BZOJ 3876 支线剧情 | 有下界费用流
题意
这题题面搞得我看了半天没看懂……是这样的,原题中的“剧情”指的是边,“剧情点”指的才是点。
题面翻译过来大概是这样:
有一个DAG,每次从1号点出发,走过一条路径,再瞬移回1号点。问:想要遍历所有的边,至少要走多少路程(瞬移回1号点不算路程)。
题解
我们用有上下界费用流的模型,建个图:
- 原图中的每条边,流量范围是\([1, +\infty]\),表示至少走一次,可以走无限次,这条边的费用就是边权。
- 原图中的每个点(除1号点外)向1号点连一条边,流量范围是\([0, +\infty]\),费用为0,表示任意节点随时可以回到1号节点。
在这个图上求一个最小费用最小流即可。
那么我们再用上下界网络流的套路给这个图改成正常的有源汇网络流:
- 对于原图中的每条边\(u \to v\)(边权为\(w\)),建边\((u, v, +\infty , w), (S, v, 1, w)\);
- 对于每个出度为\(t\)的点\(u\),建边\((u, T, t, 0)\);
- 对于每个非1的点\(u\),建边\((u, 1, +\infty, 0)\)。
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;
typedef long long ll;
#define enter putchar('\n')
#define space putchar(' ')
template <class T>
void read(T &x){
char c;
bool op = 0;
while(c = getchar(), c > '9' || c < '0')
if(c == '-') op = 1;
x = c - '0';
while(c = getchar(), c >= '0' && c <= '9')
x = x * 10 + c - '0';
if(op) x = -x;
}
template <class T>
void write(T x){
if(x < 0) putchar('-'), x = -x;
if(x >= 10) write(x / 10);
putchar('0' + x % 10);
}
const int N = 305, M = 2000005, INF = 0x3f3f3f3f;
int n, src, des;
int ecnt = 1, adj[N], pre[N], dis[N], go[M], nxt[M], cap[M], cost[M];
queue <int> que;
bool inq[N];
void ADD(int u, int v, int _cap, int _cost){
go[++ecnt] = v;
nxt[ecnt] = adj[u];
adj[u] = ecnt;
cap[ecnt] = _cap;
cost[ecnt] = _cost;
}
void add(int u, int v, int _cap, int _cost){
ADD(u, v, _cap, _cost);
ADD(v, u, 0, -_cost);
}
bool spfa(){
for(int i = 1; i <= des; i++)
dis[i] = INF, pre[i] = 0;
dis[src] = 0, que.push(src), inq[src] = 1;
while(!que.empty()){
int u = que.front();
que.pop(), inq[u] = 0;
for(int e = adj[u], v; e; e = nxt[e])
if(cap[e] && dis[v = go[e]] > dis[u] + cost[e]){
dis[v] = dis[u] + cost[e], pre[v] = e;
if(!inq[v]) que.push(v), inq[v] = 1;
}
}
return pre[des] != 0;
}
int mcmf(){
int ret = 0;
while(spfa()){
int flow = INF;
for(int e = pre[des]; e; e = pre[go[e ^ 1]])
flow = min(flow, cap[e]);
for(int e = pre[des]; e; e = pre[go[e ^ 1]])
cap[e] -= flow, cap[e ^ 1] += flow;
ret += flow * dis[des];
}
return ret;
}
int main(){
read(n), src = n + 1, des = src + 1;
for(int u = 1, t; u <= n; u++){
read(t);
for(int i = 1, v, w; i <= t; i++)
read(v), read(w), add(src, v, 1, w), add(u, v, INF, w);
add(u, des, t, 0);
if(u != 1) add(u, 1, INF, 0);
}
write(mcmf()), enter;
return 0;
}