树形DP 学习总结

Posted 更强一点才行

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了树形DP 学习总结相关的知识,希望对你有一定的参考价值。

DP毕竟是算法中最精妙的部分,理解并玩得花哨还是需要一定的时间积累

之前对普通的DP也不敢说掌握,只能说略懂皮毛

在学习树形DP 的同时也算是对DP有了更深的理解吧

DP的关键就在于状态的定义以及找转移

首先要考虑清楚状态,状态要能够很好地并且完整地描述子问题

其次考虑最底层的状态,这些状态一般是最简单的情况或者是边界情况

再就是考虑某一个状态能从哪些子状态转移过来,同时还要考虑转移的顺序,确保子问题已经解决

树形DP很多时候就是通过子节点推父亲节点的状态

还是通过题目来加强理解吧

1.HDU 1520

题意:给一棵树,选最多的结点使得选择的结点不存在直接的父子关系

很容易想到一个结点有两个状态:选或者不选

所以自然地想到状态dp[u][0/1]表示u子树内的最佳答案,u的状态为选或者不选

初始化自然是叶子结点dp[u][0]=0,dp[u][1]=w[u]

转移则可以考虑依次考虑

u不选的时候:u的儿子可以任意选或者不选,所以dp[u][0]+=max(dp[v][0],dp[v][1])

u选的时候:u的儿子必定不能选,所以dp[u][1]+=dp[v][0]   然后dp[u][1]+=w[u]表示加上u这个点

答案自然就是max(dp[rt][0],dp[rt][1])了

#include"cstdio"
#include"queue"
#include"cmath"
#include"stack"
#include"iostream"
#include"algorithm"
#include"cstring"
#include"queue"
#include"map"
#include"set"
#include"vector"
#include"bitset"
#define LL long long
#define ull unsigned long long
#define mems(a,b) memset(a,b,sizeof(a))
#define mdzz int mid=(L+R)>>1
#pragma comment(linker, "/STACK:1024000000,1024000000")
using namespace std;

const int N = 6005;
const int M = 1e5+5;
const int MOD = 998244353;
const int INF = 0x3f3f3f3f;

int tot;
int first[N],w[N],deg[N];
int dp[N][2];

struct node{
    int e,next;
    node(){}
    node(int a,int b):e(a),next(b){}
}edge[M];

void init(){
    tot=0;
    mems(first,-1);
    mems(deg,0);
    mems(dp,0);
}

void addedge(int u,int v){
    edge[tot]=node(v,first[u]);
    first[u]=tot++;
}

void dfs(int u){
    dp[u][0]=0;
    dp[u][1]=w[u];
    for(int i=first[u];i!=-1;i=edge[i].next){
        int v=edge[i].e;
        dfs(v);
        dp[u][0]+=max(dp[v][1],dp[v][0]);
        dp[u][1]+=dp[v][0];
    }
}
int n;
int main(){
    //freopen("in.txt","r",stdin);
    while(~scanf("%d",&n)){
        init();
        for(int i=1;i<=n;i++) scanf("%d",&w[i]);
        int u,v;
        while(1){
            scanf("%d%d",&v,&u);
            if(!v&&!u) break;
            addedge(u,v);deg[v]++;
        }
        int rt;
        for(int i=1;i<=n;i++) if(!deg[i]){
            dfs(rt=i);
            break;
        }
        printf("%d\\n",max(dp[rt][0],dp[rt][1]));
    }
    return 0;
}
View Code

2.POJ 1436

题意:选中一个点则与其相连的边将被覆盖,问最少选几个点使得树中所有边均被覆盖

和上一个题很类似

同样状态设为dp[u][0/1]

初始的底层状态自然是dp[u][0]=0,dp[u][1]=1;

考虑一个非叶子结点和它儿子的所有连边

如果当前结点不选,那这些边只能由其儿子保护,所以dp[u][0]+=dp[v][1]

如果当前结点选,那这些边已被保护,其儿子选和不选都行,所以dp[u][1]+=min(dp[v][0],dp[v][1])

答案自然是min(dp[rt][0],dp[rt][1])

#include"cstdio"
#include"queue"
#include"cmath"
#include"stack"
#include"iostream"
#include"algorithm"
#include"cstring"
#include"queue"
#include"map"
#include"set"
#include"vector"
#include"bitset"
#define LL long long
#define ull unsigned long long
#define mems(a,b) memset(a,b,sizeof(a))
#define mdzz int mid=(L+R)>>1
#pragma comment(linker, "/STACK:1024000000,1024000000")
using namespace std;

