diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 3b793628aa6fdb08544ba90274736c9d29262a8b..b5b4668074cc79511f9b5e9bb462701360621ba1 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -271,7 +271,13 @@ class OpKernelRegistrar : public Registrar { #define REGISTER_OP(op_type, op_class, op_maker_class) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \ - static ::paddle::framework::OpRegistrar \ + class _OpClass_##op_type##_ : public op_class { \ + public: \ + DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \ + DEFINE_OP_CTOR(_OpClass_##op_type##_, op_class); \ + }; \ + static ::paddle::framework::OpRegistrar<_OpClass_##op_type##_, \ + op_maker_class> \ __op_registrar_##op_type##__(#op_type); \ int TouchOpRegistrar_##op_type() { \ __op_registrar_##op_type##__.Touch(); \ @@ -285,7 +291,12 @@ class OpKernelRegistrar : public Registrar { STATIC_ASSERT_GLOBAL_NAMESPACE( \ __reg_gradient_op__##op_type##_##grad_op_type, \ "REGISTER_GRADIENT_OP must be called in global namespace"); \ - static ::paddle::framework::GradOpRegistrar \ + class _OpGradClass_##op_type##_ : public grad_op_class { \ + public: \ + DEFINE_OP_CLONE_METHOD(_OpGradClass_##op_type##_); \ + DEFINE_OP_CTOR(_OpGradClass_##op_type##_, grad_op_class); \ + }; \ + static ::paddle::framework::GradOpRegistrar<_OpGradClass_##op_type##_> \ __op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \ #grad_op_type); \ int TouchOpGradientRegistrar_##op_type() { \ diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 4a72ced6ced92054eb170cd3012cafb181744953..92032478661be65481bbd66693df95df83a678e7 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -69,10 +69,6 @@ class OperatorBase { OperatorBase(const std::string& type, const VarNameMap& inputs, const VarNameMap& outputs, const AttributeMap& attrs); - OperatorBase(const OperatorBase& o) = delete; - OperatorBase& operator=(const OperatorBase& o) = delete; - OperatorBase(OperatorBase&& o) = delete; - virtual ~OperatorBase() {} template @@ -115,6 +111,8 @@ class OperatorBase { std::string Type() const { return type_; } const AttributeMap& Attrs() const { return attrs_; } + virtual OperatorBase* Clone() const = 0; + public: std::string type_; // NOTE: in case of OpGrad, inputs_ contains: @@ -129,6 +127,14 @@ class OperatorBase { AttributeMap attrs_; }; +#define DEFINE_OP_CLONE_METHOD(CLS) \ + OperatorBase* Clone() const final { return new CLS(*this); } + +#define DEFINE_OP_CTOR(CLS, PARENT_CLS) \ + CLS(const std::string& type, const VarNameMap& inputs, \ + const VarNameMap& outputs, const paddle::framework::AttributeMap& attrs) \ + : PARENT_CLS(type, inputs, outputs, attrs) {} + class InferShapeContext { public: InferShapeContext(const OperatorBase& op, const Scope& scope) diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 6804841587730d51d9cfad30a9de81401d36695b..ceba7f5e6ea485f0e3642dcb4984822a783d6d4d 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -242,3 +242,22 @@ TEST(OpKernel, multi_inputs) { auto op = paddle::framework::OpRegistry::CreateOp(op_desc); op->Run(scope, cpu_device_context); } + +class OperatorClone : public paddle::framework::OperatorBase { + public: + DEFINE_OP_CLONE_METHOD(OperatorClone); + OperatorClone(const std::string& type, const VarNameMap& inputs, + const VarNameMap& outputs, + const paddle::framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void InferShape(const paddle::framework::Scope& scope) const override {} + void Run(const paddle::framework::Scope& scope, + const paddle::platform::DeviceContext& dev_ctx) const override {} +}; + +TEST(Operator, Clone) { + OperatorClone a("ABC", {}, {}, {}); + auto* b = a.Clone(); + ASSERT_EQ(a.Type(), b->Type()); + delete b; +} \ No newline at end of file diff --git a/paddle/operators/net_op.cc b/paddle/operators/net_op.cc index 1d1b290440ec125bdb5b190745735dd077261731..896550f9d0c55665a9c9d6f133dc8f27bc5a07f1 100644 --- a/paddle/operators/net_op.cc +++ b/paddle/operators/net_op.cc @@ -87,5 +87,12 @@ NetOp::NetOp(const std::string& type, const framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {} +framework::OperatorBase* NetOp::Clone() const { + PADDLE_ENFORCE( + add_op_done_, + "Must clone a sealed NetOp, invoke Net::CompleteAddOp before clone"); + return new NetOp(*this); +} + } // namespace operators } // namespace paddle diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index 4a3408c158a029a96740717280c1562671fa938f..deee543065ab716d9d802647745c3884c1e80380 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -41,6 +41,18 @@ class NetOp : public framework::OperatorBase { NetOp(const std::string& type, const VarNameMap& inputs, const VarNameMap& outputs, const framework::AttributeMap& attrs); + NetOp(const NetOp& o) + : framework::OperatorBase( + static_cast(o)) { + this->ops_.reserve(o.ops_.size()); + std::transform(o.ops_.begin(), o.ops_.end(), std::back_inserter(this->ops_), + [](const std::shared_ptr& op) + -> std::shared_ptr { + return std::shared_ptr(op->Clone()); + }); + this->CompleteAddOp(); + } + /** * Infer all the operators' input and output variables' shapes, will be called * before every mini-batch @@ -97,6 +109,7 @@ class NetOp : public framework::OperatorBase { bool IsNetOp() const override; std::vector OutputVars(bool has_intermediate) const override; + framework::OperatorBase* Clone() const override; std::vector> ops_; diff --git a/paddle/operators/net_op_test.cc b/paddle/operators/net_op_test.cc index f7aa56262ef71c24bf668950f6e9914e5f96ff70..40e43f46df79fa311f7113b347a2d7c179c30eeb 100644 --- a/paddle/operators/net_op_test.cc +++ b/paddle/operators/net_op_test.cc @@ -13,6 +13,7 @@ static int run_cnt = 0; class TestOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; + DEFINE_OP_CLONE_METHOD(TestOp); void InferShape(const Scope& scope) const override { ++infer_shape_cnt; } void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override { @@ -23,6 +24,7 @@ class TestOp : public framework::OperatorBase { class EmptyOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; + DEFINE_OP_CLONE_METHOD(EmptyOp); void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {} }; @@ -77,5 +79,20 @@ TEST(NetOp, insert_op) { ASSERT_EQ(3UL, net.ops_.size()); } +TEST(NetOp, Clone) { + NetOp net; + net.AddOp(std::shared_ptr(new EmptyOp{"empty", {}, {}, {}})); + net.AddOp(std::shared_ptr(new EmptyOp{"empty2", {}, {}, {}})); + net.CompleteAddOp(true); + auto* new_net_op = net.Clone(); + ASSERT_NE(new_net_op, nullptr); + ASSERT_TRUE(new_net_op->IsNetOp()); + auto* new_net = static_cast(new_net_op); + ASSERT_EQ(2, new_net->ops_.size()); + ASSERT_EQ(new_net->ops_[0]->Type(), "empty"); + ASSERT_EQ(new_net->ops_[1]->Type(), "empty2"); + delete new_net; +} + } // namespace operators } // namespace paddle diff --git a/paddle/operators/recurrent_op.h b/paddle/operators/recurrent_op.h index 8f4f2444d844b4ba5948f001a365a7ecaeecc106..cc40eff0cf94c75dd21c0e4e8c2b28ef0ed5c3e4 100644 --- a/paddle/operators/recurrent_op.h +++ b/paddle/operators/recurrent_op.h @@ -99,13 +99,20 @@ class RecurrentGradientAlgorithm { mutable size_t seq_len_; }; -class RecurrentOp final : public framework::OperatorBase { +class RecurrentOp : public framework::OperatorBase { public: RecurrentOp(const std::string& type, const VarNameMap& inputs, const VarNameMap& outputs, const framework::AttributeMap& attrs); + + RecurrentOp(const RecurrentOp& o) + : framework::OperatorBase( + static_cast(o)) { + // TODO(yuyang18): Implement copy ctor well. + PADDLE_THROW("Not implemented"); + } /** - * InferShape must be called before Run. - */ + * InferShape must be called before Run. + */ void InferShape(const framework::Scope& scope) const override { alg_.InferShape(scope); } @@ -121,12 +128,19 @@ class RecurrentOp final : public framework::OperatorBase { RecurrentAlgorithm alg_; }; -class RecurrentGradientOp final : public framework::OperatorBase { +class RecurrentGradientOp : public framework::OperatorBase { public: RecurrentGradientOp(const std::string& type, const VarNameMap& inputs, const VarNameMap& outputs, const framework::AttributeMap& attrs); + RecurrentGradientOp(const RecurrentGradientOp& o) + : framework::OperatorBase( + static_cast(o)) { + // TODO(yuyang18): Implement Copy ctor. + PADDLE_THROW("Not Implemented"); + } + /** * InferShape must be called before Run. */