Tree with Maximum Cost (树上DP)
题意:
给定一颗无根树,每点都有权值a[i],定义两点之间的距离dis(x,y)是两点之间简单路径上的边数,求树的最大成本。其中树的成本定义为:假设树的根节点为v,成本为
Σdis(i,v)∗a[i](1<=i<=n)
思路:
首先假设根节点为1,预先dfs处理出每个点的子树的权值和a[i]以及子树的成本dp[i]。
再考虑换根,假设本来u是树的根节点,设u的子节点为v,现在根从u转移到v,那么对距离的影响就是所有v的子节点距离根节点的距离都减少了1,所有v的非子节点距离根节点的距离都增加了一。又因为对于成本来说,是距离*权值,所以对答案的影响就相当于是减少了子节点的权值和,增加了非子节点的权值和。我们记录a[i]表示点i的子树的权值和,那么a[1]就表示整棵树的权值和,所以子节点的权值和就是a[i],非子节点的权值和就是a[1]-a[i];所以每次换根的时候,对答案的贡献就是:
dp[u]=dp[fa]−a[u]+a[1]−a[u];
代码:
///#pragma GCC optimize(3) ///#pragma GCC optimize("Ofast","unroll-loops","omit-frame-pointer","inline") ///#pragma GCC optimize(2) #include<bits/stdc++.h> using namespace std; typedef long long ll; typedef unsigned long long ull; typedef pair<ll,ll>PLL; typedef pair<int,int>PII; typedef pair<double,double>PDD; #define I_int ll inline ll read() { ll x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x*f; } char F[200]; inline void out(I_int x) { if (x == 0) return (void) (putchar('0')); I_int tmp = x > 0 ? x : -x; if (x < 0) putchar('-'); int cnt = 0; while (tmp > 0) { F[cnt++] = tmp % 10 + '0'; tmp /= 10; } while (cnt > 0) putchar(F[--cnt]); //cout<<" "; } ll ksm(ll a,ll b,ll p){ll res=1;while(b){if(b&1)res=res*a%p;a=a*a%p;b>>=1;}return res;} const int inf=0x3f3f3f3f,mod=998244353; const ll INF = 0x3f3f3f3f3f3f3f3f; const int maxn=2e5+100,maxm=3e5+7,N=1e6+7; const double PI = atan(1.0)*4; int h[maxn]; struct node{ int e,ne; }edge[maxn*2]; ll n; ll w[maxn],a[maxn],dp[maxn]; int idx=0; void add(int u,int v){ edge[idx].e=v,edge[idx].ne=h[u],h[u]=idx++; } ll res=0; void dfs1(int u,int fa){ for(int i=h[u];~i;i=edge[i].ne){ int j=edge[i].e; if(j==fa) continue; dfs1(j,u); a[u]+=a[j]; dp[u]+=dp[j]+a[j]; } } void dfs2(int u,int fa){ if(u!=1) dp[u]=dp[fa]+a[1]-2*a[u]; for(int i=h[u];~i;i=edge[i].ne){ int j=edge[i].e; if(j==fa) continue; dfs2(j,u); } res=max(res,dp[u]); } int main(){ memset(h,-1,sizeof h); memset(dp,0,sizeof dp); n=read(); for(int i=1;i<=n;i++) a[i]=read(); for(int i=1;i<n;i++){ int u=read(),v=read(); add(u,v);add(v,u); } dfs1(1,-1); dfs2(1,-1); out(res); return 0; }
参考博客