Numpy ndarray 的高级索引存在 "bug" ?

时间:2022-06-28 01:49:10

Numpy ndarray 高级索引 "bug" ?

话说一天,搞事情,代码如下

import numpy as np

tmp = [1, 2, 3, 4] * 2
a, b = np.zeros((10, 10)), np.zeros((10, 10))
a[tmp[:-1], tmp[1:]] += 1
for i in range(len(tmp) - 1):
    b[tmp[i], tmp[i + 1]] += 1
print(a.sum() - b.sum())

心理预期a 与 b应该完全一样,但是实际结果却不一样!得出的和差是-3

为什么?

print(list(zip(tmp[:-1],tmp[1:])))
# [(1, 2), (2, 3), (3, 4), (4, 1), (1, 2), (2, 3), (3, 4)]

坐标集中存在重复的坐标。。。。改成如下,就没毛病了。。。直接使用ndarray 的高级索引进行操作,会自动对索引进行去重操作。。。

import numpy as np

tmp = [1, 2, 3, 4] * 2
a, b = np.zeros((10, 10)), np.zeros((10, 10))

np.add.at(a, [tmp[:-1], tmp[1:]], 1)
for i in range(len(tmp) - 1):
    b[tmp[i], tmp[i + 1]] += 1

print(a.sum() - b.sum())