diff --git a/tools/validate.py b/tools/validate.py index 8b8a9e82c8eccd10243ba1232984c080c873900d..4f7800426834d52006e8785a5acf21daabd889e5 100644 --- a/tools/validate.py +++ b/tools/validate.py @@ -69,9 +69,9 @@ def calculate_similarity(u, v, data_type=np.float64): norm = u_norm * v_norm if norm == 0: if u_norm == 0 and v_norm == 0: - return 1 + return data_type(1) else: - return 0 + return data_type(0) else: return np.dot(u, v) / norm