From db658bec92ae2b0adf6e7a4ed9e870ccafa4a9f0 Mon Sep 17 00:00:00 2001
From: zhangwen31
Date: Mon, 14 Sep 2020 20:11:26 +0000
Subject: [PATCH] [arm][kernel] refactor: elementwise-sub uses template now

---
 lite/kernels/arm/elementwise_compute.cc       | 27 +++++++++----------
 lite/kernels/arm/elementwise_compute.h        |  4 +--
 .../kernels/elementwise_grad_compute_test.cc  |  2 +-
 3 files changed, 16 insertions(+), 17 deletions(-)

diff --git a/lite/kernels/arm/elementwise_compute.cc b/lite/kernels/arm/elementwise_compute.cc
index 9251201734..71fedc9665 100644
--- a/lite/kernels/arm/elementwise_compute.cc
+++ b/lite/kernels/arm/elementwise_compute.cc
@@ -132,24 +132,25 @@ void ElementwiseAddActivationCompute::Run() {
   }
 }
 
-void ElementwiseSubCompute::Run() {
-  auto& param = Param<operators::ElementwiseParam>();
-  const float* x_data = param.X->data<float>();
-  const float* y_data = param.Y->data<float>();
-  float* out_data = param.Out->mutable_data<float>();
+template <typename T, PrecisionType PType>
+void ElementwiseSubCompute<T, PType>::Run() {
+  auto& param = this->template Param<operators::ElementwiseParam>();
+  const T* x_data = param.X->template data<T>();
+  const T* y_data = param.Y->template data<T>();
+  T* out_data = param.Out->template mutable_data<T>();
   int axis = param.axis;
   auto x_dims = param.X->dims();
   auto y_dims = param.Y->dims();
   int pre, n, post;
   if (x_dims.size() < y_dims.size() &&
       is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
-    lite::arm::math::elementwise_sub_broadcast(
+    lite::arm::math::elementwise_sub_broadcast<T>(
         y_data, x_data, out_data, pre, n, post);
   } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
-    lite::arm::math::elementwise_sub_broadcast(
+    lite::arm::math::elementwise_sub_broadcast<T>(
         x_data, y_data, out_data, pre, n, post);
   } else {
-    lite::arm::math::elementwise_sub(
+    lite::arm::math::elementwise_sub<T>(
         x_data, y_data, out_data, x_dims.production());
   }
 }
@@ -419,12 +420,10 @@ REGISTER_LITE_KERNEL(
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
 
-REGISTER_LITE_KERNEL(elementwise_sub,
-                     kARM,
-                     kFloat,
-                     kNCHW,
-                     paddle::lite::kernels::arm::ElementwiseSubCompute,
-                     def)
+using elementwise_sub_float_t =
+    paddle::lite::kernels::arm::ElementwiseSubCompute<float, PRECISION(kFloat)>;
+REGISTER_LITE_KERNEL(
+    elementwise_sub, kARM, kFloat, kNCHW, elementwise_sub_float_t, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
diff --git a/lite/kernels/arm/elementwise_compute.h b/lite/kernels/arm/elementwise_compute.h
index 025ce7cc7f..46efac084e 100644
--- a/lite/kernels/arm/elementwise_compute.h
+++ b/lite/kernels/arm/elementwise_compute.h
@@ -38,8 +38,8 @@ class ElementwiseAddActivationCompute
   virtual ~ElementwiseAddActivationCompute() = default;
 };
 
-class ElementwiseSubCompute
-    : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+template <typename T, PrecisionType PType>
+class ElementwiseSubCompute : public KernelLite<TARGET(kARM), PType> {
  public:
   void Run() override;
 
diff --git a/lite/tests/kernels/elementwise_grad_compute_test.cc b/lite/tests/kernels/elementwise_grad_compute_test.cc
index baf7b16a94..46964485cd 100644
--- a/lite/tests/kernels/elementwise_grad_compute_test.cc
+++ b/lite/tests/kernels/elementwise_grad_compute_test.cc
@@ -27,7 +27,7 @@ using param_t = operators::ElementwiseParam;
 using grad_param_t = operators::ElementwiseGradParam;
 using kernel_add_t = ElementwiseAddCompute;
 using grad_kernel_add_t = ElementwiseAddGradCompute;
-using kernel_sub_t = ElementwiseSubCompute;
+using kernel_sub_t = ElementwiseSubCompute<float, PRECISION(kFloat)>;
 using grad_kernel_sub_t = ElementwiseSubGradCompute;
 
 void elementwise_common(grad_param_t& param,  // NOLINT
-- 
GitLab