当前位置: 首页 > news >正文

李沐动手深度学习(pycharm中运行笔记)——11.模型选择+过拟合欠拟合

11.模型选择+过拟合欠拟合(与课程对应)


一、模型选择

  • 训练误差与泛化误差
    • 误差定义:训练误差是模型在训练数据集上的误差,泛化误差是模型在新数据上的误差。
    • 举例说明:高考模考成绩好是训练误差,不能代表真实考试成绩好;背题的学生模拟考成绩好但真实考试可能不如理解解题思路的学生
    • 关注重点:应关注泛化误差,而非训练误差。
  • 验证数据集与测试数据集
    • 数据集作用:验证数据集用于评估模型好坏、选择超参数,测试数据集理论上只能用一次,不能用于调参。
    • 常见错误:将验证数据集与训练数据集混用,如用 Google 搜来的图片作为验证集,其与训练集 image net 的图片有重复,导致模型上线效果差;在测试数据集上调参,使测试结果虚高。
    • 实际情况:代码里常将验证数据集称为 test data,但其精度可能虚高。
  • k 折交叉验证算法
    • 算法原理:将训练数据集随机打散后分割成 k 块,做 k 次计算,每次将第 k 块作为验证数据集,其余作为训练数据集,取 k 次验证精度的平均值。
    • 举例说明:以三折交叉验证为例,进行三次不同验证数据集的计算。
    • k 值选择:常用 k 等于 5 或 10,若数据大 k 可取 2 或 3,数据小可大于 10;n 折交叉验证能最大程度使用训练数据,但计算代价高。
  • 超参数选择方法
    • 选择流程:有多种超参数时,采用 k 折交叉验证,为每个超参数得到交叉验证的平均精度,选择精度最好的超参数。

二、过拟合欠拟合

  • 过拟合与欠拟合概念
    • 过拟合:简单数据使用复杂模型,模型记住所有样本,对新样本无泛化性。如在简单数据集上用特别深的神经网络。
    • 欠拟合:复杂数据使用简单模型,无法很好训练模型。如用线性模型拟合复杂的 fashion MNIST 数据集或异或函数。
  • 模型容量定义与选择
    • 模型容量定义:指拟合各种函数的能力,低容量难以拟合训练数据,高容量可记住所有训练数据。
    • 模型容量选择:简单数据选低模型容量,复杂数据选高模型容量。
  • 模型容量与误差关系
    • 低模型容量情况:训练误差和泛化误差都高,因模型简单无法拟合数据。
    • 模型容量增加:训练误差下降,可降为 0,但泛化误差先降后升,因模型过于关注细节。
    • 最优情况:泛化误差上升时的点为最优,要将该点误差往下拉,减小训练误差和泛化误差的差距。
  • 模型容量估计因素
    • 参数个数:参数个数多,模型容量高。如线性模型参数个数为 d + 1,单层隐含层感知机参数个数为 (d + 1)×m + (m + 1)×k 。
    • 参数值选择范围:参数选择范围大,模型复杂度高;范围小,模型容量低。
  • VC 维相关内容
    • VC 维定义:对分类模型,等价于最大数据集大小,不管其标号如何,都存在模型可完美分类。
    • 示例:二维输入感知机 VC 维为 3,可任意分类三个点,但不能分类异或问题;支持 n 维输入的感知器 VC 维是 n 。
  • 数据复杂度衡量因素
    • 样本个数:样本个数不同,数据复杂度不同,如 100 个样本和 100 万个样本。
    • 样本元素个数:如二维向量和不同尺寸图片,图片尺寸大则复杂度高。
    • 时空结构:数据有空间、时间或时空结构,如图片有空间结构,股票预测有时间结构,视频有时空结构。
    • 数据多样性:分类类别数不同,多样性不同,如 10 类、100 类、1000 类分类。

三、代码

