当前位置: 首页 > news >正文

DreamDiffusion代码学习及复现

论文解读在这里

File path | Description
```

/pretrains
┣ 📂 models
┃   ┗ 📜 config.yaml
┃   ┗ 📜 v1-5-pruned.ckpt

┣ 📂 generation  
┃   ┗ 📜 checkpoint_best.pth 

┣ 📂 eeg_pretain
┃   ┗ 📜 checkpoint.pth  (pre-trained EEG encoder)

/datasets
┣ 📂 imageNet_images (subset of Imagenet)

┗  📜 block_splits_by_image_all.pth
┗  📜 block_splits_by_image_single.pth 
┗  📜 eeg_5_95_std.pth  

/code
┣ 📂 sc_mbm
┃   ┗ 📜 mae_for_eeg.py
┃   ┗ 📜 trainer.py
┃   ┗ 📜 utils.py

┣ 📂 dc_ldm
┃   ┗ 📜 ldm_for_eeg.py
┃   ┗ 📜 utils.py
┃   ┣ 📂 models
┃   ┃   ┗ (adopted from LDM)
┃   ┣ 📂 modules
┃   ┃   ┗ (adopted from LDM)

┗  📜 stageA1_eeg_pretrain.py   (main script for EEG pre-training)
┗  📜 eeg_ldm.py    (main script for fine-tuning stable diffusion)
┗  📜 gen_eval_eeg.py               (main script for generating images)

┗  📜 dataset.py                (functions for loading datasets)
┗  📜 eval_metrics.py           (functions for evaluation metrics)
┗  📜 config.py                 (configurations for the main scripts)

```

目录

dataset.py

gen_eval_eeg.py

stageA1_eeg_pretrain.py

eeg_ldm.py

gen_eval_eeg.py


dataset.py

一、基础工具函数模块

"沿时间轴进行环形填充"是一种信号处理技术,当数据长度不足时,用数据的起始部分循环填充到末尾(类似"循环播放")

  • 对比其他填充方式

    • 零填充(Zero-pad):[1,2,3] -> [1,2,3,0,0]

    • 环形填充:[1,2,3] -> [1,2,3,1,2]

  • 参数解读

    • ((0,0), (0, pad_size)):表示只在第二个维度(时间轴)右侧填充

    • 'wrap':指定环形填充模式

  • 输入

    • x.shape = (128, 500)(128个EEG通道,500个时间点)

    • patch_size = 16(每个时间块包含16个时间点)

  • 计算需要填充的长度

    • 当前时间点:500

    • 需要达到 N × patch_size 的最小长度

    • ceil(500 / 16) = 32 块 → 32×16=512

    • 需填充:512 - 500 = 12 个时间点

  • 填充操作:从每个通道的起始位置取前12个时间点,拼接到末尾

为什么选择环形填充?

填充方式优点缺点适用场景
环形填充保持信号周期性
避免边界突变
可能引入周期性假象EEG/ECG等准周期信号
零填充实现简单引入高频噪声通用场景
镜像填充平滑边界计算复杂图像处理

对于EEG信号:

  • 具有准周期性(alpha/beta波等)

  • 避免零填充导致的频谱泄漏(spectral leakage)

  • 更适合后续的块处理(patch划分)

Z-score标准化(又称标准差标准化)是一种常见的数据标准化方法,其核心是通过线性变换将原始数据转换为均值为0、标准差为1的分布。

对于一组数据 x,其标准化值 z的计算公式为:z=(x−μ)/σ

  • μ:数据的均值(平均值)

  • σ:数据的标准差(反映数据离散程度)

二、时间序列处理模块

 时间窗口

  • 定义:将连续的EEG信号按固定时长分段处理

  • 目的

    • 降低计算复杂度

    • 捕捉局部时域特征

    • 匹配后续处理(如傅里叶变换、模型输入长度)

  • 8 / 0.75 ≈ 10.67,0.75秒/帧:该数据集的时间分辨率(每帧持续时间)

三、数据增强模块

四、核心数据集类

1. 预训练数据集

2. 完整EEG-Image数据集
class EEGDataset(Dataset):
    def __init__(self, eeg_signals_path):
        loaded = torch.load(eeg_signals_path)  # 加载预处理数据
        self.data = [
            {
                'eeg': tensor,       # EEG信号 [通道, 时间]
                'label': int,        # 类别标签 
                'image': 'n01440764' # ImageNet ID
            }, ...
        ]
    
    def __getitem__(self, i):
        # EEG处理
        eeg = data[i]['eeg'].t()     # 转置为[时间, 通道]
        eeg = eeg[20:460]            # 选择有效时间窗口
        eeg = interp1d(...)          # 插值到512点
        
        # 图像处理
        image_path = 'n01440764/n01440764_10026.JPEG'
        image = Image.open(path)
        image = processor(image)     # CLIP预处理

