From 6a577b1ebf19d6659a46d02e861fc21da8303054 Mon Sep 17 00:00:00 2001 From: yejianwu Date: Wed, 7 Feb 2018 17:20:01 +0800 Subject: [PATCH] add validate mode in tools/mace_tools.py --- mace_tools.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mace_tools.py b/mace_tools.py index 4909e5e2..e201e49c 100644 --- a/mace_tools.py +++ b/mace_tools.py @@ -235,7 +235,7 @@ def parse_args(): parser.add_argument( "--tuning", type="bool", default="true", help="Tune opencl params.") parser.add_argument( - "--mode", type=str, default="all", help="[build|run|merge|all].") + "--mode", type=str, default="all", help="[build|run|validate|merge|all].") return parser.parse_known_args() @@ -249,6 +249,9 @@ def main(unused_args): elif os.path.exists(os.path.join(FLAGS.output_dir, "libmace")): shutil.rmtree(os.path.join(FLAGS.output_dir, "libmace")) + if FLAGS.mode == "validate": + FLAGS.round = 1 + libmace_name = get_libs(configs) model_output_dirs = [] @@ -265,17 +268,17 @@ def main(unused_args): os.makedirs(model_output_dir) clear_env() - if FLAGS.mode == "build" or FLAGS.mode == "run" or FLAGS.mode == "all": + if FLAGS.mode == "build" or FLAGS.mode == "run" or FLAGS.mode == "validate" or FLAGS.mode == "all": generate_random_input(model_output_dir) if FLAGS.mode == "build" or FLAGS.mode == "all": generate_model_code() build_mace_run_prod(model_output_dir, FLAGS.tuning, libmace_name) - if FLAGS.mode == "run" or FLAGS.mode == "all": + if FLAGS.mode == "run" or FLAGS.mode == "validate" or FLAGS.mode == "all": run_model(model_output_dir, FLAGS.round) - if FLAGS.mode == "all": + if FLAGS.mode == "validate" or FLAGS.mode == "all": validate_model(model_output_dir) if FLAGS.mode == "build" or FLAGS.mode == "merge" or FLAGS.mode == "all": -- GitLab