未验证 提交 a35aa8cf 编写于 作者: Y Yuanle Liu 提交者: GitHub

fix paddle.static.load_inference_model api (#54793)

上级 18329d67
...@@ -528,11 +528,18 @@ def save_inference_model( ...@@ -528,11 +528,18 @@ def save_inference_model(
legacy_format=legacy_format, legacy_format=legacy_format,
) )
save_to_file(model_path, program_bytes) save_to_file(model_path, program_bytes)
# serialize and save params
params_bytes = _serialize_persistables(program, executor) vars = list(filter(is_persistable, program.list_vars()))
# program may not contain any parameter and just compute operation if len(vars) > 0:
if params_bytes is not None: save_dirname = os.path.dirname(params_path)
save_to_file(params_path, params_bytes) params_filename = os.path.basename(params_path)
save_vars(
executor,
dirname=save_dirname,
main_program=program,
predicate=is_persistable,
filename=params_filename,
)
@static_only @static_only
...@@ -581,6 +588,7 @@ def deserialize_program(data): ...@@ -581,6 +588,7 @@ def deserialize_program(data):
return program return program
# NOTE(liuyuanle): Due to load from memory, deserialize_persistables does not support loading weights with file sizes exceeding 2GB.
@static_only @static_only
def deserialize_persistables(program, data, executor): def deserialize_persistables(program, data, executor):
""" """
...@@ -797,16 +805,29 @@ def load_inference_model(path_prefix, executor, **kwargs): ...@@ -797,16 +805,29 @@ def load_inference_model(path_prefix, executor, **kwargs):
# load from memory # load from memory
if path_prefix is None: if path_prefix is None:
_logger.warning("Load inference model from memory is deprecated.") _logger.warning(
"Load inference model from memory is deprecated. Please specify path_prefix."
)
model_filename = kwargs.get('model_filename', None) model_filename = kwargs.get('model_filename', None)
params_filename = kwargs.get('params_filename', None) params_filename = kwargs.get('params_filename', None)
if params_filename is None: if params_filename is None:
raise ValueError( raise ValueError(
"params_filename cannot be None when path_prefix is None." "params_filename cannot be None when path_prefix is None."
) )
load_dirname = ''
program_bytes = model_filename program_bytes = model_filename
params_bytes = params_filename # deserialize bytes to program
program = deserialize_program(program_bytes)
vars = list(filter(is_persistable, program.list_vars()))
if len(vars) > 0:
load_vars(
executor,
# load from memory, dirname is None
dirname=None,
main_program=program,
predicate=is_persistable,
filename=params_filename,
)
# load from file # load from file
else: else:
# check and norm path_prefix # check and norm path_prefix
...@@ -841,24 +862,27 @@ def load_inference_model(path_prefix, executor, **kwargs): ...@@ -841,24 +862,27 @@ def load_inference_model(path_prefix, executor, **kwargs):
if not os.path.exists(params_path): if not os.path.exists(params_path):
params_path = os.path.join(path_prefix, params_filename) params_path = os.path.join(path_prefix, params_filename)
_logger.warning( _logger.warning(
"The old way to load inference model is deprecated." "The old way to load inference model is deprecated. Please specify path_prefix."
" model path: {}, params path: {}".format( " model path: {}, params path: {}".format(
model_path, params_path model_path, params_path
) )
) )
program_bytes = load_from_file(model_path) program_bytes = load_from_file(model_path)
load_dirname = os.path.dirname(params_path)
params_filename = os.path.basename(params_path)
# load params data
params_path = os.path.join(load_dirname, params_filename)
params_bytes = None
if os.path.exists(params_path):
params_bytes = load_from_file(params_path)
# deserialize bytes to program # deserialize bytes to program
program = deserialize_program(program_bytes) program = deserialize_program(program_bytes)
# deserialize bytes to params
deserialize_persistables(program, params_bytes, executor) vars = list(filter(is_persistable, program.list_vars()))
if len(vars) > 0:
load_dirname = os.path.dirname(params_path)
params_filename = os.path.basename(params_path)
load_vars(
executor,
dirname=load_dirname,
main_program=program,
predicate=is_persistable,
filename=params_filename,
)
feed_target_names = program.desc.get_feed_target_names() feed_target_names = program.desc.get_feed_target_names()
fetch_target_names = program.desc.get_fetch_target_names() fetch_target_names = program.desc.get_fetch_target_names()
...@@ -952,7 +976,7 @@ def save_vars( ...@@ -952,7 +976,7 @@ def save_vars(
if dirname is None and filename is None: if dirname is None and filename is None:
save_to_memory = True save_to_memory = True
main_program = paddle.static.io._get_valid_program(main_program) main_program = _get_valid_program(main_program)
if vars is None: if vars is None:
return save_vars( return save_vars(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册