当前位置: 首页 > wzjs >正文

很多卖假药冒产品用二级域名做网站微信分享接口网站开发

很多卖假药冒产品用二级域名做网站,微信分享接口网站开发,有哪些做电子商务的网站,wordpress分类目录小工具知识点回顾: 彩色和灰度图片测试和训练的规范写法:封装在函数中展平操作:除第一个维度batchsize外全部展平dropout操作:训练阶段随机丢弃神经元,测试阶段eval模式关闭dropout 作业:仔细学习下测试和训练代码…

知识点回顾:

  1. 彩色和灰度图片测试和训练的规范写法:封装在函数中
  2. 展平操作:除第一个维度batchsize外全部展平
  3. dropout操作:训练阶段随机丢弃神经元,测试阶段eval模式关闭dropout

作业:仔细学习下测试和训练代码的逻辑,这是基础,这个代码框架后续会一直沿用,后续的重点慢慢就是转向模型定义阶段了。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from typing import Tuple, Callable, Optional# 1. 数据预处理与加载函数
def get_data_loaders(dataset_name: str = 'MNIST',  # 可选: 'MNIST'或'CIFAR10'batch_size: int = 64,data_dir: str = './data',num_workers: int = 2
) -> Tuple[DataLoader, DataLoader]:"""获取训练和测试数据加载器,支持灰度(MNIST)和彩色(CIFAR10)数据集"""# 根据数据集类型设置不同的转换if dataset_name == 'MNIST':# 灰度图像转换transform = transforms.Compose([transforms.ToTensor(),  # 转换为Tensor并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差])train_dataset = datasets.MNIST(data_dir, train=True, download=True, transform=transform)test_dataset = datasets.MNIST(data_dir, train=False, transform=transform)elif dataset_name == 'CIFAR10':# 彩色图像转换transform = transforms.Compose([transforms.ToTensor(),  # 转换为Tensor并归一化到[0,1]transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化到[-1,1]])train_dataset = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform)test_dataset = datasets.CIFAR10(data_dir, train=False, transform=transform)else:raise ValueError(f"Unsupported dataset: {dataset_name}")# 创建数据加载器train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=num_workers)test_loader = DataLoader(test_dataset,batch_size=batch_size,shuffle=False,num_workers=num_workers)return train_loader, test_loader# 2. 通用模型定义(支持灰度和彩色图像)
class Flatten(nn.Module):"""自定义展平层,保留batch维度"""def forward(self, x):return x.view(x.size(0), -1)  # 保留batch维度,展平其余维度class ImageClassifier(nn.Module):"""通用图像分类器,支持灰度和彩色图像"""def __init__(self,input_channels: int = 1,  # MNIST:1, CIFAR10:3input_size: int = 28,     # MNIST:28, CIFAR10:32hidden_size: int = 128,num_classes: int = 10,dropout_rate: float = 0.5):super().__init__()self.model = nn.Sequential(Flatten(),  # 展平除batch外的所有维度nn.Linear(input_channels * input_size * input_size, hidden_size),nn.ReLU(),nn.Dropout(dropout_rate),  # 训练时随机丢弃神经元nn.Linear(hidden_size, num_classes))def forward(self, x):return self.model(x)# 3. 训练函数
def train(model: nn.Module,train_loader: DataLoader,criterion: nn.Module,optimizer: optim.Optimizer,device: torch.device,epoch: int,log_interval: int = 100
) -> None:"""训练模型一个epoch"""model.train()  # 启用训练模式(激活dropout等)running_loss = 0.0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)# 前向传播optimizer.zero_grad()output = model(data)loss = criterion(output, target)# 反向传播loss.backward()optimizer.step()running_loss += loss.item()# 打印训练进度if batch_idx % log_interval == 0:print(f'Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} 'f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')# 打印平均损失avg_loss = running_loss / len(train_loader)print(f'Epoch {epoch} average loss: {avg_loss:.4f}')# 4. 测试函数
def test(model: nn.Module,test_loader: DataLoader,criterion: nn.Module,device: torch.device
) -> Tuple[float, float]:"""评估模型在测试集上的性能"""model.eval()  # 启用评估模式(关闭dropout等)test_loss = 0correct = 0with torch.no_grad():  # 不计算梯度,节省内存和计算资源for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()  # 累加批次损失pred = output.argmax(dim=1, keepdim=True)  # 获取最大概率的类别correct += pred.eq(target.view_as(pred)).sum().item()  # 统计正确预测数# 计算平均损失和准确率test_loss /= len(test_loader)accuracy = 100. * correct / len(test_loader.dataset)print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} 'f'({accuracy:.2f}%)\n')return test_loss, accuracy# 5. 主函数:训练和测试流程
def main(dataset_name: str = 'MNIST',batch_size: int = 64,epochs: int = 5,lr: float = 0.001,dropout_rate: float = 0.5,use_cuda: bool = True
) -> None:"""主函数:整合数据加载、模型训练和测试流程"""# 设置设备device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")print(f"Using device: {device}")# 获取数据加载器train_loader, test_loader = get_data_loaders(dataset_name=dataset_name,batch_size=batch_size)# 确定输入参数if dataset_name == 'MNIST':input_channels = 1input_size = 28num_classes = 10elif dataset_name == 'CIFAR10':input_channels = 3input_size = 32num_classes = 10else:raise ValueError(f"Unsupported dataset: {dataset_name}")# 初始化模型model = ImageClassifier(input_channels=input_channels,input_size=input_size,hidden_size=128,num_classes=num_classes,dropout_rate=dropout_rate).to(device)# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=lr)# 训练和测试循环for epoch in range(1, epochs + 1):train(model, train_loader, criterion, optimizer, device, epoch)test(model, test_loader, criterion, device)# 保存模型torch.save(model.state_dict(), f"{dataset_name}_mlp_model.pth")print(f"Model saved as: {dataset_name}_mlp_model.pth")if __name__ == "__main__":# 训练MNIST模型main(dataset_name='MNIST', batch_size=64, epochs=5)# 训练CIFAR10模型(取消注释下面一行)# main(dataset_name='CIFAR10', batch_size=64, epochs=10)    

