diff --git a/lite/kernels/arm/elementwise_compute.cc b/lite/kernels/arm/elementwise_compute.cc index 92512017348a556fdead7ef8aa02a65023844c11..71fedc966556aadc40848fff54bae7884cd0e9f8 100644 --- a/lite/kernels/arm/elementwise_compute.cc +++ b/lite/kernels/arm/elementwise_compute.cc @@ -132,24 +132,25 @@ void ElementwiseAddActivationCompute::Run() { } } -void ElementwiseSubCompute::Run() { - auto& param = Param(); - const float* x_data = param.X->data(); - const float* y_data = param.Y->data(); - float* out_data = param.Out->mutable_data(); +template +void ElementwiseSubCompute::Run() { + auto& param = this->template Param(); + const T* x_data = param.X->template data(); + const T* y_data = param.Y->template data(); + T* out_data = param.Out->template mutable_data(); int axis = param.axis; auto x_dims = param.X->dims(); auto y_dims = param.Y->dims(); int pre, n, post; if (x_dims.size() < y_dims.size() && is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) { - lite::arm::math::elementwise_sub_broadcast( + lite::arm::math::elementwise_sub_broadcast( y_data, x_data, out_data, pre, n, post); } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { - lite::arm::math::elementwise_sub_broadcast( + lite::arm::math::elementwise_sub_broadcast( x_data, y_data, out_data, pre, n, post); } else { - lite::arm::math::elementwise_sub( + lite::arm::math::elementwise_sub( x_data, y_data, out_data, x_dims.production()); } } @@ -419,12 +420,10 @@ REGISTER_LITE_KERNEL( .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); -REGISTER_LITE_KERNEL(elementwise_sub, - kARM, - kFloat, - kNCHW, - paddle::lite::kernels::arm::ElementwiseSubCompute, - def) +using elementwise_sub_float_t = + paddle::lite::kernels::arm::ElementwiseSubCompute; +REGISTER_LITE_KERNEL( + elementwise_sub, kARM, kFloat, kNCHW, elementwise_sub_float_t, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) diff --git a/lite/kernels/arm/elementwise_compute.h b/lite/kernels/arm/elementwise_compute.h index 025ce7cc7f677d0c20b3ff72a0604409810d0deb..46efac084e52e52431604da372c9061f2b07acb4 100644 --- a/lite/kernels/arm/elementwise_compute.h +++ b/lite/kernels/arm/elementwise_compute.h @@ -38,8 +38,8 @@ class ElementwiseAddActivationCompute virtual ~ElementwiseAddActivationCompute() = default; }; -class ElementwiseSubCompute - : public KernelLite { +template +class ElementwiseSubCompute : public KernelLite { public: void Run() override; diff --git a/lite/tests/kernels/elementwise_grad_compute_test.cc b/lite/tests/kernels/elementwise_grad_compute_test.cc index baf7b16a94d4ba539bcd17d9c1b670001956a889..46964485cd1c80f4f79f401d6d017af8b0b66ebe 100644 --- a/lite/tests/kernels/elementwise_grad_compute_test.cc +++ b/lite/tests/kernels/elementwise_grad_compute_test.cc @@ -27,7 +27,7 @@ using param_t = operators::ElementwiseParam; using grad_param_t = operators::ElementwiseGradParam; using kernel_add_t = ElementwiseAddCompute; using grad_kernel_add_t = ElementwiseAddGradCompute; -using kernel_sub_t = ElementwiseSubCompute; +using kernel_sub_t = ElementwiseSubCompute; using grad_kernel_sub_t = ElementwiseSubGradCompute; void elementwise_common(grad_param_t& param, // NOLINT