强化学习之GRPO
强化学习之GRPO
- 一、前置知识
- 1.1、CoT方法与扩展策略
- 1.2、训练数据与目标
- 1.3、GRPO优化算法详解
- 1.4、关键步骤与逻辑
- 1.5、算法流程总结
- 1.6、GPRO_loss
- 1.7、⼿撕GRPO KL
- 二、GRPO训练流程
- 2.1、模型加载
- 2.2、加载数据集
- 2.2.1、看一下数据的格式
- 2.3、定义输出模版
- 2.4、数据预处理
- 2.5、定义基于规则的奖励函数
- 2.6、GRPO训练
- 2.7、配置wandb
一、前置知识
1.1、CoT方法与扩展策略
针对数学问题,我们常采⽤以下五种 Prompt Engineering⼿段来促使模型给出答案:
⚡1、直接回答
● 直接要求模型输出最终答案(如数值或选项)。● 特点:实现简单,但缺乏推理过程,准确率较低。
⚡2、上下⽂学习(In-context Learning)
● 在输⼊中添加"问题-答案"⽰例,引导模型模仿输出模式。● ⽰例:在选择题场景中,通过"问题:... 答案:B"的模板引导选项输出。● 特点:依赖⽰例设计质量,适⽤于简单问题快速适配。
⚡3、思维链提⽰(Chain-of-Thought, CoT)
● 在Prompt中提供结构化推理⽰例,要求模型⽣成"问题分析→分步推导→最终答案"的完整过程。● 示例:○ 问题:x + 2 = 5,求x的值。 ○ 解答:将等式两边减2得x = 5 - 2,计算结果为x = 3,因此答案是\boxed{3}。● 优势:显式训练逻辑推导能⼒,显著提升复杂问题准确率。
⚡4、⾃动思维链(Auto-CoT)
● 使⽤通⽤Prompt(如"Let’s think step - by - step")触发模型⾃主⽣成推理链,⽆需⼈⼯设计详细⽰例。 ● 特点:降低⼈⼯⼲预成本,但依赖模型⾃⾝的推理能⼒。
⚡5、⻓思维链训练
● 核⼼:通过合成更复杂、多⻆度的推理链⽰例,并微调模型参数,使其内化结构化推理能⼒。 ● 实现⽅式:○ 提供包含多步骤验证、逻辑分⽀的⻓推理⽰例(如数学证明中的分情况讨论)。 ○ 使⽤单样本学习(One-shot)⽣成扩展推理链,强化模型对⻓逻辑路径的处理能⼒。 ● 差异点: ○ 基础⽅法(2-4)依赖输⼊Prompt设计,⽽LongCoT通过参数微调直接改变模型⾏为,实现更稳定的推理能⼒。
1.2、训练数据与目标
● 输⼊:数学问题 q (如“求解⽅程 x + 2 = 5 ”)
● 输出:模型需⽣成包含以下两部分的回答: ○ 推理过程(CoT):分步逻辑推导(如“将等式两边减2得 x = 3 ”) ○ 最终答案:以指定格式呈现(如 \boxed{3} )
● 正确性判定:仅通过最终答案与标签的数学等价性判断正确性(例如 0.5 与 1/2 视为等价)。
1.3、GRPO优化算法详解
GRPO(Group Relative Policy Optimization)是⼀种针对⽣成任务优化的强化学习算法,其核⼼⽬标
是在提升模型答案准确率的同时,约束策略更新幅度以保持稳定性。
1.4、关键步骤与逻辑
1、多组回答⽣成与采样策略● ⽣成过程:对每个问题q,模型⽣成G组不同回答 {o1 , o2 , ..., oG},每组回答⻓度可能不同(即∣oi∣可变)。● 逐token策略:在解码的第t步(⽣成第t个token时),模型基于当前策略πθ采样⽣成token oi,t,其概率分布为πθ (oi,t ∣ q, oi,<t) 。
2、KL散度约束的优化特性 ● 分布对⻬:通过 KL 散度DKL[πθ /πref] 约束当前策略πθ与参考策略πref(如预训练模型)的差异,避免策略突变。 ● ⽅差控制:KL 项的优化设计偏向降低策略更新的⽅差,确保训练稳定性。
3、组相对优势的计算逻辑 奖励标准化:对每组回答的奖励ri进⾏标准化(减去均值、除以标准差),得到相对优势值A^ i,t。⼀致性分配:同⼀组回答oi 中,所有token oi,t共享相同的优势值 A^ i,t,即整条推理链的奖励⼀致。
1.5、算法流程总结
1、⽣成阶段:对每个问题 q 采样 G 组回答,每组回答按token逐步⽣成 t = 1, 2, ..., ∣oi∣ 。
2、策略更新:● 通过概率⽐πθ /πold衡量新旧策略差异,并⽤ clip 函数限制更新幅度(超参数ϵ控制裁剪范围)。● 结合组相对优势 A^i,t 计算梯度,优先提升⾼奖励组的⽣成概率。
3、约束控制:KL 散度项确保策略更新不偏离参考分布,同时优化设计降低⽅差。
1.6、GPRO_loss
def grpo_loss(pi_logprob, pi_old_logprob, pi_ref_logprob, advantage, input_len, len_oi):# 策略梯度计算ratio = torch.exp(pi_logprob - pi_old_logprob)ratio_clip = torch.clamp(ratio, 1-0.2, 1+0.2)policy_grad = torch.min(ratio*advantage, ratio_clip*advantage)# KL约束计算kl = grpo_kl(pi_logprob, pi_ref_logprob) # 使⽤改进的KL公式# 加权平均与归⼀化loss = (-1/group_num) * (1/len_oi) * (policy_grad - 0.01*kl).sum()return loss
1.7、⼿撕GRPO KL
def grpo_kl(pi_logprob, pi_ref_logprob)# 利⽤对数概率计算,避免数值溢出ratio = (pi_ref_logprob - pi_logprob).exp()return ratio - (pi_ref_logprob - pi_logprob) - 1pi = torch.randn(3, 5) # batch,sequence 当前策略对数概率
pi_ref = torch.randn(3, 5) # batch,sequence 参考策略对数概率
pi_logprob = torch.log_softmax(pi, dim=1)
pi_ref_logprob = torch.log_softmax(pi_ref, dim=1)
print(grpo_kl(pi_logprob, pi_ref_logprob))
二、GRPO训练流程
2.1、模型加载
from modelscope import AutoModelForCausalLM, AutoTokenizer
model_name = "./Qwen2.5-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name,torch_dtype="auto",device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
2.2、加载数据集
# 加载数据集
from datasets import load_dataset
data = load_dataset('gsm8k')
2.2.1、看一下数据的格式
data['train'][0]
抽出一条数据,进行tokenizer编码
prompt = data['train'][0]['question']# 按照 Qwen要求的格式构造数据
messages = [{"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(messages,tokenize=False,add_generation_prompt=True
)
# 编码
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
text
# '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nNatalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?<|im_end|>\n<|im_start|>assistant\n'
model_inputs
# {'input_ids': tensor([[151644, 8948, 198, 2610, 525, 1207, 16948, 11, 3465,
# 553, 54364, 14817, 13, 1446, 525, 264, 10950, 17847,
# 13, 151645, 198, 151644, 872, 198, 45, 4212, 685,
# 6088, 26111, 311, 220, 19, 23, 315, 1059, 4780,
# 304, 5813, 11, 323, 1221, 1340, 6088, 4279, 438,
# 1657, 26111, 304, 3217, 13, 2585, 1657, 26111, 1521,
# 41601, 685, 4559, 30055, 304, 5813, 323, 3217, 30,
# 151645, 198, 151644, 77091, 198]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
# 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
# 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])})
# inputs_ids 对应embedding,token_type_ids表示属于第几个句子,attention_mask表示embedding中哪部分是真实有效的。
利用初始模型,看看输出
generated_ids = model.generate(**model_inputs,max_new_tokens=512
)
# 只提取答案部分
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
# 解码得到文本
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
response
# 输出比较混乱,说明模型还不具备比较强的推理能力,急需强化学习拯救世界
2.3、定义输出模版
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer
# 定义系统模板
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
# 最终的输出应该是这个格式的
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
2.4、数据预处理
定义输出和答案的提取函数
# 按照输出模板,提取出模型的输出
def extract_xml_answer(text: str) -> str:answer = text.split("<answer>")[-1]answer = answer.split("</answer>")[0]return answer.strip()
# 按照数据集的格式,提取出答案
def extract_hash_answer(text: str) -> str | None:if "####" not in text:return Nonereturn text.split("####")[1].strip()
按照qwen的格式要求进行数据预处理
def get_gsm8k_questions(split = "train") -> Dataset:data = load_dataset('gsm8k')[split] # type: ignoredata = data.map(lambda x: { # type: ignore'prompt': [{'role':'system', 'content': SYSTEM_PROMPT},{'role': 'user', 'content': x['question']}],'answer': extract_hash_answer(x['answer'])}) # type: ignorereturn data # type: ignore
dataset=get_gsm8k_questions()
dataset['answer'][0]
# '72'
dataset['prompt'][0]
# [{'content': '\nRespond in the following format:\n<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n',
# 'role': 'system'},
# {'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
# 'role': 'user'}]
2.5、定义基于规则的奖励函数
判断答案是否正确的奖励函数
# 答案完全正确得2分(是按照要求的xml格式,且是整数,且答案正确),否则0分
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:responses = [completion[0]['content'] for completion in completions]q = prompts[0][-1]['content']extracted_responses = [extract_xml_answer(r) for r in responses]print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
判断答案是整数的奖励函数
# 答案是整数(是<answer></answer>得xml格式,且是整数)得0.5分,否则0分
def int_reward_func(completions, **kwargs) -> list[float]:responses = [completion[0]['content'] for completion in completions]extracted_responses = [extract_xml_answer(r) for r in responses]return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
```判断答案是否严格符合输出模板的奖励函数```
# 答案严格符合<reasoning>{reasoning}</reasoning><answer>{answer}</answer>的格式(换行也要正确)得0.5分,否则0分
def strict_format_reward_func(completions, **kwargs) -> list[float]:"""Reward function that checks if the completion has a specific format."""pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"responses = [completion[0]["content"] for completion in completions]matches = [re.match(pattern, r) for r in responses]return [0.5 if match else 0.0 for match in matches]
判断答案是否基本符合输出模板的奖励函数
# 答案没有强制要求换行符,只要标签之间有任何空白字符(包括空格或换行符)即可,符合则得0.5分,否则得0分
def soft_format_reward_func(completions, **kwargs) -> list[float]:"""Reward function that checks if the completion has a specific format."""pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"responses = [completion[0]["content"] for completion in completions]matches = [re.match(pattern, r) for r in responses]return [0.5 if match else 0.0 for match in matches]
判断答案中是否存在标签,标签位置是否正确的奖励函数
# 根据<reasoning><answer>标签是否出现,位置是否正确打分,0~0.5分
def count_xml(text) -> float:count = 0.0if text.count("<reasoning>\n") == 1:count += 0.125if text.count("\n</reasoning>\n") == 1:count += 0.125if text.count("\n<answer>\n") == 1:count += 0.125count -= len(text.split("\n</answer>\n")[-1])*0.001if text.count("\n</answer>") == 1:count += 0.125count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001return count
计算GRPO一组输出的奖励
#计算一个批次的xml得分
def xmlcount_reward_func(completions, **kwargs) -> list[float]:contents = [completion[0]["content"] for completion in completions]return [count_xml(c) for c in contents]
2.6、GRPO训练
训练参数设置
model_name = "Qwen2.5-0.5B-Instruct"
output_dir="outputs/Qwen2.5-0.5B-reasoning-GRPO"
run_name="Qwen2.5-0.5B-GRPO-gsm8k"
training_args = GRPOConfig(output_dir=output_dir, # 输出目录run_name=run_name, # wandb 中的项目名称learning_rate=5e-6, # 强化学习学习率设置的比较小adam_beta1 = 0.9, # adam优化器adam_beta2 = 0.99,weight_decay = 0.1, # 正则warmup_ratio = 0.1, # 学习率预热比例lr_scheduler_type='cosine', # 学习率衰减策略logging_steps=1,bf16=True, # 混合精度训练per_device_train_batch_size=8, # 总的batch = per_device_train_batch_size * 显卡数gradient_accumulation_steps=4, # 累计gradient_accumulation_steps个batch更新一次模型num_generations=8, # GRPO中每个q输出num_generations个omax_prompt_length=256, # 限制prompt长度max_completion_length=200, # 限制模型输出上限 num_train_epochs=1,save_steps=100, # 每save_steps步保存一次模型max_grad_norm=0.1, # 梯度裁剪log_on_each_node=False,use_vllm=False,vllm_gpu_memory_utilization=.3, # vllm 加速vllm_device="cuda:0",report_to="wandb"
)
Trainer设置,开始训练
model = AutoModelForCausalLM.from_pretrained(model_name,torch_dtype=torch.bfloat16,device_map=None
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
trainer = GRPOTrainer(model=model,processing_class=tokenizer,reward_funcs=[xmlcount_reward_func,soft_format_reward_func,strict_format_reward_func,int_reward_func,correctness_reward_func],args=training_args,train_dataset=dataset,
)
trainer.train()trainer.save_model(output_dir)
2.7、配置wandb
wandb.login(key="your key")
wandb.init(project="project name")