diff --git a/mace/kernels/neon/conv_2d_neon_3x3.cc b/mace/kernels/neon/conv_2d_neon_3x3.cc index b6331409b6a41c43f3cab06d23723f557f9063c9..0cf589b929831375fab427d6f8e66b8728ea190c 100644 --- a/mace/kernels/neon/conv_2d_neon_3x3.cc +++ b/mace/kernels/neon/conv_2d_neon_3x3.cc @@ -8,221 +8,288 @@ namespace mace { namespace kernels { +#define KERNEL_HEAD_CODE \ + int output_batch = output_shape[0]; \ + int output_channels = output_shape[1]; \ + int output_height = output_shape[2]; \ + int output_width = output_shape[3]; \ + int input_batch = input_shape[0]; \ + int input_channels = input_shape[1]; \ + int input_height = input_shape[2]; \ + int input_width = input_shape[3]; \ + int kernel_h = 3; \ + int kernel_w = 3; \ + for (int b = 0; b < output_batch; ++b) { \ + float* output_ptr_base = output + b * output_channels * output_height * output_width; \ + for (int oc = 0; oc < output_channels; ++oc) { \ + const float* filter_ptr = filter + oc * input_channels * kernel_h * kernel_w; \ + const float* input_ptr = input + b * input_channels * input_height * input_width; \ + float* output_ptr = output_ptr_base + oc * output_height * output_width; \ + std::fill(output_ptr, output_ptr + output_height * output_width, bias[oc]); \ + for (int ic = 0; ic < input_channels; ++ic) { \ + float32x4_t n_filter_v[3] = {vld1q_f32(filter_ptr), vld1q_f32(filter_ptr+3), vld1q_f32(filter_ptr+6)}; + +#define KERNEL_TAIL_CODE \ + filter_ptr += 9; \ + input_ptr += input_height * input_width; \ + } \ + } \ + } + static const int kRegisterSize = 4; -void Conv2dNeonK3x3S1(const float* input, // NCHW - const index_t* input_shape, - const float* filter, // c_out, c_in, kernel_h, kernel_w - const float* bias, // c_out - float* output, // NCHW - const index_t* output_shape) { - int batch = output_shape[0]; - int channels = output_shape[1]; - int height = output_shape[2]; - int width = output_shape[3]; - - int input_batch = input_shape[0]; - int input_channels = input_shape[1]; - int input_height = input_shape[2]; - int input_width = input_shape[3]; - - int kernel_h = 3; - int kernel_w = 3; - - int height_count = (height >> 1) << 1; - for (int b = 0; b < batch; ++b) { - float* output_ptr_base = output + b * channels * height * width; - for (int oc = 0; oc < channels; ++oc) { - const float* filter_ptr = - filter + oc * input_channels * kernel_h * kernel_w; - const float* input_ptr = - input + b * input_channels * input_height * input_width; - float* output_ptr = output_ptr_base + oc * height * width; - - std::fill(output_ptr, output_ptr + height * width, bias[oc]); - for (int ic = 0; ic < input_channels; ++ic) { - float32x4_t filter0 = vld1q_f32(filter_ptr); - float32x4_t filter3 = vld1q_f32(filter_ptr + 3); - float32x4_t filter6 = vld1q_f32(filter_ptr + 6); - - const float* row[kRegisterSize] = {input_ptr, input_ptr + input_width, - input_ptr + 2 * input_width, - input_ptr + 3 * input_width}; - - float* output_ptr1 = output_ptr; - float* output_ptr2 = output_ptr + width; +void Conv2dNeonK3x3S1(const float *input, // NCHW + const index_t *input_shape, + const float *filter, // c_out, c_in, kernel_h, kernel_w + const float *bias, // c_out + float *output, // NCHW + const index_t *output_shape) { + + int height_count = (output_shape[2] >> 1) << 1; + + KERNEL_HEAD_CODE + + const float *row_ptr_v[kRegisterSize] = { + input_ptr, input_ptr + input_width, + input_ptr + 2 * input_width, input_ptr + 3 * input_width + }; + + float *output_ptr_v[] = {output_ptr, output_ptr + output_width}; for (int h = 0; h < height_count; h += 2) { - int 
count = width >> 2; - int remain_count = width & 3; + int count = output_width >> 2; + int remain_count = output_width & 3; for (; count > 0; --count) { - float32x4_t sum0 = vdupq_n_f32(.0f); - float32x4_t sum1 = vdupq_n_f32(.0f); - float32x4_t row0_ext_0 = vld1q_f32(row[0]); // 0123 - float32x4_t row0_latter = - vld1q_f32(row[0] + kRegisterSize); // 4567 - float32x4_t row0_ext_1 = - vextq_f32(row0_ext_0, row0_latter, 1); // 1234 - float32x4_t row0_ext_2 = - vextq_f32(row0_ext_0, row0_latter, 2); // 2345 - - sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter0, 0); - sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter0, 1); - sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter0, 2); - - float32x4_t row1_ext_0 = vld1q_f32(row[1]); // 0123 - float32x4_t row1_latter = - vld1q_f32(row[1] + kRegisterSize); // 4567 - float32x4_t row1_ext_1 = - vextq_f32(row1_ext_0, row1_latter, 1); // 1234 - float32x4_t row1_ext_2 = - vextq_f32(row1_ext_0, row1_latter, 2); // 2345 - - sum0 = vfmaq_laneq_f32(sum0, row1_ext_0, filter3, 0); - sum0 = vfmaq_laneq_f32(sum0, row1_ext_1, filter3, 1); - sum0 = vfmaq_laneq_f32(sum0, row1_ext_2, filter3, 2); - - row0_ext_0 = vld1q_f32(row[2]); // 0123 - row0_latter = vld1q_f32(row[2] + kRegisterSize); // 4567 - row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1); // 1234 - row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2); // 2345 - - sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter6, 0); - sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter6, 1); - sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter6, 2); + float32x4_t n_sum0 = vdupq_n_f32(.0f); + + float32x4_t n_row_former = vld1q_f32(row_ptr_v[0]); + float32x4_t n_row_latter = vld1q_f32(row_ptr_v[0] + kRegisterSize); + float32x4_t n_row_ext0 = vextq_f32(n_row_former, n_row_latter, 1); + float32x4_t n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 2); + n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_former, n_filter_v[0], 0); + n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_ext0, n_filter_v[0], 1); + n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_ext1, n_filter_v[0], 2); + + float32x4_t n_row1_former = vld1q_f32(row_ptr_v[1]); + float32x4_t n_row1_latter = vld1q_f32(row_ptr_v[1] + kRegisterSize); + float32x4_t n_row1_ext0 = vextq_f32(n_row1_former, n_row1_latter, 1); + float32x4_t n_row1_ext1 = vextq_f32(n_row1_former, n_row1_latter, 2); + n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_former, n_filter_v[1], 0); + n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_ext0, n_filter_v[1], 1); + n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_ext1, n_filter_v[1], 2); + + n_row_former = vld1q_f32(row_ptr_v[2]); + n_row_latter = vld1q_f32(row_ptr_v[2] + kRegisterSize); + n_row_ext0 = vextq_f32(n_row_former, n_row_latter, 1); + n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 2); + n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_former, n_filter_v[2], 0); + n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_ext0, n_filter_v[2], 1); + n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_ext1, n_filter_v[2], 2); // second row - sum1 = vfmaq_laneq_f32(sum1, row1_ext_0, filter0, 0); - sum1 = vfmaq_laneq_f32(sum1, row1_ext_1, filter0, 1); - sum1 = vfmaq_laneq_f32(sum1, row1_ext_2, filter0, 2); - - sum1 = vfmaq_laneq_f32(sum1, row0_ext_0, filter3, 0); - sum1 = vfmaq_laneq_f32(sum1, row0_ext_1, filter3, 1); - sum1 = vfmaq_laneq_f32(sum1, row0_ext_2, filter3, 2); - - row1_ext_0 = vld1q_f32(row[3]); // 0123 - row1_latter = vld1q_f32(row[3] + kRegisterSize); // 4567 - row1_ext_1 = vextq_f32(row1_ext_0, row1_latter, 1); // 1234 - row1_ext_2 = vextq_f32(row1_ext_0, row1_latter, 2); // 2345 - - sum1 = vfmaq_laneq_f32(sum1, row1_ext_0, 
filter6, 0); - sum1 = vfmaq_laneq_f32(sum1, row1_ext_1, filter6, 1); - sum1 = vfmaq_laneq_f32(sum1, row1_ext_2, filter6, 2); - - float32x4_t output_row0 = vld1q_f32(output_ptr1); - float32x4_t output_row1 = vld1q_f32(output_ptr2); - output_row0 = vaddq_f32(output_row0, sum0); - output_row1 = vaddq_f32(output_row1, sum1); - vst1q_f32(output_ptr1, output_row0); - vst1q_f32(output_ptr2, output_row1); - - output_ptr1 += kRegisterSize; - output_ptr2 += kRegisterSize; + float32x4_t n_sum1 = vdupq_n_f32(.0f); + + n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_former, n_filter_v[0], 0); + n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_ext0, n_filter_v[0], 1); + n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_ext1, n_filter_v[0], 2); + + n_sum1 = vfmaq_laneq_f32(n_sum1, n_row_former, n_filter_v[1], 0); + n_sum1 = vfmaq_laneq_f32(n_sum1, n_row_ext0, n_filter_v[1], 1); + n_sum1 = vfmaq_laneq_f32(n_sum1, n_row_ext1, n_filter_v[1], 2); + + n_row1_former = vld1q_f32(row_ptr_v[3]); + n_row1_latter = vld1q_f32(row_ptr_v[3] + kRegisterSize); + n_row1_ext0 = vextq_f32(n_row1_former, n_row1_latter, 1); + n_row1_ext1 = vextq_f32(n_row1_former, n_row1_latter, 2); + n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_former, n_filter_v[2], 0); + n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_ext0, n_filter_v[2], 1); + n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_ext1, n_filter_v[2], 2); + + float32x4_t n_output_row = vld1q_f32(output_ptr_v[0]); + float32x4_t n_output_row1 = vld1q_f32(output_ptr_v[1]); + n_output_row = vaddq_f32(n_output_row, n_sum0); + n_output_row1 = vaddq_f32(n_output_row1, n_sum1); + vst1q_f32(output_ptr_v[0], n_output_row); + vst1q_f32(output_ptr_v[1], n_output_row1); + output_ptr_v[0] += kRegisterSize; + output_ptr_v[1] += kRegisterSize; for (int i = 0; i < kRegisterSize; ++i) { - row[i] += kRegisterSize; + row_ptr_v[i] += kRegisterSize; } } for (; remain_count > 0; --remain_count) { - float32x4_t row0 = vld1q_f32(row[0]); // 0123 - float32x4_t row1 = vld1q_f32(row[1]); // 0123 - float32x4_t row2 = vld1q_f32(row[2]); // 0123 - float32x4_t row3 = vld1q_f32(row[3]); // 0123 - - float32x4_t sum = vmulq_f32(row0, filter0); - sum = vmlaq_f32(sum, row1, filter3); - sum = vmlaq_f32(sum, row2, filter6); - sum = vsetq_lane_f32(*output_ptr1, sum, 3); - *output_ptr1 = vaddvq_f32(sum); - - sum = vmulq_f32(row1, filter0); - sum = vmlaq_f32(sum, row2, filter3); - sum = vmlaq_f32(sum, row3, filter6); - sum = vsetq_lane_f32(*output_ptr2, sum, 3); - *output_ptr2 = vaddvq_f32(sum); - - ++output_ptr1; - ++output_ptr2; + float32x4_t n_row_v[] = { + vld1q_f32(row_ptr_v[0]), + vld1q_f32(row_ptr_v[1]), + vld1q_f32(row_ptr_v[2]) + }; + float32x4_t n_sum0 = vmulq_f32(n_row_v[0], n_filter_v[0]); + n_sum0 = vmlaq_f32(n_sum0, n_row_v[1], n_filter_v[1]); + n_sum0 = vmlaq_f32(n_sum0, n_row_v[2], n_filter_v[2]); + n_sum0 = vsetq_lane_f32(*output_ptr_v[0], n_sum0, 3); + *output_ptr_v[0] = vaddvq_f32(n_sum0); + + float32x4_t n_row3 = vld1q_f32(row_ptr_v[3]); + float32x4_t n_sum1 = vmulq_f32(n_row_v[1], n_filter_v[0]); + n_sum1 = vmlaq_f32(n_sum1, n_row_v[2], n_filter_v[1]); + n_sum1 = vmlaq_f32(n_sum1, n_row3, n_filter_v[2]); + n_sum1 = vsetq_lane_f32(*output_ptr_v[1], n_sum1, 3); + *output_ptr_v[1] = vaddvq_f32(n_sum1); + + ++output_ptr_v[0]; + ++output_ptr_v[1]; for (int i = 0; i < kRegisterSize; ++i) { - row[i] += 1; + row_ptr_v[i] += 1; } } - output_ptr1 += width; - output_ptr2 += width; + output_ptr_v[0] += output_width; + output_ptr_v[1] += output_width; for (int i = 0; i < kRegisterSize; ++i) { - row[i] += 2 + input_width; + row_ptr_v[i] += 2 + input_width; } 
} - if (height != height_count) { - int count = width >> 2; - int remain_count = width & 3; + if (output_height != height_count) { + int count = output_width >> 2; + int remain_count = output_width & 3; for (; count > 0; --count) { - float32x4_t sum0 = vdupq_n_f32(.0f); - float32x4_t row0_ext_0 = vld1q_f32(row[0]); // 0123 - float32x4_t row0_latter = - vld1q_f32(row[0] + kRegisterSize); // 4567 - float32x4_t row0_ext_1 = - vextq_f32(row0_ext_0, row0_latter, 1); // 1234 - float32x4_t row0_ext_2 = - vextq_f32(row0_ext_0, row0_latter, 2); // 2345 - - sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter0, 0); - sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter0, 1); - sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter0, 2); - - float32x4_t row1_ext_0 = vld1q_f32(row[1]); // 0123 - float32x4_t row1_latter = - vld1q_f32(row[1] + kRegisterSize); // 4567 - float32x4_t row1_ext_1 = - vextq_f32(row1_ext_0, row1_latter, 1); // 1234 - float32x4_t row1_ext_2 = - vextq_f32(row1_ext_0, row1_latter, 2); // 2345 - - sum0 = vfmaq_laneq_f32(sum0, row1_ext_0, filter3, 0); - sum0 = vfmaq_laneq_f32(sum0, row1_ext_1, filter3, 1); - sum0 = vfmaq_laneq_f32(sum0, row1_ext_2, filter3, 2); - - row0_ext_0 = vld1q_f32(row[2]); // 0123 - row0_latter = vld1q_f32(row[2] + kRegisterSize); // 4567 - row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1); // 1234 - row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2); // 2345 - - sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter6, 0); - sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter6, 1); - sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter6, 2); - - float32x4_t output_row0 = vld1q_f32(output_ptr1); - output_row0 = vaddq_f32(output_row0, sum0); - vst1q_f32(output_ptr1, output_row0); - output_ptr1 += kRegisterSize; + float32x4_t n_sum = vdupq_n_f32(.0f); + float32x4_t n_row_former = vld1q_f32(row_ptr_v[0]); + float32x4_t n_row_latter = vld1q_f32(row_ptr_v[0] + kRegisterSize); + float32x4_t n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 1); + float32x4_t n_row_ext2 = vextq_f32(n_row_former, n_row_latter, 2); + n_sum = vfmaq_laneq_f32(n_sum, n_row_former, n_filter_v[0], 0); + n_sum = vfmaq_laneq_f32(n_sum, n_row_ext1, n_filter_v[0], 1); + n_sum = vfmaq_laneq_f32(n_sum, n_row_ext2, n_filter_v[0], 2); + + n_row_former = vld1q_f32(row_ptr_v[1]); + n_row_latter = vld1q_f32(row_ptr_v[1] + kRegisterSize); + n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 1); + n_row_ext2 = vextq_f32(n_row_former, n_row_latter, 2); + n_sum = vfmaq_laneq_f32(n_sum, n_row_former, n_filter_v[1], 0); + n_sum = vfmaq_laneq_f32(n_sum, n_row_ext1, n_filter_v[1], 1); + n_sum = vfmaq_laneq_f32(n_sum, n_row_ext2, n_filter_v[1], 2); + + n_row_former = vld1q_f32(row_ptr_v[2]); + n_row_latter = vld1q_f32(row_ptr_v[2] + kRegisterSize); + n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 1); + n_row_ext2 = vextq_f32(n_row_former, n_row_latter, 2); + n_sum = vfmaq_laneq_f32(n_sum, n_row_former, n_filter_v[2], 0); + n_sum = vfmaq_laneq_f32(n_sum, n_row_ext1, n_filter_v[2], 1); + n_sum = vfmaq_laneq_f32(n_sum, n_row_ext2, n_filter_v[2], 2); + + float32x4_t n_output_row = vld1q_f32(output_ptr_v[0]); + n_output_row = vaddq_f32(n_output_row, n_sum); + vst1q_f32(output_ptr_v[0], n_output_row); + output_ptr_v[0] += kRegisterSize; for (int i = 0; i < 3; ++i) { - row[i] += kRegisterSize; + row_ptr_v[i] += kRegisterSize; } } for (; remain_count > 0; --remain_count) { - float32x4_t row0 = vld1q_f32(row[0]); // 0123 - float32x4_t row1 = vld1q_f32(row[1]); // 0123 - float32x4_t row2 = vld1q_f32(row[2]); // 0123 + float32x4_t n_row_v[] = { + 
vld1q_f32(row_ptr_v[0]), + vld1q_f32(row_ptr_v[1]), + vld1q_f32(row_ptr_v[2]), + }; + + float32x4_t n_sum = vmulq_f32(n_row_v[0], n_filter_v[0]); + n_sum = vmlaq_f32(n_sum, n_row_v[1], n_filter_v[1]); + n_sum = vmlaq_f32(n_sum, n_row_v[2], n_filter_v[2]); + n_sum = vsetq_lane_f32(*output_ptr_v[0], n_sum, 3); + *output_ptr_v[0] = vaddvq_f32(n_sum); + + ++output_ptr_v[0]; + for (int i = 0; i < 3; ++i) { + row_ptr_v[i] += 1; + } + } + } + + KERNEL_TAIL_CODE +} + +void Conv2dNeonK3x3S2(const float *input, // NCHW + const index_t *input_shape, + const float *filter, // c_out, c_in, kernel_h, kernel_w + const float *bias, // c_out + float *output, // NCHW + const index_t *output_shape) { + int tail_step = 2 * (input_shape[3] - output_shape[3]); - float32x4_t sum = vmulq_f32(row0, filter0); - sum = vmlaq_f32(sum, row1, filter3); - sum = vmlaq_f32(sum, row2, filter6); - sum = vsetq_lane_f32(*output_ptr1, sum, 3); - *output_ptr1 = vaddvq_f32(sum); + KERNEL_HEAD_CODE - ++output_ptr1; + const float *row_ptr_v[3] = { + input_ptr, input_ptr + input_width, input_ptr + 2 * input_width + }; + + float *output_ptr_inner = output_ptr; + + for (int h = 0; h < output_height; ++h) { + int count = output_width >> 2; + int remain_count = output_width & 3; + + for (; count > 0; --count) { + float32x4_t n_sum = vdupq_n_f32(.0f); + + float32x4x2_t n_row_former = vld2q_f32(row_ptr_v[0]); + float32x4_t n_row_latter = vld1q_f32(row_ptr_v[0] + 8); + float32x4_t n_row_ext = vextq_f32(n_row_former.val[0], n_row_latter, 1); + + n_sum = vfmaq_laneq_f32(n_sum, n_row_former.val[0], n_filter_v[0], 0); + n_sum = vfmaq_laneq_f32(n_sum, n_row_former.val[1], n_filter_v[0], 1); + n_sum = vfmaq_laneq_f32(n_sum, n_row_ext, n_filter_v[0], 2); + + float32x4x2_t n_row1_former = vld2q_f32(row_ptr_v[1]); + float32x4_t n_row1_latter = vld1q_f32(row_ptr_v[1] + 8); + float32x4_t n_row1_ext = vextq_f32(n_row1_former.val[0], n_row1_latter, 1); + n_sum = vfmaq_laneq_f32(n_sum, n_row1_former.val[0], n_filter_v[1], 0); + n_sum = vfmaq_laneq_f32(n_sum, n_row1_former.val[1], n_filter_v[1], 1); + n_sum = vfmaq_laneq_f32(n_sum, n_row1_ext, n_filter_v[1], 2); + + float32x4x2_t n_row2_former = vld2q_f32(row_ptr_v[2]); + float32x4_t n_row2_latter = vld1q_f32(row_ptr_v[2] + 8); + float32x4_t n_row2_ext = vextq_f32(n_row2_former.val[0], n_row2_latter, 1); + n_sum = vfmaq_laneq_f32(n_sum, n_row2_former.val[0], n_filter_v[2], 0); + n_sum = vfmaq_laneq_f32(n_sum, n_row2_former.val[1], n_filter_v[2], 1); + n_sum = vfmaq_laneq_f32(n_sum, n_row2_ext, n_filter_v[2], 2); + + float32x4_t n_output_row = vld1q_f32(output_ptr_inner); + n_output_row = vaddq_f32(n_output_row, n_sum); + vst1q_f32(output_ptr_inner, n_output_row); + output_ptr_inner += kRegisterSize; + for (int i = 0; i < 3; ++i) { + row_ptr_v[i] += 2 * kRegisterSize; + } + } + for (; remain_count > 0; --remain_count) { + float32x4_t n_row_v[] = { + vld1q_f32(row_ptr_v[0]), + vld1q_f32(row_ptr_v[1]), + vld1q_f32(row_ptr_v[2]) + }; + float32x4_t n_sum = vmulq_f32(n_row_v[0], n_filter_v[0]); + n_sum = vmlaq_f32(n_sum, n_row_v[1], n_filter_v[1]); + n_sum = vmlaq_f32(n_sum, n_row_v[2], n_filter_v[2]); + n_sum = vsetq_lane_f32(*output_ptr_inner, n_sum, 3); + *output_ptr_inner = vaddvq_f32(n_sum); + + ++output_ptr_inner; for (int i = 0; i < 3; ++i) { - row[i] += 1; + row_ptr_v[i] += 2; } } + for (int i = 0; i < 3; ++i) { + row_ptr_v[i] += tail_step; + } } - filter_ptr += 9; - input_ptr += input_height * input_width; - } - } - } + + KERNEL_TAIL_CODE } -} // namespace kernels -} // namespace mace +#undef 
KERNEL_HEAD_CODE +#undef KERNEL_TAIL_CODE + +} // namespace kernels +} // namespace mace diff --git a/mace/ops/conv_2d_benchmark.cc b/mace/ops/conv_2d_benchmark.cc index 066c56e0d3d08b4bf26fa386c74d01ab190fd220..4164ba0c9faa1189249d1037aa0204d962a08336 100644 --- a/mace/ops/conv_2d_benchmark.cc +++ b/mace/ops/conv_2d_benchmark.cc @@ -76,6 +76,10 @@ BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float); BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128, float); BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128, float); BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 128, float); +BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 128, float); +BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 128, float); +BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 128, float); +BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 128, float); BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, VALID, 128, float); BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, VALID, 128, float); BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, SAME, 128, float); diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index db6f2b48850fdbdd9c11f183685decf04838f826..96880a02ef0fc41857aba2bf8a78698a91e4240f 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -174,8 +174,8 @@ TEST_F(Conv2dOpTest, ConvNxNS12) { // generate random input index_t batch = 1 + rand() % 10; index_t input_channels = 1 + rand() % 50; - index_t height = 7 + rand() % 100; - index_t width = 7 + rand() % 100; + index_t height = 11 + rand() % 100; + index_t width = 11 + rand() % 100; index_t output_channels = 1 + rand() % 50; // Construct graph auto& net = test_net();
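
The stride-1 path above is built around one idiom: load eight consecutive inputs, use vextq_f32 to form the windows shifted by one and by two elements, and accumulate each window against a single filter lane with vfmaq_laneq_f32. A minimal sketch of that idiom follows, assuming an AArch64 target (vfmaq_laneq_f32 is ARMv8-A only); RowTap3 is a hypothetical helper name, not part of this patch:

#include <arm_neon.h>

// Accumulates one 3-tap filter row into four adjacent outputs:
// out[i] += in[i]*w[0] + in[i+1]*w[1] + in[i+2]*w[2], for i = 0..3.
static inline void RowTap3(const float *in, float32x4_t w, float *out) {
  float32x4_t x0123 = vld1q_f32(in);               // in[0..3]
  float32x4_t x4567 = vld1q_f32(in + 4);           // in[4..7]
  float32x4_t x1234 = vextq_f32(x0123, x4567, 1);  // in[1..4]
  float32x4_t x2345 = vextq_f32(x0123, x4567, 2);  // in[2..5]
  float32x4_t acc = vld1q_f32(out);
  acc = vfmaq_laneq_f32(acc, x0123, w, 0);         // += in[i]   * w[0]
  acc = vfmaq_laneq_f32(acc, x1234, w, 1);         // += in[i+1] * w[1]
  acc = vfmaq_laneq_f32(acc, x2345, w, 2);         // += in[i+2] * w[2]
  vst1q_f32(out, acc);
}

Calling RowTap3 once per filter row (advancing in by input_width each time) reproduces what one iteration of the vectorized loop in Conv2dNeonK3x3S1 computes for a single output row; the kernel additionally interleaves two output rows per pass to reuse the middle input rows.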
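
The new Conv2dNeonK3x3S2 swaps the two vextq_f32 shifts for a deinterleaving load: vld2q_f32 reads eight consecutive floats and splits them into even lanes (in[0,2,4,6]) and odd lanes (in[1,3,5,7]), which are exactly in[2i] and in[2i+1] for four stride-2 outputs; one vextq_f32 against a further load then yields in[2i+2]. A sketch mirroring that load pattern, again assuming AArch64 (RowTap3Stride2 is a hypothetical name):

#include <arm_neon.h>

// Accumulates one 3-tap filter row into four stride-2 outputs:
// out[i] += in[2i]*w[0] + in[2i+1]*w[1] + in[2i+2]*w[2], for i = 0..3.
static inline void RowTap3Stride2(const float *in, float32x4_t w,
                                  float *out) {
  float32x4x2_t eo = vld2q_f32(in);                 // val[0] = in[0,2,4,6]
                                                    // val[1] = in[1,3,5,7]
  float32x4_t tail = vld1q_f32(in + 8);             // in[8..11]
  float32x4_t ext = vextq_f32(eo.val[0], tail, 1);  // in[2,4,6,8]
  float32x4_t acc = vld1q_f32(out);
  acc = vfmaq_laneq_f32(acc, eo.val[0], w, 0);      // += in[2i]   * w[0]
  acc = vfmaq_laneq_f32(acc, eo.val[1], w, 1);      // += in[2i+1] * w[1]
  acc = vfmaq_laneq_f32(acc, ext, w, 2);            // += in[2i+2] * w[2]
  vst1q_f32(out, acc);
}

This keeps the stride-2 inner loop at the same three fused multiply-adds per filter row as the stride-1 path. Note that the loads touch in[0..11] while only in[0..8] contribute, so the last vector iteration reads a few floats past the window it needs; the raised minimum spatial sizes in conv_2d_test.cc plausibly keep such reads inside the input buffer.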
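
For checking either path, a plain scalar reference of the arithmetic both kernels implement over an already-padded input (single batch, NCHW input, c_out/c_in/kh/kw filter layout as in the comments above, output pre-seeded with bias); a sketch for validation, not code from this patch:

void Conv2d3x3Reference(const float *input, const float *filter,
                        const float *bias, float *output,
                        int channels_in, int channels_out,
                        int in_h, int in_w, int stride) {
  int out_h = (in_h - 3) / stride + 1;
  int out_w = (in_w - 3) / stride + 1;
  for (int oc = 0; oc < channels_out; ++oc) {
    for (int oh = 0; oh < out_h; ++oh) {
      for (int ow = 0; ow < out_w; ++ow) {
        float sum = bias[oc];  // matches the std::fill in KERNEL_HEAD_CODE
        for (int ic = 0; ic < channels_in; ++ic) {
          for (int kh = 0; kh < 3; ++kh) {
            for (int kw = 0; kw < 3; ++kw) {
              sum += input[(ic * in_h + oh * stride + kh) * in_w +
                           ow * stride + kw] *
                     filter[((oc * channels_in + ic) * 3 + kh) * 3 + kw];
            }
          }
        }
        output[(oc * out_h + oh) * out_w + ow] = sum;
      }
    }
  }
}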