본문 바로가기
algorithm

[JAVA] 세그먼트 트리구현법과 lazy propagation (feat. 백준)

by onejunu 2021. 3. 12.

www.acmicpc.net/blog/view/26

 

세그먼트 트리 나중에 업데이트 해야지!

배열 A가 있고, 여기서 다음과 같은 두 연산을 수행해야하는 문제가 있습니다. 10999번 문제: 구간 합 구하기 2 구간 l, r (l ≤ r)이 주어졌을 때, A[l] + A[l+1] + ... + A[r-1] + A[r]을 구해서 출력하기 i번째

www.acmicpc.net

 

여기서 공부한 내용을 바탕으로 간단하게 정리해보고자 한다.

 

세그먼트 트리를 공부하면서 분할정복에 대한 감을 좀더 잘 잡을 수 있는거 같다.

 

세그먼트 트리 구현을 잊어버려도 되지만 감은 계속 붙잡고 있었으면 하는 마음에 본인이 이해한 내용을 그림과 코드로 설명한다.

 

설명은 위에 백준님이 설명해준 그림을 바탕으로 설명하겠다. 먼저 백준님글을 읽고 이해가 완벽하게 된다면 이글은 읽을 필요가 없다.

 

다만, 자바코드를 참고하고 싶다면 맨 아래 첨부해 두었다.

 

 


1. 세그먼트 트리란?

 

한번 더 세그먼트 트리에 대해 정리한다.

 

배열 = [3,6,2,5,3,1,8,9,7,3] 를 세그먼트 트리로 표현한 모습이 아래의 모습이다.

 

 

노드 한개 안에 위에 적힌 숫자는 인덱스의 범위이다. 아래에 적힌 숫자는 인덱스 범위 안의 모든 숫자의 합이다.

꼭 합을 tree의 노드값으로 정할 이유는 없고 문제에 따라 최대값,최소값등 정하면 나름이다. 

 

5~7 을 보면 인덱스 5번 부터 7번 원소까지의 합은 1 + 8 + 9 = 18 이다. 

루트노드의 값은 3 + 6 + 2 + ... + 3 = 47 이다.

 


2. 세그먼트 트리 기본적인 구현법

a =[3,6,2,5,3,1,8,9,7,3] 이 세그먼트 트리를 만들 배열이라고 하자. 

 

tree =[] 를 세그먼트 트리로 사용할 배열이라고 하자. 여기서 1차원 배열을 사용한다.

tree[i] 의 자식은 tree[2*i] 와 tree[2*i+1] 이다.

a 의 배열을 tree로 만들어 버리면 아래와 같다. 

2번째 원소인 19의 자식은 (2*2=4)번째원소인 11과 (2*2+1=5)번째 원소인 8의 합이다.

우리는 tree를 1번 인덱스부터 사용할 것이다. 왜냐하면 tree를 0번 인덱스부터 사용하면 2*i 와 2*i+1 인 자식으로 올바르게 갈수 없기 때문이다.

1번째 원소는 : 47
2번째 원소는 : 19
3번째 원소는 : 28
4번째 원소는 : 11
5번째 원소는 : 8
6번째 원소는 : 18
7번째 원소는 : 10
8번째 원소는 : 9
9번째 원소는 : 2
10번째 원소는 : 5
11번째 원소는 : 3
12번째 원소는 : 9
13번째 원소는 : 9
14번째 원소는 : 7
15번째 원소는 : 3
16번째 원소는 : 3
17번째 원소는 : 6
18번째 원소는 : 0
19번째 원소는 : 0
20번째 원소는 : 0
21번째 원소는 : 0
22번째 원소는 : 0
23번째 원소는 : 0
24번째 원소는 : 1
25번째 원소는 : 8
26번째 원소는 : 0
27번째 원소는 : 0
28번째 원소는 : 0
29번째 원소는 : 0
30번째 원소는 : 0
31번째 원소는 : 0

 

