Codeforces 855-C 树形DP
September 26, 2017
题意 #
给一颗树, 每个结点可以标记一个type。 一共有m种type,分别是1-m。 其中有一种特殊的type,编号为k,要求:
- 在这棵树中,最多出现x个type为k的点。
- 每个type为k的点,其相邻的点type必须小于k。
问,给定以上数据,求树的标记种类数有多少。
数据范围: n <= 1e5, x <=10, m <=1e9
分析 #
这是一个典型的树形DP,DP的维度应该至少是3维。
Dp[i][j][z]
表示: 以i为根的子树,用掉z个k时的种类数。
其中对于根结点的type选取,分三种情况,
0
:type < k (type >=1 )1
:type = k2
:type > k (type <= m)
Dp表示内容确定以后,就要来找转移方程了。 这种组合种类的问题,只要分清楚哪些量之间是相乘的关系,哪些量是相加的关系就好办了。
思考一种普遍的情况,对于一个根节点和其儿子节点之间的关系,以及儿子之间的关系。
可以确定的是,各个儿子之间的关系是相乘的关系。因为儿子之间是没有约束的,不直接相连就没有type < k
的约束。所以每个子树内的排列种类是直接乘到贡献中的。
然后就是儿子与父亲节点的关系。
讨论一下,如果父亲的取值是0类型,也就是type < k
。 对于任何一个儿子都可以任意选择type,并且这些选择之间是相加的关系。选择之和与父亲的选择种类之间是相乘关系。
总上,父亲与儿子、儿子与儿子之间都是相乘的关系,儿子内部是相加关系。
这个时候我们可以选择将父亲先放在一边,先合并各个儿子。最后把合并后的儿子与父亲合并。
为什么不能直接把儿子合并到父亲上呢?
因为对于一个固定的z来说,其对应的情况是z个k在所有子树中的全排列种情况。这些排列又是相加的关系,很显然我们不能去求这个排列,复杂度太大。
我们要在处理每个儿子的时候,将其对不同的z的贡献都存起来,这个我们可以做到。
以下的代码以根为0时举例,其他两个情况一样。
for(int j=0; j<=x ;j++){
for(int z = 0; z <= x; z++){
if(l+z > x) continue;
Dp[u][0][j+z] += ta[0][j] * (Dp[v][0][z] + Dp[v][1][z] + Dp[v][2][z]);
}
}
ta[0][j]
记录的是在当前儿子之前的所有儿子中,使用了j个k时,所有的种类数。
其中,Dp[u][0][j+z]
记录的是对于当前儿子结点v,使用了j+z 个k时,所有的种类数。
在每次处理完一个儿子以后,就将ta数组更新一下。并将Dp[u]…清空。
最后再将结果和根结点合并。这个合并就比较简单了。
附上代码: #
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 100005;
const long long MOD = 1000000007;
int cnt = 0, head[MAXN];
int n, m, k, x;
typedef long long ll;
ll Dp[MAXN][3][15], ta[3][15];
struct Edge{
int u, v, nxt;
}edge[MAXN<<2];
void init(){
cnt = 0;
memset(head, -1, sizeof(head));
}
void add(int u, int v){
edge[cnt].u = u;
edge[cnt].v = v;
edge[cnt].nxt = head[u];
head[u] = cnt++;
}
void Dfs(int u, int pre){
int hasSon = 0;
for(int i = head[u]; i!=-1; i = edge[i].nxt ){
int v = edge[i].v;
if(v == pre) continue;
hasSon = 1;
Dfs(v, u);
}
if( !hasSon ) {
Dp[u][0][0] = 1ll*k-1;
Dp[u][1][1] = 1ll;
Dp[u][2][0] = 1ll*m-k;
return ;
}
memset(ta, 0, sizeof(ta));
ta[0][0] = ta[1][0] = ta[2][0] = 1ll;
for(int j = 0; j<=x ; j++){
Dp[u][0][j] = Dp[u][1][j] = Dp[u][2][j] = 0;
}
for(int i = head[u]; i != -1 ; i = edge[i].nxt){
int v = edge[i].v;
if(v == pre) continue;
for(int j = 0; j <= x; j++ ){
for(int z = 0; z <= x; z++){
if(j+z>x) continue;
Dp[u][0][j+z] += ta[0][j] * ( Dp[v][0][z] + Dp[v][1][z] + Dp[v][2][z] ) %MOD;
Dp[u][1][j+z] += ta[1][j] * Dp[v][0][z] %MOD;
Dp[u][2][j+z] += ta[2][j] * ( Dp[v][0][z] + Dp[v][2][z] ) %MOD;
}
}
for(int j = 0; j <=x; j++)
{
ta[0][j] = Dp[u][0][j]%MOD;
ta[1][j] = Dp[u][1][j]%MOD;
ta[2][j] = Dp[u][2][j]%MOD;
Dp[u][0][j] = 0;
Dp[u][1][j] = 0;
Dp[u][2][j] = 0;
}
}
for(int i=0;i<=x; i++){
Dp[u][0][i] = 1ll*(k-1)*ta[0][i]%MOD;
if(i>0)
Dp[u][1][i] = 1ll*ta[1][i-1]%MOD;
Dp[u][2][i] = 1ll*(m-k)*ta[2][i]%MOD;
}
}
int main(){
init();
cin >> n >> m;
int u, v;
for(int i=1;i<n;i++){
cin >> u >> v;
add(u, v);
add(v, u);
}
cin >> k >> x;
memset(Dp, 0, sizeof(Dp));
Dfs(1, 0);
ll ans = 0;
for(int i=0;i<3;i++){
for(int j = 0; j <=x; j++)
ans = (ans + Dp[1][i][j])%MOD;
}
printf("%lld\n", ans%MOD);
return 0;
}