Loss Design for a Multi-Exposure Fusion Network

For a bracketed-exposure fusion network, the loss is designed as follows:

\[ L(y, y^{*})=\left \| y - y^* \right \|_{1} + \left \| \nabla y - \nabla y^* \right \|_{1} + \gamma L_{gradsign}(y, y^*) + \eta L_{highsmooth}(y, y^*) \]

Here \(y\) is the fused image and \(y^*\) is the ground-truth label for the fused image. \(\left \| y - y^* \right \|_{1}\) denotes the L1 norm, i.e. the sum of the absolute values of the elements, and \(\nabla\) denotes the image gradient. The \(L_{gradsign}\) term encourages the edge transitions of the predicted image to follow the real image more closely, suppressing the stitching artifacts caused by the discontinuity between bracketed exposures. The \(L_{highsmooth}\) term puts extra weight on highlight regions so they are reconstructed better, compensating for the small fraction of the image that highlights typically occupy; \(H_{th}\) is the threshold for highlight pixel values. \(\gamma\) and \(\eta\) are constant loss weights that can be tuned to the desired effect.

\[ L_{gradsign}(y, y^*) = \mathrm{ReLU}(-\nabla y \cdot \nabla y^{*}) \]

\[ L_{highsmooth}(y, y^*) = \left \| \nabla y - \nabla y^* \right \|_1 \cdot I(y^*, H_{th}) \]

\[ \mathrm{ReLU}(x) = \begin{cases} x, & x \geqslant 0 \\ 0, & \text{else} \end{cases} \]

\[ I(x, threshold) = \begin{cases} 1, & x \geqslant threshold \\ 0, & \text{else} \end{cases} \]
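
To make the formulas concrete, here is a minimal PyTorch sketch of the full loss. The choice of forward differences for \(\nabla\), the mean reduction of each term, and the default values of `gamma`, `eta`, and `h_th` are assumptions not fixed by the text:

```python
import torch
import torch.nn.functional as F

def image_gradients(img):
    # Forward differences along width and height; the text does not pin
    # down a specific gradient operator, so this is one common choice.
    dx = img[:, :, :, 1:] - img[:, :, :, :-1]
    dy = img[:, :, 1:, :] - img[:, :, :-1, :]
    return dx, dy

def fusion_loss(y, y_star, gamma=0.1, eta=1.0, h_th=0.9):
    # gamma, eta, h_th defaults are illustrative, not from the original text.
    l1 = torch.mean(torch.abs(y - y_star))  # ||y - y*||_1 (mean-reduced)

    dx, dy = image_gradients(y)
    dx_s, dy_s = image_gradients(y_star)

    # ||grad y - grad y*||_1
    grad_l1 = torch.mean(torch.abs(dx - dx_s)) + torch.mean(torch.abs(dy - dy_s))

    # L_gradsign: positive penalty wherever predicted and target gradients
    # point in opposite directions (their elementwise product is negative).
    grad_sign = torch.mean(F.relu(-dx * dx_s)) + torch.mean(F.relu(-dy * dy_s))

    # L_highsmooth: gradient L1 restricted to highlight regions of the
    # target (indicator I(y*, H_th)), cropped to the gradient shapes.
    mask = (y_star >= h_th).float()
    high_smooth = torch.mean(torch.abs(dx - dx_s) * mask[:, :, :, 1:]) + \
                  torch.mean(torch.abs(dy - dy_s) * mask[:, :, 1:, :])

    return l1 + grad_l1 + gamma * grad_sign + eta * high_smooth
```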

Two additional loss functions are also used:

  1. LumaGradLoss downsamples the prediction and the ground truth to 8×8 maps with a global average pooling operation, then computes an L2 loss on the horizontal and vertical gradients of the pooled maps. The goal is to keep the overall luminance distribution of the prediction consistent with the ground truth and to avoid luminance inversion (implementation below).

  2. BlackRegLoss builds a mask of pixels whose luminance falls below a given threshold, then penalizes the fusion weights of the 1/4 and 1/16 underexposed frames inside that mask. The goal is to keep dark regions from being influenced by the underexposed frames, which would otherwise pull down the brightness of non-overexposed areas such as the sky in the fused result (implementation below).

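A cleaned-up version of LumaGradLoss is shown below (this is a method of the loss module; `torch` is assumed to be imported at module level). Note that the pooled reference is taken from the first four channels of `input` rather than from the `target` argument, which the original code overwrites: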
```python
def LumaGradLoss(self, prediction, target, input):
    # Downsample both maps to 8x8 via adaptive average pooling (the
    # "global average pooling" downsample described above).
    prediction = torch.nn.AdaptiveAvgPool2d((8, 8))(prediction)
    # The reference luma is taken from the first four channels of `input`;
    # the `target` argument is overwritten here, as in the original code.
    target = torch.nn.AdaptiveAvgPool2d((8, 8))(input[:, 0:4, :, :])

    target = torch.clamp(target, 0.0, 1.0)
    prediction = torch.clamp(prediction, 0.0, 1.0)

    # Horizontal and vertical forward differences of the 8x8 maps.
    prediction_x_diff = prediction[:, :, :-1, :-1] - prediction[:, :, :-1, 1:]
    prediction_y_diff = prediction[:, :, :-1, :-1] - prediction[:, :, 1:, :-1]
    target_x_diff = target[:, :, :-1, :-1] - target[:, :, :-1, 1:]
    target_y_diff = target[:, :, :-1, :-1] - target[:, :, 1:, :-1]

    # L2 loss on the gradient differences.
    error = (target_x_diff - prediction_x_diff) ** 2.0 \
          + (target_y_diff - prediction_y_diff) ** 2.0

    return torch.mean(error)
```
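And a cleaned-up BlackRegLoss (likewise a method; `torch.nn as nn` is assumed imported, and `self.tf_high_light_dilate_blur` is an external helper, presumably dilating and blurring the mask):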
```python
def BlackRegLoss(self, logit, target, input_refs, weight):
    '''
    Black-region regularization: penalize the fusion weights of the
    1/4 and 1/16 frames so the network prefers the 1/1 frame in dark areas.
    '''
    # Luma proxy: per-pixel max over the four channels of the 1/1 frame.
    input_1 = input_refs[:, -12:-8, :, :]
    raw_grey, _ = torch.max(input_1, dim=1, keepdim=True)

    # Mask of dark pixels (luma <= 0.4); ones_like/zeros_like keep the
    # tensors on the same device as the inputs.
    low_mask = torch.where(raw_grey <= 0.4,
                           torch.ones_like(raw_grey),
                           torch.zeros_like(raw_grey))

    # Total-variation-style edge map of the target; exclude edges so the
    # penalty applies only to flat dark regions.
    tv_mask_h = nn.ReflectionPad2d([0, 0, 1, 0])(target)
    tv_mask_h = tv_mask_h[:, :, 1:, :] - tv_mask_h[:, :, :-1, :]
    tv_mask_w = nn.ReflectionPad2d([1, 0, 0, 0])(target)
    tv_mask_w = tv_mask_w[:, :, :, 1:] - tv_mask_w[:, :, :, :-1]
    tv_mask = torch.max(torch.abs(tv_mask_h), torch.abs(tv_mask_w))
    tv_mask, _ = torch.max(tv_mask, dim=1, keepdim=True)
    tv_mask = 1.0 - tv_mask

    # Dilate and blur the mask to soften its boundary.
    low_mask = self.tf_high_light_dilate_blur(low_mask * tv_mask, 5)

    # Mean weight of the 1/4 and 1/16 frames over the masked region.
    if weight.shape[1] in (3, 4):
        return torch.sum(weight[:, 1:3, :, :] * low_mask) / (torch.sum(low_mask) * 2.0)
    elif weight.shape[1] in (12, 13):
        # 8 channels = 2 frames x 4 channels each.
        return torch.sum(weight[:, 4:4 * 3, :, :] * low_mask) / (torch.sum(low_mask) * (2.0 * 4))
    else:
        raise ValueError('Unsupported weight type')
```
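
In training, these two regularizers would presumably be added to the bracketed-exposure loss above, each with its own scalar weight in the same spirit as \(\gamma\) and \(\eta\); the exact weighting is not specified here.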