当前位置: 首页 > news >正文

Google开源Tunix:JAX生态的LLM微调方案来了

JAX生态这两年在LLM训练这块追赶得挺快。PyTorch虽然还是主流但JAX在并行计算、TPU加速和API组合性上确实有些独特的优势。Google今天放出了Tunix这个库,专门做LLM的后训练——微调、强化学习、知识蒸馏这些都能搞。

Tunix是什么

这是个构建在JAX之上的后训练库,和Flax NNX集成得比较紧密。主要解决三类问题:

  • 监督微调(Supervised Fine-Tuning)
  • 强化学习(Reinforcement Learning)
  • 知识蒸馏(Knowledge Distillation)

现在还在早期开发阶段,功能在持续迭代,支持的模型也在慢慢扩展。

核心功能

监督微调:既支持全参数微调,也支持LoRA和Q-LoRA这类参数高效的方法。内存和算力受限的时候,PEFT方案还是挺实用的。

强化学习:实现了几个主流算法:PPO(Proximal Policy Optimization)、GRPO(Group Relative Policy Optimization)、还有token级别的GSPO。另外还有DPO(Direct Preference Optimization)做偏好对齐,这个在RLHF场景用得比较多。

知识蒸馏:支持几种策略,包括基于logit的概率分布匹配、注意力机制的转移和投影、跨架构的特征池化与投影。这几种方法在不同场景下各有用处。

库的设计比较模块化,组件可以自由组合,想扩展自定义流程也不算麻烦。分布式训练支持数据并行(DP)、完全分片数据并行(FSDP)和张量并行(TP),对TPU做了专门优化。

安装

三种装法:

从PyPI装(推荐):

 pip install "tunix[prod]"

或者直接从GitHub主分支:

 pip install git+https://github.com/google/tunix

开发模式从源码装:

 git clone https://github.com/google/tunix.git  cd tunix  pip install -e".[dev]"

TPU上用QLoRA微调Gemma

拿个英译法的任务来演示。用的是Google的Gemma 2B模型,跑在TPU v5e-8上。

环境准备

 pip install -q kagglehub safetensors tensorflow tensorflow_datasets tensorboardX transformers grain datasets  pip install -q git+https://github.com/google/tunix  pip install -q git+https://github.com/google/qwix  # Flax需要升级到最新版pip uninstall -q -y flax  pip install -q git+https://github.com/google/flax.git

完整流程

第一步,从Kaggle拉预训练checkpoint:

 import kagglehub  model_path = "google/gemma/flax/2b"  kaggle_ckpt_path = kagglehub.model_download(model_path)

初始化模型和tokenizer:

 from flax import nnx  
from tunix.models.gemma import model as gemma_lib, params as params_lib  
from tunix.generate import tokenizer_adapter as tokenizer_lib  base_model = gemma_lib.Transformer.from_params(  params_lib.load_and_format_params(kaggle_ckpt_path, "2b"),  version="2b"  
)  tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=f"{kaggle_ckpt_path}/tokenizer.model")

挂上QLoRA adapter:

 import qwix  lora_provider = qwix.LoraProvider(  module_path=".*(q_einsum|kv_einsum|proj)",  rank=16,  alpha=2.0,  weight_qtype="nf4"  # enable QLoRA quantization
)  lora_model = qwix.apply_lora_to_model(base_model, lora_provider)

这里rank设成16,alpha是2.0,weight_qtype指定nf4量化格式。

加载训练数据:

 from tunix.examples.data import translation_dataset  train_ds, validation_ds = translation_dataset.create_datasets(  dataset_name="mtnt/en-fr",  global_batch_size=16,  max_target_length=256,  num_train_epochs=3,  tokenizer=tokenizer,  )

用的是mtnt的英法平行语料,batch size 16,目标序列最长256个token。

开始训练:

 from tunix.sft import peft_trainer, utils  
