diff --git a/lite/backends/arm/math/reduce_max.cc b/lite/backends/arm/math/reduce_max.cc index 302dcc105d3756ad88c24a18f8a545074224a8a0..0ca0cf2abb369fc27bacb0cd75755b56551d6c1b 100644 --- a/lite/backends/arm/math/reduce_max.cc +++ b/lite/backends/arm/math/reduce_max.cc @@ -47,62 +47,62 @@ void reduce_n(const float* src, } template <> -void reduce_first_of_three(const float* src, - float* dst, - int first_in, - int second_in, - int third_in){ - for (int i = 0; i < second_in; i++){ - for (int j = 0; j < third_in; j++){ - dst[i*third_in+j] = src[i*third_in+j]; - for (int k = 1; k < first_in; k++){ - dst[i*third_in+j] = src[k*second_in*third_in+i*third_in+j] > dst[i*third_in+j] ? src[k*second_in*third_in+i*third_in+j] : dst[i*third_in+j]; +void reduce_first_of_three( + const float* src, float* dst, int first_in, int second_in, int third_in) { + for (int i = 0; i < second_in; i++) { + for (int j = 0; j < third_in; j++) { + dst[i * third_in + j] = src[i * third_in + j]; + for (int k = 1; k < first_in; k++) { + dst[i * third_in + j] = + src[k * second_in * third_in + i * third_in + j] > + dst[i * third_in + j] + ? src[k * second_in * third_in + i * third_in + j] + : dst[i * third_in + j]; } } } } template <> -void reduce_second_of_three(const float* src, - float* dst, - int first_in, - int second_in, - int third_in){ - for (int i = 0; i < first_in; i++){ - for (int j = 0; j < third_in; j++){ - dst[i*third_in+j] = src[i*second_in*third_in+j]; - for (int k = 1; k < second_in; k++){ - dst[i*third_in+j] = src[i*second_in*third_in+third_in*k+j] > dst[i*third_in+j] ? src[i*second_in*third_in+third_in*k+j] : dst[i*third_in+j]; +void reduce_second_of_three( + const float* src, float* dst, int first_in, int second_in, int third_in) { + for (int i = 0; i < first_in; i++) { + for (int j = 0; j < third_in; j++) { + dst[i * third_in + j] = src[i * second_in * third_in + j]; + for (int k = 1; k < second_in; k++) { + dst[i * third_in + j] = + src[i * second_in * third_in + third_in * k + j] > + dst[i * third_in + j] + ? src[i * second_in * third_in + third_in * k + j] + : dst[i * third_in + j]; } } } } template <> -void reduce_third_of_three(const float* src, - float* dst, - int first_in, - int second_in, - int third_in){ - for (int i = 0; i < first_in; i++){ - for (int j = 0; j < second_in; j++){ - dst[i*second_in+j] = src[i*second_in*third_in+j*second_in]; - for (int k = 0; k< third_in; k++){ - dst[i*second_in+j] = src[i*second_in*third_in+j*second_in+k] > dst[i*second_in+j] ? src[i*second_in*third_in+j*second_in+k] : dst[i*second_in+j]; +void reduce_third_of_three( + const float* src, float* dst, int first_in, int second_in, int third_in) { + for (int i = 0; i < first_in; i++) { + for (int j = 0; j < second_in; j++) { + dst[i * second_in + j] = src[i * second_in * third_in + j * second_in]; + for (int k = 0; k < third_in; k++) { + dst[i * second_in + j] = + src[i * second_in * third_in + j * second_in + k] > + dst[i * second_in + j] + ? src[i * second_in * third_in + j * second_in + k] + : dst[i * second_in + j]; } } } } template <> -void reduce_all_of_three(const float* src, - float* dst, - int first_in, - int second_in, - int third_in){ +void reduce_all_of_three( + const float* src, float* dst, int first_in, int second_in, int third_in) { float max = src[0]; int total_element = first_in * second_in * third_in; - for (int i = 0; i max ? src[i] : max; } dst[0] = max; diff --git a/lite/backends/arm/math/reduce_max.h b/lite/backends/arm/math/reduce_max.h index 972daa79941d7429f33116559de0bd4a63ebcccb..e8dafd076536abee12c7d9abe57627ef91b7c3c9 100644 --- a/lite/backends/arm/math/reduce_max.h +++ b/lite/backends/arm/math/reduce_max.h @@ -36,32 +36,20 @@ void reduce_c(const T* src, int width_in); template -void reduce_all_of_three(const T* src, - T* dst, - int first_in, - int second_in, - int third_in); +void reduce_all_of_three( + const T* src, T* dst, int first_in, int second_in, int third_in); template -void reduce_first_of_three(const T* src, - T* dst, - int first_in, - int second_in, - int third_in); +void reduce_first_of_three( + const T* src, T* dst, int first_in, int second_in, int third_in); template -void reduce_second_of_three(const T* src, - T* dst, - int first_in, - int second_in, - int third_in); +void reduce_second_of_three( + const T* src, T* dst, int first_in, int second_in, int third_in); template -void reduce_third_of_three(const T* src, - T* dst, - int first_in, - int second_in, - int third_in); +void reduce_third_of_three( + const T* src, T* dst, int first_in, int second_in, int third_in); template void reduce_h(const T* src, diff --git a/lite/kernels/arm/reduce_max_compute.cc b/lite/kernels/arm/reduce_max_compute.cc index ddd0723b70c5405a8289b8d9ebac95fbf33825fc..51053a52f523de20b710144a3fbf99d8c249a008 100644 --- a/lite/kernels/arm/reduce_max_compute.cc +++ b/lite/kernels/arm/reduce_max_compute.cc @@ -25,7 +25,7 @@ void ReduceMaxCompute::Run() { auto& param = Param(); const float* input = param.X->data(); auto x_dims = param.X->dims(); - + int x_rank = x_dims.size(); float* output = param.Out->mutable_data(); bool keep_dim = param.keep_dim; @@ -39,37 +39,33 @@ void ReduceMaxCompute::Run() { } } - if (x_dims.size()==3){ - if (dim.size() == 0 || dim.size() == 3){ - lite::arm::math::reduce_all_of_three(input, output, x_dims[0], x_dims[1], x_dims[2]); - } - else if (dim.size() == 1){ - switch (dim[0]) - { - case 0: - lite::arm::math::reduce_first_of_three(input, output, x_dims[0], x_dims[1], x_dims[2]); - break; - case 1: - lite::arm::math::reduce_second_of_three(input, output, x_dims[0], x_dims[1], x_dims[2]); - break; + if (x_dims.size() == 3) { + if (dim.size() == 0 || dim.size() == 3) { + lite::arm::math::reduce_all_of_three( + input, output, x_dims[0], x_dims[1], x_dims[2]); + } else if (dim.size() == 1) { + switch (dim[0]) { + case 0: + lite::arm::math::reduce_first_of_three( + input, output, x_dims[0], x_dims[1], x_dims[2]); + break; + case 1: + lite::arm::math::reduce_second_of_three( + input, output, x_dims[0], x_dims[1], x_dims[2]); + break; - case 2: - lite::arm::math::reduce_third_of_three(input, output, x_dims[0], x_dims[1], x_dims[2]); - break; - default: - LOG(FATAL) << "error!!!"; + case 2: + lite::arm::math::reduce_third_of_three( + input, output, x_dims[0], x_dims[1], x_dims[2]); + break; + default: + LOG(FATAL) << "error!!!"; } - } - else if (dim.size() == 2){ - - } - else { + } else if (dim.size() == 2) { + } else { LOG(FATAL) << "dim size should not larger than 3!!!"; } - - } - - else if (x_dims.size()==4){ + } else if (x_dims.size() == 4) { int n_in = x_dims[0]; int c_in = x_dims[1]; int h_in = x_dims[2]; @@ -78,37 +74,38 @@ void ReduceMaxCompute::Run() { if (dim.size() == 0) { lite::arm::math::reduce_all(input, output, n_in, c_in, h_in, w_in); } else if (dim.size() == 1) { - switch (dim[0]) { - case 0: - lite::arm::math::reduce_n(input, output, n_in, c_in, h_in, w_in); - break; - case 1: - lite::arm::math::reduce_c(input, output, n_in, c_in, h_in, w_in); - break; - case 2: - lite::arm::math::reduce_h(input, output, n_in, c_in, h_in, w_in); - break; - case 3: - lite::arm::math::reduce_w(input, output, n_in, c_in, h_in, w_in); - break; - default: - LOG(FATAL) << "error!!!"; + switch (dim[0]) { + case 0: + lite::arm::math::reduce_n(input, output, n_in, c_in, h_in, w_in); + break; + case 1: + lite::arm::math::reduce_c(input, output, n_in, c_in, h_in, w_in); + break; + case 2: + lite::arm::math::reduce_h(input, output, n_in, c_in, h_in, w_in); + break; + case 3: + lite::arm::math::reduce_w(input, output, n_in, c_in, h_in, w_in); + break; + default: + LOG(FATAL) << "error!!!"; } } else if (dim.size() == 2) { - if (dim[0] == 0 && dim[1] == 1) { - lite::arm::math::reduce_nc(input, output, n_in, c_in, h_in, w_in); - } else if (dim[0] == 1 && dim[1] == 2) { - lite::arm::math::reduce_ch(input, output, n_in, c_in, h_in, w_in); - } else if (dim[0] == 2 && dim[1] == 3) { - lite::arm::math::reduce_hw(input, output, n_in, c_in, h_in, w_in); - } else { - LOG(FATAL) << "invalid dim!!"; - } + if (dim[0] == 0 && dim[1] == 1) { + lite::arm::math::reduce_nc(input, output, n_in, c_in, h_in, w_in); + } else if (dim[0] == 1 && dim[1] == 2) { + lite::arm::math::reduce_ch(input, output, n_in, c_in, h_in, w_in); + } else if (dim[0] == 2 && dim[1] == 3) { + lite::arm::math::reduce_hw(input, output, n_in, c_in, h_in, w_in); + } else { + LOG(FATAL) << "invalid dim!!"; + } } else { LOG(FATAL) << "dim's size over than 2, which is not supported now!!"; } + } else { + LOG(FATAL) << "only support input with 3&4 dimensions now!!"; } - } } // namespace arm diff --git a/lite/tests/kernels/reduce_max_compute_test.cc b/lite/tests/kernels/reduce_max_compute_test.cc index ac32f0ed97c55a7abf9c033784794d9caba3dd7e..bf9d4ea720dcc73bdf4f01b6cad10644a72de9b1 100644 --- a/lite/tests/kernels/reduce_max_compute_test.cc +++ b/lite/tests/kernels/reduce_max_compute_test.cc @@ -190,71 +190,64 @@ void reduce_hw(const float* src, reduce_w(tmp_out, dst, num_in, channel_in, 1, width_in); } -void reduce_first_of_three(const float* src, - float* dst, - int first_in, - int second_in, - int third_in){ - - for (int i = 0; i < second_in; i++){ - for (int j = 0; j < third_in; j++){ - dst[i*third_in+j] = src[i*third_in+j]; - for (int k = 1; k < first_in; k++){ - dst[i*third_in+j] = src[k*second_in*third_in+i*third_in+j] > dst[i*third_in+j] ? src[k*second_in*third_in+i*third_in+j] : dst[i*third_in+j]; +void reduce_first_of_three( + const float* src, float* dst, int first_in, int second_in, int third_in) { + for (int i = 0; i < second_in; i++) { + for (int j = 0; j < third_in; j++) { + dst[i * third_in + j] = src[i * third_in + j]; + for (int k = 1; k < first_in; k++) { + dst[i * third_in + j] = + src[k * second_in * third_in + i * third_in + j] > + dst[i * third_in + j] + ? src[k * second_in * third_in + i * third_in + j] + : dst[i * third_in + j]; } } } } - -void reduce_second_of_three(const float* src, - float* dst, - int first_in, - int second_in, - int third_in){ - - for (int i = 0; i < first_in; i++){ - for (int j = 0; j < third_in; j++){ - dst[i*third_in+j] = src[i*second_in*third_in+j]; - for (int k = 1; k < second_in; k++){ - dst[i*third_in+j] = src[i*second_in*third_in+third_in*k+j] > dst[i*third_in+j] ? src[i*second_in*third_in+third_in*k+j] : dst[i*third_in+j]; +void reduce_second_of_three( + const float* src, float* dst, int first_in, int second_in, int third_in) { + for (int i = 0; i < first_in; i++) { + for (int j = 0; j < third_in; j++) { + dst[i * third_in + j] = src[i * second_in * third_in + j]; + for (int k = 1; k < second_in; k++) { + dst[i * third_in + j] = + src[i * second_in * third_in + third_in * k + j] > + dst[i * third_in + j] + ? src[i * second_in * third_in + third_in * k + j] + : dst[i * third_in + j]; } } } } - -void reduce_third_of_three(const float* src, - float* dst, - int first_in, - int second_in, - int third_in){ - - for (int i = 0; i < first_in; i++){ - for (int j = 0; j < second_in; j++){ - dst[i*second_in+j] = src[i*second_in*third_in+j*second_in]; - for (int k = 0; k< third_in; k++){ - dst[i*second_in+j] = src[i*second_in*third_in+j*second_in+k] > dst[i*second_in+j] ? src[i*second_in*third_in+j*second_in+k] : dst[i*second_in+j]; +void reduce_third_of_three( + const float* src, float* dst, int first_in, int second_in, int third_in) { + for (int i = 0; i < first_in; i++) { + for (int j = 0; j < second_in; j++) { + dst[i * second_in + j] = src[i * second_in * third_in + j * second_in]; + for (int k = 0; k < third_in; k++) { + dst[i * second_in + j] = + src[i * second_in * third_in + j * second_in + k] > + dst[i * second_in + j] + ? src[i * second_in * third_in + j * second_in + k] + : dst[i * second_in + j]; } } } } - -void reduce_all_of_three(const float* src, - float* dst, - int first_in, - int second_in, - int third_in){ +void reduce_all_of_three( + const float* src, float* dst, int first_in, int second_in, int third_in) { float max = src[0]; int total_element = first_in * second_in * third_in; - for (int i = 0; i max ? src[i] : max; } dst[0] = max; } - class ReduceMaxComputeTester : public arena::TestCase { protected: // common attributes for this op. @@ -321,37 +314,36 @@ class ReduceMaxComputeTester : public arena::TestCase { } auto* out_data = out->mutable_data(); - - if (x_dims_.size()==3){ - if (dim_.size() == 0 || dim_.size() == 3){ - reduce_all_of_three(x_data, out_data, x_dims_[0], x_dims_[1], x_dims_[2]); - } - else if (dim_.size() == 1){ - switch (dim_[0]) - { - case 0: - reduce_first_of_three(x_data, out_data, x_dims_[0], x_dims_[1], x_dims_[2]); - break; - case 1: - reduce_second_of_three(x_data, out_data, x_dims_[0], x_dims_[1], x_dims_[2]); - break; - case 2: - reduce_third_of_three(x_data, out_data, x_dims_[0], x_dims_[1], x_dims_[2]); - break; - default: - LOG(FATAL) << "error!!!"; + if (x_dims_.size() == 3) { + if (dim_.size() == 0 || dim_.size() == 3) { + reduce_all_of_three( + x_data, out_data, x_dims_[0], x_dims_[1], x_dims_[2]); + } else if (dim_.size() == 1) { + switch (dim_[0]) { + case 0: + reduce_first_of_three( + x_data, out_data, x_dims_[0], x_dims_[1], x_dims_[2]); + break; + case 1: + reduce_second_of_three( + x_data, out_data, x_dims_[0], x_dims_[1], x_dims_[2]); + break; + + case 2: + reduce_third_of_three( + x_data, out_data, x_dims_[0], x_dims_[1], x_dims_[2]); + break; + default: + LOG(FATAL) << "error!!!"; } - } - else if (dim_.size() == 2){ - LOG(FATAL) << "invalid dims_!!"; - } - else { + } else if (dim_.size() == 2) { + LOG(FATAL) << "invalid dims_!!"; + } else { LOG(FATAL) << "dim size should not larger than 3!!!"; } - - } - else if (x_dims_.size()==4){ + + } else if (x_dims_.size() == 4) { int in_n = x_dims_[0]; int in_c = x_dims_[1]; int in_h = x_dims_[2]; @@ -384,13 +376,9 @@ class ReduceMaxComputeTester : public arena::TestCase { reduce_hw(x_data, out_data, in_n, in_c, in_h, in_w); } else { LOG(FATAL) << "invalid dims_!!"; - } + } } - } - - - } void PrepareOpDesc(cpp::OpDesc* op_desc) { @@ -434,26 +422,23 @@ void test_reduce_max(Place place) { } void test_reduce_max_for_three(Place place) { - std::vector> reduce_dim{ - {0}, {1}, {2}}; + std::vector> reduce_dim{{0}, {1}, {2}}; for (auto f : {1, 3}) { for (auto s : {1, 2}) { for (auto t : {1, 3}) { for (bool keep_dim : {false, true}) { for (auto dim : reduce_dim) { auto x_dims = DDim(std::vector({f, s, t})); - std::unique_ptr tester( - new ReduceMaxComputeTester( - place, "def", dim, keep_dim, x_dims)); + std::unique_ptr tester(new ReduceMaxComputeTester( + place, "def", dim, keep_dim, x_dims)); arena::Arena arena(std::move(tester), place, 2e-5); arena.TestPrecision(); - } } } } } } - +} TEST(ReduceMax, precision) { // #ifdef LITE_WITH_X86