当前位置: 首页 > wzjs >正文

南昌设计政府网站的公司一键优化大师下载

南昌设计政府网站的公司,一键优化大师下载,想学电商运营在哪里学,步骤一WGAN-GP 原理及实现 一、WGAN-GP 原理1.1 WGAN-GP 核心原理1.2 WGAN-GP 实现步骤1.3 总结 二、WGAN-GP 实现2.1 导包2.2 数据加载和处理2.3 构建生成器2.4 构建判别器2.5 训练和保存模型2.6 图片转GIF 一、WGAN-GP 原理 Wasserstein GAN with Gradient Penalty (WGAN-GP) 是对…

WGAN-GP 原理及实现

  • 一、WGAN-GP 原理
    • 1.1 WGAN-GP 核心原理
    • 1.2 WGAN-GP 实现步骤
    • 1.3 总结
  • 二、WGAN-GP 实现
    • 2.1 导包
    • 2.2 数据加载和处理
    • 2.3 构建生成器
    • 2.4 构建判别器
    • 2.5 训练和保存模型
    • 2.6 图片转GIF

一、WGAN-GP 原理

Wasserstein GAN with Gradient Penalty (WGAN-GP) 是对原始 WGAN 的改进,通过梯度惩罚(Gradient Penalty)替代权重裁剪(Weight Clipping),解决了 WGAN 训练不稳定、权重裁剪导致梯度消失或爆炸的问题。


1.1 WGAN-GP 核心原理

(1) Wasserstein 距离(Earth-Mover 距离)

  • 原始 GAN 的 JS 散度在分布不重叠时梯度消失,而 WGAN 使用 Wasserstein 距离衡量生成分布 P g P_g Pg 和真实分布 P r P_r Pr 的距离:
    W ( P r , P g ) = inf ⁡ γ ∼ Π ( P r , P g ) E ( x , y ) ∼ γ [ ∥ x − y ∥ ] W(P_r, P_g) = \inf_{\gamma \sim \Pi(P_r, P_g)} \mathbb{E}_{(x,y)\sim \gamma} [\|x-y\|] W(Pr,Pg)=infγΠ(Pr,Pg)E(x,y)γ[xy]
  • 通过 Kantorovich-Rubinstein 对偶形式,转化为:
    W ( P r , P g ) = sup ⁡ ∥ D ∥ L ≤ 1 E x ∼ P r [ D ( x ) ] − E z ∼ P z [ D ( G ( z ) ) ] W(P_r, P_g) = \sup_{\|D\|_L \leq 1} \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim P_z}[D(G(z))] W(Pr,Pg)=supDL1ExPr[D(x)]EzPz[D(G(z))],其中 D D D 是 1-Lipschitz 函数(梯度范数不超过 1)

(2) 梯度惩罚(Gradient Penalty)

  • 原始 WGAN 的问题:通过权重裁剪强制判别器(Critic)满足 Lipschitz 约束,但会导致梯度不稳定或容量下降
  • WGAN-GP 的改进:直接对判别器的梯度施加惩罚项,强制其梯度范数接近 1: λ ⋅ E x ^ ∼ P x ^ \lambda \cdot \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} λEx^Px^ [ ( ∥ ∇ x ^ D ( x ^ ) ∥ 2 − 1 ) 2 ] \left [(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2 \right] [(x^D(x^)21)2]
    • x ^ \hat{x} x^ 是真实数据和生成数据的随机插值点: x ^ = ϵ x + ( 1 − ϵ ) G ( z ) \hat{x} = \epsilon x + (1-\epsilon) G(z) x^=ϵx+(1ϵ)G(z) ϵ ∼ U [ 0 , 1 ] \epsilon \sim U[0,1] ϵU[0,1]
    • λ \lambda λ 是惩罚系数(通常设为 10)

1.2 WGAN-GP 实现步骤

(1) 判别器(Critic)的损失函数
判别器的目标是最大化 Wasserstein 距离,同时满足梯度约束:
L D = E x ∼ P r [ D ( x ) ] − E z ∼ P z [ D ( G ( z ) ) ] ⏟ Wasserstein 距离 + λ ⋅ E x ^ ∼ P x ^ [ ( ∥ ∇ x ^ D ( x ^ ) ∥ 2 − 1 ) 2 ] ⏟ 梯度惩罚 L_D = \underbrace{\mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim P_z}[D(G(z))]}_{\text{Wasserstein 距离}} + \underbrace{\lambda \cdot \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} \left[ (\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2 \right]}_{\text{梯度惩罚}} LD=Wasserstein 距离 ExPr[D(x)]EzPz[D(G(z))]+梯度惩罚 λEx^Px^[(x^D(x^)21)2]

