scipy curve_fit即使提供了很好的猜测也根本无法正确拟合?

我正在尝试对某些数据拟合经过指数修改的高斯函数。数据位于顶部。

我有以下代码:

import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit
from scipy.special import erfc

bins = [-46.82455,-46.41738,-46.01021,-45.60304,-45.19587,-44.7887,-44.38153,-43.97436,-43.56719,-43.16002,-42.75285,-42.34568,-41.93851,-41.53134,-41.12417,-40.717,-40.30983,-39.90266,-39.49549,-39.08832,-38.68115,-38.27398,-37.86681,-37.45964,-37.05247,-36.6453,-36.23813,-35.83096,-35.42379,-35.01662,-34.60945,-34.20228,-33.79511,-33.38794,-32.98077,-32.5736,-32.16643,-31.75926,-31.35209,-30.94492,-30.53775,-30.13058,-29.72341,-29.31624,-28.90907,-28.5019,-28.09473,-27.68756,-27.28039,-26.87322,-26.46605,-26.05888,-25.65171,-25.24454,-24.83737,-24.4302,-24.02303,-23.61586,-23.20869,-22.80152,-22.39435,-21.98718,-21.58001,-21.17284,-20.76567,-20.3585,-19.95133,-19.54416,-19.13699,-18.72982,-18.32265,-17.91548,-17.50831,-17.10114,-16.69397,-16.2868,-15.87963,-15.47246,-15.06529,-14.65812,-14.25095,-13.84378,-13.43661,-13.02944,-12.62227,-12.2151,-11.80793,-11.40076,-10.99359,-10.58642,-10.17925,-9.77208,-9.36491,-8.95774,-8.55057,-8.1434,-7.73623,-7.32906,-6.92189,-6.51472,-6.10755,-5.70038,-5.29321,-4.88604,-4.47887,-4.0717,-3.66453,-3.25736,-2.85019,-2.44302,-2.03585,-1.62868,-1.22151,-0.81434,-0.40717,0.0,0.40717,0.81434,1.22151,1.62868,2.03585,2.44302,2.85019,3.25736,3.66453,4.0717,4.47887,4.88604,5.29321,5.70038,6.10755,6.51472,6.92189,7.32906,7.73623,8.1434,8.55057,8.95774,9.36491,9.77208,10.17925,10.58642,10.99359,11.40076,11.80793,12.2151,12.62227,13.02944,13.43661,13.84378,14.25095,14.65812,15.06529,15.47246,15.87963,16.2868,16.69397,17.10114,17.50831,17.91548,18.32265,18.72982,19.13699,19.54416,19.95133,20.3585,20.76567,21.17284,21.58001,21.98718,22.39435,22.80152,23.20869,23.61586,24.02303,24.4302,24.83737,25.24454,25.65171,26.05888,26.46605,26.87322,27.28039,27.68756,28.09473,28.5019,28.90907,29.31624,29.72341,30.13058,30.53775,30.94492,31.35209,31.75926,32.16643,32.5736,32.98077,33.38794,33.79511,34.20228,34.60945,35.01662,35.42379,35.83096,36.23813,36.6453,37.05247,37.45964,37.86681,38.27398,38.68115,39.08832,39.49549,39.90266,40.30983,40.717,41.12417,41.53134,41.93851,42.34568,42.75285,43.16002,43.56719,43.97436,44.38153,44.7887,45.19587,45.60304,46.01021,46.41738]

