bzoj 4559 [JLoi2016]成绩比较 —— DP+拉格朗日插值

时间:2023-03-09 17:00:10
bzoj 4559 [JLoi2016]成绩比较 —— DP+拉格朗日插值

题目:https://www.lydsy.com/JudgeOnline/problem.php?id=4559

看了看拉格朗日插值:http://www.cnblogs.com/ECJTUACM-873284962/p/6833391.html

https://blog.****.net/lvzelong2014/article/details/79159346

https://blog.****.net/qq_35649707/article/details/78018944

还只会最简单的那种,正好在这道题里可以用到;

计算方案数,可以考虑DP,利用那个所有成绩都小于 B 的性质,枚举超过 B 的一门课;

设计 f[i][j] 表示当前到了第 i 门课,还剩 j 个人被碾压(一开始是所有人都被碾压,然后渐渐突破...);

则 f[i][j] = ∑(j<=t<=n-1) f[i-1][t] * C(n-1-t,rk[i]-1-(t-j)) * C(t,j) * g[i]

其中第一个组合数表示在 n-1-t 个上一次已经不被碾压的人中选出  rk[i]-1-(t-j) 个作为这次成绩高于 B 的人,第二个组合数表示从 t 个上次被碾压的人中选出 j 个这次仍然被碾压(也等同与选出 t-j 个人这次成绩高于 B );

g[i] 则表示在 i 这门课上的成绩分布情况,则选出的人的成绩可以对号入座;

而 g[i] = ∑(1<=j<=lim[i]) j^(n-rk[i]) * (lim[i]-j)^(rk[i]-1),表示若 B 的成绩是 j,则有 n-rk[i] 个人的成绩在 1~j 中选择,有 rk[i]-1 个人的成绩在 lim[i]-j~lim[i] 中选择;

可以发现这是个大约 n+1 次的多项式,所以设出几个点,求出当 x=lim 时的取值即可,这个过程的复杂度是 n^2 的。

代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
int const xn=,mod=1e9+;
int n,m,K,lm[xn],rk[xn],g[xn],c[xn][xn],f[xn][xn],xx[xn],yy[xn];
int rd()
{
int ret=,f=; char ch=getchar();
while(ch<''||ch>''){if(ch=='-')f=; ch=getchar();}
while(ch>=''&&ch<='')ret=(ret<<)+(ret<<)+ch-'',ch=getchar();
return f?ret:-ret;
}
int pw(ll a,int b)
{
ll ret=;
for(;b;b>>=,a=(a*a)%mod)
if(b&)ret=(ret*a)%mod;
return ret;
}
int upt(int x){while(x>=mod)x-=mod; while(x<)x+=mod; return x;}
void init()
{
for(int i=;i<=n;i++)c[i][]=;
for(int i=;i<=n;i++)
for(int j=;j<=i;j++)c[i][j]=upt(c[i-][j]+c[i-][j-]);
}
int solve(int lim,int n,int m)
{
int num=n+m+,sum=;
for(int i=;i<=num;i++)
xx[i]=i,yy[i]=upt(yy[i-]+(ll)pw(i,n)*pw(lim-i,m)%mod);
for(int i=;i<=num;i++)
{
ll s1=,s2=;
for(int j=;j<=num;j++)
if(i!=j)//!!!
s1=s1*(lim-xx[j])%mod,s2=s2*(xx[i]-xx[j])%mod;
sum=upt(sum+s1*pw(s2,mod-)%mod*yy[i]%mod);
}
return sum;
}
int main()
{
n=rd()-; m=rd(); K=rd(); init();//n-1
for(int i=;i<=m;i++)lm[i]=rd();
for(int i=;i<=m;i++)rk[i]=rd(),g[i]=solve(lm[i],n-rk[i]+,rk[i]-);//+1
f[][n]=;//n
for(int i=;i<=m;i++)
for(int j=K;j<=n;j++)//k
for(int t=j;t<=n;t++)
{
if(t-j>rk[i]-||j>n-rk[i]+)continue;//+1!
f[i][j]=upt(f[i][j]+(ll)f[i-][t]*c[t][j]%mod*c[n-t][rk[i]--t+j]%mod*g[i]%mod);
}
printf("%d\n",f[m][K]);
return ;
}