提交 b2806135 编写于 作者: Y Yu Yang

Change Interface to unique_ptr

上级 495a80a7
......@@ -42,7 +42,7 @@ The `GradOpDescMaker` will be registered in `OpInfo`, to replace `grad_op_type_`
```cpp
struct OpInfo {
std::function<std::vector<OpDescBind>(const OpDescBind&)> grad_op_maker_;
std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)> grad_op_maker_;
...
};
```
......@@ -55,11 +55,11 @@ We propose a base class called `GradOpDescMakerBase` to let operator developers
class GradOpDescMakerBase {
public:
GradOpDescMakerBase(const OpDescBind& );
virtual std::vector<OpDescBind> operator()()const = 0;
virtual std::vector<std::unique_ptr<OpDescBind>> operator()()const = 0;
};
```
We can convert `GradOpDescMakerBase` to `std::function<std::vector<OpDescBind>(const OpDescBind&)>` by
We can convert `GradOpDescMakerBase` to `std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>` by
```cpp
using GradOpMaker = ...;
......
......@@ -24,7 +24,7 @@ class GradOpDescMakerBase {
explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
virtual ~GradOpDescMakerBase() = default;
virtual std::vector<OpDescBind> operator()() const = 0;
virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
protected:
static std::vector<std::string> ToGradNames(
......@@ -81,10 +81,14 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase {
public:
using GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<OpDescBind> operator()() const { return {this->Apply()}; }
std::vector<std::unique_ptr<OpDescBind>> operator()() const {
std::vector<std::unique_ptr<OpDescBind>> retv;
retv.emplace_back(this->Apply());
return retv;
}
protected:
virtual OpDescBind Apply() const = 0;
virtual std::unique_ptr<OpDescBind> Apply() const = 0;
};
class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
......@@ -92,23 +96,23 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
using SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
virtual OpDescBind Apply() const {
OpDescBind grad;
grad.SetType(this->GradOpType());
virtual std::unique_ptr<OpDescBind> Apply() const {
auto* grad = new OpDescBind();
grad->SetType(this->GradOpType());
for (auto& input_param : this->InputNames()) {
grad.SetInput(input_param, this->Input(input_param));
grad.SetOutput(GradVarName(input_param), this->InputGrad(input_param));
grad->SetInput(input_param, this->Input(input_param));
grad->SetOutput(GradVarName(input_param), this->InputGrad(input_param));
}
for (auto& output_param : this->OutputNames()) {
grad.SetInput(output_param, this->Output(output_param));
grad.SetInput(GradVarName(output_param), this->OutputGrad(output_param));
grad->SetInput(output_param, this->Output(output_param));
grad->SetInput(GradVarName(output_param), this->OutputGrad(output_param));
}
grad.SetAttrMap(this->Attrs());
grad->SetAttrMap(this->Attrs());
return grad;
return std::unique_ptr<OpDescBind>(grad);
}
virtual std::string GradOpType() const {
......
......@@ -28,7 +28,7 @@ namespace framework {
struct OpInfo {
OpCreator creator_;
std::string grad_op_type_;
std::function<std::vector<OpDescBind>(const OpDescBind&)> grad_op_maker_;
GradOpMakerFN grad_op_maker_;
OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr};
......
......@@ -20,6 +20,7 @@
namespace paddle {
namespace framework {
class OperatorBase;
class OpDescBind;
using VariableNameMap = std::map<std::string, std::vector<std::string>>;
// The order should be as same as framework.proto
......@@ -34,5 +35,8 @@ using OpCreator = std::function<OperatorBase*(
const std::string& /*type*/, const VariableNameMap& /*inputs*/,
const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
using GradOpMakerFN =
std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>;
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册