当前位置: 首页 > news >正文

python学习打卡day35

知识点回顾:

  1. 三种不同的模型可视化方法:推荐torchinfo打印summary+权重分布可视化
  2. 进度条功能:手动和自动写法,让打印结果更加美观
  3. 推理的写法:评估模式

作业:调整模型定义时的超参数,对比下效果。

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import time
import matplotlib.pyplot as plt
from tqdm import tqdm  # 导入tqdm库用于进度条显示
import warnings
warnings.filterwarnings("ignore")  # 忽略警告信息# 设置GPU设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 加载鸢尾花数据集
iris = load_iris()
X = iris.data  # 特征数据
y = iris.target  # 标签数据# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 归一化数据
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)# 将数据转换为PyTorch张量并移至GPU
X_train = torch.FloatTensor(X_train).to(device)
y_train = torch.LongTensor(y_train).to(device)
X_test = torch.FloatTensor(X_test).to(device)
y_test = torch.LongTensor(y_test).to(device)class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.fc1 = nn.Linear(4, 10)  # 输入层到隐藏层self.relu = nn.ReLU()self.fc2 = nn.Linear(10, 3)  # 隐藏层到输出层def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return out# 实例化模型并移至GPU
model = MLP().to(device)# 分类问题使用交叉熵损失函数
criterion = nn.CrossEntropyLoss()# 使用随机梯度下降优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
num_epochs = 20000  # 训练的轮数# 用于存储每100个epoch的损失值和对应的epoch数
losses = []
epochs = []start_time = time.time()  # 记录开始时间# 创建tqdm进度条
with tqdm(total=num_epochs, desc="训练进度", unit="epoch") as pbar:# 训练模型for epoch in range(num_epochs):# 前向传播outputs = model(X_train)  # 隐式调用forward函数loss = criterion(outputs, y_train)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()# 记录损失值并更新进度条if (epoch + 1) % 200 == 0:losses.append(loss.item())epochs.append(epoch + 1)# 更新进度条的描述信息pbar.set_postfix({'Loss': f'{loss.item():.4f}'})# 每1000个epoch更新一次进度条if (epoch + 1) % 1000 == 0:pbar.update(1000)  # 更新进度条# 确保进度条达到100%if pbar.n < num_epochs:pbar.update(num_epochs - pbar.n)  # 计算剩余的进度并更新time_all = time.time() - start_time  # 计算训练时间
print(f'Training time: {time_all:.2f} seconds')# 可视化损失曲线
plt.figure(figsize=(10, 6))
plt.plot(epochs, losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.grid(True)
plt.show()# 在测试集上评估模型,此时model内部已经是训练好的参数了
# 评估模型
model.eval() # 设置模型为评估模式
with torch.no_grad(): # torch.no_grad()的作用是禁用梯度计算,可以提高模型推理速度outputs = model(X_test)  # 对测试数据进行前向传播,获得预测结果_, predicted = torch.max(outputs, 1) # torch.max(outputs, 1)返回每行的最大值和对应的索引#这个函数返回2个值,分别是最大值和对应索引,参数1是在第1维度(行)上找最大值,_ 是Python的约定,表示忽略这个返回值,所以这个写法是找到每一行最大值的下标# 此时outputs是一个tensor,p每一行是一个样本,每一行有3个值,分别是属于3个类别的概率,取最大值的下标就是预测的类别# predicted == y_test判断预测值和真实值是否相等,返回一个tensor,1表示相等,0表示不等,然后求和,再除以y_test.size(0)得到准确率# 因为这个时候数据是tensor,所以需要用item()方法将tensor转化为Python的标量# 之所以不用sklearn的accuracy_score函数,是因为这个函数是在CPU上运行的,需要将数据转移到CPU上,这样会慢一些# size(0)获取第0维的长度,即样本数量correct = (predicted == y_test).sum().item() # 计算预测正确的样本数accuracy = correct / y_test.size(0)print(f'测试集准确率: {accuracy * 100:.2f}%')

@浙大疏锦行

http://www.dtcms.com/a/333531.html

相关文章:

  • 分库分表和sql的进阶用法总结
  • AI客户维护高效解决方案
  • element-plus 如何通过js验证页面的表单
  • 开发避坑指南(27):Vue3中高效安全修改列表元素属性的方法
  • IP地址代理服务避坑指南:如何选择优质的IP地址代理服务公司?
  • 前端设置不同环境高德地图 key 和秘钥(秘钥通过运维统一配置)
  • 六大主流负载均衡算法
  • w484扶贫助农系统设计与实现
  • 【postgresql】一文详解postgresql中的统计模块
  • [Pyro概率编程] 概率分布 | 共轭计算 | 参数存储库
  • Qt开发:实现跨组件的条件触发
  • android 悬浮窗权限申请
  • 正点原子STM32H743配置 LTDC + DMA2D
  • 零基础学会制作 基于STM32单片机智能加湿系统/加湿监测/蓝牙系统/监测水量
  • Docker部署MySQL命令解读
  • redis-保姆级配置详解
  • 嵌入式软件开发--回调函数
  • 大肠杆菌重组蛋白表达致命痛点:包涵体 / 低表达 / 可溶性差?高效解决方案全解析!
  • JVM核心原理与实战优化指南
  • c++程序示例:多线程下的实例计数器
  • Nginx反向代理与缓存实现
  • 企业级Java项目和大模型结合场景(智能客服系统:电商、金融、政务、企业)
  • 正确维护邵氏硬度计的使用寿命至关重要
  • 【办公类110-01】20250813 园园通新生分班(python+uibot)
  • 量化线性层(42)
  • JavaScript 逻辑运算符与实战案例:从原理到落地
  • JavaScript 中 call、apply 和 bind 方法的区别与使用
  • 技术解读 | 搭建NL2SQL系统需要大模型么?
  • 【Git】Git-fork开发模式
  • 从0开始学习Java+AI知识点总结-15.后端web基础(Maven基础)