scipy.stats.multivariate_normal.pdf与使用numpy编写的同一函数有何不同?

我需要在脚本中使用多元正态分布。我已经注意到,我的版本与scipy的方法给出了不同的答案。我真的不知道为什么...

这是我的功能:

def gauss(x,mu,sigma):
    assert np.linalg.det(sigma)!=0,"determinant of sigma is 0"
    y = np.exp((-1/2)*(x-mu).T.dot(np.linalg.inv(sigma)).dot(x-mu))/np.sqrt(
      np.power(2*np.pi,len(x))*np.linalg.det(sigma)
    )
    return y

以下是结果的比较:

from scipy.stats import multivariate_normal
import numpy as np

x = np.array([-0.54849176,6.39530657])
mu = np.array([15,20])
sigma = np.array([
  [2,3],[4,10]
])

print(gauss(x,sigma))
# output is 1.8781656851138248e-37

print(multivariate_normal.pdf(x,sigma))
# output is 2.698549423643947e-61

有人注意到吗?我的功能错了吗?任何帮助将不胜感激!

afang2468 回答:scipy.stats.multivariate_normal.pdf与使用numpy编写的同一函数有何不同?

您用作示例的特定输入可能会引起误解,因为这些值太低,以至于数字问题很容易引起您所看到的差异。但是,即使使用密度较大的示例,您仍然会遇到问题:

In [95]: x = np.array([15.00054849176,20.0009530657]) 
    ...: mu = np.array([15,20]) 
    ...: sigma = np.array([ 
    ...:   [2,3],...:   [4,10] 
    ...: ]) 
    ...:                                                                                        

In [96]: print(gauss(x,mu,sigma)) 
    ...: print(multivariate_normal.pdf(x,sigma)) 
    ...:                                                                                        
0.05626976565965294
0.07957746514880353

也许有趣的是,在数值问题上,差异是np.sqrt(2)的一个因素,但这有点像个红色鲱鱼:事实证明,差异仅是由您的协方差矩阵而不是协方差引起的矩阵:虽然它是正半定数,但它是不对称。使用有效的输入,这两种方法确实可以达成共识(在数字问题上):

In [99]: x = np.array([15.00054849176,...:   [3,10] 
    ...: ]) 
    ...:                                                                                        

In [100]: print(gauss(x,sigma)) 
     ...: print(multivariate_normal.pdf(x,sigma)) 
     ...:                                                                                       
0.047987017204594515
0.04798701720459451

或者,使用您的原始输入:

In [111]: x = np.array([-0.54849176,6.39530657]) 
     ...: mu = np.array([15,20]) 
     ...: sigma = np.array([ 
     ...:   [2,10] 
     ...: ]) 
     ...:                                                                                       

In [112]: print(gauss(x,sigma)) 
     ...:                                                                                       
5.060725651214228e-32
5.060725651214157e-32
本文链接:https://www.f2er.com/3061310.html

大家都在问