From 293b292e0ff3e6055dceb807c4cb57fc7bacb226 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 19 Dec 2017 17:00:55 +0800 Subject: [PATCH] refine im2col --- paddle/operators/math/im2col.cc | 39 +++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index 707ebf059..a746c267b 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -61,14 +61,22 @@ class Im2ColFunctor(); T* col_data = col->data(); - + int w_offset = -1; + int h_offset = 0; + int c_im = 0; for (int c = 0; c < channels_col; ++c) { - int w_offset = c % filter_width; - int h_offset = (c / filter_width) % filter_height; - int c_im = c / filter_width / filter_height; + ++w_offset; + if (UNLIKELY(w_offset == filter_width)) { + w_offset = 0; + ++h_offset; + if (UNLIKELY(h_offset == filter_height)) { + h_offset = 0; + ++c_im; + } + } for (int h = 0; h < col_height; ++h) { + int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; for (int w = 0; w < col_width; ++w) { - int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; int col_idx = (c * col_height + h) * col_width + w; int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx; @@ -127,19 +135,26 @@ class Col2ImFunctordata(); const T* col_data = col.data(); + int w_offset = -1; + int h_offset = 0; + int c_im = 0; for (int c = 0; c < channels_col; ++c) { - int w_offset = c % filter_width; - int h_offset = (c / filter_width) % filter_height; - int c_im = c / filter_width / filter_height; + ++w_offset; + if (UNLIKELY(w_offset == filter_width)) { + w_offset = 0; + ++h_offset; + if (UNLIKELY(h_offset == filter_height)) { + h_offset = 0; + ++c_im; + } + } for (int h = 0; h < col_height; ++h) { + int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; for (int w = 0; w < col_width; ++w) { - int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; - if ((im_row_idx) >= 0 && (im_row_idx) < im_height && (im_col_idx) >= 0 && (im_col_idx) < im_width) { - im_row_idx += c_im * im_height; - im_data[im_row_idx * im_width + im_col_idx] += + im_data[(im_row_idx + c_im * im_height) * im_width + im_col_idx] += col_data[(c * col_height + h) * col_width + w]; } } -- GitLab