From c2543f5b29df028e9eceec0273b882484998c03a Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 24 Jul 2017 15:20:29 +0800 Subject: [PATCH] Remove ScopePtr and OperatorPtr * ScopePtr means pointer of scope, but it can be shared or uniqued. Change it to std::shared_ptr to make code better to read. --- paddle/framework/net.h | 10 +++++----- paddle/framework/net_op_test.cc | 4 ++-- paddle/framework/op_registry.h | 12 ++++++------ paddle/framework/op_registry_test.cc | 24 ++++++++++-------------- paddle/framework/operator.h | 7 +++---- paddle/framework/operator_test.cc | 12 +++++------- paddle/framework/scope.h | 5 ++--- paddle/pybind/pybind.cc | 5 +++-- 8 files changed, 36 insertions(+), 43 deletions(-) diff --git a/paddle/framework/net.h b/paddle/framework/net.h index 19c5fa223..b2c64a867 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -39,7 +39,7 @@ namespace framework { */ class Net : public OperatorBase { public: - virtual void AddOp(const OperatorPtr& op) = 0; + virtual void AddOp(const std::shared_ptr& op) = 0; virtual void CompleteAddOp(bool calc) = 0; }; @@ -57,7 +57,7 @@ class PlainNet : public Net { * Infer all the operators' input and output variables' shapes, will be called * before every mini-batch */ - void InferShape(const ScopePtr& scope) const override { + void InferShape(const std::shared_ptr& scope) const override { for (auto& op : ops_) { op->InferShape(scope); } @@ -70,7 +70,7 @@ class PlainNet : public Net { * scope will be used instead. If no OpContext is provicded, default context * will be used. */ - void Run(const ScopePtr& scope, + void Run(const std::shared_ptr& scope, const platform::DeviceContext& dev_ctx) const override { for (auto& op : ops_) { op->Run(scope, dev_ctx); @@ -80,7 +80,7 @@ class PlainNet : public Net { /** * @brief Add an operator by ptr */ - void AddOp(const OperatorPtr& op) override { + void AddOp(const std::shared_ptr& op) override { PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); ops_.push_back(op); } @@ -89,7 +89,7 @@ class PlainNet : public Net { std::string DebugString() const override; - std::vector ops_; + std::vector> ops_; private: bool add_op_done_{false}; diff --git a/paddle/framework/net_op_test.cc b/paddle/framework/net_op_test.cc index e814a7e43..c179042c8 100644 --- a/paddle/framework/net_op_test.cc +++ b/paddle/framework/net_op_test.cc @@ -10,10 +10,10 @@ static int run_cnt = 0; class TestOp : public pd::OperatorBase { public: - void InferShape(const paddle::framework::ScopePtr& scope) const override { + void InferShape(const std::shared_ptr& scope) const override { ++infer_shape_cnt; } - void Run(const paddle::framework::ScopePtr& scope, + void Run(const std::shared_ptr& scope, const paddle::platform::DeviceContext& dev_ctx) const override { ++run_cnt; } diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index c41fe1072..165a68c1c 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -227,10 +227,10 @@ class OpRegistry { } } - static OperatorPtr CreateOp(const std::string& type, - const VarNameList& inputs, - const VarNameList& outputs, - const AttributeMap& attrs) { + static std::shared_ptr CreateOp(const std::string& type, + const VarNameList& inputs, + const VarNameList& outputs, + const AttributeMap& attrs) { auto op_create_it = creators().find(type); PADDLE_ENFORCE(op_create_it != creators().end(), "Operator %s cannot be found", type); @@ -252,10 +252,10 @@ class OpRegistry { } op->Init(); - return OperatorPtr(op); + return std::shared_ptr(op); } - static OperatorPtr CreateOp(const OpDesc& op_desc) { + static std::shared_ptr CreateOp(const OpDesc& op_desc) { std::vector inputs; inputs.reserve((size_t)op_desc.inputs_size()); std::copy(op_desc.inputs().begin(), op_desc.inputs().end(), diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 32a7e88a8..05095372d 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -7,9 +7,9 @@ namespace paddle { namespace framework { class CosineOp : public OperatorBase { public: - void Run(const ScopePtr& scope, + void Run(const std::shared_ptr& scope, const platform::DeviceContext& dev_ctx) const override {} - void InferShape(const ScopePtr& scope) const override {} + void InferShape(const std::shared_ptr& scope) const override {} }; class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { @@ -27,8 +27,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class MyTestOp : public OperatorBase { public: - void InferShape(const ScopePtr& scope) const override {} - void Run(const ScopePtr& scope, + void InferShape(const std::shared_ptr& scope) const override {} + void Run(const std::shared_ptr& scope, const platform::DeviceContext& dev_ctx) const override {} }; @@ -67,7 +67,7 @@ TEST(OpRegistry, CreateOp) { attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_f(scale); - paddle::framework::OperatorPtr op = + std::shared_ptr op = paddle::framework::OpRegistry::CreateOp(op_desc); auto scope = std::make_shared(); paddle::platform::CPUDeviceContext dev_ctx; @@ -89,8 +89,7 @@ TEST(OpRegistry, IllegalAttr) { bool caught = false; try { - paddle::framework::OperatorPtr op __attribute__((unused)) = - paddle::framework::OpRegistry::CreateOp(op_desc); + paddle::framework::OpRegistry::CreateOp(op_desc); } catch (std::runtime_error& err) { caught = true; std::string msg = "larger_than check fail"; @@ -110,7 +109,7 @@ TEST(OpRegistry, DefaultValue) { ASSERT_TRUE(op_desc.IsInitialized()); - paddle::framework::OperatorPtr op = + std::shared_ptr op = paddle::framework::OpRegistry::CreateOp(op_desc); auto scope = std::make_shared(); paddle::platform::CPUDeviceContext dev_ctx; @@ -136,8 +135,7 @@ TEST(OpRegistry, CustomChecker) { // attr 'test_attr' is not set bool caught = false; try { - paddle::framework::OperatorPtr op __attribute__((unused)) = - paddle::framework::OpRegistry::CreateOp(op_desc); + paddle::framework::OpRegistry::CreateOp(op_desc); } catch (std::runtime_error& err) { caught = true; std::string msg = "Attribute 'test_attr' is required!"; @@ -155,8 +153,7 @@ TEST(OpRegistry, CustomChecker) { attr->set_i(3); caught = false; try { - paddle::framework::OperatorPtr op __attribute__((unused)) = - paddle::framework::OpRegistry::CreateOp(op_desc); + paddle::framework::OpRegistry::CreateOp(op_desc); } catch (std::runtime_error& err) { caught = true; std::string msg = "'test_attr' must be even!"; @@ -174,8 +171,7 @@ TEST(OpRegistry, CustomChecker) { attr->set_type(paddle::framework::AttrType::INT); attr->set_i(4); SetInputFormat(&op_desc); - paddle::framework::OperatorPtr op = - paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); paddle::platform::CPUDeviceContext dev_ctx; auto scope = std::make_shared(); op->Run(scope, dev_ctx); diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 5f046d629..6b8dbb39a 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -47,7 +47,6 @@ struct EigenDeviceConverter { #endif class OperatorBase; -using OperatorPtr = std::shared_ptr; /** * OperatorBase has the basic element that Net will call to do computation. * Only CreateOperator from OpRegistry will new Operator directly. User @@ -80,10 +79,10 @@ class OperatorBase { /// InferShape infer the size of Variables used by this Operator with /// information inside scope - virtual void InferShape(const ScopePtr& scope) const = 0; + virtual void InferShape(const std::shared_ptr& scope) const = 0; /// Net will call this function to Run an op. - virtual void Run(const ScopePtr& scope, + virtual void Run(const std::shared_ptr& scope, const platform::DeviceContext& dev_ctx) const = 0; // Get a input with argument's name described in `op_proto` @@ -208,7 +207,7 @@ class OperatorWithKernel : public OperatorBase { using OpKernelMap = std::unordered_map, OpKernelHash>; - void Run(const ScopePtr& scope, + void Run(const std::shared_ptr& scope, const platform::DeviceContext& dev_ctx) const final { auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); opKernel->Compute(KernelContext(this, scope, dev_ctx)); diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 8e55d0111..3fae356c3 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -24,8 +24,8 @@ static int op_run_num = 0; class OpWithoutKernelTest : public OperatorBase { public: void Init() override { x = 1; } - void InferShape(const ScopePtr& scope) const override {} - void Run(const ScopePtr& scope, + void InferShape(const std::shared_ptr& scope) const override {} + void Run(const std::shared_ptr& scope, const platform::DeviceContext& dev_ctx) const override { op_run_num++; ASSERT_EQ((int)inputs_.size(), 1); @@ -70,8 +70,7 @@ TEST(OperatorBase, all) { paddle::platform::CPUDeviceContext device_context; auto scope = std::make_shared(); - paddle::framework::OperatorPtr op = - paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); scope->CreateVariable("OUT1"); ASSERT_EQ(paddle::framework::op_run_num, 0); op->Run(scope, device_context); @@ -189,8 +188,7 @@ TEST(OpKernel, all) { paddle::platform::CPUDeviceContext cpu_device_context; auto scope = std::make_shared(); - paddle::framework::OperatorPtr op = - paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0); op->Run(scope, cpu_device_context); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); @@ -236,6 +234,6 @@ TEST(OpKernel, multi_inputs) { paddle::platform::CPUDeviceContext cpu_device_context; auto scope = std::make_shared(); - OperatorPtr op(paddle::framework::OpRegistry::CreateOp(op_desc)); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); op->Run(scope, cpu_device_context); } diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index ec62c9189..79c9ffd1a 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -24,7 +24,6 @@ namespace paddle { namespace framework { class Scope; -using ScopePtr = std::shared_ptr; /** * @brief Scope that manage all variables. @@ -44,7 +43,7 @@ class Scope { /** * @brief Initialize a Scope with parent. */ - explicit Scope(const ScopePtr& parent) : parent_(parent) {} + explicit Scope(const std::shared_ptr& parent) : parent_(parent) {} /** * @brief Create Variable @@ -91,7 +90,7 @@ class Scope { private: std::unordered_map> vars_; - ScopePtr parent_{nullptr}; + std::shared_ptr parent_{nullptr}; }; } // namespace framework diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 2c843839c..d48a948d2 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -126,9 +126,10 @@ All parameter, weight, gradient are variables in Paddle. return new paddle::platform::CPUDeviceContext(); }); - py::class_ operator_base(m, "Operator"); + py::class_> operator_base( + m, "Operator"); - operator_base.def_static("create", [](py::bytes protobin) -> pd::OperatorPtr { + operator_base.def_static("create", [](py::bytes protobin) { pd::OpDesc desc; PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), "Cannot parse user input to OpDesc"); -- GitLab