본문 바로가기
PS/BOJ

[자바] 백준 1311 - 할 일 정하기 1 (java)

by Nahwasa 2023. 8. 3.

목차

    문제 : boj1311

     

     

    필요 알고리즘

    • bit DP (비트 필드를 이용한 다이나믹 프로그래밍)
      • 비트 단위 DP로 푼 문제이다.

    ※ 제 코드에서 왜 main 함수에 로직을 직접 작성하지 않았는지, 왜 Scanner를 쓰지 않고 BufferedReader를 사용했는지 등에 대해서는 '자바로 백준 풀 때의 팁 및 주의점' 글을 참고해주세요. 백준을 자바로 풀어보려고 시작하시는 분이나, 백준에서 자바로 풀 때의 팁을 원하시는 분들도 보시는걸 추천드립니다.

     

     

    풀이

      문제를 좀 더 간소화시켜서, N이 항상 3이라고 해보자. 이 때 dp[A][B] 를 A번 사람을 B라는 일의 조합으로 진행한 경우 비용의 최솟값이라고 정의하자.

     

      이 경우 B는 3개의 일을 가지고 할 수 있는 모든 경우를 나타낼 수 있으면 된다.

    0. 아무일도 안한 경우

    1. 1번일을 한 경우

    2. 2번일을 한 경우

    3. 3번일을 한 경우

    4. 1번과 2번 일을 한 경우

    5. 2번과 3번 일을 한 경우

    6. 1번과 3번 일을 한 경우 

    7. 1,2,3번 일을 모두 한 경우

     

      그럼 우선 dp[][] 를 모두 무한대 값으로 초기화 시켜두고, dp[0][0] = 0 으로 시작할꺼다.

     

    3
    2 3 3
    3 2 3
    3 3 2

      예제 1 을 가지고 살펴보자. A=1 까지 dp[A][B]를 그려보면 이하와 같다. INF는 무한대 값임을 뜻한다. 여기까지는 해석해보면 A=1일 때, 1번 사람이 1번 일을 하면 2, 2번 일을 하면 3, 3번 일을 하면 3이 비용의 최솟값이란 의미이다. 4,5,6,7번 조합은 A=1일 경우 INF이므로 불가한 셈이다. 

     

      그 다음 A=2, A=3도 보면 다음과 같을꺼다. A=2는 즉 2 사람이 1번 사람이 이미 진행 한 것 기준으로, 1번과 2번 일을 하는 경우 4가 최솟값 (= min( dp[1][1] + 2번사람이 2번일 하는 비용, dp[1][2] + 2번 사람이 1번일 하는 비용) 이 되는 식이다. 마지막으로 A=3인 경우는 마찬가지로 2번 사람이 2개의 일을 한 경우들에 3번 사람이 남는 일 하나를 한 경우들 중의 최솟값이 된다.

     

      일단 위의 내용을 이해했다면 다 이해한거다!

    위의 설명은 결국 간소화해서 확인한건데, 실제론 저 모든 조합을 N이 변경됨에 따라 직접 구하긴 상당히 어렵다. N=2일 땐 0. 아무일도 안함 / 1. 1번일함 / 2. 2번일함 / 3.1,2번일 함 이런식으로 될꺼다. 이걸 N=20까지 전부 미리 생각해두긴 비효율적이다.

     

      근데 잘 생각해보면 N의 최대치는 20이고, int값은 총 32개의 bit가 있다. 따라서 각 비트를 가지고 모든 조합을 나타낼 수 있다. 예를들어 N=3일 때 아래처럼 나타내볼 수 있다.

    00000...000 = 아무일도 안함

    00000...001 = 1번일 함

    00000...101 = 1,3번일 함

    00000...111 = 1,2,3번일 함

     

      이런식으로 bit를 이용해 조합을 나타내는 것과 위에서 0부터 7번까지 모든 조합을 나타낸 것에 차이점은 없다! 결국 모든 경우의 수를 나타낼 수만 있으면 되는거니깐. 그러니 위의 표를 가지고 설명한게 이해됬다면, 그대로 모든 조합을 미리 생각해두는 것 대신 단순히 int 하나를 가지고 bit로 표현해주면 그걸로 풀이가 된다. 다만 bit를 가지고 놀다보니 코드가 이해하기 힘들 수 있다. 코드 풀이는 이하 주석으로 달아두었다.

     

    private void solution() throws Exception {
        int n = Integer.parseInt(br.readLine());
        int[][] arr = new int[n+1][n+1];
        for (int i = 1; i <= n; i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            for (int j = 1; j <= n; j++) {
                arr[i][j] = Integer.parseInt(st.nextToken());
            }
        }
    
        int[][] dp = new int[n+1][1<<(n+1)];	// dp[A][B]
        for (int[] row : dp) Arrays.fill(row, INF);	// INF로 전체 초기화
        dp[0][0] = 0;	// 얘기한대로 dp[0][0]만 0으로 초기화
    
        for (int i = 1; i <= n; i++) {	// A=1부터 A=N까지 증가시키면서
    
            for (int j = 0; j < 1<<(n+1); j++) {	// 직전 사람이 진행한 조합을 확인하며
                if (dp[i-1][j] == INF) continue;	// 직전 값이 무한이면 무시
    
                for (int k = 1; k <= n; k++) {	// A=i 일 때 k번 일을 할꺼임.
                    if ((j&(1<<k)) != 0) continue;	// 이미 bit에 k번 일을 한 경우 무시
                    dp[i][j|(1<<k)] = Math.min(dp[i][j|(1<<k)], dp[i-1][j] + arr[i][k]);
                    // j라는 조합에 k라는 일을 더한 조합의 최소값 구하기.
                }
            }
    
        }
    
        System.out.println(dp[n][(1<<(n+1))-2]);	
        // 최종적으로 답은 dp[n][모든 일을 한 경우(즉 n개의 비트가 모두 켜진 경우)]
    }

      

     

    코드 : github

    import java.io.BufferedReader;
    import java.io.InputStreamReader;
    import java.util.Arrays;
    import java.util.StringTokenizer;
    
    public class Main {
    
        static BufferedReader br = new BufferedReader(new InputStreamReader(System.in), 1<<16);
        static final int INF = 20*10000+1;
    
        public static void main(String[] args) throws Exception {
            new Main().solution();
        }
    
        private void solution() throws Exception {
            int n = Integer.parseInt(br.readLine());
            int[][] arr = new int[n+1][n+1];
            for (int i = 1; i <= n; i++) {
                StringTokenizer st = new StringTokenizer(br.readLine());
                for (int j = 1; j <= n; j++) {
                    arr[i][j] = Integer.parseInt(st.nextToken());
                }
            }
    
            int[][] dp = new int[n+1][1<<(n+1)];
            for (int[] row : dp) Arrays.fill(row, INF);
            dp[0][0] = 0;
    
            for (int i = 1; i <= n; i++) {
    
                for (int j = 0; j < 1<<(n+1); j++) {
                    if (dp[i-1][j] == INF) continue;
    
                    for (int k = 1; k <= n; k++) {
                        if ((j&(1<<k)) != 0) continue;
                        dp[i][j|(1<<k)] = Math.min(dp[i][j|(1<<k)], dp[i-1][j] + arr[i][k]);
                    }
                }
    
            }
    
            System.out.println(dp[n][(1<<(n+1))-2]);
        }
    }

     

    댓글