提交 1e4f55ae，作者：Bin Li

optimize general convolution

上级 ecea6d5b
......@@ -98,33 +98,168 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
const int dilation_h,
const int dilation_w,
float *output) {
const index_t in_image_size = in_height * in_width;
const index_t out_image_size = out_height * out_width;
const index_t in_batch_size = in_channels * in_image_size;
const index_t out_batch_size = out_channels * out_image_size;
const index_t filter_size = filter_height * filter_width;
const index_t in_tile_size =
3 * stride_w + (filter_width - 1) * dilation_w + 1;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t m = 0; m < out_channels; ++m) {
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
index_t out_offset =
((b * out_channels + m) * out_height + h) * out_width + w;
for (index_t c = 0; c < in_channels; ++c) {
for (index_t kh = 0; kh < filter_height; ++kh) {
for (index_t kw = 0; kw < filter_width; ++kw) {
index_t ih = h * stride_h + kh * dilation_h;
index_t iw = w * stride_w + kw * dilation_w;
index_t in_offset =
((b * in_channels + c) * in_height + ih) * in_width + iw;
index_t filter_offset =
(((m * in_channels) + c) * filter_height + kh)
* filter_width
+ kw;
output[out_offset] +=
input[in_offset] * filter[filter_offset];
for (index_t m = 0; m < out_channels; m += 4) {
if (m + 3 < out_channels) {
float *out_ptr0_base =
output + b * out_batch_size + m * out_image_size;
float *out_ptr1_base =
output + b * out_batch_size + (m + 1) * out_image_size;
float *out_ptr2_base =
output + b * out_batch_size + (m + 2) * out_image_size;
float *out_ptr3_base =
output + b * out_batch_size + (m + 3) * out_image_size;
for (index_t c = 0; c < in_channels; ++c) {
const float *in_ptr_base =
input + b * in_batch_size + c * in_image_size;
const float *filter_ptr0 =
filter + m * in_channels * filter_size + c * filter_size;
const float *filter_ptr1 =
filter + (m + 1) * in_channels * filter_size + c * filter_size;
const float *filter_ptr2 =
filter + (m + 2) * in_channels * filter_size + c * filter_size;
const float *filter_ptr3 =
filter + (m + 3) * in_channels * filter_size + c * filter_size;
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w + 3 < out_width; w += 4) {
// input offset
index_t ih = h * stride_h;
index_t iw = w * stride_w;
index_t in_offset = ih * in_width + iw;
// output (4 outch x 1 height x 4 width): vo_outch_height
float vo0[4], vo1[4], vo2[4], vo3[4];
// load output
index_t out_offset = h * out_width + w;
for (index_t ow = 0; ow < 4; ++ow) {
vo0[ow] = out_ptr0_base[out_offset + ow];
vo1[ow] = out_ptr1_base[out_offset + ow];
vo2[ow] = out_ptr2_base[out_offset + ow];
vo3[ow] = out_ptr3_base[out_offset + ow];
}
}
}
}
}
}
}
// calc by row
for (index_t kh = 0; kh < filter_height; ++kh) {
for (index_t kw = 0; kw < filter_width; ++kw) {
// outch 0
vo0[0] += in_ptr_base[in_offset
+ kw * dilation_w] * filter_ptr0[kw];
vo0[1] += in_ptr_base[in_offset + stride_w
+ kw * dilation_w] * filter_ptr0[kw];
vo0[2] += in_ptr_base[in_offset + 2 * stride_w
+ kw * dilation_w] * filter_ptr0[kw];
vo0[3] += in_ptr_base[in_offset + 3 * stride_w
+ kw * dilation_w] * filter_ptr0[kw];
// outch 1
vo1[0] += in_ptr_base[in_offset
+ kw * dilation_w] * filter_ptr1[kw];
vo1[1] += in_ptr_base[in_offset + stride_w
+ kw * dilation_w] * filter_ptr1[kw];
vo1[2] += in_ptr_base[in_offset + 2 * stride_w
+ kw * dilation_w] * filter_ptr1[kw];
vo1[3] += in_ptr_base[in_offset + 3 * stride_w
+ kw * dilation_w] * filter_ptr1[kw];
// outch 2
vo2[0] += in_ptr_base[in_offset
+ kw * dilation_w] * filter_ptr2[kw];
vo2[1] += in_ptr_base[in_offset + stride_w
+ kw * dilation_w] * filter_ptr2[kw];
vo2[2] += in_ptr_base[in_offset + 2 * stride_w
+ kw * dilation_w] * filter_ptr2[kw];
vo2[3] += in_ptr_base[in_offset + 3 * stride_w
+ kw * dilation_w] * filter_ptr2[kw];
// outch 3
vo3[0] += in_ptr_base[in_offset
+ kw * dilation_w] * filter_ptr3[kw];
vo3[1] += in_ptr_base[in_offset + stride_w
+ kw * dilation_w] * filter_ptr3[kw];
vo3[2] += in_ptr_base[in_offset + 2 * stride_w
+ kw * dilation_w] * filter_ptr3[kw];
vo3[3] += in_ptr_base[in_offset + 3 * stride_w
+ kw * dilation_w] * filter_ptr3[kw];
} // kw
in_offset += dilation_h * in_width;
filter_ptr0 += filter_width;
filter_ptr1 += filter_width;
filter_ptr2 += filter_width;
filter_ptr3 += filter_width;
} // kh
for (index_t ow = 0; ow < 4; ++ow) {
out_ptr0_base[out_offset + ow] = vo0[ow];
out_ptr1_base[out_offset + ow] = vo1[ow];
out_ptr2_base[out_offset + ow] = vo2[ow];
out_ptr3_base[out_offset + ow] = vo3[ow];
}
filter_ptr0 -= filter_size;
filter_ptr1 -= filter_size;
filter_ptr2 -= filter_size;
filter_ptr3 -= filter_size;
} // w
} // h
} // c
} else {
for (index_t mm = m; mm < out_channels; ++mm) {
float *out_ptr0_base =
output + b * out_batch_size + mm * out_image_size;
for (index_t c = 0; c < in_channels; ++c) {
const float *in_ptr_base =
input + b * in_batch_size + c * in_image_size;
const float *filter_ptr0 =
filter + mm * in_channels * filter_size + c * filter_size;
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w + 3 < out_width; w += 4) {
// input offset
index_t ih = h * stride_h;
index_t iw = w * stride_w;
index_t in_offset = ih * in_width + iw;
// output (1 outch x 1 height x 4 width): vo_outch_height
float vo0[4];
// load output
index_t out_offset = h * out_width + w;
for (index_t ow = 0; ow < 4; ++ow) {
vo0[ow] = out_ptr0_base[out_offset + ow];
}
// calc by row
for (index_t kh = 0; kh < filter_height; ++kh) {
for (index_t kw = 0; kw < filter_width; ++kw) {
// outch 0
vo0[0] += in_ptr_base[in_offset
+ kw * dilation_w] * filter_ptr0[kw];
vo0[1] += in_ptr_base[in_offset + stride_w
+ kw * dilation_w] * filter_ptr0[kw];
vo0[2] += in_ptr_base[in_offset + 2 * stride_w
+ kw * dilation_w] * filter_ptr0[kw];
vo0[3] += in_ptr_base[in_offset + 3 * stride_w
+ kw * dilation_w] * filter_ptr0[kw];
} // kw
in_offset += dilation_h * in_width;
filter_ptr0 += filter_width;
} // kh
for (index_t ow = 0; ow < 4; ++ow) {
out_ptr0_base[out_offset + ow] = vo0[ow];
}
filter_ptr0 -= filter_size;
} // w
} // h
} // c
} // mm
} // if
} // m
} // b
}
void operator()(const Tensor *input,
......@@ -286,63 +421,15 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
} else if (use_neon_3x3_s2) {
extra_output_height = height;
extra_input_height =
std::max(padded_input_height, (extra_output_height - 1) * 2 + 3);
extra_output_width = RoundUp<index_t>(width, 4);
extra_input_width =
std::max(padded_input_width, (extra_output_width - 1) * 2 + 3);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
} else if (use_neon_5x5_s1) {
extra_output_height = height;
extra_input_height =
std::max(padded_input_height, extra_output_height + 4);
extra_output_width = RoundUp<index_t>(width, 4);
extra_input_width = std::max(padded_input_width, extra_output_width + 4);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
} else if (use_neon_7x7_s1) {
extra_output_height = height;
extra_input_height =
std::max(padded_input_height, extra_output_height + 6);
extra_output_width = RoundUp<index_t>(width, 4);
extra_input_width = std::max(padded_input_width, extra_output_width + 6);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
} else if (use_neon_7x7_s2) {
extra_output_height = height;
extra_input_height =
std::max(padded_input_height, (extra_output_height - 1) * 2 + 7);
extra_output_width = RoundUp<index_t>(width, 4);
extra_input_width =
std::max(padded_input_width, (extra_output_width - 1) * 2 + 7);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
} else if (use_neon_7x7_s3) {
} else {
extra_output_height = height;
extra_input_height =
std::max(padded_input_height, (extra_output_height - 1) * 3 + 7);
std::max(padded_input_height, (extra_output_height - 1) * stride_h
+ (filter_h - 1) * dilation_h + 1);
extra_output_width = RoundUp<index_t>(width, 4);
extra_input_width =
std::max(padded_input_width, (extra_output_width - 1) * 3 + 7);
std::max(padded_input_width, (extra_output_width - 1) * stride_w
+ (filter_w - 1) * dilation_w + 1);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册