当前位置: 首页 > news >正文

论文阅读Diffusion Autoencoders: Toward a Meaningful and Decodable Representation

原文框架图:

官方代码: https://github.com/phizaz/diffae/blob/master/interpolate.ipynb

主要想记录一下模型的推理过程 :

%load_ext autoreload
%autoreload 2
from templates import *
device = 'cuda:1'
conf = ffhq256_autoenc()
# print(conf.name)
model = LitModel(conf)
state = torch.load(f'checkpoints/{conf.name}/last.ckpt', map_location='cpu')
model.load_state_dict(state['state_dict'], strict=False)
model.ema_model.eval()
model.ema_model.to(device);
Global seed set to 0
Model params: 160.69 M
data = ImageDataset('imgs_interpolate', image_size=conf.img_size, exts=['jpg', 'JPG', 'png'], do_augment=False)
batch = torch.stack([
    data[0]['img'],
    data[1]['img'],
])
import matplotlib.pyplot as plt
plt.imshow(batch[0].permute([1, 2, 0]) / 2 + 0.5)

cond = model.encode(batch.to(device))
xT = model.encode_stochastic(batch.to(device), cond, T=250)

import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ori = (batch + 1) / 2
ax[0].imshow(ori[0].permute(1, 2, 0).cpu())
ax[1].imshow(xT[0].permute(1, 2, 0).cpu())

 

 Interpolate

Semantic codes are interpolated using convex combination, while stochastic codes are interpolated using spherical linear interpolation.

import numpy as np
alpha = torch.tensor(np.linspace(0, 1, 10, dtype=np.float32)).to(cond.device)
intp = cond[0][None] * (1 - alpha[:, None]) + cond[1][None] * alpha[:, None]

def cos(a, b):
    a = a.view(-1)
    b = b.view(-1)
    a = F.normalize(a, dim=0)
    b = F.normalize(b, dim=0)
    return (a * b).sum()

theta = torch.arccos(cos(xT[0], xT[1]))
x_shape = xT[0].shape
intp_x = (torch.sin((1 - alpha[:, None]) * theta) * xT[0].flatten(0, 2)[None] + torch.sin(alpha[:, None] * theta) * xT[1].flatten(0, 2)[None]) / torch.sin(theta)
intp_x = intp_x.view(-1, *x_shape)

pred = model.render(intp_x, intp, T=20)



import matplotlib.pyplot as plt
# torch.manual_seed(1)
fig, ax = plt.subplots(1, 10, figsize=(5*10, 5))
for i in range(len(alpha)):
    ax[i].imshow(pred[i].permute(1, 2, 0).cpu())
# plt.savefig('imgs_manipulated/compare.png')

http://www.dtcms.com/a/111613.html

相关文章:

  • Dagster系列教程:快速掌握数据资产定义
  • 数据库系统概述 | 第二章课后习题答案
  • 计算机系统---CPU
  • 嵌入式系统应用-拓展-相关开发软件说明
  • 常见的微信个人号二次开发功能
  • Unity:平滑输入(Input.GetAxis)
  • 【Cursor】切换主题
  • JS API
  • 【软考中级软件设计师】数据表示:原码、反码、补码、移码、浮点数
  • sward V1.0.8版本发布,全面支持各种附件上传预览
  • 初识数据结构——算法效率的“两面性”:时间与空间复杂度全解析
  • yolov12检测 聚类轨迹运动速度
  • 与总社团联合会合作啦
  • Linux的: /proc/sys/net/ipv6/conf/ 笔记250404
  • 操作系统面经(一)
  • 2025年【陕西省安全员C证】报名考试及陕西省安全员C证找解析
  • Qt QTableView QAbstractTableModel实现复选框+代理实现单元格编辑
  • 进行性核上性麻痹:饮食调理为健康护航
  • SpringBoot项目报错: 缺少 Validation
  • 【NLP 55、投机采样加速推理】
  • 在线考试系统带万字文档java项目java课程设计java毕业设计springboot项目
  • 【matplotlib参数调整】
  • 2011-2019年各省地方财政国土资源气象等事务支出决策数数据
  • 如何理解缓存一致性?
  • Linux 安装 MySQL8数据库
  • LLM面试题六
  • Linux随机数
  • React: hook相当于函数吗?
  • 算法设计学习9
  • 【Groovy快速上手 ONLY ONE】Groovy与Java的核心差异