
Notes on a small bug in layer-wise fine-tuning of a large model: the loss computation's output was missing a dimension

1. If the layer being fine-tuned is a Linear layer

import warnings
from typing import Iterator, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def _compute_mse_on_batch(
    layer: nn.Module,
    batch_iter: Iterator[Tuple[torch.Tensor, torch.Tensor]],
    **kwargs,
) -> torch.Tensor:
    inps_batch, outs_batch = next(batch_iter)
    print("Initial inps_batch:", inps_batch.shape)
    print("Initial outs_batch:", outs_batch.shape)
    # print("Any NaNs in inps_batch:", torch.isnan(inps_batch).any())
    # print("Any NaNs in outs_batch:", torch.isnan(outs_batch).any())
    # if inps_batch.shape[0] != 1:
    #     for name, value in list(kwargs.items()):
    #         if isinstance(value, torch.Tensor) and value.shape[0] == 1:
    #             if name not in ("attention_mask", "position_ids"):
    #                 warnings.warn(f"Tiling an unexpected kwarg {name} over batch size; make sure this is valid.")
    #             repeats = [len(inps_batch)] + [1 for _ in range(value.ndim - 1)]
    #             kwargs[name] = value.tile(*repeats)

    # A Linear layer returns a plain tensor, so assign it directly.
    outs_prediction = layer(inps_batch, **kwargs)
    assert outs_prediction.shape == outs_batch.shape
    loss = F.mse_loss(outs_prediction, outs_batch)
    # print("Computed loss:", loss.item())
    return loss

2. If the layer being fine-tuned is a Transformer block

import warnings
from typing import Iterator, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def _compute_mse_on_batch(
    layer: nn.Module,
    batch_iter: Iterator[Tuple[torch.Tensor, torch.Tensor]],
    **kwargs,
) -> torch.Tensor:
    inps_batch, outs_batch = next(batch_iter)
    print("Initial inps_batch:", inps_batch.shape)
    print("Initial outs_batch:", outs_batch.shape)
    # print("Any NaNs in inps_batch:", torch.isnan(inps_batch).any())
    # print("Any NaNs in outs_batch:", torch.isnan(outs_batch).any())
    # if inps_batch.shape[0] != 1:
    #     for name, value in list(kwargs.items()):
    #         if isinstance(value, torch.Tensor) and value.shape[0] == 1:
    #             if name not in ("attention_mask", "position_ids"):
    #                 warnings.warn(f"Tiling an unexpected kwarg {name} over batch size; make sure this is valid.")
    #             repeats = [len(inps_batch)] + [1 for _ in range(value.ndim - 1)]
    #             kwargs[name] = value.tile(*repeats)

    # A Transformer block returns a tuple (hidden states first), so unpack it.
    outs_prediction, *_unused = layer(inps_batch, **kwargs)
    # print("outs_prediction device in loss:", outs_prediction.device)
    # print("outs_batch device in loss:", outs_batch.device)
    assert outs_prediction.shape == outs_batch.shape
    loss = F.mse_loss(outs_prediction, outs_batch)
    # print("Computed loss:", loss.item())
    return loss

Note that if we write outs_prediction, *_unused = layer(inps_batch, **kwargs) for the Linear case, things go wrong: a Linear layer returns a plain tensor rather than a tuple, so the star-unpacking iterates over the tensor's first (batch) dimension. Given an input of shape [batch_size, input], outs_prediction ends up being just the first row, of shape [output], instead of the expected [batch_size, output].
