From c0f7ecb4f29e569332895d02cb714dd9718bc642 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Sat, 16 Mar 2019 18:28:30 +0800 Subject: [PATCH] Optimize general col2im to speed up transpose conv --- src/operators/math/im2col.cpp | 348 ++++++++++++++++++---------------- src/operators/math/im2col.h | 12 ++ 2 files changed, 198 insertions(+), 162 deletions(-) diff --git a/src/operators/math/im2col.cpp b/src/operators/math/im2col.cpp index fedd17ed0c..02e6b1c6f9 100644 --- a/src/operators/math/im2col.cpp +++ b/src/operators/math/im2col.cpp @@ -22,10 +22,13 @@ namespace paddle_mobile { namespace operators { namespace math { -void ExtractToImg(const float *im_data, float *col_data, const int im_height, - const int im_width, const int col_height, const int col_width, - const int padding_h, const int padding_w, const int stride_h, - const int stride_w, const int kh, const int kw) { +template <> +void ExtractToImg(const float *im_data, float *col_data, + const int im_height, const int im_width, + const int col_height, const int col_width, + const int padding_h, const int padding_w, + const int stride_h, const int stride_w, const int kh, + const int kw) { int h = padding_h - kh; int w = padding_w - kw; int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0; @@ -41,48 +44,43 @@ void ExtractToImg(const float *im_data, float *col_data, const int im_height, im_data += start_height * im_width + start_width; col_data += col_start_height * col_width + col_start_width; - for (int i = start_height; i < end_height; i += stride_h) { + int s = 0; if (stride_w == 1) { - // memcpy(col_data, im_data, extract * sizeof(float)); - int s = 0; #if __ARM_NEON for (; s < extract - 3; s += 4) { - float32x4_t img = vld1q_f32(im_data + s); - vst1q_f32(col_data + s, img); + float32x4_t _img = vld1q_f32(im_data + s); + vst1q_f32(col_data + s, _img); } #endif for (; s < extract; ++s) { col_data[s] = im_data[s]; } } else if (stride_w == 2) { - int s = 0; #if __ARM_NEON for (; s < extract - 3; s += 4) { - float32x4x2_t img = vld2q_f32(im_data + s * 2); - vst1q_f32(col_data + s, img.val[0]); + float32x4x2_t _img = vld2q_f32(im_data + s * 2); + vst1q_f32(col_data + s, _img.val[0]); } #endif for (; s < extract; ++s) { col_data[s] = im_data[s * 2]; } } else if (stride_w == 3) { - int s = 0; #if __ARM_NEON for (; s < extract - 3; s += 4) { - float32x4x3_t img = vld3q_f32(im_data + s * 3); - vst1q_f32(col_data + s, img.val[0]); + float32x4x3_t _img = vld3q_f32(im_data + s * 3); + vst1q_f32(col_data + s, _img.val[0]); } #endif for (; s < extract; ++s) { col_data[s] = im_data[s * 3]; } } else if (stride_w == 4) { - int s = 0; #if __ARM_NEON for (; s < extract - 3; s += 4) { - float32x4x4_t img = vld4q_f32(im_data + s * 4); - vst1q_f32(col_data + s, img.val[0]); + float32x4x4_t _img = vld4q_f32(im_data + s * 4); + vst1q_f32(col_data + s, _img.val[0]); } #endif for (; s < extract; ++s) { @@ -96,77 +94,13 @@ void ExtractToImg(const float *im_data, float *col_data, const int im_height, } } -/* - * im = [input_channels, input_height, input_width] - * col = - * [input_channels, filter_height, filter_width, output_height, - * output_width] - */ template <> -void Im2ColFunctor::operator()( - const framework::Tensor &im, const std::vector &dilation, - const std::vector &stride, const std::vector &padding, - framework::Tensor *col) { - int im_channels = im.dims()[0]; - int im_height = im.dims()[1]; - int im_width = im.dims()[2]; - int filter_height = col->dims()[1]; - int filter_width = col->dims()[2]; - int col_height = col->dims()[3]; - int col_width = col->dims()[4]; - - int channels_col = im_channels * filter_height * filter_width; - const float *im_data = im.data(); - float *col_data = col->data(); -#if __ARM_NEON - if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) { - int im_spatial_size = im_height * im_width; - int col_spatial_size = col_height * col_width; - // pad 0 - memset(col_data, 0, col->numel() * sizeof(float)); - #pragma omp parallel for - for (int ic = 0; ic < im_channels; ++ic) { - const float *local_im_data = im_data + ic * im_spatial_size; - float *local_col_data = - col_data + ic * filter_height * filter_width * col_spatial_size; - for (int kh = 0; kh < filter_height; ++kh) { - for (int kw = 0; kw < filter_width; ++kw) { - ExtractToImg(local_im_data, local_col_data, im_height, im_width, - col_height, col_width, padding[0], padding[1], stride[0], - stride[1], kh, kw); - local_col_data += col_spatial_size; - } - } - } - } else { -#endif - for (int c = 0; c < channels_col; ++c) { - int w_offset = c % filter_width; - int h_offset = (c / filter_width) % filter_height; - int c_im = c / (filter_width * filter_height); - for (int h = 0; h < col_height; ++h) { - int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; - for (int w = 0; w < col_width; ++w) { - int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; - int col_idx = (c * col_height + h) * col_width + w; - int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx; - - col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height || - im_col_idx < 0 || im_col_idx >= im_width) - ? static_cast(0) - : im_data[im_idx]; - } - } - } -#if __ARM_NEON - } -#endif -} - -void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height, - const int im_width, const int col_height, const int col_width, - const int padding_h, const int padding_w, const int stride_h, - const int stride_w, const int kh, const int kw) { +void ExtractToImg(const int8_t *im_data, int8_t *col_data, + const int im_height, const int im_width, + const int col_height, const int col_width, + const int padding_h, const int padding_w, + const int stride_h, const int stride_w, const int kh, + const int kw) { int h = padding_h - kh; int w = padding_w - kw; int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0; @@ -183,21 +117,26 @@ void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height, im_data += start_height * im_width + start_width; col_data += col_start_height * col_width + col_start_width; for (int i = start_height; i < end_height; i += stride_h) { + int s = 0; if (stride_w == 1) { - memcpy(col_data, im_data, extract * sizeof(int8_t)); + for (; s < extract - 15; s += 16) { + int8x16_t _img = vld1q_s8(im_data + s); + vst1q_s8(col_data + s, _img); + } + for (; s < extract; ++s) { + col_data[s] = im_data[s]; + } } else if (stride_w == 2) { - int s = 0; #if __ARM_NEON for (; s < extract - 15; s += 16) { - int8x16x2_t img = vld2q_s8(im_data + s * 2); - vst1q_s8(col_data + s, img.val[0]); + int8x16x2_t _img = vld2q_s8(im_data + s * 2); + vst1q_s8(col_data + s, _img.val[0]); } #endif for (; s < extract; ++s) { col_data[s] = im_data[s * 2]; } } else if (stride_w == 3) { - int s = 0; #if __ARM_NEON for (; s < extract - 15; s += 16) { int8x16x3_t img = vld3q_s8(im_data + s * 3); @@ -208,7 +147,6 @@ void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height, col_data[s] = im_data[s * 3]; } } else if (stride_w == 4) { - int s = 0; #if __ARM_NEON for (; s < extract - 15; s += 16) { int8x16x4_t img = vld4q_s8(im_data + s * 4); @@ -232,65 +170,128 @@ void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height, * [input_channels, filter_height, filter_width, output_height, * output_width] */ -template <> -void Im2ColFunctor::operator()( - const framework::Tensor &im, const std::vector &dilation, - const std::vector &stride, const std::vector &padding, - framework::Tensor *col) { - int im_channels = im.dims()[0]; - int im_height = im.dims()[1]; - int im_width = im.dims()[2]; - int filter_height = col->dims()[1]; - int filter_width = col->dims()[2]; - int col_height = col->dims()[3]; - int col_width = col->dims()[4]; - - int channels_col = im_channels * filter_height * filter_width; - const int8_t *im_data = im.data(); - int8_t *col_data = col->mutable_data(); -#if defined(__ARM_NEON__) || defined(__ARM_NEON) - if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) { - int im_spatial_size = im_height * im_width; - int col_spatial_size = col_height * col_width; - // pad 0 - memset(col_data, 0, col->numel() * sizeof(int8_t)); - #pragma omp parallel for - for (int ic = 0; ic < im_channels; ++ic) { - const int8_t *local_im_data = im_data + ic * im_spatial_size; - int8_t *local_col_data = - col_data + ic * filter_height * filter_width * col_spatial_size; - for (int kh = 0; kh < filter_height; ++kh) { - for (int kw = 0; kw < filter_width; ++kw) { - ExtractToImg(local_im_data, local_col_data, im_height, im_width, - col_height, col_width, padding[0], padding[1], stride[0], - stride[1], kh, kw); - local_col_data += col_spatial_size; +template +class Im2ColFunctor { + public: + void operator()(const framework::Tensor &im, const std::vector &dilation, + const std::vector &stride, + const std::vector &padding, framework::Tensor *col) { + int im_channels = im.dims()[0]; + int im_height = im.dims()[1]; + int im_width = im.dims()[2]; + int filter_height = col->dims()[1]; + int filter_width = col->dims()[2]; + int col_height = col->dims()[3]; + int col_width = col->dims()[4]; + + int channels_col = im_channels * filter_height * filter_width; + const T *im_data = im.data(); + T *col_data = col->data(); +#if __ARM_NEON + if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) { + int im_spatial_size = im_height * im_width; + int col_spatial_size = col_height * col_width; + // pad 0 + memset(col_data, 0, col->numel() * sizeof(T)); + + #pragma omp parallel for + for (int ic = 0; ic < im_channels; ++ic) { + const T *local_im_data = im_data + ic * im_spatial_size; + T *local_col_data = + col_data + ic * filter_height * filter_width * col_spatial_size; + for (int kh = 0; kh < filter_height; ++kh) { + for (int kw = 0; kw < filter_width; ++kw) { + ExtractToImg(local_im_data, local_col_data, im_height, im_width, + col_height, col_width, padding[0], padding[1], + stride[0], stride[1], kh, kw); + local_col_data += col_spatial_size; + } } } - } - } else { + } else { #endif - for (int c = 0; c < channels_col; ++c) { - int w_offset = c % filter_width; - int h_offset = (c / filter_width) % filter_height; - int c_im = c / (filter_width * filter_height); - for (int h = 0; h < col_height; ++h) { - int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; - for (int w = 0; w < col_width; ++w) { - int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; - int col_idx = (c * col_height + h) * col_width + w; - int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx; - - col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height || - im_col_idx < 0 || im_col_idx >= im_width) - ? static_cast(0) - : im_data[im_idx]; + for (int c = 0; c < channels_col; ++c) { + int w_offset = c % filter_width; + int h_offset = (c / filter_width) % filter_height; + int c_im = c / (filter_width * filter_height); + for (int h = 0; h < col_height; ++h) { + int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; + for (int w = 0; w < col_width; ++w) { + int im_col_idx = + w * stride[1] - padding[1] + w_offset * dilation[1]; + int col_idx = (c * col_height + h) * col_width + w; + int im_idx = + (im_row_idx + c_im * im_height) * im_width + im_col_idx; + + col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height || + im_col_idx < 0 || im_col_idx >= im_width) + ? static_cast(0) + : im_data[im_idx]; + } } } +#if __ARM_NEON } -#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#endif } +}; + +template <> +void ExtendToImg(const float *col_data, float *im_data, + const int im_height, const int im_width, + const int col_height, const int col_width, + const int padding_h, const int padding_w, + const int stride_h, const int stride_w, const int kh, + const int kw) { + int h = padding_h - kh; + int w = padding_w - kw; + int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0; + int col_start_width = w > 0 ? (w + stride_w - 1) / stride_w : 0; + int start_height = kh + col_start_height * stride_h - padding_h; + int start_width = kw + col_start_width * stride_w - padding_w; + + int end_height = (col_height - col_start_height) * stride_h + start_height; + end_height = end_height > im_height ? im_height : end_height; + int end_width = (col_width - col_start_width) * stride_w + start_width; + end_width = end_width > im_width ? im_width : end_width; + // int extract = (end_width - start_width + stride_w - 1) / stride_w; + int extend = end_width - start_width; + + im_data += start_height * im_width + start_width; + col_data += col_start_height * col_width + col_start_width; + + for (int i = start_height; i < end_height; i += stride_h) { + int s = 0; + if (stride_w == 1) { +#if __ARM_NEON + for (; s < extend - 3; s += 4) { + float32x4_t _col = vld1q_f32(col_data + s); + float32x4_t _img = vld1q_f32(im_data + s); + _img = vaddq_f32(_img, _col); + vst1q_f32(im_data + s, _img); + } #endif + for (; s < extend; ++s) { + im_data[s] += col_data[s]; + } + } else if (stride_w == 2) { +#if __ARM_NEON + for (; s < extend - 7; s += 8) { + float32x4_t _col = vld1q_f32(col_data + s / 2); + float32x4x2_t _img = vld2q_f32(im_data + s); + _img.val[0] = vaddq_f32(_img.val[0], _col); + vst2q_f32(im_data + s, _img); + } +#endif + for (; s < extend; s += 2) { + im_data[s] += col_data[s / 2]; + } + } else { + PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1 and 2."); + } + im_data += im_width * stride_h; + col_data += col_width; + } } /* @@ -306,8 +307,6 @@ class Col2ImFunctor { const std::vector &dilation, const std::vector &stride, const std::vector &padding, framework::Tensor *im) { - // PADDLE_ENFORCE(im->dims().size() == 3); - // PADDLE_ENFORCE(col.dims().size() == 5); int im_channels = im->dims()[0]; int im_height = im->dims()[1]; int im_width = im->dims()[2]; @@ -317,34 +316,59 @@ class Col2ImFunctor { int col_width = col.dims()[4]; int channels_col = im_channels * filter_height * filter_width; - - T *im_data = im->data(); const T *col_data = col.data(); + T *im_data = im->data(); memset(static_cast(im_data), 0, sizeof(T) * im->numel()); - for (int c = 0; c < channels_col; ++c) { - int w_offset = c % filter_width; - int h_offset = (c / filter_width) % filter_height; - int c_im = c / (filter_width * filter_height); - for (int h = 0; h < col_height; ++h) { - int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; - for (int w = 0; w < col_width; ++w) { - int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; - if ((im_row_idx) >= 0 && (im_row_idx) < im_height && - (im_col_idx) >= 0 && (im_col_idx) < im_width) { - im_data[(im_row_idx + c_im * im_height) * im_width + im_col_idx] += - col_data[(c * col_height + h) * col_width + w]; +#if __ARM_NEON + if (stride[0] <= 2 && dilation[0] == 1 && dilation[0] == dilation[1]) { + int im_spatial_size = im_height * im_width; + int col_spatial_size = col_height * col_width; + + #pragma omp parallel for + for (int ic = 0; ic < im_channels; ++ic) { + T *local_im_data = im_data + ic * im_spatial_size; + const T *local_col_data = + col_data + ic * filter_height * filter_width * col_spatial_size; + for (int kh = 0; kh < filter_height; ++kh) { + for (int kw = 0; kw < filter_width; ++kw) { + ExtendToImg(local_col_data, local_im_data, im_height, im_width, + col_height, col_width, padding[0], padding[1], + stride[0], stride[1], kh, kw); + local_col_data += col_spatial_size; } } } + } else { +#endif + for (int c = 0; c < channels_col; ++c) { + int w_offset = c % filter_width; + int h_offset = (c / filter_width) % filter_height; + int c_im = c / (filter_width * filter_height); + for (int h = 0; h < col_height; ++h) { + int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; + for (int w = 0; w < col_width; ++w) { + int im_col_idx = + w * stride[1] - padding[1] + w_offset * dilation[1]; + if ((im_row_idx) >= 0 && (im_row_idx) < im_height && + (im_col_idx) >= 0 && (im_col_idx) < im_width) { + im_data[(im_row_idx + c_im * im_height) * im_width + + im_col_idx] += + col_data[(c * col_height + h) * col_width + w]; + } + } + } + } +#if __ARM_NEON } +#endif } }; template class Im2ColFunctor; template class Im2ColFunctor; template class Col2ImFunctor; -template class Col2ImFunctor; +// template class Col2ImFunctor; /* * im = [input_channels, input_height, input_width] diff --git a/src/operators/math/im2col.h b/src/operators/math/im2col.h index fd557ac7c5..f6b17c074e 100644 --- a/src/operators/math/im2col.h +++ b/src/operators/math/im2col.h @@ -25,6 +25,18 @@ namespace math { * Col2ImFunctor. */ enum class ColFormat { kCFO = 0, kOCF = 1 }; +template +void ExtractToImg(const T *im_data, T *col_data, const int im_height, + const int im_width, const int col_height, const int col_width, + const int padding_h, const int padding_w, const int stride_h, + const int stride_w, const int kh, const int kw); + +template +void ExtendToImg(const T *col_data, T *im_data, const int im_height, + const int im_width, const int col_height, const int col_width, + const int padding_h, const int padding_w, const int stride_h, + const int stride_w, const int kh, const int kw); + /* * \brief Converts the image data of three dimensions(CHW) into a * colData of -- GitLab