当前位置: 首页 > news >正文

LLMs-from-scratch(dataloader)

代码链接:https://github.com/rasbt/LLMs-from-scratch/blob/main/ch02/01_main-chapter-code/dataloader.ipynb

《从零开始构建大型语言模型》一书的补充代码,作者:Sebastian Raschka

代码仓库:https://github.com/rasbt/LLMs-from-scratch

主要数据加载管道总结

完整的章节代码位于 ch02.ipynb。

这个笔记本包含了主要要点,即不包含中间步骤的数据加载管道。

本笔记本中使用的包:

# NBVAL_SKIP
from importlib.metadata import versionprint("torch version:", version("torch"))
print("tiktoken version:", version("tiktoken"))
torch version: 2.5.1+cu124
tiktoken version: 0.12.0
import tiktoken
import torch
from torch.utils.data import Dataset, DataLoaderclass GPTDatasetV1(Dataset):def __init__(self, txt, tokenizer, max_length, stride):self.input_ids = []self.target_ids = []# 对整个文本进行标记化token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})# 使用滑动窗口将书籍分割成重叠的 max_length 长度序列for i in range(0, len(token_ids) - max_length, stride):input_chunk = token_ids[i:i + max_length]target_chunk = token_ids[i + 1: i + max_length + 1]self.input_ids.append(torch.tensor(input_chunk))self.target_ids.append(torch.tensor(target_chunk))def __len__(self):return len(self.input_ids)def __getitem__(self, idx):return self.input_ids[idx], self.target_ids[idx]def create_dataloader_v1(txt, batch_size, max_length, stride,shuffle=True, drop_last=True, num_workers=0):# 初始化标记器tokenizer = tiktoken.get_encoding("gpt2")# 创建数据集dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)# 创建数据加载器dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)return dataloader# 读取文本文件
with open("the-verdict.txt", "r", encoding="utf-8") as f:raw_text = f.read()# 定义模型参数
vocab_size = 50257      # 词汇表大小
output_dim = 256        # 输出维度
context_length = 1024   # 上下文长度# 创建嵌入层
token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)  # 标记嵌入层
pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)  # 位置嵌入层# 设置数据加载器参数
batch_size = 8
max_length = 4
dataloader = create_dataloader_v1(raw_text,batch_size=batch_size,max_length=max_length,stride=max_length
)
# 处理一个批次的数据
for batch in dataloader:x, y = batch  # x: 输入序列, y: 目标序列# 获取标记嵌入token_embeddings = token_embedding_layer(x)# 获取位置嵌入pos_embeddings = pos_embedding_layer(torch.arange(max_length))# 将标记嵌入和位置嵌入相加得到输入嵌入input_embeddings = token_embeddings + pos_embeddingsbreak  # 只处理第一个批次作为示例
print(input_embeddings.shape)  # 输出嵌入张量的形状
torch.Size([8, 4, 256])
http://www.dtcms.com/a/490083.html

相关文章:

  • 兴义哪有做网站婚纱影楼网站源码
  • C++_394_tableWidget控件,两种模式,1、行显示模式 2、网格显示模式
  • MyBatis拦截器实现saas租户同库同表数据隔离
  • 求n以内最大的k个素数以及它们的和
  • 手机 网站建设在线自动取名网站怎么做
  • PHP电动汽车租赁管理系统-计算机毕业设计源码35824
  • 零基础新手小白快速了解掌握服务集群与自动化运维(十二)Python3编程之python基础
  • 大型网站怎样做优化PHP营销推广的主要方法
  • 【泛3C篇】AI深度学习在手机前/后摄像头外观缺陷检测应用方案
  • 建设网站需要申请网站建设与管理专业好找工作吗
  • 绿色在线网站模板下载工具别人做的网站不能用怎么办
  • Initiater for mac 小巧的菜单栏OCR工具
  • ntfs可以用在mac上吗?3 种实用方案,解决Mac与NTFS硬盘兼容问题
  • 数据结构——二十、树与森林的遍历
  • 洛杉矶服务器常见问题汇总与解决方案大全
  • Linux云计算基础篇(27)-NFS网络文件系统
  • Mac安装使用Gradle
  • 夜莺监控设计思考(二)边缘机房架构思考
  • AI+大数据时代:时序数据库的架构革新与生态重构
  • 【记录】MAC本地微调大模型(MLX + Qwen2.5)并利用Ollama接入项目实战
  • wordpress 导购站模板接私活app有哪些平台
  • 有哪些网站可以做推广十大奢侈品牌logo图片
  • 服务注册 / 服务发现 - Eureka
  • 2025机器人自动化打磨抛光设备及汽车零件打磨新技术10月应用解析
  • bk7258 libzip崩溃之解决
  • 【Android】【底层机制】组件生命周期以及背后的状态管理
  • CPM:CMake 包管理详细介绍
  • D3.js + SVG:数据可视化领域的黄金搭档,绘制动态交互图表。
  • 【个人成长笔记】在 QT 中 SkipEmptyParts 编译错误信息及其解决方案
  • 设计模式篇之 备忘录模式 Memento