# The most important companion file to read is IFNet.py.
import torch
import torch.nn as nn
import numpy as np
from torch.optim import AdamW
import torch.optim as optim
import itertools
# Building blocks imported from the model package:
from model.warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from model.IFNet import *
from model.IFNet_m import *
import torch.nn.functional as F
# Loss functions for the model (EPE, SOBEL, ...).
from model.loss import *
# Also a loss function: the Laplacian pyramid loss.
from model.laplacian import *
from model.refine import *
# Training and inference controller for the whole RIFE frame interpolation
# system. It wraps the core components: network architecture, inference
# logic, training step, loss functions, and optimizer, so the RIFE model can
# be trained and used conveniently.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Model:
    def __init__(self, local_rank=-1, arbitrary=False):
        # train.py constructs this as model = Model(args.local_rank).
        # arbitrary=True selects IFNet_m (RIFEm, arbitrary timestep);
        # otherwise the standard IFNet from IFNet.py is used.
        if arbitrary:
            self.flownet = IFNet_m()
        else:
            # self.flownet exposes every public method and attribute of IFNet.
            self.flownet = IFNet()
        # Move the network to the selected device.
        self.device()
        # Optimizer.
        self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-3)  # use large weight decay may avoid NaN loss
        # Loss functions: end-point error, Laplacian pyramid, Sobel edge loss.
        self.epe = EPE()
        self.lap = LapLoss()
        self.sobel = SOBEL()
        if local_rank != -1:
            self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)

    def train(self):
        # Not the train() defined in train.py: this only switches the network
        # to training mode. flownet is an IFNet, which subclasses nn.Module,
        # so it inherits the .train() method.
        self.flownet.train()

    def eval(self):
        self.flownet.eval()

    def device(self):
        self.flownet.to(device)

    def load_model(self, path, rank=0):
        def convert(param):
            # Strip the "module." prefix that DistributedDataParallel adds to
            # parameter names; keys without the prefix are kept unchanged.
            return {k.replace("module.", ""): v for k, v in param.items()}

        if rank <= 0:
            # Build '{path}/flownet.pkl', torch.load the checkpoint
            # (map_location keeps CPU-only machines working), normalize the
            # keys with convert() above, then load_state_dict() -- the
            # standard PyTorch method that copies a parameter dict into the
            # model.
            self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path), map_location=device)))

    # path is the log directory; train.py uses log_path = 'train_log'.
    def save_model(self, path, rank=0):
        # Only the main process (rank 0) saves the state_dict.
        if rank == 0:
            torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))

    def inference(self, img0, img1, scale=1, scale_list=None, TTA=False, timestep=0.5):
        if scale_list is None:
            scale_list = [4, 2, 1]
        scale_list = [s * 1.0 / scale for s in scale_list]
        imgs = torch.cat((img0, img1), 1)
        # IFNet.forward returns flow_list, mask_list[2], merged, flow_teacher,
        # merged_teacher, loss_distill; see IFNet.py for how it runs.
        flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(imgs, scale_list, timestep=timestep)
        # Without test-time augmentation (TTA), return the final merged frame;
        # with TTA, average it with a prediction made on flipped inputs.
        if not TTA:
            return merged[2]
        else:
            # imgs.flip(2).flip(3): dim 2 is H and dim 3 is W, so the frames
            # are flipped vertically and horizontally.
            flow2, mask2, merged2, flow_teacher2, merged_teacher2, loss_distill2 = self.flownet(imgs.flip(2).flip(3), scale_list, timestep=timestep)
            return (merged[2] + merged2[2].flip(2).flip(3)) / 2

    # One optimization step: imgs holds the two input frames concatenated on
    # the channel axis, gt the ground-truth middle frame.
    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        # Split the 6-channel tensor back into the two 3-channel frames.
        img0 = imgs[:, :3]
        img1 = imgs[:, 3:]
        # Switch between training and evaluation mode.
        if training:
            self.train()
        else:
            self.eval()
        # See IFNet.py for how these are computed:
        # flow: optical flow estimated by the student; mask: fusion mask from
        # the student; merged: interpolated frames from the student;
        # loss_distill: discrepancy between teacher and student predictions.
        flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(torch.cat((imgs, gt), 1), scale=[4, 2, 1])
        # Laplacian loss between prediction and ground truth, for the student
        # (merged[2]) and the teacher (merged_teacher).
        loss_l1 = (self.lap(merged[2], gt)).mean()
        loss_tea = (self.lap(merged_teacher, gt)).mean()
        if training:
            # Clear stale gradients.
            self.optimG.zero_grad()
            # Total loss.
            loss_G = loss_l1 + loss_tea + loss_distill * 0.01  # when training RIFEm, the weight of loss_distill should be 0.005 or 0.002
            # Backpropagate.
            loss_G.backward()
            # Update the parameters.
            self.optimG.step()
        else:
            flow_teacher = flow[2]
        # See IFNet.py (and the loss modules) for details on these outputs.
        return merged[2], {
            'merged_tea': merged_teacher,
            'mask': mask,
            'mask_tea': mask,
            'flow': flow[2][:, :2],
            'flow_tea': flow_teacher,
            'loss_l1': loss_l1,
            'loss_tea': loss_tea,
            'loss_distill': loss_distill,
        }
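

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hedged example of
# driving Model for inference and for one training step. It assumes frames
# shaped (N, 3, H, W) with H and W divisible by 32 (the IFNet pyramid
# downsamples by up to 32x) and runs on random weights; in practice you would
# first call model.load_model('train_log') to load a trained flownet.pkl.
if __name__ == '__main__':
    model = Model()

    # Inference: interpolate the frame at timestep=0.5 between img0 and img1.
    # Setting TTA=True would average this with a prediction on flipped inputs.
    model.eval()
    img0 = torch.rand(1, 3, 256, 448, device=device)
    img1 = torch.rand(1, 3, 256, 448, device=device)
    with torch.no_grad():
        mid = model.inference(img0, img1, TTA=False, timestep=0.5)
    print('interpolated frame:', mid.shape)  # torch.Size([1, 3, 256, 448])

    # One training step on dummy data: update() expects the two frames
    # concatenated on the channel axis plus the ground-truth middle frame.
    imgs = torch.cat((img0, img1), 1)               # (N, 6, H, W)
    gt = torch.rand(1, 3, 256, 448, device=device)  # dummy ground truth
    pred, info = model.update(imgs, gt, learning_rate=1e-6, training=True)
    print('loss_l1: {:.4f}, loss_tea: {:.4f}, loss_distill: {:.4f}'.format(
        info['loss_l1'].item(), info['loss_tea'].item(), info['loss_distill'].item()))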