提交 c2543f5b 编写于 作者: Y Yu Yang

Remove ScopePtr and OperatorPtr

* ScopePtr means pointer of scope, but it can be shared or uniqued.
Change it to std::shared_ptr<Scope> to make code better to read.
上级 1294b3c5
...@@ -39,7 +39,7 @@ namespace framework { ...@@ -39,7 +39,7 @@ namespace framework {
*/ */
class Net : public OperatorBase { class Net : public OperatorBase {
public: public:
virtual void AddOp(const OperatorPtr& op) = 0; virtual void AddOp(const std::shared_ptr<OperatorBase>& op) = 0;
virtual void CompleteAddOp(bool calc) = 0; virtual void CompleteAddOp(bool calc) = 0;
}; };
...@@ -57,7 +57,7 @@ class PlainNet : public Net { ...@@ -57,7 +57,7 @@ class PlainNet : public Net {
* Infer all the operators' input and output variables' shapes, will be called * Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch * before every mini-batch
*/ */
void InferShape(const ScopePtr& scope) const override { void InferShape(const std::shared_ptr<Scope>& scope) const override {
for (auto& op : ops_) { for (auto& op : ops_) {
op->InferShape(scope); op->InferShape(scope);
} }
...@@ -70,7 +70,7 @@ class PlainNet : public Net { ...@@ -70,7 +70,7 @@ class PlainNet : public Net {
* scope will be used instead. If no OpContext is provicded, default context * scope will be used instead. If no OpContext is provicded, default context
* will be used. * will be used.
*/ */
void Run(const ScopePtr& scope, void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
for (auto& op : ops_) { for (auto& op : ops_) {
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
...@@ -80,7 +80,7 @@ class PlainNet : public Net { ...@@ -80,7 +80,7 @@ class PlainNet : public Net {
/** /**
* @brief Add an operator by ptr * @brief Add an operator by ptr
*/ */
void AddOp(const OperatorPtr& op) override { void AddOp(const std::shared_ptr<OperatorBase>& op) override {
PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed");
ops_.push_back(op); ops_.push_back(op);
} }
...@@ -89,7 +89,7 @@ class PlainNet : public Net { ...@@ -89,7 +89,7 @@ class PlainNet : public Net {
std::string DebugString() const override; std::string DebugString() const override;
std::vector<OperatorPtr> ops_; std::vector<std::shared_ptr<OperatorBase>> ops_;
private: private:
bool add_op_done_{false}; bool add_op_done_{false};
......
...@@ -10,10 +10,10 @@ static int run_cnt = 0; ...@@ -10,10 +10,10 @@ static int run_cnt = 0;
class TestOp : public pd::OperatorBase { class TestOp : public pd::OperatorBase {
public: public:
void InferShape(const paddle::framework::ScopePtr& scope) const override { void InferShape(const std::shared_ptr<pd::Scope>& scope) const override {
++infer_shape_cnt; ++infer_shape_cnt;
} }
void Run(const paddle::framework::ScopePtr& scope, void Run(const std::shared_ptr<pd::Scope>& scope,
const paddle::platform::DeviceContext& dev_ctx) const override { const paddle::platform::DeviceContext& dev_ctx) const override {
++run_cnt; ++run_cnt;
} }
......
...@@ -227,10 +227,10 @@ class OpRegistry { ...@@ -227,10 +227,10 @@ class OpRegistry {
} }
} }
static OperatorPtr CreateOp(const std::string& type, static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameList& inputs, const VarNameList& inputs,
const VarNameList& outputs, const VarNameList& outputs,
const AttributeMap& attrs) { const AttributeMap& attrs) {
auto op_create_it = creators().find(type); auto op_create_it = creators().find(type);
PADDLE_ENFORCE(op_create_it != creators().end(), PADDLE_ENFORCE(op_create_it != creators().end(),
"Operator %s cannot be found", type); "Operator %s cannot be found", type);
...@@ -252,10 +252,10 @@ class OpRegistry { ...@@ -252,10 +252,10 @@ class OpRegistry {
} }
op->Init(); op->Init();
return OperatorPtr(op); return std::shared_ptr<OperatorBase>(op);
} }
static OperatorPtr CreateOp(const OpDesc& op_desc) { static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
std::vector<std::string> inputs; std::vector<std::string> inputs;
inputs.reserve((size_t)op_desc.inputs_size()); inputs.reserve((size_t)op_desc.inputs_size());
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(), std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
......
...@@ -7,9 +7,9 @@ namespace paddle { ...@@ -7,9 +7,9 @@ namespace paddle {
namespace framework { namespace framework {
class CosineOp : public OperatorBase { class CosineOp : public OperatorBase {
public: public:
void Run(const ScopePtr& scope, void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {} const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const ScopePtr& scope) const override {} void InferShape(const std::shared_ptr<Scope>& scope) const override {}
}; };
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...@@ -27,8 +27,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -27,8 +27,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class MyTestOp : public OperatorBase { class MyTestOp : public OperatorBase {
public: public:
void InferShape(const ScopePtr& scope) const override {} void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const ScopePtr& scope, void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {} const platform::DeviceContext& dev_ctx) const override {}
}; };
...@@ -67,7 +67,7 @@ TEST(OpRegistry, CreateOp) { ...@@ -67,7 +67,7 @@ TEST(OpRegistry, CreateOp) {
attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(scale); attr->set_f(scale);
paddle::framework::OperatorPtr op = std::shared_ptr<paddle::framework::OperatorBase> op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<paddle::framework::Scope>(); auto scope = std::make_shared<paddle::framework::Scope>();
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
...@@ -89,8 +89,7 @@ TEST(OpRegistry, IllegalAttr) { ...@@ -89,8 +89,7 @@ TEST(OpRegistry, IllegalAttr) {
bool caught = false; bool caught = false;
try { try {
paddle::framework::OperatorPtr op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) { } catch (std::runtime_error& err) {
caught = true; caught = true;
std::string msg = "larger_than check fail"; std::string msg = "larger_than check fail";
...@@ -110,7 +109,7 @@ TEST(OpRegistry, DefaultValue) { ...@@ -110,7 +109,7 @@ TEST(OpRegistry, DefaultValue) {
ASSERT_TRUE(op_desc.IsInitialized()); ASSERT_TRUE(op_desc.IsInitialized());
paddle::framework::OperatorPtr op = std::shared_ptr<paddle::framework::OperatorBase> op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<paddle::framework::Scope>(); auto scope = std::make_shared<paddle::framework::Scope>();
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
...@@ -136,8 +135,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -136,8 +135,7 @@ TEST(OpRegistry, CustomChecker) {
// attr 'test_attr' is not set // attr 'test_attr' is not set
bool caught = false; bool caught = false;
try { try {
paddle::framework::OperatorPtr op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) { } catch (std::runtime_error& err) {
caught = true; caught = true;
std::string msg = "Attribute 'test_attr' is required!"; std::string msg = "Attribute 'test_attr' is required!";
...@@ -155,8 +153,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -155,8 +153,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_i(3); attr->set_i(3);
caught = false; caught = false;
try { try {
paddle::framework::OperatorPtr op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) { } catch (std::runtime_error& err) {
caught = true; caught = true;
std::string msg = "'test_attr' must be even!"; std::string msg = "'test_attr' must be even!";
...@@ -174,8 +171,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -174,8 +171,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_type(paddle::framework::AttrType::INT); attr->set_type(paddle::framework::AttrType::INT);
attr->set_i(4); attr->set_i(4);
SetInputFormat(&op_desc); SetInputFormat(&op_desc);
paddle::framework::OperatorPtr op = auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
auto scope = std::make_shared<paddle::framework::Scope>(); auto scope = std::make_shared<paddle::framework::Scope>();
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
......
...@@ -47,7 +47,6 @@ struct EigenDeviceConverter<platform::GPUPlace> { ...@@ -47,7 +47,6 @@ struct EigenDeviceConverter<platform::GPUPlace> {
#endif #endif
class OperatorBase; class OperatorBase;
using OperatorPtr = std::shared_ptr<OperatorBase>;
/** /**
* OperatorBase has the basic element that Net will call to do computation. * OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User * Only CreateOperator from OpRegistry will new Operator directly. User
...@@ -80,10 +79,10 @@ class OperatorBase { ...@@ -80,10 +79,10 @@ class OperatorBase {
/// InferShape infer the size of Variables used by this Operator with /// InferShape infer the size of Variables used by this Operator with
/// information inside scope /// information inside scope
virtual void InferShape(const ScopePtr& scope) const = 0; virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0;
/// Net will call this function to Run an op. /// Net will call this function to Run an op.
virtual void Run(const ScopePtr& scope, virtual void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const = 0; const platform::DeviceContext& dev_ctx) const = 0;
// Get a input with argument's name described in `op_proto` // Get a input with argument's name described in `op_proto`
...@@ -208,7 +207,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -208,7 +207,7 @@ class OperatorWithKernel : public OperatorBase {
using OpKernelMap = using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>; std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
void Run(const ScopePtr& scope, void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const final { const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(KernelContext(this, scope, dev_ctx)); opKernel->Compute(KernelContext(this, scope, dev_ctx));
......
...@@ -24,8 +24,8 @@ static int op_run_num = 0; ...@@ -24,8 +24,8 @@ static int op_run_num = 0;
class OpWithoutKernelTest : public OperatorBase { class OpWithoutKernelTest : public OperatorBase {
public: public:
void Init() override { x = 1; } void Init() override { x = 1; }
void InferShape(const ScopePtr& scope) const override {} void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const ScopePtr& scope, void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
op_run_num++; op_run_num++;
ASSERT_EQ((int)inputs_.size(), 1); ASSERT_EQ((int)inputs_.size(), 1);
...@@ -70,8 +70,7 @@ TEST(OperatorBase, all) { ...@@ -70,8 +70,7 @@ TEST(OperatorBase, all) {
paddle::platform::CPUDeviceContext device_context; paddle::platform::CPUDeviceContext device_context;
auto scope = std::make_shared<paddle::framework::Scope>(); auto scope = std::make_shared<paddle::framework::Scope>();
paddle::framework::OperatorPtr op = auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::OpRegistry::CreateOp(op_desc);
scope->CreateVariable("OUT1"); scope->CreateVariable("OUT1");
ASSERT_EQ(paddle::framework::op_run_num, 0); ASSERT_EQ(paddle::framework::op_run_num, 0);
op->Run(scope, device_context); op->Run(scope, device_context);
...@@ -189,8 +188,7 @@ TEST(OpKernel, all) { ...@@ -189,8 +188,7 @@ TEST(OpKernel, all) {
paddle::platform::CPUDeviceContext cpu_device_context; paddle::platform::CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<paddle::framework::Scope>(); auto scope = std::make_shared<paddle::framework::Scope>();
paddle::framework::OperatorPtr op = auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::OpRegistry::CreateOp(op_desc);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0);
op->Run(scope, cpu_device_context); op->Run(scope, cpu_device_context);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
...@@ -236,6 +234,6 @@ TEST(OpKernel, multi_inputs) { ...@@ -236,6 +234,6 @@ TEST(OpKernel, multi_inputs) {
paddle::platform::CPUDeviceContext cpu_device_context; paddle::platform::CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<Scope>();
OperatorPtr op(paddle::framework::OpRegistry::CreateOp(op_desc)); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context); op->Run(scope, cpu_device_context);
} }
...@@ -24,7 +24,6 @@ namespace paddle { ...@@ -24,7 +24,6 @@ namespace paddle {
namespace framework { namespace framework {
class Scope; class Scope;
using ScopePtr = std::shared_ptr<Scope>;
/** /**
* @brief Scope that manage all variables. * @brief Scope that manage all variables.
...@@ -44,7 +43,7 @@ class Scope { ...@@ -44,7 +43,7 @@ class Scope {
/** /**
* @brief Initialize a Scope with parent. * @brief Initialize a Scope with parent.
*/ */
explicit Scope(const ScopePtr& parent) : parent_(parent) {} explicit Scope(const std::shared_ptr<Scope>& parent) : parent_(parent) {}
/** /**
* @brief Create Variable * @brief Create Variable
...@@ -91,7 +90,7 @@ class Scope { ...@@ -91,7 +90,7 @@ class Scope {
private: private:
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_; std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
ScopePtr parent_{nullptr}; std::shared_ptr<Scope> parent_{nullptr};
}; };
} // namespace framework } // namespace framework
......
...@@ -126,9 +126,10 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -126,9 +126,10 @@ All parameter, weight, gradient are variables in Paddle.
return new paddle::platform::CPUDeviceContext(); return new paddle::platform::CPUDeviceContext();
}); });
py::class_<pd::OperatorBase, pd::OperatorPtr> operator_base(m, "Operator"); py::class_<pd::OperatorBase, std::shared_ptr<pd::OperatorBase>> operator_base(
m, "Operator");
operator_base.def_static("create", [](py::bytes protobin) -> pd::OperatorPtr { operator_base.def_static("create", [](py::bytes protobin) {
pd::OpDesc desc; pd::OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc"); "Cannot parse user input to OpDesc");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册