当前位置: 首页 > news >正文

DAY 39 图像数据与显存

  1. 图像数据的格式:灰度和彩色数据
  2. 模型的定义
  3. 显存占用的4个主要部分
    1. 模型参数+梯度参数
    2. 优化器参数
    3. 数据批量所占显存
    4. 神经元输出中间状态
  4. batch size 和训练的关系
  5. import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import DataLoader , Dataset
    from torchvision import datasets, transforms
    import matplotlib.pyplot as plt

    # Seed the global RNG so the random sample index drawn below is reproducible.
    torch.manual_seed(42)

    # Preprocessing pipeline: ToTensor followed by single-channel
    # Normalize(mean=(0.1307,), std=(0.3081,)).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # MNIST training split; downloaded into ./data when not already present.
    train_dataset = datasets.MNIST(
        root='./data',
        train=True,
        download=True,
        transform=transform,
    )
    # MNIST test split; reuses the same root directory and transform.
    test_dataset = datasets.MNIST(
        root='./data',
        train=False,
        transform=transform,
    )

    # Pick one random index and fetch the corresponding (image, label) pair.
    sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item()
    image, label = train_dataset[sample_idx]
    def imshow(img):
        """Display a normalized image tensor as a grayscale picture.

        Args:
            img: Tensor normalized with mean 0.1307 and std 0.3081. Only the
                first slice along axis 0 is drawn.
        """
        # Undo the normalization (x * std + mean) before plotting.
        img = img * 0.3081 + 0.1307
        npimg = img.numpy()
        plt.imshow(npimg[0], cmap='gray')
        plt.show()


    print(f'Label:{label}')
    imshow(image)
    import torch
    import torchvision
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt
    import numpy as np

    # Seed the global RNG so shuffling and the sample index are reproducible.
    torch.manual_seed(42)

    # Preprocessing pipeline: ToTensor followed by three-channel
    # Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # CIFAR-10 training split; downloaded into ./data when not already present.
    trainset = torchvision.datasets.CIFAR10(
        root='./data',
        train=True,
        download=True,
        transform=transform,
    )
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=4,
        shuffle=True,
    )

    # Class names, indexed by the integer label.
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # Pick one random index and fetch the corresponding (image, label) pair.
    sample_idx = torch.randint(0, len(trainset), size=(1,)).item()
    image, label = trainset[sample_idx]
    print(f"图像形状: {image.shape}")
    print(f"图像类别: {classes[label]}")


    def imshow(img):
        """Display a normalized colour image tensor.

        Args:
            img: Tensor normalized with mean 0.5 and std 0.5 per channel.
                Axes are reordered from (0, 1, 2) to (1, 2, 0) before
                plotting, i.e. the leading axis is moved to the end.
        """
        # Undo the normalization: x * 0.5 + 0.5 is written here as x / 2 + 0.5.
        img = img / 2 + 0.5
        npimg = img.numpy()
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
        plt.axis('off')
        plt.show()


    imshow(image)
    # Preprocessing pipeline: ToTensor followed by single-channel
    # Normalize(mean=(0.1307,), std=(0.3081,)).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    import matplotlib.pyplot as plt

    # MNIST training split; downloaded into ./data when not already present.
    train_dataset = datasets.MNIST(
        root='./data',
        train=True,
        download=True,
        transform=transform,
    )
    # MNIST test split; reuses the same root directory and transform.
    test_dataset = datasets.MNIST(
        root='./data',
        train=False,
        transform=transform,
    )
    class MLP(nn.Module):
        """Two-layer perceptron: Flatten -> Linear(784, 128) -> ReLU -> Linear(128, 10).

        Each input sample must flatten to 784 features, because the first
        linear layer is declared with in_features=784.
        """

        def __init__(self):
            super(MLP, self).__init__()
            self.flatten = nn.Flatten()
            self.layer1 = nn.Linear(784, 128)
            self.relu = nn.ReLU()
            self.layer2 = nn.Linear(128, 10)

        def forward(self, x):
            """Return the 10 output scores per sample for batch ``x``."""
            x = self.flatten(x)
            x = self.layer1(x)
            x = self.relu(x)
            x = self.layer2(x)
            return x


    model = MLP()
    # Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Move the model's parameters onto the selected device.
    model = model.to(device)

    from torchsummary import summary

    print("\n模型结构信息:")
    # input_size excludes the batch dimension.
    summary(model, input_size=(1, 28, 28))
    class MLP(nn.Module):
        """Two-layer perceptron: Flatten -> Linear -> ReLU -> Linear.

        Args:
            input_size: Number of features per flattened sample (default 3072).
            hidden_size: Width of the hidden layer (default 128).
            num_classes: Number of output scores per sample (default 10).
        """

        def __init__(self, input_size=3072, hidden_size=128, num_classes=10):
            super(MLP, self).__init__()
            self.flatten = nn.Flatten()
            self.fc1 = nn.Linear(input_size, hidden_size)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(hidden_size, num_classes)

        def forward(self, x):
            """Return ``num_classes`` output scores per sample for batch ``x``."""
            x = self.flatten(x)
            x = self.fc1(x)
            x = self.relu(x)
            x = self.fc2(x)
            return x


    model = MLP()
    # Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Move the model's parameters onto the selected device.
    model = model.to(device)
    from torchsummary import summary

    print("\n模型结构信息:")
    # input_size excludes the batch dimension.
    summary(model, input_size=(3, 32, 32))


    class MLP(nn.Module):
        """Two-layer perceptron: Flatten -> Linear(784, 128) -> ReLU -> Linear(128, 10).

        Each input sample must flatten to 784 features, because the first
        linear layer is declared with in_features=784.
        """

        def __init__(self):
            super().__init__()
            self.flatten = nn.Flatten()
            self.layer1 = nn.Linear(784, 128)
            self.relu = nn.ReLU()
            self.layer2 = nn.Linear(128, 10)

        def forward(self, x):
            """Return the 10 output scores per sample for batch ``x``."""
            x = self.flatten(x)
            x = self.layer1(x)
            x = self.relu(x)
            x = self.layer2(x)
            return x
    from torch.utils.data import DataLoader

    # Training loader: batches of 64, reshuffled on each pass.
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=64,
        shuffle=True,
    )
    # Test loader: batches of 1000 in fixed order.
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=1000,
        shuffle=False,
    )

