当前位置: 首页 > news >正文

【transformers.Trainer填坑】在自定义compute_metrics时logits和labels数据维度不一致问题

问题描述

我在使用 transformers.Trainer 训练我的模型时,我自定义了 compute_loss 函数和compute_metrics函数,我的模型是一个简单的二分类模型。

在自定义 compute_loss 时这样写的:

def compute_loss(self, model, inputs, return_outputs=False):
        """
        重写 Trainer.compute_loss:
        1) 提取字典中的 images, bboxes, locs, labels 等
        2) 用 vision_encoder 先处理图像,得到特征
        3) 用下游 model 做预测
        4) 计算并返回 loss
        """
        # 前向传播
        outputs, labels = model(**inputs)  # (bz, num_classes), or (bz*num_frames, num_classes)

        batch_size = inputs['labels'].shape[0]

        outputs = outputs.squeeze()  # (bz*num_frames)

        if batch_size == 1:
            outputs = outputs.unsqueeze(0)

        # 计算 loss
        loss = self.loss_func(outputs, labels.float())

        if self.state.global_step % 10 == 0 and self.state.global_step > 0:
            # 以50个step为间隔打印
            pred_probs = torch.sigmoid(outputs)
            preds = (pred_probs > 0.5).int()
            logger.info(f"[global_step={self.state.global_step}] preds={preds.tolist()} / labels={labels.tolist()} / loss={loss.item():.4f}")
            # compute metric
            accuracy = accuracy_score(labels.cpu().numpy(), preds.cpu().numpy())
            precision = precision_score(labels.cpu().numpy(), preds.cpu().numpy())
            recall = recall_score(labels.cpu().numpy(), preds.cpu().numpy())
            logger.info(f"[global_step={self.state.global_step}] accuracy={accuracy:.4f} / precision={precision:.4f} / recall={recall:.4f}")

        # 返回 (loss, outputs) 或者只返回 loss
        return (loss, outputs) if return_outputs else loss

于是就出现了报错,像这样的:

File "/opt/conda/lib/python3.9/site-packages/transformers/trainer.py", line 3754, in predict
    output = eval_loop(
  File "/opt/conda/lib/python3.9/site-packages/transformers/trainer.py", line 3966, in evaluation_loop
    metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
  File "/workspace/train/object_query/train.py", line 281, in compute_metrics
    correct_num = preds == labels
ValueError: operands could not be broadcast together with shapes (11720,) (12104,)
    output = eval_loop(
  File "/opt/conda/lib/python3.9/site-packages/transformers/trainer.py", line 3966, in evaluation_loop
    metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
  File "/workspace/train/object_query/train.py", line 281, in compute_metrics
    correct_num = preds == labels
ValueError: operands could not be broadcast together with shapes (11720,) (12104,)

原因

该问题是 transformers.Trainer 内部有一段对outputs的操作造成的:

if isinstance(outputs, dict):
    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
else:
    logits = outputs[1:]

这里当 outputs 不是字典时,会把第一个位置的元素offset掉。

解决

Refer to here
所以,我们应该在返回那里这样写:

return (loss, {"label": outputs}) if return_outputs else loss

相关文章:

  • 通过沙箱技术测试识别潜在的威胁
  • 第一章:认识Tailwind CSS - 第三节 - Tailwind CSS 开发环境搭建和工具链配置
  • redis的哨兵模式和集群模式
  • 1.3 AI大模型应用浪潮解析:高校、硅谷与地缘政治的三角博弈
  • vscode调试和环境路径配置
  • 【微软- Entra ID】Microsoft Entra ID
  • 强化学习《初学者》--基础概念贝尔曼公式
  • 【Java】一文了解spring的三级缓存
  • 如何使用智能化RFID管控系统,对涉密物品进行安全有效的管理?
  • 在香橙派5 NPU上使用Yolov5
  • Ollama+Deepseek+chatbox快速部署属于自己的大模型
  • SSM课设-学生选课系统
  • 格式工厂 FormatFactory v5.18.便携版 ——多功能媒体文件转换工具
  • 玄机——第一章 应急响应-Linux入侵排查
  • 在 Go 中实现事件溯源:构建高效且可扩展的系统
  • Jupyter lab 无法导出格式 Save and Export Notebook As无法展开
  • CSS实现单行、多行文本溢出显示省略号(…)
  • JVM 类加载机制
  • QT无弹窗运行和只允许运行一个exe
  • 问卷数据分析|SPSS实操之独立样本T检验
  • 杭州网站开发工程师/杭州做seo的公司
  • 衢州php网站建设/网络营销推广活动有哪些
  • 怎么登录手机wordpress/南昌seo优化
  • 成都手机网站制作/优化大师官方下载
  • 温州建设公司网站/国外免费网站域名服务器查询
  • 天元建设集团有限公司邮编/seo排名优化价格