본문 바로가기
algorithm

[JAVA] 세그먼트 트리 ( SegmentTree )

by onejunu 2021. 3. 7.

www.acmicpc.net/blog/view/9

 

세그먼트 트리 (Segment Tree)

문제 배열 A가 있고, 여기서 다음과 같은 두 연산을 수행해야하는 문제를 생각해봅시다. 구간 l, r (l ≤ r)이 주어졌을 때, A[l] + A[l+1] + ... + A[r-1] + A[r]을 구해서 출력하기 i번째 수를 v로 바꾸기. A[i

www.acmicpc.net

백준 사이트의 설명을 보며 나름 자바 코드로 정리해봤다.

 

<자바 코드> 

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

class Main{

    static long[] segTree;
    static long[] a;

    /**
     *  node 번호는 1번부터 1<<(h+1) 번까지 있음.
     * */
    public static long init(int node,int st,int ed){
        if(st == ed) { // leaf node
            return segTree[node] = a[st];
        }
        else{ // not leaf node
            return segTree[node] = init(node*2,st,(st+ed)/2) + init(node*2 + 1,(st+ed)/2+1,ed);
        }
    }

    /**
     *
     * update
     *
     * */

    public static void update(int node,int st,int ed,int index,long diff){
        if(index < st || index > ed){
            return;
        }
        else{
            segTree[node] += diff;
            if(st != ed) {
                update(node * 2, st, (st + ed) / 2, index, diff);
                update(node * 2 + 1, (st + ed)/2 + 1, ed, index, diff);
            }
        }
    }

    /**
     *
     *  sum
     * */

    public static long sum(int node,int st,int ed,int left,long right){
        if(left > ed || right < st){ // 구하고자 하는 범위에 포함되지 않을때
            return 0;
        }
        if(left <= st && ed <= right){ // 완전히 포함되는 경우
            return segTree[node];
        }
        return sum(node*2,st,(st+ed)/2,left,right) + sum(node*2+1,(st+ed)/2+1,ed,left,right);
    }


    public static void main(String[] args) throws Exception{
        int n,m,k;
        int tt;
        Fs fs = new Fs();
        n = fs.nInt(); // 수의 개수

        m = fs.nInt(); // 변경횟수
        k = fs.nInt(); // 구해야할 구간합 수

        tt = m+k; // 총 반복해야하는 경우의 수

        a = new long[n];
        int h = (int)Math.ceil(Math.log(n) / Math.log(2));
        int treeSize = (1<<(h+1));
        segTree = new long[treeSize];

        for(int i=0;i<n;i++){
            a[i] = fs.nInt();
        }

        init(1,0,n-1);


        while(tt-- != 0){
            int t1,t2;
            long t3;
            t1 = fs.nInt(); t2 = fs.nInt(); t3 = fs.nLong();
            if(t1==1){ // update
                long diff = t3 - a[t2-1];
                a[t2-1] = t3;
                update(1,0,n-1,t2-1,diff);
            }
            else{
                System.out.println(sum(1,0,n-1,t2-1,t3-1));
            }

        }
    }


    static class Fs{
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer("");

        public int nInt() throws Exception{
            if(!st.hasMoreElements()) st= new StringTokenizer(br.readLine());
            return Integer.parseInt(st.nextToken());
        }

        public long nLong() throws Exception{
            if(!st.hasMoreElements()) st= new StringTokenizer(br.readLine());
            return Long.parseLong(st.nextToken());
        }
    }
}

 

댓글