From d132b8f3fb846ccda3fe16ffb881f8bc13fd5d38 Mon Sep 17 00:00:00 2001
From: jingqinghe
Date: Wed, 5 Aug 2020 14:07:29 +0800
Subject: [PATCH] add support for three-dimensional input in reduce max
 test=develop

---
Review notes (free-form text between "---" and the first "diff --git";
not part of the commit message):
- This copy of the patch was flattened and had its "<...>" spans removed.
  Line breaks, indentation and template argument lists were restored by
  hand; `operators::ReduceMaxParam` is an assumption - confirm against
  reduce_max_compute.h.
- reduce_third_of_three (ARM routine and the test reference, which is a
  copy of it) used `j*second_in` as the row stride. Row (i, j) of a
  row-major [first, second, third] tensor starts at
  i*second_in*third_in + j*third_in, so the old code returned wrong maxima
  and read out of range whenever second_in != third_in.
- ReduceMaxCompute::Run(): reducing two axes of a 3-D input and inputs whose
  rank is neither 3 nor 4 now LOG(FATAL) instead of returning with `output`
  unwritten; the rank dispatch uses the (previously unused) `x_rank`.
- All fixes reuse existing added-line slots, so hunk headers and the
  diffstat below are unchanged.

 lite/backends/arm/math/reduce_max.cc          |  62 +++++++
 lite/backends/arm/math/reduce_max.h           |  28 +++
 lite/kernels/arm/reduce_max_compute.cc        |  84 ++++++---
 lite/tests/kernels/reduce_max_compute_test.cc | 171 +++++++++++++++---
 4 files changed, 297 insertions(+), 48 deletions(-)

diff --git a/lite/backends/arm/math/reduce_max.cc b/lite/backends/arm/math/reduce_max.cc
index 5c75960d72..302dcc105d 100644
--- a/lite/backends/arm/math/reduce_max.cc
+++ b/lite/backends/arm/math/reduce_max.cc
@@ -46,6 +46,68 @@ void reduce_n<float>(const float* src,
   }
 }

