未验证 提交 35a7ae21 编写于 作者: C cyber-pioneer 提交者: GitHub

[Prim] polish prim arg None check (#52449)

* polish prim arg None check

* fix bug
上级 9197b7f7
...@@ -550,6 +550,17 @@ def _lower(block, reverse, blacklist): ...@@ -550,6 +550,17 @@ def _lower(block, reverse, blacklist):
block._sync_with_cpp() block._sync_with_cpp()
# In some case, inputs and outputs of composite op or its replaced composite rule might be None.
# It means such arg will be no longer required in processed program by composite mechanism.
# Therefore, such special ops should be recorded in advance and be released in args check.
ops_contain_none = (
"batch_norm",
"flatten_contiguous_range",
"squeeze2",
"unsqueeze2",
)
def _lower_composite( def _lower_composite(
block, block,
filter_: typing.Callable[[framework.Operator], bool] = lambda x: True, filter_: typing.Callable[[framework.Operator], bool] = lambda x: True,
...@@ -664,10 +675,17 @@ def _lower_composite( ...@@ -664,10 +675,17 @@ def _lower_composite(
f'when replace origin op {op_name} with composite rule, num of origin outs should be equal to new outs, ' f'when replace origin op {op_name} with composite rule, num of origin outs should be equal to new outs, '
f'but len(orig_outs) = {len(orig_outs)} and len(new_outs) = {len(new_outs)}' f'but len(orig_outs) = {len(orig_outs)} and len(new_outs) = {len(new_outs)}'
) )
for orig_out, new_out in zip( for orig_out, new_out in zip(
orig_outs, orig_outs,
new_outs, new_outs,
): ):
if (orig_out is None or new_out is None) and (
op_name not in ops_contain_none
):
raise ValueError(
f"op {op_name} should not contain any None value. original outs={orig_outs} and its composite rule outs={new_outs}"
)
if orig_out is None: if orig_out is None:
# to keep same as phi op definition, orig_out may receive None # to keep same as phi op definition, orig_out may receive None
continue continue
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册