From c723f6f44b1cb156c972f1b7b342477a4943ce13 Mon Sep 17 00:00:00 2001 From: yejianwu Date: Mon, 23 Apr 2018 14:25:08 +0800 Subject: [PATCH] fix common node name conflit when validate multi tf model --- tools/sh_commands.py | 9 +++++---- tools/validate.py | 1 + 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tools/sh_commands.py b/tools/sh_commands.py index f54c0a0d..6aa77834 100644 --- a/tools/sh_commands.py +++ b/tools/sh_commands.py @@ -423,8 +423,8 @@ def gen_random_input(model_output_dir, input_file_name="model_input"): for input_name in input_nodes: formatted_name = formatted_file_name(input_name, input_file_name) - if os.path.exists(formatted_name): - sh.rm(formatted_name) + if os.path.exists("%s/%s" % (model_output_dir, formatted_name)): + sh.rm("%s/%s" % (model_output_dir, formatted_name)) input_nodes_str = ",".join(input_nodes) input_shapes_str = ":".join(input_shapes) generate_input_data("%s/%s" % (model_output_dir, input_file_name), @@ -606,8 +606,9 @@ def validate_model(target_soc, for output_name in output_nodes: formatted_name = formatted_file_name( output_name, output_file_name) - if os.path.exists(formatted_name): - sh.rm(formatted_name) + if os.path.exists("%s/%s" % (model_output_dir, + formatted_name)): + sh.rm("%s/%s" % (model_output_dir, formatted_name)) adb_pull("%s/%s" % (phone_data_dir, formatted_name), model_output_dir, serialno) validate(platform, model_file_path, "", diff --git a/tools/validate.py b/tools/validate.py index 2c76ecae..608cf1c5 100644 --- a/tools/validate.py +++ b/tools/validate.py @@ -75,6 +75,7 @@ def validate_tf_model(platform, mace_runtime, model_file, input_file, print("Input graph file '" + model_file + "' does not exist!") sys.exit(-1) + tf.reset_default_graph() input_graph_def = tf.GraphDef() with open(model_file, "rb") as f: data = f.read() -- GitLab