当前位置: 首页 > news >正文

极客时间:在 Google Colab 上尝试 Prefix Tuning

  每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领域的领跑者。点击订阅,与未来同行! 订阅:https://rengongzhineng.io/

Prefix Tuning 是当前最酷的参数高效微调(PEFT)方法之一,它可以在无需重新训练整个大模型的前提下对大语言模型(LLM)进行任务适配。为了理解它的工作原理,我们先了解下背景:传统微调需要更新模型的所有参数,成本高、计算密集。随后出现了 Prompting(提示学习),通过巧妙设计输入引导模型输出;Instruction Tuning(指令微调)进一步提升模型对任务指令的理解能力。再后来,LoRA(低秩适配)通过在网络中插入可训练的低秩矩阵实现任务适配,大大减少了可训练参数。

而 Prefix Tuning 则是另一种思路:它不会更改模型本体参数,也不插入额外矩阵,而是学习一小组“前缀向量”,将它们添加到每一层 Transformer 的输入中。这种方法轻巧快速,非常适合在 Google Colab 这样资源受限的环境中实践。

在这篇博客中,我们将一步步地在 Google Colab 上,使用 Hugging Face Transformers 和 peft 库完成 Prefix Tuning 的演示。


第一步:安装运行环境

!pip install transformers peft datasets accelerate bitsandbytes

使用的库包括:

  • transformers: 加载基础模型

  • peft: 实现 Prefix Tuning

  • datasets: 加载示例数据集

  • acceleratebitsandbytes: 优化训练性能


第二步:加载预训练模型和分词器

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, PrefixTuningConfig, TaskTypemodel_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

这里我们使用 GPT-2 作为演示模型,也可以替换为其他因果语言模型。


第三步:配置 Prefix Tuning

peft_config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM,inference_mode=False,num_virtual_tokens=10
)model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

上述配置在每层 Transformer 中加入了 10 个可学习的虚拟前缀 token,我们将对它们进行微调。


第四步:加载并预处理 Yelp 数据集样本

from datasets import load_datasetdataset = load_dataset("yelp_review_full", cache_dir="/tmp/hf-datasets")
dataset = dataset.shuffle(seed=42).select(range(1000))def preprocess(example):tokens = tokenizer(example["text"], truncation=True, padding="max_length", max_length=128)return {"input_ids": tokens["input_ids"], "attention_mask": tokens["attention_mask"]}dataset = dataset.map(preprocess, batched=True)

第五步:使用 Prefix Tuning 训练模型

from transformers import TrainingArguments, Trainertraining_args = TrainingArguments(output_dir="./prefix_model",per_device_train_batch_size=4,num_train_epochs=1,logging_dir="./logs",logging_steps=10
)trainer = Trainer(model=model,args=training_args,train_dataset=dataset
)trainer.train()

第六步:保存并加载 Prefix Adapter

model.save_pretrained("prefix_yelp")

之后加载方法如下:

from peft import PeftModelbase_model = AutoModelForCausalLM.from_pretrained("gpt2")
prefix_model = PeftModel.from_pretrained(base_model, "prefix_yelp")

第七步:推理测试

训练完成后,我们可以使用调优后的模型进行生成测试。

input_text = "This restaurant was absolutely amazing!"
inputs = tokenizer(input_text, return_tensors="pt")output = prefix_model.generate(**inputs, max_new_tokens=50)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("\nGenerated Output:")
print(generated_text)

示例输出

这是在训练 3 个 epoch 并使用 20 个虚拟 token 后的输出示例:

This restaurant was absolutely amazing!, a the the the the the the the the the the the the the the the the the the the the the the the the the the the the the a., and the way. , and the was

虽然模型初步模仿了 Yelp 评论的风格,但输出仍重复性强、连贯性不足。为获得更好效果,可增加训练数据、延长训练周期,或使用更强的基础模型(如 gpt2-medium)。


完整代码

以下是经过改进的完整代码(包含更大前缀尺寸和更多训练轮次):

