当前位置: 首页 > news >正文

生成了一个AI算法

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# 1. 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)) # MNIST单通道归一化
])
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, transform=transform)

# 2. 模型定义
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.layers = nn.Sequential(
            nn.Linear(28*28, 128), # 输入层
            nn.ReLU(),             # 激活函数
            nn.Dropout(0.2),       # 防过拟合
            nn.Linear(128, 10)     # 输出层(10分类)
        )
    def forward(self, x):
        x = self.flatten(x)
        return self.layers(x)

# 3. 训练配置
model = NeuralNetwork()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
batch_size = 64
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

# 4. 训练循环
for epoch in range(10):
    for images, labels in train_loader:
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# 5. 评估
test_loader = torch.utils.data.DataLoader(test_data, batch_size=256)
correct = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
print(f'准确率: {100 * correct / len(test_data):.2f}%')

相关文章:

  • 华为设备端口隔离
  • 【Azure Redis】Redis导入备份文件(RDB)失败的原因
  • NVIDIA Halos:智能汽车革命中的全栈式安全系统
  • Selenium模拟人类,操作网页的行为(全)
  • 强啊!Oracle Database 23aiOracle Database 23ai:使用列别名进行分组排序!
  • Oracle04-基本使用
  • 26届秋招收割offer指南
  • JavaScript性能优化实战:深入探讨性能瓶颈与优化技巧
  • 嵌入式面试八股文(十四)·内存管理机制、优先级继承机制以及优先级翻转
  • 多行文本省略
  • 精益数据分析(43/126):媒体网站商业模式的盈利与指标解析
  • Gitee的介绍
  • QtGUI模块功能详细说明,图像处理(三)
  • MUX-vlan
  • 多模态大语言模型arxiv论文略读(六十)
  • 山东大学软件学院项目实训-基于大模型的模拟面试系统-个人主页头像上传
  • 面试常问系列(一)-神经网络参数初始化-之自注意力机制为什么除以根号d而不是2*根号d或者3*根号d
  • 双ISP(双互联网服务提供商)
  • 为什么Transformer推理需要做KV缓存
  • Kotlin-访问权限控制
  • 联合国秘书长吁印巴“最大程度克制”,特朗普:遗憾,希望尽快结束冲突
  • 上海市政府常务会议部署提升入境旅游公共服务水平,让国际友人“无障碍”畅游上海
  • 躺着玩手机真有意思,我“瞎”之前最喜欢了
  • “五一”假期预计全社会跨区域人员流动累计14.67亿人次
  • 贵州黔西市游船倾覆事故最后一名失联人员被找到,但已无生命体征
  • 击败老对手韩国队夺冠!国羽第14次问鼎苏迪曼杯创历史