未验证 提交 2dc7ee27 编写于 作者: L lijianshe02 提交者: GitHub

enhance error message of nll_loss op test=develop (#30125)

* enhance error message of nll_loss op test=develop
上级 54bf3f5a
...@@ -53,10 +53,14 @@ class NLLLossOp : public framework::OperatorWithKernel { ...@@ -53,10 +53,14 @@ class NLLLossOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(w_dims.size(), 1, PADDLE_ENFORCE_EQ(w_dims.size(), 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Input(Weight) should be a 1D tensor.")); "Input(Weight) should be a 1D tensor."));
PADDLE_ENFORCE_EQ(x_dims[1], w_dims[0], PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( x_dims[1], w_dims[0],
"Input(Weight) Tensor's size should match " platform::errors::InvalidArgument(
"to the the total number of classes.")); "Expected input tensor Weight's size should equal "
"to the first dimension of the input tensor X. But received "
"Weight's "
"size is %d, the first dimension of input X is %d",
w_dims[0], x_dims[1]));
} }
} }
if (x_dims.size() == 2) { if (x_dims.size() == 2) {
...@@ -68,7 +72,8 @@ class NLLLossOp : public framework::OperatorWithKernel { ...@@ -68,7 +72,8 @@ class NLLLossOp : public framework::OperatorWithKernel {
} else if (x_dims.size() == 4) { } else if (x_dims.size() == 4) {
PADDLE_ENFORCE_EQ(label_dims.size(), 3, PADDLE_ENFORCE_EQ(label_dims.size(), 3,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The tensor rank of Input(Label) must be 3.")); "Expected Input(Lable) dimensions=3, received %d.",
label_dims.size()));
auto input0 = x_dims[0]; auto input0 = x_dims[0];
auto input2 = x_dims[2]; auto input2 = x_dims[2];
auto input3 = x_dims[3]; auto input3 = x_dims[3];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册