diff --git a/tools/validate.py b/tools/validate.py index be499c1a3bed51ea2e0631d71dd3d3630ad97bff..3c85fb553f4329931a8178adafeb94d49ebbb4ac 100644 --- a/tools/validate.py +++ b/tools/validate.py @@ -18,8 +18,6 @@ import os import os.path import numpy as np import re -from scipy import spatial -from scipy import stats import common @@ -60,14 +58,22 @@ def calculate_sqnr(expected, actual): return signal_power_sum / (noise_power_sum + 1e-15) +def calculate_similarity(u, v, data_type=np.float64): + if u.dtype is not data_type: + u = u.astype(data_type) + if v.dtype is not data_type: + v = v.astype(data_type) + return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)) + + def compare_output(platform, device_type, output_name, mace_out_value, out_value, validation_threshold): if mace_out_value.size != 0: out_value = out_value.reshape(-1) mace_out_value = mace_out_value.reshape(-1) assert len(out_value) == len(mace_out_value) - similarity = (1 - spatial.distance.cosine(out_value, mace_out_value)) sqnr = calculate_sqnr(out_value, mace_out_value) + similarity = calculate_similarity(out_value, mace_out_value) common.MaceLogger.summary( output_name + ' MACE VS ' + platform.upper() + ' similarity: ' + str(similarity) + ' , sqnr: ' + str(sqnr))