import optax  trainer=peft_trainer.PeftTrainer(  lora_model,  optimizer=optax.adamw(1e-3),  config=peft_trainer.TrainingConfig(max_steps=100)  
)  trainer.train(train_ds, validation_ds)

优化器用AdamW,学习率1e-3,跑100步看看效果。

推理测试:

训练完直接用adapter过的模型做生成。Tunix提供了Sampler工具:

 from tunix.generate import sampler as sampler_lib  # initialize sampler
sampler = sampler_lib.Sampler(  transformer=lora_model,  tokenizer=tokenizer,  cache_config=sampler_lib.CacheConfig(  cache_size=256,  num_layers=base_model.num_layers,  num_kv_heads=base_model.num_kv_heads,  head_dim=base_model.head_dim,  ),  
)  # test prompts
input_batch = [  "Translate this into French:\nHello, my name is Morgane.\n",  "Translate this into French:\nThis dish is delicious!\n",  "Translate this into French:\nI am a student.\n",  "Translate this into French:\nHow's the weather today?\n",  
]  # generate predictions
out_data = sampler(  input_strings=input_batch,  max_generation_steps=20,  
)  # print results
for input_string, out_string in zip(input_batch, out_data.text):  print(f"----------------------")  print(f"Prompt:\n{input_string}")  print(f"Output:\n{out_string}")

如果用的是QLoRA,把lora_model换成qlora_model就行。生产环境可以考虑把adapter合并回基模型,推理延迟能降下来。

总结

100步训练之后,模型已经能生成一些翻译结果了,虽然质量还不够好。多训练一段时间,准确率会明显提升,而且内存开销和训练速度都保持在不错的水平。

Tunix现在还比较新,但已经能看出一些潜力。TPU优先的设计、模块化的API、LoRA/QLoRA支持、完整的分布式训练策略,这些对做LLM适配研究的人来说都挺有用。

后续应该会继续扩展支持的模型类型和训练算法,值得关注。

地址:https://avoid.overfit.cn/post/c434311d8a894922b6c52ea179cf8d97

作者:Abish Pius

http://www.dtcms.com/a/442447.html

相关文章:

  • YOLO入门教程(番外):机器视觉实践—Kaggle实战:深度学习实现狗的品种识别
  • Redis和MySQL的数据同步
  • 织梦网站转移服务器厦门网站建设网络推广
  • 嵌入式系统应用-触摸屏输入 LVGL 9.3版本
  • GPT-5最新特性和优点
  • 如何做幸运28网站代理做网站怎么销售
  • 洛谷P5365解题报告
  • C语言入门:数组的常见操作算法
  • 洛谷 P1054 [NOIP 2005 提高组] 等价表达式
  • 【左程云算法020】递归和master公式
  • php 怎么做 网站 图片福州外语外贸学院
  • 网站点击率东莞网站建设的公司
  • 【Linux】线程的互斥
  • 第三十九天:斐波那契数列
  • JAVA中用到的线程调度算法是什么?
  • 网站开发是无形资产如何在家里做网站
  • PySide6 打印或显示系统支持字体(QFontDataBase)
  • 网站开发框架怎么写wordpress前端会员中心开发教程
  • redis-zset数据类型的常见指令(sorted set)
  • 触摸未来2025.10.04:当神经网络拥有了内在记忆……
  • 生成对抗网络(GANs)深度解析:从原理、变体到前沿应用
  • 项目1:get_rga_thread线程和low_camera_venc_thread线程获取低分辨率VENC码流数据
  • 哪个网站做简历好musik wordpress视频
  • 【Linux】Linux管道与进程池深度解析:从原理到实战
  • Kotlin 协程之 Flow 操作符大全
  • python高级01——linux基础命令
  • 发帖那个网站好 做装修的怎么优化关键词排名优化
  • 分类信息网站建设价格西安公司注册网站
  • 数据要素X_第三批“数据要素×”典型案例——科技创新领域【附全文阅读】
  • 安装nginx时,yum 不从stable源安装