diff --git a/paddle/fluid/operators/math/im2col.cc b/paddle/fluid/operators/math/im2col.cc index c29a1373194d6efe6c227e9ed57ce042e64713d6..be373c99d138dbc12eb7a1d96fe0b38418a88529 100644 --- a/paddle/fluid/operators/math/im2col.cc +++ b/paddle/fluid/operators/math/im2col.cc @@ -126,11 +126,9 @@ class Im2ColFunctor 1 for (int ic = 0; ic < im_channels; ++ic) { // TODO(TJ): use add and resue stride - T* dst_data_ic = - col_data + ic * filter_width * filter_height * col_matrix_width; + T* dst_data_ic = col_data + ic * col_block_ic; for (int kh = 0; kh < filter_height; ++kh) { - T* dst_data_kh = - dst_data_ic + kh * filter_width * col_matrix_width; + T* dst_data_kh = dst_data_ic + kh * col_block_fh; for (int kw = 0; kw < plw; ++kw) { // TODO(TJ): reuse array outside this for size_t sz = sizeof(T) * (plw - kw); @@ -158,6 +156,67 @@ class Im2ColFunctor 2*pw: kw = 3, pw = 1 + // 0 x x x x ... x x x x 0 + // 1 1 1 1 1 1 + // ==> + // 0 x ... x x + // x x ... x x + // x x ... x 0 + // 2. kw < 2*pw: kw = 3, pw = 2 + // 0 0 x x x ... x x x 0 0 + // 1 1 1 1 1 1 + // ==> + // 0 0 x ... x x x + // 0 x x ... x x 0 + // x x x ... x 0 0 + + // TODO(TJ): use array like: size_t copy_size[kw]={sizeof(T) * + // (output_width-1)} + // length of copy_size is equal kw. + if (plw + prw < filter_width) { + for (int oh = 0; oh < output_height; ++oh) { + const T* im_data_start = + im_data + (oh - plh > 0 ? oh - plh : 0) * im_width; + T* dst_data = col_data + oh * output_width; + for (int ic = 0; ic < im_channels; ++ic) { + const T* src_data = im_data_start + ic * im_size; + for (int kh = 0; kh < filter_height; ++kh) { + if ((oh < plh && kh < plh) || + (oh > (output_height - prh - 1) && + kh > (filter_height - prh - 1))) { + dst_data = dst_data + filter_width * col_matrix_width; + continue; + } + // TODO(TJ): reuse plw-kw outside this for + // try to unify + for (int kw = 0; kw < plw; ++kw) { + std::memcpy(dst_data + (plw - kw), src_data, + sizeof(T) * (output_width - (plw - kw))); + dst_data = dst_data + col_matrix_width; + } + for (int kw = plw; kw < filter_width - prw; ++kw) { + std::memcpy(dst_data, src_data + (kw - plw), + sizeof(T) * output_width); + dst_data = dst_data + col_matrix_width; + } + int i = 1; + for (int kw = filter_width - prw; kw < filter_width; + ++kw, ++i) { + std::memcpy(dst_data, src_data + (kw - plw), + sizeof(T) * (output_width - i)); + dst_data = dst_data + col_matrix_width; + } + src_data = src_data + im_width; + } + } + } + } else { + LOG(FATAL) << "Not implement yet"; + } return; } }