当前位置: 首页 > news >正文

RIFE.py代码学习 自学

#比较重要的是去看IFNet.py代码import torch
import torch.nn as nn
import numpy as np
from torch.optim import AdamW
import torch.optim as optim
import itertools
#导入了很多model里面的东西
from model.warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from model.IFNet import *
from model.IFNet_m import *
import torch.nn.functional as F
#model的损失函数
from model.loss import *
#也是损失函数
from model.laplacian import *
#
from model.refine import *
#整个 RIFE 插帧系统的训练与推理控制器。
#它封装了模型结构、推理逻辑、训练流程、损失函数、优化器等核心组件
#目的是让你可以方便地训练和使用 RIFE 插帧模型。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")class Model:#转化为类的实例def __init__(self, local_rank=-1, arbitrary=False):#我看train中调用model是 model = Model(args.local_rank)#那就是IFNet.py代码中if arbitrary == True:self.flownet = IFNet_m()else:#flownet可以用IFNet类中的所有公开方法和属性self.flownet = IFNet()#设备self.device()#优化器self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-3) # use large weight decay may avoid NaN loss#损失函数-lossself.epe = EPE()#损失函数-laplacianself.lap = LapLoss()#损失函数-lossself.sobel = SOBEL()if local_rank != -1:self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)def train(self):#这个train跟train.py中定义的train方法不一样 这个是切换模型状态为训练模式#是因为flownet=IFNet类 而IFNet继承了class IFBlock(nn.Module) module#所以可以使用.train方法self.flownet.train()def eval(self):self.flownet.eval()def device(self):self.flownet.to(device)def load_model(self, path, rank=0):def convert(param):#如果是分布式训练,去掉 module. 前缀return {k.replace("module.", ""): vfor k, v in param.items()if "module." in k}if rank <= 0:#'{}/flownet.pkl'.format(path)的意思就是将path插入到flownet 然后torch.load#然后convert 看到上面的def convert了没 然后load_state_dict#load_state_dict() 是 PyTorch 模型对象的方法,用于将参数字典加载到模型中self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path))))# path=log_path log_path = 'train_log'def save_model(self, path, rank=0):#只主进程保存 state_dict 保存if rank == 0:torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path))def inference(self, img0, img1, scale=1, scale_list=None, TTA=False, timestep=0.5):if scale_list is None:scale_list = [4, 2, 1]for i in range(3):scale_list[i] = scale_list[i] * 1.0 / scaleimgs = torch.cat((img0, img1), 1)#在IFNet类中 是return flow_list, mask_list[2], merged, flow_teacher, merged_teacher, loss_distill#看IFNet.py中的这个类中的forward中是如何运行的flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(imgs, scale_list, timestep=timestep)#根据TTA来决定是返回merged[2] 还是。。 可以去看看TTA是干嘛的?if TTA == False:return merged[2]else:flow2, mask2, merged2, flow_teacher2, merged_teacher2, loss_distill2 = self.flownet(imgs.flip(2).flip(3), scale_list, timestep=timestep)#imgs.flip(2).flip(3) 维度2 h 维度3 w 进行上下左右翻转return (merged[2] + merged2[2].flip(2).flip(3)) / 2#输入 imgs gtdef update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):for param_group in self.optimG.param_groups:param_group['lr'] = learning_rate#拿出来img0 = imgs[:, :3]img1 = imgs[:, 3:]#启动train还是evalif training:self.train()else:self.eval()#计算这些参数 到IFNet.py代码中去看#flow:学生模型估计的光流 mask:学生模型生成的融合掩码,merged 学生模型生成的插值帧#distill 教师和学生模型生成的插帧 然后计算二者差别flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(torch.cat((imgs, gt), 1), scale=[4, 2, 1])#merged[2]是生成的 gt是真实 计算损失 学生 教师loss_l1 = (self.lap(merged[2], gt)).mean()loss_tea = (self.lap(merged_teacher, gt)).mean()#如果是训练阶段,启动if training:#启动梯度更新 zero_grad()清空梯度self.optimG.zero_grad()#损失计算loss_G = loss_l1 + loss_tea + loss_distill * 0.01 # when training RIFEm, the weight of loss_distill should be 0.005 or 0.002#反向传播loss_G.backward()#optimG 更新参数self.optimG.step()else:flow_teacher = flow[2]#这些参数可以在IFNet中再去详细看(还有损失函数)return merged[2], {'merged_tea': merged_teacher,'mask': mask,'mask_tea': mask,'flow': flow[2][:, :2],'flow_tea': flow_teacher,'loss_l1': loss_l1,'loss_tea': loss_tea,'loss_distill': loss_distill,}

