OpenCV机器学习(1)人工神经网络 - 多层感知器类cv::ml::ANN_MLP
- 操作系统:ubuntu22.04
- OpenCV版本:OpenCV4.9
- IDE:Visual Studio Code
- 编程语言:C++11
算法描述
cv::ml::ANN_MLP 是 OpenCV 库中的一部分,用于实现人工神经网络 - 多层感知器(Artificial Neural Network - Multi-Layer Perceptron, ANN-MLP)。它提供了一种方式来创建和训练多层感知器模型,以解决分类、回归等问题。
主要特点
- 多层架构:支持一个输入层、多个隐藏层和一个输出层。
- 激活函数:可以选择不同的激活函数,如Sigmoid、Identity、ReLU等。
- 训练算法:包括误差反向传播算法,用户可以指定参数如迭代次数、终止条件等。
- 正则化参数:可以设置权重衰减项,帮助防止过拟合。
常用成员函数
- create(): 创建一个指定层数和每层神经元数目的网络。
- setLayerSizes(): 设置每一层的大小(神经元数量)。
- setActivationFunction(): 设置使用的激活函数。
- train(): 使用提供的数据集进行模型训练。
- predict(): 对新的输入数据进行预测。
- save()/load(): 保存和加载训练好的模型。
使用步骤
- 初始化网络:使用 create() 函数初始化网络,并通过 setLayerSizes() 定义网络结构。
- 配置训练参数:选择激活函数、设置训练方法及相应参数。
- 准备数据:准备好训练数据集和标签。
- 训练模型:调用 train() 方法对模型进行训练。
- 评估与预测:利用 predict() 方法对新数据进行预测,并根据需要评估模型性能。
代码示例
include <iostream>
#include <opencv2/ml.hpp>
#include <opencv2/opencv.hpp>
using namespace cv;
using namespace cv::ml;
using namespace std;
int main()
{
// 训练数据及对应标签
float trainingData[ 8 ][ 2 ] = { { 480, 500 }, { 50, 130 }, { 110, 32 }, { 490, 60 }, { 60, 190 }, { 200, 189 }, { 78, 256 }, { 45, 315 } };
float labels[ 8 ] = { 0, 1, 0, 0, 1, 0, 1, 1 };
Mat trainingDataMat( 8, 2, CV_32FC1, trainingData );
Mat labelsMat( 8, 1, CV_32FC1, labels );
// 创建ANN_MLP模型
Ptr< ANN_MLP > model = ANN_MLP::create();
// 设置网络结构:输入层大小为2,隐藏层大小为2,输出层大小为1
Mat layerSizes = ( Mat_< int >( 1, 3 ) << 2, 2, 1 );
model->setLayerSizes( layerSizes );
// 设置激活函数
model->setActivationFunction( ANN_MLP::SIGMOID_SYM );
// 设置训练方法
model->setTrainMethod( ANN_MLP::BACKPROP );
model->setBackpropWeightScale( 0.1 );
model->setBackpropMomentumScale( 0.1 );
// 设置迭代终止准则
TermCriteria termCrit = TermCriteria( TermCriteria::MAX_ITER + TermCriteria::EPS, 1000, 0.01 );
model->setTermCriteria( termCrit );
// 准备训练数据
Ptr< TrainData > tData = TrainData::create( trainingDataMat, ROW_SAMPLE, labelsMat );
// 训练模型
model->train( tData );
// 预测新数据点
Mat sampleMat = ( Mat_< float >( 1, 2 ) << 500, 500 );
Mat responseMat;
float predictedClass = model->predict( sampleMat, responseMat );
cout << "Predicted class: " << predictedClass << endl;
return 0;
}
运行结果
Predicted class: 0