记录一个大模型逐层微调时,计算损失的输出少了一个维度(batch 维)的小 bug。
1.假如针对的对象是linear
def _compute_mse_on_batch(layer: nn.Module, batch_iter: Iterator[Tuple[torch.Tensor, torch.Tensor]], **kwargs
) -> torch.Tensor:inps_batch, outs_batch = next(batch_iter)print("Initial inps_batch:", inps_batch.shape)print("Initial outs_batch:", outs_batch.shape)# print("Any NaNs in inps_batch:", torch.isnan(inps_batch).any())# print("Any NaNs in outs_batch:", torch.isnan(outs_batch).any())# if inps_batch.shape[0] != 1:# for name, value in list(kwargs.items()):# if isinstance(value, torch.Tensor) and value.shape[0] == 1:# if name not in ("attention_mask", "position_ids"):# warnings.warn(f"Tiling an unexpected kwarg {name} over batch size; make sure this is valid.")# repeats = [len(inps_batch)] + [1 for _ in range(value.ndim - 1)]# kwargs[name] = value.tile(*repeats)outs_prediction= layer(inps_batch, **kwargs)assert outs_prediction.shape == outs_batch.shapeloss = F.mse_loss(outs_prediction, outs_batch)# print("Computed loss:", loss.item())return loss
2.假如针对的对象是transformer
def _compute_mse_on_batch(layer: nn.Module, batch_iter: Iterator[Tuple[torch.Tensor, torch.Tensor]], **kwargs
) -> torch.Tensor:inps_batch, outs_batch = next(batch_iter)print("Initial inps_batch:", inps_batch.shape)print("Initial outs_batch:", outs_batch.shape)# print("Any NaNs in inps_batch:", torch.isnan(inps_batch).any())# print("Any NaNs in outs_batch:", torch.isnan(outs_batch).any())# if inps_batch.shape[0] != 1:# for name, value in list(kwargs.items()):# if isinstance(value, torch.Tensor) and value.shape[0] == 1:# if name not in ("attention_mask", "position_ids"):# warnings.warn(f"Tiling an unexpected kwarg {name} over batch size; make sure this is valid.")# repeats = [len(inps_batch)] + [1 for _ in range(value.ndim - 1)]# kwargs[name] = value.tile(*repeats)outs_prediction, *_unused = layer(inps_batch, **kwargs)# print("outs_prediction device in loss:", outs_prediction.device)# print(" outs_batch device in loss:", outs_batch.device)assert outs_prediction.shape == outs_batch.shapeloss = F.mse_loss(outs_prediction, outs_batch)# print("Computed loss:", loss.item())return loss
值得注意的是:假如在线性层版本里也写成 outs_prediction, *_unused = layer(inps_batch, **kwargs),由于线性层直接返回一个张量而不是元组,这个解包会沿该张量的第 0 维(batch 维)迭代,outs_prediction 只会取到第一个样本。于是当输入是 [batchsize, input] 时,得到的不是期望的 [batchsize, output],而只是形状为 [output] 的单个样本——这正是"输出少了一个维度"的原因。