diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 05f51d885c3d2b09cf09d73da221f6f94428e877..aed244d61acd8efb49b9650151d53e754bd3ab0f 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -307,7 +307,10 @@ class OpRegistry { } }; -class Registrar {}; +class Registrar { + public: + void Touch() {} +}; template class OpRegistrar : public Registrar { @@ -351,37 +354,40 @@ class OpKernelRegistrar : public Registrar { #define REGISTER_OP(op_type, op_class, op_maker_class) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \ + static ::paddle::framework::OpRegistrar \ + __op_registrar_##op_type##__(#op_type); \ int TouchOpRegistrar_##op_type() { \ - static ::paddle::framework::OpRegistrar \ - __op_registrar_##op_type##__(#op_type); \ + __op_registrar_##op_type##__.Touch(); \ return 0; \ } /** * Macro to Register Gradient Operator. */ -#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_gradient_op__##op_type##_##grad_op_type, \ - "REGISTER_GRADIENT_OP must be called in global namespace"); \ - int TouchOpGradientRegistrar_##op_type() { \ - static ::paddle::framework::GradOpRegistrar \ - __op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \ - #grad_op_type); \ - return 0; \ +#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_gradient_op__##op_type##_##grad_op_type, \ + "REGISTER_GRADIENT_OP must be called in global namespace"); \ + static ::paddle::framework::GradOpRegistrar \ + __op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \ + #grad_op_type); \ + int TouchOpGradientRegistrar_##op_type() { \ + __op_gradient_registrar_##op_type##_##grad_op_type##__.Touch(); \ + return 0; \ } /** * Macro to Register OperatorKernel. */ -#define REGISTER_OP_KERNEL(op_type, DEVICE_TYPE, place_class, ...) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_op_kernel_##op_type##_##DEVICE_TYPE##__, \ - "REGISTER_OP_KERNEL must be called in global namespace"); \ - int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() { \ - static ::paddle::framework::OpKernelRegistrar \ - __op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type); \ - return 0; \ +#define REGISTER_OP_KERNEL(op_type, DEVICE_TYPE, place_class, ...) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_op_kernel_##op_type##_##DEVICE_TYPE##__, \ + "REGISTER_OP_KERNEL must be called in global namespace"); \ + static ::paddle::framework::OpKernelRegistrar \ + __op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type); \ + int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() { \ + __op_kernel_registrar_##op_type##_##DEVICE_TYPE##__.Touch(); \ + return 0; \ } /** @@ -436,6 +442,8 @@ class OpKernelRegistrar : public Registrar { __attribute__((unused)) = \ TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() +// TODO(jiayi): The following macros seems ugly, do we have better method? + #ifdef PADDLE_ONLY_CPU #define USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU) #else