From 4b8b4c71fc7535e597ca2932dfb01f1219053ec6 Mon Sep 17 00:00:00 2001 From: cyber-pioneer Date: Sat, 11 Feb 2023 13:38:23 +0000 Subject: [PATCH] map output from composite rule to origin op --- .../incubate/autograd/composite_rules.py | 4 ++-- python/paddle/incubate/autograd/primx.py | 12 ++++++---- python/paddle/incubate/autograd/utils.py | 24 +++++++++++++++++++ 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index d7eb54ba57..1d7f61e4ab 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -98,6 +98,6 @@ def composite_batchnorm( run_mean_ = assign(run_mean) run_var_ = assign(run_var) if trainable_statistics or not is_test: - return run_mean_, None, batch_mean_, batch_var_, run_var_, y + return y, run_mean_, run_var_, batch_mean_, batch_var_, None else: - return run_mean_, batch_mean_, batch_var_, run_var_, y + return y, run_mean_, run_var_, batch_mean_, batch_var_ diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py index 9314019dd5..b2a6cc8062 100644 --- a/python/paddle/incubate/autograd/primx.py +++ b/python/paddle/incubate/autograd/primx.py @@ -36,6 +36,7 @@ from .utils import ( flatten_and_remove_none, get_input_var_list, get_output_var_list, + get_output_vars_from_comosite, prepare_python_api_arguments, ) @@ -605,14 +606,15 @@ def _lower_composite(block, blacklist=[]): bind(input_args, to_bind, value_table) for orig_out, new_out in zip( - expand_nested_list(get_output_var_list(op)), + expand_nested_list(get_output_vars_from_comosite(op)), expand_nested_list(as_tensors(lower_fn(op, *input_args))), ): if new_out is not None: - assert orig_out.shape == new_out.shape, ( - f'when replace origin op with composite rule, origin out shape should be equal to new out shape, ' - f'but orig_out.shape={orig_out.shape} and new_out.shape={new_out.shape}' - ) + if orig_out.shape and new_out.shape: + assert orig_out.shape == new_out.shape, ( + f'when replace origin op with composite rule, origin out shape should be equal to new out shape, ' + f'but orig_out.shape={orig_out.shape} and new_out.shape={new_out.shape}' + ) assert not (orig_out is None) ^ ( new_out is None ), "orig_out and new_out should match." diff --git a/python/paddle/incubate/autograd/utils.py b/python/paddle/incubate/autograd/utils.py index 70537a3c8b..2f651f0e16 100644 --- a/python/paddle/incubate/autograd/utils.py +++ b/python/paddle/incubate/autograd/utils.py @@ -219,6 +219,30 @@ def get_output_var_list(op): ] +def get_output_vars_from_comosite(op): + """origin op outputs must be mapped into outputs of composite rule.""" + origin_output_names = op.output_names + if origin_output_names is None: + return [] + else: + name = op.type + res = [] + if op_map[name].get("outputs"): + for item in op_map[name]["outputs"].keys(): + origin_output_name = op_map[name]["outputs"][item] + if origin_output_name not in origin_output_names: + continue + origin_output_var = get_var_block(op.block, op.output(origin_output_name)) + res.append(origin_output_var) + elif len(origin_output_names) == 1: + # When origin output num is 1, map info is not needed. + origin_output_var = get_var_block(op.block, op.output(origin_output_names[0])) + res.append(origin_output_var) + else: + raise ValueError("When replace op with composite rule, there must exist output map info from origin op to composite rule.") + return res + + def flatten(inp): if inp is None or isinstance(inp, paddle.fluid.framework.Variable): return [inp] -- GitLab