未验证 提交 4a675b7a 编写于 作者: W WangZhen 提交者: GitHub

Enhance slice to support 0 shape Tensor (#45861)

* Enhance slice to support 0 dims Tensor

* Add UT
上级 f06ab336
...@@ -759,8 +759,10 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self, ...@@ -759,8 +759,10 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
decrease_axis, none_axes, infer_flags, list_select_idxs; decrease_axis, none_axes, infer_flags, list_select_idxs;
// if index is a list, list_select_flag will be true // if index is a list, list_select_flag will be true
bool list_select_flag = false; bool list_select_flag = false;
// Note(0x45f): Using defined() instead of initialized()
// to support slice tensor which shape like [0, 0, 0].
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
self->tensor.initialized(), self->tensor.defined(),
true, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"tensor %s has not been initialized, we can only slice initialized " "tensor %s has not been initialized, we can only slice initialized "
......
...@@ -284,5 +284,23 @@ class TestPaddleStridedSlice(unittest.TestCase): ...@@ -284,5 +284,23 @@ class TestPaddleStridedSlice(unittest.TestCase):
np.testing.assert_array_equal(sl.numpy(), array_slice) np.testing.assert_array_equal(sl.numpy(), array_slice)
def slice_zero_shape_tensor(x):
y = x[1:2]
return y
class TestSliceZeroShapeTensor(unittest.TestCase):
def test_slice(self):
paddle.disable_static()
x = paddle.ones([0, 0, 0, 0])
y = slice_zero_shape_tensor(x)
np.testing.assert_equal(y.shape, [0, 0, 0, 0])
static_func = paddle.jit.to_static(slice_zero_shape_tensor)
y = static_func(x)
np.testing.assert_equal(y.shape, [0, 0, 0, 0])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册