当前位置: 首页 > news >正文

day35 python模型可视化与推理

目录

python模型可视化与推理

一、实验环境与数据准备

二、模型构建与训练

三、模型推理与评估

四、模型结构可视化与参数分析

五、学习总结


一、实验环境与数据准备

首先,我搭建了实验环境,导入了必要的库,包括PyTorch及其相关模块、sklearn用于数据处理、matplotlib用于可视化等。然后,我加载了鸢尾花数据集,这是一个经典的多分类数据集,包含150个样本,每个样本有4个特征,对应3种鸢尾花类别。

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import time
import matplotlib.pyplot as plt

为了更好地训练模型,我对数据进行了预处理。使用train_test_split将数据集划分为训练集和测试集,测试集占20%。接着,我使用MinMaxScaler对特征数据进行了归一化处理,将特征值缩放到0到1之间,这有助于加快模型的收敛速度。

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 归一化数据
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

二、模型构建与训练

我构建了一个简单的MLP模型,包含一个输入层、一个隐藏层和一个输出层。输入层的神经元数量与特征数量相同,隐藏层有10个神经元,输出层有3个神经元,分别对应3个类别。模型使用ReLU作为激活函数,交叉熵损失函数作为优化目标,随机梯度下降(SGD)作为优化器。

class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.fc1 = nn.Linear(4, 10)  # 输入层到隐藏层self.relu = nn.ReLU()self.fc2 = nn.Linear(10, 3)  # 隐藏层到输出层def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return out

在训练模型时,我设置了20000个训练轮数,并使用GPU加速训练过程。通过optimizer.zero_grad()清零梯度,loss.backward()计算梯度,optimizer.step()更新参数,完成了模型的训练。同时,我使用tqdm库创建了进度条,实时显示训练进度和损失值,让训练过程更加直观。

# 训练模型
num_epochs = 20000  # 训练的轮数# 创建tqdm进度条
with tqdm(total=num_epochs, desc="训练进度", unit="epoch") as pbar:for epoch in range(num_epochs):# 前向传播outputs = model(X_train)  # 隐式调用forward函数loss = criterion(outputs, y_train)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()# 记录损失值并更新进度条if (epoch + 1) % 200 == 0:losses.append(loss.item())epochs.append(epoch + 1)# 更新进度条的描述信息pbar.set_postfix({'Loss': f'{loss.item():.4f}'})# 每1000个epoch更新一次进度条if (epoch + 1) % 1000 == 0:pbar.update(1000)  # 更新进度条

三、模型推理与评估

训练完成后,我使用测试集对模型进行了推理和评估。在推理阶段,我将模型设置为评估模式,并禁用了梯度计算,以提高推理速度。通过计算预测值与真实值的匹配程度,我得到了模型在测试集上的准确率,达到了96.67%。

# 评估模型
model.eval() # 设置模型为评估模式
with torch.no_grad(): # 禁用梯度计算outputs = model(X_test)  # 对测试数据进行前向传播,获得预测结果_, predicted = torch.max(outputs, 1) # 获取预测的类别correct = (predicted == y_test).sum().item() # 计算预测正确的样本数accuracy = correct / y_test.size(0)print(f'测试集准确率: {accuracy * 100:.2f}%')

四、模型结构可视化与参数分析

为了更好地理解模型,我还对模型结构进行了可视化,并分析了模型参数。通过打印模型对象,我清晰地看到了模型的每一层及其参数信息。此外,我使用torchsummarytorchinfo库生成了模型摘要,详细展示了每一层的输出形状、参数数量等信息。

# 打印模型结构
print(model)# 使用torchsummary生成模型摘要
from torchsummary import summary
summary(model, input_size=(4,))

我还提取了模型的权重参数,并对其进行了可视化和统计分析。通过绘制权重分布直方图,我直观地看到了不同层权重的分布情况。同时,我计算了每层权重的均值、标准差、最小值和最大值等统计信息,这些信息为我后续调整超参数提供了参考。

# 提取权重数据并可视化
weight_data = {}
for name, param in model.named_parameters():if 'weight' in name:weight_data[name] = param.detach().cpu().numpy()# 可视化权重分布
fig, axes = plt.subplots(1, len(weight_data), figsize=(15, 5))
fig.suptitle('Weight Distribution of Layers')for i, (name, weights) in enumerate(weight_data.items()):weights_flat = weights.flatten()axes[i].hist(weights_flat, bins=50, alpha=0.7)axes[i].set_title(name)axes[i].set_xlabel('Weight Value')axes[i].set_ylabel('Frequency')axes[i].grid(True, linestyle='--', alpha=0.7)plt.tight_layout()
plt.subplots_adjust(top=0.85)
plt.show()

五、学习总结

通过这次实验,我不仅掌握了PyTorch框架的基本使用方法,还对深度学习模型的训练、推理、评估以及可视化有了更深入的理解。我学会了如何构建和训练一个简单的MLP模型,如何使用GPU加速训练过程,如何使用tqdm库创建进度条,以及如何对模型结构和参数进行可视化和分析。在实验过程中,我也遇到了一些问题,例如训练时间较长、模型准确率有待提高等。这些问题让我意识到,深度学习模型的优化是一个复杂的过程,需要不断地调整超参数、改进模型结构、尝试不同的优化算法等。未来,我将继续深入学习深度学习知识,探索更复杂的模型和优化方法,提高模型的性能和泛化能力。

@浙大疏锦行

相关文章:

  • 使用防火墙禁止程序联网(这里禁止vscode)
  • 天猫平台实时商品数据 API 接入方案与开发实践
  • C++——volatile
  • Python打卡第35天
  • Ollama01-安装教程
  • C#学习第25天:GUI编程
  • 关于支付组织
  • 黑马k8s(十五)
  • Mac的显卡架构种类
  • 数据透视表和公式法在Excel中实现去除重复计数的方法
  • 攻防世界RE-666
  • exti line2 interrupt 如何写中断回调
  • 关于使用QT时写客户端连接时因使用代理出现的问题
  • GeoTools 将 Shp 导入PostGIS 空间数据库
  • 路径规划算法BFS/Astar/HybridAstar简单实现
  • 如何实现Aurora MySQL 零停机升级
  • linux线程同步
  • ES6 扩展运算符与 Rest 参数
  • yum命令常用选项
  • nginx 基于IP和用户的访问
  • 浙江做网站找谁/网络营销大赛策划书
  • 打开上海发布/在线网站seo诊断
  • 用php做网站需要什么/有产品怎么找销售渠道
  • 深圳银行网站建设/淘宝店铺买卖交易平台
  • 梅州建站联系方式/优化疫情防控
  • 把自己做的网站发布/ip域名查询网