From e14a4541dd8f85a49ee3c42429f0f663864f1e0a Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 9 Aug 2017 13:16:08 -0700 Subject: [PATCH] Refactor registry macro --- paddle/framework/op_registry.h | 102 ++++++++++++++++++--------------- 1 file changed, 56 insertions(+), 46 deletions(-) diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index b3663f8bf..0ac3ffda2 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -307,7 +307,10 @@ class OpRegistry { } }; -class Registrar {}; +class Registrar { + public: + void Touch() {} +}; template class OpRegistrar : public Registrar { @@ -336,8 +339,6 @@ class OpKernelRegistrar : public Registrar { } }; -int TouchRegistrar(const Registrar& registrar) { return 0; } - /** * check if MACRO is used in GLOBAL NAMESPACE. */ @@ -354,28 +355,40 @@ int TouchRegistrar(const Registrar& registrar) { return 0; } STATIC_ASSERT_GLOBAL_NAMESPACE( \ __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \ static ::paddle::framework::OpRegistrar \ - __op_registrar_##op_type##__(#op_type); + __op_registrar_##op_type##__(#op_type); \ + int TouchOpRegistrar_##op_type() { \ + __op_registrar_##op_type##__.Touch(); \ + return 0; \ + } /** * Macro to Register Gradient Operator. */ -#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_gradient_op__##op_type##_##grad_op_type, \ - "REGISTER_GRADIENT_OP must be called in global namespace"); \ - static ::paddle::framework::GradOpRegistrar \ - __op_gradient_register_##op_type##_##grad_op_type##__(#op_type, \ - #grad_op_type); +#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_gradient_op__##op_type##_##grad_op_type, \ + "REGISTER_GRADIENT_OP must be called in global namespace"); \ + static ::paddle::framework::GradOpRegistrar \ + __op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \ + #grad_op_type); \ + int TouchOpGradientRegister_##op_type() { \ + __op_gradient_registrar_##op_type##_##grad_op_type##__.Touch(); \ + return 0; \ + } /** * Macro to Register OperatorKernel. */ -#define REGISTER_OP_KERNEL(op_type, DEVICE_TYPE, place_class, kernel_class) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_op_kernel_##op_type##_##DEVICE_TYPE##__, \ - "REGISTER_OP_KERNEL must be called in global namespace"); \ - static ::paddle::framework::OpKernelRegistrar \ - __op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type); +#define REGISTER_OP_KERNEL(op_type, DEVICE_TYPE, place_class, ...) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_op_kernel_##op_type##_##DEVICE_TYPE##__, \ + "REGISTER_OP_KERNEL must be called in global namespace"); \ + static ::paddle::framework::OpKernelRegistrar \ + __op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type); \ + int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() { \ + __op_kernel_registrar_##op_type##_##DEVICE_TYPE##__.Touch(); \ + return 0; \ + } /** * Macro to Forbid user register Gradient Operator. @@ -385,44 +398,41 @@ int TouchRegistrar(const Registrar& registrar) { return 0; } __reg_gradient_op__##op_type##_##op_type##_grad, \ "NO_GRADIENT must be called in global namespace") -#define REGISTER_OP_GPU_KERNEL(op_type, kernel_class) \ - REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, kernel_class) +#define REGISTER_OP_GPU_KERNEL(op_type, ...) \ + REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__) -#define REGISTER_OP_CPU_KERNEL(op_type, kernel_class) \ - REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, kernel_class) +#define REGISTER_OP_CPU_KERNEL(op_type, ...) \ + REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) /** * Macro to mark what Operator and Kernel we will use and tell the compiler to * link them into target. */ -#define USE_OP_ITSELF(op_type) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __use_op_itself_##op_type, \ - "USE_OP_ITSELF must be called in global namespace"); \ - extern ::paddle::framework::OpRegistrar \ - __op_registrar_##op_type##__; \ - static int __use_op_ptr_##op_type##_without_kernel__ \ - __attribute__((unused)) = __op_register_##op_type##_handle__() - -#define USE_OP_KERNEL(op_type, DEVICE_TYPE) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __use_op_kernel_##op_type##_##DEVICE_TYPE##__, \ - "USE_OP_KERNEL must be in global namespace"); \ - extern int __op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__(); \ - static int __use_op_ptr_##op_type##_##DEVICE_TYPE##_kernel__ \ - __attribute__((unused)) = \ - __op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__() - -// use Operator with only cpu kernel. -#define USE_OP_CPU(op_type) \ - USE_OP_ITSELF(op_type); \ - USE_OP_KERNEL(op_type, CPU) +#define USE_OP_ITSELF(op_type) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __use_op_itself_##op_type, \ + "USE_OP_ITSELF must be called in global namespace"); \ + extern int TouchOpRegistrar_##op_type(); \ + static int use_op_itself_##op_type##_ __attribute__((unused)) = \ + TouchOpRegistrar_##op_type##() + +#define USE_OP_KERNEL(op_type, DEVICE_TYPE) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __use_op_kernel_##op_type##_##DEVICE_TYPE##__, \ + "USE_OP_KERNEL must be in global namespace"); \ + extern int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE(); \ + static int use_op_kernel_##op_type##_##DEVICE_TYPE##_ \ + __attribute__((unused)) = \ + TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE##() #ifdef PADDLE_ONLY_CPU -#define USE_OP(op_type) USE_OP_CPU(op_type) +#define USE_OP(op_type) \ + USE_OP_ITSELF(op_type); \ + USE_OP_KERNEL(op_type, CPU) #else -#define USE_OP(op_type) \ - USE_OP_CPU(op_type); \ +#define USE_OP(op_type) \ + USE_OP_ITSELF(op_type); \ + USE_OP_KERNEL(op_type, CPU); \ USE_OP_KERNEL(op_type, GPU) #endif -- GitLab