diff --git a/paddle/fluid/operators/gru_unit_op.cc b/paddle/fluid/operators/gru_unit_op.cc index 8f75a67bc78a5829a5ef5fbe5ed2887368b55e57..f8d1d44b5423dd09fe5aad11434911af6f14fe77 100644 --- a/paddle/fluid/operators/gru_unit_op.cc +++ b/paddle/fluid/operators/gru_unit_op.cc @@ -124,7 +124,7 @@ $$ which is same as one time step of GRU Operator. -@note To implement the complete GRU unit, fully-connected operator must be +@note To implement the complete GRU unit, fully-connected operator must be used before to feed xu, xr and xc as the Input of GRUUnit operator. )DOC"); @@ -194,13 +194,45 @@ class GRUUnitGradOp : public framework::OperatorWithKernel { } }; +class GRUUnitGradOpMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* op = new framework::OpDesc(); + op->SetType("gru_unit_grad"); + + op->SetInput("Input", Input("Input")); + op->SetInput("HiddenPrev", Input("HiddenPrev")); + op->SetInput("Weight", Input("Weight")); + op->SetInput("Bias", Input("Bias")); + + op->SetInput("Hidden", Output("Hidden")); + op->SetInput("Gate", Output("Gate")); + op->SetInput("ResetHiddenPrev", Output("ResetHiddenPrev")); + op->SetInput(framework::GradVarName("Hidden"), OutputGrad("Hidden")); + + op->SetAttrMap(Attrs()); + + op->SetOutput(framework::GradVarName("Input"), InputGrad("Input")); + op->SetOutput(framework::GradVarName("HiddenPrev"), + InputGrad("HiddenPrev")); + op->SetOutput(framework::GradVarName("Weight"), InputGrad("Weight")); + op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias")); + return std::unique_ptr(op); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; + REGISTER_OPERATOR(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker, - paddle::framework::DefaultGradOpDescMaker) -REGISTER_OPERATOR(gru_unit_grad, ops::GRUUnitGradOp) + ops::GRUUnitGradOpMaker); +REGISTER_OPERATOR(gru_unit_grad, ops::GRUUnitGradOp); + REGISTER_OP_CPU_KERNEL( gru_unit, ops::GRUUnitKernel, ops::GRUUnitKernel);