当前位置: 首页 > news >正文

【LLIE技术专题】基于成对低光图像学习自适应先验方案代码讲解

本文是基于成对低光图像学习自适应先验方案的代码讲解,文章讲解可看链接PairLLE。

1、原文概要

本文PairLIE 是一种无监督低光图像增强方法,核心是利用成对低光图像(同场景、不同光照) 学习自适应先验,减少人工先验依赖,与一般方法的区别如下图所示:
在这里插入图片描述
常规的方法是(a),本文提出的方法是(b),可以看到本文的方法需要2张图来优化。

2、代码结构

代码整体结构如下
在这里插入图片描述

核心代码模块包含模型结构、数据加载、训练流程3部分。

3 、核心代码模块

1. 模型结构

模型包含用于完成降噪恒等映射的N_net,以及预测反射图和光照图的R_net和L_net,代码在net/net.py中。


class L_net(nn.Module):def __init__(self, num=64):super(L_net, self).__init__()self.L_net = nn.Sequential(nn.ReflectionPad2d(1),nn.Conv2d(3, num, 3, 1, 0),nn.ReLU(),               nn.ReflectionPad2d(1),nn.Conv2d(num, num, 3, 1, 0),nn.ReLU(), nn.ReflectionPad2d(1),nn.Conv2d(num, num, 3, 1, 0),nn.ReLU(),               nn.ReflectionPad2d(1),nn.Conv2d(num, num, 3, 1, 0),nn.ReLU(),   nn.ReflectionPad2d(1),nn.Conv2d(num, 1, 3, 1, 0),)def forward(self, input):return torch.sigmoid(self.L_net(input))class R_net(nn.Module):def __init__(self, num=64):super(R_net, self).__init__()self.R_net = nn.Sequential(nn.ReflectionPad2d(1),nn.Conv2d(3, num, 3, 1, 0),nn.ReLU(), nn.ReflectionPad2d(1),nn.Conv2d(num, num, 3, 1, 0),nn.ReLU(),               nn.ReflectionPad2d(1),nn.Conv2d(num, num, 3, 1, 0),nn.ReLU(),               nn.ReflectionPad2d(1),nn.Conv2d(num, num, 3, 1, 0),            nn.ReLU(),   nn.ReflectionPad2d(1),nn.Conv2d(num, 3, 3, 1, 0),)def forward(self, input):return torch.sigmoid(self.R_net(input))class N_net(nn.Module):def __init__(self, num=64):super(N_net, self).__init__()self.N_net = nn.Sequential(nn.ReflectionPad2d(1),nn.Conv2d(3, num, 3, 1, 0),nn.ReLU(), nn.ReflectionPad2d(1),nn.Conv2d(num, num, 3, 1, 0),nn.ReLU(),               nn.ReflectionPad2d(1),nn.Conv2d(num, num, 3, 1, 0),nn.ReLU(),               nn.ReflectionPad2d(1),nn.Conv2d(num, num, 3, 1, 0),            nn.ReLU(),   nn.ReflectionPad2d(1),nn.Conv2d(num, 3, 3, 1, 0),)def forward(self, input):return torch.sigmoid(self.N_net(input))class net(nn.Module):def __init__(self):super(net, self).__init__()        self.L_net = L_net(num=64)self.R_net = R_net(num=64)self.N_net = N_net(num=64)        def forward(self, input):x = self.N_net(input)L = self.L_net(x)R = self.R_net(x)return L, R, x

模型结构比较简单是几个卷积+relu的组合,以上为训练模型结构。推理时,会对L图进行gamma增强后与反射图处理,结构如下所示:

L, R, X = model(input)    
I = torch.pow(L,0.2) * R  # default=0.2, LOL=0.14.

此与论文给出的流程图对应。

2. 数据加载

由于该篇论文选用的是多曝光的成对数据,因此它只需要加载某一个文件夹中的不同曝光数据即可完成训练,如dataset.py文件所示,:

class DatasetFromFolder(data.Dataset):def __init__(self, data_dir, transform=None):super(DatasetFromFolder, self).__init__()self.data_dir = data_dirself.transform = transformdef __getitem__(self, index):index = indexdata_filenames = [join(join(self.data_dir, str(index+1)), x) for x in listdir(join(self.data_dir, str(index+1))) if is_image_file(x)]num = len(data_filenames)index1 = random.randint(1,num)index2 = random.randint(1,num)while abs(index1 - index2) == 0:index2 = random.randint(1,num)im1 = load_img(data_filenames[index1-1])im2 = load_img(data_filenames[index2-1])_, file1 = os.path.split(data_filenames[index1-1])_, file2 = os.path.split(data_filenames[index2-1])seed = np.random.randint(123456789) # make a seed with numpy generator if self.transform:random.seed(seed) # apply this seed to img tranfsormstorch.manual_seed(seed) # needed for torchvision 0.7im1 = self.transform(im1)random.seed(seed)torch.manual_seed(seed)         im2 = self.transform(im2)        return im1, im2, file1, file2def __len__(self):return 324 # for custom datasets, please check the dataset size and modify this number

