diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 7ebfc7df8c117edd7bcf14cc5ae6ba3dc1302c03..9edfacd6dfb506e12a0d82772d8de301bb8506e2 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -103,10 +103,12 @@ class MidWiseTransformIterator { MidWiseTransformIterator& operator++() { ++j_; - i_ = j_ / post_; - if (UNLIKELY(i_ == n_)) { + if (UNLIKELY(j_ == post_)) { + ++i_; j_ = 0; - i_ = 0; + if (UNLIKELY(i_ == n_)) { + i_ = 0; + } } return *this; } @@ -125,10 +127,10 @@ class MidWiseTransformIterator { private: const T* ptr_; - int i_; + int64_t i_; int64_t j_; int64_t n_; - int post_; + int64_t post_; }; #ifdef __NVCC__ diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index 707ebf05962fb65892c2adbbf41a0a3449763d31..c2633b2e16434558d16f699a701e7b8cf1de8342 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -61,14 +61,13 @@ class Im2ColFunctor(); T* col_data = col->data(); - for (int c = 0; c < channels_col; ++c) { int w_offset = c % filter_width; int h_offset = (c / filter_width) % filter_height; - int c_im = c / filter_width / filter_height; + int c_im = c / (filter_width * filter_height); for (int h = 0; h < col_height; ++h) { + int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; for (int w = 0; w < col_width; ++w) { - int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; int col_idx = (c * col_height + h) * col_width + w; int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx; @@ -130,16 +129,14 @@ class Col2ImFunctor= 0 && (im_row_idx) < im_height && (im_col_idx) >= 0 && (im_col_idx) < im_width) { - im_row_idx += c_im * im_height; - im_data[im_row_idx * im_width + im_col_idx] += + im_data[(im_row_idx + c_im * im_height) * im_width + im_col_idx] += col_data[(c * col_height + h) * col_width + w]; } } @@ -199,12 +196,13 @@ class Im2ColFunctor= 0 && im_row_offset < im_height && im_col_offset >= 0 && im_col_offset < im_width) { int im_offset =