# 安装依赖
!pip install -U fsspec==2023.9.2
!pip install transformers peft datasets accelerate bitsandbytes# 加载模型与分词器
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, PrefixTuningConfig, TaskTypemodel_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)
model.config.pad_token_id = tokenizer.pad_token_id# 配置 Prefix Tuning
peft_config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM,inference_mode=False,num_virtual_tokens=20  # 使用更多虚拟 token
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()# 加载和预处理数据集
from datasets import load_dataset
try:dataset = load_dataset("yelp_review_full", split="train[:1000]")
except:dataset = load_dataset("yelp_review_full")dataset = dataset["train"].select(range(1000))def preprocess(examples):tokenized = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)tokenized["labels"] = [[-100 if mask == 0 else token for token, mask in zip(input_ids, attention_mask)]for input_ids, attention_mask in zip(tokenized["input_ids"], tokenized["attention_mask"])]return tokenizeddataset = dataset.map(preprocess, batched=True, remove_columns=["text", "label"])# 配置训练参数
from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(output_dir="./prefix_model",per_device_train_batch_size=4,num_train_epochs=3,  # 增加轮次logging_dir="./logs",logging_steps=10,report_to="none"
)trainer = Trainer(model=model,args=training_args,train_dataset=dataset
)trainer.train()# 保存和加载前缀
model.save_pretrained("prefix_yelp")
from peft import PeftModel
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
prefix_model = PeftModel.from_pretrained(base_model, "prefix_yelp")# 推理
input_text = "This restaurant was absolutely amazing!"
inputs = tokenizer(input_text, return_tensors="pt")
output = prefix_model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(output[0], skip_special_tokens=True))

不同微调技术如何选择?

方法特点适合场景
Prompting零样本/少样本,无需训练快速实验、通用模型调用
Instruction Tuning统一风格指导多个任务多任务模型,提示兼容性强
Full Fine-Tuning全模型更新,效果最好但成本高数据量大、计算资源充足场景
LoRA插入低秩矩阵,性能和效率平衡中等规模适配任务、部署灵活
Prefix Tuning训练前缀向量,模块化且轻量多任务共享底模、小规模快速适配

真实应用案例

  • 客服机器人:为不同产品线训练不同前缀,提高回答准确性

  • 法律/医学摘要:为专业领域调优风格和术语的理解

  • 多语种翻译:为不同语言对训练前缀,重用同一个基础模型

  • 角色对话代理:通过前缀改变语气(如正式、幽默、亲切)

  • SaaS 多租户服务:不同客户使用不同前缀,但共用主模型架构


总结

Prefix Tuning 是一种灵活且资源友好的方法,适合:

  • 有多个任务/用户但希望复用基础大模型的情况

  • 算力有限,但希望实现快速个性化的场景

  • 构建模块化、可热切换行为的 LLM 服务

建议从小任务入手测试,尝试不同 prefix 长度与训练轮次,并结合任务类型进行微调策略选择。

如果你想将此教程发布到 Colab、Hugging Face 或本地部署,欢迎继续交流!

相关文章:

  • 01.SQL语言概述
  • 算法-构造题
  • CSS悬停闪现与a标签嵌套的问题
  • vue3:十六、个人中心-修改密码
  • 《前端面试题:JavaScript 作用域深度解析》
  • leetcode Top100 189.轮转数组
  • Python Cookbook-7.13 生成一个字典将字段名映射为列号
  • 【学习笔记】TLS
  • 【threejs】每天一个小案例讲解:题外话篇
  • JDK 17 新特性
  • Java常见异常处理指南:IndexOutOfBoundsException与ClassCastException深度解析
  • Linux系统防火墙之iptables
  • LeetCode --- 452周赛
  • 基于FPGA的超声波显示水位距离,通过蓝牙传输水位数据到手机,同时支持RAM存储水位数据,读取数据。
  • Java八股文——并发编程「场景篇」
  • 基于n8n指定网页自动抓取解析入库工作流实战
  • Python学习(7) ----- Python起源
  • 【DAY43】复习日
  • JESD204B IP核接口实例,ADI的ADRV9009板卡,ZYNQ7045驱动实现2发2收。
  • Halo站点全站定时备份并通过邮箱存储备份
  • 东营市做网站/互联网营销策划案
  • 网站建设属于什么职能/搜狗推广助手
  • 做公司网站的费用计入什么科目/世界比分榜
  • 怎么做网站百度贴吧/深圳做网站的公司
  • 尤溪网站开发/关键词排名软件官网
  • 咖啡网站源码/中国十大营销策划机构