본문 바로가기
PS/BOJ

[자바] 백준 23040 - 누텔라 트리 (Easy) (java)

by Nahwasa 2023. 11. 22.

목차

    문제 : boj23040

     

     

    필요 알고리즘

    • 분리 집합 (union-find), 그래프 탐색 (bfs, dfs 등), 트리 (tree)
      • 그래프 탐색과 분리 집합에 대한 개념이 필요한 문제이다.

    ※ 제 코드에서 왜 main 함수에 로직을 직접 작성하지 않았는지, 왜 Scanner를 쓰지 않고 BufferedReader를 사용했는지 등에 대해서는 '자바로 백준 풀 때의 팁 및 주의점' 글을 참고해주세요. 백준을 자바로 풀어보려고 시작하시는 분이나, 백준에서 자바로 풀 때의 팁을 원하시는 분들도 보시는걸 추천드립니다.

     

     

    풀이

      만약 서로 붙어 있는 빨간 정점 그룹들의 갯수를 알고 있다고 해보자.

    그럼 간단하게 모든 검정 정점들에서 간선 1개로 갈 수 있는 모든 정점 중 빨간 정점 그룹의 수를 모두 세주면 끝나는 문제이다. 언제나 말로는 간단한데 구현은 녹록치 않다 ㅋㅋ 

     

    1. 입력 받으면서 간선들을 초기화해준다.

    List<Integer>[] edges = new List[n+1];
    for (int i = 1; i <= n; i++) edges[i] = new ArrayList<>();
    
    for (int i = 1; i < n; i++) {
        StringTokenizer st = new StringTokenizer(br.readLine());
        int a = Integer.parseInt(st.nextToken());
        int b = Integer.parseInt(st.nextToken());
        edges[a].add(b);
        edges[b].add(a);
    }

     

     

    2. 아무 정점이나 하나 골라서 거기서부터 트리 전체를 탐색하며 서로 붙어있는 빨간 정점들을 그룹화 해준다 (분리 집합). 이 때 union-find 알고리즘에서 union을 할 때 음수로 해당 그룹에 포함된 정점의 수를 기록해두었다.

    Queue<Integer> q = new ArrayDeque<>();
    q.add(1);
    Set<Integer> v = new HashSet<>();
    v.add(1);
    
    while (!q.isEmpty()) {
        int cur = q.poll();
        for (int next : edges[cur]) {
            if (v.contains(next)) continue;
            v.add(next);
            if (colors[cur] && colors[next])
                union(cur, next);
    
            q.add(next);
        }
    }

     

     

    3. 이제 검정 정점에서 간선 1개로 갈 수 있는 모든 빨간 그룹의 수를 결과에 더해 출력하면 된다.

    long answer = 0;
    for (int i = 1; i <= n; i++) {
        if (colors[i]) continue;
    
        v = new HashSet<>();
        for (int next : edges[i]) {
            if (!colors[next] || v.contains(find(next))) continue;
            v.add(find(next));
            answer += -parents[find(next)];
        }
    }
    System.out.println(answer);

     

     

    코드 : github

    import java.io.BufferedReader;
    import java.io.InputStreamReader;
    import java.util.*;
    
    public class Main {
    
        static BufferedReader br = new BufferedReader(new InputStreamReader(System.in), 1<<16);
    
        public static void main(String[] args) throws Exception {
            new Main().solution();
        }
    
        private int[] parents;
    
        private int find(int a) {
            if (parents[a] < 0) return a;
            return parents[a] = find(parents[a]);
        }
    
        private void union(int a, int b) {
            a = find(a);
            b = find(b);
            if (a == b) return;
    
            int hi = parents[a] < parents[b] ? a:b;
            int lo = parents[a] < parents[b] ? b:a;
            parents[hi] += parents[lo];
            parents[lo] = hi;
        }
    
        private void solution() throws Exception {
            int n = Integer.parseInt(br.readLine());
            parents = new int[n+1];
            Arrays.fill(parents, -1);
    
            List<Integer>[] edges = new List[n+1];
            for (int i = 1; i <= n; i++) edges[i] = new ArrayList<>();
    
            for (int i = 1; i < n; i++) {
                StringTokenizer st = new StringTokenizer(br.readLine());
                int a = Integer.parseInt(st.nextToken());
                int b = Integer.parseInt(st.nextToken());
                edges[a].add(b);
                edges[b].add(a);
            }
    
            boolean[] colors = new boolean[n+1];
            String tmp = br.readLine();
            for (int i = 1; i <= n; i++) {
                colors[i] = tmp.charAt(i-1)=='R';
            }
    
            Queue<Integer> q = new ArrayDeque<>();
            q.add(1);
            Set<Integer> v = new HashSet<>();
            v.add(1);
    
            while (!q.isEmpty()) {
                int cur = q.poll();
                for (int next : edges[cur]) {
                    if (v.contains(next)) continue;
                    v.add(next);
                    if (colors[cur] && colors[next])
                        union(cur, next);
    
                    q.add(next);
                }
            }
    
            long answer = 0;
            for (int i = 1; i <= n; i++) {
                if (colors[i]) continue;
    
                v = new HashSet<>();
                for (int next : edges[i]) {
                    if (!colors[next] || v.contains(find(next))) continue;
                    v.add(find(next));
                    answer += -parents[find(next)];
                }
            }
            System.out.println(answer);
        }
    }

     

    댓글