비트마스킹으로 상태를 정의하면서 bfs를 돌면 된다.
도는 중간에 켜져있는 발전소 개수가 P이상일 때만 저장된 정답과 비교하면서 돌면 끝이다.
struct Data
{
int bit, cost, cnt;
bool operator<(const Data r)const{
if(cnt==r.cnt)
{
return cost>r.cost;
}
return cnt>r.cnt;
}
};
구조체를 이렇게 정의해서 어떤 발전소들이 켜져 있는지 저장하면 bit, 현재까지의 비용 cost, 켜져 있는 발전소의 개수 cnt
이렇게 3가지 정보를 이용한다.
priority_queue<Data>pw;
void solve(int start,int stcnt)
{
pw.push({start,0,stcnt});
while(!pw.empty())
{
Data t = pw.top();
pw.pop();
int nbit = t.bit;
int ncost = t.cost;
int ncnt = t.cnt;
if(visit[nbit]) continue;
visit[nbit]=1;
if(ncnt>=P)
{
if(ans>ncost) ans = ncost;
continue;
}
int tar = nbit;
vector<int>on,off;
for(int x=1;x<=N;x++)
{
if(tar%2==1) on.push_back(x);
else off.push_back(x);
tar/=2;
}
for(int i=0;i<on.size();i++)
{
for(int j=0;j<off.size();j++)
{
int nextbit = nbit+(1<<(off[j]-1));
pw.push({nextbit,ncost+cost[on[i]][off[j]],ncnt+1});
}
}
}
}
이렇게 해서 풀면 된다.
비트를 받아서 on과 off 벡터에 켜져있는 발전소와 꺼져있는 발전소들을 분류하고,
이중 for문으로 켜져있는 발전소 -> 꺼져있는 발전소로 돌면서 다음 상태와 비용을 정의해주면 된다.
#include<stdio.h>
#include<string>
#include<iostream>
#include<queue>
#include<vector>
using namespace std;
int N,P;
int cost[20][20];
int states[100004];
int visit[100004];
int ans = 100000000;
struct Data
{
int bit, cost, cnt;
bool operator<(const Data r)const{
if(cnt==r.cnt)
{
return cost>r.cost;
}
return cnt>r.cnt;
}
};
priority_queue<Data>pw;
void solve(int start,int stcnt)
{
pw.push({start,0,stcnt});
while(!pw.empty())
{
Data t = pw.top();
pw.pop();
int nbit = t.bit;
int ncost = t.cost;
int ncnt = t.cnt;
if(visit[nbit]) continue;
visit[nbit]=1;
if(ncnt>=P)
{
if(ans>ncost) ans = ncost;
continue;
}
int tar = nbit;
vector<int>on,off;
for(int x=1;x<=N;x++)
{
if(tar%2==1) on.push_back(x);
else off.push_back(x);
tar/=2;
}
for(int i=0;i<on.size();i++)
{
for(int j=0;j<off.size();j++)
{
int nextbit = nbit+(1<<(off[j]-1));
pw.push({nextbit,ncost+cost[on[i]][off[j]],ncnt+1});
}
}
}
}
int main()
{
scanf("%d",&N);
for(int i=1;i<=N;i++)
{
for(int j=1;j<=N;j++)
{
scanf("%d",&cost[i][j]);
}
}
string start;
cin >> start;
int stbit=0,stcnt=0;
for(int i=0;i<start.size();i++)
{
if(start[i]=='Y')
{
stcnt++;
stbit += 1<<i;
}
}
scanf("%d",&P);
solve(stbit,stcnt);
if(ans==100000000)
{
printf("-1");
return 0;
}
printf("%d",ans);
}
전체 코드이다.