c++手撕协程库,实现生成器与自定义可等待对象
今天我们来从零实现一个非对称协程库,这个库是使用汇编实现上下文切换,实现了生成器与自定义可等待对象
协程是用户态的线程,它需要由程序来进行调度,如上下文切换与调度设计都需要程序来设计,并且协程运行在单个线程中,这就成就了线程的低成本,简单讲协程就是一种可以被挂起与恢复的特殊函数
看之前建议先看看详解c++20的协程,自定义可等待对象,生成器详解-CSDN博客
上下文切换实际上就是保存现在的寄存器状态,恢复之前的寄存器状态,通过这一系列操作我们就可以实现切换执行的函数,下面是汇编的实现:
.section .text
.globl swap_context
swap_context:
leaq (%rsp),%rax
movq %rax, 104(%rdi)
movq %rbx, 96(%rdi)
movq %rcx, 88(%rdi)
movq %rdx, 80(%rdi)
movq 0(%rax), %rax
movq %rax, 72(%rdi)
movq %rsi, 64(%rdi)
movq %rdi, 56(%rdi)
movq %rbp, 48(%rdi)
movq %r8, 40(%rdi)
movq %r9, 32(%rdi)
movq %r12, 24(%rdi)
movq %r13, 16(%rdi)
movq %r14, 8(%rdi)
movq %r15, (%rdi)
xorq %rax, %rax
movq 48(%rsi), %rbp
movq 104(%rsi), %rsp
movq (%rsi), %r15
movq 8(%rsi), %r14
movq 16(%rsi), %r13
movq 24(%rsi), %r12
movq 32(%rsi), %r9
movq 40(%rsi), %r8
movq 56(%rsi), %rdi
movq 80(%rsi), %rdx
movq 88(%rsi), %rcx
movq 96(%rsi), %rbx
leaq 8(%rsp), %rsp
pushq 72(%rsi)
movq 64(%rsi), %rsi
ret
.section .text
此语句表明后续代码会被放在 .text
节,而 .text
节通常用来存放可执行代码,详细可见简述计算机系统中的抽象,进程、线程、虚拟内存与文件抽象_线程,进程,内存-CSDN博客中关于虚拟内存的内容。.globl swap_context
:这行代码把 swap_context
声明为全局符号,这样其他文件也能调用这个函数。
swap_context:
你可以简单的理解为就是一个函数名,冒号后面是函数的运行逻辑,我们分为两部分保存于恢复之间用换行分隔,leaq (%rsp), %rax
将当前栈指针 rsp
的值存入 rax
。此时 rax
指向栈顶(即调用 swap_context
后的返回地址)。
movq %rax, 104(%rdi)
将 %rax
(也就是当前栈指针)的值存到由 %rdi
指向的内存地址偏移 104 字节处。这里的 %rdi
一般是第一个参数,代表要保存上下文的结构体地址。
在 x86 - 64 架构中,栈指针 %rsp
是一个非常重要的寄存器,它指向当前栈的栈顶。栈是一种后进先出(LIFO)的数据结构,常用于存储函数调用的上下文信息,如局部变量、返回地址等。
当函数被调用时,系统会将一些信息压入栈中,包括返回地址、调用者的寄存器状态等,同时 %rsp
会相应地向下移动(地址值减小)。当函数返回时,系统会从栈中弹出这些信息,%rsp
会向上移动(地址值增大)。
movq %rbx, 96(%rdi)
到 movq %r15, (%rdi)
依次将寄存器 rbx, rcx, rdx, rsi, rdi, rbp, r8, r9, r12, r13, r14, r15
的值保存到旧上下文结构体的对应偏移位置。
%r15 - %r8
这些是 64 位通用寄存器,用于存储临时数据和计算结果。在上下文切换时,需要保存和恢复它们的值,以确保协程的执行状态可以正确恢复。
%rbp
基址指针寄存器,通常用于指向当前栈帧的底部。在函数调用时,%rbp
可以帮助访问局部变量和参数。
%rdi
和 %rsi
:通常用于传递函数的第一个和第二个参数。在汇编代码中,%rdi
指向要保存上下文的结构体地址,%rsi
指向要恢复上下文的结构体地址。
%rbx
、%rcx
、%rdx
:这些寄存器也用于存储临时数据和计算结果,在上下文切换时需要保存和恢复。
movq 0(%rax), %rax
: 获取栈顶的返回地址(即调用 swap_context
后的下一条指令地址),保存到偏移 72 处。movq %rdi, 56(%rdi)
: 保存旧上下文结构体指针自身(rdi
)到其偏移 56 处。xorq %rax, %rax
将 rax
清零
后续的 movq
指令从由 %rsi
指向的内存区域读取数据,并将其恢复到对应的寄存器中。例如,movq 48(%rsi), %rbp
从 %rsi
偏移 48 字节的内存地址读取数据,并将其恢复到 %rbp
寄存器。
接下来我们来看看c++的实现,先看与汇编代码最相关的:
struct Context {
void *regs[14];
};
extern "C" void swap_context(Context* old_ctx, Context* new_ctx, void* arg);
Context定义了我们用到的14个寄存器,而swap_context函数就是我们代码层面用来调用汇编的接口,这个函数只有声明没有定义,在连接阶段会将汇编代码与当前的声明合并,详细的c++编译过程可以看详解c++的编译过程,如何从源文件到可执行文件到-CSDN博客
static ThreadPool thread_pool;
template<typename T>
class Coroutine {
public:
template<typename Func, typename... Args>
explicit Coroutine(Func&& func, Args&&... args) {
setCoroutine(std::forward<Func>(func), std::forward<Args>(args)...);
}
~Coroutine() { destroy(); }
template <typename Func, typename... Args>
void setCoroutine(Func &&func, Args &&...args) {
func_ = std::bind(std::forward<Func>(func), this, std::forward<Args>(args)...);
stack_mem = malloc(STACK_SIZE);
void *stack_top = static_cast<char *>(stack_mem) + STACK_SIZE;
stack_top = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(stack_top) & ~0xF);
ctx_coroutine = Context();
ctx_coroutine.regs[13] = stack_top;
ctx_coroutine.regs[9] = reinterpret_cast<void *>(&Coroutine::entry);
ctx_coroutine.regs[7] = this;
ctx_coroutine.regs[6] = stack_top;
resume();
}
void await(){ swap_context(&ctx_coroutine, &sched_ctx, nullptr); }
void await_time(int ms){
thread_pool.addTask([ms, this]{
std::this_thread::sleep_for(std::chrono::milliseconds(ms));
resume();
});
await();
}
template<typename Func, typename... Args>
auto await_future_return(Func&& func, Args&&... args) -> decltype(func(std::forward<Args>(args)...)) {
using ReturnType = decltype(func(std::forward<Args>(args)...));
ReturnType resp;
auto fun = std::bind(func, std::forward<Args>(args)...);
thread_pool.addTask([this, fun = std::move(fun), &resp]() mutable {
resp = fun();
resume();
});
await();
return resp;
}
void resume() { if (stack_mem) swap_context(&sched_ctx, &ctx_coroutine, this); }
void destroy() {
if (stack_mem) free(stack_mem);
stack_mem = nullptr;
}
void yield_value(T val);
void return_value(T val);
T get_yield_value();
T get_return_value() { return value_; }
private:
Context sched_ctx;
Context ctx_coroutine;
T value_;
void* stack_mem = nullptr;
std::function<void()> func_;
static constexpr int STACK_SIZE = 256 * 1024;
static void entry(void* arg) {
Coroutine* self = static_cast<Coroutine*>(arg);
self->func_();
swap_context(&self->ctx_coroutine, &self->sched_ctx, nullptr);
}
};
template<typename T>
inline void Coroutine<T>::yield_value(T val) {
value_ = val;
swap_context(&ctx_coroutine, &sched_ctx, nullptr);
}
template <typename T>
inline void Coroutine<T>::return_value(T val) {
value_ = val;
swap_context(&ctx_coroutine, &sched_ctx, nullptr);
destroy();
}
template <typename T>
inline T Coroutine<T>::get_yield_value() {
if (stack_mem) swap_context(&sched_ctx, &ctx_coroutine, this);
return value_;
}
我们使用了一个线程池来实现可等待对象,这个线程池的具体实现可以看c++ 手撕线程池这个协程库的实现基本仿照c++20的协程,不过大大简化了使用难度,不再需要直接事件各种类,下面可以看下使用演示:
int add(int a, int b){
std::this_thread::sleep_for(std::chrono::milliseconds(2));
return a + b;
}
std::string str(){
return "Hello, world!";
}
void fibonacci(Coroutine<int>* co) {
int a = 0, b = 1, c;
while (true) {
co->yield_value(a);
c = a + b;
a = b;
b = c;
}
}
void example_function(Coroutine<int>* co) {
std::cout << "Example function started" << std::endl;
co->await_time(1);
std::cout << "Example function ended" << std::endl;
}
void func1(Coroutine<std::string>* co, int a, const std::string& b) {
std::string s;
for (int i = 0; i < a; i++) {
s += b + " ";
}
co->return_value(s);
}
void func2(Coroutine<int>* co, int a) {
int i = co->await_future_return(add, 242, a);
std::cout << "Co1 Received int: " << i << std::endl;
auto ss = co->await_future_return(str);
std::cout << "Co1 Received string: " << ss << std::endl;
}
int main() {
Coroutine<int> co(example_function);
Coroutine<int> co2(fibonacci);
for (int i = 0; i < 10; i++) {
int fib = co2.get_yield_value();
std::cout << fib << " ";
}
std::cout << std::endl;
Coroutine<std::string> co3(func1, 3, "hello");
std::string str = co3.get_return_value();
std::cout << "Co3 Received string: " << str << std::endl;
Coroutine<int> co4(func2, 25);
std::cout << "Co4 Received int: " << std::endl;
std::this_thread::sleep_for(std::chrono::milliseconds(10));
return 0;
}
// 输出
Example function started
1 1 2 3 5 8 13 21 34 55
Co3 Received string: hello hello hello
Co4 Received int:
Example function ended
Co1 Received int: 267
Co1 Received string: Hello, world!