提交 23d1bead 编写于 作者: L liukai6

move get_model_files from convert to format_model_config

上级 a157a65f
......@@ -236,6 +236,77 @@ def sha256_checksum(fname):
return hash_func.hexdigest()
def download_file(url, dst, num_retries=3):
from six.moves import urllib
try:
urllib.request.urlretrieve(url, dst)
MaceLogger.info('\nDownloaded successfully.')
except (urllib.error.ContentTooShortError, urllib.error.HTTPError,
urllib.error.URLError) as e:
MaceLogger.warning('Download error:' + str(e))
if num_retries > 0:
return download_file(url, dst, num_retries - 1)
else:
return False
return True
def get_model_files(model_config, model_output_dir):
if not os.path.exists(model_output_dir):
os.makedirs(model_output_dir)
model_file_path = model_config[YAMLKeyword.model_file_path]
model_sha256_checksum = model_config[YAMLKeyword.model_sha256_checksum]
weight_file_path = model_config.get(YAMLKeyword.weight_file_path, "")
weight_sha256_checksum = model_config.get(YAMLKeyword.weight_sha256_checksum, "") # noqa
quantize_range_file_path = model_config.get(YAMLKeyword.quantize_range_file, "") # noqa
model_file = model_file_path
weight_file = weight_file_path
quantize_range_file = quantize_range_file_path
if model_file_path.startswith("http://") or \
model_file_path.startswith("https://"):
model_file = model_output_dir + "/" + md5sum(model_file_path) + ".pb"
if not os.path.exists(model_file) or \
sha256_checksum(model_file) != model_sha256_checksum:
MaceLogger.info("Downloading model, please wait ...")
if not download_file(model_file_path, model_file):
MaceLogger.error(ModuleName.MODEL_CONVERTER,
"Model download failed.")
model_config[YAMLKeyword.model_file_path] = model_file
if sha256_checksum(model_file) != model_sha256_checksum:
MaceLogger.error(ModuleName.MODEL_CONVERTER,
"model file sha256checksum not match")
if weight_file_path.startswith("http://") or \
weight_file_path.startswith("https://"):
weight_file = \
model_output_dir + "/" + md5sum(weight_file_path) + ".caffemodel"
if not os.path.exists(weight_file) or \
sha256_checksum(weight_file) != weight_sha256_checksum:
MaceLogger.info("Downloading model weight, please wait ...")
if not download_file(weight_file_path, weight_file):
MaceLogger.error(ModuleName.MODEL_CONVERTER,
"Model download failed.")
model_config[YAMLKeyword.weight_file_path] = weight_file
if weight_file:
if sha256_checksum(weight_file) != weight_sha256_checksum:
MaceLogger.error(ModuleName.MODEL_CONVERTER,
"weight file sha256checksum not match")
if quantize_range_file_path.startswith("http://") or \
quantize_range_file_path.startswith("https://"):
quantize_range_file = \
model_output_dir + "/" + md5sum(quantize_range_file_path) \
+ ".range"
if not download_file(quantize_range_file_path, quantize_range_file):
MaceLogger.error(ModuleName.MODEL_CONVERTER,
"Model range file download failed.")
model_config[YAMLKeyword.quantize_range_file] = quantize_range_file
def format_model_config(flags):
with open(flags.config) as f:
configs = yaml.load(f)
......@@ -351,6 +422,8 @@ def format_model_config(flags):
else:
model_config[YAMLKeyword.weight_sha256_checksum] = ""
get_model_files(model_config, BUILD_DOWNLOADS_DIR)
runtime = model_config.get(YAMLKeyword.runtime, "")
mace_check(runtime in RuntimeTypeStrs,
ModuleName.YAML_CONFIG,
......@@ -580,9 +653,6 @@ def format_model_config(flags):
+ str(WinogradParameters) +
". 0 for disable winograd convolution")
weight_file_path = model_config.get(YAMLKeyword.weight_file_path, "")
model_config[YAMLKeyword.weight_file_path] = weight_file_path
return configs
......@@ -623,73 +693,6 @@ def print_configuration(configs):
MaceLogger.summary(StringFormatter.table(header, data, title))
def download_file(url, dst, num_retries=3):
from six.moves import urllib
try:
urllib.request.urlretrieve(url, dst)
MaceLogger.info('\nDownloaded successfully.')
except (urllib.error.ContentTooShortError, urllib.error.HTTPError,
urllib.error.URLError) as e:
MaceLogger.warning('Download error:' + str(e))
if num_retries > 0:
return download_file(url, dst, num_retries - 1)
else:
return False
return True
def get_model_files(model_file_path,
model_sha256_checksum,
model_output_dir,
weight_file_path="",
weight_sha256_checksum="",
quantize_range_file_path=""):
model_file = model_file_path
weight_file = weight_file_path
quantize_range_file = quantize_range_file_path
if model_file_path.startswith("http://") or \
model_file_path.startswith("https://"):
model_file = model_output_dir + "/" + md5sum(model_file_path) + ".pb"
if not os.path.exists(model_file) or \
sha256_checksum(model_file) != model_sha256_checksum:
MaceLogger.info("Downloading model, please wait ...")
if not download_file(model_file_path, model_file):
MaceLogger.error(ModuleName.MODEL_CONVERTER,
"Model download failed.")
if sha256_checksum(model_file) != model_sha256_checksum:
MaceLogger.error(ModuleName.MODEL_CONVERTER,
"model file sha256checksum not match")
if weight_file_path.startswith("http://") or \
weight_file_path.startswith("https://"):
weight_file = \
model_output_dir + "/" + md5sum(weight_file_path) + ".caffemodel"
if not os.path.exists(weight_file) or \
sha256_checksum(weight_file) != weight_sha256_checksum:
MaceLogger.info("Downloading model weight, please wait ...")
if not download_file(weight_file_path, weight_file):
MaceLogger.error(ModuleName.MODEL_CONVERTER,
"Model download failed.")
if weight_file:
if sha256_checksum(weight_file) != weight_sha256_checksum:
MaceLogger.error(ModuleName.MODEL_CONVERTER,
"weight file sha256checksum not match")
if quantize_range_file_path.startswith("http://") or \
quantize_range_file_path.startswith("https://"):
quantize_range_file = \
model_output_dir + "/" + md5sum(quantize_range_file_path) \
+ ".range"
if not download_file(quantize_range_file_path, quantize_range_file):
MaceLogger.error(ModuleName.MODEL_CONVERTER,
"Model range file download failed.")
return model_file, weight_file, quantize_range_file
def convert_model(configs, cl_mem_type):
# Remove previous output dirs
library_name = configs[YAMLKeyword.library_name]
......@@ -738,15 +741,6 @@ def convert_model(configs, cl_mem_type):
else:
model_config[YAMLKeyword.cl_mem_type] = "image"
model_file_path, weight_file_path, quantize_range_file_path = \
get_model_files(
model_config[YAMLKeyword.model_file_path],
model_config[YAMLKeyword.model_sha256_checksum],
BUILD_DOWNLOADS_DIR,
model_config[YAMLKeyword.weight_file_path],
model_config[YAMLKeyword.weight_sha256_checksum],
model_config.get(YAMLKeyword.quantize_range_file, ""))
data_type = model_config[YAMLKeyword.data_type]
# TODO(liuqi): support multiple subgraphs
subgraphs = model_config[YAMLKeyword.subgraphs]
......@@ -755,8 +749,8 @@ def convert_model(configs, cl_mem_type):
sh_commands.gen_model_code(
model_codegen_dir,
model_config[YAMLKeyword.platform],
model_file_path,
weight_file_path,
model_config[YAMLKeyword.model_file_path],
model_config[YAMLKeyword.weight_file_path],
model_config[YAMLKeyword.model_sha256_checksum],
model_config[YAMLKeyword.weight_sha256_checksum],
",".join(subgraphs[0][YAMLKeyword.input_tensors]),
......@@ -777,7 +771,7 @@ def convert_model(configs, cl_mem_type):
model_config[YAMLKeyword.winograd],
model_config[YAMLKeyword.quantize],
model_config[YAMLKeyword.quantize_large_weights],
quantize_range_file_path,
model_config[YAMLKeyword.quantize_range_file],
model_config[YAMLKeyword.change_concat_ranges],
model_config[YAMLKeyword.obfuscate],
configs[YAMLKeyword.model_graph_format],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册