diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index a1049f718dcd3ea42320b7d5d11152e57ace07b4..9d30887224fe0020ff5665f362e7403bf5c724ee 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -89,7 +89,7 @@ static std::unique_ptr BackwardRecursive( } // Returned gradient network - auto net = std::unique_ptr(); + auto net = std::unique_ptr(new operators::NetOp()); if (forwardOp.IsNetOp()) { // Because forwardOp is a net op, it can static_cast. @@ -204,7 +204,7 @@ static std::unique_ptr BackwardRecursive( } // See header for comments -std::shared_ptr Backward( +std::unique_ptr Backward( const OperatorBase& forwardOp, const std::unordered_set& no_grad_vars) { std::unordered_set no_grad_names; diff --git a/paddle/framework/backward.h b/paddle/framework/backward.h index c181919dc165cf0b49362f85e22ceb4131bbd387..1ecf69881b3126c2904920b9f4b77bfcccc9cf86 100644 --- a/paddle/framework/backward.h +++ b/paddle/framework/backward.h @@ -20,7 +20,7 @@ namespace framework { // Create the backward operator from a forward operator. // TODO(yuyang18): Add more API reference comment. -extern std::shared_ptr Backward( +extern std::unique_ptr Backward( const OperatorBase& forwardOp, const std::unordered_set& no_grad_vars); } // namespace framework diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index d942604bf05998ab9e1ee147b28586e7e4e9777d..1003b1ccd83c85c54f07c8d2b84f993203408941 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -180,8 +180,7 @@ TEST(Backward, simple_op_not_need_grad) { auto no_input_gop = f::Backward(*fwd, {"x", "b"}); ASSERT_NE(no_input_gop, nullptr); ASSERT_TRUE(no_input_gop->IsNetOp()); - ASSERT_EQ(0UL, - std::static_pointer_cast(no_input_gop)->ops_.size()); + ASSERT_EQ(0UL, static_cast(no_input_gop.get())->ops_.size()); } TEST(Backward, net_fc_backward_normal) { diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index 2fc1e214b28d8675327a52e379c6be41d15314cd..f0114b9e4908d65b3fddb493230777f9e500b4e1 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -48,29 +48,6 @@ namespace framework { using Tensor = framework::Tensor; -template -void ExposeOperator(ClassType &m) { - m.def("infer_shape", &ClassType::type::InferShape) - .def("run", &ClassType::type::Run) - .def("type", - [](const typename ClassType::type &op) -> std::string { - return op.Type(); - }) - .def("outputs", - [](const typename ClassType::type &op) - -> std::map> { - return op.Outputs(); - }) - .def("inputs", - [](const typename ClassType::type &op) { return op.Inputs(); }) - .def("__str__", &ClassType::type::DebugString) - .def("no_intermediate_outputs", - [](const typename ClassType::type &op) { - return op.OutputVars(false); - }) - .def("support_gpu", &ClassType::type::SupportGPU); -} - static size_t UniqueIntegerGenerator() { static std::atomic generator; return generator.fetch_add(1); @@ -207,70 +184,69 @@ All parameter, weight, gradient are variables in Paddle. .def(py::init<>()) .def("__str__", string::to_string); - py::class_ operator_base(m, "Operator"); - - operator_base.def_static("create", [](py::bytes protobin) { - OpDesc desc; - PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), - "Cannot parse user input to OpDesc"); - PADDLE_ENFORCE(desc.IsInitialized(), - "User OpDesc is not initialized, reason %s", - desc.InitializationErrorString()); - return OpRegistry::CreateOp(desc); - }); - - operator_base.def("backward", - [](const OperatorBase &forwardOp, - const std::unordered_set &no_grad_vars) { - return Backward(forwardOp, no_grad_vars); - }); - - ExposeOperator(operator_base); - - py::class_ net(m, "Net"); + py::class_(m, "Operator") + .def_static("create", + [](py::bytes protobin) { + OpDesc desc; + PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), + "Cannot parse user input to OpDesc"); + PADDLE_ENFORCE(desc.IsInitialized(), + "User OpDesc is not initialized, reason %s", + desc.InitializationErrorString()); + return OpRegistry::CreateOp(desc); + }) + .def("backward", + [](const OperatorBase &forwardOp, + const std::unordered_set &no_grad_vars) { + return Backward(forwardOp, no_grad_vars).release(); + }) + .def("infer_shape", &OperatorBase::InferShape) + .def("run", &OperatorBase::Run) + .def("type", + [](const OperatorBase &op) -> std::string { return op.Type(); }) + .def("outputs", + [](const OperatorBase &op) + -> std::map> { + return op.Outputs(); + }) + .def("inputs", [](const OperatorBase &op) { return op.Inputs(); }) + .def("__str__", &OperatorBase::DebugString) + .def("no_intermediate_outputs", + [](const OperatorBase &op) { return op.OutputVars(false); }) + .def("support_gpu", &OperatorBase::SupportGPU); - net.def_static("create", - []() -> operators::NetOp * { - auto *retv = new operators::NetOp; - retv->SetType("plain_net"); - return retv; - }) + py::class_(m, "Net") + .def_static("create", + []() -> operators::NetOp * { + auto *retv = new operators::NetOp; + retv->SetType("plain_net"); + return retv; + }) .def("add_op", [](operators::NetOp &self, const OperatorBase &op) { self.AddOp(op); }) - .def("add_op", - [](operators::NetOp &self, const operators::NetOp &net) -> void { - self.AddOp(net); - }) - .def("add_op", - [](operators::NetOp &self, - const operators::RecurrentOp &rnn) -> void { self.AddOp(rnn); }) .def("complete_add_op", &operators::NetOp::CompleteAddOp) .def("complete_add_op", [](std::shared_ptr &self) { self->CompleteAddOp(); }); - ExposeOperator(net); - // recurrent_op - py::class_ rnn(m, "RecurrentOp"); - - rnn.def_static( - "create", - [](py::bytes protobin) -> operators::RecurrentOp * { - OpDesc desc; - PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), - "Cannot parse user input to OpDesc"); - PADDLE_ENFORCE(desc.IsInitialized(), - "User OpDesc is not initialized, reason %s", - desc.InitializationErrorString()); - auto rnn_op = OpRegistry::CreateOp(desc); - return static_cast(rnn_op.release()); - }) + py::class_(m, "RecurrentOp") + .def_static( + "create", + [](py::bytes protobin) -> operators::RecurrentOp * { + OpDesc desc; + PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), + "Cannot parse user input to OpDesc"); + PADDLE_ENFORCE(desc.IsInitialized(), + "User OpDesc is not initialized, reason %s", + desc.InitializationErrorString()); + auto rnn_op = OpRegistry::CreateOp(desc); + return static_cast(rnn_op.release()); + }) .def("set_stepnet", [](operators::RecurrentOp &self, const operators::NetOp &net) -> void { self.set_stepnet(net.Clone()); }); - ExposeOperator(rnn); m.def("unique_integer", UniqueIntegerGenerator); diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index 2ec65c63f3facaf8bdcc3668b95d73edc476c807..ce7da1f383672bb2b2dd78881a6e63ef8dba620f 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -41,9 +41,7 @@ class NetOp : public framework::OperatorBase { NetOp(const std::string& type, const VarNameMap& inputs, const VarNameMap& outputs, const framework::AttributeMap& attrs); - NetOp(const NetOp& o) - : framework::OperatorBase( - static_cast(o)) { + NetOp(const NetOp& o) : framework::OperatorBase(o.type_, {}, {}, o.attrs_) { this->ops_.reserve(o.ops_.size()); std::transform( o.ops_.begin(), o.ops_.end(), std::back_inserter(this->ops_), diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index aae78a1cecb8b2faf2483c7bf20f7a18ceab58b6..78ce0ba3c0fa4fe380e49a848c2434fe593cd00b 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -42,7 +42,7 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const { rnn::LinkMemories(step_scopes, arg_->memories, i, -1, true /*infer_shape_mode*/); } - stepnet_->InferShape(*step_scopes[i]); + (*stepnet_)->InferShape(*step_scopes[i]); } rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/); @@ -61,7 +61,7 @@ void RecurrentAlgorithm::Run(const Scope& scope, rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1, false /*infer_shape_mode*/); } - stepnet_->Run(*step_scopes[step_id], dev_ctx); + (*stepnet_)->Run(*step_scopes[step_id], dev_ctx); } rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/); @@ -76,15 +76,15 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { // Now all variables in scope must be created outside of op. PADDLE_ENFORCE_NOT_NULL(stepnet_); - PADDLE_ENFORCE(!stepnet_->Outputs().empty(), "stepnet_ op has no outputs"); - PADDLE_ENFORCE(!stepnet_->Outputs().empty(), "net_op has no outputs"); + PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs"); + PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "net_op has no outputs"); if (seq_len_ > step_scopes->size()) { for (size_t i = step_scopes->size(); i < seq_len_; ++i) { auto& step_scope = scope.NewScope(); // create step net's temp inputs - for (auto& input : stepnet_->Inputs()) { + for (auto& input : (*stepnet_)->Inputs()) { // the weight are located in parent scope for (auto& var_name : input.second) { if (!step_scope.FindVar(var_name)) { @@ -93,7 +93,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { } } // create stepnet's outputs - for (const auto& output : stepnet_->Outputs()) { + for (const auto& output : (*stepnet_)->Outputs()) { for (auto& var_name : output.second) { step_scope.NewVar(var_name); } @@ -136,7 +136,7 @@ RecurrentOp::RecurrentOp(const std::string& type, const framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) { rnn::InitArgument(kArgName, &arg_, *this); - alg_.Init(&arg_, stepnet_.get()); + alg_.Init(&arg_, &stepnet_); } class RecurrentAlgorithmProtoAndCheckerMaker @@ -178,7 +178,7 @@ void RecurrentGradientAlgorithm::Run( rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1, false /*infer_shape_mode*/); } - stepnet_->Run(*step_scopes[step_id], dev_ctx); + (*stepnet_)->Run(*step_scopes[step_id], dev_ctx); } LinkBootMemoryGradients(step_scopes[0], false); rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, @@ -215,7 +215,7 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const { rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1, true /*infer_shape_mode*/); } - stepnet_->InferShape(*step_scopes[step_id]); + (*stepnet_)->InferShape(*step_scopes[step_id]); } rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/); @@ -228,7 +228,7 @@ RecurrentGradientOp::RecurrentGradientOp( const framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) { rnn::InitArgument(kArgName, &arg_, *this); - alg_.Init(&arg_, stepnet_.get()); + alg_.Init(&arg_, &stepnet_); } } // namespace operators diff --git a/paddle/operators/recurrent_op.h b/paddle/operators/recurrent_op.h index 4d091aa21241e85c32bcc1497ac28244448140f1..bcfa817de8242153b164fa091309f19a6ad8a246 100644 --- a/paddle/operators/recurrent_op.h +++ b/paddle/operators/recurrent_op.h @@ -34,7 +34,8 @@ class RecurrentAlgorithm { void Run(const framework::Scope& scope, const platform::DeviceContext& dev_ctx) const; - void Init(rnn::Argument* arg, framework::OperatorBase* stepnet) { + void Init(rnn::Argument* arg, + std::unique_ptr* stepnet) { PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before."); arg_ = arg; stepnet_ = stepnet; @@ -63,7 +64,7 @@ class RecurrentAlgorithm { void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const; private: - framework::OperatorBase* stepnet_; + std::unique_ptr* stepnet_; rnn::Argument* arg_; mutable size_t seq_len_; }; @@ -80,7 +81,8 @@ class RecurrentGradientAlgorithm { * operator. */ public: - void Init(rnn::Argument* arg, framework::OperatorBase* stepnet) { + void Init(rnn::Argument* arg, + std::unique_ptr* stepnet) { PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before."); arg_ = std::move(arg); stepnet_ = stepnet; @@ -107,7 +109,7 @@ class RecurrentGradientAlgorithm { private: rnn::Argument* arg_; mutable size_t seq_len_; - framework::OperatorBase* stepnet_; + std::unique_ptr* stepnet_; }; class RecurrentOp : public framework::OperatorBase { diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py index 501cf6110ff745b8a6022b463bc9cc3a70145c60..831c0f0f2a9096cd78ae0a0ffd9b515110c6fa06 100644 --- a/python/paddle/v2/framework/tests/gradient_checker.py +++ b/python/paddle/v2/framework/tests/gradient_checker.py @@ -165,7 +165,6 @@ class GradientChecker(unittest.TestCase): for no_grad in no_grad_set: if no_grad not in in_names: raise ValueError("no_grad should be in in_names") - backward_op = core.Operator.backward(forward_op, no_grad_set) bwd_outputs = backward_op.outputs()