(2) 生成器(Generator)的损失函数
生成器的目标是最小化 Wasserstein 距离: L G = − E z ∼ P z [ D ( G ( z ) ) ] L_G = -\mathbb{E}_{z \sim P_z}[D(G(z))] LG=EzPz[D(G(z))]

(3) 训练流程

  1. 输入:真实数据 x x x,噪声 z ∼ N ( 0 , 1 ) z \sim \mathcal{N}(0,1) zN(0,1)
  2. 生成数据 G ( z ) G(z) G(z)
  3. 插值采样 x ^ = ϵ x + ( 1 − ϵ ) G ( z ) \hat{x} = \epsilon x + (1-\epsilon) G(z) x^=ϵx+(1ϵ)G(z) ϵ ∼ U [ 0 , 1 ] \epsilon \sim U[0,1] ϵU[0,1]
  4. 计算梯度惩罚
    • 对插值样本 x ^ \hat{x} x^ 计算判别器输出 D ( x ^ ) D(\hat{x}) D(x^)
    • 求梯度 ∇ x ^ D ( x ^ ) \nabla_{\hat{x}} D(\hat{x}) x^D(x^) 并计算惩罚项
  5. 更新判别器:最小化 L D L_D LD
  6. 更新生成器:最小化 L G L_G LG(每 n critic n_{\text{critic}} ncritic 次判别器更新后更新 1 次生成器)

1.3 总结

WGAN-GP 通过梯度惩罚替代权重裁剪,显著提升了 WGAN 的训练稳定性,是生成对抗网络的重要改进之一。实际应用中需注意:

  • 判别器架构设计
  • 梯度惩罚的正确实现
  • 学习率和训练次数的调优

二、WGAN-GP 实现

2.1 导包

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as npimport os
import time
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm  
from torchsummary import summary# 判断是否存在可用的GPU
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 指定存放日志路径
writer=SummaryWriter(log_dir="./runs/wgan_gp")os.makedirs("./img/wgan_gp_mnist", exist_ok=True) # 存放生成样本目录
os.makedirs("./model", exist_ok=True) # 模型存放目录

2.2 数据加载和处理

# 加载 MNIST 数据集
def load_data(batch_size=64,img_shape=(1,28,28)):transform = transforms.Compose([transforms.ToTensor(),  # 将图像转换为张量transforms.Normalize(mean=[0.5], std=[0.5])  # 归一化到[-1,1]])# 下载训练集和测试集train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 创建 DataLoadertrain_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=2,shuffle=True)test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=2,shuffle=False)return train_loader, test_loader

2.3 构建生成器

class Generator(nn.Module):"""生成器"""def __init__(self, latent_dim=100,img_shape=(1,28,28)):super(Generator,self).__init__()# 网络块def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat))layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))return layersself.model = nn.Sequential(*block(latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh() # 输出归一化到[-1,1] )def forward(self,z): # 噪声z,2维[batch_size,latent_dim]gen_img=self.model(z) gen_img=gen_img.view(gen_img.shape[0],*img_shape)return gen_img # 4维[batch_size,1,H,W]

2.4 构建判别器

class Discriminator(nn.Module):"""判别器"""def __init__(self,img_shape=(1,28,28)):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(negative_slope=0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(negative_slope=0.2, inplace=True),nn.Linear(256, 1))def forward(self,img): # 输入图片,4维[batc_size,1,H,W]img=img.view(img.shape[0], -1) pred = self.model(img)return pred # 2维[batch_size,1] 

2.5 训练和保存模型

  • WGAN-GP 算法流程

  • 定义梯度惩罚函数

def compute_gradient_penalty(critic, real, fake, device):batch_size = real.shape[0]epsilon = torch.rand(batch_size, 1, 1, 1).to(device)  # 随机插值系数interpolates = (epsilon * real + (1 - epsilon) * fake).requires_grad_(True)critic_interpolates = critic(interpolates)# 计算梯度gradients = torch.autograd.grad(outputs=critic_interpolates,inputs=interpolates,grad_outputs=torch.ones_like(critic_interpolates),create_graph=True,retain_graph=True,)[0]gradients = gradients.view(gradients.shape[0], -1)gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()return gradient_penalty
  • 训练和保存
# 设置超参数
batch_size = 64
epochs = 200
lr= 0.0002
latent_dim=100 # 生成器输入噪声向量的长度(维数)
sample_interval=400 #每400次迭代保存生成样本# WGAN的特别设置
num_iter_critic = 5
lambda_gp = 10# 设置图片形状1*28*28
img_shape = (1,28,28)# 加载数据
train_loader,_= load_data(batch_size=batch_size,img_shape=img_shape)# 实例化生成器G、判别器D
G=Generator().to(device)
D=Discriminator().to(device)# 设置优化器
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))# 开始训练
batches_done=0
loader_len=len(train_loader) #训练集加载器的长度
for epoch in range(epochs):# 进入训练模式G.train()D.train()loop = tqdm(train_loader, desc=f"第{epoch+1}轮")for i, (real_imgs, _) in enumerate(loop):real_imgs=real_imgs.to(device)  # [B,C,H,W]# -----------------#  训练判别器# -----------------# 获取噪声样本[B,latent_dim)z=torch.normal(0,1,size=(real_imgs.shape[0],latent_dim),device=device)  #从正态分布中抽样# Step-1 计算判断器损失=判断真实图片损失+判断生成图片损失+惩罚项fake_imgs=G(z).detach()gradient_penalty=compute_gradient_penalty(D, real_imgs, fake_imgs, device)dis_loss=-torch.mean(D(real_imgs)) + torch.mean(D(fake_imgs))+lambda_gp*gradient_penalty# Step-2 更新判别器参数optimizer_D.zero_grad() # 梯度清零dis_loss.backward() #反向传播,计算梯度optimizer_D.step()  #更新判别器 # -----------------#  训练生成器# -----------------# 判别器每迭代 num_iter_critic 次,生成器迭代一次if i % num_iter_critic ==0 :gen_imgs=G(z).detach()# 更新生成器参数optimizer_G.zero_grad() #梯度清零gen_loss=-torch.mean(D(gen_imgs))gen_loss.backward() #反向传播,计算梯度optimizer_G.step()  #更新生成器  # 更新进度条loop.set_postfix(gen_loss=f"{gen_loss:.8f}",dis_loss=f"{dis_loss:.8f}")# 每 sample_interval 次迭代保存生成样本if batches_done % sample_interval == 0:save_image(gen_imgs.data[:25], f"./img/wgan_gp_mnist/{epoch}_{i}.png", nrow=5, normalize=True)batches_done += 1print('总共训练用时: %.2f min' % ((time.time() - start_time)/60))#仅保存模型的参数(权重和偏置),灵活性高,可以在不同的模型结构之间加载参数
torch.save(G.state_dict(), "./model/WGAN-GP_G.pth") 
torch.save(D.state_dict(), "./model/WGAN-GP_D.pth") 

