2016ACM/ICPC亚洲区沈阳站H - Guessing the Dice Roll HDU - 5955 ac自动机+概率dp+高斯消元

时间:2024-01-02 16:08:38

http://acm.hdu.edu.cn/showproblem.php?pid=5955

题意:给你长度为l的n组数,每个数1-6,每次扔色子,问你每个串第一次被匹配的概率是多少

题解:先建成ac自动机构造fail数组,然后因为fail指针可能向前转移所以不能不能直接递推dp,需要高斯消元解方程,对于节点i,假设不是结束点而且能转移到它的点有a1,a2...an,那么dp[i]=1/6*dp[a1]+1/6*dp[a2]+...+1/6*a[n],然后我们可以列出n个方程,高斯消元然后找到每个串结尾点的概率就是答案了

//#pragma comment(linker, "/stack:200000000")
//#pragma GCC optimize("Ofast,no-stack-protector")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
//#pragma GCC optimize("unroll-loops")
#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define pi acos(-1.0)
#define ll long long
#define vi vector<int>
#define mod 1000000007
#define C 0.5772156649
#define ls l,m,rt<<1
#define rs m+1,r,rt<<1|1
#define pil pair<int,ll>
#define pli pair<ll,int>
#define pii pair<int,int>
#define cd complex<double>
#define ull unsigned long long
#define base 1000000000000000000
#define fio ios::sync_with_stdio(false);cin.tie(0) using namespace std; const double g=10.0,eps=1e-;
const int N=+,maxn=+,inf=0x3f3f3f3f,INF=0x3f3f3f3f3f3f3f3f; int l;
double a[N][N],ans[N];
void gauss(int n)
{
for(int i=;i<n;i++)
{
if(a[i][i]==)
{
int id=;
for(int j=i+;j<=n;j++)
if(a[j][i]!=)
id=j;
for(int j=i;j<=n+;j++)
swap(a[i][j],a[id][j]);
}
for(int j=i+;j<=n;j++)
{
double t=a[j][i]/a[i][i];
for(int k=i;k<=n+;k++)
a[j][k]-=(a[i][k]*t);
}
}
// for(int i=1;i<=n;i++)
// {
// for(int j=1;j<=n+1;j++)
// printf("%.12f ",a[i][j]);
// puts("");
// }
for(int i=n;i>=;i--)
{
for(int j=i+;j<=n;j++)
a[i][n+]-=ans[j]*a[i][j];
ans[i]=a[i][n+]/a[i][i];
}
}
char s[N];
struct ACM{
int root,tot;
int Next[N][],fail[N],End[N];
int newnode()
{
memset(Next[tot],-,sizeof Next[tot]);
End[tot]=;
return tot++;
}
void init()
{
tot=;
root=newnode();
}
void ins(int i)
{
int now=root;
for(int i=,x;i<l;i++)
{
scanf("%d",&x);x--;
if(Next[now][x]==-)
Next[now][x]=newnode();
now=Next[now][x];
}
End[now]=i;
}
void build()
{
queue<int>q;
fail[root]=root;
for(int i=;i<;i++)
{
if(Next[root][i]==-)Next[root][i]=root;
else
{
fail[Next[root][i]]=root;
q.push(Next[root][i]);
}
}
while(!q.empty())
{
int now=q.front();
q.pop();
if(End[fail[now]])End[now]=End[fail[now]];
for(int i=;i<;i++)
{
if(Next[now][i]==-)Next[now][i]=Next[fail[now]][i];
else
{
fail[Next[now][i]]=Next[fail[now]][i];
q.push(Next[now][i]);
}
}
}
}
void solve()
{
memset(a,,sizeof a);
a[][tot+]=-1.0;
for(int i=;i<tot;i++)
{
a[i+][i+]=-1.0;
if(End[i])continue;
for(int j=;j<;j++)a[Next[i][j]+][i+]+=1.0/;
}
// for(int i=1;i<=tot;i++)
// {
// for(int j=1;j<=tot+1;j++)printf("%.5f ",a[i][j]);
// puts("");
// }
gauss(tot);
bool ok=;
for(int i=;i<tot;i++)
{
if(End[i])
{
if(!ok)printf("%.6f",ans[i+]);
else printf(" %.6f",ans[i+]);
ok=;
}
}
puts("");
}
}ac;
int main()
{
int T;scanf("%d",&T);
while(T--)
{
ac.init();
int n;
scanf("%d%d",&n,&l);
for(int i=;i<n;i++)ac.ins(i+);
ac.build();
ac.solve();
}
return ;
}
/***********************
1
2 2
1 1
2 1
***********************/