From b2806135a53cbe85fbc764375d9cecc2596ab4be Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 3 Oct 2017 14:41:21 -0700 Subject: [PATCH] Change Interface to unique_ptr --- doc/design/register_grad_op.md | 6 +++--- paddle/framework/grad_op_desc_maker.h | 28 +++++++++++++++------------ paddle/framework/op_info.h | 2 +- paddle/framework/type_defs.h | 4 ++++ 4 files changed, 24 insertions(+), 16 deletions(-) diff --git a/doc/design/register_grad_op.md b/doc/design/register_grad_op.md index cdb7a8435b..3cf8a59446 100644 --- a/doc/design/register_grad_op.md +++ b/doc/design/register_grad_op.md @@ -42,7 +42,7 @@ The `GradOpDescMaker` will be registered in `OpInfo`, to replace `grad_op_type_` ```cpp struct OpInfo { - std::function(const OpDescBind&)> grad_op_maker_; + std::function>(const OpDescBind&)> grad_op_maker_; ... }; ``` @@ -55,11 +55,11 @@ We propose a base class called `GradOpDescMakerBase` to let operator developers class GradOpDescMakerBase { public: GradOpDescMakerBase(const OpDescBind& ); - virtual std::vector operator()()const = 0; + virtual std::vector> operator()()const = 0; }; ``` -We can convert `GradOpDescMakerBase` to `std::function(const OpDescBind&)>` by +We can convert `GradOpDescMakerBase` to `std::function>(const OpDescBind&)>` by ```cpp using GradOpMaker = ...; diff --git a/paddle/framework/grad_op_desc_maker.h b/paddle/framework/grad_op_desc_maker.h index e6d63e4b8a..e9ae6e2206 100644 --- a/paddle/framework/grad_op_desc_maker.h +++ b/paddle/framework/grad_op_desc_maker.h @@ -24,7 +24,7 @@ class GradOpDescMakerBase { explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {} virtual ~GradOpDescMakerBase() = default; - virtual std::vector operator()() const = 0; + virtual std::vector> operator()() const = 0; protected: static std::vector ToGradNames( @@ -81,10 +81,14 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase { public: using GradOpDescMakerBase::GradOpDescMakerBase; - std::vector operator()() const { return {this->Apply()}; } + std::vector> operator()() const { + std::vector> retv; + retv.emplace_back(this->Apply()); + return retv; + } protected: - virtual OpDescBind Apply() const = 0; + virtual std::unique_ptr Apply() const = 0; }; class DefaultGradOpDescMaker : public SingleGradOpDescMaker { @@ -92,23 +96,23 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker { using SingleGradOpDescMaker::SingleGradOpDescMaker; protected: - virtual OpDescBind Apply() const { - OpDescBind grad; - grad.SetType(this->GradOpType()); + virtual std::unique_ptr Apply() const { + auto* grad = new OpDescBind(); + grad->SetType(this->GradOpType()); for (auto& input_param : this->InputNames()) { - grad.SetInput(input_param, this->Input(input_param)); - grad.SetOutput(GradVarName(input_param), this->InputGrad(input_param)); + grad->SetInput(input_param, this->Input(input_param)); + grad->SetOutput(GradVarName(input_param), this->InputGrad(input_param)); } for (auto& output_param : this->OutputNames()) { - grad.SetInput(output_param, this->Output(output_param)); - grad.SetInput(GradVarName(output_param), this->OutputGrad(output_param)); + grad->SetInput(output_param, this->Output(output_param)); + grad->SetInput(GradVarName(output_param), this->OutputGrad(output_param)); } - grad.SetAttrMap(this->Attrs()); + grad->SetAttrMap(this->Attrs()); - return grad; + return std::unique_ptr(grad); } virtual std::string GradOpType() const { diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h index 806a960018..8b7882485f 100644 --- a/paddle/framework/op_info.h +++ b/paddle/framework/op_info.h @@ -28,7 +28,7 @@ namespace framework { struct OpInfo { OpCreator creator_; std::string grad_op_type_; - std::function(const OpDescBind&)> grad_op_maker_; + GradOpMakerFN grad_op_maker_; OpProto* proto_{nullptr}; OpAttrChecker* checker_{nullptr}; diff --git a/paddle/framework/type_defs.h b/paddle/framework/type_defs.h index dec5066f1e..a5b9472213 100644 --- a/paddle/framework/type_defs.h +++ b/paddle/framework/type_defs.h @@ -20,6 +20,7 @@ namespace paddle { namespace framework { class OperatorBase; +class OpDescBind; using VariableNameMap = std::map>; // The order should be as same as framework.proto @@ -34,5 +35,8 @@ using OpCreator = std::function; +using GradOpMakerFN = + std::function>(const OpDescBind&)>; + } // namespace framework } // namespace paddle -- GitLab