当前位置: 首页 > news >正文

【GAN网络入门系列】一,手写字MINST图片生成

在这里插入图片描述

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

博主简介:努力学习的22级本科生一枚 🌟​;探索AI算法,C++,go语言的世界;在迷茫中寻找光芒​🌸
博客主页:羊小猪~~-CSDN博客
内容简介GAN入门案例一,以生成minst字体为例。
GAN入门简介:GAN难度不小,本文打算更新三篇文章入门GAN,第一篇以知道什么是GAN(判别器、生成器),以手写生成字体为例;第二篇是GAN论文精度;第三篇是GAN人脸生成。

文章目录

  • 1、理论基础
    • 生成器
    • 判别器
    • 简单举例
  • 2、问题提出
  • 3、模型搭建(以mnist字为例)
    • 1、准备
      • 1、导入库与定义参数
      • 2、下载数据
      • 3、加载数据
    • 2、定义模型
      • 1、定义判别器
      • 2、定义生成器
    • 三、模型训练(概率分布)
      • 1、创建实例
      • 2、模型训练(核心)
      • 3、模型保存
  • 4、模型结果(生成图片)
  • 5、问题解决

1、理论基础

生成对抗网络(GAN)是一个很热门的方向—-AI生成图片,他不是一个具体的网络,而是指一种基于博弈思想设计的网络。

GAN网络有两部分组成,生成器和判别器,其中:

  • 生成器:从某种噪声分布(一般是正态分布)中随机采用作为输入,输出与训练集中真实样本非常相识的人工样本。
  • 判别器(分类器):输入真实样本或者人工样本数据,而判别器的作用就是将人工样本和真实样本分离出来。

👍 博弈

生成器和判别器交替运行,相互博弈,生成器的任务是如何让判别器判别不出来是人工样本还是真实样本,判别器的任务是尽最大努力将真实样本和人工样本分离出来;理想情况下,最后结果是判别器无法判别给定样本的真实性,即输出的样本中50%是真,50%是假,这个时候停止博弈,生成器就具有**“伪造样本”**的能力。

综上可以看出,GAN网络核心是生成器和判别器

生成器

📚 GANs中,生成器G选取随机噪声z作为输入,通过生成器不断拟合和优化,最总输出一个真实样本尺寸相同,分布相似的伪造样本G(z).

生成器本质

使用一个生成式方法模型,对数据的分布假设和参数进行学习,根据学习到的模型重新采样出新的样本。

📖 生成式方法

指那些能够学习数据分布,并基于该分布生产新数据的模型,如自回归模型(AR),基于前面的数据来生成下一个数据。

📘 数据的分布假设

指在进行数据分析的时候,对数据的分布形式做出的一种假设。比如说我们认为一组人的身高符合正态分布,这个时候就可以假设身高数据符合正态分布,就可以用正态分布的性质去求解问题。

😂 我查阅相关的资料发现,其实这个思想我们在高中、大学学的概率是一样的,我们在做概率题的时候,经常都会看到一个条件,“数据符合……分布”。因为对于相同的数据分布来说,他的统计值是很相似的,如:均值(衡量数据的平均水平)、方差(数据的分散程度)、峰值等,预测能力也是相同的。

📚分布参数

分布假设的参数,如假设数据符合状态分布,则标准差,方差就是分布参数

数学角度

首先对于给定的真实数据的显示变量(观察到的)或隐式变量(隐藏的,需要挖掘的变量)进行分布假设,然后将真实数据输入到模型中,对变量和参数进行训练,最后学习到一个假设的进似分布,可以用来生产新数据。

机器学习角度

模型不会直接生成这个假设,而是通过不断学习新数据,对模型进行修正,不断的进行优化,最后到达目标。

判别器

目的:判别输入数据的真伪。

判别器D对于输入样本x,输出一个[0, 1]之间的概率数值D(x)x是真实数据或者来着生成器的数据。

💁‍♂ 规定

D(x)越接近为1就代表样本为真值的可能性更大,反正伪造的样本可能性就越大。

简单举例

在这里插入图片描述

  • 首先第一代模型1G输入的是随机噪声z,然后根据生产模型会生成一张图片,判别器进行二分类操作,生成照片判别为0,生成照片判别为1;
  • 为欺瞒一代判别器,于是开始模型优化,就变成了二代,同样的对于判别器也会进行优化更新,重复第一步,知道判别器判别真假概率各位50%则停止。

2、问题提出

  • 如何进行分布假设?

  • 生成器、判别器如何进行模型参数优化更新?(核心)

  • 判别器如何判别真假?

3、模型搭建(以mnist字为例)

1、准备

