[nanoGPT] 文本生成 | 自回归采样 | `generate`方法
第六章:文本生成与采样
欢迎回来
在第五章:检查点与预训练模型加载中,我们学会了如何保存模型习得的知识,甚至加载强大的预训练GPT-2模型。现在,经过所有准备和训练,终于迎来激动人心的环节:让我们的模型创作文本
想象你培养了一位才华横溢的年轻作家。你已教授其语法、词汇甚至特定文风(如莎士比亚风格)。现在,你给出开头句子,要求其创造性续写。这正是文本生成与采样的核心。
本章重点在于使用训练好的GPT模型,基于给定提示生成新颖连贯的文本。
其流程包括:输入初始文本→预测最可能的下个标记→重复追加标记直至达到目标长度。
用户可通过temperature(控制随机性)和top_k(限制候选标记范围)等参数调节生成文本的创造性。
文本生成的核心任务
核心功能是:使用训练好的GPT模型,基于初始提示生成类人文本。
例如输入:“在遥远的银河系…”
模型应据此续写合理的故事段落。
核心机制:自回归采样
GPT模型通过自回归采样逐标记生成文本,如同谨慎的作家:写一个词→思考下一个词→再写→再思考,循环往复。
基本流程:
- 输入提示:提供初始标记序列(开头文本)
- 预测下个标记:模型根据当前序列预测下一个最可能标记
- 扩展序列:将预测标记追加到序列
- 循环执行:用新序列重复预测,直至达到目标长度
控制台:sample.py脚本
nanoGPT中负责文本生成的主脚本是sample.py。该脚本加载训练好的模型(或预训练模型)并生成文本,同时提供类似第三章:配置系统中的"调节旋钮"来控制文本风格与创造性。
典型调用方式:
python sample.py --start="你好,我叫" --num_samples=1 --max_new_tokens=50
该命令指示脚本以"你好,我叫"开头,生成1段50个标记的文本。
创造性控制参数
sample.py提供多个关键参数调节生成过程:
--start:初始提示文本。可以是简单字符串或文件路径(如FILE:prompt.txt)--max_new_tokens:控制生成的新标记数量。若提示含5个标记且设为50,最终输出为55个标记--temperature:浮点数(如0.8/1.0/1.2)控制输出随机性=1.0:按模型原始概率采样<1.0(如0.5):输出更确定化,偏向高频词>1.0(如1.2):增加非常用词概率,输出更天马行空- 类比:视作文本"辣度",低值保守,高值奔放
--top_k:整数(如5/200)限制候选标记范围- 仅考虑词表中前
k个最可能标记 =1时总是选择最高概率标记,输出确定性高但易重复=200时在前200个候选中选择,平衡创造性与连贯性- 类比:如同为作家提供精选的200个候选词
- 仅考虑词表中前
生成实例
1. 从莎士比亚微调模型生成
首先按第四章训练字符级莎士比亚模型:
python data/shakespeare_char/prepare.py
python train.py config/train_shakespeare_char.py
训练完成后(或设置always_save_checkpoint=True实时保存),在out-shakespeare-char目录生成ckpt.pt文件。生成示例:
python sample.py --out_dir=out-shakespeare-char --start="ROMEO:" --num_samples=1 --max_new_tokens=100 --temperature=0.8 --top_k=50
示例输出:
ROMEO:
I am the death of all the land.
What, art thou come? and what will be the day?
I will not be gone.
这是小型字符级模型的输出,虽不完美但遵循了"ROMEO:"模式。注意通过out_dir指定模型路径。
2. 从预训练GPT-2 XL模型生成
使用大型预训练GPT-2模型(需先按第五章配置):
python sample.py \--init_from=gpt2-xl \--start="生命、宇宙及万物的终极答案是什么?" \--num_samples=1 \--max_new_tokens=100 \--temperature=0.7 \--top_k=20
示例输出:
生命、宇宙及万物的终极答案是什么?
根据《银河系漫游指南》的记载,答案是42。但问题在于,人们问错了问题。真正的问题应该是"哪个问题的答案是42?"。这个原问题比表面看起来复杂得多,已成为哲学界长期辩论的主题。
大型模型配合BPE分词(见第一章)能生成更高质量的文本,甚至引经据典
核心实现:generate方法
sample.py的核心是调用GPT类中的generate方法,该方法实现逐标记生成逻辑
生成流程
假设输入提示为"The cat sat":
- 初始化:将提示转为标记ID序列(如
[10,20,30]),获取max_new_tokens等参数 - 循环生成:开始最多50次的循环(对应
max_new_tokens) - 上下文截断:若当前序列超模型
block_size,仅保留最近部分 - 模型预测:将当前序列输入模型,获取所有可能下个标记的原始分数(logits)
- 温度调节:用
temperature调整分数分布,高值使分布更平缓 - Top-K过滤:若设置
top_k,仅保留前k个高分标记 - 概率计算:通过
softmax将分数转为概率 - 标记采样:依概率随机选取一个标记
- 序列扩展:将新标记追加到当前序列
- 循环继续:用扩展后的序列重复预测
- 返回结果:达到目标长度后返回完整标记序列
代码(model.py)
model.py中generate方法的核心逻辑(使用torch.no_grad()避免计算梯度):
# 摘自model.py(简化版)
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):for _ in range(max_new_tokens):# 1. 上下文截断idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]# 2. 获取预测分数logits, _ = self(idx_cond)logits = logits[:, -1, :] # 仅取最后位置的预测# 3. 温度调节logits = logits / temperature# 4. Top-K过滤if top_k is not None:v, _ = torch.topk(logits, min(top_k, logits.size(-1)))logits[logits < v[:, [-1]]] = -float('Inf')# 5. 概率计算与采样probs = F.softmax(logits, dim=-1)idx_next = torch.multinomial(probs, num_samples=1)# 6. 序列扩展idx = torch.cat((idx, idx_next), dim=1)return idx
self(idx_cond):调用模型前向传播(见第二章),获取每个位置的预测分数logits[:, -1, :]:仅保留序列末位的预测(即下一个标记)torch.topk:找出前k个高分标记torch.multinomial:依概率分布随机采样torch.cat:将新标记追加到序列
最终sample.py通过解码器(见第一章)将标记ID序列转为可读文本
小结
我们探索了文本生成与采样的精彩世界,学会了如何:
- 使用
sample.py脚本控制生成过程 - 通过
temperature和top_k调节文本创造性 - 从微调模型或预训练大模型生成文本
- 理解
generate方法的自回归采样机制
现在我们的模型已能创作文本,下一站是让训练与生成过程更加高效
下一章:性能与效率工具
