diff --git a/example.yaml b/example.yaml index 1d1523b09b449f867cc6dc285c84b9a07a9a2874..c476e20924e203206ff85fbb977f67ab6658993f 100644 --- a/example.yaml +++ b/example.yaml @@ -1,7 +1,8 @@ # example.yaml # Each yaml file describes a exported library (could be named [target_abi]/libmace-${filename}.a), # which can contains more than one models -target_abi: armeabi-v7a # arm64-v8a +target_abis: [armeabi-v7a, arm64-v8a] +target_socs: [MSM8953] # target_socs not enabled yet embed_model_data: 1 vlog_level: 0 models: diff --git a/mace_tools.py b/mace_tools.py index 9533df877d95c319e1ddd7c6be023cdad4a80d79..69f105d0f5240d3d573634193e3316c9b10c9103 100644 --- a/mace_tools.py +++ b/mace_tools.py @@ -186,51 +186,55 @@ def main(unused_args): if FLAGS.mode == "validate": FLAGS.round = 1 - target_abi = configs["target_abi"] - libmace_name = get_libs(target_abi, configs) + # target_abi = configs["target_abi"] + # libmace_name = get_libs(target_abi, configs) # Transfer params by environment - os.environ["TARGET_ABI"] = target_abi + # os.environ["TARGET_ABI"] = target_abi os.environ["EMBED_MODEL_DATA"] = str(configs["embed_model_data"]) os.environ["VLOG_LEVEL"] = str(configs["vlog_level"]) os.environ["PROJECT_NAME"] = os.path.splitext(FLAGS.config)[0] - model_output_dirs = [] - for model_name in configs["models"]: + for target_abi in configs["target_abis"]: + libmace_name = get_libs(target_abi, configs) # Transfer params by environment - os.environ["MODEL_TAG"] = model_name - model_config = configs["models"][model_name] - for key in model_config: - os.environ[key.upper()] = str(model_config[key]) - - model_output_dir = FLAGS.output_dir + "/" + target_abi + "/" + os.path.splitext( - model_config["model_file_path"])[0] - model_output_dirs.append(model_output_dir) - - if FLAGS.mode == "build" or FLAGS.mode == "all": - if os.path.exists(model_output_dir): - shutil.rmtree(model_output_dir) - os.makedirs(model_output_dir) - clear_env() - - if FLAGS.mode == "build" or FLAGS.mode == "run" or FLAGS.mode == "validate" or FLAGS.mode == "all": - generate_random_input(model_output_dir) - - if FLAGS.mode == "build" or FLAGS.mode == "all": - generate_model_code() - build_mace_run_prod(model_output_dir, FLAGS.tuning, libmace_name) - - if FLAGS.mode == "run" or FLAGS.mode == "validate" or FLAGS.mode == "all": - run_model(model_output_dir, FLAGS.round) - - if FLAGS.mode == "benchmark": - benchmark_model(model_output_dir) - - if FLAGS.mode == "validate" or FLAGS.mode == "all": - validate_model(model_output_dir) - - if FLAGS.mode == "build" or FLAGS.mode == "merge" or FLAGS.mode == "all": - merge_libs_and_tuning_results(FLAGS.output_dir + "/" + target_abi, - model_output_dirs) + os.environ["TARGET_ABI"] = target_abi + model_output_dirs = [] + for model_name in configs["models"]: + # Transfer params by environment + os.environ["MODEL_TAG"] = model_name + model_config = configs["models"][model_name] + for key in model_config: + os.environ[key.upper()] = str(model_config[key]) + + model_output_dir = FLAGS.output_dir + "/" + target_abi + "/" + os.path.splitext( + model_config["model_file_path"])[0] + model_output_dirs.append(model_output_dir) + + if FLAGS.mode == "build" or FLAGS.mode == "all": + if os.path.exists(model_output_dir): + shutil.rmtree(model_output_dir) + os.makedirs(model_output_dir) + clear_env() + + if FLAGS.mode == "build" or FLAGS.mode == "run" or FLAGS.mode == "validate" or FLAGS.mode == "all": + generate_random_input(model_output_dir) + + if FLAGS.mode == "build" or FLAGS.mode == "all": + generate_model_code() + build_mace_run_prod(model_output_dir, FLAGS.tuning, libmace_name) + + if FLAGS.mode == "run" or FLAGS.mode == "validate" or FLAGS.mode == "all": + run_model(model_output_dir, FLAGS.round) + + if FLAGS.mode == "benchmark": + benchmark_model(model_output_dir) + + if FLAGS.mode == "validate" or FLAGS.mode == "all": + validate_model(model_output_dir) + + if FLAGS.mode == "build" or FLAGS.mode == "merge" or FLAGS.mode == "all": + merge_libs_and_tuning_results(FLAGS.output_dir + "/" + target_abi, + model_output_dirs) if __name__ == "__main__":