From 00548a1601998cea55e7b8096408dc2f5881ef90 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Wed, 18 Apr 2018 21:51:07 +0800
Subject: [PATCH] Remove intermediate output's gradient from inputs of grad_op.

---
Review notes (ignored by git am):
- Line breaks and the template argument lists (std::unique_ptr<...>,
  GRUUnitKernel<...>, DefaultGradOpDescMaker<...>) were reconstructed from a
  flattened copy of this patch; check the context lines against the tree.
- The removed "@note" line ends in one trailing space.
- The author's e-mail address is missing from the "From:" header.
- Relative to the flattened copy, GRUUnitGradOpMaker now holds the new
  OpDesc in a std::unique_ptr from the start and carries comments; the
  diffstat and hunk sizes reflect that, the "index" blob ids do not.

 paddle/fluid/operators/gru_unit_op.cc | 51 ++++++++++++++++++++++++---
 1 file changed, 48 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/operators/gru_unit_op.cc b/paddle/fluid/operators/gru_unit_op.cc
index 8f75a67bc78..f8d1d44b542 100644
--- a/paddle/fluid/operators/gru_unit_op.cc
+++ b/paddle/fluid/operators/gru_unit_op.cc
@@ -124,7 +124,7 @@ $$
 
 which is same as one time step of GRU Operator.
 
-@note To implement the complete GRU unit, fully-connected operator must be 
+@note To implement the complete GRU unit, fully-connected operator must be
 used before to feed xu, xr and xc as the Input of GRUUnit operator.
 
 )DOC");
@@ -194,13 +194,58 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
   }
 };
 
+// Describes how to build the gru_unit_grad op from a forward gru_unit op.
+//
+// The grad op is fed the forward inputs (Input, HiddenPrev, Weight, Bias),
+// the forward outputs (Hidden, Gate, ResetHiddenPrev) and the gradient of
+// Hidden only; the gradients of the intermediate outputs Gate and
+// ResetHiddenPrev are not passed to it.
+class GRUUnitGradOpMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  // Returns the description of the gru_unit_grad op; the caller owns it.
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    // Owned by a unique_ptr from the start so that the desc is released if
+    // any of the calls below throws.
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("gru_unit_grad");
+
+    // Forward inputs.
+    op->SetInput("Input", Input("Input"));
+    op->SetInput("HiddenPrev", Input("HiddenPrev"));
+    op->SetInput("Weight", Input("Weight"));
+    op->SetInput("Bias", Input("Bias"));
+
+    // Forward outputs, followed by the gradient of Hidden.
+    op->SetInput("Hidden", Output("Hidden"));
+    op->SetInput("Gate", Output("Gate"));
+    op->SetInput("ResetHiddenPrev", Output("ResetHiddenPrev"));
+    op->SetInput(framework::GradVarName("Hidden"), OutputGrad("Hidden"));
+
+    // The forward op's attributes are passed through unchanged.
+    op->SetAttrMap(Attrs());
+
+    // One gradient output per forward input.
+    op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
+    op->SetOutput(framework::GradVarName("HiddenPrev"),
+                  InputGrad("HiddenPrev"));
+    op->SetOutput(framework::GradVarName("Weight"), InputGrad("Weight"));
+    op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+
 REGISTER_OPERATOR(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>)
-REGISTER_OPERATOR(gru_unit_grad, ops::GRUUnitGradOp)
+                  ops::GRUUnitGradOpMaker);
+REGISTER_OPERATOR(gru_unit_grad, ops::GRUUnitGradOp);
+
 REGISTER_OP_CPU_KERNEL(
     gru_unit, ops::GRUUnitKernel<paddle::platform::CPUDeviceContext, float>,
     ops::GRUUnitKernel<paddle::platform::CPUDeviceContext, double>);
-- 
GitLab