Chapter 7-Fine-tuning to follow instructions

7.4-Creating data loaders for an instruction dataset

  • We simply plug the InstructionDataset objects and the custom_collate_fn function into a PyTorch DataLoader.

  • Use the following code to initialize the device:

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Note:
    # Uncommenting the following lines will allow the code to run on Apple Silicon chips, if applicable,
    # which is much faster than on an Apple CPU (as measured on an M3 MacBook Air).
    # However, the resulting loss values may be slightly different.

    #if torch.cuda.is_available():
    #    device = torch.device("cuda")
    #elif torch.backends.mps.is_available():
    #    device = torch.device("mps")
    #else:
    #    device = torch.device("cpu")

    print("Device:", device)

    """Output"""
    Device: cuda
    

    The device argument and the allowed_max_length argument of the custom_collate_fn function are pre-filled with the device variable and 1024, respectively. This way, later calls to customized_collate_fn no longer need these two values passed in manually.

    from functools import partial

    customized_collate_fn = partial(
        custom_collate_fn,
        device=device,
        allowed_max_length=1024
    )
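
    For reference, here is a minimal sketch of what such a custom_collate_fn might look like. The actual implementation comes from the previous section, so treat the details below (padding with token ID 50256, shifting the targets right by one position, and masking all but the first padding token in the targets with -100) as an illustrative reconstruction rather than the definitive version.

    import torch

    def custom_collate_fn(batch, pad_token_id=50256, ignore_index=-100,
                          allowed_max_length=None, device="cpu"):
        # Pad every sequence to the longest one in the batch (+1 for the target shift)
        batch_max_length = max(len(item) + 1 for item in batch)
        inputs_lst, targets_lst = [], []

        for item in batch:
            new_item = item.copy()
            new_item += [pad_token_id]  # one extra pad token enables the shift below
            padded = new_item + [pad_token_id] * (batch_max_length - len(new_item))
            inputs = torch.tensor(padded[:-1])   # inputs drop the last token
            targets = torch.tensor(padded[1:])   # targets are shifted right by one

            # Replace all but the first pad token in the targets with ignore_index
            mask = targets == pad_token_id
            indices = torch.nonzero(mask).squeeze(-1)
            if indices.numel() > 1:
                targets[indices[1:]] = ignore_index

            # Optionally truncate to the model's supported context length
            if allowed_max_length is not None:
                inputs = inputs[:allowed_max_length]
                targets = targets[:allowed_max_length]

            inputs_lst.append(inputs)
            targets_lst.append(targets)

        return (torch.stack(inputs_lst).to(device),
                torch.stack(targets_lst).to(device))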
    

    Next, we set up the data loaders; this time, however, we use our custom collate function for the batching process.

    from torch.utils.data import DataLoader

    num_workers = 0
    batch_size = 8

    torch.manual_seed(123)

    train_dataset = InstructionDataset(train_data, tokenizer)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        collate_fn=customized_collate_fn,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers
    )

    val_dataset = InstructionDataset(val_data, tokenizer)
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        collate_fn=customized_collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers
    )

    test_dataset = InstructionDataset(test_data, tokenizer)
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        collate_fn=customized_collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers
    )
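
    As a reminder, InstructionDataset was defined in the previous section. A minimal sketch of such a dataset is shown below; it assumes a format_input helper from earlier in the chapter that renders the instruction-and-input portion of the Alpaca-style prompt, so the details here are illustrative rather than authoritative.

    from torch.utils.data import Dataset

    class InstructionDataset(Dataset):
        def __init__(self, data, tokenizer):
            self.data = data
            # Pre-tokenize every example once so __getitem__ is a cheap lookup
            self.encoded_texts = []
            for entry in data:
                prompt = format_input(entry)  # assumed helper from earlier sections
                response = f"\n\n### Response:\n{entry['output']}"
                self.encoded_texts.append(tokenizer.encode(prompt + response))

        def __getitem__(self, index):
            return self.encoded_texts[index]

        def __len__(self):
            return len(self.data)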
    

    Let's take a look at what the dimensions of the input and target batches look like.

    print("Train loader:")
    for inputs, targets in train_loader:print(inputs.shape, targets.shape)"""输出"""
    Train loader:
    torch.Size([8, 61]) torch.Size([8, 61])
    torch.Size([8, 76]) torch.Size([8, 76])
    ......
    torch.Size([8, 69]) torch.Size([8, 69])
    

    From the output above, we can see that every batch has a batch size of 8 but a different length; [8, 61], for example, means the batch holds 8 training examples with 61 tokens each. Let's double-check that the inputs contain the <|endoftext|> padding tokens corresponding to token ID 50256 by printing the contents of the first training example in the inputs batch.

    print(inputs[0])

    """Output"""
    tensor([21106,   318,   281, 12064,   326,  8477,   257,  4876,    13, 19430,
              257,  2882,   326, 20431, 32543,   262,  2581,    13,   198,   198,
            21017, 46486,    25,   198, 30003,  6525,   262,  6827,  1262,   257,
              985,   576,    13,   198,   198, 21017, 23412,    25,   198,   464,
             5156,   318,   845, 13779,    13,   198,   198, 21017, 18261,    25,
              198,   464,  5156,   318,   355, 13779,   355,   257,  4936,    13,
            50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256],
           device='cuda:0')
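
    As a quick programmatic check (a small addition to the original walkthrough), we can count the padding tokens in that example directly instead of eyeballing the tensor; for the batch shown above this prints 9:

    pad_token_id = 50256
    num_pad = (inputs[0] == pad_token_id).sum().item()
    print("Padding tokens in first example:", num_pad)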
    

    Similarly, let's double-check that the targets contain the -100 placeholder tokens.

    print(targets[0])

    """Output"""
    tensor([  318,   281, 12064,   326,  8477,   257,  4876,    13, 19430,   257,
             2882,   326, 20431, 32543,   262,  2581,    13,   198,   198, 21017,
            46486,    25,   198, 30003,  6525,   262,  6827,  1262,   257,   985,
              576,    13,   198,   198, 21017, 23412,    25,   198,   464,  5156,
              318,   845, 13779,    13,   198,   198, 21017, 18261,    25,   198,
              464,  5156,   318,   355, 13779,   355,   257,  4936,    13, 50256,
             -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
           device='cuda:0')
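
    The -100 value is not arbitrary: PyTorch's cross-entropy loss skips targets equal to -100 by default (its ignore_index parameter), so the masked padding positions contribute nothing to the training loss. A minimal self-contained demonstration:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 5)            # 2 token positions, vocabulary of 5
    targets = torch.tensor([1, -100])     # the second position is masked out

    # cross_entropy defaults to ignore_index=-100, so only position 0 counts
    loss_masked = F.cross_entropy(logits, targets)
    loss_first_only = F.cross_entropy(logits[:1], targets[:1])
    print(loss_masked, loss_first_only)   # identical values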
    

