Baukit库使用教程——监控和修改LLM中间层输出
- 原始项目地址:https://github.com/davidbau/baukit
1. TraceDict概述
# Illustrative TraceDict construction; each keyword argument is explained in the list below.
TraceDict(
    model,
    layers=target_layers,
    retain_input=True,
    retain_output=True,
    edit_output=None,
    stop=False,
)
- `layers`:想要监控的层。一个 list,其中每个元素形如 `model.layers.{layer_num}.{specific_module_name}`
- `retain_input` / `retain_output`:是否保存目标层的原始输入和输出
- `edit_output`:若需要编辑模型某层的输出,则传入一个函数(函数对象本身)。该函数需要能够接受 `(output, layer)` 两个参数
- `stop`:若为 True,运行完被监控的层后,模型即停止运行
2. TraceDict使用方法
# Hooks are active only inside the `with` block, so the forward pass must run there.
with TraceDict(model, layers=target_layers, retain_input=True, retain_output=True, edit_output=None, stop=False) as td:
    model(**input_tokens)  # afterwards td[layer_name].input / td[layer_name].output hold the traced data
2.1 监控输入输出
def monitor_layers(self, input_text: str, hook_layers: List[str], target_layer_idx: int = 5):
    """Trace ``hook_layers`` during one forward pass and report the shapes of
    the input and output of ``hook_layers[target_layer_idx]``.

    Args:
        input_text: Prompt fed to the model.
        hook_layers: Module names to hook, e.g. ``model.layers.5.self_attn.o_proj``.
        target_layer_idx: Index into ``hook_layers`` (negative indices allowed).

    Returns:
        Dict with keys ``before_layer_shape`` and ``after_layer_shape``.

    Raises:
        IndexError: If ``target_layer_idx`` is out of range for ``hook_layers``.
        RuntimeError: If the model/tokenizer have not been loaded yet.
    """
    # Validate up front so a bad index fails before the (expensive) forward pass.
    if not -len(hook_layers) <= target_layer_idx < len(hook_layers):
        raise IndexError(
            f"target_layer_idx {target_layer_idx} is out of range for "
            f"{len(hook_layers)} hook layers"
        )
    if self.model is None or self.tokenizer is None:
        raise RuntimeError("Model not loaded; call load_model() first.")
    target_layer = hook_layers[target_layer_idx]
    input_tokens = self.tokenizer(input_text, return_tensors='pt').to(self.device)
    # Monitoring only reads activations, so skip recording the autograd graph.
    with torch.no_grad(), TraceDict(self.model, layers=hook_layers, retain_input=True, retain_output=True) as td:
        self.model(**input_tokens)
        before_layer = td[target_layer].input  # input to the target module
        after_layer = td[target_layer].output  # output of the target module
        # td['model.layers.5.self_attn.q_proj'] would raise here, because only the
        # modules listed in hook_layers are traced.
    return {
        'before_layer_shape': before_layer.shape,
        'after_layer_shape': after_layer.shape,
    }
2.2 修改模型中间层输出
- 需要定义一个修改输出的函数,重点是其内部(嵌套)的那个函数。
- 按照 baukit 的约定,内部函数需要接受 output 和 layer_name 两个参数:
  - output:模型中间层的输出
  - layer_name:当前前向传播到的模块名称(例如 `model.layers.5.self_attn.o_proj`)
- 外层封装函数的参数可以随意定义,只要最终返回内部函数即可。
def wrap_func(edit_layer, device, idx=-1, scale=0.1):
    """Build a baukit ``edit_output`` callback that adds Gaussian noise to one layer.

    Args:
        edit_layer: Decoder-layer number whose output is perturbed. It is compared with
            the third dot-separated component of the module name
            (``model.layers.5.self_attn.o_proj`` -> 5).
        device: Kept for backward compatibility only. The noise is created on the
            output tensor's own device, so this value is not used.
        idx: Unused; kept for backward compatibility.
        scale: Multiplier on standard-normal noise, i.e. its standard deviation
            (default 0.1, the previously hard-coded value).

    Returns:
        A function ``(output, layer_name) -> output`` for ``TraceDict(edit_output=...)``.
    """
    def add_func(output, layer_name):
        current_layer = int(layer_name.split(".")[2])
        if current_layer == edit_layer:  # reached edit_layer
            print("output_sum", output.sum())
            # randn_like already matches output's shape, dtype and device; forcing the
            # noise onto a fixed `device` breaks when the layer lives elsewhere.
            perturbation = torch.randn_like(output) * scale
            # Out-of-place add: leaves the module's original tensor untouched (in-place
            # edits can trip autograd if that tensor is saved for backward).
            output = output + perturbation
            print("output_sum", output.sum())
        return output
    return add_func
