当前位置: 首页 > news >正文

深度学习 模型和代码

提供一个简单的深度学习模型(类似 DeepSeek 工作原理的简单示例,比如一个简单的神经网络实现手写数字识别,使用 PyTorch 框架)示例代码

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# 加载训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True,
                               download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False,
                              download=True, transform=transform)

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


# 定义神经网络模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool(x)
        x = x.view(-1, 32 * 7 * 7)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.fc2(x)
        return x


# 初始化模型、损失函数和优化器
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))

# 在测试集上评估模型
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        output = model(data)
        test_loss += criterion(output, target).item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

以上代码构建了一个简单的卷积神经网络用于 MNIST 手写数字识别,包含了数据加载、模型定义、训练和测试等流程。

相关文章:

  • 【经验】Ubuntu|VMware 新建虚拟机后打开 SSH 服务、在主机上安装vscode并连接、配置 git 的 ssh
  • Spring Security的作用
  • 为你的python程序上锁:软件序列号生成器
  • 来看两篇RAG相关的优化工作:多跳查询的优化L-RAG以及利用记忆增强的查询重构MemQ框架
  • 大语言模型打卡学习DAY1
  • 【数据结构C语言】一、基本概念
  • java-正则表达式
  • 【Hadoop】Hadoop是什么?
  • 简单易懂Modbus Tcp和Rtu的异同点
  • AI重构私域增长:从流量收割到终身价值运营的三阶跃迁
  • UI自动化:poium测试库
  • 网络安全系统集成
  • 从Swish到SwiGLU:激活函数的进化与革命,qwen2.5应用的激活函数
  • pythonSTL---os
  • Spring Boot 启动失败:Failed to start bean ‘documentationPluginsBootstrapper’ 解决方案
  • 在 Linux 中,lsblk 命令输出内容解释
  • Linux网络编程——TCP网络通信多线程处理
  • 大白话JavaScript详细描述基于原型链实现对象继承的步骤,分析其在共享属性、内存占用等方面的优缺点
  • OpenHands:OpenDevin的升级版,由人工智能驱动的软件开发代理平台
  • stm32第四天控制蜂鸣器
  • 融创中国:今年前4个月销售额约112亿元
  • 奥园集团将召开债券持有人会议,拟调整“H20奥园2”本息兑付方案
  • 公元1057年:千年龙虎榜到底有多厉害?
  • 紧盯大V、网红带货肉制品,整治制售假劣肉制品专项行动开展
  • 中邮保险斥资8.69亿元举牌东航物流,持股比例达5%
  • 专访|李沁云:精神分析不会告诉你“应该怎么做”,但是……