2.6 图片转GIF

from PIL import Imagedef create_gif(img_dir="./img/wgan_gp_mnist", output_file="./img/wgan_gp_mnist/wgan_gp_figure.gif", duration=100):images = []img_paths = [f for f in os.listdir(img_dir) if f.endswith(".png")]# 自定义排序:按 "x_y.png" 的 x 和 y 排序img_paths_sorted = sorted(img_paths,key=lambda x: (int(x.split('_')[0]),  # 第一个数字(如 0_400.png 的 0)int(x.split('_')[1].split('.')[0])  # 第二个数字(如 0_400.png 的 400)))for img_file in img_paths_sorted:img = Image.open(os.path.join(img_dir, img_file))images.append(img)images[0].save(output_file, save_all=True, append_images=images[1:], duration=duration, loop=0)print(f"GIF已保存至 {output_file}")
create_gif()
http://www.dtcms.com/wzjs/41899.html

相关文章:

  • 找人做网站引擎搜索网站
  • 百度站长提交百度统计手机app
  • 百度搜自己的网站win7优化大师官方免费下载
  • 洛阳尚贤网络科技有限公司南昌seo数据监控
  • 开发网站实训的心得体会北京做网络优化的公司
  • 医疗产品网站建设河南网站公司
  • wordpress+做仿站自助建站工具
  • 个人怎样建立网站怎么让某个关键词排名上去
  • 西安网站建设公司十强湖人最新排名最新排名
  • 漂亮购物网站欣赏一个完整的产品运营方案
  • 贵阳专业性网站制作郑州网络推广专业公司
  • 东营网站建设国内销售平台有哪些
  • 做网站一排文字怎么水平对齐百度数据平台
  • 南昌网站建设服务器网站广告投放收费标准
  • 网站后台管理系统怎么进新闻媒体发稿平台
  • 建湖做网站多少钱系统优化工具
  • android官网潍坊seo教程
  • 新生活cms下载win7优化配置的方法
  • 刘家窑做网站的公司有哪些平台可以免费发广告
  • 石狮市住房和城乡建设局网站关键词分析工具网站
  • 网站建设销售怎么样2022重大时政热点事件简短
  • b站推广网站动漫深圳新闻最新事件
  • 宜昌十堰网站建设哪家好网站注册账号
  • 沈阳和平三好街做网站企业营销策划及推广
  • 网站开发行业前景营销策划案的模板
  • 网站推广连接怎么做的写软文推广
  • wordpress搬家出现404seo软件
  • 网站运维合同网络营销平台都有哪些
  • 兰州网站建设公司百度竞价点击工具
  • 网站app软件营销顾问公司