未验证 提交 de0cb386 编写于 作者: Z zyfncg 提交者: GitHub

fix bug of indexing tensor with None (#37400)

上级 31344ab7
...@@ -562,7 +562,7 @@ static void ParseIndexingSlice( ...@@ -562,7 +562,7 @@ static void ParseIndexingSlice(
PADDLE_ENFORCE_LE(ell_count, 1, PADDLE_ENFORCE_LE(ell_count, 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"An index can only have a single ellipsis ('...')")); "An index can only have a single ellipsis ('...')"));
int none_count = 0;
for (int i = 0, dim = 0; i < size; ++i) { for (int i = 0, dim = 0; i < size; ++i) {
PyObject *slice_item = PyTuple_GetItem(index, i); PyObject *slice_item = PyTuple_GetItem(index, i);
...@@ -608,7 +608,8 @@ static void ParseIndexingSlice( ...@@ -608,7 +608,8 @@ static void ParseIndexingSlice(
} else if (slice_item == Py_Ellipsis) { } else if (slice_item == Py_Ellipsis) {
dim += rank - specified_dims; dim += rank - specified_dims;
} else if (slice_item == Py_None) { } else if (slice_item == Py_None) {
none_axes->push_back(dim); none_axes->push_back(dim + none_count);
none_count++;
} else if (PyList_Check(slice_item)) { } else if (PyList_Check(slice_item)) {
*list_select_flag = true; *list_select_flag = true;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -1214,29 +1215,6 @@ void BindImperative(py::module *m_ptr) { ...@@ -1214,29 +1215,6 @@ void BindImperative(py::module *m_ptr) {
axis -= len; axis -= len;
} }
// Deal with cases that there are more than one
// prefix none index, For example:
// [None, None, :, :, None]
// the none_axes int the return of ParseIndexingSlice is:
// [0, 0, 2 ]
// according to the interface of "unsqueeze2",
// we should convert it to:
// [0, 0, 4 ]
int prefix_zero_cnt = 0;
for (const auto &axis : none_axes) {
if (axis == 0) {
prefix_zero_cnt++;
} else {
break;
}
}
if (prefix_zero_cnt > 0) {
int none_axes_num = static_cast<int>(none_axes.size());
for (int i = prefix_zero_cnt; i < none_axes_num; ++i) {
none_axes[i] += prefix_zero_cnt;
}
}
imperative::NameVarBaseMap ins = {{"X", {out}}}; imperative::NameVarBaseMap ins = {{"X", {out}}};
framework::AttributeMap attrs = {{"axes", none_axes}}; framework::AttributeMap attrs = {{"axes", none_axes}};
auto new_out = std::shared_ptr<imperative::VarBase>( auto new_out = std::shared_ptr<imperative::VarBase>(
......
...@@ -408,6 +408,14 @@ class TestSetValueItemNone9(TestSetValueApi): ...@@ -408,6 +408,14 @@ class TestSetValueItemNone9(TestSetValueApi):
self.data[None, :, 1, ..., None] = np.zeros(self.shape)[0, 0, :, None] self.data[None, :, 1, ..., None] = np.zeros(self.shape)[0, 0, :, None]
class TestSetValueItemNone10(TestSetValueApi):
def _call_setitem(self, x):
x[..., None, :, None] = np.zeros(self.shape)[..., None, :, None]
def _get_answer(self):
self.data[..., None, :, None] = np.zeros(self.shape)[..., None, :, None]
# 1.5 item is list or Tensor of bol # 1.5 item is list or Tensor of bol
class TestSetValueItemBool1(TestSetValueApi): class TestSetValueItemBool1(TestSetValueApi):
def _call_setitem(self, x): def _call_setitem(self, x):
......
...@@ -723,6 +723,7 @@ class TestVarBase(unittest.TestCase): ...@@ -723,6 +723,7 @@ class TestVarBase(unittest.TestCase):
var_tensor[None].numpy(), var_tensor[None].numpy(),
var_tensor[0, 0, None, 0, 0, None].numpy(), var_tensor[0, 0, None, 0, 0, None].numpy(),
var_tensor[None, None, 0, ..., None].numpy(), var_tensor[None, None, 0, ..., None].numpy(),
var_tensor[..., None, :, None].numpy(),
var_tensor[0, 1:10:2, None, None, ...].numpy(), var_tensor[0, 1:10:2, None, None, ...].numpy(),
] ]
...@@ -738,11 +739,12 @@ class TestVarBase(unittest.TestCase): ...@@ -738,11 +739,12 @@ class TestVarBase(unittest.TestCase):
np.array_equal(var[8], np_value[0, 0, None, 0, 0, None])) np.array_equal(var[8], np_value[0, 0, None, 0, 0, None]))
self.assertTrue( self.assertTrue(
np.array_equal(var[9], np_value[None, None, 0, ..., None])) np.array_equal(var[9], np_value[None, None, 0, ..., None]))
self.assertTrue(np.array_equal(var[10], np_value[..., None, :, None]))
# TODO(zyfncg) there is a bug of dimensions when slice step > 1 and # TODO(zyfncg) there is a bug of dimensions when slice step > 1 and
# indexs has int type # indexs has int type
# self.assertTrue( # self.assertTrue(
# np.array_equal(var[10], np_value[0, 1:10:2, None, None, ...])) # np.array_equal(var[11], np_value[0, 1:10:2, None, None, ...]))
def _test_bool_index(self): def _test_bool_index(self):
shape = (4, 2, 5, 64) shape = (4, 2, 5, 64)
......
...@@ -436,13 +436,15 @@ class TestVariableSlice(unittest.TestCase): ...@@ -436,13 +436,15 @@ class TestVariableSlice(unittest.TestCase):
out1 = x[0:, None] out1 = x[0:, None]
out2 = x[None, 1:] out2 = x[None, 1:]
out3 = x[None] out3 = x[None]
out4 = x[..., None, :, None]
outs = [out0, out1, out2, out3] outs = [out0, out1, out2, out3, out4]
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=outs) result = exe.run(prog, fetch_list=outs)
expected = [ expected = [
data[0:, None, 1:], data[0:, None], data[None, 1:], data[None] data[0:, None, 1:], data[0:, None], data[None, 1:], data[None],
data[..., None, :, None]
] ]
for i in range(len(outs)): for i in range(len(outs)):
self.assertEqual(outs[i].shape, expected[i].shape) self.assertEqual(outs[i].shape, expected[i].shape)
......
...@@ -204,7 +204,8 @@ def replace_ellipsis(var, item): ...@@ -204,7 +204,8 @@ def replace_ellipsis(var, item):
# Remove Variable to skip bug when counting Ellipsis # Remove Variable to skip bug when counting Ellipsis
item_remove_var = [ item_remove_var = [
ele for ele in item if not isinstance(ele, (Variable, np.ndarray)) ele for ele in item
if not isinstance(ele, (Variable, np.ndarray)) and ele is not None
] ]
ell_count = item_remove_var.count(Ellipsis) ell_count = item_remove_var.count(Ellipsis)
if ell_count == 0: if ell_count == 0:
...@@ -218,7 +219,7 @@ def replace_ellipsis(var, item): ...@@ -218,7 +219,7 @@ def replace_ellipsis(var, item):
return item[:-1] return item[:-1]
else: else:
item[ell_idx:ell_idx + 1] = [slice(None)] * ( item[ell_idx:ell_idx + 1] = [slice(None)] * (
len(var.shape) - len(item) + 1) len(var.shape) - len(item) + item.count(None) + 1)
return item return item
...@@ -298,8 +299,8 @@ def _getitem_impl_(var, item): ...@@ -298,8 +299,8 @@ def _getitem_impl_(var, item):
use_strided_slice = False use_strided_slice = False
item = replace_ndarray(item) item = replace_ndarray(item)
item, none_axes = replace_none(item)
item = replace_ellipsis(var, item) item = replace_ellipsis(var, item)
item, none_axes = replace_none(item)
slice_info = SliceInfo() slice_info = SliceInfo()
for dim, slice_item in enumerate(item): for dim, slice_item in enumerate(item):
...@@ -517,8 +518,8 @@ def _setitem_impl_(var, item, value): ...@@ -517,8 +518,8 @@ def _setitem_impl_(var, item, value):
steps = [] steps = []
item = replace_ndarray(item) item = replace_ndarray(item)
item, none_axes = replace_none(item)
item = replace_ellipsis(var, item) item = replace_ellipsis(var, item)
item, none_axes = replace_none(item)
slice_info = SliceInfo() slice_info = SliceInfo()
dim = 0 dim = 0
for _, slice_item in enumerate(item): for _, slice_item in enumerate(item):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册