提交 051676a7 编写于 作者: Q Qiao Longfei 提交者: GitHub

support multiple template parameter in KernelType for REGISTER_OP_XPU_KERNEL (#2932)

上级 861b66d4
...@@ -311,7 +311,7 @@ class OpRegisterHelper { ...@@ -311,7 +311,7 @@ class OpRegisterHelper {
/** /**
* Macro to Register OperatorKernel. * Macro to Register OperatorKernel.
*/ */
#define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, KernelType) \ #define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##type##_##DEVICE_TYPE##__, \ __reg_op_kernel_##type##_##DEVICE_TYPE##__, \
"REGISTER_OP_KERNEL must be in global namespace"); \ "REGISTER_OP_KERNEL must be in global namespace"); \
...@@ -320,17 +320,19 @@ class OpRegisterHelper { ...@@ -320,17 +320,19 @@ class OpRegisterHelper {
::paddle::framework::OperatorWithKernel::OpKernelKey key; \ ::paddle::framework::OperatorWithKernel::OpKernelKey key; \
key.place_ = PlaceType(); \ key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \ ::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
.reset(new KernelType()); \ .reset(new __VA_ARGS__()); \
} \ } \
}; \ }; \
static __op_kernel_register__##type##__ __reg_kernel_##type##__; \ static __op_kernel_register__##type##__ __reg_kernel_##type##__; \
int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0; } int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0; }
#define REGISTER_OP_GPU_KERNEL(type, KernelType) \ // (type, KernelType)
REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, KernelType) #define REGISTER_OP_GPU_KERNEL(type, ...) \
REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL(type, KernelType) \ // (type, KernelType)
REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, KernelType) #define REGISTER_OP_CPU_KERNEL(type, ...) \
REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
/** /**
* Macro to mark what Operator and Kernel we will use and tell the compiler to * Macro to mark what Operator and Kernel we will use and tell the compiler to
......
...@@ -102,6 +102,7 @@ class OpWithKernelTest : public OperatorWithKernel { ...@@ -102,6 +102,7 @@ class OpWithKernelTest : public OperatorWithKernel {
const std::vector<Tensor*>& outputs) const override {} const std::vector<Tensor*>& outputs) const override {}
}; };
template <typename T1, typename T2>
class CPUKernelTest : public OpKernel { class CPUKernelTest : public OpKernel {
public: public:
void Compute(const KernelContext& ctx) const { void Compute(const KernelContext& ctx) const {
...@@ -171,7 +172,8 @@ class CPUKernalMultiInputsTest : public OpKernel { ...@@ -171,7 +172,8 @@ class CPUKernalMultiInputsTest : public OpKernel {
REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest, REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
paddle::framework::OpKernelTestProtoAndCheckerMaker); paddle::framework::OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest); REGISTER_OP_CPU_KERNEL(op_with_kernel,
paddle::framework::CPUKernelTest<float, float>);
// test with single input // test with single input
TEST(OpKernel, all) { TEST(OpKernel, all) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册