diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index cc2234d50e8ce5a594e137b223e8308b8b9e2645..a76a4d60b4107dfa3b2a3cea09a443d7b136bcdc 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -144,8 +144,18 @@ class OpKernelRegistrar : public Registrar {
                     grad_op_class)                                            \
   STATIC_ASSERT_GLOBAL_NAMESPACE(                                             \
       __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
-  static ::paddle::framework::OpRegistrar<op_class, op_maker_class,           \
-                                          grad_op_class>                      \
+  class _OpClass_##op_type##_ : public op_class {                             \
+   public:                                                                    \
+    DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_);                            \
+    DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class);                   \
+  };                                                                          \
+  class _OpGradClass_##op_type##_ : public grad_op_class {                    \
+   public:                                                                    \
+    DEFINE_OP_CLONE_METHOD(_OpGradClass_##op_type##_);                        \
+    DEFINE_OP_CONSTRUCTOR(_OpGradClass_##op_type##_, grad_op_class);          \
+  };                                                                          \
+  static ::paddle::framework::OpRegistrar<                                    \
+      _OpClass_##op_type##_, op_maker_class, _OpGradClass_##op_type##_>       \
       __op_registrar_##op_type##__(#op_type, #grad_op_type);                  \
   int TouchOpRegistrar_##op_type() {                                          \
     __op_registrar_##op_type##__.Touch();                                     \
@@ -176,7 +186,8 @@ class OpKernelRegistrar : public Registrar {
   REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
 
 /**
- * Macro to mark what Operator and Kernel we will use and tell the compiler to
+ * Macro to mark what Operator and Kernel
+ * we will use and tell the compiler to
  * link them into target.
  */
 #define USE_OP_ITSELF(op_type) \
@@ -196,7 +207,8 @@ class OpKernelRegistrar : public Registrar {
       __attribute__((unused)) =                                    \
       TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE()
 
-// TODO(fengjiayi): The following macros seems ugly, do we have better method?
+// TODO(fengjiayi): The following macros
+// seem ugly; do we have a better method?
 
 #ifdef PADDLE_ONLY_CPU
 #define USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU)
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 2c8620a7ce007ede4e2bef089e2fc8902bf0c9f4..848baeeeb6493f61c41193a5cc0fc69e93934bfb 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -67,10 +67,6 @@ class OperatorBase {
   OperatorBase(const std::string& type, const VarNameMap& inputs,
                const VarNameMap& outputs, const AttributeMap& attrs);
 
-  OperatorBase(const OperatorBase& o) = delete;
-  OperatorBase& operator=(const OperatorBase& o) = delete;
-  OperatorBase(OperatorBase&& o) = delete;
-
   virtual ~OperatorBase() {}
 
   template <typename T>
@@ -116,10 +112,14 @@ class OperatorBase {
   void SetType(const std::string& type) { type_ = type; }
   const AttributeMap& Attrs() const { return attrs_; }
 
+  // Returns a new operator instance that is the same as this one.
+  // Use unique_ptr so the caller cannot forget to delete the returned pointer.
+  virtual std::unique_ptr<OperatorBase> Clone() const = 0;
+
  protected:
   std::string type_;
   // NOTE: in case of OpGrad, inputs_ contains:
   // I (Inputs)
   // O (Outputs)
   // OG (Output Gradients)
   VarNameMap inputs_;
@@ -130,12 +130,32 @@ class OperatorBase {
   AttributeMap attrs_;
 };
 
+// Macro for defining a clone method.
+// If you are writing a kernel operator, `Clone` will be defined when you
+// register it, i.e. you do not need to define `Clone` yourself.
+#define DEFINE_OP_CLONE_METHOD(CLS)                       \
+  std::unique_ptr<OperatorBase> Clone() const final {     \
+    return std::unique_ptr<OperatorBase>(new CLS(*this)); \
+  }
+
+// Macro for defining a default constructor for an Operator.
+// You can also use
+//   using PARENT_CLASS::PARENT_CLASS;
+// to reuse the parent's constructor.
+#define DEFINE_OP_CONSTRUCTOR(CLS, PARENT_CLS)                                 \
+  CLS(const std::string& type, const VarNameMap& inputs,                       \
+      const VarNameMap& outputs, const paddle::framework::AttributeMap& attrs) \
+      : PARENT_CLS(type, inputs, outputs, attrs) {}
+
 class NOP : public OperatorBase {
  public:
   using OperatorBase::OperatorBase;
   void InferShape(const Scope& scope) const override {}
   void Run(const Scope& scope,
            const platform::DeviceContext& dev_ctx) const override {}
+  std::unique_ptr<OperatorBase> Clone() const override {
+    return std::unique_ptr<OperatorBase>(new NOP(*this));
+  }
 };
 
 // this class not only make proto but also init attribute checkers.
diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc
index 0441cec9f6d10246fba38b02b4de3cbe2ee4766b..2425b87779f6af01b0e8a91b5f574a28385f0efd 100644
--- a/paddle/framework/operator_test.cc
+++ b/paddle/framework/operator_test.cc
@@ -245,3 +245,21 @@ TEST(OpKernel, multi_inputs) {
   auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
   op->Run(scope, cpu_device_context);
 }
+
+class OperatorClone : public paddle::framework::OperatorBase {
+ public:
+  DEFINE_OP_CLONE_METHOD(OperatorClone);
+  OperatorClone(const std::string& type, const VarNameMap& inputs,
+                const VarNameMap& outputs,
+                const paddle::framework::AttributeMap& attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+  void InferShape(const paddle::framework::Scope& scope) const override {}
+  void Run(const paddle::framework::Scope& scope,
+           const paddle::platform::DeviceContext& dev_ctx) const override {}
+};
+
+TEST(Operator, Clone) {
+  OperatorClone a("ABC", {}, {}, {});
+  auto b = a.Clone();
+  ASSERT_EQ(a.Type(), b->Type());
+}
\ No newline at end of file
diff --git a/paddle/operators/net_op.cc b/paddle/operators/net_op.cc
index c36fe8d6b58a0afa568e31e43567baa5f261c7d0..a7d710511093dfbe13a13b1222b0230bba0398bd 100644
--- a/paddle/operators/net_op.cc
+++ b/paddle/operators/net_op.cc
@@ -85,7 +85,14 @@ NetOp::NetOp(const std::string& type,
              const framework::OperatorBase::VarNameMap& inputs,
              const framework::OperatorBase::VarNameMap& outputs,
              const framework::AttributeMap& attrs)
-    : OperatorBase(type, inputs, outputs, attrs) {}
+    : framework::OperatorBase(type, inputs, outputs, attrs) {}
+
+std::unique_ptr<framework::OperatorBase> NetOp::Clone() const {
+  PADDLE_ENFORCE(
+      add_op_done_,
+      "Must clone a sealed NetOp; invoke NetOp::CompleteAddOp before cloning");
+  return std::unique_ptr<framework::OperatorBase>(new NetOp(*this));
+}
 
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h
index 4a3408c158a029a96740717280c1562671fa938f..743f0e67dbeaab2de97a6cf635aad0ee90b2cef1 100644
--- a/paddle/operators/net_op.h
+++ b/paddle/operators/net_op.h
@@ -41,6 +41,18 @@ class NetOp : public framework::OperatorBase {
   NetOp(const std::string& type, const VarNameMap& inputs,
         const VarNameMap& outputs, const framework::AttributeMap& attrs);
+  NetOp(const NetOp& o)
+      : framework::OperatorBase(
+            static_cast<const framework::OperatorBase&>(o)) {
+    this->ops_.reserve(o.ops_.size());
+    std::transform(o.ops_.begin(), o.ops_.end(), std::back_inserter(this->ops_),
+                   [](const std::shared_ptr<OperatorBase>& op)
+                       -> std::shared_ptr<OperatorBase> {
+                     return std::shared_ptr<OperatorBase>(op->Clone());
+                   });
+    this->CompleteAddOp();
+  }
+
   /**
    * Infer all the operators' input and output variables' shapes, will be called
    * before every mini-batch
    */
@@ -98,6 +110,8 @@ class NetOp : public framework::OperatorBase {
   bool IsNetOp() const override;
   std::vector<std::string> OutputVars(bool has_intermediate) const override;
 
+  std::unique_ptr<framework::OperatorBase> Clone() const override;
+
   std::vector<std::shared_ptr<OperatorBase>> ops_;
 
  private:
diff --git a/paddle/operators/net_op_test.cc b/paddle/operators/net_op_test.cc
index 0cef71de6a032674a54387986f65f17ca99b400e..e28d4df6a570968205851c2e5b630a14c0492535 100644
--- a/paddle/operators/net_op_test.cc
+++ b/paddle/operators/net_op_test.cc
@@ -13,6 +13,7 @@ static int run_cnt = 0;
 class TestOp : public framework::OperatorBase {
  public:
   using framework::OperatorBase::OperatorBase;
+  DEFINE_OP_CLONE_METHOD(TestOp);
   void InferShape(const Scope& scope) const override { ++infer_shape_cnt; }
   void Run(const Scope& scope,
            const platform::DeviceContext& dev_ctx) const override {
@@ -70,5 +71,21 @@ TEST(NetOp, insert_op) {
   ASSERT_EQ(3UL, net.ops_.size());
 }
 
+TEST(NetOp, Clone) {
+  NetOp net;
+  net.AddOp(
+      std::shared_ptr<framework::NOP>(new framework::NOP{"empty", {}, {}, {}}));
+  net.AddOp(std::shared_ptr<framework::NOP>(
+      new framework::NOP{"empty2", {}, {}, {}}));
+  net.CompleteAddOp(true);
+  auto new_net_op = net.Clone();
+  ASSERT_NE(new_net_op, nullptr);
+  ASSERT_TRUE(new_net_op->IsNetOp());
+  auto* new_net = static_cast<NetOp*>(new_net_op.get());
+  ASSERT_EQ(2UL, new_net->ops_.size());
+  ASSERT_EQ(new_net->ops_[0]->Type(), "empty");
+  ASSERT_EQ(new_net->ops_[1]->Type(), "empty2");
+}
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/recurrent_op.h b/paddle/operators/recurrent_op.h
index 171a0bd2ae80245939a9237f51d195e8bc9178fc..1d8a6973955cf0b4ab372412fbb5428ff2622a0a 100644
--- a/paddle/operators/recurrent_op.h
+++ b/paddle/operators/recurrent_op.h
@@ -110,13 +110,20 @@ class RecurrentGradientAlgorithm {
   std::shared_ptr<NetOp>* stepnet_;
 };
 
-class RecurrentOp final : public framework::OperatorBase {
+class RecurrentOp : public framework::OperatorBase {
  public:
   RecurrentOp(const std::string& type, const VarNameMap& inputs,
               const VarNameMap& outputs, const framework::AttributeMap& attrs);
+
+  RecurrentOp(const RecurrentOp& o)
+      : framework::OperatorBase(
+            static_cast<const framework::OperatorBase&>(o)) {
+    // TODO(yuyang18): Implement copy ctor well.
+    PADDLE_THROW("Not implemented");
+  }
   /**
-  * InferShape must be called before Run.
-  */
+   * InferShape must be called before Run.
+   */
   void InferShape(const framework::Scope& scope) const override {
     alg_.InferShape(scope);
   }
@@ -137,12 +144,19 @@ class RecurrentOp final : public framework::OperatorBase {
   std::shared_ptr<NetOp> stepnet_;
 };
 
-class RecurrentGradientOp final : public framework::OperatorBase {
+class RecurrentGradientOp : public framework::OperatorBase {
  public:
   RecurrentGradientOp(const std::string& type, const VarNameMap& inputs,
                       const VarNameMap& outputs,
                       const framework::AttributeMap& attrs);
 
+  RecurrentGradientOp(const RecurrentGradientOp& o)
+      : framework::OperatorBase(
+            static_cast<const framework::OperatorBase&>(o)) {
+    // TODO(yuyang18): Implement copy ctor well.
+    PADDLE_THROW("Not implemented");
+  }
+
   /**
    * InferShape must be called before Run.
    */
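
Note: the following standalone sketch is not part of the patch above; it is a minimal, self-contained illustration of the clone pattern the diff introduces: a pure-virtual Clone() returning std::unique_ptr, a macro that stamps out the override via the subclass's copy constructor (the analogue of DEFINE_OP_CLONE_METHOD), and a container operator whose copy constructor deep-copies its children the way NetOp's new copy constructor does. All names here (OperatorBase, NopOp, NetLikeOp, DEFINE_CLONE_METHOD) are simplified stand-ins, not Paddle's real API.

#include <algorithm>
#include <iostream>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

class OperatorBase {
 public:
  explicit OperatorBase(std::string type) : type_(std::move(type)) {}
  virtual ~OperatorBase() = default;
  const std::string& Type() const { return type_; }
  // Callers receive exclusive ownership of the copy, as in the patch.
  virtual std::unique_ptr<OperatorBase> Clone() const = 0;

 protected:
  std::string type_;
};

// Analogue of DEFINE_OP_CLONE_METHOD: the override simply invokes the
// subclass's copy constructor.
#define DEFINE_CLONE_METHOD(CLS)                          \
  std::unique_ptr<OperatorBase> Clone() const override {  \
    return std::unique_ptr<OperatorBase>(new CLS(*this)); \
  }

class NopOp : public OperatorBase {
 public:
  using OperatorBase::OperatorBase;
  DEFINE_CLONE_METHOD(NopOp);
};

// Analogue of NetOp's new copy constructor: clone every child so the copy
// owns an independent operator tree (a default shallow copy would share the
// shared_ptr-held children between original and copy).
class NetLikeOp : public OperatorBase {
 public:
  NetLikeOp() : OperatorBase("net") {}
  NetLikeOp(const NetLikeOp& o) : OperatorBase(o) {
    ops_.reserve(o.ops_.size());
    std::transform(o.ops_.begin(), o.ops_.end(), std::back_inserter(ops_),
                   [](const std::shared_ptr<OperatorBase>& op) {
                     return std::shared_ptr<OperatorBase>(op->Clone());
                   });
  }
  DEFINE_CLONE_METHOD(NetLikeOp);

  std::vector<std::shared_ptr<OperatorBase>> ops_;
};

int main() {
  NetLikeOp net;
  net.ops_.push_back(std::make_shared<NopOp>("empty"));

  std::unique_ptr<OperatorBase> copy = net.Clone();
  auto* net_copy = static_cast<NetLikeOp*>(copy.get());

  // The clone owns its own child: distinct pointers, same type string.
  std::cout << (net_copy->ops_[0].get() != net.ops_[0].get()) << "\n";  // 1
  std::cout << net_copy->ops_[0]->Type() << "\n";  // empty
}

The patch marks the generated Clone() as final on the classes wrapped by REGISTER_OP, so nothing re-overrides it after registration; the sketch uses override for brevity. Returning unique_ptr makes the transfer of ownership explicit, while the children stay in shared_ptr because NetOp's ops_ vector is shared_ptr-based.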