未验证 提交 25d3ed65 编写于 作者: C chenxiao120660 提交者: GitHub

fix zero bug of case21: paddle.mode (#51091)

上级 2bcd3935
......@@ -29,6 +29,12 @@ void ModeKernel(const Context& dev_ctx,
DenseTensor* out,
DenseTensor* indices) {
const auto& in_dims = x.dims();
for (int i = 0; i < in_dims.size(); i++) {
PADDLE_ENFORCE_LT(0,
in_dims[i],
errors::InvalidArgument(
"The dims of Input(X) should be greater than 0."));
}
auto out_dims = out->dims();
// axis < 0, cacluate the real axis
if (axis < 0) axis += in_dims.size();
......
......@@ -30,6 +30,12 @@ void ModeKernel(const Context& dev_ctx,
DenseTensor* indices) {
// get the input dims
const auto& in_dims = x.dims();
for (int i = 0; i < in_dims.size(); i++) {
PADDLE_ENFORCE_LT(0,
in_dims[i],
errors::InvalidArgument(
"The dims of Input(X) should be greater than 0."));
}
// calcluate the real axis
if (axis < 0) axis += in_dims.size();
......
......@@ -182,5 +182,17 @@ class TestModeOpInStatic(unittest.TestCase):
np.testing.assert_allclose(paddle_result, expect_value, rtol=1e-05)
class TestModeZeroError(unittest.TestCase):
def test_errors(self):
with paddle.fluid.dygraph.guard():
def test_0_size():
array = np.array([], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [0, 0]), dtype='float32')
paddle.mode(x, axis=0)
self.assertRaises(ValueError, test_0_size)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册