当前位置: 首页 > news >正文

3 mnist gan小试牛刀

3 mnist gan小试牛刀

  • 背景
  • 原理
  • 训练方法
  • 模型定义
    • G模型定义
    • D模型定义
    • 一些讨论
  • 训练函数
  • 图像的预处理和训练过程
  • 结果输出

背景

之前已经做了手写数字的预测了,最近对生成方向比较感兴趣,就算我不是这个研究方向的人,但是我也听过gan和diffusion的大名,但在我的印象中gan是比较早提出来的,就从gan先开始玩。

原理

生成式对抗网络(Generative adversarial network, GAN) 的原理很简单,本质上就是训练两个网络一个判别器(Discriminator,D),一个生成器(Generator,G)

其中D的作用就是判断输入的图像是否是伪造的
G的作用就是根据随机噪声来生成图片,从而达到欺骗D的效果。

然后通过不断训练,使得G的生成效果越来越好,D的鉴别效果也越来越好

训练方法

既然D的核心目标,是判断输入图像的真伪,那么D的loss函数也很直观了

D的loss应该有两个部分来组成,一部分是真实图像所产生的loss,我们需要鼓励模型预测真是的图像为真,也就是说希望模型对于真实的图像输出的概率值越接近1越好

同时对于伪造的图像,肯定就是输出的概率值越接近0越好。

因此就可以用简单二元交叉熵损失函数来进行loss的计算了,这里给出简单的二元交叉熵损失函数的描述

其中,y 是真实标签,p 是模型预测的概率。
当真实标签 y 为1时,如果预测概率 p 接近1,则损失接近0;
反之,如果 p 接近0,则损失会变得非常大。
同理,当 y 为0时,如果 p 接近0,则损失接近0;
如果 p 接近1,则损失会变得非常大。

在计算过程中,我们将真实图像的y记为1,伪造图像的y记为0

我们记真实图像产生的loss是D_T_loss,伪造图像产生的误差是D_F_loss,那么D的loss就是

D_loss = D_T_loss + D_F_loss

而G的loss就更简单了,G的目标是欺骗D,是的D将本是伪造的图像预测成真实的图像,所以
我们只需要将伪造图像的y记为1,使得他往更加欺骗D的角度去训练,换句话说,就是我们希望伪造图像的预测值越来越接近1

模型定义

G模型定义

