C++数据结构 —— 平衡树Treap
文章目录
- 一、简介
- 二、具体实现
- 0. 小Tips
- 1. 右旋-左旋
- 2. 关键函数
- 三、完整代码及测试
- 1. 头文件
- 2. 测试主函数
一、简介
Treap是由二叉搜索树(BST—“Binary Search Tree”)和堆(Heap)结合而成的数据结构。
(补充BST:当前节点左子树中的任何一个节点的权值均严格小于当前节点的权值,右子树中的任何一个节点的权值均严格大于当前节点的权值。其中序遍历的结果是从小到大的有序序列。)
本质:动态维护一个有序序列
当一组有序数据顺序插入BST时,整棵树会退化成一条链,复杂度会从 O ( l o g n ) O(log\ n) O(log n)退化到 O ( n ) O(n) O(n)。为了避免这种情况的发生,尽可能使不管插入什么样的数据,树的高度均维持在 l o g n log\ n log n,可以考虑给每个节点引入一个随机的key
值,然后根据key
值重构树为一个堆(同时val
仍满足BST性质)。
这样每次在插入删除数据时,动态调整数的结构,使其维持“平衡”。
Treap的相关操作
- 插入一个数
- 删除一个数
- 找前驱/后继
- 找最值
可拓展的操作
(算法竞赛中可能会出现)
- 求某个值的排名
- 求排名为k的数
- 求比某数小的最大数
- 求比某数大的最小数
二、具体实现
定义Treap类:
#pragma once#include <functional>
#include <random>
#include <limits>template <typename T, typename Compare = std::less<T>>
class Treap {
private:struct Node {T key; // T类型的值int priority; // 用于堆调整的优先级size_t size; // 记录以该节点为根的子树大小(快速求解排序时可用)Node* left; // 左儿子节点Node* right; // 右儿子节点Node(const T& k, int p) : key(k), priority(p), left(nullptr), right(nullptr), size(1) {}};Node* root;Compare comp;std::mt19937 rng;std::uniform_int_distribution<int> dist;// 由子节点计算当前节点的大小void updateSize(Node* node);// 右旋操作Node* rotateRight(Node* node);// 左旋操作Node* rotateLeft(Node* node);// 插入操作Node* insert(Node* node, const T& key, int priority);// 查找操作Node* find(Node* node, const T& key);// 获取前驱节点Node* get_prev(Node* node, const T& key);// 获取后继节点Node* get_next(Node* node, const T& key);// 获取最小值节点Node* get_min(Node* node);// 获取最大值节点Node* get_max(Node* node);// 删除操作Node* erase(Node* node, const T& key);// 清空操作void clear(Node* node); public:Treap() : root(nullptr), comp(), rng(std::random_device{}()),dist(0, std::numeric_limits<int>::max()) {}~Treap() {clear(root);}/*外部接口实现*/void insert(const T& key);void erase(const T& key);bool find(const T& key);T get_prev(const T& key);T get_next(const T& key);T get_min();T get_max();size_t size() const;bool empty() const;void clear();
};
0. 小Tips
-
模板加入
Compare
,默认从小到大排序,可以通过修改该类型自定义排序规则 -
std::mt19937
这是随机数生成器:
- 全名是 Mersenne Twister 19937(梅森旋转算法)
- 用来产生高质量的伪随机数
- 这里用来生成每个节点的随机优先级(Treap需要这个机制来保持平衡)
- 初始化时用随机设备种子:rng(std::random_device{}())
相当于更高级的 rand() 函数,但随机性更好、周期更长
-
std::uniform_int_distribution<int>
这是随机数分布器:
- 用来把随机数生成器的结果映射到指定范围
- 初始化时设置的范围是 0 到
int
最大值(dist(0, std::numeric_limits<int>::max())
) - 和
rng
配合使用:dist(rng)
就会生成一个随机优先级
1. 右旋-左旋
右旋和左旋是Treap维持平衡的关键操作
-
右旋
将当前节点旋转到其左儿子的右儿子位置上
// 右旋操作 Node* rotateRight(Node* node) {Node* tmpNode = node->left;node->left = tmpNode->right;tmpNode->right = node;updateSize(node);updateSize(tmpNode);return tmpNode; }
-
左旋
将当前节点旋转到其右儿子的左儿子位置上
// 左旋操作 Node* rotateLeft(Node* node) {Node* tmpNode = node->right;node->right = tmpNode->left;tmpNode->left = node;updateSize(node);updateSize(tmpNode);return tmpNode; }
2. 关键函数
以下解释均以Compare = std::less<T>
为例
-
更新大小
通过子节点信息计算当前节点信息
void updateSize(Node* node) {if (node) {node->size = 1 + (node->left ? node->left->size : 0) + (node->right ? node->right->size : 0);} }
-
插入
递归实现:将值为
key
、优先级为priority
的节点插入到节点node
的下面插入分为四种情况:
- 如果当前节点不存在,根据值和优先级新建节点
- 如果当前插入的值与
node
的值相等,直接返回 - 如果当前插入的值小于
node
的值:- 将其插入
node
的左子树 - 插入完毕后,如果左子节点优先级大于当前
node
,需要右旋node
(大根堆性质,父节点优先级必须大于子节点优先级)
- 将其插入
- 如果当前插入的值大于
node
的值:- 将其插入
node
的右子树 - 插入完毕后,如果右子节点优先级大于当前
node
,需要左旋node
- 将其插入
插入后记得跟新当前
node
的大小,然后返回当前node
// 插入操作 Node* insert(Node* node, const T& key, int priority) {// 如果当前节点为空,则创建一个新节点if (!node) {return new Node(key, priority);}// 如果当前节点的优先级小于新节点的优先级,则需要进行旋转if (key == node->key) {return node;} else if (comp(key, node->key)) {node->left = insert(node->left, key, priority);if (node->left->priority > node->priority) {node = rotateRight(node);}} else if (comp(node->key, key)) {node->right = insert(node->right, key, priority);if (node->right->priority > node->priority) {node = rotateLeft(node);}}// 更新当前节点的大小updateSize(node);return node; }
-
查找
和BST的查找方式一样,根据待查找Key与当前节点值比较的结果决定是递归查找右子树还是左子树
// 查找操作 Node* find(Node* node, const T& key) {while (node) {if (comp(key, node->key)) {node = node->left;} else if (comp(node->key, key)) {node = node->right;} else {return node; // 找到}}return nullptr; // 没找到 }
-
找前驱
此处key的前驱节点指中序遍历中小于key的最大值。
分两种情况:
- 当key小于等于当前节点值时,递归到左子树查找key的前驱
- 否则,递归到右子树查找key的前驱
// 获取前驱节点 Node* get_prev(Node* node, const T& key) {if (!node) {return nullptr;}if (comp(key, node->key) || key == node->key) {return get_prev(node->left, key);} else {Node* right = get_prev(node->right, key);return right ? right : node;} }
-
找后继
此处key的前驱节点指中序遍历中大于key的最小值。
分两种情况:
- 当key大于等于当前节点值时,递归到右子树查找key的后继
- 否则,递归到左子树查找key的后继
// 获取后继节点 Node* get_next(Node* node, const T& key) {if (!node) {return nullptr;}if (comp(node->key, key) || key == node->key) {return get_next(node->right, key);} else {Node* left = get_next(node->left, key);return left ? left : node;} }
-
获取最值
最值分别在Treap的最左边和最右边的节点
// 获取最左边节点 Node* get_left(Node* node) {while (node && node->left) {node = node->left;}return node; } // 获取最右边节点 Node* get_right(Node* node) {while (node && node->right) {node = node->right;}return node; }
-
删除元素
根据节点的值,分三种情况:
-
当要删除的元素小于当前节点时,递归进入左子树
-
当要删除的元素大于当前节点时,递归进入右子树
-
当找到要删除的节点时:
-
如果是叶子节点,则直接删除
-
如果是非叶子节点:
通过旋转操作将要删除的结点旋转到叶节点位置
- 如果没有右子树,或者同时存在左右子树且左子树优先级大于右子树优先级:先右旋当前节点,然后递归删除右子树
- 否则,先左旋当前节点,然后递归删除左子树
-
// 删除操作 Node* erase(Node* node, const T& key) {if (!node) {return nullptr;}// 如果当前节点的key小于要删除的key,则在右子树中删除if (comp(key, node->key)) {node->left = erase(node->left, key);} else if (comp(node->key, key)) { // 如果当前节点的key大于要删除的key,则在左子树中删除node->right = erase(node->right, key);} else { // 找到要删除的节点// 如果是非叶子结点,通过旋转将其转化为叶子结点if (node->left || node->right) { // 如果没有右子树,或者左子树的优先级大于右子树的优先级,则进行右旋转// 否则进行左旋转if (!node->right || (node->left && node->left->priority > node->right->priority)) {node = rotateRight(node);node->right = erase(node->right, key);} else {node = rotateLeft(node);node->left = erase(node->left, key);}} else { // 如果是叶子节点,直接删除delete node;node = nullptr;}}updateSize(node);return node; }
给出一个样例:
-
-
清空
递归删除左右子树即可
// 清空操作 void clear(Node* node) {if (node) {clear(node->left);clear(node->right);delete node;} }
三、完整代码及测试
1. 头文件
查询资料推荐模板类编程将函数实现放在类内
#pragma once#include <functional>
#include <random>
#include <limits>template <typename T, typename Compare = std::less<T>>
class Treap {
private:struct Node {T key;int priority;size_t size;Node* left;Node* right;Node(const T& k, int p) : key(k), priority(p), left(nullptr), right(nullptr), size(1) {}};Node* root;Compare comp;std::mt19937 rng;std::uniform_int_distribution<int> dist;// 由子节点计算当前节点的大小void updateSize(Node* node) {if (node) {node->size = 1 + (node->left ? node->left->size : 0) + (node->right ? node->right->size : 0);}} // 右旋操作Node* rotateRight(Node* node) {Node* tmpNode = node->left;node->left = tmpNode->right;tmpNode->right = node;updateSize(node);updateSize(tmpNode);return tmpNode;} // 左旋操作Node* rotateLeft(Node* node) {Node* tmpNode = node->right;node->right = tmpNode->left;tmpNode->left = node;updateSize(node);updateSize(tmpNode);return tmpNode;} // 插入操作Node* insert(Node* node, const T& key, int priority) {// 如果当前节点为空,则创建一个新节点if (!node) {return new Node(key, priority);}// 如果当前节点的优先级小于新节点的优先级,则需要进行旋转if (key == node->key) {return node;} else if (comp(key, node->key)) {node->left = insert(node->left, key, priority);if (node->left->priority > node->priority) {node = rotateRight(node);}} else if (comp(node->key, key)) {node->right = insert(node->right, key, priority);if (node->right->priority > node->priority) {node = rotateLeft(node);}}// 更新当前节点的大小updateSize(node);return node;}// 查找操作Node* find(Node* node, const T& key) {while (node) {if (comp(key, node->key)) {node = node->left;} else if (comp(node->key, key)) {node = node->right;} else {return node; // 找到}}return nullptr; // 没找到} // 获取前驱节点Node* get_prev(Node* node, const T& key) {if (!node) {return nullptr;}if (comp(key, node->key) || key == node->key) {return get_prev(node->left, key);} else {Node* right = get_prev(node->right, key);return right ? right : node;}} // 获取后继节点Node* get_next(Node* node, const T& key) {if (!node) {return nullptr;}if (comp(node->key, key) || key == node->key) {return get_next(node->right, key);} else {Node* left = get_next(node->left, key);return left ? left : node;}} // 获取最左边节点Node* get_left(Node* node) {while (node && node->left) {node = node->left;}return node;} // 获取最右边节点Node* get_right(Node* node) {while (node && node->right) {node = node->right;}return node;} // 删除操作Node* erase(Node* node, const T& key) {if (!node) {return nullptr;}// 如果当前节点的key小于要删除的key,则在右子树中删除if (comp(key, node->key)) {node->left = erase(node->left, key);} else if (comp(node->key, key)) { // 如果当前节点的key大于要删除的key,则在左子树中删除node->right = erase(node->right, key);} else { // 找到要删除的节点// 如果是非叶子结点,通过旋转将其转化为叶子结点if (node->left || node->right) { // 如果没有右子树,或者左子树的优先级大于右子树的优先级,则进行右旋转// 否则进行左旋转if (!node->right || (node->left && node->left->priority > node->right->priority)) {node = rotateRight(node);node->right = erase(node->right, key);} else {node = rotateLeft(node);node->left = erase(node->left, key);}} else { // 如果是叶子节点,直接删除delete node;node = nullptr;}}updateSize(node);return node;} // 清空操作void clear(Node* node) {if (node) {clear(node->left);clear(node->right);delete node;}} public:Treap() : root(nullptr), comp(), rng(std::random_device{}()),dist(0, std::numeric_limits<int>::max()) {}~Treap() {clear(root);}/*外部接口实现*/void insert(const T& key) {int priority = dist(rng); // 生成随机优先级root = insert(root, key, priority);}void erase(const T& key) {root = erase(root, key);}bool find(const T& key) {return find(root, key) != nullptr;}T get_prev(const T& key) {Node* node = get_prev(root, key);return node ? node->key : T(); // 返回前驱节点的key}T get_next(const T& key) {Node* node = get_next(root, key);return node ? node->key : T(); // 返回后继节点的key}T get_min() {Node* node = nullptr;if constexpr (std::is_same_v<Compare, std::less<T>>) {node = get_left(root);} else {node = get_right(root);}return node ? node->key : T(); // 返回最小值节点的key}T get_max() {Node* node = nullptr;if constexpr (std::is_same_v<Compare, std::less<T>>) {node = get_right(root);} else {node = get_left(root);}return node ? node->key : T(); // 返回最大值节点的key}size_t size() const {return root ? root->size : 0;}bool empty() const {return size() == 0;}void clear() {clear(root);root = nullptr;}
};
2. 测试主函数
#include <iostream>
#include "Treap.h"int main()
{Treap<int> treap;treap.insert(5);treap.insert(3);treap.insert(3);treap.insert(7);treap.insert(2);treap.insert(4);std::cout << "Size: " << treap.size() << std::endl; // 5std::cout << "Find 3: " << treap.find(3) << std::endl; // 1std::cout << "Precursor of 3: " << treap.get_prev(3) << std::endl; // 2std::cout << "Next of 3: " << treap.get_next(3) << std::endl; // 4std::cout << "Min: " << treap.get_min() << std::endl; // 2std::cout << "Max: " << treap.get_max() << std::endl; // 7// 特殊情况测试std::cout << "Find 9: " << treap.find(9) << std::endl; // 0std::cout << "Precursor of 9: " << treap.get_prev(9) << std::endl; // 7std::cout << "Next of 9: " << treap.get_next(9) << std::endl; // 0std::cout << "Precursor of 0: " << treap.get_prev(0) << std::endl; // 0std::cout << "Next of 0: " << treap.get_next(0) << std::endl; // 2treap.erase(3);std::cout << "Size after erase: " << treap.size() << std::endl; // 4std::cout << "find 3: " << treap.find(3) << std::endl; // 0return 0;
}
参考:
- DeepSeek
- ACWing算法提高课Treap实现