embedding_bag如何在PyTorch中正常工作

在PyTorch中,

似乎是torch.nn.functional.embedding_bag的主要功能,它负责完成嵌入查找的实际工作。在PyTorch的文档中,已经提到embedding_bag在不实例化中间嵌入的情况下完成了它的工作。这到底是什么意思?这是否意味着例如当模式为“求和”时就地进行求和?还是仅仅意味着在调用embedding_bag时不会产生额外的张量,但是从系统的角度来看,所有中间行向量已经被提取到处理器中以用于计算最终的张量?

iCMS 回答:embedding_bag如何在PyTorch中正常工作

在最简单的情况下,torch.nn.functional.embedding_bag 在概念上是一个两步过程。第一步是创建嵌入,第二步是减少(sum/mean/max,根据“mode”参数)跨维度 0 的嵌入输出。因此,您可以通过调用 { {1}},后跟 torch.nn.functional.embedding。在以下示例中,torch.sum/mean/maxembedding_bag_res 相等。

embedding_mean_res

但是,概念上的两步流程并未反映其实际实施方式。由于 >>> weight = torch.randn(3,4) >>> weight tensor([[ 0.3987,1.6173,0.4912,1.5001],[ 0.2418,1.5810,-1.3191,0.0081],[ 0.0931,0.4102,0.3003,0.2288]]) >>> indices = torch.tensor([2,1]) >>> embedding_res = torch.nn.functional.embedding(indices,weight) >>> embedding_res tensor([[ 0.0931,0.2288],0.0081]]) >>> embedding_mean_res = embedding_res.mean(dim=0,keepdim=True) >>> embedding_mean_res tensor([[ 0.1674,0.9956,-0.5094,0.1185]]) >>> embedding_bag_res = torch.nn.functional.embedding_bag(indices,weight,torch.tensor([0]),mode='mean') >>> embedding_bag_res tensor([[ 0.1674,0.1185]]) 不需要返回中间结果,因此它实际上不会为嵌入生成 Tensor 对象。它只是直接计算减少量,根据 embedding_bag 参数中的索引从 weight 参数中提取适当的数据。避免创建嵌入张量可以提高性能。

所以你的问题的答案(如果我理解正确的话)

这只是意味着在调用 embedding_bag 时不会产生额外的 Tensor,但仍然从系统的角度来看,所有中间行向量都已经被提取到处理器中用于计算最终的 Tensor?

是的。

本文链接:https://www.f2er.com/2258310.html

大家都在问