본문 바로가기
PS/BOJ

[자바, C++] 백준 2042 - 구간 합 구하기 (java cpp)

by Nahwasa 2022. 8. 12.

 문제 : boj2042


 

필요 알고리즘 개념

  •  펜윅 트리, 세그먼트 트리 등 
    • 펜윅 트리, 세그먼트 트리 등 구간 쿼리 알고리즘(자료구조)를 알고 있어야 풀 수 있다.

※ 제 코드에서 왜 main 함수에 로직을 직접 작성하지 않았는지, 왜 Scanner를 쓰지 않고 BufferedReader를 사용했는지 등에 대해서는 '자바로 백준 풀 때의 팁 및 주의점' 글을 참고해주세요. 백준을 자바로 풀어보려고 시작하시는 분이나, 백준에서 자바로 풀 때의 팁을 원하시는 분들도 보시는걸 추천드립니다.

 


 

풀이

 

  펜윅 트리, 세그먼트 트리 등으로 풀 수 있는 기본형태의 문제이다. 펜윅 트리로 풀려면 작성해둔 펜윅 트리 글에서 '기본 : 펜윅 트리' 부분을 읽어보면 이 문제를 풀 수 있다.

 

 


 

코드(C++ 펜윅 트리) : github

#include <iostream>
using namespace std;
#define ll long long

int n;
ll arr[1000001] = {0,};
ll bit[1000001] = {0,};

ll getPrefixSum(int ith) {
    ll answer = 0l;
    while (ith > 0) {
        answer += bit[ith];
        ith -= ith&-ith;
    }
    return answer;
}

ll query(int b, int c) {
    return getPrefixSum(c) - getPrefixSum(b-1);
}

void update(int ith, long val) {
    ll diff = val - arr[ith];
    arr[ith] = val;

    while (ith <= n) {
        bit[ith] += diff;
        ith += ith&-ith;
    }
}

int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int m, k;
    cin >> n >> m >> k;
    m += k;
    for (int i = 1; i <= n; i++) {
        ll cur;
        cin >> cur;
        update(i, cur);
    }
    while (m--) {
        int a, b, c;
        ll v;
        cin >> a;
        switch (a) {
            case 1:
                cin >> b >> v;
                update(b, v);
                break;
            case 2:
                cin >> b >> c;
                cout << query(b, c) << '\n';
                break;
        }
    }
}

 

 

 

--- 이하 자바 코드 ---

 

코드(세그먼트 트리) : github

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.StringTokenizer;

public class Main {
    static int N;
    static long[] arr, seg;

    private static void init(int n, int s, int e) {
        if (s==e) {
            seg[n]=arr[s];
            return;
        }
        int m = (s+e)/2;
        init(n*2,s,m);
        init(n*2+1,m+1,e);
        seg[n]=seg[n*2]+seg[n*2+1];
    }

    private static void update(int n, int s, int e, int t, long diff) {
        if (t<s || t>e)
            return;
        seg[n]+=diff;
        if (s==e)
            return;
        int m = (s+e)/2;
        update(n*2,s,m,t,diff);
        update(n*2+1,m+1,e,t,diff);
    }

    private static long query(int n, int s, int e, int l, int r) {
        if (s>r || l>e)
            return 0;
        if (l<=s && e<=r)
            return seg[n];
        int m = (s+e)/2;
        long q1 = query(n*2, s, m, l, r);
        long q2 = query(n*2+1, m+1, e, l, r);
        return q1 + q2;
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        StringTokenizer st = new StringTokenizer(br.readLine());
        // input
        N = Integer.parseInt(st.nextToken());
        int M = Integer.parseInt(st.nextToken()) + Integer.parseInt(st.nextToken());
        int h = (int) Math.ceil(Math.log(N) / Math.log(2));
        arr = new long[N+1];
        seg = new long[1<<(h+1)];
        for (int i = 1; i <= N; i++)
            arr[i] = Long.parseLong(br.readLine());
        // seg init
        init(1, 1, N);
        // proc
        while (M-->0) {
            st = new StringTokenizer(br.readLine());
            switch (st.nextToken()) {
                case "1" :
                    int t = Integer.parseInt(st.nextToken());
                    long v = Long.parseLong(st.nextToken());
                    long diff = v-arr[t];
                    arr[t] = v;
                    update(1, 1, N, t, diff);
                    break;
                case "2" :
                    int l = Integer.parseInt(st.nextToken());
                    int r = Integer.parseInt(st.nextToken());
                    bw.write(query(1, 1, N, l, r) + "\n");
                    break;
            }
        }
        bw.flush();
        bw.close();
        br.close();
    }
}

 

 

