当前位置: 首页 > news >正文

python学习打卡day33

DAY 33 简单的神经网络

知识点回顾:

  1. PyTorch和cuda的安装
  2. 查看显卡信息的命令行命令(cmd中使用)
  3. cuda的检查
  4. 简单神经网络的流程
    1. 数据预处理(归一化、转换成张量)
    2. 模型的定义
      1. 继承nn.Module类
      2. 定义每一个层
      3. 定义前向传播流程
    3. 定义损失函数和优化器
    4. 定义训练流程
    5. 可视化loss过程

预处理补充:

注意事项:

1. 分类任务中,若标签是整数(如 0/1/2 类别),需转为long类型(对应 PyTorch 的torch.long),否则交叉熵损失函数会报错。

2. 回归任务中,标签需转为float类型(如torch.float32)。

作业:今日的代码,要做到能够手敲。这已经是最简单最基础的版本了。

import pandas as pd 
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
iris=load_iris()
X=iris.data
y=iris.target
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.2,random_state=42)
scaler=MinMaxScaler()
X_train=scaler.fit_transform(X_train)
X_test=scaler.transform(X_test)
X_train=torch.FloatTensor(X_train)
X_test=torch.FloatTensor(X_test)
y_train=torch.LongTensor(y_train)
y_test=torch.LongTensor(y_test)class MLP(nn.Module):def __init__(self):super(MLP,self).__init__()self.fc1=nn.Linear(4,10)self.relu=nn.ReLU()self.fc2=nn.Linear(10,3)def forward(self,x):out=self.fc1(x)out=self.relu(out)out=self.fc2(out)return out
model=MLP()criterion=nn.CrossEntropyLoss()
optimizer=optim.Adam(model.parameters(),lr=0.01)num_epochs=20000
losses=[]
for epoch in range(num_epochs):outputs=model.forward(X_train)loss=criterion(outputs,y_train)optimizer.zero_grad()loss.backward()optimizer.step()losses.append(loss.item())if(epoch+1)%100==0:print(f'Epoch[{epoch+1}/{num_epochs}],loss:{loss.item():.4f}')plt.plot(range(num_epochs), losses)#()内对应的是X轴和Y轴的数据
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.show()

 

使用Adam优化器不到2000轮就收敛了,且损失要比SGD小

@浙大疏锦行

相关文章:

  • Mysql刷题之正则表达式专题
  • MA网络笔记
  • leetcode2261. 含最多 K 个可整除元素的子数组-medium
  • 关于Python编程语言的详细介绍,结合其核心特性、应用领域和发展现状,以结构化方式呈现:
  • 网络编程 之 从BIO到 NIO加多线程高性能网络编程实战
  • JMeter 教程:响应断言
  • 融合蛋白质语言模型和图像修复模型,麻省理工与哈佛联手提出PUPS ,实现单细胞级蛋白质定位
  • recurrent neural network(rnn)
  • 记录Pycharm断点调试的一个BUG
  • Java的列表、集合、数组的添加一个元素各自用的什么方法?
  • 蜂鸣器模块
  • 7.2.顺序查找
  • 【KWDB 2025 创作者计划】_KWDB时序数据库特性及跨模查询
  • 把银河装进镜头里!动态星轨素材使用实录
  • iisARR负均衡
  • indicator-sysmonitor 在Ubuntu 右上角实时显示CPU/MEM/NET的利用率
  • 实现一个前端动态模块组件(Vite+原生JS)
  • anaconda的c++环境与ros2需要的系统变量c++环境冲突
  • 冲刺卷软考总结-案例分析
  • MySQL索引事务
  • 泰安做网站公司/旅游营销推广方案
  • 本地网站建设流程/关键词你们懂的
  • 怎么自己做一个公众号/厦门百度seo点击软件
  • 手机浏览器下载app/个人如何做seo推广
  • 网站结构是什么 怎么做/360手机优化大师安卓版
  • 创建网站大约多少钱/seo推广软件品牌