diff --git a/tools/validate.py b/tools/validate.py index 516cf5128dddf1f1557bfdb8dd5db5a0364660bc..bc3a9c5db46e2851d6309ad2b0c181b6a9acd26d 100644 --- a/tools/validate.py +++ b/tools/validate.py @@ -49,6 +49,17 @@ def load_data(file, data_type='float32'): return np.empty([0]) +def calculate_sqnr(expected, actual): + noise = expected - actual + + def power_sum(xs): + return sum([x * x for x in xs]) + + signal_power_sum = power_sum(expected) + noise_power_sum = power_sum(noise) + return signal_power_sum / (noise_power_sum + 1e-15) + + def compare_output(platform, device_type, output_name, mace_out_value, out_value, validation_threshold): if mace_out_value.size != 0: @@ -56,9 +67,10 @@ def compare_output(platform, device_type, output_name, mace_out_value, mace_out_value = mace_out_value.reshape(-1) assert len(out_value) == len(mace_out_value) similarity = (1 - spatial.distance.cosine(out_value, mace_out_value)) + sqnr = calculate_sqnr(out_value, mace_out_value) common.MaceLogger.summary( output_name + ' MACE VS ' + platform.upper() - + ' similarity: ' + str(similarity)) + + ' similarity: ' + str(similarity) + ' , sqnr: ' + str(sqnr)) if similarity > validation_threshold: common.MaceLogger.summary( common.StringFormatter.block("Similarity Test Passed"))