hdu1024 dp

时间:2022-01-02 22:47:21

题意:求一个序列中的最大 m 段和,m 段不能交叉。

dp[i][0/1][j] 表示已经取完第 i 个物品,第 i 个物品取或不取,取到第 j 个子段。

用vis[i][0/1][j] 表示该 dp 值是否存在。

然后当 vis[i][0][j] 存在,即第 i 个物品不取,之前已经取了 j 个子段,可推得:

  第 i+1 个不取: dp[i+1][0][j]=max(dp[i+1][0][j],dp[i][0][j]);

  第 i+1 个取: dp[i+1][1][j+1]=max(dp[i+1][1][j+1],dp[i][0][j]+a[i]);

当 vis[i][1][j] 存在,即第 i 个物品取,之前已经取了 j 个子段(第 j 段可能还没有取完),可推得:

  第 i+1 个不取: dp[i+1][0][j]=max(dp[i+1][0][j],dp[i][1][j]);

  第 i+1 个取且放在第 j 个子段中: dp[i+1][1][j]=max(dp[i+1][1][j],dp[i][1][j]+a[i]);

  第 i+1 个取且放在第 j+1 个子段中: dp[i+1][1][j+1]=max(dp[i+1][1][j+1],dp[i][1][j]+a[i]);

然后初始化 dp[1][1][1]=a[1],dp[1][0][0]=0;

由于直接开 n*2*m 会MLE,所以将第一维滚动,2*2*m 就完全没有问题,复杂度 O(n*m);

 #include<stdio.h>
#include<string.h>
#include<algorithm>
using namespace std; typedef long long ll;
const int maxn=1e6+; int a[maxn];
int dp[][][];
bool vis[][][]; int main(){
int m,n;
while(scanf("%d%d",&m,&n)!=EOF){
for(int i=;i<=n;++i)scanf("%d",&a[i]);
memset(dp,,sizeof(dp));
memset(vis,,sizeof(vis));
dp[][][]=a[];
vis[][][]=;
dp[][][]=;
vis[][][]=;
for(int k=;k<n;++k){
int i=k&;
memset(vis[i^],,sizeof(vis[i^]));
for(int j=;j<=m;++j){
if(vis[i][][j]){
if(!vis[i^][][j]){
vis[i^][][j]=;
dp[i^][][j]=dp[i][][j];
}
else if(dp[i][][j]>dp[i^][][j])dp[i^][][j]=dp[i][][j];
if(!vis[i^][][j+]){
vis[i^][][j+]=;
dp[i^][][j+]=dp[i][][j]+a[k+];
}
else if(dp[i][][j]+a[k+]>dp[i^][][j+])dp[i^][][j+]=dp[i][][j]+a[k+];
}
if(vis[i][][j]){
if(!vis[i^][][j]){
vis[i^][][j]=;
dp[i^][][j]=dp[i][][j];
}
else if(dp[i][][j]>dp[i^][][j])dp[i^][][j]=dp[i][][j];
if(!vis[i^][][j]){
vis[i^][][j]=;
dp[i^][][j]=dp[i][][j]+a[k+];
}
else if(dp[i][][j]+a[k+]>dp[i^][][j])dp[i^][][j]=dp[i][][j]+a[k+];
if(!vis[i^][][j+]){
vis[i^][][j+]=;
dp[i^][][j+]=dp[i][][j]+a[k+];
}
else if(vis[i^][][j+]&&dp[i][][j]+a[k+]>dp[i^][][j+])dp[i^][][j+]=dp[i][][j]+a[k+];
}
}
}
int ans=-0x3f3f3f3f;
if(vis[n&][][m])ans=max(ans,dp[n&][][m]);
if(vis[n&][][m])ans=max(ans,dp[n&][][m]);
printf("%d\n",ans);
}
return ;
}