未验证 提交 395cb561 编写于 作者: Z zhupengyang 提交者: GitHub

refine logsumexp error message and docs (#27713)

上级 057e28bc
......@@ -32,7 +32,7 @@ class LogsumexpOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_LE(x_rank, 4,
platform::errors::InvalidArgument(
"The input tensor X's dimensions of logsumexp "
"should be less equal than 4. But received X's "
"should be less or equal than 4. But received X's "
"dimensions = %d, X's shape = [%s].",
x_rank, x_dims));
auto axis = ctx->Attrs().Get<std::vector<int>>("axis");
......@@ -45,20 +45,18 @@ class LogsumexpOp : public framework::OperatorWithKernel {
axis.size()));
for (size_t i = 0; i < axis.size(); i++) {
PADDLE_ENFORCE_LT(
axis[i], x_rank,
platform::errors::InvalidArgument(
"axis[%d] should be in the "
"range [-dimension(X), dimension(X)] "
"where dimesion(X) is %d. But received axis[i] = %d.",
i, x_rank, axis[i]));
PADDLE_ENFORCE_GE(
axis[i], -x_rank,
platform::errors::InvalidArgument(
"axis[%d] should be in the "
"range [-dimension(X), dimension(X)] "
"where dimesion(X) is %d. But received axis[i] = %d.",
i, x_rank, axis[i]));
PADDLE_ENFORCE_LT(axis[i], x_rank,
platform::errors::InvalidArgument(
"axis[%d] should be in the "
"range [-D, D), where D is the dimensions of X and "
"D is %d. But received axis[%d] = %d.",
i, x_rank, i, axis[i]));
PADDLE_ENFORCE_GE(axis[i], -x_rank,
platform::errors::InvalidArgument(
"axis[%d] should be in the "
"range [-D, D), where D is the dimensions of X and "
"D is %d. But received axis[%d] = %d.",
i, x_rank, i, axis[i]));
if (axis[i] < 0) {
axis[i] += x_rank;
}
......
......@@ -999,7 +999,7 @@ def logsumexp(x, axis=None, keepdim=False, name=None):
This OP calculates the log of the sum of exponentials of ``x`` along ``axis`` .
.. math::
logsumexp(x) = \log\sum exp(x)
logsumexp(x) = \\log\\sum exp(x)
Args:
x (Tensor): The input Tensor with data type float32, float64.
......@@ -1030,8 +1030,6 @@ def logsumexp(x, axis=None, keepdim=False, name=None):
import paddle
paddle.disable_static()
x = paddle.to_tensor([[-1.5, 0., 2.], [3., 1.2, -2.4]])
out1 = paddle.logsumexp(x) # [3.4691226]
out2 = paddle.logsumexp(x, 1) # [2.15317821, 3.15684602]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册