NOIP2015提高组.运输计划
题目
521. 运输计划
算法标签: 树上倍增, l c a lca lca, 前缀和, 树上差分, 二分
思路
注意到答案是具有二分性质的, 对于某个时间 m i d mid mid假设是最优答案, 小于该时间是不可以的, 但是大于该时间是可行的, 因此可以二分答案
这样就将问题转化为, 对于给定的时间
m
i
d
mid
mid, 将树中的一条边权变为
0
0
0, 所有的运输路线耗时是否
≤
m
i
d
\le mid
≤mid
可以将所有运输的路线分为两类, 一种是运输时间
≤
m
i
d
\le mid
≤mid的, 这种路线不要需要删除边
但是还有一种路线是
>
m
i
d
> mid
>mid, 对于这些路线需要找个这些路线的公共边, 将这个公共边的权值变为
0
0
0, 但是直接枚举所有的边和路线会超时, 因此需要进行优化
可以在所有路线上的边 + 1 + 1 +1, 最终结果就是公共边被加了 t t t次, t t t是大于 m i d mid mid的路线的数量, 这样就找到了这个边, 利用树上差分, 实现对每个边 + 1 +1 +1的操作
代码
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
const int N = 300010, M = N << 1, K = 19;
int n, m;
int head[N], ed[M], ne[M], w[M], idx;
int fa[N][K], depth[N], d[N];
struct Path {
int u, v, p, d;
} path[N];
int s[N];
void add(int u, int v, int val) {
ed[idx] = v, ne[idx] = head[u], w[idx] = val, head[u] = idx++;
}
void dfs(int u, int pre, int dep) {
depth[u] = dep;
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (v == pre) continue;
fa[v][0] = u;
for (int k = 1; k < K; ++k) fa[v][k] = fa[fa[v][k - 1]][k - 1];
d[v] = d[u] + w[i];
dfs(v, u, dep + 1);
}
}
int lca(int u, int v) {
if (depth[u] < depth[v]) swap(u, v);
for (int k = K - 1; k >= 0; --k) {
if (depth[fa[u][k]] >= depth[v]) {
u = fa[u][k];
}
}
if (u == v) return v;
for (int k = K - 1; k >= 0; --k) {
if (fa[u][k] != fa[v][k]) {
u = fa[u][k];
v = fa[v][k];
}
}
return fa[u][0];
}
void dfs_sum(int u, int pre) {
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (v == pre) continue;
dfs_sum(v, u);
s[u] += s[v];
}
}
bool check(int mid) {
memset(s, 0, sizeof s);
int c = 0, max_d = 0;
for (int i = 0; i < m; ++i) {
auto [u, v, p, val] = path[i];
if (val > mid) {
c++;
max_d = max(max_d, val);
s[u]++;
s[v]++;
s[p] -= 2;
}
}
if (c == 0) return true;
dfs_sum(1, -1);
for (int u = 2; u <= n; ++u) {
if (s[u] == c && max_d - (d[u] - d[fa[u][0]]) <= mid) {
return true;
}
}
return false;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
memset(head, -1, sizeof head);
cin >> n >> m;
for (int i = 0; i < n - 1; ++i) {
int u, v, w;
cin >> u >> v >> w;
add(u, v, w), add(v, u, w);
}
dfs(1, -1, 1);
for (int i = 0; i < m; ++i) {
int u, v;
cin >> u >> v;
int p = lca(u, v);
int dis = d[u] + d[v] - 2 * d[p];
path[i] = {u, v, p, dis};
}
int l = 0, r = 3e8;
while (l < r) {
int mid = l + r >> 1;
if (check(mid)) r = mid;
else l = mid + 1;
}
cout << l << "\n";
return 0;
}
* v e c t o r vector vector存邻接表会超时
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
typedef pair<int, int> PII;
const int N = 300010, M = N << 1, K = 19;
int n, m;
vector<PII> head[N];
int fa[N][K], depth[N], d[N];
struct Path {
int u, v, p, d;
};
vector<Path> path;
int s[M];
void init() {
path.resize(m + 1);
}
void add(int u, int v, int w) {
head[u].push_back({v, w});
}
void dfs(int u, int pre, int dep) {
depth[u] = dep;
for (auto [v, w] : head[u]) {
if (v == pre) continue;
fa[v][0] = u;
for (int k = 1; k < K; ++k) fa[v][k] = fa[fa[v][k - 1]][k - 1];
d[v] = d[u] + w;
dfs(v, u, dep + 1);
}
}
int lca(int u, int v) {
if (depth[u] < depth[v]) swap(u, v);
for (int k = K - 1; k >= 0; --k) {
if (depth[fa[u][k]] >= depth[v]) {
u = fa[u][k];
}
}
if (u == v) return u;
for (int k = K - 1; k >= 0; --k) {
if (fa[u][k] != fa[v][k]) {
u = fa[u][k];
v = fa[v][k];
}
}
return fa[u][0];
}
void dfs_sum(int u, int fa) {
for (auto [v, w] : head[u]) {
if (v == fa) continue;
dfs_sum(v, u);
s[u] += s[v];
}
}
bool check(int mid) {
memset(s, 0, sizeof s);
int cnt = 0, max_d = 0;
for (auto [u, v, p, dis] : path) {
if (dis > mid) {
cnt++;
s[u]++;
s[v]++;
s[p] -= 2;
max_d = max(max_d, dis);
}
}
if (cnt == 0) return true;
dfs_sum(1, -1);
for (int u = 2; u <= n; ++u) {
if (s[u] == cnt && max_d - (d[u] - d[fa[u][0]]) <= mid) return true;
}
return false;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
cin >> n >> m;
init();
for (int i = 0; i < n - 1; ++i) {
int u, v, w;
cin >> u >> v >> w;
add(u, v, w), add(v, u, w);
}
dfs(1, -1, 1);
for (int i = 0; i < m; ++i) {
int u, v;
cin >> u >> v;
int p = lca(u, v);
path[i] = {u, v, p, d[u] + d[v] - 2 * d[p]};
}
int l = 0, r = 3e8;
while (l < r) {
int mid = l + r >> 1;
if (check(mid)) r = mid;
else l = mid + 1;
}
cout << l << "\n";
return 0;
}