counts = [0.00000000e+00,0.00000000e+00,9.82318271e-04,1.96463654e-03,7.85854617e-03,9.82318271e-03,1.27701375e-02,1.47347741e-02,1.76817289e-02,2.75049116e-02,3.14341847e-02,4.32220039e-02,5.79567780e-02,6.77799607e-02,9.43025540e-02,1.29666012e-01,1.48330059e-01,1.87622790e-01,2.07269155e-01,2.54420432e-01,3.00589391e-01,3.33005894e-01,4.03732809e-01,4.72495088e-01,5.22593320e-01,5.99214145e-01,6.34577603e-01,7.04322200e-01,8.18271120e-01,8.58546169e-01,9.26326130e-01,9.65618861e-01,9.35166994e-01,9.76424361e-01,9.39096267e-01,1.00000000e+00,9.67583497e-01,9.36149312e-01,9.13555992e-01,9.38113949e-01,8.35952849e-01,8.31041257e-01,8.33988212e-01,7.54420432e-01,7.17092338e-01,6.12966601e-01,6.22789784e-01,5.37328094e-01,4.76424361e-01,4.35166994e-01,3.89980354e-01,3.53634578e-01,3.47740668e-01,3.51669941e-01,2.87819253e-01,2.67190570e-01,3.04518664e-01,2.60314342e-01,2.70137525e-01,2.65225933e-01,3.06483301e-01,2.72102161e-01,2.61296660e-01,2.57367387e-01,2.45579568e-01,2.25933202e-01,2.28880157e-01,2.21021611e-01,2.23968566e-01,1.95481336e-01,1.80746562e-01,1.56188605e-01,1.53241650e-01,1.23772102e-01,1.47347741e-01,1.26719057e-01,8.93909627e-02,7.17092338e-02,8.84086444e-02,6.28683694e-02,6.97445972e-02,6.58153242e-02,4.61689587e-02,4.51866405e-02,4.22396857e-02,3.92927308e-02,3.43811395e-02,2.45579568e-02,3.53634578e-02,2.94695481e-02,2.16110020e-02,3.63457760e-02,1.96463654e-02,2.35756385e-02,2.84872299e-02,2.55402750e-02,2.06286837e-02,1.86640472e-02,3.33988212e-02,2.25933202e-02,2.65225933e-02,6.87622790e-03,1.66994106e-02,1.17878193e-02,1.08055010e-02,1.37524558e-02,5.89390963e-03,9.82318271e-03]

def exp_mod_gauss(x,m,s,l):
    y = 0.5*l*np.exp(0.5*l*(2*m+l*s*s-2*x))*erfc((m+l*s*s-x)/(np.sqrt(2)*s))
    return y
    #l=Lambda,s=Sigma,m=Mu

bins=np.asarray(bins,dtype='float')
counts=np.asarray(counts,dtype='float')

popt,pcov = curve_fit(exp_mod_gauss,bins,counts,p0=[-3.5,2.8736,0.1548])
fitted_func = exp_mod_gauss(bins,popt[0],popt[1],popt[2])
#fitted_func = exp_mod_gauss(bins,-3.5,0.1548) #used for manual example
plt.plot(bins,'o',markersize=1) #plot actual counts
plt.plot(bins,fitted_func/max(fitted_func),'-') #plot fitted func/scaled
plt.show()

如果按照编写的方式使用scipy拟合运行代码,则会得到以下结果:

scipy curve_fit即使提供了很好的猜测也根本无法正确拟合?

显然不是很好。

但是,如果我注释掉使用curve_fit参数的fit_func行,并使用我在初始猜测中提供的参数(-3.5、2.876、0.1548),则会得到以下结果:

scipy curve_fit即使提供了很好的猜测也根本无法正确拟合?

因此,即使我为curve_fit提供最初的猜测,这基本上就是我要寻找的答案,但是它失败了。通过在Matlab中执行完全相同的过程,我得到了很好的拟合参数,但是我不想使用Matlab。我想使用Python。

有人知道这里发生了什么吗?

非常感谢。

zcg345 回答:scipy curve_fit即使提供了很好的猜测也根本无法正确拟合?

因此,事实证明,该配件需要EMG功能中的附加自由度才能起作用。以便可以扩展到数据。如果我将EMG函数修改为:

def exp_mod_gauss(x,b,m,s,l):
    y = b*(0.5*l*np.exp(0.5*l*(2*m+l*s*s-2*x))*erfc((m+l*s*s-x)/(np.sqrt(2)*s)))
    return y
    #l=Lambda,s=Sigma,m=Mu,#b=scaling

因此,将b项相加来缩放峰进行了排序。我提供了猜测[1,-1,1,0]现在可以满足我期望的各种数据。

,

根据散点图,数据中似乎有两个单独的重叠峰。这是一个图形化的Python拟合器,使用您的数据并将其拟合到两个高斯峰的总和,并由scipy的Differential Evolution遗传算法模块提供了curve_fit()的初始参数估计。该模块使用拉丁文Hypercube算法来确保对参数空间进行彻底搜索,从而需要在搜索范围内进行搜索。在此示例中,这些边界取自最大值和最小值数据值,其中0.0用作我怀疑应该为正的参数的下限。

