From 5d2b21136b3725b3e71864b530ea2d973ac2294a Mon Sep 17 00:00:00 2001 From: liutuo Date: Mon, 5 Nov 2018 16:20:45 +0800 Subject: [PATCH] fix calculate cosine similarity accurate --- tools/validate.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tools/validate.py b/tools/validate.py index be499c1a..3c85fb55 100644 --- a/tools/validate.py +++ b/tools/validate.py @@ -18,8 +18,6 @@ import os import os.path import numpy as np import re -from scipy import spatial -from scipy import stats import common @@ -60,14 +58,22 @@ def calculate_sqnr(expected, actual): return signal_power_sum / (noise_power_sum + 1e-15) +def calculate_similarity(u, v, data_type=np.float64): + if u.dtype is not data_type: + u = u.astype(data_type) + if v.dtype is not data_type: + v = v.astype(data_type) + return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)) + + def compare_output(platform, device_type, output_name, mace_out_value, out_value, validation_threshold): if mace_out_value.size != 0: out_value = out_value.reshape(-1) mace_out_value = mace_out_value.reshape(-1) assert len(out_value) == len(mace_out_value) - similarity = (1 - spatial.distance.cosine(out_value, mace_out_value)) sqnr = calculate_sqnr(out_value, mace_out_value) + similarity = calculate_similarity(out_value, mace_out_value) common.MaceLogger.summary( output_name + ' MACE VS ' + platform.upper() + ' similarity: ' + str(similarity) + ' , sqnr: ' + str(sqnr)) -- GitLab