1240 노드 사이의 거리
아이디어
트리를 아래와 같은 자료구조(왼쪽 그림)를 사용하여 표현했다.
# q: [(start0, end0), (start1, end1), (start2, end2), ...]
for q in question:
bfs(q[0])
for q in question:
print(d[q[0]][q[1]])
# start 노드에서부터 나머지 노드까지의 거리를 계산한다.
def bfs(start):
start 노드 방문 처리
q = deque()
q에 start 노드 삽입
while q:
v = q.popleft()
for v와 연결된 모든 노드 x에 대해:
if x를 방문하지 않은 경우:
# (start에서 x까지의 거리) += (start에서 v까지으 거리) + (v와 x 사이 거리)
d[start][x] += d[start][v] tree[v][x]
오답
import sys
input = sys.stdin.readline
from collections import deque
n, m = map(int, input().strip().split())
tree = [[] for _ in range(n+1)]
d = [[0]*(n+1) for _ in range(n+1)]
for _ in range(n-1):
s,e,w = map(int, input().strip().split())
tree[s].append((e,w))
tree[e].append((s,w))
question = []
for i in range(m):
s,e = map(int, input().strip().split())
question.append((s,e))
def bfs(start):
visited = [False] * (n+1)
q = deque()
q.append(start)
visited[start] = True
while q:
v = q.popleft()
for i in range(len(tree[v])):
if not visited[tree[v][i][0]]:
visited[tree[v][i][0]] = True
d[start][tree[v][i][0]] += d[start][v] + tree[v][i][1]
q.append(tree[v][i][0])
for i in range(m):
bfs(question[i][0])
for i in range(m):
print(d[question[i][0]][question[i][1]])
원인
각 노드 사이의 거리를 기록하는 d 를 전역변수로 선언해버렸기 때문이다.
반례
5 2
1 2 2
1 5 3
3 5 1
4 5 5
4 3
4 1
위와 같이 m개의 질문에 출발 노드가 같은 질문이 존재할 경우 d의 값이 덮어써지게 되어 실제보다 큰 값이 출력된다.
그리고 어차피 d의 첫 번째 인덱스가 [start]
로 고정이기 때문에 굳이 2차원 배열을 사용할 필요가 없다.
정답 코드
import sys
input = sys.stdin.readline
from collections import deque
def bfs(start, end):
visited = [False] * (n+1)
q = deque()
q.append(start)
visited[start] = True
d = [0] * (n+1)
while q:
v = q.popleft()
for i in range(len(tree[v])):
if not visited[tree[v][i][0]]:
visited[tree[v][i][0]] = True
d[tree[v][i][0]] += d[v] + tree[v][i][1]
q.append(tree[v][i][0])
if tree[v][i][0] == end:
return d[tree[v][i][0]]
n, m = map(int, input().strip().split())
tree = [[] for _ in range(n+1)]
for _ in range(n-1):
s,e,w = map(int, input().strip().split())
tree[s].append((e,w))
tree[e].append((s,w))
for _ in range(m):
s,e = map(int, input().strip().split())
print(bfs(s,e))
Comments