본문 바로가기
Computer Science/Algorithm

[알고리즘/파이썬] 세그먼트 트리 segment tree

by ggyongi 2022. 3. 27.
반응형

세그먼트 트리는 언제 쓰나?

배열과 같은 자료형에서 특정 구간에 속한 원소들의 연산(합, 최솟값, 최댓값 등)을 알아볼 때 선형 탐색보다 좀 더 효율적인 탐색이 가능하다. 그럼 누적합이랑 비슷한거 아닌가? 합만을 알려주는 누적합보다 좀 더 폭넓게 활용이 가능하다. 특히 특정 데이터가 변경되었을 때 누적합은 O(N)으로 누적합 데이터를 업데이트해줘야 하지만, 세그먼트는 O(logN)에 특정 데이터 변경을 반영할 수 있다. 

 

사실 알고리즘보다는 자료구조에 가깝다.

대략 아래의 그림처럼, 어떤 배열이 존재할 때 세그먼트 트리는 그 배열의 특정 구간에 대한 정보를 추가로 담고 있게 된다. 이해가 안되면 그림을 보면 된다. 아래의 그림에서 트리의 루트 노드는 0~8번 데이터를 포함하는 정보를 담고 있다. 예를 들어, 만약 세그먼트 트리가 구간합 정보를 담고 있다면(세그먼트 트리는 구간합 말고도 최솟값, 최댓값 구할 때도 사용됨), 루트 노드에서는 배열의 0번부터 8번까지의 합을 담고 있다고 생각하면 된다. 그 루트 노드의 왼쪽 자식은 0번부터 4번까지의 합을, 루트 노드의 오른쪽 자식은 5번부터 8번까지의 합을 담고 있는 것이다. 

 

그렇기 때문에 구간합을 트리의 노드 정보를 이용하여 매우 빠르게 구해낼 수 있다. 아래 그림을 보자.

배열의 2번부터 8번까지의 원소의 합을 빠르게 구하고 싶으면?

=> 세그먼트 트리에서 주황색으로 색칠된 노드 3개가 담고 있는 합만 더해주면 된다.

 

 

세그먼트 트리 구현

세그먼트 트리는 재귀 구조를 이용하여 구현한다. 세그먼트 트리를 구현할 때 노드의 인덱스를 주의깊게 살펴볼 필요가 있다. 특이한 점은, 노드의 인덱스가 1부터 시작한다. 즉 루트 노드의 인덱스는 1번이다. 

 

파이썬 구현 코드 - 재귀 구조 사용 

num_lst = [1, 2, 5, 3, 9, 6, 5, 3, 2]
print('num_lst : ', num_lst)

# create segment tree
stree = [0 for x in range(4*len(num_lst))]  # 길이는 넉넉하게 4N으로 설정

def merge(left, right):
    return left + right

def build(stree, node, left, right):
    # leaf node
    if left == right:
        stree[node] = num_lst[left]
        return stree[node]

    mid = left + (right - left)//2
    left_val = build(stree, 2*node, left, mid)
    right_val = build(stree, 2*node + 1, mid + 1, right)
    stree[node] = merge(left_val, right_val)
    return stree[node]

## stree의 1번 노드는 num_lst의 0번부터 len(num_lst)-1번까지의 정보를 담고 있다.
## 이 루트 노드를 시작으로, 재귀 구조를 통해 자식 노드들로 쭉쭉 뻗어 나간다.
build(stree, 1, 0, len(num_lst)-1)

## 결과 출력
print("==segment tree==")
print(stree[1])
print(stree[2], stree[3])
print(stree[4], stree[5], stree[6], stree[7])
print(stree[8], stree[9], stree[10], stree[11], stree[12], stree[13], stree[14], stree[15])
print(stree[16], stree[17])

 출력 결과

num_lst :  [1, 2, 5, 3, 9, 6, 5, 3, 2]
==segment tree==
36
20 16
8 12 11 5
3 5 3 9 6 5 3 2
1 2

 

특징 1.

위의 코드에서 merge 함수를 보면 left와 right를 더한다고 되어있으므로 이 세그먼트는 구간합 정보를 담고 있는 세그먼트 트리가 된다. merge 함수를 적절히 바꾸면 최댓값, 최솟값 등을 담고 있는 형태로 변경이 가능하다. 

 

특징 2.

세그먼트 트리는 트리 구조를 갖지만 실제로는 배열로 구현하는 것이 편하다. 이때 배열의 길이는 넉넉하게 잡아준다. 트리의 루트 노드 인덱스를 1이라고 하기로 했으므로, 이 배열의 0번째에는 그냥 값이 비어있게 된다. 

 

특징 3. 

함수 build의 파라미터로 stree, node, left, right를 넘겨 준다. 

stree는 세그먼트 트리 정보를 담고 있는 배열,

node는 현재 탐색하고 있는 노드 번호,

left와 right는 node의 적용 범위다. 즉 node에는 원래 배열(num_list)의 left번부터 right번까지 숫자의 합이 들어있다. 

 

 

