diff --git a/tools/sh_commands.py b/tools/sh_commands.py index f54c0a0d767e7cea8f29fad7c0e2356a39a03e19..6aa7783451d0c786217d367c304fdc1ec23b8018 100644 --- a/tools/sh_commands.py +++ b/tools/sh_commands.py @@ -423,8 +423,8 @@ def gen_random_input(model_output_dir, input_file_name="model_input"): for input_name in input_nodes: formatted_name = formatted_file_name(input_name, input_file_name) - if os.path.exists(formatted_name): - sh.rm(formatted_name) + if os.path.exists("%s/%s" % (model_output_dir, formatted_name)): + sh.rm("%s/%s" % (model_output_dir, formatted_name)) input_nodes_str = ",".join(input_nodes) input_shapes_str = ":".join(input_shapes) generate_input_data("%s/%s" % (model_output_dir, input_file_name), @@ -606,8 +606,9 @@ def validate_model(target_soc, for output_name in output_nodes: formatted_name = formatted_file_name( output_name, output_file_name) - if os.path.exists(formatted_name): - sh.rm(formatted_name) + if os.path.exists("%s/%s" % (model_output_dir, + formatted_name)): + sh.rm("%s/%s" % (model_output_dir, formatted_name)) adb_pull("%s/%s" % (phone_data_dir, formatted_name), model_output_dir, serialno) validate(platform, model_file_path, "", diff --git a/tools/validate.py b/tools/validate.py index 2c76ecae5626ba8b8de6037de9becb3b5809094f..608cf1c521182b3f4f9924390b3c8c03e09b5ba7 100644 --- a/tools/validate.py +++ b/tools/validate.py @@ -75,6 +75,7 @@ def validate_tf_model(platform, mace_runtime, model_file, input_file, print("Input graph file '" + model_file + "' does not exist!") sys.exit(-1) + tf.reset_default_graph() input_graph_def = tf.GraphDef() with open(model_file, "rb") as f: data = f.read()