Commit 29d892c1 authored by Yu Yang

Add Clone Method For OperatorBase

* The Clone method creates a new operator instance that is identical to
  the original; a minimal usage sketch follows the commit metadata below.
* This is the first step toward removing shared_ptr for OperatorBase.
Parent ffbb0be2
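The pattern this commit introduces can be summarized outside the framework. Below is a minimal, self-contained sketch of the idea, not the PaddlePaddle code itself: OpBase, EmptyOp, MiniNetOp, and DEFINE_CLONE_METHOD are illustrative stand-ins for OperatorBase, a registered operator class, NetOp, and DEFINE_OP_CLONE_METHOD. The base class exposes a pure-virtual Clone(), the macro implements Clone() through the copy constructor, and the composite operator's copy constructor clones each child, so a cloned net owns independent operators instead of sharing them through shared_ptr.

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Stand-in for OperatorBase: a polymorphic base that can deep-copy itself.
class OpBase {
 public:
  explicit OpBase(std::string type) : type_(std::move(type)) {}
  virtual ~OpBase() = default;
  const std::string& Type() const { return type_; }
  // Every concrete operator must return a fresh copy of itself.
  virtual OpBase* Clone() const = 0;

 private:
  std::string type_;
};

// Mirrors DEFINE_OP_CLONE_METHOD: Clone() is just "new CLS(*this)", so a
// concrete class only needs a working copy constructor.
#define DEFINE_CLONE_METHOD(CLS) \
  OpBase* Clone() const override { return new CLS(*this); }

class EmptyOp : public OpBase {
 public:
  using OpBase::OpBase;
  DEFINE_CLONE_METHOD(EmptyOp);
};

// Mirrors the NetOp copy constructor added in this commit: cloning a net
// clones every child operator, so the copy owns its own operators.
class MiniNetOp : public OpBase {
 public:
  MiniNetOp() : OpBase("net") {}
  MiniNetOp(const MiniNetOp& o) : OpBase(o) {
    ops_.reserve(o.ops_.size());
    for (const auto& op : o.ops_) {
      ops_.emplace_back(op->Clone());
    }
  }
  DEFINE_CLONE_METHOD(MiniNetOp);
  void AddOp(std::unique_ptr<OpBase> op) { ops_.push_back(std::move(op)); }

  std::vector<std::unique_ptr<OpBase>> ops_;
};

int main() {
  MiniNetOp net;
  net.AddOp(std::unique_ptr<OpBase>(new EmptyOp("empty")));
  net.AddOp(std::unique_ptr<OpBase>(new EmptyOp("empty2")));

  std::unique_ptr<OpBase> copy(net.Clone());
  auto* net_copy = static_cast<MiniNetOp*>(copy.get());
  // The clone has its own children; nothing is shared with the original.
  std::cout << net_copy->ops_[0]->Type() << " " << net_copy->ops_[1]->Type()
            << std::endl;  // prints: empty empty2
  return 0;
}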
@@ -271,7 +271,13 @@ class OpKernelRegistrar : public Registrar {
#define REGISTER_OP(op_type, op_class, op_maker_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
-  static ::paddle::framework::OpRegistrar<op_class, op_maker_class> \
+  class _OpClass_##op_type##_ : public op_class { \
+   public: \
+    DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \
+    DEFINE_OP_CTOR(_OpClass_##op_type##_, op_class); \
+  }; \
+  static ::paddle::framework::OpRegistrar<_OpClass_##op_type##_, \
+                                          op_maker_class> \
__op_registrar_##op_type##__(#op_type); \
int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \
@@ -285,7 +291,12 @@ class OpKernelRegistrar : public Registrar {
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##op_type##_##grad_op_type, \
"REGISTER_GRADIENT_OP must be called in global namespace"); \
-  static ::paddle::framework::GradOpRegistrar<grad_op_class> \
+  class _OpGradClass_##op_type##_ : public grad_op_class { \
+   public: \
+    DEFINE_OP_CLONE_METHOD(_OpGradClass_##op_type##_); \
+    DEFINE_OP_CTOR(_OpGradClass_##op_type##_, grad_op_class); \
+  }; \
+  static ::paddle::framework::GradOpRegistrar<_OpGradClass_##op_type##_> \
__op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \
#grad_op_type); \
int TouchOpGradientRegistrar_##op_type() { \
@@ -69,10 +69,6 @@ class OperatorBase {
OperatorBase(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs);
-  OperatorBase(const OperatorBase& o) = delete;
-  OperatorBase& operator=(const OperatorBase& o) = delete;
-  OperatorBase(OperatorBase&& o) = delete;
virtual ~OperatorBase() {}
template <typename T>
@@ -115,6 +111,8 @@ class OperatorBase {
std::string Type() const { return type_; }
const AttributeMap& Attrs() const { return attrs_; }
virtual OperatorBase* Clone() const = 0;
public:
std::string type_;
// NOTE: in case of OpGrad, inputs_ contains:
@@ -129,6 +127,14 @@ class OperatorBase {
AttributeMap attrs_;
};
#define DEFINE_OP_CLONE_METHOD(CLS) \
OperatorBase* Clone() const final { return new CLS(*this); }
#define DEFINE_OP_CTOR(CLS, PARENT_CLS) \
CLS(const std::string& type, const VarNameMap& inputs, \
const VarNameMap& outputs, const paddle::framework::AttributeMap& attrs) \
: PARENT_CLS(type, inputs, outputs, attrs) {}
class InferShapeContext {
public:
InferShapeContext(const OperatorBase& op, const Scope& scope)
@@ -242,3 +242,22 @@ TEST(OpKernel, multi_inputs) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context);
}
class OperatorClone : public paddle::framework::OperatorBase {
public:
DEFINE_OP_CLONE_METHOD(OperatorClone);
OperatorClone(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs,
const paddle::framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void InferShape(const paddle::framework::Scope& scope) const override {}
void Run(const paddle::framework::Scope& scope,
const paddle::platform::DeviceContext& dev_ctx) const override {}
};
TEST(Operator, Clone) {
OperatorClone a("ABC", {}, {}, {});
auto* b = a.Clone();
ASSERT_EQ(a.Type(), b->Type());
delete b;
}
\ No newline at end of file
@@ -87,5 +87,12 @@ NetOp::NetOp(const std::string& type,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
framework::OperatorBase* NetOp::Clone() const {
PADDLE_ENFORCE(
add_op_done_,
"Must clone a sealed NetOp, invoke Net::CompleteAddOp before clone");
return new NetOp(*this);
}
} // namespace operators
} // namespace paddle
@@ -41,6 +41,18 @@ class NetOp : public framework::OperatorBase {
NetOp(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const framework::AttributeMap& attrs);
NetOp(const NetOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
this->ops_.reserve(o.ops_.size());
std::transform(o.ops_.begin(), o.ops_.end(), std::back_inserter(this->ops_),
[](const std::shared_ptr<OperatorBase>& op)
-> std::shared_ptr<OperatorBase> {
return std::shared_ptr<OperatorBase>(op->Clone());
});
this->CompleteAddOp();
}
/**
* Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch
@@ -97,6 +109,7 @@ class NetOp : public framework::OperatorBase {
bool IsNetOp() const override;
std::vector<std::string> OutputVars(bool has_intermediate) const override;
framework::OperatorBase* Clone() const override;
std::vector<std::shared_ptr<OperatorBase>> ops_;
@@ -13,6 +13,7 @@ static int run_cnt = 0;
class TestOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
DEFINE_OP_CLONE_METHOD(TestOp);
void InferShape(const Scope& scope) const override { ++infer_shape_cnt; }
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
@@ -23,6 +24,7 @@ class TestOp : public framework::OperatorBase {
class EmptyOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
DEFINE_OP_CLONE_METHOD(EmptyOp);
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {}
};
@@ -77,5 +79,20 @@ TEST(NetOp, insert_op) {
ASSERT_EQ(3UL, net.ops_.size());
}
TEST(NetOp, Clone) {
NetOp net;
net.AddOp(std::shared_ptr<EmptyOp>(new EmptyOp{"empty", {}, {}, {}}));
net.AddOp(std::shared_ptr<EmptyOp>(new EmptyOp{"empty2", {}, {}, {}}));
net.CompleteAddOp(true);
auto* new_net_op = net.Clone();
ASSERT_NE(new_net_op, nullptr);
ASSERT_TRUE(new_net_op->IsNetOp());
auto* new_net = static_cast<NetOp*>(new_net_op);
ASSERT_EQ(2, new_net->ops_.size());
ASSERT_EQ(new_net->ops_[0]->Type(), "empty");
ASSERT_EQ(new_net->ops_[1]->Type(), "empty2");
delete new_net;
}
} // namespace operators
} // namespace paddle
@@ -99,13 +99,20 @@ class RecurrentGradientAlgorithm {
mutable size_t seq_len_;
};
-class RecurrentOp final : public framework::OperatorBase {
+class RecurrentOp : public framework::OperatorBase {
public:
RecurrentOp(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const framework::AttributeMap& attrs);
RecurrentOp(const RecurrentOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
// TODO(yuyang18): Implement copy ctor well.
PADDLE_THROW("Not implemented");
}
  /**
   * InferShape must be called before Run.
   */
void InferShape(const framework::Scope& scope) const override {
alg_.InferShape(scope);
}
@@ -121,12 +128,19 @@ class RecurrentOp final : public framework::OperatorBase {
RecurrentAlgorithm alg_;
};
-class RecurrentGradientOp final : public framework::OperatorBase {
+class RecurrentGradientOp : public framework::OperatorBase {
public:
RecurrentGradientOp(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs,
const framework::AttributeMap& attrs);
RecurrentGradientOp(const RecurrentGradientOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
// TODO(yuyang18): Implement Copy ctor.
PADDLE_THROW("Not Implemented");
}
/**
* InferShape must be called before Run.
*/