未验证 提交 de456e74 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #11913 from reyoung/feature/remove_clone_method

Remove Op::Clone method
...@@ -182,21 +182,15 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I, ...@@ -182,21 +182,15 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
VarTypeInference VarTypeInference
InferShapeBase InferShapeBase
*/ */
#define REGISTER_OPERATOR(op_type, op_class, ...) \ #define REGISTER_OPERATOR(op_type, op_class, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, \ __reg_op__##op_type, \
"REGISTER_OPERATOR must be called in global namespace"); \ "REGISTER_OPERATOR must be called in global namespace"); \
class _OpClass_##op_type##_ : public op_class { \ static ::paddle::framework::OperatorRegistrar<op_class, ##__VA_ARGS__> \
public: \ __op_registrar_##op_type##__(#op_type); \
DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \ int TouchOpRegistrar_##op_type() { \
DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \ __op_registrar_##op_type##__.Touch(); \
}; \ return 0; \
static ::paddle::framework::OperatorRegistrar<_OpClass_##op_type##_, \
##__VA_ARGS__> \
__op_registrar_##op_type##__(#op_type); \
int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \
return 0; \
} }
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \ #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
......
...@@ -193,15 +193,10 @@ TEST(OpRegistry, CustomChecker) { ...@@ -193,15 +193,10 @@ TEST(OpRegistry, CustomChecker) {
ASSERT_EQ(test_attr, 4); ASSERT_EQ(test_attr, 4);
} }
class CosineOpComplete : public paddle::framework::CosineOp {
public:
DEFINE_OP_CONSTRUCTOR(CosineOpComplete, paddle::framework::CosineOp);
DEFINE_OP_CLONE_METHOD(CosineOpComplete);
};
TEST(OperatorRegistrar, Test) { TEST(OperatorRegistrar, Test) {
paddle::framework::OperatorRegistrar< paddle::framework::OperatorRegistrar<
CosineOpComplete, paddle::framework::CosineOpProtoAndCheckerMaker> paddle::framework::CosineOp,
paddle::framework::CosineOpProtoAndCheckerMaker>
reg("cos"); reg("cos");
} }
......
...@@ -121,10 +121,6 @@ class OperatorBase { ...@@ -121,10 +121,6 @@ class OperatorBase {
//! Get all outputs variable names //! Get all outputs variable names
virtual std::vector<std::string> OutputVars(bool has_intermediate) const; virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
// Return a new operator instance, which is as same as this.
// Use unique_ptr to prevent caller forget to delete this pointer.
virtual std::unique_ptr<OperatorBase> Clone() const = 0;
protected: protected:
std::string type_; std::string type_;
// NOTE: in case of OpGrad, inputs_ contains: // NOTE: in case of OpGrad, inputs_ contains:
...@@ -145,37 +141,6 @@ class OperatorBase { ...@@ -145,37 +141,6 @@ class OperatorBase {
const platform::Place& place) const = 0; const platform::Place& place) const = 0;
}; };
// Macro for define a clone method.
// If you are writing an kernel operator, `Clone` will be defined when you
// register it. i.e. `Clone` method is not needed to define by yourself.
#define DEFINE_OP_CLONE_METHOD(cls) \
std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final { \
return std::unique_ptr<::paddle::framework::OperatorBase>(new cls(*this)); \
}
// Macro for define a default constructor for Operator.
// You can also use
// using PARENT_CLASS::PARENT_CLASS;
// to use parent's constructor.
#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls) \
cls(const std::string& type, \
const ::paddle::framework::VariableNameMap& inputs, \
const ::paddle::framework::VariableNameMap& outputs, \
const paddle::framework::AttributeMap& attrs) \
: parent_cls(type, inputs, outputs, attrs) {}
class NOP : public OperatorBase {
public:
using OperatorBase::OperatorBase;
std::unique_ptr<OperatorBase> Clone() const override {
return std::unique_ptr<OperatorBase>(new NOP(*this));
}
private:
void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
};
class ExecutionContext { class ExecutionContext {
public: public:
ExecutionContext(const OperatorBase& op, const Scope& scope, ExecutionContext(const OperatorBase& op, const Scope& scope,
......
...@@ -247,26 +247,3 @@ TEST(OpKernel, multi_inputs) { ...@@ -247,26 +247,3 @@ TEST(OpKernel, multi_inputs) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_place); op->Run(scope, cpu_place);
} }
class OperatorClone : public paddle::framework::OperatorBase {
public:
DEFINE_OP_CLONE_METHOD(OperatorClone);
OperatorClone(const std::string& type,
const paddle::framework::VariableNameMap& inputs,
const paddle::framework::VariableNameMap& outputs,
const paddle::framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const paddle::framework::Scope& scope,
const paddle::platform::Place& place) const override {}
};
TEST(Operator, Clone) {
paddle::framework::InitDevices(true);
OperatorClone a("ABC", paddle::framework::VariableNameMap{},
paddle::framework::VariableNameMap{},
paddle::framework::AttributeMap{});
auto b = a.Clone();
ASSERT_EQ(a.Type(), b->Type());
}
...@@ -22,6 +22,17 @@ limitations under the License. */ ...@@ -22,6 +22,17 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class NOP : public OperatorBase {
public:
NOP(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const Scope &scope,
const platform::Place &place) const override {}
};
class SumOpMaker : public OpProtoAndCheckerMaker { class SumOpMaker : public OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册