코드(펜윅트리 - 0인덱스 베이스) : github

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class Main {
    private long[] bit, arr;
    private int n;

    private long query(int b, int c) {
        return getPrefixSum(c) - getPrefixSum(b-1);
    }

    private long getPrefixSum(int idx) {
        long answer = 0l;
        while (idx >= 0) {
            answer += bit[idx];
            idx = (idx&(idx+1))-1;
        }
        return answer;
    }

    private void update(int idx, long val) {
        long diff = val - arr[idx];
        arr[idx] = val;

        while (idx < n) {
            bit[idx] += diff;
            idx |= idx+1;
        }
    }

    private void solution() throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        n = Integer.parseInt(st.nextToken());
        int mk = Integer.parseInt(st.nextToken()) + Integer.parseInt(st.nextToken());
        arr = new long[n];
        bit = new long[n];
        for (int i = 0; i < n; i++) {
            update(i, Long.parseLong(br.readLine()));
        }
        StringBuilder sb = new StringBuilder();
        while (mk-->0) {
            st = new StringTokenizer(br.readLine());
            switch (Integer.parseInt(st.nextToken())) {
                case 1: update(Integer.parseInt(st.nextToken())-1, Long.parseLong(st.nextToken())); break;
                case 2: sb.append(query(Integer.parseInt(st.nextToken())-1, Integer.parseInt(st.nextToken())-1)).append('\n'); break;
            }
        }
        System.out.print(sb);
    }

    public static void main(String[] args) throws Exception {
        new Main().solution();
    }
}

 

 

코드(펜윅트리 - 1인덱스 베이스) : github

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class Main {
    private long[] bit, arr;
    private int n;

    private long query(int b, int c) {
        return getPrefixSum(c) - getPrefixSum(b-1);
    }

    private long getPrefixSum(int ith) {
        long answer = 0l;
        while (ith > 0) {
            answer += bit[ith];
            ith -= ith&-ith;
        }
        return answer;
    }

    private void update(int ith, long val) {
        long diff = val - arr[ith];
        arr[ith] = val;

        while (ith <= n) {
            bit[ith] += diff;
            ith += ith&-ith;
        }
    }

    private void solution() throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        n = Integer.parseInt(st.nextToken());
        int mk = Integer.parseInt(st.nextToken()) + Integer.parseInt(st.nextToken());
        arr = new long[n+1];
        bit = new long[n+1];
        for (int i = 1; i <= n; i++) {
            update(i, Long.parseLong(br.readLine()));
        }
        StringBuilder sb = new StringBuilder();
        while (mk-->0) {
            st = new StringTokenizer(br.readLine());
            switch (Integer.parseInt(st.nextToken())) {
                case 1: update(Integer.parseInt(st.nextToken()), Long.parseLong(st.nextToken())); break;
                case 2: sb.append(query(Integer.parseInt(st.nextToken()), Integer.parseInt(st.nextToken()))).append('\n'); break;
            }
        }
        System.out.print(sb);
    }

    public static void main(String[] args) throws Exception {
        new Main().solution();
    }
}

 

 

코드(펜윅트리 - 원본 논문 그대로 구현) : github

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class Main {
    private long[] bit, arr;
    private int n;

    private long query(int b, int c) {
        return getPrefixSum(c) - getPrefixSum(b-1);
    }

    private long getPrefixSum(int ith) {
        long answer = 0l;
        while (ith > 0) {
            answer += bit[ith];
            ith = ith&(ith-1);
        }
        return answer;
    }

    private void update(int ith, long val) {
        long diff = val - arr[ith];
        arr[ith] = val;

        while (ith <= n) {
            bit[ith] += diff;
            ith += ith&-ith;
        }
    }

    private void solution() throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        n = Integer.parseInt(st.nextToken());
        int mk = Integer.parseInt(st.nextToken()) + Integer.parseInt(st.nextToken());
        arr = new long[n+1];
        bit = new long[n+1];
        for (int i = 1; i <= n; i++) {
            update(i, Long.parseLong(br.readLine()));
        }
        StringBuilder sb = new StringBuilder();
        while (mk-->0) {
            st = new StringTokenizer(br.readLine());
            switch (Integer.parseInt(st.nextToken())) {
                case 1: update(Integer.parseInt(st.nextToken()), Long.parseLong(st.nextToken())); break;
                case 2: sb.append(query(Integer.parseInt(st.nextToken()), Integer.parseInt(st.nextToken()))).append('\n'); break;
            }
        }
        System.out.print(sb);
    }

    public static void main(String[] args) throws Exception {
        new Main().solution();
    }
}

댓글