【LUT技术专题】SVDLUT代码讲解
本文是对SVDLUT技术的代码讲解,原文解读请看SVDLUT文章讲解。
1、原文概要
SVDLUT最大的创新点在于:将3DLUT的查找拆解为多个2DLUT查找结果的加权求和,并用SVD(奇异值分解)对每个2DLUT做低秩参数化,从而同时减小了参数量和计算量。流程如图所示:

与SABLUT相比,两者在方法上是相似的:都包含双边grid的插值和颜色LUT的插值操作。不过由于查找维度从3D降到了2D,作者额外引入了 $Grid\ weights$ 和 $LUT\ weights$,分别用于对多个2D grid和多个2DLUT的插值结果进行加权求和,从而实现与3DLUT类似的功能。
2、代码结构
代码整体结构如下:

主要关注模型部分的实现即可。
3、核心代码模块
models.py 文件
整体如下所示:
class SVDLUT(nn.Module):
    """Adaptive colour-enhancement network built from SVD-factorised 2D LUTs.

    A backbone turns the input image into a feature vector. Four heads map
    that vector to: the 2D LUTs, the weights/biases that fuse the LUT samples,
    the 2D bilateral grids, and the weights/biases that fuse the grid samples.
    ``bilinear_2Dslicing_lut_transform`` then applies all of them to ``img``
    and a ReLU clips negative outputs.

    Note: ``grid_n_singular`` is accepted but not used by this class.
    """

    def __init__(self, backbone_type='cnn', backbone_coef=8,
                 lut_n_vertices=17, lut_n_ranks=24,
                 grid_n_vertices=17, grid_n_ranks=24, ch_per_grid=2,
                 lut_weight_ranks=8, grid_weight_ranks=8,
                 lut_n_singular=8, grid_n_singular=8):
        super().__init__()
        normalized_type = backbone_type.lower()
        self.backbone_type = normalized_type

        if normalized_type != 'resnet':
            self.backbone = Backbone(backbone_coef=backbone_coef)
            print('CNN backbone apply')
            n_feats = 32 * backbone_coef
        else:
            self.backbone = resnet18_224()
            print('Resnet backbone apply')
            n_feats = 512

        # LUT branch: SVD-factorised 2D LUTs and their fusion weights/biases.
        self.gen_2d_lut = Gen_2D_SVD_LUT(
            n_vertices=lut_n_vertices, n_feats=n_feats,
            n_ranks=lut_n_ranks, n_singlar=lut_n_singular)
        self.gen_2d_lut_weight_bias = Gen_2D_LUT_weight_bias(
            n_vertices=lut_n_vertices, n_feats=n_feats, n_ranks=lut_weight_ranks)

        # Grid branch: 2D bilateral grids and their fusion weights/biases.
        self.gen_2d_bilateral = Gen_2D_bilateral_grids(
            n_vertices=grid_n_vertices, n_feats=n_feats,
            n_ranks=grid_n_ranks, ch_per_grid=ch_per_grid)
        self.gen_2d_grid_weight_bias = Gen_2D_bilateral_grids_weight_bias(
            n_vertices=grid_n_vertices, n_feats=n_feats,
            n_ranks=grid_weight_ranks, ch_per_grid=ch_per_grid)

        self.slicing_transform = bilinear_2Dslicing_lut_transform
        self.relu = nn.ReLU()

    def init_weights(self):
        """Initialise the CNN backbone (when used) and every generator head."""

        def _init_layer(module):
            layer_name = type(module).__name__
            if 'Conv' in layer_name:
                nn.init.xavier_normal_(module.weight.data)
            elif 'InstanceNorm' in layer_name:
                nn.init.normal_(module.weight.data, 1.0, 0.02)
                nn.init.constant_(module.bias.data, 0.0)

        # The ResNet backbone is left untouched here.
        if self.backbone_type != 'resnet':
            self.backbone.apply(_init_layer)
        for head in (self.gen_2d_lut, self.gen_2d_lut_weight_bias,
                     self.gen_2d_bilateral, self.gen_2d_grid_weight_bias):
            head.init_weights()

    def forward(self, img):
        """Enhance ``img``.

        Returns:
            ``(output, lut_weights, grid_weights, g3d_lut, gbilateral)``: the
            ReLU-clipped result of the slicing transform, followed by the rank
            weights of the LUT and grid heads, the generated LUTs and the
            generated bilateral grids (presumably consumed by regularisation
            losses during training -- not visible here).
        """
        feats = self.backbone(img)

        g3d_lut, lut_weights = self.gen_2d_lut(feats)
        lut_param_weights, lut_param_bias = self.gen_2d_lut_weight_bias(feats)
        gbilateral, grid_weights = self.gen_2d_bilateral(feats)
        grid_param_weights, grid_param_bias = self.gen_2d_grid_weight_bias(feats)

        enhanced = self.slicing_transform(
            gbilateral, img, grid_param_weights, grid_param_bias,
            g3d_lut, lut_param_weights, lut_param_bias)
        return self.relu(enhanced), lut_weights, grid_weights, g3d_lut, gbilateral
整体跟我们在文章讲解中一样,先提取特征,然后生成2DLUT和grid以及它们对应的weights和bias用于加权,经过一个统一的slice插值模块得到最终结果,整个过程包含以下几类。
1. Backbone类
这里用于提取图像特征,没有什么特别的。
class Backbone(nn.Module):
    """Small CNN encoder that produces one flat feature vector per image.

    The input is resized to 256x256, passed through a strided conv and four
    ``discriminator_block`` stages, then dropout and 5x5/stride-2 average
    pooling; the result is flattened to ``32 * backbone_coef`` values per
    sample.
    """

    def __init__(self, backbone_coef=8):
        super().__init__()
        self.backbone_coef = backbone_coef
        c = backbone_coef
        # Shape notes (channels x H x W) are for backbone_coef=8.
        layers = [
            nn.Upsample(size=(256, 256), mode='bilinear'),
            nn.Conv2d(3, c, 3, stride=2, padding=1),                      # 8 x 128 x 128
            nn.LeakyReLU(0.2),
            nn.InstanceNorm2d(c, affine=True),
        ]
        layers.extend(discriminator_block(c, 2 * c, normalization=True))      # 16 x 64 x 64
        layers.extend(discriminator_block(2 * c, 4 * c, normalization=True))  # 32 x 32 x 32
        layers.extend(discriminator_block(4 * c, 8 * c, normalization=True))  # 64 x 16 x 16
        layers.extend(discriminator_block(8 * c, 8 * c))                      # 64 x 8 x 8
        layers.append(nn.Dropout(p=0.5))
        layers.append(nn.AvgPool2d(5, stride=2))                              # 64 x 2 x 2
        self.model = nn.Sequential(*layers)

    def forward(self, img_input):
        features = self.model(img_input)
        # (8 * coef) channels x 2 x 2 -> 32 * coef values per sample
        return features.view([-1, 32 * self.backbone_coef])
2. Gen_2D_SVD_LUT和Gen_2D_LUT_weight_bias类
用于生成以SVD形式参数化的多个2DLUT(它们组合起来承担3DLUT的作用),以及融合这些2DLUT查找结果所需的权重和偏置。
class Gen_2D_SVD_LUT(nn.Module):
    r"""Generate image-adaptive 2D LUTs stored in truncated-SVD form.

    ``weights_generator`` (h0) maps a feature vector to ``n_ranks`` rank
    weights; ``basis_luts_bank`` (h1) mixes ``n_ranks`` learned bases into the
    SVD factors of ``n_colors * ch_per_lut`` planes. Each plane is an
    ``n_vertices x n_vertices`` table rebuilt in ``forward`` as
    ``U @ diag(S) @ V^T`` with ``U``: (n_vertices, n_singlar),
    ``S``: (n_singlar,), ``V^T``: (n_singlar, n_vertices).

    Args:
        n_colors: number of output colour channels.
        ch_per_lut: number of 2D LUT planes per output channel.
        n_lut_dim: not used by this class; kept for interface compatibility.
        n_vertices: number of lattice points along each LUT axis.
        n_feats: length of the input feature vector.
        n_ranks: number of learned bases mixed per image.
        n_singlar: number of singular triplets kept per plane.
    """

    def __init__(self, n_colors=3, ch_per_lut=3, n_lut_dim=2, n_vertices=17,
                 n_feats=256, n_ranks=24, n_singlar=8):
        super(Gen_2D_SVD_LUT, self).__init__()
        # h0: features -> rank weights
        self.weights_generator = nn.Linear(n_feats, n_ranks)
        # values per plane: U (n_vertices * n_singlar) + S (n_singlar) + V^T (n_singlar * n_vertices)
        self.n_svd = n_vertices * n_singlar + n_singlar + n_singlar * n_vertices
        # h1: rank weights -> flattened SVD factors of every plane
        self.basis_luts_bank = nn.Linear(n_ranks, n_colors * ch_per_lut * self.n_svd)
        self.n_colors = n_colors
        self.n_vertices = n_vertices
        self.n_feats = n_feats
        self.n_ranks = n_ranks
        self.ch_per_lut = ch_per_lut
        self.n_singlar = n_singlar

    def init_weights(self):
        r"""Init weights for models.

        For the mapping f (`backbone`) and h (`lut_generator`), we follow the
        initialization in [3D-LUT](https://github.com/HuiZeng/Image-Adaptive-3DLUT).

        The rank-weight bias is set to one, the first basis to the truncated
        SVD of nine hand-built ramp/zero planes and the remaining bases to
        zero. The hand-built table assumes ``n_colors * ch_per_lut == 9``.
        """
        nn.init.ones_(self.weights_generator.bias)
        nn.init.zeros_(self.basis_luts_bank.bias)
        n_planes = self.n_colors * self.ch_per_lut
        # indexing='ij' is the legacy default made explicit; omitting it
        # triggers torch's meshgrid deprecation warning.
        cols, rows = torch.stack(
            torch.meshgrid(*[torch.arange(self.n_vertices) for _ in range(2)], indexing='ij'),
            dim=0).div(self.n_vertices - 1).flip(0)
        # cols[i, j] = j / (n_vertices - 1); rows[i, j] = i / (n_vertices - 1)
        zero2d = torch.zeros(self.n_vertices, self.n_vertices)
        d = torch.stack([cols, cols, zero2d,
                         rows, zero2d, cols,
                         zero2d, rows, rows], dim=0)
        # torch.svd is deprecated; linalg.svd(full_matrices=False) is its
        # documented replacement and returns Vh (= V^T for real input).
        u, s, vh = torch.linalg.svd(d, full_matrices=False)
        u = u[:, :, :self.n_singlar].contiguous().view([n_planes, -1])
        s = s[:, :self.n_singlar]
        vh = vh[:, :self.n_singlar, :].contiguous().view([n_planes, -1])
        d = torch.cat([u, s, vh], dim=1)
        identity_lut = torch.stack(
            [d, *[torch.zeros(n_planes, self.n_svd) for _ in range(self.n_ranks - 1)]],
            dim=0).view(self.n_ranks, -1)
        self.basis_luts_bank.weight.data.copy_(identity_lut.t())

    def forward(self, img_feature):
        """Return ``(luts, weights)``.

        ``luts`` is reshaped to (-1, n_colors, ch_per_lut, n_vertices,
        n_vertices); ``weights`` is the output of ``weights_generator``.
        """
        weights = self.weights_generator(img_feature)
        lut_svd = self.basis_luts_bank(weights).view([-1, self.n_svd])
        u_end = self.n_vertices * self.n_singlar
        s_end = u_end + self.n_singlar
        lut_u = lut_svd[:, :u_end].view([-1, self.n_vertices, self.n_singlar])
        lut_s = torch.diag_embed(lut_svd[:, u_end:s_end])
        lut_v = lut_svd[:, s_end:].view([-1, self.n_singlar, self.n_vertices])
        # U @ diag(S) @ V^T for every plane
        luts = torch.bmm(torch.bmm(lut_u, lut_s), lut_v)
        luts = luts.view([-1, self.n_colors, self.ch_per_lut, self.n_vertices, self.n_vertices])
        return luts, weights
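网络先预测每个2DLUT的SVD分量U、S、V,再通过 $U\,\mathrm{diag}(S)\,V^T$ 还原出多个2DLUT,每个2DLUT是一张 n_vertices×n_vertices 的二维表。3DLUT被去掉的那一个颜色维度,由每个输出通道使用 ch_per_lut 个2DLUT(分别以rg、rb、gb作为索引)来弥补;否则仅靠单个2DLUT无法表达三个颜色分量之间的相互作用。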
class Gen_2D_LUT_weight_bias(nn.Module):
    """Predict the per-channel weights and bias that fuse the 2D LUT samples.

    For each of ``n_colors`` output channels the head predicts a row of
    ``ch_per_lut`` weights (one per 2D LUT plane) followed by one bias.
    ``n_vertices`` is stored but not otherwise used.
    """

    def __init__(self, n_colors=3, ch_per_lut=3, n_vertices=17, n_feats=256, n_ranks=24):
        super(Gen_2D_LUT_weight_bias, self).__init__()
        # h0: features -> rank weights
        self.weights_generator = nn.Linear(n_feats, n_ranks)
        # h1: rank weights -> n_colors rows of (ch_per_lut weights + 1 bias)
        self.basis_luts_bank = nn.Linear(n_ranks, n_colors * (ch_per_lut + 1))
        self.n_colors = n_colors
        self.n_vertices = n_vertices
        self.n_feats = n_feats
        self.n_ranks = n_ranks
        self.ch_per_lut = ch_per_lut

    def init_weights(self):
        """Set the rank-weight bias to one and the first basis to a fixed table.

        Each channel starts with two planes weighted 0.5 and zero bias; the
        other bases are zero. The hard-coded table assumes
        ``n_colors == ch_per_lut == 3``.
        """
        nn.init.ones_(self.weights_generator.bias)
        nn.init.zeros_(self.basis_luts_bank.bias)
        d = torch.tensor([[0.5, 0.5, 0, 0],
                          [0.5, 0, 0.5, 0],
                          [0, 0.5, 0.5, 0]])
        identity_lut = torch.stack(
            [d, *[torch.zeros(self.n_colors, self.ch_per_lut + 1) for _ in range(self.n_ranks - 1)]],
            dim=0).view(self.n_ranks, -1)
        self.basis_luts_bank.weight.data.copy_(identity_lut.t())

    def forward(self, img_feature):
        """Return ``(lut_param_weights, lut_param_bias)``.

        Shapes: (-1, n_colors, ch_per_lut) and (-1, n_colors, 1).
        """
        weights = self.weights_generator(img_feature)
        weights_bias = self.basis_luts_bank(weights)
        weights_bias = weights_bias.view([-1, self.n_colors, self.ch_per_lut + 1])
        # Each row is [w_0, ..., w_{ch_per_lut-1}, bias]: split at ch_per_lut.
        # (Splitting at n_colors is only correct when n_colors == ch_per_lut.)
        lut_param_weights = weights_bias[:, :, :self.ch_per_lut]
        lut_param_bias = weights_bias[:, :, self.ch_per_lut:]
        return lut_param_weights, lut_param_bias
生成加权系数,跟LUT数目有关,参数量非常小。
3. Gen_2D_bilateral_grids和Gen_2D_bilateral_grids_weight_bias类
用于生成双边网格。
class Gen_2D_bilateral_grids(nn.Module):
    r"""Generate image-adaptive 2D bilateral grids.

    ``weights_generator`` (h0) maps a feature vector to ``n_ranks`` rank
    weights; ``basis_grids_bank`` (h1) mixes ``n_ranks`` learned bases into
    ``n_grids = 3 * ch_per_grid`` grids of 3 planes each, every plane an
    ``n_vertices x n_vertices`` table. Unlike the LUT branch, the grids are
    output directly by the linear layer (no SVD factorisation).
    """

    def __init__(self, n_grid_dim=2, n_vertices=17, n_feats=256, n_ranks=24, ch_per_grid=2):
        super(Gen_2D_bilateral_grids, self).__init__()
        # h0: features -> rank weights
        self.weights_generator = nn.Linear(n_feats, n_ranks)
        # h1: rank weights -> all grid planes, flattened
        self.basis_grids_bank = nn.Linear(n_ranks, ch_per_grid * 3 * 3 * (n_vertices ** n_grid_dim))
        self.n_grid_dim = n_grid_dim
        self.n_vertices = n_vertices
        self.n_feats = n_feats
        self.n_ranks = n_ranks
        self.ch_per_grid = ch_per_grid
        self.n_grids = ch_per_grid * 3

    def init_weights(self):
        r"""Init weights for models.

        For the mapping f (`backbone`) and h (`lut_generator`), we follow the
        initialization in [3D-LUT](https://github.com/HuiZeng/Image-Adaptive-3DLUT).

        First basis: in every grid, plane 0 is zero and planes 1 and 2 hold
        the ramp ``rows[i, j] = i / (n_vertices - 1)``; the other bases are
        zero. Assumes ``n_grid_dim == 2``.
        """
        nn.init.ones_(self.weights_generator.bias)
        nn.init.zeros_(self.basis_grids_bank.bias)
        # indexing='ij' is the legacy default made explicit (avoids the
        # torch.meshgrid deprecation warning).
        _, rows = torch.stack(
            torch.meshgrid(*[torch.arange(self.n_vertices) for _ in range(2)], indexing='ij'),
            dim=0).div(self.n_vertices - 1).flip(0)
        zero2d = torch.zeros(self.n_vertices, self.n_vertices)
        d = torch.stack([zero2d, rows, rows] * self.n_grids, dim=0)
        identity_grid = torch.stack(
            [d, *[torch.zeros(self.n_grids * 3, self.n_vertices, self.n_vertices)
                  for _ in range(self.n_ranks - 1)]],
            dim=0).view(self.n_ranks, -1)
        self.basis_grids_bank.weight.data.copy_(identity_grid.t())

    def forward(self, img_feature):
        """Return ``(grids, weights)`` with grids shaped (-1, n_grids, 3, n_vertices, n_vertices)."""
        weights = self.weights_generator(img_feature)
        grids = self.basis_grids_bank(weights)
        grids = grids.view([-1, self.n_grids, 3, self.n_vertices, self.n_vertices])
        return grids, weights


class Gen_2D_bilateral_grids_weight_bias(nn.Module):
    """Predict the weights and biases that fuse the bilateral-grid samples.

    ``forward`` reshapes the head output to (-1, ch_per_grid, 4 * n_colors)
    and reads each row as ``3 * n_colors`` weights (three 2D grid planes per
    colour) followed by ``n_colors`` biases. ``n_vertices`` is stored but not
    otherwise used.
    """

    def __init__(self, n_colors=3, ch_per_grid=2, n_vertices=17, n_feats=256, n_ranks=24):
        super(Gen_2D_bilateral_grids_weight_bias, self).__init__()
        # h0: features -> rank weights
        self.weights_generator = nn.Linear(n_feats, n_ranks)
        # h1: rank weights -> ch_per_grid rows of (3 * n_colors weights + n_colors biases)
        self.basis_luts_bank = nn.Linear(n_ranks, ch_per_grid * (3 * n_colors + n_colors))
        self.n_colors = n_colors
        self.n_vertices = n_vertices
        self.n_feats = n_feats
        self.n_ranks = n_ranks
        self.ch_per_grid = ch_per_grid

    def init_weights(self):
        r"""Init weights for models.

        For the mapping f (`backbone`) and h (`lut_generator`), we follow the
        initialization in [3D-LUT](https://github.com/HuiZeng/Image-Adaptive-3DLUT).

        First basis: every grid gets plane weights [0, 1, 1] / (2 * ch_per_grid)
        and zero bias; the other bases are zero. The table is built in the
        same [weights..., biases...] row layout that ``forward`` splits on.
        (A table of [w, w, w, b] rows would put part of the intended zero
        biases into weight slots and vice versa.)
        """
        nn.init.ones_(self.weights_generator.bias)
        nn.init.zeros_(self.basis_luts_bank.bias)
        per_colour = torch.tensor([0.0, 1.0, 1.0])
        row = torch.cat([per_colour.repeat(self.n_colors), torch.zeros(self.n_colors)])
        d = row.repeat(self.ch_per_grid, 1).div(self.ch_per_grid * 2)
        identity_lut = torch.stack(
            [d, *[torch.zeros_like(d) for _ in range(self.n_ranks - 1)]],
            dim=0).view(self.n_ranks, -1)
        self.basis_luts_bank.weight.data.copy_(identity_lut.t())

    def forward(self, img_feature):
        """Return ``(grid_param_weights, grid_param_bias)``.

        Shapes: (-1, ch_per_grid, 3 * n_colors) and (-1, ch_per_grid, n_colors).
        """
        weights = self.weights_generator(img_feature)
        weights_bias = self.basis_luts_bank(weights)
        weights_bias = weights_bias.view([-1, self.ch_per_grid, 3 * self.n_colors + self.n_colors])
        grid_param_weights = weights_bias[:, :, :3 * self.n_colors]
        grid_param_bias = weights_bias[:, :, 3 * self.n_colors:]
        return grid_param_weights, grid_param_bias
grid的生成没有使用SVD分解:按照作者的说法,对grid做SVD会导致性能下降,因此这里直接由Linear层输出完整的grid。
4. bilinear_2Dslicing_lut_transform函数
在kernel_code/bilateral_slicing/src/trilinear2D_slice_LUTTransform_cpu.cpp中可以看到源码。
void TriLinearCPU2DSliceAndLUTTransformForward(const int nthreads,const scalar_t *grid,const scalar_t *image,const scalar_t *grid_weights,const scalar_t *grid_bias,const scalar_t *lut,const scalar_t *lut_weights,const scalar_t *lut_bias,scalar_t *output,const int grid_dim,const int grid_shift,const scalar_t grid_binsize,const int lut_dim,const int lut_shift,const scalar_t lut_binsize,const int width,const int height,const int num_channels,const int grid_per_ch)
{for (int index = 0; index < nthreads; index++){const int x_ = index % width;const int y_ = index / width;const scalar_t x = x_ / (width - 1);const scalar_t y = y_ / (height - 1);const scalar_t r = image[index];const scalar_t g = image[index + width * height];const scalar_t b = image[index + width * height * 2];const int32_t x_id = clamp((int32_t)floor(x * (grid_dim - 1)), 0, grid_dim - 2);const int32_t y_id = clamp((int32_t)floor(y * (grid_dim - 1)), 0, grid_dim - 2);int32_t r_id = clamp((int32_t)floor(r * (grid_dim - 1)), 0, grid_dim - 2);int32_t g_id = clamp((int32_t)floor(g * (grid_dim - 1)), 0, grid_dim - 2);int32_t b_id = clamp((int32_t)floor(b * (grid_dim - 1)), 0, grid_dim - 2);const scalar_t x_d = (x - grid_binsize * x_id) / grid_binsize;const scalar_t y_d = (y - grid_binsize * y_id) / grid_binsize;scalar_t r_d = (r - grid_binsize * r_id) / grid_binsize;scalar_t g_d = (g - grid_binsize * g_id) / grid_binsize;scalar_t b_d = (b - grid_binsize * b_id) / grid_binsize;const int id00_xy = (x_id) + (y_id)*grid_dim;const int id10_xy = (x_id + 1) + (y_id)*grid_dim;const int id01_xy = (x_id) + (y_id + 1) * grid_dim;const int id11_xy = (x_id + 1) + (y_id + 1) * grid_dim;const int id00_xr = (x_id) + (r_id)*grid_dim;const int id10_xr = (x_id + 1) + (r_id)*grid_dim;const int id01_xr = (x_id) + (r_id + 1) * grid_dim;const int id11_xr = (x_id + 1) + (r_id + 1) * grid_dim;const int id00_yr = (y_id) + (r_id)*grid_dim;const int id10_yr = (y_id + 1) + (r_id)*grid_dim;const int id01_yr = (y_id) + (r_id + 1) * grid_dim;const int id11_yr = (y_id + 1) + (r_id + 1) * grid_dim;const int id00_xg = (x_id) + (g_id)*grid_dim;const int id10_xg = (x_id + 1) + (g_id)*grid_dim;const int id01_xg = (x_id) + (g_id + 1) * grid_dim;const int id11_xg = (x_id + 1) + (g_id + 1) * grid_dim;const int id00_yg = (y_id) + (g_id)*grid_dim;const int id10_yg = (y_id + 1) + (g_id)*grid_dim;const int id01_yg = (y_id) + (g_id + 1) * grid_dim;const int id11_yg = (y_id + 1) + (g_id + 1) * grid_dim;const int 
id00_xb = (x_id) + (b_id)*grid_dim;const int id10_xb = (x_id + 1) + (b_id)*grid_dim;const int id01_xb = (x_id) + (b_id + 1) * grid_dim;const int id11_xb = (x_id + 1) + (b_id + 1) * grid_dim;const int id00_yb = (y_id) + (b_id)*grid_dim;const int id10_yb = (y_id + 1) + (b_id)*grid_dim;const int id01_yb = (y_id) + (b_id + 1) * grid_dim;const int id11_yb = (y_id + 1) + (b_id + 1) * grid_dim;const scalar_t w00_xy = (1 - x_d) * (1 - y_d);const scalar_t w10_xy = (x_d) * (1 - y_d);const scalar_t w01_xy = (1 - x_d) * (y_d);const scalar_t w11_xy = (x_d) * (y_d);const scalar_t w00_xr = (1 - x_d) * (1 - r_d);const scalar_t w10_xr = (x_d) * (1 - r_d);const scalar_t w01_xr = (1 - x_d) * (r_d);const scalar_t w11_xr = (x_d) * (r_d);const scalar_t w00_yr = (1 - y_d) * (1 - r_d);const scalar_t w10_yr = (y_d) * (1 - r_d);const scalar_t w01_yr = (1 - y_d) * (r_d);const scalar_t w11_yr = (y_d) * (r_d);const scalar_t w00_xg = (1 - x_d) * (1 - g_d);const scalar_t w10_xg = (x_d) * (1 - g_d);const scalar_t w01_xg = (1 - x_d) * (g_d);const scalar_t w11_xg = (x_d) * (g_d);const scalar_t w00_yg = (1 - y_d) * (1 - g_d);const scalar_t w10_yg = (y_d) * (1 - g_d);const scalar_t w01_yg = (1 - y_d) * (g_d);const scalar_t w11_yg = (y_d) * (g_d);const scalar_t w00_xb = (1 - x_d) * (1 - b_d);const scalar_t w10_xb = (x_d) * (1 - b_d);const scalar_t w01_xb = (1 - x_d) * (b_d);const scalar_t w11_xb = (x_d) * (b_d);const scalar_t w00_yb = (1 - y_d) * (1 - b_d);const scalar_t w10_yb = (y_d) * (1 - b_d);const scalar_t w01_yb = (1 - y_d) * (b_d);const scalar_t w11_yb = (y_d) * (b_d);scalar_t int_img[3] = {0,};for (int i = 0; i < grid_per_ch; ++i){int_img[0] = int_img[0] + grid_weights[3 * (i + grid_per_ch * 0)] * (w00_xy * grid[id00_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)] + w10_xy * grid[id10_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)] + w01_xy * grid[id01_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)] + w11_xy * grid[id11_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)]) 
+grid_weights[3 * (i + grid_per_ch * 0) + 1] * (w00_xr * grid[id00_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)] +w10_xr * grid[id10_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)] +w01_xr * grid[id01_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)] +w11_xr * grid[id11_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)]) +grid_weights[3 * (i + grid_per_ch * 0) + 2] * (w00_yr * grid[id00_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)] +w10_yr * grid[id10_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)] +w01_yr * grid[id01_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)] +w11_yr * grid[id11_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)]) +grid_bias[(i + grid_per_ch * 0)];int_img[1] = int_img[1] + grid_weights[3 * (i + grid_per_ch * 1)] * (w00_xy * grid[id00_xy + grid_shift * (3 * (i + grid_per_ch * 1) + 0)] + w10_xy * grid[id10_xy + grid_shift * (3 * (i + grid_per_ch * 1) + 0)] + w01_xy * grid[id01_xy + grid_shift * (3 * (i + grid_per_ch * 1) + 0)] + w11_xy * grid[id11_xy + grid_shift * (3 * (i + grid_per_ch * 1) + 0)]) +grid_weights[3 * (i + grid_per_ch * 1) + 1] * (w00_xg * grid[id00_xg + grid_shift * (3 * (i + grid_per_ch * 1) + 1)] +w10_xg * grid[id10_xg + grid_shift * (3 * (i + grid_per_ch * 1) + 1)] +w01_xg * grid[id01_xg + grid_shift * (3 * (i + grid_per_ch * 1) + 1)] +w11_xg * grid[id11_xg + grid_shift * (3 * (i + grid_per_ch * 1) + 1)]) +grid_weights[3 * (i + grid_per_ch * 1) + 2] * (w00_yg * grid[id00_yg + grid_shift * (3 * (i + grid_per_ch * 1) + 2)] +w10_yg * grid[id10_yg + grid_shift * (3 * (i + grid_per_ch * 1) + 2)] +w01_yg * grid[id01_yg + grid_shift * (3 * (i + grid_per_ch * 1) + 2)] +w11_yg * grid[id11_yg + grid_shift * (3 * (i + grid_per_ch * 1) + 2)]) +grid_bias[(i + grid_per_ch * 1)];int_img[2] = int_img[2] + grid_weights[3 * (i + grid_per_ch * 2)] * (w00_xy * grid[id00_xy + grid_shift * (3 * (i + grid_per_ch * 2) + 0)] + w10_xy * grid[id10_xy + grid_shift * (3 * (i + grid_per_ch * 2) + 0)] + w01_xy * grid[id01_xy + grid_shift 
* (3 * (i + grid_per_ch * 2) + 0)] + w11_xy * grid[id11_xy + grid_shift * (3 * (i + grid_per_ch * 2) + 0)]) +grid_weights[3 * (i + grid_per_ch * 2) + 1] * (w00_xb * grid[id00_xb + grid_shift * (3 * (i + grid_per_ch * 2) + 1)] +w10_xb * grid[id10_xb + grid_shift * (3 * (i + grid_per_ch * 2) + 1)] +w01_xb * grid[id01_xb + grid_shift * (3 * (i + grid_per_ch * 2) + 1)] +w11_xb * grid[id11_xb + grid_shift * (3 * (i + grid_per_ch * 2) + 1)]) +grid_weights[3 * (i + grid_per_ch * 2) + 2] * (w00_yb * grid[id00_yb + grid_shift * (3 * (i + grid_per_ch * 2) + 2)] +w10_yb * grid[id10_yb + grid_shift * (3 * (i + grid_per_ch * 2) + 2)] +w01_yb * grid[id01_yb + grid_shift * (3 * (i + grid_per_ch * 2) + 2)] +w11_yb * grid[id11_yb + grid_shift * (3 * (i + grid_per_ch * 2) + 2)]) +grid_bias[(i + grid_per_ch * 2)];}r_id = clamp((int32_t)floor(r * (lut_dim - 1)), 0, lut_dim - 2);g_id = clamp((int32_t)floor(g * (lut_dim - 1)), 0, lut_dim - 2);b_id = clamp((int32_t)floor(b * (lut_dim - 1)), 0, lut_dim - 2);r_d = (r - lut_binsize * r_id) / lut_binsize;g_d = (g - lut_binsize * g_id) / lut_binsize;b_d = (b - lut_binsize * b_id) / lut_binsize;const int id00_rg = r_id + g_id * lut_dim;const int id10_rg = r_id + 1 + g_id * lut_dim;const int id01_rg = r_id + (g_id + 1) * lut_dim;const int id11_rg = r_id + 1 + (g_id + 1) * lut_dim;const int id00_rb = r_id + b_id * lut_dim;const int id10_rb = r_id + 1 + b_id * lut_dim;const int id01_rb = r_id + (b_id + 1) * lut_dim;const int id11_rb = r_id + 1 + (b_id + 1) * lut_dim;const int id00_gb = g_id + b_id * lut_dim;const int id10_gb = g_id + 1 + b_id * lut_dim;const int id01_gb = g_id + (b_id + 1) * lut_dim;const int id11_gb = g_id + 1 + (b_id + 1) * lut_dim;const scalar_t w00_rg = (1 - r_d) * (1 - g_d);const scalar_t w10_rg = (r_d) * (1 - g_d);const scalar_t w01_rg = (1 - r_d) * (g_d);const scalar_t w11_rg = (r_d) * (g_d);const scalar_t w00_rb = (1 - r_d) * (1 - b_d);const scalar_t w10_rb = (r_d) * (1 - b_d);const scalar_t w01_rb = (1 - r_d) * 
(b_d);const scalar_t w11_rb = (r_d) * (b_d);const scalar_t w00_gb = (1 - g_d) * (1 - b_d);const scalar_t w10_gb = (g_d) * (1 - b_d);const scalar_t w01_gb = (1 - g_d) * (b_d);const scalar_t w11_gb = (g_d) * (b_d);for (int i = 0; i < num_channels; ++i){scalar_t output_rg = w00_rg * lut[id00_rg + lut_shift * 3 * i] + w10_rg * lut[id10_rg + lut_shift * 3 * i] +w01_rg * lut[id01_rg + lut_shift * 3 * i] + w11_rg * lut[id11_rg + lut_shift * 3 * i];scalar_t output_rb = w00_rb * lut[id00_rb + lut_shift * (3 * i + 1)] + w10_rb * lut[id10_rb + lut_shift * (3 * i + 1)] +w01_rb * lut[id01_rb + lut_shift * (3 * i + 1)] + w11_rb * lut[id11_rb + lut_shift * (3 * i + 1)];scalar_t output_gb = w00_gb * lut[id00_gb + lut_shift * (3 * i + 2)] + w10_gb * lut[id10_gb + lut_shift * (3 * i + 2)] +w01_gb * lut[id01_gb + lut_shift * (3 * i + 2)] + w11_gb * lut[id11_gb + lut_shift * (3 * i + 2)];output[index + width * height * i] = int_img[i] + lut_weights[3 * i] * output_rg + lut_weights[3 * i + 1] * output_rb + lut_weights[3 * i + 2] * output_gb + lut_bias[i];}}
}
由于它把grid的slice和最后的LUT变换合并到了同一个函数中,所以代码比较长。我们先看第一段,即空间信息融合部分的输出:每个输出通道(以r为例)的结果是对grid_per_ch组grid的累加,每组由xy、xr、yr三个2D grid的双线性插值结果加权求和、再加上bias组成,这与讲解中的公式对应,对应代码为:
// Excerpt of the grid-slicing loop of the kernel above: red output channel
// (int_img[0]), i-th grid. Three 2D planes of this grid are bilinearly
// sampled -- (x, y), (x, r) and (y, r) -- each scaled by its own entry of
// grid_weights, and the grid's bias is added; the result accumulates into
// int_img[0] over the grid_per_ch grids.
int_img[0] = int_img[0]
    // (x, y) plane: pixel position only
    + grid_weights[3 * (i + grid_per_ch * 0)] * (w00_xy * grid[id00_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)]
        + w10_xy * grid[id10_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)]
        + w01_xy * grid[id01_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)]
        + w11_xy * grid[id11_xy + grid_shift * (3 * (i + grid_per_ch * 0) + 0)])
    // (x, r) plane: column position and red value
    + grid_weights[3 * (i + grid_per_ch * 0) + 1] * (w00_xr * grid[id00_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)]
        + w10_xr * grid[id10_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)]
        + w01_xr * grid[id01_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)]
        + w11_xr * grid[id11_xr + grid_shift * (3 * (i + grid_per_ch * 0) + 1)])
    // (y, r) plane: row position and red value
    + grid_weights[3 * (i + grid_per_ch * 0) + 2] * (w00_yr * grid[id00_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)]
        + w10_yr * grid[id10_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)]
        + w01_yr * grid[id01_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)]
        + w11_yr * grid[id11_yr + grid_shift * (3 * (i + grid_per_ch * 0) + 2)])
    // per-grid bias
    + grid_bias[(i + grid_per_ch * 0)];
后续完成2DLUT的颜色插值时同理(以r为例):每个通道的输出由上面grid slice的输出结果、rg、rb、gb三个2DLUT插值结果的加权和以及bias组合而成,同样与公式对应,对应代码为:
// Excerpt of the LUT stage of the kernel above, output channel i. The three
// 2D LUT planes of channel i -- indexed by (r, g), (r, b) and (g, b) -- are
// bilinearly sampled at the pixel's colour.
scalar_t output_rg = w00_rg * lut[id00_rg + lut_shift * 3 * i] + w10_rg * lut[id10_rg + lut_shift * 3 * i] +
                     w01_rg * lut[id01_rg + lut_shift * 3 * i] + w11_rg * lut[id11_rg + lut_shift * 3 * i];
scalar_t output_rb = w00_rb * lut[id00_rb + lut_shift * (3 * i + 1)] + w10_rb * lut[id10_rb + lut_shift * (3 * i + 1)] +
                     w01_rb * lut[id01_rb + lut_shift * (3 * i + 1)] + w11_rb * lut[id11_rb + lut_shift * (3 * i + 1)];
scalar_t output_gb = w00_gb * lut[id00_gb + lut_shift * (3 * i + 2)] + w10_gb * lut[id10_gb + lut_shift * (3 * i + 2)] +
                     w01_gb * lut[id01_gb + lut_shift * (3 * i + 2)] + w11_gb * lut[id11_gb + lut_shift * (3 * i + 2)];
// Final value: grid-slice result + weighted LUT samples + per-channel bias.
output[index + width * height * i] = int_img[i] + lut_weights[3 * i] * output_rg + lut_weights[3 * i + 1] * output_rb + lut_weights[3 * i + 2] * output_gb + lut_bias[i];
4、总结
代码实现的核心部分讲解完毕。SVDLUT利用3D转2D和SVD分解的思路进一步优化了SABLUT,较好地兼顾了性能、存储与推理效率,特别适合部署在资源受限的边缘设备上。
感谢阅读,欢迎留言或私信,一起探讨和交流。
如果对你有帮助的话,也希望可以给博主点一个关注,感谢。
