diff --git a/tools/converter.py b/tools/converter.py index 1f0f568cb963231d9f6435f2d7833cda773567d9..9e7973a94ba328b63d5191ee4f739bf12dc7a76f 100644 --- a/tools/converter.py +++ b/tools/converter.py @@ -39,6 +39,7 @@ from common import StringFormatter # common definitions ################################ BUILD_OUTPUT_DIR = 'build' +BUILD_DOWNLOADS_DIR = BUILD_OUTPUT_DIR + '/downloads' PHONE_DATA_DIR = "/data/local/tmp/mace_run" MODEL_OUTPUT_DIR_NAME = 'model' MODEL_HEADER_DIR_PATH = 'include/mace/public' @@ -536,36 +537,41 @@ def print_configuration(flags, configs): MaceLogger.summary(StringFormatter.table(header, data, title)) -def download_model_files(model_file_path, - model_output_dir, - weight_file_path=""): - MaceLogger.info("Downloading model, please wait ...") - if model_file_path.startswith("http://") or \ - model_file_path.startswith("https://"): - model_file = model_output_dir + "/model.pb" - urllib.urlretrieve(model_file_path, model_file) - - if weight_file_path.startswith("http://") or \ - weight_file_path.startswith("https://"): - weight_file = model_output_dir + "/model.caffemodel" - urllib.urlretrieve(weight_file_path, weight_file) - MaceLogger.info("Model downloaded successfully.") - +def get_model_files(model_file_path, + model_sha256_checksum, + model_output_dir, + weight_file_path="", + weight_sha256_checksum=""): + model_file = model_file_path + weight_file = weight_file_path -def get_model_files_path(model_file_path, - model_output_dir, - weight_file_path=""): if model_file_path.startswith("http://") or \ model_file_path.startswith("https://"): - model_file = model_output_dir + "/model.pb" - else: - model_file = model_file_path + model_file = model_output_dir + "/" + md5sum(model_file_path) + ".pb" + if not os.path.exists(model_file) or \ + sha256_checksum(model_file) != model_sha256_checksum: + MaceLogger.info("Downloading model, please wait ...") + urllib.urlretrieve(model_file_path, model_file) + MaceLogger.info("Model downloaded successfully.") + + if sha256_checksum(model_file) != model_sha256_checksum: + MaceLogger.error(ModuleName.MODEL_CONVERTER, + "model file sha256checksum not match") if weight_file_path.startswith("http://") or \ weight_file_path.startswith("https://"): - weight_file = model_output_dir + "/model.caffemodel" - else: - weight_file = weight_file_path + weight_file = \ + model_output_dir + "/" + md5sum(weight_file_path) + ".caffemodel" + if not os.path.exists(weight_file) or \ + sha256_checksum(weight_file) != weight_sha256_checksum: + MaceLogger.info("Downloading model weight, please wait ...") + urllib.urlretrieve(weight_file_path, weight_file) + MaceLogger.info("Model weight downloaded successfully.") + + if weight_file: + if sha256_checksum(weight_file) != weight_sha256_checksum: + MaceLogger.error(ModuleName.MODEL_CONVERTER, + "weight file sha256checksum not match") return model_file, weight_file @@ -578,6 +584,8 @@ def convert_model(configs): elif os.path.exists(os.path.join(BUILD_OUTPUT_DIR, library_name)): sh.rm("-rf", os.path.join(BUILD_OUTPUT_DIR, library_name)) os.makedirs(os.path.join(BUILD_OUTPUT_DIR, library_name)) + if not os.path.exists(BUILD_DOWNLOADS_DIR): + os.makedirs(BUILD_DOWNLOADS_DIR) model_output_dir = \ '%s/%s/%s' % (BUILD_OUTPUT_DIR, library_name, MODEL_OUTPUT_DIR_NAME) @@ -609,38 +617,12 @@ def convert_model(configs): model_config = configs[YAMLKeyword.models][model_name] runtime = model_config[YAMLKeyword.runtime] - # Create model build directory - model_path_digest = md5sum( - model_config[YAMLKeyword.model_file_path]) - - model_output_base_dir = "%s/%s/%s/%s/%s" % ( - BUILD_OUTPUT_DIR, library_name, BUILD_TMP_DIR_NAME, - model_name, model_path_digest) - - if os.path.exists(model_output_base_dir): - sh.rm("-rf", model_output_base_dir) - os.makedirs(model_output_base_dir) - - download_model_files( - model_config[YAMLKeyword.model_file_path], - model_output_base_dir, - model_config[YAMLKeyword.weight_file_path]) - - model_file_path, weight_file_path = get_model_files_path( + model_file_path, weight_file_path = get_model_files( model_config[YAMLKeyword.model_file_path], - model_output_base_dir, - model_config[YAMLKeyword.weight_file_path]) - - if sha256_checksum(model_file_path) != \ - model_config[YAMLKeyword.model_sha256_checksum]: - MaceLogger.error(ModuleName.MODEL_CONVERTER, - "model file sha256checksum not match") - - if weight_file_path: - if sha256_checksum(weight_file_path) != \ - model_config[YAMLKeyword.weight_sha256_checksum]: - MaceLogger.error(ModuleName.MODEL_CONVERTER, - "weight file sha256checksum not match") + model_config[YAMLKeyword.model_sha256_checksum], + BUILD_DOWNLOADS_DIR, + model_config[YAMLKeyword.weight_file_path], + model_config[YAMLKeyword.weight_sha256_checksum]) data_type = model_config[YAMLKeyword.data_type] # TODO(liuqi): support multiple subgraphs @@ -1068,10 +1050,12 @@ def run_specific_target(flags, configs, target_abi, linkshared=linkshared, ) if flags.validate: - model_file_path, weight_file_path = get_model_files_path( + model_file_path, weight_file_path = get_model_files( model_config[YAMLKeyword.model_file_path], - model_output_base_dir, - model_config[YAMLKeyword.weight_file_path]) + model_config[YAMLKeyword.model_sha256_checksum], + BUILD_DOWNLOADS_DIR, + model_config[YAMLKeyword.weight_file_path], + model_config[YAMLKeyword.weight_sha256_checksum]) sh_commands.validate_model( abi=target_abi,