RealSR性能大幅提升,旷视科技+快手科技+电子科大联合提出“先发散再收敛”的D2CSR

2021-04-29 14:19:49 浏览数 (1)

点击下方卡片,关注“AIWalker”公众号

重磅干货,第一时间送达

标题&作者团队

本文是旷视科技&快手科技&电子科大联合提出的一种新颖的图像超分框架。本文从图像超分“病态”特性出发,提出一种两阶段的超分框架。在divergence阶段,本文构建了一种新颖的树状深度网络用于输出具有发散性质的预测结果,为达到该效果,引入triplet损失;在convergence阶段,我们采用数据自适应加权方式对divergence分支的结果进行融合得到最终的输出。此外,本文还提出了一个新的用于x8超分任务的Real-world数据集D2CRealSR。所提方法在RealSR、DRealSR以及D2cRealSR等数据集上取得了大幅超越其他方案的效果,相比次优模型CDC,所提方法取得了0.2~0.5dB不等的性能提升。

Abstract

本文提出一种新颖的框架D2C-SR用于图像超分。图像超分作为一种病态问题,其关键挑战在于:给定低分辨率输入存在多个合理预测。大部分经典方法与早期深度学习方法忽略了该基本事实,将图像超分建模为确定性处理,这就导致不理想结果。

受近期工作(如SRFLow)启发,我们采用半概率方式解决该问题,提出一种两阶段方案:divergence阶段采用离散形势学习潜在高分辨率输出分布;convergence阶段则用于将所学习预测融合为最终的输出。更具体来首,我们提出一种树状深度网络,每个分支用于学习一种可能的高分辨率预测。训练过程中,每个分支分别训练以拟合Ground-truth,我们采用triple损失迫使不同分支的输出具有发散性。然后,我们添加一个融合模块合并多个输出作为最终的输出,融合模块可以通过端到端方式训练。

我们在多个基准数据集上进行了评估,并提出了一个8x超分数据集。实验结果表明:所提D2C-SR取得了SOTA性能(PSNR与SSIM),且具有更少的恶计算量。

本文主要贡献包含以下几点:

  • 提出一种新颖的版概率框架D2C-SR用于图像超分,它包含divergence与convergence两个阶段;
  • 提出采用半监督方式训练divergence阶段模型,它采用triplet损失迫使预测的发散性;
  • 在多个主流基准数据集上,D2C-SR取得了SOTA性能;
  • 构建了一个8x超分数据集D2CRealSR;
  • 我们对D2C-SR框架中的不用设计选择进行了深度分析。

Method

framework

上图给出了本文所提D2C-SR框架示意图,它包含两个阶段:divergence与convergence。接下来,我们将针对这两个阶段进行相似介绍。

Divergence Network

在所提Divergence网络中,我们通过显示地设计一个具有发散多输出的网络解决图像超分这个病态问题。具体来说,我们设计了一个树状结构网络得到期望预测。

该树状深度网络包含三个主要模块:

  • 浅层特征提取模块:它由一个卷积构成;
  • 基础分支模块:它由多个残差组构成,每个残差组包含B个残差通道注意力模块(类似RCAN);
  • 上采样模块:它由卷积 pixelshuffle构成。

该网络有L个分支,每个分支由基础分支模块构成并包含C个子分支。以前面图示为例,divergence网络从浅层特征提取模块开始,然后按照树状结构逐层通过网络得到输出。需要注意的是:每个分支的权值不进行共享。divergence网络生成了P个预测结果,它们具有不同的高频成分。这些预测可以表示为:

I_D = F(I_{LR};Theta_D)

Deep Residual Structure 我们在divergence网络中构建了相对深的残差,这种深度残差结构使得每个分支可以学习深度残差特征,它的子分支可以学习更深的残差结果。每个分支聚焦于学习比父分支更进一步的残差,进而促进高频学习。

Divergence loss divergence网络中的发散损失由

L_2

损失与triplet损失构成。每个预测结果

I_D^i

与HR图像计算

L_2

损失并相加构成最终的

L_2

损失,定义如下:

L_2^D = sum_{i=1}^P | I_D^i - I_{HR} |_2

为使得divergence网络生成更发散结果,我们采用了triplet损失。我们目标在于使得

