DeepSeek辅助实现的DuckDB copy to自定义函数
copy to自定义函数指将DuckDB 数据库中的数据导出成各种自己需要的格式,除了官方已经提供的csv、parquet、xlsx,只要知道某个格式用什么库来写,就将读出的数据填充到那个库的写函数中即可。
本来以为这是最简单的一个,因为有DuckDB自己实现copy to csv的源码在前,又有读写网上的Google Gsheet插件在后,随便抄哪个都能抄对,但DeepSeek折腾了一整天也没有输出一个正确的程序,最后还是人工比对他的实现和人家正确实现的不同点,才解决。
关键在于:
1. `InitializeGlobal` 只能用来写标题行,不能写数据本身;写数据还必须提供 `InitializeLocal`(数据本身在 `Sink` 中写出),否则连 `Sink` 函数都够不着执行,就报 `Segmentation fault (core dumped)` 错误退出,至少今天的试验是这样的。
2. 读数据要从 `DataChunk` 类型的参数中读,而不是从 `LocalFunctionData` 参数读;否则,虽然 `local.size()` 和 `local.ColumnCount()` 的值都正确,`GetValue` 读出来的却全是空白。
这就是一整天的教训,也怪我让他做一个精简的实现,而不是原封不动地照抄。
源代码如下,它实现了一个 `mycsv` 文件后缀名(格式名),并在输出标题行的列名前插入 `myduck` 前缀,以便与系统内置的 csv 区分。
#include "duckdb.hpp"
#include "duckdb/common/file_system.hpp"
#include "duckdb/common/serializer/buffered_file_writer.hpp"
#include "duckdb/catalog/catalog_entry/copy_function_catalog_entry.hpp"
#include "duckdb/function/copy_function.hpp"
#include "duckdb/parser/parsed_data/create_copy_function_info.hpp"
#include "duckdb/main/extension_util.hpp"
#include <iostream>

// Debug tracing helper.
// Disabled by default: the macro expands to an empty statement and `msg` is discarded.
// Compile with -DMYCSV_DEBUG to stream `msg` to std::cerr with a "[DEBUG] " prefix.
// The do/while(0) wrapper makes the macro behave like a single statement, so it is
// safe inside unbraced if/else bodies.
#ifdef MYCSV_DEBUG
#define DEBUG_LOG(msg) do { std::cerr << "[DEBUG] " << msg << std::endl; } while (0)
#else
#define DEBUG_LOG(msg) do { } while (0)
#endif
namespace duckdb {// 1. 全局状态// 修改全局状态管理,确保文件正确关闭
// 1. 修改全局状态类,添加写入状态跟踪
/// Global state of one mycsv COPY operation: owns the buffered writer for the output file.
struct MyCSVCopyGlobalState : public GlobalFunctionData {
    /// @param writer     Writer for the target file; ownership is transferred to this state.
    /// @param file_path  Path of the file being written (used in debug messages).
    explicit MyCSVCopyGlobalState(unique_ptr<BufferedFileWriter> writer, string file_path)
        : writer(std::move(writer)), file_path(std::move(file_path)), initialized(true) {
        DEBUG_LOG("Writer initialized for: " << this->file_path
                                              << ", writer valid: " << (this->writer != nullptr));
    }

    /// Flushes any buffered output before the writer is destroyed.
    ~MyCSVCopyGlobalState() {
        if (!writer) {
            return;
        }
        DEBUG_LOG("Final flush for: " << file_path);
        try {
            writer->Flush();
        } catch (...) { // NOLINT
            // Destructors are implicitly noexcept: letting an I/O error escape here
            // would call std::terminate. The flush is best-effort at this point.
            DEBUG_LOG("Final flush failed for: " << file_path);
        }
    }

