[题解] [CF451E] Devu and Flowers

时间:2022-12-19 00:16:03

题面

题解

就是一个求\(\sum_{i= 1}^{n}x _ i = m\)的不重复多重集的个数, 我们可以由容斥原理得到:
\[ ans = C_{n + m - 1}^{n - 1} - \sum_{i = 1}^{n}C_{n + m - f_i - 2}^{n - 1} + \sum_{1 \leq i < j \leq n}C_{n + m - f_i - f_j - 3}^{n - 1} - \cdots + (-1)^n C_{n + m - \sum_{k = 1}^{n}f_k -(n + 1)}^{n - 1} \]
数据范围中\(1\leq N\leq 20\)告诉了我们什么?

我们考虑枚举\(x = 0 \sim 2 ^ n - 1\), 设\(x\)在二进制表示下共有\(p\)位为1, 分别是\(i_1, i_2, i_3, \cdots, i_p\), 则这个\(x\)对答案的贡献就是
\[ (-1)^pC_{n+m-\sum_{k=1}^{p}f_{i_k}-(p+1)}^{n-1} \]
注意到\(x\)为0时它对答案的贡献为\(C_{n + m - 1}^{n - 1}\)

还是因为\(N\)比较小, 我们可以将\(C_{n+m-1}^{n-1}\)转化为\(P_{n+m-1}^{n-1}/(n-1)!\)

由于\(P_{n+m-1}^{n-1}=(n+m-1)*(n+m-2)*\cdots*((n+m-1)-(n-1)+1)\), 再乘上一个\((n-1)!\)的逆元就可以算出来了, 最后对于每个\(x\)求个和即可

Code

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <vector>
#define itn int
#define reaD read
#define mod 1000000007
#define int long long
using namespace std;

int n, m, f[22], inv[22], ans; 

inline int read()
{
    int x = 0, w = 1; char c = getchar();
    while(c < '0' || c > '9') { if (c == '-') w = -1; c = getchar(); }
    while(c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar(); }
    return x * w;
}

int fpow(int x, int y)
{
    int res = 1;
    while(y)
    {
        if(y & 1) res = res * x % mod;
        x = x * x % mod;
        y >>= 1; 
    }
    return res; 
}

int lucas(int n, int m)
{
    if(n < 0 || m < 0 || n < m) return 0;
    n %= mod;
    if(!n || !m) return 1;
    int res = 1;
    for(int i = n; i >= n - m + 1; i--) res = res * i % mod;
    res = res * inv[m] % mod;
    return res; 
}

signed main()
{
    n = read(); m = read(); inv[0] = 1;
    for(int i = 1; i <= n; i++) f[i] = reaD();
    for(int i = 1; i <= 20; i++) inv[i] = fpow(i, mod - 2);
    for(int i = 1; i <= 20; i++) inv[i] = inv[i] * inv[i - 1] % mod;
    for(int x = 0; x < (1 << n); x++)
    {
        if(!x) { ans = (ans + lucas(n + m - 1, n - 1)) % mod; continue; }
        int num = n + m - 1, cnt = 0;
        for(int i = 0; i < n; i++) if((x >> i) & 1) cnt++, num -= f[i + 1] + 1;
        if(cnt % 2) ans = ((ans - lucas(num, n - 1)) % mod + mod) % mod;
        else ans = (ans + lucas(num, n - 1)) % mod; 
    }
    printf("%lld\n", ans); 
    return 0;
}