
Transfer Learning with a Pretrained CNN (MATLAB Example)

Extract layers from a neural network trained on a large dataset, and fine-tune them on a new dataset. This example fine-tunes on a subset of ImageNet.

This example retrains a SqueezeNet neural network using transfer learning. This network has been trained on over a million images, and can classify images into 1000 object categories (such as keyboard, coffee mug, pencil, and many animals). The network has learned rich feature representations for a wide range of images. The network takes an image as input and outputs a prediction score for each of these classes.

Performing transfer learning and fine-tuning of a pretrained neural network typically requires less data, is much faster, and is easier than training a neural network from scratch.

To adapt a pretrained neural network for a new task, replace the last few layers (the network head) so that it outputs prediction scores for each of the classes of the new task. The following diagram outlines the architecture of a neural network that makes predictions for the original classes, and illustrates how to edit the network so that it outputs predictions for the new classes.

[Figure: replacing the network head so the network predicts the new classes]

ImageNet uses the WordNet hierarchical taxonomy; each category has a unique ID.

  • Tiger
    • WordNet ID: n02129604
    • Subcategories: include the Bengal tiger, the Siberian tiger, and others.
  • Rabbit
    • WordNet ID: n02325366
    • Subcategories: such as the European rabbit and the hare.
  • Chicken
    • WordNet ID: n01514668
    • Subcategories: such as hen, rooster, and chick.

  • Tiger: 1,300 images (different tiger subspecies).
  • Rabbit: 1,300 images (including domestic rabbits and hares).
  • Chicken: 1,300 images (different breeds and ages).
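The ID-to-name correspondence above can be written down directly in MATLAB. A minimal sketch, assuming MATLAB R2022b or later (for `dictionary`); the mapping mirrors the three classes listed above:

```matlab
% Map the ImageNet WordNet IDs used in this example to readable names.
% (containers.Map works similarly on older releases.)
wnid2name = dictionary( ...
    ["n01514668","n02129604","n02325366"], ...
    ["chicken","tiger","rabbit"]);

wnid2name("n02129604")   % "tiger"
```

The same mapping is what the `renamecats` call in the next section applies to the datastore labels.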



Load Training Data
Create an image datastore. An image datastore enables you to store large collections of image data, including data that does not fit in memory, and efficiently read batches of images when training a neural network. Specify the folder with the extracted images, and indicate that the subfolder names correspond to the image labels.

imds = imageDatastore(digitDatasetPath, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");
imds.Labels = renamecats(imds.Labels, ...
    {'n01514668','n02129604','n02325366'}, ...
    {'chicken','tiger','rabbit'});
numObsPerClass = countEachLabel(imds)
numObsPerClass =

     Label     Count
    _______    _____
    chicken    1300 
    tiger      1300 
    rabbit     1300 

Load Pretrained Network


Load a pretrained SqueezeNet neural network into the workspace by using the imagePretrainedNetwork function. To return a neural network ready for retraining for the new data, specify the number of classes. When you specify the number of classes, the imagePretrainedNetwork function adapts the neural network so that it outputs prediction scores for each of the specified number of classes.

You can try other pretrained networks. Deep Learning Toolbox™ provides various pretrained networks that have different sizes, speeds, and accuracies. These additional networks usually require a support package. If the support package for a selected network is not installed, then the function provides a download link. For more information, see Pretrained Deep Neural Networks.

numClasses = numel(categories(imds.Labels));
net = imagePretrainedNetwork("squeezenet",NumClasses=numClasses);
inputSize = net.Layers(1).InputSize

The learnable layer in the network head (the last layer with learnable parameters) requires retraining. The layer is usually a fully connected layer, or a convolutional layer, with an output size that matches the number of classes.

The networkHead function, attached to this example as a supporting file, returns the layer and learnable parameter names of the learnable layer in the network head.

[layerName,learnableNames] = networkHead(net)

For transfer learning, you can freeze the weights of earlier layers in the network by setting the learning rates in those layers to 0. During training, the trainnet function does not update the parameters of these frozen layers. Because the function does not compute the gradients of the frozen layers, freezing the weights can significantly speed up network training. For small datasets, freezing the network layers prevents those layers from overfitting to the new dataset.
Freeze the weights of the network, keeping the last learnable layer unfrozen.

net = freezeNetwork(net,LayerNamesToIgnore=layerName);
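The `freezeNetwork` supporting function is not listed in this post. As a rough sketch of the same idea, assuming `net` is a `dlnetwork` and `layerName` is the head layer name returned by `networkHead`, the freezing can be done manually with `setLearnRateFactor`:

```matlab
% Set the learn rate factor of every learnable parameter to 0,
% except the parameters of the head layer, which stay trainable.
for i = 1:height(net.Learnables)
    layer = net.Learnables.Layer(i);
    param = net.Learnables.Parameter(i);
    if layer ~= layerName
        net = setLearnRateFactor(net,layer,param,0);
    end
end
```

Setting the factor to 0 (rather than deleting layers) keeps the architecture intact while excluding those parameters from gradient updates.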

Prepare Data for Training
The images in the datastore can have different sizes. To automatically resize the training images, use an augmented image datastore.

augImds = augmentedImageDatastore(inputSize(1:2),imds,ColorPreprocessing='gray2rgb');

Specify Training Options
Specify the training options. Choosing among the options requires empirical analysis. To explore different training option configurations by running experiments, you can use the Experiment Manager app.
For this example, use these options:
Train using the Adam optimizer.
Display the training progress in a plot, and monitor the accuracy metric.
Disable the verbose output.
Because this example does not hold out a validation set, no validation options are specified.

opts = trainingOptions("adam", ...
    InitialLearnRate=1e-4, ...
    MaxEpochs=50, ...
    MiniBatchSize=128, ...
    Verbose=false, ...
    Plots="training-progress", ...
    Metrics="accuracy");

Train Neural Network
Train the neural network using the trainnet function. For classification, use cross-entropy loss. By default, the trainnet function uses a GPU if one is available. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements. Otherwise, the trainnet function uses the CPU. To specify the execution environment, use the ExecutionEnvironment training option.

rng default
net = trainnet(augImds,net,"crossentropy",opts);

The dataset is not split here, because the purpose of this example is to observe the CNN's feature transformations.
[Figure: training progress plot]

>> summary(net)
   Initialized: true
   Number of learnables: 724k
   Inputs:
      1   'data'   227×227×3 image

Observe the performance on the training set.

You can apply the trained neural network directly to classification problems. To classify new images, use minibatchpredict. To convert the predicted classification scores to labels, use the scores2label function. For an example showing how to classify with a pretrained neural network, see Classify Image Using GoogLeNet.
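A minimal sketch of that workflow for this example, assuming the trained network `net` and the datastores defined above:

```matlab
% Predict softmax scores for every image, then convert them to labels.
classNames = categories(imds.Labels);
scores = minibatchpredict(net,augImds);        % N-by-3 softmax scores
predLabels = scores2label(scores,classNames);  % most likely class per image
trainAccuracy = mean(predLabels == imds.Labels)
```

Comparing the predicted labels with the datastore labels gives the training-set accuracy that the figures below summarize.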



Ambiguity of Classifications
You can use the softmax activations to find the image classifications that are most likely to be incorrect. Define the ambiguity of a classification as the ratio of the second-largest probability to the largest probability. The ambiguity is between 0 (a nearly certain classification) and 1 (the second most likely class is nearly as probable as the first). An ambiguity near 1 means the network is unsure of the class to which a particular image belongs. This uncertainty might arise because two classes appear so similar to the network that it cannot learn the differences between them, or because a particular observation contains elements of more than one class, so the network cannot decide which classification is correct. Note that low ambiguity does not necessarily imply a correct classification; even if the network assigns a high probability to a class, the classification can still be incorrect.

softmaxActivations = minibatchpredict(net,augImds);  % softmax scores per image
[R,RI] = maxk(softmaxActivations,2,2);               % two largest scores per row
ambiguity = R(:,2)./R(:,1);                          % second-largest / largest
Find the most ambiguous images.
[ambiguity,ambiguityIdx] = sort(ambiguity,"descend");
View the most probable classes of the ambiguous images and the true classes.
classList = unique(imds.Labels);
top10Idx = ambiguityIdx(1:10);
top10Ambiguity = ambiguity(1:10);
mostLikely = classList(RI(ambiguityIdx,1));
secondLikely = classList(RI(ambiguityIdx,2));
table(top10Idx,top10Ambiguity,mostLikely(1:10),secondLikely(1:10), ...
    imds.Labels(ambiguityIdx(1:10)), ...
    VariableNames=["Image #","Ambiguity","Likeliest","Second","True Class"])
  10×5 table

    Image #    Ambiguity    Likeliest    Second     True Class
    _______    _________    _________    _______    __________
     2268       0.99602     chicken      tiger        tiger   
     3330       0.99584     tiger        rabbit       rabbit  
      104       0.99187     chicken      tiger        chicken 
      304       0.98644     rabbit       chicken      chicken 
     1163       0.98466     tiger        chicken      chicken 
     3071       0.95684     chicken      rabbit       rabbit  
     1925       0.95373     rabbit       tiger        tiger   
     3006       0.95209     rabbit       chicken      rabbit  
     2772       0.93734     chicken      rabbit       rabbit  
     3461        0.9258     tiger        rabbit       rabbit  

The misclassifications cluster in these three groups. These samples are relatively complex: the foreground is not prominent, or the background is cluttered, so the extracted features are ambiguous.
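To inspect such cases, the most ambiguous image can be displayed alongside its competing predictions. A sketch, reusing the variables from the table above:

```matlab
% Show the single most ambiguous image with its true and predicted labels.
v = top10Idx(1);
figure
imshow(readimage(imds,v))
title(sprintf("True: %s, predicted: %s (ambiguity %.3f)", ...
    string(imds.Labels(v)),string(mostLikely(1)),top10Ambiguity(1)))
```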
[Figures: examples of the most ambiguous images]
