codeforces 439 E. Devu and Birthday Celebration 组合数学 容斥定理

时间:2022-12-19 00:17:09

题意:

q个询问,每一个询问给出2个数sum,n

1 <= q <= 10^5, 1 <= n <= sum <= 10^5

对于每一个询问,求满足下列条件的数组的方案数

1.数组有n个元素,ai >= 1

2.sigma(ai) = sum

3.gcd(ai) = 1

 

solution:

这道题的做法类似bzoj2005能量采集

f(d) 表示gcd(ai) = d 的方案数

h(d) 表示d|gcd(ai)的方案数

令ai = bi * d

则有sigma(bi) = sum / n

  d | gcd(ai)

还要满足bi >= 1

则显然有h(d) = C(sum / d - 1,n - 1)

    h(d) = f(d) + f(2d) + ... + f(d_max)

 

这里的d满足:

1.d是sum 的约数

2.sum / d >= n

则f(d) = h(d) - sigma(f(j)) ,2d <=j<=sum/n

倒序遍历d

ans = f(1)

 

由于询问的次数太多,每次询问后,可以把(sum,n)放入map中,记录下来

 

                                            
//File Name: cf439E.cpp
//Author: long
//Mail: 736726758@qq.com
//Created Time: 2016年02月17日 星期三 14时58分16秒


#include
<cstdio>
#include
<cstring>
#include
<iostream>
#include
<algorithm>
#include
<map>
#include
<cmath>
#include
<cstdlib>
#include
<vector>

#define LL long long
#define pb push_back

using namespace std;

const int MAXN = 1e5+5;
const int MOD = 1e9+7;

LL f[MAXN];
LL jie[MAXN];
bool is[MAXN];
vector
<int> dive;
map
< pair<int,int>,int > rem;

void init()
{
jie[
0] = 1;
for(int i=1;i<MAXN;i++){
jie[i]
= jie[i-1] * i % MOD;
}
rem.clear();
}

void get_dive(int sum,int n)
{
int e = (int)sqrt(sum + 0.0);
dive.clear();
int j;
for(int i=1;i<=e;i++){
if(sum % i == 0){
if(sum / i >= n)
dive.pb(i);
j
= sum / i;
if(j != i && sum / j >= n)
dive.pb(j);
}
}
sort(dive.begin(),dive.end());
for(int i=0;i<dive.size();i++){
is[dive[i]] = true;
}
}

LL qp(LL x,LL y)
{
LL res
= 1LL;
while(y){
if(y & 1)
res
= res * x % MOD;
x
= x * x % MOD;
y
>>= 1;
}
return res;
}

LL comb(
int x ,int y)
{
if(y < 0 || y > x)
return 0;
if(y == 0 || y == x)
return 1;
return jie[x] * qp(jie[y] * jie[x-y] % MOD,MOD - 2) % MOD;
}

void solve(int sum,int n)
{
map
< pair<int,int>,int >::iterator it;
it
= rem.find(make_pair(sum,n));
if(it != rem.end()){
printf(
"%d\n",(int)(it->second));
return ;
}
memset(f,
0,sizeof f);
memset(
is,false,sizeof is);
get_dive(sum,n);
int ma = dive.size();
for(int i=ma-1;i>=0;i--){
int d = dive[i];
f[d]
= comb(sum / d - 1,n - 1);
for(int j=2*d;j<=dive[ma-1];j+=d){
if(is[j]){
f[d]
= ((f[d] - f[j] + MOD) % MOD + MOD) % MOD;
}
}
}
printf(
"%d\n",(int)f[1]);
rem[make_pair(sum,n)]
= f[1];
return ;
}

int main()
{
init();
int test;
scanf(
"%d",&test);
while(test--){
int sum,n;
scanf(
"%d %d",&sum,&n);
solve(sum,n);
}
return 0;
}