@浙大疏锦行


文章转载自:

http://LS6R7mVL.kkLwz.cn
http://btC9FKoF.kkLwz.cn
http://GuGMJRiq.kkLwz.cn
http://isG7s9RR.kkLwz.cn
http://Baxkdpw9.kkLwz.cn
http://r4nb1zcS.kkLwz.cn
http://tg1QcCLB.kkLwz.cn
http://V83KOAbP.kkLwz.cn
http://lN3hexrV.kkLwz.cn
http://PztcUHhY.kkLwz.cn
http://g7fKBXKL.kkLwz.cn
http://nJqgNN1M.kkLwz.cn
http://M9i72eCF.kkLwz.cn
http://2oTVDm5S.kkLwz.cn
http://CZmgraPQ.kkLwz.cn
http://OsexekLA.kkLwz.cn
http://ImTVfT7i.kkLwz.cn
http://w4u1MJTE.kkLwz.cn
http://tOQENkEg.kkLwz.cn
http://EoeGofpB.kkLwz.cn
http://AqFAvywM.kkLwz.cn
http://RTMuThVh.kkLwz.cn
http://Yv7lUSdF.kkLwz.cn
http://1Olk9cWE.kkLwz.cn
http://8ttFr2Fm.kkLwz.cn
http://2nSJkgpY.kkLwz.cn
http://dBKfcbaj.kkLwz.cn
http://Ng6s06Lp.kkLwz.cn
http://2F6nzUYa.kkLwz.cn
http://9Lprc8Ug.kkLwz.cn
http://www.dtcms.com/wzjs/677046.html

相关文章:

  • 西安中交建设集团网站建设部监理资质申报网站
  • 湛江免费建站进入公众号即弹出图文
  • 国外流行的内容网站wordpress登陆后评论
  • 网站备案是否收费网站建设步骤 高清教 程
  • 网址查询网站上海亿网站建设
  • wordpress网站运行时间代码做家教网站赚钱么
  • 海东商城网站建设碗网站
  • 网站源码分享丹徒网站建设哪家好
  • 网站不需要什么备案凯盛建设公司网站
  • 品牌网站建设怎么做wordpress 设置伪静态后
  • 做网站怎样连数据库室内设计工作室简介
  • 网站建设设计简介品牌建设的好处
  • 怎么做网站的内链wordpress更新之后字体发生变化
  • wordpress 教垜东莞网站排名优化seo
  • 做网站彩票代理犯法吗iis wordpress rewrite
  • 男女做爰全过程的视频网站wordpress设置权限777
  • 南京百度做网站电话网站的维护怎么做
  • 淘宝官网首页版本湖南seo推广软件
  • 西安 做网站铁汉生态建设有限公司网站
  • 做网站需要记哪些代码wordpress内网使用
  • 宠物网站设计说明书东莞外贸网站的推广
  • 南软科技网站开发开源php建站系统
  • 昆山建设局网站首页wordpress减少请求次数
  • 网站数据库结构被删了怎么办网站搭建技术要求
  • wordpress网站生成app应用wordpress多设备网页生成
  • 网站开发人员绩效如何计算北京网站开发建设 58同城
  • 苏州外贸网站建设公司排名百度收录方法
  • 网站建设准备取消wordpress邮箱认证
  • 网站区域名是什么随州网站seo
  • 美食网站建设的内容分析建设网站的建筑公司