01_svm_二分类
环境
- MinGW:7.3.0
- Dlib:20.0
dlib 库SVM 二分类
模型训练
dlib库二分类问题中1表示正类,-1表示反类(这一点与其他库不太一样)
训练模型前需要将样本数据转换为std::vector<dlib::matrix<double,T,1>>形式(matrix 是一个多维数组的模板类,可以用来表示不同类型的矩阵)
typedef dlib::matrix<double,30,1> cancer_type; // 定义一个矩阵类型cancer_type,用于存储样本数据
// 加载数据
vector<BreastCancer> vec = LoadCancer(pfile); // 通过文件加载数据(sklearn中的breast_cancer.csv)
vector<cancer_type> train_datas;
vector<double> train_labels;
for (int i = 0; i < vec.size(); i++)
{BreastCancer cancer = vec.at(i);cancer_type cc_type;for (int j = 0; j < 30; j++){cc_type(j) = cancer.ft[j];}train_datas.push_back(cc_type);train_labels.push_back(cancer.tag > 0 ? 1 : -1); // 注意正、反类的取值
}
svm算法对数据缩放敏感(可以参考蜥蜴书《Python机器学习基础教程》),这里也对数据进行缩放处理
dlib::vector_normalizer<cancer_type> normalizer; // 数据预处理,使不同特征的量纲统一,帮助模型训练时更高效收敛。// sklearn中也称这种预处理算法为无监督算法
normalizer.train(train_datas);
for (int i = 0; i < train_datas.size(); i++)
{train_datas[i] = normalizer(train_datas[i]);
}
创建模型并训练
typedef dlib::radial_basis_kernel<cancer_type> kernel_type; // rbf 径向基函数核,主要用于支持向量机(SVM)、高斯过程回归等机器学习算法,能够实现输入空间的非线性映射。
typedef dlib::decision_function<kernel_type> dec_func_type; // dlib::decision_function是dlib库中用于支持向量机(SVM)分类器决策的核心接口,其主要功能是通过训练后的模型对新样本进行分类或回归预测
typedef dlib::normalized_function<dec_func_type> func_type; // 用于对决策函数进行归一化处理的工具类,主要用于支持向量机(SVM)等分类器的输出标准化。typedef dlib::probabilistic_decision_function<kernel_type> pro_func_type; // 与decision_function类似,但输出的是概率值
typedef dlib::normalized_function<pro_func_type> p_func_type;// 打乱数据
dlib::randomize_samples(train_datas, train_labels);
// 创建训练器
dlib::svm_c_trainer<kernel_type> trainer;
trainer.set_c(10); // 设置C参数
trainer.set_kernel(kernel_type(0.01)); // 设置gamma参数
// 训练模型
p_func_type learned_func;
learned_func.normalizer = normalizer;
// learned_func.function = trainer.train(train_datas, train_labels); // 训练模型
learned_func.function = dlib::train_probabilistic_decision_function(trainer, train_datas, train_labels, 3);// 训练模型,数值3表示对样本数据的折叠数,可以参考sklearn中的模型评估
模型验证
int ok_count = 0;
for (int i = 0; i < vec.size(); i++)
{BreastCancer cancer = vec.at(i);cancer_type cc_type;for (int j = 0; j < 30; j++){cc_type(j) = cancer.ft[j];}double ret_d = learned_func(cc_type);cout << "probabilistic : " << ret_d << endl; // 概率值,大于0.5的是正类,否则视为反类int ret = ret_d > 0.5 ? 1 : 0;if (ret == cancer.tag)ok_count += 1;
}
cout << "accurary:" << (ok_count * 1.0) / vec.size() << endl;
模型保存与加载
通过dlib::serialize可以保存模型
dlib::serialize("svm_cancer_model.dat") << learned_func; // 保存模型到svm_cancer_model.dat
通过dlib::deserialize加载模型
p_func_type learned_func; //
dlib::deserialize(file_name) >> learned_func;
其他
加载breast_cancer.csv数据
typedef struct {float ft[30];int tag;
} BreastCancer;vector<BreastCancer> LoadCancer(const string &fpath)
{vector<BreastCancer> vec;fstream input_file(fpath);string line;if (input_file.is_open()){getline(input_file, line); // 跳过头while (getline(input_file, line)){vector<string> sp_vec = SplitString(line, ',');if (sp_vec.size() == 0)continue;BreastCancer cancer;for (int i = 0; i < 30; i++){cancer.ft[i] = atof(sp_vec.at(i).c_str());}cancer.tag = atoi(sp_vec.at(30).c_str());vec.push_back(cancer);}}return vec;
}vector<string> SplitString(const string &str, char delim)
{vector<string> vec;istringstream iss(str);string token;while (getline(iss, token, delim)){vec.push_back(token);}return vec;
}