【速写】policy与reward分词器冲突问题(附XAI阅读推荐)
TRL的`PPOTrainer`实现存在一个很严重的问题:它的`model`和`reward_model`两个参数所使用的分词器必须相同,否则要么直接报错(例如token id超出奖励模型的词表范围),要么奖励模型会把策略模型的token id按自己的词表来解读,得到毫无意义的分数。
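之前已经提过,`PPOTrainer`要求训练数据(`train_dataset`参数)必须包含`input_ids`字段,这与`SFTTrainer`、`DPOTrainer`、`GRPOTrainer`都不同。查了一下源码(`trl/trainer/ppo_trainer.py`),发现只有`PPOTrainer`重写了父类`transformers.Trainer`的`.train`方法,其他三个都是直接继承,因此看上去它们三个理论上应当适配相同的数据集格式,即`text`,或者`input`+`target`,或者`prompt`+`completion`。这就很神奇了,因为`DPOTrainer`也没有重写`.train`方法,但显然`DPOTrainer`的字段要求与其他不同(它需要`prompt`/`chosen`/`rejected`这样的偏好数据)。实际上,数据字段的要求主要取决于各个Trainer在初始化时做的数据预处理,以及它们各自重写的`compute_loss`等方法,而不是`.train`方法本身。所以是否重写`.train`并不能说明数据格式是否相同。`PPOTrainer`之所以要求`input_ids`,是因为它重写的`.train`里直接读取了`data["input_ids"]`(见下文)。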
简单看一下`PPOTrainer`的`.train`方法:
def train(self):
    """Run the PPO training loop (excerpt of TRL's ``trl/trainer/ppo_trainer.py``).

    Each iteration of the outer ``update`` loop:

    1. takes the next batch of prompt token ids from ``data["input_ids"]``;
    2. samples responses from ``model.policy`` and records their log-probs;
    3. evaluates the samples with the reference policy, the value model and
       ``self.reward_model``;
    4. builds per-token rewards (a KL penalty on every token plus the
       reward-model score added at one end position per sequence) and
       GAE advantages/returns;
    5. runs ``args.num_ppo_epochs`` epochs of clipped PPO updates over
       shuffled mini-/micro-batches;
    6. gathers and logs metrics, steps the LR scheduler and fires the
       trainer callbacks (including checkpoint saving).

    NOTE: the token ids produced by the policy, together with
    ``processing_class.pad_token_id``, are passed unchanged to the reference
    policy, the value model and the reward model. Nothing is re-tokenized,
    so the reward model must use the same tokenizer/vocabulary as the policy
    for its scores to be meaningful.
    """
    args = self.args
    accelerator = self.accelerator
    optimizer = self.optimizer
    # The (possibly accelerator-wrapped) model; it exposes `.policy` and
    # `.value_model`, both of which are used below.
    model = self.model
    # May be None; in that case the policy is evaluated inside `null_ref_context()`.
    ref_policy = self.ref_model
    reward_model = self.reward_model
    # Tokenizer whose pad/eos ids drive generation; the same ids are reused
    # for every model below, including the reward model.
    processing_class = self.processing_class
    dataloader = self.dataloader
    device = accelerator.device

    def repeat_generator():
        # Cycle over the dataloader forever; the number of updates is bounded
        # by `args.num_total_batches`, not by dataset epochs.
        while True:
            yield from dataloader

    iter_dataloader = iter(repeat_generator())
    generation_config = GenerationConfig(
        max_new_tokens=args.response_length,
        # The epsilon keeps the temperature strictly positive.
        temperature=(args.temperature + 1e-7),
        top_k=0.0,
        top_p=1.0,
        do_sample=True,
    )

    accelerator.print("===training policy===")
    start_time = time.time()
    # One statistics slot per (PPO epoch, mini-batch, gradient-accumulation micro-batch).
    stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
    approxkl_stats = torch.zeros(stats_shape, device=device)
    pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
    pg_loss_stats = torch.zeros(stats_shape, device=device)
    vf_loss_stats = torch.zeros(stats_shape, device=device)
    vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
    entropy_stats = torch.zeros(stats_shape, device=device)
    ratio_stats = torch.zeros(stats_shape, device=device)
    model.train()

    # trainer state initialization
    self.state.global_step = 0
    self.state.episode = 0
    self.state.max_steps = args.num_total_batches
    self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
    # Compute absolute values for logging, eval, and save if given as ratio
    if args.logging_steps is not None:
        if args.logging_steps < 1:
            self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
        else:
            self.state.logging_steps = args.logging_steps
    if args.eval_steps is not None:
        if args.eval_steps < 1:
            self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
        else:
            self.state.eval_steps = args.eval_steps
    if args.save_steps is not None:
        if args.save_steps < 1:
            self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
        else:
            self.state.save_steps = args.save_steps
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

    # backward compatibility
    if self.is_deepspeed_enabled:
        self.deepspeed = self.model
        self.model_wrapped = self.model

    for update in range(1, args.num_total_batches + 1):
        # `episode` counts samples seen so far, not updates.
        self.state.episode += 1 * args.batch_size
        data = next(iter_dataloader)
        # ---- Rollout phase: sampling and scoring, no gradients needed. ----
        with torch.no_grad():
            # PPOTrainer reads pre-tokenized prompts directly, which is why the
            # training dataset must provide an "input_ids" column.
            queries = data["input_ids"].to(device)
            # All prompts in the batch share this width (they form one tensor);
            # generated tokens start right after it.
            context_length = queries.shape[1]
            responses = []
            postprocessed_responses = []
            logprobs = []
            ref_logprobs = []
            scores = []
            sequence_lengths = []
            values = []
            with unwrap_model_for_generation(
                self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
            ) as unwrapped_model:
                # Sample responses from the policy. `logitss` presumably holds the
                # logits of each generated position (it is sliced per rollout chunk below).
                query_responses, logitss = batch_generation(
                    unwrapped_model.policy,
                    queries,
                    args.local_rollout_forward_batch_size,
                    processing_class.pad_token_id,
                    generation_config,
                )

            # Score the rollouts in chunks of `local_rollout_forward_batch_size`
            # to bound peak memory.
            for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
                query = queries[i : i + args.local_rollout_forward_batch_size]
                query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
                # Keep only the generated part.
                response = query_response[:, context_length:]
                logits = logitss[i : i + args.local_rollout_forward_batch_size]
                # Policy log-probs of the sampled tokens (assumes selective_log_softmax
                # gathers log_softmax(logits) at the `response` ids).
                logprob = selective_log_softmax(logits, response)
                del logits
                torch.cuda.empty_cache()

                # Reference log-probs for the KL penalty. Without a separate
                # reference model, the policy itself is evaluated inside
                # `null_ref_context()`.
                if ref_policy is None:
                    with self.null_ref_context():
                        ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
                else:
                    ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
                # Shift by one so that position t's logits score response token t.
                ref_logits = ref_output.logits[:, context_length - 1 : -1]
                # Apply the same temperature that was used for sampling.
                ref_logits /= args.temperature + 1e-7
                ref_logprob = selective_log_softmax(ref_logits, response)
                del ref_output, ref_logits
                torch.cuda.empty_cache()

                # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
                postprocessed_response = response
                if self.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
                    postprocessed_response = truncate_response(
                        self.stop_token_id, processing_class.pad_token_id, response
                    )

                # Response Processing 2. run reward model on the truncated responses
                postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                # Position of the last non-pad response token, assuming
                # `first_true_indices` returns the index of the first pad token.
                sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
                unwrapped_value_model = accelerator.unwrap_model(model).value_model
                # The value model sees the *untruncated* `query_response`; its first
                # output is used as per-position values.
                full_value, _, _ = get_reward(
                    unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
                )
                value = full_value[:, context_length - 1 : -1].squeeze(-1)
                # NOTE: the reward model receives the policy's token ids and the
                # policy tokenizer's pad id. A reward model trained with a different
                # tokenizer would read these ids against the wrong vocabulary; this
                # is why policy and reward model must share a tokenizer. The second
                # output is used as one score per sequence.
                _, score, _ = get_reward(
                    reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
                )

                responses.append(response)
                postprocessed_responses.append(postprocessed_response)
                logprobs.append(logprob)
                ref_logprobs.append(ref_logprob)
                sequence_lengths.append(sequence_length)
                scores.append(score)
                values.append(value)
            responses = torch.cat(responses, 0)
            postprocessed_responses = torch.cat(postprocessed_responses, 0)
            logprobs = torch.cat(logprobs, 0)
            ref_logprobs = torch.cat(ref_logprobs, 0)
            sequence_lengths = torch.cat(sequence_lengths, 0)
            scores = torch.cat(scores, 0)
            values = torch.cat(values, 0)
            del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
            torch.cuda.empty_cache()
            gc.collect()

            # Response Processing 3. Filter completion.  Ensure that the sample contains stop_token_id
            # Completions not passing that filter will receive a lower score.
            # (The check below actually looks for `eos_token_id`.)
            contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
            if self.args.missing_eos_penalty is not None:
                scores[~contain_eos_token] -= self.args.missing_eos_penalty
            # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")

            # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
            response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
            # True for positions after the last real response token.
            padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
            # Masked log-probs are overwritten with the INVALID_LOGPROB sentinel.
            logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
            ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
            sequence_lengths_p1 = sequence_lengths + 1
            # Like `padding_mask`, but keeps one extra position; used for values/rewards.
            padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
            values = torch.masked_fill(values, padding_mask_p1, 0)

            # 4. compute rewards
            # Formula used by http://joschu.net/blog/kl-approx.html for the k1 and k3 estimators
            logr = ref_logprobs - logprobs
            kl = -logr if args.kl_estimator == "k1" else (logr.exp() - 1) - logr  # Else statement is k3
            # Per-token KL penalty.
            non_score_reward = -args.kl_coef * kl
            rewards = non_score_reward.clone()
            actual_start = torch.arange(rewards.size(0), device=rewards.device)
            # The sequence-level score is added at index sequence_length + 1, or at
            # sequence_length when sequence_length + 1 would be out of range.
            actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
            rewards[[actual_start, actual_end]] += scores

            # 5. whiten rewards
            if args.whiten_rewards:
                rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
                rewards = torch.masked_fill(rewards, padding_mask_p1, 0)

            # 6. compute advantages and returns
            # Generalized Advantage Estimation, computed backwards over the response:
            #   delta_t = r_t + gamma * V_{t+1} - V_t      (V after the last position is 0)
            #   A_t     = delta_t + gamma * lam * A_{t+1}
            lastgaelam = 0
            advantages_reversed = []
            gen_length = responses.shape[1]
            for t in reversed(range(gen_length)):
                nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
                delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
                lastgaelam = delta + args.gamma * args.lam * lastgaelam
                advantages_reversed.append(lastgaelam)
            advantages = torch.stack(advantages_reversed[::-1], axis=1)
            # Returns are the value-function targets.
            returns = advantages + values
            advantages = masked_whiten(advantages, ~padding_mask)
            advantages = torch.masked_fill(advantages, padding_mask, 0)
            torch.cuda.empty_cache()

        # ---- Optimization phase ----
        # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
        for ppo_epoch_idx in range(args.num_ppo_epochs):
            b_inds = np.random.permutation(args.local_batch_size)
            minibatch_idx = 0
            for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
                mini_batch_end = mini_batch_start + args.local_mini_batch_size
                mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
                gradient_accumulation_idx = 0
                for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
                    # NOTE(review): `optimizer.step()` is called for every micro-batch;
                    # presumably `accelerator.accumulate` makes it take effect only at
                    # gradient-accumulation boundaries. TODO confirm against accelerate docs.
                    with accelerator.accumulate(model):
                        micro_batch_end = micro_batch_start + args.per_device_train_batch_size
                        micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
                        mb_advantage = advantages[micro_batch_inds]
                        mb_responses = responses[micro_batch_inds]
                        mb_query_responses = query_responses[micro_batch_inds]
                        mb_logprobs = logprobs[micro_batch_inds]
                        mb_return = returns[micro_batch_inds]
                        mb_values = values[micro_batch_inds]

                        # Re-evaluate the current policy and value head on the stored rollouts.
                        output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
                        logits = output.logits[:, context_length - 1 : -1]
                        logits /= args.temperature + 1e-7
                        new_logprobs = selective_log_softmax(logits, mb_responses)
                        new_logprobs = torch.masked_fill(
                            new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
                        )
                        vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
                        vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
                        # Clipped value loss: keep new values within cliprange_value of the
                        # rollout values and take the larger of the two squared errors.
                        vpredclipped = torch.clamp(
                            vpred,
                            mb_values - args.cliprange_value,
                            mb_values + args.cliprange_value,
                        )
                        vf_losses1 = torch.square(vpred - mb_return)
                        vf_losses2 = torch.square(vpredclipped - mb_return)
                        vf_loss_max = torch.max(vf_losses1, vf_losses2)
                        vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
                        vf_clipfrac = masked_mean(
                            (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
                        )
                        # Clipped surrogate policy loss on the importance ratio new/old.
                        logprobs_diff = new_logprobs - mb_logprobs
                        ratio = torch.exp(logprobs_diff)
                        pg_losses = -mb_advantage * ratio
                        pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
                        pg_loss_max = torch.max(pg_losses, pg_losses2)
                        pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
                        loss = pg_loss + args.vf_coef * vf_loss
                        accelerator.backward(loss)
                        optimizer.step()
                        optimizer.zero_grad()
                        # Diagnostics only; nothing here feeds back into the loss.
                        with torch.no_grad():
                            pg_clipfrac = masked_mean(
                                (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
                            )
                            prob_dist = torch.nn.functional.softmax(logits, dim=-1)
                            entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                            approxkl = 0.5 * (logprobs_diff**2).mean()
                            approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
                            pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                                pg_clipfrac
                            )
                            pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                            vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
                            vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                                vf_clipfrac
                            )
                            entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
                            ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
                    gradient_accumulation_idx += 1
                minibatch_idx += 1
                # del everything and empty cache
                # fmt: off
                del (
                    output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
                    vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
                    pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
                    mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
                )
                # fmt: on
                torch.cuda.empty_cache()
        # ---- Logging ----
        with torch.no_grad():
            mean_kl = kl.sum(1).mean()
            mean_entropy = (-logprobs).sum(1).mean()
            mean_non_score_reward = non_score_reward.sum(1).mean()
            rlhf_reward = mean_non_score_reward + scores.mean()
            # Samples ("episodes") processed per second since training started.
            eps = int(self.state.episode / (time.time() - start_time))
            metrics = {}
            metrics["eps"] = eps
            # gather_for_metrics presumably collects each value from all processes before averaging.
            metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
            metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
            metrics["objective/non_score_reward"] = (
                self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
            )
            metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
            metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
            metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
            metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
            metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
            metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
            metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
            metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
            metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
            metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
            metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
            metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
            metrics["episode"] = self.state.episode
            self.state.epoch = self.state.episode / self.train_dataset_len  # used by self.log
            self.state.global_step += 1
            self.log(metrics)

        self.lr_scheduler.step()
        self.control = self.callback_handler.on_step_end(args, self.state, self.control)
        if self.control.should_save:
            self._save_checkpoint(model, trial=None)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)
        del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
        torch.cuda.empty_cache()
        gc.collect()

        # Periodically generate sample completions.
        if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
            self.generate_completions(sampling=True)
            torch.cuda.empty_cache()
        del (
            query_responses,
            responses,
            postprocessed_responses,
            logprobs,
            ref_logprobs,
            values,
            sequence_lengths,
            contain_eos_token,
            sequence_lengths_p1,
            response_idxs,
            padding_mask,
            padding_mask_p1,
            rewards,
            actual_start,
            actual_end,
            advantages,
            returns,
        )
        torch.cuda.empty_cache()

    # HF trainer specifics
    self.control = self.callback_handler.on_train_end(args, self.state, self.control)
    if self.control.should_save:
        self._save_checkpoint(model, trial=None, metrics=None)
        self.control = self.callback_handler.on_save(self.args, self.state, self.control)
逻辑很清晰(这里很明显能看到`queries = data["input_ids"].to(device)`,即要求训练数据中有`input_ids`字段):
- 首先在policy上做生成,即采样得到`query_responses`:

  ```python
  query_responses, logitss = batch_generation(
      unwrapped_model.policy,
      queries,
      args.local_rollout_forward_batch_size,
      processing_class.pad_token_id,
      generation_config,
  )
  ```
- 然后把生成结果拿去计算价值和奖励(注意:价值模型用的是原始的`query_response`,奖励模型用的是在停止符处截断后的`postprocessed_query_response`):

  ```python
  full_value, _, _ = get_reward(unwrapped_value_model, query_response, processing_class.pad_token_id, context_length)
  value = full_value[:, context_length - 1 : -1].squeeze(-1)
  _, score, _ = get_reward(reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length)
  ```
问题就出在这儿了:reward和policy使用的是同一个`processing_class`(即分词器)。奖励模型直接接收策略模型分词、生成得到的token id,连`pad_token_id`和`context_length`也都来自策略模型的分词结果。这个问题很难修改:除非先把`query_responses`解码还原成文本,再用`reward_model`自己的分词器重新分一次词,否则这里就是强制要求两者的分词器相同。
在PPO官方示例中,使用的策略模型是`EleutherAI/pythia-1b-deduped`,而`PPOConfig`默认的奖励模型是`EleutherAI/pythia-160m`。这两个模型同属Pythia系列,分词器刚好一样,所以没有出问题。
如果现在想用其他策略模型(即换个LLM来训练),就必须找到与它同一基座(即分词器相同)的奖励模型。比如对于Qwen系列,TRL提供了一个以Qwen为基座的奖励模型:`trl-lib/Qwen2-0.5B-Reward`(https://huggingface.co/trl-lib/Qwen2-0.5B-Reward)。
然后推荐一本XAI的好书:Explainable AI with Python,电子版挂在下面了:
通过网盘分享的文件:explainable-ai-with-python.pdf
链接: https://pan.baidu.com/s/13nS8mNMhif62o0F3cG0X5A?pwd=avdu 提取码: avdu 复制这段内容后打开百度网盘手机App,操作更方便哦
这本书对XAI概括得很专业。之前提过的那篇做Learning Dynamics的ICLR工作(arXiv:2407.10490),我一直觉得它没做完:它也是从MNIST入手,开头明明说要研究"学习了一个新的sample之后,对老的sample有何影响",但做到后面却变成了"学习了一个新的sample之后,模型对这个新sample的预测有何变化"。这就很奇怪,前后说不通。
在这本书关于DeepLIFT的例子中,给出了一个非常好的MNIST解释案例:利用这种可视化图来说明图片的每个像素对各个预测类别标签的贡献,很有说服力。代码如下(其中`shap.DeepExplainer`即SHAP库中基于DeepLIFT思想的DeepSHAP实现):
# -*- coding: utf8 -*-
from __future__ import print_function
import shape
import numpy
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
from keras import backend as K# DeepShap using DeepExplainer
# ... include code from https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py
batch_size = 128
num_classes = 10
epochs = 1# input image dimensions
img_rows, img_cols = 28, 28# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()if K.image_data_format() == "channel_first":x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)input_shape = (1, img_rows, img_cols)
else:x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)input_shape = (img_rows, img_cols, 1)x_train = x_train.astype("float32")
x_test = x_test.astype("float32")
x_train /= 255
x_test /= 255print("x_train shape:", x_train.shape)
print("x_test shape:", x_test.shape)# convert class vectors to binary class metrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation="relu", input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation="relu"))
model.add(MaxPooling2D(pool_size=(2, 2)))
# model.add(Dropout(.25))
model.add(Flatten())
model.add(Dense(128, activation="relu"))
model.add(Dropout(.5))
model.add(Dense(num_classes, activation="softmax"))
model.compile(loss=keras.losses.categorical_crossentropy, optimizer=keras.optimizers.Adadelta(), metrics=["accuracy"])
model.fix(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print("Test loss:", score[0])
print("Test accuracy:", score[1])# ----------------------------------------------------------------------# DeepShap using DeepExplainer
# select a set of background examples to take an expectation over
background = x_train[numpy.random.choice(x_train.shape[0], 100, replace=False)]
# explain predictions of the model on four images
e = shap.DeepExplainer(model, background)
# ... or pass tensors directly
# e = shap.DeepExplainer((model.layers[0].input, model.layers[-1].output), background)
shap_values = e.shap_values(x_test[1:5])
# plot the feature attributions
shap.image_plot(shap_values, -x_test[1:5])