BDC-LSTM


Traditional LSTM

The classic LSTM is an improvement on the recurrent neural network (RNN).

The classic LSTM is defined as follows:
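A common formulation, consistent with the symbols described below and with the convolutional variant given later (peephole variants add extra terms), is:

\left\{\begin{array}{l} i_{t}=\sigma\left(W_{x i} x_{t}+W_{h i} h_{t-1}+b_{i}\right) \\ f_{t}=\sigma\left(W_{x f} x_{t}+W_{h f} h_{t-1}+b_{f}\right) \\ c_{t}=f_{t} \odot c_{t-1}+i_{t} \odot \tanh \left(W_{x c} x_{t}+W_{h c} h_{t-1}+b_{c}\right) \\ o_{t}=\sigma\left(W_{x o} x_{t}+W_{h o} h_{t-1}+b_{o}\right) \\ h_{t}=o_{t} \odot \tanh \left(c_{t}\right) \end{array}\right.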

Here i_{t}, f_{t}, o_{t} denote the input gate, forget gate, and output gate respectively, \sigma(\cdot) is the sigmoid function, b_{i}, b_{f}, b_{c}, b_{o} are the biases, and x_{t}, c_{t}, h_{t} are the input, cell (activation) state, and hidden state at time t. W denotes a weight matrix; for example, W_{hf} controls how the forget gate reads from the hidden state.

C-LSTM (ConvLSTM)

C-LSTM [1] is a more recently proposed LSTM variant intended mainly for image inputs: it replaces the matrix-vector multiplications with convolutions. A plain LSTM handles 1-D sequence data, whereas C-LSTM has an advantage for sequences of images, for example image-sequence prediction and encoder-decoder frameworks.

The gates i_{t}, f_{t}, o_{t}, the inputs \mathcal{X}_{1}, \ldots, \mathcal{X}_{t}, the cell states \mathcal{C}_{1}, \ldots, \mathcal{C}_{t}, and the hidden states \mathcal{H}_{1}, \ldots, \mathcal{H}_{t} are all 3D tensors. The definition is as follows: the vector products of the traditional LSTM are replaced by convolutions.

\left\{\begin{array}{l} i_{z}=\sigma\left(x_{z} * W_{x i}+h_{z-1} * W_{h i}+b_{i}\right) \\ f_{z}=\sigma\left(x_{z} * W_{x f}+h_{z-1} * W_{h f}+b_{f}\right) \\ c_{z}=c_{z-1} \odot f_{z}+i_{z} \odot \tanh \left(x_{z} * W_{x c}+h_{z-1} * W_{h c}+b_{c}\right) \\ o_{z}=\sigma\left(x_{z} * W_{x o}+h_{z-1} * W_{h o}+b_{o}\right) \\ h_{z}=o_{z} \odot \tanh \left(c_{z}\right) \end{array}\right.

CLSTMCell represents a single node in a CLSTM series and produces the output for just one time (spatial) step. hidden_channels is the number of hidden channels, analogous to the number of convolution kernels in a CNN layer.

import torch
import torch.nn as nn
from torch.autograd import Variable
     

# Batch x NumChannels x Height x Width
# UNET --> BatchSize x 1 (3?) x 240 x 240
# BDCLSTM --> BatchSize x 64 x 240 x240
     
''' Class CLSTMCell.
    This represents a single node in a CLSTM series.
    It produces just one time (spatial) step output.
'''
     
     
class CLSTMCell(nn.Module):
    # Constructor
    def __init__(self, input_channels, hidden_channels,
                 kernel_size, bias=True):
        super(CLSTMCell, self).__init__()

        assert hidden_channels % 2 == 0

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.bias = bias
        self.kernel_size = kernel_size
        self.num_features = 4

        # A single convolution over [x, h] computes all four gates at once
        self.padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(self.input_channels + self.hidden_channels,
                              self.num_features * self.hidden_channels,
                              self.kernel_size,
                              1,
                              self.padding)

    # Forward propagation formulation
    def forward(self, x, h, c):
        combined = torch.cat((x, h), dim=1)
        A = self.conv(combined)

        # NOTE: A? = x_z * Wx? + h_{z-1} * Wh? + b? where * is convolution
        (Ai, Af, Ao, Ag) = torch.split(A,
                                       A.size()[1] // self.num_features,
                                       dim=1)

        i = torch.sigmoid(Ai)     # input gate
        f = torch.sigmoid(Af)     # forget gate
        o = torch.sigmoid(Ao)     # output gate
        g = torch.tanh(Ag)        # candidate cell values

        c = c * f + i * g         # cell activation state
        h = o * torch.tanh(c)     # cell hidden state

        return h, c

    @staticmethod
    def init_hidden(batch_size, hidden_c, shape):
        # Zero-initialized hidden and cell states; falls back to CPU
        # if CUDA is unavailable.
        try:
            return (Variable(torch.zeros(batch_size,
                                         hidden_c,
                                         shape[0],
                                         shape[1])).cuda(),
                    Variable(torch.zeros(batch_size,
                                         hidden_c,
                                         shape[0],
                                         shape[1])).cuda())
        except Exception:
            return (Variable(torch.zeros(batch_size,
                                         hidden_c,
                                         shape[0],
                                         shape[1])),
                    Variable(torch.zeros(batch_size,
                                         hidden_c,
                                         shape[0],
                                         shape[1])))
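
As a quick sanity check, a single cell can be driven for one step as below. This is a hypothetical snippet, not part of the original code; the shapes follow the comments above.

# Hypothetical smoke test for one CLSTMCell step.
cell = CLSTMCell(input_channels=64, hidden_channels=64, kernel_size=5)
h, c = CLSTMCell.init_hidden(batch_size=1, hidden_c=64, shape=(240, 240))
cell = cell.to(h.device)                             # keep weights on the same device as the states
x = torch.randn(1, 64, 240, 240, device=h.device)    # one 64-channel feature slice
h, c = cell(x, h, c)
print(h.shape, c.shape)                               # both: torch.Size([1, 64, 240, 240])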

CLSTM is a combination of multiple CLSTMCells (one per layer), unrolled over the time steps.

''' Class CLSTM.
    This represents a series of CLSTM nodes (one direction)
'''
     
     
class CLSTM(nn.Module):
    # Constructor
    def __init__(self, input_channels=64, hidden_channels=[64],
                 kernel_size=5, bias=True):
        super(CLSTM, self).__init__()

        # store stuff
        self.input_channels = [input_channels] + hidden_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.num_layers = len(hidden_channels)

        self.bias = bias
        self.all_layers = []

        # create a node for each layer in the CLSTM
        for layer in range(self.num_layers):
            name = 'cell{}'.format(layer)
            cell = CLSTMCell(self.input_channels[layer],
                             self.hidden_channels[layer],
                             self.kernel_size,
                             self.bias)
            setattr(self, name, cell)
            self.all_layers.append(cell)

    # Forward propagation
    # x --> BatchSize x NumSteps x NumChannels x Height x Width
    #       BatchSize x 2 x 64 x 240 x 240
    def forward(self, x):
        bsize, steps, _, height, width = x.size()
        internal_state = []
        outputs = []
        for step in range(steps):
            input = torch.squeeze(x[:, step, :, :, :], dim=1)
            for layer in range(self.num_layers):
                # populate hidden states for all layers
                if step == 0:
                    (h, c) = CLSTMCell.init_hidden(bsize,
                                                   self.hidden_channels[layer],
                                                   (height, width))
                    internal_state.append((h, c))

                # do forward
                name = 'cell{}'.format(layer)
                (h, c) = internal_state[layer]

                input, c = getattr(self, name)(
                    input, h, c)  # forward propagation call
                internal_state[layer] = (input, c)

            # keep the last layer's hidden map for this step
            outputs.append(input)

        return outputs
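
A hypothetical usage sketch for the one-directional CLSTM follows; the 2-step, 64-channel, 240x240 shapes mirror the comment above and are assumptions for illustration only.

# Hypothetical example of running the one-directional CLSTM over two slices.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clstm = CLSTM(input_channels=64, hidden_channels=[64], kernel_size=5).to(device)
x = torch.randn(1, 2, 64, 240, 240, device=device)   # BatchSize x NumSteps x C x H x W
outputs = clstm(x)                                    # one hidden map per step
print(len(outputs), outputs[0].shape)                 # 2 torch.Size([1, 64, 240, 240])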

BDC-LSTM (Bi-Directional Conv-LSTM)

BDC-LSTM [2] is an extension of C-LSTM: C-LSTM is applied in two opposite directions, one along z^- and one along z^+, and the contextual information from the two passes is combined to form the output.

As a unit, a BDC-LSTM can be treated much like a CNN layer in a deep architecture and stacked to build deeper structures.

class BDCLSTM(nn.Module):
    # Constructor
    def __init__(self, input_channels=64, hidden_channels=[64],
                 kernel_size=5, bias=True, num_classes=2):
        super(BDCLSTM, self).__init__()
        self.forward_net = CLSTM(
            input_channels, hidden_channels, kernel_size, bias)
        self.reverse_net = CLSTM(
            input_channels, hidden_channels, kernel_size, bias)
        self.conv = nn.Conv2d(
            2 * hidden_channels[-1], num_classes, kernel_size=1)
        self.soft = nn.Softmax2d()

    # Forward propagation
    # x --> BatchSize x NumChannels x Height x Width
    #       BatchSize x 64 x 240 x 240
    def forward(self, x1, x2, x3):
        x1 = torch.unsqueeze(x1, dim=1)
        x2 = torch.unsqueeze(x2, dim=1)
        x3 = torch.unsqueeze(x3, dim=1)

        # forward pass sees (x1, x2); reverse pass sees (x3, x2)
        xforward = torch.cat((x1, x2), dim=1)
        xreverse = torch.cat((x3, x2), dim=1)

        yforward = self.forward_net(xforward)
        yreverse = self.reverse_net(xreverse)

        # assumes y is BatchSize x NumClasses x 240 x 240
        ycat = torch.cat((yforward[-1], yreverse[-1]), dim=1)
        y = self.conv(ycat)
        y = self.soft(y)
        return y
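
A hypothetical end-to-end call is sketched below; the three inputs stand in for 64-channel feature maps of three adjacent slices (e.g. produced by a 2D U-Net, as the comments at the top of the code suggest), and all shapes are assumptions for illustration.

# Hypothetical example: segment the middle slice x2 using its neighbours x1 and x3.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BDCLSTM(input_channels=64, hidden_channels=[64], num_classes=2).to(device)
x1, x2, x3 = [torch.randn(1, 64, 240, 240, device=device) for _ in range(3)]
y = model(x1, x2, x3)   # per-pixel class probabilities from Softmax2d
print(y.shape)          # torch.Size([1, 2, 240, 240])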

References:

[1] X. Shi, Z. Chen, H. Wang, D.-Y. Yeung, W.-K. Wong, and W.-C. Woo. Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting. arXiv preprint arXiv:1506.04214, 2015.

[2] Jianxu Chen, Lin Yang, et al. Combining Fully Convolutional and Recurrent Neural Networks for 3D Biomedical Image Segmentation. arXiv:1609.01006v2 [cs.CV], 2016.