2.3 完整的监控与修改脚本
import torch
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
from baukit import TraceDict
from typing import List, Dict, Any

# Maps each --layer_type choice to the module-name suffix hooked in every decoder layer.
LAYER_TYPE_SUFFIXES = {
    "attn_o": "self_attn.o_proj",
    "attn_q": "self_attn.q_proj",
    "attn_k": "self_attn.k_proj",
    "attn_v": "self_attn.v_proj",
    "mlp_gate": "mlp.gate_proj",
    "mlp_up": "mlp.up_proj",
    "mlp_down": "mlp.down_proj",
}


class ModelHandler:
    """Load a causal LM with transformers and trace its layers with baukit.TraceDict."""

    def __init__(self, model_path: str, device: str = "cuda"):
        self.model_path = model_path
        self.device = device
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Load model and tokenizer from ``self.model_path``; return ``(model, tokenizer)``."""
        print(f"Loading model: {self.model_path}")
        # NOTE(review): float16 matmuls are not implemented on CPU in some PyTorch
        # versions, so CPU runs fall back to float32; other devices keep float16.
        dtype = torch.float32 if torch.device(self.device).type == "cpu" else torch.float16
        self.model = AutoModelForCausalLM.from_pretrained(self.model_path, torch_dtype=dtype).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        print("Model loaded!")
        return self.model, self.tokenizer

    def monitor_layers(self, input_text: str, hook_layers: List[str], target_layer_idx: int = 5):
        """Trace ``hook_layers`` during one forward pass and report the shapes of
        the input and output of ``hook_layers[target_layer_idx]``.

        Returns:
            Dict with keys ``before_layer_shape`` and ``after_layer_shape``.

        Raises:
            IndexError: If ``target_layer_idx`` is out of range for ``hook_layers``.
            RuntimeError: If ``load_model`` has not been called.
        """
        # Validate up front so a bad index fails before the (expensive) forward pass.
        if not -len(hook_layers) <= target_layer_idx < len(hook_layers):
            raise IndexError(
                f"target_layer_idx {target_layer_idx} is out of range for "
                f"{len(hook_layers)} hook layers"
            )
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model not loaded; call load_model() first.")
        target_layer = hook_layers[target_layer_idx]
        input_tokens = self.tokenizer(input_text, return_tensors='pt').to(self.device)
        # Monitoring only reads activations, so skip recording the autograd graph.
        with torch.no_grad(), TraceDict(self.model, layers=hook_layers, retain_input=True, retain_output=True) as rep:
            self.model(**input_tokens)
            before_layer = rep[target_layer].input  # input to the target module
            after_layer = rep[target_layer].output  # output of the target module
        return {
            'before_layer_shape': before_layer.shape,
            'after_layer_shape': after_layer.shape,
        }


def generate_hook_layers(layer_type: str, num_layers: int = 16):
    """Return ``model.layers.{i}.<suffix>`` module names for ``i`` in ``range(num_layers)``.

    Raises:
        ValueError: If ``layer_type`` is not a key of ``LAYER_TYPE_SUFFIXES``.
    """
    try:
        suffix = LAYER_TYPE_SUFFIXES[layer_type]
    except KeyError:
        raise ValueError(f"Unsupported layer type: {layer_type}") from None
    return [f'model.layers.{l}.{suffix}' for l in range(num_layers)]


def wrap_func(edit_layer, device, idx=-1, scale=0.1):
    """Build a baukit ``edit_output`` callback that adds Gaussian noise to one layer.

    Args:
        edit_layer: Decoder-layer number whose output is perturbed. It is compared with
            the third dot-separated component of the module name
            (``model.layers.5.self_attn.o_proj`` -> 5).
        device: Kept for backward compatibility only. The noise is created on the
            output tensor's own device, so this value is not used.
        idx: Unused; kept for backward compatibility.
        scale: Multiplier on standard-normal noise (default 0.1, the previous constant).

    Returns:
        A function ``(output, layer_name) -> output`` for ``TraceDict(edit_output=...)``.
    """
    def add_func(output, layer_name):
        current_layer = int(layer_name.split(".")[2])
        if current_layer == edit_layer:  # reached edit_layer
            print("output_sum", output.sum())
            # randn_like already matches output's shape, dtype and device; forcing the
            # noise onto a fixed `device` breaks when the layer lives elsewhere.
            perturbation = torch.randn_like(output) * scale
            # Out-of-place add: leaves the module's original tensor untouched.
            output = output + perturbation
            print("output_sum", output.sum())
        return output
    return add_func


def _build_arg_parser():
    """Return the argparse parser describing this script's command-line options."""
    parser = argparse.ArgumentParser(description='Model layer monitoring and editing tool')
    parser.add_argument('--model_path', type=str, default='YOUR_MODEL_PATH', help='Model path')
    parser.add_argument('--input_text', type=str, default='Hello, how are you?', help='Input text')
    parser.add_argument('--layer_type', type=str, default='attn_o',
                        choices=list(LAYER_TYPE_SUFFIXES),
                        help='Layer type to monitor')
    parser.add_argument('--num_layers', type=int, default=16, help='Total number of layers')
    parser.add_argument('--target_layer_idx', type=int, default=5, help='Layer index to analyze')
    parser.add_argument('--mode', type=str, default='edit', choices=['monitor', 'edit'],
                        help='Running mode: monitor or edit')
    parser.add_argument('--device', type=str, default='cuda', help='device')
    return parser


