From 58f3de95cf34d8c826221781e8a8dbea954e7069 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 14 Jul 2017 14:56:49 +0800 Subject: [PATCH] Optimize ptr (#2851) * use OperatorPtr = std::shared_ptr; * use ScopePtr = std::share_ptr; --- paddle/framework/net.cc | 4 +- paddle/framework/net.h | 13 +++--- paddle/framework/op_registry.h | 4 +- paddle/framework/op_registry_test.cc | 20 +++++----- paddle/framework/operator.h | 12 +++--- paddle/framework/operator_test.cc | 59 +++++++++++++++++++++++----- paddle/framework/scope.h | 7 +++- 7 files changed, 82 insertions(+), 37 deletions(-) diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index 73b3051235e..854ad8e33e9 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -5,13 +5,13 @@ namespace framework { PlainNet::PlainNet(const NetDesc& def) {} -void PlainNet::InferShape(Scope* scope) { +void PlainNet::InferShape(const ScopePtr& scope) const { for (auto& op : ops_) { op.InferShape(); } } -void PlainNet::Run(std::shared_ptr scope, DeviceContext* ctx) { +void PlainNet::Run(const ScopePtr& scope, const DeviceContext& ctx) const { for (auto& op : ops_) { op.Run(ctx); } diff --git a/paddle/framework/net.h b/paddle/framework/net.h index 76992e07282..0481d8f47cc 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -37,8 +37,8 @@ struct OpAttrs {}; class Operator { public: Operator(const OpDesc &def) {} - void InferShape() {} - void Run(DeviceContext *ctx) {} + void InferShape() const {} + void Run(const DeviceContext &ctx) const {} }; /** @@ -60,7 +60,7 @@ class Net { /** * @brief Infer shapes of all inputs and outputs of operators. */ - virtual void InferShape(Scope *scope) = 0; + virtual void InferShape(const ScopePtr &scope) const = 0; /** * @brief Run the network. * @@ -69,7 +69,7 @@ class Net { * environment for ops. `begin` and `end` specify the scope of `ops_` to run, * If no positive indexes are provided, all operators in `ops_` will run. */ - virtual void Run(std::shared_ptr scope, DeviceContext *ctx) = 0; + virtual void Run(const ScopePtr &scope, const DeviceContext &ctx) const = 0; /** * @brief Add an Operator according to `def`. @@ -114,7 +114,7 @@ class PlainNet : public Net { * Infer all the operators' input and output varialbes' shapes, will be called * before every mini-batch */ - virtual void InferShape(Scope *scope) override; + virtual void InferShape(const ScopePtr &scope) const override; /** * @brief Run the network. @@ -123,7 +123,8 @@ class PlainNet : public Net { * scope will be used instead. If no OpContext is provicded, default context * will be used. */ - virtual void Run(std::shared_ptr scope, DeviceContext *ctx) override; + virtual void Run(const ScopePtr &scope, + const DeviceContext &ctx) const override; /** * @brief Add an operator to this network. diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index d049599a2fd..6be6ae15c24 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -198,9 +198,9 @@ class OpRegistry { op_type, op_proto.InitializationErrorString()); } - static OperatorBase* CreateOp(const OpDesc& op_desc) { + static OperatorPtr CreateOp(const OpDesc& op_desc) { std::string op_type = op_desc.type(); - OperatorBase* op = creators().at(op_type)(); + OperatorPtr op(creators().at(op_type)()); op->desc_ = op_desc; op->inputs_.reserve((size_t)op_desc.inputs_size()); std::copy(op_desc.inputs().begin(), op_desc.inputs().end(), diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 1adafa3714e..4791d4aaab4 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -5,9 +5,9 @@ namespace paddle { namespace framework { class CosineOp : public OperatorBase { public: - void Run(const std::shared_ptr& scope, + void Run(const ScopePtr& scope, const platform::DeviceContext& dev_ctx) const override {} - void InferShape(const std::shared_ptr& scope) const override {} + void InferShape(const ScopePtr& scope) const override {} }; class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { @@ -25,8 +25,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class MyTestOp : public OperatorBase { public: - void InferShape(const std::shared_ptr& scope) const override {} - void Run(const std::shared_ptr& scope, + void InferShape(const ScopePtr& scope) const override {} + void Run(const ScopePtr& scope, const platform::DeviceContext& dev_ctx) const override {} public: @@ -67,7 +67,7 @@ TEST(OpRegistry, CreateOp) { attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_f(scale); - paddle::framework::OperatorBase* op = + paddle::framework::OperatorPtr op = paddle::framework::OpRegistry::CreateOp(op_desc); auto scope = std::make_shared(); paddle::platform::CPUDeviceContext dev_ctx; @@ -89,7 +89,7 @@ TEST(OpRegistry, IllegalAttr) { bool caught = false; try { - paddle::framework::OperatorBase* op __attribute__((unused)) = + paddle::framework::OperatorPtr op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc); } catch (paddle::framework::EnforceNotMet err) { caught = true; @@ -110,7 +110,7 @@ TEST(OpRegistry, DefaultValue) { ASSERT_TRUE(op_desc.IsInitialized()); - paddle::framework::OperatorBase* op = + paddle::framework::OperatorPtr op = paddle::framework::OpRegistry::CreateOp(op_desc); auto scope = std::make_shared(); paddle::platform::CPUDeviceContext dev_ctx; @@ -136,7 +136,7 @@ TEST(OpRegistry, CustomChecker) { // attr 'test_attr' is not set bool caught = false; try { - paddle::framework::OperatorBase* op __attribute__((unused)) = + paddle::framework::OperatorPtr op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc); } catch (paddle::framework::EnforceNotMet err) { caught = true; @@ -155,7 +155,7 @@ TEST(OpRegistry, CustomChecker) { attr->set_i(3); caught = false; try { - paddle::framework::OperatorBase* op __attribute__((unused)) = + paddle::framework::OperatorPtr op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc); } catch (paddle::framework::EnforceNotMet err) { caught = true; @@ -174,7 +174,7 @@ TEST(OpRegistry, CustomChecker) { attr->set_type(paddle::framework::AttrType::INT); attr->set_i(4); SetInputFormat(&op_desc); - paddle::framework::OperatorBase* op = + paddle::framework::OperatorPtr op = paddle::framework::OpRegistry::CreateOp(op_desc); paddle::platform::CPUDeviceContext dev_ctx; auto scope = std::make_shared(); diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index d3c55e0ceb6..cf79f379fae 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -30,7 +30,7 @@ namespace paddle { namespace framework { class OperatorBase; - +using OperatorPtr = std::shared_ptr; /** * OperatorBase has the basic element that Net will call to do computation. * Only CreateOperator from OpRegistry will new Operator directly. User @@ -56,10 +56,10 @@ class OperatorBase { /// InferShape infer the size of Variables used by this Operator with /// information inside scope - virtual void InferShape(const std::shared_ptr& scope) const = 0; + virtual void InferShape(const ScopePtr& scope) const = 0; /// Net will call this function to Run an op. - virtual void Run(const std::shared_ptr& scope, + virtual void Run(const ScopePtr& scope, const platform::DeviceContext& dev_ctx) const = 0; protected: @@ -82,7 +82,7 @@ class OpKernel { */ class KernelContext { public: - KernelContext(const OperatorBase* op, const std::shared_ptr& scope, + KernelContext(const OperatorBase* op, const ScopePtr& scope, const platform::DeviceContext& device_context) : op_(*op), scope_(scope), device_context_(device_context) {} @@ -95,7 +95,7 @@ class OpKernel { } const OperatorBase& op_; - const std::shared_ptr& scope_; + const ScopePtr& scope_; const platform::DeviceContext& device_context_; }; @@ -140,7 +140,7 @@ class OperatorWithKernel : public OperatorBase { using OpKernelMap = std::unordered_map, OpKernelHash>; - void Run(const std::shared_ptr& scope, + void Run(const ScopePtr& scope, const platform::DeviceContext& dev_ctx) const final { auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx)); opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx)); diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 204b601a4aa..d0c3153faef 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -22,8 +22,8 @@ namespace framework { class OperatorTest : public OperatorBase { public: void Init() override { x = 1; } - void InferShape(const std::shared_ptr& scope) const override {} - void Run(const std::shared_ptr& scope, + void InferShape(const ScopePtr& scope) const override {} + void Run(const ScopePtr& scope, const platform::DeviceContext& dev_ctx) const override { float scale = GetAttr("scale"); ASSERT_NEAR(scale, 3.14, 1e-5); @@ -36,6 +36,50 @@ class OperatorTest : public OperatorBase { float x = 0; }; +class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { + public: + OperatorTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("input", "input of test op"); + AddOutput("output", "output of test op"); + AddAttr("scale", "scale of cosine op") + .SetDefault(1.0) + .LargerThan(0.0); + AddComment("This is test op"); + } +}; + +} // namespace framework +} // namespace paddle + +REGISTER_OP(test_operator, paddle::framework::OperatorTest, + paddle::framework::OperatorTestProtoAndCheckerMaker); + +TEST(OperatorBase, all) { + paddle::framework::OpDesc op_desc; + op_desc.set_type("test_operator"); + *op_desc.mutable_inputs()->Add() = "IN1"; + *op_desc.mutable_outputs()->Add() = "OUT1"; + auto attr = op_desc.mutable_attrs()->Add(); + attr->set_name("scale"); + attr->set_type(paddle::framework::AttrType::FLOAT); + float scale = 3.14; + attr->set_f(scale); + + paddle::platform::CPUDeviceContext device_context; + auto scope = std::make_shared(); + + paddle::framework::OperatorPtr op = + paddle::framework::OpRegistry::CreateOp(op_desc); + ASSERT_EQ(op->GetAttr("scale"), scale); + scope->CreateVariable("OUT1"); + op->Run(scope, device_context); + std::cout << op->DebugString() << std::endl; +} + +namespace paddle { +namespace framework { + class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { public: OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) @@ -73,9 +117,7 @@ REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest, REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest); TEST(OpKernel, all) { - using namespace paddle::framework; - - OpDesc op_desc; + paddle::framework::OpDesc op_desc; op_desc.set_type("op_with_kernel"); *op_desc.mutable_inputs()->Add() = "IN1"; *op_desc.mutable_outputs()->Add() = "OUT1"; @@ -85,10 +127,9 @@ TEST(OpKernel, all) { attr->set_f(3.14); paddle::platform::CPUDeviceContext cpu_device_context; - auto scope = std::make_shared(); + auto scope = std::make_shared(); - OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); + paddle::framework::OperatorPtr op = + paddle::framework::OpRegistry::CreateOp(op_desc); op->Run(scope, cpu_device_context); - - delete op; } diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index a4470f726fb..ec62c9189fd 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -23,6 +23,9 @@ limitations under the License. */ namespace paddle { namespace framework { +class Scope; +using ScopePtr = std::shared_ptr; + /** * @brief Scope that manage all variables. * @@ -41,7 +44,7 @@ class Scope { /** * @brief Initialize a Scope with parent. */ - explicit Scope(const std::shared_ptr& parent) : parent_(parent) {} + explicit Scope(const ScopePtr& parent) : parent_(parent) {} /** * @brief Create Variable @@ -88,7 +91,7 @@ class Scope { private: std::unordered_map> vars_; - std::shared_ptr parent_{nullptr}; + ScopePtr parent_{nullptr}; }; } // namespace framework -- GitLab