可撤销并查集,原理分析,题目练习
零、写在前面
可撤销并查集代码相对简单,但是使用场景往往比较复杂,经常用于处理离线查询,比较经典的应用是结合线段树分治维护动态连通性问题。在一些较为综合的图论问题中也经常出现。
前置知识:并查集,扩展域并查集,带边权并查集详解,OJ练习,详细代码
一、可撤销并查集
1.1 可撤销并查集
可撤销并查集(Revertible Disjoint Union)是一种支持撤销上一次合并操作的并查集。
1.2 启发式合并
可撤销并查集无法使用路径压缩优化,只能使用启发式合并,因而可撤销并查集的查询复杂度是O(logn)的
1.3 数据定义
constexpr int N = 2E5; // 节点数
int f[N]; // 父节点
int siz[N]; // 集合size
std::stack<std::pair<int, int>> his; // 操作记录栈
1.4 查询祖先
暴力上跳即可,树高为O(logn),因而时间复杂度为O(logn)
int find(int x) {while (f[x] != x) {x = f[x]; }return x;
}
1.5 合并操作
- 如果 x, y 在同一集合,合并失败 返回false
- 否则 取 x 为大树根节点(启发式合并)
- 记录 merge(x, y) 操作到his
- siz[x] += siz[y], f[y] = x
- 合并成功返回 true
bool merge(int x, int y) {x = find(x);y = find(y);if (x == y) {return false;}if (siz[x] < siz[y]) {std::swap(x, y);}his.emplace_back(x, y);siz[x] += siz[y];f[y] = x;return true;
}
1.6 撤销操作
- 因为栈顶始终保存的是我们上一次合并操作
- 所以回滚到上一次的状态只需弹出上次 merge(x, y) 的 x, y,并恢复 f[y] 和 siz[x] 即可
- 进一步,如果 想要回滚到 tm时刻,我们只需不断的弹栈恢复,直至 his 栈的size = tm
void revert(int tm) {while (his.size() > tm) {auto [x, y] = his.back();his.pop_back();f[y] = y;siz[x] -= siz[y];}
}
1.7 基本模板
因为 需要可撤销并查集的题目往往需要对维护的信息做修改,所以我一般选择手写,不过这里还是放一个泛用性强一点的模板。
struct DSU {std::vector<int> siz;std::vector<int> f;std::vector<std::pair<int, int>> his;DSU(int n) : siz(n, 1), f(n) {std::iota(f.begin(), f.end(), 0);}int find(int x) {while (f[x] != x) {x = f[x];}return x;}bool merge(int x, int y) {x = find(x);y = find(y);if (x == y) {return false;}if (siz[x] < siz[y]) {std::swap(x, y);}his.emplace_back(x, y);siz[x] += siz[y];f[y] = x;return true;}int time() {return his.size();}void revert(int tm) {while (his.size() > tm) {auto [x, y] = his.back();his.pop_back();f[y] = y;siz[x] -= siz[y];}}int size(int x) {return siz[find(x)];}
};
二、题目练习
2.1 Ex - Ball Collector
原题链接
Ex - Ball Collector
思路分析
题意就是说一颗以1为根的树,每个节点有两个数字,求出 2~n 每个节点的 最小代价
最小代价的定义为:1 ~ v 路径上每个节点拿一个数,能拿的最小种类数。
我们单看对于 1~v 这条路径的代价怎么求?
我们把每个数字看作新图节点,每个树节点的 A[v], B[v] 相当于 对 A[v] 和 B[v] 连边,那么我们对1~v路径建图得到的新图会有若干个连通块。
我们发现每条边其实就是选数操作,所以答案不会超过边数
又因为节点数目一共就那么多,所以每个连通块的贡献就是 min(V, E),其中 V 为节点数,E 为边数
每条路径暴力计算是O(N^2),我们发现我们dfs 暴力计算的问题就在于:回溯的时候如何删边删贡献?
可撤销并查集可以帮我们轻松解决,dfs 进入一个节点,我们merge(u, v),记录当前时刻 t
回溯时,我们revert 到 t 时刻即可。
现在考虑如何维护 min(V, E)
我们发现 一个连通块如果是树,那么答案就是 E(即 V - 1),否则就是 V
除了 可撤销DSU 本身维护的 siz 和 f 数组外,我们额外维护一个 loop[] 数组,代表该连通块环的数目
对于merge(x, y),我们除了维护f[] 和 siz[] 还要合并两个连通块的 loop
对于x 和 y 在同一连通块的情况,如果 当前连通块无环,说明我们原先是树,我们 e[x] = 1后返回true,否则返回false
对于x 和 y 不在同一连通块的情况,如果x 和 y 所在连通块都有环,那么我们不合并,返回false。因为合并无用
否则 正常合并,并且 loop[x] += loop[y]
撤销的时候相应撤销即可。
AC代码
时间复杂度:O(nlogn),log 是可撤销DSU的代价
#include <bits/stdc++.h>using i64 = long long;constexpr int N = 2E5;int f[N], siz[N], e[N];
std::stack<std::pair<int &, int>> his;int find(int x) {while (f[x] != x) {x = f[x]; }return x;
}void change(int &x, int y) {his.emplace(x, x);x = y;
}bool merge(int x, int y) {x = find(x);y = find(y);if (x == y) {if (!e[x]) {change(e[x], 1);return true;} else {return false;}}if (e[x] && e[y]) {return false;}if (siz[x] < siz[y]) {std::swap(x, y);}change(siz[x], siz[x] + siz[y]);change(e[x], e[x] + e[y]);change(f[y], x);return true;
}int time() {return his.size();
}void revert(int tim) {while(his.size() > tim) {auto [x, y] = his.top();his.pop();x = y;}
}int main() {std::ios::sync_with_stdio(false);std::cin.tie(nullptr);int N;std::cin >> N;std::iota(f, f + N, 0);std::fill(siz, siz + N, 1);std::vector<int> A(N), B(N);for (int i = 0; i < N; ++ i) {std::cin >> A[i] >> B[i];-- A[i], -- B[i];}std::vector<std::vector<int>> adj(N);for (int i = 0; i + 1 < N; ++ i) {int u, v;std::cin >> u >> v;-- u, -- v;adj[u].push_back(v);adj[v].push_back(u);}std::vector<int> ans(N);auto dfs = [&](auto &&self, int x, int p, int res) -> void{int t = his.size();res += merge(A[x], B[x]);ans[x] = res;for (int y : adj[x]) {if (y != p) {self(self, y, x, res);}}revert(t);};dfs(dfs, 0, -1, 0);for (int i = 1; i < N; ++ i) {std::cout << ans[i] << " \n"[i + 1 == N];}return 0;
}
2.2 C. Envy
原题链接
C. Envy
思路分析
首先,对于Kruscal算法,我们无论按怎样的顺序 合并一组边,我们得到的图的连通性是一样的(只可能”形状“不同,但连通性一定相同)
对于所有MST,某种权值的边的数目是一定的(如果可以被替换为更小或更大的,那么该树不是MST)
对于一组查询,如果存在某个MST 包含这些边,我们可以这样check:
- 模拟Kruscal算法
- 当前合并完了权值小于 x 的边权,检查完了查询中权值为x的边权
- 如果查询中所有边权为x 的边都合并成功,那么说明这些边是必须的,一定存在某个MST包含这些边,我们合并完剩余边权为x的边后继续往后check
- 否则查询非法
对于单个查询的处理复杂度是 O(mlogn) 即Kruscal的复杂度
对于多组查询,我们考虑离线,仍然按边权模拟Kruscal
不过不同的是,我们把普通并查集换成可撤销并查集,check完一组边后,我们回退到check前的时间戳
check完所有查询中该边权的边后,我们再去正常的合并原图中边权为x的边
时间复杂度:O(qlogk + V + mlogn)
AC代码
#include <bits/stdc++.h>using i64 = long long;struct DSU {std::vector<int> siz;std::vector<int> f;std::vector<std::pair<int, int>> his;DSU(int n) : siz(n, 1), f(n) {std::iota(f.begin(), f.end(), 0);}int find(int x) {while (f[x] != x) {x = f[x];}return x;}bool merge(int x, int y) {x = find(x);y = find(y);if (x == y) {return false;}if (siz[x] < siz[y]) {std::swap(x, y);}his.emplace_back(x, y);siz[x] += siz[y];f[y] = x;return true;}int time() {return his.size();}void revert(int tm) {while (his.size() > tm) {auto [x, y] = his.back();his.pop_back();f[y] = y;siz[x] -= siz[y];}}int size(int x) {return siz[find(x)];}
};int main() {std::ios::sync_with_stdio(false);std::cin.tie(nullptr);int n, m;std::cin >> n >> m;std::vector<int> u(m), v(m), w(m);for (int i = 0; i < m; ++ i) {std::cin >> u[i] >> v[i] >> w[i];-- u[i], -- v[i];}const int V = std::ranges::max(w) + 1;std::vector<std::vector<int>> e(V);for (int i = 0; i < m; ++ i) {e[w[i]].push_back(i);}std::vector<std::vector<std::pair<int, std::vector<int>>>> g(V);int q;std::cin >> q;for (int i = 0; i < q; ++ i) {int k;std::cin >> k;std::vector<int> a(k);for (int j = 0; j < k; ++ j) {std::cin >> a[j];-- a[j];}std::ranges::sort(a, {}, [&](int i){return w[i];});for (int l = 0, r = 0; l < k; l = r) {while (r < k && w[a[l]] == w[a[r]]) {++ r;}g[w[a[l]]].emplace_back(i, std::vector<int>(a.begin() + l, a.begin() + r));}}DSU dsu(n);std::vector<bool> ans(q, true);for (int x = 1; x < V; ++ x) {for (auto &[i, a] : g[x]) {int tim = dsu.time();std::vector<std::pair<int, int>> p;p.reserve(a.size());for (int j : a) {p.emplace_back(u[j], v[j]);}for (auto &[X, Y] : p) {if (!dsu.merge(X, Y)) {ans[i] = false;}}dsu.revert(tim);}for (int i : e[x]) {dsu.merge(u[i], v[i]);}}for (int i = 0; i < q; ++ i) {std::cout << (ans[i] ? "YES\n" : "NO\n");}return 0;
}
2.3 C. Team-Building
原题链接
C. Team-Building
思路分析
显然可以用扩展域并查集来判断二分图
考虑先去掉不满足二分图的组
对于剩下的组数记为cnt,那么有 cnt * (cnt - 1) / 2 对
我们只需计算不满足二分图的配对
直觉上似乎对数很多,没法算,事实上,我们从边的角度来分析的话,其实需要计算的对数很少。
对于一条边,他连接的两个组别已经确定了,所以我们按照边连接的 组的编号对进行排序,只需处理按照连接的组的编号对分段处理即可。
时间复杂度:O(m(logm + logn))
AC代码
#include <bits/stdc++.h>using i64 = long long;struct DSU {std::vector<int> siz;std::vector<int> f;std::vector<std::pair<int, int>> his;DSU(int n) : siz(n, 1), f(n) {std::iota(f.begin(), f.end(), 0);}int find(int x) {while (f[x] != x) {x = f[x];}return x;}bool merge(int x, int y) {x = find(x);y = find(y);if (x == y) {return false;}if (siz[x] < siz[y]) {std::swap(x, y);}his.emplace_back(x, y);siz[x] += siz[y];f[y] = x;return true;}int time() {return his.size();}void revert(int tm) {while (his.size() > tm) {auto [x, y] = his.back();his.pop_back();f[y] = y;siz[x] -= siz[y];}}int size(int x) {return siz[find(x)];}bool same(int x, int y) {return find(x) == find(y);}
};int main() {std::ios::sync_with_stdio(false);std::cin.tie(nullptr);int n, m, k;std::cin >> n >> m >> k;std::vector<int> c(n);for (int i = 0; i < n; ++ i) {std::cin >> c[i];-- c[i];}DSU dsu(2 * n);std::vector<int> a(m), b(m);std::vector<bool> good(k, true);for (int i = 0; i < m; ++ i) {std::cin >> a[i] >> b[i];-- a[i], -- b[i];if (c[a[i]] > c[b[i]]) {std::swap(a[i], b[i]);}if (c[a[i]] == c[b[i]]) {dsu.merge(a[i], b[i] + n);dsu.merge(a[i] + n, b[i]);if (dsu.same(a[i], b[i]) || dsu.same(a[i] + n, b[i] + n)) {good[c[a[i]]] = false;}}}int cnt = std::reduce(good.begin(), good.end(), 0);i64 ans = cnt * (cnt - 1LL) / 2;std::vector<int> p(m);std::ranges::iota(p, 0);std::ranges::sort(p, {}, [&](int i){return std::pair(c[a[i]], c[b[i]]);});for (int i = 0, j = 0; i < m; i = j) {while (j < m && c[a[p[i]]] == c[a[p[j]]] && c[b[p[i]]] == c[b[p[j]]]) ++ j;if (c[a[p[i]]] != c[b[p[i]]] && good[c[a[p[i]]]] && good[c[b[p[i]]]]) {int tim = dsu.time();for (int x = i; x < j; ++ x) {dsu.merge(a[p[x]], b[p[x]] + n);dsu.merge(a[p[x]] + n, b[p[x]]);if (dsu.same(a[p[x]], b[p[x]]) || dsu.same(a[p[x]] + n, b[p[x]] + n)) {-- ans;break;}}dsu.revert(tim);}}std::cout << ans << "\n";return 0;
}