From 8018225c89330734f3051814324bef101083e921 Mon Sep 17 00:00:00 2001 From: zhangwen31 Date: Mon, 14 Sep 2020 05:19:57 +0000 Subject: [PATCH] [arm][kernel]refactor: ElementwiseAddCompute in arm kernel is template now --- lite/kernels/arm/elementwise_compute.cc | 27 +++++++++---------- lite/kernels/arm/elementwise_compute.h | 4 +-- lite/kernels/arm/elementwise_compute_test.cc | 4 +-- .../kernels/elementwise_grad_compute_test.cc | 2 +- 4 files changed, 18 insertions(+), 19 deletions(-) diff --git a/lite/kernels/arm/elementwise_compute.cc b/lite/kernels/arm/elementwise_compute.cc index 2214bf8ce8..44ce1c8652 100644 --- a/lite/kernels/arm/elementwise_compute.cc +++ b/lite/kernels/arm/elementwise_compute.cc @@ -71,24 +71,25 @@ inline bool is_broadcast(const DDim& x_dims, return true; } -void ElementwiseAddCompute::Run() { - auto& param = Param(); - const float* x_data = param.X->data(); - const float* y_data = param.Y->data(); - float* out_data = param.Out->mutable_data(); +template +void ElementwiseAddCompute::Run() { + auto& param = this->template Param(); + const T* x_data = param.X->template data(); + const T* y_data = param.Y->template data(); + T* out_data = param.Out->template mutable_data(); int axis = param.axis; auto x_dims = param.X->dims(); auto y_dims = param.Y->dims(); int pre, n, post; if (x_dims.size() < y_dims.size() && is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) { - lite::arm::math::elementwise_add_broadcast( + lite::arm::math::elementwise_add_broadcast( y_data, x_data, out_data, pre, n, post); } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { - lite::arm::math::elementwise_add_broadcast( + lite::arm::math::elementwise_add_broadcast( x_data, y_data, out_data, pre, n, post); } else { - lite::arm::math::elementwise_add( + lite::arm::math::elementwise_add( x_data, y_data, out_data, x_dims.production()); } } @@ -377,12 +378,10 @@ void ElementwiseModCompute::Run() { } // namespace lite } // namespace paddle -REGISTER_LITE_KERNEL(elementwise_add, - kARM, - kFloat, - kNCHW, - paddle::lite::kernels::arm::ElementwiseAddCompute, - def) +using elementwise_add_float_t = + paddle::lite::kernels::arm::ElementwiseAddCompute; +REGISTER_LITE_KERNEL( + elementwise_add, kARM, kFloat, kNCHW, elementwise_add_float_t, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) diff --git a/lite/kernels/arm/elementwise_compute.h b/lite/kernels/arm/elementwise_compute.h index 89d9898648..025ce7cc7f 100644 --- a/lite/kernels/arm/elementwise_compute.h +++ b/lite/kernels/arm/elementwise_compute.h @@ -22,8 +22,8 @@ namespace lite { namespace kernels { namespace arm { -class ElementwiseAddCompute - : public KernelLite { +template +class ElementwiseAddCompute : public KernelLite { public: void Run() override; diff --git a/lite/kernels/arm/elementwise_compute_test.cc b/lite/kernels/arm/elementwise_compute_test.cc index 79262fb4ef..bf454f10a8 100644 --- a/lite/kernels/arm/elementwise_compute_test.cc +++ b/lite/kernels/arm/elementwise_compute_test.cc @@ -33,7 +33,7 @@ TEST(elementwise_add_arm, retrive_op) { } TEST(elementwise_add_arm, init) { - ElementwiseAddCompute elementwise_add; + ElementwiseAddCompute elementwise_add; ASSERT_EQ(elementwise_add.precision(), PRECISION(kFloat)); ASSERT_EQ(elementwise_add.target(), TARGET(kARM)); } @@ -255,7 +255,7 @@ template void elementwise_imod_compute_ref( const operators::ElementwiseParam& param, const std::string act_type); TEST(elementwise_add, compute) { - ElementwiseAddCompute elementwise_add; + ElementwiseAddCompute elementwise_add; operators::ElementwiseParam param; lite::Tensor x, y, output, output_ref; diff --git a/lite/tests/kernels/elementwise_grad_compute_test.cc b/lite/tests/kernels/elementwise_grad_compute_test.cc index 04e74e4909..baf7b16a94 100644 --- a/lite/tests/kernels/elementwise_grad_compute_test.cc +++ b/lite/tests/kernels/elementwise_grad_compute_test.cc @@ -25,7 +25,7 @@ namespace arm { using param_t = operators::ElementwiseParam; using grad_param_t = operators::ElementwiseGradParam; -using kernel_add_t = ElementwiseAddCompute; +using kernel_add_t = ElementwiseAddCompute; using grad_kernel_add_t = ElementwiseAddGradCompute; using kernel_sub_t = ElementwiseSubCompute; using grad_kernel_sub_t = ElementwiseSubGradCompute; -- GitLab