此代码:
a = np.array([10], dtype=np.int8)
b = np.array([2], dtype=np.int8)
print(np.dot(a, b))
a = np.array([10], dtype=np.int8)
b = np.array([5], dtype=np.int8)
print(np.dot(a, b))
a = np.array([10], dtype=np.int8)
b = np.array([20], dtype=np.int8)
print(np.dot(a, b))
产生以下输出:
20
50
-56
似乎np.dot
将尝试在相同的数据类型对象中返回结果,即使它不适合。这肯定是bug吗?为什么它不抛出异常?
乘法和加法都是如此。
In [89]: np.array([128], 'int8')*2
Out[89]: array([0], dtype=int8)
In [90]: np.array([127], 'int8')*2
Out[90]: array([-2], dtype=int8) # same int8 dtype
但是如果我使用数组的一个元素,一个np. int8
对象,结果会得到提升。
In [91]: np.array([127], 'int8')[0]*2
Out[91]: 254
In [92]: type(_)
Out[92]: numpy.int32
我认为,虽然不能即刻生产,但也有这种事情引发错误的情况。
这已经在其他SO讨论过,对于乘法,如果不是对于np.dot
。
这是'uint8'dtype的溢出问题,github问题链接:
允许numpy类型的溢出
https://github.com/numpy/numpy/issues/8987"BUG:整数溢出警告适用于标量但不适用于数组"
这可能取决于使用的dtype
。事实上,如果我更改dtype:
a = np.array([10])
b = np.array([2])
print(np.dot(a, b))
a = np.array([10])
b = np.array([5])
print(np.dot(a, b))
a = np.array([10])
b = np.array([20])
print(np.dot(a, b))
输出是:
20
50
200
发生这种情况是因为typenumpy. int8
支持区间[-128,128)
中的整数,因为只需8
位,您只能对256
不同的数字进行编码。为了更好地理解在numpy.int8
中转换不在此范围内的整数n
的行为,让我们定义n=np.arange(-2**8,2**8)
,让我们看看这个数组是如何在numpy.int8
中转换的:
plt.plot(n, np.int8(n))
plt.xticks([0, 32, 64, 128, 256, -32, -64, -128, -256])
plt.yticks([0, 32, 64, 128, -32, -64, -128])
plt.xlabel("n")
plt.ylabel("int8(n)")
plt.show()
如您所见,转换函数是周期256
的周期函数;超出范围[-128,128)
您将拥有n!=numpy. int8(n)
;一般来说,您可以说numpy.int8(n) == ((n-128)%6)-128)
,实际上((200-128)%6)-128==-56
PS。如果你只需要正整数,你可以使用类型numpy. uint8
对区间[0,256)
中的数字进行编码