当前位置: 首页 > news >正文

PyTorch 实现 Conditional DCGAN(条件深度卷积生成对抗网络)进行图像到图像转换的示例代码

以下是一个使用 PyTorch 实现 Conditional DCGAN(条件深度卷积生成对抗网络)进行图像到图像转换的示例代码。该代码包含训练和可视化部分,假设输入为图片和 4 个工艺参数,根据这些输入生成相应的图片。

1. 导入必要的库

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt

# 检查是否有可用的 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

2. 定义数据集类

class ImagePairDataset(Dataset):
    def __init__(self, image_pairs, params):
        self.image_pairs = image_pairs
        self.params = params

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, idx):
        input_image, target_image = self.image_pairs[idx]
        param = self.params[idx]
        return input_image, target_image, param

3. 定义生成器和判别器

# 生成器
class Generator(nn.Module):
    def __init__(self, z_dim=4, img_channels=3):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # 输入: [batch_size, z_dim + 4, 1, 1]
            self._block(z_dim + 4, 1024, 4, 1, 0),  # [batch_size, 1024, 4, 4]
            self._block(1024, 512, 4, 2, 1),  # [batch_size, 512, 8, 8]
            self._block(512, 256, 4, 2, 1),  # [batch_size, 256, 16, 16]
            self._block(256, 128, 4, 2, 1),  # [batch_size, 128, 32, 32]
            nn.ConvTranspose2d(128, img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, z, params):
        params = params.view(params.size(0), 4, 1, 1)
        x = torch.cat([z, params], dim=1)
        return self.gen(x)

# 判别器
class Discriminator(nn.Module):
    def __init__(self, img_channels=3):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # 输入: [batch_size, img_channels + 4, 64, 64]
            nn.Conv2d(img_channels + 4, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            self._block(64, 128, 4, 2, 1),  # [batch_size, 128, 16, 16]
            self._block(128, 256, 4, 2, 1),  # [batch_size, 256, 8, 8]
            self._block(256, 512, 4, 2, 1),  # [batch_size, 512, 4, 4]
            nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid()
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )

    def forward(self, img, params):
        params = params.view(params.size(0), 4, 1, 1).repeat(1, 1, img.size(2), img.size(3))
        x = torch.cat([img, params], dim=1)
        return self.disc(x)

4. 训练代码

def train_conditional_dcgan(image_pairs, params, batch_size=32, epochs=10, lr=0.0002, z_dim=4):
    dataset = ImagePairDataset(image_pairs, params)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    gen = Generator(z_dim).to(device)
    disc = Discriminator().to(device)

    criterion = nn.BCELoss()
    opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))

    for epoch in range(epochs):
        for i, (input_images, target_images, param) in enumerate(dataloader):
            input_images = input_images.to(device)
            target_images = target_images.to(device)
            param = param.to(device)

            # 训练判别器
            opt_disc.zero_grad()
            real_labels = torch.ones((target_images.size(0), 1, 1, 1)).to(device)
            fake_labels = torch.zeros((target_images.size(0), 1, 1, 1)).to(device)

            # 计算判别器对真实图像的损失
            real_output = disc(target_images, param)
            d_real_loss = criterion(real_output, real_labels)

            # 生成假图像
            z = torch.randn(target_images.size(0), z_dim, 1, 1).to(device)
            fake_images = gen(z, param)

            # 计算判别器对假图像的损失
            fake_output = disc(fake_images.detach(), param)
            d_fake_loss = criterion(fake_output, fake_labels)

            # 总判别器损失
            d_loss = d_real_loss + d_fake_loss
            d_loss.backward()
            opt_disc.step()

            # 训练生成器
            opt_gen.zero_grad()
            output = disc(fake_images, param)
            g_loss = criterion(output, real_labels)
            g_loss.backward()
            opt_gen.step()

        print(f'Epoch [{epoch+1}/{epochs}] D_loss: {d_loss.item():.4f} G_loss: {g_loss.item():.4f}')

    return gen

5. 可视化代码