tree의 사이즈는 어떻게 정할까?? ( 이부분은 백준님글에 설명이 안되어 있다)

 

1개의 원소가 있다면 루트노드 한개만 있으면 될것이다.

 

2개의 원소가 있다면 3개의 노드가 필요하다.

왜냐하면 루트노드에는 2개의 원소의 합이 저장되어야 하기 때문이다.

 

3개의 원소가 있다면 6개가 필요하다. (그림으로 생각해보기 바란다.)

 

4개의 원소가 있다면 7개가 필요하다.

5개의 원소가 있다면 9개가 필요하다.

 

..

이진트리에서 리프노드의 숫자가 바로 기존 배열의 숫자이며 이진트리의 모든 노드들의 숫자가 tree의 크기다.

"5개의 원소가 있다면" =>  "5개의 리프노드가 있다면" 으로 바꿀수 있다.

 

다시 질문해서 5개의 리프노드가 있다면 이진트리는 몇층으로 쌓아야 5개의 리프노드를 안정적으로 놓을 수 있는가? 적어도 4층은 쌓아야 5개의 리프노드가 올 수 있다. 4개의 층을 쌓으려면 tree는 총 2^4 개의 노드가 필요하며 최종적으로 tree 사이즈는 16개면 충분하다는 것이다.

 

10개의 원소가 있다면 (=> 10개의 리프노드가 있다면) 5층을 쌓아야하기 때문에 2^5개의 노드는 총 32개가 필요하다.

반드시 32개가 필요하다는 것이 아니라 32개면 충분하다는 의미로 받아들이면 좋겠다.

이러한 계산과정을 백준님은 아래처럼 구현하였다.

    int h = (int)ceil(log2(n));
    int tree_size = (1 << (h+1)) - 1;

 

 

init 에 대해 알아보자. init의 자바코드는 아래와 같다.

    public static long init(int node,int st,int ed){
        if(st == ed ) { // 리프노드인가?
            return tree[node] = a[st]; //그렇다면 tree에 넣고 반환하기
        }
        else{
        // 리프노드가 아니라면 왼쪽,오른쪽 자식에게 갔다와서 그 합을 현재 노드에게 반환한다.
            return tree[node] = init(node*2,st,(st+ed)/2) + init(node*2+1,(st+ed)/2+1,ed);
        }
    }

세그먼트 트리를 구현하면서 node, st, ed는 항상 붙어다닌다.

 

node는 tree에서 쓰는 인덱스이며 st 와 ed는 a 배열의 구간 범위를 나타내는 것으로 a배열에서 사용한다.

st == ed 도 구간의 범위가 처음과 끝이 같다는 말은 리프노드라는 의미다.

 

왼쪽 자식의 node 번호는 현재 node 번호의 2배이며 구간은 전체 구간을 반으로 나눈것이므로 

node*2 , st , (st+ed)/2 로 다음 Init에게 전달해준다. (사실 외우면 편함)

 

왼쪽으로 가기 => node*2 , st , (st+ed)/2

오른쪽으로 가기 => node*2 , (st+ed)/2+1, ed

리프노드 =>  st == ed

 

update에 대해 알아보자. update의 자바코드는 아래와 같다.

특정 인덱스 index에 대해서 변경하는 경우 세그먼트 트리를 업데이트하는 코드다.

    public static void update(int node,int st,int ed,int index,long diff){
        if(index < st || index > ed){ // 바꾸려는 인덱스가 현재 범위(st~ed)에 포함되지 않는다면 패스!
            return;
        }
        else{
            segTree[node] += diff; // 현재 범위에 포함되니까 일단 차이값만큼 수정!
            if(st != ed) { // 리프노드가 아니라면 왼쪽 자식 오른쪽 자식에게 넘기기!
                update(node * 2, st, (st + ed) / 2, index, diff);
                update(node * 2 + 1, (st + ed)/2 + 1, ed, index, diff);
            }
        }
    }

