diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index f51eb7b1f4c66d1a479640fb9f0fced0ace352ea..7b88c1251e2de100d353fd9ed7b5fb3d7ce1039c 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -745,20 +745,13 @@ struct Conv2dFunctor : Conv2dFunctorBase { const index_t stride_w, const T zero_point, const int pad_height, const int pad_width, const std::vector &out_shape, const index_t depth, T* im2col_data) { - const index_t batches = out_shape[0]; - const index_t out_height = out_shape[1]; - const index_t out_width = out_shape[2]; - const index_t column_len = depth; - const index_t in_height = in_shape[1]; - const index_t in_width = in_shape[2]; - const index_t in_channels = in_shape[3]; - const index_t input_row_size = in_width * in_channels; - const index_t patch_row_size = filter_w * in_channels; + const index_t input_row_size = in_shape[2] * in_shape[3]; + const index_t patch_row_size = filter_w * in_shape[3]; #pragma omp parallel for collapse(3) - for (index_t b = 0; b < batches; ++b) { - for (index_t h = 0; h < out_height; ++h) { - for (index_t w = 0; w < out_width; ++w) { + for (index_t b = 0; b < out_shape[0]; ++b) { + for (index_t h = 0; h < out_shape[1]; ++h) { + for (index_t w = 0; w < out_shape[2]; ++w) { // Reshape a patch of input to column, which is corresponding to // a column of output(:, column). const index_t ih_begin = h * stride_h - (pad_height >> 1); @@ -767,15 +760,15 @@ struct Conv2dFunctor : Conv2dFunctorBase { const index_t iw_end = iw_begin + filter_w; // gate height and width to separate padding const index_t ih_begin_gated = std::max(0, ih_begin); - const index_t ih_end_gated = std::min(ih_end, in_height); + const index_t ih_end_gated = std::min(ih_end, in_shape[1]); const index_t iw_begin_gated = std::max(0, iw_begin); - const index_t iw_end_gated = std::min(iw_end, in_width); + const index_t iw_end_gated = std::min(iw_end, in_shape[2]); const index_t pad_top = std::max(0, -ih_begin); const index_t pad_bottom = ih_end - ih_end_gated; const index_t pad_left = std::max(0, -iw_begin); const index_t pad_right = iw_end - iw_end_gated; index_t im2col_column_offset = - ((b * out_height + h) * out_width + w) * column_len; + ((b * out_shape[1] + h) * out_shape[2] + w) * depth; // fill in padding top if (pad_top > 0) { @@ -785,16 +778,15 @@ struct Conv2dFunctor : Conv2dFunctorBase { const index_t patch_row_size_gated = std::min(filter_w - pad_left, - in_width - iw_begin_gated) * in_channels; + in_shape[2] - iw_begin_gated) * in_shape[3]; MACE_CHECK(patch_row_size_gated == - ((filter_w - (pad_left + pad_right)) * in_channels)); - const index_t pad_left_size = pad_left * in_channels; - const index_t pad_right_size = pad_right * in_channels; + ((filter_w - (pad_left + pad_right)) * in_shape[3])); + const index_t pad_left_size = pad_left * in_shape[3]; + const index_t pad_right_size = pad_right * in_shape[3]; index_t im2col_offset = im2col_column_offset + - (pad_top * filter_w + pad_left) * in_channels; - index_t in_offset = - ((b * in_height + ih_begin_gated) * in_width + iw_begin_gated) * - in_channels; + (pad_top * filter_w + pad_left) * in_shape[3]; + index_t in_offset = ((b * in_shape[1] + ih_begin_gated) * in_shape[2] + + iw_begin_gated) * in_shape[3]; // fill in effective rows for (index_t ih = ih_begin_gated; ih < ih_end_gated; ++ih) { @@ -820,7 +812,7 @@ struct Conv2dFunctor : Conv2dFunctorBase { if (pad_bottom > 0) { const index_t pad_bottom_size = pad_bottom * patch_row_size; const index_t bottom_offset = - im2col_column_offset + column_len - pad_bottom_size; + im2col_column_offset + depth - pad_bottom_size; std::fill_n(im2col_data + bottom_offset, pad_bottom_size, zero_point); }