From 767ee6e3cd5f29bd88fe98f84e1f630e2b4f67af Mon Sep 17 00:00:00 2001 From: liyin Date: Fri, 2 Aug 2019 13:24:28 +0800 Subject: [PATCH] Sort range log topologically --- mace/python/tools/quantization/quantize_stat.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/mace/python/tools/quantization/quantize_stat.py b/mace/python/tools/quantization/quantize_stat.py index 31fd110f..a3e911bb 100644 --- a/mace/python/tools/quantization/quantize_stat.py +++ b/mace/python/tools/quantization/quantize_stat.py @@ -9,6 +9,8 @@ class QuantizeStat(object): @staticmethod def run(log_file, percentile, enhance, enhance_ratio): res = {} + tensor_id = {} + idx = 0 tensor_ranges = {} with open(log_file) as log: for line in log: @@ -20,6 +22,9 @@ class QuantizeStat(object): tensor_ranges[tensor_name] = ([], []) tensor_ranges[tensor_name][0].append(min_val) tensor_ranges[tensor_name][1].append(max_val) + if tensor_name not in tensor_id: + tensor_id[tensor_name] = idx + idx = idx + 1 for tensor_name in tensor_ranges: samples = len(tensor_ranges[tensor_name][0]) @@ -65,7 +70,10 @@ class QuantizeStat(object): res[tensor_name] = (cur_min, cur_max) - return res + res_list = [(name, rng) for (name, rng) in res.items()] + res_list.sort(key=lambda x: tensor_id[x[0]]) + + return res_list if __name__ == '__main__': @@ -93,5 +101,5 @@ if __name__ == '__main__': res = QuantizeStat.run(FLAGS.log_file, FLAGS.percentile, FLAGS.enhance, FLAGS.enhance_ratio) - for tensor in res: - print("%s@@%f,%f" % (tensor, res[tensor][0], res[tensor][1])) + for r in res: + print("%s@@%f,%f" % (r[0], r[1][0], r[1][1])) -- GitLab