当前位置: 首页 > news >正文

基于Fashion-MNIST的softmax回归-直接运行

引用 Fashion-MNIST 数据集进行分类问题训练,代码如下,可直接运行。

import torch
import torchvision
from torchvision import transforms
from torch.utils import data
import time


class Timer:
    """Record multiple run times.

    Defined in :numref:`subsec_linear_model`.
    """

    def __init__(self):
        # Timing starts immediately on construction.
        self.times = []
        self.start()

    def start(self):
        """Start the timer."""
        self.tik = time.time()

    def stop(self):
        """Stop the timer, record the elapsed time, and return it.

        Returns:
            float: seconds elapsed since the last ``start()``.
        """
        self.times.append(time.time() - self.tik)
        return self.times[-1]

    def avg(self):
        """Return the average of the recorded times."""
        return sum(self.times) / len(self.times)

    def sum(self):
        """Return the total of the recorded times."""
        return sum(self.times)

    def cumsum(self):
        """Return the cumulative sum of the recorded times as a list.

        Fix: the original called ``np.array(...).cumsum()`` but numpy is
        never imported in this file, so this method would raise
        NameError; ``itertools.accumulate`` gives the same result.
        """
        from itertools import accumulate
        return list(accumulate(self.times))


# Number of elements in a tensor (thin alias for ``Tensor.numel``).
size = lambda x, *args, **kwargs: x.numel(*args, **kwargs)
def reduce_sum(x, *args, **kwargs):
    """Sum the elements of ``x`` (thin alias for ``Tensor.sum``)."""
    return x.sum(*args, **kwargs)
def argmax(x, *args, **kwargs):
    """Index of the maximum element (thin alias for ``Tensor.argmax``)."""
    return x.argmax(*args, **kwargs)
# Cast a tensor's dtype (thin alias for ``Tensor.type``).
astype = lambda x, *args, **kwargs: x.type(*args, **kwargs)


def accuracy(y_hat, y):
    """Return the number of correct predictions as a float.

    Args:
        y_hat: predictions — either class indices, or per-class scores of
            shape (n, num_classes), in which case the argmax is taken.
        y: ground-truth class indices.
    """
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        # Scores -> predicted class index per row.
        y_hat = argmax(y_hat, axis=1)
    cmp = astype(y_hat, y.dtype) == y
    return float(reduce_sum(astype(cmp, y.dtype)))


def evaluate_accuracy(net, data_iter):
    """Compute the accuracy of ``net`` over the dataset ``data_iter``."""
    metric = Accumulator(2)  # no. of correct predictions, no. of predictions
    with torch.no_grad():
        for X, y in data_iter:
            metric.add(accuracy(net(X), y), size(y))
    return metric[0] / metric[1]


class Accumulator:
    """Accumulate running sums over ``n`` variables."""

    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        """Add one value per variable (values are coerced to float)."""
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        """Zero all accumulated sums."""
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def get_dataloader_workers():
    """Number of worker processes used to read the data."""
    return 4


def load_data_fashion_mnist(batch_size, resize=None):
    """Download Fashion-MNIST and return (train_iter, test_iter).

    Args:
        batch_size: examples per minibatch.
        resize: optional target image size applied before ``ToTensor``.
    """
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))


def softmax(X):
    """Row-wise softmax of a 2-D score tensor.

    Fix: subtract the per-row maximum before exponentiating so large
    logits cannot overflow ``exp``; the result is mathematically
    unchanged.
    """
    X_exp = torch.exp(X - X.max(dim=1, keepdim=True).values)
    partition = X_exp.sum(1, keepdim=True)
    return X_exp / partition


def net(X):
    """Softmax-regression model: flatten, affine transform, softmax.

    Reads the module-level parameters ``W`` and ``b``.
    """
    return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)


def cross_entropy(y_hat, y):
    """Per-example cross-entropy loss.

    Picks the predicted probability of the true class for each row, so
    the loss is independent of the number of classes.
    """
    probability = y_hat[range(len(y_hat)), y]
    return -torch.log(probability)


def train_epoch_ch3(net, train_iter, loss, updater):
    """Train ``net`` for one epoch; return (avg loss, train accuracy).

    Fix: the original ignored its ``updater`` argument and called the
    global ``sgd`` directly; ``updater`` is now actually used (the only
    caller passes ``sgd``, so behavior is unchanged).
    """
    # Sum of training loss, sum of correct predictions, no. of examples.
    metric = Accumulator(3)
    for X, y in train_iter:
        y_hat = net(X)
        l = loss(y_hat, y)
        l.sum().backward()
        # Custom optimizer step on the global parameters W and b;
        # X.shape[0] is the actual minibatch size.
        updater([W, b], lr, X.shape[0])
        metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    return metric[0] / metric[2], metric[1] / metric[2]


