当前位置: 首页 > news >正文

[nanoGPT] 文本生成 | 自回归采样 | `generate`方法

第六章:文本生成与采样

欢迎回来

在第五章:检查点与预训练模型加载中,我们学会了如何保存模型习得的知识,甚至加载强大的预训练GPT-2模型。现在,经过所有准备和训练,终于迎来激动人心的环节:让我们的模型创作文本

想象你培养了一位才华横溢的年轻作家。你已教授其语法、词汇甚至特定文风(如莎士比亚风格)。现在,你给出开头句子,要求其创造性续写。这正是文本生成与采样的核心。

本章重点在于使用训练好的GPT模型,基于给定提示生成新颖连贯的文本。

其流程包括:输入初始文本→预测最可能的下个标记→重复追加标记直至达到目标长度。

用户可通过temperature(控制随机性)和top_k(限制候选标记范围)等参数调节生成文本的创造性。

文本生成的核心任务

核心功能是:使用训练好的GPT模型,基于初始提示生成类人文本

例如输入:“在遥远的银河系…”
模型应据此续写合理的故事段落。

核心机制:自回归采样

GPT模型通过自回归采样逐标记生成文本,如同谨慎的作家:写一个词→思考下一个词→再写→再思考,循环往复

基本流程:

  1. 输入提示:提供初始标记序列(开头文本)
  2. 预测下个标记:模型根据当前序列预测下一个最可能标记
  3. 扩展序列:将预测标记追加到序列
  4. 循环执行:用新序列重复预测,直至达到目标长度

控制台:sample.py脚本

nanoGPT中负责文本生成的主脚本是sample.py。该脚本加载训练好的模型(或预训练模型)并生成文本,同时提供类似第三章:配置系统中的"调节旋钮"来控制文本风格与创造性。

典型调用方式:

python sample.py --start="你好,我叫" --num_samples=1 --max_new_tokens=50

该命令指示脚本以"你好,我叫"开头,生成1段50个标记的文本。

创造性控制参数

