From 12735ae4b16c89da1dc89a74f954ab81633a88d8 Mon Sep 17 00:00:00 2001 From: HappyAngel Date: Fri, 18 Sep 2020 15:35:58 +0800 Subject: [PATCH] [arm] add reduce_sum op on arm. test=develop (#4289) * add reduce op on arm. test=develop * fix format. test=develop * fix acccording to comments. test=develop --- lite/backends/arm/math/CMakeLists.txt | 1 + lite/backends/arm/math/funcs.h | 10 + lite/backends/arm/math/reduce_sum.cc | 385 ++++++++++++++++++ lite/backends/arm/math/reduce_sum.h | 84 ++++ lite/kernels/arm/CMakeLists.txt | 1 + lite/kernels/arm/reduce_sum_compute.cc | 114 ++++++ lite/kernels/arm/reduce_sum_compute.h | 36 ++ lite/tests/kernels/reduce_sum_compute_test.cc | 8 +- 8 files changed, 635 insertions(+), 4 deletions(-) create mode 100644 lite/backends/arm/math/reduce_sum.cc create mode 100644 lite/backends/arm/math/reduce_sum.h create mode 100644 lite/kernels/arm/reduce_sum_compute.cc create mode 100644 lite/kernels/arm/reduce_sum_compute.h diff --git a/lite/backends/arm/math/CMakeLists.txt b/lite/backends/arm/math/CMakeLists.txt index 67fc64ab9d..88c449e6a9 100644 --- a/lite/backends/arm/math/CMakeLists.txt +++ b/lite/backends/arm/math/CMakeLists.txt @@ -127,6 +127,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR) anchor_generator.cc split_merge_lod_tenosr.cc reduce_prod.cc + reduce_sum.cc lstm.cc clip.cc pixel_shuffle.cc diff --git a/lite/backends/arm/math/funcs.h b/lite/backends/arm/math/funcs.h index 131c1dbd37..f1ac1d63a1 100644 --- a/lite/backends/arm/math/funcs.h +++ b/lite/backends/arm/math/funcs.h @@ -53,6 +53,7 @@ #include "lite/backends/arm/math/reduce_max.h" #include "lite/backends/arm/math/reduce_mean.h" #include "lite/backends/arm/math/reduce_prod.h" +#include "lite/backends/arm/math/reduce_sum.h" #include "lite/backends/arm/math/scale.h" #include "lite/backends/arm/math/scatter.h" #include "lite/backends/arm/math/sequence_expand.h" @@ -358,6 +359,15 @@ inline float32x4_t pow_ps(float32x4_t a, float32x4_t b) { return exp_ps(vmulq_f32(b, log_ps(a))); } +inline float32x4_t vpaddq_f32(float32x4_t a, float32x4_t b) { + float32x4_t vrst; + vrst[0] = a[0] + a[1]; + vrst[1] = a[2] + a[3]; + vrst[2] = b[0] + b[1]; + vrst[3] = b[2] + b[3]; + return vrst; +} + template void fill_bias_fc( T* tensor, const T* bias, int num, int channel, bool flag_relu); diff --git a/lite/backends/arm/math/reduce_sum.cc b/lite/backends/arm/math/reduce_sum.cc new file mode 100644 index 0000000000..b563887e86 --- /dev/null +++ b/lite/backends/arm/math/reduce_sum.cc @@ -0,0 +1,385 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "lite/backends/arm/math/reduce_sum.h" +#include "lite/backends/arm/math/funcs.h" +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void reduce_sum_n(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int chw_size = channel_in * height_in * width_in; + if (num_in == 1) { + memcpy(dst, src, sizeof(float) * chw_size); + } else { + int cnt_n = num_in >> 2; + int remain_n = num_in & 3; + int cnt_chw = chw_size >> 3; + int cnt_rem = chw_size & 7; + int stride = chw_size << 2; + int stride_c = 0; + for (int c = 0; c < cnt_chw; c++) { + float32x4_t vsum0 = vdupq_n_f32(0.f); + float32x4_t vsum1 = vdupq_n_f32(0.f); + const float* din_ptr0 = src + stride_c; + const float* din_ptr1 = din_ptr0 + chw_size; + const float* din_ptr2 = din_ptr1 + chw_size; + const float* din_ptr3 = din_ptr2 + chw_size; + for (int n = 0; n < cnt_n; n++) { + float32x4_t va0 = vld1q_f32(din_ptr0); + float32x4_t vb0 = vld1q_f32(din_ptr1); + float32x4_t va1 = vld1q_f32(din_ptr0 + 4); + float32x4_t vb1 = vld1q_f32(din_ptr1 + 4); + float32x4_t vc0 = vld1q_f32(din_ptr2); + float32x4_t vd0 = vld1q_f32(din_ptr3); + float32x4_t vs00 = vaddq_f32(va0, vb0); + float32x4_t vc1 = vld1q_f32(din_ptr2 + 4); + float32x4_t vs10 = vaddq_f32(va1, vb1); + float32x4_t vd1 = vld1q_f32(din_ptr3 + 4); + float32x4_t vs01 = vaddq_f32(vc0, vd0); + vsum0 = vaddq_f32(vsum0, vs00); + float32x4_t vs11 = vaddq_f32(vc1, vd1); + vsum1 = vaddq_f32(vsum1, vs10); + din_ptr0 += stride; + din_ptr1 += stride; + vsum0 = vaddq_f32(vsum0, vs01); + din_ptr2 += stride; + din_ptr3 += stride; + vsum1 = vaddq_f32(vsum1, vs11); + } + for (int n = 0; n < remain_n; n++) { + float32x4_t va0 = vld1q_f32(din_ptr0); + float32x4_t va1 = vld1q_f32(din_ptr0 + 4); + vsum0 = vaddq_f32(vsum0, va0); + din_ptr0 += chw_size; + vsum1 = vaddq_f32(vsum1, va1); + } + vst1q_f32(dst, vsum0); + dst += 4; + stride_c += 8; + vst1q_f32(dst, vsum1); + dst += 4; + } + if (cnt_rem > 3) { + float32x4_t vsum0 = vdupq_n_f32(0.f); + const float* din_ptr0 = src + stride_c; + const float* din_ptr1 = din_ptr0 + chw_size; + const float* din_ptr2 = din_ptr1 + chw_size; + const float* din_ptr3 = din_ptr2 + chw_size; + for (int n = 0; n < cnt_n; n++) { + float32x4_t va0 = vld1q_f32(din_ptr0); + float32x4_t vb0 = vld1q_f32(din_ptr1); + float32x4_t vc0 = vld1q_f32(din_ptr2); + float32x4_t vd0 = vld1q_f32(din_ptr3); + float32x4_t vs00 = vaddq_f32(va0, vb0); + float32x4_t vs01 = vaddq_f32(vc0, vd0); + vsum0 = vaddq_f32(vsum0, vs00); + din_ptr0 += stride; + din_ptr1 += stride; + vsum0 = vaddq_f32(vsum0, vs01); + din_ptr2 += stride; + din_ptr3 += stride; + } + for (int n = 0; n < remain_n; n++) { + float32x4_t va0 = vld1q_f32(din_ptr0); + din_ptr0 += chw_size; + vsum0 = vaddq_f32(vsum0, va0); + } + stride_c += 4; + vst1q_f32(dst, vsum0); + dst += 4; + cnt_rem -= 4; + } + for (int c = 0; c < cnt_rem; c++) { + const float* din_ptr0 = src + stride_c; + const float* din_ptr1 = din_ptr0 + chw_size; + const float* din_ptr2 = din_ptr1 + chw_size; + const float* din_ptr3 = din_ptr2 + chw_size; + float sum = 0.0; + for (int n = 0; n < cnt_n; n++) { + float tmp0 = din_ptr0[0] + din_ptr1[0]; + float tmp1 = din_ptr2[0] + din_ptr3[0]; + din_ptr0 += stride; + din_ptr1 += stride; + sum += tmp0; + din_ptr2 += stride; + din_ptr3 += stride; + sum += tmp1; + } + for (int n = 0; n < remain_n; n++) { + sum += din_ptr0[0]; + din_ptr0 += chw_size; + } + stride_c++; + dst[0] = sum; + dst++; + } + } +} + +template <> +void reduce_sum_c(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int hw_size = height_in * width_in; + int chw_size = hw_size * channel_in; + for (int n = 0; n < num_in; ++n) { + reduce_sum_n(src, dst, channel_in, 1, height_in, width_in); + src += chw_size; + dst += hw_size; + } +} + +template <> +void reduce_sum_h(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int nc_size = num_in * channel_in; + int hw_size = height_in * width_in; + for (int n = 0; n < nc_size; ++n) { + reduce_sum_n(src, dst, height_in, 1, 1, width_in); + src += hw_size; + dst += width_in; + } +} + +template <> +void reduce_sum_w(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int nch_size = num_in * channel_in * height_in; + int cnt_w = width_in >> 3; + int cnt_n = nch_size >> 2; + int rem_w = width_in & 7; + int rem_n = nch_size & 3; + int stride = 0; + int stride_n = width_in << 2; + for (int n = 0; n < cnt_n; n++) { + const float* din_ptr0 = src + stride; + const float* din_ptr1 = din_ptr0 + width_in; + const float* din_ptr2 = din_ptr1 + width_in; + const float* din_ptr3 = din_ptr2 + width_in; + float32x4_t vsum = vdupq_n_f32(0.f); + int tmp = rem_w; + for (int w = 0; w < cnt_w; w++) { + float32x4_t va0 = vld1q_f32(din_ptr0); + float32x4_t va1 = vld1q_f32(din_ptr0 + 4); + float32x4_t vb0 = vld1q_f32(din_ptr1); + float32x4_t vb1 = vld1q_f32(din_ptr1 + 4); + float32x4_t vc0 = vld1q_f32(din_ptr2); + float32x4_t vc1 = vld1q_f32(din_ptr2 + 4); + float32x4_t vs0 = vaddq_f32(va0, va1); + float32x4_t vd0 = vld1q_f32(din_ptr3); + float32x4_t vs1 = vaddq_f32(vb0, vb1); + float32x4_t vd1 = vld1q_f32(din_ptr3 + 4); + float32x4_t vs2 = vaddq_f32(vc0, vc1); + din_ptr0 += 8; + float32x4_t vs3 = vaddq_f32(vd0, vd1); + din_ptr1 += 8; + float32x4_t vs00 = vpaddq_f32(vs0, vs1); + din_ptr2 += 8; + float32x4_t vs01 = vpaddq_f32(vs2, vs3); + din_ptr3 += 8; + float32x4_t vs = vpaddq_f32(vs00, vs01); + vsum = vaddq_f32(vs, vsum); + } + if (tmp > 3) { + float32x4_t va0 = vld1q_f32(din_ptr0); + float32x4_t vb0 = vld1q_f32(din_ptr1); + float32x4_t vc0 = vld1q_f32(din_ptr2); + float32x4_t vd0 = vld1q_f32(din_ptr3); + din_ptr0 += 4; + din_ptr1 += 4; + float32x4_t vs00 = vpaddq_f32(va0, vb0); + float32x4_t vs01 = vpaddq_f32(vc0, vd0); + din_ptr2 += 4; + din_ptr3 += 4; + float32x4_t vs = vpaddq_f32(vs00, vs01); + vsum = vaddq_f32(vs, vsum); + tmp -= 4; + } + for (int w = 0; w < tmp; w++) { + vsum[0] += *din_ptr0++; + vsum[1] += *din_ptr1++; + vsum[2] += *din_ptr2++; + vsum[3] += *din_ptr3++; + } + stride += stride_n; + vst1q_f32(dst, vsum); + dst += 4; + } + if (rem_n > 1) { + const float* din_ptr0 = src + stride; + const float* din_ptr1 = din_ptr0 + width_in; + float32x4_t vsum = vdupq_n_f32(0.f); + for (int w = 0; w < cnt_w; w++) { + float32x4_t va0 = vld1q_f32(din_ptr0); + din_ptr0 += 4; + float32x4_t vb0 = vld1q_f32(din_ptr1); + din_ptr1 += 4; + float32x4_t va1 = vld1q_f32(din_ptr0); + float32x4_t vb1 = vld1q_f32(din_ptr1); + float32x4_t vs0 = vpaddq_f32(va0, vb0); + din_ptr0 += 4; + float32x4_t vs1 = vpaddq_f32(va1, vb1); + din_ptr1 += 4; + float32x4_t vs00 = vpaddq_f32(vs0, vs1); + vsum = vaddq_f32(vs00, vsum); + } + int tmp = rem_w; + if (tmp > 3) { + float32x4_t va0 = vld1q_f32(din_ptr0); + float32x4_t vb0 = vld1q_f32(din_ptr1); + din_ptr0 += 4; + din_ptr1 += 4; + float32x4_t vs00 = vpaddq_f32(va0, vb0); + tmp -= 4; + vsum[0] += vs00[0]; + vsum[2] += vs00[1]; + vsum[1] += vs00[2]; + vsum[3] += vs00[3]; + } + vsum[0] += vsum[2]; + vsum[1] += vsum[3]; + for (int w = 0; w < tmp; w++) { + vsum[0] += *din_ptr0++; + vsum[1] += *din_ptr1++; + } + stride += width_in; + *dst++ = vsum[0]; + stride += width_in; + *dst++ = vsum[1]; + rem_n -= 2; + } + for (int n = 0; n < rem_n; n++) { + const float* din_ptr0 = src + stride; + float32x4_t vsum = vdupq_n_f32(0.f); + for (int w = 0; w < cnt_w; w++) { + float32x4_t va0 = vld1q_f32(din_ptr0); + float32x4_t va1 = vld1q_f32(din_ptr0 + 4); + float32x4_t vs0 = vaddq_f32(va0, va1); + din_ptr0 += 8; + vsum = vaddq_f32(vs0, vsum); + } + if (rem_w > 3) { + float32x4_t va0 = vld1q_f32(din_ptr0); + din_ptr0 += 4; + vsum = vaddq_f32(vsum, va0); + rem_w -= 4; + } + vsum[1] += vsum[2]; + for (int w = 0; w < rem_w; w++) { + vsum[0] += *din_ptr0++; + } + vsum[1] += vsum[3]; + vsum[0] += vsum[1]; + *dst++ = vsum[0]; + } +} + +template <> +void reduce_sum_all(const float* src, float* dst, int all_size) { + int cnt_n = all_size >> 4; + int rem_n = all_size & 15; + int cnt_rem = rem_n >> 2; + int rem_rem = rem_n & 3; + float32x4_t vsum = vdupq_n_f32(0.f); + for (int n = 0; n < cnt_n; n++) { + float32x4_t va0 = vld1q_f32(src); + float32x4_t va1 = vld1q_f32(src + 4); + float32x4_t va2 = vld1q_f32(src + 8); + float32x4_t va3 = vld1q_f32(src + 12); + src += 16; + float32x4_t vs0 = vaddq_f32(va0, va1); + float32x4_t vs1 = vaddq_f32(va2, va3); + float32x4_t vs = vpaddq_f32(vs0, vs1); + vsum = vaddq_f32(vsum, vs); + } + for (int n = 0; n < cnt_rem; n++) { + float32x4_t va0 = vld1q_f32(src); + src += 4; + vsum = vaddq_f32(vsum, va0); + } + vsum[1] += vsum[2]; + for (int n = 0; n < rem_rem; n++) { + vsum[0] += *src++; + } + vsum[1] += vsum[3]; + vsum[0] += vsum[1]; + dst[0] = vsum[0]; +} + +template <> +void reduce_sum_nc(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + // reduce nc. + int num = num_in * channel_in; + int size = height_in * width_in; + reduce_sum_n(src, dst, num, size, 1, 1); +} + +template <> +void reduce_sum_ch(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int ch_size = channel_in * height_in; + int chw_size = ch_size * width_in; + for (int n = 0; n < num_in; n++) { + reduce_sum_n(src, dst, ch_size, 1, 1, width_in); + src += chw_size; + dst += width_in; + } +} + +template <> +void reduce_sum_hw(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int hw_size = height_in * width_in; + int nc_size = num_in * channel_in; + reduce_sum_w(src, dst, nc_size, 1, 1, hw_size); +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/reduce_sum.h b/lite/backends/arm/math/reduce_sum.h new file mode 100644 index 0000000000..74e0b6dc75 --- /dev/null +++ b/lite/backends/arm/math/reduce_sum.h @@ -0,0 +1,84 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +void reduce_sum_n(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_sum_c(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_sum_h(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_sum_w(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_sum_nc(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_sum_ch(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_sum_hw(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_sum_all(const T* src, T* dst, int all_size); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index 83789070cc..40cb03872d 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -68,6 +68,7 @@ add_kernel(sequence_conv_compute_arm ARM extra SRCS sequence_conv_compute.cc DEP add_kernel(layer_norm_compute_arm ARM extra SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(gather_compute_arm ARM extra SRCS gather_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(reduce_prod_compute_arm ARM extra SRCS reduce_prod_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(reduce_sum_compute_arm ARM extra SRCS reduce_sum_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(split_lod_tensor_compute_arm ARM extra SRCS split_lod_tensor_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(merge_lod_tensor_compute_arm ARM extra SRCS merge_lod_tensor_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(anchor_generator_compute_arm ARM extra SRCS anchor_generator_compute.cc DEPS ${lite_kernel_deps} math_arm) diff --git a/lite/kernels/arm/reduce_sum_compute.cc b/lite/kernels/arm/reduce_sum_compute.cc new file mode 100644 index 0000000000..261ed2b6a3 --- /dev/null +++ b/lite/kernels/arm/reduce_sum_compute.cc @@ -0,0 +1,114 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/reduce_sum_compute.h" +#include +#include +#include "lite/backends/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void ReduceSumCompute::Run() { + auto& param = this->template Param(); + auto* input = param.x->template data(); + auto x_dims = param.x->dims(); + int x_rank = x_dims.size(); + auto* output = param.output->template mutable_data(); + std::vector dim = param.dim; + bool keep_dim = param.keep_dim; + bool reduce_all = param.reduce_all; + + if (!dim.empty()) { + for (int i = 0; i < dim.size(); i++) { + if (dim[i] < 0) { + dim[i] += x_rank; + } + } + } + + if (reduce_all) { + lite::arm::math::reduce_sum_all(input, output, x_dims.production()); + } else { + int n_in = 1; + int c_in = 1; + int h_in = 1; + int w_in = 1; + switch (x_dims.size()) { + case 4: + w_in = x_dims[3]; + case 3: + h_in = x_dims[2]; + case 2: + c_in = x_dims[1]; + case 1: + n_in = x_dims[0]; + break; + default: + LOG(FATAL) << "x_dims.size is " << x_dims.size() + << ", which should not be over than 4."; + } + + if (dim.size() == 1) { + switch (dim[0]) { + case 0: + lite::arm::math::reduce_sum_n(input, output, n_in, c_in, h_in, w_in); + break; + case 1: + lite::arm::math::reduce_sum_c(input, output, n_in, c_in, h_in, w_in); + break; + case 2: + lite::arm::math::reduce_sum_h(input, output, n_in, c_in, h_in, w_in); + break; + case 3: + lite::arm::math::reduce_sum_w(input, output, n_in, c_in, h_in, w_in); + break; + default: + LOG(FATAL) << "dim[0] is " << dim[0] + << ", which should be less than 4."; + } + } else if (dim.size() == 2) { + if (dim[0] == 0 && dim[1] == 1) { + lite::arm::math::reduce_sum_nc(input, output, n_in, c_in, h_in, w_in); + } else if (dim[0] == 1 && dim[1] == 2) { + lite::arm::math::reduce_sum_ch(input, output, n_in, c_in, h_in, w_in); + } else if (dim[0] == 2 && dim[1] == 3) { + lite::arm::math::reduce_sum_hw(input, output, n_in, c_in, h_in, w_in); + } else { + LOG(FATAL) + << "Only support the values of the dim are 0,1 1,2 or 2,3 for now."; + } + } else { + LOG(FATAL) << "dim's size: " << dim.size() + << " over than 2, which is not supported now!!"; + } + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(reduce_sum, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::ReduceSumCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) + .Finalize(); diff --git a/lite/kernels/arm/reduce_sum_compute.h b/lite/kernels/arm/reduce_sum_compute.h new file mode 100644 index 0000000000..15dcc90b64 --- /dev/null +++ b/lite/kernels/arm/reduce_sum_compute.h @@ -0,0 +1,36 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/backends/arm/math/type_trans.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class ReduceSumCompute : public KernelLite { + public: + void Run() override; + + virtual ~ReduceSumCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/tests/kernels/reduce_sum_compute_test.cc b/lite/tests/kernels/reduce_sum_compute_test.cc index 18490e2f9e..c38132a1a0 100644 --- a/lite/tests/kernels/reduce_sum_compute_test.cc +++ b/lite/tests/kernels/reduce_sum_compute_test.cc @@ -340,10 +340,10 @@ TEST(ReduceSum, precision) { Place place(TARGET(kX86)); test_reduce_sum(place); #endif - // #ifdef LITE_WITH_ARM - // Place place(TARGET(kARM)); - // test_reduce_sum(place); - // #endif +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + test_reduce_sum(place); +#endif } } // namespace lite -- GitLab