[백준 2213] 트리의 독립집합 (파이썬 / Python)
포스트
취소

[백준 2213] 트리의 독립집합 (파이썬 / Python)


📑 문제

문제링크 : 트리의 독립집합


🤔 생각의 흐름

처음에는 각 노드마다 선택했을 때와 선택하지 않았을 때, 두 경우의 해당 노드부터 leaf 노드까지의 가중치 합을 저장하는 dp 리스트에 저장하여 해결했습니다.

하지만 다른 분의 풀이를 보니 굳이 dp 리스트를 사용하지 않고, 노드에 방문할 때마다 선택했을 때와 선택하지 않았을 때를 함께 계산해 나가며 풀었습니다. 이 때의 로직과 코드가 더 간단했습니다.


🎯 풀이방법

트리 문제이며, 시간복잡도는 $O(N)$ 입니다.

arr : 가중치 리스트. 1-based 로 저장하기 위해 맨 앞에 0을 추가했습니다.

tree : 각 노드와 연결되어있는 다른 노드가 저장되는 2차원 리스트. 입력받으면서 어떤 노드가 부모노드인지 알 수 없으므로, 입력받은 두 노드 모두에 서로를 추가해줍니다.

visited : 어떤 노드가 부모노드인지 알기 위해 선언한 boolean 리스트입니다. 재귀적으로 get_ans 함수를 호출할 때, 이미 도달했던 노드가 부모노드이므로 tree[idx] 중 방문하지 않은 노드들이 자식노드입니다.

get_ans(idx) : 부모노드(idx)를 인자로 받아, idx 노드가 선택되었을 때와 안되었을 때, idx 노드부터 leaf 노드까지의 가중치의 합과 포함되는 노드들을 담은 리스트를 반환해주는 함수입니다.

  1. 가중치를 입력받아 1-based 로 저장해줍니다.

  2. 간선들을 입력받습니다. 모든 간선들을 입력받기 전까지는 트리가 어떻게 구성되어있는지 알 수 없으므로, 그래프를 입력받듯이 두 노드에 연결된 노드로 서로를 추가합니다.

  3. 부모노드에서 자식노드 방향으로 가중치와 독립집합을 계산해 나가야 합니다.

  4. 현재 노드가 선택되었을 경우와 안되었을 경우에 따라 행동이 달라집니다.
    1. 현재 노드가 선택되었다면 모든 자식노드들은 선택되어서는 안됩니다. 따라서 현재 노드의 가중치와 (선택하지 않은 상태의 자식노드들 ~ leaf 노드의 가중치 합) 을 더한 값이 현재 노드를 선택했을 때 현재노드 ~ leaf 노드까지의 가중치의 합이 됩니다.

    2. 현재 노드가 선택되지 않았다면 자식 노드들을 선택해도 되고 선택하지 않아도 됩니다. 이 때는 각 자식 노드를 선택했을 때와 선택하지 않았을 때, 두 경우 중 가중치의 합이 더 높은 것을 골라서 더해주면 현재 노드를 선택하지 않았을 때 현재노드 ~ leaf 노드까지의 가중치의 합이 됩니다.

  5. 독립집합의 경우 현재 노드가 선택되었는지 안되었는지 나누어서 생각해야 합니다.
    1. 현재 노드가 선택된 경우 독립집합에 현재 노드를 추가해야합니다. 그리고 선택하지 않은 자식 노드 ~ leaf 노드까지의 독립집합을 현재 독립집합에 추가해줍니다.

    2. 현재 노드가 선택되지 않았을 경우, 자식을 선택했을 때와 선택하지 않았을 때의 가중치 합이 더 높은 쪽의 자식 ~ leaf 노드까지의 독립집합을 현재 독립집합에 추가해줍니다.

  6. 4번과 5번에서 구한 가중치와 독립집합들을 return 해줍니다. (4번과 5번 동작을 해주는 함수가 get_ans 함수입니다.)

  7. 1번 노드를 루트노드로 설정하여 가중치와 독립집합을 get_ans 함수를 이용하여 구하고 출력해줍니다.

🔎 유의할 점

  • 간선들을 다 입력받기 전까지는 어떤 노드가 부모 노드고 자식 노드인지 온전히 알기는 어렵습니다. 그래서 입력을 다 받은 후 루트노드부터 자식노드까지 tree 를 파악하고 탐색을 시작하거나, 방문여부를 저장한 boolean 리스트를 이용해 탐색해가며 어떤 노드가 부모노드이고 자식노드인지 파악해야 합니다.

  • 마지막 독립집합의 노드들을 출력할 때 정렬을 한 후 출력해주어야 합니다.


💻 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from sys import stdin

input = stdin.readline


def solve():
    N = int(input())
    arr = [0, *map(int, input().split())]

    tree = [[] for _ in range(N + 1)]
    for _ in range(N-1):
        a, b = map(int, input().split())
        tree[a].append(b)
        tree[b].append(a)

    visited = [False] * (N + 1)

    def get_ans(idx):
        visited[idx] = True
        sel_w, nsel_w = arr[idx], 0
        sel_n, nsel_n = [idx], []

        for nxt in tree[idx]:
            if not visited[nxt]:
                w_sel, w_nsel, n_sel, n_nsel = get_ans(nxt)
                sel_w += w_nsel
                sel_n += n_nsel
                if w_sel > w_nsel:
                    nsel_w += w_sel
                    nsel_n += n_sel
                else:
                    nsel_w += w_nsel
                    nsel_n += n_nsel

        return sel_w, nsel_w, sel_n, nsel_n

    wa, wb, na, nb = get_ans(1)

    if wa > wb:
        print(wa)
        print(*sorted(na))
    else:
        print(wb)
        print(*sorted(nb))


if __name__ == '__main__':
    solve()
이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.

[백준 1007] 벡터 매칭 (파이썬 / Python)

[알고리즘] 희소 배열(Sparse Table) 알고리즘

Comments powered by Disqus.