C++八股 —— 手撕shared_ptr
文章目录
- 一、列出需要实现的接口
- 二、实现细节
- 三、接口细节
- 四、完整代码:
来自:【面试精选】大佬带你一周刷完一线互联网大厂C++面试八股文,比啃书效果好多了!_哔哩哔哩_bilibili
字节C++二面:
手撕shared_ptr
,要求:
- 不考虑删除器和空间配置器
- 不考虑弱引用
- 考虑引用计数的线程安全
- 提出测试案例
相关概念参考:
- C++八股——智能指针-CSDN博客
- C++八股 —— 原子操作-CSDN博客
- C++八股——关键字_c++ volatile cpu缓存-CSDN博客
一、列出需要实现的接口
- 构造函数
- 析构函数
- 拷贝构造函数
- 拷贝赋值运算符
- 移动构造函数
- 移动赋值运算符
- 解引用、箭头运算符
- 引用计数、原始指针、重置指针
二、实现细节
-
空的
shared_ptr
大小为16字节不考虑删除器、空间配置器、弱引用,只有引用计数和指针,所以空的
shared_ptr
大小16字节 -
std::atomic<std::size_t>*
引用计数原因参考C++八股——智能指针-CSDN博客中的
shared_ptr
部分
三、接口细节
- 有参构造函数需要
explicit
修饰 - 拷贝构造函数和拷贝赋值运算符需要
const T &
常引用 - 移动构造函数和移动赋值运算符需要
noexcept
修饰 - 只读接口用
const
修饰
四、完整代码:
shared_ptr.h
:
#pragma once#include <atomic>template <typename T>
class shared_ptr {
private:T* ptr; // 指向管理的对象std::atomic<std::size_t>* ref_count; // 原子引用计数// 释放资源void release() {// 使用 std::memory_order_acq_rel 内存序,保证释放资源时的原子性if (ref_count && ref_count->fetch_sub(1, std::memory_order_acq_rel) == 1) {delete ptr; // 删除对象delete ref_count; // 删除引用计数}ptr = nullptr; // 清空指针ref_count = nullptr; // 清空引用计数}public:// 默认构造函数shared_ptr() : ptr(nullptr), ref_count(nullptr) {}// 构造函数// 使用explicit关键字防止隐式转换// shared_ptr<int> ptr1 = new int(10); 不允许出现explicit shared_ptr(T* p) : ptr(p), ref_count(p ? new std::atomic<std::size_t>(1) : nullptr) {}// 析构函数~shared_ptr() {release(); // 释放资源}// 拷贝构造函数shared_ptr(const shared_ptr<T>& other) : ptr(other.ptr), ref_count(other.ref_count) {if (ref_count) {ref_count->fetch_add(1, std::memory_order_relaxed); // 增加引用计数,不需要增强内存序,可以加快代码的执行速度}}// 赋值运算符shared_ptr<T>& operator=(const shared_ptr<T>& other) {if (this != &other) { // 防止自赋值release(); // 释放当前资源ptr = other.ptr; // 复制指针ref_count = other.ref_count; // 复制引用计数if (ref_count) {ref_count->fetch_add(1, std::memory_order_relaxed); // 增加引用计数}}return *this;}// 移动构造函数// 使用noexcept关键字,表示该函数不会抛出异常,帮助编译器优化代码,不需要为异常处理生成额外的代码// 标准库中的某些操作(如:std::swap)要求移动操作是noexcept的,以确保异常安全shared_ptr(shared_ptr<T>&& other) noexcept : ptr(other.ptr), ref_count(other.ref_count) {other.ptr = nullptr; // 清空原对象的指针other.ref_count = nullptr; // 清空原对象的引用计数}// 移动赋值运算符shared_ptr<T>& operator=(shared_ptr<T>&& other) noexcept {if (this != &other) { // 防止自赋值release(); // 释放当前资源ptr = other.ptr; // 复制指针ref_count = other.ref_count; // 复制引用计数other.ptr = nullptr; // 清空原对象的指针other.ref_count = nullptr; // 清空原对象的引用计数}return *this;}// 重载解引用运算符T& operator*() const {return *ptr;}// 重载箭头运算符T* operator->() const {return ptr;}// 获取引用计数std::size_t use_count() const {return ref_count ? ref_count->load(std::memory_order_acquire) : 0; // 使用 relaxed 内存序,获取引用计数}// 获取原始指针T* get() const {return ptr;}// 重置指针void reset(T* p = nullptr) {release(); // 释放当前资源ptr = p; // 复制指针ref_count = p ? new std::atomic<std::size_t>(1) : nullptr; // 创建新的引用计数}
};
测试代码:
#include <iostream>
#include <thread>
#include <vector>
#include <chrono>#include "shared_ptr.h"void test_shared_ptr_thread_safety() {shared_ptr<int> ptr(new int(10)); // 创建一个shared_ptr对象,管理一个int类型的对象std::cout << "Initial value: " << *ptr << std::endl; // 输出初始值// 创建多个线程,测试线程安全性const int num_threads = 5;std::vector<std::thread> threads;for (int i = 0; i < num_threads; ++i) {threads.emplace_back([&ptr]() {for (int j = 0; j < 5; ++j) {shared_ptr<int> local_ptr(ptr); // 创建一个新的shared_ptr对象,引用计数加1std::cout << "use_count: " << ptr.use_count() << std::endl; // 输出引用计数std::this_thread::sleep_for(std::chrono::milliseconds(1000)); // 模拟一些工作}});}for (auto& t : threads) {t.join(); // 等待所有线程完成}// 检查引用计数是否正确std::cout << "use_count: " << ptr.use_count() << std::endl; // 输出引用计数if (ptr.use_count() == 1) {std::cout << "Thread safety test passed!" << std::endl; // 如果引用计数等于线程数,测试通过} else {std::cout << "Thread safety test failed!" << std::endl; // 否则测试失败}
}int main() {test_shared_ptr_thread_safety(); // 测试shared_ptr的线程安全性return 0;
}