当前位置: 首页 > news >正文

nano-GPT:最小可复现的GPT实操

引言

昨天看到新闻,不少人在介绍nano-GPT[1]这个项目,去仓库地址看了眼,是个三年前开源的项目,不知道为什么现在突然又受到关注。

这个项目的作者是 OpenAI 的研究员 Andrej Karpathy,该项目旨在用最少量/简单的代码来对GPT-2(124M)代码进行微调,就像下图所展现的那样,可用的GPT是巡洋舰,nanoGPT是小帆船,代码虽少,五脏俱全也能跑。

image.png

小帆船意味着只用笔记本级别的显卡就能复现出GPT-2的训练过程,从实用角度来说,放到今日已然过时。

但从学习的角度来说,是一份优质的学习资料。之前已经分析过GPT-1[2]和GPT-2[3]的基本原理,这个代码可以作为一份高质量的实操材料。

本文就从实践的角度,运行并解析一下这个项目。

环境安装

首先需要安装python相关依赖。

如果是 Linux,可以直接用 uv 安装环境:

uv add torch numpy transformers datasets tiktoken wandb tqdm

在 Windows 上,直接安装 torch 会报错,因此需要修改一下pyproject.toml,改成以下内容,主要添加uv.index

[project]
name = "nanogpt"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"dependencies = ["torch>=2.6.0","torchvision>=0.21.0","datasets>=4.2.0","numpy>=2.3.3","tiktoken>=0.12.0","tqdm>=4.67.1","transformers>=4.57.1","wandb>=0.22.2",
][tool.uv.sources]
torch = [{ index = "pytorch-cu126", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
torchvision = [{ index = "pytorch-cu126", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
][[tool.uv.index]]
name = "pytorch-cu126"
url = "https://download.pytorch.org/whl/cu126"
explicit = true

注:这个项目实际用不到 torchvision,但是为了方便其它项目可复制此配置,将torchvision的信息进行保留。

然后运行

uv sync

安装好之后,可以用下面的脚本,测试 torch 是否能正常工作,以及是否能准确识别出显卡信息:

import torchif torch.cuda.is_available():print("CUDA is available!")print(f"Number of GPUs: {torch.cuda.device_count()}")print(f"GPU Name: {torch.cuda.get_device_name(0)}")
else:print("CUDA is not available on this system.")

数据获取

要训练模型,首先需要准备数据,初次运行可以使用 Tiny Shakespeare 这个数据集,运行数据下载脚本:

python data/shakespeare_char/prepare.py

Tiny Shakespeare 是一个包含了莎士比亚所有作品(包括戏剧和诗歌)的纯文本文件,文本内容部分如下图所示,格式类似于剧本。

image.png

其中,训练集train.bin 包含 1,003,854 tokens;验证集val.bin 包含 111,540 tokens。

训练模型

有了数据之后,就可以开始训练模型,训练脚本train.py支持三种模式:

  • scratch:从零开始进行训练
  • resume:从某个checkpoint恢复训练
  • gpt2:从gpt2模型的权重进行初始化

运行下面的命令,即可采用scratch模式进行训练:

python train.py config/train_shakespeare_char.py --compile=False

如果不加compile=False,在windows系统上无法运行。

compile 是 pytorch2.0 引入的特性,compile 能够在动态计算图的基础上,通过算子融合、常量折叠等一系列优化操作,针对目标硬件(如 NVIDIA GPU、AMD GPU、CPU),生成高度优化的底层机器代码,从而加快模型运行效率。

然而,该特性目前尚未支持windows,因此需要将其关闭,linux系统则无需添加此命令。

根据config/train_shakespeare_char.py这个参数配置,该脚本训练的是一个“缩水版”的GPT-2,层数(n_layer)是6,多头头数(n_head)是6,嵌入层维度(n_embd)是384,这三个Transformer block的关键值是真正的 GPT-2 的一半。

总共的迭代次数(max_iters)被设置为5000,运行约占用4GB显存,差不多2小时就能训完,训练完的模型参数默认保存在out-shakespeare-char\ckpt.pt

训练中,每隔 250 轮会自动进行一次验证,在控制台会输出相关的损失:

step 5000: train loss 0.6137, val loss 1.7057
iter 5000: loss 0.8154, time 78204.53ms, mfu 0.61%

这里有个值mfu还是比较陌生的,它代表的是模型浮点运算利用率,可以理解为GPU的利用率。由于我没有开compile,数据io也没有通过缓存等方式进行加速,因此GPU大部分时间都在等待,mfu值并不高。

推理

运行下面的命令,可以从加载训练好的模型,并进行推理:

python sample.py --init_from='resume' --out_dir='out-shakespeare-char' --start="What is the answer to life, the universe, and everything?"  --num_samples=1

由于gpt是个预测模型,即根据输入的内容不断预测下一个词,因此start就是输入的内容,num_samples是指进行一次采样,即进行一次推理,由于输出内容会受到temperaturetop_k等参数的影响,因此每次都会不太一样。

运行结果如下:

What is the answer to life, the universe, and everything?CORIOLANUS:
O, affection the
baits aboard with the people?MENENIUS:
Your uncle, my lord: you are all the confiscation
In peace, and full of the state,
Which you ever have late made over-peared me now
With joys, like dispatched her departure,
That seem to make a sentence of golden power,
And that you have done evil to your dishonour,
And now you mock my business noble lords
Are great up to your execution.EDWARD:
How art thou! still is a royal point?YORK:
How now! what? Irely most my cousi
---------------

默认的max_new_tokens数值是500,意味着输出500个tokens就会被截断。

模型结构

模型结构在model.py这个文件,整体基本就是Transformer的Decoder部分,根据代码画了一个结构图,如下图所示:

结构图

总结

这个项目是个挺不错的学习资料,除了本文所提及的部分之外,该项目还支持了一些多卡DP,特定优化编译等实用技巧,对理解GPT原理会很有帮助。

参考

[1] nanoGPT仓库:https://github.com/karpathy/nanoGPT
[2] 【不背八股】18.GPT1:GPT系列的初代目:https://zstar.blog.csdn.net/article/details/152122075
[3] 【不背八股】19.GPT-2:不再微调,聚焦零样本:https://zstar.blog.csdn.net/article/details/152233948

http://www.dtcms.com/a/490658.html

相关文章:

  • 网站建设公众号wordpress中文模板下载地址
  • 菜单及库(Num28)
  • super()核心作用是调用父类的属性/方法
  • 【Win32 多线程程序设计基础第三章笔记】
  • CentOS 7 FTP安装与配置详细介绍
  • 网页设计跟网站建设的区别淘宝店铺运营推广
  • 机器学习使用GPU
  • 做网站分为哪些功能的网站找工作网
  • 湖南粒界教育科技有限公司:专注影视技能培养,AI辅助教学提升学员就业竞争力
  • 【系统分析师】写作框架:静态测试方法及其应用
  • React useEffect组件渲染执行操作 组件生命周期 监视器 副作用
  • 在哪些场景下适合使用 v-model 机制?
  • 长沙申请域名网站备案查域名服务商
  • 游标卡尺 东莞网站建设大连建设工程信息网去哪里找
  • 华为USG防火墙之开局上网配置
  • 【第五章:计算机视觉-计算机视觉在医学领域中应用】1.生物细胞检测实战-(3)基于YOLO的细胞检测实战:数据读取、模型搭建、训练与测试
  • 【MFC实用技巧】对话框“边框”属性四大选项:None、Thin、Resizing、对话框外框,到底怎么选?
  • 网站备案 备注关联性天津网站建设内容
  • 所有网站收录入口济南市住监局官网
  • frida android quickstart
  • 作为测试工程师,我们该如何应用 AI?
  • 【Flutter】Flutter项目整体架构
  • 电子电气架构 --- 未来汽车软件架构
  • 怎么优化网站关键词辽宁省住房建设厅网站科技中心
  • 电力自动化新突破:Modbus如何变身Profinet?智能仪表连接的终极解决方案
  • cGVHD患者的血常规指标 生化指标 动态监测
  • 重庆网站建设师网站顶部布局
  • 【算法与数据结构】二叉树后序遍历非递归算法:保姆级教程(附具体实例+可运行代码)
  • AI-调查研究-105-具身智能 机器人学习数据采集:从示范视频到状态-动作对的流程解析
  • 基于 PyQt5 的多算法视频关键帧提取工具