当前位置: 首页 > news >正文

da y54

1.对inception网络在cifar10上观察精度

 

Inception网络是一种经典的卷积神经网络架构,其核心特点是通过“ inception模块”组合不同尺寸的卷积核(如1x1、3x3、5x5)和池化操作,在提升特征提取能力的同时控制计算量。

在CIFAR-10数据集(含10类小尺寸图像)上,Inception网络的表现受具体版本和训练配置影响,测试精度通常在87% - 96%之间。例如,Inception-v3在该数据集上可达到约96.5%的精度,展现了其对小图像分类的有效性。

 

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
 
# 检查设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
 
# 数据预处理
transform = transforms.Compose([
    transforms.Resize((299, 299)),  # Inception 网络需要输入大小为 299x299
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化到 [-1, 1]
])
 
# 加载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
 
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
 
# 加载预训练的 Inception 网络
model = models.inception_v3(pretrained=True)
model.aux_logits = False  # 禁用辅助分类器
model.fc = nn.Linear(model.fc.in_features, 10)  # 修改最后一层为 CIFAR-10 的 10 类
model = model.to(device)
 
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
 
# 训练函数
def train(model, train_loader, criterion, optimizer, device, epochs):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
        
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100. * correct / total
        print(f"Epoch [{epoch+1}/{epochs}] | Loss: {epoch_loss:.4f} | Accuracy: {epoch_acc:.2f}%")
 
# 测试函数
def test(model, test_loader, criterion, device):
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            test_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
    
    avg_loss = test_loss / len(test_loader)
    accuracy = 100. * correct / total
    print(f"Test Loss: {avg_loss:.4f} | Test Accuracy: {accuracy:.2f}%")
    return accuracy
 
# 执行训练和测试
epochs = 10
print("开始训练 Inception 网络...")
train(model, train_loader, criterion, optimizer, device, epochs)
print("训练完成!开始测试...")
test_accuracy = test(model, test_loader, criterion, device)
print(f"最终测试准确率: {test_accuracy:.2f}%")

 

http://www.dtcms.com/a/269068.html

相关文章:

  • LED 闪烁 LED 流水灯 蜂鸣器
  • IROS 2025|RL vs MPC性能对比:加州理工无人机实测,谁在「变形控制」中更胜一筹?
  • pg_class 系统表信息
  • React + Express 传输加密以及不可逆加密
  • OpenCV人脸分析------绘制面部关键点函数drawFacemarks()
  • day08-Elasticsearch
  • MinIO与SpringBoot集成完整指南
  • maven 发布到中央仓库常用脚本-02
  • 视频序列和射频信号多模态融合算法Fusion-Vital解读
  • 力扣 hot100 Day37
  • C++笔记之和的区别
  • Isaac Lab:让机器人学习更简单的开源框架
  • Go defer(二):从汇编的角度理解延迟调用的实现
  • RAG实战指南 Day 8:PDF、Word和HTML文档解析实战
  • Stirling-PDF 本地化部署,建立自己的专属PDF工具箱
  • 力扣_链表(前后指针)_python版本
  • 虚幻引擎UE5 GAS开发RPG游戏-02 设置英雄角色-18 改成网络多人游戏
  • C++:string类(3)(string类的模拟实现)
  • 批量OCR的GitHub项目
  • Linux 进程控制:全面深入剖析进程创建、终止、替换与等待
  • UI自动化常见面试题
  • qt-C++笔记之QSplitter
  • PyTorch笔记3----------统计学相关函数
  • AI PPT探秘
  • ARMv7单核CPU上SWI(软件中断)验证
  • 策略与工厂的演进:打造工业级Spring路由框架
  • window显示驱动开发—X 通道解释
  • 如何远程管理Linux服务器
  • Rust 内存结构:深入解析
  • DPDK 网络驱动 之 UIO