提交 b2806135 编写于 作者: Y Yu Yang

Change Interface to unique_ptr

上级 495a80a7
...@@ -42,7 +42,7 @@ The `GradOpDescMaker` will be registered in `OpInfo`, to replace `grad_op_type_` ...@@ -42,7 +42,7 @@ The `GradOpDescMaker` will be registered in `OpInfo`, to replace `grad_op_type_`
```cpp ```cpp
struct OpInfo { struct OpInfo {
std::function<std::vector<OpDescBind>(const OpDescBind&)> grad_op_maker_; std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)> grad_op_maker_;
... ...
}; };
``` ```
...@@ -55,11 +55,11 @@ We propose a base class called `GradOpDescMakerBase` to let operator developers ...@@ -55,11 +55,11 @@ We propose a base class called `GradOpDescMakerBase` to let operator developers
class GradOpDescMakerBase { class GradOpDescMakerBase {
public: public:
GradOpDescMakerBase(const OpDescBind& ); GradOpDescMakerBase(const OpDescBind& );
virtual std::vector<OpDescBind> operator()()const = 0; virtual std::vector<std::unique_ptr<OpDescBind>> operator()()const = 0;
}; };
``` ```
We can convert `GradOpDescMakerBase` to `std::function<std::vector<OpDescBind>(const OpDescBind&)>` by We can convert `GradOpDescMakerBase` to `std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>` by
```cpp ```cpp
using GradOpMaker = ...; using GradOpMaker = ...;
......
...@@ -24,7 +24,7 @@ class GradOpDescMakerBase { ...@@ -24,7 +24,7 @@ class GradOpDescMakerBase {
explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {} explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
virtual ~GradOpDescMakerBase() = default; virtual ~GradOpDescMakerBase() = default;
virtual std::vector<OpDescBind> operator()() const = 0; virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
protected: protected:
static std::vector<std::string> ToGradNames( static std::vector<std::string> ToGradNames(
...@@ -81,10 +81,14 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase { ...@@ -81,10 +81,14 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase {
public: public:
using GradOpDescMakerBase::GradOpDescMakerBase; using GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<OpDescBind> operator()() const { return {this->Apply()}; } std::vector<std::unique_ptr<OpDescBind>> operator()() const {
std::vector<std::unique_ptr<OpDescBind>> retv;
retv.emplace_back(this->Apply());
return retv;
}
protected: protected:
virtual OpDescBind Apply() const = 0; virtual std::unique_ptr<OpDescBind> Apply() const = 0;
}; };
class DefaultGradOpDescMaker : public SingleGradOpDescMaker { class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
...@@ -92,23 +96,23 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker { ...@@ -92,23 +96,23 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
using SingleGradOpDescMaker::SingleGradOpDescMaker; using SingleGradOpDescMaker::SingleGradOpDescMaker;
protected: protected:
virtual OpDescBind Apply() const { virtual std::unique_ptr<OpDescBind> Apply() const {
OpDescBind grad; auto* grad = new OpDescBind();
grad.SetType(this->GradOpType()); grad->SetType(this->GradOpType());
for (auto& input_param : this->InputNames()) { for (auto& input_param : this->InputNames()) {
grad.SetInput(input_param, this->Input(input_param)); grad->SetInput(input_param, this->Input(input_param));
grad.SetOutput(GradVarName(input_param), this->InputGrad(input_param)); grad->SetOutput(GradVarName(input_param), this->InputGrad(input_param));
} }
for (auto& output_param : this->OutputNames()) { for (auto& output_param : this->OutputNames()) {
grad.SetInput(output_param, this->Output(output_param)); grad->SetInput(output_param, this->Output(output_param));
grad.SetInput(GradVarName(output_param), this->OutputGrad(output_param)); grad->SetInput(GradVarName(output_param), this->OutputGrad(output_param));
} }
grad.SetAttrMap(this->Attrs()); grad->SetAttrMap(this->Attrs());
return grad; return std::unique_ptr<OpDescBind>(grad);
} }
virtual std::string GradOpType() const { virtual std::string GradOpType() const {
......
...@@ -28,7 +28,7 @@ namespace framework { ...@@ -28,7 +28,7 @@ namespace framework {
struct OpInfo { struct OpInfo {
OpCreator creator_; OpCreator creator_;
std::string grad_op_type_; std::string grad_op_type_;
std::function<std::vector<OpDescBind>(const OpDescBind&)> grad_op_maker_; GradOpMakerFN grad_op_maker_;
OpProto* proto_{nullptr}; OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr}; OpAttrChecker* checker_{nullptr};
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class OperatorBase; class OperatorBase;
class OpDescBind;
using VariableNameMap = std::map<std::string, std::vector<std::string>>; using VariableNameMap = std::map<std::string, std::vector<std::string>>;
// The order should be as same as framework.proto // The order should be as same as framework.proto
...@@ -34,5 +35,8 @@ using OpCreator = std::function<OperatorBase*( ...@@ -34,5 +35,8 @@ using OpCreator = std::function<OperatorBase*(
const std::string& /*type*/, const VariableNameMap& /*inputs*/, const std::string& /*type*/, const VariableNameMap& /*inputs*/,
const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>; const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
using GradOpMakerFN =
std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>;
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册