文章转载自:

http://pM2I5lOL.ykrkb.cn
http://hW01sWx6.ykrkb.cn
http://cKraxuJ3.ykrkb.cn
http://0kDMasbS.ykrkb.cn
http://1gEfLSkx.ykrkb.cn
http://XOonQMNu.ykrkb.cn
http://efTEXbUi.ykrkb.cn
http://jmOnEUoo.ykrkb.cn
http://5ivZTRJ5.ykrkb.cn
http://AOS7UYsK.ykrkb.cn
http://DVhyZ8cS.ykrkb.cn
http://7PwDaPTg.ykrkb.cn
http://0V71NiHC.ykrkb.cn
http://MOFZuN1x.ykrkb.cn
http://DWHVpkUZ.ykrkb.cn
http://HUcPB0mn.ykrkb.cn
http://LI8TnIim.ykrkb.cn
http://J3qFUhUM.ykrkb.cn
http://HSOmUhc9.ykrkb.cn
http://GGCqmgon.ykrkb.cn
http://0b2nMaYm.ykrkb.cn
http://wyYoFIMl.ykrkb.cn
http://8heIuE0q.ykrkb.cn
http://IdWcoMrs.ykrkb.cn
http://9kfoeeWs.ykrkb.cn
http://JAae7PFw.ykrkb.cn
http://vI6lpwUn.ykrkb.cn
http://2gclZKKt.ykrkb.cn
http://mF3tUxyX.ykrkb.cn
http://65kiwecl.ykrkb.cn
http://www.dtcms.com/a/385546.html

相关文章:

  • Gateway-路由-规则配置
  • 低端影视官网入口 - 免费看影视资源网站|网页版|电脑版地址
  • 【Python3教程】Python3高级篇之日期与时间
  • 计算机网络——传输层(25王道最新版)
  • 5-14 forEach-数组简易循环(实例:数组的汇总)
  • 【智能体】rStar2-Agent
  • ego(5)---Astar绕障
  • UE5C++编译遇到MSB3073
  • 记一次JS逆向学习
  • 【PyTorch】单目标检测
  • RabbitMQ—基础篇
  • 介绍一下 Test-Time Training 技术
  • 【LangChain指南】Document loaders
  • 日语学习-日语知识点小记-进阶-JLPT-N1阶段蓝宝书,共120语法(10):91-100语法+考え方13
  • 2021/07 JLPT听力原文 问题四
  • MySQL 视图的更新与删除:从操作规范到风险防控
  • 【SQLMap】获取 Shell
  • Java之异常处理
  • C# 通过 TCP/IP 控制 Keysight 34465A 万用表(保姆级教程)
  • TVS二极管详解:原理、选型与应用实战
  • C++实现文件中单词统计等
  • 数据库(四)MySQL读写分离原理和实现
  • 关于数据库的导入和导出
  • 【氮化镓】GaN中受主的氢相关钝化余激活
  • AI 进课堂 - 语文教学流程重塑
  • 最近一些机器github解析到本地回环地址127.0.0.1
  • P6352 [COCI 2007/2008 #3] CETIRI
  • 【LeetCode 每日一题】37. 解数独
  • 多项式回归:线性回归的扩展
  • AI生成到无缝PBR材质:Firefly+第三方AI+Substance工作流