1. 문제 설명
난 하다가 Pypy로 제출해버렸는데, 지금 생각해보면 몇 가지 수정을 해주면 Python으로 충분히 통과할 수 있을 것 같다.
2. 아이디어
전 단계 방법을 그대로 사용해서 다익스트라 알고리즘을 곁들였는데 Python으로는 시간초과가 나서
Pypy로 제출했다.
포스팅을 하다가 갑자기 Pypy로 제출해서 통과한 게 자존심이 상해서 관련 알고리즘을 뒤져보니
HLD(Heavy Light Demoposion) 알고리즘이라는 걸 발견했다.
가장 무거운 간선을 따라 내려간 경로를 찾아 트리를 묶으면,
트리 위의 어떠한 경로도 길이가 O(logn)을 넘지 않는다.
트리를 무식하게 순회하면 O(N)라는 시간 복잡도가 발생하므로 LCA라는 알고리즘을 사용하는 건데
이 알고리즘은 공간 복잡도가 O(NlogN)이기 때문에 현재 문제가 발생한다.
대충 원리만 설명하면 비선형 구조 트리를 선형 구조 트리로 만들어 버리면 segment tree 기법을 사용할 수 있게 된다.
균형 잡히지 않은 트리에서 균형 잡힌 트리의 구역(chain)을 나누어 다루는 것이다.
자세한 설명은 나중에 알고리즘 탭에서 할 거고 간단히 코드에 대한 설명을 하면
처음 dfs는 그래프의 기본적인 정보를 파악하여 현재(current)노드의 깊이, 부모 정점, 부트리의 정점의 수(무게)를
재귀적으로 계산한다.
그 다음 dfs는 root 노드를 시작으로 체인을 형성한다.
각 정점에서 가장 큰 자식만 정점의 체인을 이어나가고 나머지는 새로운 체인의 시작이 된다. (분가 형성)
hld 알고리즘 수행을 통해 얻어낸 chain들을 가지고 다시 LCA를 수행하면 해결되는 문제였다.
3. 코드
import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline
def dfs(cur, pre):
parent[cur] = pre # 현재 노드의 부모 노드 정보 업데이트
size[cur] = 1 # 부트리 정점의 수 계산
for child in tree[cur]:
if child == pre: # 부모 노드 제외
continue
size[cur] += dfs(child, cur) # 재귀 빠져나오면서
return size[cur]
def hld(cur, prev, chain_idx, depth):
chain_depth[cur] = depth # 체인 구분
belong_chain[cur] = chain_idx #
rel_pos_in_chain[cur] = len(chains[chain_idx]) # 현재 만들고 있는 체인 길이
chains[chain_idx].append(cur) # 체인에 포함된 노드 업데이트
max_idx = 0
for child in tree[cur]:
if child == prev:
continue
if size[child] > size[max_idx]:
max_idx = child
if max_idx != 0:
hld(max_idx, cur, chain_idx, depth) # 현재 체인 확장
for child in tree[cur]:
if child == prev or child == max_idx:
continue
hld(child, cur, child, depth+1) # 새로운 체인 생성
def lca(a, b):
while belong_chain[a] != belong_chain[b]: # 깊이 맞추기
if chain_depth[a] > chain_depth[b]:
a = parent[belong_chain[a]]
else:
b = parent[belong_chain[b]]
if rel_pos_in_chain[a] > rel_pos_in_chain[b]:
return b
else:
return a
def solution():
m = int(input())
dfs(1, 0) # parent info update
hld(1, 0, 1, 0)
for _ in range(m):
a, b = map(int, input().split())
print(lca(a, b))
if __name__ == "__main__":
n = int(input())
parent = [0] * (n+1) # 각 노드의 부모 노드 info
d = [0] * (n+1) # 각 노드까지 깊이
size = [0] * (n+1)
tree = [[] for _ in range(n+1)]
belong_chain = [-1] * (n+1)
rel_pos_in_chain = [-1] * (n+1)
chain_depth = [-1] * (n+1)
chains = [[] for _ in range(n+1)]
for _ in range(n-1):
u, v = map(int, input().split())
tree[u].append(v)
tree[v].append(u)
solution()