超分辨率Loss设计记录(1)

在超分辨率(Super-Resolution)网络训练中使用均方误差(MSE)损失函数确实存在一个常见问题, 即可能导致生成的结果过于平滑和模糊 这是因为MSE损失函数鼓励模型生成像素值, 以使其与目标图像的像素值之间的平方差最小化这种最小化平方差的方法有时会导致图像细节的丢失, 使得生成的高分辨率图像看起来过于平滑.

相关材料可以查看SRGAN中Figure2以及1.1.3小节. 虽然直接优化MSE可以产生较高的PSNR/SSIM, 但是在zoom scale较大的情况下, MSE作为loss function引导的学习无法使得重建图像捕获细节信息, 从论文Figure2中可以看到, 左二图有较高的PSNR/SSIM, 但是从观感上判断, 左三图明显具有更多的细节.

为了解决这个问题, 通常在超分辨率网络中使用其他损失函数或技术, 以更好地保留细节和纹理 以下是一些替代方法:

  1. 感知损失(Perceptual Loss):使用感知损失, 通常是使用预训练的深度卷积神经网络(如VGG)来计算生成图像与目标图像之间的特征表示的差异 这种方法更强调图像的结构和纹理, 而不仅仅是像素值 这有助于生成更具细节和真实感的图像

  2. 对抗性损失(Adversarial Loss):引入生成对抗网络(GAN)的方法, 其中生成器网络和判别器网络相互竞争 生成器的目标是欺骗判别器, 而判别器的目标是区分生成图像和真实图像 这种对抗性训练有助于生成更逼真的图像

  3. 内容和风格损失(Content and Style Loss):结合感知损失和风格损失, 以确保生成的图像在内容和风格上都与目标图像相似

  4. 自适应权重调整:使用动态权重或损失加权方法, 以平衡不同类型的损失函数, 以便在训练过程中更好地控制平滑度和细节

这些方法可以帮助超分辨率网络生成更锐利和更具细节的高分辨率图像, 而不仅仅是通过最小化像素级差异来生成图像 选择哪种方法取决于具体的问题和数据集, 以及对生成图像的期望质量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import torch.nn as nn
import torchvision.models as models

#SRGAN使用预训练好的VGG19,用生成器的结果以及原始图像通过VGG后分别得到的特征图计算MSE,具体解释推荐看SRGAN的相关资料
class VGG(nn.Module):
def __init__(self, device):
super(VGG, self).__init__()
vgg = models.vgg19(True)
for pa in vgg.parameters():
pa.requires_grad = False
self.vgg = vgg.features[:16]
self.vgg = self.vgg.to(device)

def forward(self, x):
out = self.vgg(x)
return out

#内容损失
class ContentLoss(nn.Module):
def __init__(self, device):
super().__init__()
self.mse = nn.MSELoss()
self.vgg19 = VGG(device)

def forward(self, fake, real):
feature_fake = self.vgg19(fake)
feature_real = self.vgg19(real)
loss = self.mse(feature_fake, feature_real)
return loss

#对抗损失
class AdversarialLoss(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
loss = torch.sum(-torch.log(x))
return loss

#感知损失
class PerceptualLoss(nn.Module):
def __init__(self, device):
super().__init__()
self.vgg_loss = ContentLoss(device)
self.adversarial = AdversarialLoss()

def forward(self, fake, real, x):
vgg_loss = self.vgg_loss(fake, real)
adversarial_loss = self.adversarial(x)
return vgg_loss + 1e-3*adversarial_loss

#正则项,需要说明的是,在SRGAN的后续版本的论文中,这个正则项被删除了
class RegularizationLoss(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
a = torch.square(
x[:, :, :x.shape[2]-1, :x.shape[3]-1] - x[:, :, 1:x.shape[2], :x.shape[3]-1]
)
b = torch.square(
x[:, :, :x.shape[2]-1, :x.shape[3]-1] - x[:, :, :x.shape[2]-1, 1:x.shape[3]]
)
loss = torch.sum(torch.pow(a+b, 1.25))
return loss