题解:[POI2008] STA-Station

洛谷P3478 | 换根 DP

Posted by TH911 on January 22, 2025

题目传送门

树形 DP

依然是树形 DP。

首先,在根节点确定的情况下,深度之和显然可以 $\mathcal O(n)$ 求。

但是现在根节点不确定,因此我们需要 $\mathcal O(n)$ 来枚举根节点。

然而这样时间复杂度就来到了 $\mathcal O\left(n^2\right)$,考虑到 $1\leq n\leq 10^6$,无法通过此题。


如图:

假设原来的根节点是 $x$,对于 $x$ 的每一个子节点 $y$,其子树内深度之和 $dp_x,dp_y$ 应当满足:

\[\begin{aligned} dp_y&=dp_x+size_z-size_y\\ &=dp_x+(n-size_y)-size_y\\ &=dp_x+n-2\times size_y \end{aligned}\]

其中,$size_z$ 表示 $x$ 的子树内除 $y$ 的子树外的大小。

这是因为根节点从 $x$ 换到 $y$ 之后,$y$ 的子树“向上抬升”了一级,剩余部分“下降”了一级。

那么先 $\mathcal O(n)$ 求出 $1$ 为根节点时的答案,然后转移求最大值并输出对应根节点即可。

AC 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
//#include<bits/stdc++.h>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<iomanip>
#include<cstdio>
#include<string>
#include<vector>
#include<cmath>
#include<ctime>
#include<deque>
#include<queue>
#include<stack>
#include<list>
using namespace std;
typedef long long ll;
constexpr const int N=1e6;
struct graph{
	struct node{
		int v,r;
	}a[2*(N-1)+1];
	
	int h[N+1];
	void create(int u,int v){
		static int top;
		a[++top]={v,h[u]};
		h[u]=top;
	}
}g;
int n;
ll size[N+1],dep[N+1];
void dfs1(int x,int fx){
	size[x]=1;
	dep[x]=dep[fx]+1;
	for(int i=g.h[x];i>0;i=g.a[i].r){
		if(g.a[i].v==fx)continue; 
		dfs1(g.a[i].v,x);
		size[x]+=size[g.a[i].v]; 
	}
}
ll dp[N+1],dp_root[N+1];
void dfs2(int x,int fx){
	dp[x]=dep[x];
	for(int i=g.h[x];i>0;i=g.a[i].r){
		if(g.a[i].v==fx)continue;
		dfs2(g.a[i].v,x);
		dp[x]+=dp[g.a[i].v];
	}
}
void dfs3(int x,int fx){
	if(x==1)dp_root[x]=dp[x];
	else dp_root[x]=dp_root[fx] + n -2*size[x];
	for(int i=g.h[x];i>0;i=g.a[i].r){
		if(g.a[i].v==fx)continue;
		dfs3(g.a[i].v,x);
	}
}
int main(){
	/*freopen("test.in","r",stdin);
	freopen("test.out","w",stdout);*/
	
	scanf("%d",&n);
	for(int i=1;i<n;i++){
		int u,v;
		scanf("%d %d",&u,&v);
		g.create(u,v);
		g.create(v,u);
	}
	dfs1(1,0);
	dfs2(1,0);
	dp_root[1]=dp[1];
	dfs3(1,0);
	ll Max=0;
	for(int i=1;i<=n;i++){
		Max=max(Max,dp_root[i]);
	}
	for(int i=1;i<=n;i++){
		if(dp_root[i]==Max){
			printf("%d\n",i);
			return 0;
		}
	}
	
	/*fclose(stdin);
	fclose(stdout);*/
	return 0;
}