CF613D Kingdom and its Cities — virtual tree (虚树) + tree DP (树形DP)

Time (时间): 2024-04-30 00:57:52

Code:

#include<bits/stdc++.h>
#define ll long long
#define maxn 300003
#define RG register
using namespace std;
// Reads one (optionally negative) decimal integer from stdin.
// Characters other than digits and '-' before the number are skipped.
// A leading '-' negates the value; if no digit follows it, 0 is returned.
// At end of input the function returns 0 (or the digits read so far) instead
// of looping forever, so truncated input cannot hang the program.
// No `register` hint: that storage class was removed in C++17.
inline int read()
{
    int value = 0;
    int sign = 1;
    int ch = getchar();  // kept as int so EOF stays distinguishable from data bytes
    while (ch != EOF && (ch < '0' || ch > '9') && ch != '-') ch = getchar();
    if (ch == '-')
    {
        sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}
// Redirects stdin to the file "<s>.in" (e.g. setIO("input") reads input.in).
// stdout is left untouched. If the file cannot be opened, a diagnostic is
// printed to stderr instead of the failure passing silently.
inline void setIO(std::string s)
{
    const std::string in = s + ".in";
    if (std::freopen(in.c_str(), "r", stdin) == nullptr) std::perror(in.c_str());
}
int edges,tim,n;
int hd[maxn], to[maxn<<1], nex[maxn<<1];
inline void add(int u,int v)
{
nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;
}
// Heavy-light decomposition data: parent, chain head, DFS entry time,
// heavy child, subtree size and depth of each node.
int fa[maxn], top[maxn], dfn[maxn], hson[maxn], siz[maxn], dep[maxn];
// First HLD pass over the original tree rooted at u (ff is u's parent, 0 for
// the root). Fills fa, dep (the root gets depth 1 because dep[0] is 0), dfn,
// siz and hson. hson is the child with the strictly largest subtree; on ties
// the earliest child in adjacency order is kept.
void dfs1(int u, int ff)
{
    fa[u] = ff;
    dep[u] = dep[ff] + 1;
    dfn[u] = ++tim;
    siz[u] = 1;
    for (int e = hd[u]; e != 0; e = nex[e])
    {
        const int child = to[e];
        if (child == ff) continue;  // skip the edge back to the parent
        dfs1(child, u);
        siz[u] += siz[child];
        if (siz[hson[u]] < siz[child]) hson[u] = child;
    }
}
void dfs2(int u,int tp)
{
top[u]=tp;
if(hson[u]) dfs2(hson[u], tp);
for(int i=hd[u];i;i=nex[i])
{
int v=to[i];
if(v==fa[u]||v==hson[u]) continue;
dfs2(v,v);
}
}
// Lowest common ancestor of x and y using the heavy-light chains from dfs2.
// While the nodes lie on different chains, lift the one whose chain head is
// deeper to just above that head; once they share a chain, the shallower of
// the two is the answer.
inline int LCA(int x, int y)
{
    for (; top[x] != top[y];)
    {
        if (dep[top[x]] > dep[top[y]])
            x = fa[top[x]];
        else
            y = fa[top[y]];
    }
    if (dep[y] <= dep[x]) return y;
    return x;
}
// Number of entries currently on the virtual-tree build stack S.
int tp;
// Child lists of the per-query virtual tree; filled by addvir, emptied by DP.
std::vector<int> G[maxn];
// arr: the current query's nodes; mk: 1 for nodes marked by the current query;
// S: build stack used by insert(); g/f: per-node values computed by DP().
int arr[maxn], mk[maxn], S[maxn], g[maxn], f[maxn];
// Sort predicate ordering nodes by DFS entry time. It returns bool, as a
// strict-weak-ordering predicate should (it previously returned int).
bool cmp(int i, int j)
{
    return dfn[i] < dfn[j];
}
// Adds the virtual-tree edge u -> v, where u is v's parent.
inline void addvir(int u, int v)
{
    G[u].push_back(v);
}
// Pushes query node x onto the virtual-tree stack S[1..tp]. Callers insert
// nodes in increasing dfn order with the root at the bottom of the stack (as
// work() does), so the stack holds a single downward path. When x is not below
// the current top, entries under LCA(x, S[tp]) are popped and linked to their
// parents with addvir. The LCA is placed on the stack if it is not already
// there, and then x is pushed.
inline void insert(int x)
{
    // With at most one entry on the stack, x can be pushed directly.
    if (tp < 2)
    {
        S[++tp] = x;
        return;
    }
    const int anc = LCA(x, S[tp]);
    if (anc != S[tp])
    {
        // While the entry beneath the top is no shallower than anc, hang the
        // top under it and pop.
        for (; tp > 1 && dep[S[tp - 1]] >= dep[anc]; --tp) addvir(S[tp - 1], S[tp]);
        // anc is not on the stack yet: it adopts the current top and replaces it.
        if (S[tp] != anc)
        {
            addvir(anc, S[tp]);
            S[tp] = anc;
        }
    }
    S[++tp] = x;
}
// Post-order DP over the virtual tree below x. It writes f[x] and g[x] and
// clears G[x] so the next query starts with empty child lists.
//   g[x]: 1 if a marked node in x's subtree is still connected up to x with
//         no cut in between, else 0.
//   f[x]: number of cuts counted inside x's subtree.
// Marked x: every child that still carries a marked node costs one cut, and
//           x itself becomes the open marked node.
// Unmarked x: two or more open children are all resolved by one cut at x;
//           a single open child stays open.
void DP(int x)
{
    int cuts = 0;
    int open = 0;
    for (const int child : G[x])
    {
        DP(child);
        cuts += f[child];
        open += g[child];
    }
    if (mk[x])
    {
        cuts += open;
        open = 1;
    }
    else if (open > 1)
    {
        ++cuts;
        open = 0;
    }
    f[x] = cuts;
    g[x] = open;
    G[x].clear();
}
// Answers one query. It reads k nodes and marks them. If any marked node's
// parent is also marked, it prints -1. Otherwise it builds the virtual tree of
// the marked nodes (plus the root and the LCAs insert() adds), runs DP from
// the root, and prints f[1]. The marks are reset on both paths.
inline void work()
{
    const int k = read();
    for (int i = 1; i <= k; ++i)
    {
        arr[i] = read();
        mk[arr[i]] = 1;
    }
    // Resets this query's marks so mk[] is all-zero for the next query.
    const auto unmark = [k]() {
        for (int j = 1; j <= k; ++j) mk[arr[j]] = 0;
    };
    std::sort(arr + 1, arr + 1 + k, cmp);
    // A marked node with a marked parent makes the answer -1.
    const bool adjacent = std::any_of(arr + 1, arr + 1 + k,
                                      [](int u) { return mk[u] && mk[fa[u]]; });
    if (adjacent)
    {
        unmark();
        printf("-1\n");
        return;
    }
    // Seed the stack with the root unless the root is itself the first query node.
    tp = 0;
    if (arr[1] != 1)
    {
        tp = 1;
        S[1] = 1;
    }
    for (int i = 1; i <= k; ++i) insert(arr[i]);
    // Link the path that is still on the stack.
    for (; tp > 1; --tp) addvir(S[tp - 1], S[tp]);
    DP(1);
    printf("%d\n", f[1]);
    unmark();
}
// Reads the tree (n nodes, n-1 undirected edges) and builds the heavy-light
// decomposition rooted at node 1. It then reads the query count and answers
// each query with work().
int main()
{
    // setIO("input");  // uncomment to read from input.in instead of stdin
    n = read();
    for (int e = 1; e < n; ++e)
    {
        const int u = read();
        const int v = read();
        add(u, v);
        add(v, u);
    }
    dfs1(1, 0);
    dfs2(1, 1);
    for (int q = read(); q > 0; --q) work();
    return 0;
}