再谈注意力机制 | 运用强化学习实现目标特征提取

2022-09-23 11:09:55 浏览数 (1)

  • 论文题目:Recurrent Models of Visual Attention
  • 论文链接:http://www.oalib.com/paper/4082117

作者及单位

研究目标

研究如何减少图像相关任务的计算量, 提出通过使用attention based RNN 模型建立序列模型(recurrent attention model, RAM), 每次基于上下文和任务来适应性的选择输入的的image patch, 而不是整张图片, 从而使得计算量独立于图片大小, 从而缓解CNN模型中计算量与输入图片的像素数成正比的缺点. 该文通过强化学习的方式来学习任务明确的策略, 从而解决模型是不可微的问题.

RAM 模型在几个图像分类任务上,在处理杂乱图像(cluttered images)时, 它明显优于基于CNN的模型,并且在动态视觉控制问题上,无需明确的训练信号, 它就能学习跟踪一个简单的对象。

Introduction

该文将注意力问题视为与视觉环境交互时以目标为导向的序列决策过程。

人类感知的一个重要特性是人们不会倾向于一次完整地处理整个场景。相反,人们将注意力有选择地集中在视觉空间的某些部分,以便在需要的时间和地点获取信息,并随着时间的推移组合来自不同固定位置(fixation)的信息,以建立场景的内部表示,指导下一步眼睛看下哪里以及决策。将计算资源聚焦在场景的各部分上节省了“带宽”,因为需要处理的“像素”更少。但它也大大降低了任务复杂性,因为感兴趣的对象可以置于固定位置(fixation)的中心,并且固定区域外的视觉环境(“混乱”)的不相关特征自然被忽略。

模型架构

attention注意力机制网络架构如下图所示:

该模型架构主要由Glimpse NetworkLocation NetworkCoreNetwork三部分网络组成。其中Glimpse Network主要由由Glimpse Sensor组成。

如上图所示,agent围绕一个递归神经网络构建。在每个时间步骤中,它处理传感器数据,随着时间的推移集成信息,并在下一次时间步骤中选择如何操作和如何部署传感器。过程主要是通过强化学习实现的,下面介绍主要部件:

  • Sensor:在每个步骤t中,agent接受到一个输入图像Xt的环境,agent没有完全访问这个图像,而是通过信息带宽有限的传感器ρ提取信息。如通过传感器在某些地区或感兴趣的频段。
  • Internal state:agent保持一种内部状态,该状态汇总从过去的观察历史中提取的信息,它对代理的环境进行编码,并有助于决定如何操作和在何处部署传感器。该内部状态由递归神经网络的隐藏单元ht组成,通过下面要介绍的它对代理的环境知识进行编码,并有助于决定如何操作和在何处部署传感器CoreNetwork进行更新。网络的外部输入是Glimpse sensor输出向量gt
  • Action:每一步骤中,agent主要执行两个动作:他决定如何通过传感控制器
l_{t}

部署传感器以及一个可能影响环境状态的动作

a_{t}

。动作

a_{t}

由一个分布得出,该分布以输出

a_{t}~p(.|f_{a}(h_{t};theta_{a}))

为条件。

  • Reward:跟强化学习里面reward的设置是一模一样的,因此此处不再赘述,具体可以参照强化学习reward的设置。即
sum_{t=1}^{T} r

训练

agent的参数由Glimpse NetworkCoreNetworkaction Network这些网络的参数

theta={theta_{g},theta_{h};theta_{a}}

组成。agent通过与环境互动来使总的Reward最大化来更新参数。

Reward的最大化形式如下公式所示:

要使Reward的方差最小,可以把上述公式转化为:

loss采用混合监督损失。

Experiments

我们评估了我们的方法在几个图像分类任务以及一个简单的游戏。我们首先描述了我们所有实验中常见的设计选择:

  • Retina and location encodings: retina encoding ρ(x, l)从中间位置L提取k^2的像素。第一个patch的大小为gw×gw像素,每个后续patch的宽度都是之前的两倍。最后将k个patch串联起来.
  • Glimpse network:
f_{g}(x,l)

有两个全连接层。g是该网络的输出。

g = Rect(Linear(hg) Linear(hl))
hg= Rect(Linear(ρ(x, l)))
hl= Rect(Linear(l))
  • Location network:位置l的policy由具有固定方差的双分量高斯分布定义。位置网络输出时刻位置策略的均值,定义为
fl(h) = Linear(h)

,其中h为核心网络/RNN的状态。

  • Core network:在分类任务中,fh定义为
ht= fh(ht−1) = Rect(Linear(ht−1) Linear(gt))

在动态环境中,fh是由LSTM单元组成。

mnist手写字母图像识别结果

来自MNIST测试集的输入图像,其中Glimpse路径以绿色(正确分类)或红色(错误分类)覆盖。

第2-7栏:网络选择的6个亮点。每个图像的中心显示全分辨率的一瞥,外部低分辨率区域是通过将低分辨率的一瞥放大到全图像的尺寸来获得的。瞥见路径清楚地表明,学习策略避免了在输入空间的空或噪声部分进行计算,并直接探索了感兴趣对象周围的区域。

结论

介绍了一种新颖的视觉注意力模型。制定作为一个以一睹窗口为递归神经网络输入和使用网络的内部状态来选择下一个位置关注以及生成控制信号在动态环境中。虽然模型是不可微的,但是所提出的统一架构是使用策略梯度方法从像素输入到操作端到端进行训练的。这个模型有几个吸引人的特性。首先,参数的数量和RAM执行的计算量都可以独立于输入图像的大小进行控制。其次,该模型能够忽略图像中存在的杂波,将视网膜集中在相关区域。

我们的实验表明,在一个混乱的对象分类任务中,RAM的性能显著优于具有相同数量参数的卷积架构。此外,我们的方法的灵活性允许许多有趣的扩展。例如,可以使用另一个操作来扩展网络,该操作允许网络在任何时间点终止并做出最终的分类决策。我们的初步实验表明,一旦有了足够的信息来进行可靠的分类,网络就可以学会停止Glimpse。该网络还可以控制视网膜采样图像的尺度,使其能够在固定大小的视网膜中适应不同大小的对象。在这两种情况下,可以使用前面描述的策略梯度过程将额外的操作简单地添加到操作网络fa中并对其进行训练。鉴于RAM取得的令人鼓舞的结果,将该模型应用于大规模对象识别和视频分类是未来工作的一个自然方向。

开源代码

https://github.com/kevinzakka/recurrent-visual-attention

0 人点赞