본문 바로가기
PS/BOJ

[자바] 백준 23324 - 어려운 모든 정점 쌍 최단 거리 (java)

by Nahwasa 2023. 1. 1.

 문제 : boj23324


 

필요 알고리즘 개념

  • 분리 집합 (disjoint set)
    • 그룹의 갯수를 알아야 풀 수 있으므로 내 경우엔 분리 집합 알고리즘을 사용했다.

※ 제 코드에서 왜 main 함수에 로직을 직접 작성하지 않았는지, 왜 Scanner를 쓰지 않고 BufferedReader를 사용했는지 등에 대해서는 '자바로 백준 풀 때의 팁 및 주의점' 글을 참고해주세요. 백준을 자바로 풀어보려고 시작하시는 분이나, 백준에서 자바로 풀 때의 팁을 원하시는 분들도 보시는걸 추천드립니다.

 


 

풀이

  사실 풀이는 금방 생각나긴했지만 가중치가 1과 0이라는 점과, 실제로 탐색을 통해 모든 정점에서 최단 거리를 판단할 경우 무조건 시간초과가 발생하도록 문제가 세팅되었다는 점에서 출제 아이디어가 좋다고 생각했다. 그래프쪽 공부하고 있다면 추천문제라 생각된다.

 

  예제 입력 1까지만 보고 풀어서 몰랐는데 아래쪽에 "어떤 정점 u에서 다른 정점 v까지 한 개 이상의 간선을 이용하여 항상 도달할 수 있다." 라는 힌트가 있었다. 근데 사실 저 조건이 없었어도 상관없다(도달 불가할 경우 0으로 치기만 하면 상관없다.). 결국 중요한건 가중치가 1로 구분 가능한 지점들의 그룹이 어디냐만 알면 된다.

 

 

A. 모두 가중치가 0인 간선만 존재한다면 어떨까?

  이 경우 당연히 모든 정점에서 모든 정점으로의 최단 거리는 각 0일 것이다. 모든 정점 쌍의 최단 거리의 합은 0이다.

 

 

B. 가중치가 1인 간선이 생겼다!

  이 경우에도 당연히 '최단' 거리이므로 가중치 1의 간선은 지나지 않을 것이니, 거리의 합은 0이다.

 

 

C. 가중치 1인 간선을 제거할 경우 그룹이 나누어진다면?

  아래와 같은 경우 가중치 1인 간선이 없다면 {1,2,3}과 {4,5,6} 정점들로 그룹이 나뉜다. 즉, 다르게 말하면 서로 다른 그룹끼리 도달하려면 가중치 1인 간선을 지날수밖에 없다.

 


 

  즉 문제의 핵심은 가중치가 1인 간선을 제거했을 때 2개의 그룹으로 나뉘는지 보면 된다. 이에 적합한 알고리즘은 분리 집합 알고리즘이다. 내 경우엔 union-find 알고리즘을 사용했다.

 

1. K번째 간선을 제외한 나머지 간선을 입력받으면서 u와 v를 union 시켜준다.

for (int i = 1; i <= m; i++) {
    st = new StringTokenizer(br.readLine());
    int u = Integer.parseInt(st.nextToken());
    int v = Integer.parseInt(st.nextToken());

    if (i == k) {
        a = u;
        b = v;
        continue;
    }

    union(u, v);
}

 

2. 두 그룹이 동일한 그룹이라면 'B'의 경우이므로 0을 출력하면 끝이다.

a = find(a);
b = find(b);

if (a == b) {
    System.out.println(0);
    return;
}

 

3. '2' 이외의 경우는 'C'의 경우이다. 이 경우 두 그룹에 포함된 정점끼리는 전부 최단거리가 '1'이다. 따라서 두 그룹의 모든 쌍의 개수를 세주면 되며, 이건 두 그룹에 포함된 정점의 수를 곱해주면 된다. 주의점은 N이 최대 10만이므로, 각 그룹에 포함된 정점의 곱은 최대 50000 x 50000이다. 25억으로 int로 표현 가능한 약 21억을 넘는다. 그러므로 long으로 계산해야 한다.

int aCnt = 0, bCnt = 0;
for (int i = 1; i <= n; i++) {
    int cur = find(i);
    if (cur == a)
        aCnt++;
    else if (cur == b)
        bCnt++;
}

System.out.println(1l*aCnt*bCnt);

 

  풀고보니 그냥 각 그룹의 수만 세면 되므로 DFS 또는 BFS를 써도 될 것 같다. 내가 분리 집합을 먼저 생각한 이유는 풀 때 맨아래 힌트 부분을 못봤기 때문이다 ㅋㅋ DFS, BFS로 그룹들끼리 탐색하는게 귀찮을 것 같아 분리 집합을 생각했었다.

 


 

코드 : github

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.*;

public class Main {
    private static final BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    int[] parents;

    public static void main(String[] args) throws Exception {
        new Main().solution();
    }

    private int find(int a) {
        if (parents[a] < 0)
            return a;
        return parents[a] = find(parents[a]);
    }

    private void union(int a, int b) {
        a = find(a);
        b = find(b);
        if (a == b) return;

        int hi = parents[a]<parents[b]?a:b;
        int lo = parents[a]<parents[b]?b:a;
        parents[hi] += parents[lo];
        parents[lo] = hi;
    }

    public void solution() throws Exception {
        StringTokenizer st = new StringTokenizer(br.readLine());
        int n = Integer.parseInt(st.nextToken());
        int m = Integer.parseInt(st.nextToken());
        int k = Integer.parseInt(st.nextToken());

        parents = new int[n+1];
        Arrays.fill(parents, -1);
        int a = 0, b = 0;
        for (int i = 1; i <= m; i++) {
            st = new StringTokenizer(br.readLine());
            int u = Integer.parseInt(st.nextToken());
            int v = Integer.parseInt(st.nextToken());

            if (i == k) {
                a = u;
                b = v;
                continue;
            }

            union(u, v);
        }

        a = find(a);
        b = find(b);

        if (a == b) {
            System.out.println(0);
            return;
        }

        int aCnt = 0, bCnt = 0;
        for (int i = 1; i <= n; i++) {
            int cur = find(i);
            if (cur == a)
                aCnt++;
            else if (cur == b)
                bCnt++;
        }

        System.out.println(1l*aCnt*bCnt);
    }
}

댓글