未验证 提交 5932fee6 编写于 作者: Z zhang wenhui 提交者: GitHub

enhance error message, test=develop (#30220)

上级 da16b33f
...@@ -30,8 +30,10 @@ class CVMOp : public framework::OperatorWithKernel { ...@@ -30,8 +30,10 @@ class CVMOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "CVM"); OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "CVM");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, platform::errors::InvalidArgument( PADDLE_ENFORCE_EQ(
"Input(X)'s rank should be 2.")); x_dims.size(), 2UL,
platform::errors::InvalidArgument(
"Input(X)'s rank should be 2, but got %d", x_dims.size()));
if (ctx->Attrs().Get<bool>("use_cvm")) { if (ctx->Attrs().Get<bool>("use_cvm")) {
ctx->SetOutputDim("Y", {x_dims[0], x_dims[1]}); ctx->SetOutputDim("Y", {x_dims[0], x_dims[1]});
...@@ -68,26 +70,31 @@ class CVMGradientOp : public framework::OperatorWithKernel { ...@@ -68,26 +70,31 @@ class CVMGradientOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto cvm_dims = ctx->GetInputDim("CVM"); auto cvm_dims = ctx->GetInputDim("CVM");
auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y")); auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
PADDLE_ENFORCE_EQ(x_dims.size(), 2, platform::errors::InvalidArgument( PADDLE_ENFORCE_EQ(
"Input(X)'s rank should be 2.")); x_dims.size(), 2,
platform::errors::InvalidArgument(
"Expect Input(X)'s rank == 2, but got %d", x_dims.size()));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dy_dims.size(), 2, dy_dims.size(), 2,
platform::errors::InvalidArgument("Input(Y@Grad)'s rank should be 2.")); platform::errors::InvalidArgument(
        "Expect Input(Y@Grad)'s rank == 2, but got %d", dy_dims.size()));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
cvm_dims.size(), 2, cvm_dims.size(), 2,
platform::errors::InvalidArgument("Input(CVM)'s rank should be 2.")); platform::errors::InvalidArgument(
        "Expect Input(CVM)'s rank == 2, but got %d", cvm_dims.size()));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_dims[0], dy_dims[0], x_dims[0], dy_dims[0],
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The 1st dimension of Input(X) and Input(Y@Grad) should " "The 1st dimension of Input(X) and Input(Y@Grad) should "
"be equal.")); "be equal, X is %d, Y@Grad is %d",
x_dims[0], dy_dims[0]));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
cvm_dims[1], 2, cvm_dims[1], 2,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"When Attr(soft_label) == false, the 2nd dimension of " "When Attr(soft_label) == false, the 2nd dimension of "
"Input(CVM) should be 2.")); "Input(CVM) should be 2, but got %d.", cvm_dims[1]));
ctx->SetOutputDim(framework::GradVarName("X"), x_dims); ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->ShareLoD("X", framework::GradVarName("X")); ctx->ShareLoD("X", framework::GradVarName("X"));
} }
......
...@@ -42,7 +42,9 @@ class FTRLOp : public framework::OperatorWithKernel { ...@@ -42,7 +42,9 @@ class FTRLOp : public framework::OperatorWithKernel {
auto param_dim = ctx->GetInputDim("Param"); auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"), PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Two input of FTRL Op's dimension must be same.")); "Two input of FTRL Op's dimension must be same, but "
"param_dim is [%s], Grad is [%s]",
param_dim, ctx->GetInputDim("Grad")));
auto lr_dim = ctx->GetInputDim("LearningRate"); auto lr_dim = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_NE(framework::product(lr_dim), 0, PADDLE_ENFORCE_NE(framework::product(lr_dim), 0,
...@@ -51,9 +53,10 @@ class FTRLOp : public framework::OperatorWithKernel { ...@@ -51,9 +53,10 @@ class FTRLOp : public framework::OperatorWithKernel {
"been initialized. You may need to confirm " "been initialized. You may need to confirm "
"if you put exe.run(startup_program) " "if you put exe.run(startup_program) "
"after optimizer.minimize function.")); "after optimizer.minimize function."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1,
framework::product(lr_dim), 1, platform::errors::InvalidArgument(
platform::errors::InvalidArgument("Learning Rate should be a scalar.")); "Learning Rate should be a scalar, but got %d",
framework::product(lr_dim)));
ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("SquaredAccumOut", param_dim); ctx->SetOutputDim("SquaredAccumOut", param_dim);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册