当前位置: 首页 > wzjs >正文

网站建设需要找工信部吗seo点击

网站建设需要找工信部吗,seo点击,网站闪图怎么做的,个人求职网站设计3 mnist gan小试牛刀 背景原理训练方法模型定义G模型定义D模型定义一些讨论 训练函数图像的预处理和训练过程结果输出 背景 之前已经做了手写数字的预测了,最近对生成方向比较感兴趣,就算我不是这个研究方向的人,但是我也听过gan和diffusion…

3 mnist gan小试牛刀

  • 背景
  • 原理
  • 训练方法
  • 模型定义
    • G模型定义
    • D模型定义
    • 一些讨论
  • 训练函数
  • 图像的预处理和训练过程
  • 结果输出

背景

之前已经做了手写数字的预测了,最近对生成方向比较感兴趣,就算我不是这个研究方向的人,但是我也听过gan和diffusion的大名,但在我的印象中gan是比较早提出来的,就从gan先开始玩。

原理

生成式对抗网络(Generative adversarial network, GAN) 的原理很简单,本质上就是训练两个网络一个判别器(Discriminator,D),一个生成器(Generator,G)

其中D的作用就是判断输入的图像是否是伪造的
G的作用就是根据随机噪声来生成图片,从而达到欺骗D的效果。

然后通过不断训练,使得G的生成效果越来越好,D的鉴别效果也越来越好

训练方法

既然D的核心目标,是判断输入图像的真伪,那么D的loss函数也很直观了

D的loss应该有两个部分来组成,一部分是真实图像所产生的loss,我们需要鼓励模型预测真是的图像为真,也就是说希望模型对于真实的图像输出的概率值越接近1越好

同时对于伪造的图像,肯定就是输出的概率值越接近0越好。

因此就可以用简单二元交叉熵损失函数来进行loss的计算了,这里给出简单的二元交叉熵损失函数的描述

其中,y 是真实标签,p 是模型预测的概率。
当真实标签 y 为1时,如果预测概率 p 接近1,则损失接近0;
反之,如果 p 接近0,则损失会变得非常大。
同理,当 y 为0时,如果 p 接近0,则损失接近0;
如果 p 接近1,则损失会变得非常大。

在计算过程中,我们将真实图像的y记为1,伪造图像的y记为0

我们记真实图像产生的loss是D_T_loss,伪造图像产生的误差是D_F_loss,那么D的loss就是

D_loss = D_T_loss + D_F_loss

而G的loss就更简单了,G的目标是欺骗D,是的D将本是伪造的图像预测成真实的图像,所以
我们只需要将伪造图像的y记为1,使得他往更加欺骗D的角度去训练,换句话说,就是我们希望伪造图像的预测值越来越接近1

模型定义

G模型定义

