Codeforces 855-C 树形DP

题意

给一颗树, 每个结点可以标记一个type。
一共有m种type,分别是1-m。
其中有一种特殊的type,编号为k,要求:

  1. 在这棵树中,最多出现x个type为k的点。
  2. 每个type为k的点,其相邻的点type必须小于k。

问,给定以上数据,求树的标记种类数有多少。

数据范围: n <= 1e5, x <=10, m <=1e9

分析

这是一个典型的树形DP,DP的维度应该至少是3维。

Dp[i][j][z] 表示: 以i为根的子树,用掉z个k时的种类数。
其中对于根结点的type选取,分三种情况,

  • 0:type < k (type >=1 )
  • 1:type = k
  • 2:type > k (type <= m)

Dp表示内容确定以后,就要来找转移方程了。
这种组合种类的问题,只要分清楚哪些量之间是相乘的关系,哪些量是相加的关系就好办了。

思考一种普遍的情况,对于一个根节点和其儿子节点之间的关系,以及儿子之间的关系。
可以确定的是,各个儿子之间的关系是相乘的关系。因为儿子之间是没有约束的,不直接相连就没有type < k的约束。所以每个子树内的排列种类是直接乘到贡献中的。

然后就是儿子与父亲节点的关系。
讨论一下,如果父亲的取值是0类型,也就是type < k。 对于任何一个儿子都可以任意选择type,并且这些选择之间是相加的关系。选择之和与父亲的选择种类之间是相乘关系。

总上,父亲与儿子、儿子与儿子之间都是相乘的关系,儿子内部是相加关系。

这个时候我们可以选择将父亲先放在一边,先合并各个儿子。最后把合并后的儿子与父亲合并。

为什么不能直接把儿子合并到父亲上呢?

因为对于一个固定的z来说,其对应的情况是z个k在所有子树中的全排列种情况。这些排列又是相加的关系,很显然我们不能去求这个排列,复杂度太大。

我们要在处理每个儿子的时候,将其对不同的z的贡献都存起来,这个我们可以做到。

以下的代码以根为0时举例,其他两个情况一样。

1
2
3
4
5
6
7
8
for(int j=0; j<=x ;j++){
for(int z = 0; z <= x; z++){
if(l+z > x) continue;

Dp[u][0][j+z] += ta[0][j] * (Dp[v][0][z] + Dp[v][1][z] + Dp[v][2][z]);

}
}

ta[0][j] 记录的是在当前儿子之前的所有儿子中,使用了j个k时,所有的种类数。
其中,Dp[u][0][j+z] 记录的是对于当前儿子结点v,使用了j+z 个k时,所有的种类数。

在每次处理完一个儿子以后,就将ta数组更新一下。并将Dp[u]…清空。

最后再将结果和根结点合并。这个合并就比较简单了。

附上代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 100005;
const long long MOD = 1000000007;
int cnt = 0, head[MAXN];
int n, m, k, x;
typedef long long ll;
ll Dp[MAXN][3][15], ta[3][15];

struct Edge{
int u, v, nxt;
}edge[MAXN<<2];

void init(){
cnt = 0;
memset(head, -1, sizeof(head));
}

void add(int u, int v){
edge[cnt].u = u;
edge[cnt].v = v;
edge[cnt].nxt = head[u];
head[u] = cnt++;
}
void Dfs(int u, int pre){
int hasSon = 0;
for(int i = head[u]; i!=-1; i = edge[i].nxt ){
int v = edge[i].v;
if(v == pre) continue;
hasSon = 1;

Dfs(v, u);
}

if( !hasSon ) {
Dp[u][0][0] = 1ll*k-1;
Dp[u][1][1] = 1ll;
Dp[u][2][0] = 1ll*m-k;

return ;
}

memset(ta, 0, sizeof(ta));
ta[0][0] = ta[1][0] = ta[2][0] = 1ll;
for(int j = 0; j<=x ; j++){
Dp[u][0][j] = Dp[u][1][j] = Dp[u][2][j] = 0;
}
for(int i = head[u]; i != -1 ; i = edge[i].nxt){
int v = edge[i].v;
if(v == pre) continue;

for(int j = 0; j <= x; j++ ){
for(int z = 0; z <= x; z++){

if(j+z>x) continue;
Dp[u][0][j+z] += ta[0][j] * ( Dp[v][0][z] + Dp[v][1][z] + Dp[v][2][z] ) %MOD;

Dp[u][1][j+z] += ta[1][j] * Dp[v][0][z] %MOD;

Dp[u][2][j+z] += ta[2][j] * ( Dp[v][0][z] + Dp[v][2][z] ) %MOD;
}
}

for(int j = 0; j <=x; j++)
{
ta[0][j] = Dp[u][0][j]%MOD;
ta[1][j] = Dp[u][1][j]%MOD;
ta[2][j] = Dp[u][2][j]%MOD;

Dp[u][0][j] = 0;
Dp[u][1][j] = 0;
Dp[u][2][j] = 0;
}
}

for(int i=0;i<=x; i++){
Dp[u][0][i] = 1ll*(k-1)*ta[0][i]%MOD;
if(i>0)
Dp[u][1][i] = 1ll*ta[1][i-1]%MOD;
Dp[u][2][i] = 1ll*(m-k)*ta[2][i]%MOD;
}
}

int main(){

init();
cin >> n >> m;
int u, v;
for(int i=1;i<n;i++){
cin >> u >> v;
add(u, v);
add(v, u);
}
cin >> k >> x;

memset(Dp, 0, sizeof(Dp));

Dfs(1, 0);

ll ans = 0;
for(int i=0;i<3;i++){
for(int j = 0; j <=x; j++)
ans = (ans + Dp[1][i][j])%MOD;
}
printf("%lld\n", ans%MOD);

return 0;
}
Talk is not cheap.