提交 803f2621 编写于 作者: 李寅

Merge branch 'keep_downloaded' into 'master'

keep downloaded models

See merge request !633
...@@ -39,6 +39,7 @@ from common import StringFormatter ...@@ -39,6 +39,7 @@ from common import StringFormatter
# common definitions # common definitions
################################ ################################
BUILD_OUTPUT_DIR = 'build' BUILD_OUTPUT_DIR = 'build'
BUILD_DOWNLOADS_DIR = BUILD_OUTPUT_DIR + '/downloads'
PHONE_DATA_DIR = "/data/local/tmp/mace_run" PHONE_DATA_DIR = "/data/local/tmp/mace_run"
MODEL_OUTPUT_DIR_NAME = 'model' MODEL_OUTPUT_DIR_NAME = 'model'
MODEL_HEADER_DIR_PATH = 'include/mace/public' MODEL_HEADER_DIR_PATH = 'include/mace/public'
...@@ -536,36 +537,41 @@ def print_configuration(flags, configs): ...@@ -536,36 +537,41 @@ def print_configuration(flags, configs):
MaceLogger.summary(StringFormatter.table(header, data, title)) MaceLogger.summary(StringFormatter.table(header, data, title))
def download_model_files(model_file_path, def get_model_files(model_file_path,
model_output_dir, model_sha256_checksum,
weight_file_path=""): model_output_dir,
MaceLogger.info("Downloading model, please wait ...") weight_file_path="",
if model_file_path.startswith("http://") or \ weight_sha256_checksum=""):
model_file_path.startswith("https://"): model_file = model_file_path
model_file = model_output_dir + "/model.pb" weight_file = weight_file_path
urllib.urlretrieve(model_file_path, model_file)
if weight_file_path.startswith("http://") or \
weight_file_path.startswith("https://"):
weight_file = model_output_dir + "/model.caffemodel"
urllib.urlretrieve(weight_file_path, weight_file)
MaceLogger.info("Model downloaded successfully.")
def get_model_files_path(model_file_path,
model_output_dir,
weight_file_path=""):
if model_file_path.startswith("http://") or \ if model_file_path.startswith("http://") or \
model_file_path.startswith("https://"): model_file_path.startswith("https://"):
model_file = model_output_dir + "/model.pb" model_file = model_output_dir + "/" + md5sum(model_file_path) + ".pb"
else: if not os.path.exists(model_file) or \
model_file = model_file_path sha256_checksum(model_file) != model_sha256_checksum:
MaceLogger.info("Downloading model, please wait ...")
urllib.urlretrieve(model_file_path, model_file)
MaceLogger.info("Model downloaded successfully.")
if sha256_checksum(model_file) != model_sha256_checksum:
MaceLogger.error(ModuleName.MODEL_CONVERTER,
"model file sha256checksum not match")
if weight_file_path.startswith("http://") or \ if weight_file_path.startswith("http://") or \
weight_file_path.startswith("https://"): weight_file_path.startswith("https://"):
weight_file = model_output_dir + "/model.caffemodel" weight_file = \
else: model_output_dir + "/" + md5sum(weight_file_path) + ".caffemodel"
weight_file = weight_file_path if not os.path.exists(weight_file) or \
sha256_checksum(weight_file) != weight_sha256_checksum:
MaceLogger.info("Downloading model weight, please wait ...")
urllib.urlretrieve(weight_file_path, weight_file)
MaceLogger.info("Model weight downloaded successfully.")
if weight_file:
if sha256_checksum(weight_file) != weight_sha256_checksum:
MaceLogger.error(ModuleName.MODEL_CONVERTER,
"weight file sha256checksum not match")
return model_file, weight_file return model_file, weight_file
...@@ -578,6 +584,8 @@ def convert_model(configs): ...@@ -578,6 +584,8 @@ def convert_model(configs):
elif os.path.exists(os.path.join(BUILD_OUTPUT_DIR, library_name)): elif os.path.exists(os.path.join(BUILD_OUTPUT_DIR, library_name)):
sh.rm("-rf", os.path.join(BUILD_OUTPUT_DIR, library_name)) sh.rm("-rf", os.path.join(BUILD_OUTPUT_DIR, library_name))
os.makedirs(os.path.join(BUILD_OUTPUT_DIR, library_name)) os.makedirs(os.path.join(BUILD_OUTPUT_DIR, library_name))
if not os.path.exists(BUILD_DOWNLOADS_DIR):
os.makedirs(BUILD_DOWNLOADS_DIR)
model_output_dir = \ model_output_dir = \
'%s/%s/%s' % (BUILD_OUTPUT_DIR, library_name, MODEL_OUTPUT_DIR_NAME) '%s/%s/%s' % (BUILD_OUTPUT_DIR, library_name, MODEL_OUTPUT_DIR_NAME)
...@@ -609,38 +617,12 @@ def convert_model(configs): ...@@ -609,38 +617,12 @@ def convert_model(configs):
model_config = configs[YAMLKeyword.models][model_name] model_config = configs[YAMLKeyword.models][model_name]
runtime = model_config[YAMLKeyword.runtime] runtime = model_config[YAMLKeyword.runtime]
# Create model build directory model_file_path, weight_file_path = get_model_files(
model_path_digest = md5sum(
model_config[YAMLKeyword.model_file_path])
model_output_base_dir = "%s/%s/%s/%s/%s" % (
BUILD_OUTPUT_DIR, library_name, BUILD_TMP_DIR_NAME,
model_name, model_path_digest)
if os.path.exists(model_output_base_dir):
sh.rm("-rf", model_output_base_dir)
os.makedirs(model_output_base_dir)
download_model_files(
model_config[YAMLKeyword.model_file_path],
model_output_base_dir,
model_config[YAMLKeyword.weight_file_path])
model_file_path, weight_file_path = get_model_files_path(
model_config[YAMLKeyword.model_file_path], model_config[YAMLKeyword.model_file_path],
model_output_base_dir, model_config[YAMLKeyword.model_sha256_checksum],
model_config[YAMLKeyword.weight_file_path]) BUILD_DOWNLOADS_DIR,
model_config[YAMLKeyword.weight_file_path],
if sha256_checksum(model_file_path) != \ model_config[YAMLKeyword.weight_sha256_checksum])
model_config[YAMLKeyword.model_sha256_checksum]:
MaceLogger.error(ModuleName.MODEL_CONVERTER,
"model file sha256checksum not match")
if weight_file_path:
if sha256_checksum(weight_file_path) != \
model_config[YAMLKeyword.weight_sha256_checksum]:
MaceLogger.error(ModuleName.MODEL_CONVERTER,
"weight file sha256checksum not match")
data_type = model_config[YAMLKeyword.data_type] data_type = model_config[YAMLKeyword.data_type]
# TODO(liuqi): support multiple subgraphs # TODO(liuqi): support multiple subgraphs
...@@ -1068,10 +1050,12 @@ def run_specific_target(flags, configs, target_abi, ...@@ -1068,10 +1050,12 @@ def run_specific_target(flags, configs, target_abi,
linkshared=linkshared, linkshared=linkshared,
) )
if flags.validate: if flags.validate:
model_file_path, weight_file_path = get_model_files_path( model_file_path, weight_file_path = get_model_files(
model_config[YAMLKeyword.model_file_path], model_config[YAMLKeyword.model_file_path],
model_output_base_dir, model_config[YAMLKeyword.model_sha256_checksum],
model_config[YAMLKeyword.weight_file_path]) BUILD_DOWNLOADS_DIR,
model_config[YAMLKeyword.weight_file_path],
model_config[YAMLKeyword.weight_sha256_checksum])
sh_commands.validate_model( sh_commands.validate_model(
abi=target_abi, abi=target_abi,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册