提交 e119177a 编写于 作者: Y Yu Yang

Use unique_ptr

上级 14a59d2e
...@@ -36,9 +36,11 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp( ...@@ -36,9 +36,11 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp(
auto grad_descs = info.grad_op_maker_(op_desc); auto grad_descs = info.grad_op_maker_(op_desc);
std::vector<std::unique_ptr<OperatorBase>> grad_ops; std::vector<std::unique_ptr<OperatorBase>> grad_ops;
grad_ops.reserve(grad_descs.size()); grad_ops.reserve(grad_descs.size());
std::transform( std::transform(grad_descs.begin(), grad_descs.end(),
grad_descs.begin(), grad_descs.end(), std::back_inserter(grad_ops), std::back_inserter(grad_ops),
[](OpDescBind& grad_desc) { return OpRegistry::CreateOp(&grad_desc); }); [](const std::unique_ptr<OpDescBind>& grad_desc) {
return OpRegistry::CreateOp(grad_desc.get());
});
PADDLE_ENFORCE_GT(grad_ops.size(), 0); PADDLE_ENFORCE_GT(grad_ops.size(), 0);
if (grad_ops.size() == 1) { if (grad_ops.size() == 1) {
return std::move(grad_ops[0]); return std::move(grad_ops[0]);
......
...@@ -39,13 +39,13 @@ class RowWiseAddGradMaker : public SingleGradOpDescMaker { ...@@ -39,13 +39,13 @@ class RowWiseAddGradMaker : public SingleGradOpDescMaker {
using SingleGradOpDescMaker::SingleGradOpDescMaker; using SingleGradOpDescMaker::SingleGradOpDescMaker;
protected: protected:
OpDescBind Apply() const override { std::unique_ptr<OpDescBind> Apply() const override {
OpDescBind grad_op; auto grad_op = new OpDescBind();
grad_op.SetInput(GradVarName("Out"), OutputGrad("Out")); grad_op->SetInput(GradVarName("Out"), OutputGrad("Out"));
grad_op.SetOutput(GradVarName("X"), InputGrad("X")); grad_op->SetOutput(GradVarName("X"), InputGrad("X"));
grad_op.SetOutput(GradVarName("b"), InputGrad("b")); grad_op->SetOutput(GradVarName("b"), InputGrad("b"));
grad_op.SetType("rowwise_add_grad"); grad_op->SetType("rowwise_add_grad");
return grad_op; return std::unique_ptr<OpDescBind>(grad_op);
} }
}; };
...@@ -147,10 +147,8 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -147,10 +147,8 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input tensors of sum operator.") AddInput("X", "the input tensors of sum operator.").AsDuplicable();
.AsDuplicable() AddOutput("Out", "the output tensor of sum operator.");
.NotInGradient();
AddOutput("Out", "the output tensor of sum operator.").NotInGradient();
AddComment(""); AddComment("");
} }
}; };
......
...@@ -57,13 +57,13 @@ class MeanGradMaker : public framework::SingleGradOpDescMaker { ...@@ -57,13 +57,13 @@ class MeanGradMaker : public framework::SingleGradOpDescMaker {
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected: protected:
framework::OpDescBind Apply() const override { std::unique_ptr<framework::OpDescBind> Apply() const override {
framework::OpDescBind grad_op; auto* grad_op = new framework::OpDescBind();
grad_op.SetType("mean_grad"); grad_op->SetType("mean_grad");
grad_op.SetInput("X", Input("X")); grad_op->SetInput("X", Input("X"));
grad_op.SetInput(framework::GradVarName("Out"), OutputGrad("Out")); grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
grad_op.SetOutput(framework::GradVarName("X"), InputGrad("X")); grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
return grad_op; return std::unique_ptr<framework::OpDescBind>(grad_op);
} }
}; };
......
...@@ -69,19 +69,22 @@ class MinusGradMaker : public framework::GradOpDescMakerBase { ...@@ -69,19 +69,22 @@ class MinusGradMaker : public framework::GradOpDescMakerBase {
public: public:
using framework::GradOpDescMakerBase::GradOpDescMakerBase; using framework::GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<framework::OpDescBind> operator()() const override { std::vector<std::unique_ptr<framework::OpDescBind>> operator()()
std::vector<framework::OpDescBind> ops; const override {
std::vector<std::unique_ptr<framework::OpDescBind>> ops;
ops.resize(2); ops.resize(2);
ops[0].SetType("scale"); ops[0].reset(new framework::OpDescBind());
ops[0].SetInput("X", OutputGrad("Out")); ops[0]->SetType("scale");
ops[0].SetOutput("Out", InputGrad("X")); ops[0]->SetInput("X", OutputGrad("Out"));
ops[0].SetAttr("scale", 1.0f); ops[0]->SetOutput("Out", InputGrad("X"));
ops[0]->SetAttr("scale", 1.0f);
ops[1].SetType("scale");
ops[1].SetInput("X", OutputGrad("Out")); ops[1].reset(new framework::OpDescBind());
ops[1].SetOutput("Out", InputGrad("Y")); ops[1]->SetType("scale");
ops[1].SetAttr("scale", -1.0f); ops[1]->SetInput("X", OutputGrad("Out"));
ops[1]->SetOutput("Out", InputGrad("Y"));
ops[1]->SetAttr("scale", -1.0f);
return ops; return ops;
} }
}; };
......
...@@ -111,18 +111,18 @@ class PadOpGrad : public framework::OperatorWithKernel { ...@@ -111,18 +111,18 @@ class PadOpGrad : public framework::OperatorWithKernel {
}; };
class PadOpGradMaker : public framework::SingleGradOpDescMaker { class PadOpGradMaker : public framework::SingleGradOpDescMaker {
protected:
framework::OpDescBind Apply() const override {
framework::OpDescBind bind;
bind.SetInput("X", Input("X"));
bind.SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
bind.SetOutput(framework::GradVarName("X"), InputGrad("X"));
bind.SetAttrMap(Attrs());
return bind;
}
public: public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto* bind = new framework::OpDescBind();
bind->SetInput("X", Input("X"));
bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
bind->SetOutput(framework::GradVarName("X"), InputGrad("X"));
bind->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDescBind>(bind);
}
}; };
} // namespace operators } // namespace operators
......
...@@ -57,13 +57,13 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker { ...@@ -57,13 +57,13 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected: protected:
framework::OpDescBind Apply() const override { std::unique_ptr<framework::OpDescBind> Apply() const override {
framework::OpDescBind grad_op; auto *grad_op = new framework::OpDescBind();
grad_op.SetType("scale"); grad_op->SetType("scale");
grad_op.SetInput("X", OutputGrad("Out")); grad_op->SetInput("X", OutputGrad("Out"));
grad_op.SetOutput("Out", InputGrad("X")); grad_op->SetOutput("Out", InputGrad("X"));
grad_op.SetAttr("scale", GetAttr("scale")); grad_op->SetAttr("scale", GetAttr("scale"));
return grad_op; return std::unique_ptr<framework::OpDescBind>(grad_op);
} }
}; };
......
...@@ -167,17 +167,17 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker { ...@@ -167,17 +167,17 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected: protected:
framework::OpDescBind Apply() const override { std::unique_ptr<framework::OpDescBind> Apply() const override {
framework::OpDescBind grad_op; auto* grad_op = new framework::OpDescBind();
grad_op.SetType("softmax_with_cross_entropy_grad"); grad_op->SetType("softmax_with_cross_entropy_grad");
grad_op.SetInput("Label", Input("Label")); grad_op->SetInput("Label", Input("Label"));
grad_op.SetInput("Softmax", Output("Softmax")); grad_op->SetInput("Softmax", Output("Softmax"));
grad_op.SetInput("Loss", Output("Loss")); grad_op->SetInput("Loss", Output("Loss"));
grad_op.SetInput(framework::GradVarName("Softmax"), OutputGrad("Softmax")); grad_op->SetInput(framework::GradVarName("Softmax"), OutputGrad("Softmax"));
grad_op.SetInput(framework::GradVarName("Loss"), OutputGrad("Loss")); grad_op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
grad_op.SetOutput(framework::GradVarName("Logits"), InputGrad("Logits")); grad_op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
grad_op.SetAttrMap(Attrs()); grad_op->SetAttrMap(Attrs());
return grad_op; return std::unique_ptr<framework::OpDescBind>(grad_op);
} }
}; };
......
...@@ -60,19 +60,20 @@ class SumGradMaker : public framework::GradOpDescMakerBase { ...@@ -60,19 +60,20 @@ class SumGradMaker : public framework::GradOpDescMakerBase {
public: public:
using framework::GradOpDescMakerBase::GradOpDescMakerBase; using framework::GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<framework::OpDescBind> operator()() const override { std::vector<std::unique_ptr<framework::OpDescBind>> operator()()
const override {
auto x_grads = InputGrad("X"); auto x_grads = InputGrad("X");
std::vector<framework::OpDescBind> grad_ops; std::vector<std::unique_ptr<framework::OpDescBind>> grad_ops;
grad_ops.reserve(x_grads.size()); grad_ops.reserve(x_grads.size());
auto og = OutputGrad("Out"); auto og = OutputGrad("Out");
std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops), std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
[&og](const std::string& x_grad) { [&og](const std::string& x_grad) {
framework::OpDescBind grad_op; auto* grad_op = new framework::OpDescBind();
grad_op.SetType("scale"); grad_op->SetType("scale");
grad_op.SetInput("X", og); grad_op->SetInput("X", og);
grad_op.SetOutput("Out", {x_grad}); grad_op->SetOutput("Out", {x_grad});
grad_op.SetAttr("scale", 1.0f); grad_op->SetAttr("scale", 1.0f);
return grad_op; return std::unique_ptr<framework::OpDescBind>(grad_op);
}); });
return grad_ops; return grad_ops;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册