未验证 提交 24a2bedb 编写于 作者: A Aurelius84 提交者: GitHub

[API] Fix slice bool infershape BUG (#45737)

上级 b83d27ac
...@@ -315,8 +315,8 @@ def get_value_for_bool_tensor(var, item): ...@@ -315,8 +315,8 @@ def get_value_for_bool_tensor(var, item):
return paddle.empty(var_shape, dtype=var.dtype) return paddle.empty(var_shape, dtype=var.dtype)
from .layers.control_flow import cond from .layers.control_flow import cond
return cond(paddle.logical_not(item.any()), lambda: idx_empty(var), return cond(item.any(), lambda: idx_not_empty(var, item),
lambda: idx_not_empty(var, item)) lambda: idx_empty(var))
def _getitem_impl_(var, item): def _getitem_impl_(var, item):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册