From a35aa8cf522b2509024265e338252503bed86f68 Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Sun, 25 Jun 2023 15:02:07 +0800 Subject: [PATCH] fix paddle.static.load_inference_model api (#54793) --- python/paddle/static/io.py | 66 ++++++++++++++++++++++++++------------ 1 file changed, 45 insertions(+), 21 deletions(-) diff --git a/python/paddle/static/io.py b/python/paddle/static/io.py index eef8cacc0e6..5a58fe5440b 100644 --- a/python/paddle/static/io.py +++ b/python/paddle/static/io.py @@ -528,11 +528,18 @@ def save_inference_model( legacy_format=legacy_format, ) save_to_file(model_path, program_bytes) - # serialize and save params - params_bytes = _serialize_persistables(program, executor) - # program may not contain any parameter and just compute operation - if params_bytes is not None: - save_to_file(params_path, params_bytes) + + vars = list(filter(is_persistable, program.list_vars())) + if len(vars) > 0: + save_dirname = os.path.dirname(params_path) + params_filename = os.path.basename(params_path) + save_vars( + executor, + dirname=save_dirname, + main_program=program, + predicate=is_persistable, + filename=params_filename, + ) @static_only @@ -581,6 +588,7 @@ def deserialize_program(data): return program +# NOTE(liuyuanle): Due to load from memory, deserialize_persistables does not support loading weights with file sizes exceeding 2GB. @static_only def deserialize_persistables(program, data, executor): """ @@ -797,16 +805,29 @@ def load_inference_model(path_prefix, executor, **kwargs): # load from memory if path_prefix is None: - _logger.warning("Load inference model from memory is deprecated.") + _logger.warning( + "Load inference model from memory is deprecated. Please specify path_prefix." + ) model_filename = kwargs.get('model_filename', None) params_filename = kwargs.get('params_filename', None) if params_filename is None: raise ValueError( "params_filename cannot be None when path_prefix is None." ) - load_dirname = '' program_bytes = model_filename - params_bytes = params_filename + # deserialize bytes to program + program = deserialize_program(program_bytes) + + vars = list(filter(is_persistable, program.list_vars())) + if len(vars) > 0: + load_vars( + executor, + # load from memory, dirname is None + dirname=None, + main_program=program, + predicate=is_persistable, + filename=params_filename, + ) # load from file else: # check and norm path_prefix @@ -841,24 +862,27 @@ def load_inference_model(path_prefix, executor, **kwargs): if not os.path.exists(params_path): params_path = os.path.join(path_prefix, params_filename) _logger.warning( - "The old way to load inference model is deprecated." + "The old way to load inference model is deprecated. Please specify path_prefix." " model path: {}, params path: {}".format( model_path, params_path ) ) program_bytes = load_from_file(model_path) - load_dirname = os.path.dirname(params_path) - params_filename = os.path.basename(params_path) - # load params data - params_path = os.path.join(load_dirname, params_filename) - params_bytes = None - if os.path.exists(params_path): - params_bytes = load_from_file(params_path) - # deserialize bytes to program - program = deserialize_program(program_bytes) - # deserialize bytes to params - deserialize_persistables(program, params_bytes, executor) + # deserialize bytes to program + program = deserialize_program(program_bytes) + + vars = list(filter(is_persistable, program.list_vars())) + if len(vars) > 0: + load_dirname = os.path.dirname(params_path) + params_filename = os.path.basename(params_path) + load_vars( + executor, + dirname=load_dirname, + main_program=program, + predicate=is_persistable, + filename=params_filename, + ) feed_target_names = program.desc.get_feed_target_names() fetch_target_names = program.desc.get_fetch_target_names() @@ -952,7 +976,7 @@ def save_vars( if dirname is None and filename is None: save_to_memory = True - main_program = paddle.static.io._get_valid_program(main_program) + main_program = _get_valid_program(main_program) if vars is None: return save_vars( -- GitLab