相关文章:

  • unity版本控制PlasticSCM转git
  • RADIUS认证服务器全面解析:核心功能、应用场景
  • FLTK从源码编译到使用
  • SQL Server基础语句4:数据定义
  • ROS 2 中 Astra Pro 相机与 YOLOv5 检测功能编译启动全记录
  • 深入解析域名解析API:从gethostbyname到getaddrinfo的演进之路
  • cherryStudio连接MCP服务器
  • 微服务网关/nacos/feign总结
  • Spring AI 项目实战(十一):Spring Boot +AI + DeepSeek 开发智能教育作业批改系统(附完整源码)
  • 华为OD-2024年E卷-字符串化繁为简[200分] -- python
  • Qt应用中处理Linux信号:实现安全退出的技术指南
  • MySQL 主从同步完整配置示例
  • 虚拟与现实交融视角下定制开发开源AI智能名片S2B2C商城小程序赋能新零售商业形态研究
  • 华为OD机考-网上商城优惠活动-模拟(JAVA 2025B卷)
  • 华为公布《鸿蒙编程语言白皮书》V1.0 版:解读适用场景
  • Ragflow 源码:task_executor.py
  • 数据库(1)-SQL
  • 超详细YOLOv8/11图像菜品分类全程概述:环境、数据准备、训练、验证/预测、onnx部署(c++/python)详解
  • 46- 赎金信
  • VB.NET,C#在线程中修改UI的安全操作
  • 网站建设目的分析/虎扑体育网体育
  • 万网注册域名做简单网站/seo服务 文库
  • 建设网站自学/昆明排名优化
  • 企业网站一般用什么域名/微信软文是什么
  • 菜鸟如何做网站/东莞网络公司排行榜
  • 做网站哪些方面会侵权/成都网站设计公司