提交 58f3de95 编写于 作者: Q Qiao Longfei 提交者: GitHub

Optimize ptr (#2851)

* use OperatorPtr = std::shared_ptr<OperatorBase>;
* use ScopePtr = std::share_ptr<Scope>;
上级 2462d0c5
...@@ -5,13 +5,13 @@ namespace framework { ...@@ -5,13 +5,13 @@ namespace framework {
PlainNet::PlainNet(const NetDesc& def) {} PlainNet::PlainNet(const NetDesc& def) {}
void PlainNet::InferShape(Scope* scope) { void PlainNet::InferShape(const ScopePtr& scope) const {
for (auto& op : ops_) { for (auto& op : ops_) {
op.InferShape(); op.InferShape();
} }
} }
void PlainNet::Run(std::shared_ptr<Scope> scope, DeviceContext* ctx) { void PlainNet::Run(const ScopePtr& scope, const DeviceContext& ctx) const {
for (auto& op : ops_) { for (auto& op : ops_) {
op.Run(ctx); op.Run(ctx);
} }
......
...@@ -37,8 +37,8 @@ struct OpAttrs {}; ...@@ -37,8 +37,8 @@ struct OpAttrs {};
class Operator { class Operator {
public: public:
Operator(const OpDesc &def) {} Operator(const OpDesc &def) {}
void InferShape() {} void InferShape() const {}
void Run(DeviceContext *ctx) {} void Run(const DeviceContext &ctx) const {}
}; };
/** /**
...@@ -60,7 +60,7 @@ class Net { ...@@ -60,7 +60,7 @@ class Net {
/** /**
* @brief Infer shapes of all inputs and outputs of operators. * @brief Infer shapes of all inputs and outputs of operators.
*/ */
virtual void InferShape(Scope *scope) = 0; virtual void InferShape(const ScopePtr &scope) const = 0;
/** /**
* @brief Run the network. * @brief Run the network.
* *
...@@ -69,7 +69,7 @@ class Net { ...@@ -69,7 +69,7 @@ class Net {
* environment for ops. `begin` and `end` specify the scope of `ops_` to run, * environment for ops. `begin` and `end` specify the scope of `ops_` to run,
* If no positive indexes are provided, all operators in `ops_` will run. * If no positive indexes are provided, all operators in `ops_` will run.
*/ */
virtual void Run(std::shared_ptr<Scope> scope, DeviceContext *ctx) = 0; virtual void Run(const ScopePtr &scope, const DeviceContext &ctx) const = 0;
/** /**
* @brief Add an Operator according to `def`. * @brief Add an Operator according to `def`.
...@@ -114,7 +114,7 @@ class PlainNet : public Net { ...@@ -114,7 +114,7 @@ class PlainNet : public Net {
* Infer all the operators' input and output varialbes' shapes, will be called * Infer all the operators' input and output varialbes' shapes, will be called
* before every mini-batch * before every mini-batch
*/ */
virtual void InferShape(Scope *scope) override; virtual void InferShape(const ScopePtr &scope) const override;
/** /**
* @brief Run the network. * @brief Run the network.
...@@ -123,7 +123,8 @@ class PlainNet : public Net { ...@@ -123,7 +123,8 @@ class PlainNet : public Net {
* scope will be used instead. If no OpContext is provicded, default context * scope will be used instead. If no OpContext is provicded, default context
* will be used. * will be used.
*/ */
virtual void Run(std::shared_ptr<Scope> scope, DeviceContext *ctx) override; virtual void Run(const ScopePtr &scope,
const DeviceContext &ctx) const override;
/** /**
* @brief Add an operator to this network. * @brief Add an operator to this network.
......
...@@ -198,9 +198,9 @@ class OpRegistry { ...@@ -198,9 +198,9 @@ class OpRegistry {
op_type, op_proto.InitializationErrorString()); op_type, op_proto.InitializationErrorString());
} }
static OperatorBase* CreateOp(const OpDesc& op_desc) { static OperatorPtr CreateOp(const OpDesc& op_desc) {
std::string op_type = op_desc.type(); std::string op_type = op_desc.type();
OperatorBase* op = creators().at(op_type)(); OperatorPtr op(creators().at(op_type)());
op->desc_ = op_desc; op->desc_ = op_desc;
op->inputs_.reserve((size_t)op_desc.inputs_size()); op->inputs_.reserve((size_t)op_desc.inputs_size());
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(), std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
......
...@@ -5,9 +5,9 @@ namespace paddle { ...@@ -5,9 +5,9 @@ namespace paddle {
namespace framework { namespace framework {
class CosineOp : public OperatorBase { class CosineOp : public OperatorBase {
public: public:
void Run(const std::shared_ptr<Scope>& scope, void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const override {} const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const std::shared_ptr<Scope>& scope) const override {} void InferShape(const ScopePtr& scope) const override {}
}; };
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...@@ -25,8 +25,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -25,8 +25,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class MyTestOp : public OperatorBase { class MyTestOp : public OperatorBase {
public: public:
void InferShape(const std::shared_ptr<Scope>& scope) const override {} void InferShape(const ScopePtr& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope, void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const override {} const platform::DeviceContext& dev_ctx) const override {}
public: public:
...@@ -67,7 +67,7 @@ TEST(OpRegistry, CreateOp) { ...@@ -67,7 +67,7 @@ TEST(OpRegistry, CreateOp) {
attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(scale); attr->set_f(scale);
paddle::framework::OperatorBase* op = paddle::framework::OperatorPtr op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<paddle::framework::Scope>(); auto scope = std::make_shared<paddle::framework::Scope>();
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
...@@ -89,7 +89,7 @@ TEST(OpRegistry, IllegalAttr) { ...@@ -89,7 +89,7 @@ TEST(OpRegistry, IllegalAttr) {
bool caught = false; bool caught = false;
try { try {
paddle::framework::OperatorBase* op __attribute__((unused)) = paddle::framework::OperatorPtr op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) { } catch (paddle::framework::EnforceNotMet err) {
caught = true; caught = true;
...@@ -110,7 +110,7 @@ TEST(OpRegistry, DefaultValue) { ...@@ -110,7 +110,7 @@ TEST(OpRegistry, DefaultValue) {
ASSERT_TRUE(op_desc.IsInitialized()); ASSERT_TRUE(op_desc.IsInitialized());
paddle::framework::OperatorBase* op = paddle::framework::OperatorPtr op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<paddle::framework::Scope>(); auto scope = std::make_shared<paddle::framework::Scope>();
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
...@@ -136,7 +136,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -136,7 +136,7 @@ TEST(OpRegistry, CustomChecker) {
// attr 'test_attr' is not set // attr 'test_attr' is not set
bool caught = false; bool caught = false;
try { try {
paddle::framework::OperatorBase* op __attribute__((unused)) = paddle::framework::OperatorPtr op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) { } catch (paddle::framework::EnforceNotMet err) {
caught = true; caught = true;
...@@ -155,7 +155,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -155,7 +155,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_i(3); attr->set_i(3);
caught = false; caught = false;
try { try {
paddle::framework::OperatorBase* op __attribute__((unused)) = paddle::framework::OperatorPtr op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) { } catch (paddle::framework::EnforceNotMet err) {
caught = true; caught = true;
...@@ -174,7 +174,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -174,7 +174,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_type(paddle::framework::AttrType::INT); attr->set_type(paddle::framework::AttrType::INT);
attr->set_i(4); attr->set_i(4);
SetInputFormat(&op_desc); SetInputFormat(&op_desc);
paddle::framework::OperatorBase* op = paddle::framework::OperatorPtr op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
auto scope = std::make_shared<paddle::framework::Scope>(); auto scope = std::make_shared<paddle::framework::Scope>();
......
...@@ -30,7 +30,7 @@ namespace paddle { ...@@ -30,7 +30,7 @@ namespace paddle {
namespace framework { namespace framework {
class OperatorBase; class OperatorBase;
using OperatorPtr = std::shared_ptr<OperatorBase>;
/** /**
* OperatorBase has the basic element that Net will call to do computation. * OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User * Only CreateOperator from OpRegistry will new Operator directly. User
...@@ -56,10 +56,10 @@ class OperatorBase { ...@@ -56,10 +56,10 @@ class OperatorBase {
/// InferShape infer the size of Variables used by this Operator with /// InferShape infer the size of Variables used by this Operator with
/// information inside scope /// information inside scope
virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0; virtual void InferShape(const ScopePtr& scope) const = 0;
/// Net will call this function to Run an op. /// Net will call this function to Run an op.
virtual void Run(const std::shared_ptr<Scope>& scope, virtual void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const = 0; const platform::DeviceContext& dev_ctx) const = 0;
protected: protected:
...@@ -82,7 +82,7 @@ class OpKernel { ...@@ -82,7 +82,7 @@ class OpKernel {
*/ */
class KernelContext { class KernelContext {
public: public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope, KernelContext(const OperatorBase* op, const ScopePtr& scope,
const platform::DeviceContext& device_context) const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {} : op_(*op), scope_(scope), device_context_(device_context) {}
...@@ -95,7 +95,7 @@ class OpKernel { ...@@ -95,7 +95,7 @@ class OpKernel {
} }
const OperatorBase& op_; const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_; const ScopePtr& scope_;
const platform::DeviceContext& device_context_; const platform::DeviceContext& device_context_;
}; };
...@@ -140,7 +140,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -140,7 +140,7 @@ class OperatorWithKernel : public OperatorBase {
using OpKernelMap = using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>; std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
void Run(const std::shared_ptr<Scope>& scope, void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const final { const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx)); auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx));
opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx)); opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx));
......
...@@ -22,8 +22,8 @@ namespace framework { ...@@ -22,8 +22,8 @@ namespace framework {
class OperatorTest : public OperatorBase { class OperatorTest : public OperatorBase {
public: public:
void Init() override { x = 1; } void Init() override { x = 1; }
void InferShape(const std::shared_ptr<Scope>& scope) const override {} void InferShape(const ScopePtr& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope, void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
float scale = GetAttr<float>("scale"); float scale = GetAttr<float>("scale");
ASSERT_NEAR(scale, 3.14, 1e-5); ASSERT_NEAR(scale, 3.14, 1e-5);
...@@ -36,6 +36,50 @@ class OperatorTest : public OperatorBase { ...@@ -36,6 +36,50 @@ class OperatorTest : public OperatorBase {
float x = 0; float x = 0;
}; };
class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OperatorTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op");
AddOutput("output", "output of test op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddComment("This is test op");
}
};
} // namespace framework
} // namespace paddle
REGISTER_OP(test_operator, paddle::framework::OperatorTest,
paddle::framework::OperatorTestProtoAndCheckerMaker);
TEST(OperatorBase, all) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("test_operator");
*op_desc.mutable_inputs()->Add() = "IN1";
*op_desc.mutable_outputs()->Add() = "OUT1";
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
float scale = 3.14;
attr->set_f(scale);
paddle::platform::CPUDeviceContext device_context;
auto scope = std::make_shared<paddle::framework::Scope>();
paddle::framework::OperatorPtr op =
paddle::framework::OpRegistry::CreateOp(op_desc);
ASSERT_EQ(op->GetAttr<float>("scale"), scale);
scope->CreateVariable("OUT1");
op->Run(scope, device_context);
std::cout << op->DebugString() << std::endl;
}
namespace paddle {
namespace framework {
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public: public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
...@@ -73,9 +117,7 @@ REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest, ...@@ -73,9 +117,7 @@ REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest); REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest);
TEST(OpKernel, all) { TEST(OpKernel, all) {
using namespace paddle::framework; paddle::framework::OpDesc op_desc;
OpDesc op_desc;
op_desc.set_type("op_with_kernel"); op_desc.set_type("op_with_kernel");
*op_desc.mutable_inputs()->Add() = "IN1"; *op_desc.mutable_inputs()->Add() = "IN1";
*op_desc.mutable_outputs()->Add() = "OUT1"; *op_desc.mutable_outputs()->Add() = "OUT1";
...@@ -85,10 +127,9 @@ TEST(OpKernel, all) { ...@@ -85,10 +127,9 @@ TEST(OpKernel, all) {
attr->set_f(3.14); attr->set_f(3.14);
paddle::platform::CPUDeviceContext cpu_device_context; paddle::platform::CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<paddle::framework::Scope>();
OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OperatorPtr op =
paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context); op->Run(scope, cpu_device_context);
delete op;
} }
...@@ -23,6 +23,9 @@ limitations under the License. */ ...@@ -23,6 +23,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class Scope;
using ScopePtr = std::shared_ptr<Scope>;
/** /**
* @brief Scope that manage all variables. * @brief Scope that manage all variables.
* *
...@@ -41,7 +44,7 @@ class Scope { ...@@ -41,7 +44,7 @@ class Scope {
/** /**
* @brief Initialize a Scope with parent. * @brief Initialize a Scope with parent.
*/ */
explicit Scope(const std::shared_ptr<Scope>& parent) : parent_(parent) {} explicit Scope(const ScopePtr& parent) : parent_(parent) {}
/** /**
* @brief Create Variable * @brief Create Variable
...@@ -88,7 +91,7 @@ class Scope { ...@@ -88,7 +91,7 @@ class Scope {
private: private:
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_; std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
std::shared_ptr<Scope> parent_{nullptr}; ScopePtr parent_{nullptr};
}; };
} // namespace framework } // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册