未验证 提交 3f5d0083 编写于 作者: R RichardWooSJTU 提交者: GitHub

[PaddleInference] Fix llm inference dy2static error (#56688)

* fix llm inference dy2static error

* use kwargs instead of default argument
上级 323566d5
......@@ -418,6 +418,9 @@ class _SaveLoadConfig:
# when need to save a prune model, use input_names_after_prune to specify the inputs left after pruning
self.input_names_after_prune = None
# in the scene of llm-inference, prunning program can cause unexpectable result, an option to skip prune is necessary
self.skip_prune_program = False
@property
def output_spec(self):
return self._output_spec
......@@ -497,6 +500,7 @@ def _parse_save_configs(configs):
"clip_extra",
"skip_forward",
"input_names_after_prune",
"skip_prune_program",
]
# input check
......@@ -517,6 +521,7 @@ def _parse_save_configs(configs):
inner_config.input_names_after_prune = configs.get(
"input_names_after_prune", None
)
inner_config.skip_prune_program = configs.get("skip_prune_program", False)
return inner_config
......@@ -1259,6 +1264,7 @@ def save(layer, path, input_spec=None, **configs):
executor=Executor(_current_expected_place()),
program=concrete_program.main_program.clone(),
clip_extra=configs.clip_extra,
skip_prune_program=configs.skip_prune_program,
)
if combine_params:
......
......@@ -187,7 +187,7 @@ def append_fetch_ops(
)
def normalize_program(program, feed_vars, fetch_vars):
def normalize_program(program, feed_vars, fetch_vars, **kwargs):
"""
Normalize/Optimize a program according to feed_vars and fetch_vars.
......@@ -196,6 +196,8 @@ def normalize_program(program, feed_vars, fetch_vars):
program(Program): Specify a program you want to optimize.
feed_vars(Tensor | list[Tensor]): Variables needed by inference.
fetch_vars(Tensor | list[Tensor]): Variables returned by inference.
kwargs: Supported keys including ``skip_prune_program``.
- skip_prune_program(bool): whether to skip prunning program. Defaults to False.
Returns:
Program: Normalized/Optimized program.
......@@ -277,6 +279,9 @@ def normalize_program(program, feed_vars, fetch_vars):
copy_program.desc.flush()
feed_var_names = [var.name for var in feed_vars]
skip_prune_program = kwargs.get('skip_prune_program', False)
if not skip_prune_program:
copy_program = copy_program._prune_with_input(
feeded_var_names=feed_var_names, targets=fetch_vars
)
......@@ -569,7 +574,12 @@ def save_inference_model(
program = _get_valid_program(kwargs.get('program', None))
clip_extra = kwargs.get('clip_extra', True)
program = normalize_program(program, feed_vars, fetch_vars)
program = normalize_program(
program,
feed_vars,
fetch_vars,
skip_prune_program=kwargs.get('skip_prune_program', False),
)
# serialize and save program
legacy_format = kwargs.get('legacy_format', False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册