const int N = 1505;
const int M = 1e5+5;
const int MOD = 998244353;
const int INF = 0x3f3f3f3f;

int tot;
int first[N],deg[N];
int dp[N][2];

struct node{
    int e,next;
    node(){}
    node(int a,int b):e(a),next(b){}
}edge[M];

void init(){
    tot=0;
    mems(first,-1);
    mems(deg,0);
}

void addedge(int u,int v){
    edge[tot]=node(v,first[u]);
    first[u]=tot++;
}

void dfs(int u){
    dp[u][0]=0;
    dp[u][1]=1;
    for(int i=first[u];i!=-1;i=edge[i].next){
        int v=edge[i].e;
        dfs(v);
        dp[u][0]+=dp[v][1];
        dp[u][1]+=min(dp[v][0],dp[v][1]);
    }
}

int n,k,u,v;
int main(){
    //freopen("in.txt","r",stdin);
    while(~scanf("%d",&n)){
        init();
        for(int i=1;i<=n;i++){
            scanf("%d:(%d)",&u,&k);
            for(int j=0;j<k;j++){
                scanf("%d",&v);
                addedge(u,v);deg[v]++;
            }
        }
        int rt;
        for(int i=0;i<n;i++) if(!deg[i]){
            dfs(rt=i);
            break;
        }
        printf("%d\\n",min(dp[rt][0],dp[rt][1]));
    }
    return 0;
}
View Code

3.URAL 1018

题意:树中每个点有权值,问只留下k个点剩下的最大权值和是多少?留下的点还是构成一棵树

树形背包

状态定义成dp[u][i]表示u子树内剩i个点的最大权值

考虑叶子结点:dp[u][0]=0,dp[u][1]=w[u]

考虑非叶子结点的一个状态dp[u][i],对于当前的一个儿子v,枚举一个k表示从这个儿子里取几个结点,维护一个最大值

其实我们这里的状态是三维的,表示u子树的前j个子树取了i个结点的答案

我们使用滚动数组把j这一维滚掉

这里简化了题目,每一个结点固定只有两个儿子,用一般做法做肯定也是没问题的

还有要注意的地方就是这里根是一定要保留的

处理方法就是对于状态dp[u][1]直接赋值,枚举时候i从2开始,这样就可以默认根已取

#include"cstdio"
#include"queue"
#include"cmath"
#include"stack"
#include"iostream"
#include"algorithm"
#include"cstring"
#include"queue"
#include"map"
#include"set"
#include"vector"
#include"bitset"
#define LL long long
#define ull unsigned long long
#define mems(a,b) memset(a,b,sizeof(a))
#define mdzz int mid=(L+R)>>1
#pragma comment(linker, "/STACK:1024000000,1024000000")
using namespace std;

const int N = 105;
const int M = 1e5+5;
const int MOD = 998244353;
const int INF = 0x3f3f3f3f;

int tot;
int first[N],w[N];
int dp[N][N],sz[N];
int ls[N],rs[N];

struct node{
    int e,next;
    node(){}
    node(int a,int b):e(a),next(b){}
}edge[M];

void init(){
    tot=0;
    mems(first,-1);
    mems(w,-1);
    //mems(deg,0);
    mems(dp,0);
    mems(ls,-1);
    mems(rs,-1);
}

void addedge(int u,int v){
    edge[tot]=node(v,first[u]);
    first[u]=tot++;
    edge[tot]=node(u,first[v]);
    first[v]=tot++;
}

void dfs1(int u,int fa){
    sz[u]=1;
    for(int i=first[u];i!=-1;i=edge[i].next){
        int v=edge[i].e;
        if(v==fa) continue;
        dfs1(v,u);
        sz[u]+=sz[v];
        if(ls[u]==-1) ls[u]=v;
        else rs[u]=v;
    }
}

void dfs(int u){
    int f=0;
    dp[u][0]=0;dp[u][1]=w[u];
    if(ls[u]!=-1) dfs(ls[u]),f=1;
    if(rs[u]!=-1) dfs(rs[u]),f=1;
    if(!f) return;
    for(int i=2;i<=sz[u];i++)
    for(int j=0;j<=sz[ls[u]];j++) if(i-1>=j) dp[u][i]=max(dp[u][i],dp[ls[u]][j]+dp[rs[u]][i-1-j]+w[u]);
}

