Asked by: 小点点

Indices of binned data with numpy


I want to bin the data each time the running sum exceeds a threshold of 10000.

I tried this, but with no luck:

import numpy as np

# data is an array of floats
diff = np.diff(np.cumsum(data) // 10000, prepend=0)
indices = np.argwhere(diff > 0).flatten()

The problem is that not all of the bins contain 10000, which was my goal.
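To make the problem concrete, here is a small sketch. The three-element input is a hypothetical example (it is not my real data), and I am assuming each bin ends at, and includes, the reported index. The floor division only detects where the global running sum crosses a multiple of 10000; it never restarts the count after a bin closes, so a bin can end up below the threshold.

```python
import numpy as np

# Hypothetical input chosen to expose the problem
data = np.array([9000.0, 9000.0, 9000.0])

running = np.cumsum(data)                    # 9000, 18000, 27000
diff = np.diff(running // 10000, prepend=0)  # steps in the "multiple of 10000" counter
indices = np.argwhere(diff > 0).flatten()

# Sum of each bin, assuming a bin ends at (and includes) each reported index
starts = np.r_[0, indices[:-1] + 1]
bin_sums = np.add.reduceat(data, starts)

print(indices)   # [1 2]
print(bin_sums)  # 18000 and 9000: the second bin never reaches 10000
```

With a counter that resets after each bin, only index 1 would be reported here, because the last 9000 on its own stays below 10000.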

input_data = [4000, 5000, 6000, 2000, 8000, 3000]
# (4000+5000+6000 >= 10000. Index 2)
# (2000+8000 >= 10000. Index 4)
Output: [2, 4]

Is there any alternative to a for loop?
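For reference, the plain loop I would like to avoid looks roughly like this. It is only a sketch of the intended behaviour; the function name `threshold_indices_loop` is made up for illustration.

```python
def threshold_indices_loop(values, threshold):
    """Return the index at which each bin closes.

    A bin closes as soon as the sum accumulated since the previous
    close reaches or exceeds `threshold`; the accumulator then resets.
    """
    indices = []
    running = 0
    for i, value in enumerate(values):
        running += value
        if running >= threshold:
            indices.append(i)
            running = 0  # start a new bin
    return indices


input_data = [4000, 5000, 6000, 2000, 8000, 3000]
print(threshold_indices_loop(input_data, 10000))  # [2, 4]
```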


1 answer

Anonymous user

Here is how you can do this fairly efficiently with a loop, using np.searchsorted to find the bin boundaries quickly:

import numpy as np

np.random.seed(0)
bin_size = 10_000
data = np.random.randint(100, size=20_000)

# Naive solution (incorrect, for comparison)
data_f = np.floor(np.cumsum(data) / bin_size).astype(int)
bin_starts = np.r_[0, np.where(np.diff(data_f) > 0)[0] + 1]
# Check bin sizes
bin_sums = np.add.reduceat(data, bin_starts)
# We go over the limit!
print(bin_sums.max())
# 10080

# Better solution with loop
data_c = np.cumsum(data)
ref_val = 0
bin_starts = [0]
while True:
    # Search next split point
    ref_idx = bin_starts[-1]
    # Binary search through remaining cumsum
    next_idx = np.searchsorted(data_c[ref_idx:], ref_val + bin_size, side='right')
    next_idx += ref_idx
    # If we finished the array stop
    if next_idx >= len(data_c):
        break
    # Add new bin boundary
    bin_starts.append(next_idx)
    ref_val = data_c[next_idx - 1]
# Convert bin boundaries to array
bin_starts = np.array(bin_starts)
# Check bin sizes
bin_sums = np.add.reduceat(data, bin_starts)
# Does not go over limit
print(bin_sums.max())
# 10000
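
Note that the loop above keeps every bin at or below bin_size, whereas the example in the question closes a bin once its sum reaches the threshold. If that second behaviour is what you need, the same searchsorted idea could be adapted roughly as follows. This is only a sketch: `threshold_indices` is a hypothetical helper name, and it assumes the data is non-negative so that the cumulative sum is sorted.

```python
import numpy as np


def threshold_indices(data, threshold):
    """Indices where the sum since the previous reported index reaches `threshold`.

    Assumes non-negative values, so that the cumulative sum is sorted
    and np.searchsorted can be applied to it.
    """
    data_c = np.cumsum(data)
    indices = []
    start = 0    # first element of the current bin
    ref_val = 0  # cumulative sum just before the current bin
    while start < len(data_c):
        # Binary search for the first position whose cumulative sum is
        # at least ref_val + threshold (side='left' gives ">=")
        idx = start + np.searchsorted(data_c[start:], ref_val + threshold, side='left')
        if idx >= len(data_c):
            break  # the remaining data never reaches the threshold
        indices.append(int(idx))
        ref_val = data_c[idx]
        start = idx + 1
    return indices


print(threshold_indices([4000, 5000, 6000, 2000, 8000, 3000], 10_000))  # [2, 4]
```

The loop runs once per bin rather than once per element, so it should stay cheap when bins are long compared to the array.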