diff --git a/paddle/fluid/operators/math/im2col.cc b/paddle/fluid/operators/math/im2col.cc index a50b9ace39249f4f899a46e171bbdced033b46bc..1a42ef12647cc397c0642808987425a9677127a2 100644 --- a/paddle/fluid/operators/math/im2col.cc +++ b/paddle/fluid/operators/math/im2col.cc @@ -40,22 +40,46 @@ class Im2ColFunctordims()[1]; int filter_width = col->dims()[2]; - int col_height = col->dims()[3]; - int col_width = col->dims()[4]; + int output_height = col->dims()[3]; + int output_width = col->dims()[4]; int channels_col = im_channels * filter_height * filter_width; const T* im_data = im.data(); T* col_data = col->data(); + // TODO(TJ): change me to template + // further optimaze: + // 1. padding != 1 + // 2. could also support stride_h != 1 + if (stride[0] == 1 && stride[1] == 1 && dilation[0] == 1 && + dilation[1] == 1 && padding[0] == 0 && padding[1] == 0) { + int col_matrix_width = output_width * output_height; + for (int oh = 0; oh < output_height; ++oh) { + const T* im_data_start = im_data + oh * im_width; + T* dst_data = col_data + oh * output_width; + for (int ic = 0; ic < im_channels; ++ic) { + const T* src_data = im_data_start + ic * im_height * im_width; + for (int kh = 0; kh < filter_height; ++kh) { + for (int kw = 0; kw < filter_width; ++kw) { + std::memcpy(dst_data, src_data + kw, sizeof(T) * output_width); + dst_data = dst_data + col_matrix_width; + } + src_data = src_data + im_width; + } + } + } + return; + } + for (int c = 0; c < channels_col; ++c) { int w_offset = c % filter_width; int h_offset = (c / filter_width) % filter_height; int c_im = c / (filter_width * filter_height); - for (int h = 0; h < col_height; ++h) { + for (int h = 0; h < output_height; ++h) { int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; - for (int w = 0; w < col_width; ++w) { + for (int w = 0; w < output_width; ++w) { int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; - int col_idx = (c * col_height + h) * col_width + w; + int col_idx = (c * output_height + h) * output_width + w; int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx; col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||