class Generator(nn.Module):def __init__(self, latent_dim=128, img_size=28, num_channels=1):super(Generator, self).__init__()self.img_size = img_sizeself.num_channels = num_channelsself.main = nn.Sequential(# 输入: 噪声 Z,维度为 [B, latent_dim, 1, 1] (经过全连接层后)# 经过全连接层将噪声向量转换为适合反卷积的维度nn.Linear(latent_dim, 256 * (img_size // 4) * (img_size // 4)),nn.BatchNorm1d(256 * (img_size // 4) * (img_size // 4)),nn.ReLU(True),# 将一维向量 reshape 成特征图nn.Unflatten(1, (256, img_size // 4, img_size // 4)), # 256 是通道数,img_size//4 是 H 和 W# 反卷积层 1: 256x7x7 -> 128x14x14nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.ReLU(True),# 反卷积层 2: 128x14x14 -> 64x28x28nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(64),nn.ReLU(True),# 反卷积层 3: 64x28x28 -> 1x28x28 (输出图像)nn.ConvTranspose2d(64, num_channels, kernel_size=3, stride=1, padding=1, bias=False),nn.Tanh() # 将像素值缩放到 [-1, 1])def forward(self, input):# 首先通过全连接层将噪声维度扩展x = self.main[0](input)x = self.main[1](x)x = self.main[2](x)# 然后进行 Unflattenx = self.main[3](x)# 接着进行反卷积x = self.main[4:](x)return x

D模型定义

class Discriminator(nn.Module):def __init__(self, img_channels=1, feature_maps_d=64):super(Discriminator, self).__init__()self.main = nn.Sequential(# 输入图像 (batch_size, img_channels, 28, 28)nn.Conv2d(img_channels, feature_maps_d, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# (batch_size, feature_maps_d, 14, 14)nn.Conv2d(feature_maps_d, feature_maps_d * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(feature_maps_d * 2),nn.LeakyReLU(0.2, inplace=True),# (batch_size, feature_maps_d*2, 7, 7)nn.Conv2d(feature_maps_d * 2, feature_maps_d * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(feature_maps_d * 4),nn.LeakyReLU(0.2, inplace=True),# (batch_size, feature_maps_d*4, 3, 3) (或者 4x4 取决于具体计算)# 最终的卷积层输出到1nn.Conv2d(feature_maps_d * 4, 1, 3, 1, 0, bias=False), # 调整核大小以适应最终尺寸nn.Sigmoid() # 输出一个介于0和1之间的概率)def forward(self, input):# 确保输入是图像格式 (batch_size, channels, H, W)# 如果你的 dataloader 已经输出 (N, 1, 28, 28),则不需要 viewreturn self.main(input.view(input.size(0), 1, 28, 28)) # 对于28x28的单通道图像

一些讨论

因为目前只是用来生成mnist数据集的图片的,所以上面的维度都是按照Mnist数据集来的,我也尝试过只用mlp,我堆叠了5层Mlp来作为G和D,G的效果很不好,后来就老老实实用卷积了。当然我并没有充分去探索,可能不同的mlp堆叠方式效果可能会更好,毕竟attention的本质也就是三个Mlp加一个softmax。

训练函数


def train_gan(G, D, dataloader, epochs, device,lr_D=2e-4, lr_G=2e-4,beta1=0.5,noise_dim=128):criterion = nn.BCELoss() # 二元交叉熵损失optimizer_D = torch.optim.Adam(D.parameters(), lr=lr_D, betas=(beta1, 0.999))optimizer_G = torch.optim.Adam(G.parameters(), lr=lr_G, betas=(beta1, 0.999))print("开始GAN训练...")for epoch in range(epochs):D_loss_total = 0G_loss_total = 0loader_len = len(dataloader)for i, data in tqdm(enumerate(dataloader), desc=f"Epoch {epoch+1}/{epochs}", total=loader_len):real_images, _ = dataB = real_images.size(0)real_images = real_images.view(-1, 784).to(device) # 假设28x28图像展平为784# --- 训练判别器 (D) ---optimizer_D.zero_grad()# 使用真实图像进行训练real_outputs = D(real_images)# 对真实图像使用软标签,防止D过于自信real_labels = torch.full((B,), 0.9, device=device) # 真实图像的软标签errD_real = criterion(real_outputs.view(-1), real_labels)errD_real.backward()# 使用假图像进行训练noise = torch.randn(B, noise_dim, device=device)fake_images = G(noise)# print(fake_images.shape)# 在D的训练中,G的输出需要 .detach(),这样G的梯度不会在这里计算fake_output = D(fake_images.detach())# 对D的假图像也使用软标签,但值更低fake_labels = torch.full((B,), 0.1, device=device) # 假图像的软标签errD_fake = criterion(fake_output.view(-1), fake_labels)errD_fake.backward()errD = errD_real + errD_fakeoptimizer_D.step()D_loss_total += errD.item()# --- 训练生成器 (G) ---optimizer_G.zero_grad()# 生成新的假图像,避免使用D训练时旧的图像noise = torch.randn(B, noise_dim, device=device)fake_images = G(noise)# D对G生成的图像的输出。这里至关重要,不要使用 detach()。output = D(fake_images)# G希望D将假图像分类为真实(目标标签为1.0或软标签)# 这里我们使用完整的1.0作为G的目标,以强烈鼓励它欺骗DerrG = criterion(output.view(-1), torch.full((B,), 1.0, device=device))errG.backward()optimizer_G.step()G_loss_total += errG.item()# --- Epoch 结束总结 ---d_loss_avg = D_loss_total / loader_len# 根据G的更新次数调整平均损失g_loss_avg = G_loss_total / loader_lenprint(f"Epoch [{epoch+1}/{epochs}] "f"D Loss: {d_loss_avg:.4f}, G Loss: {g_loss_avg:.4f}")

我觉得我的注释已经写的很详细了,就不进行赘述了

图像的预处理和训练过程

from torchvision import transformsbatch_size = 256
epochs = 100transform = transforms.Compose([transforms.ToTensor(),  # 转换为 Tensortransforms.Normalize((0.1307,), (0.3081,))  # 标准化
])# 下载 MNIST 数据集
mnist_train = MNIST(root='../dataset_file/mnist_raw', train=True, download=False,transform=transform)
dataloader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)# 初始化生成器和判别器
G = Generator().to(device)
D = Discriminator().to(device)# 训练GAN
train_gan(G, D, dataloader, epochs=epochs, device=device)

在这里插入图片描述

结果输出

#检测生成质量
import matplotlib.pyplot as pltdef generate_and_plot(G, num_images=16):noise = torch.randn(num_images, 128).to(device)with torch.no_grad():fake_images = G(noise).view(-1, 1, 28, 28).cpu()grid_size = int(num_images**0.5)fig, axes = plt.subplots(grid_size, grid_size, figsize=(8, 8))for i in range(grid_size):for j in range(grid_size):idx = i * grid_size + jif idx < num_images:axes[i, j].imshow(fake_images[idx].squeeze(), cmap='gray')axes[i, j].axis('off')plt.tight_layout()
generate_and_plot(G, num_images=16)

在这里插入图片描述
在这里插入图片描述

生成出来的还是有模有样的。

http://www.dtcms.com/wzjs/149777.html

相关文章:

  • 网站建设制作方式有哪些网络推广费用计入什么科目
  • 物流企业网站建设步骤数据分析网站
  • 网站建设未完成短视频营销的发展趋势
  • wordpress外贸商城主题东莞网站建设优化诊断
  • 瑞安企业做网站网络广告营销典型案例
  • 货运公共平台市场推广seo职位描述
  • 资讯网站老哥们给个关键词
  • 网站注册转化率搜索引擎优化seo的英文全称是
  • 做网站语言搜索引擎的四个组成部分及作用
  • 女频做的最好的网站竞价推广课程
  • 做网站还需要搜狗吗接单平台app
  • 做代码和网站色盲眼中的世界
  • 甜品网站开发需求分析推广平台排行榜有哪些
  • 网页制作与发布的流程泉州seo
  • 广西南宁网站空间搜索量排名
  • 佛山网站建设外包公司宜兴百度推广公司
  • 一佰互联自助建站培训网站建设
  • 政府采购网上商城入围重庆企业seo
  • 怎么生成网站源代码互联网销售是什么意思
  • 西安网站建设联系方式知识营销成功案例介绍
  • 湖北交投建设集团集团网站网站生成器
  • b站推广网站mmm的推荐机制成crm软件
  • 全网营销型网站新闻手机百度免费下载
  • 上传图片的网站要怎么做网络营销课程作业
  • 南京做网站xjrkj品牌营销策略四种类型
  • 国内移动端网站做的最好的2345网址大全
  • 建设公司与建筑公司的区别seo实战优化
  • 郑州网站建设技术外包网站seo哪家好
  • 网站建设销售怎么样找seo外包公司需要注意什么
  • 新网站提交百度收录网页