似乎是torch.nn.functional.embedding_bag的主要功能,它负责完成嵌入查找的实际工作。在PyTorch的文档中,已经提到embedding_bag在不实例化中间嵌入的情况下完成了它的工作。这到底是什么意思?这是否意味着例如当模式为“求和”时就地进行求和?还是仅仅意味着在调用embedding_bag时不会产生额外的张量,但是从系统的角度来看,所有中间行向量已经被提取到处理器中以用于计算最终的张量?
iCMS 回答:embedding_bag如何在PyTorch中正常工作
在最简单的情况下,torch.nn.functional.embedding_bag
在概念上是一个两步过程。第一步是创建嵌入,第二步是减少(sum/mean/max,根据“mode”参数)跨维度 0 的嵌入输出。因此,您可以通过调用 { {1}},后跟 torch.nn.functional.embedding
。在以下示例中,torch.sum/mean/max
和 embedding_bag_res
相等。
embedding_mean_res
但是,概念上的两步流程并未反映其实际实施方式。由于 >>> weight = torch.randn(3,4)
>>> weight
tensor([[ 0.3987,1.6173,0.4912,1.5001],[ 0.2418,1.5810,-1.3191,0.0081],[ 0.0931,0.4102,0.3003,0.2288]])
>>> indices = torch.tensor([2,1])
>>> embedding_res = torch.nn.functional.embedding(indices,weight)
>>> embedding_res
tensor([[ 0.0931,0.2288],0.0081]])
>>> embedding_mean_res = embedding_res.mean(dim=0,keepdim=True)
>>> embedding_mean_res
tensor([[ 0.1674,0.9956,-0.5094,0.1185]])
>>> embedding_bag_res = torch.nn.functional.embedding_bag(indices,weight,torch.tensor([0]),mode='mean')
>>> embedding_bag_res
tensor([[ 0.1674,0.1185]])
不需要返回中间结果,因此它实际上不会为嵌入生成 Tensor 对象。它只是直接计算减少量,根据 embedding_bag
参数中的索引从 weight
参数中提取适当的数据。避免创建嵌入张量可以提高性能。
所以你的问题的答案(如果我理解正确的话)
这只是意味着在调用 embedding_bag 时不会产生额外的 Tensor,但仍然从系统的角度来看,所有中间行向量都已经被提取到处理器中用于计算最终的 Tensor?
是的。