当前位置: 首页 > news >正文

Day.34

优化耗时:

import torch

import torch.nn as nn

import torch.optim as optim

from sklearn.datasets import load_iris

from sklearn.model_selection import train_test_split

from sklearn.preprocessing import MinMaxScaler

import time

import matplotlib.pyplot as plt

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(f"使用设备: {device}")

iris = load_iris()

X = iris.data  

y = iris.target  

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

scaler = MinMaxScaler()

X_train = scaler.fit_transform(X_train)

X_test = scaler.transform(X_test)

X_train = torch.FloatTensor(X_train).to(device)

y_train = torch.LongTensor(y_train).to(device)

X_test = torch.FloatTensor(X_test).to(device)

y_test = torch.LongTensor(y_test).to(device)

class MLP(nn.Module):

    def __init__(self):

        super(MLP, self).__init__()

        self.fc1 = nn.Linear(4, 10)  

        self.relu = nn.ReLU()

        self.fc2 = nn.Linear(10, 3)  

    def forward(self, x):

        out = self.fc1(x)

        out = self.relu(out)

        out = self.fc2(out)

        return out

model = MLP().to(device)

criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=0.01)

num_epochs = 20000  

losses = []

start_time = time.time()  

for epoch in range(num_epochs):

    outputs = model(X_train) 

    loss = criterion(outputs, y_train)

    optimizer.zero_grad()

    loss.backward()

    optimizer.step()

    if (epoch + 1) % 200 == 0:

        losses.append(loss.item()) 

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

    if (epoch + 1) % 100 == 0: 

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

time_all = time.time() - start_time  

print(f'Training time: {time_all:.2f} seconds')

plt.plot(range(len(losses)), losses)

plt.xlabel('Epoch')

plt.ylabel('Loss')

plt.title('Training Loss over Epochs')@浙大疏锦行

plt.show()

相关文章:

  • JVM: 内存、类与垃圾
  • API 管理系统实践指南:监控、安全、性能全覆盖
  • MCP基本概念
  • synchronized 做了哪些优化?
  • 【Algorithm】图论入门
  • 软件体系结构-论述、设计、问答
  • 每天一个前端小知识 Day 4 - TypeScript 核心类型系统与实践
  • 跨境卖家警报。抽绳背包版权案立案,TRO在即速排查
  • 二维数组 结构体01 day15,16
  • 【大模型:知识库管理】--MinerU本地部署
  • SpringBoot Starter设计:依赖管理的革命
  • 什么是数据清洗?数据清洗有哪些步骤?
  • 选择与方法专栏(9) 职场内篇: 是否要跳出舒适圈?如何处理犯错?
  • ffmpeg python rgba图片合成 4444格式mov视频,保留透明通道
  • 有趣的git
  • 【git】错误
  • 《深度学习基础与概念》task2/3
  • 使用 Java + WebSocket 实现简单实时双人协同 pk 答题
  • 设计模式精讲 Day 4:建造者模式(Builder Pattern)
  • Datawhale YOLO Master 第1次笔记
  • 如何做cpa单页网站/什么是seo优化?
  • 网站建设硬件设置/下载手机百度最新版
  • 路由器屏蔽网站怎么做/小红书推广
  • h5case什么网站/怎么让百度搜索靠前
  • 电子商务 网站模板/下载百度地图2022最新版官方
  • wordpress整站搬家/网络推广怎么找客户资源