当前位置: 首页 > wzjs >正文

永康建设投标网站宁德网站建设制作

永康建设投标网站,宁德网站建设制作,湛江市建设工程造价信息网,微信公众号里面免费做网站CWGAN-GP 一、CWGAN-GP 原理1.1 CWGAN-GP 的核心改进1.2 CWGAN-GP 的损失函数1.3 CWGAN-GP 的优势1.4 关键参数选择1.5 应用场景 二、CWGAN-GP 实现2.1 导包2.2 数据加载和处理2.3 构建生成器2.4 构建判别器2.5 训练和保存模型2.6 查看训练损失2.7 图片转GIF2.8 模型加载和推理…

CWGAN-GP

  • 一、CWGAN-GP 原理
    • 1.1 CWGAN-GP 的核心改进
    • 1.2 CWGAN-GP 的损失函数
    • 1.3 CWGAN-GP 的优势
    • 1.4 关键参数选择
    • 1.5 应用场景
  • 二、CWGAN-GP 实现
    • 2.1 导包
    • 2.2 数据加载和处理
    • 2.3 构建生成器
    • 2.4 构建判别器
    • 2.5 训练和保存模型
    • 2.6 查看训练损失
    • 2.7 图片转GIF
    • 2.8 模型加载和推理

一、CWGAN-GP 原理

CWGAN-GPConditional Wasserstein GAN with Gradient Penalty)是 WGAN-GP 的条件生成版本,结合了 条件生成对抗网络(CGAN)Wasserstein GAN 的梯度惩罚(GP),以提高生成样本的质量和训练稳定性。

1.1 CWGAN-GP 的核心改进

(1)条件生成(Conditional Generation)

  • CGAN 思想:生成器 G G G 和判别器 D D D 都接收额外的条件信息 y y y(如类别标签、文本描述等),使生成的数据 G ( z ∣ y ) G(z|y) G(zy) 与条件 y y y 相关。
  • 数学表达
    • 生成器: G ( z ∣ y ) → x f a k e G(z|y) \rightarrow x_{fake} G(zy)xfake
    • 判别器: D ( x ∣ y ) → 评分 D(x|y) \rightarrow \text{评分} D(xy)评分

(2)Wasserstein 距离(WGAN)

  • WGAN 用 Earth-Mover(EM)距离(即 Wasserstein-1 距离)替代原始 GAN 的 JS 散度,解决梯度消失/爆炸问题:
    W ( P r , P g ) = sup ⁡ ∥ D ∥ L ≤ 1 E x ∼ P r [ D ( x ) ] − E x ∼ P g [ D ( x ) ] W(P_r, P_g) = \underset{\|D\|_L \leq 1}{\sup} \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{x \sim P_g}[D(x)] W(Pr,Pg)=DL1supExPr[D(x)]ExPg[D(x)]
    • D D D 需要是 1-Lipschitz 函数(梯度范数 ≤ 1)

(3) 梯度惩罚(Gradient Penalty, GP)

  • WGAN-GP 通过 梯度惩罚 强制判别器满足 Lipschitz 约束,替代 WGAN 的权重裁剪(更稳定):
    L G P = λ ⋅ E x ^ ∼ P x ^ [ ( ∥ ∇ x ^ D ( x ^ ∣ y ) ∥ 2 − 1 ) 2 ] \mathcal{L}_{GP} = \lambda \cdot \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} \left[ (\|\nabla_{\hat{x}} D(\hat{x}|y)\|_2 - 1)^2 \right] LGP=λEx^Px^[(x^D(x^y)21)2]
    • x ^ \hat{x} x^ 是真实数据和生成数据的随机插值点:
      x ^ = ϵ x r e a l + ( 1 − ϵ ) x f a k e , ϵ ∼ U [ 0 , 1 ] \hat{x} = \epsilon x_{real} + (1 - \epsilon) x_{fake}, \quad \epsilon \sim U[0,1] x^=ϵxreal+(1ϵ)xfake,ϵU[0,1]
    • λ \lambda λ 是惩罚系数(通常设为 10)

1.2 CWGAN-GP 的损失函数

