理解Pytorch中LSTM的输入输出参数含义

2020-06-12 10:01:32 浏览数 (1)

  • 这个输出tensor包含了LSTM模型最后一层每个time step的输出特征,比如说LSTM有两层,那么最后输出的是
[h^1_0,h^1_1,...,h^1_l]

,表示第二层LSTM每个time step对应的输出。

代码语言:txt复制
- 另外如果前面你对输入数据使用了`torch.nn.utils.rnn.PackedSequence`,那么输出也会做同样的操作编程packed sequence。
- 对于unpacked情况,我们可以对输出做如下处理来对方向作分离`output.view(seq_len, batch, num_directions, hidden_size)`, 其中前向和后向分别用0和1表示Similarly, the directions can be separated in the packed case.

3、 代码示例

代码语言:javascript复制
rnn = nn.LSTM(10, 20, 2) # 一个单词向量长度为10,隐藏层节点数为20,LSTM有2层
input = torch.randn(5, 3, 10) # 输入数据由3个句子组成,每个句子由5个单词组成,单词向量长度为10
h0 = torch.randn(2, 3, 20) # 2:LSTM层数*方向 3:batch 20: 隐藏层节点数
c0 = torch.randn(2, 3, 20) # 同上
output, (hn, cn) = rnn(input, (h0, c0))

print(output.shape, hn.shape, cn.shape)

>>> torch.Size([5, 3, 20]) torch.Size([2, 3, 20]) torch.Size([2, 3, 20])

参考:

  • LSTM神经网络输入输出究竟是怎样的?Scofield的回答
  • Pytorch-LSTM

MARSGGBO♥原创

如有意合作或学术讨论欢迎私戳联系~

邮箱:marsggbo@foxmail.com

2019-12-31 10:41:09

0 人点赞