当前位置: 首页 > news >正文

【DLI】Generative AI with Diffusion Models通关秘籍

Generative AI with Diffusion Models,加载时间在20分钟左右,耐心等待。
在这里插入图片描述
在这里插入图片描述

6.2TODO

在这里插入图片描述

这里是在设置扩散模型的参数,代码里的FIXME部分需要根据上下文进行替换。以下是各个FIXME的替换说明:
1.a_bar 是 a 的累积乘积,在 PyTorch 里可以用 torch.cumprod 实现。
2.sqrt_a_bar、sqrt_one_minus_a_bar 和 sqrt_a_inv 都是对输入张量求平方根,可使用 torch.sqrt 实现。
3.pred_noise_coeff 中的 FIXME(1 - a_bar) 同样是求平方根,用 torch.sqrt 即可。
以下是替换后的代码:

nrows = 10
ncols = 15

T = nrows * ncols
B_start = 0.0001
B_end = 0.02
B = torch.linspace(B_start, B_end, T).to(device)

a = 1.0 - B
a_bar = torch.cumprod(a, dim=0)
sqrt_a_bar = torch.sqrt(a_bar)  # Mean Coefficient
sqrt_one_minus_a_bar = torch.sqrt(1 - a_bar)  # St. Dev. Coefficient

# Reverse diffusion variables
sqrt_a_inv = torch.sqrt(1 / a)
pred_noise_coeff = (1 - a) / torch.sqrt(1 - a_bar)  # Predicted Noise Coefficient

在扩散模型里,正向扩散过程 q 函数是按照如下公式把原始图像 x_0 逐步添加噪声变成 x_t 的
在这里插入图片描述
FIXME 部分应该分别用 sqrt_a_bar_t 和 sqrt_one_minus_a_bar_t 来替换。
在这个 q 函数中,按照扩散模型的正向过程公式,把原始图像 x_0 和随机噪声 noise 按一定比例组合,从而得到加噪后的图像 x_t。

def q(x_0, t):
    t = t.int()
    noise = torch.randn_like(x_0)
    sqrt_a_bar_t = sqrt_a_bar[t, None, None, None]
    sqrt_one_minus_a_bar_t = sqrt_one_minus_a_bar[t, None, None, None]

    x_t = sqrt_a_bar_t * x_0 + sqrt_one_minus_a_bar_t * noise
    return x_t, noise

在反向扩散过程中,我们要根据当前的潜在图像,当前时间步 , 以及预测的噪声 来恢复上一个时间步的图像。在这里插入图片描述
在这个 reverse_q 函数中,我们根据反向扩散过程的公式,从当前的潜在图像和预测的噪声中恢复上一个时间步的图像。如果当前时间步为 0,则表示反向扩散过程完成。否则,我们会添加一些噪声以模拟扩散过程。下面是对代码中 FIXME 部分的分析与替换:

@torch.no_grad()
def reverse_q(x_t, t, e_t):
    t = t.int()
    pred_noise_coeff_t = pred_noise_coeff[t]
    sqrt_a_inv_t = sqrt_a_inv[t]
    u_t = sqrt_a_inv_t * (x_t - pred_noise_coeff_t * e_t)
    if t[0] == 0:  # All t values should be the same
        return u_t  # Reverse diffusion complete!
    else:
        B_t = B[t - 1]  # Apply noise from the previous timestep
        new_noise = torch.randn_like(x_t)
        return u_t + torch.sqrt(B_t) * new_noise

在这里插入图片描述

6.3TODO

在这里插入图片描述

每个类的功能来添加正确模块名 依次改写FIXME 即可:

DownBlock进行下采样操作,包含卷积和池化相关的块
EmbedBlock将输入进行线性变换和激活
GELUConvBlock使用了卷积、组归一化和 GELU 激活函数,通常是一个卷积块
RearrangePoolBlock使用了 Rearrange 进行张量重排和卷积操作
ResidualConvBlock使用了两个卷积块并进行了残差连接
SinusoidalPositionEmbedBlock实现了正弦位置嵌入的功能
UpBlock上采样操作,包含转置卷积和卷积块

6.4TODO

在这个 get_context_mask 函数里,其目的是随机丢弃上下文信息。要实现随机丢弃,通常会使用 torch.bernoulli 函数。torch.bernoulli 函数会依据给定的概率来生成一个二进制掩码张量,其中每个元素为 1 的概率就是传入的概率值。
在这个函数中,我们希望以 drop_prob 的概率丢弃上下文,所以每个元素保留的概率是 1 - drop_prob。因此,FIXME 处应该填入 bernoulli。

def get_context_mask(c, drop_prob):
    c_hot = F.one_hot(c.to(torch.int64), num_classes=N_CLASSES).to(device)
    c_mask = torch.bernoulli(torch.ones_like(c_hot).float() * (1 - drop_prob)).to(device)
    return c_hot, c_mask