int n,k;

int main(){
    //freopen("in.txt","r",stdin);
    while(~scanf("%d%d",&n,&k)){
        init();
        int u,v,x,rt=1;w[rt]=0;
        for(int i=1;i<n;i++){
            scanf("%d%d%d",&u,&v,&x);
            addedge(u,v);
            if(w[v]==-1) w[v]=x;
            else w[u]=x;
        }
        dfs1(rt,-1);
        dfs(rt);
        printf("%d\\n",dp[rt][k+1]);
    }
    return 0;
}
View Code

4.HDU 2196

题意:对于树中的每一个结点,输出树中与其距离最远的结点的距离值

经典的树形DP问题

状态定义为dp[u][0/1]为u到其子树内结点的最远距离、次远距离

这样一轮dp下来,可以想到对于根来说,dp[rt][0]就是它的答案

但是对于其它结点来说只得到了其子树内的答案,而我们需要的是其对于整棵树的信息

这里需要再一次dfs,相当于反过来从根往叶子再dp一次,通过根的答案推其余结点的答案

这里之所以要维护一个次大值,就是对于一个结点u的儿子v,

若u的最远距离是经过u的,那v的答案应该是v子树内的答案和u的次大值比较,否则v的答案和u的最大值比较

画个图就明白了

#include"cstdio"
#include"queue"
#include"cmath"
#include"stack"
#include"iostream"
#include"algorithm"
#include"cstring"
#include"queue"
#include"map"
#include"set"
#include"vector"
#include"bitset"
#define LL long long
#define ull unsigned long long
#define mems(a,b) memset(a,b,sizeof(a))
#define mdzz int mid=(L+R)>>1
#pragma comment(linker, "/STACK:1024000000,1024000000")
using namespace std;

const int N = 1e4+5;
const int M = 2e4+5;
const int MOD = 998244353;
const int INF = 0x3f3f3f3f;

int tot;
int first[N];
int mx[N][2],id[N][2];

struct node{
    int e,next,w;
    node(){}
    node(int a,int b,int c):e(a),next(b),w(c){}
}edge[M];

void init(){
    tot=0;
    mems(first,-1);
    mems(mx,0);
    mems(id,-1);
}

void addedge(int u,int v,int w){
    edge[tot]=node(v,first[u],w);
    first[u]=tot++;
    edge[tot]=node(u,first[v],w);
    first[v]=tot++;
}

void dfs1(int u,int fa){
    for(int i=first[u];i!=-1;i=edge[i].next){
        int v=edge[i].e;
        if(v==fa) continue;
        dfs1(v,u);
        if(mx[v][0]+edge[i].w>=mx[u][0]){
            mx[u][1]=mx[u][0];
            id[u][1]=id[u][0];id[u][0]=v;
            mx[u][0]=mx[v][0]+edge[i].w;
        }
        else if(mx[v][0]+edge[i].w>mx[u][1]) mx[u][1]=mx[v][0]+edge[i].w,id[u][1]=v;
    }
}

void dfs2(int u,int fa){
    for(int i=first[u];i!=-1;i=edge[i].next){
        int v=edge[i].e;
        if(v==fa) continue;
        if(id[u][0]==v){
            if(mx[v][1]<mx[u][1]+edge[i].w){
                mx[v][1]=mx[u][1]+edge[i].w;
                id[v][1]=u;
            }
        }
        else{
            if(mx[v][1]<mx[u][0]+edge[i].w){
                mx[v][1]=edge[i].w+mx[u][0];
                id[v][1]=u;
            }
        }
        if(mx[v][0]<mx[v][1]){
            swap(mx[v][0],mx[v][1]);
            swap(id[v][0],id[v][1]);
        }
        dfs2(v,u);
    }
}

int n,u,w;

int main(){
    //freopen("in.txt","r",stdin);
    while(~scanf("%d",&n)){
        init();
        for(int i=2;i<=n;i++){
            scanf("%d%d",&u,&w);
            addedge(i,u,w);
        }
        dfs1(1,-1);
        dfs2(1,-1);
        for(int i=1;i<=n;i++) printf("%d\\n",mx[i][0]);
    }
    return 0;
}
View Code

5.POJ 2152

题意:树中每个结点有两个值:w[i]表示在i建设防火设施的价格,d[i]表示与i最近的防火设施的距离上限,求满足所有d[i]的最小花费

算是一道比较难的树形dp,状态和普通的树形DP略有不同

