提交 c723f6f4 编写于 作者: Y yejianwu

fix common node name conflit when validate multi tf model

上级 f2fc2197
...@@ -423,8 +423,8 @@ def gen_random_input(model_output_dir, ...@@ -423,8 +423,8 @@ def gen_random_input(model_output_dir,
input_file_name="model_input"): input_file_name="model_input"):
for input_name in input_nodes: for input_name in input_nodes:
formatted_name = formatted_file_name(input_name, input_file_name) formatted_name = formatted_file_name(input_name, input_file_name)
if os.path.exists(formatted_name): if os.path.exists("%s/%s" % (model_output_dir, formatted_name)):
sh.rm(formatted_name) sh.rm("%s/%s" % (model_output_dir, formatted_name))
input_nodes_str = ",".join(input_nodes) input_nodes_str = ",".join(input_nodes)
input_shapes_str = ":".join(input_shapes) input_shapes_str = ":".join(input_shapes)
generate_input_data("%s/%s" % (model_output_dir, input_file_name), generate_input_data("%s/%s" % (model_output_dir, input_file_name),
...@@ -606,8 +606,9 @@ def validate_model(target_soc, ...@@ -606,8 +606,9 @@ def validate_model(target_soc,
for output_name in output_nodes: for output_name in output_nodes:
formatted_name = formatted_file_name( formatted_name = formatted_file_name(
output_name, output_file_name) output_name, output_file_name)
if os.path.exists(formatted_name): if os.path.exists("%s/%s" % (model_output_dir,
sh.rm(formatted_name) formatted_name)):
sh.rm("%s/%s" % (model_output_dir, formatted_name))
adb_pull("%s/%s" % (phone_data_dir, formatted_name), adb_pull("%s/%s" % (phone_data_dir, formatted_name),
model_output_dir, serialno) model_output_dir, serialno)
validate(platform, model_file_path, "", validate(platform, model_file_path, "",
......
...@@ -75,6 +75,7 @@ def validate_tf_model(platform, mace_runtime, model_file, input_file, ...@@ -75,6 +75,7 @@ def validate_tf_model(platform, mace_runtime, model_file, input_file,
print("Input graph file '" + model_file + "' does not exist!") print("Input graph file '" + model_file + "' does not exist!")
sys.exit(-1) sys.exit(-1)
tf.reset_default_graph()
input_graph_def = tf.GraphDef() input_graph_def = tf.GraphDef()
with open(model_file, "rb") as f: with open(model_file, "rb") as f:
data = f.read() data = f.read()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册