diff --git a/paddle/operators/sequence_pool_op.cc b/paddle/operators/sequence_pool_op.cc index 34e1a12591515e2363651a2722f52963f7ae43b5..549d9620ef20e2be7d030f41f1e7567cb1922f3b 100644 --- a/paddle/operators/sequence_pool_op.cc +++ b/paddle/operators/sequence_pool_op.cc @@ -115,12 +115,32 @@ class SequencePoolGradOp : public framework::OperatorWithKernel { } }; +class SequencePoolGradOpMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* op_desc_ptr = new framework::OpDesc(); + op_desc_ptr->SetType("sequence_pool_grad"); + op_desc_ptr->SetInput("X", Input("X")); + if (boost::get(GetAttr("pooltype")) == "MAX") { + op_desc_ptr->SetInput("MaxIndex", Output("MaxIndex")); + } + op_desc_ptr->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + op_desc_ptr->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op_desc_ptr->SetAttrMap(Attrs()); + return std::unique_ptr(op_desc_ptr); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker, - sequence_pool_grad, ops::SequencePoolGradOp); +REGISTER_OPERATOR(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker, + ops::SequencePoolGradOpMaker); +REGISTER_OPERATOR(sequence_pool_grad, ops::SequencePoolGradOp); REGISTER_OP_CPU_KERNEL( sequence_pool, ops::SequencePoolKernel);