未验证 提交 a35aa8cf 编写于 作者: Y Yuanle Liu 提交者: GitHub

fix paddle.static.load_inference_model api (#54793)

上级 18329d67
......@@ -528,11 +528,18 @@ def save_inference_model(
legacy_format=legacy_format,
)
save_to_file(model_path, program_bytes)
# serialize and save params
params_bytes = _serialize_persistables(program, executor)
# program may not contain any parameter and just compute operation
if params_bytes is not None:
save_to_file(params_path, params_bytes)
vars = list(filter(is_persistable, program.list_vars()))
if len(vars) > 0:
save_dirname = os.path.dirname(params_path)
params_filename = os.path.basename(params_path)
save_vars(
executor,
dirname=save_dirname,
main_program=program,
predicate=is_persistable,
filename=params_filename,
)
@static_only
......@@ -581,6 +588,7 @@ def deserialize_program(data):
return program
# NOTE(liuyuanle): Due to load from memory, deserialize_persistables does not support loading weights with file sizes exceeding 2GB.
@static_only
def deserialize_persistables(program, data, executor):
"""
......@@ -797,16 +805,29 @@ def load_inference_model(path_prefix, executor, **kwargs):
# load from memory
if path_prefix is None:
_logger.warning("Load inference model from memory is deprecated.")
_logger.warning(
"Load inference model from memory is deprecated. Please specify path_prefix."
)
model_filename = kwargs.get('model_filename', None)
params_filename = kwargs.get('params_filename', None)
if params_filename is None:
raise ValueError(
"params_filename cannot be None when path_prefix is None."
)
load_dirname = ''
program_bytes = model_filename
params_bytes = params_filename
# deserialize bytes to program
program = deserialize_program(program_bytes)
vars = list(filter(is_persistable, program.list_vars()))
if len(vars) > 0:
load_vars(
executor,
# load from memory, dirname is None
dirname=None,
main_program=program,
predicate=is_persistable,
filename=params_filename,
)
# load from file
else:
# check and norm path_prefix
......@@ -841,24 +862,27 @@ def load_inference_model(path_prefix, executor, **kwargs):
if not os.path.exists(params_path):
params_path = os.path.join(path_prefix, params_filename)
_logger.warning(
"The old way to load inference model is deprecated."
"The old way to load inference model is deprecated. Please specify path_prefix."
" model path: {}, params path: {}".format(
model_path, params_path
)
)
program_bytes = load_from_file(model_path)
load_dirname = os.path.dirname(params_path)
params_filename = os.path.basename(params_path)
# load params data
params_path = os.path.join(load_dirname, params_filename)
params_bytes = None
if os.path.exists(params_path):
params_bytes = load_from_file(params_path)
# deserialize bytes to program
program = deserialize_program(program_bytes)
# deserialize bytes to params
deserialize_persistables(program, params_bytes, executor)
vars = list(filter(is_persistable, program.list_vars()))
if len(vars) > 0:
load_dirname = os.path.dirname(params_path)
params_filename = os.path.basename(params_path)
load_vars(
executor,
dirname=load_dirname,
main_program=program,
predicate=is_persistable,
filename=params_filename,
)
feed_target_names = program.desc.get_feed_target_names()
fetch_target_names = program.desc.get_fetch_target_names()
......@@ -952,7 +976,7 @@ def save_vars(
if dirname is None and filename is None:
save_to_memory = True
main_program = paddle.static.io._get_valid_program(main_program)
main_program = _get_valid_program(main_program)
if vars is None:
return save_vars(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册