From c0a29d2f7a65538f22aee5e24826b634ab78413a Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Mon, 25 Jul 2022 10:31:25 +0800 Subject: [PATCH] [JitLayer]Fix jit.save error when save params combined (#44504) * Fix jit.save error when save params combined * Change dict_value to list --- python/paddle/fluid/dygraph/jit.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index a55bcb9aaab..c3c3838f4be 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -483,9 +483,9 @@ def _get_output_vars(outputs, output_spec, with_hook=False): if isinstance(var, Variable): output_vars_dict[var.name] = var if output_spec is None: - result_list = output_vars_dict.values() + result_list = list(output_vars_dict.values()) elif output_spec is not None and len(output_spec) == len(output_vars_dict): - result_list = output_vars_dict.values() + result_list = list(output_vars_dict.values()) for var in output_spec: if var.name not in output_vars_dict: warnings.warn(name_no_exists_error % var.name) @@ -868,7 +868,7 @@ def save(layer, path, input_spec=None, **configs): layer, ] - all_vars = set() + combine_vars = {} property_vals = [] # (value, key) for attr_func in functions: if isinstance(layer, Layer): @@ -1020,19 +1020,28 @@ def save(layer, path, input_spec=None, **configs): program_only=configs._program_only, clip_extra=configs.clip_extra) - # collect all vars - for var in concrete_program.main_program.list_vars(): - all_vars.add(var) + if combine_params: + clone_main_program = concrete_program.main_program.clone() + clone_main_program = clone_main_program._prune_with_input( + input_var_names, output_vars) + for block in clone_main_program.blocks: + combine_vars.update(block.vars) # save shared params if combine_params: + # sort vars by name + combine_vars = sorted(combine_vars.items(), key=lambda item: item[0]) + ordered_vars = [] + for name, var in combine_vars: + ordered_vars.append(var) + params_filename = file_prefix + INFER_PARAMS_SUFFIX with scope_guard(scope): paddle.static.save_vars(Executor(_current_expected_place()), dirname=model_path, vars=list( filter(paddle.fluid.io.is_persistable, - all_vars)), + ordered_vars)), filename=params_filename) # TODO: save property -- GitLab