我正在尝试在Swift中为Tensorflow创建余弦相似度图层,以创建单词嵌入。
我试图根据Wikipedia的定义通过张量函数实现它。
@differentiable
func cosinesimilarity(_ input: Tensor<Float>) -> Tensor<Float> {
// https://en.wikipedia.org/wiki/Cosine_similarity
let numerator = input.product(squeezingAxes: 1).sum(squeezingAxes: 1)
let denominator = sqrt(input.squared().sum(squeezingAxes: 1)).product(squeezingAxes: 1)
return numerator / denominator
}
但是会导致编译错误:
error: expression is not differentiable
let denominator = sqrt(input.squared().sum(squeezingAxes: 1)).product(squeezingAxes: 1)
具有增强的product
功能。因为乘积只是乘法的总和,所以我希望它能起作用。如果我尝试手动创建产品,则可以在for循环中运行,但是速度很慢。