diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index c226e4e3d2a58d1a647e204c4cd26f4eb6bcd968..a1049f718dcd3ea42320b7d5d11152e57ace07b4 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -15,6 +15,8 @@
 #include "paddle/framework/backward.h"
 #include <list>
+#include <memory>
+
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/net_op.h"
 #include "paddle/operators/recurrent_op.h"
@@ -43,11 +45,11 @@ static bool AllInSet(
   return all_in_set;
 }
 
-static std::shared_ptr<OperatorBase> NOP() {
-  auto net_op = std::make_shared<operators::NetOp>();
+static std::unique_ptr<OperatorBase> NOP() {
+  auto net_op = new operators::NetOp();
   net_op->SetType("@NOP@");
   net_op->CompleteAddOp();
-  return net_op;
+  return std::unique_ptr<OperatorBase>(net_op);
 }
 
 // Get backward operator from a forward operator, a recursive implementation.
@@ -62,11 +64,7 @@ static std::shared_ptr<OperatorBase> NOP() {
 // operator, in a complex situation, it maybe a NetOp.
 //
 // See Backward.h for details
-static std::shared_ptr<OperatorBase> BackwardRecursive(
-    const OperatorBase& forwardOp,
-    std::unordered_set<std::string>& no_grad_names, size_t& uniq_id);
-
-std::shared_ptr<OperatorBase> BackwardRecursive(
+static std::unique_ptr<OperatorBase> BackwardRecursive(
     const OperatorBase& forwardOp,
     std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
   // If all input gradients of forwarding operator do not need to calculate,
@@ -91,7 +89,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
   }
 
   // Returned gradient network
-  auto net = std::make_shared<operators::NetOp>();
+  auto net = std::unique_ptr<operators::NetOp>(new operators::NetOp());
 
   if (forwardOp.IsNetOp()) {
     // Because forwardOp is a net op, it can static_cast.
@@ -105,14 +103,14 @@
     // reversely travel forwardNet and collect all duplicate outputs.
     for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
          ++it, ++local_op_id) {
-      auto fwd = *it;
+      auto& fwd = *it;
       auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
-      net->AddOp(bwd);
       ForEachVarName(bwd->Outputs(),
                      [&dup_output_ops, local_op_id](const std::string& out) {
                        dup_output_ops[out].emplace_back(local_op_id);
                        return false;
                      });
+      net->AddOp(std::move(bwd));
     }
     // Get unique ID for this method.
     auto uid = uniq_id++;
@@ -122,7 +120,7 @@
     // to handle this case. For each duplicate output, rename it to an alias
     // (original name with a offset), append an `add` op for its operator,
     // and finally sum all the alias variable to the final output variable y.
-    using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>;
+    using Pos = std::pair<size_t, std::unique_ptr<OperatorBase>>;
     std::list<Pos> insert_position;
     for (auto& dup_output_op : dup_output_ops) {
       const std::string& name = dup_output_op.first;
@@ -150,13 +148,13 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
                          [](const Pos& l, const Pos& r) { return l.first > r.first; });
 
     for (auto& pos : insert_position) {
-      net->InsertOp(pos.first + 1, pos.second);
+      net->InsertOp(pos.first + 1, std::move(pos.second));
     }
   } else {
-    std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
+    std::unique_ptr<OperatorBase> grad_op(OpRegistry::CreateGradOp(forwardOp));
 
-    ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net,
-                                       grad_op](const std::string& grad_input) {
+    ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
+                                          const std::string& grad_input) {
      if (no_grad_names.count(grad_input)) {
        // +1 for \0
        std::string prefix = grad_input.substr(
@@ -190,20 +188,20 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
    const auto& stepnet_op =
        *static_cast<const OperatorBase*>(&rnnop.stepnet());
    // create stepnet's gradient op
-    auto grad_stepnet = BackwardRecursive(stepnet_op, no_grad_names, uniq_id);
    rnn_grad_op->set_stepnet(
-        std::static_pointer_cast<operators::NetOp>(grad_stepnet));
+        BackwardRecursive(stepnet_op, no_grad_names, uniq_id));
  }

  if (net->ops_.empty()) {  // Current no aux op is added to network
    return grad_op;
  }
-    net->AddOp(grad_op);
+    net->AddOp(std::move(grad_op));
  }
  net->SetType("@GENERATED_BACKWARD@");
  net->CompleteAddOp();
-  return net;
-}  // namespace framework
+  return std::unique_ptr<OperatorBase>(
+      static_cast<OperatorBase*>(net.release()));
+}

 // See header for comments
 std::shared_ptr<OperatorBase> Backward(
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 4fa0a2750b239f7399086a4c5356995c4852eabc..f0cc0012e11def9bcdd98a6c72a29ec181a6580e 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -174,7 +174,7 @@ class OpRegistry {
     }
   }
 
-  static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
+  static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
                                                 const VarNameMap& inputs,
                                                 const VarNameMap& outputs,
                                                 AttributeMap attrs) {
@@ -183,7 +183,7 @@ class OpRegistry {
                    "Operator '%s' has not been registered.", type);
     it->second.checker_->Check(attrs);
     auto op = it->second.creator_(type, inputs, outputs, attrs);
-    return std::shared_ptr<OperatorBase>(op);
+    return std::unique_ptr<OperatorBase>(op);
   }
 
   static VarNameMap ConvertOpDescVarsToVarNameMap(
@@ -199,7 +199,7 @@ class OpRegistry {
     return ret_val;
   }
 
-  static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
+  static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
     VarNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
     VarNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
     AttributeMap attrs;
@@ -210,11 +210,10 @@ class OpRegistry {
     return CreateOp(op_desc.type(), inputs, outputs, attrs);
   }
 
-  static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op) {
+  static std::unique_ptr<OperatorBase> CreateGradOp(const OperatorBase& op) {
     PADDLE_ENFORCE(!op.IsNetOp(),
                    "Use framework::Backward to get backward ops");
-    std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op));
-    return grad_op;
+    return std::unique_ptr<OperatorBase>(BuildGradOp(&op));
   }
 
   static std::unordered_map<std::string, const OpInfo>& op_info_map() {
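For context, a minimal caller-side sketch of the new `CreateOp`/`AddOp` ownership contract (the `my_add` op type and the variable names are hypothetical placeholders, not part of this patch):

```cpp
#include <memory>
#include <utility>

#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"

void OwnershipSketch() {
  // CreateOp now returns std::unique_ptr<OperatorBase>: the caller holds
  // exclusive ownership instead of sharing it through shared_ptr.
  std::unique_ptr<paddle::framework::OperatorBase> op =
      paddle::framework::OpRegistry::CreateOp(
          "my_add", {{"X", {"x"}}, {"Y", {"y"}}}, {{"Out", {"out"}}}, {});

  // Any further transfer of ownership must now be explicit; here it moves
  // into a NetOp, after which `op` is null.
  paddle::operators::NetOp net;
  net.AddOp(std::move(op));
  net.CompleteAddOp();
}
```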
diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc
index 1a85d568350dc04ca1df28129de19cd45b5204b8..50c45919c53af22665feeeebe753da283ded2b0c 100644
--- a/paddle/framework/op_registry_test.cc
+++ b/paddle/framework/op_registry_test.cc
@@ -76,8 +76,7 @@ TEST(OpRegistry, CreateOp) {
   attr->set_type(paddle::framework::AttrType::FLOAT);
   attr->set_f(scale);
 
-  std::shared_ptr<paddle::framework::OperatorBase> op =
-      paddle::framework::OpRegistry::CreateOp(op_desc);
+  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
   paddle::framework::Scope scope;
   paddle::platform::CPUDeviceContext dev_ctx;
   op->Run(scope, dev_ctx);
@@ -118,8 +117,7 @@ TEST(OpRegistry, DefaultValue) {
 
   ASSERT_TRUE(op_desc.IsInitialized());
 
-  std::shared_ptr<paddle::framework::OperatorBase> op =
-      paddle::framework::OpRegistry::CreateOp(op_desc);
+  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
   paddle::framework::Scope scope;
   paddle::platform::CPUDeviceContext dev_ctx;
   op->Run(scope, dev_ctx);
diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc
index fe0c87bc570825014222807cb90a3bb341b44e8e..2fc1e214b28d8675327a52e379c6be41d15314cd 100644
--- a/paddle/framework/pybind.cc
+++ b/paddle/framework/pybind.cc
@@ -207,8 +207,7 @@ All parameter, weight, gradient are variables in Paddle.
       .def(py::init<>())
       .def("__str__", string::to_string);
 
-  py::class_<OperatorBase, std::shared_ptr<OperatorBase>> operator_base(
-      m, "Operator");
+  py::class_<OperatorBase> operator_base(m, "Operator");
 
   operator_base.def_static("create", [](py::bytes protobin) {
     OpDesc desc;
@@ -228,25 +227,23 @@ All parameter, weight, gradient are variables in Paddle.
 
   ExposeOperator(operator_base);
 
-  py::class_<operators::NetOp, std::shared_ptr<operators::NetOp>> net(m, "Net");
+  py::class_<operators::NetOp> net(m, "Net");
 
   net.def_static("create",
-                 []() -> std::shared_ptr<operators::NetOp> {
-                   auto retv = std::make_shared<operators::NetOp>();
+                 []() -> operators::NetOp * {
+                   auto *retv = new operators::NetOp;
                    retv->SetType("plain_net");
                    return retv;
                  })
-      .def("add_op", &operators::NetOp::AddOp)
+      .def("add_op", [](operators::NetOp &self,
+                        const OperatorBase &op) { self.AddOp(op); })
       .def("add_op",
-           [](operators::NetOp &self,
-              const std::shared_ptr<operators::NetOp> &net) -> void {
-             self.AddOp(std::static_pointer_cast<OperatorBase>(net));
+           [](operators::NetOp &self, const operators::NetOp &net) -> void {
+             self.AddOp(net);
            })
       .def("add_op",
            [](operators::NetOp &self,
-              const std::shared_ptr<operators::RecurrentOp> &rnn) -> void {
-             self.AddOp(std::static_pointer_cast<OperatorBase>(rnn));
-           })
+              const operators::RecurrentOp &rnn) -> void { self.AddOp(rnn); })
       .def("complete_add_op", &operators::NetOp::CompleteAddOp)
       .def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
         self->CompleteAddOp();
@@ -255,12 +252,11 @@ All parameter, weight, gradient are variables in Paddle.
   ExposeOperator(net);
 
   // recurrent_op
-  py::class_<operators::RecurrentOp, std::shared_ptr<operators::RecurrentOp>>
-      rnn(m, "RecurrentOp");
+  py::class_<operators::RecurrentOp> rnn(m, "RecurrentOp");
 
   rnn.def_static(
          "create",
-         [](py::bytes protobin) -> std::shared_ptr<operators::RecurrentOp> {
+         [](py::bytes protobin) -> operators::RecurrentOp * {
            OpDesc desc;
            PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
                           "Cannot parse user input to OpDesc");
@@ -268,13 +264,12 @@ All parameter, weight, gradient are variables in Paddle.
                           "User OpDesc is not initialized, reason %s",
                           desc.InitializationErrorString());
            auto rnn_op = OpRegistry::CreateOp(desc);
-           return std::dynamic_pointer_cast<operators::RecurrentOp>(rnn_op);
+           return static_cast<operators::RecurrentOp *>(rnn_op.release());
          })
-      .def("set_stepnet",
-           [](operators::RecurrentOp &self,
-              const std::shared_ptr<operators::NetOp> &net) -> void {
-             self.set_stepnet(net);
-           });
+      .def("set_stepnet", [](operators::RecurrentOp &self,
+                             const operators::NetOp &net) -> void {
+        self.set_stepnet(net.Clone());
+      });
 
   ExposeOperator(rnn);
 
   m.def("unique_integer", UniqueIntegerGenerator);
"User OpDesc is not initialized, reason %s", desc.InitializationErrorString()); auto rnn_op = OpRegistry::CreateOp(desc); - return std::dynamic_pointer_cast(rnn_op); + return static_cast(rnn_op.release()); }) - .def("set_stepnet", - [](operators::RecurrentOp &self, - const std::shared_ptr &net) -> void { - self.set_stepnet(net); - }); + .def("set_stepnet", [](operators::RecurrentOp &self, + const operators::NetOp &net) -> void { + self.set_stepnet(net.Clone()); + }); ExposeOperator(rnn); m.def("unique_integer", UniqueIntegerGenerator); diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index 743f0e67dbeaab2de97a6cf635aad0ee90b2cef1..2ec65c63f3facaf8bdcc3668b95d73edc476c807 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -45,11 +45,11 @@ class NetOp : public framework::OperatorBase { : framework::OperatorBase( static_cast(o)) { this->ops_.reserve(o.ops_.size()); - std::transform(o.ops_.begin(), o.ops_.end(), std::back_inserter(this->ops_), - [](const std::shared_ptr& op) - -> std::shared_ptr { - return std::shared_ptr(op->Clone()); - }); + std::transform( + o.ops_.begin(), o.ops_.end(), std::back_inserter(this->ops_), + [](const std::unique_ptr& op) { + return std::unique_ptr(op->Clone()); + }); this->CompleteAddOp(); } @@ -86,21 +86,42 @@ class NetOp : public framework::OperatorBase { return true; } + void AddOp(const framework::OperatorBase& op) { AddOp(op.Clone()); } + /** * @brief Add an operator by ptr */ - void AddOp(const std::shared_ptr& op) { + void AddOp(framework::OperatorBase* op, bool own) { PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op"); - ops_.push_back(op); + if (!own) { + op = op->Clone().release(); + } + ops_.emplace_back(op); } - void InsertOp(size_t pos, const std::shared_ptr& op) { + void AddOp(std::unique_ptr&& op) { + AddOp(op.release(), true); + } + + void InsertOp(size_t pos, framework::OperatorBase* op, bool own) { PADDLE_ENFORCE(!add_op_done_, "Cannot InsertOp when this network is sealed"); PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op"); PADDLE_ENFORCE_LE(pos, ops_.size(), "Out of range"); - ops_.insert(ops_.begin() + pos, op); + if (!own) { + op = op->Clone().release(); + } + ops_.insert(ops_.begin() + pos, + std::unique_ptr(op)); + } + + void InsertOp(size_t pos, std::unique_ptr&& op) { + InsertOp(pos, op.release(), true); + } + + void InsertOp(size_t pos, const framework::OperatorBase& op) { + InsertOp(pos, op.Clone()); } void CompleteAddOp(bool calculate = true); @@ -112,7 +133,7 @@ class NetOp : public framework::OperatorBase { std::unique_ptr Clone() const override; - std::vector> ops_; + std::vector> ops_; private: bool add_op_done_{false}; diff --git a/paddle/operators/net_op_test.cc b/paddle/operators/net_op_test.cc index e28d4df6a570968205851c2e5b630a14c0492535..e9598610c0a74e08a613a397109ad65994821498 100644 --- a/paddle/operators/net_op_test.cc +++ b/paddle/operators/net_op_test.cc @@ -38,15 +38,12 @@ TEST(OpKernel, all) { auto net = std::make_shared(); ASSERT_NE(net, nullptr); - auto op1 = std::shared_ptr( + net->AddOp(std::unique_ptr( new TestOp("test", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}, - {{"Out", {"y"}}}, {})); - net->AddOp(op1); - - auto op2 = std::shared_ptr( + {{"Out", {"y"}}}, {}))); + net->AddOp(std::unique_ptr( new TestOp("test", {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}}, - {{"Out", {"z"}}}, {})); - net->AddOp(op2); + {{"Out", {"z"}}}, {}))); net->CompleteAddOp(); AssertSameVectorWithoutOrder({"x", 
"w1", "b1", "w2", "b2"}, @@ -61,21 +58,21 @@ TEST(OpKernel, all) { TEST(NetOp, insert_op) { NetOp net; - auto op1 = std::shared_ptr( + auto op1 = std::unique_ptr( new framework::NOP("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}, {{"Out", {"y"}}}, {})); - net.AddOp(op1); - net.InsertOp(0, op1); + net.AddOp(*op1); + net.InsertOp(0, *op1); ASSERT_EQ(2UL, net.ops_.size()); - net.InsertOp(2, op1); + net.InsertOp(2, std::move(op1)); ASSERT_EQ(3UL, net.ops_.size()); } TEST(NetOp, Clone) { NetOp net; net.AddOp( - std::shared_ptr(new framework::NOP{"empty", {}, {}, {}})); - net.AddOp(std::shared_ptr( + std::unique_ptr(new framework::NOP{"empty", {}, {}, {}})); + net.AddOp(std::unique_ptr( new framework::NOP{"empty2", {}, {}, {}})); net.CompleteAddOp(true); auto new_net_op = net.Clone(); diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index 78ce0ba3c0fa4fe380e49a848c2434fe593cd00b..aae78a1cecb8b2faf2483c7bf20f7a18ceab58b6 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -42,7 +42,7 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const { rnn::LinkMemories(step_scopes, arg_->memories, i, -1, true /*infer_shape_mode*/); } - (*stepnet_)->InferShape(*step_scopes[i]); + stepnet_->InferShape(*step_scopes[i]); } rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/); @@ -61,7 +61,7 @@ void RecurrentAlgorithm::Run(const Scope& scope, rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1, false /*infer_shape_mode*/); } - (*stepnet_)->Run(*step_scopes[step_id], dev_ctx); + stepnet_->Run(*step_scopes[step_id], dev_ctx); } rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/); @@ -76,15 +76,15 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { // Now all variables in scope must be created outside of op. 
   PADDLE_ENFORCE_NOT_NULL(stepnet_);
-  PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs");
-  PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "net_op has no outputs");
+  PADDLE_ENFORCE(!stepnet_->Outputs().empty(), "stepnet_ op has no outputs");
+  PADDLE_ENFORCE(!stepnet_->Outputs().empty(), "net_op has no outputs");
 
   if (seq_len_ > step_scopes->size()) {
     for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
       auto& step_scope = scope.NewScope();
 
       // create step net's temp inputs
-      for (auto& input : (*stepnet_)->Inputs()) {
+      for (auto& input : stepnet_->Inputs()) {
         // the weight are located in parent scope
         for (auto& var_name : input.second) {
           if (!step_scope.FindVar(var_name)) {
@@ -93,7 +93,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
         }
       }
       // create stepnet's outputs
-      for (const auto& output : (*stepnet_)->Outputs()) {
+      for (const auto& output : stepnet_->Outputs()) {
         for (auto& var_name : output.second) {
           step_scope.NewVar(var_name);
         }
@@ -136,7 +136,7 @@ RecurrentOp::RecurrentOp(const std::string& type,
                          const framework::AttributeMap& attrs)
     : OperatorBase(type, inputs, outputs, attrs) {
   rnn::InitArgument(kArgName, &arg_, *this);
-  alg_.Init(&arg_, &stepnet_);
+  alg_.Init(&arg_, stepnet_.get());
 }
 
 class RecurrentAlgorithmProtoAndCheckerMaker
@@ -178,7 +178,7 @@ void RecurrentGradientAlgorithm::Run(
       rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
                         false /*infer_shape_mode*/);
     }
-    (*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
+    stepnet_->Run(*step_scopes[step_id], dev_ctx);
   }
   LinkBootMemoryGradients(step_scopes[0], false);
   rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
@@ -215,7 +215,7 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
       rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
                         true /*infer_shape_mode*/);
     }
-    (*stepnet_)->InferShape(*step_scopes[step_id]);
+    stepnet_->InferShape(*step_scopes[step_id]);
   }
   rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
                      true /*infer_shape_mode*/);
@@ -228,7 +228,7 @@ RecurrentGradientOp::RecurrentGradientOp(
     const framework::AttributeMap& attrs)
     : OperatorBase(type, inputs, outputs, attrs) {
   rnn::InitArgument(kArgName, &arg_, *this);
-  alg_.Init(&arg_, &stepnet_);
+  alg_.Init(&arg_, stepnet_.get());
 }
 
 }  // namespace operators
diff --git a/paddle/operators/recurrent_op.h b/paddle/operators/recurrent_op.h
index 1d8a6973955cf0b4ab372412fbb5428ff2622a0a..4d091aa21241e85c32bcc1497ac28244448140f1 100644
--- a/paddle/operators/recurrent_op.h
+++ b/paddle/operators/recurrent_op.h
@@ -34,7 +34,7 @@ class RecurrentAlgorithm {
   void Run(const framework::Scope& scope,
            const platform::DeviceContext& dev_ctx) const;
 
-  void Init(rnn::Argument* arg, std::shared_ptr<NetOp>* stepnet) {
+  void Init(rnn::Argument* arg, framework::OperatorBase* stepnet) {
     PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
     arg_ = arg;
     stepnet_ = stepnet;
@@ -63,7 +63,7 @@ class RecurrentAlgorithm {
   void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const;
 
  private:
-  std::shared_ptr<NetOp>* stepnet_;
+  framework::OperatorBase* stepnet_;
   rnn::Argument* arg_;
   mutable size_t seq_len_;
 };
@@ -80,7 +80,7 @@ class RecurrentGradientAlgorithm {
    * operator.
    */
  public:
-  void Init(rnn::Argument* arg, std::shared_ptr<NetOp>* stepnet) {
+  void Init(rnn::Argument* arg, framework::OperatorBase* stepnet) {
     PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
     arg_ = std::move(arg);
     stepnet_ = stepnet;
@@ -107,7 +107,7 @@ class RecurrentGradientAlgorithm {
  private:
   rnn::Argument* arg_;
   mutable size_t seq_len_;
-  std::shared_ptr<NetOp>* stepnet_;
+  framework::OperatorBase* stepnet_;
 };
 
 class RecurrentOp : public framework::OperatorBase {
@@ -133,15 +133,17 @@ class RecurrentOp : public framework::OperatorBase {
     alg_.Run(scope, dev_ctx);
   }
 
-  void set_stepnet(std::shared_ptr<NetOp> net) { stepnet_ = net; }
-  const NetOp& stepnet() const { return *stepnet_; }
+  void set_stepnet(std::unique_ptr<OperatorBase> net) {
+    stepnet_ = std::move(net);
+  }
+  const OperatorBase& stepnet() const { return *stepnet_; }
 
   static const rnn::ArgumentName kArgName;
 
  private:
   RecurrentAlgorithm alg_;
   rnn::Argument arg_;
-  std::shared_ptr<NetOp> stepnet_;
+  std::unique_ptr<OperatorBase> stepnet_;
 };
 
 class RecurrentGradientOp : public framework::OperatorBase {
@@ -171,12 +173,14 @@ class RecurrentGradientOp : public framework::OperatorBase {
 
   static const rnn::ArgumentName kArgName;
 
-  void set_stepnet(const std::shared_ptr<NetOp>& net) { stepnet_ = net; }
-  const NetOp& stepnet() const { return *stepnet_; }
+  void set_stepnet(std::unique_ptr<OperatorBase> net) {
+    stepnet_ = std::move(net);
+  }
+  const OperatorBase& stepnet() const { return *stepnet_; }
 
  private:
   RecurrentGradientAlgorithm alg_;
-  std::shared_ptr<NetOp> stepnet_;
+  std::unique_ptr<OperatorBase> stepnet_;
   rnn::Argument arg_;
 };
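Finally, a sketch of how a step network is wired into a `RecurrentOp` under the new `unique_ptr` API (construction of the `RecurrentOp` itself is elided; `rnn` is assumed to exist and `WireStepNet` is a hypothetical helper):

```cpp
#include <memory>
#include <utility>

#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"

void WireStepNet(paddle::operators::RecurrentOp* rnn) {
  auto stepnet =
      std::unique_ptr<paddle::operators::NetOp>(new paddle::operators::NetOp);
  stepnet->SetType("plain_net");
  // ... per-step operators would be added here via stepnet->AddOp(...) ...
  stepnet->CompleteAddOp();
  // set_stepnet now takes the step net by unique_ptr, so the RecurrentOp
  // becomes the sole owner; no shared aliasing with the caller remains.
  rnn->set_stepnet(std::move(stepnet));
}
```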