提交 78c3e1de 编写于 作者: F fengjiayi

refactor

上级 9a52056d
...@@ -307,22 +307,37 @@ class OpRegistry { ...@@ -307,22 +307,37 @@ class OpRegistry {
} }
}; };
class Registrar {};
template <typename OpType, typename ProtoMakerType> template <typename OpType, typename ProtoMakerType>
class OpRegisterHelper { class OpRegistrar : public Registrar {
public: public:
explicit OpRegisterHelper(const char* op_type) { explicit OpRegistrar(const char* op_type) {
OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type); OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type);
} }
}; };
template <typename GradOpType> template <typename GradOpType>
class GradOpRegisterHelper { class GradOpRegistrar : public Registrar {
public: public:
GradOpRegisterHelper(const char* op_type, const char* grad_op_type) { GradOpRegistrar(const char* op_type, const char* grad_op_type) {
OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type); OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
} }
}; };
template <typename PlaceType, typename KernelType>
class OpKernelRegistrar : public Registrar {
public:
explicit OpKernelRegistrar(const char* op_type) {
::paddle::framework::OperatorWithKernel::OpKernelKey key;
key.place_ = PlaceType();
::paddle::framework::OperatorWithKernel::AllOpKernels()[op_type][key].reset(
new KernelType);
}
};
int TouchRegistrar(const Registrar& registrar) { return 0; }
/** /**
* check if MACRO is used in GLOBAL NAMESPACE. * check if MACRO is used in GLOBAL NAMESPACE.
*/ */
...@@ -335,72 +350,58 @@ class GradOpRegisterHelper { ...@@ -335,72 +350,58 @@ class GradOpRegisterHelper {
/** /**
* Macro to Register Operator. * Macro to Register Operator.
*/ */
#define REGISTER_OP(__op_type, __op_class, __op_maker_class) \ #define REGISTER_OP(op_type, op_class, op_maker_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE(__reg_op__##__op_type, \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
"REGISTER_OP must be in global namespace"); \ __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
static ::paddle::framework::OpRegisterHelper<__op_class, __op_maker_class> \ static ::paddle::framework::OpRegistrar<op_class, op_maker_class> \
__op_register_##__op_type##__(#__op_type); \ __op_registrar_##op_type##__(#op_type);
int __op_register_##__op_type##_handle__() { return 0; }
/** /**
* Macro to Register Gradient Operator. * Macro to Register Gradient Operator.
*/ */
#define REGISTER_GRADIENT_OP(__op_type, __grad_op_type, __grad_op_class) \ #define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##__op_type##__grad_op_type, \ __reg_gradient_op__##op_type##_##grad_op_type, \
"REGISTER_GRADIENT_OP must be in global namespace"); \ "REGISTER_GRADIENT_OP must be called in global namespace"); \
static ::paddle::framework::GradOpRegisterHelper<__grad_op_class> \ static ::paddle::framework::GradOpRegistrar<grad_op_class> \
__op_gradient_register_##__op_type##__grad_op_type##__(#__op_type, \ __op_gradient_register_##op_type##_##grad_op_type##__(#op_type, \
#__grad_op_type); \ #grad_op_type);
int __op_gradient_register_##__op_type##__grad_op_type##_handle__() { \
return 0; \
}
/** /**
* Macro to Forbid user register Gradient Operator. * Macro to Register OperatorKernel.
*/ */
#define NO_GRADIENT(__op_type) \ #define REGISTER_OP_KERNEL(op_type, DEVICE_TYPE, place_class, kernel_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##__op_type##__op_type##_grad, \ __reg_op_kernel_##op_type##_##DEVICE_TYPE##__, \
"NO_GRADIENT must be in global namespace") "REGISTER_OP_KERNEL must be called in global namespace"); \
static ::paddle::framework::OpKernelRegistrar<place_class, kernel_class> \
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type);
/** /**
* Macro to Register OperatorKernel. * Macro to Forbid user register Gradient Operator.
*/ */
#define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, ...) \ #define NO_GRADIENT(op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##type##_##DEVICE_TYPE##__, \ __reg_gradient_op__##op_type##_##op_type##_grad, \
"REGISTER_OP_KERNEL must be in global namespace"); \ "NO_GRADIENT must be called in global namespace")
struct __op_kernel_register__##type##__##DEVICE_TYPE##__ { \
__op_kernel_register__##type##__##DEVICE_TYPE##__() { \ #define REGISTER_OP_GPU_KERNEL(op_type, kernel_class) \
::paddle::framework::OperatorWithKernel::OpKernelKey key; \ REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, kernel_class)
key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \ #define REGISTER_OP_CPU_KERNEL(op_type, kernel_class) \
.reset(new __VA_ARGS__()); \ REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, kernel_class)
} \
}; \
static __op_kernel_register__##type##__##DEVICE_TYPE##__ \
__reg_kernel_##type##__##DEVICE_TYPE##__; \
int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0; }
// (type, KernelType)
#define REGISTER_OP_GPU_KERNEL(type, ...) \
REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)
// (type, KernelType)
#define REGISTER_OP_CPU_KERNEL(type, ...) \
REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
/** /**
* Macro to mark what Operator and Kernel we will use and tell the compiler to * Macro to mark what Operator and Kernel we will use and tell the compiler to
* link them into target. * link them into target.
*/ */
#define USE_OP_WITHOUT_KERNEL(op_type) \ #define USE_OP_ITSELF(op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_without_kernel_##op_type, \ __use_op_itself_##op_type, \
"USE_OP_WITHOUT_KERNEL must be in global namespace"); \ "USE_OP_ITSELF must be called in global namespace"); \
extern int __op_register_##op_type##_handle__(); \ extern ::paddle::framework::OpRegistrar<op_class, op_maker_class> \
static int __use_op_ptr_##op_type##_without_kernel__ \ __op_registrar_##op_type##__; \
static int __use_op_ptr_##op_type##_without_kernel__ \
__attribute__((unused)) = __op_register_##op_type##_handle__() __attribute__((unused)) = __op_register_##op_type##_handle__()
#define USE_OP_KERNEL(op_type, DEVICE_TYPE) \ #define USE_OP_KERNEL(op_type, DEVICE_TYPE) \
...@@ -413,8 +414,8 @@ class GradOpRegisterHelper { ...@@ -413,8 +414,8 @@ class GradOpRegisterHelper {
__op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__() __op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__()
// use Operator with only cpu kernel. // use Operator with only cpu kernel.
#define USE_OP_CPU(op_type) \ #define USE_OP_CPU(op_type) \
USE_OP_WITHOUT_KERNEL(op_type); \ USE_OP_ITSELF(op_type); \
USE_OP_KERNEL(op_type, CPU) USE_OP_KERNEL(op_type, CPU)
#ifdef PADDLE_ONLY_CPU #ifdef PADDLE_ONLY_CPU
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册