AtCoder AGC019E Shuffle and Swap (DP、FFT、多项式求逆、多项式快速幂)

时间:2022-06-11 09:16:43

题目链接

https://atcoder.jp/contests/agc019/tasks/agc019_e

题解

tourist的神仙E题啊做不来做不来……这题我好像想歪了啊= =……

首先我们可以考虑,什么样的操作序列才是合法的?

有用的位置只有两种,一种是两个序列在这个位置上都是1, 称作11型,另一种是一个0一个1, 称作01型。设两种位置分别有\(A\)个和\(2B\)个。

考虑一个操作序列,交换两个11型相当于没交换,每个11型只会被交换两次,每个01型只会被交换一次。这也就是说,如果我们从shuffle之后的\(a_i\)向\(b_i\)连边,那么形成的图一定是若干个环加上\(m\)条链,链的开头结尾都是01型,中间是11型。对于链来说,上面操作的顺序必须固定;对于环来说,上面操作的顺序可以任意。

下面有两种处理方式。

做法一

设\(dp[i][j]\)表示把\(j\)个无标号的11型放到\(i\)条链中,可得DP式: \(dp[i][j]=\sum_{k\ge 0}\frac{dp[i-1][j-k]}{(k+1)!}\), 其中分母的含义是链上\((k+1)\)个点顺序固定,最后的答案是\(A!B!(A+B)!\sum^A_{i=0}dp[B][i]\). \((A+B)!\)表示将边随意排序,\(A!B!\)表示11型和01型点之间是有标号的。

时间复杂度\(O(n^3)\).

但是我们发现这个DP就相当于在给多项式\(\sum_{n\ge 0}\frac{1}{(n+1)!}x^n\)进行幂运算,于是用多项式快速幂加速即可,时间复杂度\(O(n\log^2n)\)或\(O(n\log n)\).

做法二

有没有聪明一点的做法?有!

设\(dp[i][j]\)表示目前一共放了\(i\)个11型和\(j\)个01型链(考虑已经放了的元素的标号,但是每次仅仅是往右添加),我们强行转移!

\(dp[i][j]=j^2\times dp[i][j-1]+ij\times dp[i-1][j]\)

前一个式子是要加一个新的01型链,选两个01型;后一个式子是要选一个链,再把这条链的结尾端点任意扩展一个位置。

答案就是\(\sum_{k}{A\choose k}\times f[k][B]\times ((A-k)!)^2\times {A+B\choose A-k}\), 其中\(A+B\choose A-k\)是选出位置,\(A\choose k\)是选出编号,\(((A-k)!)^2\)是求出组成环的方案数。

时间复杂度\(O(n^2)\), 可以通过。

但是我们发现这个DP还可以用多项式优化!

令\(g[i][j]=\frac{dp[i][j]}{(j!)^2i!}\), 显然有\(g[i][j]=g[i][j-1]+j\times g[i-1][j]\)

然后这个使用NE Lattice Path的方式来理解,就是从\((0,0)\)走到\((i,j)\),每往上走一次路径权值乘上横坐标,求所有路径权值和。

考虑另一种DP,枚举在第\(i\)列走几步,那么发现第\(i\)列的生成函数就是\(\frac{1}{1-ix}\), 然后答案就是所有列生成函数之积

于是可以分治NTT+多项式求逆计算,时间复杂度\(O(n\log^2n)\).

代码

做法二\(O(n^2)\)

#include<cstdio>
#include<cstdlib>
#include<iostream>
#include<cassert>
#include<cstring>
#define llong long long
using namespace std; const int N = 2e4;
const int P = 998244353;
const llong INV2 = 499122177;
llong fact[N+3],finv[N+3]; llong quickpow(llong x,llong y)
{
llong cur = x,ret = 1ll;
for(int i=0; y; i++)
{
if(y&(1ll<<i)) {y-=(1ll<<i); ret = ret*cur%P;}
cur = cur*cur%P;
}
return ret;
}
llong mulinv(llong x) {return quickpow(x,P-2);}
llong comb(llong x,llong y) {return x<0||y<0||x<y ? 0ll : fact[x]*finv[y]%P*finv[x-y]%P;} int n,a,b;
char s[N+3],t[N+3];
llong dp[2][N+3]; int main()
{
fact[0] = 1ll; for(int i=1; i<=N; i++) fact[i] = fact[i-1]*i%P;
finv[N] = quickpow(fact[N],P-2); for(int i=N-1; i>=0; i--) finv[i] = finv[i+1]*(i+1)%P;
scanf("%s%s",s+1,t+1); n = strlen(s+1);
for(int i=1; i<=n; i++)
{
if(s[i]=='1' && t[i]=='1') {a++;}
else if(s[i]^t[i]) {b++;}
}
b>>=1;
int cur = 0,prv = 1;
dp[0][0] = 1ll;
for(int j=1; j<=b; j++)
{
cur^=1; prv^=1;
dp[cur][0] = dp[prv][0]*j*j%P;
for(int i=1; i<=a; i++)
{
dp[cur][i] = (dp[prv][i]*j*j+dp[cur][i-1]*i*j)%P;
// printf("dp[%d][%d]=%lld\n",i,j,dp[cur][i]);
}
}
llong ans = 0ll;
for(int k=0; k<=a; k++)
{
ans = (ans+dp[cur][k]*comb(a,k)%P*fact[a-k]%P*fact[a-k]%P*comb(a+b,a-k))%P;
// printf("k%d ans%lld\n",k,ans);
}
printf("%lld\n",ans);
return 0;
}