未验证 提交 cb47a74c 编写于 作者: B Bai Yifan 提交者: GitHub

Fix fsp_op error message,test=develop (#24410)

* fix fsp_op error message,test=develop
上级 ee4795d1
...@@ -23,9 +23,9 @@ class FSPOp : public framework::OperatorWithKernel { ...@@ -23,9 +23,9 @@ class FSPOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fsp_op"); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fsp");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "fsp_op"); OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "fsp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fsp_op"); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fsp");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y"); auto y_dims = ctx->GetInputDim("Y");
...@@ -103,10 +103,11 @@ class FSPOpGrad : public framework::OperatorWithKernel { ...@@ -103,10 +103,11 @@ class FSPOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fsp_grad");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "fsp_grad");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Input(Out@GRAD) should not be null"); framework::GradVarName("Out"), "fsp_grad");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y"); auto y_dims = ctx->GetInputDim("Y");
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册