当前位置: 首页 > news >正文

7.4-Creating data loaders for an instruction dataset

Chapter 7-Fine-tuning to follow instructions

7.4-Creating data loaders for an instruction dataset

  • 我们只需将InstructionDataset对象和custom_collate_fn函数接入 PyTorch 数据加载器

  • 使用以下代码来初始化设备信息

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Note:
    # Uncommenting the following lines will allow the code to run on Apple Silicon chips, if applicable,
    # which is much faster than on an Apple CPU (as measured on an M3 MacBook Air).
    # However, the resulting loss values may be slightly different.#if torch.cuda.is_available():
    #    device = torch.device("cuda")
    #elif torch.backends.mps.is_available():
    #    device = torch.device("mps")
    #else:
    #    device = torch.device("cpu")print("Device:", device)"""输出"""
    Device: cuda
    

    custom_collate_fn函数中的device参数和allowed_max_length预先设定为变量device1024。这样在后续调用customized_collate_fn时,就不需要再手动传入这两个参数的值了。

    from functools import partialcustomized_collate_fn = partial(custom_collate_fn,device=device,allowed_max_length=1024
    )
    

    接下来,我们设置数据加载器,但是这次,我们将使用我们的自定义排序函数进行批处理过程。

    from torch.utils.data import DataLoadernum_workers = 0
    batch_size = 8torch.manual_seed(123)train_dataset = InstructionDataset(train_data, tokenizer)
    train_loader = DataLoader(train_dataset,batch_size=batch_size,collate_fn=customized_collate_fn,shuffle=True,drop_last=True,num_workers=num_workers
    )val_dataset = InstructionDataset(val_data, tokenizer)
    val_loader = DataLoader(val_dataset,batch_size=batch_size,collate_fn=customized_collate_fn,shuffle=False,drop_last=False,num_workers=num_workers
    )test_dataset = InstructionDataset(test_data, tokenizer)
    test_loader = DataLoader(test_dataset,batch_size=batch_size,collate_fn=customized_collate_fn,shuffle=False,drop_last=False,num_workers=num_workers
    )
    

    让我们看看input 和target批次的维度是什么样的

    print("Train loader:")
    for inputs, targets in train_loader:print(inputs.shape, targets.shape)"""输出"""
    Train loader:
    torch.Size([8, 61]) torch.Size([8, 61])
    torch.Size([8, 76]) torch.Size([8, 76])
    ......
    torch.Size([8, 69]) torch.Size([8, 69])
    

    根据上面的输出,我们可以看到,所有批次的批次大小为8,但长度不同,第一个[8,61]表示,batchsize为8,在当前批次中,每个训练示例中的token数量为61。让我们通过打印“input”批处理中第一个训练示例的内容来仔细检查输入是否包含与tokenID 50256对应的“<|endoftext|>”填充token

    print(inputs[0])"""输出"""
    tensor([21106,   318,   281, 12064,   326,  8477,   257,  4876,    13, 19430,257,  2882,   326, 20431, 32543,   262,  2581,    13,   198,   198,21017, 46486,    25,   198, 30003,  6525,   262,  6827,  1262,   257,985,   576,    13,   198,   198, 21017, 23412,    25,   198,   464,5156,   318,   845, 13779,    13,   198,   198, 21017, 18261,    25,198,   464,  5156,   318,   355, 13779,   355,   257,  4936,    13,50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256],device='cuda:0')
    

    同样,我们仔细检查target是否包含-100占位符标记

    print(target[0])"""输出"""
    tensor([  318,   281, 12064,   326,  8477,   257,  4876,    13, 19430,   257,2882,   326, 20431, 32543,   262,  2581,    13,   198,   198, 21017,46486,    25,   198, 30003,  6525,   262,  6827,  1262,   257,   985,576,    13,   198,   198, 21017, 23412,    25,   198,   464,  5156,318,   845, 13779,    13,   198,   198, 21017, 18261,    25,   198,464,  5156,   318,   355, 13779,   355,   257,  4936,    13, 50256,-100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],device='cuda:0')
    

相关文章:

  • Nacos 2.4.3 登录配置
  • Day43
  • Day43 Python打卡训练营
  • Flickr30k Entities 短语定位评测沉浸式代码指南
  • 手机归属地查询接口如何用Java调用?
  • comfyui利用 SkyReels-V2直接生成长视频本地部署问题总结 2 :寻找丢失的model 和工作流中 get set 方法的应用
  • 新版智慧社区(小区)智能化弱电系统解决方案
  • 第18讲、Odoo接口开发详解:原理、类型与实践
  • 【CF】Day73——Codeforces Round 887 (Div. 2) B (思维 + 模拟)
  • 20250602在Ubuntu20.04.6下修改压缩包的日期和时间
  • 内网应用如何实现外网访问?无公网IP本地端口网址服务提供互联网连接
  • python打卡day43@浙大疏锦行
  • 软件开发项目管理工具选型及禅道开源版安装
  • 从0开始学vue:vue3和vue2的关系
  • 《信号与系统》--期末总结V1.0
  • 【算法训练营Day05】哈希表part1
  • 逐步检索增强推理的跨知识库路由学习
  • Ubuntu22.04 安装 CUDA12.8
  • 类和对象:实现日期类
  • MATLAB 安装与使用详细教程
  • 做外汇模拟的网站/培训心得简短
  • 做网站加入视频无法播放/汽车seo是什么意思
  • 阿里云做企业网站/职业培训机构有哪些
  • 自己建网站做那个模块好/海外aso优化
  • 深圳做网站费用/软文发稿平台
  • 旅游商务网站建设/厦门seo俱乐部