未验证 提交 35a7ae21 编写于 作者: C cyber-pioneer 提交者: GitHub

[Prim] polish prim arg None check (#52449)

* polish prim arg None check

* fix bug
上级 9197b7f7
......@@ -550,6 +550,17 @@ def _lower(block, reverse, blacklist):
block._sync_with_cpp()
# In some case, inputs and outputs of composite op or its replaced composite rule might be None.
# It means such arg will be no longer required in processed program by composite mechanism.
# Therefore, such special ops should be recorded in advance and be released in args check.
ops_contain_none = (
"batch_norm",
"flatten_contiguous_range",
"squeeze2",
"unsqueeze2",
)
def _lower_composite(
block,
filter_: typing.Callable[[framework.Operator], bool] = lambda x: True,
......@@ -664,10 +675,17 @@ def _lower_composite(
f'when replace origin op {op_name} with composite rule, num of origin outs should be equal to new outs, '
f'but len(orig_outs) = {len(orig_outs)} and len(new_outs) = {len(new_outs)}'
)
for orig_out, new_out in zip(
orig_outs,
new_outs,
):
if (orig_out is None or new_out is None) and (
op_name not in ops_contain_none
):
raise ValueError(
f"op {op_name} should not contain any None value. original outs={orig_outs} and its composite rule outs={new_outs}"
)
if orig_out is None:
# to keep same as phi op definition, orig_out may receive None
continue
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册