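I'm trying to train an RNN, but I'm running into a problem with the embedding layer. I get the following error message:

    TypeError: embedding(): argument 'indices' (position 2) must be Tensor, not list

My forward method starts like this: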

    def forward(self, word_indices: [int]):
        print("sentences")
        print(len(word_indices))
        print(word_indices)
        word_ind_tensor = torch.tensor(word_indices, device="cpu")
        print(word_ind_tensor)
        print(word_ind_tensor.size())
        embeds_word = self.embedding_word(word_indices)

The output of all this is:

    sentences
    29
    [261, 15, 5149, 44, 287, 688, 1125, 4147, 9874, 582, 9875, 3, 2, 6732, 34, 6733, 9, 485, 7, 6734, 741, 2179, 1571, 1]
    tensor([ 261,1])
    torch.Size([29])
    Traceback (most recent call last):
      File "/home/lukas/Documents/HU/Materialen/21SoSe-Studienprojekt/flair-Studienprojekt/TestModel.py", line 68, in <module>
        embeddings_storage_mode = "CPU") #auf cuda ändern
      File "/home/lukas/Documents/HU/Materialen/21SoSe-Studienprojekt/flair-Studienprojekt/flair/trainers/trainer.py", line 423, in train
        loss = self.model.forward_loss(batch_step)
      File "/home/lukas/Documents/HU/Materialen/21SoSe-Studienprojekt/flair-Studienprojekt/flair/models/sandbox/srl_tagger.py", line 122, in forward_loss
        features = self.forward(word_indices = sent_word_ind, frame_indices = sent_frame_ind)
      File "/home/lukas/Documents/HU/Materialen/21SoSe-Studienprojekt/flair-Studienprojekt/flair/models/sandbox/srl_tagger.py", line 147, in forward
        embeds_word = self.embedding_word(word_indices)
      File "/home/lukas/miniconda3/envs/studienprojekt/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/lukas/miniconda3/envs/studienprojekt/lib/python3.7/site-packages/torch/nn/modules/sparse.py", line 114, in forward
        self.norm_type, self.scale_grad_by_freq, self.sparse)
      File "/home/lukas/miniconda3/envs/studienprojekt/lib/python3.7/site-packages/torch/nn/functional.py", line 1724, in embedding
        return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
    TypeError: embedding(): argument 'indices' (position 2) must be Tensor, not list

I initialized the embedding like this:

    self.embedding_word = torch.nn.Embedding(self.word_dict_size, embedding_size)

Both word_dict_size and embedding_size are integers. Am I doing something obviously wrong, or is this a deeper problem?
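A minimal sketch of the likely cause, assuming a reasonably recent PyTorch version: torch.nn.Embedding expects a tensor of integer indices, and the last line of forward passes word_indices (the original Python list) instead of the word_ind_tensor that was built from it two lines earlier. The class name WordEmbedder and the sizes 10000 and 8 below are hypothetical, chosen only for illustration.

```python
from typing import List

import torch


class WordEmbedder(torch.nn.Module):
    """Hypothetical minimal module mirroring the setup in the question."""

    def __init__(self, word_dict_size: int, embedding_size: int):
        super().__init__()
        self.embedding_word = torch.nn.Embedding(word_dict_size, embedding_size)

    def forward(self, word_indices: List[int]) -> torch.Tensor:
        # nn.Embedding only accepts an integer tensor, so convert the list first
        # (on the same device as the embedding weights) ...
        word_ind_tensor = torch.tensor(
            word_indices, dtype=torch.long, device=self.embedding_word.weight.device
        )
        # ... and pass the tensor, not the original list, to the embedding layer.
        return self.embedding_word(word_ind_tensor)


model = WordEmbedder(word_dict_size=10000, embedding_size=8)
indices = [261, 15, 5149, 44, 287]  # every index must be < word_dict_size

# Calling the layer on the raw list reproduces the TypeError from the question.
try:
    model.embedding_word(indices)
except TypeError as err:
    print("list rejected:", err)

embeds = model(indices)
print(embeds.shape)  # torch.Size([5, 8])
```

Creating the tensor on the embedding weight's device means the same code keeps working if the model is later moved to CUDA. Feeding the result into a batched RNN would additionally need a batch dimension, which this sketch does not cover.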