当前位置: 首页 > news >正文

记录一个大模型逐层微调计算损失输出少了一个维度的小bug

1.假如针对的对象是linear

def _compute_mse_on_batch(layer: nn.Module, batch_iter: Iterator[Tuple[torch.Tensor, torch.Tensor]], **kwargs
) -> torch.Tensor:inps_batch, outs_batch = next(batch_iter)print("Initial inps_batch:", inps_batch.shape)print("Initial outs_batch:", outs_batch.shape)# print("Any NaNs in inps_batch:", torch.isnan(inps_batch).any())# print("Any NaNs in outs_batch:", torch.isnan(outs_batch).any())# if inps_batch.shape[0] != 1:#     for name, value in list(kwargs.items()):#         if isinstance(value, torch.Tensor) and value.shape[0] == 1:#             if name not in ("attention_mask", "position_ids"):#                 warnings.warn(f"Tiling an unexpected kwarg {name} over batch size; make sure this is valid.")#             repeats = [len(inps_batch)] + [1 for _ in range(value.ndim - 1)]#             kwargs[name] = value.tile(*repeats)outs_prediction= layer(inps_batch, **kwargs)assert outs_prediction.shape == outs_batch.shapeloss = F.mse_loss(outs_prediction, outs_batch)# print("Computed loss:", loss.item())return loss

2.假如针对的对象是transformer

def _compute_mse_on_batch(layer: nn.Module, batch_iter: Iterator[Tuple[torch.Tensor, torch.Tensor]], **kwargs
) -> torch.Tensor:inps_batch, outs_batch = next(batch_iter)print("Initial inps_batch:", inps_batch.shape)print("Initial outs_batch:", outs_batch.shape)# print("Any NaNs in inps_batch:", torch.isnan(inps_batch).any())# print("Any NaNs in outs_batch:", torch.isnan(outs_batch).any())# if inps_batch.shape[0] != 1:#     for name, value in list(kwargs.items()):#         if isinstance(value, torch.Tensor) and value.shape[0] == 1:#             if name not in ("attention_mask", "position_ids"):#                 warnings.warn(f"Tiling an unexpected kwarg {name} over batch size; make sure this is valid.")#             repeats = [len(inps_batch)] + [1 for _ in range(value.ndim - 1)]#             kwargs[name] = value.tile(*repeats)outs_prediction, *_unused = layer(inps_batch, **kwargs)# print("outs_prediction device in loss:", outs_prediction.device)# print(" outs_batch device in loss:",  outs_batch.device)assert outs_prediction.shape == outs_batch.shapeloss = F.mse_loss(outs_prediction, outs_batch)# print("Computed loss:", loss.item())return loss

值得注意的是,假如我们在线性层里面写的是: outs_prediction, *_unused = layer(inps_batch, **kwargs),由于线性层返回没有 *_unused ,会导致我们输入[batchsize,input]时候得到的输出不是我们期望的[batchsize,output],而只会有[output]


文章转载自:

http://ED5xxgFc.yrjym.cn
http://Gbb2dYCJ.yrjym.cn
http://w1CkEAOy.yrjym.cn
http://nKV6pNdG.yrjym.cn
http://e62idkNm.yrjym.cn
http://mAygzXfC.yrjym.cn
http://Rx8ikMew.yrjym.cn
http://Z253rCDl.yrjym.cn
http://vtm6YVJb.yrjym.cn
http://lvH8aHNu.yrjym.cn
http://6AwVTDip.yrjym.cn
http://KaKEQAth.yrjym.cn
http://zt4i2Qmr.yrjym.cn
http://9UXWprRj.yrjym.cn
http://9ee0HUOS.yrjym.cn
http://LKZmVyzS.yrjym.cn
http://L9vv41d0.yrjym.cn
http://ImpmLaMU.yrjym.cn
http://VfhZDGxx.yrjym.cn
http://yHnlQOBD.yrjym.cn
http://PUBdwsRH.yrjym.cn
http://qYJmVQfi.yrjym.cn
http://zTDtP5rm.yrjym.cn
http://PtvfXC4u.yrjym.cn
http://KXZcuNdk.yrjym.cn
http://WHIyBaLC.yrjym.cn
http://ruuLDvLI.yrjym.cn
http://x1m6UIIi.yrjym.cn
http://ZXXMIPyL.yrjym.cn
http://KuwDytO2.yrjym.cn
http://www.dtcms.com/a/245353.html

相关文章:

  • Go语言高并发爬虫程序源码
  • 软件测试BUG
  • 在Ubuntu中使用Apache2部署项目
  • Vivado libtinfo.so.5
  • 前缀和题目:子数组异或查询
  • react实现axios 的简单封装
  • 解决新版RN 热更新报错:recreateReactContextInBackground
  • 基于sample_aiisp例子,创建3路编码流,记录
  • 【微软RDP协议】微软RDP协议技术架构特点与跨地域应用实践
  • 【 java 虚拟机知识 第二篇 】
  • android 之 CALL
  • 使用adb 抓取perfetto-trace的注意事项
  • 基于 Redis 的幂等性设计:SpringBoot @Async 在高并发 MySQL 日志存储中的应用
  • Mac 系统 Node.js 安装与版本管理指南
  • RAG检索前处理
  • GO后端开发内存管理及参考答案
  • adb 查看android 设备的硬盘及存储空间
  • 录制mp4 rospy
  • 2025年中国人工智能发展研究报告:技术突破、行业变革与全球竞争新格局
  • Spring 路由匹配机制详解:时间复杂度从 O(n) 降至 O(log n)
  • 学习STC51单片机36(芯片为STC89C52RCRC)智能小车3(PWM差速小车)
  • Redis 安装实践:基于鲲鹏 ARM 架构 Ubuntu 环境
  • 随记:sw2urdf插件导出urdf模型在ROS2-rviz2显示
  • 电流传感器在工业自动化中的应用
  • Tess4J:基于 Java 的 OCR 解决方案
  • 【doris】doris集成ranger控制权限,ranger配置无法存储doris密码password信息
  • 代码随想录训练营第三十天 | 452. 用最少数量的箭引爆气球 435. 无重叠区间 763.划分字母区间
  • 【Net】OPC UA(OPC Unified Architecture)协议
  • 图片压缩工具类
  • 深入剖析 C++ 默认函数:拷贝构造与赋值运算符重载