Codeforces 1085G(1086E) Beautiful Matrix $dp$+树状数组

时间:2022-05-06 05:59:46

题意

定义一个\(n*n\)的矩阵是\(beautiful\)的,需要满足以下三个条件:

1.每一行是一个排列。

2.上下相邻的两个元素的值不同。

再定义两个矩阵的字典序大的矩阵大(从左往右从上到下一个一个比较)。

给出一个\(beautiful\)的\(n*n\)的矩阵,求有多少个矩阵小于这个矩阵。

Solution

逐行计算。

\(ans=\)每行字典序比这行小的排列且与上一行相邻的两个元素值不同的排列个数*\(n\)个元素错排的方案数\(^{n-i}\)

第一行的方案数随便算,我就不说了。

另外的行大概就是逐位算。

从后往前枚举前\(i\)个数相同,树状数组维护当前位置可以填的数有几个有限制(即上一行后\(n-i+1\)中有这个数)和当前能填哪些数(即比\(a_{i,j}\)小且在当前行后\(n-i+1\)个数中出现了),不难发现有限制的数或者没限制的数都是同质的,那么就答案就是方案数乘上数的个数,问题就是有几个数有限制的错排怎么算方案数?\(dp\)一下就好了。

设\(dp_{i,j}\)表示\(i\)个数中有\(j\)个数有限制的排列的方案数。

考虑从\(dp_{i,j-1}\)转移,减去多了一个限制的数会少的方案数。

多了一个限制的数不合法的方案数?那我们就强制多的那个数不符合限制,另外数符合限制,也就是\(dp_{i-1,j-1}\)。

\(dp_{i,j}=dp_{i,j-1}-dp_{i-1,j-1}\)

如果不会推,也可以打表

\(dp_{n,n}\)的值就是\(n\)个数错排的方案数。

#include<bits/stdc++.h>
#define For(i,x,y) for (register int i=(x);i<=(y);i++)
#define Dow(i,x,y) for (register int i=(x);i>=(y);i--)
#define cross(i,k) for (register int i=first[k];i;i=last[i])
#define Debug(x) cerr<<#x<<"="<<(x)<<endl
#define mp make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pa;
inline ll read(){
ll x=0;int ch=getchar(),f=1;
while (!isdigit(ch)&&(ch!='-')&&(ch!=EOF)) ch=getchar();
if (ch=='-'){f=-1;ch=getchar();}
while (isdigit(ch)){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
return x*f;
}
const int N = 2010;
int n,a[N][N];
const int mod = 998244353;
int fac[N],dp[N][N],p[N];
inline void init(){
fac[0]=1;For(i,1,n) fac[i]=1ll*fac[i-1]*i%mod;
dp[1][0]=1;
For(i,2,n){
dp[i][0]=fac[i];
For(j,1,i) dp[i][j]=(dp[i][j-1]-dp[i-1][j-1]+mod)%mod;
}
p[0]=1;For(i,1,n) p[i]=1ll*p[i-1]*dp[n][n]%mod;
}
struct BIT{
int c[N],sum;
inline void clear(){sum=0,memset(c,0,sizeof c);}
inline void Add(int x){sum++;for (;x<=n;x+=x&-x) c[x]++;}
inline int Query(int x){int ans=0;for (;x;x-=x&-x) ans+=c[x];return ans;}
}t,T;
int b[N],ans;
inline void Add(int x){if (++b[x]==2) T.Add(x);}
inline void upd(int &x,int y){x+=y,(x>=mod)?x-=mod:0;}
int main(){
n=read();
For(i,1,n) For(j,1,n) a[i][j]=read();
init();int sum=0;
For(i,1,n) upd(sum,1ll*fac[n-i]*(a[1][i]-1-t.Query(a[1][i]-1))%mod),t.Add(a[1][i]);
ans=1ll*sum*p[n-1]%mod;//printf("%d\n",ans);
For(i,2,n){
t.clear(),T.clear(),sum=0,memset(b,0,sizeof b);
Dow(j,n,1){
Add(a[i][j]),Add(a[i-1][j]),t.Add(a[i][j]);
int x=T.Query(a[i][j]-1),y=t.Query(a[i][j]-1)-x,z=T.sum;
if (b[a[i-1][j]]==2&&a[i-1][j]<a[i][j]) x--;
if (b[a[i-1][j]]==2) z--;
upd(sum,1ll*x*dp[n-j][z-1]%mod),upd(sum,1ll*y*dp[n-j][z]%mod);
//printf("%d %d ",z,x*dp[n-j][z-1]);
}//puts("");
upd(ans,1ll*sum*p[n-i]%mod);
}
printf("%d\n",ans);
}