未验证 提交 059699a2 编写于 作者: J JYChen 提交者: GitHub

just a patch for bool tensor indexing with shape -1 (#51046)

上级 74442f5e
......@@ -320,14 +320,19 @@ def get_value_for_bool_tensor(var, item):
"the dims of bool index except to be equal or less "
"than {}, but received {}.".format(len(var.shape), len(item.shape))
)
for i, dim_len in enumerate(item.shape):
if dim_len != var.shape[i]:
i = 0
item_shape = item.shape
while i < len(item.shape):
dim_len = item_shape[i]
if dim_len != -1 and var.shape[i] != -1 and dim_len != var.shape[i]:
raise IndexError(
"The dimension of bool index doesn't match indexed array along "
"dimension {}, the target dimension is {}, but received {}.".format(
i, var.shape[i], dim_len
)
)
i += 1
empty_shape = [0] + list(var.shape[i:])
def idx_not_empty(var, item):
from ..tensor import gather_nd
......@@ -335,15 +340,12 @@ def get_value_for_bool_tensor(var, item):
bool_2_idx = paddle.nonzero(item == True)
return gather_nd(var, bool_2_idx)
def idx_empty(var):
var_shape = list(var.shape)
var_shape[0] = 0
return paddle.empty(var_shape, dtype=var.dtype)
from paddle.static.nn import cond
return cond(
item.any(), lambda: idx_not_empty(var, item), lambda: idx_empty(var)
item.any(),
lambda: idx_not_empty(var, item),
lambda: paddle.empty(empty_shape, var.dtype),
)
......@@ -848,7 +850,7 @@ def set_value_for_bool_tensor(var, item, value):
"than {}, but received {}.".format(len(var.shape), len(item.shape))
)
for i, dim_len in enumerate(item.shape):
if dim_len != var.shape[i]:
if dim_len != -1 and var.shape[i] != -1 and dim_len != var.shape[i]:
raise IndexError(
"The dimension of bool index doesn't match indexed array along "
"dimension {}, the target dimension is {}, but received {}.".format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册