提交 f1913d46 编写于 作者: Y Yu Yang

Change registry, test register double kernel

上级 2c05465d
......@@ -100,14 +100,38 @@ class OpRegistrar : public Registrar {
}
};
template <typename PlaceType, typename KernelType>
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor;
template <typename PlaceType, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelType...> {
using KT = typename std::tuple_element<I, std::tuple<KernelType...>>::type;
void operator()(const char* op_type) const {
using T = typename KT::ELEMENT_TYPE;
OperatorWithKernel::OpKernelKey key(ToDataType(std::type_index(typeid(T))),
PlaceType());
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KT);
constexpr auto size = std::tuple_size<std::tuple<KernelType...>>::value;
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelType...>
func;
func(op_type);
}
};
template <typename PlaceType, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
void operator()(const char* op_type) const {}
};
// User can register many kernel in one place. The data type could be different.
template <typename PlaceType, typename... KernelType>
class OpKernelRegistrar : public Registrar {
public:
explicit OpKernelRegistrar(const char* op_type) {
using T = typename KernelType::ELEMENT_TYPE;
OperatorWithKernel::OpKernelKey key(ToDataType(std::type_index(typeid(T))),
PlaceType());
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType);
OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func;
func(op_type);
}
};
......
......@@ -36,7 +36,9 @@ REGISTER_OP(elementwise_mul, ops::ElementwiseOp, ops::ElementwiseMulOpMaker,
elementwise_mul_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_mul,
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, float>);
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, float>,
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
elementwise_mul_grad,
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, float>);
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, float>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, double>);
......@@ -19,7 +19,9 @@ namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
elementwise_mul,
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, float>);
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, float>,
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
elementwise_mul_grad,
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, float>);
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, float>,
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, double>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册