代码解释:
c_hot = F.one_hot(c.to(torch.int64), num_classes=N_CLASSES).to(device):将输入的 c 转换为独热编码向量,并且移动到指定的设备(如 GPU)上。
c_mask = torch.bernoulli(torch.ones_like(c_hot).float() * (1 - drop_prob)).to(device):生成一个与 c_hot 形状相同的二进制掩码张量,每个元素以 1 - drop_prob 的概率为 1,以 drop_prob 的概率为 0。
return c_hot, c_mask:返回独热编码向量和二进制掩码张量。
这样,你就可以使用这个函数来随机丢弃上下文信息了。

在这里插入图片描述

在扩散模型里,通常采用均方误差损失(Mean Squared Error Loss,MSE)来衡量预测噪声 noise_pred 和实际添加的噪声 noise 之间的差异。因为均方误差能够很好地衡量两个向量之间的平均平方误差,这对于扩散模型中预测噪声的准确性评估是很合适的。
在 PyTorch 中,nn.functional.mse_loss 函数可用于计算均方误差损失。所以 FIXME 处应填入 mse_loss。

def get_loss(model, x_0, t, *model_args):
    x_noisy, noise = q(x_0, t)
    noise_pred = model(x_noisy, t/T, *model_args)
    return F.mse_loss(noise, noise_pred)

代码解释
x_noisy, noise = q(x_0, t):调用 q 函数给原始图像 x_0 添加噪声,得到加噪后的图像 x_noisy 以及实际添加的噪声 noise。
noise_pred = model(x_noisy, t/T, *model_args):把加噪后的图像 x_noisy 和归一化后的时间步 t/T 输入到模型 model 中,得到模型预测的噪声 noise_pred。
return F.mse_loss(noise, noise_pred):使用 F.mse_loss 函数计算实际噪声 noise 和预测噪声 noise_pred 之间的均方误差损失并返回。
通过使用均方误差损失,模型能够学习到如何更准确地预测添加到图像中的噪声,从而在反向扩散过程中更好地恢复原始图像。

下一个 TODO

  1. c_drop_prob 的设置
    c_drop_prob 是上下文丢弃概率,一般在训练过程中会采用线性衰减策略,也就是在训练初期以较高概率丢弃上下文,随着训练的推进逐渐降低丢弃概率。在代码中,我们可以简单地将其设置为一个随着训练轮数逐渐降低的值。
  2. get_context_mask 函数的输入
    get_context_mask 函数需要一个上下文标签作为输入,在代码里这个标签应该从 batch 中获取。通常假设 batch 的第二个元素为上下文标签。

optimizer = Adam(model.parameters(), lr=0.001)
epochs = 5
preview_c = 0

model.train()
for epoch in range(epochs):
    # 线性衰减上下文丢弃概率
    c_drop_prob = max(0.1, 1 - epoch / epochs)  #这里我调整了顺序
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()

        t = torch.randint(0, T, (BATCH_SIZE,), device=device).float()
        x = batch[0].to(device)
        # 假设 batch 的第二个元素是上下文标签
        c = batch[1].to(device)
        c_hot, c_mask = get_context_mask(c, c_drop_prob)
        loss = get_loss(model, x, t, c_hot, c_mask)
        loss.backward()
        optimizer.step()

        if epoch % 1 == 0 and step % 100 == 0:
            print(f"Epoch {epoch} | Step {step:03d} | Loss: {loss.item()} | C: {preview_c}")
            c_drop_prob = 0  # Do not drop context for preview
            c_hot, c_mask = get_context_mask(torch.Tensor([preview_c]).to(device), c_drop_prob)
            sample_images(model, IMG_CH, IMG_SIZE, ncols, c_hot, c_mask)
            preview_c = (preview_c + 1) % N_CLASSES

代码解释
c_drop_prob 的设置:运用线性衰减策略,在训练初期 c_drop_prob 为 0.9,随着训练的推进逐渐降低到 0.1。
get_context_mask 函数的输入:假设 batch 的第二个元素是上下文标签,将其传入 get_context_mask 函数。
训练过程:在每个训练步骤中,先将梯度清零,接着计算损失,再进行反向传播和参数更新。每训练 100 个步骤,就打印一次损失信息并进行一次样本生成。
通过这些修改,代码就能正常运行,从而开始训练模型。
在这里插入图片描述

6.5TODO

在扩散模型的采样过程中,为了给扩散过程添加权重,一般会根据给定的权重 w 对保留上下文的预测噪声 e_t_keep_c 和丢弃上下文的预测噪声 e_t_drop_c 进行加权组合。在这里插入图片描述
在代码中,FIXME 处应该根据上述公式进行计算,将 e_t_keep_c 和 e_t_drop_c 按照权重 w 进行组合。具体的代码如下:

