From 4be0678b8d07b3833be26ca7086a24d49b567a31 Mon Sep 17 00:00:00 2001 From: GaoWei8 Date: Wed, 26 Feb 2020 02:48:29 +0000 Subject: [PATCH] refine methods of reduce sum test=develop --- lite/tests/kernels/reduce_sum_compute_test.cc | 80 ++++++------------- 1 file changed, 23 insertions(+), 57 deletions(-) diff --git a/lite/tests/kernels/reduce_sum_compute_test.cc b/lite/tests/kernels/reduce_sum_compute_test.cc index c009623449..9e979ec3d0 100644 --- a/lite/tests/kernels/reduce_sum_compute_test.cc +++ b/lite/tests/kernels/reduce_sum_compute_test.cc @@ -119,55 +119,32 @@ void reduce_sum_w(const float* src, } } -void reduce_sum_c( - const float* src, float* dst, int channel_in, int height_in, int width_in) { - int hw_size = height_in * width_in; - int data_index, src_index; - for (int h = 0; h < height_in; ++h) { - for (int w = 0; w < width_in; ++w) { - data_index = h * width_in + w; - dst[data_index] = 0.0; - for (int n = 0; n < channel_in; ++n) { - src_index = n * hw_size + data_index; - dst[data_index] += static_cast(src[src_index]); - } +void reduce_sum(const float* src, float* dst, const DDim& x_dim, int dims) { + int reduce_b[3] = {1, 1, 1}; + reduce_b[dims] = 0; + DDim reduce_dim{x_dim}; + reduce_dim[dims] = 1; + int dim_size = 1; + for (int i = 0; i < 3; i++) { + if (i != dims) { + dim_size *= x_dim[i]; } } -} -void reduce_sum_h( - const float* src, float* dst, int channel_in, int height_in, int width_in) { - int hw_size = height_in * width_in; - int data_index, src_index0, src_index; - for (int c = 0; c < channel_in; ++c) { - for (int w = 0; w < width_in; ++w) { - data_index = c * width_in + w; - src_index0 = c * hw_size + w; - dst[data_index] = 0.0; - for (int h = 0; h < height_in; ++h) { - src_index = src_index0 + h * width_in; - dst[data_index] += static_cast(src[src_index]); - } - } + for (int i = 0; i < dim_size; i++) { + dst[i] = 0.0; } -} -void reduce_sum_w( - const float* src, float* dst, int channel_in, int height_in, int width_in) { - int hw_size = height_in * width_in; - int data_index = 0; - int src_index0 = 0; - int src_index = 0; - for (int c = 0; c < channel_in; ++c) { - for (int h = 0; h < height_in; ++h) { - data_index = c * height_in + h; - src_index0 = c * hw_size + h * width_in; - dst[data_index] = 0.0; - for (int w = 0; w < width_in; ++w) { - src_index = src_index0 + w; - dst[data_index] += static_cast(src[src_index]); + for (int i = 0; i < x_dim[0]; i++) { + for (int j = 0; j < x_dim[1]; j++) { + for (int k = 0; k < x_dim[2]; k++) { + int src_index = i * x_dim[1] * x_dim[2] + j * x_dim[2] + k; + int dst_index = i * reduce_dim[1] * reduce_dim[2] * reduce_b[0] + + j * reduce_dim[2] * reduce_b[1] + k * reduce_b[2]; + dst[dst_index] += static_cast(src[src_index]); } } } } + void reduce_sum_all(const float* src, float* dst, int num_in, @@ -341,22 +318,11 @@ class ReduceSumComputeTester : public arena::TestCase { } } } else { - int in_c = x_dims_[0]; - int in_h = x_dims_[1]; - int in_w = x_dims_[2]; if (dim_.size() == 1 && !reduce_all_) { - switch (dim_[0]) { - case 0: - reduce_sum_c(x_data, out_data, in_c, in_h, in_w); - break; - case 1: - reduce_sum_h(x_data, out_data, in_c, in_h, in_w); - break; - case 2: - reduce_sum_w(x_data, out_data, in_c, in_h, in_w); - break; - default: - LOG(FATAL) << "error!!!"; + if (dim_[0] < x_dims_.size()) { + reduce_sum(x_data, out_data, x_dims_, dim_[0]); + } else { + LOG(FATAL) << "error!!!"; } } } -- GitLab