diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc index 8844c5a89732992314948df446ec13d95930f39b..54818470b277443e411ea5f7d9c7561eddc7046a 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc @@ -51,6 +51,20 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker { } }; +template +class ReduceSumDoubleOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetInput("X", this->OutputGrad(framework::GradVarName("X"))); + op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out"))); + op->SetAttrMap(this->Attrs()); + op->SetType("reduce_sum"); + } +}; + DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReduceSumGradNoNeedBufferVarInferer, "X"); class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference { public: @@ -63,50 +77,6 @@ class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference { } }; -class ReduceSumDoubleGradDescMaker : public framework::GradOpDescMakerBase { - public: - using framework::GradOpDescMakerBase::GradOpDescMakerBase; - - std::vector> operator()() const override { - std::vector> ops; - auto x_gg = OutputGrad(framework::GradVarName("X")); // input ddx - auto out_grads = InputGrad(framework::GradVarName("Out")); - if (!out_grads.empty()) { - auto* out_grad_op = new framework::OpDesc(); - out_grad_op->SetType("reduce_sum"); - out_grad_op->SetInput("X", x_gg); - out_grad_op->SetAttrMap(Attrs()); - out_grad_op->SetOutput("Out", out_grads); - ops.emplace_back(out_grad_op); - } - - return ops; - } -}; - -class ReduceSumDoubleGradOpBaseMaker : public imperative::GradOpBaseMakerBase { - public: - using imperative::GradOpBaseMakerBase::GradOpBaseMakerBase; - - std::shared_ptr operator()() const override { - auto out_grads = InputGrad(framework::GradVarName("Out")); - if (!out_grads.empty()) { - auto x_gg = OutputGrad(framework::GradVarName("X")); // input ddx - auto node = this->NewGradNode(); - { - imperative::TracedGradOp op(node); - op.SetType("reduce_sum"); - op.SetInput("X", x_gg); - op.SetAttrMap(Attrs()); - op.SetOutput("Out", out_grads); - } - return node; - } else { - return nullptr; - } - } -}; - } // namespace operators } // namespace paddle @@ -121,8 +91,8 @@ REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker, ops::ReduceSumOpGradMaker, ops::ReduceSumOpGradMaker); REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp, - ops::ReduceSumDoubleGradDescMaker, - ops::ReduceSumDoubleGradOpBaseMaker, + ops::ReduceSumDoubleOpGradMaker, + ops::ReduceSumDoubleOpGradMaker, ops::ReduceSumGradNoNeedBufferVarInferer); REGISTER_OP_CPU_KERNEL(