树的直径 (dp或贪心)
B4016 树的直径 - 洛谷
题目大意:
给定一棵 n 个结点的树,树没有边权。请求出树的直径是多少,即树上的最长路径长度是多少。
思路:
- 树形 d p dp dp 求树的直径
定义 d [ x ] d[x] d[x] 表示以 x x x 节点出发走向 x x x 的子树,能到达的最远距离,
接下来只需要考虑对每个 x x x 节点求出 经过 x x x 节点的最长链即可,
定义 f [ x ] f[x] f[x] 表示经过 x x x 节点的最长链,
考虑转移, d [ x ] d[x] d[x] 只需先 d f s dfs dfs 到叶子节点,再由下向上更新即可
void dfs(int u,int fa){
for(auto x:g[u]){
if(x==fa) continue;
dfs(x,u);
d[u]=max(d[u],d[x]+1);
}
}
f [ x ] f[x] f[x] 只需考虑 x x x 节点能够到达的 两个最远节点 y i y_i yi 的 d [ y i ] d[y_i] d[yi] 即可,
而在更新 d [ x ] d[x] d[x] 的时候,每次都会保存一个最大的节点,因此更新 f [ x ] f[x] f[x] 时,只需考虑 d [ x ] d[x] d[x] 中保存最大的与要更新的路径求 m a x max max 即可
void dfs(int u,int fa){
for(auto x:g[u]){
if(x==fa) continue;
dfs(x,u);
f[u]=max(f[u],d[u]+d[x]+1);
d[u]=max(d[u],d[x]+1);
}
}
- 两边 d f s dfs dfs 贪心求直径
从任意一个节点出发, d f s dfs dfs 遍历到这个点最远能够到达的点,这个点一定是直径上的点,如果不是直径上的点,那么一定存在一个点比这个点更优
第二次以直径上的点 d f s dfs dfs 求一遍最大值即可
void dfs(int u,int fa){
for(auto x:g[u]){
if(x==fa) continue;
d[x]=d[u]+1;
if(d[x]>ans){
ans=d[x];
pos=x;
}
dfs(x,u);
}
}
代码1(树形 d p dp dp ):
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define fi first
#define se second
#define PII pair<int,int>
#define lowbit(x) x&-x
#define ALL(x) x.begin(),x.end()
const int mod = 1e9 + 7;
const int N = 1e5+10;
int d[N],f[N];
int n;
vector<int> g[N];
void dfs(int u,int fa) {
for(auto x:g[u]){
if(x==fa) continue;
dfs(x,u);
f[u]=max(f[u],d[u]+d[x]+1);
d[u]=max(d[u],d[x]+1);
}
}
void solve() {
int n;cin>>n;
for(int i=2;i<=n;i++){
int u,v;cin>>u>>v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1,0);
int ans=0;
for(int i=1;i<=n;i++){
ans=max(ans,f[i]);
}
cout<<ans;
}
signed main() {
std::ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int T = 1;
// cin >> T;
while (T--) {
solve();
}
return 0;
}
代码2(贪心 d f s dfs dfs ):
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define fi first
#define se second
#define PII pair<int,int>
#define lowbit(x) x&-x
#define ALL(x) x.begin(),x.end()
const int mod = 1e9 + 7;
const int N = 1e5+10;
int d[N],f[N];
int n;
vector<int> g[N];
int ans,pos;
void dfs(int u,int fa) {
for(auto x:g[u]){
if(x==fa) continue;
d[x]=d[u]+1;
if(d[x]>ans){
ans=d[x];
pos=x;
}
dfs(x,u);
}
}
void solve() {
int n;cin>>n;
for(int i=2;i<=n;i++){
int u,v;cin>>u>>v;
g[u].push_back(v);
g[v].push_back(u);
}
d[1]=0;
dfs(1,0);
d[pos]=0;
ans=0;
dfs(pos,0);
cout<<ans;
}
signed main() {
std::ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int T = 1;
// cin >> T;
while (T--) {
solve();
}
return 0;
}