Codeforces Round #554 (Div. 2) F2. Neko Rules the Catniverse (Large Version) (矩阵快速幂 状压DP)

时间:2024-01-06 12:03:02

题意

有nnn个点,每个点只能走到编号在[1,min(n+m,1)][1,min(n+m,1)][1,min(n+m,1)]范围内的点。求路径长度恰好为kkk的简单路径(一个点最多走一次)数。

1≤n≤109,1≤m≤4,1≤k≤min(n,12)1\le n\le 10^9,1\le m\le 4,1\le k\le min(n,12)1≤n≤109,1≤m≤4,1≤k≤min(n,12)

分析

直接考虑走路径的话不能判有没有走过,然后就把路径转化为一个序列,每次往里面插入新的点(神了)。因为一个点可以走到比他小的所有点,那么我们把点从大到小插入。

假设现在已有序列为p1,p2,p3,...,pkp_1,p_2,p_3,...,p_kp1​,p2​,p3​,...,pk​。那么当前插入一个点iii。

假设插在pjp_jpj​和pj+1p_{j+1}pj+1​之间,必须满足pjp_jpj​能走到iii并且iii能走到pj+1p_{j+1}pj+1​。由于iii是最小的,那么所有pjp_jpj​都能走到iii,所以只用考虑iii能走到哪些点。

  • 一种情况是直接放在最后。
  • 另一种情况是i+m≥pj+1i+m\ge p_{j+1}i+m≥pj+1​。那么满足这个式子的pj+1p_{j+1}pj+1​最多有m(≤4)m(\le 4)m(≤4)个。那么就把[i+1,i+m][i+1,i+m][i+1,i+m]这mmm个数有没有出现在序列中过状压成mmm位222进制数,记为SSS。当前方案就是bitcount(S)bitcount(S)bitcount(S)(SSS在222进制中有多少个111)。

那么一共就有bitcount(S)+1bitcount(S)+1bitcount(S)+1种方案。

另外还可以不插入。

所以DPDPDP状态设为f[i][j][S]f[i][j][S]f[i][j][S]表示到iii这个点,序列长度为jjj,上述状态为SSS的方案数,转移方程为:

f[i+1][j+1][(S<<1∣1)&(2m−1)]+=f[i][j][S]∗(bitcount(S)+1)f[i+1][j][(S<<1∣0)&(2m−1)]+=f[i][j][S]\begin{aligned}
f[i+1][j+1][(S<<1|1)\&(2^m-1)]&+=f[i][j][S]*(bitcount(S)+1)\\
f[i+1][j][(S&lt;&lt;1|0)\&amp;(2^m-1)]&amp;+=f[i][j][S]\end{aligned}f[i+1][j+1][(S<<1∣1)&(2m−1)]f[i+1][j][(S<<1∣0)&(2m−1)]​+=f[i][j][S]∗(bitcount(S)+1)+=f[i][j][S]​

由于nnn比较大,就矩阵加速就行了。这个矩阵快速幂还是挺好写的。。

CODE

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 100005;
const int mod = 1e9 + 7;
int n, k, m, all;
inline void add(int &x, int y) { x += y; if(x >= mod) x -= mod; }
struct mat {
int a[210][210]; //最大状态数为(k+1)*(1<<m)<=(12+1)*(2^4)=208
mat() { memset(a, 0, sizeof a); }
inline mat operator *(const mat &o)const {
mat re;
for(int k = 0; k < all; ++k)
for(int i = 0; i < all; ++i) if(a[i][k])
for(int j = 0; j < all; ++j) if(o.a[k][j])
add(re.a[i][j], 1ll * a[i][k] * o.a[k][j] % mod);
return re;
}
inline mat operator ^(int b)const {
mat re, A = *this;
for(int i = 0; i < all; ++i) re.a[i][i] = 1;
while(b) {
if(b & 1) re = re * A;
A = A * A; b >>= 1;
}
return re;
}
};
inline int enc(int K, int S) { return K*(1<<m) + S; }
inline int nxt(int S, bool x) { return ((S<<1)|x) & ((1<<m)-1); }
int cnt[16];
int main () {
scanf("%d%d%d", &n, &k, &m);
all = (k+1)*(1<<m); //所有状态数
mat trans, ans;
ans.a[0][enc(0, 0)] = 1;
for(int s = 1; s < (1<<m); ++s) cnt[s] = cnt[s>>1] + (s&1); //预处理2进制下有多少个1
for(int i = 0; i <= k; ++i)
for(int s = 0; s < (1<<m); ++s) {
if(i < k) trans.a[enc(i, s)][enc(i+1, nxt(s, 1))] = cnt[s]+1;
trans.a[enc(i, s)][enc(i, nxt(s, 0))] = 1;
}
ans = ans * (trans ^ n);
int Ans = 0;
for(int s = 0; s < (1<<m); ++s)
add(Ans, ans.a[0][enc(k, s)]);
printf("%d\n", (Ans + mod) % mod);
}