五、数据划分模块

class Splitter:
    def __init__(self, dataset, split_path):
        loaded = torch.load(split_path)
        self.split_idx = loaded['splits'][0]['train']  # 取第一个划分方案
        # 过滤条件:
        # 1. EEG长度在450-600之间
        # 2. 被试匹配(当subject!=0时)

六、图像处理模块

class random_crop:
    def __call__(self, img):
        if 概率p: 执行随机裁剪
        else: 返回原图

def normalize2(img):
    return img * 2.0 - 1.0  # 归一化到[-1,1]

七、重要技术细节

对齐流程:

sequenceDiagram
    participant EEG_Data
    participant ImageNet
    EEG_Data->>EEGDataset: 加载样本i
    EEGDataset->>EEG_Data: 读取self.data[i]["image"]字段
    EEGDataset->>ImageNet: 根据ID构造路径
    ImageNet-->>EEGDataset: 返回对应图像
    EEGDataset->>Model: 返回{'eeg':eeg, 'image':image}

gen_eval_eeg.py

基于MAE (Masked Autoencoder) 的EEG信号预训练框架,主要包含以下核心模块:

  1. 环境配置与工具函数

  2. 数据加载与预处理

  3. 模型定义与训练流程

  4. 可视化与日志记录

  5. 分布式训练支持

1. 核心模块解析

2. 关键实现细节

4. 可视化模块

代码流程图

graph TD
    A[初始化配置] --> B[加载数据集]
    B --> C[构建MAE模型]
    C --> D[初始化优化器]
    D --> E[训练循环]
    E --> F{达到保存点?}
    F -- 是 --> G[保存模型+可视化]
    F -- 否 --> E
    G --> H[完成训练]

stageA1_eeg_pretrain.py

Pre-training on EEG data

用于大量训练的数据集从MOABB上下载,还没学会,,,,

eeg_ldm.py

Finetune the Stable Diffusion with Pre-trained EEG Encoder

实现了一个基于Latent Diffusion Model (LDM) 的EEG信号到图像生成的完整流程:


一、代码整体架构

本代码是DreamDiffusion项目的第二阶段(Stage B),主要包含以下核心模块:

  1. 配置管理(Config_Generative_Model)

  2. 数据加载与预处理(create_EEG_dataset)

  3. 生成模型定义(eLDM)

  4. 训练流程控制(main函数)

  5. 图像生成与评估(generate_images)

  6. 实验日志记录(wandb集成)

二、核心组件详解

1. 配置管理
class Config_Generative_Model:
    def __init__(self):
        # 项目参数
        self.seed = 2022
        self.root_path = '.'
        self.eeg_signals_path = 'datasets/eeg_5_95_std.pth'
        
        # 模型参数
        self.pretrain_mbm_path = 'pretrains/generation/checkpoint.pth'
        self.pretrain_gm_path = 'pretrains/stable-diffusion-v1-5'
        
        # 训练参数
        self.batch_size = 25
        self.lr = 5.3e-5
        self.num_epoch = 500
