【LLIE技术专题】基于成对低光图像学习自适应先验方案代码讲解
本文是基于成对低光图像学习自适应先验方案的代码讲解,文章讲解可看链接PairLLE。
1、原文概要
本文PairLIE 是一种无监督低光图像增强方法,核心是利用成对低光图像(同场景、不同光照) 学习自适应先验,减少人工先验依赖,与一般方法的区别如下图所示:

常规的方法是(a),本文提出的方法是(b),可以看到本文的方法需要2张图来优化。
2、代码结构
代码整体结构如下:

核心代码模块包含模型结构、数据加载、训练流程3部分。
3 、核心代码模块
1. 模型结构
模型包含用于完成降噪恒等映射的N_net,以及预测反射图和光照图的R_net和L_net,代码在net/net.py中。
class L_net(nn.Module):def __init__(self, num=64):super(L_net, self).__init__()self.L_net = nn.Sequential(nn.ReflectionPad2d(1),nn.Conv2d(3, num, 3, 1, 0),nn.ReLU(), nn.ReflectionPad2d(1),nn.Conv2d(num, num, 3, 1, 0),nn.ReLU(), nn.ReflectionPad2d(1),nn.Conv2d(num, num, 3, 1, 0),nn.ReLU(), nn.ReflectionPad2d(1),nn.Conv2d(num, num, 3, 1, 0),nn.ReLU(), nn.ReflectionPad2d(1),nn.Conv2d(num, 1, 3, 1, 0),)def forward(self, input):return torch.sigmoid(self.L_net(input))class R_net(nn.Module):def __init__(self, num=64):super(R_net, self).__init__()self.R_net = nn.Sequential(nn.ReflectionPad2d(1),nn.Conv2d(3, num, 3, 1, 0),nn.ReLU(), nn.ReflectionPad2d(1),nn.Conv2d(num, num, 3, 1, 0),nn.ReLU(), nn.ReflectionPad2d(1),nn.Conv2d(num, num, 3, 1, 0),nn.ReLU(), nn.ReflectionPad2d(1),nn.Conv2d(num, num, 3, 1, 0), nn.ReLU(), nn.ReflectionPad2d(1),nn.Conv2d(num, 3, 3, 1, 0),)def forward(self, input):return torch.sigmoid(self.R_net(input))class N_net(nn.Module):def __init__(self, num=64):super(N_net, self).__init__()self.N_net = nn.Sequential(nn.ReflectionPad2d(1),nn.Conv2d(3, num, 3, 1, 0),nn.ReLU(), nn.ReflectionPad2d(1),nn.Conv2d(num, num, 3, 1, 0),nn.ReLU(), nn.ReflectionPad2d(1),nn.Conv2d(num, num, 3, 1, 0),nn.ReLU(), nn.ReflectionPad2d(1),nn.Conv2d(num, num, 3, 1, 0), nn.ReLU(), nn.ReflectionPad2d(1),nn.Conv2d(num, 3, 3, 1, 0),)def forward(self, input):return torch.sigmoid(self.N_net(input))class net(nn.Module):def __init__(self):super(net, self).__init__() self.L_net = L_net(num=64)self.R_net = R_net(num=64)self.N_net = N_net(num=64) def forward(self, input):x = self.N_net(input)L = self.L_net(x)R = self.R_net(x)return L, R, x
模型结构比较简单是几个卷积+relu的组合,以上为训练模型结构。推理时,会对L图进行gamma增强后与反射图处理,结构如下所示:
L, R, X = model(input)
I = torch.pow(L,0.2) * R # default=0.2, LOL=0.14.
此与论文给出的流程图对应。
2. 数据加载
由于该篇论文选用的是多曝光的成对数据,因此它只需要加载某一个文件夹中的不同曝光数据即可完成训练,如dataset.py文件所示,:
class DatasetFromFolder(data.Dataset):def __init__(self, data_dir, transform=None):super(DatasetFromFolder, self).__init__()self.data_dir = data_dirself.transform = transformdef __getitem__(self, index):index = indexdata_filenames = [join(join(self.data_dir, str(index+1)), x) for x in listdir(join(self.data_dir, str(index+1))) if is_image_file(x)]num = len(data_filenames)index1 = random.randint(1,num)index2 = random.randint(1,num)while abs(index1 - index2) == 0:index2 = random.randint(1,num)im1 = load_img(data_filenames[index1-1])im2 = load_img(data_filenames[index2-1])_, file1 = os.path.split(data_filenames[index1-1])_, file2 = os.path.split(data_filenames[index2-1])seed = np.random.randint(123456789) # make a seed with numpy generator if self.transform:random.seed(seed) # apply this seed to img tranfsormstorch.manual_seed(seed) # needed for torchvision 0.7im1 = self.transform(im1)random.seed(seed)torch.manual_seed(seed) im2 = self.transform(im2) return im1, im2, file1, file2def __len__(self):return 324 # for custom datasets, please check the dataset size and modify this number
通过在提前准备好的文件夹中选出2个不同的文件,完成多曝光数据的准备。
3. 训练流程
位于main.py文件中,完成了R正则损失和Retinex损失的计算。
def train():model.train()loss_print = 0for iteration, batch in enumerate(training_data_loader, 1):im1, im2, file1, file2 = batch[0], batch[1], batch[2], batch[3]im1 = im1.cuda()im2 = im2.cuda()L1, R1, X1 = model(im1)L2, R2, X2 = model(im2) loss1 = C_loss(R1, R2)loss2 = R_loss(L1, R1, im1, X1)loss3 = P_loss(im1, X1)loss = loss1 * 1 + loss2 * 1 + loss3 * 500optimizer.zero_grad()loss.backward()optimizer.step()loss_print = loss_print + loss.item()if iteration % 10 == 0:print("===> Epoch[{}]({}/{}): Loss: {:.4f} || Learning rate: lr={}.".format(epoch,iteration, len(training_data_loader), loss_print, optimizer.param_groups[0]['lr']))
可以看到损失分为3个部分:
- loss1代表的是R正则,两个不同曝光的图像它们的R图一样。
- loss2代表的是Retinex假设损失,分解后的结果需要满足假设。
- loss3是一个降噪后的保真度损失。
其中所有损失的具体计算在util.py中。
def gradient(img):height = img.size(2)width = img.size(3)gradient_h = (img[:,:,2:,:]-img[:,:,:height-2,:]).abs()gradient_w = (img[:, :, :, 2:] - img[:, :, :, :width-2]).abs()return gradient_h, gradient_wdef tv_loss(illumination):gradient_illu_h, gradient_illu_w = gradient(illumination)loss_h = gradient_illu_hloss_w = gradient_illu_wloss = loss_h.mean() + loss_w.mean()return lossdef C_loss(R1, R2):loss = torch.nn.MSELoss()(R1, R2) return lossdef R_loss(L1, R1, im1, X1):max_rgb1, _ = torch.max(im1, 1)max_rgb1 = max_rgb1.unsqueeze(1) loss1 = torch.nn.MSELoss()(L1*R1, X1) + torch.nn.MSELoss()(R1, X1/L1.detach())loss2 = torch.nn.MSELoss()(L1, max_rgb1) + tv_loss(L1)return loss1 + loss2def P_loss(im1, X1):loss = torch.nn.MSELoss()(im1, X1)return loss
与讲解中公式对应。
3、总结
代码实现核心的部分讲解完毕,本文实现了一种无监督低光图像增强方法,核心是利用成对低光图像(同场景、不同光照) 学习自适应先验,减少人工先验依赖,在简单的网络结构下实际了降噪增强的效果。
感谢阅读,欢迎留言或私信,一起探讨和交流。
如果对你有帮助的话,也希望可以给博主点一个关注,感谢。