다음 코드는 특정 index에 대해 변경하는 코드가 아니라 어떤 범위에 대해 diff라는 일정한 값을 더해주는 코드다.

left 와 right 에 각각 모두 diff를 더해주는 연산을 하는 코드다.

static void update(int node,int st,int ed,int left,int right,long diff){
        if(st > right || ed < left) return; // left 와 right사이에 포함되지 않는다면 패스!
        // left~right 와 st~ed 의 공통구간을 구한다.
        int t1 = Math.max(st,left); // 왼쪽 시작
        int t2 = Math.min(ed,right); // 오른쪽 끝
        long d = diff * (t2-t1 + 1); // 겹치는 구간만큼 diff를 곱한다.
        tree[node] = tree[node] + d; // 현재 노드에 d 만큼 더해준다.
        if(st != ed){ // 리프 노드가 아니라면
            update(node*2,st,(st+ed)/2,left,right,diff); // 왼쪽 ㄱ
            update(node*2+1,(st+ed)/2+1,ed,left,right,diff); // 오른쪽 ㄱ
        }
    }

 

위 코드는 치명적인 단점이 있는데 만약 모든 인덱스의 값이 변경된다면 모든 세그먼트트리를 수정해야되기 때문에 시간이 상당히 걸릴 수 있다. 어떤 범위에 대해서 똑같은 값으로 수정을 한다면 lazy propagation이라는 기법을 이용하여 업데이트를 좀더 효율적으로 할 수 있다.

실제 아래 문제를 위 update문으로 하면 시간초과로 실패한다.

www.acmicpc.net/problem/10999

 

10999번: 구간 합 구하기 2

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

아래코드는 lazy propagation을 적용한 코드이다.

 

lazy propagation 개념은 간단하다. "자식노드는 나중에 업데이트한다" 이다.

- 업데이트하고자 하는 범위를 left , right 일때 노드를 돌면서 현재 범위(st ~ ed)가 left 와 right에 완전히 포함된다면 더이상 자식노드를 업데이트 하지 않고 lazy 값만 자식노드에게 전파한뒤 return 해버리고 자식노드들에 대해서는 방문할때만 업데이트를 한다는 것이다.

 

백준님의 설명을 보고 왔다는 가정하에 코드로 설명하겠다.

 public static void updateLazy(int node,int st,int ed){
        if(lazy[node]!=0){ // 만약 현재 노드에 lazy값이 있다면
            tree[node] += (lazy[node]*(ed-st+1)); // 현재노드 업데이트

            if(st != ed) { // 만약 자식이 있다면 자식들에게 lazy값을 전파한다.
                lazy[node * 2] += lazy[node];
                lazy[node * 2 + 1] += lazy[node];
            }

            lazy[node] = 0; // lazy값을 업데이트했으므로 0으로 만들어준다.
        }
    }

    public static void updateRange(int node,int st,int ed,int left,int right,long diff){
        // 현재 노드의 lazy를 업데이트한다.
        updateLazy(node,st,ed);

        if(st > right || ed < left) { //현재 범위와 겹치치 않는다면
            return;
        }

        if(left <= st && ed <= right){ //현재 범위와 완전히 겹친다면 더이상 내려가서 업데이트할 필요 없음
            tree[node] += (diff * (ed-st+1));

            if(st != ed){
                lazy[node*2] += diff;
                lazy[node*2+1] += diff;
            }
            return;
        }

        // 걸친다면
        updateRange(node*2,st,(st+ed)/2,left,right,diff); //왼쪽 자식 업데이트
        updateRange(node*2+1,(st+ed)/2+1,ed,left,right,diff); // 오른쪽 자식 업데이트
        tree[node] = tree[node*2] + tree[node*2+1]; // 현재 노드 업데이트

    }

 

