提交 29d892c1 编写于 作者: Y Yu Yang

Add Clone Method For OperatorBase

* Clone method will create a new object instance, which is as same as
  itself.
* This is the first step to remove shared_ptr for OperatorBase
上级 ffbb0be2
...@@ -271,7 +271,13 @@ class OpKernelRegistrar : public Registrar { ...@@ -271,7 +271,13 @@ class OpKernelRegistrar : public Registrar {
#define REGISTER_OP(op_type, op_class, op_maker_class) \ #define REGISTER_OP(op_type, op_class, op_maker_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \ __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
static ::paddle::framework::OpRegistrar<op_class, op_maker_class> \ class _OpClass_##op_type##_ : public op_class { \
public: \
DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \
DEFINE_OP_CTOR(_OpClass_##op_type##_, op_class); \
}; \
static ::paddle::framework::OpRegistrar<_OpClass_##op_type##_, \
op_maker_class> \
__op_registrar_##op_type##__(#op_type); \ __op_registrar_##op_type##__(#op_type); \
int TouchOpRegistrar_##op_type() { \ int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \ __op_registrar_##op_type##__.Touch(); \
...@@ -285,7 +291,12 @@ class OpKernelRegistrar : public Registrar { ...@@ -285,7 +291,12 @@ class OpKernelRegistrar : public Registrar {
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##op_type##_##grad_op_type, \ __reg_gradient_op__##op_type##_##grad_op_type, \
"REGISTER_GRADIENT_OP must be called in global namespace"); \ "REGISTER_GRADIENT_OP must be called in global namespace"); \
static ::paddle::framework::GradOpRegistrar<grad_op_class> \ class _OpGradClass_##op_type##_ : public grad_op_class { \
public: \
DEFINE_OP_CLONE_METHOD(_OpGradClass_##op_type##_); \
DEFINE_OP_CTOR(_OpGradClass_##op_type##_, grad_op_class); \
}; \
static ::paddle::framework::GradOpRegistrar<_OpGradClass_##op_type##_> \
__op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \ __op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \
#grad_op_type); \ #grad_op_type); \
int TouchOpGradientRegistrar_##op_type() { \ int TouchOpGradientRegistrar_##op_type() { \
......
...@@ -69,10 +69,6 @@ class OperatorBase { ...@@ -69,10 +69,6 @@ class OperatorBase {
OperatorBase(const std::string& type, const VarNameMap& inputs, OperatorBase(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs); const VarNameMap& outputs, const AttributeMap& attrs);
OperatorBase(const OperatorBase& o) = delete;
OperatorBase& operator=(const OperatorBase& o) = delete;
OperatorBase(OperatorBase&& o) = delete;
virtual ~OperatorBase() {} virtual ~OperatorBase() {}
template <typename T> template <typename T>
...@@ -115,6 +111,8 @@ class OperatorBase { ...@@ -115,6 +111,8 @@ class OperatorBase {
std::string Type() const { return type_; } std::string Type() const { return type_; }
const AttributeMap& Attrs() const { return attrs_; } const AttributeMap& Attrs() const { return attrs_; }
virtual OperatorBase* Clone() const = 0;
public: public:
std::string type_; std::string type_;
// NOTE: in case of OpGrad, inputs_ contains: // NOTE: in case of OpGrad, inputs_ contains:
...@@ -129,6 +127,14 @@ class OperatorBase { ...@@ -129,6 +127,14 @@ class OperatorBase {
AttributeMap attrs_; AttributeMap attrs_;
}; };
#define DEFINE_OP_CLONE_METHOD(CLS) \
OperatorBase* Clone() const final { return new CLS(*this); }
#define DEFINE_OP_CTOR(CLS, PARENT_CLS) \
CLS(const std::string& type, const VarNameMap& inputs, \
const VarNameMap& outputs, const paddle::framework::AttributeMap& attrs) \
: PARENT_CLS(type, inputs, outputs, attrs) {}
class InferShapeContext { class InferShapeContext {
public: public:
InferShapeContext(const OperatorBase& op, const Scope& scope) InferShapeContext(const OperatorBase& op, const Scope& scope)
......
...@@ -242,3 +242,22 @@ TEST(OpKernel, multi_inputs) { ...@@ -242,3 +242,22 @@ TEST(OpKernel, multi_inputs) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context); op->Run(scope, cpu_device_context);
} }
class OperatorClone : public paddle::framework::OperatorBase {
public:
DEFINE_OP_CLONE_METHOD(OperatorClone);
OperatorClone(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs,
const paddle::framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void InferShape(const paddle::framework::Scope& scope) const override {}
void Run(const paddle::framework::Scope& scope,
const paddle::platform::DeviceContext& dev_ctx) const override {}
};
TEST(Operator, Clone) {
OperatorClone a("ABC", {}, {}, {});
auto* b = a.Clone();
ASSERT_EQ(a.Type(), b->Type());
delete b;
}
\ No newline at end of file
...@@ -87,5 +87,12 @@ NetOp::NetOp(const std::string& type, ...@@ -87,5 +87,12 @@ NetOp::NetOp(const std::string& type,
const framework::AttributeMap& attrs) const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
framework::OperatorBase* NetOp::Clone() const {
PADDLE_ENFORCE(
add_op_done_,
"Must clone a sealed NetOp, invoke Net::CompleteAddOp before clone");
return new NetOp(*this);
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -41,6 +41,18 @@ class NetOp : public framework::OperatorBase { ...@@ -41,6 +41,18 @@ class NetOp : public framework::OperatorBase {
NetOp(const std::string& type, const VarNameMap& inputs, NetOp(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const framework::AttributeMap& attrs); const VarNameMap& outputs, const framework::AttributeMap& attrs);
NetOp(const NetOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
this->ops_.reserve(o.ops_.size());
std::transform(o.ops_.begin(), o.ops_.end(), std::back_inserter(this->ops_),
[](const std::shared_ptr<OperatorBase>& op)
-> std::shared_ptr<OperatorBase> {
return std::shared_ptr<OperatorBase>(op->Clone());
});
this->CompleteAddOp();
}
/** /**
* Infer all the operators' input and output variables' shapes, will be called * Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch * before every mini-batch
...@@ -97,6 +109,7 @@ class NetOp : public framework::OperatorBase { ...@@ -97,6 +109,7 @@ class NetOp : public framework::OperatorBase {
bool IsNetOp() const override; bool IsNetOp() const override;
std::vector<std::string> OutputVars(bool has_intermediate) const override; std::vector<std::string> OutputVars(bool has_intermediate) const override;
framework::OperatorBase* Clone() const override;
std::vector<std::shared_ptr<OperatorBase>> ops_; std::vector<std::shared_ptr<OperatorBase>> ops_;
......
...@@ -13,6 +13,7 @@ static int run_cnt = 0; ...@@ -13,6 +13,7 @@ static int run_cnt = 0;
class TestOp : public framework::OperatorBase { class TestOp : public framework::OperatorBase {
public: public:
using framework::OperatorBase::OperatorBase; using framework::OperatorBase::OperatorBase;
DEFINE_OP_CLONE_METHOD(TestOp);
void InferShape(const Scope& scope) const override { ++infer_shape_cnt; } void InferShape(const Scope& scope) const override { ++infer_shape_cnt; }
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
...@@ -23,6 +24,7 @@ class TestOp : public framework::OperatorBase { ...@@ -23,6 +24,7 @@ class TestOp : public framework::OperatorBase {
class EmptyOp : public framework::OperatorBase { class EmptyOp : public framework::OperatorBase {
public: public:
using framework::OperatorBase::OperatorBase; using framework::OperatorBase::OperatorBase;
DEFINE_OP_CLONE_METHOD(EmptyOp);
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {} void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {}
}; };
...@@ -77,5 +79,20 @@ TEST(NetOp, insert_op) { ...@@ -77,5 +79,20 @@ TEST(NetOp, insert_op) {
ASSERT_EQ(3UL, net.ops_.size()); ASSERT_EQ(3UL, net.ops_.size());
} }
TEST(NetOp, Clone) {
NetOp net;
net.AddOp(std::shared_ptr<EmptyOp>(new EmptyOp{"empty", {}, {}, {}}));
net.AddOp(std::shared_ptr<EmptyOp>(new EmptyOp{"empty2", {}, {}, {}}));
net.CompleteAddOp(true);
auto* new_net_op = net.Clone();
ASSERT_NE(new_net_op, nullptr);
ASSERT_TRUE(new_net_op->IsNetOp());
auto* new_net = static_cast<NetOp*>(new_net_op);
ASSERT_EQ(2, new_net->ops_.size());
ASSERT_EQ(new_net->ops_[0]->Type(), "empty");
ASSERT_EQ(new_net->ops_[1]->Type(), "empty2");
delete new_net;
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -99,10 +99,17 @@ class RecurrentGradientAlgorithm { ...@@ -99,10 +99,17 @@ class RecurrentGradientAlgorithm {
mutable size_t seq_len_; mutable size_t seq_len_;
}; };
class RecurrentOp final : public framework::OperatorBase { class RecurrentOp : public framework::OperatorBase {
public: public:
RecurrentOp(const std::string& type, const VarNameMap& inputs, RecurrentOp(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const framework::AttributeMap& attrs); const VarNameMap& outputs, const framework::AttributeMap& attrs);
RecurrentOp(const RecurrentOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
// TODO(yuyang18): Implement copy ctor well.
PADDLE_THROW("Not implemented");
}
/** /**
* InferShape must be called before Run. * InferShape must be called before Run.
*/ */
...@@ -121,12 +128,19 @@ class RecurrentOp final : public framework::OperatorBase { ...@@ -121,12 +128,19 @@ class RecurrentOp final : public framework::OperatorBase {
RecurrentAlgorithm alg_; RecurrentAlgorithm alg_;
}; };
class RecurrentGradientOp final : public framework::OperatorBase { class RecurrentGradientOp : public framework::OperatorBase {
public: public:
RecurrentGradientOp(const std::string& type, const VarNameMap& inputs, RecurrentGradientOp(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const VarNameMap& outputs,
const framework::AttributeMap& attrs); const framework::AttributeMap& attrs);
RecurrentGradientOp(const RecurrentGradientOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
// TODO(yuyang18): Implement Copy ctor.
PADDLE_THROW("Not Implemented");
}
/** /**
* InferShape must be called before Run. * InferShape must be called before Run.
*/ */
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册