단순한 문제인데 꽤 어려웠다.
N개를 배열하면 N!가지가 있어서 일일히 모든 경우를 따져보는 것은 불가능하다.
그런데 정확한 경우의 수를 세려면 모든 경우를 따져야 한다.
모든 경우를 나눠보지 않고 어떻게 나눠지는 경우의 수를 셀 수 있을지 생각해 내기가 어려운 문제였다.
결론적으로 방법은 dp를 쓰면 된다.
N만 놓고 본다면 15! 이어서 모든 경우를 셀 수 없는 것이 맞는데
이 문제는 K가 100이하여서 풀리는 문제이다.
N개의 숫자들 중에 어떤 숫자를 사용했는지를 비트마스킹으로 표현을 하면
dp[i][j]는 비트가 i이고, K로 나눈 나머지가 j인 경우의 수이다.
이렇게 하면 2^15 * 100 크기의 dp 배열을 채우면 되는 문제가 돼서 시간 안에 풀 수 있게 된다.
void solv()
{
dp[0][0] = 1;
for(int cur = 0;cur<(1<<N);cur++)
{
for(int i=0;i<N;i++)
{
int next = cur | (1<<i);
if(next!=cur)
{
for(int j=0;j<K;j++)
{
int nextR1 = (j*digits[nums[i].size()])%K;
int nextR2 = numremains[i];
int nextR = (nextR1 + nextR2)%K;
dp[next][nextR] += dp[cur][j];
}
}
}
}
}
핵심 풀이 코드이다.
집합 내에서 사용한 숫자들을 1로 표시해 만든 비트를 cur로 생각하고
모든 경우 2^N을 전부 돌면서 i번 숫자를 추가해 나가는 방식이다.
이때 i번 숫자가 이미 쓴 숫자면 안되니까 next != cur 일때만 처리한다.
그리고 j를 0부터 K-1까지 돌면서 nextR을 계산해준다.
여기서 digits[x]는 10^x을 K로 나눈 나머지로 미리 계산해서 넣어놓는다.
numremains는 i번째 숫자를 K로 나눈 나머지를 미리 계산해서 넣어놓은 결과이다.
그러면 나머지 j에 10^(i의 자릿수) 를 곱한 값에 i번째 수를 K로 나눈 나머지를 더해주면 된다.
예를 들어 (~~~)가 있을 때 (~~~)(~~~~~) 를 K로 나눈 나머지는 (~~~)를 K로 나눈 나머지에 10^5을 K로 나눈 나머지를 곱하고,
거기에 (~~~~~)를 K로 나눈 나머지를 더해서 계산하는 방식이다.
그리고 숫자가 최대 50자리나 되기 때문에 숫자로는 계산을 할 수가 없어서 나머지 계산을 문자열로 해줘야 한다.
int div(string x)
{
int res = 0;
for(int i=0;i<x.size();i++)
{
res *= 10;
res += x[i]-'0';
res %= K;
}
return res;
}
여기서 좀 해맸던 부분이 res += (x[i]-'0')%K 이렇게 하면 이상한 값이 나와서 두 줄에 걸쳐 작성해야 한다.
디버깅하면서 찾아냈던 부분인데 찾느라 꽤 오래걸렸다.
#include<stdio.h>
#include<vector>
#include<string>
#include<iostream>
using namespace std;
int N,K;
vector<string>nums;
long long dp[40004][104];
int digits[104];
int numremains[20];
int div(string x)
{
int res = 0;
for(int i=0;i<x.size();i++)
{
res *= 10;
res += x[i]-'0';
res %= K;
}
return res;
}
void solv()
{
dp[0][0] = 1;
for(int cur = 0;cur<(1<<N);cur++)
{
for(int i=0;i<N;i++)
{
int next = cur | (1<<i);
if(next!=cur)
{
for(int j=0;j<K;j++)
{
int nextR1 = (j*digits[nums[i].size()])%K;
int nextR2 = numremains[i];
int nextR = (nextR1 + nextR2)%K;
dp[next][nextR] += dp[cur][j];
}
}
}
}
}
long long fac(int n)
{
if(n==1) return 1;
return n*fac(n-1);
}
long long gcd(long long a, long long b)
{
if(a<b)
{
long long tmp = a;
a = b;
b = tmp;
}
if(b==0) return a;
return gcd(a%b,b);
}
int main()
{
scanf("%d",&N);
for(int i=0;i<N;i++)
{
string x;
cin >> x;
nums.push_back(x);
}
scanf("%d",&K);
digits[0]=1;
digits[1]=10%K;
for(int i=2;i<=50;i++)
{
digits[i] = (digits[i-1]*10)%K;
}
for(int i=0;i<N;i++)
{
numremains[i] = div(nums[i]);
}
solv();
long long ans = dp[(1<<N)-1][0];
long long g = gcd(ans,fac(N));
printf("%lld",ans/g);
printf("/");
printf("%lld",fac(N)/g);
}
전체 코드이다.