diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index 50af3199f20099fc81c01fb90690f1dd5b9640d8..c2633b2e16434558d16f699a701e7b8cf1de8342 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -126,19 +126,10 @@ class Col2ImFunctordata(); const T* col_data = col.data(); - int w_offset = -1; - int h_offset = 0; - int c_im = 0; for (int c = 0; c < channels_col; ++c) { - ++w_offset; - if (w_offset == filter_width) { - w_offset = 0; - ++h_offset; - if (h_offset == filter_height) { - h_offset = 0; - ++c_im; - } - } + int w_offset = c % filter_width; + int h_offset = (c / filter_width) % filter_height; + int c_im = c / (filter_width * filter_height); for (int h = 0; h < col_height; ++h) { int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; for (int w = 0; w < col_width; ++w) {