From 1e4f55ae5873e84cef8a6511c11035bceaa3c690 Mon Sep 17 00:00:00 2001 From: Bin Li Date: Sat, 28 Apr 2018 19:14:11 +0800 Subject: [PATCH] optimize general convolution --- mace/kernels/conv_2d.h | 241 ++++++++++++++++++++++++++++------------- 1 file changed, 164 insertions(+), 77 deletions(-) diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index 790d6f3a..cf49d2cb 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -98,33 +98,168 @@ struct Conv2dFunctor : Conv2dFunctorBase { const int dilation_h, const int dilation_w, float *output) { + const index_t in_image_size = in_height * in_width; + const index_t out_image_size = out_height * out_width; + const index_t in_batch_size = in_channels * in_image_size; + const index_t out_batch_size = out_channels * out_image_size; + const index_t filter_size = filter_height * filter_width; + const index_t in_tile_size = + 3 * stride_w + (filter_width - 1) * dilation_w + 1; + #pragma omp parallel for collapse(2) for (index_t b = 0; b < batch; ++b) { - for (index_t m = 0; m < out_channels; ++m) { - for (index_t h = 0; h < out_height; ++h) { - for (index_t w = 0; w < out_width; ++w) { - index_t out_offset = - ((b * out_channels + m) * out_height + h) * out_width + w; - for (index_t c = 0; c < in_channels; ++c) { - for (index_t kh = 0; kh < filter_height; ++kh) { - for (index_t kw = 0; kw < filter_width; ++kw) { - index_t ih = h * stride_h + kh * dilation_h; - index_t iw = w * stride_w + kw * dilation_w; - index_t in_offset = - ((b * in_channels + c) * in_height + ih) * in_width + iw; - index_t filter_offset = - (((m * in_channels) + c) * filter_height + kh) - * filter_width - + kw; - output[out_offset] += - input[in_offset] * filter[filter_offset]; + for (index_t m = 0; m < out_channels; m += 4) { + if (m + 3 < out_channels) { + float *out_ptr0_base = + output + b * out_batch_size + m * out_image_size; + float *out_ptr1_base = + output + b * out_batch_size + (m + 1) * out_image_size; + float *out_ptr2_base = + output + b * out_batch_size + (m + 2) * out_image_size; + float *out_ptr3_base = + output + b * out_batch_size + (m + 3) * out_image_size; + for (index_t c = 0; c < in_channels; ++c) { + const float *in_ptr_base = + input + b * in_batch_size + c * in_image_size; + const float *filter_ptr0 = + filter + m * in_channels * filter_size + c * filter_size; + const float *filter_ptr1 = + filter + (m + 1) * in_channels * filter_size + c * filter_size; + const float *filter_ptr2 = + filter + (m + 2) * in_channels * filter_size + c * filter_size; + const float *filter_ptr3 = + filter + (m + 3) * in_channels * filter_size + c * filter_size; + for (index_t h = 0; h < out_height; ++h) { + for (index_t w = 0; w + 3 < out_width; w += 4) { + // input offset + index_t ih = h * stride_h; + index_t iw = w * stride_w; + index_t in_offset = ih * in_width + iw; + // output (4 outch x 1 height x 4 width): vo_outch_height + float vo0[4], vo1[4], vo2[4], vo3[4]; + // load output + index_t out_offset = h * out_width + w; + for (index_t ow = 0; ow < 4; ++ow) { + vo0[ow] = out_ptr0_base[out_offset + ow]; + vo1[ow] = out_ptr1_base[out_offset + ow]; + vo2[ow] = out_ptr2_base[out_offset + ow]; + vo3[ow] = out_ptr3_base[out_offset + ow]; } - } - } - } - } - } - } + // calc by row + for (index_t kh = 0; kh < filter_height; ++kh) { + for (index_t kw = 0; kw < filter_width; ++kw) { + // outch 0 + vo0[0] += in_ptr_base[in_offset + + kw * dilation_w] * filter_ptr0[kw]; + vo0[1] += in_ptr_base[in_offset + stride_w + + kw * dilation_w] * filter_ptr0[kw]; + vo0[2] += in_ptr_base[in_offset + 2 * stride_w + + kw * dilation_w] * filter_ptr0[kw]; + vo0[3] += in_ptr_base[in_offset + 3 * stride_w + + kw * dilation_w] * filter_ptr0[kw]; + // outch 1 + vo1[0] += in_ptr_base[in_offset + + kw * dilation_w] * filter_ptr1[kw]; + vo1[1] += in_ptr_base[in_offset + stride_w + + kw * dilation_w] * filter_ptr1[kw]; + vo1[2] += in_ptr_base[in_offset + 2 * stride_w + + kw * dilation_w] * filter_ptr1[kw]; + vo1[3] += in_ptr_base[in_offset + 3 * stride_w + + kw * dilation_w] * filter_ptr1[kw]; + // outch 2 + vo2[0] += in_ptr_base[in_offset + + kw * dilation_w] * filter_ptr2[kw]; + vo2[1] += in_ptr_base[in_offset + stride_w + + kw * dilation_w] * filter_ptr2[kw]; + vo2[2] += in_ptr_base[in_offset + 2 * stride_w + + kw * dilation_w] * filter_ptr2[kw]; + vo2[3] += in_ptr_base[in_offset + 3 * stride_w + + kw * dilation_w] * filter_ptr2[kw]; + // outch 3 + vo3[0] += in_ptr_base[in_offset + + kw * dilation_w] * filter_ptr3[kw]; + vo3[1] += in_ptr_base[in_offset + stride_w + + kw * dilation_w] * filter_ptr3[kw]; + vo3[2] += in_ptr_base[in_offset + 2 * stride_w + + kw * dilation_w] * filter_ptr3[kw]; + vo3[3] += in_ptr_base[in_offset + 3 * stride_w + + kw * dilation_w] * filter_ptr3[kw]; + } // kw + + in_offset += dilation_h * in_width; + filter_ptr0 += filter_width; + filter_ptr1 += filter_width; + filter_ptr2 += filter_width; + filter_ptr3 += filter_width; + } // kh + + for (index_t ow = 0; ow < 4; ++ow) { + out_ptr0_base[out_offset + ow] = vo0[ow]; + out_ptr1_base[out_offset + ow] = vo1[ow]; + out_ptr2_base[out_offset + ow] = vo2[ow]; + out_ptr3_base[out_offset + ow] = vo3[ow]; + } + + filter_ptr0 -= filter_size; + filter_ptr1 -= filter_size; + filter_ptr2 -= filter_size; + filter_ptr3 -= filter_size; + } // w + } // h + } // c + } else { + for (index_t mm = m; mm < out_channels; ++mm) { + float *out_ptr0_base = + output + b * out_batch_size + mm * out_image_size; + for (index_t c = 0; c < in_channels; ++c) { + const float *in_ptr_base = + input + b * in_batch_size + c * in_image_size; + const float *filter_ptr0 = + filter + mm * in_channels * filter_size + c * filter_size; + + for (index_t h = 0; h < out_height; ++h) { + for (index_t w = 0; w + 3 < out_width; w += 4) { + // input offset + index_t ih = h * stride_h; + index_t iw = w * stride_w; + index_t in_offset = ih * in_width + iw; + // output (1 outch x 1 height x 4 width): vo_outch_height + float vo0[4]; + // load output + index_t out_offset = h * out_width + w; + for (index_t ow = 0; ow < 4; ++ow) { + vo0[ow] = out_ptr0_base[out_offset + ow]; + } + + // calc by row + for (index_t kh = 0; kh < filter_height; ++kh) { + for (index_t kw = 0; kw < filter_width; ++kw) { + // outch 0 + vo0[0] += in_ptr_base[in_offset + + kw * dilation_w] * filter_ptr0[kw]; + vo0[1] += in_ptr_base[in_offset + stride_w + + kw * dilation_w] * filter_ptr0[kw]; + vo0[2] += in_ptr_base[in_offset + 2 * stride_w + + kw * dilation_w] * filter_ptr0[kw]; + vo0[3] += in_ptr_base[in_offset + 3 * stride_w + + kw * dilation_w] * filter_ptr0[kw]; + } // kw + + in_offset += dilation_h * in_width; + filter_ptr0 += filter_width; + } // kh + + for (index_t ow = 0; ow < 4; ++ow) { + out_ptr0_base[out_offset + ow] = vo0[ow]; + } + filter_ptr0 -= filter_size; + } // w + } // h + } // c + } // mm + } // if + } // m + } // b } void operator()(const Tensor *input, @@ -286,63 +421,15 @@ struct Conv2dFunctor : Conv2dFunctorBase { if (extra_input_width != padded_input_width) { pad_right += (extra_input_width - padded_input_width); } - } else if (use_neon_3x3_s2) { - extra_output_height = height; - extra_input_height = - std::max(padded_input_height, (extra_output_height - 1) * 2 + 3); - extra_output_width = RoundUp(width, 4); - extra_input_width = - std::max(padded_input_width, (extra_output_width - 1) * 2 + 3); - if (extra_input_height != padded_input_height) { - pad_bottom += (extra_input_height - padded_input_height); - } - if (extra_input_width != padded_input_width) { - pad_right += (extra_input_width - padded_input_width); - } - } else if (use_neon_5x5_s1) { - extra_output_height = height; - extra_input_height = - std::max(padded_input_height, extra_output_height + 4); - extra_output_width = RoundUp(width, 4); - extra_input_width = std::max(padded_input_width, extra_output_width + 4); - if (extra_input_height != padded_input_height) { - pad_bottom += (extra_input_height - padded_input_height); - } - if (extra_input_width != padded_input_width) { - pad_right += (extra_input_width - padded_input_width); - } - } else if (use_neon_7x7_s1) { - extra_output_height = height; - extra_input_height = - std::max(padded_input_height, extra_output_height + 6); - extra_output_width = RoundUp(width, 4); - extra_input_width = std::max(padded_input_width, extra_output_width + 6); - if (extra_input_height != padded_input_height) { - pad_bottom += (extra_input_height - padded_input_height); - } - if (extra_input_width != padded_input_width) { - pad_right += (extra_input_width - padded_input_width); - } - } else if (use_neon_7x7_s2) { - extra_output_height = height; - extra_input_height = - std::max(padded_input_height, (extra_output_height - 1) * 2 + 7); - extra_output_width = RoundUp(width, 4); - extra_input_width = - std::max(padded_input_width, (extra_output_width - 1) * 2 + 7); - if (extra_input_height != padded_input_height) { - pad_bottom += (extra_input_height - padded_input_height); - } - if (extra_input_width != padded_input_width) { - pad_right += (extra_input_width - padded_input_width); - } - } else if (use_neon_7x7_s3) { + } else { extra_output_height = height; extra_input_height = - std::max(padded_input_height, (extra_output_height - 1) * 3 + 7); + std::max(padded_input_height, (extra_output_height - 1) * stride_h + + (filter_h - 1) * dilation_h + 1); extra_output_width = RoundUp(width, 4); extra_input_width = - std::max(padded_input_width, (extra_output_width - 1) * 3 + 7); + std::max(padded_input_width, (extra_output_width - 1) * stride_w + + (filter_w - 1) * dilation_w + 1); if (extra_input_height != padded_input_height) { pad_bottom += (extra_input_height - padded_input_height); } -- GitLab