From 89b49ec0784a6a1fb99e28600f9d541f7067fed4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Wed, 10 Oct 2018 11:04:32 +0800 Subject: [PATCH] Add sqnr metric --- tools/validate.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tools/validate.py b/tools/validate.py index 516cf512..bc3a9c5d 100644 --- a/tools/validate.py +++ b/tools/validate.py @@ -49,6 +49,17 @@ def load_data(file, data_type='float32'): return np.empty([0]) +def calculate_sqnr(expected, actual): + noise = expected - actual + + def power_sum(xs): + return sum([x * x for x in xs]) + + signal_power_sum = power_sum(expected) + noise_power_sum = power_sum(noise) + return signal_power_sum / (noise_power_sum + 1e-15) + + def compare_output(platform, device_type, output_name, mace_out_value, out_value, validation_threshold): if mace_out_value.size != 0: @@ -56,9 +67,10 @@ def compare_output(platform, device_type, output_name, mace_out_value, mace_out_value = mace_out_value.reshape(-1) assert len(out_value) == len(mace_out_value) similarity = (1 - spatial.distance.cosine(out_value, mace_out_value)) + sqnr = calculate_sqnr(out_value, mace_out_value) common.MaceLogger.summary( output_name + ' MACE VS ' + platform.upper() - + ' similarity: ' + str(similarity)) + + ' similarity: ' + str(similarity) + ' , sqnr: ' + str(sqnr)) if similarity > validation_threshold: common.MaceLogger.summary( common.StringFormatter.block("Similarity Test Passed")) -- GitLab