(1)判别器损失
L D = E x ∼ P r [ D ( x ∣ y ) ] − E z ∼ p ( z ) [ D ( G ( z ∣ y ) ∣ y ) ] ⏟ Wasserstein 距离 + λ ⋅ E x ^ ∼ P x ^ [ ( ∥ ∇ x ^ D ( x ^ ∣ y ) ∥ 2 − 1 ) 2 ] ⏟ 梯度惩罚 \mathcal{L}_D = \underbrace{\mathbb{E}_{x \sim P_r}[D(x|y)] - \mathbb{E}_{z \sim p(z)}[D(G(z|y)|y)]}_{\text{Wasserstein 距离}} + \underbrace{\lambda \cdot \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}[(\|\nabla_{\hat{x}} D(\hat{x}|y)\|_2 - 1)^2]}_{\text{梯度惩罚}} LD=Wasserstein 距离 ExPr[D(xy)]Ezp(z)[D(G(zy)y)]+梯度惩罚 λEx^Px^[(x^D(x^y)21)2]

(2) 生成器损失
L G = − E z ∼ p ( z ) [ D ( G ( z ∣ y ) ∣ y ) ] \mathcal{L}_G = -\mathbb{E}_{z \sim p(z)}[D(G(z|y)|y)] LG=Ezp(z)[D(G(zy)y)]

  • 生成器的目标是让判别器对生成样本的评分最大化(即最小化 − L G -\mathcal{L}_G LG

1.3 CWGAN-GP 的优势

改进点WGANWGAN-GPCWGAN-GP
距离度量WassersteinWasserstein + GPWasserstein + GP
条件生成
训练稳定性权重裁剪(不稳定)梯度惩罚(稳定)梯度惩罚 + 条件控制
模式崩溃较少更少最少(因条件约束)

1.4 关键参数选择

  • 梯度惩罚系数 λ \lambda λ:通常设为 10
  • 判别器训练次数 n c r i t i c n_{critic} ncritic:通常 3~5 次/生成器 1 次
  • 学习率:建议 1 0 − 4 10^{-4} 104(Adam 优化器, β 1 = 0.5 , β 2 = 0.9 \beta_1=0.5, \beta_2=0.9 β1=0.5,β2=0.9

1.5 应用场景

  • 图像生成(如条件生成 MNIST 数字)
  • 文本到图像合成(如 StackGAN)
  • 数据增强(生成特定类别的数据)

CWGAN-GP 通过 条件输入 + 梯度惩罚显著提升了生成质量和训练稳定性,是许多现代 GAN 变种(如 ProGAN、StyleGAN)的基础架构


二、CWGAN-GP 实现

2.1 导包

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as npimport os
import time
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm  
from torchsummary import summary# 判断是否存在可用的GPU
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 设置日志
time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()) 
log_dir = os.path.join("./logs/cwgan-gp", time_str) 
os.makedirs(log_dir, exist_ok=True) 
writer = SummaryWriter(log_dir=log_dir) os.makedirs("./img/cwgan-gp_mnist", exist_ok=True) # 存放生成样本目录
os.makedirs("./model", exist_ok=True) # 模型存放目录# 超参数配置
config = {"num_epochs": 100,"batch_size": 64,"lr": 1e-4,"img_channels": 1,"img_size": 32,"features_g": 64,"features_d": 64,"z_dim": 100, "num_classes": 10, # 分类数"label_embed_dim": 10,  # 嵌入维度"lambda_gp": 10,  # 梯度惩罚系数"n_critic": 5,    # 判别器更新次数/生成器更新1次     
}

2.2 数据加载和处理

# 加载 MNIST 数据集
def load_data(config):transform = transforms.Compose([transforms.Resize(config["img_size"]),transforms.ToTensor(),  # 将图像转换为张量transforms.Normalize(mean=[0.5], std=[0.5])  # 归一化到[-1,1]])# 下载训练集和测试集train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 创建 DataLoadertrain_loader = DataLoader(dataset=train_dataset, batch_size=config["batch_size"],shuffle=True)test_loader = DataLoader(dataset=test_dataset, batch_size=config["batch_size"], shuffle=False)return train_loader, test_loader 

2.3 构建生成器

class Generator(nn.Module):"""生成器"""def __init__(self, config):super(Generator,self).__init__()# 定义嵌入层 [B]-> [B,label_embed_dim]=[64,10]self.label_embed = nn.Embedding(config["num_classes"], config["label_embed_dim"]) # num_classes 个类别, label_embed_dim 维嵌入# 定义模型结构self.model = nn.Sequential(# 转置卷积-1: [B,z_dim + label_embed_dim,1,1]-> [B,features_g*8,4,4]nn.ConvTranspose2d(config["z_dim"] + config["label_embed_dim"],config["features_g"]*8, 4, 1, 0),nn.BatchNorm2d(config["features_g"]*8),nn.LeakyReLU(negative_slope=0.0001, inplace=True),# 转置卷积-2: [B,features_g*8,4,4]-> [B,features_g*4,8,8]nn.ConvTranspose2d(config["features_g"]*8, config["features_g"]*4, 4, 2, 1),nn.BatchNorm2d(config["features_g"]*4),nn.LeakyReLU(negative_slope=0.0001, inplace=True),# 转置卷积-3: [B,features_g*4,8,8]-> [B,features_g*2,16,16]nn.ConvTranspose2d(config["features_g"]*4, config["features_g"]*2, 4, 2, 1),nn.BatchNorm2d(config["features_g"]*2),nn.LeakyReLU(negative_slope=0.0001, inplace=True),# 转置卷积-4: [B,features_g*2,16,16]-> [B,img_channels,32,32]nn.ConvTranspose2d(config["features_g"]*2, config["img_channels"], 4, 2, 1),nn.Tanh() # 输出归一化到[-1,1]  )def forward(self, z, labels): # z[B,z_dim,1,1]# 嵌入标签 [B]-> [B,label_embed_dim]=[B,10]label_embed = self.label_embed(labels)# 标签升维 [B,label_embed_dim]-> [B,label_embed_dim,1,1]=[B,10,1,1]label_embed = label_embed.unsqueeze(2).unsqueeze(3)# 拼接噪声和嵌入标签 ->[B,z_dim + label_embed_dim ,1,1]=[B,100+10,1,1]gen_input= torch.cat([z,label_embed], dim=1)# 生成图片 [B,label_embed_dim + z_dim,1,1]-> [B,img_channels,32,32]img=self.model(gen_input)return img # [B,1,32,32]

2.4 构建判别器

class Discriminator(nn.Module):def __init__(self, img_shape=(1,32,32),num_classes=10,label_embed_dim=10,features_d=64):"判别器"super(Discriminator, self).__init__()# 定义嵌入层 [B]-> [B,label_embed_dim]=[B,10]self.label_embed = nn.Embedding(config["num_classes"], config["label_embed_dim"]) # num_classes 个类别, label_embed_dim 维嵌入# 定义模型结构self.model = nn.Sequential(# 卷积-1:[B,img_channels + label_embed_dim,32,32]-> [B,features_d,16,16]nn.Conv2d(config["img_channels"] + config["label_embed_dim"], config["features_d"], 4,2,1),nn.LeakyReLU(negative_slope=0.2, inplace=True),# 卷积-2:[B,features_d,16,16]-> [B,features_d*2,8,8]nn.Conv2d( config["features_d"],  config["features_d"]* 2, 4, 2, 1),nn.InstanceNorm2d( config["features_d"]* 2),nn.LeakyReLU(negative_slope=0.2, inplace=True),# 卷积-3:[B,features_d*2,8,8]-> [B,features_d*4,4,4]nn.Conv2d( config["features_d"]*2,  config["features_d"]*4, 4, 2, 1),nn.InstanceNorm2d( config["features_d"]*4),nn.LeakyReLU(negative_slope=0.2, inplace=True),# 卷积-4:[B,features_d*4,4,4]-> [B,1,1,1]nn.Conv2d( config["features_d"]*4, 1, 4, 1, 0),nn.Flatten(), # 展平,4维[B,1,1,1]-> 2维[B,1,1,1])def forward(self, img, labels):# 嵌入标签 [B]-> [B,label_embed_dim]=[B,10]label_embed = self.label_embed(labels)# 标签升维 [B,label_embed_dim]-> [B,label_embed_dim,1,1]=[B,10,1,1]label_embed = label_embed.unsqueeze(2).unsqueeze(3)# 沿着空间维度(高度和宽度)进行复制扩展,使其与目标图像的空间尺寸匹配label_embed = label_embed.repeat(1, 1, img.shape[2], img.shape[3])# 拼接图片和嵌入标签 ->[B,img.shape[0]+label_embed_dim ,img.shape[2],img.shape[3]]=[B,1+10,32,32]dis_input= torch.cat([img,label_embed], dim=1)# 进行判定validity = self.model(dis_input)return validity # [B,1]

2.5 训练和保存模型

  1. 定义梯度惩罚函数
# 梯度惩罚计算函数
def compute_gradient_penalty(disc, real_samples, fake_samples, labels,device=device):"""计算梯度惩罚项"""# 随机插值系数alpha = torch.rand(real_samples.shape[0], 1, 1, 1).to(device)# 生成插值样本interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)# 计算判别器对插值样本的输出d_interpolates = disc(interpolates, labels)# 计算梯度gradients = torch.autograd.grad(outputs=d_interpolates,inputs=interpolates,grad_outputs=torch.ones_like(d_interpolates),create_graph=True,retain_graph=True,only_inputs=True,)[0]# 计算梯度惩罚项 gradients = gradients.view(gradients.shape[0], -1) # 多维-> 2维 [B,*]gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()return gradient_penalty
  1. 定义保存生成样本
def sample_image(G,n_row, batches_done,latent_dim=100,device=device):"""Saves a grid of generated digits ranging from 0 to n_classes"""# 随机噪声-> [n_row ** 2,latent_dim]=[100,100]z=torch.normal(0,1,size=(n_row ** 2,latent_dim,1,1),device=device)  #从正态分布中抽样# 条件标签->[100]labels = torch.arange(n_row, dtype=torch.long, device=device).repeat_interleave(n_row)gen_imgs = G(z, labels)save_image(gen_imgs.data, "./img/cwgan-gp_mnist/%d.png" % batches_done, nrow=n_row, normalize=True)
  1. 训练和保存
# 加载数据
train_loader,_= load_data(config)# 实例化生成器G、判别器D
G=Generator(config).to(device)
D=Discriminator(config).to(device)# 设置优化器
optimizer_G = torch.optim.Adam(G.parameters(), lr=config["lr"],betas=(0.5, 0.9))
optimizer_D = torch.optim.Adam(D.parameters(), lr=config["lr"],betas=(0.5, 0.9))# 开始训练
start_time = time.time()  # 计时器
loader_len=len(train_loader) #训练集加载器的长度
for epoch in range(config["num_epochs"]):# 进入训练模式G.train()D.train()loop = tqdm(train_loader, desc=f"第{epoch+1}轮")for i, (real_imgs, labels) in enumerate(loop):real_imgs=real_imgs.to(device)  # [B,C,H,W]labels=labels.to(device) # [B]# -----------------#  训练判别器# 1. 必须 gen_imgs.detach(),防止生成器的梯度干扰判别器更新# 2. 损失函数为 D(real_img) - D(gen_img.detach()) + 惩罚项# -----------------for _ in range(config["n_critic"]):# 获取噪声样本[B,latent_dim,1,1]及对应的条件标签 [B]z=torch.normal(0,1,size=(real_imgs.shape[0],config["z_dim"],1,1),device=device)  #从正态分布中抽样# --- 计算判别器损失 ---# Step-1:对真实图片损失valid_loss=torch.mean(D(real_imgs,labels))# Step-2:对生成图片损失gen_imgs=G(z,labels) # 生成图片fake_loss=torch.mean(D(gen_imgs.detach(),labels))# Step-3:计算梯度惩罚gradient_penalty = compute_gradient_penalty(D, real_imgs.data, gen_imgs.data, labels)# Step-4:整体损失dis_loss= -(valid_loss - fake_loss) + config["lambda_gp"]* gradient_penalty# 更新判别器参数optimizer_D.zero_grad() #梯度清零dis_loss.backward() #反向传播,计算梯度optimizer_D.step()  #更新判别器   # -----------------#  训练生成器# 1.禁止gen_imgs.detach(),需保持完整的计算图# 2.损失函数应为 -D(gen_img)(WGAN 的目标是最大化判别器对生成样本的评分)# -----------------# 计算生成器损失       gen_loss = -torch.mean(D(gen_imgs, labels)) # 更新生成器参数optimizer_G.zero_grad() #梯度清零gen_loss.backward() #反向传播,计算梯度optimizer_G.step()  #更新生成器  # --- 记录损失到TensorBoard ---batches_done = epoch * loader_len + iif batches_done % 100 == 0:writer.add_scalars("CWGAN-GP",{"Generator": gen_loss.item(),"Discriminator": dis_loss.item(),},global_step=epoch * loader_len + i,  # 全局步数)# 更新进度条loop.set_postfix(gen_loss=f"{gen_loss:.8f}",dis_loss=f"{dis_loss:.8f}")if batches_done % 400 == 0:sample_image(G,10, batches_done)writer.close()
print('总共训练用时: %.2f min' % ((time.time() - start_time)/60))#仅保存模型的参数(权重和偏置),灵活性高,可以在不同的模型结构之间加载参数
torch.save(G.state_dict(), "./model/CWGAN-GP_G.pth") 
torch.save(D.state_dict(), "./model/CWGAN-GP_D.pth") 

2.6 查看训练损失

# 加载魔术命令
%reload_ext tensorboard# 启动TensorBoard(
%tensorboard --logdir ./logs/cwgan-gp --port=6007 

在这里插入图片描述

2.7 图片转GIF

from PIL import Imagedef create_gif(img_dir="./img/cwgan-gp_mnist", output_file="./img/cwgan-gp_mnist/cwgan-gp_figure.gif", duration=100):images = []img_paths = [f for f in os.listdir(img_dir) if f.endswith(".png")]# 自定义排序:按 "x.png" 的排序img_paths_sorted = sorted(img_paths,key=lambda x: (int(x.split('.')[0]),  ))for img_file in img_paths_sorted:img = Image.open(os.path.join(img_dir, img_file))images.append(img)images[0].save(output_file, save_all=True, append_images=images[1:], duration=duration, loop=0)print(f"GIF已保存至 {output_file}")
create_gif()

2.8 模型加载和推理

#载入训练好的模型
G = Generator(config) # 定义模型结构
G.load_state_dict(torch.load("./model/CWGAN-GP_G.pth",weights_only=True,map_location=device)) # 加载保存的参数
G.to(device) 
G.eval() # 获取噪声样本[10,100]及对应的条件标签 [10]
z=torch.normal(0,1,size=(10,100,1,1),device=device)  #从正态分布中抽样
gen_labels = torch.arange(10, dtype=torch.long, device=device) #0~9整数#生成假样本
gen_imgs=G(z,gen_labels).view(-1,32,32) # 4维->3维
gen_imgs=gen_imgs.detach().cpu().numpy()#绘制
plt.figure(figsize=(8, 5)) 
plt.rcParams.update({'font.family': 'serif',  
})
for i in range(10):plt.subplot(2, 5, i + 1)  plt.xticks([], []) plt.yticks([], [])  plt.imshow(gen_imgs[i], cmap='gray')  plt.title(f"Figure {i}", fontsize=12)  
plt.tight_layout()  
plt.show()
http://www.dtcms.com/wzjs/300988.html

相关文章:

  • zbolg搭建的网站软件外包公司排行榜
  • dz网站开发苏州网站建设优化
  • 网站建设后还有什么费用如何做企业产品推广
  • 建个网站找个人网页模板
  • js网站记住密码怎么做种子搜索引擎在线
  • 贵阳网站建设怎么样seo+网站排名
  • 高级网站开发技术太原网站优化公司
  • 哈尔滨建设网站公司吗上海外包seo
  • 广东新冠疫情最新情况百度seo排名优化公司
  • 深圳做网站排名哪家好优化王
  • 衡水网站公司朝阳seo
  • 兰山网站建设天津seo排名费用
  • 怎样在外管局网站做延期付款长春百度快速优化
  • 湖南省人民政府办公厅seo长尾关键词优化
  • 水平滚动网站快速排名方案
  • 自己做网站怎么能被访问水平优化
  • 农村电子商务网站建设方案seo是什么软件
  • 网站怎么做扫码微信支付接口想要推广网页
  • 网站开发报价合同2000元代理微信朋友圈广告
  • 企业网站空间域名在线生成个人网站免费
  • 如何做响应式网站设计推广运营
  • 网店代运营什么意思扬州百度seo
  • 亚马逊网站建设进度计划表seo搜索引擎优化教程
  • 源码论坛下载山西免费网站关键词优化排名
  • 教育培训类网站建设与维护百度100%秒收录
  • it运维管理个人网站seo入门
  • 怎么做有趣视频网站搜索最多的关键词的排名
  • 网建公司浅谈网站建设的目的和意义seo优化文章网站
  • 制作一个专门浏览图片的网站杭州网站建设网页制作
  • 绛帐做网站关键词代发排名推广