Google开源Tunix:JAX生态的LLM微调方案来了
JAX生态这两年在LLM训练这块追赶得挺快。PyTorch虽然还是主流但JAX在并行计算、TPU加速和API组合性上确实有些独特的优势。Google今天放出了Tunix这个库,专门做LLM的后训练——微调、强化学习、知识蒸馏这些都能搞。
Tunix是什么
这是个构建在JAX之上的后训练库,和Flax NNX集成得比较紧密。主要解决三类问题:
- 监督微调(Supervised Fine-Tuning)
- 强化学习(Reinforcement Learning)
- 知识蒸馏(Knowledge Distillation)
现在还在早期开发阶段,功能在持续迭代,支持的模型也在慢慢扩展。
核心功能
监督微调:既支持全参数微调,也支持LoRA和Q-LoRA这类参数高效的方法。内存和算力受限的时候,PEFT方案还是挺实用的。
强化学习:实现了几个主流算法:PPO(Proximal Policy Optimization)、GRPO(Group Relative Policy Optimization)、还有token级别的GSPO。另外还有DPO(Direct Preference Optimization)做偏好对齐,这个在RLHF场景用得比较多。
知识蒸馏:支持几种策略,包括基于logit的概率分布匹配、注意力机制的转移和投影、跨架构的特征池化与投影。这几种方法在不同场景下各有用处。
库的设计比较模块化,组件可以自由组合,想扩展自定义流程也不算麻烦。分布式训练支持数据并行(DP)、完全分片数据并行(FSDP)和张量并行(TP),对TPU做了专门优化。
安装
三种装法:
从PyPI装(推荐):
pip install "tunix[prod]"
或者直接从GitHub主分支:
pip install git+https://github.com/google/tunix
开发模式从源码装:
git clone https://github.com/google/tunix.git cd tunix pip install -e".[dev]"
TPU上用QLoRA微调Gemma
拿个英译法的任务来演示。用的是Google的Gemma 2B模型,跑在TPU v5e-8上。
环境准备
pip install -q kagglehub safetensors tensorflow tensorflow_datasets tensorboardX transformers grain datasets pip install -q git+https://github.com/google/tunix pip install -q git+https://github.com/google/qwix # Flax需要升级到最新版pip uninstall -q -y flax pip install -q git+https://github.com/google/flax.git
完整流程
第一步,从Kaggle拉预训练checkpoint:
import kagglehub model_path = "google/gemma/flax/2b" kaggle_ckpt_path = kagglehub.model_download(model_path)
初始化模型和tokenizer:
from flax import nnx
from tunix.models.gemma import model as gemma_lib, params as params_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib base_model = gemma_lib.Transformer.from_params( params_lib.load_and_format_params(kaggle_ckpt_path, "2b"), version="2b"
) tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=f"{kaggle_ckpt_path}/tokenizer.model")
挂上QLoRA adapter:
import qwix lora_provider = qwix.LoraProvider( module_path=".*(q_einsum|kv_einsum|proj)", rank=16, alpha=2.0, weight_qtype="nf4" # enable QLoRA quantization
) lora_model = qwix.apply_lora_to_model(base_model, lora_provider)
这里rank设成16,alpha是2.0,weight_qtype指定nf4量化格式。
加载训练数据:
from tunix.examples.data import translation_dataset train_ds, validation_ds = translation_dataset.create_datasets( dataset_name="mtnt/en-fr", global_batch_size=16, max_target_length=256, num_train_epochs=3, tokenizer=tokenizer, )
用的是mtnt的英法平行语料,batch size 16,目标序列最长256个token。
开始训练:
from tunix.sft import peft_trainer, utils
import optax trainer=peft_trainer.PeftTrainer( lora_model, optimizer=optax.adamw(1e-3), config=peft_trainer.TrainingConfig(max_steps=100)
) trainer.train(train_ds, validation_ds)
优化器用AdamW,学习率1e-3,跑100步看看效果。
推理测试:
训练完直接用adapter过的模型做生成。Tunix提供了Sampler工具:
from tunix.generate import sampler as sampler_lib # initialize sampler
sampler = sampler_lib.Sampler( transformer=lora_model, tokenizer=tokenizer, cache_config=sampler_lib.CacheConfig( cache_size=256, num_layers=base_model.num_layers, num_kv_heads=base_model.num_kv_heads, head_dim=base_model.head_dim, ),
) # test prompts
input_batch = [ "Translate this into French:\nHello, my name is Morgane.\n", "Translate this into French:\nThis dish is delicious!\n", "Translate this into French:\nI am a student.\n", "Translate this into French:\nHow's the weather today?\n",
] # generate predictions
out_data = sampler( input_strings=input_batch, max_generation_steps=20,
) # print results
for input_string, out_string in zip(input_batch, out_data.text): print(f"----------------------") print(f"Prompt:\n{input_string}") print(f"Output:\n{out_string}")
如果用的是QLoRA,把lora_model换成qlora_model就行。生产环境可以考虑把adapter合并回基模型,推理延迟能降下来。
总结
100步训练之后,模型已经能生成一些翻译结果了,虽然质量还不够好。多训练一段时间,准确率会明显提升,而且内存开销和训练速度都保持在不错的水平。
Tunix现在还比较新,但已经能看出一些潜力。TPU优先的设计、模块化的API、LoRA/QLoRA支持、完整的分布式训练策略,这些对做LLM适配研究的人来说都挺有用。
后续应该会继续扩展支持的模型类型和训练算法,值得关注。
地址:https://avoid.overfit.cn/post/c434311d8a894922b6c52ea179cf8d97
作者:Abish Pius