1、导入库与定义参数

在这里插入图片描述

在这里插入图片描述

import argparse  # 命令行,参数解析
import os 
import numpy as np 
import torchvision.transforms as transforms
from torchvision.utils import save_image # 可保存为图片
from torch.utils.data import DataLoader
from torchvision import datasets 
from torch.autograd import Variable
import torch.nn as nn 
import torch# 创建文件夹
os.makedirs("./images/", exist_ok=True)  # 记录训练过程图片效果
os.makedirs("./save/", exist_ok=True)  # 训练完模型保存位置
os.makedirs("./datasets/mnist", exist_ok=True)  # 下载数据集保存位置## 设置超参数
n_epochs = 50  # 训练50轮
batch_size = 64  # 批次大小
lr = 0.0002  # 学习率
'''  
# b1、b2 通过矩估计,根据之前的梯度样本来调整现在的梯度,使梯度变化更加稳定
# 结合公式发现b1、b2大,这对于梯度大的跟新就小一点,梯度小的更新就大一点
'''
b1 = 0.5  # adam 一阶矩估计,决定了历史梯度对当前时刻影响程度
b2 = 0.999  # adam 二阶矩估计,影响了梯度平方的历史历史保留比例
n_cpu = 2  # 数据加载CPU数量
latent_dim = 100  # 随机向量维度,影响生成器
img_size = 28  # 图像大小
channels = 1 # 图片通道数
sample_interval = 500  # 保存图像间隔,这个参数决定了训练过程中多久保存一次生成图像# 图像尺寸
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)  # 每个维度相乘cuda = True if torch.cuda.is_available() else False 
cuda
True

2、下载数据

mnist = datasets.MNIST(root="./datasets", train=True,download=True,transform=transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),transforms.Normalize([0.5], [0.5])])
)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not FoundDownloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./datasets\MNIST\raw\train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [07:47<00:00, 21193.22it/s]
Extracting ./datasets\MNIST\raw\train-images-idx3-ubyte.gz to ./datasets\MNIST\rawDownloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not FoundDownloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./datasets\MNIST\raw\train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 29647.75it/s]
Extracting ./datasets\MNIST\raw\train-labels-idx1-ubyte.gz to ./datasets\MNIST\rawDownloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not FoundDownloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./datasets\MNIST\raw\t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:37<00:00, 43870.66it/s]
Extracting ./datasets\MNIST\raw\t10k-images-idx3-ubyte.gz to ./datasets\MNIST\rawDownloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not FoundDownloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./datasets\MNIST\raw\t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 1283728.35it/s]Extracting ./datasets\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./datasets\MNIST\raw

3、加载数据

dataloader = DataLoader(mnist,batch_size=batch_size,shuffle=True   # 随机打乱
)

2、定义模型

1、定义判别器

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(img_area, 512),   # 输入特征784, 输出为512nn.LeakyReLU(0.2, inplace=True),  # 激活函数nn.Linear(512, 256),   # 512 -> 256nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),   # 二分类激活函数nn.Sigmoid(),)def forward(self, img):img_flat = img.view(img.size(0), -1)v = self.model(img_flat)return v

2、定义生成器

# 输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维,
# 然后通过LeakyReLU激活函数,接着进行一个线性变换,再经过一个LeakyReLU激活函数,
# 然后经过线性变换将其变成784维,最后经过Tanh激活函数是希望生成的假的图片数据分布, 能够在-1~1之间。
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()# 模型中间块def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))  # 正则化layers.append(nn.LeakyReLU(0.2, inplace=True))return layers # 定义分类模型self.model = nn.Sequential(*block(latent_dim, 128, normalize=False), # 不进行标准化*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, img_area),# 映射[-1, 1]nn.Tanh())def forward(self, z):   # 输入(64, 100)噪声imgs = self.model(z)  # 噪声通过生成器imgs = imgs.view(imgs.size(0), *img_shape)  # reshape(64, 1, 28, 28)return imgs 

三、模型训练(概率分布)

1、创建实例

generator = Generator()
discriminator = Discriminator()# 损失函数
loss_fn = torch.nn.BCELoss()# 定义优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))# 转移到显卡上
if torch.cuda.is_available():generator = generator.cuda()discriminator = discriminator.cuda()loss_fn = loss_fn.cuda()

2、模型训练(核心)

