I understand your confusion. Take a look at the example below and its comments:
import numpy as np
import torch
from torch import nn

# [batch size, sequence length, embedding size]
inputs = torch.rand(128, 5, 300)
gru = nn.GRU(input_size=300, hidden_size=400, num_layers=2, batch_first=True)
with torch.no_grad():
    # output holds the last layer's hidden state at every time step, for each element in the batch
    # a is the final hidden state of the first layer
    # b is the final hidden state of the second (last) layer
    output, (a, b) = gru(inputs)
If we print the shapes, they confirm our understanding:
print(output.shape)  # torch.Size([128, 5, 400])
print(a.shape)       # torch.Size([128, 400])
print(b.shape)       # torch.Size([128, 400])
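To make the shapes easier to see, here is a minimal sketch that keeps the second return value as a single tensor instead of unpacking it per layer. The small sizes and the name h_n are illustrative choices for this sketch, not part of the example above. Note that the hidden-state tensor is laid out as [num_layers, batch size, hidden size] even when batch_first=True, which is why unpacking it yields one tensor per layer.

```python
import torch
from torch import nn

# Small illustrative sizes (assumed for this sketch only).
batch, seq_len, emb, hidden, layers = 4, 5, 8, 6, 2

gru = nn.GRU(input_size=emb, hidden_size=hidden, num_layers=layers, batch_first=True)
inputs = torch.rand(batch, seq_len, emb)

with torch.no_grad():
    # h_n stacks the final hidden state of every layer along dim 0.
    output, h_n = gru(inputs)

print(output.shape)  # torch.Size([4, 5, 6])  -> [batch, seq_len, hidden]
print(h_n.shape)     # torch.Size([2, 4, 6])  -> [num_layers, batch, hidden]
```

Indexing h_n[-1] gives the same tensor as the last name in the tuple unpacking, so it works for any number of layers.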
In addition, we can test that the last hidden state of the last layer for each element in the batch, taken from output, is equal to b:
np.testing.assert_almost_equal(b.numpy(), output[:, -1, :].numpy())
Finally, we can create an RNN with 3 layers and run the same test:
gru = nn.GRU(input_size=300, hidden_size=400, num_layers=3, batch_first=True)
with torch.no_grad():
    output, (a, b, c) = gru(inputs)
np.testing.assert_almost_equal(c.numpy(), output[:, -1, :].numpy())
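Again, the assertion passes, but only when we run it against c, which is now the last layer of the RNN. Otherwise:
np.testing.assert_almost_equal(b.numpy(), output[:, -1, :].numpy())
raises an error:
AssertionError: Arrays are not almost equal to 7 decimals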
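The same check can be written once for any depth. The sketch below is only an illustration under assumed small sizes; layers_matching_output is a hypothetical helper, not a PyTorch function. It reports which layers' final hidden states match the last time step of output, and with randomly initialized weights it should list only the last layer's index.

```python
import torch
from torch import nn

def layers_matching_output(num_layers, seed=0):
    """Hypothetical helper: indices of layers whose final hidden state
    equals the last time step of `output`."""
    torch.manual_seed(seed)
    gru = nn.GRU(input_size=8, hidden_size=6, num_layers=num_layers, batch_first=True)
    inputs = torch.rand(4, 5, 8)  # [batch, seq_len, embedding]
    with torch.no_grad():
        output, h_n = gru(inputs)
    last_step = output[:, -1, :]  # last layer's hidden state at the final time step
    return [i for i in range(num_layers)
            if torch.allclose(h_n[i], last_step, atol=1e-6)]

for n in (1, 2, 3):
    print(n, layers_matching_output(n))
```

For a 3-layer GRU this corresponds to the assertion passing for c and failing for a and b.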
I hope this makes things clear.