From 23d1bead18ac984a7313200b3832e184377f492b Mon Sep 17 00:00:00 2001 From: liukai6 Date: Mon, 13 May 2019 18:25:42 +0800 Subject: [PATCH] move get_model_files from convert to format_model_config --- tools/converter.py | 158 ++++++++++++++++++++++----------------------- 1 file changed, 76 insertions(+), 82 deletions(-) diff --git a/tools/converter.py b/tools/converter.py index 74cb28b8..98fb6bc3 100644 --- a/tools/converter.py +++ b/tools/converter.py @@ -236,6 +236,77 @@ def sha256_checksum(fname): return hash_func.hexdigest() +def download_file(url, dst, num_retries=3): + from six.moves import urllib + + try: + urllib.request.urlretrieve(url, dst) + MaceLogger.info('\nDownloaded successfully.') + except (urllib.error.ContentTooShortError, urllib.error.HTTPError, + urllib.error.URLError) as e: + MaceLogger.warning('Download error:' + str(e)) + if num_retries > 0: + return download_file(url, dst, num_retries - 1) + else: + return False + return True + + +def get_model_files(model_config, model_output_dir): + if not os.path.exists(model_output_dir): + os.makedirs(model_output_dir) + model_file_path = model_config[YAMLKeyword.model_file_path] + model_sha256_checksum = model_config[YAMLKeyword.model_sha256_checksum] + weight_file_path = model_config.get(YAMLKeyword.weight_file_path, "") + weight_sha256_checksum = model_config.get(YAMLKeyword.weight_sha256_checksum, "") # noqa + quantize_range_file_path = model_config.get(YAMLKeyword.quantize_range_file, "") # noqa + model_file = model_file_path + weight_file = weight_file_path + quantize_range_file = quantize_range_file_path + + if model_file_path.startswith("http://") or \ + model_file_path.startswith("https://"): + model_file = model_output_dir + "/" + md5sum(model_file_path) + ".pb" + if not os.path.exists(model_file) or \ + sha256_checksum(model_file) != model_sha256_checksum: + MaceLogger.info("Downloading model, please wait ...") + if not download_file(model_file_path, model_file): + MaceLogger.error(ModuleName.MODEL_CONVERTER, + "Model download failed.") + model_config[YAMLKeyword.model_file_path] = model_file + + if sha256_checksum(model_file) != model_sha256_checksum: + MaceLogger.error(ModuleName.MODEL_CONVERTER, + "model file sha256checksum not match") + + if weight_file_path.startswith("http://") or \ + weight_file_path.startswith("https://"): + weight_file = \ + model_output_dir + "/" + md5sum(weight_file_path) + ".caffemodel" + if not os.path.exists(weight_file) or \ + sha256_checksum(weight_file) != weight_sha256_checksum: + MaceLogger.info("Downloading model weight, please wait ...") + if not download_file(weight_file_path, weight_file): + MaceLogger.error(ModuleName.MODEL_CONVERTER, + "Model download failed.") + model_config[YAMLKeyword.weight_file_path] = weight_file + + if weight_file: + if sha256_checksum(weight_file) != weight_sha256_checksum: + MaceLogger.error(ModuleName.MODEL_CONVERTER, + "weight file sha256checksum not match") + + if quantize_range_file_path.startswith("http://") or \ + quantize_range_file_path.startswith("https://"): + quantize_range_file = \ + model_output_dir + "/" + md5sum(quantize_range_file_path) \ + + ".range" + if not download_file(quantize_range_file_path, quantize_range_file): + MaceLogger.error(ModuleName.MODEL_CONVERTER, + "Model range file download failed.") + model_config[YAMLKeyword.quantize_range_file] = quantize_range_file + + def format_model_config(flags): with open(flags.config) as f: configs = yaml.load(f) @@ -351,6 +422,8 @@ def format_model_config(flags): else: model_config[YAMLKeyword.weight_sha256_checksum] = "" + get_model_files(model_config, BUILD_DOWNLOADS_DIR) + runtime = model_config.get(YAMLKeyword.runtime, "") mace_check(runtime in RuntimeTypeStrs, ModuleName.YAML_CONFIG, @@ -580,9 +653,6 @@ def format_model_config(flags): + str(WinogradParameters) + ". 0 for disable winograd convolution") - weight_file_path = model_config.get(YAMLKeyword.weight_file_path, "") - model_config[YAMLKeyword.weight_file_path] = weight_file_path - return configs @@ -623,73 +693,6 @@ def print_configuration(configs): MaceLogger.summary(StringFormatter.table(header, data, title)) -def download_file(url, dst, num_retries=3): - from six.moves import urllib - - try: - urllib.request.urlretrieve(url, dst) - MaceLogger.info('\nDownloaded successfully.') - except (urllib.error.ContentTooShortError, urllib.error.HTTPError, - urllib.error.URLError) as e: - MaceLogger.warning('Download error:' + str(e)) - if num_retries > 0: - return download_file(url, dst, num_retries - 1) - else: - return False - return True - - -def get_model_files(model_file_path, - model_sha256_checksum, - model_output_dir, - weight_file_path="", - weight_sha256_checksum="", - quantize_range_file_path=""): - model_file = model_file_path - weight_file = weight_file_path - quantize_range_file = quantize_range_file_path - - if model_file_path.startswith("http://") or \ - model_file_path.startswith("https://"): - model_file = model_output_dir + "/" + md5sum(model_file_path) + ".pb" - if not os.path.exists(model_file) or \ - sha256_checksum(model_file) != model_sha256_checksum: - MaceLogger.info("Downloading model, please wait ...") - if not download_file(model_file_path, model_file): - MaceLogger.error(ModuleName.MODEL_CONVERTER, - "Model download failed.") - - if sha256_checksum(model_file) != model_sha256_checksum: - MaceLogger.error(ModuleName.MODEL_CONVERTER, - "model file sha256checksum not match") - - if weight_file_path.startswith("http://") or \ - weight_file_path.startswith("https://"): - weight_file = \ - model_output_dir + "/" + md5sum(weight_file_path) + ".caffemodel" - if not os.path.exists(weight_file) or \ - sha256_checksum(weight_file) != weight_sha256_checksum: - MaceLogger.info("Downloading model weight, please wait ...") - if not download_file(weight_file_path, weight_file): - MaceLogger.error(ModuleName.MODEL_CONVERTER, - "Model download failed.") - - if weight_file: - if sha256_checksum(weight_file) != weight_sha256_checksum: - MaceLogger.error(ModuleName.MODEL_CONVERTER, - "weight file sha256checksum not match") - - if quantize_range_file_path.startswith("http://") or \ - quantize_range_file_path.startswith("https://"): - quantize_range_file = \ - model_output_dir + "/" + md5sum(quantize_range_file_path) \ - + ".range" - if not download_file(quantize_range_file_path, quantize_range_file): - MaceLogger.error(ModuleName.MODEL_CONVERTER, - "Model range file download failed.") - return model_file, weight_file, quantize_range_file - - def convert_model(configs, cl_mem_type): # Remove previous output dirs library_name = configs[YAMLKeyword.library_name] @@ -738,15 +741,6 @@ def convert_model(configs, cl_mem_type): else: model_config[YAMLKeyword.cl_mem_type] = "image" - model_file_path, weight_file_path, quantize_range_file_path = \ - get_model_files( - model_config[YAMLKeyword.model_file_path], - model_config[YAMLKeyword.model_sha256_checksum], - BUILD_DOWNLOADS_DIR, - model_config[YAMLKeyword.weight_file_path], - model_config[YAMLKeyword.weight_sha256_checksum], - model_config.get(YAMLKeyword.quantize_range_file, "")) - data_type = model_config[YAMLKeyword.data_type] # TODO(liuqi): support multiple subgraphs subgraphs = model_config[YAMLKeyword.subgraphs] @@ -755,8 +749,8 @@ def convert_model(configs, cl_mem_type): sh_commands.gen_model_code( model_codegen_dir, model_config[YAMLKeyword.platform], - model_file_path, - weight_file_path, + model_config[YAMLKeyword.model_file_path], + model_config[YAMLKeyword.weight_file_path], model_config[YAMLKeyword.model_sha256_checksum], model_config[YAMLKeyword.weight_sha256_checksum], ",".join(subgraphs[0][YAMLKeyword.input_tensors]), @@ -777,7 +771,7 @@ def convert_model(configs, cl_mem_type): model_config[YAMLKeyword.winograd], model_config[YAMLKeyword.quantize], model_config[YAMLKeyword.quantize_large_weights], - quantize_range_file_path, + model_config[YAMLKeyword.quantize_range_file], model_config[YAMLKeyword.change_concat_ranges], model_config[YAMLKeyword.obfuscate], configs[YAMLKeyword.model_graph_format], -- GitLab