diff --git a/tools/mace_tools.py b/tools/mace_tools.py index 39b0bba674744d3148c432bc7a8bd382de3d2c49..86071dbb9639d891712b38493ea2425d309d60fd 100644 --- a/tools/mace_tools.py +++ b/tools/mace_tools.py @@ -246,20 +246,24 @@ def merge_libs_and_tuning_results(target_soc, embed_model_data) -def download_model_files(model_file_path, - model_output_dir, - weight_file_path=""): +def get_model_files(model_file_path, + model_output_dir, + weight_file_path=""): model_file = "" weight_file = "" if model_file_path.startswith("http://") or \ model_file_path.startswith("https://"): model_file = model_output_dir + "/model.pb" urllib.urlretrieve(model_file_path, model_file) + else: + model_file = model_file_path if weight_file_path.startswith("http://") or \ weight_file_path.startswith("https://"): weight_file = model_output_dir + "/model.caffemodel" urllib.urlretrieve(weight_file_path, weight_file) + else: + weight_file = weight_file_path return model_file, weight_file @@ -350,7 +354,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level, sh_commands.clear_mace_run_data( target_abi, target_soc, phone_data_dir) - model_file_path, weight_file_path = download_model_files( + model_file_path, weight_file_path = get_model_files( model_config["model_file_path"], model_output_dir, model_config.get("weight_file_path", ""))