import math
from cProfile import label
from symbol import trailer
import numpy as np
import torch
from pyexpat import features
from torch import nn
from d2l import torch as d2l# 模型选择:欠拟合 和 过拟合;通过多项式拟合来交互地探索这些概念
max_degree = 20  # 特征为20
n_train, n_test = 100, 100  # 100个训练样本、100个测试样本(验证)
true_w = np.zeros(max_degree)  # 长为20的w
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])  # 前4个有数(多项式前4项系数非0),其他16个都是0,给它一些噪声features = np.random.normal(size=(n_train + n_test, 1))
np.random.shuffle(features)
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))
for i in range(max_degree):poly_features[:, i] /= math.gamma(i + 1)
labels = np.dot(poly_features, true_w)
labels += np.random.normal(scale=0.1, size=labels.shape)true_w, features, poly_features, labels = [torch.tensor(x, dtype=torch.float32)for x in [true_w, features, poly_features, labels]
]print(features[:2], '\n', poly_features[:2, :], '\n', labels[:2])  # 打印一下部分数据长什么样子# 实现一个函数来评估模型在给定数据集上的损失
def evaluate_loss(net, data_iter, loss):metric = d2l.Accumulator(2)for X, y in data_iter:out = net(X)y = y.reshape(out.shape)l = loss(out, y)metric.add(l.sum(), l.numel())return metric[0] / metric[1]# 训练函数
def train(train_features, test_features, train_labels, test_labels, num_epochs=400):loss = nn.MSELoss()input_shape = train_features.shape[-1]net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))batch_size = min(10, train_labels.shape[0])train_iter = d2l.load_array((train_features, train_labels.reshape(-1, 1)), batch_size)test_iter = d2l.load_array((test_features, test_labels.reshape(-1, 1)), batch_size, is_train=False)trainer = torch.optim.SGD(net.parameters(), lr=0.01)animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log', xlim=[1, num_epochs], ylim=[1e-3, 1e2], legend=['train', 'test'])for epoch in range(num_epochs):d2l.train_epoch_ch3(net, train_iter, loss, trainer)if epoch == 0 or (epoch + 1) % 20 == 0:animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),evaluate_loss(net, test_iter, loss)))print('weight:', net[0].weight.data.numpy())# 三阶多项式拟合(正常);从多项式特征中选择前4个维度
train(poly_features[:n_train, :4], poly_features[n_train:, :4],labels[:n_train], labels[n_train:])
d2l.plt.show()# 线性函数拟合(欠拟合);数据给的不全
train(poly_features[:n_train, :2], poly_features[n_train:, :2],labels[:n_train], labels[n_train:])
d2l.plt.show()# 高阶多项式函数拟合(过拟合)
train(poly_features[:n_train, :], poly_features[n_train:, :],labels[:n_train], labels[n_train:])
d2l.plt.show()

如果此文章对您有所帮助,那就请点个赞吧,收藏+关注 那就更棒啦,十分感谢!!! 

相关文章:

  • SQL关键字三分钟入门:UNION 与 UNION ALL —— 数据合并全攻略
  • RKNN开发环境搭建3-RKNN Model Zoo 板载部署以Whisper为例
  • pyqt 简单条码系统
  • OpenStack入门
  • 搭建简易采购系统:从需求分析到供应商数据库设计
  • 【第二章:机器学习与神经网络概述】01.聚类算法理论与实践-(2)层次聚类算法(Hierarchical Clustering)
  • 【对比】DeepAR 和 N-Beats
  • 【unitrix】 3.0 基本结构体(types.rs)
  • python 解码 jwt
  • javaweb -Ajax
  • LVS—DR模式
  • 最新FVCOM 潮流、波浪、泥沙、水质、温盐、染色剂、粒子示踪、嵌套、背景流、自动化全流程
  • 在线教育平台敏捷开发项目
  • CppCon 2017 学习:C++ in Academia
  • ModbusTcp使用
  • Qt事件处理机制
  • Transformer推理拓扑关系
  • 2025年06月18日Github流行趋势
  • Jenkins审核插件实战:实现流水线审批控制的最佳实践
  • 经典风格的免费wordpress模板
  • 部队网站设计/全网整合营销平台
  • 济宁做网站公司/佐力药业股票
  • 溧水区住房和城乡建设厅网站/seo服务公司
  • dw网页制作下载/网站优化搜索排名
  • 外贸网站推广/南京seo排名扣费
  • 凡科做数据查询网站/网络推广外包公司哪家好