通过在提前准备好的文件夹中选出2个不同的文件,完成多曝光数据的准备。

3. 训练流程

位于main.py文件中,完成了R正则损失和Retinex损失的计算。

def train():model.train()loss_print = 0for iteration, batch in enumerate(training_data_loader, 1):im1, im2, file1, file2 = batch[0], batch[1], batch[2], batch[3]im1 = im1.cuda()im2 = im2.cuda()L1, R1, X1 = model(im1)L2, R2, X2 = model(im2)   loss1 = C_loss(R1, R2)loss2 = R_loss(L1, R1, im1, X1)loss3 = P_loss(im1, X1)loss =  loss1 * 1 + loss2 * 1 + loss3 * 500optimizer.zero_grad()loss.backward()optimizer.step()loss_print = loss_print + loss.item()if iteration % 10 == 0:print("===> Epoch[{}]({}/{}): Loss: {:.4f} || Learning rate: lr={}.".format(epoch,iteration, len(training_data_loader), loss_print, optimizer.param_groups[0]['lr']))

可以看到损失分为3个部分:

  1. loss1代表的是R正则,两个不同曝光的图像它们的R图一样。
  2. loss2代表的是Retinex假设损失,分解后的结果需要满足假设。
  3. loss3是一个降噪后的保真度损失。

其中所有损失的具体计算在util.py中。

def gradient(img):height = img.size(2)width = img.size(3)gradient_h = (img[:,:,2:,:]-img[:,:,:height-2,:]).abs()gradient_w = (img[:, :, :, 2:] - img[:, :, :, :width-2]).abs()return gradient_h, gradient_wdef tv_loss(illumination):gradient_illu_h, gradient_illu_w = gradient(illumination)loss_h = gradient_illu_hloss_w = gradient_illu_wloss = loss_h.mean() + loss_w.mean()return lossdef C_loss(R1, R2):loss = torch.nn.MSELoss()(R1, R2) return lossdef R_loss(L1, R1, im1, X1):max_rgb1, _ = torch.max(im1, 1)max_rgb1 = max_rgb1.unsqueeze(1) loss1 = torch.nn.MSELoss()(L1*R1, X1) + torch.nn.MSELoss()(R1, X1/L1.detach())loss2 = torch.nn.MSELoss()(L1, max_rgb1) + tv_loss(L1)return loss1 + loss2def P_loss(im1, X1):loss = torch.nn.MSELoss()(im1, X1)return loss

与讲解中公式对应。

3、总结

代码实现核心的部分讲解完毕,本文实现了一种无监督低光图像增强方法,核心是利用成对低光图像(同场景、不同光照) 学习自适应先验,减少人工先验依赖,在简单的网络结构下实际了降噪增强的效果。


感谢阅读,欢迎留言或私信,一起探讨和交流。
如果对你有帮助的话,也希望可以给博主点一个关注,感谢。

http://www.dtcms.com/a/594683.html

相关文章:

  • 瑞金网站建设推广自助建站吧
  • 深圳大型网站开发seo与网站建设
  • 行业网站 cms最好的在线影视免费
  • Day1算法训练(数字统计,两个数组的交集,点击消除)
  • 双并网点 + 104 协议传输!Acrel1000 打造厂区储能综合自动化标杆方案
  • 网站备案要收费吗网络营销方案论文
  • 哪个网站可以帮助做数学题计算机应用网站建设与维护是做什么
  • Vue3:详解toRefs
  • 性价比高的建筑设备监控管理系统企业
  • 网站建设好吗网站建设与制作实训报告
  • 如何做好网站推广营销WordPress最强大的主题
  • 免费室内设计素材网站网站建站基本要素
  • P4198 楼房重建 题解
  • asp网站例子一套公司vi设计多少钱一
  • YOLOv5(一):目录结构 学习顺序
  • 密云建站推广电子商务网站建设考试
  • Python | 常用的控制流语句及工作原理
  • 网站建设公司有哪些方面郑州妇科
  • seo综合查询网站源码微网站建设招聘
  • Linux 重定向与Cookie
  • 24G毫米波雷达实现风扇跟随人转动,精准智能,节能省事
  • 杭州网站建设洛洛科技权威的唐山网站建设
  • 广东省省考备考(第一百四十六天11.10)——资料分析、数量关系(强化训练)
  • 自己做的网站无法访问网站页面设计招聘
  • CommonJS 与 ES Module 完全入门指南:从基础概念到项目实战
  • dedecms 调用 另一个网站凡科网站怎样做
  • 建立自己的平台网站吗如何做网站小编
  • 一个数据库两个网站wordpress登陆重庆网站托管
  • 烟台专业做网站公司哪家好百度云虚拟主机搭建wordpress
  • 网站设计的公司皆选奇点网络招商网站建设