diff --git a/lite/tests/kernels/reduce_sum_compute_test.cc b/lite/tests/kernels/reduce_sum_compute_test.cc index 9cfe213750b1191c1ef8fe7fba1b1c1035c2ae42..c00962344970b6d75b97afb84c975d34a1691606 100644 --- a/lite/tests/kernels/reduce_sum_compute_test.cc +++ b/lite/tests/kernels/reduce_sum_compute_test.cc @@ -119,6 +119,55 @@ void reduce_sum_w(const float* src, } } +void reduce_sum_c( + const float* src, float* dst, int channel_in, int height_in, int width_in) { + int hw_size = height_in * width_in; + int data_index, src_index; + for (int h = 0; h < height_in; ++h) { + for (int w = 0; w < width_in; ++w) { + data_index = h * width_in + w; + dst[data_index] = 0.0; + for (int n = 0; n < channel_in; ++n) { + src_index = n * hw_size + data_index; + dst[data_index] += static_cast(src[src_index]); + } + } + } +} +void reduce_sum_h( + const float* src, float* dst, int channel_in, int height_in, int width_in) { + int hw_size = height_in * width_in; + int data_index, src_index0, src_index; + for (int c = 0; c < channel_in; ++c) { + for (int w = 0; w < width_in; ++w) { + data_index = c * width_in + w; + src_index0 = c * hw_size + w; + dst[data_index] = 0.0; + for (int h = 0; h < height_in; ++h) { + src_index = src_index0 + h * width_in; + dst[data_index] += static_cast(src[src_index]); + } + } + } +} +void reduce_sum_w( + const float* src, float* dst, int channel_in, int height_in, int width_in) { + int hw_size = height_in * width_in; + int data_index = 0; + int src_index0 = 0; + int src_index = 0; + for (int c = 0; c < channel_in; ++c) { + for (int h = 0; h < height_in; ++h) { + data_index = c * height_in + h; + src_index0 = c * hw_size + h * width_in; + dst[data_index] = 0.0; + for (int w = 0; w < width_in; ++w) { + src_index = src_index0 + w; + dst[data_index] += static_cast(src[src_index]); + } + } + } +} void reduce_sum_all(const float* src, float* dst, int num_in, @@ -255,39 +304,60 @@ class ReduceSumComputeTester : public arena::TestCase { out->Resize(DDim(out_dims)); auto* out_data = out->mutable_data(); - int in_n = x_dims_[0]; - int in_c = x_dims_[1]; - int in_h = x_dims_[2]; - int in_w = x_dims_[3]; + if (x_dims_.size() == 4) { + int in_n = x_dims_[0]; + int in_c = x_dims_[1]; + int in_h = x_dims_[2]; + int in_w = x_dims_[3]; - if (reduce_all_) { - reduce_sum_all(x_data, out_data, in_n, in_c, in_h, in_w); - } else if (dim_.size() == 1) { - switch (dim_[0]) { - case 0: - reduce_sum_n(x_data, out_data, in_n, in_c, in_h, in_w); - break; - case 1: - reduce_sum_c(x_data, out_data, in_n, in_c, in_h, in_w); - break; - case 2: - reduce_sum_h(x_data, out_data, in_n, in_c, in_h, in_w); - break; - case 3: - reduce_sum_w(x_data, out_data, in_n, in_c, in_h, in_w); - break; - default: - LOG(FATAL) << "error!!!"; + if (reduce_all_) { + reduce_sum_all(x_data, out_data, in_n, in_c, in_h, in_w); + } else if (dim_.size() == 1) { + switch (dim_[0]) { + case 0: + reduce_sum_n(x_data, out_data, in_n, in_c, in_h, in_w); + break; + case 1: + reduce_sum_c(x_data, out_data, in_n, in_c, in_h, in_w); + break; + case 2: + reduce_sum_h(x_data, out_data, in_n, in_c, in_h, in_w); + break; + case 3: + reduce_sum_w(x_data, out_data, in_n, in_c, in_h, in_w); + break; + default: + LOG(FATAL) << "error!!!"; + } + } else if (dim_.size() == 2) { + if (dim_[0] == 0 && dim_[1] == 1) { + reduce_sum_nc(x_data, out_data, in_n, in_c, in_h, in_w); + } else if (dim_[0] == 1 && dim_[1] == 2) { + reduce_sum_ch(x_data, out_data, in_n, in_c, in_h, in_w); + } else if (dim_[0] == 2 && dim_[1] == 3) { + reduce_sum_hw(x_data, out_data, in_n, in_c, in_h, in_w); + } else { + LOG(FATAL) << "invalid dims_!!"; + } } - } else if (dim_.size() == 2) { - if (dim_[0] == 0 && dim_[1] == 1) { - reduce_sum_nc(x_data, out_data, in_n, in_c, in_h, in_w); - } else if (dim_[0] == 1 && dim_[1] == 2) { - reduce_sum_ch(x_data, out_data, in_n, in_c, in_h, in_w); - } else if (dim_[0] == 2 && dim_[1] == 3) { - reduce_sum_hw(x_data, out_data, in_n, in_c, in_h, in_w); - } else { - LOG(FATAL) << "invalid dims_!!"; + } else { + int in_c = x_dims_[0]; + int in_h = x_dims_[1]; + int in_w = x_dims_[2]; + if (dim_.size() == 1 && !reduce_all_) { + switch (dim_[0]) { + case 0: + reduce_sum_c(x_data, out_data, in_c, in_h, in_w); + break; + case 1: + reduce_sum_h(x_data, out_data, in_c, in_h, in_w); + break; + case 2: + reduce_sum_w(x_data, out_data, in_c, in_h, in_w); + break; + default: + LOG(FATAL) << "error!!!"; + } } } } @@ -333,6 +403,20 @@ void test_reduce_sum(Place place) { } } } + std::vector> reduce_dimm{{0}, {1}, {2}}; + for (auto dim : reduce_dimm) { + for (auto c : {1, 3}) { + for (auto h : {1, 3}) { + for (auto w : {1, 4}) { + auto x_dims = DDim(std::vector({c, h, w})); + std::unique_ptr tester(new ReduceSumComputeTester( + place, "def", dim, false, false, x_dims)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); + } + } + } + } } TEST(ReduceSum, precision) {