提交 92ada581 编写于 作者: 刘琦

Merge branch 'fix_multi_tf_model_validate' into 'master'

fix common node name conflit when validate multi tf model

See merge request !408
......@@ -246,7 +246,7 @@ def merge_libs_and_tuning_results(target_soc,
embed_model_data)
def download_model_files(model_file_path,
def get_model_files(model_file_path,
model_output_dir,
weight_file_path=""):
model_file = ""
......@@ -255,11 +255,15 @@ def download_model_files(model_file_path,
model_file_path.startswith("https://"):
model_file = model_output_dir + "/model.pb"
urllib.urlretrieve(model_file_path, model_file)
else:
model_file = model_file_path
if weight_file_path.startswith("http://") or \
weight_file_path.startswith("https://"):
weight_file = model_output_dir + "/model.caffemodel"
urllib.urlretrieve(weight_file_path, weight_file)
else:
weight_file = weight_file_path
return model_file, weight_file
......@@ -350,7 +354,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
sh_commands.clear_mace_run_data(
target_abi, target_soc, phone_data_dir)
model_file_path, weight_file_path = download_model_files(
model_file_path, weight_file_path = get_model_files(
model_config["model_file_path"],
model_output_dir,
model_config.get("weight_file_path", ""))
......
......@@ -423,8 +423,8 @@ def gen_random_input(model_output_dir,
input_file_name="model_input"):
for input_name in input_nodes:
formatted_name = formatted_file_name(input_name, input_file_name)
if os.path.exists(formatted_name):
sh.rm(formatted_name)
if os.path.exists("%s/%s" % (model_output_dir, formatted_name)):
sh.rm("%s/%s" % (model_output_dir, formatted_name))
input_nodes_str = ",".join(input_nodes)
input_shapes_str = ":".join(input_shapes)
generate_input_data("%s/%s" % (model_output_dir, input_file_name),
......@@ -606,8 +606,9 @@ def validate_model(target_soc,
for output_name in output_nodes:
formatted_name = formatted_file_name(
output_name, output_file_name)
if os.path.exists(formatted_name):
sh.rm(formatted_name)
if os.path.exists("%s/%s" % (model_output_dir,
formatted_name)):
sh.rm("%s/%s" % (model_output_dir, formatted_name))
adb_pull("%s/%s" % (phone_data_dir, formatted_name),
model_output_dir, serialno)
validate(platform, model_file_path, "",
......
......@@ -75,6 +75,7 @@ def validate_tf_model(platform, mace_runtime, model_file, input_file,
print("Input graph file '" + model_file + "' does not exist!")
sys.exit(-1)
tf.reset_default_graph()
input_graph_def = tf.GraphDef()
with open(model_file, "rb") as f:
data = f.read()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册