提交 e14a4541 编写于 作者: F fengjiayi

Refactor registry macro

上级 78c3e1de
...@@ -307,7 +307,10 @@ class OpRegistry { ...@@ -307,7 +307,10 @@ class OpRegistry {
} }
}; };
class Registrar {}; class Registrar {
public:
void Touch() {}
};
template <typename OpType, typename ProtoMakerType> template <typename OpType, typename ProtoMakerType>
class OpRegistrar : public Registrar { class OpRegistrar : public Registrar {
...@@ -336,8 +339,6 @@ class OpKernelRegistrar : public Registrar { ...@@ -336,8 +339,6 @@ class OpKernelRegistrar : public Registrar {
} }
}; };
int TouchRegistrar(const Registrar& registrar) { return 0; }
/** /**
* check if MACRO is used in GLOBAL NAMESPACE. * check if MACRO is used in GLOBAL NAMESPACE.
*/ */
...@@ -354,7 +355,11 @@ int TouchRegistrar(const Registrar& registrar) { return 0; } ...@@ -354,7 +355,11 @@ int TouchRegistrar(const Registrar& registrar) { return 0; }
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \ __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
static ::paddle::framework::OpRegistrar<op_class, op_maker_class> \ static ::paddle::framework::OpRegistrar<op_class, op_maker_class> \
__op_registrar_##op_type##__(#op_type); __op_registrar_##op_type##__(#op_type); \
int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \
return 0; \
}
/** /**
* Macro to Register Gradient Operator. * Macro to Register Gradient Operator.
...@@ -364,18 +369,26 @@ int TouchRegistrar(const Registrar& registrar) { return 0; } ...@@ -364,18 +369,26 @@ int TouchRegistrar(const Registrar& registrar) { return 0; }
__reg_gradient_op__##op_type##_##grad_op_type, \ __reg_gradient_op__##op_type##_##grad_op_type, \
"REGISTER_GRADIENT_OP must be called in global namespace"); \ "REGISTER_GRADIENT_OP must be called in global namespace"); \
static ::paddle::framework::GradOpRegistrar<grad_op_class> \ static ::paddle::framework::GradOpRegistrar<grad_op_class> \
__op_gradient_register_##op_type##_##grad_op_type##__(#op_type, \ __op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \
#grad_op_type); #grad_op_type); \
int TouchOpGradientRegister_##op_type() { \
__op_gradient_registrar_##op_type##_##grad_op_type##__.Touch(); \
return 0; \
}
/** /**
* Macro to Register OperatorKernel. * Macro to Register OperatorKernel.
*/ */
#define REGISTER_OP_KERNEL(op_type, DEVICE_TYPE, place_class, kernel_class) \ #define REGISTER_OP_KERNEL(op_type, DEVICE_TYPE, place_class, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##op_type##_##DEVICE_TYPE##__, \ __reg_op_kernel_##op_type##_##DEVICE_TYPE##__, \
"REGISTER_OP_KERNEL must be called in global namespace"); \ "REGISTER_OP_KERNEL must be called in global namespace"); \
static ::paddle::framework::OpKernelRegistrar<place_class, kernel_class> \ static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__> \
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type); __op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type); \
int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() { \
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__.Touch(); \
return 0; \
}
/** /**
* Macro to Forbid user register Gradient Operator. * Macro to Forbid user register Gradient Operator.
...@@ -385,11 +398,11 @@ int TouchRegistrar(const Registrar& registrar) { return 0; } ...@@ -385,11 +398,11 @@ int TouchRegistrar(const Registrar& registrar) { return 0; }
__reg_gradient_op__##op_type##_##op_type##_grad, \ __reg_gradient_op__##op_type##_##op_type##_grad, \
"NO_GRADIENT must be called in global namespace") "NO_GRADIENT must be called in global namespace")
#define REGISTER_OP_GPU_KERNEL(op_type, kernel_class) \ #define REGISTER_OP_GPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, kernel_class) REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL(op_type, kernel_class) \ #define REGISTER_OP_CPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, kernel_class) REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
/** /**
* Macro to mark what Operator and Kernel we will use and tell the compiler to * Macro to mark what Operator and Kernel we will use and tell the compiler to
...@@ -399,30 +412,27 @@ int TouchRegistrar(const Registrar& registrar) { return 0; } ...@@ -399,30 +412,27 @@ int TouchRegistrar(const Registrar& registrar) { return 0; }
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_itself_##op_type, \ __use_op_itself_##op_type, \
"USE_OP_ITSELF must be called in global namespace"); \ "USE_OP_ITSELF must be called in global namespace"); \
extern ::paddle::framework::OpRegistrar<op_class, op_maker_class> \ extern int TouchOpRegistrar_##op_type(); \
__op_registrar_##op_type##__; \ static int use_op_itself_##op_type##_ __attribute__((unused)) = \
static int __use_op_ptr_##op_type##_without_kernel__ \ TouchOpRegistrar_##op_type##()
__attribute__((unused)) = __op_register_##op_type##_handle__()
#define USE_OP_KERNEL(op_type, DEVICE_TYPE) \ #define USE_OP_KERNEL(op_type, DEVICE_TYPE) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_kernel_##op_type##_##DEVICE_TYPE##__, \ __use_op_kernel_##op_type##_##DEVICE_TYPE##__, \
"USE_OP_KERNEL must be in global namespace"); \ "USE_OP_KERNEL must be in global namespace"); \
extern int __op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__(); \ extern int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE(); \
static int __use_op_ptr_##op_type##_##DEVICE_TYPE##_kernel__ \ static int use_op_kernel_##op_type##_##DEVICE_TYPE##_ \
__attribute__((unused)) = \ __attribute__((unused)) = \
__op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__() TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE##()
// use Operator with only cpu kernel. #ifdef PADDLE_ONLY_CPU
#define USE_OP_CPU(op_type) \ #define USE_OP(op_type) \
USE_OP_ITSELF(op_type); \ USE_OP_ITSELF(op_type); \
USE_OP_KERNEL(op_type, CPU) USE_OP_KERNEL(op_type, CPU)
#ifdef PADDLE_ONLY_CPU
#define USE_OP(op_type) USE_OP_CPU(op_type)
#else #else
#define USE_OP(op_type) \ #define USE_OP(op_type) \
USE_OP_CPU(op_type); \ USE_OP_ITSELF(op_type); \
USE_OP_KERNEL(op_type, CPU); \
USE_OP_KERNEL(op_type, GPU) USE_OP_KERNEL(op_type, GPU)
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册