def main():
    """Parse CLI arguments, load the model, then monitor or edit the selected layers."""
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Generate hook layers first so a bad --target_layer_idx is rejected before
    # the slow model load instead of after a forward pass.
    hook_layers = generate_hook_layers(args.layer_type, args.num_layers)
    if not -len(hook_layers) <= args.target_layer_idx < len(hook_layers):
        parser.error(
            f"--target_layer_idx {args.target_layer_idx} is out of range "
            f"for {len(hook_layers)} hook layers"
        )
    target_layer = hook_layers[args.target_layer_idx]

    # Create model handler
    handler = ModelHandler(args.model_path, args.device)
    handler.load_model()

    print(f"Input text: {args.input_text}")
    print(f"Monitor layer type: {args.layer_type}")
    print(f"Monitor layer number: {len(hook_layers)}")
    print(f"Target layer index: {args.target_layer_idx}")

    if args.mode == 'monitor':
        results = handler.monitor_layers(args.input_text, hook_layers, args.target_layer_idx)
        print("\n=== Monitor results ===")
        print(f"Target layer: {target_layer}")
        # Presumably [bsz, num_tokens, features]; the feature size depends on the hooked
        # module and is not always the model hidden size — confirm per model.
        print(f"Layer input shape: {results['before_layer_shape']}")
        print(f"Layer output shape: {results['after_layer_shape']}")
    elif args.mode == 'edit':
        # Hook the same modules selected with --layer_type, and read the decoder-layer
        # number from the target module name so negative indices behave as in monitor mode.
        edit_layer = int(target_layer.split(".")[2])
        intervention_fn = wrap_func(edit_layer, handler.model.device)
        input_tokens = handler.tokenizer(args.input_text, return_tensors='pt').to(handler.device)
        # The edited forward pass is only observed, so no autograd graph is needed.
        with torch.no_grad(), TraceDict(handler.model, layers=hook_layers, edit_output=intervention_fn):
            handler.model(**input_tokens)


if __name__ == "__main__":
    main()