diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index d11a6afe9b38a90fc039ab42fbce4b49f9fe26a6..50af3199f20099fc81c01fb90690f1dd5b9640d8 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -61,19 +61,10 @@ class Im2ColFunctor(); T* col_data = col->data(); - int w_offset = -1; - int h_offset = 0; - int c_im = 0; for (int c = 0; c < channels_col; ++c) { - ++w_offset; - if (w_offset == filter_width) { - w_offset = 0; - ++h_offset; - if (h_offset == filter_height) { - h_offset = 0; - ++c_im; - } - } + int w_offset = c % filter_width; + int h_offset = (c / filter_width) % filter_height; + int c_im = c / (filter_width * filter_height); for (int h = 0; h < col_height; ++h) { int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; for (int w = 0; w < col_width; ++w) {