diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py index b9b7c90b17a9e1e805e9e672001569dd87b76b33..50943b07458ce8232e5f6782ed8182e0d735b230 100644 --- a/python/paddle/incubate/autograd/primx.py +++ b/python/paddle/incubate/autograd/primx.py @@ -550,6 +550,17 @@ def _lower(block, reverse, blacklist): block._sync_with_cpp() +# In some case, inputs and outputs of composite op or its replaced composite rule might be None. +# It means such arg will be no longer required in processed program by composite mechanism. +# Therefore, such special ops should be recorded in advance and be released in args check. +ops_contain_none = ( + "batch_norm", + "flatten_contiguous_range", + "squeeze2", + "unsqueeze2", +) + + def _lower_composite( block, filter_: typing.Callable[[framework.Operator], bool] = lambda x: True, @@ -664,10 +675,17 @@ def _lower_composite( f'when replace origin op {op_name} with composite rule, num of origin outs should be equal to new outs, ' f'but len(orig_outs) = {len(orig_outs)} and len(new_outs) = {len(new_outs)}' ) + for orig_out, new_out in zip( orig_outs, new_outs, ): + if (orig_out is None or new_out is None) and ( + op_name not in ops_contain_none + ): + raise ValueError( + f"op {op_name} should not contain any None value. original outs={orig_outs} and its composite rule outs={new_outs}" + ) if orig_out is None: # to keep same as phi op definition, orig_out may receive None continue