+template <>  // dst[second_in][third_in] = max over axis 0 of src[first_in][second_in][third_in]
+void reduce_first_of_three<float>(const float* src,
+                                  float* dst,
+                                  int first_in,
+                                  int second_in,
+                                  int third_in){
+  for (int i = 0; i < second_in; i++){
+    for (int j = 0; j < third_in; j++){
+      dst[i*third_in+j] = src[i*third_in+j];
+      for (int k = 1; k < first_in; k++){
+        dst[i*third_in+j] = src[k*second_in*third_in+i*third_in+j] > dst[i*third_in+j] ? src[k*second_in*third_in+i*third_in+j] : dst[i*third_in+j];
+      }
+    }
+  }
+}
+
+template <>  // dst[first_in][third_in] = max over axis 1 of src[first_in][second_in][third_in]
+void reduce_second_of_three<float>(const float* src,
+                                   float* dst,
+                                   int first_in,
+                                   int second_in,
+                                   int third_in){
+  for (int i = 0; i < first_in; i++){
+    for (int j = 0; j < third_in; j++){
+      dst[i*third_in+j] = src[i*second_in*third_in+j];
+      for (int k = 1; k < second_in; k++){
+        dst[i*third_in+j] = src[i*second_in*third_in+third_in*k+j] > dst[i*third_in+j] ? src[i*second_in*third_in+third_in*k+j] : dst[i*third_in+j];
+      }
+    }
+  }
+}
+
+template <>  // dst[first_in][second_in] = max over axis 2 of src[first_in][second_in][third_in]
+void reduce_third_of_three<float>(const float* src,
+                                  float* dst,
+                                  int first_in,
+                                  int second_in,
+                                  int third_in){
+  for (int i = 0; i < first_in; i++){
+    for (int j = 0; j < second_in; j++){
+      dst[i*second_in+j] = src[i*second_in*third_in+j*third_in];  // row (i, j) starts at j*third_in, not j*second_in
+      for (int k = 1; k < third_in; k++){
+        dst[i*second_in+j] = src[i*second_in*third_in+j*third_in+k] > dst[i*second_in+j] ? src[i*second_in*third_in+j*third_in+k] : dst[i*second_in+j];
+      }
+    }
+  }
+}
+
+template <>  // dst[0] = max over every element of src[first_in][second_in][third_in]
+void reduce_all_of_three<float>(const float* src,
+                                float* dst,
+                                int first_in,
+                                int second_in,
+                                int third_in){
+  float max = src[0];
+  int total_element = first_in * second_in * third_in;
+  for (int i = 0; i < total_element; i++){
+    max = src[i] > max ? src[i] : max;
+  }
+  dst[0] = max;
+}
+
 template <>
 void reduce_c<float>(const float* src,
                      float* dst,
diff --git a/lite/backends/arm/math/reduce_max.h b/lite/backends/arm/math/reduce_max.h
index dab9626182..972daa7994 100644
--- a/lite/backends/arm/math/reduce_max.h
+++ b/lite/backends/arm/math/reduce_max.h
@@ -35,6 +35,34 @@ void reduce_c(const T* src,
               int height_in,
               int width_in);

+template <typename T>  // max over all elements of a [first_in, second_in, third_in] tensor
+void reduce_all_of_three(const T* src,
+                         T* dst,
+                         int first_in,
+                         int second_in,
+                         int third_in);
+
+template <typename T>  // max over axis 0; dst holds second_in * third_in values
+void reduce_first_of_three(const T* src,
+                           T* dst,
+                           int first_in,
+                           int second_in,
+                           int third_in);
+
+template <typename T>  // max over axis 1; dst holds first_in * third_in values
+void reduce_second_of_three(const T* src,
+                            T* dst,
+                            int first_in,
+                            int second_in,
+                            int third_in);
+
+template <typename T>  // max over axis 2; dst holds first_in * second_in values
+void reduce_third_of_three(const T* src,
+                           T* dst,
+                           int first_in,
+                           int second_in,
+                           int third_in);
+
 template <typename T>
 void reduce_h(const T* src,
               T* dst,
diff --git a/lite/kernels/arm/reduce_max_compute.cc b/lite/kernels/arm/reduce_max_compute.cc
index 7a4a4313e0..ddd0723b70 100644
--- a/lite/kernels/arm/reduce_max_compute.cc
+++ b/lite/kernels/arm/reduce_max_compute.cc
@@ -25,6 +25,7 @@ void ReduceMaxCompute::Run() {
   auto& param = Param<operators::ReduceMaxParam>();
   const float* input = param.X->data<float>();
   auto x_dims = param.X->dims();
+  int x_rank = x_dims.size();  // selects the 3-D or 4-D dispatch below
   float* output = param.Out->mutable_data<float>();
   bool keep_dim = param.keep_dim;

@@ -37,42 +38,77 @@ void ReduceMaxCompute::Run() {
       }
     }
   }
-  int n_in = x_dims[0];
-  int c_in = x_dims[1];
-  int h_in = x_dims[2];
-  int w_in = x_dims[3];
-  if (dim.size() == 0) {
-    lite::arm::math::reduce_all(input, output, n_in, c_in, h_in, w_in);
-  } else if (dim.size() == 1) {
-    switch (dim[0]) {
+
+  if (x_rank == 3){
+    if (dim.size() == 0 || dim.size() == 3){
+      lite::arm::math::reduce_all_of_three(input, output, x_dims[0], x_dims[1], x_dims[2]);
+    }
+    else if (dim.size() == 1){
+      switch (dim[0])
+      {
       case 0:
-        lite::arm::math::reduce_n(input, output, n_in, c_in, h_in, w_in);
+        lite::arm::math::reduce_first_of_three(input, output, x_dims[0], x_dims[1], x_dims[2]);
         break;
       case 1:
-        lite::arm::math::reduce_c(input, output, n_in, c_in, h_in, w_in);
+        lite::arm::math::reduce_second_of_three(input, output, x_dims[0], x_dims[1], x_dims[2]);
         break;
+
       case 2:
-        lite::arm::math::reduce_h(input, output, n_in, c_in, h_in, w_in);
-        break;
-      case 3:
-        lite::arm::math::reduce_w(input, output, n_in, c_in, h_in, w_in);
+        lite::arm::math::reduce_third_of_three(input, output, x_dims[0], x_dims[1], x_dims[2]);
         break;
       default:
         LOG(FATAL) << "error!!!";
+      }
+    }
+    else if (dim.size() == 2){
+      LOG(FATAL) << "reducing two dims of a 3-D input is not supported now!!";  // was an empty branch
     }
-  } else if (dim.size() == 2) {
-    if (dim[0] == 0 && dim[1] == 1) {
-      lite::arm::math::reduce_nc(input, output, n_in, c_in, h_in, w_in);
-    } else if (dim[0] == 1 && dim[1] == 2) {
-      lite::arm::math::reduce_ch(input, output, n_in, c_in, h_in, w_in);
-    } else if (dim[0] == 2 && dim[1] == 3) {
-      lite::arm::math::reduce_hw(input, output, n_in, c_in, h_in, w_in);
+    else {
+      LOG(FATAL) << "dim size should not larger than 3!!!";
+    }
+
+  }
+
+  else if (x_rank == 4){
+    int n_in = x_dims[0];
+    int c_in = x_dims[1];
+    int h_in = x_dims[2];
+    int w_in = x_dims[3];
+
+    if (dim.size() == 0) {
+      lite::arm::math::reduce_all(input, output, n_in, c_in, h_in, w_in);
+    } else if (dim.size() == 1) {
+      switch (dim[0]) {
+        case 0:
+          lite::arm::math::reduce_n(input, output, n_in, c_in, h_in, w_in);
+          break;
+        case 1:
+          lite::arm::math::reduce_c(input, output, n_in, c_in, h_in, w_in);
+          break;
+        case 2:
+          lite::arm::math::reduce_h(input, output, n_in, c_in, h_in, w_in);
+          break;
+        case 3:
+          lite::arm::math::reduce_w(input, output, n_in, c_in, h_in, w_in);
+          break;
+        default:
+          LOG(FATAL) << "error!!!";
+      }
+    } else if (dim.size() == 2) {
+      if (dim[0] == 0 && dim[1] == 1) {
+        lite::arm::math::reduce_nc(input, output, n_in, c_in, h_in, w_in);
+      } else if (dim[0] == 1 && dim[1] == 2) {
+        lite::arm::math::reduce_ch(input, output, n_in, c_in, h_in, w_in);
+      } else if (dim[0] == 2 && dim[1] == 3) {
+        lite::arm::math::reduce_hw(input, output, n_in, c_in, h_in, w_in);
+      } else {
+        LOG(FATAL) << "invalid dim!!";
+      }
     } else {
-      LOG(FATAL) << "invalid dim!!";
+      LOG(FATAL) << "dim's size over than 2, which is not supported now!!";
     }
-  } else {
-    LOG(FATAL) << "dim's size over than 2, which is not supported now!!";
   }
+  else { LOG(FATAL) << "only 3-D and 4-D inputs are supported now!!"; }  // never return with output unwritten
 }

 }  // namespace arm
diff --git a/lite/tests/kernels/reduce_max_compute_test.cc b/lite/tests/kernels/reduce_max_compute_test.cc
index 506a45368b..ac32f0ed97 100644
--- a/lite/tests/kernels/reduce_max_compute_test.cc
+++ b/lite/tests/kernels/reduce_max_compute_test.cc
@@ -190,6 +190,71 @@ void reduce_hw(const float* src,
   reduce_w(tmp_out, dst, num_in, channel_in, 1, width_in);
 }

+void reduce_first_of_three(const float* src,  // baseline: max over axis 0 of a 3-D tensor
+                           float* dst,
+                           int first_in,
+                           int second_in,
+                           int third_in){
+
+  for (int i = 0; i < second_in; i++){
+    for (int j = 0; j < third_in; j++){
+      dst[i*third_in+j] = src[i*third_in+j];
+      for (int k = 1; k < first_in; k++){
+        dst[i*third_in+j] = src[k*second_in*third_in+i*third_in+j] > dst[i*third_in+j] ? src[k*second_in*third_in+i*third_in+j] : dst[i*third_in+j];
+      }
+    }
+  }
+}
+
+
+void reduce_second_of_three(const float* src,  // baseline: max over axis 1 of a 3-D tensor
+                            float* dst,
+                            int first_in,
+                            int second_in,
+                            int third_in){
+
+  for (int i = 0; i < first_in; i++){
+    for (int j = 0; j < third_in; j++){
+      dst[i*third_in+j] = src[i*second_in*third_in+j];
+      for (int k = 1; k < second_in; k++){
+        dst[i*third_in+j] = src[i*second_in*third_in+third_in*k+j] > dst[i*third_in+j] ? src[i*second_in*third_in+third_in*k+j] : dst[i*third_in+j];
+      }
+    }
+  }
+}
+
+
+void reduce_third_of_three(const float* src,  // baseline: max over axis 2 of a 3-D tensor
+                           float* dst,
+                           int first_in,
+                           int second_in,
+                           int third_in){
+
+  for (int i = 0; i < first_in; i++){
+    for (int j = 0; j < second_in; j++){
+      dst[i*second_in+j] = src[i*second_in*third_in+j*third_in];  // row stride is third_in
+      for (int k = 1; k < third_in; k++){
+        dst[i*second_in+j] = src[i*second_in*third_in+j*third_in+k] > dst[i*second_in+j] ? src[i*second_in*third_in+j*third_in+k] : dst[i*second_in+j];
+      }
+    }
+  }
+}
+
+
+void reduce_all_of_three(const float* src,  // baseline: max over every element of a 3-D tensor
+                         float* dst,
+                         int first_in,
+                         int second_in,
+                         int third_in){
+  float max = src[0];
+  int total_element = first_in * second_in * third_in;
+  for (int i = 0; i < total_element; i++){
+    max = src[i] > max ? src[i] : max;
+  }
+  dst[0] = max;
+}
+
+
 class ReduceMaxComputeTester : public arena::TestCase {
  protected:
   // common attributes for this op.
@@ -256,41 +321,76 @@ class ReduceMaxComputeTester : public arena::TestCase {
     }

     auto* out_data = out->mutable_data<float>();
-    int in_n = x_dims_[0];
-    int in_c = x_dims_[1];
-    int in_h = x_dims_[2];
-    int in_w = x_dims_[3];
-
-    if (dim_.size() == 0) {
-      reduce_all(x_data, out_data, in_n, in_c, in_h, in_w);
-    } else if (dim_.size() == 1) {
-      switch (dim_[0]) {
+
+    if (x_dims_.size()==3){  // 3-D input: reduce all axes or exactly one axis
+      if (dim_.size() == 0 || dim_.size() == 3){
+        reduce_all_of_three(x_data, out_data, x_dims_[0], x_dims_[1], x_dims_[2]);
+      }
+      else if (dim_.size() == 1){
+        switch (dim_[0])
+        {
         case 0:
-          reduce_n(x_data, out_data, in_n, in_c, in_h, in_w);
+          reduce_first_of_three(x_data, out_data, x_dims_[0], x_dims_[1], x_dims_[2]);
           break;
         case 1:
-          reduce_c(x_data, out_data, in_n, in_c, in_h, in_w);
+          reduce_second_of_three(x_data, out_data, x_dims_[0], x_dims_[1], x_dims_[2]);
           break;
+
         case 2:
-          reduce_h(x_data, out_data, in_n, in_c, in_h, in_w);
-          break;
-        case 3:
-          reduce_w(x_data, out_data, in_n, in_c, in_h, in_w);
+          reduce_third_of_three(x_data, out_data, x_dims_[0], x_dims_[1], x_dims_[2]);
           break;
         default:
           LOG(FATAL) << "error!!!";
+        }
       }
-    } else if (dim_.size() == 2) {
-      if (dim_[0] == 0 && dim_[1] == 1) {
-        reduce_nc(x_data, out_data, in_n, in_c, in_h, in_w);
-      } else if (dim_[0] == 1 && dim_[1] == 2) {
-        reduce_ch(x_data, out_data, in_n, in_c, in_h, in_w);
-      } else if (dim_[0] == 2 && dim_[1] == 3) {
-        reduce_hw(x_data, out_data, in_n, in_c, in_h, in_w);
-      } else {
-        LOG(FATAL) << "invalid dims_!!";
+      else if (dim_.size() == 2){
+        LOG(FATAL) << "invalid dims_!!";
+      }
+      else {
+        LOG(FATAL) << "dim size should not larger than 3!!!";
       }
+    }
+    else if (x_dims_.size()==4){  // 4-D input: same dispatch as before this change
+      int in_n = x_dims_[0];
+      int in_c = x_dims_[1];
+      int in_h = x_dims_[2];
+      int in_w = x_dims_[3];
+      if (dim_.size() == 0) {
+        reduce_all(x_data, out_data, in_n, in_c, in_h, in_w);
+      } else if (dim_.size() == 1) {
+        switch (dim_[0]) {
+          case 0:
+            reduce_n(x_data, out_data, in_n, in_c, in_h, in_w);
+            break;
+          case 1:
+            reduce_c(x_data, out_data, in_n, in_c, in_h, in_w);
+            break;
+          case 2:
+            reduce_h(x_data, out_data, in_n, in_c, in_h, in_w);
+            break;
+          case 3:
+            reduce_w(x_data, out_data, in_n, in_c, in_h, in_w);
+            break;
+          default:
+            LOG(FATAL) << "error!!!";
+        }
+      } else if (dim_.size() == 2) {
+        if (dim_[0] == 0 && dim_[1] == 1) {
+          reduce_nc(x_data, out_data, in_n, in_c, in_h, in_w);
+        } else if (dim_[0] == 1 && dim_[1] == 2) {
+          reduce_ch(x_data, out_data, in_n, in_c, in_h, in_w);
+        } else if (dim_[0] == 2 && dim_[1] == 3) {
+          reduce_hw(x_data, out_data, in_n, in_c, in_h, in_w);
+        } else {
+          LOG(FATAL) << "invalid dims_!!";
+        }
+      }
+      else { LOG(FATAL) << "dim's size over than 2, which is not supported now!!"; }  // baseline must not return with out_data unwritten
+
     }
+    else { LOG(FATAL) << "only 3-D and 4-D inputs are supported now!!"; }  // baseline must not return with out_data unwritten
+
+
   }

   void PrepareOpDesc(cpp::OpDesc* op_desc) {
@@ -333,6 +433,28 @@ void test_reduce_max(Place place) {
   }
 }

+void test_reduce_max_for_three(Place place) {  // precision sweep: 3-D shapes x one reduce axis x keep_dim
+  std::vector<std::vector<int>> reduce_dim{
+      {0}, {1}, {2}};
+  for (auto f : {1, 3}) {
+    for (auto s : {1, 2}) {
+      for (auto t : {1, 3}) {
+        for (bool keep_dim : {false, true}) {
+          for (auto dim : reduce_dim) {
+            auto x_dims = DDim(std::vector<int64_t>({f, s, t}));
+            std::unique_ptr<arena::TestCase> tester(
+                new ReduceMaxComputeTester(
+                    place, "def", dim, keep_dim, x_dims));
+            arena::Arena arena(std::move(tester), place, 2e-5);
+            arena.TestPrecision();
+          }
+        }
+      }
+    }
+  }
+}
+
+
 TEST(ReduceMax, precision) {
 // #ifdef LITE_WITH_X86
 //   Place place(TARGET(kX86));
@@ -340,6 +462,7 @@ TEST(ReduceMax, precision) {
 #ifdef LITE_WITH_ARM
   Place place(TARGET(kARM));
   test_reduce_max(place);
+  test_reduce_max_for_three(place);  // also run the 3-D input cases on ARM
 #endif
 }

--
GitLab