diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 7aa59f0b630d5a99fa15c00c9e32a22dd59b9a70..48f77a6784b1c993a369ffe4f7544c7efe1b7de8 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -311,7 +311,7 @@ class OpRegisterHelper { /** * Macro to Register OperatorKernel. */ -#define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, KernelType) \ +#define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, ...) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ __reg_op_kernel_##type##_##DEVICE_TYPE##__, \ "REGISTER_OP_KERNEL must be in global namespace"); \ @@ -320,17 +320,19 @@ class OpRegisterHelper { ::paddle::framework::OperatorWithKernel::OpKernelKey key; \ key.place_ = PlaceType(); \ ::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \ - .reset(new KernelType()); \ + .reset(new __VA_ARGS__()); \ } \ }; \ static __op_kernel_register__##type##__ __reg_kernel_##type##__; \ int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0; } -#define REGISTER_OP_GPU_KERNEL(type, KernelType) \ - REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, KernelType) +// (type, KernelType) +#define REGISTER_OP_GPU_KERNEL(type, ...) \ + REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__) -#define REGISTER_OP_CPU_KERNEL(type, KernelType) \ - REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, KernelType) +// (type, KernelType) +#define REGISTER_OP_CPU_KERNEL(type, ...) \ + REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) /** * Macro to mark what Operator and Kernel we will use and tell the compiler to diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 6fa110f94ccc0c0a2f2e61316aa5dc271631a11c..8e55d0111f39b2f632cf5a49c2ad3f210683652c 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -102,6 +102,7 @@ class OpWithKernelTest : public OperatorWithKernel { const std::vector& outputs) const override {} }; +template class CPUKernelTest : public OpKernel { public: void Compute(const KernelContext& ctx) const { @@ -171,7 +172,8 @@ class CPUKernalMultiInputsTest : public OpKernel { REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest, paddle::framework::OpKernelTestProtoAndCheckerMaker); -REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest); +REGISTER_OP_CPU_KERNEL(op_with_kernel, + paddle::framework::CPUKernelTest); // test with single input TEST(OpKernel, all) {