def sample_w(model, c, w):
    input_size = (IMG_CH, IMG_SIZE, IMG_SIZE)
    n_samples = len(c)
    w = torch.tensor([w]).float()
    w = w[:, None, None, None].to(device)  # Make w broadcastable
    x_t = torch.randn(n_samples, *input_size).to(device)

    # One c for each w
    c = c.repeat(len(w), 1)

    # Double the batch
    c = c.repeat(2, 1)

    # Don't drop context at test time
    c_mask = torch.ones_like(c).to(device)
    c_mask[n_samples:] = 0.0

    x_t_store = []
    for i in range(0, T)[::-1]:
        # Duplicate t for each sample
        t = torch.tensor([i]).to(device)
        t = t.repeat(n_samples, 1, 1, 1)

        # Double the batch
        x_t = x_t.repeat(2, 1, 1, 1)
        t = t.repeat(2, 1, 1, 1)

        # Find weighted noise
        e_t = model(x_t, t, c, c_mask)
        e_t_keep_c = e_t[:n_samples]
        e_t_drop_c = e_t[n_samples:]
        e_t = w * e_t_keep_c + (1 - w) * e_t_drop_c

        # Deduplicate batch for reverse diffusion
        x_t = x_t[:n_samples]
        t = t[:n_samples]
        x_t = reverse_q(x_t, t, e_t)

    return x_t

## TODO

在扩散模型里,权重 w 可用于控制上下文信息在生成过程中的影响程度。w 值越接近 1,生成结果就越依赖上下文信息;w 值越接近 0,生成结果受上下文信息的影响就越小。若要让生成的数字能够被持续识别,你可以试着增大 w 的值,以此增强上下文信息对生成过程的影响。
下面是修改后的代码,你可以调整 w 的值来观察生成结果:

model.eval()
w = 5.0  # 可以尝试不同的值,通常大于 1 能增强上下文的影响
c = torch.arange(N_CLASSES).to(device)
c_drop_prob = 0 
c_hot, c_mask = get_context_mask(c, c_drop_prob)

x_0 = sample_w(model, c_hot, w)
other_utils.to_image(make_grid(x_0.cpu(), nrow=N_CLASSES))

代码解释
w = 5.0:把 w 的值设为 5.0,你可以根据实际情况调整这个值。通常,当 w 大于 1 时,上下文信息的影响会得到增强,这样生成的数字可能会更易于识别。
x_0 = sample_w(model, c_hot, w):调用 sample_w 函数生成图像,将 w 作为参数传入。
other_utils.to_image(make_grid(x_0.cpu(), nrow=N_CLASSES)):把生成的图像转换为可视化的形式。
你可以多次运行这段代码,并且调整 w 的值,直到生成的数字能够被稳定识别。

至此结束。
在这里插入图片描述

完整代码都在图片里

http://www.dtcms.com/a/108793.html

相关文章:

  • Redis基础知识-2
  • 从零构建大语言模型全栈开发指南:第五部分:行业应用与前沿探索-5.1.1百度ERNIE、阿里通义千问的技术对比
  • 程序化广告行业(56/89):S2S对接与第三方广告监测全解析
  • 《第三次世界大战》第七章:破碎的未来
  • 《实战AI智能体》MCP对Agent有哪些好处
  • [CISSP] [7] PKI和密码应用
  • 应用安全系列之四十五:日志伪造(Log_Forging)之二
  • 基于BusyBox构建ISO镜像
  • 多模态模型:专栏概要与内容目录
  • 网络爬虫的基础知识
  • 《inZOI(云族裔)》50+MOD整合包
  • 【目标检测】【深度学习】【Pytorch版本】YOLOV2模型算法详解
  • 【现代深度学习技术】现代卷积神经网络07:稠密连接网络(DenseNet)
  • CFResNet鸟类识别:原网络基础上改进算法
  • Springboot logback日志实例
  • RK3568下的QT工程配置
  • Joomla教程—Joomla 模块管理与Joomla 模块类型介绍
  • AI SEO内容优化指南:如何打造AI平台青睐的高质量内容
  • 在 Elasticsearch 中使用 Amazon Nova 模型
  • Vue父组件调用子组件设置table表格合并
  • chromium魔改——修改 navigator.webdriver 检测
  • 【大模型系列篇】大模型基建工程:基于 FastAPI 自动构建 SSE MCP 服务器 —— 进阶篇
  • Leetcode hot 100(day 4)
  • 03.01、三合一
  • 使用Amazon Bedrock Guardrails保护你的DeepSeek模型部署
  • 一问讲透redis持久化机制-rdb aof
  • 深度优化:解决SpringBoot应用启动速度慢的8个关键策略
  • 部署大模型实战:如何巧妙权衡效果、成本与延迟?
  • 智慧园区大屏如何实现全局监测:监测意义、内容、方式
  • .NET WebApi的详细发布流程——及其部署到Linux与Windows平台