diff --git a/paddle/fluid/operators/lstm_unit_op.cc b/paddle/fluid/operators/lstm_unit_op.cc index e6ffda201ba0fd717274e76e746b84aab297873d..e4b72cb50b707ca1bdc5fd4ab6b235d3f498e828 100644 --- a/paddle/fluid/operators/lstm_unit_op.cc +++ b/paddle/fluid/operators/lstm_unit_op.cc @@ -25,8 +25,8 @@ class LstmUnitOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lstm_unit"); OP_INOUT_CHECK(ctx->HasInput("C_prev"), "Input", "C_prev", "lstm_unit"); - OP_INOUT_CHECK(ctx->HasInput("C"), "Output", "C", "lstm_unit"); - OP_INOUT_CHECK(ctx->HasInput("H"), "Output", "H", "lstm_unit"); + OP_INOUT_CHECK(ctx->HasOutput("C"), "Output", "C", "lstm_unit"); + OP_INOUT_CHECK(ctx->HasOutput("H"), "Output", "H", "lstm_unit"); auto x_dims = ctx->GetInputDim("X"); auto c_prev_dims = ctx->GetInputDim("C_prev");