Codeforces 446D - DZY Loves Games(高斯消元+期望 DP+矩阵快速幂)

时间:2023-03-09 00:51:25
Codeforces 446D - DZY Loves Games(高斯消元+期望 DP+矩阵快速幂)

Codeforces 题目传送门 & 洛谷题目传送门

神仙题,%%%

首先考虑所有格子都是陷阱格的情况,那显然就是一个矩阵快速幂,具体来说,设 \(f_{i,j}\) 表示走了 \(i\) 步到达 \(j\) 点的概率,那显然有 \(dp_{i+1,k}\leftarrow dp_{i,j}\times\dfrac{1}{\delta^+(j)}\)(\(j,k\) 之间有边相连),矩阵快速幂优化一下即可,最终答案即为 \(f_{k-1,n}\),时间复杂度 \(n^3\log k\)。

接下来考虑原题,注意到本题 \(n\) 的数据范围为 \(500\),\(500^3\log k\) 过不了,但是陷阱格的数据范围只有 \(101\),\(101^3\log k\) 刚好能过,这启示我们可以以陷阱格为状态采用类似的方式进行矩阵快速幂。下设陷阱格个数为 \(m\),我们将所有陷阱格编个号,\(x_1,x_2,\cdots,x_m\),我们记 \(p_{i,j}\) 表示从第 \(i\) 个陷阱格出发,在不经过其他陷阱格的情况到达第 \(j\) 个陷阱格的概率,再设 \(f_{i,j}\) 表示已经经过了 \(i\) 个陷阱格,现在在第 \(j\) 个陷阱格的概率。倘若我们已经知道了 \(p_{i,j}\),那么显然有 \(f_{i+1,k}\leftarrow f_{i,j}\times p_{j,k}\),刚好矩阵快速幂。

考虑怎样求 \(p_{i,j}\),我们考虑再令 \(g_{i,j}\) 表示当前在 \(i\) 号点,有多大概率能过从 \(i\to j\),满足 \(i=x_j\) 或者路径上除第 \(j\) 个陷阱格外不存在其他陷阱格,那么就有 \(p_{i,j}=\sum\limits_{(x_i,t)\in E}g_{t,j}\times\dfrac{1}{\delta^+(x_i)}\)。最后考虑求 \(g_{i,j}\),显然若 \(i\) 本身就是陷阱格,那只有 \(i=x_j\) 时 \(g_{i,j}=1\),其他情况下 \(g_{i,j}=0\) 了。否则枚举下一个点进行转移,即 \(g_{i,j}=\sum\limits_{(i,k)\in E}g_{k,j}\times\dfrac{1}{\delta^+(i)}\),显然这个东西不能当作 DP 方程进行转移,因为它存在后效性,但注意到这题 \(n\) 只有 \(500\),刚好放 \(n^3\) 过,因此考虑高斯消元。不过对于所有的 \(j\) 都消一遍复杂度高达 \(mn^3\),无法通过,但稍加观察即可注意到这 \(nm\) 个方程中,所有 \(i\) 相同的方程除了常数项,前面的系数没有区别,因此考虑按照 CF832E 的套路,前面的系数列不是一列而是 \(m\) 列,相当于将增广矩阵由 \(n\times (n+1)\) 变为 \(n\times (n+m)\),这样复杂度即可讲到三方级别。

总复杂度 \(n^3+m^3\log k\)。

