C. Helga Hufflepuff's Cup 树形dp 难

时间:2023-03-09 13:31:26
C. Helga Hufflepuff's Cup 树形dp 难

C. Helga Hufflepuff's Cup

这个题目我感觉挺难的,想了好久也写了很久,还是没有写出来。

dp[i][j][k] 代表以 i 为根的子树*选择了 j 个特殊颜色,且当前节点 i 的状态为 k 的染色方案数。

  1. k=0 ,代表当前节点 i 的颜色值小于 K 。
  2. k=1,代表当前节点 i 的颜色值等于 K 。
  3. k=2,代表当前节点 i 的颜色值大于 K 。

但是这个dfs过程的处理我觉得很复杂。

我们需要一个数组来进行临时的存储。

tmp[i][k] 表示选了 i 个  状态为 j 的方案数。

先枚举这个点已经用了 i 个,然后枚举这个子节点可以加上 j 个

tmp[j + h][0] += dp[u][j][0] * (dp[v][h][0] + dp[v][h][1] + dp[v][h][2]) % mod;
tmp[j + h][0] %= mod;
tmp[j + h][1] += dp[u][j][1] * dp[v][h][0] % mod;
tmp[j + h][1] %= mod;
tmp[j + h][2] += dp[u][j][2] * (dp[v][h][0] + dp[v][h][2]) % mod;
tmp[j + h][2] %= mod;

然后每次再赋值回去,注意这个tmp的定义,tmp[i,j]是k有 i 个,状态为 j 的方案。

这个地方我觉得挺难想的,我是没有想到。

这个是对于每一个节点u,我们枚举它的子树,一开始这个dp[u] 的初始化要注意,

我们先考虑如果子树没有k,这个状态是怎么样的,因为只有这个状态我们是可以定下来的,如果有k,这种状态必须通过转移得来。

然后再枚举子树k的个数,再去更新这个节点u。

这里用了一点点的乘法原理,挺厉害的,我一开始想的是,枚举这个u的所有可能的k的数量,然后再去用背包,这个好像有点不对。

因为我们不是取最大的那个,而是取求组合数,应该枚举每一种可能,然后相乘,也有可能是我想的不太对。

#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <queue>
#include <vector>
#include <iostream>
#include <string>
#define inf 0x3f3f3f3f
#define inf64 0x3f3f3f3f3f3f3f3f
using namespace std;
const int maxn = 2e5 + ;
const int mod = 1e9 + ;
typedef long long ll;
struct node {
int v, nxt;
node(int v = , int nxt = ) :v(v), nxt(nxt) {}
}ex[maxn];
int head[maxn], cnt = ;
int n, m, k, x;
void init()
{
memset(head, -, sizeof(head));
cnt = ;
}
void add(int u,int v)
{
ex[cnt] = node(v, head[u]);
head[u] = cnt++;
ex[cnt] = node(u, head[v]);
head[v] = cnt++;
}
int num[maxn];
ll tmp[][];
ll dp[maxn][][];//注意这个dp的定义,dp[i,j,x]表示到i这个节点,子树用k的数量为j,x==0 表示这个节点小于k,1代表等于k,2代表大于k
void dfs(int u,int pre)
{
dp[u][][] = k - ;
dp[u][][] = ;
dp[u][][] = m - k;
num[u] = ;
for (int i = head[u]; i != -; i = ex[i].nxt) {
int v = ex[i].v;
if (v == pre) continue;
dfs(v, u);
memset(tmp, , sizeof(tmp));
for (int j = ; j <= num[u]; j++) {
for (int h = ; h <= num[v]; h++) {
if (j + h > x) continue;
tmp[j + h][] += dp[u][j][] * (dp[v][h][] + dp[v][h][] + dp[v][h][]) % mod;
tmp[j + h][] %= mod;
tmp[j + h][] += dp[u][j][] * dp[v][h][] % mod;
tmp[j + h][] %= mod;
tmp[j + h][] += dp[u][j][] * (dp[v][h][] + dp[v][h][]) % mod;
tmp[j + h][] %= mod;
}
}
num[u] = min(x, num[u] + num[v]);
for (int j = ; j <= num[u]; j++) {
for (int h = ; h < ; h++) {
dp[u][j][h] = tmp[j][h];
}
}
}
} int main()
{
init();
scanf("%d%d", &n, &m);
for(int i=;i<=n-;i++)
{
int u, v;
scanf("%d%d", &u, &v);
add(u, v);
}
ll ans = ;
scanf("%d%d", &k, &x);
dfs(, -);
for (int i = ; i <= x; i++) {
for (int j = ; j < ; j++) {
ans = (ans + dp[][i][j]) % mod;
}
}
printf("%lld\n", ans);
return ;
}