diff --git a/lite/kernels/arm/elementwise_compute.cc b/lite/kernels/arm/elementwise_compute.cc index 2214bf8ce854c1ab960067989be7caa57dcdb2e1..44ce1c8652d94b7d686833f6f8c5ae76505de4f6 100644 --- a/lite/kernels/arm/elementwise_compute.cc +++ b/lite/kernels/arm/elementwise_compute.cc @@ -71,24 +71,25 @@ inline bool is_broadcast(const DDim& x_dims, return true; } -void ElementwiseAddCompute::Run() { - auto& param = Param(); - const float* x_data = param.X->data(); - const float* y_data = param.Y->data(); - float* out_data = param.Out->mutable_data(); +template +void ElementwiseAddCompute::Run() { + auto& param = this->template Param(); + const T* x_data = param.X->template data(); + const T* y_data = param.Y->template data(); + T* out_data = param.Out->template mutable_data(); int axis = param.axis; auto x_dims = param.X->dims(); auto y_dims = param.Y->dims(); int pre, n, post; if (x_dims.size() < y_dims.size() && is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) { - lite::arm::math::elementwise_add_broadcast( + lite::arm::math::elementwise_add_broadcast( y_data, x_data, out_data, pre, n, post); } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { - lite::arm::math::elementwise_add_broadcast( + lite::arm::math::elementwise_add_broadcast( x_data, y_data, out_data, pre, n, post); } else { - lite::arm::math::elementwise_add( + lite::arm::math::elementwise_add( x_data, y_data, out_data, x_dims.production()); } } @@ -377,12 +378,10 @@ void ElementwiseModCompute::Run() { } // namespace lite } // namespace paddle -REGISTER_LITE_KERNEL(elementwise_add, - kARM, - kFloat, - kNCHW, - paddle::lite::kernels::arm::ElementwiseAddCompute, - def) +using elementwise_add_float_t = + paddle::lite::kernels::arm::ElementwiseAddCompute; +REGISTER_LITE_KERNEL( + elementwise_add, kARM, kFloat, kNCHW, elementwise_add_float_t, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) diff --git a/lite/kernels/arm/elementwise_compute.h b/lite/kernels/arm/elementwise_compute.h index 89d9898648d25fec98568f2456fe96903da0a69d..025ce7cc7f677d0c20b3ff72a0604409810d0deb 100644 --- a/lite/kernels/arm/elementwise_compute.h +++ b/lite/kernels/arm/elementwise_compute.h @@ -22,8 +22,8 @@ namespace lite { namespace kernels { namespace arm { -class ElementwiseAddCompute - : public KernelLite { +template +class ElementwiseAddCompute : public KernelLite { public: void Run() override; diff --git a/lite/kernels/arm/elementwise_compute_test.cc b/lite/kernels/arm/elementwise_compute_test.cc index 79262fb4ef75283eba12efa0a4ad8dc048681338..bf454f10a874f6ad1d65887f5199e75f9afce284 100644 --- a/lite/kernels/arm/elementwise_compute_test.cc +++ b/lite/kernels/arm/elementwise_compute_test.cc @@ -33,7 +33,7 @@ TEST(elementwise_add_arm, retrive_op) { } TEST(elementwise_add_arm, init) { - ElementwiseAddCompute elementwise_add; + ElementwiseAddCompute elementwise_add; ASSERT_EQ(elementwise_add.precision(), PRECISION(kFloat)); ASSERT_EQ(elementwise_add.target(), TARGET(kARM)); } @@ -255,7 +255,7 @@ template void elementwise_imod_compute_ref( const operators::ElementwiseParam& param, const std::string act_type); TEST(elementwise_add, compute) { - ElementwiseAddCompute elementwise_add; + ElementwiseAddCompute elementwise_add; operators::ElementwiseParam param; lite::Tensor x, y, output, output_ref; diff --git a/lite/tests/kernels/elementwise_grad_compute_test.cc b/lite/tests/kernels/elementwise_grad_compute_test.cc index 04e74e49099f13a7e5920b306f8d2e26650a2574..baf7b16a94d4ba539bcd17d9c1b670001956a889 100644 --- a/lite/tests/kernels/elementwise_grad_compute_test.cc +++ b/lite/tests/kernels/elementwise_grad_compute_test.cc @@ -25,7 +25,7 @@ namespace arm { using param_t = operators::ElementwiseParam; using grad_param_t = operators::ElementwiseGradParam; -using kernel_add_t = ElementwiseAddCompute; +using kernel_add_t = ElementwiseAddCompute; using grad_kernel_add_t = ElementwiseAddGradCompute; using kernel_sub_t = ElementwiseSubCompute; using grad_kernel_sub_t = ElementwiseSubGradCompute;