未验证 提交 669853f5 编写于 作者: W WeiXin 提交者: GitHub

Polish the error message of paddle.slice. (#35179)

* polish the error message of paddle.slice.

* polish code.
上级 32c1ec42
......@@ -113,7 +113,7 @@ class SliceOp : public framework::OperatorWithKernel {
}
ctx->SetOutputDim("Out", out_dims);
if (axes[0] != 0) {
if (axes.size() > 0 && axes[0] != 0) {
ctx->ShareLoD("Input", /*->*/ "Out");
}
}
......
......@@ -10975,6 +10975,22 @@ def slice(input, axes, starts, ends):
attrs = ()
starts_tensor = None
ends_tensor = None
if isinstance(axes, (list, tuple)):
if len(axes) == 0:
raise ValueError(
"Input axes should not be an empty list/tuple.")
for i in range(len(axes)):
if axes[i] < 0:
axes[i] = max(0, axes[i] + len(input.shape))
else:
axes[i] = min(len(input.shape) - 1, axes[i])
else:
raise ValueError(
"Input axes must be a python list or tuple, but reveived {}".
format(type(axes)))
infer_flags = list(1 for i in range(len(axes)))
if isinstance(starts, (list, tuple)):
......
......@@ -694,6 +694,45 @@ class TestInferShape(unittest.TestCase):
out0 = paddle.slice(x, axes=[1], starts=[0], ends=[3])
self.assertEqual(out0.shape, (3, 3, 5))
def test_axis_less_than_zero(self):
# Using paddle.disable_static will make other unittests fail.
with fluid.dygraph.guard():
x_arr = np.arange(0, 24, dtype=np.float32).reshape([2, 3, 4])
x = paddle.to_tensor(x_arr)
pp_slice = paddle.slice(x, [100, ], [0], [1])
np_slice = x_arr[:, :, 0:1]
self.assertTrue(np.array_equal(pp_slice, np_slice))
pp_slice = paddle.slice(x, [-100, ], [0], [1])
np_slice = x_arr[0:1]
self.assertTrue(np.array_equal(pp_slice, np_slice))
x_arr = np.array([], dtype=np.float32)
x = paddle.to_tensor(np.reshape(x_arr, (0, 0, 0)))
starts = paddle.to_tensor(
np.reshape(
np.array(
[], dtype=np.int32), (0, )))
ends = paddle.to_tensor(
np.reshape(
np.array(
[], dtype=np.int32), (0, )))
with self.assertRaises(ValueError):
paddle.slice(x, [-1000000], starts, ends)
with self.assertRaises(ValueError):
paddle.slice(x, [1000000], starts, ends)
with self.assertRaises(ValueError):
paddle.slice(x, [], starts, ends)
with self.assertRaises(ValueError):
paddle.slice(x, 0, starts, ends)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册