diff --git a/mace/kernels/neon/max_pooling_neon_2x2.cc b/mace/kernels/neon/max_pooling_neon_2x2.cc index 5ac9daa8d33366e20c32c14b1f74d53ca0627492..69743b33cb4886d88c645229e84a567184bab0a2 100644 --- a/mace/kernels/neon/max_pooling_neon_2x2.cc +++ b/mace/kernels/neon/max_pooling_neon_2x2.cc @@ -61,6 +61,8 @@ void PoolingMaxNeonK2x2S2x2(const float *input, } } + w += num_vectors << 2; + for (; num_vectors > 0; --num_vectors) { float32x4_t r00 = vld1q_f32(r0); float32x4_t r10 = vld1q_f32(r1); @@ -79,7 +81,6 @@ void PoolingMaxNeonK2x2S2x2(const float *input, outptr += 4; } - w += num_vectors << 2; for (; w < out_width; ++w) { float max = std::numeric_limits::lowest(); for (int kh = 0; kh < 2; ++kh) { diff --git a/mace/kernels/neon/max_pooling_neon_3x3.cc b/mace/kernels/neon/max_pooling_neon_3x3.cc index 85b9e9648beb45e64334ff7dc9517b7baabcc50a..5a8bf246c9d338b9e777df2caeabda81bf86c47b 100644 --- a/mace/kernels/neon/max_pooling_neon_3x3.cc +++ b/mace/kernels/neon/max_pooling_neon_3x3.cc @@ -71,6 +71,7 @@ void PoolingMaxNeonK3x3S2x2(const float *input, } } + w += num_vectors << 2; float32x4x2_t row0 = vld2q_f32(r0); float32x4x2_t row1 = vld2q_f32(r1); float32x4x2_t row2 = vld2q_f32(r2); @@ -105,7 +106,6 @@ void PoolingMaxNeonK3x3S2x2(const float *input, outptr += 4; } - w += num_vectors << 2; for (; w < out_width; ++w) { float max = std::numeric_limits::lowest(); for (int kh = 0; kh < 3; ++kh) { diff --git a/mace/kernels/pooling.h b/mace/kernels/pooling.h index 78ce0ec5fdd1b2e14ef7d38c94ee5cc308fe845a..7925a42595ebac8155b456da81dcb02166497fed 100644 --- a/mace/kernels/pooling.h +++ b/mace/kernels/pooling.h @@ -39,10 +39,12 @@ class PoolingFunctor { index_t channels = output_shape[1]; index_t height = output_shape[2]; index_t width = output_shape[3]; + index_t out_image_size = height * width; index_t input_channels = input_shape[1]; index_t input_height = input_shape[2]; index_t input_width = input_shape[3]; + index_t in_image_size = input_height * input_width; int kernel_h = kernels_[0]; int kernel_w = kernels_[1]; @@ -57,56 +59,55 @@ class PoolingFunctor { int padded_h_start = 0 - paddings_[0] / 2; int padded_w_start = 0 - paddings_[1] / 2; + if (pooling_type_ == MAX) { #pragma omp parallel for collapse(2) - for (int n = 0; n < batch; ++n) { - for (int c = 0; c < channels; ++c) { - index_t out_offset = n * channels * height * width + c * height * width; - index_t in_offset = n * input_channels * input_height * input_width + - c * input_height * input_width; - for (int h = 0; h < height; ++h) { - for (int w = 0; w < width; ++w) { - T sum_or_max = 0; - switch (pooling_type_) { - case AVG: - break; - case MAX: - sum_or_max = std::numeric_limits::lowest(); - break; - default: - MACE_CHECK(false, "Unsupported pooling type: ", pooling_type_); - } - for (int kh = 0; kh < kernel_h; ++kh) { - for (int kw = 0; kw < kernel_w; ++kw) { - int inh = padded_h_start + h * stride_h + dilation_h * kh; - int inw = padded_w_start + w * stride_w + dilation_w * kw; - if (inh >= 0 && inh < input_height && inw >= 0 && - inw < input_width) { - index_t input_offset = in_offset + inh * input_width + inw; - switch (pooling_type_) { - case AVG: - sum_or_max += input[input_offset]; - break; - case MAX: - sum_or_max = std::max(sum_or_max, input[input_offset]); - break; - default: - MACE_CHECK(false, "Unsupported pooling type: ", - pooling_type_); + for (int b = 0; b < batch; ++b) { + for (int c = 0; c < channels; ++c) { + index_t out_offset = (b * channels + c) * out_image_size; + index_t in_offset = (b * input_channels + c) * in_image_size; + for (int h = 0; h < height; ++h) { + for (int w = 0; w < width; ++w) { + T max = std::numeric_limits::lowest(); + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + int inh = padded_h_start + h * stride_h + dilation_h * kh; + int inw = padded_w_start + w * stride_w + dilation_w * kw; + if (inh >= 0 && inh < input_height && inw >= 0 && + inw < input_width) { + index_t input_offset = in_offset + inh * input_width + inw; + max = std::max(max, input[input_offset]); } } } + output[out_offset] = max; + out_offset += 1; } - switch (pooling_type_) { - case AVG: - output[out_offset] = sum_or_max / (kernel_h * kernel_w); - break; - case MAX: - output[out_offset] = sum_or_max; - break; - default: - MACE_CHECK(false, "Unsupported pooling type: ", pooling_type_); + } + } + } + } else if (pooling_type_ == AVG) { +#pragma omp parallel for collapse(2) + for (int b = 0; b < batch; ++b) { + for (int c = 0; c < channels; ++c) { + index_t out_offset = (b * channels + c) * out_image_size; + index_t in_offset = (b * input_channels + c) * in_image_size; + for (int h = 0; h < height; ++h) { + for (int w = 0; w < width; ++w) { + T sum = 0; + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + int inh = padded_h_start + h * stride_h + dilation_h * kh; + int inw = padded_w_start + w * stride_w + dilation_w * kw; + if (inh >= 0 && inh < input_height && inw >= 0 && + inw < input_width) { + index_t input_offset = in_offset + inh * input_width + inw; + sum += input[input_offset]; + } + } + } + output[out_offset] = sum / (kernel_h * kernel_w); + out_offset += 1; } - out_offset += 1; } } } diff --git a/mace/ops/pooling.h b/mace/ops/pooling.h index 597a4724be631d41aaffbbea85ceccd6a92d280a..3afb0f71a522a741ac62b9e6ca725efff127ad01 100644 --- a/mace/ops/pooling.h +++ b/mace/ops/pooling.h @@ -24,16 +24,17 @@ class PoolingOp : public ConvPool2dOpBase { bool Run() override { const Tensor* input = this->Input(INPUT); Tensor* output = this->Output(OUTPUT); - std::vector in_shape = input->shape(); std::vector output_shape(4); std::vector paddings(2); - std::vector filter_shape = std::vector(4); - filter_shape[0] = in_shape[1]; - filter_shape[1] = in_shape[0]; + std::vector filter_shape(4); + filter_shape[0] = input->shape()[1]; + filter_shape[1] = input->shape()[0]; filter_shape[2] = kernels_[0]; filter_shape[3] = kernels_[1]; - kernels::CalcPaddingAndOutputSize(in_shape.data(), filter_shape.data(), + + kernels::CalcPaddingAndOutputSize(input->shape().data(), + filter_shape.data(), this->dilations_.data(), this->strides_.data(), this->padding_, output_shape.data(), paddings.data()); @@ -42,7 +43,7 @@ class PoolingOp : public ConvPool2dOpBase { auto pooling_func = kernels::PoolingFunctor( pooling_type_, kernels_.data(), this->strides_.data(), paddings.data(), this->dilations_.data()); - pooling_func(input->data(), in_shape.data(), + pooling_func(input->data(), input->shape().data(), output->mutable_data(), output->shape().data()); return true; };