提交 58f3de95 编写于 作者: Q Qiao Longfei 提交者: GitHub

Optimize ptr (#2851)

* use OperatorPtr = std::shared_ptr<OperatorBase>;
* use ScopePtr = std::share_ptr<Scope>;
上级 2462d0c5
......@@ -5,13 +5,13 @@ namespace framework {
PlainNet::PlainNet(const NetDesc& def) {}
void PlainNet::InferShape(Scope* scope) {
void PlainNet::InferShape(const ScopePtr& scope) const {
for (auto& op : ops_) {
op.InferShape();
}
}
void PlainNet::Run(std::shared_ptr<Scope> scope, DeviceContext* ctx) {
void PlainNet::Run(const ScopePtr& scope, const DeviceContext& ctx) const {
for (auto& op : ops_) {
op.Run(ctx);
}
......
......@@ -37,8 +37,8 @@ struct OpAttrs {};
class Operator {
public:
Operator(const OpDesc &def) {}
void InferShape() {}
void Run(DeviceContext *ctx) {}
void InferShape() const {}
void Run(const DeviceContext &ctx) const {}
};
/**
......@@ -60,7 +60,7 @@ class Net {
/**
* @brief Infer shapes of all inputs and outputs of operators.
*/
virtual void InferShape(Scope *scope) = 0;
virtual void InferShape(const ScopePtr &scope) const = 0;
/**
* @brief Run the network.
*
......@@ -69,7 +69,7 @@ class Net {
* environment for ops. `begin` and `end` specify the scope of `ops_` to run,
* If no positive indexes are provided, all operators in `ops_` will run.
*/
virtual void Run(std::shared_ptr<Scope> scope, DeviceContext *ctx) = 0;
virtual void Run(const ScopePtr &scope, const DeviceContext &ctx) const = 0;
/**
* @brief Add an Operator according to `def`.
......@@ -114,7 +114,7 @@ class PlainNet : public Net {
* Infer all the operators' input and output varialbes' shapes, will be called
* before every mini-batch
*/
virtual void InferShape(Scope *scope) override;
virtual void InferShape(const ScopePtr &scope) const override;
/**
* @brief Run the network.
......@@ -123,7 +123,8 @@ class PlainNet : public Net {
* scope will be used instead. If no OpContext is provicded, default context
* will be used.
*/
virtual void Run(std::shared_ptr<Scope> scope, DeviceContext *ctx) override;
virtual void Run(const ScopePtr &scope,
const DeviceContext &ctx) const override;
/**
* @brief Add an operator to this network.
......
......@@ -198,9 +198,9 @@ class OpRegistry {
op_type, op_proto.InitializationErrorString());
}
static OperatorBase* CreateOp(const OpDesc& op_desc) {
static OperatorPtr CreateOp(const OpDesc& op_desc) {
std::string op_type = op_desc.type();
OperatorBase* op = creators().at(op_type)();
OperatorPtr op(creators().at(op_type)());
op->desc_ = op_desc;
op->inputs_.reserve((size_t)op_desc.inputs_size());
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
......
......@@ -5,9 +5,9 @@ namespace paddle {
namespace framework {
class CosineOp : public OperatorBase {
public:
void Run(const std::shared_ptr<Scope>& scope,
void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void InferShape(const ScopePtr& scope) const override {}
};
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
......@@ -25,8 +25,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class MyTestOp : public OperatorBase {
public:
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
void InferShape(const ScopePtr& scope) const override {}
void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const override {}
public:
......@@ -67,7 +67,7 @@ TEST(OpRegistry, CreateOp) {
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(scale);
paddle::framework::OperatorBase* op =
paddle::framework::OperatorPtr op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<paddle::framework::Scope>();
paddle::platform::CPUDeviceContext dev_ctx;
......@@ -89,7 +89,7 @@ TEST(OpRegistry, IllegalAttr) {
bool caught = false;
try {
paddle::framework::OperatorBase* op __attribute__((unused)) =
paddle::framework::OperatorPtr op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) {
caught = true;
......@@ -110,7 +110,7 @@ TEST(OpRegistry, DefaultValue) {
ASSERT_TRUE(op_desc.IsInitialized());
paddle::framework::OperatorBase* op =
paddle::framework::OperatorPtr op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<paddle::framework::Scope>();
paddle::platform::CPUDeviceContext dev_ctx;
......@@ -136,7 +136,7 @@ TEST(OpRegistry, CustomChecker) {
// attr 'test_attr' is not set
bool caught = false;
try {
paddle::framework::OperatorBase* op __attribute__((unused)) =
paddle::framework::OperatorPtr op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) {
caught = true;
......@@ -155,7 +155,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_i(3);
caught = false;
try {
paddle::framework::OperatorBase* op __attribute__((unused)) =
paddle::framework::OperatorPtr op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) {
caught = true;
......@@ -174,7 +174,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_type(paddle::framework::AttrType::INT);
attr->set_i(4);
SetInputFormat(&op_desc);
paddle::framework::OperatorBase* op =
paddle::framework::OperatorPtr op =
paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::platform::CPUDeviceContext dev_ctx;
auto scope = std::make_shared<paddle::framework::Scope>();
......
......@@ -30,7 +30,7 @@ namespace paddle {
namespace framework {
class OperatorBase;
using OperatorPtr = std::shared_ptr<OperatorBase>;
/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
......@@ -56,10 +56,10 @@ class OperatorBase {
/// InferShape infer the size of Variables used by this Operator with
/// information inside scope
virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0;
virtual void InferShape(const ScopePtr& scope) const = 0;
/// Net will call this function to Run an op.
virtual void Run(const std::shared_ptr<Scope>& scope,
virtual void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const = 0;
protected:
......@@ -82,7 +82,7 @@ class OpKernel {
*/
class KernelContext {
public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
KernelContext(const OperatorBase* op, const ScopePtr& scope,
const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {}
......@@ -95,7 +95,7 @@ class OpKernel {
}
const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
const ScopePtr& scope_;
const platform::DeviceContext& device_context_;
};
......@@ -140,7 +140,7 @@ class OperatorWithKernel : public OperatorBase {
using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
void Run(const std::shared_ptr<Scope>& scope,
void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx));
opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx));
......
......@@ -22,8 +22,8 @@ namespace framework {
class OperatorTest : public OperatorBase {
public:
void Init() override { x = 1; }
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
void InferShape(const ScopePtr& scope) const override {}
void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const override {
float scale = GetAttr<float>("scale");
ASSERT_NEAR(scale, 3.14, 1e-5);
......@@ -36,6 +36,50 @@ class OperatorTest : public OperatorBase {
float x = 0;
};
class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OperatorTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op");
AddOutput("output", "output of test op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddComment("This is test op");
}
};
} // namespace framework
} // namespace paddle
REGISTER_OP(test_operator, paddle::framework::OperatorTest,
paddle::framework::OperatorTestProtoAndCheckerMaker);
TEST(OperatorBase, all) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("test_operator");
*op_desc.mutable_inputs()->Add() = "IN1";
*op_desc.mutable_outputs()->Add() = "OUT1";
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
float scale = 3.14;
attr->set_f(scale);
paddle::platform::CPUDeviceContext device_context;
auto scope = std::make_shared<paddle::framework::Scope>();
paddle::framework::OperatorPtr op =
paddle::framework::OpRegistry::CreateOp(op_desc);
ASSERT_EQ(op->GetAttr<float>("scale"), scale);
scope->CreateVariable("OUT1");
op->Run(scope, device_context);
std::cout << op->DebugString() << std::endl;
}
namespace paddle {
namespace framework {
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
......@@ -73,9 +117,7 @@ REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest);
TEST(OpKernel, all) {
using namespace paddle::framework;
OpDesc op_desc;
paddle::framework::OpDesc op_desc;
op_desc.set_type("op_with_kernel");
*op_desc.mutable_inputs()->Add() = "IN1";
*op_desc.mutable_outputs()->Add() = "OUT1";
......@@ -85,10 +127,9 @@ TEST(OpKernel, all) {
attr->set_f(3.14);
paddle::platform::CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<Scope>();
auto scope = std::make_shared<paddle::framework::Scope>();
OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::OperatorPtr op =
paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context);
delete op;
}
......@@ -23,6 +23,9 @@ limitations under the License. */
namespace paddle {
namespace framework {
class Scope;
using ScopePtr = std::shared_ptr<Scope>;
/**
* @brief Scope that manage all variables.
*
......@@ -41,7 +44,7 @@ class Scope {
/**
* @brief Initialize a Scope with parent.
*/
explicit Scope(const std::shared_ptr<Scope>& parent) : parent_(parent) {}
explicit Scope(const ScopePtr& parent) : parent_(parent) {}
/**
* @brief Create Variable
......@@ -88,7 +91,7 @@ class Scope {
private:
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
std::shared_ptr<Scope> parent_{nullptr};
ScopePtr parent_{nullptr};
};
} // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册