
PyTorch demo of a Variational State-Space Model (VSSM) on the MNIST dataset

The script below builds a simplified VSSM that jointly reconstructs and classifies MNIST digits: an encoder maps each image to a Gaussian latent state, the reparameterization trick samples from it, a transition network predicts the next latent state, and a decoder and classifier read out from the latent space.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os

# Set the random seed so results are reproducible
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the directory for saved visualizations
os.makedirs('visualizations', exist_ok=True)

# Data loading and preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

# A moderate batch size is chosen here; a larger one may run out of GPU memory
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Define a simplified VSSM model
class VSSM(nn.Module):
    def __init__(self, input_size=784, hidden_size=32, state_size=16, output_size=10):
        super(VSSM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.state_size = state_size
        self.output_size = output_size
        # Encoder network - maps the input to a hidden representation
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU()
        )
        # Variational inference heads - produce the latent state's mean and log-variance
        self.fc_mu = nn.Linear(hidden_size, state_size)
        self.fc_logvar = nn.Linear(hidden_size, state_size)
        # State-transition network - predicts the next latent state
        self.transition = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, state_size)
        )
        # Decoder network - reconstructs the input from the latent state
        self.decoder = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_size)
        )
        # Classifier network - predicts the class from the latent state
        self.classifier = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),  # Dropout to reduce overfitting
            nn.Linear(hidden_size, output_size)
        )

    def encode(self, x):
        # x: [batch_size, input_size]
        h = self.encoder(x)
        mu = self.fc_mu(h)          # [batch_size, state_size]
        logvar = self.fc_logvar(h)  # [batch_size, state_size]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: differentiable sampling of the latent variable
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std  # [batch_size, state_size]

    def decode(self, z):
        # z: [batch_size, state_size]
        return self.decoder(z)  # [batch_size, input_size]

    def classify(self, z):
        # z: [batch_size, state_size]
        return self.classifier(z)  # [batch_size, output_size]

    def forward(self, x):
        # x: [batch_size, 1, 28, 28]
        batch_size = x.size(0)
        x_flat = x.view(batch_size, -1)  # [batch_size, 784]
        # Encode and sample the latent state
        mu, logvar = self.encode(x_flat)
        z = self.reparameterize(mu, logvar)
        # State transition
        z_next = self.transition(z)
        # Decode and classify
        recon_flat = self.decode(z_next)
        pred = self.classify(z)
        return recon_flat, pred, mu, logvar, z, x_flat
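As a quick aside (not part of the original post), a dummy forward pass makes the tensor shapes above concrete; the batch size of 4 and the underscore-prefixed names are arbitrary choices for this sketch.

# Sanity-check sketch (hypothetical, not in the original script): verify the
# shapes of the six forward-pass outputs with a random dummy batch.
_model = VSSM()
_dummy = torch.randn(4, 1, 28, 28)
_recon, _pred, _mu, _logvar, _z, _x_flat = _model(_dummy)
assert _recon.shape == (4, 784)   # reconstruction of the flattened image
assert _pred.shape == (4, 10)     # class logits
assert _mu.shape == _logvar.shape == _z.shape == (4, 16)  # latent stats and sample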
# Define the VSSM loss function
def vssm_loss(recon_x, x, pred, target, mu, logvar, lambda_kl=0.1, lambda_cls=1.0):
    # Reconstruction loss - difference between the reconstruction and the original image
    recon_loss = F.mse_loss(recon_x, x.view(x.size(0), -1), reduction='sum')
    # KL divergence - distance between the latent distribution and a standard normal
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Classification loss
    cls_loss = F.cross_entropy(pred, target, reduction='sum')
    # Total loss
    batch_size = x.size(0)
    total_loss = (recon_loss + lambda_kl * kl_loss + lambda_cls * cls_loss) / batch_size
    return total_loss, recon_loss.item() / batch_size, kl_loss.item() / batch_size, cls_loss.item() / batch_size
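The kl_loss line is the closed-form KL divergence between the diagonal Gaussian posterior N(μ, σ²) and the standard-normal prior: KL = -½ Σⱼ (1 + log σⱼ² - μⱼ² - σⱼ²). A minimal check (again, not in the original script) confirms the term vanishes exactly when the posterior equals the prior, i.e. μ = 0 and log σ² = 0.

# Illustrative check (hypothetical): the KL term is zero for mu = 0, logvar = 0.
_mu = torch.zeros(4, 16)
_logvar = torch.zeros(4, 16)
_kl = -0.5 * torch.sum(1 + _logvar - _mu.pow(2) - _logvar.exp())
assert _kl.item() == 0.0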
# Plot the training and test loss curves
def pltLoss(train_losses, test_losses, epochs):
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, epochs + 1), train_losses, 'b-', label='Training Loss')
    plt.plot(range(1, epochs + 1), test_losses, 'r-', label='Test Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Test Loss')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig('loss_curve.png')
    plt.close()
# Visualize the most confident test sample and its prediction
def plotTest(model, test_loader, device, epoch):
    model.eval()
    best_sample = None
    best_confidence = -1
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            # Forward pass to collect the intermediate results
            recon_flat, pred, mu, logvar, z, x_flat = model(data)
            # Prediction confidence per sample
            confidence = F.softmax(pred, dim=1).max(dim=1)[0]
            # Find the most confident sample in the batch
            max_idx = confidence.argmax().item()
            if confidence[max_idx] > best_confidence:
                best_confidence = confidence[max_idx].item()
                best_sample = {
                    'input': data[max_idx].cpu(),
                    'recon': recon_flat[max_idx].cpu().view(1, 28, 28),
                    'target': target[max_idx].cpu().item(),
                    'pred': pred[max_idx].argmax().cpu().item(),
                    'confidence': best_confidence,
                    'mu': mu[max_idx].cpu().numpy(),
                    'logvar': logvar[max_idx].cpu().numpy(),
                    'z': z[max_idx].cpu().numpy(),
                    'pred_dist': F.softmax(pred[max_idx], dim=0).cpu().numpy()
                }
            # Free tensors that are no longer needed to save GPU memory
            del data, target, recon_flat, pred, mu, logvar, z, x_flat, confidence, max_idx
            torch.cuda.empty_cache()
    if best_sample is not None:
        plt.figure(figsize=(12, 8))
        # 1. Original input image
        plt.subplot(2, 3, 1)
        plt.title(f'Input Image (True: {best_sample["target"]})')
        plt.imshow(best_sample['input'].squeeze().numpy(), cmap='gray')
        plt.axis('off')
        # 2. Reconstructed image
        plt.subplot(2, 3, 2)
        plt.title('Reconstructed Image')
        plt.imshow(best_sample['recon'].squeeze().numpy(), cmap='gray')
        plt.axis('off')
        # 3. Latent mean
        plt.subplot(2, 3, 3)
        plt.title('Latent Mean (μ)')
        plt.bar(range(len(best_sample['mu'])), best_sample['mu'])
        plt.xlabel('Dimension')
        plt.ylabel('Value')
        # 4. Latent log-variance
        plt.subplot(2, 3, 4)
        plt.title('Latent Log Variance (log σ²)')
        plt.bar(range(len(best_sample['logvar'])), best_sample['logvar'])
        plt.xlabel('Dimension')
        plt.ylabel('Value')
        # 5. Sampled latent variable
        plt.subplot(2, 3, 5)
        plt.title('Sampled Latent Variable (z)')
        plt.bar(range(len(best_sample['z'])), best_sample['z'])
        plt.xlabel('Dimension')
        plt.ylabel('Value')
        # 6. Predicted class distribution
        plt.subplot(2, 3, 6)
        plt.title(f'Prediction Distribution (Pred: {best_sample["pred"]}, Conf: {best_sample["confidence"]:.4f})')
        plt.bar(range(10), best_sample['pred_dist'])
        plt.xticks(range(10))
        plt.xlabel('Class')
        plt.ylabel('Probability')
        plt.tight_layout()
        plt.savefig(f'visualizations/epoch_{epoch}_best_sample.png')
        plt.close()

# Initialize the model, optimizer, and learning-rate scheduler
model = VSSM().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

# Training function
def train(model, train_loader, optimizer, epoch, device):
    model.train()
    train_loss = 0
    train_recon_loss = 0
    train_kl_loss = 0
    train_cls_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Forward pass - unpack all six return values
        recon, pred, mu, logvar, z, x_flat = model(data)
        # Compute the loss
        loss, recon_loss, kl_loss, cls_loss = vssm_loss(recon, data, pred, target, mu, logvar)
        # Backward pass and optimizer step
        loss.backward()
        optimizer.step()
        # Accumulate the losses
        train_loss += loss.item()
        train_recon_loss += recon_loss
        train_kl_loss += kl_loss
        train_cls_loss += cls_loss
        # Optionally free tensors that are no longer needed to save GPU memory
        # del data, target, recon, pred, mu, logvar, z, x_flat, loss, recon_loss, kl_loss, cls_loss
        # torch.cuda.empty_cache()
        # Print training progress
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
    # Average the losses over batches
    avg_loss = train_loss / len(train_loader)
    avg_recon_loss = train_recon_loss / len(train_loader)
    avg_kl_loss = train_kl_loss / len(train_loader)
    avg_cls_loss = train_cls_loss / len(train_loader)
    print(f'Epoch: {epoch} Average training loss: {avg_loss:.4f} '
          f'(Recon: {avg_recon_loss:.4f}, KL: {avg_kl_loss:.4f}, Cls: {avg_cls_loss:.4f})')
    return avg_loss

# Test function
def test(model, test_loader, device):
    model.eval()
    test_loss = 0
    test_recon_loss = 0
    test_kl_loss = 0
    test_cls_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            # Forward pass - unpack all six return values
            recon, pred, mu, logvar, z, x_flat = model(data)
            # Compute the loss
            loss, recon_loss, kl_loss, cls_loss = vssm_loss(recon, data, pred, target, mu, logvar)
            # Accumulate the losses
            test_loss += loss.item()
            test_recon_loss += recon_loss
            test_kl_loss += kl_loss
            test_cls_loss += cls_loss
            # Classification accuracy
            pred_class = pred.argmax(dim=1, keepdim=True)
            correct += pred_class.eq(target.view_as(pred_class)).sum().item()
            # Optionally free tensors that are no longer needed to save GPU memory
            # del data, target, recon, pred, mu, logvar, z, x_flat, loss, recon_loss, kl_loss, cls_loss, pred_class
            # torch.cuda.empty_cache()
    # Average losses and accuracy
    avg_loss = test_loss / len(test_loader)
    avg_recon_loss = test_recon_loss / len(test_loader)
    avg_kl_loss = test_kl_loss / len(test_loader)
    avg_cls_loss = test_cls_loss / len(test_loader)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'Average test loss: {avg_loss:.4f} '
          f'(Recon: {avg_recon_loss:.4f}, KL: {avg_kl_loss:.4f}, Cls: {avg_cls_loss:.4f})')
    print(f'Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)')
    return avg_loss, accuracy

# Main training loop
epochs = 10
train_losses = []
test_losses = []
best_accuracy = 0.0

for epoch in range(1, epochs + 1):
    print(f'\nEpoch {epoch}/{epochs}')
    # Train for one epoch
    train_loss = train(model, train_loader, optimizer, epoch, device)
    train_losses.append(train_loss)
    # Evaluate on the test set
    test_loss, accuracy = test(model, test_loader, device)
    test_losses.append(test_loss)
    # Visualize the most confident sample
    plotTest(model, test_loader, device, epoch)
    # Adjust the learning rate
    scheduler.step(test_loss)
    # Save the best model
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), 'best_model.pth')
        print(f'Best model saved with accuracy: {accuracy:.2f}%')
    # Plot the loss curves
    pltLoss(train_losses, test_losses, epoch)
    # Free GPU memory that is no longer needed
    torch.cuda.empty_cache()

print(f'\nTraining completed. Best accuracy: {best_accuracy:.2f}%')
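Because the decoder doubles as a generator, a natural follow-up is to sample latent states from the standard-normal prior and decode them. The sketch below is not part of the original post: it assumes the loop above has saved best_model.pth, and it undoes the (0.1307, 0.3081) MNIST normalization before display.

# Generation sketch (hypothetical follow-up, not in the original script):
# sample z from the prior, apply the transition, and decode to images.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()
with torch.no_grad():
    z = torch.randn(16, 16, device=device)            # 16 samples, state_size = 16
    gen = model.decode(model.transition(z))           # [16, 784], normalized pixel space
    gen = gen.view(-1, 1, 28, 28) * 0.3081 + 0.1307   # undo the MNIST normalization
plt.figure(figsize=(6, 6))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow(gen[i].squeeze().cpu().numpy(), cmap='gray')
    plt.axis('off')
plt.tight_layout()
plt.savefig('visualizations/generated_samples.png')
plt.close()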