时间梯度匹配损失 TGMLoss
时间梯度匹配损失(Temporal Gradient Matching Loss, TGM Loss)的完整示例,该损失函数常用于视频预测、运动平滑等任务,通过约束预测序列的时间梯度与真实序列一致来提升时序连续性
训练测试demo代码:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
# 定义时间梯度匹配损失
class TGMLoss(nn.Module):
def __init__(self, mode='l1', reduction='mean'):
super().__init__()
self.mode = mode
self.reduction = reduction
def forward(self, preds, targets):
"""
输入形状: (B, T, ...)
B: 批量大小, T