세그먼트 트리 적용

def query(start, end, node, left, right):
    # 나의 담당구역 left~right 사이에 start~end 가 아예 포함되지 않는 경우 => 가지 치기
    if end < left or start > right:
        return 0

    # 나의 담당구역 left~right 전체가 start~end 에 포함되는 경우 => 내 정보를 바로 리턴
    if start <= left and right <= end:
        return stree[node]

    mid = left + (right - left) // 2
    left_val = query(start, end, 2 * node, left, mid)
    right_val = query(start, end, 2 * node + 1, mid + 1, right)
    return merge(left_val, right_val)

print("==query==")
print('sum 0 to 5 : ', query(0, 5, 1, 0, len(num_lst)-1))  # 0~5번까지의 합
print('sum 4 to 7 : ', query(4, 7, 1, 0, len(num_lst)-1))  # 4~7번까지의 합

출력 결과

==query==
sum 0 to 5 :  26
sum 4 to 7 :  23

 

세그먼트 트리를 구현했으면, 이 트리의 장점을 써먹을 쿼리 처리 함수를 만들어준다.

위의 코드에서 함수 query는 파라미터로 start, end, node, left, right를 갖는다.

start, end => 우리가 관심 있어하는 구간이다. ex) 쿼리 예시 : start번부터 end번까지의 숫자 합을 구하여라

node => 현재 탐색 중인 노드의 번호

left, right => 현재 노드의 포용 범위. 즉 left~right의 정보를 담고 있음

 

 

직접 실행을 따라가보자.

0번부터 5번까지의 구간합을 구하라는 문제를 만났을 때, 1번 노드(루트 노드)는 0~8번까지의 정보를 담고 있다. 이 구간은 우리의 관심 구간(0~5번)을 포함하면서 더 큰 범위다. 이때는 0~4, 5~8로 쪼개져서 재귀 함수를 타게 된다. 

 

2번 노드를 보자. 2번 노드는 0~4번의 정보를 담고 있다. 이 구간은 우리가 관심있는 0~5번에 완전 포함된다. 이때는 더 이상 아래로 탐색할 필요 없이 본인이 가지고 있는 정보를 넘기면 된다. 

한편 3번 노드는 5~8번의 정보를 담고 있고, 이 구간은 우리의 남은 관심 구간(5번)을 포함한다. 따라서 1번 노드와 마찬가지로 쪼개진다. 5~6, 7~8로 쪼개져서 다시 재귀함수를 타게 된다.

이후 6번 노드는 또 다시 우리의 관심 구간(5번)을 포함하므로 쪼개진다. 5, 6으로 쪼개짐.

이후 12번 노드에 이르게 되면 그제서야 관심 구간(5번)에 노드의 정보(5번)이 포함되므로 탐색을 멈추고 정보를 넘겨준다. 

 

 

 

세그먼트 트리 업데이트

def update(idx, val, node, left, right):
    # 나의 담당구역 left~right 사이에 target이 없는 경우 => 기존 값을 그대로 반환(변경 필요 x)
    if idx < left or idx > right:
        return stree[node]

    # leaf node: 바꿀 idx 값 발견 => 값 변경
    if left == right:
        stree[node] = val
        return stree[node]

    mid = left + (right - left) // 2
    left_val = update(idx, val, 2*node, left, mid)
    right_val = update(idx, val, 2*node + 1, mid + 1, right)
    stree[node] = merge(left_val, right_val)
    return stree[node]

update(3, 100, 1, 0, len(num_lst)-1)  # 3번째 숫자를 100으로 변경
print("==update==")
print(stree[1])
print(stree[2], stree[3])
print(stree[4], stree[5], stree[6], stree[7])
print(stree[8], stree[9], stree[10], stree[11], stree[12], stree[13], stree[14], stree[15])
print(stree[16], stree[17])

출력 결과: 업데이트된 세그먼트 트리.

잘 보면 루트 노드의 값이 기존엔 36이었는데 지금은 133이 되었다. 이는 우리가 3번째 배열 값을 3에서 100으로 바꿔주었기 때문에 그 변화가 반영된 것이다.

==update==
133
117 16
8 109 11 5
3 5 100 9 6 5 3 2
1 2

배열의 특정 값을 변경하면, 세그먼트 트리는 그 결과를 O(logN)시간에 반영한다. 이것이 세그먼트 트리의 장점.

업데이트 함수는 기존의 함수들과 유사하다.

탐색을 하면서 내가 찾는 idx를 찾으면, 내가 원하는 값으로 바꿔주고, 

노드가 담고 있는 구간(left~right)에 idx가 포함되어 있지 않으면 어차피 그 노드를 포함하여 자식 노드들은 값의 변경이 없기 때문에 탐색을 멈추면 된다.

 

비전공자 네카라 신입 취업 노하우

시행착오 끝에 얻어낸 취업 노하우가 모두 담긴 전자책!

kmong.com

댓글