当前位置: 首页 > news >正文

03_全连接神经网络

描述

dlib的神经网络算法虽然没法与tensorflow、pytorch这些框架相提并论的(dlib中可供选择的的损失函数、优化器十分有限,缺失自动微分等等),但不代表dlib不是一个强大的机器学习库。

dlib广泛应用于人脸识别、关键点检测、图像处理等领域,支持Linux/Windows/macOS多平台运行(跨平台编译很方便)。

这记录一下全连接网络下二分类、多分类、回归模型的设计与训练

二分类

处理二分类问题时,优先使用loss_binary_log(逻辑损失)、loss_binary_hinge(最大化分类间隔,SVM中最常用)损失函数,专门处理二分类问题的。这里需要注意:1表示正类,-1表示反类。

这里以breast_cancer.csv数据分类为例,设计一个30x60x60x1的神经网络(这里要注意:用loss_binary_xxx处理二分类问题,网络输出层只能有一个参数),当然数据归一化处理也是少不了的(为了提高模型的精度)。

void dnnBinaryClassifierTest(const string &pfile)
{vector<BreastCancer> vec = LoadCancer(pfile);vector<dlib::matrix<double>> train_datas;vector<float> train_labels; // 注意label的类型,与loss的选择有关for (int i = 0; i < vec.size(); i++){BreastCancer canser = vec.at(i);train_datas.push_back({canser.ft[0], canser.ft[1], canser.ft[2], canser.ft[3], canser.ft[4], canser.ft[5], canser.ft[6],canser.ft[7], canser.ft[8], canser.ft[9], canser.ft[10], canser.ft[11], canser.ft[12], canser.ft[13],canser.ft[14], canser.ft[15], canser.ft[16], canser.ft[17], canser.ft[18], canser.ft[19], canser.ft[20],canser.ft[21], canser.ft[22], canser.ft[23], canser.ft[24], canser.ft[25], canser.ft[26], canser.ft[27],canser.ft[28], canser.ft[29]});train_labels.push_back(canser.tag > 0 ? 1 : -1);}dlib::vector_normalizer<dlib::matrix<double>> normalizer; // 数据归一化normalizer.train(train_datas);for (int i = 0; i < train_datas.size(); i++){train_datas[i] = normalizer(train_datas.at(i));}dlib::randomize_samples(train_datas, train_labels); // 打乱数据vector<dlib::matrix<double>> test_datas;vector<float> test_labels;test_datas.assign(train_datas.begin(), train_datas.begin() + 150);test_labels.assign(train_labels.begin(), train_labels.begin() + 150);train_datas.assign(train_datas.begin() + 150, train_datas.end());train_labels.assign(train_labels.begin() + 150, train_labels.end());// 定义网络结构:输出层<-中间层隐层<-...<-输入层using net_type = dlib::loss_binary_log<dlib::fc<1, dlib::relu<dlib::fc<60, dlib::relu<dlib::fc<60, dlib::relu<dlib::fc<30, dlib::input<dlib::matrix<double>>>>>>>>>>;net_type net;dlib::dnn_trainer<net_type> trainer(net);trainer.set_max_num_epochs(50);   // 设置训练最大轮数trainer.set_learning_rate(0.0001);// 设置学习率trainer.set_mini_batch_size(4);   // 设置数据最批次trainer.be_verbose();			 // 输出训练过程日志trainer.train(train_datas, train_labels); // 训练模型net.clean(); // 很有必要执行。官方解释:模型在训练结束后会保留最后一批训练数据的相关状态,当然这些数据对模型预测结果不会产生任何影响,但会影响模型文件的大小int ok_count = 0;for (int i = 0; i < test_datas.size(); i++){double ret_f = net(test_datas.at(i));cout << "predicted : " << ret_f << ",real val:" << test_labels.at(i) << endl;int pred_val = ret_f > 0 ? 1 : -1; // 大于0的是正类,小于0的是反类(会不会等于0,理论上有可能,但概率极低)if (test_labels.at(i) == pred_val)ok_count += 1;}cout << "accurary:" << (ok_count * 1.0) / test_datas.size() << endl;// 保存模型dlib::serialize("dnn_cancer_model.dat")<<net;dlib::serialize("dnn_cancer_normalizer.dat")<<normalizer;
}

多分类

多分类问题可以选择loss_multiclass_log损失函数。网络输出层参数与类别数保持一致,类别从0开始,网络输出结果为类别编号。

以鸢尾花分类为例,设计一个4x30x30x3的神经网络。

void dnnMultiClassifierTest(const string &pfile)
{vector<Iris> vec = LoadIris(pfile);vector<dlib::matrix<double>> train_datas;vector<unsigned long> train_labels; // 注意类型int type_no = 0;map<string, int> temp_map;for (int i = 0; i < vec.size(); i++){Iris iris = vec.at(i);train_datas.push_back({iris.ft[0], iris.ft[1], iris.ft[2], iris.ft[3]});if (temp_map.find(iris.specie) == temp_map.end()){temp_map[iris.specie] = type_no;type_no += 1;}train_labels.push_back(temp_map[iris.specie]);}// iris的样本数据有三种分类,故输出层是3个参数;特征有4个,故输入层参数是4using net_type = dlib::loss_multiclass_log<dlib::fc<3, dlib::relu<dlib::fc<40, dlib::relu<dlib::fc<40, dlib::relu<dlib::fc<4, dlib::input<dlib::matrix<double>>>>>>>>>>;net_type net;dlib::dnn_trainer<net_type> trainer(net);trainer.set_max_num_epochs(20);trainer.set_learning_rate(0.001);trainer.set_mini_batch_size(4);trainer.be_verbose();trainer.train(train_datas,train_labels);net.clean();int ok_count=0;for (int i = 0; i < vec.size(); i++){Iris iris = vec.at(i);uint32_t ret_d = net({iris.ft[0], iris.ft[1], iris.ft[2], iris.ft[3]});cout<<"predict:"<<ret_d<<",real:"<<temp_map[iris.specie]<<endl;if(ret_d == temp_map[iris.specie]) ok_count+=1;}cout << "accurary:" << (ok_count * 1.0) / vec.size() << endl;dlib::serialize("dnn_iris_model.dat")<<net;
}