class Generator(nn.Module):def __init__(self, latent_dim=128, img_size=28, num_channels=1):super(Generator, self).__init__()self.img_size = img_sizeself.num_channels = num_channelsself.main = nn.Sequential(# 输入: 噪声 Z,维度为 [B, latent_dim, 1, 1] (经过全连接层后)# 经过全连接层将噪声向量转换为适合反卷积的维度nn.Linear(latent_dim, 256 * (img_size // 4) * (img_size // 4)),nn.BatchNorm1d(256 * (img_size // 4) * (img_size // 4)),nn.ReLU(True),# 将一维向量 reshape 成特征图nn.Unflatten(1, (256, img_size // 4, img_size // 4)), # 256 是通道数,img_size//4 是 H 和 W# 反卷积层 1: 256x7x7 -> 128x14x14nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.ReLU(True),# 反卷积层 2: 128x14x14 -> 64x28x28nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(64),nn.ReLU(True),# 反卷积层 3: 64x28x28 -> 1x28x28 (输出图像)nn.ConvTranspose2d(64, num_channels, kernel_size=3, stride=1, padding=1, bias=False),nn.Tanh() # 将像素值缩放到 [-1, 1])def forward(self, input):# 首先通过全连接层将噪声维度扩展x = self.main[0](input)x = self.main[1](x)x = self.main[2](x)# 然后进行 Unflattenx = self.main[3](x)# 接着进行反卷积x = self.main[4:](x)return x

D模型定义

class Discriminator(nn.Module):def __init__(self, img_channels=1, feature_maps_d=64):super(Discriminator, self).__init__()self.main = nn.Sequential(# 输入图像 (batch_size, img_channels, 28, 28)nn.Conv2d(img_channels, feature_maps_d, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# (batch_size, feature_maps_d, 14, 14)nn.Conv2d(feature_maps_d, feature_maps_d * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(feature_maps_d * 2),nn.LeakyReLU(0.2, inplace=True),# (batch_size, feature_maps_d*2, 7, 7)nn.Conv2d(feature_maps_d * 2, feature_maps_d * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(feature_maps_d * 4),nn.LeakyReLU(0.2, inplace=True),# (batch_size, feature_maps_d*4, 3, 3) (或者 4x4 取决于具体计算)# 最终的卷积层输出到1nn.Conv2d(feature_maps_d * 4, 1, 3, 1, 0, bias=False), # 调整核大小以适应最终尺寸nn.Sigmoid() # 输出一个介于0和1之间的概率)def forward(self, input):# 确保输入是图像格式 (batch_size, channels, H, W)# 如果你的 dataloader 已经输出 (N, 1, 28, 28),则不需要 viewreturn self.main(input.view(input.size(0), 1, 28, 28)) # 对于28x28的单通道图像

一些讨论

因为目前只是用来生成mnist数据集的图片的,所以上面的维度都是按照Mnist数据集来的,我也尝试过只用mlp,我堆叠了5层Mlp来作为G和D,G的效果很不好,后来就老老实实用卷积了。当然我并没有充分去探索,可能不同的mlp堆叠方式效果可能会更好,毕竟attention的本质也就是三个Mlp加一个softmax。

训练函数


def train_gan(G, D, dataloader, epochs, device,lr_D=2e-4, lr_G=2e-4,beta1=0.5,noise_dim=128):criterion = nn.BCELoss() # 二元交叉熵损失optimizer_D = torch.optim.Adam(D.parameters(), lr=lr_D, betas=(beta1, 0.999))optimizer_G = torch.optim.Adam(G.parameters(), lr=lr_G, betas=(beta1, 0.999))print("开始GAN训练...")for epoch in range(epochs):D_loss_total = 0G_loss_total = 0loader_len = len(dataloader)for i, data in tqdm(enumerate(dataloader), desc=f"Epoch {epoch+1}/{epochs}", total=loader_len):real_images, _ = dataB = real_images.size(0)real_images = real_images.view(-1, 784).to(device) # 假设28x28图像展平为784# --- 训练判别器 (D) ---optimizer_D.zero_grad()# 使用真实图像进行训练real_outputs = D(real_images)# 对真实图像使用软标签,防止D过于自信real_labels = torch.full((B,), 0.9, device=device) # 真实图像的软标签errD_real = criterion(real_outputs.view(-1), real_labels)errD_real.backward()# 使用假图像进行训练noise = torch.randn(B, noise_dim, device=device)fake_images = G(noise)# print(fake_images.shape)# 在D的训练中,G的输出需要 .detach(),这样G的梯度不会在这里计算fake_output = D(fake_images.detach())# 对D的假图像也使用软标签,但值更低fake_labels = torch.full((B,), 0.1, device=device) # 假图像的软标签errD_fake = criterion(fake_output.view(-1), fake_labels)errD_fake.backward()errD = errD_real + errD_fakeoptimizer_D.step()D_loss_total += errD.item()# --- 训练生成器 (G) ---optimizer_G.zero_grad()# 生成新的假图像,避免使用D训练时旧的图像noise = torch.randn(B, noise_dim, device=device)fake_images = G(noise)# D对G生成的图像的输出。这里至关重要,不要使用 detach()。output = D(fake_images)# G希望D将假图像分类为真实(目标标签为1.0或软标签)# 这里我们使用完整的1.0作为G的目标,以强烈鼓励它欺骗DerrG = criterion(output.view(-1), torch.full((B,), 1.0, device=device))errG.backward()optimizer_G.step()G_loss_total += errG.item()# --- Epoch 结束总结 ---d_loss_avg = D_loss_total / loader_len# 根据G的更新次数调整平均损失g_loss_avg = G_loss_total / loader_lenprint(f"Epoch [{epoch+1}/{epochs}] "f"D Loss: {d_loss_avg:.4f}, G Loss: {g_loss_avg:.4f}")

我觉得我的注释已经写的很详细了,就不进行赘述了

图像的预处理和训练过程

from torchvision import transformsbatch_size = 256
epochs = 100transform = transforms.Compose([transforms.ToTensor(),  # 转换为 Tensortransforms.Normalize((0.1307,), (0.3081,))  # 标准化
])# 下载 MNIST 数据集
mnist_train = MNIST(root='../dataset_file/mnist_raw', train=True, download=False,transform=transform)
dataloader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)# 初始化生成器和判别器
G = Generator().to(device)
D = Discriminator().to(device)# 训练GAN
train_gan(G, D, dataloader, epochs=epochs, device=device)

在这里插入图片描述

结果输出

#检测生成质量
import matplotlib.pyplot as pltdef generate_and_plot(G, num_images=16):noise = torch.randn(num_images, 128).to(device)with torch.no_grad():fake_images = G(noise).view(-1, 1, 28, 28).cpu()grid_size = int(num_images**0.5)fig, axes = plt.subplots(grid_size, grid_size, figsize=(8, 8))for i in range(grid_size):for j in range(grid_size):idx = i * grid_size + jif idx < num_images:axes[i, j].imshow(fake_images[idx].squeeze(), cmap='gray')axes[i, j].axis('off')plt.tight_layout()
generate_and_plot(G, num_images=16)

在这里插入图片描述
在这里插入图片描述

生成出来的还是有模有样的。

相关文章:

  • 6.11 打卡
  • 亚马逊商品数据实时获取方案:API 接口开发与安全接入实践
  • Jenkins + Docker + Kubernetes(JKD)在 DevOps CI/CD 中的核心价值与实践要点
  • 鹰盾Win播放器作为专业的视频安全解决方案,除了硬件翻录外还有什么呢?
  • 网络安全中对抗性漂移的多智能体强化学习
  • R语言缓释制剂QBD解决方案之二
  • 微信小程序分享带参数地址
  • 网传西门子12亿美元收购云原生工业软件,云化PLM系统转机在协同
  • UniApp APP打包方法(Android/iOS双平台)
  • iOS 26 beta1 重新禁止 JIT 执行,Flutter 下的 iOS 真机 hot load 暂时无法使用
  • React Native 跨平台开发:iOS 与安卓原生模块高效交互
  • 腾讯开源 ovCompose 跨平台框架:实现一次跨三端(Android/iOS/鸿蒙)
  • 前端实现ios26最新液态玻璃效果!
  • 【云原生】阿里云SLS日志自定义字段标签实现日志告警
  • MatAnyone本地部署,视频分割处理,绿幕抠像(WIN/MAC)
  • 数据可视化新姿势:Altair的声明式魔法
  • PyTorch:让深度学习飞入寻常百姓家(从零开始玩转张量与神经网络!)
  • MFE微前端基础版:Angular + Module Federation + webpack + 路由(Route way)完整示例
  • Mac 上使用 mysql -u root -p 命令,出现“zsh: command not found: mysql“?如何解决
  • 11.TCP三次握手
  • 中山企业网站建设公司/今天最新疫情情况
  • 反恶意镜像网站/网络营销是什么
  • 网站建设总流程/客户引流的最快方法是什么
  • wordpress备份网站/软文推广发布平台
  • 网站模板套用/谁有恶意点击软件
  • 1668阿里巴巴官网/外贸网站优化推广