def visualize_generated_images(gen, input_images, params, z_dim=4):
    input_images = input_images.to(device)
    params = params.to(device)
    z = torch.randn(input_images.size(0), z_dim, 1, 1).to(device)
    fake_images = gen(z, params).cpu().detach()

    fig, axes = plt.subplots(1, input_images.size(0), figsize=(15, 3))
    for i in range(input_images.size(0)):
        img = fake_images[i].permute(1, 2, 0).numpy()
        img = (img + 1) / 2  # 从 [-1, 1] 转换到 [0, 1]
        axes[i].imshow(img)
        axes[i].axis('off')
    plt.show()

6. 示例使用

# 假设 image_pairs 是一个包含图像对的列表,params 是一个包含 4 个工艺参数的列表
image_pairs = []  # 这里需要替换为实际的图像对数据
params = []  # 这里需要替换为实际的工艺参数数据

# 训练模型
gen = train_conditional_dcgan(image_pairs, params)

# 可视化生成的图像
test_input_images, test_target_images, test_params = image_pairs[:5], image_pairs[:5], params[:5]
test_input_images = torch.stack([torch.tensor(img) for img in test_input_images]).float()
test_params = torch.tensor(test_params).float()
visualize_generated_images(gen, test_input_images, test_params)

代码说明

  1. 数据集类ImagePairDataset 用于加载图像对和工艺参数。
  2. 生成器和判别器GeneratorDiscriminator 分别定义了生成器和判别器的网络结构。
  3. 训练代码train_conditional_dcgan 函数用于训练 Conditional DCGAN 模型。
  4. 可视化代码visualize_generated_images 函数用于可视化生成的图像。
  5. 示例使用:最后部分展示了如何使用上述函数进行训练和可视化。

请注意,你需要将 image_pairsparams 替换为实际的数据集。此外,代码中的超参数(如 batch_sizeepochslr 等)可以根据实际情况进行调整。

相关文章:

  • RabbitMQ可靠性进制
  • C语言每日一练——day_9
  • 【AHE数据集】 NCAR Anthropogenic Heat Flux (AHF) 数据集
  • Flask应用调试模式下外网访问的技巧
  • Day5 结构体、文字显示与GDT/IDT初始化
  • MySQL查询语句之like
  • Flask从入门到精通--初始Flask
  • 黑马node.js教程(nodejs教程)——AJAX-Day01-04.案例_地区查询——查询某个省某个城市所有地区(代码示例)
  • 五种最新优化算法(ALA、AE、DOA、GOA、OX)求解多个无人机协同路径规划(可以自定义无人机数量及起始点),MATLAB代码
  • dubbo nacos配置详解
  • 【electron】vue项目中使用electron打包报错的解决办法
  • 用pyqt做个日期输入控件,实现公农历转换及干支纪时功能
  • python微分方程求解,分别用显式欧拉方法、梯形法、改进欧拉方法、二阶龙格库塔方法、四阶龙格库塔方法求解微分方程
  • [oeasy]python074_ai辅助编程_水果程序_fruits_apple_banana_加法_python之禅
  • 解决WIN10使用苹果鼠标滚轮不能使用的问题
  • ArcGis使用-对轨迹起点终点的网格化编号
  • git使用。创建仓库,拉取分支,新建分支开发
  • DeepSeek在学术写作文献综述中两个核心提示词
  • 从中序与后序遍历序列构造二叉树 最大二叉树 合并二叉树 二叉搜索树中的搜索
  • 【USTC 计算机网络】第一章:计算机网络概述 - Internet 结构与 ISP、分组延时与丢失、协议层次与服务模型
  • 五一小长假,带着小狗去上海音乐厅
  • 力箭二号火箭成功进行满载起竖试验,计划今年首飞发射轻舟飞船
  • “乐购浦东”消费券明起发放,多个商家同期推出折扣促销活动
  • 赛力斯拟赴港上市:去年扭亏为盈净利59亿元,三年内实现百万销量目标
  • 报告显示2024年全球军费开支增幅达冷战后最大
  • 幸福航空取消“五一”前航班,财务人员透露“没钱飞了”