Python 调用底层 C++ 算子示例
test.cpp
#include <torch/extension.h>

// A simple addition function: returns the result of `a + b` for two tensors.
at::Tensor add(at::Tensor a, at::Tensor b) {
    return a + b;
}

// Expose `add` to Python under the name "add".
// TORCH_EXTENSION_NAME is not defined in this file; it is supplied at build
// time (the accompanying Python script passes it via the `name` argument of
// `load`), and it becomes the name of the generated extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("add", &add, "Add two tensors");
}
test.py
import torch
from torch.utils.cpp_extension import load

# Build test.cpp and import the resulting extension as a Python module.
# `name` is what the C++ side sees as TORCH_EXTENSION_NAME, so the module
# (and the shared library produced for it) is called "test_load".
# verbose=True makes the build steps visible on the console.
test_load = load(
    name='test_load',
    sources=['test.cpp'],
    extra_cflags=['-O2'],
    verbose=True,
)

# Quick demonstration: call the C++ `add` bound in test.cpp.
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = test_load.add(a, b)
print(result)  # Should print tensor([5, 7, 9])
注意:load 的 name 参数实际上会作为 TORCH_EXTENSION_NAME 传给 test.cpp,编译后会生成一个名为 test_load.so 的动态库。
test2.py:调用 test.py 中的 test_load 模块(注意:导入 test 时会执行 test.py 的全部顶层代码,包括编译/加载扩展以及其中的 print)
# Reuse the extension module that test.py already built and loaded.
# Importing `test` runs test.py's top-level code, which performs the load.
from test import test_load
import torch

# Two small integer tensors to feed to the C++ operator.
a, b = torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])

# Call the C++ `add` through the imported module and show the result.
result = test_load.add(a, b)
print(result)  # Should print tensor([5, 7, 9])