관련 링크 : https://www.acmicpc.net/problem/1167

백준 1103 : 트리의 지름

트리의 지름이란, 트리에서 임의의 두 점 사이의 거리 중 가장 긴 것을 말한다. 트리의 지름을 구하는 프로그램을 작성하시오.

입력

트리가 입력으로 주어진다. 먼저 첫 번째 줄에서는 트리의 정점의 개수 V가 주어지고 (2 ≤ V ≤ 100,000)둘째 줄부터 V개의 줄에 걸쳐 간선의 정보가 다음과 같이 주어진다. 정점 번호는 1부터 V까지 매겨져 있다.

먼저 정점 번호가 주어지고, 이어서 연결된 간선의 정보를 의미하는 정수가 두 개씩 주어지는데, 하나는 정점번호, 다른 하나는 그 정점까지의 거리이다. 예를 들어 네 번째 줄의 경우 정점 3은 정점 1과 거리가 2인 간선으로 연결되어 있고, 정점 4와는 거리가 3인 간선으로 연결되어 있는 것을 보여준다. 각 줄의 마지막에는 -1이 입력으로 주어진다. 주어지는 거리는 모두 10,000 이하의 자연수이다.

1
2
3
4
5
6
5
1 3 2 -1
2 4 4 -1
3 1 2 4 3 -1
4 2 4 3 3 5 6 -1
5 4 6 -1

출력

첫째 줄에 트리의 지름을 출력한다.

1
11

풀이

DFS를 이용해서 완전 탐색으로 풀려고 시도했으나, 시간 초과를 벗어날 수 없었다. 구글링을 통해 그래프 이론 관련 문제임을 알았고, 트리의 지름을 구하는 방법을 알아낸 후 다시 시도했다.

다른 블로그에서 가져온 트리의 지름 구하는 방법. 증명도 블로그에 있다.

  1. 트리에서 임의의 정점 $x$를 잡는다,
  2. 정점 $x$에서 가장 먼 정점 $y$를 찾는다.
  3. 정점 $y$에서 가장 먼 정점 $z$를 찾는다.
  4. 트리의 지름은 정점 $y$와 정점 $z$를 연결하는 통로다
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import sys
#입력 처리
input = sys.stdin.readline
v = int(input().strip())
tree = [[] for _ in range(v)]
for i in range(v):
    node = list(map(int, input().strip().split()))
    for j in range(1, len(node)//2):
        tree[i-1].append([node[2*j-1]-1, node[2*j]])
        
result = [0 for _ in range(v)] #기준 노드에서 다른 노드까지의 거리를 저장
history = [] #1->3, 3->1 같은 경우를 막기위해 추가
def dfs(start, result):
    for dest, dist in tree[start]:
        history.append(start)
        if result[dest] == 0 and dest not in history:
            result[dest] = result[start] + dist
            dfs(dest, result)
        history.pop()

dfs(0, result) #첫 번째 노드에서 DFS

#가장 먼 거리에 있는 노드 찾기
max_dist = 0
for idx in range(len(result)):
    if result[idx] > max_dist:
        max_idx = idx
        max_dist = result[idx]

result = [0 for i in range(v)]
dfs(max_idx, result) #가장 먼 노드에서 한번 더 DFS
print(max(result))

근데 여기서 계속 틀려서 원인을 찾지 못했는데, 알고 보니 입력이 노드 숫자에 따라 순서대로 들어오는 게 아니어서 그랬다. 그래서 입력 처리를 다음과 같이 수정했다.

1
2
3
4
5
6
v = int(input().strip())
tree = [[] for _ in range(v)]
for _ in range(v):
    i, *node = list(map(int, input().strip().split()))
    for j in range(len(node)//2):
        tree[i-1].append([node[2*j]-1, node[2*j+1]])

그리고 다른 사람들의 코드를 보니 history를 사용하지 않고, 원래 위치로 돌아오는 경우만 마지막에 0으로 만들어 주어 문제를 풀더라. 생각해보니 result[dest]==0에서 방문한 곳이 자동으로 걸러지더라. 그래서 최종적으로 다음과 같이 수정

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import sys
 
input = sys.stdin.readline

v = int(input().strip())
tree = [[] for _ in range(v)]
for _ in range(v):
    i, *node = list(map(int, input().strip().split()))
    for j in range(len(node)//2):
        tree[i-1].append([node[2*j]-1, node[2*j+1]])
        
result = [0 for _ in range(v)]
def dfs2(start, result):
    for dest, dist in tree[start]:
        if result[dest] == 0:
            result[dest] = result[start] + dist
            dfs2(dest, result)

dfs2(0, result)
result[0] = 0 #자기 자신으로 오는 곳 예외처리
max_dist = 0
for idx in range(len(result)):
    if result[idx] > max_dist:
        max_idx = idx
        max_dist = result[idx]

result = [0 for _ in range(v)]
dfs2(max_idx, result)
result[max_idx] = 0 #자기 자신으로 오는 곳 예외처리
print(max(result))

Leave a comment