未验证 提交 2f900965 编写于 作者: C chenxiao120660 提交者: GitHub

fix zero bug of case18: paddle.logsumexp (#51034)

* fix bug of logsumexp

* fix bug for logsumexp

* fix bug for logsumexp
上级 83f61bd5
......@@ -71,6 +71,13 @@ void LogsumexpKernel(const Context& dev_ctx,
reduce_all = recompute_reduce_all(x, axis, reduce_all);
auto x_dim = x.dims();
for (int i = 0; i < x_dim.size(); i++) {
PADDLE_ENFORCE_LT(0,
x_dim[i],
errors::InvalidArgument(
"The dims of Input(X) should be greater than 0."));
}
if (reduce_all) {
// Flatten and reduce 1-D tensor
auto input = phi::EigenVector<T>::Flatten(x);
......
......@@ -238,5 +238,20 @@ class TestLogsumexpAPI(unittest.TestCase):
paddle.enable_static()
# Test logsumexp bug
class TestLogZeroError(unittest.TestCase):
def test_errors(self):
with paddle.fluid.dygraph.guard():
def test_0_size():
array = np.array([], dtype=np.float32)
x = paddle.to_tensor(
np.reshape(array, [0, 0, 0]), dtype='float32'
)
paddle.logsumexp(x, axis=1)
self.assertRaises(ValueError, test_0_size)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册