📑 문제
문제링크 : 친구 네트워크
🤔 생각의 흐름
Disjoint Set 자료구조가 필요한 문제입니다. 귀찮은 부분은 이름이 숫자로 주어진 것이 아니라 문자열로 이루어진 점과 사람이 총 몇 명인지 모르는 점입니다.
문자열 이름은 딕셔너리를 이용해서 처리해주면 될 것 같네요. 처리해주는 방법은 디테일하게 2가지가 있습니다.
딕셔너리 key 값에 사람 이름을 담고, value 값으로 루트노드가 아니라면 부모의 이름을, 루트 노드라면 해당 tree 의 노드 갯수를 저장하는 방법
딕셔너리 key 값에 사람 이름을 담고 value 값으로 리스트의 인덱스 번호를 저장하여, 리스트를 이용해 disjoint set 을 구현하는 방법
설명은 1번 방법으로 설명하겠습니다만 코드는 1번 풀이와 2번 풀이 둘 다 작성하였습니다.
🎯 풀이방법
Disjoint Set 문제입니다.
disjoint_set : 딕셔너리로 이루어진 Disjoint Set 자료구조. key 에는 사람의 이름이, value 에는 자식노드라면 부모 이름, 부모노드라면 해당 트리의 원소 갯수가 담깁니다.
union(a, b) : disjoint_set 에서 a 를 b 의 자식으로 합치는 함수입니다. b 는 노드 갯수가 증가하며 a 는 value 값이 b로 바뀝니다.
find(a) : disjoint_set 에서 a 의 루트노드를 찾아주는 함수입니다.
disjoint set 을 위한 딕셔너리를 만들어줍니다.
2명씩 사람 이름을 입력받습니다. 만약 처음 입력받는 이름이라면 disjoint set 딕셔너리에 1 값을 value 로 갖도록 넣어주어 루트노드로 만들어줍니다.
- 두 사람의 루트노드를 find 함수를 이용해 구해줍니다.
- 동일한 루트 노드를 가지는 경우, 아무것도 해주지 않습니다.
- 다른 루트노드를 가진 경우, 노드의 갯수가 적은 트리를 더 많은 트리로 합쳐줍니다.
- 3번을 수행한 후 루트노드의 value 값을 출력해줍니다.
🔎 유의할 점
동일한 루트를 갖는 사람들을 입력받은 경우 union 을 해주면 안됩니다. (해줬을 때 메모리초과가 나오더라구요.)
union 이나 find 함수를 재귀함수로 구현하는 경우, 재귀함수의 깊이가 1000이 넘어갈 수 있으니 setrecursionlimit 으로 재귀함수 깊이를 높여주어야 합니다.
💻 코드
딕셔너리만 이용
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from sys import stdin
input = stdin.readline
def solve():
TC = int(input())
# a를 b의 자식으로 연결
def union(a, b):
disjoint_set[b] += disjoint_set[a]
disjoint_set[a] = b
return disjoint_set[b]
def find(a):
while isinstance(disjoint_set[a], str):
a = disjoint_set[a]
return a
for _ in range(TC):
F = int(input())
# 자식이름: 부모이름 or 집합 갯수
disjoint_set = {}
for _ in range(F):
a, b = input().split()
if a not in disjoint_set:
disjoint_set[a] = 1
if b not in disjoint_set:
disjoint_set[b] = 1
a, b = find(a), find(b)
if a == b:
print(disjoint_set[a])
else:
a_val, b_val = disjoint_set[a], disjoint_set[b]
if a_val > b_val:
print(union(b, a))
else:
print(union(a, b))
if __name__ == '__main__':
solve()
리스트도 이용
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from sys import stdin
input = stdin.readline
def solve():
TC = int(input())
# a를 b의 자식으로 연결
def union(a, b):
disjoint_set[b] += disjoint_set[a]
disjoint_set[a] = b
return -disjoint_set[b]
def find(a):
while disjoint_set[a] >= 0:
a = disjoint_set[a]
return a
for _ in range(TC):
F = int(input())
disjoint_set = []
name = {}
for _ in range(F):
a, b = input().split()
if a not in name:
name[a] = len(name.keys())
disjoint_set.append(-1)
if b not in name:
name[b] = len(name.keys())
disjoint_set.append(-1)
a, b = find(name[a]), find(name[b])
if a == b:
print(-disjoint_set[a])
elif a < b:
print(union(b, a))
else:
print(union(a, b))
if __name__ == '__main__':
solve()
Comments powered by Disqus.