提交 a5ad26c8 编写于 作者: B Bin Li

keep downloaded models

上级 82ff5515
......@@ -39,6 +39,7 @@ from common import StringFormatter
# common definitions
################################
BUILD_OUTPUT_DIR = 'build'
BUILD_DOWNLOADS_DIR = BUILD_OUTPUT_DIR + '/downloads'
PHONE_DATA_DIR = "/data/local/tmp/mace_run"
MODEL_OUTPUT_DIR_NAME = 'model'
MODEL_HEADER_DIR_PATH = 'include/mace/public'
......@@ -536,36 +537,41 @@ def print_configuration(flags, configs):
MaceLogger.summary(StringFormatter.table(header, data, title))
def download_model_files(model_file_path,
model_output_dir,
weight_file_path=""):
MaceLogger.info("Downloading model, please wait ...")
if model_file_path.startswith("http://") or \
model_file_path.startswith("https://"):
model_file = model_output_dir + "/model.pb"
urllib.urlretrieve(model_file_path, model_file)
if weight_file_path.startswith("http://") or \
weight_file_path.startswith("https://"):
weight_file = model_output_dir + "/model.caffemodel"
urllib.urlretrieve(weight_file_path, weight_file)
MaceLogger.info("Model downloaded successfully.")
def get_model_files(model_file_path,
model_sha256_checksum,
model_output_dir,
weight_file_path="",
weight_sha256_checksum=""):
model_file = model_file_path
weight_file = weight_file_path
def get_model_files_path(model_file_path,
model_output_dir,
weight_file_path=""):
if model_file_path.startswith("http://") or \
model_file_path.startswith("https://"):
model_file = model_output_dir + "/model.pb"
else:
model_file = model_file_path
model_file = model_output_dir + "/" + md5sum(model_file_path) + ".pb"
if not os.path.exists(model_file) or \
sha256_checksum(model_file) != model_sha256_checksum:
MaceLogger.info("Downloading model, please wait ...")
urllib.urlretrieve(model_file_path, model_file)
MaceLogger.info("Model downloaded successfully.")
if sha256_checksum(model_file) != model_sha256_checksum:
MaceLogger.error(ModuleName.MODEL_CONVERTER,
"model file sha256checksum not match")
if weight_file_path.startswith("http://") or \
weight_file_path.startswith("https://"):
weight_file = model_output_dir + "/model.caffemodel"
else:
weight_file = weight_file_path
weight_file = \
model_output_dir + "/" + md5sum(weight_file_path) + ".caffemodel"
if not os.path.exists(weight_file) or \
sha256_checksum(weight_file) != weight_sha256_checksum:
MaceLogger.info("Downloading model weight, please wait ...")
urllib.urlretrieve(weight_file_path, weight_file)
MaceLogger.info("Model weight downloaded successfully.")
if weight_file:
if sha256_checksum(weight_file) != weight_sha256_checksum:
MaceLogger.error(ModuleName.MODEL_CONVERTER,
"weight file sha256checksum not match")
return model_file, weight_file
......@@ -578,6 +584,8 @@ def convert_model(configs):
elif os.path.exists(os.path.join(BUILD_OUTPUT_DIR, library_name)):
sh.rm("-rf", os.path.join(BUILD_OUTPUT_DIR, library_name))
os.makedirs(os.path.join(BUILD_OUTPUT_DIR, library_name))
if not os.path.exists(BUILD_DOWNLOADS_DIR):
os.makedirs(BUILD_DOWNLOADS_DIR)
model_output_dir = \
'%s/%s/%s' % (BUILD_OUTPUT_DIR, library_name, MODEL_OUTPUT_DIR_NAME)
......@@ -609,38 +617,12 @@ def convert_model(configs):
model_config = configs[YAMLKeyword.models][model_name]
runtime = model_config[YAMLKeyword.runtime]
# Create model build directory
model_path_digest = md5sum(
model_config[YAMLKeyword.model_file_path])
model_output_base_dir = "%s/%s/%s/%s/%s" % (
BUILD_OUTPUT_DIR, library_name, BUILD_TMP_DIR_NAME,
model_name, model_path_digest)
if os.path.exists(model_output_base_dir):
sh.rm("-rf", model_output_base_dir)
os.makedirs(model_output_base_dir)
download_model_files(
model_config[YAMLKeyword.model_file_path],
model_output_base_dir,
model_config[YAMLKeyword.weight_file_path])
model_file_path, weight_file_path = get_model_files_path(
model_file_path, weight_file_path = get_model_files(
model_config[YAMLKeyword.model_file_path],
model_output_base_dir,
model_config[YAMLKeyword.weight_file_path])
if sha256_checksum(model_file_path) != \
model_config[YAMLKeyword.model_sha256_checksum]:
MaceLogger.error(ModuleName.MODEL_CONVERTER,
"model file sha256checksum not match")
if weight_file_path:
if sha256_checksum(weight_file_path) != \
model_config[YAMLKeyword.weight_sha256_checksum]:
MaceLogger.error(ModuleName.MODEL_CONVERTER,
"weight file sha256checksum not match")
model_config[YAMLKeyword.model_sha256_checksum],
BUILD_DOWNLOADS_DIR,
model_config[YAMLKeyword.weight_file_path],
model_config[YAMLKeyword.weight_sha256_checksum])
data_type = model_config[YAMLKeyword.data_type]
# TODO(liuqi): support multiple subgraphs
......@@ -1068,10 +1050,12 @@ def run_specific_target(flags, configs, target_abi,
linkshared=linkshared,
)
if flags.validate:
model_file_path, weight_file_path = get_model_files_path(
model_file_path, weight_file_path = get_model_files(
model_config[YAMLKeyword.model_file_path],
model_output_base_dir,
model_config[YAMLKeyword.weight_file_path])
model_config[YAMLKeyword.model_sha256_checksum],
BUILD_DOWNLOADS_DIR,
model_config[YAMLKeyword.weight_file_path],
model_config[YAMLKeyword.weight_sha256_checksum])
sh_commands.validate_model(
abi=target_abi,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册