
During fine-tuning, how does ignore_empty_think protect the model's reasoning ability?


flyfish

1. Config file definition (specifying the patterns to ignore)

swift/plugin/loss_scale/config/ignore_empty_think.json specifies the token pattern whose loss should be ignored: it sets the loss weight of empty `<think></think>` blocks (whitespace between and after the tags is allowed, e.g. `<think>\n\n</think>\n\n`) to 0.0:

{
    "<think>\\s*</think>\\s*": [0.0]
}

Here the regular expression `<think>\s*</think>\s*` matches a `<think>` tag, any whitespace between it and the closing `</think>` tag (including newlines), and any trailing whitespace, i.e. exactly the empty-think token sequence to be ignored.
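To see exactly what this pattern matches, a quick standalone check with Python's `re` module (independent of swift; the sample response string is made up):

```python
import re

# The pattern from ignore_empty_think.json: an empty think block --
# <think>, optional whitespace (including newlines), </think>,
# and any trailing whitespace.
pattern = re.compile(r'<think>\s*</think>\s*')

response = '<think>\n\n</think>\n\nThe answer is 42.'
m = pattern.match(response)
print(repr(m.group(0)))          # the empty-think span -> weight 0.0
print(repr(response[m.end():]))  # 'The answer is 42.' -> keeps weight 1.0
```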

2. Loss-scale class implementation (linking config to logic)

swift/plugin/loss_scale/loss_scale.py defines the IgnoreEmptyThink class, which points at the config file above and inherits the loss-weight computation from REACTLossScale:

class IgnoreEmptyThink(REACTLossScale):
    loss_scale_config = 'ignore_empty_think.json'

The class uses loss_scale_config to point at the config file above and reuses REACTLossScale's get_loss_scale method, which ultimately calls the calculate_loss_scale function to map tokens to their weights.

3. Command-line argument mapping (linking argument to class)

Also in swift/plugin/loss_scale/loss_scale.py, the loss_scale_map dict maps the command-line value ignore_empty_think to the IgnoreEmptyThink class, so that passing --loss_scale ignore_empty_think at training time takes effect:

loss_scale_map = {
    ...
    'ignore_empty_think': IgnoreEmptyThink,
    ...
}
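The dispatch itself is just a dict lookup. A minimal stand-alone sketch with stub classes standing in for the real swift ones (only the names and the `v.name = k` step come from the article's code; everything else is illustrative):

```python
# Stub classes standing in for swift's real LossScale hierarchy
class LossScale:
    pass

class IgnoreEmptyThink(LossScale):
    pass

loss_scale_map = {
    'default': LossScale,
    'ignore_empty_think': IgnoreEmptyThink,
}

# As at the bottom of loss_scale.py: each class learns its CLI name
for k, v in loss_scale_map.items():
    v.name = k

# --loss_scale ignore_empty_think resolves to an IgnoreEmptyThink instance
scale = loss_scale_map['ignore_empty_think']()
print(type(scale).name)  # ignore_empty_think
```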

loss_scale.py

import os
from typing import List, Optional, Tuple

import json

from swift.llm import Messages
from swift.llm.template.utils import ContextType
from .utils import calculate_loss_scale


class LossScale:
    loss_scale_config = None  # path

    def _set_keep_loss_scale(self):
        self.keep_loss_scale = False
        if self.loss_scale_map is None:
            return
        res = set()
        for v in self.loss_scale_map.values():
            res.update(v)
        if len(res - {0., 1.}) > 0:
            self.keep_loss_scale = True

    def __init__(self):
        if self.loss_scale_config is not None:
            path = os.path.dirname(os.path.abspath(__file__))
            config_path = os.path.join(path, 'config', self.loss_scale_config)
            with open(config_path, 'r', encoding='utf-8') as json_file:
                self.loss_scale_map = json.load(json_file)
        else:
            self.loss_scale_map = None
        self._set_keep_loss_scale()

    def get_loss_scale(self,
                       context: str,
                       context_type: ContextType,
                       is_last_round: bool,
                       *,
                       query: Optional[str] = None) -> Tuple[List[str], List[float]]:
        """Calculate loss scale

        Args:
            context: The input context
            context_type: The type of this context, like response/suffix(eos token)/other(query/system, etc.)
            is_last_round: If this is the last round of messages.
            query: The query of this round.

        Returns:
            A tuple, list of context and list of loss_scales
        """
        if context_type in {ContextType.RESPONSE, ContextType.SUFFIX}:
            loss_scale = 1.
        else:
            loss_scale = 0.
        return [context], [loss_scale]

    def __call__(self, context_list: List[str], context_types: List[ContextType], messages: Messages,
                 **kwargs) -> Tuple[List[str], List[float]]:
        res_context_list = []
        res_loss_scale = []
        i = 0
        n_round = len(messages) // 2
        for context, context_type in zip(context_list, context_types):
            is_last_round = i + 1 == n_round
            if context_type == ContextType.RESPONSE:
                query = messages[2 * i]['content']
                assert context == messages[2 * i + 1]['content']
                kwargs = {'query': query}
                i += 1
            if isinstance(context, dict) and 'loss_scale' in context:
                new_context = [[token] for token in context['token_ids']]
                loss_scale = context['loss_scale']
            else:
                if isinstance(context, dict) and 'token_ids' in context:
                    context = context['token_ids']
                new_context, loss_scale = self.get_loss_scale(context, context_type, is_last_round, **kwargs)
            res_context_list += new_context
            res_loss_scale += loss_scale
        return res_context_list, res_loss_scale


class LastRoundLossScale(LossScale):

    def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool, **kwargs):
        if context_type == ContextType.RESPONSE:
            return [context], [float(is_last_round)]
        return super().get_loss_scale(context, context_type, is_last_round)


class AgentFlanLossScale(LossScale):
    loss_scale_config = 'agentflan.json'

    def get_loss_scale(self,
                       context: str,
                       context_type: ContextType,
                       is_last_round: bool,
                       *,
                       query: Optional[str] = None):
        if context_type == ContextType.RESPONSE and isinstance(context, str):
            return calculate_loss_scale(query, context, self.loss_scale_map['response'], self.loss_scale_map['query'])
        return super().get_loss_scale(context, context_type, is_last_round)


class REACTLossScale(LossScale):
    loss_scale_config = 'react.json'

    def get_loss_scale(self,
                       context: str,
                       context_type: ContextType,
                       is_last_round: bool,
                       *,
                       query: Optional[str] = None):
        if context_type == ContextType.RESPONSE and isinstance(context, str):
            return calculate_loss_scale(query, context, self.loss_scale_map)
        return super().get_loss_scale(context, context_type, is_last_round)


class QwenLossScale(REACTLossScale):
    loss_scale_config = 'qwen.json'


class HermesLossScale(REACTLossScale):
    loss_scale_config = 'hermes.json'


class AlphaUmiLossScale(REACTLossScale):
    loss_scale_config = 'alpha_umi.json'


class TrainAllLossScale(LossScale):

    def get_loss_scale(self, context: str, context_type: ContextType, *args, **kwargs):
        return [context], [1.]


class IgnoreEmptyThink(REACTLossScale):
    loss_scale_config = 'ignore_empty_think.json'


class LastRoundWithIgnoreEmptyThink(LossScale):
    loss_scale_config = 'ignore_empty_think.json'

    def get_loss_scale(self,
                       context: str,
                       context_type: ContextType,
                       is_last_round: bool,
                       *,
                       query: Optional[str] = None):
        if context_type == ContextType.RESPONSE:
            if not is_last_round:
                return [context], [0.]
            elif isinstance(context, str):
                return calculate_loss_scale(query, context, self.loss_scale_map)
        return super().get_loss_scale(context, context_type, is_last_round)


# Add your loss scale here, use --loss_scale xxx to train
loss_scale_map = {
    'last_round': LastRoundLossScale,
    'default': LossScale,
    'all': TrainAllLossScale,
    'ignore_empty_think': IgnoreEmptyThink,
    'last_round_with_ignore_empty_think': LastRoundWithIgnoreEmptyThink,
    # agent
    'react': REACTLossScale,
    'hermes': HermesLossScale,
    'qwen': QwenLossScale,
    'agentflan': AgentFlanLossScale,
    'alpha_umi': AlphaUmiLossScale,
}

for k, v in loss_scale_map.items():
    v.name = k

4. Loss-weight computation logic

The calculate_loss_scale function in swift/plugin/loss_scale/utils.py applies the weight configured for each matched token sequence (here 0.0) according to the mapping in the config file, which is what actually ignores the loss:

from typing import Dict, List, Optional, Tuple

from swift.llm.template import split_str_parts_by


def calculate_loss_scale(query: str,
                         response: str,
                         response_loss_scale_map: Dict[str, list],
                         query_loss_scale_map: Optional[Dict[str, list]] = None) -> Tuple[List[str], List[float]]:
    """Calculate the loss scale by splitting the agent response.

    This algorithm comes from paper: https://arxiv.org/pdf/2309.00986.pdf

    Agent response format:

    ```text
    Thought: you should always think about what to do
    Action: the action to take, should be one of the above tools[fire_recognition,
        fire_alert, call_police, call_fireman]
    Action Input: the input to the action
    Observation: the result of the action
    ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
    Thought: I now know the final answer
    Final Answer: the final answer to the original input question
    ```

    Returns:
        A tuple of agent response parts and their weights.
    """
    # query loss scale map
    if query_loss_scale_map is not None:
        for key in query_loss_scale_map.keys():
            if key in query:
                if isinstance(query_loss_scale_map[key], (float, int)):
                    query_loss_scale_map[key] = [query_loss_scale_map[key]]
                loss_scale_value = query_loss_scale_map[key][0]
                return [response], [float(loss_scale_value)]
    delimiters = [k for k, v in response_loss_scale_map.items() if len(v) == 2]
    if delimiters:
        agent_parts = split_str_parts_by(response, delimiters)
    else:
        regex_delimiters = [k for k, v in response_loss_scale_map.items() if len(v) == 1]
        agent_parts = split_str_parts_by(response, regex_delimiters, regex_mode=True)
    weights = []
    agent_content = []
    for c in agent_parts:
        if c['key'] in response_loss_scale_map:
            loss_scale = response_loss_scale_map[c['key']]
            assert len(loss_scale) in {1, 2}, f'loss_scale: {loss_scale}'
            if len(loss_scale) == 1:
                weights += loss_scale
                agent_content.append(c['content'])
            else:
                weights += loss_scale
                agent_content += [c['key'], c['content']]
        else:
            weights.append(1.)
            agent_content.append(c['content'])
    return agent_content, weights
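`split_str_parts_by` lives inside swift, so as a stand-alone illustration of the regex branch, here is a simplified re-implementation (the helper `split_with_weights` is hypothetical; only the pattern and its [0.0] weight come from the config above):

```python
import re

def split_with_weights(response, loss_scale_map):
    # Simplified stand-in for split_str_parts_by + the weight loop:
    # split the response on each regex key; matched spans get the
    # configured weight, everything else keeps weight 1.0.
    patterns = {k: v[0] for k, v in loss_scale_map.items() if len(v) == 1}
    combined = '(' + '|'.join(patterns) + ')'
    parts, weights = [], []
    for piece in re.split(combined, response):
        if not piece:
            continue
        weight = 1.0
        for pat, w in patterns.items():
            if re.fullmatch(pat, piece):
                weight = float(w)
                break
        parts.append(piece)
        weights.append(weight)
    return parts, weights

scale_map = {r'<think>\s*</think>\s*': [0.0]}  # as in ignore_empty_think.json
parts, weights = split_with_weights('<think>\n\n</think>\n\nParis.', scale_map)
print(weights)  # [0.0, 1.0]: the empty think block is ignored, the answer trained
```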

5. Applying the loss weights

When computing the loss, the loss_scale_func function in swift/plugin/loss.py applies the weights obtained above, so tokens with weight 0.0 contribute nothing to the loss:

def loss_scale_func(outputs, labels, loss_scale=None, ...):
    loss, masks = ce_loss_func(outputs, labels)
    if loss_scale is not None:
        shift_scale = loss_scale[..., 1:].to(masks.device)
        shift_scale = shift_scale[masks]
        loss = (shift_scale * loss)  # terms with weight 0 are zeroed out
    ...
# Use @register_loss_func to decorate your own loss, use --loss_type xxx to train
@register_loss_func(LossType.loss_scale)
def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    """Loss func

    Args:
        outputs: The model outputs
        labels: The labels
        loss_scale: The loss scale
        num_items_in_batch: Number of tokens in the labels of gradient accumulation round that are not -100.

    Returns:

    """
    loss, masks = ce_loss_func(outputs, labels)
    if loss_scale is not None:
        shift_scale = loss_scale[..., 1:].to(masks.device)
        shift_scale = shift_scale[masks]
        loss = (shift_scale * loss)
    if num_items_in_batch is None:
        loss = loss.mean()
    else:
        # compat transformers>=4.46
        loss = loss.sum() / num_items_in_batch
    return loss

loss_scale_func

loss_scale_func is a custom loss function registered in ms-swift via @register_loss_func. It implements weighted cross-entropy: the loss_scale argument lets you flexibly adjust the loss weight of individual tokens.

1. Registration and purpose

The decorator @register_loss_func(LossType.loss_scale) registers the function as a loss named loss_scale, which can be enabled at training time with --loss_type loss_scale.
Purpose: on top of the regular cross-entropy (CE) loss, the loss_scale argument scales each token's loss, adjusting how much attention training pays to specific tokens (for example, ignoring some tokens' loss entirely, or amplifying others).

2. Parameters

outputs: the model outputs (typically containing the logits and other predictions).
labels: the label tensor (same length as the input sequence; positions set to -100 do not participate in the loss, i.e. they are masked).
loss_scale: the per-token loss weights (same length as the label sequence; each value is the loss weight of the corresponding token, e.g. 0 ignores that token's loss and 2.0 doubles it).
num_items_in_batch: the number of labels in the gradient-accumulation round that are not -100.

3. Logic

Step 1: compute the base cross-entropy loss and the mask
loss, masks = ce_loss_func(outputs, labels)

ce_loss_func is a helper that computes the raw cross-entropy loss:
It shifts the model's logits and the labels (in a language model, tokens 1..n-1 predict token n, so the logits take [:, :-1, :] and the labels take [:, 1:]).
It builds the masks tensor marking the positions whose label is not -100 (i.e. the tokens that participate in the loss), with the same shape as the shifted labels.
It returns the unweighted cross-entropy loss (loss, covering only the valid tokens) and the mask (masks).
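ce_loss_func itself is not shown in this article; as an illustration of the shift-and-mask it performs, here is a tiny pure-Python sketch over a 2-token vocabulary (the toy logits, labels, and the name `ce_loss_sketch` are made up):

```python
import math

def ce_loss_sketch(logits, labels):
    # logits: [seq_len][vocab_size], labels: [seq_len]; -100 = masked.
    shift_logits = logits[:-1]  # tokens < n predict token n ...
    shift_labels = labels[1:]   # ... so labels shift left by one
    masks = [lab != -100 for lab in shift_labels]
    losses = []
    for row, lab, keep in zip(shift_logits, shift_labels, masks):
        if not keep:
            continue  # -100 positions contribute no loss
        log_z = math.log(sum(math.exp(x) for x in row))
        losses.append(log_z - row[lab])  # -log softmax(row)[lab]
    return losses, masks

logits = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
labels = [0, 1, -100]  # the last position is masked out
loss, masks = ce_loss_sketch(logits, labels)
print(masks)      # [True, False]
print(len(loss))  # 1 -- only one shifted label is valid (not -100)
```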

Step 2: apply the loss-scale weights
if loss_scale is not None:
    shift_scale = loss_scale[..., 1:].to(masks.device)
    shift_scale = shift_scale[masks]
    loss = (shift_scale * loss)

If loss_scale is provided (i.e. custom weights were specified), the loss is weighted:
shift_scale = loss_scale[..., 1:]: shift loss_scale (mirroring the label shift in step 1, so each weight lines up with its token).
shift_scale = shift_scale[masks]: use masks to keep only the weights of valid tokens (dropping the weights at -100 positions).
loss = (shift_scale * loss): multiply each raw loss term by its weight (a weight of 0 drops that token's loss; a weight of 2.0 doubles it).
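On toy data (plain Python lists standing in for the tensors; all numbers are made up), the shift, mask, and multiply look like this:

```python
# Per-token CE loss for the valid (unmasked) positions, from step 1
loss = [0.7, 1.2]
# One weight per label position; index 0 pairs with the dropped first label
loss_scale = [1.0, 0.0, 2.0]
masks = [True, True]  # both shifted positions are valid here

shift_scale = loss_scale[1:]                                # align with shifted labels
shift_scale = [s for s, m in zip(shift_scale, masks) if m]  # keep valid positions
weighted = [s * l for s, l in zip(shift_scale, loss)]
print(weighted)  # [0.0, 2.4]: first token ignored, second token doubled
```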

Step 3: loss normalization
if num_items_in_batch is None:
    loss = loss.mean()
else:
    loss = loss.sum() / num_items_in_batch

The weighted loss is then normalized so that its magnitude does not depend on the batch size:
num_items_in_batch is None: take the mean over all valid token losses.
num_items_in_batch provided (typically under gradient accumulation): divide the loss sum by the total number of valid tokens across the accumulation round.
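The two normalization branches on toy numbers (values made up):

```python
weighted_loss = [0.0, 2.0, 1.0]  # per-token weighted losses from step 2

# num_items_in_batch is None: plain mean over the valid tokens
mean_loss = sum(weighted_loss) / len(weighted_loss)
print(mean_loss)  # 1.0

# Gradient accumulation (transformers>=4.46 style): divide by the
# valid-token count of the whole accumulation round, not just this
# micro-batch, so every micro-batch contributes proportionally
num_items_in_batch = 6
accum_loss = sum(weighted_loss) / num_items_in_batch
print(accum_loss)  # 0.5
```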

