封装红黑树实现map和set(部分常用接口)(C++)
1. 框架分析
SGI-STL 3.0版本源代码中,map和set的实现位于map、set、stl_map.h、stl_set.h、stl_tree.h等几个头文件中。
// Abridged excerpt of SGI STL 3.0 stl_set.h.
// set passes Key as both the tree's key_type and its value_type, and uses
// identity<> as the functor that extracts the key from a stored value.
template <class Key, class Compare = less<Key>, class Alloc = alloc>
class set {
public:
  // typedefs:
  typedef Key key_type;
  typedef Key value_type;
  typedef Compare key_compare;   // used by rep_type below
private:
  typedef rb_tree<key_type, value_type,
                  identity<value_type>, key_compare, Alloc> rep_type;
  rep_type t; // red-black tree representing set
};
// Abridged excerpt of SGI STL 3.0 stl_map.h.
// map keys the tree on Key but stores pair<const Key, T>; select1st<>
// extracts the key (the pair's first member) from a stored value.
template <class Key, class T, class Compare = less<Key>, class Alloc = alloc>
class map {
public:
  // typedefs:
  typedef Key key_type;
  typedef T mapped_type;
  typedef pair<const Key, T> value_type;
  typedef Compare key_compare;   // used by rep_type below
private:
  typedef rb_tree<key_type, value_type,
                  select1st<value_type>, key_compare, Alloc> rep_type;
  rep_type t; // red-black tree representing map
};
这里需要重点关注:之前实现的红黑树直接用pair<K,V>来存储数据;而这里rb_tree的第二个模板参数才是结点中实际存储的数据类型——set传入的两个参数都是Key,map传入的则是Key和pair<const Key,T>。set这样传参是为了和map保持接口一致,这样只需实现一棵红黑树,两个容器就都能复用。
对于map和set,find/erase等函数的参数都是Key,所以rb_tree的第一个模板参数就是传给find/erase等函数做形参的类型。对于set而言两个模板参数是一样的,但是对于map而言就完全不一样了:map的insert插入的是pair对象,而find和erase的参数是Key对象。
2. 模拟实现map和set
2.1 利用红黑树的框架,实现insert
因为RBTree实现了泛型,它不知道T参数到底是K还是pair,所以insert内部进行插入逻辑比较时就没办法直接比较:pair默认的比较是key和value一起参与比较,而我们只需要比较key。因此我们在map和set层分别实现一个MapKeyOfT和SetKeyOfT仿函数,传给RBTree的KeyOfT模板参数;RBTree中再通过KeyOfT仿函数取出T类型对象中的key,然后进行比较。
map的比较:
template<class K,class V>
class map
{
	// Extracts the key from a stored element.
	// The tree stores pair<const K, V>, so the parameter must be exactly that
	// type. Taking const pair<K, V>& instead would build a converted temporary
	// copy of the whole pair (including V) on every call, and would return a
	// reference into that temporary.
	struct MapKeyOFT
	{
		const K& operator()(const pair<const K, V>& kv) const
		{
			return kv.first;
		}
	};
public:
private:
	// Keyed on K; each node stores pair<const K, V> so the key stays immutable.
	RBTree<K, pair<const K, V>, MapKeyOFT> _t;
};
set的比较:
template<class T>
class set
{
	// Key extractor for set: the stored element is its own key.
	struct SetKeyOFT
	{
		const T& operator()(const T& value)
		{
			return value;
		}
	};
public:
private:
	// Nodes store const T, so stored keys cannot be modified in place.
	RBTree<T, const T, SetKeyOFT> _t;
};
2.2 支持iterator的实现
// In-order iterator over the red-black tree (basic members only; ++/-- are
// shown separately below).
// REF/PTR are T&/T* for iterator and const T&/const T* for const_iterator.
template<class T, class REF, class PTR>
struct Treeiterator
{
	typedef RBTreeNode<T> Node;
	typedef Treeiterator<T, REF, PTR> Self;

	Node* _node; // current node; nullptr represents end()
	Node* _root; // tree root, kept so that --end() can find the last node

	Treeiterator(Node* node, Node* root)
		:_node(node)
		, _root(root)
	{}

	// All accessors and comparisons are const member functions, so they also
	// work on const iterator objects.
	REF operator*() const
	{
		return _node->_data;
	}
	PTR operator->() const
	{
		return &_node->_data;
	}
	bool operator!=(const Self& s) const
	{
		return _node != s._node;
	}
	bool operator==(const Self& s) const
	{
		return _node == s._node;
	}
};
这里的解引用,->,==,!=实现的逻辑和之前实现的迭代器是一样的,这里主要考虑的是++,--操作时,结点的变化。
2.2.1 operator++()
map和set的迭代器按中序遍历的顺序访问:左子树->根结点->右子树。
迭代器++时,如果it指向的结点的右子树不为空,代表当前结点已经访问完了,下一个要访问的结点是右子树中序遍历的第一个结点;一棵树中序遍历的第一个结点是它的最左结点,所以直接找右子树的最左结点即可。
迭代器++时,如果it指向的结点的右子树为空,代表当前结点已经访问完了,且当前结点所在的子树也访问完了,要访问的下一个结点在当前结点的祖先里面,所以要沿着当前结点到根的祖先路径向上找:直到找到某个祖先parent,使得cur是parent->_left,这个parent就是下一个要访问的结点。
还有就是end()如何表示呢?这里直接用nullptr表示end()。对最后一个结点++时,它的右子树为空,会沿祖先路径一直向上找;根结点的_parent为空,所以_node最终被置为nullptr,恰好相当于最后一个结点的下一个位置。也正因为end()是nullptr,迭代器中需要额外保存_root:对end()进行--时,要从根结点出发找到最右结点(即最后一个结点)。
// Advances to the in-order successor; stepping past the last node yields end().
Self& operator++()
{
	if (_node->_right == nullptr)
	{
		// No right subtree: climb while we are coming up from a right child.
		// The first ancestor reached from its left side is the successor;
		// running off the root leaves nullptr, i.e. end().
		Node* child = _node;
		Node* ancestor = child->_parent;
		for (; ancestor != nullptr && ancestor->_right == child; ancestor = ancestor->_parent)
		{
			child = ancestor;
		}
		_node = ancestor;
		return *this;
	}

	// Right subtree exists: the successor is its leftmost node.
	Node* next = _node->_right;
	while (next->_left != nullptr)
	{
		next = next->_left;
	}
	_node = next;
	return *this;
}
2.2.2 operator--()
--的逻辑和++对称,只是方向相反:如果左子树不为空,就找左子树的最右结点;如果左子树为空,就沿祖先路径向上找,直到cur是parent->_right,这个parent就是上一个结点。另外,如果当前迭代器是end()(_node为nullptr),--后应指向整棵树的最右结点。
// Moves to the in-order predecessor; --end() yields the last (rightmost) node.
Self& operator--()
{
	if (_node != nullptr && _node->_left != nullptr)
	{
		// Left subtree exists: the predecessor is its rightmost node.
		Node* prev = _node->_left;
		while (prev->_right != nullptr)
			prev = prev->_right;
		_node = prev;
	}
	else if (_node != nullptr)
	{
		// No left subtree: climb while we are coming up from a left child.
		Node* child = _node;
		Node* ancestor = _node->_parent;
		while (ancestor != nullptr && ancestor->_left == child)
		{
			child = ancestor;
			ancestor = ancestor->_parent;
		}
		_node = ancestor;
	}
	else
	{
		// end(): go to the rightmost node of the whole tree (stays nullptr
		// when the tree is empty).
		Node* last = _root;
		while (last != nullptr && last->_right != nullptr)
			last = last->_right;
		_node = last;
	}
	return *this;
}
2.3 map支持[ ]
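map要支持[ ],需要修改insert的返回值:把RBTree中insert的返回值改为pair<iterator, bool>,即 pair<iterator, bool> insert(const T& data)。operator[]用{key, V()}调用insert:key不存在时插入成功,返回新结点的迭代器;key已存在时插入失败,但返回已有结点的迭代器。两种情况下都通过该迭代器返回对应value的引用。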
// Returns a reference to the value mapped to key, inserting {key, V()} first
// when key is not yet present.
V& operator[](const K& key)
{
	pair<iterator, bool> result = _t.insert(pair<const K, V>(key, V()));
	iterator pos = result.first;
	return (*pos).second;
}
2.4 整体实现
set的iterator也不能修改元素,否则会破坏红黑树的搜索结构,把set传给RBTree的第二个模板参数改成const K(代码中为const T)即可。
2.4.1 set实现
#include"RBTree.h"
template<class T>
class set
{
struct SetKeyOFT
{
const T& operator()(const T& key)
{
return key;
}
};
public:
typedef typename RBTree<T, const T, SetKeyOFT>::iterator iterator;
typedef typename RBTree<T, const T, SetKeyOFT>::const_iterator const_iterator;
pair<iterator, bool> insert(const T& key)
{
return _t.insert(key);
}
iterator begin()
{
return _t.Begin();
}
iterator end()
{
return _t.End();
}
const_iterator begin() const
{
return _t.Begin();
}
const_iterator end() const
{
return _t.End();
}
private:
RBTree<T,const T, SetKeyOFT> _t;
};
2.4.2 map实现
map的iterator不能修改key,但是可以修改value。我们把map传给RBTree的第二个模板参数(pair)的第一个参数改成const K即可:RBTree<K, pair<const K, V>, MapKeyOfT> _t;
template<class K,class V>
class map
{
struct MapKeyOFT
{
const K& operator()(const pair<K, V>& kv)
{
return kv.first;
}
};
public:
typedef typename RBTree<K, pair<const K, V>, MapKeyOFT>::iterator iterator;
typedef typename RBTree<K, pair<const K, V>, MapKeyOFT>::const_iterator const_iterator;
pair<iterator,bool> insert(const pair<K, V>& kv)
{
return _t.insert(kv);
}
iterator begin()
{
return _t.Begin();
}
iterator end()
{
return _t.End();
}
const_iterator begin() const
{
return _t.Begin();
}
const_iterator end() const
{
return _t.End();
}
V& operator[](const K& key)
{
pair<iterator, bool> ret = _t.insert({ key,V()});
return ret.first->second;
}
private:
RBTree<K, pair<const K, V>, MapKeyOFT> _t;
};
2.4.3 RBTree代码:
#include<iostream>
#include<utility>
using namespace std;
// Colour of a red-black tree node.
enum Color
{
	RED,
	BLACK
};

// One node of the red-black tree, holding a single element of type T.
template<class T>
struct RBTreeNode
{
	T _data;
	RBTreeNode<T>* _parent;
	RBTreeNode<T>* _left;
	RBTreeNode<T>* _right;
	Color _col;

	// A new node starts detached and RED (the colour insertion gives it), so
	// _col is never read uninitialized.
	RBTreeNode(const T& data)
		:_data(data)
		, _parent(nullptr)
		, _left(nullptr)
		, _right(nullptr)
		, _col(RED)
	{}
};

// Bidirectional in-order iterator over the red-black tree.
// REF/PTR are T&/T* for iterator and const T&/const T* for const_iterator.
template<class T, class REF, class PTR>
struct Treeiterator
{
	typedef RBTreeNode<T> Node;
	typedef Treeiterator<T, REF, PTR> Self;

	Node* _node; // current node; nullptr represents end()
	Node* _root; // tree root, kept so that --end() can find the last node

	Treeiterator(Node* node, Node* root)
		:_node(node)
		, _root(root)
	{}

	REF operator*() const
	{
		return _node->_data;
	}
	PTR operator->() const
	{
		return &_node->_data;
	}

	// Pre-increment: move to the in-order successor (end() after the last node).
	Self& operator++()
	{
		if (_node->_right)
		{
			// Right subtree exists: the successor is its leftmost node.
			_node = _node->_right;
			while (_node->_left)
			{
				_node = _node->_left;
			}
		}
		else
		{
			// Climb until we arrive from a left child; that parent is next.
			// Running off the root leaves nullptr, i.e. end().
			Node* cur = _node;
			Node* parent = cur->_parent;
			while (parent && cur == parent->_right)
			{
				cur = parent;
				parent = parent->_parent;
			}
			_node = parent;
		}
		return *this;
	}

	// Post-increment: advance, returning the iterator's previous position.
	Self operator++(int)
	{
		Self old(*this);
		++*this;
		return old;
	}

	// Pre-decrement: move to the in-order predecessor; --end() is the last node.
	Self& operator--()
	{
		if (_node == nullptr)
		{
			// end(): step to the rightmost node of the whole tree.
			Node* rightMost = _root;
			while (rightMost && rightMost->_right)
			{
				rightMost = rightMost->_right;
			}
			_node = rightMost;
		}
		else if (_node->_left)
		{
			// Left subtree exists: the predecessor is its rightmost node.
			Node* rightMost = _node->_left;
			while (rightMost->_right)
			{
				rightMost = rightMost->_right;
			}
			_node = rightMost;
		}
		else
		{
			// Climb until we arrive from a right child; that parent is previous.
			Node* cur = _node;
			Node* parent = cur->_parent;
			while (parent && cur == parent->_left)
			{
				cur = parent;
				parent = parent->_parent;
			}
			_node = parent;
		}
		return *this;
	}

	// Post-decrement: step back, returning the iterator's previous position.
	Self operator--(int)
	{
		Self old(*this);
		--*this;
		return old;
	}

	// Comparisons are const so const iterators (and temporaries) can be compared.
	bool operator!=(const Self& s) const
	{
		return _node != s._node;
	}
	bool operator==(const Self& s) const
	{
		return _node == s._node;
	}
};
template<class K, class T,class KeyOfT>
class RBTree
{
typedef RBTreeNode<T> Node;
public:
typedef Treeiterator<T, T&, T*> iterator;
typedef Treeiterator<T,const T&,const T*> const_iterator;
iterator Begin()
{
Node* cur = _root;
while (cur && cur->_left)
{
cur = cur->_left;
}
return iterator(cur, _root);
}
iterator End()
{
return iterator(nullptr, _root);
}
const_iterator Begin() const
{
Node* cur = _root;
while (cur && cur->_left)
{
cur = cur->_left;
}
return const_iterator(cur, _root);
}
const_iterator End() const
{
return const_iterator(nullptr, _root);
}
pair<iterator,bool> insert(const T&data)
{
if (_root == nullptr)
{
_root = new Node(data);
_root->_col = BLACK;
return {iterator(_root,_root),true};
}
KeyOfT oft;
Node* parent = nullptr;
Node* cur = _root;
while (cur)
{
if (oft(data) < oft(cur->_data))
{
parent = cur;
cur = cur->_left;
}
else if (oft(data) > oft(cur->_data))
{
parent = cur;
cur = cur->_right;
}
else
{
return {iterator(cur,_root),false};
}
}
cur = new Node(data);
Node* newnode = cur;
cur->_col = RED;
if (oft(parent->_data) > oft(data))
{
parent->_left = cur;
}
else
{
parent->_right = cur;
}
cur->_parent = parent;
while (parent && parent->_col == RED)
{
Node* grandfather = parent->_parent;
if (parent == grandfather->_left)
{
Node* uncle = grandfather->_right;
if (uncle && uncle->_col == RED)
{
grandfather->_col = RED;
parent->_col = BLACK;
uncle->_col = BLACK;
cur = grandfather;
parent = cur->_parent;
}
else
{
if (cur == parent->_left)
{
RotateR(grandfather);
parent->_col = BLACK;
grandfather->_col = RED;
}
else
{
RotateL(parent);
RotateR(grandfather);
cur->_col = BLACK;
grandfather->_col = RED;
}
break;
}
}
else
{
Node* uncle = grandfather->_left;
if (uncle && uncle->_col == RED)
{
grandfather->_col = RED;
parent->_col = BLACK;
uncle->_col = BLACK;
cur = grandfather;
parent = cur->_parent;
}
else
{
if (cur == parent->_right)
{
RotateL(grandfather);
parent->_col = BLACK;
grandfather->_col = RED;
}
else
{
RotateR(parent);
RotateL(grandfather);
cur->_col = BLACK;
grandfather->_col = RED;
}
break;
}
}
}
_root->_col = BLACK;
return {iterator(newnode,_root),true};
}
void RotateR(Node* parent)
{
Node* subL = parent->_left;
Node* subLR = subL->_right;
parent->_left = subLR;
if (subLR)
subLR->_parent = parent;
Node* parentParent = parent->_parent;
subL->_right = parent;
parent->_parent = subL;
if (parentParent == nullptr)
{
_root = subL;
subL->_parent = nullptr;
}
else
{
if (parent == parentParent->_left)
{
parentParent->_left = subL;
}
else
{
parentParent->_right = subL;
}
subL->_parent = parentParent;
}
}
void RotateL(Node* parent)
{
Node* subR = parent->_right;
Node* subRL = subR->_left;
parent->_right = subRL;
if (subRL)
subRL->_parent = parent;
Node* parentParent = parent->_parent;
subR->_left = parent;
parent->_parent = subR;
if (parentParent == nullptr)
{
_root = subR;
subR->_parent = nullptr;
}
else
{
if (parent == parentParent->_left)
{
parentParent->_left = subR;
}
else
{
parentParent->_right = subR;
}
subR->_parent = parentParent;
}
}
iterator find(const K& key)
{
KeyOfT kot;
Node* cur = _root;
while (cur)
{
if (kot(cur->_data) < key)
{
cur = cur->_right;
}
else if (kot(cur->_data) > key)
{
cur = cur->_left;
}
else
{
return iterator(cur, _root);
}
}
return End();
}
void Inoder()
{
_Inoder(_root);
cout << endl;
}
void _Inoder(Node* root)
{
if (root == nullptr)
return;
_Inoder(root->_left);
cout << root->_kv.first << " ";
_Inoder(root->_right);
}
bool Check(Node* root, int blacknum, const int refnum)
{
if (root == nullptr)
{
if (blacknum == refnum)
return true;
else
{
cout << "存在黑色节点的数量不相等的路径" << endl;
return false;
}
}
if (root->_col == RED && root->_parent && root->_parent->_col == RED)
{
cout << "存在连续的红色结点" << endl;
return false;
}
if (root->_col == BLACK)
{
blacknum++;
}
return Check(root->_left, blacknum, refnum) && Check(root->_right, blacknum, refnum);
}
bool IsBalanceTree()
{
if (_root == nullptr)
return true;
if (_root->_col == RED)
return false;
Node* cur = _root;
int refnum = 0;
while (cur)
{
if (cur->_col == BLACK)
{
++refnum;
}
cur = cur->_left;
}
return Check(_root, 0, refnum);
}
private:
Node* _root = nullptr;
};