문제 보기

아이디어

트리를 아래와 같은 자료구조(왼쪽 그림)를 사용하여 표현했다. 자료구조

# q: [(start0, end0), (start1, end1), (start2, end2), ...]
for q in question:
	bfs(q[0])
for q in question:
	print(d[q[0]][q[1]])

# start 노드에서부터 나머지 노드까지의 거리를 계산한다.
def bfs(start):
	start 노드 방문 처리
	q = deque()
	q에 start 노드 삽입
	while q:
		v = q.popleft()
		for v와 연결된 모든 노드 x에 대해:
			if x를 방문하지 않은 경우:
				# (start에서 x까지의 거리) += (start에서 v까지으 거리) + (v와 x 사이 거리)
				d[start][x] += d[start][v] tree[v][x]

오답

import sys
input = sys.stdin.readline
from collections import deque

n, m = map(int, input().strip().split())
tree = [[] for _ in range(n+1)]
d = [[0]*(n+1) for _ in range(n+1)]
for _ in range(n-1):
    s,e,w = map(int, input().strip().split())
    tree[s].append((e,w))
    tree[e].append((s,w))
question = []
for i in range(m):
    s,e = map(int, input().strip().split())
    question.append((s,e))

def bfs(start):
    visited = [False] * (n+1)
    q = deque()
    q.append(start)
    visited[start] = True
    while q:
        v = q.popleft()
        for i in range(len(tree[v])):
            if not visited[tree[v][i][0]]:
                visited[tree[v][i][0]] = True
                d[start][tree[v][i][0]] += d[start][v] + tree[v][i][1]
                q.append(tree[v][i][0])

for i in range(m):
    bfs(question[i][0])

for i in range(m):
    print(d[question[i][0]][question[i][1]])

원인

각 노드 사이의 거리를 기록하는 d 를 전역변수로 선언해버렸기 때문이다.

반례

5 2
1 2 2 
1 5 3
3 5 1
4 5 5
4 3
4 1

위와 같이 m개의 질문에 출발 노드가 같은 질문이 존재할 경우 d의 값이 덮어써지게 되어 실제보다 큰 값이 출력된다.
그리고 어차피 d의 첫 번째 인덱스가 [start]로 고정이기 때문에 굳이 2차원 배열을 사용할 필요가 없다.

정답 코드

import sys
input = sys.stdin.readline
from collections import deque

def bfs(start, end):
    visited = [False] * (n+1)
    q = deque()
    q.append(start)
    visited[start] = True
    d = [0] * (n+1)
    while q:
        v = q.popleft()
        for i in range(len(tree[v])):
            if not visited[tree[v][i][0]]:
                visited[tree[v][i][0]] = True
                d[tree[v][i][0]] += d[v] + tree[v][i][1]
                q.append(tree[v][i][0])
            if tree[v][i][0] == end:
                return d[tree[v][i][0]]

n, m = map(int, input().strip().split())
tree = [[] for _ in range(n+1)]

for _ in range(n-1):
    s,e,w = map(int, input().strip().split())
    tree[s].append((e,w))
    tree[e].append((s,w))

for _ in range(m):
    s,e = map(int, input().strip().split())
    print(bfs(s,e))

Categories:

Updated:

Comments