【BZOJ5287】[HNOI2018]毒瘤(动态规划,容斥)

时间:2023-03-08 22:26:41

【BZOJ5287】[HNOI2018]毒瘤(动态规划,容斥)

题面

BZOJ

洛谷

题解

考场上想到的暴力做法是容斥:

因为\(m-n\le 10\),所以最多会多出来\(11\)条非树边。

如果就是一棵树的话,显然答案就是独立集的个数。

非树边\(2^{11}\)枚举,强制非树边的两端同时备选导致不合法,容斥计算答案即可。

这样子的复杂度是\(O(2^{11}n)\),估算出来是\(2s\),然而在\(HNOI\)考场跑要\(20s\)(大雾

考虑如何优化这个东西。

我们\(2^{11}\)枚举出来之后,显然是强制令枚举的非树边的两端都被选入进集合。但是我们并不需要每次重新做一遍\(dp\),显然会出现大量的重复计算内容。

把枚举的点的虚树给构建出来,显然会影响到的部分只有虚树上的点和链。

对于每个虚树上的点,考虑修改后对于其虚树上父亲的影响。

因为\(dp\)状态是\(f[i][0/1]\),所以可以把关键点的状态设为\(x,y\),到虚树上父亲的链的转移全部用\(x,y\)的形式转移,这样子到其父亲时就可以合并一堆\(x,y\)的状态,当确定所有\(x,y\)后就能确定所有虚树上的关键点的\(dp\)值。

这样子单次容斥的复杂度就变成了虚树点数,这个东西很小。

这是一个很类似于动态\(dp\)的思路。

代码有点丑

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
using namespace std;
#define ll long long
#define MAX 100100
#define MOD 998244353
#define pb push_back
inline int read()
{
int x=0;bool t=false;char ch=getchar();
while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
if(ch=='-')t=true,ch=getchar();
while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
return t?-x:x;
}
int fpow(int a,int b)
{
int s=1;
while(b){if(b&1)s=1ll*s*a%MOD;a=1ll*a*a%MOD;b>>=1;}
return s;
}
int dsu[MAX];
int getf(int x){return x==dsu[x]?x:dsu[x]=getf(dsu[x]);}
struct Line{int v,next;}e[MAX<<1];
int h[MAX],cnt=1;
inline void Add(int u,int v){e[cnt]=(Line){v,h[u]};h[u]=cnt++;}
int n,m,ans,f[MAX][2],zr[MAX][2],g[MAX][2];
int fa[MAX],dfn[MAX],tim,size[MAX],hson[MAX],top[MAX],dep[MAX],low[MAX];
void dfs1(int u,int ff)
{
f[u][0]=f[u][1]=1;fa[u]=ff;dep[u]=dep[ff]+1;size[u]=1;
for(int i=h[u];i;i=e[i].next)
{
int v=e[i].v;if(v==ff)continue;
dfs1(v,u);size[u]+=size[v];
if(size[v]>size[hson[u]])hson[u]=v;
if((f[v][0]+f[v][1])%MOD)f[u][0]=1ll*f[u][0]*(f[v][0]+f[v][1])%MOD;else zr[u][0]+=1;
if(f[v][0])f[u][1]=1ll*f[u][1]*f[v][0]%MOD;else zr[u][1]+=1;
}
}
void dfs2(int u,int tp)
{
top[u]=tp;dfn[u]=++tim;
if(hson[u])dfs2(hson[u],tp);
for(int i=h[u];i;i=e[i].next)
if(e[i].v!=fa[u]&&e[i].v!=hson[u])
dfs2(e[i].v,e[i].v);
low[u]=tim;
}
int LCA(int u,int v)
{
while(top[u]^top[v])dep[top[u]]<dep[top[v]]?v=fa[top[v]]:u=fa[top[u]];
return dep[u]<dep[v]?u:v;
}
bool cmp(int a,int b){return dfn[a]<dfn[b];}
int S[MAX],Top,snt;bool spn[MAX];
struct data{int x,y;}nt[50];
data operator*(data a,int b){return (data){1ll*a.x*b%MOD,1ll*a.y*b%MOD};}
data operator+(data a,data b){return (data){(a.x+b.x)%MOD,(a.y+b.y)%MOD};}
vector<int> fr[MAX];
vector<data> F0[MAX],F1[MAX];
int Q[MAX],tot;
int Div(int i,int p,int j)
{
if(j)return zr[i][p]?0:1ll*f[i][p]*fpow(j,MOD-2)%MOD;
else return zr[i][p]==1?f[i][p]:0;
}
void Calc(int x,int y)
{
data f0=(data){1,0},f1=(data){0,1},ff0,ff1;
int p=x;
for(int i=fa[x],j=x;i!=y;p=j=i,i=fa[i])
{
int F0=Div(i,0,(f[j][0]+f[j][1])%MOD),F1=Div(i,1,f[j][0]);
ff0=(f0+f1)*F0;ff1=f0*F1;
f0=ff0;f1=ff1;
}
fr[y].pb(x);F0[y].pb(f0);F1[y].pb(f1);
int a=(f[p][0]+f[p][1])%MOD,b=f[p][0];
if(a)f[y][0]=1ll*f[y][0]*fpow(a,MOD-2)%MOD;else zr[y][0]-=1;
if(b)f[y][1]=1ll*f[y][1]*fpow(b,MOD-2)%MOD;else zr[y][1]-=1; }
bool Vis[MAX];
int DP()
{
for(int i=Top;i;--i)g[S[i]][0]=zr[S[i]][0]?0:f[S[i]][0],g[S[i]][1]=zr[S[i]][1]?0:f[S[i]][1];
for(int i=Top;i;--i)
if(Vis[S[i]])g[S[i]][0]=0;
for(int i=Top;i;--i)
for(int j=0,l=fr[S[i]].size();j<l;++j)
{
int u=S[i],v=fr[u][j];
data f0=F0[u][j],f1=F1[u][j];
int ff0=(1ll*f0.x*g[v][0]+1ll*f0.y*g[v][1])%MOD;
int ff1=(1ll*f1.x*g[v][0]+1ll*f1.y*g[v][1])%MOD;
g[u][0]=1ll*g[u][0]*(ff0+ff1)%MOD;
g[u][1]=1ll*g[u][1]*ff0%MOD;
}
return (g[1][0]+g[1][1])%MOD;
}
int main()
{
n=read();m=read();
for(int i=1;i<=n;++i)dsu[i]=i;
for(int i=1;i<=m;++i)
{
int u=read(),v=read();
if(getf(u)==getf(v))S[++Top]=u,S[++Top]=v,nt[snt++]=(data){u,v};
else Add(u,v),Add(v,u),dsu[getf(u)]=getf(v);
}
dfs1(1,0);dfs2(1,1);S[++Top]=1;
sort(&S[1],&S[Top+1],cmp);
for(int i=Top;i>1;--i)S[++Top]=LCA(S[i],S[i-1]);
sort(&S[1],&S[Top+1],cmp);Top=unique(&S[1],&S[Top+1])-S-1;
for(int i=1;i<=Top;++i)spn[S[i]]=true;
Q[tot=1]=S[1];
for(int i=2;i<=Top;++i)
{
while(!(dfn[Q[tot]]<=dfn[S[i]]&&dfn[S[i]]<=low[Q[tot]]))--tot;
Calc(S[i],Q[tot]);Q[++tot]=S[i];
}
for(int i=0;i<1<<snt;++i)
{
int d=1;
for(int j=0;j<snt;++j)
if(i&(1<<j))
Vis[nt[j].x]=Vis[nt[j].y]=true,d^=1;
int ret=DP();
if(d)ans=(ans+ret)%MOD;
else ans=(ans+MOD-ret)%MOD;
for(int j=0;j<snt;++j)Vis[nt[j].x]=Vis[nt[j].y]=false;
}
printf("%d\n",ans);
return 0;
}