diff --git a/mace/kernels/arm/conv_2d_neon.h b/mace/kernels/arm/conv_2d_neon.h index b35429baf035b87950b500c673e8e6260a38f469..59a24dc66316f939e9fe1d800742c7db685654e6 100644 --- a/mace/kernels/arm/conv_2d_neon.h +++ b/mace/kernels/arm/conv_2d_neon.h @@ -47,6 +47,18 @@ extern void Conv2dNeonK5x5S1(const float *input, const index_t *out_shape, float *output); +extern void Conv2dNeonK1x7S1(const float *input, + const float *filter, + const index_t *in_shape, + const index_t *out_shape, + float *output); + +extern void Conv2dNeonK7x1S1(const float *input, + const float *filter, + const index_t *in_shape, + const index_t *out_shape, + float *output); + extern void Conv2dNeonK7x7S1(const float *input, const float *filter, const index_t *in_shape, @@ -77,6 +89,29 @@ extern void Conv2dNeonK15x1S1(const float *input, const index_t *out_shape, float *output); +// calculate one output channel and one input channel +inline void Conv2dCPUKHxKWCalc(const float *in_ptr, + const float *filter_ptr, + const index_t in_width, + const index_t filter_height, + const index_t filter_width, + const index_t out_height, + const index_t out_width, + float *out_ptr, + const int stride) { + for (index_t h = 0; h < out_height; ++h) { + for (index_t w = 0; w < out_width; ++w) { + for (int i = 0; i < filter_height; ++i) { + for (int j = 0; j < filter_width; ++j) { + out_ptr[h * out_width + w] + += in_ptr[(h * stride + i) * in_width + (w * stride + j)] + * filter_ptr[i * filter_width + j]; + } + } + } + } +} + } // namespace kernels } // namespace mace diff --git a/mace/kernels/arm/conv_2d_neon_15x1.cc b/mace/kernels/arm/conv_2d_neon_15x1.cc index 80dda31493b1ba3f157dd6333848d13f6c247001..9a5d2c410f46ea0cb125805291db6c43bf26fd0e 100644 --- a/mace/kernels/arm/conv_2d_neon_15x1.cc +++ b/mace/kernels/arm/conv_2d_neon_15x1.cc @@ -76,7 +76,7 @@ void Conv2dNeonK15x1S1(const float *input, input + b * in_batch_size + c * in_image_size; const float *filter_ptr = filter + m * in_channels * 15 + c * 15; #if defined(MACE_ENABLE_NEON) && !defined(__aarch64__) - /* load filter (1 outch x 1 height x 4 width) */ + /* load filter (1 outch x 4 height x 1 width) */ float32x4_t vf0, vf1, vf2, vf3; vf0 = vld1q_f32(filter_ptr); vf1 = vld1q_f32(filter_ptr + 4); @@ -87,7 +87,7 @@ void Conv2dNeonK15x1S1(const float *input, for (index_t wt = 0; wt < tile_width && w + wt < out_width; ++wt) { // load output index_t out_offset = h * out_width + w + wt; - // output (1 outch x 1 height x 4 width): vo_outch_height + // output (1 outch x 4 height x 1 width): vo_outch_height float32x4_t vo = {out_ptr_base[out_offset], out_ptr_base[out_offset + out_width], out_ptr_base[out_offset + 2 * out_width], diff --git a/mace/kernels/arm/conv_2d_neon_1x7.cc b/mace/kernels/arm/conv_2d_neon_1x7.cc new file mode 100644 index 0000000000000000000000000000000000000000..a60fa56dbb4d98556b7be4b3740eb9a4acff0063 --- /dev/null +++ b/mace/kernels/arm/conv_2d_neon_1x7.cc @@ -0,0 +1,256 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined(MACE_ENABLE_NEON) +#include +#endif + +#include "mace/kernels/arm/conv_2d_neon.h" + +namespace mace { +namespace kernels { + +// Ho = 1, Wo = 4, Co = 4 +void Conv2dNeonK1x7S1(const float *input, + const float *filter, + const index_t *in_shape, + const index_t *out_shape, + float *output) { + const index_t in_image_size = in_shape[2] * in_shape[3]; + const index_t out_image_size = out_shape[2] * out_shape[3]; + const index_t in_batch_size = in_shape[1] * in_image_size; + const index_t out_batch_size = out_shape[1] * out_image_size; + +#pragma omp parallel for collapse(2) + for (index_t b = 0; b < out_shape[0]; ++b) { + for (index_t m = 0; m < out_shape[1]; m += 4) { + const index_t out_channels = out_shape[1]; + const index_t out_height = out_shape[2]; + const index_t out_width = out_shape[3]; + const index_t in_channels = in_shape[1]; + const index_t in_width = in_shape[3]; + if (m + 3 < out_channels) { + float *out_ptr0_base = + output + b * out_batch_size + m * out_image_size; +#if defined(MACE_ENABLE_NEON) + float *out_ptr1_base = + output + b * out_batch_size + (m + 1) * out_image_size; + float *out_ptr2_base = + output + b * out_batch_size + (m + 2) * out_image_size; + float *out_ptr3_base = + output + b * out_batch_size + (m + 3) * out_image_size; +#endif + for (index_t c = 0; c < in_channels; ++c) { + const float *in_ptr_base = + input + b * in_batch_size + c * in_image_size; + const float *filter_ptr0 = filter + m * in_channels * 7 + c * 7; +#if defined(MACE_ENABLE_NEON) + const float *filter_ptr1 = + filter + (m + 1) * in_channels * 7 + c * 7; + const float *filter_ptr2 = + filter + (m + 2) * in_channels * 7 + c * 7; + const float *filter_ptr3 = + filter + (m + 3) * in_channels * 7 + c * 7; + /* load filter (4 outch x 1 height x 4 width) */ + float32x4_t vf00, vf01; + float32x4_t vf10, vf11; + float32x4_t vf20, vf21; + float32x4_t vf30, vf31; + vf00 = vld1q_f32(filter_ptr0); + vf01 = vld1q_f32(filter_ptr0 + 3); + vf10 = vld1q_f32(filter_ptr1); + vf11 = vld1q_f32(filter_ptr1 + 3); + vf20 = vld1q_f32(filter_ptr2); + vf21 = vld1q_f32(filter_ptr2 + 3); + vf30 = vld1q_f32(filter_ptr3); + vf31 = vld1q_f32(filter_ptr3 + 3); + + for (index_t h = 0; h < out_height; ++h) { + for (index_t w = 0; w + 3 < out_width; w += 4) { + // output (4 outch x 1 height x 4 width): vo_outch_height + float32x4_t vo0, vo1, vo2, vo3; + // load output + index_t out_offset = h * out_width + w; + vo0 = vld1q_f32(out_ptr0_base + out_offset); + vo1 = vld1q_f32(out_ptr1_base + out_offset); + vo2 = vld1q_f32(out_ptr2_base + out_offset); + vo3 = vld1q_f32(out_ptr3_base + out_offset); + + // input (3 slide) + float32x4_t vi0, vi1, vi2, vi3, vi4, vi5, vi6, vi8; + // input offset + index_t in_offset = h * in_width + w; + // load input + vi0 = vld1q_f32(in_ptr_base + in_offset); + vi4 = vld1q_f32(in_ptr_base + in_offset + 4); + vi8 = vld1q_f32(in_ptr_base + in_offset + 8); + vi1 = vextq_f32(vi0, vi4, 1); + vi2 = vextq_f32(vi0, vi4, 2); + vi3 = vextq_f32(vi0, vi4, 3); + vi5 = vextq_f32(vi4, vi8, 1); + vi6 = vextq_f32(vi4, vi8, 2); + +#if defined(__aarch64__) + /* outch 0 */ + vo0 = vfmaq_laneq_f32(vo0, vi0, vf00, 0); + vo0 = vfmaq_laneq_f32(vo0, vi1, vf00, 1); + vo0 = vfmaq_laneq_f32(vo0, vi2, vf00, 2); + vo0 = vfmaq_laneq_f32(vo0, vi3, vf00, 3); + vo0 = vfmaq_laneq_f32(vo0, vi4, vf01, 1); + vo0 = vfmaq_laneq_f32(vo0, vi5, vf01, 2); + vo0 = vfmaq_laneq_f32(vo0, vi6, vf01, 3); + /* outch 1 */ + vo1 = vfmaq_laneq_f32(vo1, vi0, vf10, 0); + vo1 = vfmaq_laneq_f32(vo1, vi1, vf10, 1); + vo1 = vfmaq_laneq_f32(vo1, vi2, vf10, 2); + vo1 = vfmaq_laneq_f32(vo1, vi3, vf10, 3); + vo1 = vfmaq_laneq_f32(vo1, vi4, vf11, 1); + vo1 = vfmaq_laneq_f32(vo1, vi5, vf11, 2); + vo1 = vfmaq_laneq_f32(vo1, vi6, vf11, 3); + /* outch 2 */ + vo2 = vfmaq_laneq_f32(vo2, vi0, vf20, 0); + vo2 = vfmaq_laneq_f32(vo2, vi1, vf20, 1); + vo2 = vfmaq_laneq_f32(vo2, vi2, vf20, 2); + vo2 = vfmaq_laneq_f32(vo2, vi3, vf20, 3); + vo2 = vfmaq_laneq_f32(vo2, vi4, vf21, 1); + vo2 = vfmaq_laneq_f32(vo2, vi5, vf21, 2); + vo2 = vfmaq_laneq_f32(vo2, vi6, vf21, 3); + /* outch 3 */ + vo3 = vfmaq_laneq_f32(vo3, vi0, vf30, 0); + vo3 = vfmaq_laneq_f32(vo3, vi1, vf30, 1); + vo3 = vfmaq_laneq_f32(vo3, vi2, vf30, 2); + vo3 = vfmaq_laneq_f32(vo3, vi3, vf30, 3); + vo3 = vfmaq_laneq_f32(vo3, vi4, vf31, 1); + vo3 = vfmaq_laneq_f32(vo3, vi5, vf31, 2); + vo3 = vfmaq_laneq_f32(vo3, vi6, vf31, 3); +#else + /* outch 0 */ + vo0 = vmlaq_lane_f32(vo0, vi0, vget_low_f32(vf00), 0); + vo0 = vmlaq_lane_f32(vo0, vi1, vget_low_f32(vf00), 1); + vo0 = vmlaq_lane_f32(vo0, vi2, vget_high_f32(vf00), 0); + vo0 = vmlaq_lane_f32(vo0, vi3, vget_high_f32(vf00), 1); + vo0 = vmlaq_lane_f32(vo0, vi4, vget_low_f32(vf01), 1); + vo0 = vmlaq_lane_f32(vo0, vi5, vget_high_f32(vf01), 0); + vo0 = vmlaq_lane_f32(vo0, vi6, vget_high_f32(vf01), 1); + /* outch 1 */ + vo1 = vmlaq_lane_f32(vo1, vi0, vget_low_f32(vf10), 0); + vo1 = vmlaq_lane_f32(vo1, vi1, vget_low_f32(vf10), 1); + vo1 = vmlaq_lane_f32(vo1, vi2, vget_high_f32(vf10), 0); + vo1 = vmlaq_lane_f32(vo1, vi3, vget_high_f32(vf10), 1); + vo1 = vmlaq_lane_f32(vo1, vi4, vget_low_f32(vf11), 1); + vo1 = vmlaq_lane_f32(vo1, vi5, vget_high_f32(vf11), 0); + vo1 = vmlaq_lane_f32(vo1, vi6, vget_high_f32(vf11), 1); + /* outch 2 */ + vo2 = vmlaq_lane_f32(vo2, vi0, vget_low_f32(vf20), 0); + vo2 = vmlaq_lane_f32(vo2, vi1, vget_low_f32(vf20), 1); + vo2 = vmlaq_lane_f32(vo2, vi2, vget_high_f32(vf20), 0); + vo2 = vmlaq_lane_f32(vo2, vi3, vget_high_f32(vf20), 1); + vo2 = vmlaq_lane_f32(vo2, vi4, vget_low_f32(vf21), 1); + vo2 = vmlaq_lane_f32(vo2, vi5, vget_high_f32(vf21), 0); + vo2 = vmlaq_lane_f32(vo2, vi6, vget_high_f32(vf21), 1); + /* outch 3 */ + vo3 = vmlaq_lane_f32(vo3, vi0, vget_low_f32(vf30), 0); + vo3 = vmlaq_lane_f32(vo3, vi1, vget_low_f32(vf30), 1); + vo3 = vmlaq_lane_f32(vo3, vi2, vget_high_f32(vf30), 0); + vo3 = vmlaq_lane_f32(vo3, vi3, vget_high_f32(vf30), 1); + vo3 = vmlaq_lane_f32(vo3, vi4, vget_low_f32(vf31), 1); + vo3 = vmlaq_lane_f32(vo3, vi5, vget_high_f32(vf31), 0); + vo3 = vmlaq_lane_f32(vo3, vi6, vget_high_f32(vf31), 1); +#endif + + vst1q_f32(out_ptr0_base + out_offset, vo0); + vst1q_f32(out_ptr1_base + out_offset, vo1); + vst1q_f32(out_ptr2_base + out_offset, vo2); + vst1q_f32(out_ptr3_base + out_offset, vo3); + } // w + } // h +#else + for (index_t oc = 0; oc < 4; ++oc) { + Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0 + oc * in_channels * 7, + in_width, 1, 7, out_height, out_width, + out_ptr0_base + oc * out_image_size, 1); + } +#endif + } // c + } else { + for (index_t mm = m; mm < out_channels; ++mm) { + float *out_ptr0_base = + output + b * out_batch_size + mm * out_image_size; + for (index_t c = 0; c < in_channels; ++c) { + const float *in_ptr_base = + input + b * in_batch_size + c * in_image_size; + const float *filter_ptr0 = filter + mm * in_channels * 7 + c * 7; +#if defined(MACE_ENABLE_NEON) + /* load filter (1 outch x 1 height x 4 width) */ + float32x4_t vf00, vf01; + vf00 = vld1q_f32(filter_ptr0); + vf01 = vld1q_f32(filter_ptr0 + 3); + + for (index_t h = 0; h < out_height; ++h) { + for (index_t w = 0; w + 3 < out_width; w += 4) { + // output (1 outch x 1 height x 4 width): vo_outch_height + float32x4_t vo0; + // load output + index_t out_offset = h * out_width + w; + vo0 = vld1q_f32(out_ptr0_base + out_offset); + + // input (3 slide) + float32x4_t vi0, vi1, vi2, vi3, vi4, vi5, vi6, vi8; + // input offset + index_t in_offset = h * in_width + w; + // load input + vi0 = vld1q_f32(in_ptr_base + in_offset); + vi4 = vld1q_f32(in_ptr_base + in_offset + 4); + vi8 = vld1q_f32(in_ptr_base + in_offset + 8); + vi1 = vextq_f32(vi0, vi4, 1); + vi2 = vextq_f32(vi0, vi4, 2); + vi3 = vextq_f32(vi0, vi4, 3); + vi5 = vextq_f32(vi4, vi8, 1); + vi6 = vextq_f32(vi4, vi8, 2); + +#if defined(__aarch64__) + vo0 = vfmaq_laneq_f32(vo0, vi0, vf00, 0); + vo0 = vfmaq_laneq_f32(vo0, vi1, vf00, 1); + vo0 = vfmaq_laneq_f32(vo0, vi2, vf00, 2); + vo0 = vfmaq_laneq_f32(vo0, vi3, vf00, 3); + vo0 = vfmaq_laneq_f32(vo0, vi4, vf01, 1); + vo0 = vfmaq_laneq_f32(vo0, vi5, vf01, 2); + vo0 = vfmaq_laneq_f32(vo0, vi6, vf01, 3); +#else + vo0 = vmlaq_lane_f32(vo0, vi0, vget_low_f32(vf00), 0); + vo0 = vmlaq_lane_f32(vo0, vi1, vget_low_f32(vf00), 1); + vo0 = vmlaq_lane_f32(vo0, vi2, vget_high_f32(vf00), 0); + vo0 = vmlaq_lane_f32(vo0, vi3, vget_high_f32(vf00), 1); + vo0 = vmlaq_lane_f32(vo0, vi4, vget_low_f32(vf01), 1); + vo0 = vmlaq_lane_f32(vo0, vi5, vget_high_f32(vf01), 0); + vo0 = vmlaq_lane_f32(vo0, vi6, vget_high_f32(vf01), 1); +#endif + + vst1q_f32(out_ptr0_base + out_offset, vo0); + } // w + } // h +#else + Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0, + in_width, 1, 7, out_height, out_width, + out_ptr0_base, 1); +#endif + } // c + } + } // if + } // m + } // b +} + +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/arm/conv_2d_neon_3x3.cc b/mace/kernels/arm/conv_2d_neon_3x3.cc index 58b28ddc48a5aa5b880c80e6bbdde8ca32f46e38..dbfa5c0c495ae349bcdb92601455eed87881c59f 100644 --- a/mace/kernels/arm/conv_2d_neon_3x3.cc +++ b/mace/kernels/arm/conv_2d_neon_3x3.cc @@ -300,19 +300,11 @@ void Conv2dNeonK3x3S1(const float *input, out_ptr1 += out_width; } // h #else - for (index_t io = 0; io < 2; ++io) { - for (index_t ih = 0; ih < out_height; ++ih) { - for (index_t iw = 0; iw < out_width; ++iw) { - for (int i = 0; i < 3; ++i) { - for (int j = 0; j < 3; ++j) { - out_ptr0[io * out_image_size + ih * out_width + iw] += - in_ptr0[(ih + i) * in_width + (iw + j)] - * filter_ptr0[io * in_channels * 9 + i * 3 + j]; - } - } - } - } - } // for + for (index_t oc = 0; oc < 2; ++oc) { + Conv2dCPUKHxKWCalc(in_ptr0, filter_ptr0 + oc * in_channels * 9, + in_width, 3, 3, out_height, out_width, + out_ptr0_base + oc * out_image_size, 1); + } #endif } // c } else { @@ -501,17 +493,9 @@ void Conv2dNeonK3x3S1(const float *input, out_ptr0 += out_width; } // h #else - for (index_t ih = 0; ih < out_height; ++ih) { - for (index_t iw = 0; iw < out_width; ++iw) { - for (int i = 0; i < 3; ++i) { - for (int j = 0; j < 3; ++j) { - out_ptr0[ih * out_width + iw] += - in_ptr0[(ih + i) * in_width + (iw + j)] - * filter_ptr0[i * 3 + j]; - } - } - } - } + Conv2dCPUKHxKWCalc(in_ptr0, filter_ptr0, + in_width, 3, 3, out_height, out_width, + out_ptr0_base, 1); #endif } // c } // mm @@ -666,17 +650,9 @@ void Conv2dNeonK3x3S2(const float *input, } // w } // h #else - for (index_t ih = 0; ih < out_height; ++ih) { - for (index_t iw = 0; iw < out_width; ++iw) { - for (int i = 0; i < 3; ++i) { - for (int j = 0; j < 3; ++j) { - out_base[ih * out_width + iw] += - in_base[(ih * 2 + i) * in_width + (iw * 2 + j)] - * filter_ptr[i * 3 + j]; - } - } - } - } + Conv2dCPUKHxKWCalc(in_base, filter_ptr, + in_width, 3, 3, out_height, out_width, + out_base, 2); #endif } // c } // m diff --git a/mace/kernels/arm/conv_2d_neon_5x5.cc b/mace/kernels/arm/conv_2d_neon_5x5.cc index 3d77d8f6b5535a386dd2d6ba1a18367e1c189bf1..61672bd435ef0d49790d0f55f69e4be5355d8a12 100644 --- a/mace/kernels/arm/conv_2d_neon_5x5.cc +++ b/mace/kernels/arm/conv_2d_neon_5x5.cc @@ -76,30 +76,6 @@ namespace kernels { vo0 = vmlaq_lane_f32(vo0, vi3, vget_high_f32(vf00), 1); \ vo0 = vmlaq_lane_f32(vo0, vi4, vf01, 1); -inline void Conv2dCPUK5x5Calc(const float *in_ptr_base, - const float *filter_ptr0, - const index_t in_width, - const index_t in_channels, - const index_t out_height, - const index_t out_width, - const index_t out_image_size, - float *out_ptr0_base, - const index_t io, - const int stride) { - for (index_t ih = 0; ih < out_height; ++ih) { - for (index_t iw = 0; iw < out_width; ++iw) { - for (int i = 0; i < 5; ++i) { - for (int j = 0; j < 5; ++j) { - out_ptr0_base[io * out_image_size + ih * out_width + iw] += - in_ptr_base[(ih * stride + i) * in_width + (iw * stride + j)] * - filter_ptr0[io * in_channels * 25 + i * 5 + j]; - } - } - } - } -} - - // Ho = 1, Wo = 4, Co = 4 void Conv2dNeonK5x5S1(const float *input, const float *filter, @@ -183,11 +159,11 @@ void Conv2dNeonK5x5S1(const float *input, } // w } // h #else - for (index_t io = 0; io < 4; ++io) { - Conv2dCPUK5x5Calc(in_ptr_base, filter_ptr0, in_width, in_channels, - out_height, out_width, out_image_size, - out_ptr0_base, io, 1); - } // for + for (index_t oc = 0; oc < 4; ++oc) { + Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0 + oc * in_channels * 25, + in_width, 5, 5, out_height, out_width, + out_ptr0_base + oc * out_image_size, 1); + } #endif } // c } else { @@ -229,9 +205,9 @@ void Conv2dNeonK5x5S1(const float *input, } // w } // h #else - Conv2dCPUK5x5Calc(in_ptr_base, filter_ptr0, in_width, in_channels, - out_height, out_width, out_image_size, - out_ptr0_base, 0, 1); + Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0, + in_width, 5, 5, out_height, out_width, + out_ptr0_base, 1); #endif } // c } // mm diff --git a/mace/kernels/arm/conv_2d_neon_7x1.cc b/mace/kernels/arm/conv_2d_neon_7x1.cc new file mode 100644 index 0000000000000000000000000000000000000000..17215bb8beea2738c7576cb382e67206353451ce --- /dev/null +++ b/mace/kernels/arm/conv_2d_neon_7x1.cc @@ -0,0 +1,297 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined(MACE_ENABLE_NEON) +#include +#endif + +#include "mace/kernels/arm/conv_2d_neon.h" + +namespace mace { +namespace kernels { + +// Ho = 4, Wo = 1, Co = 4 +void Conv2dNeonK7x1S1(const float *input, + const float *filter, + const index_t *in_shape, + const index_t *out_shape, + float *output) { + const index_t in_image_size = in_shape[2] * in_shape[3]; + const index_t out_image_size = out_shape[2] * out_shape[3]; + const index_t in_batch_size = in_shape[1] * in_image_size; + const index_t out_batch_size = out_shape[1] * out_image_size; + +#pragma omp parallel for collapse(2) + for (index_t b = 0; b < out_shape[0]; ++b) { + for (index_t m = 0; m < out_shape[1]; m += 4) { + const index_t out_channels = out_shape[1]; + const index_t out_height = out_shape[2]; + const index_t out_width = out_shape[3]; + const index_t in_channels = in_shape[1]; + const index_t in_width = in_shape[3]; + if (m + 3 < out_channels) { + float *out_ptr0_base = + output + b * out_batch_size + m * out_image_size; +#if defined(MACE_ENABLE_NEON) + float *out_ptr1_base = + output + b * out_batch_size + (m + 1) * out_image_size; + float *out_ptr2_base = + output + b * out_batch_size + (m + 2) * out_image_size; + float *out_ptr3_base = + output + b * out_batch_size + (m + 3) * out_image_size; +#endif + for (index_t c = 0; c < in_channels; ++c) { + const float *in_ptr_base = + input + b * in_batch_size + c * in_image_size; + const float *filter_ptr0 = filter + m * in_channels * 7 + c * 7; +#if defined(MACE_ENABLE_NEON) + const float *filter_ptr1 = + filter + (m + 1) * in_channels * 7 + c * 7; + const float *filter_ptr2 = + filter + (m + 2) * in_channels * 7 + c * 7; + const float *filter_ptr3 = + filter + (m + 3) * in_channels * 7 + c * 7; + /* load filter (4 outch x 4 height x 1 width) */ + float32x4_t vf00, vf01; + float32x4_t vf10, vf11; + float32x4_t vf20, vf21; + float32x4_t vf30, vf31; + vf00 = vld1q_f32(filter_ptr0); + vf01 = vld1q_f32(filter_ptr0 + 3); + vf10 = vld1q_f32(filter_ptr1); + vf11 = vld1q_f32(filter_ptr1 + 3); + vf20 = vld1q_f32(filter_ptr2); + vf21 = vld1q_f32(filter_ptr2 + 3); + vf30 = vld1q_f32(filter_ptr3); + vf31 = vld1q_f32(filter_ptr3 + 3); + + for (index_t h = 0; h + 3 < out_height; h += 4) { + for (index_t w = 0; w < out_width; ++w) { + // load output + index_t out_offset = h * out_width + w; + // output (4 outch x 4 height x 1 width): vo_outch_height + float32x4_t vo0 = {out_ptr0_base[out_offset], + out_ptr0_base[out_offset + out_width], + out_ptr0_base[out_offset + 2 * out_width], + out_ptr0_base[out_offset + 3 * out_width]}; + float32x4_t vo1 = {out_ptr1_base[out_offset], + out_ptr1_base[out_offset + out_width], + out_ptr1_base[out_offset + 2 * out_width], + out_ptr1_base[out_offset + 3 * out_width]}; + float32x4_t vo2 = {out_ptr2_base[out_offset], + out_ptr2_base[out_offset + out_width], + out_ptr2_base[out_offset + 2 * out_width], + out_ptr2_base[out_offset + 3 * out_width]}; + float32x4_t vo3 = {out_ptr3_base[out_offset], + out_ptr3_base[out_offset + out_width], + out_ptr3_base[out_offset + 2 * out_width], + out_ptr3_base[out_offset + 3 * out_width]}; + + + // input offset + index_t in_offset = h * in_width + w; + // input (3 slide) + float32x4_t vi0 = {in_ptr_base[in_offset], + in_ptr_base[in_offset + in_width], + in_ptr_base[in_offset + 2 * in_width], + in_ptr_base[in_offset + 3 * in_width]}; + float32x4_t vi4 = {in_ptr_base[in_offset + 4 * in_width], + in_ptr_base[in_offset + 5 * in_width], + in_ptr_base[in_offset + 6 * in_width], + in_ptr_base[in_offset + 7 * in_width]}; + float32x4_t vi8 = {in_ptr_base[in_offset + 8 * in_width], + in_ptr_base[in_offset + 9 * in_width]}; + float32x4_t vi1 = vextq_f32(vi0, vi4, 1); + float32x4_t vi2 = vextq_f32(vi0, vi4, 2); + float32x4_t vi3 = vextq_f32(vi0, vi4, 3); + float32x4_t vi5 = vextq_f32(vi4, vi8, 1); + float32x4_t vi6 = vextq_f32(vi4, vi8, 2); + +#if defined(__aarch64__) + /* outch 0 */ + vo0 = vfmaq_laneq_f32(vo0, vi0, vf00, 0); + vo0 = vfmaq_laneq_f32(vo0, vi1, vf00, 1); + vo0 = vfmaq_laneq_f32(vo0, vi2, vf00, 2); + vo0 = vfmaq_laneq_f32(vo0, vi3, vf00, 3); + vo0 = vfmaq_laneq_f32(vo0, vi4, vf01, 1); + vo0 = vfmaq_laneq_f32(vo0, vi5, vf01, 2); + vo0 = vfmaq_laneq_f32(vo0, vi6, vf01, 3); + /* outch 1 */ + vo1 = vfmaq_laneq_f32(vo1, vi0, vf10, 0); + vo1 = vfmaq_laneq_f32(vo1, vi1, vf10, 1); + vo1 = vfmaq_laneq_f32(vo1, vi2, vf10, 2); + vo1 = vfmaq_laneq_f32(vo1, vi3, vf10, 3); + vo1 = vfmaq_laneq_f32(vo1, vi4, vf11, 1); + vo1 = vfmaq_laneq_f32(vo1, vi5, vf11, 2); + vo1 = vfmaq_laneq_f32(vo1, vi6, vf11, 3); + /* outch 2 */ + vo2 = vfmaq_laneq_f32(vo2, vi0, vf20, 0); + vo2 = vfmaq_laneq_f32(vo2, vi1, vf20, 1); + vo2 = vfmaq_laneq_f32(vo2, vi2, vf20, 2); + vo2 = vfmaq_laneq_f32(vo2, vi3, vf20, 3); + vo2 = vfmaq_laneq_f32(vo2, vi4, vf21, 1); + vo2 = vfmaq_laneq_f32(vo2, vi5, vf21, 2); + vo2 = vfmaq_laneq_f32(vo2, vi6, vf21, 3); + /* outch 3 */ + vo3 = vfmaq_laneq_f32(vo3, vi0, vf30, 0); + vo3 = vfmaq_laneq_f32(vo3, vi1, vf30, 1); + vo3 = vfmaq_laneq_f32(vo3, vi2, vf30, 2); + vo3 = vfmaq_laneq_f32(vo3, vi3, vf30, 3); + vo3 = vfmaq_laneq_f32(vo3, vi4, vf31, 1); + vo3 = vfmaq_laneq_f32(vo3, vi5, vf31, 2); + vo3 = vfmaq_laneq_f32(vo3, vi6, vf31, 3); +#else + /* outch 0 */ + vo0 = vmlaq_lane_f32(vo0, vi0, vget_low_f32(vf00), 0); + vo0 = vmlaq_lane_f32(vo0, vi1, vget_low_f32(vf00), 1); + vo0 = vmlaq_lane_f32(vo0, vi2, vget_high_f32(vf00), 0); + vo0 = vmlaq_lane_f32(vo0, vi3, vget_high_f32(vf00), 1); + vo0 = vmlaq_lane_f32(vo0, vi4, vget_low_f32(vf01), 1); + vo0 = vmlaq_lane_f32(vo0, vi5, vget_high_f32(vf01), 0); + vo0 = vmlaq_lane_f32(vo0, vi6, vget_high_f32(vf01), 1); + /* outch 1 */ + vo1 = vmlaq_lane_f32(vo1, vi0, vget_low_f32(vf10), 0); + vo1 = vmlaq_lane_f32(vo1, vi1, vget_low_f32(vf10), 1); + vo1 = vmlaq_lane_f32(vo1, vi2, vget_high_f32(vf10), 0); + vo1 = vmlaq_lane_f32(vo1, vi3, vget_high_f32(vf10), 1); + vo1 = vmlaq_lane_f32(vo1, vi4, vget_low_f32(vf11), 1); + vo1 = vmlaq_lane_f32(vo1, vi5, vget_high_f32(vf11), 0); + vo1 = vmlaq_lane_f32(vo1, vi6, vget_high_f32(vf11), 1); + /* outch 2 */ + vo2 = vmlaq_lane_f32(vo2, vi0, vget_low_f32(vf20), 0); + vo2 = vmlaq_lane_f32(vo2, vi1, vget_low_f32(vf20), 1); + vo2 = vmlaq_lane_f32(vo2, vi2, vget_high_f32(vf20), 0); + vo2 = vmlaq_lane_f32(vo2, vi3, vget_high_f32(vf20), 1); + vo2 = vmlaq_lane_f32(vo2, vi4, vget_low_f32(vf21), 1); + vo2 = vmlaq_lane_f32(vo2, vi5, vget_high_f32(vf21), 0); + vo2 = vmlaq_lane_f32(vo2, vi6, vget_high_f32(vf21), 1); + /* outch 3 */ + vo3 = vmlaq_lane_f32(vo3, vi0, vget_low_f32(vf30), 0); + vo3 = vmlaq_lane_f32(vo3, vi1, vget_low_f32(vf30), 1); + vo3 = vmlaq_lane_f32(vo3, vi2, vget_high_f32(vf30), 0); + vo3 = vmlaq_lane_f32(vo3, vi3, vget_high_f32(vf30), 1); + vo3 = vmlaq_lane_f32(vo3, vi4, vget_low_f32(vf31), 1); + vo3 = vmlaq_lane_f32(vo3, vi5, vget_high_f32(vf31), 0); + vo3 = vmlaq_lane_f32(vo3, vi6, vget_high_f32(vf31), 1); +#endif + + out_ptr0_base[out_offset] = vo0[0]; + out_ptr0_base[out_offset + out_width] = vo0[1]; + out_ptr0_base[out_offset + 2 * out_width] = vo0[2]; + out_ptr0_base[out_offset + 3 * out_width] = vo0[3]; + out_ptr1_base[out_offset] = vo1[0]; + out_ptr1_base[out_offset + out_width] = vo1[1]; + out_ptr1_base[out_offset + 2 * out_width] = vo1[2]; + out_ptr1_base[out_offset + 3 * out_width] = vo1[3]; + out_ptr2_base[out_offset] = vo2[0]; + out_ptr2_base[out_offset + out_width] = vo2[1]; + out_ptr2_base[out_offset + 2 * out_width] = vo2[2]; + out_ptr2_base[out_offset + 3 * out_width] = vo2[3]; + out_ptr3_base[out_offset] = vo3[0]; + out_ptr3_base[out_offset + out_width] = vo3[1]; + out_ptr3_base[out_offset + 2 * out_width] = vo3[2]; + out_ptr3_base[out_offset + 3 * out_width] = vo3[3]; + } // w + } // h +#else + for (index_t oc = 0; oc < 4; ++oc) { + Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0 + oc * in_channels * 7, + in_width, 7, 1, out_height, out_width, + out_ptr0_base + oc * out_image_size, 1); + } +#endif + } // c + } else { + for (index_t mm = m; mm < out_channels; ++mm) { + float *out_ptr0_base = + output + b * out_batch_size + mm * out_image_size; + for (index_t c = 0; c < in_channels; ++c) { + const float *in_ptr_base = + input + b * in_batch_size + c * in_image_size; + const float *filter_ptr0 = filter + mm * in_channels * 7 + c * 7; +#if defined(MACE_ENABLE_NEON) + /* load filter (1 outch x 4 height x 1 width) */ + float32x4_t vf00, vf01; + vf00 = vld1q_f32(filter_ptr0); + vf01 = vld1q_f32(filter_ptr0 + 3); + + for (index_t h = 0; h + 3 < out_height; h += 4) { + for (index_t w = 0; w < out_width; ++w) { + // load output + index_t out_offset = h * out_width + w; + // output (1 outch x 4 height x 1 width): vo_outch_height + float32x4_t vo0 = {out_ptr0_base[out_offset], + out_ptr0_base[out_offset + out_width], + out_ptr0_base[out_offset + 2 * out_width], + out_ptr0_base[out_offset + 3 * out_width]}; + + // input offset + index_t in_offset = h * in_width + w; + // input (3 slide) + float32x4_t vi0 = {in_ptr_base[in_offset], + in_ptr_base[in_offset + in_width], + in_ptr_base[in_offset + 2 * in_width], + in_ptr_base[in_offset + 3 * in_width]}; + float32x4_t vi4 = {in_ptr_base[in_offset + 4 * in_width], + in_ptr_base[in_offset + 5 * in_width], + in_ptr_base[in_offset + 6 * in_width], + in_ptr_base[in_offset + 7 * in_width]}; + float32x4_t vi8 = {in_ptr_base[in_offset + 8 * in_width], + in_ptr_base[in_offset + 9 * in_width], + in_ptr_base[in_offset + 10 * in_width], + in_ptr_base[in_offset + 11 * in_width]}; + float32x4_t vi1 = vextq_f32(vi0, vi4, 1); + float32x4_t vi2 = vextq_f32(vi0, vi4, 2); + float32x4_t vi3 = vextq_f32(vi0, vi4, 3); + float32x4_t vi5 = vextq_f32(vi4, vi8, 1); + float32x4_t vi6 = vextq_f32(vi4, vi8, 2); + +#if defined(__aarch64__) + vo0 = vfmaq_laneq_f32(vo0, vi0, vf00, 0); + vo0 = vfmaq_laneq_f32(vo0, vi1, vf00, 1); + vo0 = vfmaq_laneq_f32(vo0, vi2, vf00, 2); + vo0 = vfmaq_laneq_f32(vo0, vi3, vf00, 3); + vo0 = vfmaq_laneq_f32(vo0, vi4, vf01, 1); + vo0 = vfmaq_laneq_f32(vo0, vi5, vf01, 2); + vo0 = vfmaq_laneq_f32(vo0, vi6, vf01, 3); +#else + vo0 = vmlaq_lane_f32(vo0, vi0, vget_low_f32(vf00), 0); + vo0 = vmlaq_lane_f32(vo0, vi1, vget_low_f32(vf00), 1); + vo0 = vmlaq_lane_f32(vo0, vi2, vget_high_f32(vf00), 0); + vo0 = vmlaq_lane_f32(vo0, vi3, vget_high_f32(vf00), 1); + vo0 = vmlaq_lane_f32(vo0, vi4, vget_low_f32(vf01), 1); + vo0 = vmlaq_lane_f32(vo0, vi5, vget_high_f32(vf01), 0); + vo0 = vmlaq_lane_f32(vo0, vi6, vget_high_f32(vf01), 1); +#endif + + out_ptr0_base[out_offset] = vo0[0]; + out_ptr0_base[out_offset + out_width] = vo0[1]; + out_ptr0_base[out_offset + 2 * out_width] = vo0[2]; + out_ptr0_base[out_offset + 3 * out_width] = vo0[3]; + } // w + } // h +#else + Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0, + in_width, 7, 1, out_height, out_width, + out_ptr0_base, 1); +#endif + } // c + } + } // if + } // m + } // b +} + +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/arm/conv_2d_neon_7x7.cc b/mace/kernels/arm/conv_2d_neon_7x7.cc index 4432f2a05b848fdc9978b5a10fed7985ff3d4cff..b6c2d5fd897f55af5add7a2b8699618d10b62bfc 100644 --- a/mace/kernels/arm/conv_2d_neon_7x7.cc +++ b/mace/kernels/arm/conv_2d_neon_7x7.cc @@ -153,30 +153,6 @@ namespace kernels { vo0 = vmlaq_lane_f32(vo0, vi5, vget_high_f32(vf01), 0); \ vo0 = vmlaq_lane_f32(vo0, vi6, vget_high_f32(vf01), 1); -inline void Conv2dCPUK7x7Calc(const float *in_ptr_base, - const float *filter_ptr0, - const index_t in_width, - const index_t in_channels, - const index_t out_height, - const index_t out_width, - const index_t out_image_size, - float *out_ptr0_base, - const index_t io, - const int stride) { - for (index_t ih = 0; ih < out_height; ++ih) { - for (index_t iw = 0; iw < out_width; ++iw) { - for (int i = 0; i < 7; ++i) { - for (int j = 0; j < 7; ++j) { - out_ptr0_base[io * out_image_size + ih * out_width + iw] += - in_ptr_base[(ih * stride + i) * in_width + (iw * stride + j)] * - filter_ptr0[io * in_channels * 49 + i * 7 + j]; - } - } - } - } -} - - // Ho = 1, Wo = 4, Co = 4 void Conv2dNeonK7x7S1(const float *input, const float *filter, @@ -268,11 +244,11 @@ void Conv2dNeonK7x7S1(const float *input, } // w } // h #else - for (index_t io = 0; io < 4; ++io) { - Conv2dCPUK7x7Calc(in_ptr_base, filter_ptr0, in_width, in_channels, - out_height, out_width, out_image_size, - out_ptr0_base, io, 1); - } // for + for (index_t oc = 0; oc < 4; ++oc) { + Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0 + oc * in_channels * 49, + in_width, 7, 7, out_height, out_width, + out_ptr0_base + oc * out_image_size, 1); + } #endif } // c } else { @@ -322,9 +298,9 @@ void Conv2dNeonK7x7S1(const float *input, } // w } // h #else - Conv2dCPUK7x7Calc(in_ptr_base, filter_ptr0, in_width, in_channels, - out_height, out_width, out_image_size, - out_ptr0_base, 0, 1); + Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0, + in_width, 7, 7, out_height, out_width, + out_ptr0_base, 1); #endif } // c } // mm @@ -429,11 +405,11 @@ void Conv2dNeonK7x7S2(const float *input, } // w } // h #else - for (index_t io = 0; io < 4; ++io) { - Conv2dCPUK7x7Calc(in_ptr_base, filter_ptr0, in_width, in_channels, - out_height, out_width, out_image_size, - out_ptr0_base, io, 2); - } // for + for (index_t oc = 0; oc < 4; ++oc) { + Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0 + oc * in_channels * 49, + in_width, 7, 7, out_height, out_width, + out_ptr0_base + oc * out_image_size, 2); + } #endif } // c } else { @@ -488,9 +464,9 @@ void Conv2dNeonK7x7S2(const float *input, } // w } // h #else - Conv2dCPUK7x7Calc(in_ptr_base, filter_ptr0, in_width, in_channels, - out_height, out_width, out_image_size, - out_ptr0_base, 0, 2); + Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0, + in_width, 7, 7, out_height, out_width, + out_ptr0_base, 2); #endif } // c } // mm @@ -595,11 +571,11 @@ void Conv2dNeonK7x7S3(const float *input, } // w } // h #else - for (index_t io = 0; io < 4; ++io) { - Conv2dCPUK7x7Calc(in_ptr_base, filter_ptr0, in_width, in_channels, - out_height, out_width, out_image_size, - out_ptr0_base, io, 3); - } // for + for (index_t oc = 0; oc < 4; ++oc) { + Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0 + oc * in_channels * 49, + in_width, 7, 7, out_height, out_width, + out_ptr0_base + oc * out_image_size, 3); + } #endif } // c } else { @@ -654,9 +630,9 @@ void Conv2dNeonK7x7S3(const float *input, } // w } // h #else - Conv2dCPUK7x7Calc(in_ptr_base, filter_ptr0, in_width, in_channels, - out_height, out_width, out_image_size, - out_ptr0_base, 0, 3); + Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0, + in_width, 7, 7, out_height, out_width, + out_ptr0_base, 3); #endif } // c } // mm diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index 7a0b8328bb84ed088dcf532295928b5a6040e658..53531324503059d4ff68855b55f515f178a62c73 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -357,6 +357,10 @@ struct Conv2dFunctor : Conv2dFunctorBase { && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1; bool use_neon_5x5_s1 = filter_h == 5 && filter_w == 5 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1; + bool use_neon_1x7_s1 = filter_h == 1 && filter_w == 7 + && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1; + bool use_neon_7x1_s1 = filter_h == 7 && filter_w == 1 + && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1; bool use_neon_7x7_s1 = filter_h == 7 && filter_w == 7 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1; bool use_neon_7x7_s2 = filter_h == 7 && filter_w == 7 @@ -414,7 +418,7 @@ struct Conv2dFunctor : Conv2dFunctorBase { } else if (use_neon_3x3_s1) { tile_h = 2; tile_w = 4; - } else if (use_neon_15x1_s1) { + } else if (use_neon_7x1_s1 || use_neon_15x1_s1) { tile_h = 4; tile_w = 1; } else { @@ -566,6 +570,22 @@ struct Conv2dFunctor : Conv2dFunctorBase { extra_output_shape, pad_output); }; + } else if (use_neon_1x7_s1) { + conv_func = [=](const float *pad_input, float *pad_output) { + Conv2dNeonK1x7S1(pad_input, + filter_data, + extra_input_shape, + extra_output_shape, + pad_output); + }; + } else if (use_neon_7x1_s1) { + conv_func = [=](const float *pad_input, float *pad_output) { + Conv2dNeonK7x1S1(pad_input, + filter_data, + extra_input_shape, + extra_output_shape, + pad_output); + }; } else if (use_neon_7x7_s1) { conv_func = [=](const float *pad_input, float *pad_output) { Conv2dNeonK7x7S1(pad_input, diff --git a/mace/kernels/gemm.cc b/mace/kernels/gemm.cc index 178e0720a13555d00387ccc57a87f065cd746b42..9b3aa599f53923925119bf21378b7ba7896ac189 100644 --- a/mace/kernels/gemm.cc +++ b/mace/kernels/gemm.cc @@ -388,8 +388,7 @@ inline void GemmTile(const float *A, "v22", "v23", "v24", - "v25", - "v26" + "v25" ); w = (width >> 2) << 2; diff --git a/mace/ops/conv_2d_benchmark.cc b/mace/ops/conv_2d_benchmark.cc index 4a5d80e4cfbfabcd7d948d874fdc8c2f144fbfd4..c0e5e28d7ed20b0aa3efc2f19d304723609dc083 100644 --- a/mace/ops/conv_2d_benchmark.cc +++ b/mace/ops/conv_2d_benchmark.cc @@ -168,6 +168,10 @@ BM_CONV_2D(1, 1024, 7, 7, 1, 1, 1, 1, SAME, 1024); BM_CONV_2D(64, 32, 34, 34, 3, 3, 1, 1, VALID, 32); BM_CONV_2D(1, 32, 34, 34, 3, 3, 1, 1, VALID, 32); +BM_CONV_2D(1, 192, 17, 17, 1, 7, 1, 1, SAME, 192); +BM_CONV_2D(1, 192, 17, 17, 7, 1, 1, 1, SAME, 192); +BM_CONV_2D(1, 160, 17, 17, 7, 1, 1, 1, SAME, 192); + BM_CONV_2D(1, 32, 256, 256, 1, 15, 1, 1, SAME, 2); BM_CONV_2D(1, 32, 256, 256, 15, 1, 1, 1, SAME, 2); BM_CONV_2D(1, 64, 64, 64, 15, 1, 1, 1, SAME, 2); diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index 543e2ac906aed88c1bb904c592dab6bfa7708482..be0e38cc5825aa8b2d090bb31957839808537a1d 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -776,6 +776,36 @@ TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv3x3S12) { {1, 1}); } +TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv5x5S12) { + TestHalfComplexConvNxNS12({32, 32}, {5, 5, 3, 64}, + {1, 1}); + TestHalfComplexConvNxNS12({32, 32}, {5, 5, 3, 63}, + {1, 1}); +} + +TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv1x7S1) { + TestHalfComplexConvNxNS12({17, 17}, {1, 7, 192, 192}, + {1, 1}); + TestHalfComplexConvNxNS12({17, 17}, {1, 7, 192, 191}, + {1, 1}); +} + +TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv7x1S1) { + TestHalfComplexConvNxNS12({17, 17}, {7, 1, 192, 192}, + {1, 1}); + TestHalfComplexConvNxNS12({17, 17}, {7, 1, 160, 192}, + {1, 1}); + TestHalfComplexConvNxNS12({17, 17}, {7, 1, 160, 191}, + {1, 1}); +} + +TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv7x7S12) { + TestHalfComplexConvNxNS12({32, 32}, {7, 7, 3, 64}, + {1, 1}); + TestHalfComplexConvNxNS12({32, 32}, {7, 7, 3, 63}, + {1, 1}); +} + TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv15x1S12) { TestHalfComplexConvNxNS12({32, 32}, {15, 1, 256, 2}, {1, 1}); @@ -792,11 +822,6 @@ TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv1x15S12) { {1, 1}); } -TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv7x75S12) { - TestHalfComplexConvNxNS12({32, 32}, {7, 7, 3, 64}, - {1, 1}); -} - TEST_F(Conv2dOpTest, OPENCLHalfUnalignedConv1x1S12) { TestHalfComplexConvNxNS12({107, 113}, {1, 1, 5, 7}, {1, 1});