Codeforces 855-C 树形DP

Codeforces 855-C 树形DP

September 26, 2017
ACM, DP

题意 #

给一颗树, 每个结点可以标记一个type。 一共有m种type,分别是1-m。 其中有一种特殊的type,编号为k,要求:

  1. 在这棵树中,最多出现x个type为k的点。
  2. 每个type为k的点,其相邻的点type必须小于k。

问,给定以上数据,求树的标记种类数有多少。

数据范围: n <= 1e5, x <=10, m <=1e9

分析 #

这是一个典型的树形DP,DP的维度应该至少是3维。

Dp[i][j][z] 表示: 以i为根的子树,用掉z个k时的种类数。 其中对于根结点的type选取,分三种情况,

  • 0:type < k (type >=1 )
  • 1:type = k
  • 2:type > k (type <= m)

Dp表示内容确定以后,就要来找转移方程了。 这种组合种类的问题,只要分清楚哪些量之间是相乘的关系,哪些量是相加的关系就好办了。

思考一种普遍的情况,对于一个根节点和其儿子节点之间的关系,以及儿子之间的关系。 可以确定的是,各个儿子之间的关系是相乘的关系。因为儿子之间是没有约束的,不直接相连就没有type < k的约束。所以每个子树内的排列种类是直接乘到贡献中的。

然后就是儿子与父亲节点的关系。 讨论一下,如果父亲的取值是0类型,也就是type < k。 对于任何一个儿子都可以任意选择type,并且这些选择之间是相加的关系。选择之和与父亲的选择种类之间是相乘关系。

总上,父亲与儿子、儿子与儿子之间都是相乘的关系,儿子内部是相加关系。

这个时候我们可以选择将父亲先放在一边,先合并各个儿子。最后把合并后的儿子与父亲合并。

为什么不能直接把儿子合并到父亲上呢?

因为对于一个固定的z来说,其对应的情况是z个k在所有子树中的全排列种情况。这些排列又是相加的关系,很显然我们不能去求这个排列,复杂度太大。

我们要在处理每个儿子的时候,将其对不同的z的贡献都存起来,这个我们可以做到。

以下的代码以根为0时举例,其他两个情况一样。

for(int j=0; j<=x ;j++){
    for(int z = 0; z <= x; z++){
        if(l+z > x) continue;

        Dp[u][0][j+z] += ta[0][j] * (Dp[v][0][z] + Dp[v][1][z] + Dp[v][2][z]);

    }
}

ta[0][j] 记录的是在当前儿子之前的所有儿子中,使用了j个k时,所有的种类数。 其中,Dp[u][0][j+z] 记录的是对于当前儿子结点v,使用了j+z 个k时,所有的种类数。

在每次处理完一个儿子以后,就将ta数组更新一下。并将Dp[u]…清空。

最后再将结果和根结点合并。这个合并就比较简单了。

附上代码: #

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 100005;
const long long MOD = 1000000007;
int cnt = 0, head[MAXN];
int n, m, k, x;
typedef long long ll;
ll Dp[MAXN][3][15], ta[3][15];

struct Edge{
    int u, v, nxt;
}edge[MAXN<<2];

void init(){
    cnt = 0;
    memset(head, -1, sizeof(head));
}

void add(int u, int v){
    edge[cnt].u = u;
    edge[cnt].v = v;
    edge[cnt].nxt = head[u];
    head[u] = cnt++;
}
void Dfs(int u, int pre){
    int hasSon = 0;
    for(int i = head[u]; i!=-1; i = edge[i].nxt ){
        int v = edge[i].v;
        if(v == pre) continue;
        hasSon = 1;

        Dfs(v, u);
    }

    if( !hasSon ) {
        Dp[u][0][0] = 1ll*k-1;
        Dp[u][1][1] = 1ll;
        Dp[u][2][0] = 1ll*m-k;

        return ;
    }

    memset(ta, 0, sizeof(ta));
    ta[0][0] = ta[1][0] = ta[2][0] = 1ll;
    for(int j = 0; j<=x ; j++){
        Dp[u][0][j] = Dp[u][1][j] = Dp[u][2][j] = 0;
    }
    for(int i = head[u]; i != -1 ; i = edge[i].nxt){
        int v = edge[i].v;
        if(v == pre) continue;

            for(int j = 0; j <= x; j++ ){
                for(int z = 0; z <= x; z++){

                    if(j+z>x) continue;
                    Dp[u][0][j+z] += ta[0][j] * ( Dp[v][0][z] + Dp[v][1][z] + Dp[v][2][z] ) %MOD;

                    Dp[u][1][j+z] += ta[1][j] * Dp[v][0][z] %MOD;

                    Dp[u][2][j+z] += ta[2][j] * ( Dp[v][0][z] + Dp[v][2][z] ) %MOD;
                }
            }

        for(int j = 0; j <=x; j++)
        {
            ta[0][j] = Dp[u][0][j]%MOD;
            ta[1][j] = Dp[u][1][j]%MOD;
            ta[2][j] = Dp[u][2][j]%MOD;

            Dp[u][0][j] = 0;
            Dp[u][1][j] = 0;
            Dp[u][2][j] = 0;
        }
    }

    for(int i=0;i<=x; i++){
        Dp[u][0][i] = 1ll*(k-1)*ta[0][i]%MOD;
        if(i>0)
            Dp[u][1][i] = 1ll*ta[1][i-1]%MOD;
        Dp[u][2][i] = 1ll*(m-k)*ta[2][i]%MOD;
    }
}

int main(){

    init();
    cin >> n >> m;
    int u, v;
    for(int i=1;i<n;i++){
        cin >> u >> v;
        add(u, v);
        add(v, u);
    }
    cin  >> k >> x;

    memset(Dp, 0, sizeof(Dp));

    Dfs(1, 0);

    ll ans = 0;
    for(int i=0;i<3;i++){
        for(int j = 0; j <=x; j++)
            ans = (ans + Dp[1][i][j])%MOD;
    }
    printf("%lld\n", ans%MOD);

    return 0;
}