From 051676a7e483b59583d92cd49aff6bdace916dc4 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Tue, 18 Jul 2017 12:57:38 +0800 Subject: [PATCH] support multiple template parameter in KernelType for REGISTER_OP_XPU_KERNEL (#2932) --- paddle/framework/op_registry.h | 14 ++++++++------ paddle/framework/operator_test.cc | 4 +++- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 7aa59f0b630..48f77a6784b 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -311,7 +311,7 @@ class OpRegisterHelper { /** * Macro to Register OperatorKernel. */ -#define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, KernelType) \ +#define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, ...) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ __reg_op_kernel_##type##_##DEVICE_TYPE##__, \ "REGISTER_OP_KERNEL must be in global namespace"); \ @@ -320,17 +320,19 @@ class OpRegisterHelper { ::paddle::framework::OperatorWithKernel::OpKernelKey key; \ key.place_ = PlaceType(); \ ::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \ - .reset(new KernelType()); \ + .reset(new __VA_ARGS__()); \ } \ }; \ static __op_kernel_register__##type##__ __reg_kernel_##type##__; \ int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0; } -#define REGISTER_OP_GPU_KERNEL(type, KernelType) \ - REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, KernelType) +// (type, KernelType) +#define REGISTER_OP_GPU_KERNEL(type, ...) \ + REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__) -#define REGISTER_OP_CPU_KERNEL(type, KernelType) \ - REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, KernelType) +// (type, KernelType) +#define REGISTER_OP_CPU_KERNEL(type, ...) \ + REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) /** * Macro to mark what Operator and Kernel we will use and tell the compiler to diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 6fa110f94cc..8e55d0111f3 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -102,6 +102,7 @@ class OpWithKernelTest : public OperatorWithKernel { const std::vector& outputs) const override {} }; +template class CPUKernelTest : public OpKernel { public: void Compute(const KernelContext& ctx) const { @@ -171,7 +172,8 @@ class CPUKernalMultiInputsTest : public OpKernel { REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest, paddle::framework::OpKernelTestProtoAndCheckerMaker); -REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest); +REGISTER_OP_CPU_KERNEL(op_with_kernel, + paddle::framework::CPUKernelTest); // test with single input TEST(OpKernel, all) { -- GitLab