From 682ffb5888b89e40243f2d72ea38124bd1c21b0c Mon Sep 17 00:00:00 2001 From: liuqi Date: Thu, 29 Mar 2018 16:44:30 +0800 Subject: [PATCH] Remove skip_validation flag and add validation_inputs_data flag --- tools/example.yaml | 2 ++ tools/mace_tools.py | 8 ++------ 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tools/example.yaml b/tools/example.yaml index 882e8514..6fe860d0 100644 --- a/tools/example.yaml +++ b/tools/example.yaml @@ -20,6 +20,8 @@ models: dsp_mode: 0 obfuscate: 1 fast_conv: 0 + validation_inputs_data: + - path/to/input_files capture_net: platform: caffe model_file_path: path/to/model.prototxt diff --git a/tools/mace_tools.py b/tools/mace_tools.py index 2e0ea3fa..c952bafa 100644 --- a/tools/mace_tools.py +++ b/tools/mace_tools.py @@ -96,7 +96,6 @@ def generate_random_input(target_soc, model_output_dir, input_file_list[i].startswith("https://"): urllib.urlretrieve(input_file_list[i], dst_input_file) else: - print 'Copy input data:', dst_input_file shutil.copy(input_file_list[i], dst_input_file) def generate_model_code(): @@ -294,10 +293,8 @@ def main(unused_args): # Transfer params by environment os.environ["MODEL_TAG"] = model_name print '=======================', model_name, '=======================' - skip_validation = configs["models"][model_name].get( - "skip_validation", 0) model_config = configs["models"][model_name] - input_file_list = model_config.get("input_files", []) + input_file_list = model_config.get("validation_inputs_data", []) for key in model_config: if key in ['input_nodes', 'output_nodes'] and isinstance( model_config[key], list): @@ -357,8 +354,7 @@ def main(unused_args): if FLAGS.mode == "benchmark": benchmark_model(target_soc, model_output_dir, option_args) - if FLAGS.mode == "validate" or (FLAGS.mode == "all" and - skip_validation == 0): + if FLAGS.mode == "validate" or FLAGS.mode == "all": validate_model(target_soc, model_output_dir) if FLAGS.mode == "build" or FLAGS.mode == "merge" or FLAGS.mode == "all": -- GitLab