I_D^i

与HR尽可能相近,且两两之间距离变远。然而直接在RGB空间使用triplet损失会导致网络聚焦于学习其他不同的方向(比如亮度)而非纹理。因此,我们提出了对

I_D^i

进行如下处理:

G(I_D^i) = frac{Y_D^i - mu_{Y_D^i}}{sigma_{Y_D^i}}

注:Y表示YCbCr空间中的Y通道。上述操作使得网络聚焦于学习纹理特征差异。由于超分病态问题主要源于高频区域,因此我们在残差域计算triplet损失,残差定义如下:

res_{I_D^i} = |G(I_D^i) - G(I_{HR}) |

triplet损失定义如下:

trip(a,p,n) = Max[d(a,p) - d(a,n) margin, 0]

因此,最终的损失定义如下:

T_D = frac{sum_{i=1}^P sum_{i=1,jne i}^P beta_{ij} * trip(res_{I_D^i}, zero, res_{I_D^j})}{P(P-1)}

注:

beta_{ij} = theta^{l-1},lin[1,L]

表示注意力系数,它用于控制不同分支的相似性。最终的总体损失则定义如下:

L_D = L_2^D alpha * T_D

Convergence Network

组合divergence网络的多个输出可以生成更精确的结果。我们认为不同分支的预测对于最终结果具有不同的贡献,因此我哦们构建了convergence网络采用加权方式组合divergence网络的多个输出。

convergence网络采用divergence网络的M个输出作为输入,输出每个预测的权值,定义如下:

W = F(Concat(I_D); Theta_C)

然后采用所得权值与divergence网路输出加权得到最终的结果:

I_{SR} = sum_{i=1}^P(I_D^i cdot W_i)

从上图可以看到:在合并过程中,不同分支在不同区域具有不同的加权权值。

Convergence loss convergence网络的目标是合并divergence网络的输出,因此该网络的损失称之为convergence损失,它仅仅包含

L_2

损失。定义如下:

L_2^C = | I_{SR} - I_{HR} |_2

Training Strategy

该框架的两个网络分别训练,我们首先训练divergence网络到稳定状态;然后固定divergence网络参数训练convergence网络。

Experiments

D2CRealSR 现有RealSR数据仅仅包含x2、x3与x4倍率数据,缺乏更大倍率数据。我们构建了一个x8倍率数据D2CRealSR,它包含115图像对,其中15个用于测试,其他用作训练。

Existing Dataset 现有Real-world超分数据有RealSR与DRealSR两个。由于DRealSR部分数据存在不对齐问题,因此我们仅仅在DRealSR的测试集上进行验证性能。

Implementation Detail 实验过程中,

L=2,C=2,G2,B=4

。优化器为Adam,初始学习率0.0001,每个2000epoch折半,LR的图像块尺寸为

96times 96

sota

上表给出了不同方法在不同数据集上的性能对比,从中可以看到:在不同倍率下,所提方法均大幅优于其他方案,超出次优模型0.2~0.5dB不等。

x4

上图给出了DRealSR数据集上X4超分的视觉效果对比,下图给出了D2CRealSR数据集上X8超分的视觉效果对比。可以看到:相比其他方案,所提方法可以复原更多高频细节

x8

model-size

上图给出了RealSR数据集上不同模型大小性能的对比,可以看到:

  • 所提0.23M模型可以取得更好的性能;而基线5.88M参数模型可以取得更高的指标,具有更好的模型大小与性能均衡。
  • 在同等PSNR水平下,CDC需要39.92M参数量,RCAN需要15M参数量。

width-depth

上表给出了不同深度、不同宽度模型的性能对比,可以看到:深度与宽度的提升均可以带来一致性的性能提升。

loss

上表给出了不同损失下的模型性能对比,可以看到:

  • 不带convergence损失时,模型性能均出现了显著下降,x2超分指标下降0.17dB
  • 单一分支的性能明显低于convergence网络的性能
  • 树结构有助于网路聚焦高频信息学习,移除树状结构后,模型的性能出现了显著的下降,比如x2任务下降了0.12dB。

branches

上图给出了不同分支的视觉效果对比,由于convergence损失的恶存在,不同分支的预测具有不同的高频预测。

0 人点赞