From 648cb5083c5b4b6469c240f056834a139c4dbcac Mon Sep 17 00:00:00 2001 From: cyber-pioneer Date: Sun, 12 Feb 2023 03:46:52 +0000 Subject: [PATCH] polish log --- python/paddle/incubate/autograd/primx.py | 17 ++++++++++++++--- python/paddle/incubate/autograd/utils.py | 1 + 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py index 6a5e4ae6fc..55ab90c27c 100644 --- a/python/paddle/incubate/autograd/primx.py +++ b/python/paddle/incubate/autograd/primx.py @@ -604,17 +604,28 @@ def _lower_composite(block, blacklist=[]): ops_to_remove.append(op_idx) if lookup_fn(op.type) is not None and op.type not in blacklist: change = True + op_name = op.type input_args = prepare_python_api_arguments(op) bind(input_args, to_bind, value_table) + orig_outs = expand_nested_list( + get_output_vars_from_comosite(op) + ) + new_outs = expand_nested_list( + as_tensors(lower_fn(op, *input_args)) + ) + assert len(orig_outs) == len(new_outs), ( + f'when replace origin op {op_name} with composite rule, num of origin outs should be equal to new outs, ' + f'but len(orig_outs) = {len(orig_outs)} and len(new_outs) = {len(new_outs)}' + ) for orig_out, new_out in zip( - expand_nested_list(get_output_vars_from_comosite(op)), - expand_nested_list(as_tensors(lower_fn(op, *input_args))), + orig_outs, + new_outs, ): if new_out is not None: if orig_out.shape and new_out.shape: assert orig_out.shape == new_out.shape, ( - f'when replace origin op with composite rule, origin out shape should be equal to new out shape, ' + f'when replace origin op {op_name} with composite rule, origin out shape should be equal to new out shape, ' f'but orig_out.shape={orig_out.shape} and new_out.shape={new_out.shape}' ) assert not (orig_out is None) ^ ( diff --git a/python/paddle/incubate/autograd/utils.py b/python/paddle/incubate/autograd/utils.py index c011c7495e..fe7fa229ff 100644 --- a/python/paddle/incubate/autograd/utils.py +++ b/python/paddle/incubate/autograd/utils.py @@ -239,6 +239,7 @@ def get_output_vars_from_comosite(op): for item in op_map[name]["outputs"].keys(): origin_output_name = op_map[name]["outputs"][item] if origin_output_name not in origin_output_names: + # in some cases, some output of origin op is optional, so op name may not be in origin_output_names continue origin_output_var = get_var_block( op.block, op.output(origin_output_name) -- GitLab