diff --git a/lite/operators/conv_op.cc b/lite/operators/conv_op.cc index 9ae52d1cb6a406dc8d1059ad97f3757dbc0a31fa..d9c0ecb4fd8457782ac90850b8b6a002c7dfcffe 100644 --- a/lite/operators/conv_op.cc +++ b/lite/operators/conv_op.cc @@ -52,12 +52,12 @@ inline int ConvOutputSize(int input_size, return output_size; } -void UpdatePaddingAndDilation(std::vector* paddings, - std::vector* dilations, - const std::vector& strides, - const std::string padding_algorithm, - const lite::DDim data_dims, - const lite::DDim& ksize) { +inline void UpdatePaddingAndDilation(std::vector* paddings, + std::vector* dilations, + const std::vector& strides, + const std::string padding_algorithm, + const lite::DDim data_dims, + const lite::DDim& ksize) { // when padding_desc is "VALID" or "SAME" if (padding_algorithm == "SAME") { for (size_t i = 0; i < strides.size(); ++i) { diff --git a/lite/tests/kernels/reduce_sum_compute_test.cc b/lite/tests/kernels/reduce_sum_compute_test.cc index 9cfe213750b1191c1ef8fe7fba1b1c1035c2ae42..644d761de06fd7f117e781d5146812b7a8ae92d8 100644 --- a/lite/tests/kernels/reduce_sum_compute_test.cc +++ b/lite/tests/kernels/reduce_sum_compute_test.cc @@ -196,7 +196,7 @@ class ReduceSumComputeTester : public arena::TestCase { std::vector dim_{0}; bool keep_dim_ = false; bool reduce_all_ = false; - DDim x_dims_{{3, 2, 3, 4}}; + DDim x_dims_; public: ReduceSumComputeTester(const Place& place, @@ -255,39 +255,76 @@ class ReduceSumComputeTester : public arena::TestCase { out->Resize(DDim(out_dims)); auto* out_data = out->mutable_data(); - int in_n = x_dims_[0]; - int in_c = x_dims_[1]; - int in_h = x_dims_[2]; - int in_w = x_dims_[3]; + if (x_dims_.size() == 4) { + int in_n = x_dims_[0]; + int in_c = x_dims_[1]; + int in_h = x_dims_[2]; + int in_w = x_dims_[3]; - if (reduce_all_) { - reduce_sum_all(x_data, out_data, in_n, in_c, in_h, in_w); - } else if (dim_.size() == 1) { - switch (dim_[0]) { - case 0: - reduce_sum_n(x_data, out_data, in_n, in_c, in_h, in_w); - break; - case 1: - reduce_sum_c(x_data, out_data, in_n, in_c, in_h, in_w); - break; - case 2: - reduce_sum_h(x_data, out_data, in_n, in_c, in_h, in_w); - break; - case 3: - reduce_sum_w(x_data, out_data, in_n, in_c, in_h, in_w); - break; - default: - LOG(FATAL) << "error!!!"; + if (reduce_all_) { + reduce_sum_all(x_data, out_data, in_n, in_c, in_h, in_w); + } else if (dim_.size() == 1) { + switch (dim_[0]) { + case 0: + reduce_sum_n(x_data, out_data, in_n, in_c, in_h, in_w); + break; + case 1: + reduce_sum_c(x_data, out_data, in_n, in_c, in_h, in_w); + break; + case 2: + reduce_sum_h(x_data, out_data, in_n, in_c, in_h, in_w); + break; + case 3: + reduce_sum_w(x_data, out_data, in_n, in_c, in_h, in_w); + break; + default: + LOG(FATAL) << "error!!!"; + } + } else if (dim_.size() == 2) { + if (dim_[0] == 0 && dim_[1] == 1) { + reduce_sum_nc(x_data, out_data, in_n, in_c, in_h, in_w); + } else if (dim_[0] == 1 && dim_[1] == 2) { + reduce_sum_ch(x_data, out_data, in_n, in_c, in_h, in_w); + } else if (dim_[0] == 2 && dim_[1] == 3) { + reduce_sum_hw(x_data, out_data, in_n, in_c, in_h, in_w); + } else { + LOG(FATAL) << "invalid dims_!!"; + } } - } else if (dim_.size() == 2) { - if (dim_[0] == 0 && dim_[1] == 1) { - reduce_sum_nc(x_data, out_data, in_n, in_c, in_h, in_w); - } else if (dim_[0] == 1 && dim_[1] == 2) { - reduce_sum_ch(x_data, out_data, in_n, in_c, in_h, in_w); - } else if (dim_[0] == 2 && dim_[1] == 3) { - reduce_sum_hw(x_data, out_data, in_n, in_c, in_h, in_w); - } else { - LOG(FATAL) << "invalid dims_!!"; + } else { + if (dim_.size() == 1 && !reduce_all_) { + switch (dim_[0]) { + case 0: + for (int i = 0; i < x_dims_[1]; i++) { + for (int j = 0; j < x_dims_[2]; j++) { + out_data[i * x_dims_[2] + j] = + i * x_dims_[0] * x_dims_[2] + j * x_dims_[0] + + x_dims_[0] * (x_dims_[0] - 1) * x_dims_[1] * x_dims_[2] / 2; + } + } + break; + case 1: + for (int i = 0; i < x_dims_[0]; i++) { + for (int j = 0; j < x_dims_[2]; j++) { + out_data[i * x_dims_[2] + j] = + i * x_dims_[1] * x_dims_[1] * x_dims_[2] + j * x_dims_[1] + + x_dims_[1] * (x_dims_[1] - 1) * x_dims_[2] / 2; + } + } + break; + case 2: + for (int i = 0; i < x_dims_[0]; i++) { + for (int j = 0; j < x_dims_[1]; j++) { + out_data[i * x_dims_[1] + j] = + i * x_dims_[1] * x_dims_[2] * x_dims_[2] + + j * x_dims_[2] * x_dims_[2] + + x_dims_[2] * (x_dims_[2] - 1) / 2; + } + } + break; + default: + LOG(FATAL) << "error!!!"; + } } } } @@ -333,6 +370,21 @@ void test_reduce_sum(Place place) { } } } + + std::vector> reduce_dimm{{0}, {1}, {2}}; + for (auto dim : reduce_dimm) { + for (auto c : {1, 3}) { + for (auto h : {1, 3}) { + for (auto w : {1, 4}) { + auto x_dims = DDim(std::vector({c, h, w})); + std::unique_ptr tester(new ReduceSumComputeTester( + place, "def", dim, false, false, x_dims)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); + } + } + } + } } TEST(ReduceSum, precision) {