未验证 提交 c0a29d2f 编写于 作者: W WangZhen 提交者: GitHub

[JitLayer]Fix jit.save error when save params combined (#44504)

* Fix jit.save error when save params combined

* Change dict_value to list
上级 e32e4a1d
...@@ -483,9 +483,9 @@ def _get_output_vars(outputs, output_spec, with_hook=False): ...@@ -483,9 +483,9 @@ def _get_output_vars(outputs, output_spec, with_hook=False):
if isinstance(var, Variable): if isinstance(var, Variable):
output_vars_dict[var.name] = var output_vars_dict[var.name] = var
if output_spec is None: if output_spec is None:
result_list = output_vars_dict.values() result_list = list(output_vars_dict.values())
elif output_spec is not None and len(output_spec) == len(output_vars_dict): elif output_spec is not None and len(output_spec) == len(output_vars_dict):
result_list = output_vars_dict.values() result_list = list(output_vars_dict.values())
for var in output_spec: for var in output_spec:
if var.name not in output_vars_dict: if var.name not in output_vars_dict:
warnings.warn(name_no_exists_error % var.name) warnings.warn(name_no_exists_error % var.name)
...@@ -868,7 +868,7 @@ def save(layer, path, input_spec=None, **configs): ...@@ -868,7 +868,7 @@ def save(layer, path, input_spec=None, **configs):
layer, layer,
] ]
all_vars = set() combine_vars = {}
property_vals = [] # (value, key) property_vals = [] # (value, key)
for attr_func in functions: for attr_func in functions:
if isinstance(layer, Layer): if isinstance(layer, Layer):
...@@ -1020,19 +1020,28 @@ def save(layer, path, input_spec=None, **configs): ...@@ -1020,19 +1020,28 @@ def save(layer, path, input_spec=None, **configs):
program_only=configs._program_only, program_only=configs._program_only,
clip_extra=configs.clip_extra) clip_extra=configs.clip_extra)
# collect all vars if combine_params:
for var in concrete_program.main_program.list_vars(): clone_main_program = concrete_program.main_program.clone()
all_vars.add(var) clone_main_program = clone_main_program._prune_with_input(
input_var_names, output_vars)
for block in clone_main_program.blocks:
combine_vars.update(block.vars)
# save shared params # save shared params
if combine_params: if combine_params:
# sort vars by name
combine_vars = sorted(combine_vars.items(), key=lambda item: item[0])
ordered_vars = []
for name, var in combine_vars:
ordered_vars.append(var)
params_filename = file_prefix + INFER_PARAMS_SUFFIX params_filename = file_prefix + INFER_PARAMS_SUFFIX
with scope_guard(scope): with scope_guard(scope):
paddle.static.save_vars(Executor(_current_expected_place()), paddle.static.save_vars(Executor(_current_expected_place()),
dirname=model_path, dirname=model_path,
vars=list( vars=list(
filter(paddle.fluid.io.is_persistable, filter(paddle.fluid.io.is_persistable,
all_vars)), ordered_vars)),
filename=params_filename) filename=params_filename)
# TODO: save property # TODO: save property
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册