Notes on the Bilateral Grid and HDRNet

This post covers a representative recent Google paper in computational photography: Deep Bilateral Learning for Real-Time Image Enhancement (HDRNet).

HDRNet's main idea is to split the data into two streams. The low-resolution stream leans toward high-level semantics and learns a set of parameters (a 3×4 color matrix per grid cell) rather than the output image itself; learning parameters takes far less computation than predicting full-resolution pixels, which makes mobile deployment much easier. The high-resolution stream handles the low-level side: it learns a guide map and applies the parameters per pixel.

First, the network structure. Overall there are two streams. The top stream processes the low-resolution image: from a downsampled input it learns low-level features, then splits again into a local-feature path and a global-feature path, and the two are fused before the parameters are predicted. The high-resolution stream first learns a guide map, uses the guide map to upsample (slice) the parameter grid, and then applies the feature transform pixel by pixel. A runnable sketch of the slice-and-apply step follows.
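To make the slice-and-apply step concrete, here is a minimal self-contained sketch of the full-resolution half, using F.grid_sample for the trilinear lookup (the coefficient layout mirrors the ApplyCoeffs code shown later; all sizes and names here are illustrative, not the repo's API):

import torch
import torch.nn.functional as F

B, H, W, lb, sb = 1, 64, 64, 8, 16
grid = torch.randn(B, 12, lb, sb, sb)   # bilateral grid: a 3x4 affine matrix per cell
guide = torch.rand(B, 1, H, W)          # guide map, values in [0, 1]
img = torch.rand(B, 3, H, W)            # full-resolution input

# build normalized (x, y, z) lookup coordinates in [-1, 1] for grid_sample;
# the guide value selects the depth (luma) slice of the grid
ys, xs = torch.meshgrid(torch.linspace(-1, 1, H), torch.linspace(-1, 1, W), indexing='ij')
xy = torch.stack([xs, ys], dim=-1).expand(B, H, W, 2)
z = guide.permute(0, 2, 3, 1) * 2 - 1
coords = torch.cat([xy, z], dim=-1).unsqueeze(1)                     # B x 1 x H x W x 3
coeffs = F.grid_sample(grid, coords, align_corners=True).squeeze(2)  # B x 12 x H x W

# per-pixel 3x4 affine transform: out_c = sum_i A[c, i] * img_i + b_c
A = coeffs[:, :9].view(B, 3, 3, H, W)
b = coeffs[:, 9:].view(B, 3, H, W)
out = (A * img.unsqueeze(1)).sum(dim=2) + b                          # B x 3 x H x W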

1. Bilateral Grid

One of the key sets of parameters the network learns is the bilateral grid, so let's first look at what a bilateral grid is. Reference: https://zhuanlan.zhihu.com/p/365874538

The bilateral grid traces back to the bilateral filter. In a bilateral filter, the weight at each filter position depends on both pixel position and pixel value (the difference between the neighbor's value and the center pixel's value):

w(p,q) = g_{\sigma_s}(\lVert p-q \rVert) \cdot g_{\sigma_r}(\lVert I(p)-I(q) \rVert)
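To ground the definition, here is a minimal brute-force reference implementation (a NumPy sketch for a grayscale image with values in [0, 1]; the function name and defaults are my own):

import numpy as np

def bilateral_filter(I, sigma_s=4.0, sigma_r=0.1):
    """Brute-force bilateral filter; O(H*W*r^2), for reference only."""
    H, W = I.shape
    r = int(3 * sigma_s)  # truncate the spatial Gaussian at 3 sigma
    out = np.zeros_like(I)
    for y in range(H):
        for x in range(W):
            y0, y1 = max(0, y - r), min(H, y + r + 1)
            x0, x1 = max(0, x - r), min(W, x + r + 1)
            patch = I[y0:y1, x0:x1]
            yy, xx = np.mgrid[y0:y1, x0:x1]
            # spatial weight * range weight, exactly the w(p, q) above
            w = np.exp(-((yy - y)**2 + (xx - x)**2) / (2 * sigma_s**2)) \
              * np.exp(-(patch - I[y, x])**2 / (2 * sigma_r**2))
            out[y, x] = (w * patch).sum() / w.sum()
    return out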

The bilateral filter is expensive to compute directly. Durand and colleagues found it can be simplified as follows: https://dspace.mit.edu/bitstream/handle/1721.1/34876/MIT-CSAIL-TR-2006-073.pdf?sequence=1

  • Treat the pixel value as an extra dimension, so that the weight computation becomes a convolution over this third dimension, improving data locality.
  • Discretize, using an extra homogeneous channel to accumulate the filter weights.
  • Use a slicing operation to read the filtered values back into the original image space.

Concretely, we first discretize the 2D image into a 3D grid. This 3D grid is low-resolution, and the filtering can be carried out inside it.

Let's walk through the filtering with the paper's 1D illustrations: a 1D bilateral filter becomes 2D after the grid lifting, which can be drawn directly and makes the idea easy to see.

In the 1D illustration, the horizontal axis is image space and the vertical axis is the lifted range (value) dimension. The procedure: first downsample in space, then run a Gaussian kernel over the grid to obtain the filtered numerator and denominator (the filter is centered on the grid cell containing the current x), divide the numerator by the denominator, and finally map back up to image space to get the denoised result. A compact code sketch of the whole pipeline follows.
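The whole splat → blur → slice pipeline fits in a few lines for a grayscale image. This is an illustrative NumPy sketch that uses nearest-neighbor splatting and slicing instead of the trilinear version used in practice; s_s and s_r (the spatial and range sampling rates) are names of my own:

import numpy as np
from scipy.ndimage import gaussian_filter

def bilateral_grid_filter(I, s_s=16, s_r=1/8, sigma=1.0):
    """Bilateral filtering via a bilateral grid: splat -> blur -> slice."""
    H, W = I.shape
    gh, gw, gd = H // s_s + 2, W // s_s + 2, int(1 / s_r) + 2
    # homogeneous grid: channel 0 accumulates values, channel 1 accumulates weights
    grid = np.zeros((gh, gw, gd, 2))
    ys, xs = np.mgrid[0:H, 0:W]
    gy, gx = np.round(ys / s_s).astype(int), np.round(xs / s_s).astype(int)
    gz = np.round(I / s_r).astype(int)
    # splat: scatter-add every pixel into its grid cell
    np.add.at(grid, (gy, gx, gz, 0), I)
    np.add.at(grid, (gy, gx, gz, 1), 1.0)
    # blur: a single 3D Gaussian over space and range replaces the bilateral weights
    grid = gaussian_filter(grid, sigma=(sigma, sigma, sigma, 0))
    # slice: read back at each pixel's (y, x, intensity) and divide out the weights
    num, den = grid[gy, gx, gz, 0], grid[gy, gx, gz, 1]
    return num / np.maximum(den, 1e-8)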

The bilateral grid itself was proposed in this paper: https://groups.csail.mit.edu/graphics/bilagrid/bilagrid_web.pdf

The bilateral grid also has many other uses: modulating edges to blur or sharpen them, guided upsampling, and so on; see the literature for details.

For understanding this paper, the key takeaway is that the bilateral grid lifts the 2D image into a 3D grid, performs the (convolution) filtering there, and slices back out the output image.

2. Network (PyTorch)

Let's follow the code to understand the network structure and parameter sizes.

The original paper is relatively old and Google mostly used TensorFlow at the time, so the officially linked repo is the TensorFlow version: https://github.com/mgharbi/hdrnet. Several PyTorch ports appeared later; we use this one to walk through the pipeline: https://github.com/creotiv/hdrnet-pytorch/tree/master

First, look at how the model's components are initialized:

class HDRPointwiseNN(nn.Module):

    def __init__(self, params):
        super(HDRPointwiseNN, self).__init__()
        self.coeffs = Coeffs(params=params)     # low-res stream: predicts the bilateral grid
        self.guide = GuideNN(params=params)     # full-res stream: predicts the guide map
        self.slice = Slice()                    # slices per-pixel coefficients from the grid
        self.apply_coeffs = ApplyCoeffs()       # applies the per-pixel affine transform
        # self.bsa = bsa.BilateralSliceApply()

    def forward(self, lowres, fullres):
        coeffs = self.coeffs(lowres)
        guide = self.guide(fullres)
        slice_coeffs = self.slice(coeffs, guide)
        out = self.apply_coeffs(slice_coeffs, fullres)
        # out = bsa.bsa(coeffs,guide,fullres)
        return out

Next is the full module source. Pay particular attention to the size of the predicted parameters (grid resolution and channel count).
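Before reading it, it helps to pre-compute the key sizes. Assuming the paper's default hyper-parameters (net_input_size=256, spatial_bin=16, luma_bins=8, channel_multiplier=1, and a 3×4 affine model, so nin=4, nout=3):

import numpy as np

nsize, sb, lb, cm, nin, nout = 256, 16, 8, 1, 4, 3
n_splat = int(np.log2(nsize / sb))                      # 4 stride-2 convs: 256 -> 16
splat_chs = [cm * (2**i) * lb for i in range(n_splat)]  # channels [8, 16, 32, 64]
n_global = int(np.log2(sb / 4))                         # 2 more stride-2 convs: 16 -> 4
out_ch = lb * nin * nout                                # conv_out emits 96 channels
print(splat_chs, n_global, out_ch)                      # [8, 16, 32, 64] 2 96
# reshaped to B x 12 x 8 x 16 x 16: a 16x16x8 bilateral grid with a 3x4 matrix per cell

With these numbers in mind, here is the full source: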

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict
from slice import bilateral_slice


class L2LOSS(nn.Module):

    def forward(self, x,y):
        return torch.mean((x-y)**2)

class ConvBlock(nn.Module):
    def __init__(self, inc , outc, kernel_size=3, padding=1, stride=1, use_bias=True, activation=nn.ReLU, batch_norm=False):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(int(inc), int(outc), kernel_size, padding=padding, stride=stride, bias=use_bias)
        self.activation = activation() if activation else None
        self.bn = nn.BatchNorm2d(outc) if batch_norm else None

        if use_bias and not batch_norm:
            self.conv.bias.data.fill_(0.00)
        # aka TF variance_scaling_initializer
        torch.nn.init.kaiming_uniform_(self.conv.weight)#, mode='fan_out',nonlinearity='relu')
        
    def forward(self, x):
        x = self.conv(x)
        if self.bn:
            x = self.bn(x)
        if self.activation:
            x = self.activation(x)
        return x

class FC(nn.Module):
    def __init__(self, inc , outc, activation=nn.ReLU, batch_norm=False):
        super(FC, self).__init__()
        self.fc = nn.Linear(int(inc), int(outc), bias=(not batch_norm))
        self.activation = activation() if activation else None
        self.bn = nn.BatchNorm1d(outc) if batch_norm else None  
        
        if not batch_norm:
            self.fc.bias.data.fill_(0.00)
        # aka TF variance_scaling_initializer
        torch.nn.init.kaiming_uniform_(self.fc.weight)#, mode='fan_out',nonlinearity='relu')

        
    def forward(self, x):
        x = self.fc(x)
        if self.bn:
            x = self.bn(x)
        if self.activation:
            x = self.activation(x)
        return x

class Slice(nn.Module):
    def __init__(self):
        super(Slice, self).__init__()

    def forward(self, bilateral_grid, guidemap):
        # incoming grid is B x C x luma x h x w; permute to what bilateral_slice
        # expects: a grid of shape (gh, gw, gd, gc) plus a guide image of shape
        # (h, w) whose values lie in [0, 1]
        bilateral_grid = bilateral_grid.permute(0,3,4,2,1)
        guidemap = guidemap.squeeze(1)
        coeffs = bilateral_slice(bilateral_grid, guidemap).permute(0,3,1,2)
        return coeffs
        # Nx12x8x16x16
        # print(guidemap.shape)
        # print(bilateral_grid.shape)
        # device = bilateral_grid.get_device()
        # N, _, H, W = guidemap.shape
        # hg, wg = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)]) # [0,511] HxW
        # if device >= 0:
        #     hg = hg.to(device)
        #     wg = wg.to(device)
        # hg = hg.float().repeat(N, 1, 1).unsqueeze(3) / (H-1)# * 2 - 1 # norm to [-1,1] NxHxWx1
        # wg = wg.float().repeat(N, 1, 1).unsqueeze(3) / (W-1)# * 2 - 1 # norm to [-1,1] NxHxWx1
        # guidemap = guidemap.permute(0,2,3,1).contiguous()
        # guidemap_guide = torch.cat([hg, wg, guidemap], dim=3).unsqueeze(1) # Nx1xHxWx3
        # # When mode='bilinear' and the input is 5-D, the interpolation mode used internally will actually be trilinear. 
        # coeff = F.grid_sample(bilateral_grid, guidemap_guide, 'bilinear')#, align_corners=True)
        # print(coeff.shape)
        # return coeff.squeeze(2)

class ApplyCoeffs(nn.Module):
    def __init__(self):
        super(ApplyCoeffs, self).__init__()

    def forward(self, coeff, full_res_input):

        '''
            Affine:
            r = a11*r + a12*g + a13*b + a14
            g = a21*r + a22*g + a23*b + a24
            ...
        '''

        # out_channels = []
        # for chan in range(n_out):
        #     ret = scale[:, :, :, chan, 0]*input_image[:, :, :, 0]
        #     for chan_i in range(1, n_in):
        #         ret += scale[:, :, :, chan, chan_i]*input_image[:, :, :, chan_i]
        #     if has_affine_term:
        #         ret += offset[:, :, :, chan]
        #     ret = tf.expand_dims(ret, 3)
        #     out_channels.append(ret)

        # ret = tf.concat(out_channels, 3)
        """
            R = r1[0]*r2 + r1[1]*g2 + r1[2]*b3 +r1[3]
        """

        # print(coeff.shape)
        # R = torch.sum(full_res_input * coeff[:, 0:3, :, :], dim=1, keepdim=True) + coeff[:, 3:4, :, :]
        # G = torch.sum(full_res_input * coeff[:, 4:7, :, :], dim=1, keepdim=True) + coeff[:, 7:8, :, :]
        # B = torch.sum(full_res_input * coeff[:, 8:11, :, :], dim=1, keepdim=True) + coeff[:, 11:12, :, :]
        # coefficient layout: channels 0-8 hold the rows of the 3x3 color
        # matrix (R, G, B) and channels 9-11 hold the per-channel offsets
        R = torch.sum(full_res_input * coeff[:, 0:3, :, :], dim=1, keepdim=True) + coeff[:, 9:10, :, :]
        G = torch.sum(full_res_input * coeff[:, 3:6, :, :], dim=1, keepdim=True) + coeff[:, 10:11, :, :]
        B = torch.sum(full_res_input * coeff[:, 6:9, :, :], dim=1, keepdim=True) + coeff[:, 11:12, :, :]

        return torch.cat([R, G, B], dim=1)


class GuideNN(nn.Module):
    def __init__(self, params=None):
        super(GuideNN, self).__init__()
        self.params = params
        self.conv1 = ConvBlock(3, params['guide_complexity'], kernel_size=1, padding=0, batch_norm=True)
        self.conv2 = ConvBlock(params['guide_complexity'], 1, kernel_size=1, padding=0, activation= nn.Sigmoid) #nn.Tanh nn.Sigmoid

    def forward(self, x):
        # pointwise (1x1 conv) network: outputs a single-channel guide map in [0, 1]
        return self.conv2(self.conv1(x))

class Coeffs(nn.Module):

    def __init__(self, nin=4, nout=3, params=None):
        super(Coeffs, self).__init__()
        self.params = params
        self.nin = nin 
        self.nout = nout
        
        lb = params['luma_bins']
        cm = params['channel_multiplier']
        sb = params['spatial_bin']
        bn = params['batch_norm']
        nsize = params['net_input_size']

        self.relu = nn.ReLU()

        # splat features
        n_layers_splat = int(np.log2(nsize/sb))
        self.splat_features = nn.ModuleList()
        prev_ch = 3
        for i in range(n_layers_splat):
            use_bn = bn if i > 0 else False
            self.splat_features.append(ConvBlock(prev_ch, cm*(2**i)*lb, 3, stride=2, batch_norm=use_bn))
            prev_ch = splat_ch = cm*(2**i)*lb

        # global features
        n_layers_global = int(np.log2(sb/4))
        self.global_features_conv = nn.ModuleList()
        self.global_features_fc = nn.ModuleList()
        for i in range(n_layers_global):
            self.global_features_conv.append(ConvBlock(prev_ch, cm*8*lb, 3, stride=2, batch_norm=bn))
            prev_ch = cm*8*lb

        n_total = n_layers_splat + n_layers_global
        prev_ch = prev_ch * (nsize/2**n_total)**2
        self.global_features_fc.append(FC(prev_ch, 32*cm*lb, batch_norm=bn))
        self.global_features_fc.append(FC(32*cm*lb, 16*cm*lb, batch_norm=bn))
        self.global_features_fc.append(FC(16*cm*lb, 8*cm*lb, activation=None, batch_norm=bn))

        # local features
        self.local_features = nn.ModuleList()
        self.local_features.append(ConvBlock(splat_ch, 8*cm*lb, 3, batch_norm=bn))
        self.local_features.append(ConvBlock(8*cm*lb, 8*cm*lb, 3, activation=None, use_bias=False))
        
        # prediction: per grid cell, lb luma bins x (nin x nout) affine coefficients
        self.conv_out = ConvBlock(8*cm*lb, lb*nout*nin, 1, padding=0, activation=None)#,batch_norm=True)

   
    def forward(self, lowres_input):
        params = self.params
        bs = lowres_input.shape[0]
        lb = params['luma_bins']
        cm = params['channel_multiplier']
        sb = params['spatial_bin']

        x = lowres_input
        for layer in self.splat_features:
            x = layer(x)
        splat_features = x
        
        for layer in self.global_features_conv:
            x = layer(x)
        x = x.view(bs, -1)
        for layer in self.global_features_fc:
            x = layer(x)
        global_features = x

        x = splat_features
        for layer in self.local_features:
            x = layer(x)        
        local_features = x

        fusion_grid = local_features
        fusion_global = global_features.view(bs,8*cm*lb,1,1)
        fusion = self.relu( fusion_grid + fusion_global )

        x = self.conv_out(fusion)
        # split the lb*nin*nout output channels into lb groups of nin*nout and
        # stack them along a new luma dimension: B x (nin*nout) x lb x sb x sb
        y = torch.stack(torch.split(x, self.nin*self.nout, 1), 2)
        # y = torch.stack(torch.split(y, self.nin, 1),3)
        # x = x.view(bs,self.nin*self.nout,lb,sb,sb) # B x Coefs x Luma x Spatial x Spatial
        return y


class HDRPointwiseNN(nn.Module):

    def __init__(self, params):
        super(HDRPointwiseNN, self).__init__()
        self.coeffs = Coeffs(params=params)
        self.guide = GuideNN(params=params)
        self.slice = Slice()
        self.apply_coeffs = ApplyCoeffs()
        # self.bsa = bsa.BilateralSliceApply()

    def forward(self, lowres, fullres):
        coeffs = self.coeffs(lowres)
        guide = self.guide(fullres)
        slice_coeffs = self.slice(coeffs, guide)
        out = self.apply_coeffs(slice_coeffs, fullres)
        # out = bsa.bsa(coeffs,guide,fullres)
        return out
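Putting it together, a minimal smoke test might look like the following. The params keys are the ones Coeffs and GuideNN actually read; the values are the paper's defaults, except guide_complexity=16, which is my assumption, and running it requires the repo's custom slice module for bilateral_slice:

import torch

params = {
    'luma_bins': 8,           # grid depth along the guide (luma) axis
    'channel_multiplier': 1,
    'spatial_bin': 16,        # grid is 16x16 spatially
    'batch_norm': True,
    'net_input_size': 256,    # resolution of the low-res stream
    'guide_complexity': 16,   # hidden channels in GuideNN (assumed value)
}
net = HDRPointwiseNN(params).eval()
low = torch.rand(1, 3, 256, 256)       # downsampled input
full = torch.rand(1, 3, 1024, 1024)    # full-resolution input
with torch.no_grad():
    out = net(low, full)
print(out.shape)   # expected: torch.Size([1, 3, 1024, 1024])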

3. Feature Visualization
