当前位置: 首页 > news >正文

在自己的数据上复现一下LlamaGen

git仓库:https://github.com/FoundationVision/LlamaGen

数据集准备

如果用ImageFolder读取,则最好和ImageNet一致。

data_path/
    class_1/
        image_001.jpg
        image_002.jpg
        ...
    class_2/
        image_003.jpg
        image_004.jpg
        ...
    ...
    class_n/
        image_005.jpg
        image_006.jpg
        ...

def build_imagenet(args, transform):
    return ImageFolder(args.data_path, transform=transform)

如果是train,val,test,最好整理成

data_path/
    train/
        class_1/
            image_001.jpg
            image_002.jpg
            ...
        class_2/
            image_003.jpg
            image_004.jpg
            ...
        ...
    val/
        class_1/
            image_005.jpg
            image_006.jpg
            ...
        class_2/
            image_007.jpg
            image_008.jpg
            ...
        ...
    test/
        class_1/
            image_009.jpg
            image_010.jpg
            ...
        class_2/
            image_011.jpg
            image_012.jpg
            ...
        ...

读取:

train_dataset = datasets.ImageFolder(root=args.data_path + '/train', transform=transform)

# 加载验证集
val_dataset = datasets.ImageFolder(root=args.data_path + '/val', transform=transform)

# 加载测试集
test_dataset = datasets.ImageFolder(root=args.data_path + '/test', transform=transform)

数据集预处理

NCCL_IB_DISABLE=1 NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=3 torchrun \
--nnodes=1 --nproc_per_node=1 --node_rank=0 \
--master_addr=localhost \
autoregressive/train/extract_codes_c2i.py \
--vq-ckpt ./pretrained_models/vq_ds16_c2i.pt \
--data-path 你的数据集 \
--code-path VQGAN处理的数据集放在哪 \
 --ten-crop \
 --crop-range 1.1 \
 --image-size 256

这里改成自己数据集的长度
在这里插入图片描述

ten-crop是作者定义的一种数据增强,每一个图片生成10个crop。最好修改一下这里的代码,训练的时候仅仅取一个。

在这里插入图片描述
注释掉这个self.flip

训练

NCCL_IB_DISABLE=1 NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=4,5 torchrun \
--nnodes=1 --nproc_per_node=2 --node_rank=0 \
--master_addr=localhost \
--master_port=8902 \
./autoregressive/train/train_c2i.py \
--cloud-save-path xxx \
--code-path 之前放VQGAN处理后数据集的地方 \
--image-size 256 \
--gpt-model GPT-B

生成

修改类别,权重

parser.add_argument("--num-classes", type=int, default=xxx)

label定义:
在这里插入图片描述
我的生成结果(数据集用了TinyImageNet的8个类)
300step
在这里插入图片描述

1500step
在这里插入图片描述

相关文章:

  • 开发HarmonyOS NEXT版五子棋游戏实战
  • 【Linux】vim 设置
  • 深入理解 Linux 中的 last 和 lastb 命令
  • OpenGL 04--GLSL、数据类型、Uniform、着色器类
  • Unity XR-XR Interaction Toolkit开发使用方法(十一)组件介绍(XR Interactable)
  • 在单位,领导不说,但自己得懂的7个道理
  • LSM-Tree (日志结构合并树)
  • Linux 运维工具-下载多个链接wget,aria2c
  • 06.【C++】模板初阶(template<typename T>,充分复用函数,函数模板和类模板的使用)
  • C#实现本地Deepseek模型及其他模型的对话
  • 在服务器Ubuntu22.04系统下,ComfyUI的部署
  • JavaScript系列(89)--前端模块化工程详解
  • centos和ubuntu安装mysql教程
  • 基于 Python 的网络监控系统开发全解
  • Android-创建mipmap-anydpi-v26的Logo
  • Activiti 5 + Spring Boot全流程开发指南
  • web安全——分析应用程序
  • java基本常识
  • 2025最新Flask学习笔记(对照Django做解析)
  • vue3-06vue2(Object.defineProperty)与vue3(基于ES6的Proxy)的响应式原理对比
  • 门户网站建设课程设计/关键词优化排名软件流量词
  • 网站建设 流程 域名申请/竞价网站推广
  • 门户网站建设理由/aso排名
  • 网站建设操作系统/推广赚钱软件排行
  • 做一个彩票网站需要怎么做/企业网站建设方案
  • 做视频网站推广/外贸平台哪个网站最好