for epoch in range(n_epochs):for i, (imgs, _) in enumerate(dataloader):#-----------  判别器imgs = imgs.view(imgs.size(0), -1) # 展开real_img = Variable(imgs).cuda()   # Tensor->Variable放到计算图中,可以自动求导real_label = Variable(torch.ones(imgs.size(0), 1)).cuda()  # 真实图片1fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()  # 真实图片1# ----------- 训练判别器# 1、计算真图片损失real_out = discriminator(real_img)  loss_real_D = loss_fn(real_out, real_label)  # 得到真实图片的lossreal_scores = real_out   # 得到真实图片的判别值tanh, 越接近1越好# 2、计算假图片损失 ***************z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()  # 随机生成噪声(128, 100)fake_img = generator(z).detach()fake_out = discriminator(fake_img)  # 生成数据判别loss_fake_D = loss_fn(fake_out, fake_label)  # 假的损失函数fake_scores = fake_out # 损失函数优化loss_D = loss_real_D + loss_fake_D # 真 + 假, 考虑到了真+ 假# 反向传播, 对判别器进行更新optimizer_D.zero_grad()loss_D.backward()optimizer_D.step()# ----------------  生成器训练# 原理:目的是希望生成的假的图片被判别器判断为真的图片,## 在此过程中,将判别器固定,将假的图片传入判别器的结果与真实的label对应,## 反向传播更新的参数是生成网络里面的参数,## 这样可以通过更新生成网络里面的参数,来训练网络,使得生成的图片让判别器以为是真的, 这样就达到了对抗的目的z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()fake_img = generator(z)  # 生成器output = discriminator(fake_img)# 损失和优化loss_G = loss_fn(output, real_label)# 梯度更新optimizer_G.zero_grad()loss_G.backward()optimizer_G.step()# 打印结果if (i + 1) % 300 == 0:print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"% (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean()))## 保存训练过程中的图像batches_done = epoch * len(dataloader) + iif batches_done % sample_interval == 0:save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True) # 生成数据保存
[Epoch 0/50] [Batch 299/938] [D loss: 1.026318] [G loss: 1.046470] [D real: 0.654074] [D fake: 0.441841]
[Epoch 0/50] [Batch 599/938] [D loss: 1.191877] [G loss: 0.912844] [D real: 0.458853] [D fake: 0.223538]
[Epoch 0/50] [Batch 899/938] [D loss: 0.938042] [G loss: 1.040793] [D real: 0.607376] [D fake: 0.302409]
[Epoch 1/50] [Batch 299/938] [D loss: 1.004832] [G loss: 2.220758] [D real: 0.825477] [D fake: 0.541220]
[Epoch 1/50] [Batch 599/938] [D loss: 1.010692] [G loss: 1.499485] [D real: 0.555188] [D fake: 0.226147]
[Epoch 1/50] [Batch 899/938] [D loss: 0.880904] [G loss: 1.338730] [D real: 0.626870] [D fake: 0.233363]
[Epoch 2/50] [Batch 299/938] [D loss: 0.829813] [G loss: 2.303998] [D real: 0.747940] [D fake: 0.371720]
[Epoch 2/50] [Batch 599/938] [D loss: 1.053018] [G loss: 2.094998] [D real: 0.830040] [D fake: 0.523504]
[Epoch 2/50] [Batch 899/938] [D loss: 0.878583] [G loss: 1.198429] [D real: 0.552428] [D fake: 0.116365]
[Epoch 3/50] [Batch 299/938] [D loss: 0.623464] [G loss: 2.115592] [D real: 0.798577] [D fake: 0.298132]
[Epoch 3/50] [Batch 599/938] [D loss: 0.849235] [G loss: 1.517444] [D real: 0.664473] [D fake: 0.281270]
[Epoch 3/50] [Batch 899/938] [D loss: 1.018560] [G loss: 2.523264] [D real: 0.762676] [D fake: 0.475363]
[Epoch 4/50] [Batch 299/938] [D loss: 0.675941] [G loss: 1.221385] [D real: 0.687409] [D fake: 0.177262]
[Epoch 4/50] [Batch 599/938] [D loss: 0.518201] [G loss: 2.030452] [D real: 0.820068] [D fake: 0.218934]
[Epoch 4/50] [Batch 899/938] [D loss: 0.596310] [G loss: 2.074541] [D real: 0.740501] [D fake: 0.182397]
[Epoch 5/50] [Batch 299/938] [D loss: 0.875196] [G loss: 3.143662] [D real: 0.874345] [D fake: 0.492405]
[Epoch 5/50] [Batch 599/938] [D loss: 0.724136] [G loss: 2.575390] [D real: 0.869490] [D fake: 0.398790]
[Epoch 5/50] [Batch 899/938] [D loss: 0.899812] [G loss: 1.146360] [D real: 0.563581] [D fake: 0.101017]
[Epoch 6/50] [Batch 299/938] [D loss: 0.702829] [G loss: 2.000169] [D real: 0.649199] [D fake: 0.068437]
[Epoch 6/50] [Batch 599/938] [D loss: 0.590536] [G loss: 2.533962] [D real: 0.798274] [D fake: 0.255741]
[Epoch 6/50] [Batch 899/938] [D loss: 0.620965] [G loss: 1.547438] [D real: 0.674020] [D fake: 0.088620]
[Epoch 7/50] [Batch 299/938] [D loss: 0.763007] [G loss: 2.397530] [D real: 0.820272] [D fake: 0.369547]
[Epoch 7/50] [Batch 599/938] [D loss: 0.642566] [G loss: 2.361846] [D real: 0.813566] [D fake: 0.277883]
[Epoch 7/50] [Batch 899/938] [D loss: 0.766600] [G loss: 1.826628] [D real: 0.725339] [D fake: 0.288866]
[Epoch 8/50] [Batch 299/938] [D loss: 0.780408] [G loss: 1.391555] [D real: 0.712020] [D fake: 0.251366]
[Epoch 8/50] [Batch 599/938] [D loss: 1.040399] [G loss: 1.866020] [D real: 0.762199] [D fake: 0.451866]
[Epoch 8/50] [Batch 899/938] [D loss: 1.095012] [G loss: 0.927824] [D real: 0.497888] [D fake: 0.120893]
[Epoch 9/50] [Batch 299/938] [D loss: 0.784348] [G loss: 1.687682] [D real: 0.692549] [D fake: 0.249467]
[Epoch 9/50] [Batch 599/938] [D loss: 0.990058] [G loss: 1.017130] [D real: 0.545202] [D fake: 0.119006]
[Epoch 9/50] [Batch 899/938] [D loss: 0.801848] [G loss: 1.602949] [D real: 0.687938] [D fake: 0.251169]
[Epoch 10/50] [Batch 299/938] [D loss: 0.791089] [G loss: 2.393065] [D real: 0.800895] [D fake: 0.392401]
[Epoch 10/50] [Batch 599/938] [D loss: 1.030113] [G loss: 1.646965] [D real: 0.738721] [D fake: 0.436578]
[Epoch 10/50] [Batch 899/938] [D loss: 0.981594] [G loss: 2.254743] [D real: 0.805582] [D fake: 0.486691]
[Epoch 11/50] [Batch 299/938] [D loss: 0.991982] [G loss: 1.162609] [D real: 0.594142] [D fake: 0.249730]
[Epoch 11/50] [Batch 599/938] [D loss: 0.906605] [G loss: 1.491534] [D real: 0.714087] [D fake: 0.357005]
[Epoch 11/50] [Batch 899/938] [D loss: 0.844362] [G loss: 1.179038] [D real: 0.718654] [D fake: 0.316384]
[Epoch 12/50] [Batch 299/938] [D loss: 0.863015] [G loss: 1.507195] [D real: 0.697577] [D fake: 0.311967]
[Epoch 12/50] [Batch 599/938] [D loss: 1.340321] [G loss: 0.789641] [D real: 0.402968] [D fake: 0.147501]
[Epoch 12/50] [Batch 899/938] [D loss: 0.778064] [G loss: 1.558621] [D real: 0.719452] [D fake: 0.301073]
[Epoch 13/50] [Batch 299/938] [D loss: 0.910379] [G loss: 1.211193] [D real: 0.591672] [D fake: 0.199390]
[Epoch 13/50] [Batch 599/938] [D loss: 1.063265] [G loss: 1.422406] [D real: 0.632027] [D fake: 0.371244]
[Epoch 13/50] [Batch 899/938] [D loss: 1.001968] [G loss: 1.666258] [D real: 0.689657] [D fake: 0.395144]
[Epoch 14/50] [Batch 299/938] [D loss: 0.989379] [G loss: 1.208383] [D real: 0.667937] [D fake: 0.383515]
[Epoch 14/50] [Batch 599/938] [D loss: 0.906237] [G loss: 1.960370] [D real: 0.761342] [D fake: 0.398946]
[Epoch 14/50] [Batch 899/938] [D loss: 0.825228] [G loss: 1.379719] [D real: 0.678601] [D fake: 0.290060]
[Epoch 15/50] [Batch 299/938] [D loss: 1.141459] [G loss: 2.423731] [D real: 0.853908] [D fake: 0.573443]
[Epoch 15/50] [Batch 599/938] [D loss: 1.029997] [G loss: 1.796759] [D real: 0.758488] [D fake: 0.466061]
[Epoch 15/50] [Batch 899/938] [D loss: 1.061821] [G loss: 0.838967] [D real: 0.528519] [D fake: 0.219860]
[Epoch 16/50] [Batch 299/938] [D loss: 0.962612] [G loss: 1.375448] [D real: 0.631942] [D fake: 0.286362]
[Epoch 16/50] [Batch 599/938] [D loss: 0.889674] [G loss: 1.145487] [D real: 0.661181] [D fake: 0.276140]
[Epoch 16/50] [Batch 899/938] [D loss: 1.162354] [G loss: 2.665335] [D real: 0.858547] [D fake: 0.568792]
[Epoch 17/50] [Batch 299/938] [D loss: 1.302850] [G loss: 0.698707] [D real: 0.428801] [D fake: 0.105473]
[Epoch 17/50] [Batch 599/938] [D loss: 0.842663] [G loss: 1.710078] [D real: 0.701462] [D fake: 0.302471]
[Epoch 17/50] [Batch 899/938] [D loss: 0.832509] [G loss: 1.283650] [D real: 0.698389] [D fake: 0.313110]
[Epoch 18/50] [Batch 299/938] [D loss: 1.368770] [G loss: 2.371649] [D real: 0.878667] [D fake: 0.666697]
[Epoch 18/50] [Batch 599/938] [D loss: 0.949422] [G loss: 1.067938] [D real: 0.567139] [D fake: 0.180208]
[Epoch 18/50] [Batch 899/938] [D loss: 0.816572] [G loss: 1.360575] [D real: 0.643414] [D fake: 0.225392]
[Epoch 19/50] [Batch 299/938] [D loss: 0.937279] [G loss: 1.885128] [D real: 0.763646] [D fake: 0.430748]
[Epoch 19/50] [Batch 599/938] [D loss: 0.974825] [G loss: 2.017726] [D real: 0.778439] [D fake: 0.470931]
[Epoch 19/50] [Batch 899/938] [D loss: 0.915921] [G loss: 1.154979] [D real: 0.595053] [D fake: 0.221935]
[Epoch 20/50] [Batch 299/938] [D loss: 0.881256] [G loss: 1.747743] [D real: 0.694174] [D fake: 0.341680]
[Epoch 20/50] [Batch 599/938] [D loss: 0.918638] [G loss: 1.552267] [D real: 0.702267] [D fake: 0.352388]
[Epoch 20/50] [Batch 899/938] [D loss: 0.951683] [G loss: 1.106557] [D real: 0.579497] [D fake: 0.215146]
[Epoch 21/50] [Batch 299/938] [D loss: 0.974893] [G loss: 1.930849] [D real: 0.803696] [D fake: 0.480613]
[Epoch 21/50] [Batch 599/938] [D loss: 0.930471] [G loss: 1.194343] [D real: 0.625553] [D fake: 0.278800]
[Epoch 21/50] [Batch 899/938] [D loss: 0.859748] [G loss: 1.238153] [D real: 0.707986] [D fake: 0.333503]
[Epoch 22/50] [Batch 299/938] [D loss: 0.983662] [G loss: 2.023248] [D real: 0.786604] [D fake: 0.449383]
[Epoch 22/50] [Batch 599/938] [D loss: 0.844784] [G loss: 1.927803] [D real: 0.792305] [D fake: 0.412590]
[Epoch 22/50] [Batch 899/938] [D loss: 0.939253] [G loss: 1.503554] [D real: 0.728900] [D fake: 0.408329]
[Epoch 23/50] [Batch 299/938] [D loss: 0.879789] [G loss: 0.887674] [D real: 0.580086] [D fake: 0.184930]
[Epoch 23/50] [Batch 599/938] [D loss: 0.764812] [G loss: 1.432183] [D real: 0.672108] [D fake: 0.214926]
[Epoch 23/50] [Batch 899/938] [D loss: 0.883344] [G loss: 1.886303] [D real: 0.846723] [D fake: 0.431878]
[Epoch 24/50] [Batch 299/938] [D loss: 1.251742] [G loss: 2.720743] [D real: 0.870346] [D fake: 0.613369]
[Epoch 24/50] [Batch 599/938] [D loss: 0.859271] [G loss: 1.429896] [D real: 0.660809] [D fake: 0.257913]
[Epoch 24/50] [Batch 899/938] [D loss: 1.149945] [G loss: 0.889781] [D real: 0.543300] [D fake: 0.254658]
[Epoch 25/50] [Batch 299/938] [D loss: 1.094652] [G loss: 1.214759] [D real: 0.605972] [D fake: 0.337315]
[Epoch 25/50] [Batch 599/938] [D loss: 0.797354] [G loss: 1.191305] [D real: 0.660583] [D fake: 0.224572]
[Epoch 25/50] [Batch 899/938] [D loss: 0.845965] [G loss: 1.268798] [D real: 0.626741] [D fake: 0.203728]
[Epoch 26/50] [Batch 299/938] [D loss: 1.115303] [G loss: 0.840342] [D real: 0.535978] [D fake: 0.193441]
[Epoch 26/50] [Batch 599/938] [D loss: 0.774557] [G loss: 1.810735] [D real: 0.781696] [D fake: 0.343132]
[Epoch 26/50] [Batch 899/938] [D loss: 0.895646] [G loss: 1.480375] [D real: 0.676262] [D fake: 0.313794]
[Epoch 27/50] [Batch 299/938] [D loss: 0.948866] [G loss: 1.774191] [D real: 0.781517] [D fake: 0.434177]
[Epoch 27/50] [Batch 599/938] [D loss: 0.857572] [G loss: 1.523750] [D real: 0.691116] [D fake: 0.303740]
[Epoch 27/50] [Batch 899/938] [D loss: 0.871821] [G loss: 1.407018] [D real: 0.676466] [D fake: 0.285737]
[Epoch 28/50] [Batch 299/938] [D loss: 0.994471] [G loss: 1.704868] [D real: 0.755193] [D fake: 0.421763]
[Epoch 28/50] [Batch 599/938] [D loss: 0.927082] [G loss: 1.568059] [D real: 0.628921] [D fake: 0.225705]
[Epoch 28/50] [Batch 899/938] [D loss: 0.859597] [G loss: 1.028539] [D real: 0.630817] [D fake: 0.248941]
[Epoch 29/50] [Batch 299/938] [D loss: 0.953265] [G loss: 1.643800] [D real: 0.725071] [D fake: 0.378147]
[Epoch 29/50] [Batch 599/938] [D loss: 0.881621] [G loss: 1.632039] [D real: 0.667327] [D fake: 0.277784]
[Epoch 29/50] [Batch 899/938] [D loss: 0.812341] [G loss: 1.214891] [D real: 0.727618] [D fake: 0.307845]
[Epoch 30/50] [Batch 299/938] [D loss: 0.928579] [G loss: 1.605430] [D real: 0.803836] [D fake: 0.434104]
[Epoch 30/50] [Batch 599/938] [D loss: 0.968729] [G loss: 0.863796] [D real: 0.586883] [D fake: 0.168837]
[Epoch 30/50] [Batch 899/938] [D loss: 1.047158] [G loss: 0.931751] [D real: 0.572705] [D fake: 0.217274]
[Epoch 31/50] [Batch 299/938] [D loss: 0.978278] [G loss: 1.163587] [D real: 0.594499] [D fake: 0.236582]
[Epoch 31/50] [Batch 599/938] [D loss: 0.874393] [G loss: 1.391558] [D real: 0.664588] [D fake: 0.276939]
[Epoch 31/50] [Batch 899/938] [D loss: 0.714764] [G loss: 1.503244] [D real: 0.735266] [D fake: 0.280823]
[Epoch 32/50] [Batch 299/938] [D loss: 0.926041] [G loss: 1.102411] [D real: 0.622352] [D fake: 0.214901]
[Epoch 32/50] [Batch 599/938] [D loss: 0.878116] [G loss: 1.503615] [D real: 0.684305] [D fake: 0.280049]
[Epoch 32/50] [Batch 899/938] [D loss: 0.933366] [G loss: 1.587198] [D real: 0.736988] [D fake: 0.381085]
[Epoch 33/50] [Batch 299/938] [D loss: 1.014969] [G loss: 1.378572] [D real: 0.576361] [D fake: 0.217423]
[Epoch 33/50] [Batch 599/938] [D loss: 0.948930] [G loss: 1.505417] [D real: 0.734145] [D fake: 0.396490]
[Epoch 33/50] [Batch 899/938] [D loss: 0.794515] [G loss: 2.165621] [D real: 0.780711] [D fake: 0.358258]
[Epoch 34/50] [Batch 299/938] [D loss: 0.923137] [G loss: 2.110215] [D real: 0.805725] [D fake: 0.443990]
[Epoch 34/50] [Batch 599/938] [D loss: 0.900147] [G loss: 1.770728] [D real: 0.708048] [D fake: 0.351760]
[Epoch 34/50] [Batch 899/938] [D loss: 0.918207] [G loss: 1.825978] [D real: 0.751272] [D fake: 0.379385]
[Epoch 35/50] [Batch 299/938] [D loss: 0.786494] [G loss: 1.572840] [D real: 0.753308] [D fake: 0.315357]
[Epoch 35/50] [Batch 599/938] [D loss: 0.926298] [G loss: 1.078452] [D real: 0.614426] [D fake: 0.223858]
[Epoch 35/50] [Batch 899/938] [D loss: 0.844453] [G loss: 1.427545] [D real: 0.661638] [D fake: 0.250853]
[Epoch 36/50] [Batch 299/938] [D loss: 0.841130] [G loss: 1.609914] [D real: 0.671988] [D fake: 0.251540]
[Epoch 36/50] [Batch 599/938] [D loss: 0.835743] [G loss: 1.123168] [D real: 0.690642] [D fake: 0.272213]
[Epoch 36/50] [Batch 899/938] [D loss: 0.878262] [G loss: 1.813486] [D real: 0.723274] [D fake: 0.328700]
[Epoch 37/50] [Batch 299/938] [D loss: 0.914168] [G loss: 1.344702] [D real: 0.642786] [D fake: 0.208098]
[Epoch 37/50] [Batch 599/938] [D loss: 0.851241] [G loss: 1.402598] [D real: 0.700463] [D fake: 0.274131]
[Epoch 37/50] [Batch 899/938] [D loss: 0.815197] [G loss: 1.269329] [D real: 0.632013] [D fake: 0.169207]
[Epoch 38/50] [Batch 299/938] [D loss: 0.608096] [G loss: 1.815545] [D real: 0.815722] [D fake: 0.265863]
[Epoch 38/50] [Batch 599/938] [D loss: 0.913200] [G loss: 1.506232] [D real: 0.628835] [D fake: 0.234755]
[Epoch 38/50] [Batch 899/938] [D loss: 1.190256] [G loss: 2.323765] [D real: 0.837308] [D fake: 0.554138]
[Epoch 39/50] [Batch 299/938] [D loss: 0.836248] [G loss: 1.411437] [D real: 0.643575] [D fake: 0.186915]
[Epoch 39/50] [Batch 599/938] [D loss: 0.682158] [G loss: 2.188884] [D real: 0.782492] [D fake: 0.269880]
[Epoch 39/50] [Batch 899/938] [D loss: 0.835626] [G loss: 1.458166] [D real: 0.682812] [D fake: 0.267175]
[Epoch 40/50] [Batch 299/938] [D loss: 0.815727] [G loss: 1.891578] [D real: 0.768537] [D fake: 0.332005]
[Epoch 40/50] [Batch 599/938] [D loss: 0.954156] [G loss: 1.477366] [D real: 0.674973] [D fake: 0.341626]
[Epoch 40/50] [Batch 899/938] [D loss: 1.165488] [G loss: 2.452061] [D real: 0.873572] [D fake: 0.580245]
[Epoch 41/50] [Batch 299/938] [D loss: 0.807535] [G loss: 1.703256] [D real: 0.688650] [D fake: 0.219471]
[Epoch 41/50] [Batch 599/938] [D loss: 0.918228] [G loss: 1.949707] [D real: 0.808916] [D fake: 0.444576]
[Epoch 41/50] [Batch 899/938] [D loss: 0.851099] [G loss: 1.622494] [D real: 0.720547] [D fake: 0.301786]
[Epoch 42/50] [Batch 299/938] [D loss: 0.990698] [G loss: 1.669534] [D real: 0.718162] [D fake: 0.371262]
[Epoch 42/50] [Batch 599/938] [D loss: 0.852659] [G loss: 1.764711] [D real: 0.772116] [D fake: 0.331731]
[Epoch 42/50] [Batch 899/938] [D loss: 0.815819] [G loss: 2.146849] [D real: 0.802263] [D fake: 0.374821]
[Epoch 43/50] [Batch 299/938] [D loss: 0.664832] [G loss: 1.936043] [D real: 0.815041] [D fake: 0.295924]
[Epoch 43/50] [Batch 599/938] [D loss: 0.642531] [G loss: 2.022935] [D real: 0.769789] [D fake: 0.232571]
[Epoch 43/50] [Batch 899/938] [D loss: 0.708858] [G loss: 2.054930] [D real: 0.883253] [D fake: 0.399093]
[Epoch 44/50] [Batch 299/938] [D loss: 0.984197] [G loss: 0.908327] [D real: 0.597017] [D fake: 0.170716]
[Epoch 44/50] [Batch 599/938] [D loss: 1.000634] [G loss: 1.766603] [D real: 0.867114] [D fake: 0.511092]
[Epoch 44/50] [Batch 899/938] [D loss: 1.116100] [G loss: 1.961709] [D real: 0.766508] [D fake: 0.442761]
[Epoch 45/50] [Batch 299/938] [D loss: 1.047572] [G loss: 2.178110] [D real: 0.819176] [D fake: 0.495287]
[Epoch 45/50] [Batch 599/938] [D loss: 0.948527] [G loss: 1.080285] [D real: 0.616301] [D fake: 0.185327]
[Epoch 45/50] [Batch 899/938] [D loss: 0.822393] [G loss: 1.724983] [D real: 0.737316] [D fake: 0.276954]
[Epoch 46/50] [Batch 299/938] [D loss: 0.771683] [G loss: 2.185272] [D real: 0.818269] [D fake: 0.359337]
[Epoch 46/50] [Batch 599/938] [D loss: 0.899866] [G loss: 1.661140] [D real: 0.749309] [D fake: 0.355079]
[Epoch 46/50] [Batch 899/938] [D loss: 0.839567] [G loss: 1.016037] [D real: 0.619461] [D fake: 0.178458]
[Epoch 47/50] [Batch 299/938] [D loss: 0.850783] [G loss: 1.755419] [D real: 0.733569] [D fake: 0.320683]
[Epoch 47/50] [Batch 599/938] [D loss: 0.842861] [G loss: 2.125519] [D real: 0.853034] [D fake: 0.412070]
[Epoch 47/50] [Batch 899/938] [D loss: 0.903069] [G loss: 0.903264] [D real: 0.599899] [D fake: 0.158276]
[Epoch 48/50] [Batch 299/938] [D loss: 0.778783] [G loss: 1.301862] [D real: 0.758963] [D fake: 0.328192]
[Epoch 48/50] [Batch 599/938] [D loss: 0.836319] [G loss: 1.851378] [D real: 0.798966] [D fake: 0.379225]
[Epoch 48/50] [Batch 899/938] [D loss: 0.987796] [G loss: 1.919121] [D real: 0.718209] [D fake: 0.348957]
[Epoch 49/50] [Batch 299/938] [D loss: 0.983706] [G loss: 1.484739] [D real: 0.635931] [D fake: 0.246885]
[Epoch 49/50] [Batch 599/938] [D loss: 0.831517] [G loss: 1.506502] [D real: 0.700573] [D fake: 0.243325]
[Epoch 49/50] [Batch 899/938] [D loss: 0.777098] [G loss: 1.839880] [D real: 0.760655] [D fake: 0.325862]