    /// Writer that header and data rows are appended to.
    unique_ptr<BufferedFileWriter> writer;
    /// Path of the output file.
    string file_path;
    /// Set to true by the constructor.
    bool initialized = false;
};
// 2. 配置选项
struct MyCSVWriteOptions {vector<string> name_list;string delimiter = "|";string prefix = "myduck";bool header = true;
};struct MyCSVLocalState : public LocalFunctionData {explicit MyCSVLocalState(ClientContext &context, const vector<LogicalType> &sql_types): executor(context) {// 初始化转换用的DataChunkcast_chunk.Initialize(Allocator::Get(context), GetVarcharTypes(sql_types));}// 类型转换执行器ExpressionExecutor executor;// 用于存储转换后的字符串数据DataChunk cast_chunk;private:static vector<LogicalType> GetVarcharTypes(const vector<LogicalType> &sql_types) {vector<LogicalType> varchar_types;for (auto &type : sql_types) {varchar_types.push_back(LogicalType::VARCHAR);}return varchar_types;}
};// 3. 绑定数据(实现Equals方法)
struct MyCSVWriteBindData : public TableFunctionData {vector<string> files;MyCSVWriteOptions options;vector<LogicalType> sql_types;MyCSVWriteBindData(string file_path, vector<LogicalType> sql_types, vector<string> names,string delimiter = "|",string prefix = "myduck",bool header = true): sql_types(std::move(sql_types)) {files.push_back(std::move(file_path));options.name_list = std::move(names);options.delimiter = std::move(delimiter);options.prefix = std::move(prefix);options.header = std::move(header);}unique_ptr<FunctionData> Copy() const override {return make_uniq<MyCSVWriteBindData>(files[0], sql_types, options.name_list, options.delimiter, options.prefix, options.header);}bool Equals(const FunctionData &other) const override {auto &other_bind = other.Cast<MyCSVWriteBindData>();return files == other_bind.files && options.delimiter == other_bind.options.delimiter &&options.prefix == other_bind.options.prefix &&options.header == other_bind.options.header;}
};// 4. 主功能类
class MyCSVCopyFunction : public CopyFunction {
public:MyCSVCopyFunction() : CopyFunction("mycsv") {copy_to_bind = Bind;copy_to_initialize_global = InitializeGlobal;copy_to_initialize_local = InitializeLocal;copy_to_sink = Sink;DEBUG_LOG("Pointers registered: " << (void*)copy_to_bind << ", "<< (void*)copy_to_initialize_global << ", "<< (void*)copy_to_initialize_local << ", "<< (void*)copy_to_sink);}static unique_ptr<FunctionData> Bind(ClientContext &context, CopyFunctionBindInput &input,const vector<string> &names,const vector<LogicalType> &sql_types);static unique_ptr<GlobalFunctionData> InitializeGlobal(ClientContext &context, FunctionData &bind_data,const string &file_path);static unique_ptr<LocalFunctionData> InitializeLocal(duckdb::ExecutionContext&, duckdb::FunctionData&);static void Sink(ExecutionContext &context,FunctionData &bind_data,GlobalFunctionData &gstate,LocalFunctionData &lstate,DataChunk &input);
};// 辅助写入函数
/// Appends the bytes of `str` to `writer` exactly as they are (no quoting or escaping).
static void WriteCSVString(BufferedFileWriter &writer, const string &str) {
    const_data_ptr_t bytes = reinterpret_cast<const_data_ptr_t>(str.c_str());
    const auto byte_count = str.size();
    writer.WriteData(bytes, byte_count);
}
// 5. Bind function implementation
/// Creates the bind data for a COPY statement and applies the recognised options.
/// Recognised option keys are "delimiter", "prefix" and "header"; an option given without
/// a value, or with any other key, leaves the defaults untouched.
unique_ptr<FunctionData> MyCSVCopyFunction::Bind(ClientContext &context, CopyFunctionBindInput &input,
                                                 const vector<string> &names,
                                                 const vector<LogicalType> &sql_types) {
    auto result = make_uniq<MyCSVWriteBindData>(input.info.file_path, sql_types, names);
    auto &settings = result->options;

    for (auto &entry : input.info.options) {
        auto &key = entry.first;
        auto &values = entry.second;
        if (values.empty()) {
            // Nothing to apply for a value-less option.
            continue;
        }
        if (key == "delimiter") {
            settings.delimiter = values[0].ToString();
        } else if (key == "prefix") {
            settings.prefix = values[0].ToString();
        } else if (key == "header") {
            settings.header = values[0].CastAs(context, LogicalType::BOOLEAN).GetValue<bool>();
        }
    }
    return std::move(result);
}

/// Creates the thread-local state, sized from the column types stored in the bind data.
unique_ptr<LocalFunctionData> MyCSVCopyFunction::InitializeLocal(ExecutionContext &context,
                                                                 FunctionData &bind_data) {
    DEBUG_LOG("Initializing thread-local state for worker ");
    auto &csv_bind = bind_data.Cast<MyCSVWriteBindData>();

    // Expression-based conversions (e.g. date formatting) would be set up on the
    // state's executor here; none are configured at the moment.
    auto thread_state = make_uniq<MyCSVLocalState>(context.client, csv_bind.sql_types);

    DEBUG_LOG("Thread-local state initialized with " << csv_bind.sql_types.size() << " columns");
    return std::move(thread_state);
}
// 修改InitializeGlobal函数
unique_ptr<GlobalFunctionData> MyCSVCopyFunction::InitializeGlobal(ClientContext &context, FunctionData &bind_data,const string &file_path) {DEBUG_LOG("Initializing global state for file: " << file_path);auto &data = bind_data.Cast<MyCSVWriteBindData>();auto &fs = FileSystem::GetFileSystem(context);try {// 检查文件是否可写DEBUG_LOG("Checking file access: " << file_path);auto handle = fs.OpenFile(file_path, FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW);handle->Close();// 创建文件写入器DEBUG_LOG("Creating BufferedFileWriter");auto writer = make_uniq<BufferedFileWriter>(fs, file_path);// 写入表头if (data.options.header) {DEBUG_LOG("Writing header");for (size_t i = 0; i < data.options.name_list.size(); ++i) {if (i != 0) {WriteCSVString(*writer, data.options.delimiter);}WriteCSVString(*writer, data.options.prefix);WriteCSVString(*writer, data.options.name_list[i]);}WriteCSVString(*writer, "\n");writer->Flush();DEBUG_LOG("Header written successfully");}return make_uniq<MyCSVCopyGlobalState>(std::move(writer), file_path);} catch (const std::exception &e) {DEBUG_LOG("InitializeGlobal failed: " << e.what());throw;}
}void MyCSVCopyFunction::Sink(ExecutionContext &context,FunctionData &bind_data,GlobalFunctionData &gstate,LocalFunctionData &lstate,DataChunk &input) {auto &state = gstate.Cast<MyCSVCopyGlobalState>();auto &local = lstate.Cast<MyCSVLocalState>();// 1. 类型转换local.cast_chunk.Reset();local.cast_chunk.SetCardinality(input);local.executor.Execute(input, local.cast_chunk);// 2. 写入数据for (idx_t row = 0; row < local.cast_chunk.size(); row++) {for (idx_t col = 0; col < local.cast_chunk.ColumnCount(); col++) {auto val = input.GetValue(col, row);//std::cout<<val.ToString()<<std::endl;WriteCSVString(*state.writer, val.IsNull() ? "NULL" : val.ToString());}WriteCSVString(*state.writer, "\n");}
}// 8. 注册函数(修正注册方式)
void RegisterMyCSVFunction(DatabaseInstance &db) {// Register COPY TO (FORMAT 'mycsv') functionMyCSVCopyFunction MyCSV_copy_function;ExtensionUtil::RegisterFunction(db, MyCSV_copy_function);DEBUG_LOG("RegisterMyCSVFunction started");return;}} // namespace duckdb
using namespace duckdb;
int main() {try {DuckDB db(nullptr);Connection con(db);// 注册自定义格式RegisterMyCSVFunction(*db.instance);// 创建测试数据DEBUG_LOG("Creating test table");auto create_result = con.Query("CREATE TABLE test AS SELECT i, 'val_'||i AS text FROM range(5) t(i)");if (create_result->HasError()) {DEBUG_LOG("Create table error: " << create_result->GetError());return 1;}auto result1 = con.Query("from test");result1->Print();auto explain = con.Query("EXPLAIN COPY test TO 'output.mycsv' WITH (FORMAT mycsv,DELIMITER '|',PREFIX 'myprefix')");if (explain) {std::cout <<explain->GetValue(1,0);}DEBUG_LOG("Executing COPY command");auto result = con.Query(R"(COPY test TO 'output.mycsv' WITH (FORMAT mycsv))");if (result->HasError()) {std::cerr << "Error: " << result->GetError() << std::endl;return 1;}DEBUG_LOG("Execution completed successfully");return 0;} catch (const std::exception &e) {DEBUG_LOG("Fatal error in main: " << e.what());return 1;} return 0;
}