How does ignore_empty_think protect a model's thinking ability during fine-tuning?
flyfish
1. Config file definition (the pattern to ignore)
swift/plugin/loss_scale/config/ignore_empty_think.json
This file specifies the token pattern whose loss should be ignored: it sets the loss weight of an empty think block, i.e. <think>\n\n</think>\n\n (with flexible whitespace), to 0.0:
{"<think>\\s*</think>\\s*": [0.0]}
The regular expression <think>\\s*</think>\\s* matches an opening <think> tag, any whitespace (including newlines) before the closing </think> tag, and any trailing whitespace: exactly the token sequence we want to exclude from the loss.
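As a quick sanity check, the pattern can be exercised directly with Python's re module (a standalone sketch, not part of ms-swift):

import re

# The pattern from ignore_empty_think.json (JSON escaping removed).
pattern = re.compile(r'<think>\s*</think>\s*')

# An "empty think" block: the tags with only whitespace in between.
assert pattern.fullmatch('<think>\n\n</think>\n\n')
# A non-empty think block is NOT matched, so its loss is kept.
assert pattern.fullmatch('<think>real reasoning</think>\n\n') is None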
2. Loss-scale class (wiring config to logic)
swift/plugin/loss_scale/loss_scale.py
defines the IgnoreEmptyThink class: it points at the config file above and inherits the loss-weight computation from REACTLossScale:
class IgnoreEmptyThink(REACTLossScale):
    loss_scale_config = 'ignore_empty_think.json'
Via loss_scale_config the class points to the config file above; it reuses REACTLossScale's get_loss_scale method, which in turn calls the calculate_loss_scale function to map matched token sequences to their weights.
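Putting the pieces together, a quick interactive check (a sketch; it assumes the swift package is importable and follows the file locations above):

from swift.plugin.loss_scale.loss_scale import IgnoreEmptyThink

ls = IgnoreEmptyThink()
# LossScale.__init__ loads config/ignore_empty_think.json next to loss_scale.py:
print(ls.loss_scale_map)   # {'<think>\\s*</think>\\s*': [0.0]}
# Only the weights 0. and 1. are in play, so keep_loss_scale stays False:
print(ls.keep_loss_scale)  # False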
3. Command-line mapping (wiring the flag to the class)
Also in swift/plugin/loss_scale/loss_scale.py, the loss_scale_map dictionary maps the command-line value ignore_empty_think to the IgnoreEmptyThink class, so passing --loss_scale ignore_empty_think at training time takes effect:
loss_scale_map = {
    ...
    'ignore_empty_think': IgnoreEmptyThink,
    ...
}
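With everything wired up, enabling this behavior is just a flag. A hypothetical invocation (the model and dataset values are placeholders; only the --loss_scale flag is the point here):

swift sft \
    --model <your-model> \
    --dataset <your-dataset> \
    --loss_scale ignore_empty_think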
The full loss_scale.py, for reference:
import os
from typing import List, Optional, Tuple

import json

from swift.llm import Messages
from swift.llm.template.utils import ContextType

from .utils import calculate_loss_scale


class LossScale:
    loss_scale_config = None  # path

    def _set_keep_loss_scale(self):
        self.keep_loss_scale = False
        if self.loss_scale_map is None:
            return
        res = set()
        for v in self.loss_scale_map.values():
            res.update(v)
        if len(res - {0., 1.}) > 0:
            self.keep_loss_scale = True

    def __init__(self):
        if self.loss_scale_config is not None:
            path = os.path.dirname(os.path.abspath(__file__))
            config_path = os.path.join(path, 'config', self.loss_scale_config)
            with open(config_path, 'r', encoding='utf-8') as json_file:
                self.loss_scale_map = json.load(json_file)
        else:
            self.loss_scale_map = None
        self._set_keep_loss_scale()

    def get_loss_scale(self,
                       context: str,
                       context_type: ContextType,
                       is_last_round: bool,
                       *,
                       query: Optional[str] = None) -> Tuple[List[str], List[float]]:
        """Calculate loss scale

        Args:
            context: The input context
            context_type: The type of this context, like response/suffix(eos token)/other(query/system, etc.)
            is_last_round: If this is the last round of messages.
            query: The query of this round.

        Returns:
            A tuple, list of context and list of loss_scales
        """
        if context_type in {ContextType.RESPONSE, ContextType.SUFFIX}:
            loss_scale = 1.
        else:
            loss_scale = 0.
        return [context], [loss_scale]

    def __call__(self, context_list: List[str], context_types: List[ContextType], messages: Messages,
                 **kwargs) -> Tuple[List[str], List[float]]:
        res_context_list = []
        res_loss_scale = []
        i = 0
        n_round = len(messages) // 2
        for context, context_type in zip(context_list, context_types):
            is_last_round = i + 1 == n_round
            if context_type == ContextType.RESPONSE:
                query = messages[2 * i]['content']
                assert context == messages[2 * i + 1]['content']
                kwargs = {'query': query}
                i += 1
            if isinstance(context, dict) and 'loss_scale' in context:
                new_context = [[token] for token in context['token_ids']]
                loss_scale = context['loss_scale']
            else:
                if isinstance(context, dict) and 'token_ids' in context:
                    context = context['token_ids']
                new_context, loss_scale = self.get_loss_scale(context, context_type, is_last_round, **kwargs)
            res_context_list += new_context
            res_loss_scale += loss_scale
        return res_context_list, res_loss_scale


class LastRoundLossScale(LossScale):

    def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool, **kwargs):
        if context_type == ContextType.RESPONSE:
            return [context], [float(is_last_round)]
        return super().get_loss_scale(context, context_type, is_last_round)


class AgentFlanLossScale(LossScale):
    loss_scale_config = 'agentflan.json'

    def get_loss_scale(self,
                       context: str,
                       context_type: ContextType,
                       is_last_round: bool,
                       *,
                       query: Optional[str] = None):
        if context_type == ContextType.RESPONSE and isinstance(context, str):
            return calculate_loss_scale(query, context, self.loss_scale_map['response'], self.loss_scale_map['query'])
        return super().get_loss_scale(context, context_type, is_last_round)


class REACTLossScale(LossScale):
    loss_scale_config = 'react.json'

    def get_loss_scale(self,
                       context: str,
                       context_type: ContextType,
                       is_last_round: bool,
                       *,
                       query: Optional[str] = None):
        if context_type == ContextType.RESPONSE and isinstance(context, str):
            return calculate_loss_scale(query, context, self.loss_scale_map)
        return super().get_loss_scale(context, context_type, is_last_round)


class QwenLossScale(REACTLossScale):
    loss_scale_config = 'qwen.json'


class HermesLossScale(REACTLossScale):
    loss_scale_config = 'hermes.json'


class AlphaUmiLossScale(REACTLossScale):
    loss_scale_config = 'alpha_umi.json'


class TrainAllLossScale(LossScale):

    def get_loss_scale(self, context: str, context_type: ContextType, *args, **kwargs):
        return [context], [1.]


class IgnoreEmptyThink(REACTLossScale):
    loss_scale_config = 'ignore_empty_think.json'


class LastRoundWithIgnoreEmptyThink(LossScale):
    loss_scale_config = 'ignore_empty_think.json'

    def get_loss_scale(self,
                       context: str,
                       context_type: ContextType,
                       is_last_round: bool,
                       *,
                       query: Optional[str] = None):
        if context_type == ContextType.RESPONSE:
            if not is_last_round:
                return [context], [0.]
            elif isinstance(context, str):
                return calculate_loss_scale(query, context, self.loss_scale_map)
        return super().get_loss_scale(context, context_type, is_last_round)


# Add your loss scale here, use --loss_scale xxx to train
loss_scale_map = {
    'last_round': LastRoundLossScale,
    'default': LossScale,
    'all': TrainAllLossScale,
    'ignore_empty_think': IgnoreEmptyThink,
    'last_round_with_ignore_empty_think': LastRoundWithIgnoreEmptyThink,
    # agent
    'react': REACTLossScale,
    'hermes': HermesLossScale,
    'qwen': QwenLossScale,
    'agentflan': AgentFlanLossScale,
    'alpha_umi': AlphaUmiLossScale,
}

for k, v in loss_scale_map.items():
    v.name = k
4. Loss-weight computation
swift/plugin/loss_scale/utils.py
contains the calculate_loss_scale function, which applies the weight from the config mapping (here 0.0) to every matched token sequence; this is what actually drops the loss:
from typing import Dict, List, Optional, Tuple

from swift.llm.template import split_str_parts_by


def calculate_loss_scale(query: str,
                         response: str,
                         response_loss_scale_map: Dict[str, list],
                         query_loss_scale_map: Optional[Dict[str, list]] = None) -> Tuple[List[str], List[float]]:
    """Calculate the loss scale by splitting the agent response.

    This algorithm comes from paper: https://arxiv.org/pdf/2309.00986.pdf

    Agent response format:

    ```text
    Thought: you should always think about what to do
    Action: the action to take, should be one of the above tools[fire_recognition,
        fire_alert, call_police, call_fireman]
    Action Input: the input to the action
    Observation: the result of the action
    ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
    Thought: I now know the final answer
    Final Answer: the final answer to the original input question
    ```

    Returns:
        A tuple of agent response parts and their weights.
    """
    # query loss scale map
    if query_loss_scale_map is not None:
        for key in query_loss_scale_map.keys():
            if key in query:
                if isinstance(query_loss_scale_map[key], (float, int)):
                    query_loss_scale_map[key] = [query_loss_scale_map[key]]
                loss_scale_value = query_loss_scale_map[key][0]
                return [response], [float(loss_scale_value)]
    delimiters = [k for k, v in response_loss_scale_map.items() if len(v) == 2]
    if delimiters:
        agent_parts = split_str_parts_by(response, delimiters)
    else:
        regex_delimiters = [k for k, v in response_loss_scale_map.items() if len(v) == 1]
        agent_parts = split_str_parts_by(response, regex_delimiters, regex_mode=True)
    weights = []
    agent_content = []
    for c in agent_parts:
        if c['key'] in response_loss_scale_map:
            loss_scale = response_loss_scale_map[c['key']]
            assert len(loss_scale) in {1, 2}, f'loss_scale: {loss_scale}'
            if len(loss_scale) == 1:
                weights += loss_scale
                agent_content.append(c['content'])
            else:
                weights += loss_scale
                agent_content += [c['key'], c['content']]
        else:
            weights.append(1.)
            agent_content.append(c['content'])
    return agent_content, weights
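To see what this yields for an empty think block, here is a hypothetical trace (the exact part boundaries depend on split_str_parts_by, so treat the output as illustrative):

response = '<think>\n\n</think>\n\n4 + 4 = 8'
# ignore_empty_think.json has one single-value entry, so the regex branch runs:
# the response is split on <think>\s*</think>\s* with regex_mode=True.
parts, weights = calculate_loss_scale(
    query='What is 4 + 4?',
    response=response,
    response_loss_scale_map={'<think>\\s*</think>\\s*': [0.0]})
# Expected shape of the result:
#   parts   ~ ['<think>\n\n</think>\n\n', '4 + 4 = 8']
#   weights ~ [0.0, 1.0]
# The empty think block gets weight 0.0; the real answer keeps weight 1.0.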
5. Applying the weights in the loss
swift/plugin/loss.py
contains the loss_scale_func function, which applies the weights obtained above when the loss is computed; tokens whose weight is 0.0 contribute nothing:
def loss_scale_func(outputs, labels, loss_scale=None, ...):
    loss, masks = ce_loss_func(outputs, labels)
    if loss_scale is not None:
        shift_scale = loss_scale[..., 1:].to(masks.device)
        shift_scale = shift_scale[masks]
        loss = (shift_scale * loss)  # loss at weight-0 positions is zeroed out
    ...
# Use @register_loss_func to decorate your own loss, use --loss_type xxx to train
@register_loss_func(LossType.loss_scale)
def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    """Loss func

    Args:
        outputs: The model outputs
        labels: The labels
        loss_scale: The loss scale
        num_items_in_batch: Number of tokens in the labels of gradient accumulation round that are not -100.

    Returns:

    """
    loss, masks = ce_loss_func(outputs, labels)
    if loss_scale is not None:
        shift_scale = loss_scale[..., 1:].to(masks.device)
        shift_scale = shift_scale[masks]
        loss = (shift_scale * loss)
    if num_items_in_batch is None:
        loss = loss.mean()
    else:
        # compat transformers>=4.46
        loss = loss.sum() / num_items_in_batch
    return loss
loss_scale_func
The loss_scale_func function is a custom loss function registered in ms-swift via @register_loss_func. It implements weighted cross-entropy: the loss_scale argument flexibly adjusts the loss weight of individual tokens.
1. Registration and purpose
The decorator @register_loss_func(LossType.loss_scale) registers the function under the name loss_scale, so training can enable it with the command-line argument --loss_type loss_scale.
What it does: on top of the regular cross-entropy (CE) loss, it scales each token's loss by the corresponding loss_scale value, tuning how much attention training pays to specific tokens (e.g. ignoring some tokens' loss entirely, or amplifying others).
2. Parameters
outputs: the model outputs (typically including the logits).
labels: the labels (same length as the input sequence; positions set to -100 are masked out of the loss).
loss_scale: the loss-scale factors (same length as the label sequence; each value is the loss weight of the corresponding token, e.g. 0 ignores that token's loss and 2.0 amplifies it).
num_items_in_batch: the number of valid labels (those not equal to -100) in the current gradient-accumulation round.
3. Logic
Step 1: compute the base cross-entropy loss and mask
loss, masks = ce_loss_func(outputs, labels)
ce_loss_func is a helper that computes the raw cross-entropy loss:
It shifts the logits and labels (a language model predicts token n from the first n-1 tokens, so the logits take [:, :-1, :] and the labels take [:, 1:]).
It builds the masks tensor, marking positions whose label is not -100 (i.e. the tokens that participate in the loss); its shape matches the shifted labels.
It returns the unweighted cross-entropy loss (loss, covering only the valid tokens) and the mask (masks).
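For intuition, here is a minimal sketch of what such a helper does, following the shift/mask description above (ce_loss_func_sketch is a hypothetical stand-in, not the actual ms-swift implementation):

import torch
import torch.nn.functional as F

def ce_loss_func_sketch(outputs, labels):
    logits = outputs.logits                          # [batch, seq, vocab]
    shift_logits = logits[..., :-1, :].contiguous()  # predict token n from tokens < n
    shift_labels = labels[..., 1:].contiguous()
    masks = shift_labels != -100                     # valid-label positions
    # Per-token CE over valid positions only; no reduction yet.
    loss = F.cross_entropy(shift_logits[masks], shift_labels[masks], reduction='none')
    return loss, masks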
Step 2: apply the loss-scale factors
if loss_scale is not None:
    shift_scale = loss_scale[..., 1:].to(masks.device)
    shift_scale = shift_scale[masks]
    loss = (shift_scale * loss)
If loss_scale is not None (custom weights were supplied), the loss is weighted:
shift_scale = loss_scale[..., 1:]: shifts loss_scale the same way the labels were shifted in step 1, keeping each weight aligned with its token.
shift_scale = shift_scale[masks]: selects the weights of the valid tokens (dropping those at -100 positions).
loss = (shift_scale * loss): multiplies the raw loss by the weights, scaling each token's loss (weight 0 removes a token's loss; weight 2.0 doubles it).
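A toy example with made-up numbers shows the effect of the multiplication:

import torch

# Per-token raw CE losses of the 3 valid tokens after masking.
loss = torch.tensor([0.7, 1.2, 0.4])
# Their aligned weights: the first token belongs to an empty <think> block.
shift_scale = torch.tensor([0.0, 1.0, 1.0])

weighted = shift_scale * loss   # tensor([0.0000, 1.2000, 0.4000])
print(weighted.mean())          # 0.5333...: the think token adds nothing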
Step 3: normalize the loss
if num_items_in_batch is None:
    loss = loss.mean()
else:
    loss = loss.sum() / num_items_in_batch
The weighted loss is then normalized so the final value does not depend on the batch size:
If num_items_in_batch is None: take the mean over all valid tokens' losses.
If num_items_in_batch is set (typical with gradient accumulation): divide the sum of the losses by the total number of valid tokens in the accumulation round.
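To see why dividing by num_items_in_batch matters, compare the two normalizations on made-up micro-batch losses:

# Two micro-batches in one gradient-accumulation round (hypothetical token losses).
micro_a = [0.9, 0.9]                  # 2 valid tokens, mean 0.9
micro_b = [0.3, 0.3, 0.3, 0.3]        # 4 valid tokens, mean 0.3
num_items_in_batch = len(micro_a) + len(micro_b)  # 6

# Averaging each micro-batch separately over-weights tokens in the small batch:
per_batch_mean = (sum(micro_a) / len(micro_a) + sum(micro_b) / len(micro_b)) / 2
print(per_batch_mean)  # 0.6

# sum / num_items_in_batch counts every valid token exactly once:
per_token = (sum(micro_a) + sum(micro_b)) / num_items_in_batch
print(per_token)       # 0.5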