大家好,又见面了,我是你们的朋友全栈君。
目录
- STN的作用 1.1 灵感来源 1.2 什么是STN?
- STN的基本架构
- Localisation net是如何实现参数的选取的? 3.1 实现平移 3.2 实现缩放 3.3 实现旋转 3.4 实现剪切 3.5 小结
- Grid generator实现像素点坐标的对应关系 4.1 为什么会有坐标的问题? 4.2 仿射变换关系
- Sampler实现坐标求解的可微性 5.1 小数坐标问题的提出 5.2 解决输出坐标为小数的问题 5.3 Sampler的数学原理
- Spatial Transformer Networks(STN)
- STN 实现代码
- reference
1.STN的作用
1.1 灵感来源
普通的CNN能够显示的学习平移不变性,以及隐式的学习旋转不变性,但attention model 告诉我们,与其让网络隐式的学习到某种能力,不如为网络设计一个显式的处理模块,专门处理以上的各种变换。因此,DeepMind就设计了Spatial Transformer Layer,简称STL来完成这样的功能。
1.2 什么是STN?
关于平移不变性 ,对于CNN来说,如果移动一张图片中的物体,那应该是不太一样的。假设物体在图像的左上角,我们做卷积,采样都不会改变特征的位置,糟糕的事情在我们把特征平滑后后接入了全连接层,而全连接层本身并不具备 平移不变性 的特征。但是 CNN 有一个采样层,假设某个物体移动了很小的范围,经过采样后,它的输出可能和没有移动的时候是一样的,这是 CNN 可以有小范围的平移不变性 的原因。
如图所示,如果是手写数字识别,图中只有一小块是数字,其他大部分地区都是黑色的,或者是小噪音。假如要识别,用Transformer Layer层来对图片数据进行旋转缩放,只取其中的一部分,放到之后然后经过CNN就能识别了。
我们发现,它其实也是一个layer,放在了CNN的前面,用来转换输入的图片数据,其实也可以转换feature map,因为feature map说白了就是浓缩的图片数据,所以Transformer layer也可以放到CNN里面。
2. STN的基本架构
如图是Spatial Transformer Networks的结构,主要的部分一共有三个,它们的功能和名称如下:
参数预测:Localisation net
参数预测:Localisation net
坐标映射:Grid generator
坐标映射:Grid generator
像素的采集:Sampler
像素的采集:Sampler为了让大家对这三个部分有一个先验知识,我先简单介绍一下。
如下图是完成的一个平移的功能,这其实就是Spatial Transformer Networks要做一个工作。
假设左边是 Layer l−1 Layer l-1的输出,也就是当前要做Transform的输入,最右边为Transform后的结果。这个过程是怎么得到的呢?
假设是一个全连接层,n,m代表输出的值在输出矩阵中的下标,输入的值通过权值w,做一个组合,完成这样的变换。
- 举个例子,假如要生成 al11 a_{11}^l,那就是将左边矩阵的九个输入元素,全部乘以一个权值,加权相加: al11=wl1111al−111 wl1112al−112 wl1113al−113 ⋯ wl1133al−133 a_{11}^l=w_{1111}^la_{11}^{l -1} w_{1112}^la_{12}^{l-1} w_{1113}^la_{13}^{l-1} cdots w_{1133}^la_{33}^{l-1}这仅仅是 al11 a_{11}^l的值,其他的结果也是这样算出来的,用公式表示称如下这样:
通过调整这些权值,达到缩放,平移的目的,其实这就是Transformer的思想。
在这个过程中,我们需要面对三个主要的问题:
- 这些参数应该怎么确定?
- 图片的像素点可以当成坐标,在平移过程中怎么实现原图片与平移后图片的坐标映射关系?
- 参数调整过程中,权值一定不可能都是整数,那输出的坐标有可能是小数,但实际坐标都是整数的,如果实现小数与整数之间的连接?
其实定义的三个部分,就是专门为了解决这几个问题的,接下来我们一个一个看一下怎么解决。
3.Localisation net是如何实现参数的选取的?
3.1 实现平移
如果是平移变换,比如从 al−111平移到al21 a_{11}^{l-1}平移到a_{21}^{l},得到 al21 a_{21}^l的表示为:
al21=wl2111al−111 wl2112al−112 wl2113al−113 ⋯ wl2133al−133
a_{21}^l=w_{2111}^la_{11}^{l -1} w_{2112}^la_{12}^{l-1} w_{2113}^la_{13}^{l-1} cdots w_{2133}^la_{33}^{l-1}我们可以令 wl2111=1 w_{2111}^l=1,其余均为0,不就得到了
al21=1∗al−111
a_{21}^{l}=1*a_{11}^{l-1}这就完成了平移了吗?其他的平移也可以用类似的方法来做到。
你可能会问了,那我该怎么得到这些权值呢?总不能人工去看吧! 当然不会,我们可以设置一个叫做NN这类的东西,把 Layer l−1 Layer l-1的输出放到NN里,然后生成一系列w。这样听起来好玄乎,但确实是可以这么做的。
3.2 实现缩放
其实缩放也不难,如图所示,如果要把图放大来看,在 x→(X2)→x′ x rightarrow (X2) rightarrow x', y→(X2)→y′ y rightarrow (X2) rightarrow y'将其同时乘以2,就达到了放大的效果了,用矩阵表示如下:
缩小也是同样的原理,如果把这张图放到坐标轴来看,就是如图所示,加上偏执值0.5表示向右,向上同时移动0.5的距离,这就完成了缩小。
3.3 实现旋转
既然前面的平移和缩放都是通过权值来改的,那旋转其实也是。但是旋转应该用什么样的权值呢? 仔细思考,不难发现,旋转是跟角度有关系的,那什么跟角度有关系呢? 正弦余弦嘛,为什么它们能做旋转呢? 一个圆圈的角度是360度,可以通过控制水平和竖直两个方向,就能控制了,如图所示。
由点A旋转 θ theta度角,到达点B.得到
x′=Rcosα
x'=Rcosalpha
y′=Rsinα
y'=Rsin alpha由A点得
x=Rcos(α θ)
x=R cos(alpha theta)
y=Rsin(α θ)
y=Rsin(alpha theta)展开,有:
x=Rcosα cosθ−Rsinα sinθ
x=Rcosalpha costheta - Rsinalpha sintheta
y=Rsinα cosθ Rcosα sinθ
y=Rsinalpha costheta Rcosalpha sintheta把未知数 α alpha替换掉
x=x′ cosθ−y′sinθ
x=x' costheta - y'sintheta
y=y′cosθ x′sinθ
y=y'costheta x'sintheta 我们可以简单的理解为 cosθ,sinθ costheta,sintheta就是控制这样的方向的,把它当成权值参数,写成矩阵形式,就完成了旋转操作。
注:如果想了解正余弦控制方向是怎么导出的,可以参考计算机图形学的相关书籍,一般都有介绍和数学公式的推导。
3.4 实现剪切
剪切变换相当于将图片沿x和y两个方向拉伸,且x方向拉伸长度与y有关,y方向拉伸长度与x有关,用矩阵形式表示前切变换如下:
3.5 小结
由此,我们发现所有的这些操作,只需要六个参数[2X3]控制就可以了,所以我们可以把feature map U作为输入,过连续若干层计算(如卷积、FC等),回归出参数θ,在我们的例子中就是一个[2,3]大小的6维仿射变换参数,用于下一步计算;
4.Grid generator实现像素点坐标的对应关系
4.1 为什么会有坐标的问题?
由上面的公式,可以发现,无论如何做旋转,缩放,平移,只用到六个参数就可以了,如图所示:
这6个参数,就足以完成我们需要的几个功能了。
而缩放的本质,其实就是在原样本上采样,拿到对应的像素点,通俗点说,就是输出的图片(i,j)的位置上,要对应输入图片的哪个位置?
如图所示旋转缩放操作,我们把像素点看成是坐标中的一个小方格,输入的图片 U∈RHxWxC U in R^{HxWxC}可以是一张图片,或者feature map,其中H表示高,W表示宽,C表示颜色通道。经过变换 Tθ(G) T_{theta}(G), θ theta是上一个部分(Localisation net)生成的参数,生成了图片 V∈RH′xW′xC V in R^{H'xW'xC},它的像素相当于被贴在了图片的固定位置上,用 G=Gi G={G_i}表示,像素点的位置可以表示为 Gi={ xti,yti} G_i={x_i^t,y_i^t}这就是我们在这一阶段要确定的坐标。
4.2 仿射变换关系
因此定义了如图的一个坐标矩阵变换关系:
(xti,yti) (x_i^t,y_i^t)是输出的目标图片的坐标, (xsi,ysi) (x_i^s,y_i^s)是原图片的坐标, Aθ A_{theta}表示仿射关系。
但仔细一点,这有一个非常重要的知识点,千万别混淆,我们的坐标映射关系是:
从目标图片→原图片
从目标图片 rightarrow 原图片也就是说,坐标的映射关系是从目标图片映射到输入图片上的,为什么这样呢?
作者在论文中写的比较模糊,比较满意的解释是坐标映射的作用,其实是让目标图片在原图片上采样,每次从原图片的不同坐标上采集像素到目标图片上,而且要把目标图片贴满,每次目标图片的坐标都要遍历一遍,是固定的,而采集的原图片的坐标是不固定的,因此用这样的映射。
举个自我感觉很贴切的小例子说一下吧。
- 如图所示,假设只有平移变换,这个过程就相当于一个拼图的过程,左图是一些像素点,右图是我们的目标,我们的目标是确定的,目标图的方框是确定的,图像也是确定的,这就是我们的目标,我们要从左边的小方块中拿一个小方块放在右边的空白方框上,因为一开始右边的方框是没有图的,只有坐标,为了确定拿过来的这个小方块应该放在哪里,我们需要遍历一遍右边这个方框的坐标,然后再决定应该放在哪个位置。所以每次从左边拿过来的方块是不固定的,而右边待填充的方框却是固定的,所以定义从 目标图片→原图片 目标图片 rightarrow 原图片的坐标映射关系更加合理,且方便。
5.Sampler实现坐标求解的可微性
5.1 小数坐标问题的提出
我们可以假设一下我们的权值矩阵的参数是如下这几个数,x,y分别是他们的下标,经过变换后,可以得到如下这样的对应。
前面举的例子中,权值都是整数,那得到的也必定是整数,如果不是整数呢? 如图所示:
假如权值是小数,拿得到的值也一定是小数,1.6,2.4,但是没有元素的下标索引是小数呀。那不然取最近吧,那就得到2,2了,也就是与 al22 a_{22}^l对应了。 那这样的方法能用梯度下降来解吗?
5.2 解决输出坐标为小数的问题
用上面的四舍五入显然是不能进行梯度下降来回传梯度的。 为什么呢? 梯度下降是一步一步调整的,而且调整的数值都比较小,哪怕权值参数有小范围的变化,虽然最后的输出也会有小范围的变化,比如一步迭代后,结果有:
1.6→1.64,2.4→2.38
1.6 rightarrow 1.64,2.4 rightarrow 2.38但是即使有这样的改变,结果依然是:
al−122→al22
a_{22}^{l-1} rightarrow a_{22}^l的对应关系没有一点变化,所以output依然没有变,我们没有办法微分了,也就是梯度依然为0呀,梯度为0就没有可学习的空间呀。所以我们需要做一个小小的调整。
仔细思考一下这个问题是什么造成的,我们发现其实在推导SVM的时候,我们也遇到过相同的问题,当时我们如果只是记录那些出界的点的个数,好像也是不能求梯度的,当时我们是用了hing loss,来计算一下出界点到边界的距离,来优化那个距离的,我们这里也类似,我们可以计算一下到输出[1.6,2.4]
附近的主要元素,如下所示,计算一下输出的结果与他们的下标的距离,可得:
然后做如下更改:
他们对应的权值都是与结果对应的距离相关的,如果目标图片发生了小范围的变化,这个式子也是可以捕捉到这样的变化的,这样就能用梯度下降法来优化了。
5.3 Sampler的数学原理
论文作者对我们前面的过程给出了非常严密的证明过程,以下是我对论文的转述。
每次变换,相当于从原图片 (xsi,ysi) (x_i^s,y_i^s)中,经过仿射变换,确定目标图片的像素点坐标 (xti,yti) (x_i^t,y_i^t)的过程,这个过程可以用公式表示为:
(注:把一张图片展开,相当于把矩阵变成坐标向量) kernel k表示一种线性插值方法,比如双线性插值,更详细的请参考:(线性插值,双线性插值Bilinear Interpolation算法), ϕx,ϕy phi_x,phi_y表示插值函数的参数; Ucnm U^c_{nm}表示位于颜色通道C中坐标为(n,m)的值。
如果使用双线性插值,可以有:
为了允许反向传播回传损失,我们可以求对该函数求偏导:
对于 ysi y_i^s的偏导也类似。
如果就能实现这一步的梯度计算,而对于 ∂xsi∂θ,∂ysi∂θ frac{partial x_i^s}{partial theta},frac{partial y_i^s}{partial theta}的求解也很简单,所以整个过程
Localisation net←Grid generator←Sampler
Localisation net leftarrow Grid generatorleftarrow Sampler的梯度回转就能走通了。
6.Spatial Transformer Networks(STN)
将这三个组块结合起来,就构成了完整STN网络结构了。
这个网络可以加入到CNN的任意位置,而且相应的计算量也很少。
将 spatial transformers 模块集成到 cnn 网络中,允许网络自动地学习如何进行 feature map 的转变,从而有助于降低网络训练中整体的代价。定位网络中输出的值,指明了如何对 每个训练数据进行转化。
7.STN 实现代码
相应的代码已经有人实现了,我就不做重复工作了。 请参考:Spatial Transformer Networks Torch code Theano code
8.reference
原论文 Spatial Transformer Spatial Transformer Networks 卷积神经网络结构变化——Spatial Transformer Networks 三十分钟理解:线性插值,双线性插值Bilinear Interpolation算法 Spatial Transformer Networks 笔记 李宏毅老师的视频讲解
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/180194.html原文链接:https://javaforall.cn