트리의 지름
문제
트리의 지름이란, 트리에서 임의의 두 점 사이의 거리 중 가장 긴 것을 말한다.
트리의 지름을 구하는 프로그램을 작성하시오.
입력
트리가 입력으로 주어진다.
먼저 첫 번째 줄에서는 트리의 정점의 개수 V가 주어지고 (2 ≤ V ≤ 100,000)
둘째 줄부터 V개의 줄에 걸쳐 간선의 정보가 다음과 같이 주어진다.
정점 번호는 1부터 V까지 매겨져 있다.
먼저 정점 번호가 주어지고, 이어서 연결된 간선의 정보를 의미하는 정수가 두 개씩 주어지는데,
하나는 정점번호, 다른 하나는 그 정점까지의 거리이다.
예를 들어 네 번째 줄의 경우 정점 3은 정점 1과 거리가 2인 간선으로 연결되어 있고,
정점 4와는 거리가 3인 간선으로 연결되어 있는 것을 보여준다.
각 줄의 마지막에는 -1이 입력으로 주어진다.
주어지는 거리는 모두 10,000 이하의 자연수이다.
출력
첫째 줄에 트리의 지름을 출력한다.
풀이
트리의 지름을 구하는 아이디어를 모르면 풀기 어려운 문제였다.
참고 : https://blog.myungwoo.kr/112
트리의 지름을 구하는 과정은 임의의 점 x를 잡고, 가장 먼 점인 y를 찾는다.
그러고서 y에서 가장 떨어진 정점 z를 찾고, y에서 z까지의 거리가 트리의 지름이다.
머릿속으로 그려보면 직관적으로 와닿기는 하는데 진짜 맞는지 검증이 필요하다.
트리의 길이는 정점 u와 정점 v를 잇는 점이라고 하자. 우선 case를 생각해보면
1
2
3
1. x가 u or v인 경우
2. y가 u or v인 경우
3. x, y가 둘 다 u나 v가 아닌 경우
가 있을 수 있다.
당연하게도 1, 2번 같은 경우에는 y에서 z가 트리의 지름의 길이와 같아진다.
1번은 y가 반대편 지름 정점이 될 것이고, 그러면 z는 다시 x가 된다.
2번은 y가 지름의 정점 중 하나이므로 y서 가장 먼 정점 z를 이으면 지름이 된다.
이제 3번만 생각을 해주면 된다.
우선 u와 v를 잇는 선이 트리의 길이라는 것은 u와 v가 트리에서 가장 먼 점이라는 것이다.
그런데 임의의 점 x에 대해 가장 먼 점이 u나 v가 아닌 y라고 가정하면 모순이 발생한다.
x와 y를 잇는 선이 u와 v를 잇는 선과 독립적이라고 가정하면, 단순하게 x와 y를 잇는 선이 트리에서
가장 길어지게 되므로, 트리의 지름이 u와 v를 잇는 선이라는 가정에 모순이 생긴다.x와 y를 잇는 선이 u와 v를 잇는 선과 겹친다고 가정해보자.
겹치는 정점을 t라고 하면 t에서 가장 먼 정점은 u나 v가 되어야 한다.
x에서 y를 가는 경로는 x -> t -> y인데 t -> y보다 t -> u or t -> v가 더 길다.
그러므로 x -> t -> y보다 x -> t -> u나 x -> t -> y가 반드시 길어지니 조건에 모순이 생긴다.
그렇기 때문에 3번 조건은 모순이므로, 1, 2번 case에 모든 경우가 들어간다.
따라서
1
2
1. 트리의 지름을 구하는 과정은 임의의 점 x를 잡고, 가장 먼 점인 y를 찾는다.
2. 그러고서 y에서 가장 떨어진 정점 z를 찾고, y에서 z까지의 거리가 트리의 지름이다.
를 통해서 트리의 지름을 구할 수 있다.
그러므로 임의의 점 1에서 dfs를 통해 가장 먼 점을 구한 다음
그 점에서 dfs를 통해 가장 먼점과의 거리를 구하면 그것이 트리의 지름이다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include <iostream>
#include <vector>
#include <string>
#include <queue>
#include <stack>
#include <set>
#include <functional>
#include <algorithm>
#include <math.h>
#include <cstring>
using namespace std;
int main()
{
ios_base ::sync_with_stdio(false);
cin.tie(NULL);
cout.tie(NULL);
int a, b, c, d;
int n, m, t;
char ch;
cin >> n;
vector<vector<pair<int, int>>> graph(n+1);
for(int i=0;i<n;i++){
cin >> a;
while(1){
cin >> b;
if(b==-1)break;
cin >> c;
graph[a].push_back({b,c});
}
}
int len=0;
int ans=0;
int point=-1;
vector<bool> visited(n+1, 0);
function <void(int)> dfs = [&](int k){
visited[k]=1;
for(int i=0;i<graph[k].size();i++){
if(visited[graph[k][i].first])continue;
len+=graph[k][i].second;
dfs(graph[k][i].first);
len-=graph[k][i].second;
}
if(ans < len){
ans = len;
point = k;
}
};
dfs(1);
fill(visited.begin(), visited.end(), 0);
dfs(point);
cout << ans;
return 0;
}