【BZOJ4042】【CERC2014】parades 状压DP

时间:2023-03-09 04:07:29
【BZOJ4042】【CERC2014】parades 状压DP

题目大意

  给你一棵\(n\)个点的树和\(m\)条路径要求你找出最多的路径,使得这些路径不共边。特别的,每个点的度数\(\leq 10\)。

  \(n\leq 1000,m\leq \frac{n(n-1)}{2}\)

题解

  先对于每个点把相邻的边编号。

  考虑状压DP。

  设\(f_{i,j}\)为以\(i\)个点的子树内,状态为\(j\)的边的子树内的边也没有选(这些边也没选),所选的最多路径数。

  然后对于每个点有很多种选法,分为两类:1.某条边不选,选对应的子树;2.选\(1\)~\(2\)条边和对应的路径,那么这些路径经过的边都不能选。

  然后直接状压DP。

  对于每个点来说,总共有最多\(O(d^2)\)种转移。考虑每个儿子的贡献,就是\(O(d)\)。

  时间复杂度:\(O(n^2+nd2^d)\)

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
struct list
{
int t[1000010];
pii v[1000010];
int h[1010];
int n;
void clear()
{
memset(h,0,sizeof h);
n=0;
}
void add(int x,pii y)
{
n++;
v[n]=y;
t[n]=h[x];
h[x]=n;
}
};
list l;
int f[1010][1<<10];
int g[1010];
int c[1010][20];
int d[1010];
int ns[12][12];
int e[1010];
void dfs2(int x,int fa,int t,int s)
{
int fc;
int i;
for(i=1;i<=d[x];i++)
if(c[x][i]==fa)
fc=i;
g[x]=s+f[x][((1<<d[x])-1)^(1<<(fc-1))];
e[x]=t;
for(i=1;i<=d[x];i++)
if(c[x][i]!=fa)
dfs2(c[x][i],x,t,s+f[x][((1<<d[x])-1)^(1<<(fc-1))^(1<<(i-1))]);
}
int dd[1010];
int ff[1010];
int lca[1010][1010];
void dfs(int x,int fa,int dep)
{
ff[x]=fa;
dd[x]=dep;
int i;
for(i=1;i<=d[x];i++)
if(c[x][i]!=fa)
dfs(c[x][i],x,dep+1);
}
int getlca(int x,int y)
{
if(x==y)
return x;
if(lca[x][y])
return lca[x][y];
if(dd[x]>dd[y])
return lca[x][y]=getlca(ff[x],y);
return lca[x][y]=getlca(x,ff[y]);
}
void dp(int x,int fa)
{
int i;
for(i=1;i<=d[x];i++)
if(c[x][i]!=fa)
dp(c[x][i],x);
for(i=1;i<=d[x];i++)
if(c[x][i]!=fa)
dfs2(c[x][i],x,i,0);
memset(ns,0,sizeof ns);
int cx,cy,cs;
for(i=l.h[x];i;i=l.t[i])
{
if(l.v[i].first==x)
{
cx=0;
cy=e[l.v[i].second];
cs=g[l.v[i].second];
}
else if(l.v[i].second==x)
{
cx=e[l.v[i].first];
cy=0;
cs=g[l.v[i].first];
}
else
{
cx=e[l.v[i].first];
cy=e[l.v[i].second];
cs=g[l.v[i].first]+g[l.v[i].second];
}
cs++;
if(cx>cy)
swap(cx,cy);
ns[cx][cy]=max(ns[cx][cy],cs);
}
for(i=1;i<=d[x];i++)
if(c[x][i]!=fa)
{
cx=0;
cy=i;
cs=f[c[x][i]][(1<<d[c[x][i]])-1];
ns[cx][cy]=max(ns[cx][cy],cs);
}
int j,k;
for(i=0;i<=d[x];i++)
for(j=0;j<=d[x];j++)
if(ns[i][j])
{
int s=0;
if(i)
s|=1<<(i-1);
if(j)
s|=1<<(j-1);
for(k=0;k<(1<<d[x]);k++)
if(!(k&s))
f[x][k|s]=max(f[x][k|s],f[x][k]+ns[i][j]);
}
}
void solve()
{
memset(d,0,sizeof d);
int n;
scanf("%d",&n);
int i,j;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
lca[i][j]=0;
for(i=1;i<=n;i++)
for(j=0;j<(1<<10);j++)
f[i][j]=0;
l.clear();
int x,y;
for(i=1;i<=n-1;i++)
{
scanf("%d%d",&x,&y);
c[x][++d[x]]=y;
c[y][++d[y]]=x;
}
dfs(1,0,1);
int m;
scanf("%d",&m);
for(i=1;i<=m;i++)
{
scanf("%d%d",&x,&y);
l.add(getlca(x,y),pii(x,y));
}
dp(1,0);
int ans=0;
for(i=1;i<=n;i++)
ans=max(ans,f[i][(1<<d[i])-1]);
printf("%d\n",ans);
}
int main()
{
#ifdef DEBUG
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
#endif
int t;
scanf("%d",&t);
while(t--)
solve();
return 0;
}