본문 바로가기

algorithm/백준알고리즘

[백준알고리즘] 1167번: 트리의 지름 -Python

728x90

[백준알고리즘] 1167번: 트리의 지름 -Python

https://www.acmicpc.net/problem/1167

 

1167번: 트리의 지름

트리가 입력으로 주어진다. 먼저 첫 번째 줄에서는 트리의 정점의 개수 V가 주어지고 (2≤V≤100,000)둘째 줄부터 V개의 줄에 걸쳐 간선의 정보가 다음과 같이 주어진다. (정점 번호는 1부터 V까지 매겨져 있다고 생각한다) 먼저 정점 번호가 주어지고, 이어서 연결된 간선의 정보를 의미하는 정수가 두 개씩 주어지는데, 하나는 정점번호, 다른 하나는 그 정점까지의 거리이다. 예를 들어 네 번째 줄의 경우 정점 3은 정점 1과 거리가 2인 간선으로 연결되

www.acmicpc.net

와 이 문제도 되게 오래 걸려서 풀었다... 오래 걸린 이유는 두 가지로 뽑을 수 있을 것 같다.

1. 트리의 지름을 구하는 방법을 몰랐다.

2. 멘탈이 나간 상태에서 이상한 부분을 놓쳤다...

 

1번 이유인 트리의 지름을 구하는 방법을 몰랐기 때문에 정말 엄청 헤매었다. 범위를 생각 안 하고 인접 행렬을 통해 각 점에서 가장 거리가 먼 노드까지의 거리를 구했었다.

그 결과는... 메모리초과가 당연히 발생했고 인접 리스트를 통해 문제를 해결하게 되었다.

하지만 각 점을 모두 반복하면서 가장 거리가 먼 노드를 구하는 코드는 시간제한에 걸리게 되어 시간 초과가 발생한다...

 

여기까지는 뭐 그러려니 했다. 계속 왜 안되지.. 어떻게 해야 하지.. 하다가 질문 게시판을 통해서 구하는 방법을 알게 되었다.

 

 

"트리의 지름을 구하는 공식은 임의의 하나의 노드 A에서 가장 거리가 먼 노드 B를 구하고, 이 노드 B에서 가장 거리가 먼 노드 C를 구하게 되었을 때, B와 C 사이의 거리가 트리의 지름이 된다."

 

A에서 가장 거리가 먼 노드 B를 구하게 되었을 때, 이 B가 반드시 트리의 지름을 구성하는 양 끝 노드 중의 하나가 된다. 설명은 참고할 사람만..

더보기

이해한 대로 설명을 조금 하자면 트리의 임의의 두 점 x와 x에서 가장 거리가 먼 y가 있다고 하자. 그리고 트리의 지름을 이루는 두 노드를 u, v라고 하자.

y가 u 또는 v라면 당연히 y에서 가장 거리가 먼 노드까지의 거리가 트리의 지름이 된다.

하지만 y가 u 또는 v가 아닌 다른 노드가 될 수 있는 가를 보겠다. x를 루트로 하는 트리가 존재한다고 했을 때, x가 아닌 u, v, y의 공통 조상인 t가 있다고 하자. 그렇다면 d(x, y) = d(x, t) + d(t, y)이고, d(t, y)가 d(t, u)와 d(t, v) 보다 크다는 것이 된다. 하지만 트리의 지름은 d(t, u) + d(t, v)인 것인데, d(t, y)가 d(t, u)와 d(t, v) 보다 크다면 지름은 반드시 d(t, u) + d(t, v)는 최댓값이 될 수 없기 때문에 트리의 지름이 될 수 없다.

 

아무튼 이 공식을 이용해서 풀자고 생각을 했는데 계속해서 틀리는 것이다... 정말 다른 사람의 코드를 봐도 똑같은 것 같고 마음을 비우고 코드를 다 지우고 다시 풀어도 똑같고..

 

그러다가 print()를 잔뜩 넣어서 어디가 틀린건지 확인하려 여러 예제를 넣고 알게 되었다... 아래 해결한 코드의 get_farthest() 메서드에서 stack = link[i][:]link[i]의 리스트를 복사하는 부분에서 처음에는 link[i]로 가리키기만 했었는데 여기서 발생한 문제였다.

 

처음 1에서 가장 먼 노드를 찾게 되는데, 이때 stack = link[1] 이 되면서, stack.pop()stack.append()를 통해 나중에는 link[1] = []로 빈 리스트가 되어버리는 것이다.... 그래서 [:]로 복사만 해주었더니 통과했다... 울고 싶다... 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import sys
 
def get_farthest(i):
   farthest, dist = 00
    visit = [False] * (v+1)
    visit[i] = True
    stack = link[i][:]
 
    for s in stack:
        visit[s[0]] = True
        if s[1> dist:
           farthest = s[0]
            dist = s[1]
    
    while stack:
        bridge, now = stack.pop()
        for b in link[bridge]:
            if not visit[b[0]]:
                visit[b[0]] = True
                new = now + b[1]
                stack.append((b[0], new))
                if new > dist:
                   farthest = b[0]
                    dist = new
 
    return farthest, dist
 
= int(sys.stdin.readline())
link = {}
for _ in range(v):
    data = list(map(int, sys.stdin.readline().split()))
    link[data[0]] = []
    for i in range(1len(data)-12):
        link[data[0]].append((data[i], data[i+1]))
 
farthest, dist = get_farthest(1)
sys.stdout.write(str(get_farthest(farthest)[1]))

 

잘못된 점이나 부족한 점 지적해주시면 감사하겠습니다

728x90