diff --git a/paddle/fluid/operators/math/im2col.cc b/paddle/fluid/operators/math/im2col.cc index 101e046acbbfa8838f6e204b802b6de3590480ab..311401b3d73a2cb2c4db4a93c3280686d215919b 100644 --- a/paddle/fluid/operators/math/im2col.cc +++ b/paddle/fluid/operators/math/im2col.cc @@ -48,29 +48,63 @@ class Im2ColFunctor(); T* col_data = col->data(); // TODO(TJ): change me to template - // further optimaze: - // 1. padding != 1 - // 2. could also support stride_h != 1 + // further optimize: padding == 1 need special if (stride[0] == 1 && stride[1] == 1 && dilation[0] == 1 && - dilation[1] == 1 && padding[0] == 0 && padding[1] == 0) { + dilation[1] == 1) { int col_matrix_width = output_width * output_height; int im_size = im_height * im_width; - size_t copy_size = sizeof(T) * output_width; - for (int oh = 0; oh < output_height; ++oh) { - const T* im_data_start = im_data + oh * im_width; - T* dst_data = col_data + oh * output_width; - for (int ic = 0; ic < im_channels; ++ic) { - const T* src_data = im_data_start + ic * im_size; - for (int kh = 0; kh < filter_height; ++kh) { + if (padding[0] == 0 && padding[1] == 0) { + size_t copy_size = sizeof(T) * output_width; + for (int oh = 0; oh < output_height; ++oh) { + const T* im_data_start = im_data + oh * im_width; + T* dst_data = col_data + oh * output_width; + for (int ic = 0; ic < im_channels; ++ic) { + const T* src_data = im_data_start + ic * im_size; + for (int kh = 0; kh < filter_height; ++kh) { + for (int kw = 0; kw < filter_width; ++kw) { + std::memcpy(dst_data, src_data + kw, copy_size); + dst_data = dst_data + col_matrix_width; + } + src_data = src_data + im_width; + } + } + } + return; + } else { + int plh = padding[0]; + // int plw = padding[1]; + int prh = + (output_height - 1) * stride[0] + filter_height - im_height - plh; + // int prw = (output_width - 1) * stride[1] + filter_width - im_width - + // plw; + + // fill height padding : 0 ~ plh-1, (oh-prh) ~ (oh-1) + // TODO(TJ): reuse sizes + assert(plh == prh); // because stride_h == 1 + for (int ph = 0; ph < plh; ++ph) { + size_t sz = sizeof(T) * output_width * (plh - ph); + T* col_start_l = col_data + ph * filter_width * col_matrix_width; + T* col_start_r = + col_data + + (filter_width - ph - 1) * filter_width * col_matrix_width + + col_matrix_width - output_width * (plh - ph); + for (int ic = 0; ic < im_channels; ++ic) { + T* dst_data_l = + col_start_l + + ic * filter_width * filter_height * col_matrix_width; + T* dst_data_r = + col_start_r + + ic * filter_width * filter_height * col_matrix_width; for (int kw = 0; kw < filter_width; ++kw) { - std::memcpy(dst_data, src_data + kw, copy_size); - dst_data = dst_data + col_matrix_width; + std::memset(dst_data_l, 0, sz); + std::memset(dst_data_r, 0, sz); + dst_data_l = dst_data_l + col_matrix_width; + dst_data_r = dst_data_r + col_matrix_width; } - src_data = src_data + im_width; } } + return; } - return; } for (int c = 0; c < channels_col; ++c) {