📑 문제
문제링크 : 피아노 연주
🤔 생각의 흐름
예전에 학교 수업시간에 배웠던 Assembly Line Scheduling 이 떠올랐습니다. dp[i][j] = i 번째에 j 손(0: 왼손, 1: 오른손) 으로 쳤을 때 0번째부터의 거리 합으로 정의하고 풀었더니 틀렸습니다가 나왔습니다.
테스트 케이스 생성하는 코드를 짜서 돌려봤더니 아래와 같은 반례들이 있었습니다.
1
2
3
4
5
6
7
G2 D1
10
A1# E1# G2# A2 D2 D0 A2 C2 E0 C0#
G2# E1#
5
E0 B1 G2 G1 A0#
Assembly Line Scheduling 알고리즘은 i 번째에서 (i+1) 번째로 넘어가는데 더해지는 값이 고정적이지만, 이 문제는 i 번째에서의 손 위치가 이후에 더해지는 값들에도 영향을 주기 때문에 이 알고리즘을 이용할 수 없던 것으로 보입니다.
그래서 다른 점화식을 세웠습니다. arr[i] 를 시간 i 에서 음, l 과 r 을 왼손과 오른손의 위치라고 할 때,
\[dp[i][l][r] = min(|arr[i] - l| + dp[i+1][arr[i]][r], |arr[i] - r| + dp[i+1][l][arr[i]])\]이며, dp[i][l][r] 은 i 번째에서 왼손과 오른손이 l 과 r 에 있을 때, i 번째부터 마지막까지 연주했을 때의 최소 거리입니다. 참고로 위 식에서 min 함수 안의 왼쪽은 왼손을 움직인 경우, 오른쪽은 오른손을 움직인 경우입니다.
🎯 풀이방법
DP 문제입니다.
convert(a) : 음을 나타내는 문자열 a 를 입력받아 C0 = 0 을 기준으로 하여 정수로 변환해주는 함수
arr : 쳐야하는 음들을 저장한 리스트
dp[i][l][r] : 왼손과 오른손이 l 과 r 에 있을 때, i 번째부터 마지막까지 연주했을 때의 거리
calc(depth, left, right) : dfs 방식으로 dp[depth][left][right] 를 계산하고 dp 리스트에 저장하는 함수
초기 왼손의 위치(= l)과 오른손의 위치(= r)를 입력받아 정수로 변환합니다.
쳐야하는 음들을 입력받아 정수로 변환하여 리스트에 저장합니다.
calc(depth=0, left=l, right=r) 을 호출하여 dp 리스트에 올바른 값을 할당해줍니다.
calc 함수에서 위 점화식에 맞게, depth 번째에 왼손과 오른손을 각각 움직였을 때 중 거리의 최소를 구하여 dp 리스트에 저장해줍니다.
i = 0 부터 i = N - 1 까지 왼손과 오른손을 움직였을 때의 거리를 비교하여, 더 적은 손을 움직여가며 움직여야하는 손들을 출력합니다.
🔎 유의할 점
i 번째의 손의 위치에 따라 이후에 더해져야하는 값이 바뀔 수 있으므로, 이 문제에서는 Assembly Line Scheduling 을 쓸 수 없습니다.
점화식을 이해하고 푼다면 크게 유의해야할 부분은 없는 것 같습니다.
💻 코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from sys import stdin, setrecursionlimit as SRL
SRL(10 ** 6)
input = stdin.readline
def solve():
conv = {'C': 0, 'D': 2, 'E': 4, 'F': 5, 'G': 7, 'A': 9, 'B': 11}
def convert(a):
return conv[a[0]] + int(a[1]) * 12 + (len(a) == 3)
l, r = map(convert, input().split())
N = int(input())
arr = [*map(convert, input().split())]
# dp[i][l][r]: i 번째에 왼손과 오른손이 l과 r에 있을 때, 마지막까지 연주하는 거리
dp = [[[-1] * 121 for _ in range(121)] for _ in range(N+1)]
seq = [0] * N
def calc(depth, left, right):
if depth == N:
return 0
elif dp[depth][left][right] != -1:
return dp[depth][left][right]
_l = abs(arr[depth] - left) + calc(depth+1, arr[depth], right)
_r = abs(arr[depth] - right) + calc(depth+1, left, arr[depth])
dp[depth][left][right] = _l if _l < _r else _r
return dp[depth][left][right]
calc(0, l, r)
print(dp[0][l][r])
_l, _r = l, r
for i in range(N):
l_val = abs(_l - arr[i]) + dp[i+1][arr[i]][_r]
r_val = abs(_r - arr[i]) + dp[i+1][_l][arr[i]]
if l_val < r_val:
_l = arr[i]
print(1, end=' ')
else:
_r = arr[i]
print(2, end=' ')
if __name__ == '__main__':
solve()
Comments powered by Disqus.