树形dp很多时候是把一个结点及其子树看成一个整体,然后考虑这个结点的状态进行转移

考虑到数据范围N<=1000,可以定义状态dp[u][i]表示u依靠i时,保证子树内安全的最小花费

为了转移方便,再定义all[u]表示保证u的安全的最小花费

其实可以理解为all[u]是在dp[u][i]中取了个最优值

要确定一个点是否能被u依靠就需要知道u与该点的距离

所以先n^2处理树中任意两点的距离

考虑叶子结点:dp[u][i]=w[i]

考虑一个非叶子结点u,先枚举它依靠的结点i

再考虑u的儿子v,v可以选择依靠或者不依靠i,则dp[u][i]+=min(dp[v][i]-c[i],all[v])

对于u的每一个i更新u的最优解all[u]

对于u的孙子k是不需要考虑的,因为k依靠i的情况在解决v的时候已经考虑到了,所以不会有重复计算的情况

#include"cstdio"
#include"queue"
#include"cmath"
#include"stack"
#include"iostream"
#include"algorithm"
#include"cstring"
#include"queue"
#include"map"
#include"set"
#include"vector"
#include"bitset"
#define LL long long
#define ull unsigned long long
#define mems(a,b) memset(a,b,sizeof(a))
#define mdzz int mid=(L+R)>>1
#define ls pos<<1
#define rs pos<<1|1
#define lson L,mid,pos<<1
#define rson mid+1,R,pos<<1|1
#pragma comment(linker, "/STACK:1024000000,1024000000")
using namespace std;

const int N = 1e3+5;
const int M = 2e3+5;
const int MOD = 998244353;
const int INF = 0x3f3f3f3f;

int tot;
int first[N];
int dp[N][N],all[N];
int n,u,v,x;
int cost[N],d[N],dis[N][N];

struct node{
    int e,next,w;
    node(){}
    node(int a,int b,int c):e(a),next(b),w(c){}
}edge[M];

void init(){
    tot=0;
    mems(first,-1);
    mems(all,INF);
    mems(dp,INF);
}

void addedge(int u,int v,int w){
    edge[tot]=node(v,first[u],w);
    first[u]=tot++;
    edge[tot]=node(u,first[v],w);
    first[v]=tot++;
}

void dfs1(int rt,int u,int fa){
    for(int i=first[u];i!=-1;i=edge[i].next){
        int v=edge[i].e;
        if(v==fa) continue;
        dis[rt][v]=dis[rt][u]+edge[i].w;
        dfs1(rt,v,u);
    }
}

void dfs2(int u,int fa){
    for(int i=first[u];i!=-1;i=edge[i].next){
        int v=edge[i].e;
        if(v==fa) continue;
        dfs2(v,u);
    }
    for(int k=1;k<=n;k++) if(dis[u][k]<=d[u]){
        dp[u][k]=cost[k];
        for(int i=first[u];i!=-1;i=edge[i].next){
            int v=edge[i].e;
            if(v==fa) continue;
            dp[u][k]+=min(dp[v][k]-cost[k],all[v]);
        }
        all[u]=min(all[u],dp[u][k]);
    }
}

int T;
int main(){
    //freopen("in.txt","r",stdin);
    for(int i=0;i<N;i++) dis[i][i]=0;
    scanf("%d",&T);
    while(T--){
        init();
        scanf("%d",&n);
        for(int i=1;i<=n;i++) scanf("%d",&cost[i]);
        for(int i=1;i<=n;i++) scanf("%d",&d[i]);
        for(int i=1;i<n;i++){
            scanf("%d%d%d",&u,&v,&x);
            addedge(u,v,x);
        }
        for(int i=1;i<=n;i++) dfs1(i,i,-1);
        dfs2(1,-1);
        printf("%d\\n",all[1]);
    }
    return 0;
}
View Code

6.POJ 3162

题意:对于树中每一个结点i找到另一个结点使得两者距离dp[i]最远,最后输出一段最长区间的长度,区间maxv-minv<=M

只是在树形dp上加了点东西而已,用线段树+two pointer维护就好了

#include"cstdio"
#include"queue"
#include"cmath"
#include"stack"
#include"iostream"
#include"algorithm"
#include"cstring"
#include"queue"
#include"map"
#include"set"树形DP

树形dp总结

树形DP初探?总结

hdu 1561 The more, The Better 树形dp

树形 DP 总结

树形 DP 总结