From 35a7ae21823f7703bae1bd88b1898aeeaac79479 Mon Sep 17 00:00:00 2001 From: cyber-pioneer <116002591+cyber-pioneer@users.noreply.github.com> Date: Mon, 3 Apr 2023 23:26:51 +0800 Subject: [PATCH] [Prim] polish prim arg None check (#52449) * polish prim arg None check * fix bug --- python/paddle/incubate/autograd/primx.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py index b9b7c90b17a..50943b07458 100644 --- a/python/paddle/incubate/autograd/primx.py +++ b/python/paddle/incubate/autograd/primx.py @@ -550,6 +550,17 @@ def _lower(block, reverse, blacklist): block._sync_with_cpp() +# In some case, inputs and outputs of composite op or its replaced composite rule might be None. +# It means such arg will be no longer required in processed program by composite mechanism. +# Therefore, such special ops should be recorded in advance and be released in args check. +ops_contain_none = ( + "batch_norm", + "flatten_contiguous_range", + "squeeze2", + "unsqueeze2", +) + + def _lower_composite( block, filter_: typing.Callable[[framework.Operator], bool] = lambda x: True, @@ -664,10 +675,17 @@ def _lower_composite( f'when replace origin op {op_name} with composite rule, num of origin outs should be equal to new outs, ' f'but len(orig_outs) = {len(orig_outs)} and len(new_outs) = {len(new_outs)}' ) + for orig_out, new_out in zip( orig_outs, new_outs, ): + if (orig_out is None or new_out is None) and ( + op_name not in ops_contain_none + ): + raise ValueError( + f"op {op_name} should not contain any None value. original outs={orig_outs} and its composite rule outs={new_outs}" + ) if orig_out is None: # to keep same as phi op definition, orig_out may receive None continue -- GitLab