The 15th Chinese Northeast Collegiate Programming Contest C. Vertex Deletion (树形dp)

Posted Accelerator

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了The 15th Chinese Northeast Collegiate Programming Contest C. Vertex Deletion (树形dp)相关的知识,希望对你有一定的参考价值。

  • 题意:一棵\\(n\\)个顶点的树,定义一次删点是”美丽的“,如果删去某个点后,树的每个点都有边,问有多少种”美丽的“删点方式。

  • 题解:对于某个点父亲结点\\(u\\),考虑它的儿子和子树。

    定义\\(f[u][0/1/2]\\)分别表示\\(u\\)的三种形式:

    0:删去\\(u\\)这个结点,并且保证删去后所有子树都是合法的。

    1:不删\\(u\\),但是它和所有儿子之间都没有边

    2:不删\\(u\\),且至少有一个儿子和它有边。

    那么不难发现,对于\\(u\\),合法的方案数为\\(f[u][0]+f[u][2]\\).

    下面来看状态怎么转移:

    对于\\(f[u][0]\\),因为要求子树全部合法,所以\\(f[u][0]=\\prod_{v\\in son_u}(f[v][0]+f[v][2])\\).

    对于\\(f[u][1]\\),因为和所有儿子都没有边,也就是说它的所有儿子都被删了,所以\\(f[u][1]=\\prod_{v\\in son_u}f[v][0]\\).

    对于\\(f[u][2]\\),因为是至少有一条边,根据容斥,转换一下也就是说所有情况减去都没有边的情况,就是至少有一条边的情况,那么\\(f[u][2]=\\prod_{v\\in son_u}(f[v][0]+f[v][1]+f[v][2])-\\prod_{v\\in son_u}f[v][1]\\).
    有可能出现负数,取模时要注意!

  • 代码:

#include <bits/stdc++.h>
#define ll long long
#define fi first
#define se second
#define pb push_back
#define me memset
#define rep(a,b,c) for(int a=b;a<=c;++a)
#define per(a,b,c) for(int a=b;a>=c;--a)
const int N = 1e6 + 10;
const int mod = 998244353;
const int INF = 0x3f3f3f3f;
using namespace std;
typedef pair<int,int> PII;
typedef pair<ll,ll> PLL;
ll gcd(ll a,ll b) {return b?gcd(b,a%b):a;}
ll lcm(ll a,ll b) {return a/gcd(a,b)*b;}

vector<int> edge[N];
ll dp[N][3];

void dfs(int u,int fa){
    dp[u][0]=dp[u][1]=dp[u][2]=1;
    for(auto to:edge[u]){
        if(to==fa) continue;
        dfs(to,u);
        dp[u][0]=(dp[to][0]+dp[to][2])%mod*dp[u][0]%mod;
        dp[u][1]=dp[u][1]*dp[to][0]%mod;
        dp[u][2]=(dp[to][0]+dp[to][1]+dp[to][2])%mod*dp[u][2]%mod;
    }
    dp[u][2]=(dp[u][2]-dp[u][1])%mod;
}

int main() {
    int _;
    scanf("%d",&_);
    while(_--){
        int n;
        scanf("%d",&n);
        for(int i=1;i<=n;++i) edge[i].clear(),dp[i][0]=dp[i][1]=dp[i][2]=0;
        for(int i=1;i<n;++i){
            int u,v;
            scanf("%d %d",&u,&v);
            edge[u].pb(v),edge[v].pb(u);
        }
        dfs(1,0);
        printf("%lld\\n",((dp[1][0]+dp[1][2])%mod+mod)%mod);
    }
    return 0;
}