当前位置: 首页 > news >正文

变分自编码器VAE的Pytorch实现

一、导入第三方库

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置中文字体
plt.rcParams['axes.unicode_minus'] = False   # 正常显示负号
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import Dataset,DataLoader

二、手写数字数据集准备

#手写数字数据集
class MINISTDataset(Dataset):def __init__(self,files,root_dir,transform=None):self.files=filesself.root_dir=root_dirself.transform=transformself.labels=[]for f in files:parts=f.split("_")p=parts[2].split(".")[0]self.labels.append(int(p))def __len__(self):return len(self.files)def __getitem__(self,idx):img_path=os.path.join(self.root_dir,self.files[idx])img=Image.open(img_path).convert("L")if self.transform:img=self.transform(img)label=self.labels[idx]return img,label

三、VAE模型的pytorch代码

#编码器
class Encoder(nn.Module):def __init__(self):super().__init__()self.conv1=nn.Sequential(nn.Conv2d(1,10,kernel_size=5),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.conv2=nn.Sequential(nn.Conv2d(10,20,kernel_size=5),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.fc1=nn.Linear(320,160)self.fc21=nn.Linear(160,80)  #均值self.fc22=nn.Linear(160,80)  #方差self.relu=nn.ReLU()def forward(self,x):batch_size=x.size(0)x=self.conv1(x)x=self.conv2(x)x=x.view(batch_size,-1)h=self.relu(self.fc1(x))mu=self.fc21(h)log_var=self.fc22(h)return mu,log_var#解码器
class Decoder(nn.Module):def __init__(self):super().__init__()self.main=nn.Sequential(nn.Linear(80,160),nn.ReLU(),nn.Linear(160,320),nn.ReLU(),nn.Linear(320,28*28),nn.Sigmoid())def forward(self,z):return self.main(z)#变分自编码器
class VAE(nn.Module):def __init__(self,encoder,decoder):super().__init__()self.encoder=encoderself.decoder=decoder#重参数化def reparameterize(self,mu,log_var):std=torch.exp(0.5*log_var)  #计算标准差eps=torch.randn_like(std)   #从标准正态分布中采样噪声z=mu+eps*std  #重参数化return zdef forward(self,x):mu,log_var=self.encoder(x)z=self.reparameterize(mu,log_var)return self.decoder(z),mu,log_var

四、主程序

if __name__=="__main__":#对数据做归一化处理transforms=transforms.Compose([transforms.Resize((28,28)),transforms.ToTensor()])#路径base_dir='C:\\Users\\Administrator\\PycharmProjects\\CNN'train_dir=os.path.join(base_dir,"minist_train")#获取文件夹里图像的名称train_files=[f for f in os.listdir(train_dir) if f.endswith('.jpg')]#创建数据集和数据加载器train_dataset=MINISTDataset(train_files,train_dir,transform=transforms)train_loader=DataLoader(train_dataset,batch_size=64,shuffle=True)#参数num_epochs=50lr=0.001#模型初始化encoder=Encoder()decoder=Decoder()vae=VAE(encoder,decoder)criterion=nn.BCELoss()optimizer=optim.Adam(vae.parameters(),lr=lr,betas=(0.5,0.999))#记录损失函数值epoch_loss=[]for epoch in range(num_epochs):total_loss=0.0for data in train_loader:images,_=data#images=images.view(images.size(0),-1)optimizer.zero_grad()outputs,mu,logvar=vae(images)#计算重构损失和KL散度reconstruction_loss=criterion(outputs,images.view(images.size(0),-1))kl_divergence=-0.5*torch.mean(1+logvar-mu.pow(2)-logvar.exp())loss=reconstruction_loss+0.1*kl_divergenceloss.backward()optimizer.step()total_loss+=loss.item()avg_loss=total_loss/len(train_loader)epoch_loss.append(avg_loss)print("Epoch",epoch,"  Loss:",avg_loss)#生成新图像with torch.no_grad():if (epoch+1)%5==0:z=torch.randn(9,80)plt.figure(figsize=(9,9))for i in range(9):plt.subplot(3,3,i+1)plt.imshow(decoder(z[i]).view(28,28),cmap="gray")plt.axis("off")name=f"vae_gen_img_{epoch}.jpg"gen_name=os.path.join("C:\\Users\\Administrator\\PycharmProjects\\CNN\\vae_gen_img",name)plt.savefig(gen_name,dpi=300)plt.close()#绘制损失函数曲线图plt.figure(figsize=(12,6))plt.plot(epoch_loss,color="tomato")plt.xlabel("epoch")plt.ylabel("loss")plt.title("损失函数曲线图")plt.legend()plt.grid()plt.savefig("C:\\Users\\Administrator\\PycharmProjects\\CNN\\vae_gen_loss.jpg")plt.close()

五、运行结果

5.1 损失函数曲线图

5.2 生成的图像

这里只展示一部分

vae_gen_img_4.jpg

vae_gen_img_29.jpg

vae_gen_img_49.jpg

六、VAE的完整代码

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置中文字体
plt.rcParams['axes.unicode_minus'] = False   # 正常显示负号
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import Dataset,DataLoader#手写数字数据集
class MINISTDataset(Dataset):def __init__(self,files,root_dir,transform=None):self.files=filesself.root_dir=root_dirself.transform=transformself.labels=[]for f in files:parts=f.split("_")p=parts[2].split(".")[0]self.labels.append(int(p))def __len__(self):return len(self.files)def __getitem__(self,idx):img_path=os.path.join(self.root_dir,self.files[idx])img=Image.open(img_path).convert("L")if self.transform:img=self.transform(img)label=self.labels[idx]return img,label#编码器
class Encoder(nn.Module):def __init__(self):super().__init__()self.conv1=nn.Sequential(nn.Conv2d(1,10,kernel_size=5),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.conv2=nn.Sequential(nn.Conv2d(10,20,kernel_size=5),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.fc1=nn.Linear(320,160)self.fc21=nn.Linear(160,80)  #均值self.fc22=nn.Linear(160,80)  #方差self.relu=nn.ReLU()def forward(self,x):batch_size=x.size(0)x=self.conv1(x)x=self.conv2(x)x=x.view(batch_size,-1)h=self.relu(self.fc1(x))mu=self.fc21(h)log_var=self.fc22(h)return mu,log_var#解码器
class Decoder(nn.Module):def __init__(self):super().__init__()self.main=nn.Sequential(nn.Linear(80,160),nn.ReLU(),nn.Linear(160,320),nn.ReLU(),nn.Linear(320,28*28),nn.Sigmoid())def forward(self,z):return self.main(z)#变分自编码器
class VAE(nn.Module):def __init__(self,encoder,decoder):super().__init__()self.encoder=encoderself.decoder=decoder#重参数化def reparameterize(self,mu,log_var):std=torch.exp(0.5*log_var)  #计算标准差eps=torch.randn_like(std)   #从标准正态分布中采样噪声z=mu+eps*std  #重参数化return zdef forward(self,x):mu,log_var=self.encoder(x)z=self.reparameterize(mu,log_var)return self.decoder(z),mu,log_varif __name__=="__main__":#对数据做归一化处理transforms=transforms.Compose([transforms.Resize((28,28)),transforms.ToTensor()])#路径base_dir='C:\\Users\\Administrator\\PycharmProjects\\CNN'train_dir=os.path.join(base_dir,"minist_train")#获取文件夹里图像的名称train_files=[f for f in os.listdir(train_dir) if f.endswith('.jpg')]#创建数据集和数据加载器train_dataset=MINISTDataset(train_files,train_dir,transform=transforms)train_loader=DataLoader(train_dataset,batch_size=64,shuffle=True)#参数num_epochs=50lr=0.001#模型初始化encoder=Encoder()decoder=Decoder()vae=VAE(encoder,decoder)criterion=nn.BCELoss()optimizer=optim.Adam(vae.parameters(),lr=lr,betas=(0.5,0.999))#记录损失函数值epoch_loss=[]for epoch in range(num_epochs):total_loss=0.0for data in train_loader:images,_=data#images=images.view(images.size(0),-1)optimizer.zero_grad()outputs,mu,logvar=vae(images)#计算重构损失和KL散度reconstruction_loss=criterion(outputs,images.view(images.size(0),-1))kl_divergence=-0.5*torch.mean(1+logvar-mu.pow(2)-logvar.exp())loss=reconstruction_loss+0.1*kl_divergenceloss.backward()optimizer.step()total_loss+=loss.item()avg_loss=total_loss/len(train_loader)epoch_loss.append(avg_loss)print("Epoch",epoch,"  Loss:",avg_loss)#生成新图像with torch.no_grad():if (epoch+1)%5==0:z=torch.randn(9,80)plt.figure(figsize=(9,9))for i in range(9):plt.subplot(3,3,i+1)plt.imshow(decoder(z[i]).view(28,28),cmap="gray")plt.axis("off")name=f"vae_gen_img_{epoch}.jpg"gen_name=os.path.join("C:\\Users\\Administrator\\PycharmProjects\\CNN\\vae_gen_img",name)plt.savefig(gen_name,dpi=300)plt.close()#绘制损失函数曲线图plt.figure(figsize=(12,6))plt.plot(epoch_loss,color="tomato")plt.xlabel("epoch")plt.ylabel("loss")plt.title("损失函数曲线图")plt.legend()plt.grid()plt.savefig("C:\\Users\\Administrator\\PycharmProjects\\CNN\\vae_gen_loss.jpg")plt.close()

http://www.dtcms.com/a/329277.html

相关文章:

  • 兰洋科技获第四届宁波市专利创新大赛殊荣,以液冷技术定义行业新标杆
  • 磁悬浮轴承转子动平衡:零接触旋转下的“隐形杀手”深度解析与精准猎杀指南
  • Java项目中地图功能如何创建
  • 使用 libpq 的 COPY 协议维护自定义 PG 到 PG 连接
  • 飞算JavaAI的中间件风暴:Redis + Kafka 全链路实战
  • WMware的安装以及Ubuntu22的安装
  • 自动驾驶中安全相关机器学习功能的可靠性定义方法
  • VirtualBox中的Ubuntu共享Windows的文件夹
  • 【Excel】被保护的文档如何显示隐藏的行或列
  • 厚铜PCB在百安级电流与高温环境中的关键作用
  • 普通电脑与云电脑的区别有哪些?全面科普
  • C++ 错误记录模块实现与解析
  • Redis:是什么、能做什么?
  • uniapp跨端性能优化方案
  • 各种排序算法(一)
  • Highcharts 图表示例|面积图与堆叠图(Area Stacked Chart)——让数据趋势更有层次感
  • SODA自然美颜相机(甜盐相机国际版) v9.3.0
  • LangChain是如何实现RAG多轮问答的
  • 【算法岗面试】手撕Self-Attention、Multi-head Attention
  • 比特币持有者结构性转变 XBIT分析BTC最新价格行情市场重构
  • 微店商品数据API接口的应用||电商API接口的应用
  • 数据结构与算法-选择题
  • 公司项目用户密码加密方案推荐(兼顾安全、可靠与通用性)
  • Chaos Vantage 2.8.1 发布:实时探索与材质工作流的全新突破
  • CacheBlend:结合缓存知识融合的快速RAG大语言模型推理服务
  • 大模型推理框架vLLM 中的Prompt缓存实现原理
  • 性能优化之通俗易懂学习requestAnimationFrame和使用场景举例
  • 来伊份×养馋记:社区零售4.0模式加速渗透上海市场
  • 四、深入剖析Java程序逻辑控制:从字节码到性能优化
  • MySQL事务原理分析以及隔离与锁