P2515 [HAOI2010]软件安装

时间:2023-03-09 03:01:43
P2515 [HAOI2010]软件安装

树形背包

 #include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<vector>
#define MAXN 110
#define MAXM 510
#define pii pair<int,int>
#define rint register int
#define mp make_pair
#define pb push_back
using namespace std;
int read(){
int x=,f=;char ch=getchar();
while(ch<''||ch>''){if('-'==ch)f=-;ch=getchar();}
while(ch>=''&&ch<=''){x=x*+ch-'';ch=getchar();}
return x*f;
}
int tot,m;
int C[MAXN],V[MAXN];
vector<int> G[MAXN];
int dfn[MAXN],low[MAXN],sta[MAXN],b[MAXN],top,idx;
int c[MAXN],v[MAXN],cmp[MAXN],n;
int fst[MAXN],nxt[MAXN],from[MAXN],to[MAXN],cnt,pin[MAXN];
void add(int x,int y){
nxt[++cnt]=fst[x],fst[x]=cnt,from[cnt]=x,to[cnt]=y;
pin[y]++;
}
void tarjan(int x){
dfn[x]=low[x]=(++idx);
sta[++top]=x,b[x]=;
int y;
for(rint i=;i<G[x].size();i++){
y=G[x][i];
if(!dfn[y]){
tarjan(y);
low[x]=min(low[x],low[y]);
}
else if(b[y]){
low[x]=min(low[x],dfn[y]);
}
}
if(dfn[x]==low[x]){
n++;
while(sta[top+]!=x){
c[n]+=C[sta[top]];
v[n]+=V[sta[top]];
cmp[sta[top]]=n;
b[sta[top]]=;
top--;
}
}
}
int f[MAXN][MAXM],dep[MAXN];
int H[MAXM],H1[MAXM],H2[MAXM];
void work(int x,int vl){
if(vl<c[x]){f[x][vl]=;return;}
vl-=c[x];
memset(H,,sizeof(H));
memset(H1,,sizeof(H1));
int y,r=;
for(rint e=fst[x];e;e=nxt[e]){
y=to[e];
for(rint i=;i<=vl;i++){
H[i]=max(H[i],H1[i]+f[y][vl-i]);
r=max(r,H[i]);
}
memcpy(H1,H,sizeof(H));
}
f[x][vl+c[x]]=r+v[x];
}
void dp(int x){
b[x]=;
int y;
for(rint e=fst[x];e;e=nxt[e]){
y=to[e];
if(!b[y]){
dp(y);
for(rint j=m-c[x];j>=;j--){
for(rint k=;k<=j;k++){
f[x][j]=max(f[x][j],f[x][k]+f[y][j-k]);
}
}
}
}
for(rint j=m;j>=;j--){
if(j>=c[x]){
f[x][j]=f[x][j-c[x]]+v[x];
}
else{
f[x][j]=;
}
}
}
int main()
{
// freopen("data.in","r",stdin);
tot=read();m=read();
for(rint i=;i<=tot;i++)C[i]=read();
for(rint i=;i<=tot;i++)V[i]=read();
int tmp;
for(rint i=;i<=tot;i++){
tmp=read();
if(tmp)G[tmp].pb(i);
}
for(rint i=;i<=tot;i++){
if(!dfn[i])tarjan(i);
}
for(rint i=;i<=tot;i++){
for(rint j=;j<G[i].size();j++){
int x=i,y=G[i][j];
if(cmp[x]!=cmp[y]){
add(cmp[x],cmp[y]);
}
}
}
for(rint i=;i<=n;i++){
if(!pin[i]){
add(,i);
}
}
// for(rint i=1;i<=cnt;i++){
// printf("%d %d\n",from[i],to[i]);
// }
memset(b,,sizeof(b));
dp();
printf("%d\n",f[][m]);
return ;
}