plot

import numpy,scipy,matplotlib
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from scipy.optimize import differential_evolution
import warnings

bins = [-46.82455,-46.41738,-46.01021,-45.60304,-45.19587,-44.7887,-44.38153,-43.97436,-43.56719,-43.16002,-42.75285,-42.34568,-41.93851,-41.53134,-41.12417,-40.717,-40.30983,-39.90266,-39.49549,-39.08832,-38.68115,-38.27398,-37.86681,-37.45964,-37.05247,-36.6453,-36.23813,-35.83096,-35.42379,-35.01662,-34.60945,-34.20228,-33.79511,-33.38794,-32.98077,-32.5736,-32.16643,-31.75926,-31.35209,-30.94492,-30.53775,-30.13058,-29.72341,-29.31624,-28.90907,-28.5019,-28.09473,-27.68756,-27.28039,-26.87322,-26.46605,-26.05888,-25.65171,-25.24454,-24.83737,-24.4302,-24.02303,-23.61586,-23.20869,-22.80152,-22.39435,-21.98718,-21.58001,-21.17284,-20.76567,-20.3585,-19.95133,-19.54416,-19.13699,-18.72982,-18.32265,-17.91548,-17.50831,-17.10114,-16.69397,-16.2868,-15.87963,-15.47246,-15.06529,-14.65812,-14.25095,-13.84378,-13.43661,-13.02944,-12.62227,-12.2151,-11.80793,-11.40076,-10.99359,-10.58642,-10.17925,-9.77208,-9.36491,-8.95774,-8.55057,-8.1434,-7.73623,-7.32906,-6.92189,-6.51472,-6.10755,-5.70038,-5.29321,-4.88604,-4.47887,-4.0717,-3.66453,-3.25736,-2.85019,-2.44302,-2.03585,-1.62868,-1.22151,-0.81434,-0.40717,0.0,0.40717,0.81434,1.22151,1.62868,2.03585,2.44302,2.85019,3.25736,3.66453,4.0717,4.47887,4.88604,5.29321,5.70038,6.10755,6.51472,6.92189,7.32906,7.73623,8.1434,8.55057,8.95774,9.36491,9.77208,10.17925,10.58642,10.99359,11.40076,11.80793,12.2151,12.62227,13.02944,13.43661,13.84378,14.25095,14.65812,15.06529,15.47246,15.87963,16.2868,16.69397,17.10114,17.50831,17.91548,18.32265,18.72982,19.13699,19.54416,19.95133,20.3585,20.76567,21.17284,21.58001,21.98718,22.39435,22.80152,23.20869,23.61586,24.02303,24.4302,24.83737,25.24454,25.65171,26.05888,26.46605,26.87322,27.28039,27.68756,28.09473,28.5019,28.90907,29.31624,29.72341,30.13058,30.53775,30.94492,31.35209,31.75926,32.16643,32.5736,32.98077,33.38794,33.79511,34.20228,34.60945,35.01662,35.42379,35.83096,36.23813,36.6453,37.05247,37.45964,37.86681,38.27398,38.68115,39.08832,39.49549,39.90266,40.30983,40.717,41.12417,41.53134,41.93851,42.34568,42.75285,43.16002,43.56719,43.97436,44.38153,44.7887,45.19587,45.60304,46.01021,46.41738]