def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):
    """Train for ``num_epochs`` epochs, printing metrics after each one.

    Returns the (loss, accuracy) metrics of the final training epoch.
    """
    for epoch in range(num_epochs):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        print(f'epoch {epoch + 1}, loss={train_metrics[0]:.5f}, train_acc={train_metrics[1]:.5f}, test_acc={test_acc:.5f}')
    return train_metrics


def sgd(params, lr, batch_size):
    """Minibatch SGD: update each parameter in place, then zero its grad."""
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()


lr = 0.1  # learning rate
# ---- Training script ----
# Hyperparameters.
num_epochs = 10
batch_size = 256
num_inputs = 784   # 28 * 28 flattened grayscale image
num_outputs = 10   # ten Fashion-MNIST classes

# Model parameters: weights ~ N(0, 0.01), bias zero; both tracked by
# autograd so ``sgd`` can update them via ``.grad``.
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

# Build the data iterators (downloads Fashion-MNIST on first run),
# then time the whole training run.
train_iter, test_iter = load_data_fashion_mnist(batch_size)
timer = Timer()
train_loss, train_acc = train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, sgd)
print(f'train takes {timer.stop():.2f} sec, loss={train_loss:.5f}, train_acc={train_acc:.5f}')

运行结果:

epoch 1, loss=0.78627, train_acc=0.74832, test_acc=0.79040
epoch 2, loss=0.57115, train_acc=0.81300, test_acc=0.81000
epoch 3, loss=0.52513, train_acc=0.82650, test_acc=0.81280
epoch 4, loss=0.50200, train_acc=0.83142, test_acc=0.82520
epoch 5, loss=0.48449, train_acc=0.83667, test_acc=0.82270
epoch 6, loss=0.47343, train_acc=0.83960, test_acc=0.82230
epoch 7, loss=0.46612, train_acc=0.84238, test_acc=0.82810
epoch 8, loss=0.45860, train_acc=0.84573, test_acc=0.82260
epoch 9, loss=0.45168, train_acc=0.84740, test_acc=0.83220
epoch 10, loss=0.44680, train_acc=0.84780, test_acc=0.83150
train takes 84.08 sec, loss=0.44680, train_acc=0.84780

训练 10 个 epoch,可以看到总耗时约 84 秒(CPU);后续改用 GPU 后时间会更短,相关内容将结合另外一篇文章一并整合。随着训练的进行,损失函数逐渐减小,准确率逐渐增大。

相关文章:

  • 第3章 自动化测试:从单元测试到硬件在环(HIL)
  • 电子电路:到底该怎么理解电容器的“通交流阻直流”?
  • ElasticSearch 8.x新特性面试题
  • 使用Maven部署WebLogic应用
  • Ubuntu 添加系统调用
  • React中useDeferredValue与useTransition终极对比。
  • Spring-boot初次使用
  • redis的pipline使用结合线程池优化实战
  • 精益数据分析(63/126):移情阶段的深度潜入——从用户生活到产品渗透的全链路解析
  • linux——mysql高可用
  • 用 CodeBuddy 打造我的「TextBeautifier」文本美化引擎
  • SEO 优化实战:ZKmall模板商城的 B2C商城的 URL 重构与结构化数据
  • Webpack DefinePlugin插件介绍(允许在编译时创建JS全局常量,常量可以在源代码中直接使用)JS环境变量
  • TCP/UDP协议原理和区别 笔记
  • RAGFlow Arbitrary Account Takeover Vulnerability
  • python的漫画网站管理系统
  • 目标检测工作原理:从滑动窗口到Haar特征检测的完整实现
  • 现代健康养生新风尚
  • 【前端基础】10、CSS的伪元素(::first-line、::first-letter、::before、::after)【注:极简描述】
  • upload-labs通关笔记-第10关 文件上传之点多重过滤(空格点绕过)
  • 体坛联播|热刺追平单赛季输球纪录,世俱杯或创收20亿美元
  • 中国社联成立95周年,《中国社联期刊汇编》等研究丛书出版
  • 赡养纠纷个案推动类案监督,检察机关保障特殊群体胜诉权
  • 辽宁盘山县一乡镇幼儿园四名老师被指多次殴打一女童,均被行拘
  • 汕头违建豪宅“英之园”将强拆,当地:将根据公告期内具体情况采取下一步措施
  • 第1现场 | 美国称将取消制裁,对叙利亚意味着什么