From 427e47459b7cc8b51fd4b4b5c05c99941d364aa0 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Tue, 9 Jan 2018 20:48:54 +0800 Subject: [PATCH] Add grad_op_maker for sequence_pool. --- paddle/operators/sequence_pool_op.cc | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/paddle/operators/sequence_pool_op.cc b/paddle/operators/sequence_pool_op.cc index 34e1a125915..549d9620ef2 100644 --- a/paddle/operators/sequence_pool_op.cc +++ b/paddle/operators/sequence_pool_op.cc @@ -115,12 +115,32 @@ class SequencePoolGradOp : public framework::OperatorWithKernel { } }; +class SequencePoolGradOpMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* op_desc_ptr = new framework::OpDesc(); + op_desc_ptr->SetType("sequence_pool_grad"); + op_desc_ptr->SetInput("X", Input("X")); + if (boost::get(GetAttr("pooltype")) == "MAX") { + op_desc_ptr->SetInput("MaxIndex", Output("MaxIndex")); + } + op_desc_ptr->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + op_desc_ptr->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op_desc_ptr->SetAttrMap(Attrs()); + return std::unique_ptr(op_desc_ptr); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker, - sequence_pool_grad, ops::SequencePoolGradOp); +REGISTER_OPERATOR(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker, + ops::SequencePoolGradOpMaker); +REGISTER_OPERATOR(sequence_pool_grad, ops::SequencePoolGradOp); REGISTER_OP_CPU_KERNEL( sequence_pool, ops::SequencePoolKernel); -- GitLab