使用NumPy高效返回具有小数部分的插入点索引

我希望有效地计算应该在数组中插入元素以保持顺序的索引,但包括一个小数部分,表示数组中两个最近点之间的“距离”。

>

应该可以使用索引和小数取回原始值。在实践中,以及为什么性能如此重要的原因,我将需要对大量数据点进行此操作。

为了演示我的意思,我通过np.searchsorted和一些if语句提出了一些可行的逻辑,但是还不能使用NumPy对逻辑进行矢量化处理。我也很高兴看到一个有效的解决方案,它利用numba并具有与NumPy相当或更好的性能。甚至我不知道的NumPy,Scipy等现成的解决方案。

我还在下面包括了一些基准测试代码。

import numpy as np

np.random.seed(0)

datapoint = np.random.random() * np.random.choice([1,-1]) * 500  # -274.4067
line = np.linspace(-500,500,101)  # [-500,-490,...,490,500] - an ordered array,may not be linspace

def get_position(line,point):
    position = np.searchsorted(line,point,side='right')
    size = line.shape[0]
    if position == 0:
        main = 0
        fraction = 0
    elif position == size:
        main = size-1
        fraction = 0
    else:
        main = position - 1
        fraction = (point - line[position-1]) / (line[position] - line[position-1])
    return main,fraction

idx,frac = get_position(line,datapoint)              # (22,0.55932480363376269)
print(line[idx] + frac * (line[idx + 1] - line[idx]))  # -274.4067; test to see if you get back original value

def run_multiple(line,data):
    out = np.empty((data.shape[0],3))
    for i in range(data.shape[0]):
        idx,data[i])
        out[i,0] = data[i]
        out[i,1] = idx
        out[i,2] = frac
    return out

基准化

# Python 3.6.0,NumPy 1.11.3,Numba 0.30.1
# Note: Numba 0.30.1 does not support "side" argument of np.searchsorted; not able to upgrade

n = 10**5  # actual n will be larger
res = run_multiple(line,np.random.random(n) * np.random.choice([1,-1],n) * 500)  # 901 ms per loop

# array([[ -4.22132874e+02,7.00000000e+00,7.86712571e-01],#        [ -4.28972809e+02,1.02719119e-01],#        [  4.23625869e+02,9.20000000e+01,3.62586939e-01],#        ...,#        [ -1.88627877e+02,3.10000000e+01,1.37212282e-01],#        [  4.98162640e+01,5.40000000e+01,9.81626397e-01],#        [  1.35777097e+02,6.30000000e+01,5.77709684e-01]])
xhq3645254 回答:使用NumPy高效返回具有小数部分的插入点索引

为进行矢量化处理,我将掩盖边缘情况,并在最后担心它们。无论如何,您只需要考虑position == size条件,因为在各个列中的低条件仅为零,out数组已经满足了这一条件。

def frac(line,points):
    pos = np.searchsorted(line,points,side='right')
    low = pos == 0
    high = pos == line.shape[0]
    m = ~(low | high)
    ii = points[m]
    jj = pos[m]
    frac = (ii - line[jj-1]) / (line[jj] - line[jj-1])
    out = np.zeros((points.shape[0],3))
    out[:,0] = points
    out[m,1] = jj - 1
    out[m,2] = frac
    out[high,1] = line.shape[0] - 1
    return out

基准

n = 10**5
line = np.linspace(-500,500,101)
points = np.random.random(n) * np.random.choice([1,-1],n) * 500

In [5]: %timeit run_multiple(line,points)
1.23 s ± 53.1 ms per loop (mean ± std. dev. of 7 runs,1 loop each)

In [7]: %timeit frac(line,points)
13.4 ms ± 290 µs per loop (mean ± std. dev. of 7 runs,100 loops each)

In [8]: np.allclose(frac(line,points),run_multiple(line,points))
Out[8]: True
,

如果Numba(或您使用的版本)不支持某些功能,最好查看Numba source code并查看已经存在的功能。 通常至少已经解决了一部分问题。

代码

import numpy as np
import numba as nb

#almost copied from Numba source
#https://github.com/numba/numba/blob/master/numba/targets/arraymath.py
"""Copyright (c) 2012,Anaconda,Inc.
All rights reserved.

Redistribution and use in source and binary forms,with or without
modification,are permitted provided that the following conditions are
met:

Redistributions of source code must retain the above copyright notice,this list of conditions and the following disclaimer.

Redistributions in binary form must reproduce the above copyright
notice,this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,INCLUDING,BUT NOT
LIMITED TO,THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,INDIRECT,INCIDENTAL,SPECIAL,EXEMPLARY,OR CONSEQUENTIAL DAMAGES (INCLUDING,PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,DATA,OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY,WHETHER IN CONTRACT,STRICT LIABILITY,OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE,EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
@nb.njit()
def searchsorted_right(a,v):
    n = len(a)
    if np.isnan(v):
        # Find the first nan (i.e. the last from the end of a,# since there shouldn't be many of them in practice)
        for i in range(n,-1):
            if not np.isnan(a[i - 1]):
                return i
        return 0
    lo = 0
    hi = n
    while hi > lo:
        mid = (lo + hi) >> 1
        if a[mid]<= v:
            # mid is too low => go up
            lo = mid + 1
        else:
            # mid is too high,or is a NaN => go down
            hi = mid
    return lo

@nb.njit()
def get_position(line,point):
    position = searchsorted_right(line,point)
    size = line.shape[0]
    if position == 0:
        main = 0
        fraction = 0
    elif position == size:
        main = size-1
        fraction = 0
    else:
        main = position - 1
        fraction = (point - line[position-1]) / (line[position] - line[position-1])
    return main,fraction

@nb.njit(parallel=True)
def run_multiple(line,data):
    out = np.empty((data.shape[0],3))
    for i in nb.prange(data.shape[0]):
        idx,frac = get_position(line,data[i])
        out[i,0] = data[i]
        out[i,1] = idx
        out[i,2] = frac
    return out

时间

n = 10**5
line = np.linspace(-500,n) * 500

%timeit run_multiple(line,points)
#1.08 ms ± 14 µs per loop (mean ± std. dev. of 7 runs,1000 loops each)

#@user3483203
%timeit frac(line,points)
#8.65 ms ± 266 µs per loop (mean ± std. dev. of 7 runs,100 loops each)
本文链接:https://www.f2er.com/3116398.html

大家都在问