diff --git a/src/operators/math/im2col.cpp b/src/operators/math/im2col.cpp index 02e6b1c6f99940a9780324752c37dbf5d7172f96..a7b97e5bfca6c5a9753d6a9e664bc4a8ee5450f6 100644 --- a/src/operators/math/im2col.cpp +++ b/src/operators/math/im2col.cpp @@ -12,12 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "operators/math/im2col.h" #include #ifdef __ARM_NEON #include #endif +#include #include "common/types.h" +#include "operators/math/im2col.h" + namespace paddle_mobile { namespace operators { namespace math { @@ -294,6 +296,173 @@ void ExtendToImg(const float *col_data, float *im_data, } } +template <> +void ExtendToImgV2(const float *col_data, float *im_data, + const int im_height, const int im_width, + const int col_height, const int col_width, + const int padding_h, const int padding_w, + const int stride_h, const int stride_w, const int kh, + const int kernel_w) { + int col_spatial_size = col_height * col_width; + int h = padding_h - kh; + int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0; + int start_height = kh + col_start_height * stride_h - padding_h; + int end_height = (col_height - col_start_height) * stride_h + start_height; + end_height = end_height > im_height ? im_height : end_height; + im_data += start_height * im_width; + col_data += col_start_height * col_width; + + int kw = 0; + for (; kw < kernel_w - 1; kw += 2) { + int w0 = padding_w - kw; + int w1 = padding_w - (kw + 1); + int col_start_width0 = w0 > 0 ? (w0 + stride_w - 1) / stride_w : 0; + int col_start_width1 = w1 > 0 ? (w1 + stride_w - 1) / stride_w : 0; + int start_width0 = kw + col_start_width0 * stride_w - padding_w; + int start_width1 = (kw + 1) + col_start_width1 * stride_w - padding_w; + + int end_width0 = (col_width - col_start_width0) * stride_w + start_width0; + end_width0 = end_width0 > im_width ? im_width : end_width0; + int end_width1 = (col_width - col_start_width1) * stride_w + start_width1; + end_width1 = end_width1 > im_width ? im_width : end_width1; + int start_width = 0; + int end_width = 0; + if (stride_w == 1) { + start_width = std::max(start_width0, start_width1); + end_width = std::min(end_width0, end_width1); + } else if (stride_w == 2) { + start_width = std::min(start_width0, start_width1); + end_width = std::min(end_width0, end_width1); + } else { + PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1 and 2."); + } + + // DLOG << "start_width0: " << start_width0 << ", end_width0: " << + // end_width0; DLOG << "start_width1: " << start_width1 << ", end_width1: + // " << end_width1; + int extend = end_width - start_width; + float *im_data01 = im_data + start_width; + float *im_data0 = im_data + start_width0; + float *im_data1 = im_data + start_width1; + const float *col_data0 = col_data + col_start_width0; + const float *col_data1 = col_data + col_spatial_size + col_start_width1; + + for (int i = start_height; i < end_height; i += stride_h) { + int s = 0; + if (stride_w == 1) { + int offset0 = start_width - start_width0; + int offset1 = start_width - start_width1; + for (int ss = 0; ss < start_width - start_width0; ++ss) { + im_data0[ss] += col_data0[ss]; + } + for (int ss = 0; ss < start_width - start_width1; ++ss) { + im_data1[ss] += col_data1[ss]; + } +#if __ARM_NEON + for (; s < extend - 3; s += 4) { + float32x4_t _col0 = vld1q_f32(col_data0 + offset0 + s); + float32x4_t _col1 = vld1q_f32(col_data1 + offset1 + s); + float32x4_t _img = vld1q_f32(im_data01 + s); + _img = vaddq_f32(_img, _col0); + _img = vaddq_f32(_img, _col1); + vst1q_f32(im_data01 + s, _img); + } +#endif + for (int ss = s; ss < end_width0 - start_width0; ++ss) { + im_data0[ss] += col_data0[ss]; + } + for (int ss = s; ss < end_width1 - start_width1; ++ss) { + im_data1[ss] += col_data1[ss]; + } + } else if (stride_w == 2) { + if (start_width0 < start_width1) { +#if __ARM_NEON + for (; s < extend - 7; s += 8) { + float32x4_t _col0 = vld1q_f32(col_data0 + s / 2); + float32x4_t _col1 = vld1q_f32(col_data1 + s / 2); + float32x4x2_t _img = vld2q_f32(im_data01 + s); + _img.val[0] = vaddq_f32(_img.val[0], _col0); + _img.val[1] = vaddq_f32(_img.val[1], _col1); + vst2q_f32(im_data01 + s, _img); + } +#endif + } else { +#if __ARM_NEON + for (; s < extend - 7; s += 8) { + float32x4_t _col0 = vld1q_f32(col_data0 + s / 2); + float32x4_t _col1 = vld1q_f32(col_data1 + s / 2); + float32x4x2_t _img = vld2q_f32(im_data01 + s); + _img.val[0] = vaddq_f32(_img.val[0], _col1); + _img.val[1] = vaddq_f32(_img.val[1], _col0); + vst2q_f32(im_data01 + s, _img); + } +#endif + } + for (int ss = s; ss < end_width0 - start_width0; ss += 2) { + im_data0[ss] += col_data0[ss / 2]; + } + for (int ss = s; ss < end_width1 - start_width1; ss += 2) { + im_data1[ss] += col_data1[ss / 2]; + } + } + + im_data0 += im_width * stride_h; + im_data1 += im_width * stride_h; + im_data01 += im_width * stride_h; + col_data0 += col_width; + col_data1 += col_width; + } + col_data += 2 * col_spatial_size; + } + + for (; kw < kernel_w; ++kw) { + int w = padding_w - kw; + int col_start_width = w > 0 ? (w + stride_w - 1) / stride_w : 0; + int start_width = kw + col_start_width * stride_w - padding_w; + + int end_width = (col_width - col_start_width) * stride_w + start_width; + end_width = end_width > im_width ? im_width : end_width; + int extend = end_width - start_width; + + float *im_data0 = im_data + start_width; + const float *col_data0 = col_data + col_start_width; + + for (int i = start_height; i < end_height; i += stride_h) { + int s = 0; + if (stride_w == 1) { +#if __ARM_NEON + for (; s < extend - 3; s += 4) { + float32x4_t _col = vld1q_f32(col_data + s); + float32x4_t _img = vld1q_f32(im_data + s); + _img = vaddq_f32(_img, _col); + vst1q_f32(im_data + s, _img); + } +#endif + for (; s < extend; ++s) { + im_data[s] += col_data[s]; + } + } else if (stride_w == 2) { +#if __ARM_NEON + for (; s < extend - 7; s += 8) { + float32x4_t _col = vld1q_f32(col_data + s / 2); + float32x4x2_t _img = vld2q_f32(im_data + s); + _img.val[0] = vaddq_f32(_img.val[0], _col); + vst2q_f32(im_data + s, _img); + } +#endif + for (; s < extend; s += 2) { + im_data[s] += col_data[s / 2]; + } + } else { + PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1 and 2."); + } + im_data += im_width * stride_h; + col_data += col_width; + } + col_data += col_spatial_size; + } +} + /* * im = [input_channels, input_height, input_width] * col = @@ -331,12 +500,19 @@ class Col2ImFunctor { const T *local_col_data = col_data + ic * filter_height * filter_width * col_spatial_size; for (int kh = 0; kh < filter_height; ++kh) { +#if 0 for (int kw = 0; kw < filter_width; ++kw) { ExtendToImg(local_col_data, local_im_data, im_height, im_width, col_height, col_width, padding[0], padding[1], stride[0], stride[1], kh, kw); local_col_data += col_spatial_size; } +#else + ExtendToImgV2(local_col_data, local_im_data, im_height, im_width, + col_height, col_width, padding[0], padding[1], + stride[0], stride[1], kh, filter_width); + local_col_data += col_spatial_size * filter_width; +#endif } } } else { diff --git a/src/operators/math/im2col.h b/src/operators/math/im2col.h index f6b17c074e621c11a401b2bac8463e97a9699f79..347f72c9177d7492d0f41d2a1abace9335422d34 100644 --- a/src/operators/math/im2col.h +++ b/src/operators/math/im2col.h @@ -37,6 +37,13 @@ void ExtendToImg(const T *col_data, T *im_data, const int im_height, const int padding_h, const int padding_w, const int stride_h, const int stride_w, const int kh, const int kw); +template +void ExtendToImgV2(const T *col_data, T *im_data, const int im_height, + const int im_width, const int col_height, + const int col_width, const int padding_h, + const int padding_w, const int stride_h, const int stride_w, + const int kh, const int kernel_w); + /* * \brief Converts the image data of three dimensions(CHW) into a * colData of