counts = [0.00000000e+00,0.00000000e+00,9.82318271e-04,1.96463654e-03,7.85854617e-03,9.82318271e-03,1.27701375e-02,1.47347741e-02,1.76817289e-02,2.75049116e-02,3.14341847e-02,4.32220039e-02,5.79567780e-02,6.77799607e-02,9.43025540e-02,1.29666012e-01,1.48330059e-01,1.87622790e-01,2.07269155e-01,2.54420432e-01,3.00589391e-01,3.33005894e-01,4.03732809e-01,4.72495088e-01,5.22593320e-01,5.99214145e-01,6.34577603e-01,7.04322200e-01,8.18271120e-01,8.58546169e-01,9.26326130e-01,9.65618861e-01,9.35166994e-01,9.76424361e-01,9.39096267e-01,1.00000000e+00,9.67583497e-01,9.36149312e-01,9.13555992e-01,9.38113949e-01,8.35952849e-01,8.31041257e-01,8.33988212e-01,7.54420432e-01,7.17092338e-01,6.12966601e-01,6.22789784e-01,5.37328094e-01,4.76424361e-01,4.35166994e-01,3.89980354e-01,3.53634578e-01,3.47740668e-01,3.51669941e-01,2.87819253e-01,2.67190570e-01,3.04518664e-01,2.60314342e-01,2.70137525e-01,2.65225933e-01,3.06483301e-01,2.72102161e-01,2.61296660e-01,2.57367387e-01,2.45579568e-01,2.25933202e-01,2.28880157e-01,2.21021611e-01,2.23968566e-01,1.95481336e-01,1.80746562e-01,1.56188605e-01,1.53241650e-01,1.23772102e-01,1.47347741e-01,1.26719057e-01,8.93909627e-02,7.17092338e-02,8.84086444e-02,6.28683694e-02,6.97445972e-02,6.58153242e-02,4.61689587e-02,4.51866405e-02,4.22396857e-02,3.92927308e-02,3.43811395e-02,2.45579568e-02,3.53634578e-02,2.94695481e-02,2.16110020e-02,3.63457760e-02,1.96463654e-02,2.35756385e-02,2.84872299e-02,2.55402750e-02,2.06286837e-02,1.86640472e-02,3.33988212e-02,2.25933202e-02,2.65225933e-02,6.87622790e-03,1.66994106e-02,1.17878193e-02,1.08055010e-02,1.37524558e-02,5.89390963e-03,9.82318271e-03]

xData = numpy.array(bins)
yData = numpy.array(counts)


def func(X,a,c,f,g,h): # sum of two gaussian peaks
    # a,c and f,h are the fitted parameters for the two peaks
    return a * numpy.exp(-0.5 * ((X-b)/c)**2)  +  f * numpy.exp(-0.5 * ((X-g)/h)**2)


# function for genetic algorithm to minimize (sum of squared error)
def sumOfSquaredError(parameterTuple):
    warnings.filterwarnings("ignore") # do not print warnings by genetic algorithm
    val = func(xData,*parameterTuple)
    return numpy.sum((yData - val) ** 2.0)


def generate_Initial_Parameters():
    # min and max used for bounds
    maxX = max(xData)
    minX = min(xData)
    #maxY = max(yData)
    #minY = min(yData)

    parameterBounds = []

    parameterBounds.append([0.0,maxX]) # search bounds for a,positive
    parameterBounds.append([minX,maxX]) # search bounds for b
    parameterBounds.append([0.0,maxX]) # search bounds for c,positive

    parameterBounds.append([0.0,maxX]) # search bounds for f,maxX]) # search bounds for g
    parameterBounds.append([0.0,maxX]) # search bounds for h,positive

    # "seed" the numpy random number generator for repeatable results
    result = differential_evolution(sumOfSquaredError,parameterBounds,seed=3)
    return result.x

# by default,differential_evolution completes by calling curve_fit() using parameter bounds
fittedParameters = generate_Initial_Parameters()
print('Fitted parameters:',fittedParameters)
print()

modelPredictions = func(xData,*fittedParameters) 

absError = modelPredictions - yData

SE = numpy.square(absError) # squared errors
MSE = numpy.mean(SE) # mean squared errors
RMSE = numpy.sqrt(MSE) # Root Mean Squared Error,RMSE
Rsquared = 1.0 - (numpy.var(absError) / numpy.var(yData))

print()
print('RMSE:',RMSE)
print('R-squared:',Rsquared)

print()


##########################################################
# graphics output section
def ModelAndScatterPlot(graphWidth,graphHeight):
    f = plt.figure(figsize=(graphWidth/100.0,graphHeight/100.0),dpi=100)
    axes = f.add_subplot(111)

    # first the raw data as a scatter plot
    axes.plot(xData,yData,'D')

    # create data for the fitted equation plot
    xModel = numpy.linspace(min(xData),max(xData))
    yModel = func(xModel,*fittedParameters)

    # now the model as a line plot
    axes.plot(xModel,yModel)

    axes.set_xlabel('X Data') # X axis data label
    axes.set_ylabel('Y Data') # Y axis data label

    plt.show()
    plt.close('all') # clean up after using pyplot

graphWidth = 800
graphHeight = 600
ModelAndScatterPlot(graphWidth,graphHeight)
本文链接:https://www.f2er.com/3123876.html

大家都在问