#include <bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define fill0(a) memset(a,0,sizeof(a))
#define fill1(a) memset(a,-1,sizeof(a))
#define fillbig(a) memset(a,63,sizeof(a))
#define pb push_back
#define ppb pop_back
#define mp make_pair
template<typename T1,typename T2> void chkmin(T1 &x,T2 y){if(x>y) x=y;}
template<typename T1,typename T2> void chkmax(T1 &x,T2 y){if(x<y) x=y;}
typedef pair<int,int> pii;
typedef long long ll;
typedef unsigned int u32;
typedef unsigned long long u64;
namespace fastio{
#define FILE_SIZE 1<<23
char rbuf[FILE_SIZE],*p1=rbuf,*p2=rbuf,wbuf[FILE_SIZE],*p3=wbuf;
inline char getc(){return p1==p2&&(p2=(p1=rbuf)+fread(rbuf,1,FILE_SIZE,stdin),p1==p2)?-1:*p1++;}
inline void putc(char x){(*p3++=x);}
template<typename T> void read(T &x){
x=0;char c=getchar();T neg=0;
while(!isdigit(c)) neg|=!(c^'-'),c=getchar();
while(isdigit(c)) x=(x<<3)+(x<<1)+(c^48),c=getchar();
if(neg) x=(~x)+1;
}
template<typename T> void recursive_print(T x){if(!x) return;recursive_print(x/10);putc(x%10^48);}
template<typename T> void print(T x){if(!x) putc('0');if(x<0) putc('-'),x=~x+1;recursive_print(x);}
void print_final(){fwrite(wbuf,1,p3-wbuf,stdout);}
}
const int MAXN=500;
const int MAXM=101;
const int TOT=601;
int n,m,k,deg[MAXN+5],num[MAXN+5][MAXN+5];
bool is[MAXN+5];int rm[MAXN+5],id[MAXN+5],cnt=0;
double a[MAXN+5][TOT+5],f[MAXN+5][MAXM+5];
struct mat{
double a[MAXM+5][MAXM+5];
mat(){for(int i=1;i<=MAXM;i++) for(int j=1;j<=MAXM;j++) a[i][j]=0;}
mat operator *(const mat &rhs){
mat ret;
for(int i=1;i<=cnt;i++) for(int j=1;j<=cnt;j++)
for(int l=1;l<=cnt;l++) ret.a[i][j]+=a[i][l]*rhs.a[l][j];
return ret;
}
};
int main(){
scanf("%d%d%d",&n,&m,&k);
for(int i=1;i<=n;i++){
scanf("%d",&is[i]);
if(is[i]) rm[++cnt]=i,id[i]=cnt;
}
for(int i=1,u,v;i<=m;i++){
scanf("%d%d",&u,&v);deg[u]++;deg[v]++;
num[u][v]++;num[v][u]++;
}
// for(int i=1;i<=n;i++) printf("%d\n",deg[i]);
for(int i=1;i<=n;i++){
if(!is[i]){
a[i][i]=-1;
for(int j=1;j<=n;j++) if(i^j)
a[i][j]=1.0*num[i][j]/deg[i];
} else {
a[i][i]=a[i][id[i]+n]=1;
}
}
// for(int i=1;i<=n;i++) for(int j=1;j<=n+cnt;j++)
// printf("%.10lf%c",a[i][j],(j==n+cnt)?'\n':' ');
for(int i=1;i<=n;i++){
int pos=i;
for(int j=i+1;j<=n;j++) if(fabs(a[j][i])>fabs(a[pos][i])) pos=j;
for(int j=i;j<=n+cnt;j++) swap(a[pos][j],a[i][j]);
for(int j=i+1;j<=n+cnt;j++) a[i][j]/=a[i][i];a[i][i]=1;
for(int j=i+1;j<=n;j++){
for(int k=i+1;k<=n+cnt;k++) a[j][k]-=a[i][k]*a[j][i];
a[j][i]=0;
}
}
for(int i=1;i<=cnt;i++) for(int j=n;j;j--){
f[j][i]=a[j][i+n];
for(int k=j+1;k<=n;k++) f[j][i]-=a[j][k]*f[k][i];
// printf("%d %d %.10lf\n",j,i,f[j][i]);
} mat trs,ini,mul;
for(int i=1;i<=cnt;i++) for(int j=1;j<=cnt;j++){
for(int l=1;l<=n;l++) trs.a[j][i]+=1.0*num[rm[i]][l]/deg[rm[i]]*f[l][j];
// printf("%d %d %.10lf\n",i,j,trs.a[i][j]);
}
for(int i=1;i<=cnt;i++) ini.a[i][1]=f[1][i];
k-=2;for(int i=1;i<=cnt;i++) mul.a[i][i]=1;
for(;k;k>>=1,trs=trs*trs) if(k&1) mul=mul*trs;
ini=mul*ini;printf("%.10lf\n",ini.a[cnt][1]);
return 0;
}