3、模型保存

torch.save(generator.state_dict(), "./save/generator.pth")
torch.save(discriminator.state_dict(), "./save/discriminator.pth")

4、模型结果(生成图片)

在这里插入图片描述

可以看出,后面效果好了不少。

5、问题解决

  • 如何进行分布假设?—>rand(正态分布)

  • 生成器、判别器如何进行模型参数优化更新?(核心)

    在这里插入图片描述

    在这里插入图片描述

  • 判别器如何判别真假?

    • 看判别器模型代码—->tanh

相关文章:

  • 前端取经路——量子UI:响应式交互新范式
  • 知识蒸馏实战:用PyTorch和预训练模型提升小模型性能
  • 【笔记】导出Conda环境依赖以复现项目虚拟环境
  • LLaMA-Factory:准备模型和数据集
  • 详解具身智能开源数据集:ARIO(All Robots In One)
  • Java 多线程基础:Thread 类核心用法详解
  • RabbitMQ 消息模式实战:从简单队列到复杂路由(三)
  • 第一次做逆向
  • LLaMA-Factory:环境准备
  • 【全解析】EN18031标准下的SSM安全存储机制
  • 小刚说C语言刷题—1700请输出所有的2位数中,含有数字2的整数
  • ubuntu22.04卸载vscode
  • C#实现访问远程硬盘(附源码)
  • 技术剖析|线性代数之特征值分解,支撑AI算法的数学原理
  • 在MYSQL中导入cookbook.sql文件
  • Chrome代理IP配置教程常见方式附问题解答
  • Android 中使用通知(Kotlin 版)
  • 若依框架Consul微服务版本
  • AI 笔记 -基于retinaface的FPN上采样替换为CARAFE
  • Vue2项目中使用videojs播放mp4视频
  • 习近平在第三十五个全国助残日到来之际作出重要指示
  • 中国军网:带你揭开3所新调整组建军队院校的神秘面纱
  • 自强!助残!全国200个集体和260名个人受到表彰
  • 商务部:今年前3月自贸试验区进出口总额达2万亿元
  • 魔都眼|锦江乐园摩天轮“换代”开拆,新摩天轮暂定118米
  • 新任国防部新闻发言人蒋斌正式亮相