diff --git a/example.yaml b/example.yaml index c476e20924e203206ff85fbb977f67ab6658993f..05a69a172842ab10fd1e0a26c680b45f4925f27e 100644 --- a/example.yaml +++ b/example.yaml @@ -7,7 +7,7 @@ embed_model_data: 1 vlog_level: 0 models: preview_net: - model_file_path: path/to/model64.pb + model_file_path: path/to/model64.pb # also support http:// and https:// input_node: input_node output_node: output_node input_shape: 1,64,64,3 diff --git a/mace_tools.py b/mace_tools.py index 69f105d0f5240d3d573634193e3316c9b10c9103..22f7b757a36d6293ce371a8251b9bd2605b9454c 100644 --- a/mace_tools.py +++ b/mace_tools.py @@ -7,10 +7,12 @@ # --mode=all import argparse +import base64 import os import shutil import subprocess import sys +import urllib import yaml from ConfigParser import ConfigParser @@ -206,8 +208,8 @@ def main(unused_args): for key in model_config: os.environ[key.upper()] = str(model_config[key]) - model_output_dir = FLAGS.output_dir + "/" + target_abi + "/" + os.path.splitext( - model_config["model_file_path"])[0] + model_output_dir = FLAGS.output_dir + "/" + target_abi + "/" + model_name + "/" + base64.b16encode( + model_config["model_file_path"]) model_output_dirs.append(model_output_dir) if FLAGS.mode == "build" or FLAGS.mode == "all": @@ -216,6 +218,12 @@ def main(unused_args): os.makedirs(model_output_dir) clear_env() + # Support http:// and https:// + if model_config["model_file_path"].startswith( + "http://") or model_config["model_file_path"].startswith("https://"): + os.environ["MODEL_FILE_PATH"] = model_output_dir + "/model.pb" + urllib.urlretrieve(model_config["model_file_path"], os.environ["MODEL_FILE_PATH"]) + if FLAGS.mode == "build" or FLAGS.mode == "run" or FLAGS.mode == "validate" or FLAGS.mode == "all": generate_random_input(model_output_dir)