2. 数据加载
  1. 加载EEG信号和对应的ImageNet图像路径

  2. 应用两种图像变换:

    • 训练集:随机裁剪+归一化(img_transform_train

    • 测试集:仅归一化(img_transform_test

  3. 返回包含EEG-图像对的数据集

3. 生成模型(eLDM)
  • 双条件机制:同时接受EEG特征和CLIP文本特征

  • 基于Latent Diffusion架构

  • 支持从检查点恢复训练

5. 图像生成与评估
def generate_images(generative_model, dataset, num_samples, ddim_steps):
    grid, samples = generative_model.generate(dataset, num_samples, ddim_steps)
    # 保存图像网格
    Image.fromarray(grid).save('samples.png')
    
    # 计算评估指标
    metrics = get_eval_metric(samples)
    return metrics

评估指标

  • 像素级:MSE, PCC, SSIM

  • 语义级:Top-1分类准确率

三、关键技术细节

1. 条件扩散模型
graph LR
    A[EEG信号] --> B[EEG编码器]
    C[CLIP文本编码] --> D[LDM UNet]
    B --> D
    D --> E[图像生成]
2. 双阶段训练策略
  1. 阶段A:预训练EEG编码器(MAE架构)

  2. 阶段B:微调扩散模型(本代码)

3. 图像变换流水线
img_transform_train = transforms.Compose([
    normalize,                     # 归一化到[-1,1]
    transforms.Resize(512),        # 调整大小
    random_crop(448, p=0.5),       # 随机裁剪(数据增强)
    transforms.Resize(512),        # 再次调整
    channel_last                   # 通道顺序转换
])

gen_eval_eeg.py

Generating Images with Trained Checkpoints

实现了EEG信号到图像生成的评估流程:

一、代码整体架构

这段代码是DreamDiffusion项目的评估部分,主要功能是加载预训练好的生成模型,对EEG信号进行图像生成并保存结果。核心模块包括:

  1. 配置加载:从检查点恢复实验配置

  2. 数据准备:加载EEG测试数据集

  3. 模型初始化:构建条件扩散模型(eLDM)

  4. 图像生成:使用训练好的模型生成图像

  5. 结果保存:存储生成的图像网格


二、核心组件详解

图像变换流程
img_transform_test = transforms.Compose([
    normalize,                  # 归一化到[-1,1]
    transforms.Resize((512,512)), # 调整尺寸
    channel_last                # 通道顺序转换 (C,H,W)->(H,W,C)
])
  • 数据规格

    • 输入EEG形状:(num_samples, 128通道, 512时间点)

    • 输出图像尺寸:512×512

3. 模型初始化
generative_model = eLDM(
    pretrain_mbm_metafile,   # EEG编码器配置
    num_voxels,              # 输入维度=EEG特征长度
    device=device,           # 计算设备
    pretrain_root=config.pretrain_gm_path,  # SD权重路径
    ddim_steps=config.ddim_steps  # 扩散步数(默认250)
)
generative_model.model.load_state_dict(sd['model_state_dict'])  # 加载训练权重

模型架构特点

  • 双条件机制:EEG特征 + CLIP文本特征

  • 基于Latent Diffusion架构

  • 使用DDIM采样方法

4. 图像生成
# 生成训练集样本(10个实例)
grid, _ = generative_model.generate(dataset_train, 
    num_samples=config.num_samples,
    ddim_steps=config.ddim_steps,
    HW=config.HW,  # 图像尺寸
    limit=10
)

# 生成测试集样本
grid, samples = generative_model.generate(dataset_test,
    num_samples=config.num_samples,
    ddim_steps=config.ddim_steps,
    state=sd['state']  # 随机状态恢复
)

生成参数

参数含义典型值
num_samples每样本生成数量5
ddim_steps扩散采样步数250
HW图像高宽[512,512]
limit最大生成样本数10

三、关键技术细节

1. 条件生成流程
sequenceDiagram
    participant EEG
    participant Model
    participant Image
    EEG->>Model: 输入EEG信号(128ch×512t)
    Model->>Model: 通过EEG编码器提取特征
    Model->>Model: 扩散模型条件生成
    Model->>Image: 输出512×512图像

这个生成代码很有问题啊,一直报错,类似这样,很多人都出现了,但目前无法解决,,,,

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.dtcms.com/a/110786.html

相关文章:

  • 【Linux】GCC编译选项-Wl 和 链接(ld)选项-rpath
  • 【自动化运维】Python 的安装和使用
  • ARM------硬件程序开发
  • 今日行情明日机会——20250403
  • 图解AUTOSAR_LINInterface
  • JavaEE-0403学习记录
  • 数据结构C语言练习(设计循环队列)
  • CSS:换行与不换行
  • openGL 学习,Hello Triangle!
  • 【机器学习】--多分类(单标签)
  • Spring Boot 整合mybatis
  • Vulnhub-PrinkysPalacev3
  • 火山 RTC 引擎 2 ----APPKEY
  • 研究下适合部署在jeston上的深度学习类单目标跟踪算法
  • 【数据结构】哈希
  • 算法每日一练 (25)
  • 【4】搭建k8s集群系列(二进制部署)之安装master节点组件(kube-apiserver)
  • 每日一题(小白)模拟娱乐篇11
  • Nginx接收https并内部转发成http
  • LORA+llama模型微调全流程
  • SQL DB 数据类型
  • UBUNTU编译dataline
  • 云渲染平台:创意产业的算力革命
  • Java面试34-Kafka的零拷贝原理
  • 讲一下resblock的跳跃连接,以及连接前后的shape保持(通过padding保持shape不变)
  • Maven+Spring实现后端开发
  • 【滑动窗口】3254. 长度为 K 的子数组的能量值 I
  • 【UE5 C++课程系列笔记】32——读Json文件并解析
  • 【GoLang】etcd初始化客户端时不会返回错误怎么办
  • Vue3命名规范指南