diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py index 6a5e4ae6fc366ad5197b40cdc487e86ddce005a5..55ab90c27c65b2b758cbb7d82f2784b4a6ed757e 100644 --- a/python/paddle/incubate/autograd/primx.py +++ b/python/paddle/incubate/autograd/primx.py @@ -604,17 +604,28 @@ def _lower_composite(block, blacklist=[]): ops_to_remove.append(op_idx) if lookup_fn(op.type) is not None and op.type not in blacklist: change = True + op_name = op.type input_args = prepare_python_api_arguments(op) bind(input_args, to_bind, value_table) + orig_outs = expand_nested_list( + get_output_vars_from_comosite(op) + ) + new_outs = expand_nested_list( + as_tensors(lower_fn(op, *input_args)) + ) + assert len(orig_outs) == len(new_outs), ( + f'when replace origin op {op_name} with composite rule, num of origin outs should be equal to new outs, ' + f'but len(orig_outs) = {len(orig_outs)} and len(new_outs) = {len(new_outs)}' + ) for orig_out, new_out in zip( - expand_nested_list(get_output_vars_from_comosite(op)), - expand_nested_list(as_tensors(lower_fn(op, *input_args))), + orig_outs, + new_outs, ): if new_out is not None: if orig_out.shape and new_out.shape: assert orig_out.shape == new_out.shape, ( - f'when replace origin op with composite rule, origin out shape should be equal to new out shape, ' + f'when replace origin op {op_name} with composite rule, origin out shape should be equal to new out shape, ' f'but orig_out.shape={orig_out.shape} and new_out.shape={new_out.shape}' ) assert not (orig_out is None) ^ ( diff --git a/python/paddle/incubate/autograd/utils.py b/python/paddle/incubate/autograd/utils.py index c011c7495e6ac269ea3eccb3045b3fe21fe292ba..fe7fa229ffa701fd3fda51c6214e4ad76dd170a1 100644 --- a/python/paddle/incubate/autograd/utils.py +++ b/python/paddle/incubate/autograd/utils.py @@ -239,6 +239,7 @@ def get_output_vars_from_comosite(op): for item in op_map[name]["outputs"].keys(): origin_output_name = op_map[name]["outputs"][item] if origin_output_name not in origin_output_names: + # in some cases, some output of origin op is optional, so op name may not be in origin_output_names continue origin_output_var = get_var_block( op.block, op.output(origin_output_name)