Commit c2543f5b authored by Yu Yang

Remove ScopePtr and OperatorPtr

* ScopePtr means a pointer to a Scope, but it could be either shared or unique.
  Change it to std::shared_ptr<Scope> to make the code easier to read.
Parent 1294b3c5
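The rationale can be seen in a small sketch: an alias such as `ScopePtr` only says "pointer to Scope" and hides whether ownership is shared or exclusive, while spelling out `std::shared_ptr<Scope>` states it at every call site. This is an illustration only, not code from the commit; the function names are hypothetical.

```cpp
#include <memory>

class Scope {};

// Before: the alias reads as "a pointer to Scope" but hides the ownership
// model -- it could just as well have wrapped a unique_ptr or a raw pointer.
using ScopePtr = std::shared_ptr<Scope>;
void RunWithAlias(const ScopePtr& scope) { /* ... */ }

// After: the signature itself states that ownership of the scope is shared.
void RunExplicit(const std::shared_ptr<Scope>& scope) { /* ... */ }

int main() {
  auto scope = std::make_shared<Scope>();
  RunWithAlias(scope);
  RunExplicit(scope);
}
```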
......@@ -39,7 +39,7 @@ namespace framework {
*/
class Net : public OperatorBase {
public:
virtual void AddOp(const OperatorPtr& op) = 0;
virtual void AddOp(const std::shared_ptr<OperatorBase>& op) = 0;
virtual void CompleteAddOp(bool calc) = 0;
};
......@@ -57,7 +57,7 @@ class PlainNet : public Net {
* Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch
*/
void InferShape(const ScopePtr& scope) const override {
void InferShape(const std::shared_ptr<Scope>& scope) const override {
for (auto& op : ops_) {
op->InferShape(scope);
}
......@@ -70,7 +70,7 @@ class PlainNet : public Net {
* scope will be used instead. If no OpContext is provided, default context
* will be used.
*/
void Run(const ScopePtr& scope,
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {
for (auto& op : ops_) {
op->Run(scope, dev_ctx);
......@@ -80,7 +80,7 @@ class PlainNet : public Net {
/**
* @brief Add an operator by ptr
*/
void AddOp(const OperatorPtr& op) override {
void AddOp(const std::shared_ptr<OperatorBase>& op) override {
PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed");
ops_.push_back(op);
}
......@@ -89,7 +89,7 @@ class PlainNet : public Net {
std::string DebugString() const override;
std::vector<OperatorPtr> ops_;
std::vector<std::shared_ptr<OperatorBase>> ops_;
private:
bool add_op_done_{false};
......
......@@ -10,10 +10,10 @@ static int run_cnt = 0;
class TestOp : public pd::OperatorBase {
public:
void InferShape(const paddle::framework::ScopePtr& scope) const override {
void InferShape(const std::shared_ptr<pd::Scope>& scope) const override {
++infer_shape_cnt;
}
void Run(const paddle::framework::ScopePtr& scope,
void Run(const std::shared_ptr<pd::Scope>& scope,
const paddle::platform::DeviceContext& dev_ctx) const override {
++run_cnt;
}
......
......@@ -227,10 +227,10 @@ class OpRegistry {
}
}
static OperatorPtr CreateOp(const std::string& type,
const VarNameList& inputs,
const VarNameList& outputs,
const AttributeMap& attrs) {
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameList& inputs,
const VarNameList& outputs,
const AttributeMap& attrs) {
auto op_create_it = creators().find(type);
PADDLE_ENFORCE(op_create_it != creators().end(),
"Operator %s cannot be found", type);
......@@ -252,10 +252,10 @@ class OpRegistry {
}
op->Init();
return OperatorPtr(op);
return std::shared_ptr<OperatorBase>(op);
}
static OperatorPtr CreateOp(const OpDesc& op_desc) {
static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
std::vector<std::string> inputs;
inputs.reserve((size_t)op_desc.inputs_size());
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
......
......@@ -7,9 +7,9 @@ namespace paddle {
namespace framework {
class CosineOp : public OperatorBase {
public:
void Run(const ScopePtr& scope,
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const ScopePtr& scope) const override {}
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
};
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
......@@ -27,8 +27,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class MyTestOp : public OperatorBase {
public:
void InferShape(const ScopePtr& scope) const override {}
void Run(const ScopePtr& scope,
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {}
};
......@@ -67,7 +67,7 @@ TEST(OpRegistry, CreateOp) {
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(scale);
paddle::framework::OperatorPtr op =
std::shared_ptr<paddle::framework::OperatorBase> op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<paddle::framework::Scope>();
paddle::platform::CPUDeviceContext dev_ctx;
......@@ -89,8 +89,7 @@ TEST(OpRegistry, IllegalAttr) {
bool caught = false;
try {
paddle::framework::OperatorPtr op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) {
caught = true;
std::string msg = "larger_than check fail";
......@@ -110,7 +109,7 @@ TEST(OpRegistry, DefaultValue) {
ASSERT_TRUE(op_desc.IsInitialized());
paddle::framework::OperatorPtr op =
std::shared_ptr<paddle::framework::OperatorBase> op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<paddle::framework::Scope>();
paddle::platform::CPUDeviceContext dev_ctx;
......@@ -136,8 +135,7 @@ TEST(OpRegistry, CustomChecker) {
// attr 'test_attr' is not set
bool caught = false;
try {
paddle::framework::OperatorPtr op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) {
caught = true;
std::string msg = "Attribute 'test_attr' is required!";
......@@ -155,8 +153,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_i(3);
caught = false;
try {
paddle::framework::OperatorPtr op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) {
caught = true;
std::string msg = "'test_attr' must be even!";
......@@ -174,8 +171,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_type(paddle::framework::AttrType::INT);
attr->set_i(4);
SetInputFormat(&op_desc);
paddle::framework::OperatorPtr op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::platform::CPUDeviceContext dev_ctx;
auto scope = std::make_shared<paddle::framework::Scope>();
op->Run(scope, dev_ctx);
......
......@@ -47,7 +47,6 @@ struct EigenDeviceConverter<platform::GPUPlace> {
#endif
class OperatorBase;
using OperatorPtr = std::shared_ptr<OperatorBase>;
/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
......@@ -80,10 +79,10 @@ class OperatorBase {
/// InferShape infer the size of Variables used by this Operator with
/// information inside scope
virtual void InferShape(const ScopePtr& scope) const = 0;
virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0;
/// Net will call this function to Run an op.
virtual void Run(const ScopePtr& scope,
virtual void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const = 0;
// Get a input with argument's name described in `op_proto`
......@@ -208,7 +207,7 @@ class OperatorWithKernel : public OperatorBase {
using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
void Run(const ScopePtr& scope,
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(KernelContext(this, scope, dev_ctx));
......
......@@ -24,8 +24,8 @@ static int op_run_num = 0;
class OpWithoutKernelTest : public OperatorBase {
public:
void Init() override { x = 1; }
void InferShape(const ScopePtr& scope) const override {}
void Run(const ScopePtr& scope,
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {
op_run_num++;
ASSERT_EQ((int)inputs_.size(), 1);
......@@ -70,8 +70,7 @@ TEST(OperatorBase, all) {
paddle::platform::CPUDeviceContext device_context;
auto scope = std::make_shared<paddle::framework::Scope>();
paddle::framework::OperatorPtr op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
scope->CreateVariable("OUT1");
ASSERT_EQ(paddle::framework::op_run_num, 0);
op->Run(scope, device_context);
......@@ -189,8 +188,7 @@ TEST(OpKernel, all) {
paddle::platform::CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<paddle::framework::Scope>();
paddle::framework::OperatorPtr op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0);
op->Run(scope, cpu_device_context);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
......@@ -236,6 +234,6 @@ TEST(OpKernel, multi_inputs) {
paddle::platform::CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<Scope>();
OperatorPtr op(paddle::framework::OpRegistry::CreateOp(op_desc));
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context);
}
......@@ -24,7 +24,6 @@ namespace paddle {
namespace framework {
class Scope;
using ScopePtr = std::shared_ptr<Scope>;
/**
* @brief Scope that manage all variables.
......@@ -44,7 +43,7 @@ class Scope {
/**
* @brief Initialize a Scope with parent.
*/
explicit Scope(const ScopePtr& parent) : parent_(parent) {}
explicit Scope(const std::shared_ptr<Scope>& parent) : parent_(parent) {}
/**
* @brief Create Variable
......@@ -91,7 +90,7 @@ class Scope {
private:
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
ScopePtr parent_{nullptr};
std::shared_ptr<Scope> parent_{nullptr};
};
} // namespace framework
......
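Because the parent is now stored as an explicit `std::shared_ptr<Scope>`, a child scope keeps its parent alive for as long as the child exists. Below is a hedged sketch of how such a parent chain is typically built and queried; the `CreateVariable`/`GetVariable` bodies are simplified stand-ins for illustration, not the framework's actual implementation.

```cpp
#include <memory>
#include <string>
#include <unordered_map>

struct Variable {};

// Simplified stand-in for the framework's Scope; falling back to the parent
// during lookup is an assumption for illustration, not copied from the commit.
class Scope {
 public:
  Scope() = default;
  explicit Scope(const std::shared_ptr<Scope>& parent) : parent_(parent) {}

  Variable* CreateVariable(const std::string& name) {
    auto& slot = vars_[name];
    if (!slot) slot.reset(new Variable());
    return slot.get();
  }

  Variable* GetVariable(const std::string& name) const {
    auto it = vars_.find(name);
    if (it != vars_.end()) return it->second.get();
    // Not found locally: walk up the parent chain, which the shared_ptr
    // guarantees is still alive.
    return parent_ ? parent_->GetVariable(name) : nullptr;
  }

 private:
  std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
  std::shared_ptr<Scope> parent_{nullptr};
};

int main() {
  auto global = std::make_shared<Scope>();
  global->CreateVariable("W");
  auto local = std::make_shared<Scope>(global);
  // "W" is found by walking up to the parent scope.
  return local->GetVariable("W") == nullptr ? 1 : 0;
}
```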
......@@ -126,9 +126,10 @@ All parameter, weight, gradient are variables in Paddle.
return new paddle::platform::CPUDeviceContext();
});
py::class_<pd::OperatorBase, pd::OperatorPtr> operator_base(m, "Operator");
py::class_<pd::OperatorBase, std::shared_ptr<pd::OperatorBase>> operator_base(
m, "Operator");
operator_base.def_static("create", [](py::bytes protobin) -> pd::OperatorPtr {
operator_base.def_static("create", [](py::bytes protobin) {
pd::OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
......
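In the pybind11 binding, the second template argument of `py::class_` is the holder type; declaring it as `std::shared_ptr<pd::OperatorBase>` lets the `create` factory return the shared pointer produced by `OpRegistry::CreateOp` directly, with Python-side references sharing ownership with C++. A minimal sketch of the same pattern; the `Operator`/`create` names mirror the diff, everything else is a simplified stand-in.

```cpp
#include <pybind11/pybind11.h>
#include <memory>

namespace py = pybind11;

struct OperatorBase {
  virtual ~OperatorBase() = default;
};

// Stand-in for OpRegistry::CreateOp; the real factory parses an OpDesc.
std::shared_ptr<OperatorBase> CreateOp() {
  return std::make_shared<OperatorBase>();
}

PYBIND11_MODULE(example, m) {
  // The second template argument is the holder type: Python objects hold a
  // shared_ptr to the operator instead of owning a raw pointer.
  py::class_<OperatorBase, std::shared_ptr<OperatorBase>>(m, "Operator")
      .def_static("create", []() { return CreateOp(); });
}
```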