From f1913d46972b11d852f42072eedd5485c721d2c5 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 27 Sep 2017 17:28:12 -0700 Subject: [PATCH] Change registry, test register double kernel --- paddle/framework/op_registry.h | 34 ++++++++++++++++++++++---- paddle/operators/elementwise_mul_op.cc | 6 +++-- paddle/operators/elementwise_mul_op.cu | 6 +++-- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 0db67e4c678..804f901dfa2 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -100,14 +100,38 @@ class OpRegistrar : public Registrar { } }; -template +template +struct OpKernelRegistrarFunctor; + +template +struct OpKernelRegistrarFunctor { + using KT = typename std::tuple_element>::type; + + void operator()(const char* op_type) const { + using T = typename KT::ELEMENT_TYPE; + OperatorWithKernel::OpKernelKey key(ToDataType(std::type_index(typeid(T))), + PlaceType()); + OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KT); + + constexpr auto size = std::tuple_size>::value; + OpKernelRegistrarFunctor + func; + func(op_type); + } +}; + +template +struct OpKernelRegistrarFunctor { + void operator()(const char* op_type) const {} +}; + +// User can register many kernel in one place. The data type could be different. +template class OpKernelRegistrar : public Registrar { public: explicit OpKernelRegistrar(const char* op_type) { - using T = typename KernelType::ELEMENT_TYPE; - OperatorWithKernel::OpKernelKey key(ToDataType(std::type_index(typeid(T))), - PlaceType()); - OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType); + OpKernelRegistrarFunctor func; + func(op_type); } }; diff --git a/paddle/operators/elementwise_mul_op.cc b/paddle/operators/elementwise_mul_op.cc index bda5dfe03e9..da7765aa6a7 100644 --- a/paddle/operators/elementwise_mul_op.cc +++ b/paddle/operators/elementwise_mul_op.cc @@ -36,7 +36,9 @@ REGISTER_OP(elementwise_mul, ops::ElementwiseOp, ops::ElementwiseMulOpMaker, elementwise_mul_grad, ops::ElementwiseOpGrad); REGISTER_OP_CPU_KERNEL( elementwise_mul, - ops::ElementwiseMulKernel); + ops::ElementwiseMulKernel, + ops::ElementwiseMulKernel); REGISTER_OP_CPU_KERNEL( elementwise_mul_grad, - ops::ElementwiseMulGradKernel); + ops::ElementwiseMulGradKernel, + ops::ElementwiseMulGradKernel); diff --git a/paddle/operators/elementwise_mul_op.cu b/paddle/operators/elementwise_mul_op.cu index da08a75596c..056f081d3e6 100644 --- a/paddle/operators/elementwise_mul_op.cu +++ b/paddle/operators/elementwise_mul_op.cu @@ -19,7 +19,9 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( elementwise_mul, - ops::ElementwiseMulKernel); + ops::ElementwiseMulKernel, + ops::ElementwiseMulKernel); REGISTER_OP_GPU_KERNEL( elementwise_mul_grad, - ops::ElementwiseMulGradKernel); + ops::ElementwiseMulGradKernel, + ops::ElementwiseMulGradKernel); -- GitLab