sample.py提供多个关键参数调节生成过程:

  • --start:初始提示文本。可以是简单字符串或文件路径(如FILE:prompt.txt
  • --max_new_tokens:控制生成的新标记数量。若提示含5个标记且设为50,最终输出为55个标记
  • --temperature:浮点数(如0.8/1.0/1.2)控制输出随机性
    • =1.0:按模型原始概率采样
    • <1.0(如0.5):输出更确定化,偏向高频词
    • >1.0(如1.2):增加非常用词概率,输出更天马行空
    • 类比:视作文本"辣度",低值保守,高值奔放
  • --top_k:整数(如5/200)限制候选标记范围
    • 仅考虑词表中前k个最可能标记
    • =1时总是选择最高概率标记,输出确定性高但易重复
    • =200时在前200个候选中选择,平衡创造性与连贯性
    • 类比:如同为作家提供精选的200个候选词

生成实例

1. 从莎士比亚微调模型生成

首先按第四章训练字符级莎士比亚模型:

python data/shakespeare_char/prepare.py
python train.py config/train_shakespeare_char.py

训练完成后(或设置always_save_checkpoint=True实时保存),在out-shakespeare-char目录生成ckpt.pt文件。生成示例:

python sample.py --out_dir=out-shakespeare-char --start="ROMEO:" --num_samples=1 --max_new_tokens=100 --temperature=0.8 --top_k=50

示例输出

ROMEO:
I am the death of all the land.
What, art thou come? and what will be the day?
I will not be gone.

这是小型字符级模型的输出,虽不完美但遵循了"ROMEO:"模式。注意通过out_dir指定模型路径。

2. 从预训练GPT-2 XL模型生成

使用大型预训练GPT-2模型(需先按第五章配置):

python sample.py \--init_from=gpt2-xl \--start="生命、宇宙及万物的终极答案是什么?" \--num_samples=1 \--max_new_tokens=100 \--temperature=0.7 \--top_k=20

示例输出

生命、宇宙及万物的终极答案是什么?
根据《银河系漫游指南》的记载,答案是42。但问题在于,人们问错了问题。真正的问题应该是"哪个问题的答案是42?"。这个原问题比表面看起来复杂得多,已成为哲学界长期辩论的主题。

大型模型配合BPE分词(见第一章)能生成更高质量的文本,甚至引经据典

核心实现:generate方法

sample.py的核心是调用GPT类中的generate方法,该方法实现逐标记生成逻辑

生成流程

假设输入提示为"The cat sat":

  1. 初始化:将提示转为标记ID序列(如[10,20,30]),获取max_new_tokens等参数
  2. 循环生成:开始最多50次的循环(对应max_new_tokens
  3. 上下文截断:若当前序列超模型block_size仅保留最近部分
  4. 模型预测:将当前序列输入模型,获取所有可能下个标记的原始分数(logits)
  5. 温度调节:用temperature调整分数分布,高值使分布更平缓
  6. Top-K过滤:若设置top_k,仅保留前k个高分标记
  7. 概率计算:通过softmax分数转为概率
  8. 标记采样:依概率随机选取一个标记
  9. 序列扩展:将新标记追加到当前序列
  10. 循环继续用扩展后的序列重复预测
  11. 返回结果:达到目标长度后返回完整标记序列

代码(model.py)

model.pygenerate方法的核心逻辑(使用torch.no_grad()避免计算梯度):

# 摘自model.py(简化版)
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):for _ in range(max_new_tokens):# 1. 上下文截断idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]# 2. 获取预测分数logits, _ = self(idx_cond)logits = logits[:, -1, :] # 仅取最后位置的预测# 3. 温度调节logits = logits / temperature# 4. Top-K过滤if top_k is not None:v, _ = torch.topk(logits, min(top_k, logits.size(-1)))logits[logits < v[:, [-1]]] = -float('Inf')# 5. 概率计算与采样probs = F.softmax(logits, dim=-1)idx_next = torch.multinomial(probs, num_samples=1)# 6. 序列扩展idx = torch.cat((idx, idx_next), dim=1)return idx
  • self(idx_cond):调用模型前向传播(见第二章),获取每个位置的预测分数
  • logits[:, -1, :]:仅保留序列末位的预测(即下一个标记)
  • torch.topk:找出前k个高分标记
  • torch.multinomial:依概率分布随机采样
  • torch.cat将新标记追加到序列

最终sample.py通过解码器(见第一章)将标记ID序列转为可读文本

小结

我们探索了文本生成与采样的精彩世界,学会了如何:

  • 使用sample.py脚本控制生成过程
  • 通过temperaturetop_k调节文本创造性
  • 从微调模型或预训练大模型生成文本
  • 理解generate方法的自回归采样机制

现在我们的模型已能创作文本,下一站是让训练与生成过程更加高效

下一章:性能与效率工具

http://www.dtcms.com/a/536945.html

相关文章:

  • 【Linux专栏】shell脚本变量的取值|转换
  • [Dify 实战] 插件调试技巧(进阶篇):本地测试与部署全流程问题排查指南(Dify本地部署环境下)
  • 一、初识 LangChain:架构、应用与开发环境部署
  • 中山公司网站建设阿里云域名交易平台
  • 做flash音乐网站的开题报告做网站建设的合同范本
  • Trait与泛型高级用法
  • 解锁效率:一份关于大语言模型量化的综合技术指南
  • 网站后天添加文章不显示上海搜索优化推广哪家强
  • 前端基础之《React(3)—webpack简介-集成JSX语法支持》
  • 虚拟机之间配置免密登录(Centos)
  • 嵌入式测试的工作内容非常具体和专业化,它横跨了软件和硬件两个领域。
  • 保定网站建设团队网站备案密码 多少位
  • ZW3D二次开发_整图缩放
  • 滁州网站建设费用开发公司网签价格
  • 福州建网站公司最好的营销型网站建设公司
  • 新手入门:一篇看懂计算机基础核心知识
  • 每日算法刷题Day80:10.27:leetcode 回溯11道题,用时3h
  • 建设一个网站的规划广州seo公司如何
  • [强化学习] 第1篇:奖励信号是智能的灵魂
  • 从“看得见“到“看得懂“:监控安全管理的智能进化
  • YOLOv5 核心模块解析与可视化
  • 昆山外贸型网站制作建站科技公司
  • 快速建站框架网站如何做360优化
  • 网站公司做网站网络推广公司介绍
  • 百度网站验证方法室内设计效果图多少钱一张
  • 网站服务器查找wordpress cms主题制作
  • 《Chart.js 柱形图:全面解析与实战指南》
  • 物联网设备运维中的上下文感知自动化响应与策略动态调整
  • JAVA面试汇总(五)数据库(二)
  • 程序员的自我修养(三)