未验证 提交 3f5d0083 编写于 作者: R RichardWooSJTU 提交者: GitHub

[PaddleInference] Fix llm inference dy2static error (#56688)

* fix llm inference dy2static error

* use kwargs instead of default argument
上级 323566d5
...@@ -418,6 +418,9 @@ class _SaveLoadConfig: ...@@ -418,6 +418,9 @@ class _SaveLoadConfig:
# when need to save a prune model, use input_names_after_prune to specify the inputs left after pruning # when need to save a prune model, use input_names_after_prune to specify the inputs left after pruning
self.input_names_after_prune = None self.input_names_after_prune = None
# in the scene of llm-inference, prunning program can cause unexpectable result, an option to skip prune is necessary
self.skip_prune_program = False
@property @property
def output_spec(self): def output_spec(self):
return self._output_spec return self._output_spec
...@@ -497,6 +500,7 @@ def _parse_save_configs(configs): ...@@ -497,6 +500,7 @@ def _parse_save_configs(configs):
"clip_extra", "clip_extra",
"skip_forward", "skip_forward",
"input_names_after_prune", "input_names_after_prune",
"skip_prune_program",
] ]
# input check # input check
...@@ -517,6 +521,7 @@ def _parse_save_configs(configs): ...@@ -517,6 +521,7 @@ def _parse_save_configs(configs):
inner_config.input_names_after_prune = configs.get( inner_config.input_names_after_prune = configs.get(
"input_names_after_prune", None "input_names_after_prune", None
) )
inner_config.skip_prune_program = configs.get("skip_prune_program", False)
return inner_config return inner_config
...@@ -1259,6 +1264,7 @@ def save(layer, path, input_spec=None, **configs): ...@@ -1259,6 +1264,7 @@ def save(layer, path, input_spec=None, **configs):
executor=Executor(_current_expected_place()), executor=Executor(_current_expected_place()),
program=concrete_program.main_program.clone(), program=concrete_program.main_program.clone(),
clip_extra=configs.clip_extra, clip_extra=configs.clip_extra,
skip_prune_program=configs.skip_prune_program,
) )
if combine_params: if combine_params:
......
...@@ -187,7 +187,7 @@ def append_fetch_ops( ...@@ -187,7 +187,7 @@ def append_fetch_ops(
) )
def normalize_program(program, feed_vars, fetch_vars): def normalize_program(program, feed_vars, fetch_vars, **kwargs):
""" """
Normalize/Optimize a program according to feed_vars and fetch_vars. Normalize/Optimize a program according to feed_vars and fetch_vars.
...@@ -196,6 +196,8 @@ def normalize_program(program, feed_vars, fetch_vars): ...@@ -196,6 +196,8 @@ def normalize_program(program, feed_vars, fetch_vars):
program(Program): Specify a program you want to optimize. program(Program): Specify a program you want to optimize.
feed_vars(Tensor | list[Tensor]): Variables needed by inference. feed_vars(Tensor | list[Tensor]): Variables needed by inference.
fetch_vars(Tensor | list[Tensor]): Variables returned by inference. fetch_vars(Tensor | list[Tensor]): Variables returned by inference.
kwargs: Supported keys including ``skip_prune_program``.
- skip_prune_program(bool): whether to skip prunning program. Defaults to False.
Returns: Returns:
Program: Normalized/Optimized program. Program: Normalized/Optimized program.
...@@ -277,6 +279,9 @@ def normalize_program(program, feed_vars, fetch_vars): ...@@ -277,6 +279,9 @@ def normalize_program(program, feed_vars, fetch_vars):
copy_program.desc.flush() copy_program.desc.flush()
feed_var_names = [var.name for var in feed_vars] feed_var_names = [var.name for var in feed_vars]
skip_prune_program = kwargs.get('skip_prune_program', False)
if not skip_prune_program:
copy_program = copy_program._prune_with_input( copy_program = copy_program._prune_with_input(
feeded_var_names=feed_var_names, targets=fetch_vars feeded_var_names=feed_var_names, targets=fetch_vars
) )
...@@ -569,7 +574,12 @@ def save_inference_model( ...@@ -569,7 +574,12 @@ def save_inference_model(
program = _get_valid_program(kwargs.get('program', None)) program = _get_valid_program(kwargs.get('program', None))
clip_extra = kwargs.get('clip_extra', True) clip_extra = kwargs.get('clip_extra', True)
program = normalize_program(program, feed_vars, fetch_vars) program = normalize_program(
program,
feed_vars,
fetch_vars,
skip_prune_program=kwargs.get('skip_prune_program', False),
)
# serialize and save program # serialize and save program
legacy_format = kwargs.get('legacy_format', False) legacy_format = kwargs.get('legacy_format', False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册