当前位置: 首页 > wzjs >正文

免费的自助设计网站网站推广 排名

免费的自助设计网站,网站推广 排名,免费建网站系统,医生问诊在线咨询免费文章目录 1. description2. code 1. description 后续整理 GAN是生成对抗网络,主要由G生成器,D判别器组成,具体形式如下 D 判别器: G生成器: 2. code 部分源码,暂定,后续修改 import nump…

文章目录

  • 1. description
  • 2. code

1. description

后续整理
GAN是生成对抗网络,主要由G生成器,D判别器组成,具体形式如下

  • D 判别器:
    在这里插入图片描述
  • G生成器:
    在这里插入图片描述

2. code

部分源码,暂定,后续修改

import numpy as np
import os
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader, Datasetimport torch.cudaimage_size = [1, 28, 28]
latent_dim = 96
label_emb_dim = 32
batch_size = 64
use_gpu = torch.cuda.is_available()
save_dir = "cgan_images"
os.makedirs(save_dir, exist_ok=True)class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.embedding = nn.Embedding(10, label_emb_dim)self.model = nn.Sequential(nn.Linear(label_emb_dim + label_emb_dim, 128),nn.BatchNorm1d(128),nn.GELU(),nn.Linear(128, 256),nn.BatchNorm1d(256),nn.GELU(),nn.Linear(256, 512),nn.BatchNorm1d(512),nn.GELU(),nn.Linear(512, 1024),nn.BatchNorm1d(1024),nn.GELU(),nn.Linear(1024, np.prod(image_size, dtype=np.int32)),nn.Sigmoid(),)def forward(self, z, labels):# shape of z:[batch_size,latent_dim]label_embedding = self.embedding(labels)z = torch.cat([z, label_embedding], axis=-1)output = self.model(z)image = output.reshape(z.shape[0], *image_size)return imageclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.embedding = nn.Embedding(10, label_emb_dim)self.model = nn.Sequential(nn.Linear(np.prod(image_size, dtype=np.int32) + label_emb_dim, 512),torch.nn.GELU(),# nn.Linear(512,256)nn.utils.spectral_norm(nn.Linear(512, 256)),nn.GELU(),# nn.Linear(256,128)nn.utils.spectral_norm(nn.Linear(256, 128)),nn.GELU(),# nn.Linear(128,64)nn.utils.spectral_norm(nn.Linear(128, 64)),nn.GELU(),# nn.Linear(64,32)nn.utils.spectral_norm(nn.Linear(64, 32)),nn.GELU(),# nn.Linear(32,1)nn.utils.spectral_norm(nn.Linear(32, 1)),nn.Sigmoid(),)def forward(self, image, labels):# shape of image:[batch_size,1,28,28]label_embedding = self.embedding(labels)prob = self.model(torch.cat([image.reshape(image.shape[0], -1), label_embedding], axis=-1))return probif __name__ == "__main__":run_code = 0v_transform = torchvision.transforms.Compose([torchvision.transforms.Resize(28),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.5], [0.5])])dataset = torchvision.datasets.MNIST("mnist_data", train=True, download=True, transform=v_transform)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)generator = Generator()discriminator = Discriminator()g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)loss_fn = nn.BCELoss()labels_one = torch.ones(batch_size, 1)labels_zero = torch.zeros(batch_size, 1)if use_gpu:print("use gpu for trainning")generator = generator.cuda()discriminator = discriminator.cuda()loss_fn = loss_fn.cuda()labels_one = labels_one.to("cuda")labels_zero = labels_zero.to("cuda")num_epoch = 200for epoch in range(num_epoch):for i, mini_batch in enumerate(dataloader):gt_images, labels = mini_batchz = torch.randn(batch_size, latent_dim)if use_gpu:gt_images = gt_images.to("cuda")z = z.to("cuda")pred_images = generator(z, labels)g_optimizer.zero_grad()recons_loss = torch.abs(pred_images - gt_images).mean()g_loss = 0.05 * recons_loss + loss_fn(discriminator(pred_images, labels), labels_one)g_loss.backward()g_optimizer.step()d_optimizer.zero_grad()real_loss = loss_fn(discriminator(gt_images, labels), labels_one)fake_loss = loss_fn(discriminator(pred_images, labels), labels_zero)d_loss = real_loss + fake_loss# 观察 real_loss 与 fake_loss 同时下降同时达到最小值,并且差不多大,说明D已经稳定了d_loss.backward()d_optimizer.step()if i % 50 == 0:print(f"step:{len(dataloader) * epoch + i},recons_loss:{recons_loss.item()},g_loss:{g_loss.item()},"f"d_loss:{d_loss.item()},real_loss:{real_loss.item()},fake_loss:{fake_loss.item()},d_loss:{d_loss.item()}")if i % 800 == 0:image = pred_images[:16].datatorchvision.utils.save_image(image, f"{save_dir}/image_{len(dataloader) * epoch + i}.png", nrow=4)
http://www.dtcms.com/wzjs/800063.html

相关文章:

  • 备案怎么关闭网站网站制作的原因
  • 西宁的网站建设wordpress固定连接怎么设置最好
  • 企业网站分为哪四类怎么做自己的网站免费
  • 营销网站建设推广网站建设云技术公司推荐
  • 中国旅游网站建设网页微信登录不了提示为了安全考虑
  • 公司互联网站全面改版wordpress 反斜杠
  • 网站建设工作室创业计划书新郑郑州网站建设
  • 网站建设公司在哪里anker 网站谁做的
  • 国内最大ae模板下载网站网站做发
  • 赣州网站优化推广网站快速被收录
  • 简单的网站类型有哪些内容小程序雀神麻将开挂视频
  • 云南seo网站关键词优化软件展馆展示设计公司哪家好一点
  • 奢侈品网站策划方案建设一个电商网站的流程
  • 秦皇岛做网站外包一流本科专业建设网站
  • 深圳做手机网站建设济南装修网
  • dw 个人网站怎么做个人建设网站程序
  • 坪山企业网站建设黑马程序员吧
  • 网站收录入口申请网站被降权重新做网站
  • 网站建设的基本流程有哪些常用的编辑html的软件
  • 大连网站建设与维护题库 天堂资源最新版中文资源
  • 简述php网站开发流程图成都 网站
  • 做电商运营要什么条件wordpress 界面优化
  • 搭建网站宣传西安网站搭建费用
  • 建协网官方网站科技管理信息网站的建设方案
  • 营销型网站方案广州建网站模板
  • 南京和筑建设有限公司网站关闭 百度云加速 后网站打不开了
  • 在家做兼职的网站服务器如何发布网站
  • 温州网站建设wmwl视频直播app下载
  • 深圳网站制作作优秀的app界面设计案例
  • 博山网站建设yx718做漆包线的招聘网站