未验证 提交 200d57c7 编写于 作者: L liym27 提交者: GitHub

[getitem] Support index is None for getitem in static mode (#33001)

上级 23b9ed34
......@@ -295,5 +295,61 @@ class TestVariable(unittest.TestCase):
self.assertRaises(Exception, _test)
class TestVariableSlice(unittest.TestCase):
def _test_item_none(self, place):
data = np.random.rand(2, 3, 4).astype("float32")
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.assign(data)
out0 = x[0:, None, 1:]
out1 = x[0:, None]
out2 = x[None, 1:]
out3 = x[None]
outs = [out0, out1, out2, out3]
exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=outs)
expected = [
data[0:, None, 1:], data[0:, None], data[None, 1:], data[None]
]
for i in range(len(outs)):
self.assertEqual(outs[i].shape, expected[i].shape)
self.assertTrue((result[i] == expected[i]).all())
def _test_item_none_and_decrease(self, place):
data = np.random.rand(2, 3, 4).astype("float32")
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.assign(data)
out0 = x[0, 1:, None]
out1 = x[0, None]
out2 = x[None, 1]
out3 = x[None]
out4 = x[0, 0, 0, None]
out5 = x[None, 0, 0, 0, None]
outs = [out0, out1, out2, out3, out4, out5]
exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=outs)
expected = [
data[0, 1:, None], data[0, None], data[None, 1], data[None],
data[0, 0, 0, None], data[None, 0, 0, 0, None]
]
for i in range(len(outs)):
self.assertEqual(outs[i].shape, expected[i].shape)
self.assertTrue((result[i] == expected[i]).all())
def test_slice(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self._test_item_none(place)
self._test_item_none_and_decrease(place)
if __name__ == '__main__':
unittest.main()
......@@ -50,6 +50,17 @@ def replace_ellipsis(var, item):
return item
def replace_none(item):
new_item = []
none_axes = []
for i, slice_item in enumerate(item):
if slice_item is None:
none_axes.append(i)
else:
new_item.append(slice_item)
return new_item, none_axes
def is_integer_or_scalar_tensor(ele):
from .framework import Variable
if isinstance(ele, int):
......@@ -97,9 +108,10 @@ def _getitem_impl_(var, item):
starts = []
ends = []
steps = []
reverse_axis = []
reverse_axes = []
use_strided_slice = False
item, none_axes = replace_none(item)
for dim, slice_item in enumerate(item):
if is_integer_or_scalar_tensor(slice_item):
......@@ -120,7 +132,7 @@ def _getitem_impl_(var, item):
if start is None and end is None:
assert (step == -1)
reverse_axis.append(dim)
reverse_axes.append(dim)
continue
start = 0 if start is None else start
......@@ -195,9 +207,38 @@ def _getitem_impl_(var, item):
attrs=attrs)
out = slice_out_var
if len(reverse_axis) > 0:
if len(reverse_axes) > 0:
from .layers.tensor import reverse
out = reverse(out, axis=reverse_axis)
out = reverse(out, axis=reverse_axes)
# Deal with cases when all axes are decreased.
# After slice, the shape of out is [1], which should have been [], but Paddle doesn't support scalar.
# In order to ensure the correctness of the final shape of out, one dimension of out needs to be decreased.
# For example:
# # x.shape: (2,3,4)
# out = x[0, 1, 1, None] # out.shape : (1)
if len(decrease_axes) == len(var.shape):
none_axes = none_axes[1:]
if len(none_axes) > 0:
# Deal with cases that decrease_axes is not empty
# For example:
# # x.shape: (2,3,4)
# out = x[0, 0:2, None] # out.shape : (2, 1, 4)
for idx, axis in enumerate(none_axes):
l = len([i for i in decrease_axes if i < axis])
new_axis = axis - l
none_axes[idx] = new_axis
# Deal with cases when all axes are decreased.
# After slice, the shape of out is [1], which should have been [], but Paddle doesn't support scalar.
# In order to ensure the correctness of the final shape of out, one dimension of out needs to be decreased.
# For example:
# # x.shape: (2,3,4)
# out = x[0, 1, 1, None] # out.shape : (1)
from ..tensor import unsqueeze
out = unsqueeze(out, axis=none_axes)
return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册