churchyard / rpms / numpy

Forked from rpms/numpy 6 years ago
Clone
Blob Blame History Raw
From ad2a73c18dcff95d844c382c94ab7f73b5571cf3 Mon Sep 17 00:00:00 2001
From: Sebastian Berg <sebastian@sipsolutions.net>
Date: Tue, 4 May 2021 17:43:26 -0500
Subject: [PATCH] MAINT: Adjust NumPy float hashing to Python's slightly
 changed hash

This is necessary, since we use the Python double hash and the
semi-private function to calculate it in Python has a new signature
to return the identity-hash when the value is NaN.

closes gh-18833, gh-18907
---
 numpy/core/src/common/npy_pycompat.h        | 16 ++++++++++
 numpy/core/src/multiarray/scalartypes.c.src | 13 ++++----
 numpy/core/tests/test_scalarmath.py         | 34 +++++++++++++++++++++
 3 files changed, 57 insertions(+), 6 deletions(-)

diff --git a/numpy/core/src/common/npy_pycompat.h b/numpy/core/src/common/npy_pycompat.h
index aa0b5c1224d3..9e94a971090a 100644
--- a/numpy/core/src/common/npy_pycompat.h
+++ b/numpy/core/src/common/npy_pycompat.h
@@ -3,4 +3,20 @@
 
 #include "numpy/npy_3kcompat.h"
 
+
+/*
+ * In Python 3.10a7 (or b1), python started using the identity for the hash
+ * when a value is NaN.  See https://bugs.python.org/issue43475
+ */
+#if PY_VERSION_HEX > 0x030a00a6
+#define Npy_HashDouble _Py_HashDouble
+#else
+static NPY_INLINE Py_hash_t
+Npy_HashDouble(PyObject *NPY_UNUSED(identity), double val)
+{
+    return _Py_HashDouble(val);
+}
+#endif
+
+
 #endif /* _NPY_COMPAT_H_ */
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index a001500b0a97..9930f7791d6e 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -3172,7 +3172,7 @@ static npy_hash_t
 static npy_hash_t
 @lname@_arrtype_hash(PyObject *obj)
 {
-    return _Py_HashDouble((double) PyArrayScalar_VAL(obj, @name@));
+    return Npy_HashDouble(obj, (double)PyArrayScalar_VAL(obj, @name@));
 }
 
 /* borrowed from complex_hash */
@@ -3180,14 +3180,14 @@ static npy_hash_t
 c@lname@_arrtype_hash(PyObject *obj)
 {
     npy_hash_t hashreal, hashimag, combined;
-    hashreal = _Py_HashDouble((double)
-            PyArrayScalar_VAL(obj, C@name@).real);
+    hashreal = Npy_HashDouble(
+            obj, (double)PyArrayScalar_VAL(obj, C@name@).real);
 
     if (hashreal == -1) {
         return -1;
     }
-    hashimag = _Py_HashDouble((double)
-            PyArrayScalar_VAL(obj, C@name@).imag);
+    hashimag = Npy_HashDouble(
+            obj, (double)PyArrayScalar_VAL(obj, C@name@).imag);
     if (hashimag == -1) {
         return -1;
     }
@@ -3202,7 +3202,8 @@ c@lname@_arrtype_hash(PyObject *obj)
 static npy_hash_t
 half_arrtype_hash(PyObject *obj)
 {
-    return _Py_HashDouble(npy_half_to_double(PyArrayScalar_VAL(obj, Half)));
+    return Npy_HashDouble(
+            obj, npy_half_to_double(PyArrayScalar_VAL(obj, Half)));
 }
 
 static npy_hash_t
diff --git a/numpy/core/tests/test_scalarmath.py b/numpy/core/tests/test_scalarmath.py
index d91b4a39146d..09a734284a76 100644
--- a/numpy/core/tests/test_scalarmath.py
+++ b/numpy/core/tests/test_scalarmath.py
@@ -707,3 +707,37 @@
                 shift_arr = np.array([shift]*32, dtype=dt)
                 res_arr = op(val_arr, shift_arr)
                 assert_equal(res_arr, res_scl)
+
+
+class TestHash:
+    @pytest.mark.parametrize("type_code", np.typecodes['AllInteger'])
+    def test_integer_hashes(self, type_code):
+        scalar = np.dtype(type_code).type
+        for i in range(128):
+            assert hash(i) == hash(scalar(i))
+
+    @pytest.mark.parametrize("type_code", np.typecodes['AllFloat'])
+    def test_float_and_complex_hashes(self, type_code):
+        scalar = np.dtype(type_code).type
+        for val in [np.pi, np.inf, 3, 6.]:
+            numpy_val = scalar(val)
+            # Cast back to Python, in case the NumPy scalar has less precision
+            if numpy_val.dtype.kind == 'c':
+                val = complex(numpy_val)
+            else:
+                val = float(numpy_val)
+            assert val == numpy_val
+            print(repr(numpy_val), repr(val))
+            assert hash(val) == hash(numpy_val)
+
+        if hash(float(np.nan)) != hash(float(np.nan)):
+            # If Python distinguises different NaNs we do so too (gh-18833)
+            assert hash(scalar(np.nan)) != hash(scalar(np.nan))
+
+    @pytest.mark.parametrize("type_code", np.typecodes['Complex'])
+    def test_complex_hashes(self, type_code):
+        # Test some complex valued hashes specifically:
+        scalar = np.dtype(type_code).type
+        for val in [np.pi+1j, np.inf-3j, 3j, 6.+1j]:
+            numpy_val = scalar(val)
+            assert hash(complex(numpy_val)) == hash(numpy_val)