回归

对于回归问题,dlib中目前好像只有一个选择:loss_mean_squared。

以boston_house_prices.csv为例,设计一个13x50x50x1的回归模型

void dnnRegressionTest(const string &pfile)
{vector<BostonHoursPrice> vec = LoadHoursPrice(pfile);vector<float> train_labels;vector<dlib::matrix<double>> train_datas;for (int i = 0; i < vec.size(); i++){BostonHoursPrice price = vec.at(i);train_datas.push_back({price.ft[0], price.ft[1], price.ft[2], price.ft[3], price.ft[4], price.ft[5],price.ft[6], price.ft[7], price.ft[8], price.ft[9], price.ft[10], price.ft[11],price.ft[12]});train_labels.push_back(price.tag);}using net_type = dlib::loss_mean_squared<dlib::fc<1, dlib::relu<dlib::fc<50, dlib::relu<dlib::fc<50, dlib::relu<dlib::fc<13, dlib::input<dlib::matrix<double>>>>>>>>>>;net_type net;dlib::dnn_trainer<net_type> trainer(net);trainer.set_learning_rate(0.000001);trainer.set_min_learning_rate(0.00000001);trainer.set_mini_batch_size(8);trainer.set_max_num_epochs(200);trainer.be_verbose();trainer.train(train_datas, train_labels);net.clean();for (int i = 0; i < train_datas.size(); i++){float ret_f = net(train_datas.at(i));cout << "predicted : " << ret_f << ",real val:" << train_labels.at(i) << endl;}
}

补充一个SVM回归的例子,svm的回归模型更简单些:

typedef dlib::matrix<double,13,1> price_type;
typedef dlib::radial_basis_kernel< price_type> rbf_kernel;void RvmRegression(const string &pfile)
{vector<price_type> train_datas;vector<double> train_labels;LoadPriceData(pfile, train_datas, train_labels);const double gamma = 0.00001;dlib::rvm_regression_trainer<rbf_kernel> trainer; // rvm_regression_trainer trainer.set_kernel(rbf_kernel(gamma));dlib::decision_function<rbf_kernel> learned_func = trainer.train(train_datas, train_labels);for (int i = 0; i < train_datas.size(); i++){price_type pt = train_datas.at(i);double ret = learned_func(pt);cout << "predicted val:" << ret << ", real val:" << train_labels.at(i) << endl;}dlib::serialize("svm_reg_price.dat") << learned_func;
}
http://www.dtcms.com/a/540181.html

相关文章:

  • 生成式AI重塑教学生态:理论基础、核心特征与伦理边界
  • html5手机网站调用微信分享wordpress缩略图加载慢
  • 动环监控:数据中心机房的“智慧守护者”
  • 5.6对象
  • 生命线与黑箱:LIME和Anchor作为两种事后可解释性分析
  • VMware安装配置CentOS 7
  • 链表算法题
  • 织梦制作wap网站高端网站开发建设
  • 网站建设公司销售经理职责全网最大的精品网站
  • 怎么做公司网站推广cms网站开发教程
  • 解决 OpenSSL 3.6.0 在 macOS 上 Conan 构建失败的链接错误
  • metaRTC7 mac/ios编程指南
  • Go语言-->Goroutine 详细解释
  • 船舶终端数据采集与监管平台一体化方案
  • 2025年10月28日Github流行趋势
  • 《红色脉络:一部PLMN在中国的演进史诗 (1G-6G)》 第14篇 | 6G畅想:通感一体、AI内生——下一代网络的愿景与挑战
  • 「Java EE开发指南」如何用MyEclipse设置Java项目依赖项属性?
  • 输电线路防外破在线监测装置是什么
  • MTK5G旗舰系列——天玑9500/9400/9300/9200/9000在AI和处理器性能、DDR频率及UFS的深度对比分析
  • 平板做网站服务器wordpress在线直播插件
  • 前端Jquery,后端Java实现预览Word、Excel、PPT,pdf等文档
  • 华为910B服务器(搭载昇腾Ascend 910B AI 芯片的AI服务器查看服务器终端信息
  • Spring JDBC实战:参数处理与嵌入式数据库
  • 图片转PPT:用Java高效处理PowerPoint的秘籍
  • Custom Animations for PPT (PowerPoint)
  • 沈阳网站哪家做的好做视频网站设备需求
  • 【数据工程】16. Notions of Time in Stream Processing
  • AOI在传统汽车制造领域中的应用
  • 搭建网站复杂吗微信公众号怎么做链接网站
  • 网站优化推广招聘wordpress后台打开超慢