이제 lazy propagation에서 구간합을 구하는 코드는 아래와 같다. 주석부분만 빼면 기본적인 sum과 동일하다.

    public static long sum(int node,int st,int ed,int left,int right){
        updateLazy(node,st,ed); // 현재노드가 lazy값이 있는지 반드시 확인하고 업데이트!
        if(st > right || ed < left) return 0;
        if(left <= st && ed <= right) {
            return tree[node];
        }
        return sum(node*2,st,(st+ed)/2,left,right)
                + sum(node*2+1,(st+ed)/2+1,ed,left,right);
    }

 

위 코드는 계속 참고하면서 세그먼트 트리를 풀면 도움될것이다.

 

구간 합 구하기 2

<전체코드>

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

class Main{
    static Fs fs = new Fs();
    static int n,m,k;
    static long[] a;
    static long[] tree;

    static long[] lazy;

    public static long init(int node,int st,int ed){
        if(st == ed ) {
            return tree[node] = a[st];
        }
        else{
            return tree[node] = init(node*2,st,(st+ed)/2) + init(node*2+1,(st+ed)/2+1,ed);
        }
    }

    public static void updateLazy(int node,int st,int ed){
        if(lazy[node]!=0){ // 만약 현재 노드에 lazy값이 있다면
            tree[node] += (lazy[node]*(ed-st+1)); // 현재노드 업데이트

            if(st != ed) { // 만약 자식이 있다면 자식들에게 lazy값을 전파한다.
                lazy[node * 2] += lazy[node];
                lazy[node * 2 + 1] += lazy[node];
            }

            lazy[node] = 0; // lazy값을 업데이트했으므로 0으로 만들어준다.
        }
    }

    public static void updateRange(int node,int st,int ed,int left,int right,long diff){
        // 현재 노드의 lazy를 업데이트한다.
        updateLazy(node,st,ed);

        if(st > right || ed < left) { //현재 범위와 겹치치 않는다면
            return;
        }

        if(left <= st && ed <= right){ //현재 범위와 완전히 겹친다면 더이상 내려가서 업데이트할 필요 없음
            tree[node] += (diff * (ed-st+1));

            if(st != ed){
                lazy[node*2] += diff;
                lazy[node*2+1] += diff;
            }
            return;
        }

        // 걸친다면
        updateRange(node*2,st,(st+ed)/2,left,right,diff); //왼쪽 자식 업데이트
        updateRange(node*2+1,(st+ed)/2+1,ed,left,right,diff); // 오른쪽 자식 업데이트
        tree[node] = tree[node*2] + tree[node*2+1]; // 현재 노드 업데이트

    }

    public static long sum(int node,int st,int ed,int left,int right){
        updateLazy(node,st,ed);
        if(st > right || ed < left) return 0;
        if(left <= st && ed <= right) {
            return tree[node];
        }
        return sum(node*2,st,(st+ed)/2,left,right)
                + sum(node*2+1,(st+ed)/2+1,ed,left,right);
    }

    public static void main(String[] args) throws IOException{
        n = fs.nInt(); m = fs.nInt(); k = fs.nInt();
        a = new long[n];
        int size = 1;
        while(size < n) {size <<=1;} size<<=1;
        tree = new long[size];
        lazy = new long[size];
        for(int i=0;i<n;i++) a[i] = fs.nLong();
        m+=k;

        init(1,0,n-1);


        while(m-- != 0) {
            int cond;
            cond = fs.nInt();
            if(cond == 1){
                int t1 = fs.nInt();
                int t2 = fs.nInt();
                long val = fs.nLong();

                updateRange(1,0,n-1,t1-1,t2-1,val);

            }
            else{
                int t1 = fs.nInt();
                int t2 = fs.nInt();

                System.out.println(sum(1,0,n-1,t1-1,t2-1));
            }

        }


    }



    static class Fs{
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer("");

        public int nInt() throws IOException {
            if(!st.hasMoreElements()) st = new StringTokenizer(br.readLine());
            return Integer.parseInt(st.nextToken());
        }

        public long nLong() throws IOException {
            if(!st.hasMoreElements()) st=  new StringTokenizer(br.readLine());
            return Long.parseLong(st.nextToken());
        }
    }
}

댓글