未验证 提交 2bcd3935 编写于 作者: A ahahahahahaha 提交者: GitHub

fix divide zero bug for paddle.all (#51088)

上级 77c9c90a
......@@ -25,6 +25,13 @@ void AllKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
DenseTensor* out) {
auto x_dim = x.dims();
for (int i = 0; i < x_dim.size(); i++) {
PADDLE_ENFORCE_LT(0,
x_dim[i],
errors::InvalidArgument(
"The dims of Input(X) should be greater than 0."));
}
bool reduce_all = recompute_reduce_all(x, dims);
AllRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
......
......@@ -1250,6 +1250,18 @@ class TestAnyAPI(unittest.TestCase):
paddle.enable_static()
class TestAllZeroError(unittest.TestCase):
def test_errors(self):
with paddle.fluid.dygraph.guard():
def test_0_size():
array = np.array([], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [0, 0, 0]), dtype='bool')
paddle.all(x, axis=1)
self.assertRaises(ValueError, test_0_size)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册