提交 b3efb72b 编写于 作者: B Bin Li

Optimize armv7 armv8 conv1x7 and conv7x1

上级 cb70c32a
......@@ -47,6 +47,18 @@ extern void Conv2dNeonK5x5S1(const float *input,
const index_t *out_shape,
float *output);
extern void Conv2dNeonK1x7S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output);
extern void Conv2dNeonK7x1S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output);
extern void Conv2dNeonK7x7S1(const float *input,
const float *filter,
const index_t *in_shape,
......@@ -77,6 +89,29 @@ extern void Conv2dNeonK15x1S1(const float *input,
const index_t *out_shape,
float *output);
// calculate one output channel and one input channel
inline void Conv2dCPUKHxKWCalc(const float *in_ptr,
const float *filter_ptr,
const index_t in_width,
const index_t filter_height,
const index_t filter_width,
const index_t out_height,
const index_t out_width,
float *out_ptr,
const int stride) {
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
for (int i = 0; i < filter_height; ++i) {
for (int j = 0; j < filter_width; ++j) {
out_ptr[h * out_width + w]
+= in_ptr[(h * stride + i) * in_width + (w * stride + j)]
* filter_ptr[i * filter_width + j];
}
}
}
}
}
} // namespace kernels
} // namespace mace
......
......@@ -76,7 +76,7 @@ void Conv2dNeonK15x1S1(const float *input,
input + b * in_batch_size + c * in_image_size;
const float *filter_ptr = filter + m * in_channels * 15 + c * 15;
#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
/* load filter (1 outch x 1 height x 4 width) */
/* load filter (1 outch x 4 height x 1 width) */
float32x4_t vf0, vf1, vf2, vf3;
vf0 = vld1q_f32(filter_ptr);
vf1 = vld1q_f32(filter_ptr + 4);
......@@ -87,7 +87,7 @@ void Conv2dNeonK15x1S1(const float *input,
for (index_t wt = 0; wt < tile_width && w + wt < out_width; ++wt) {
// load output
index_t out_offset = h * out_width + w + wt;
// output (1 outch x 1 height x 4 width): vo_outch_height
// output (1 outch x 4 height x 1 width): vo_outch_height
float32x4_t vo = {out_ptr_base[out_offset],
out_ptr_base[out_offset + out_width],
out_ptr_base[out_offset + 2 * out_width],
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#include "mace/kernels/arm/conv_2d_neon.h"
namespace mace {
namespace kernels {
// Ho = 1, Wo = 4, Co = 4
void Conv2dNeonK1x7S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output) {
const index_t in_image_size = in_shape[2] * in_shape[3];
const index_t out_image_size = out_shape[2] * out_shape[3];
const index_t in_batch_size = in_shape[1] * in_image_size;
const index_t out_batch_size = out_shape[1] * out_image_size;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < out_shape[0]; ++b) {
for (index_t m = 0; m < out_shape[1]; m += 4) {
const index_t out_channels = out_shape[1];
const index_t out_height = out_shape[2];
const index_t out_width = out_shape[3];
const index_t in_channels = in_shape[1];
const index_t in_width = in_shape[3];
if (m + 3 < out_channels) {
float *out_ptr0_base =
output + b * out_batch_size + m * out_image_size;
#if defined(MACE_ENABLE_NEON)
float *out_ptr1_base =
output + b * out_batch_size + (m + 1) * out_image_size;
float *out_ptr2_base =
output + b * out_batch_size + (m + 2) * out_image_size;
float *out_ptr3_base =
output + b * out_batch_size + (m + 3) * out_image_size;
#endif
for (index_t c = 0; c < in_channels; ++c) {
const float *in_ptr_base =
input + b * in_batch_size + c * in_image_size;
const float *filter_ptr0 = filter + m * in_channels * 7 + c * 7;
#if defined(MACE_ENABLE_NEON)
const float *filter_ptr1 =
filter + (m + 1) * in_channels * 7 + c * 7;
const float *filter_ptr2 =
filter + (m + 2) * in_channels * 7 + c * 7;
const float *filter_ptr3 =
filter + (m + 3) * in_channels * 7 + c * 7;
/* load filter (4 outch x 1 height x 4 width) */
float32x4_t vf00, vf01;
float32x4_t vf10, vf11;
float32x4_t vf20, vf21;
float32x4_t vf30, vf31;
vf00 = vld1q_f32(filter_ptr0);
vf01 = vld1q_f32(filter_ptr0 + 3);
vf10 = vld1q_f32(filter_ptr1);
vf11 = vld1q_f32(filter_ptr1 + 3);
vf20 = vld1q_f32(filter_ptr2);
vf21 = vld1q_f32(filter_ptr2 + 3);
vf30 = vld1q_f32(filter_ptr3);
vf31 = vld1q_f32(filter_ptr3 + 3);
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w + 3 < out_width; w += 4) {
// output (4 outch x 1 height x 4 width): vo_outch_height
float32x4_t vo0, vo1, vo2, vo3;
// load output
index_t out_offset = h * out_width + w;
vo0 = vld1q_f32(out_ptr0_base + out_offset);
vo1 = vld1q_f32(out_ptr1_base + out_offset);
vo2 = vld1q_f32(out_ptr2_base + out_offset);
vo3 = vld1q_f32(out_ptr3_base + out_offset);
// input (3 slide)
float32x4_t vi0, vi1, vi2, vi3, vi4, vi5, vi6, vi8;
// input offset
index_t in_offset = h * in_width + w;
// load input
vi0 = vld1q_f32(in_ptr_base + in_offset);
vi4 = vld1q_f32(in_ptr_base + in_offset + 4);
vi8 = vld1q_f32(in_ptr_base + in_offset + 8);
vi1 = vextq_f32(vi0, vi4, 1);
vi2 = vextq_f32(vi0, vi4, 2);
vi3 = vextq_f32(vi0, vi4, 3);
vi5 = vextq_f32(vi4, vi8, 1);
vi6 = vextq_f32(vi4, vi8, 2);
#if defined(__aarch64__)
/* outch 0 */
vo0 = vfmaq_laneq_f32(vo0, vi0, vf00, 0);
vo0 = vfmaq_laneq_f32(vo0, vi1, vf00, 1);
vo0 = vfmaq_laneq_f32(vo0, vi2, vf00, 2);
vo0 = vfmaq_laneq_f32(vo0, vi3, vf00, 3);
vo0 = vfmaq_laneq_f32(vo0, vi4, vf01, 1);
vo0 = vfmaq_laneq_f32(vo0, vi5, vf01, 2);
vo0 = vfmaq_laneq_f32(vo0, vi6, vf01, 3);
/* outch 1 */
vo1 = vfmaq_laneq_f32(vo1, vi0, vf10, 0);
vo1 = vfmaq_laneq_f32(vo1, vi1, vf10, 1);
vo1 = vfmaq_laneq_f32(vo1, vi2, vf10, 2);
vo1 = vfmaq_laneq_f32(vo1, vi3, vf10, 3);
vo1 = vfmaq_laneq_f32(vo1, vi4, vf11, 1);
vo1 = vfmaq_laneq_f32(vo1, vi5, vf11, 2);
vo1 = vfmaq_laneq_f32(vo1, vi6, vf11, 3);
/* outch 2 */
vo2 = vfmaq_laneq_f32(vo2, vi0, vf20, 0);
vo2 = vfmaq_laneq_f32(vo2, vi1, vf20, 1);
vo2 = vfmaq_laneq_f32(vo2, vi2, vf20, 2);
vo2 = vfmaq_laneq_f32(vo2, vi3, vf20, 3);
vo2 = vfmaq_laneq_f32(vo2, vi4, vf21, 1);
vo2 = vfmaq_laneq_f32(vo2, vi5, vf21, 2);
vo2 = vfmaq_laneq_f32(vo2, vi6, vf21, 3);
/* outch 3 */
vo3 = vfmaq_laneq_f32(vo3, vi0, vf30, 0);
vo3 = vfmaq_laneq_f32(vo3, vi1, vf30, 1);
vo3 = vfmaq_laneq_f32(vo3, vi2, vf30, 2);
vo3 = vfmaq_laneq_f32(vo3, vi3, vf30, 3);
vo3 = vfmaq_laneq_f32(vo3, vi4, vf31, 1);
vo3 = vfmaq_laneq_f32(vo3, vi5, vf31, 2);
vo3 = vfmaq_laneq_f32(vo3, vi6, vf31, 3);
#else
/* outch 0 */
vo0 = vmlaq_lane_f32(vo0, vi0, vget_low_f32(vf00), 0);
vo0 = vmlaq_lane_f32(vo0, vi1, vget_low_f32(vf00), 1);
vo0 = vmlaq_lane_f32(vo0, vi2, vget_high_f32(vf00), 0);
vo0 = vmlaq_lane_f32(vo0, vi3, vget_high_f32(vf00), 1);
vo0 = vmlaq_lane_f32(vo0, vi4, vget_low_f32(vf01), 1);
vo0 = vmlaq_lane_f32(vo0, vi5, vget_high_f32(vf01), 0);
vo0 = vmlaq_lane_f32(vo0, vi6, vget_high_f32(vf01), 1);
/* outch 1 */
vo1 = vmlaq_lane_f32(vo1, vi0, vget_low_f32(vf10), 0);
vo1 = vmlaq_lane_f32(vo1, vi1, vget_low_f32(vf10), 1);
vo1 = vmlaq_lane_f32(vo1, vi2, vget_high_f32(vf10), 0);
vo1 = vmlaq_lane_f32(vo1, vi3, vget_high_f32(vf10), 1);
vo1 = vmlaq_lane_f32(vo1, vi4, vget_low_f32(vf11), 1);
vo1 = vmlaq_lane_f32(vo1, vi5, vget_high_f32(vf11), 0);
vo1 = vmlaq_lane_f32(vo1, vi6, vget_high_f32(vf11), 1);
/* outch 2 */
vo2 = vmlaq_lane_f32(vo2, vi0, vget_low_f32(vf20), 0);
vo2 = vmlaq_lane_f32(vo2, vi1, vget_low_f32(vf20), 1);
vo2 = vmlaq_lane_f32(vo2, vi2, vget_high_f32(vf20), 0);
vo2 = vmlaq_lane_f32(vo2, vi3, vget_high_f32(vf20), 1);
vo2 = vmlaq_lane_f32(vo2, vi4, vget_low_f32(vf21), 1);
vo2 = vmlaq_lane_f32(vo2, vi5, vget_high_f32(vf21), 0);
vo2 = vmlaq_lane_f32(vo2, vi6, vget_high_f32(vf21), 1);
/* outch 3 */
vo3 = vmlaq_lane_f32(vo3, vi0, vget_low_f32(vf30), 0);
vo3 = vmlaq_lane_f32(vo3, vi1, vget_low_f32(vf30), 1);
vo3 = vmlaq_lane_f32(vo3, vi2, vget_high_f32(vf30), 0);
vo3 = vmlaq_lane_f32(vo3, vi3, vget_high_f32(vf30), 1);
vo3 = vmlaq_lane_f32(vo3, vi4, vget_low_f32(vf31), 1);
vo3 = vmlaq_lane_f32(vo3, vi5, vget_high_f32(vf31), 0);
vo3 = vmlaq_lane_f32(vo3, vi6, vget_high_f32(vf31), 1);
#endif
vst1q_f32(out_ptr0_base + out_offset, vo0);
vst1q_f32(out_ptr1_base + out_offset, vo1);
vst1q_f32(out_ptr2_base + out_offset, vo2);
vst1q_f32(out_ptr3_base + out_offset, vo3);
} // w
} // h
#else
for (index_t oc = 0; oc < 4; ++oc) {
Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0 + oc * in_channels * 7,
in_width, 1, 7, out_height, out_width,
out_ptr0_base + oc * out_image_size, 1);
}
#endif
} // c
} else {
for (index_t mm = m; mm < out_channels; ++mm) {
float *out_ptr0_base =
output + b * out_batch_size + mm * out_image_size;
for (index_t c = 0; c < in_channels; ++c) {
const float *in_ptr_base =
input + b * in_batch_size + c * in_image_size;
const float *filter_ptr0 = filter + mm * in_channels * 7 + c * 7;
#if defined(MACE_ENABLE_NEON)
/* load filter (1 outch x 1 height x 4 width) */
float32x4_t vf00, vf01;
vf00 = vld1q_f32(filter_ptr0);
vf01 = vld1q_f32(filter_ptr0 + 3);
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w + 3 < out_width; w += 4) {
// output (1 outch x 1 height x 4 width): vo_outch_height
float32x4_t vo0;
// load output
index_t out_offset = h * out_width + w;
vo0 = vld1q_f32(out_ptr0_base + out_offset);
// input (3 slide)
float32x4_t vi0, vi1, vi2, vi3, vi4, vi5, vi6, vi8;
// input offset
index_t in_offset = h * in_width + w;
// load input
vi0 = vld1q_f32(in_ptr_base + in_offset);
vi4 = vld1q_f32(in_ptr_base + in_offset + 4);
vi8 = vld1q_f32(in_ptr_base + in_offset + 8);
vi1 = vextq_f32(vi0, vi4, 1);
vi2 = vextq_f32(vi0, vi4, 2);
vi3 = vextq_f32(vi0, vi4, 3);
vi5 = vextq_f32(vi4, vi8, 1);
vi6 = vextq_f32(vi4, vi8, 2);
#if defined(__aarch64__)
vo0 = vfmaq_laneq_f32(vo0, vi0, vf00, 0);
vo0 = vfmaq_laneq_f32(vo0, vi1, vf00, 1);
vo0 = vfmaq_laneq_f32(vo0, vi2, vf00, 2);
vo0 = vfmaq_laneq_f32(vo0, vi3, vf00, 3);
vo0 = vfmaq_laneq_f32(vo0, vi4, vf01, 1);
vo0 = vfmaq_laneq_f32(vo0, vi5, vf01, 2);
vo0 = vfmaq_laneq_f32(vo0, vi6, vf01, 3);
#else
vo0 = vmlaq_lane_f32(vo0, vi0, vget_low_f32(vf00), 0);
vo0 = vmlaq_lane_f32(vo0, vi1, vget_low_f32(vf00), 1);
vo0 = vmlaq_lane_f32(vo0, vi2, vget_high_f32(vf00), 0);
vo0 = vmlaq_lane_f32(vo0, vi3, vget_high_f32(vf00), 1);
vo0 = vmlaq_lane_f32(vo0, vi4, vget_low_f32(vf01), 1);
vo0 = vmlaq_lane_f32(vo0, vi5, vget_high_f32(vf01), 0);
vo0 = vmlaq_lane_f32(vo0, vi6, vget_high_f32(vf01), 1);
#endif
vst1q_f32(out_ptr0_base + out_offset, vo0);
} // w
} // h
#else
Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0,
in_width, 1, 7, out_height, out_width,
out_ptr0_base, 1);
#endif
} // c
}
} // if
} // m
} // b
}
} // namespace kernels
} // namespace mace
......@@ -300,19 +300,11 @@ void Conv2dNeonK3x3S1(const float *input,
out_ptr1 += out_width;
} // h
#else
for (index_t io = 0; io < 2; ++io) {
for (index_t ih = 0; ih < out_height; ++ih) {
for (index_t iw = 0; iw < out_width; ++iw) {
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 3; ++j) {
out_ptr0[io * out_image_size + ih * out_width + iw] +=
in_ptr0[(ih + i) * in_width + (iw + j)]
* filter_ptr0[io * in_channels * 9 + i * 3 + j];
}
}
}
}
} // for
for (index_t oc = 0; oc < 2; ++oc) {
Conv2dCPUKHxKWCalc(in_ptr0, filter_ptr0 + oc * in_channels * 9,
in_width, 3, 3, out_height, out_width,
out_ptr0_base + oc * out_image_size, 1);
}
#endif
} // c
} else {
......@@ -501,17 +493,9 @@ void Conv2dNeonK3x3S1(const float *input,
out_ptr0 += out_width;
} // h
#else
for (index_t ih = 0; ih < out_height; ++ih) {
for (index_t iw = 0; iw < out_width; ++iw) {
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 3; ++j) {
out_ptr0[ih * out_width + iw] +=
in_ptr0[(ih + i) * in_width + (iw + j)]
* filter_ptr0[i * 3 + j];
}
}
}
}
Conv2dCPUKHxKWCalc(in_ptr0, filter_ptr0,
in_width, 3, 3, out_height, out_width,
out_ptr0_base, 1);
#endif
} // c
} // mm
......@@ -666,17 +650,9 @@ void Conv2dNeonK3x3S2(const float *input,
} // w
} // h
#else
for (index_t ih = 0; ih < out_height; ++ih) {
for (index_t iw = 0; iw < out_width; ++iw) {
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 3; ++j) {
out_base[ih * out_width + iw] +=
in_base[(ih * 2 + i) * in_width + (iw * 2 + j)]
* filter_ptr[i * 3 + j];
}
}
}
}
Conv2dCPUKHxKWCalc(in_base, filter_ptr,
in_width, 3, 3, out_height, out_width,
out_base, 2);
#endif
} // c
} // m
......
......@@ -76,30 +76,6 @@ namespace kernels {
vo0 = vmlaq_lane_f32(vo0, vi3, vget_high_f32(vf00), 1); \
vo0 = vmlaq_lane_f32(vo0, vi4, vf01, 1);
inline void Conv2dCPUK5x5Calc(const float *in_ptr_base,
const float *filter_ptr0,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_image_size,
float *out_ptr0_base,
const index_t io,
const int stride) {
for (index_t ih = 0; ih < out_height; ++ih) {
for (index_t iw = 0; iw < out_width; ++iw) {
for (int i = 0; i < 5; ++i) {
for (int j = 0; j < 5; ++j) {
out_ptr0_base[io * out_image_size + ih * out_width + iw] +=
in_ptr_base[(ih * stride + i) * in_width + (iw * stride + j)] *
filter_ptr0[io * in_channels * 25 + i * 5 + j];
}
}
}
}
}
// Ho = 1, Wo = 4, Co = 4
void Conv2dNeonK5x5S1(const float *input,
const float *filter,
......@@ -183,11 +159,11 @@ void Conv2dNeonK5x5S1(const float *input,
} // w
} // h
#else
for (index_t io = 0; io < 4; ++io) {
Conv2dCPUK5x5Calc(in_ptr_base, filter_ptr0, in_width, in_channels,
out_height, out_width, out_image_size,
out_ptr0_base, io, 1);
} // for
for (index_t oc = 0; oc < 4; ++oc) {
Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0 + oc * in_channels * 25,
in_width, 5, 5, out_height, out_width,
out_ptr0_base + oc * out_image_size, 1);
}
#endif
} // c
} else {
......@@ -229,9 +205,9 @@ void Conv2dNeonK5x5S1(const float *input,
} // w
} // h
#else
Conv2dCPUK5x5Calc(in_ptr_base, filter_ptr0, in_width, in_channels,
out_height, out_width, out_image_size,
out_ptr0_base, 0, 1);
Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0,
in_width, 5, 5, out_height, out_width,
out_ptr0_base, 1);
#endif
} // c
} // mm
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#include "mace/kernels/arm/conv_2d_neon.h"
namespace mace {
namespace kernels {
// Ho = 4, Wo = 1, Co = 4
void Conv2dNeonK7x1S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output) {
const index_t in_image_size = in_shape[2] * in_shape[3];
const index_t out_image_size = out_shape[2] * out_shape[3];
const index_t in_batch_size = in_shape[1] * in_image_size;
const index_t out_batch_size = out_shape[1] * out_image_size;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < out_shape[0]; ++b) {
for (index_t m = 0; m < out_shape[1]; m += 4) {
const index_t out_channels = out_shape[1];
const index_t out_height = out_shape[2];
const index_t out_width = out_shape[3];
const index_t in_channels = in_shape[1];
const index_t in_width = in_shape[3];
if (m + 3 < out_channels) {
float *out_ptr0_base =
output + b * out_batch_size + m * out_image_size;
#if defined(MACE_ENABLE_NEON)
float *out_ptr1_base =
output + b * out_batch_size + (m + 1) * out_image_size;
float *out_ptr2_base =
output + b * out_batch_size + (m + 2) * out_image_size;
float *out_ptr3_base =
output + b * out_batch_size + (m + 3) * out_image_size;
#endif
for (index_t c = 0; c < in_channels; ++c) {
const float *in_ptr_base =
input + b * in_batch_size + c * in_image_size;
const float *filter_ptr0 = filter + m * in_channels * 7 + c * 7;
#if defined(MACE_ENABLE_NEON)
const float *filter_ptr1 =
filter + (m + 1) * in_channels * 7 + c * 7;
const float *filter_ptr2 =
filter + (m + 2) * in_channels * 7 + c * 7;
const float *filter_ptr3 =
filter + (m + 3) * in_channels * 7 + c * 7;
/* load filter (4 outch x 4 height x 1 width) */
float32x4_t vf00, vf01;
float32x4_t vf10, vf11;
float32x4_t vf20, vf21;
float32x4_t vf30, vf31;
vf00 = vld1q_f32(filter_ptr0);
vf01 = vld1q_f32(filter_ptr0 + 3);
vf10 = vld1q_f32(filter_ptr1);
vf11 = vld1q_f32(filter_ptr1 + 3);
vf20 = vld1q_f32(filter_ptr2);
vf21 = vld1q_f32(filter_ptr2 + 3);
vf30 = vld1q_f32(filter_ptr3);
vf31 = vld1q_f32(filter_ptr3 + 3);
for (index_t h = 0; h + 3 < out_height; h += 4) {
for (index_t w = 0; w < out_width; ++w) {
// load output
index_t out_offset = h * out_width + w;
// output (4 outch x 4 height x 1 width): vo_outch_height
float32x4_t vo0 = {out_ptr0_base[out_offset],
out_ptr0_base[out_offset + out_width],
out_ptr0_base[out_offset + 2 * out_width],
out_ptr0_base[out_offset + 3 * out_width]};
float32x4_t vo1 = {out_ptr1_base[out_offset],
out_ptr1_base[out_offset + out_width],
out_ptr1_base[out_offset + 2 * out_width],
out_ptr1_base[out_offset + 3 * out_width]};
float32x4_t vo2 = {out_ptr2_base[out_offset],
out_ptr2_base[out_offset + out_width],
out_ptr2_base[out_offset + 2 * out_width],
out_ptr2_base[out_offset + 3 * out_width]};
float32x4_t vo3 = {out_ptr3_base[out_offset],
out_ptr3_base[out_offset + out_width],
out_ptr3_base[out_offset + 2 * out_width],
out_ptr3_base[out_offset + 3 * out_width]};
// input offset
index_t in_offset = h * in_width + w;
// input (3 slide)
float32x4_t vi0 = {in_ptr_base[in_offset],
in_ptr_base[in_offset + in_width],
in_ptr_base[in_offset + 2 * in_width],
in_ptr_base[in_offset + 3 * in_width]};
float32x4_t vi4 = {in_ptr_base[in_offset + 4 * in_width],
in_ptr_base[in_offset + 5 * in_width],
in_ptr_base[in_offset + 6 * in_width],
in_ptr_base[in_offset + 7 * in_width]};
float32x4_t vi8 = {in_ptr_base[in_offset + 8 * in_width],
in_ptr_base[in_offset + 9 * in_width]};
float32x4_t vi1 = vextq_f32(vi0, vi4, 1);
float32x4_t vi2 = vextq_f32(vi0, vi4, 2);
float32x4_t vi3 = vextq_f32(vi0, vi4, 3);
float32x4_t vi5 = vextq_f32(vi4, vi8, 1);
float32x4_t vi6 = vextq_f32(vi4, vi8, 2);
#if defined(__aarch64__)
/* outch 0 */
vo0 = vfmaq_laneq_f32(vo0, vi0, vf00, 0);
vo0 = vfmaq_laneq_f32(vo0, vi1, vf00, 1);
vo0 = vfmaq_laneq_f32(vo0, vi2, vf00, 2);
vo0 = vfmaq_laneq_f32(vo0, vi3, vf00, 3);
vo0 = vfmaq_laneq_f32(vo0, vi4, vf01, 1);
vo0 = vfmaq_laneq_f32(vo0, vi5, vf01, 2);
vo0 = vfmaq_laneq_f32(vo0, vi6, vf01, 3);
/* outch 1 */
vo1 = vfmaq_laneq_f32(vo1, vi0, vf10, 0);
vo1 = vfmaq_laneq_f32(vo1, vi1, vf10, 1);
vo1 = vfmaq_laneq_f32(vo1, vi2, vf10, 2);
vo1 = vfmaq_laneq_f32(vo1, vi3, vf10, 3);
vo1 = vfmaq_laneq_f32(vo1, vi4, vf11, 1);
vo1 = vfmaq_laneq_f32(vo1, vi5, vf11, 2);
vo1 = vfmaq_laneq_f32(vo1, vi6, vf11, 3);
/* outch 2 */
vo2 = vfmaq_laneq_f32(vo2, vi0, vf20, 0);
vo2 = vfmaq_laneq_f32(vo2, vi1, vf20, 1);
vo2 = vfmaq_laneq_f32(vo2, vi2, vf20, 2);
vo2 = vfmaq_laneq_f32(vo2, vi3, vf20, 3);
vo2 = vfmaq_laneq_f32(vo2, vi4, vf21, 1);
vo2 = vfmaq_laneq_f32(vo2, vi5, vf21, 2);
vo2 = vfmaq_laneq_f32(vo2, vi6, vf21, 3);
/* outch 3 */
vo3 = vfmaq_laneq_f32(vo3, vi0, vf30, 0);
vo3 = vfmaq_laneq_f32(vo3, vi1, vf30, 1);
vo3 = vfmaq_laneq_f32(vo3, vi2, vf30, 2);
vo3 = vfmaq_laneq_f32(vo3, vi3, vf30, 3);
vo3 = vfmaq_laneq_f32(vo3, vi4, vf31, 1);
vo3 = vfmaq_laneq_f32(vo3, vi5, vf31, 2);
vo3 = vfmaq_laneq_f32(vo3, vi6, vf31, 3);
#else
/* outch 0 */
vo0 = vmlaq_lane_f32(vo0, vi0, vget_low_f32(vf00), 0);
vo0 = vmlaq_lane_f32(vo0, vi1, vget_low_f32(vf00), 1);
vo0 = vmlaq_lane_f32(vo0, vi2, vget_high_f32(vf00), 0);
vo0 = vmlaq_lane_f32(vo0, vi3, vget_high_f32(vf00), 1);
vo0 = vmlaq_lane_f32(vo0, vi4, vget_low_f32(vf01), 1);
vo0 = vmlaq_lane_f32(vo0, vi5, vget_high_f32(vf01), 0);
vo0 = vmlaq_lane_f32(vo0, vi6, vget_high_f32(vf01), 1);
/* outch 1 */
vo1 = vmlaq_lane_f32(vo1, vi0, vget_low_f32(vf10), 0);
vo1 = vmlaq_lane_f32(vo1, vi1, vget_low_f32(vf10), 1);
vo1 = vmlaq_lane_f32(vo1, vi2, vget_high_f32(vf10), 0);
vo1 = vmlaq_lane_f32(vo1, vi3, vget_high_f32(vf10), 1);
vo1 = vmlaq_lane_f32(vo1, vi4, vget_low_f32(vf11), 1);
vo1 = vmlaq_lane_f32(vo1, vi5, vget_high_f32(vf11), 0);
vo1 = vmlaq_lane_f32(vo1, vi6, vget_high_f32(vf11), 1);
/* outch 2 */
vo2 = vmlaq_lane_f32(vo2, vi0, vget_low_f32(vf20), 0);
vo2 = vmlaq_lane_f32(vo2, vi1, vget_low_f32(vf20), 1);
vo2 = vmlaq_lane_f32(vo2, vi2, vget_high_f32(vf20), 0);
vo2 = vmlaq_lane_f32(vo2, vi3, vget_high_f32(vf20), 1);
vo2 = vmlaq_lane_f32(vo2, vi4, vget_low_f32(vf21), 1);
vo2 = vmlaq_lane_f32(vo2, vi5, vget_high_f32(vf21), 0);
vo2 = vmlaq_lane_f32(vo2, vi6, vget_high_f32(vf21), 1);
/* outch 3 */
vo3 = vmlaq_lane_f32(vo3, vi0, vget_low_f32(vf30), 0);
vo3 = vmlaq_lane_f32(vo3, vi1, vget_low_f32(vf30), 1);
vo3 = vmlaq_lane_f32(vo3, vi2, vget_high_f32(vf30), 0);
vo3 = vmlaq_lane_f32(vo3, vi3, vget_high_f32(vf30), 1);
vo3 = vmlaq_lane_f32(vo3, vi4, vget_low_f32(vf31), 1);
vo3 = vmlaq_lane_f32(vo3, vi5, vget_high_f32(vf31), 0);
vo3 = vmlaq_lane_f32(vo3, vi6, vget_high_f32(vf31), 1);
#endif
out_ptr0_base[out_offset] = vo0[0];
out_ptr0_base[out_offset + out_width] = vo0[1];
out_ptr0_base[out_offset + 2 * out_width] = vo0[2];
out_ptr0_base[out_offset + 3 * out_width] = vo0[3];
out_ptr1_base[out_offset] = vo1[0];
out_ptr1_base[out_offset + out_width] = vo1[1];
out_ptr1_base[out_offset + 2 * out_width] = vo1[2];
out_ptr1_base[out_offset + 3 * out_width] = vo1[3];
out_ptr2_base[out_offset] = vo2[0];
out_ptr2_base[out_offset + out_width] = vo2[1];
out_ptr2_base[out_offset + 2 * out_width] = vo2[2];
out_ptr2_base[out_offset + 3 * out_width] = vo2[3];
out_ptr3_base[out_offset] = vo3[0];
out_ptr3_base[out_offset + out_width] = vo3[1];
out_ptr3_base[out_offset + 2 * out_width] = vo3[2];
out_ptr3_base[out_offset + 3 * out_width] = vo3[3];
} // w
} // h
#else
for (index_t oc = 0; oc < 4; ++oc) {
Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0 + oc * in_channels * 7,
in_width, 7, 1, out_height, out_width,
out_ptr0_base + oc * out_image_size, 1);
}
#endif
} // c
} else {
for (index_t mm = m; mm < out_channels; ++mm) {
float *out_ptr0_base =
output + b * out_batch_size + mm * out_image_size;
for (index_t c = 0; c < in_channels; ++c) {
const float *in_ptr_base =
input + b * in_batch_size + c * in_image_size;
const float *filter_ptr0 = filter + mm * in_channels * 7 + c * 7;
#if defined(MACE_ENABLE_NEON)
/* load filter (1 outch x 4 height x 1 width) */
float32x4_t vf00, vf01;
vf00 = vld1q_f32(filter_ptr0);
vf01 = vld1q_f32(filter_ptr0 + 3);
for (index_t h = 0; h + 3 < out_height; h += 4) {
for (index_t w = 0; w < out_width; ++w) {
// load output
index_t out_offset = h * out_width + w;
// output (1 outch x 4 height x 1 width): vo_outch_height
float32x4_t vo0 = {out_ptr0_base[out_offset],
out_ptr0_base[out_offset + out_width],
out_ptr0_base[out_offset + 2 * out_width],
out_ptr0_base[out_offset + 3 * out_width]};
// input offset
index_t in_offset = h * in_width + w;
// input (3 slide)
float32x4_t vi0 = {in_ptr_base[in_offset],
in_ptr_base[in_offset + in_width],
in_ptr_base[in_offset + 2 * in_width],
in_ptr_base[in_offset + 3 * in_width]};
float32x4_t vi4 = {in_ptr_base[in_offset + 4 * in_width],
in_ptr_base[in_offset + 5 * in_width],
in_ptr_base[in_offset + 6 * in_width],
in_ptr_base[in_offset + 7 * in_width]};
float32x4_t vi8 = {in_ptr_base[in_offset + 8 * in_width],
in_ptr_base[in_offset + 9 * in_width],
in_ptr_base[in_offset + 10 * in_width],
in_ptr_base[in_offset + 11 * in_width]};
float32x4_t vi1 = vextq_f32(vi0, vi4, 1);
float32x4_t vi2 = vextq_f32(vi0, vi4, 2);
float32x4_t vi3 = vextq_f32(vi0, vi4, 3);
float32x4_t vi5 = vextq_f32(vi4, vi8, 1);
float32x4_t vi6 = vextq_f32(vi4, vi8, 2);
#if defined(__aarch64__)
vo0 = vfmaq_laneq_f32(vo0, vi0, vf00, 0);
vo0 = vfmaq_laneq_f32(vo0, vi1, vf00, 1);
vo0 = vfmaq_laneq_f32(vo0, vi2, vf00, 2);
vo0 = vfmaq_laneq_f32(vo0, vi3, vf00, 3);
vo0 = vfmaq_laneq_f32(vo0, vi4, vf01, 1);
vo0 = vfmaq_laneq_f32(vo0, vi5, vf01, 2);
vo0 = vfmaq_laneq_f32(vo0, vi6, vf01, 3);
#else
vo0 = vmlaq_lane_f32(vo0, vi0, vget_low_f32(vf00), 0);
vo0 = vmlaq_lane_f32(vo0, vi1, vget_low_f32(vf00), 1);
vo0 = vmlaq_lane_f32(vo0, vi2, vget_high_f32(vf00), 0);
vo0 = vmlaq_lane_f32(vo0, vi3, vget_high_f32(vf00), 1);
vo0 = vmlaq_lane_f32(vo0, vi4, vget_low_f32(vf01), 1);
vo0 = vmlaq_lane_f32(vo0, vi5, vget_high_f32(vf01), 0);
vo0 = vmlaq_lane_f32(vo0, vi6, vget_high_f32(vf01), 1);
#endif
out_ptr0_base[out_offset] = vo0[0];
out_ptr0_base[out_offset + out_width] = vo0[1];
out_ptr0_base[out_offset + 2 * out_width] = vo0[2];
out_ptr0_base[out_offset + 3 * out_width] = vo0[3];
} // w
} // h
#else
Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0,
in_width, 7, 1, out_height, out_width,
out_ptr0_base, 1);
#endif
} // c
}
} // if
} // m
} // b
}
} // namespace kernels
} // namespace mace
......@@ -153,30 +153,6 @@ namespace kernels {
vo0 = vmlaq_lane_f32(vo0, vi5, vget_high_f32(vf01), 0); \
vo0 = vmlaq_lane_f32(vo0, vi6, vget_high_f32(vf01), 1);
inline void Conv2dCPUK7x7Calc(const float *in_ptr_base,
const float *filter_ptr0,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_image_size,
float *out_ptr0_base,
const index_t io,
const int stride) {
for (index_t ih = 0; ih < out_height; ++ih) {
for (index_t iw = 0; iw < out_width; ++iw) {
for (int i = 0; i < 7; ++i) {
for (int j = 0; j < 7; ++j) {
out_ptr0_base[io * out_image_size + ih * out_width + iw] +=
in_ptr_base[(ih * stride + i) * in_width + (iw * stride + j)] *
filter_ptr0[io * in_channels * 49 + i * 7 + j];
}
}
}
}
}
// Ho = 1, Wo = 4, Co = 4
void Conv2dNeonK7x7S1(const float *input,
const float *filter,
......@@ -268,11 +244,11 @@ void Conv2dNeonK7x7S1(const float *input,
} // w
} // h
#else
for (index_t io = 0; io < 4; ++io) {
Conv2dCPUK7x7Calc(in_ptr_base, filter_ptr0, in_width, in_channels,
out_height, out_width, out_image_size,
out_ptr0_base, io, 1);
} // for
for (index_t oc = 0; oc < 4; ++oc) {
Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0 + oc * in_channels * 49,
in_width, 7, 7, out_height, out_width,
out_ptr0_base + oc * out_image_size, 1);
}
#endif
} // c
} else {
......@@ -322,9 +298,9 @@ void Conv2dNeonK7x7S1(const float *input,
} // w
} // h
#else
Conv2dCPUK7x7Calc(in_ptr_base, filter_ptr0, in_width, in_channels,
out_height, out_width, out_image_size,
out_ptr0_base, 0, 1);
Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0,
in_width, 7, 7, out_height, out_width,
out_ptr0_base, 1);
#endif
} // c
} // mm
......@@ -429,11 +405,11 @@ void Conv2dNeonK7x7S2(const float *input,
} // w
} // h
#else
for (index_t io = 0; io < 4; ++io) {
Conv2dCPUK7x7Calc(in_ptr_base, filter_ptr0, in_width, in_channels,
out_height, out_width, out_image_size,
out_ptr0_base, io, 2);
} // for
for (index_t oc = 0; oc < 4; ++oc) {
Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0 + oc * in_channels * 49,
in_width, 7, 7, out_height, out_width,
out_ptr0_base + oc * out_image_size, 2);
}
#endif
} // c
} else {
......@@ -488,9 +464,9 @@ void Conv2dNeonK7x7S2(const float *input,
} // w
} // h
#else
Conv2dCPUK7x7Calc(in_ptr_base, filter_ptr0, in_width, in_channels,
out_height, out_width, out_image_size,
out_ptr0_base, 0, 2);
Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0,
in_width, 7, 7, out_height, out_width,
out_ptr0_base, 2);
#endif
} // c
} // mm
......@@ -595,11 +571,11 @@ void Conv2dNeonK7x7S3(const float *input,
} // w
} // h
#else
for (index_t io = 0; io < 4; ++io) {
Conv2dCPUK7x7Calc(in_ptr_base, filter_ptr0, in_width, in_channels,
out_height, out_width, out_image_size,
out_ptr0_base, io, 3);
} // for
for (index_t oc = 0; oc < 4; ++oc) {
Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0 + oc * in_channels * 49,
in_width, 7, 7, out_height, out_width,
out_ptr0_base + oc * out_image_size, 3);
}
#endif
} // c
} else {
......@@ -654,9 +630,9 @@ void Conv2dNeonK7x7S3(const float *input,
} // w
} // h
#else
Conv2dCPUK7x7Calc(in_ptr_base, filter_ptr0, in_width, in_channels,
out_height, out_width, out_image_size,
out_ptr0_base, 0, 3);
Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0,
in_width, 7, 7, out_height, out_width,
out_ptr0_base, 3);
#endif
} // c
} // mm
......
......@@ -357,6 +357,10 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
bool use_neon_5x5_s1 = filter_h == 5 && filter_w == 5
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
bool use_neon_1x7_s1 = filter_h == 1 && filter_w == 7
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
bool use_neon_7x1_s1 = filter_h == 7 && filter_w == 1
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
bool use_neon_7x7_s1 = filter_h == 7 && filter_w == 7
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
bool use_neon_7x7_s2 = filter_h == 7 && filter_w == 7
......@@ -414,7 +418,7 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
} else if (use_neon_3x3_s1) {
tile_h = 2;
tile_w = 4;
} else if (use_neon_15x1_s1) {
} else if (use_neon_7x1_s1 || use_neon_15x1_s1) {
tile_h = 4;
tile_w = 1;
} else {
......@@ -566,6 +570,22 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
extra_output_shape,
pad_output);
};
} else if (use_neon_1x7_s1) {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK1x7S1(pad_input,
filter_data,
extra_input_shape,
extra_output_shape,
pad_output);
};
} else if (use_neon_7x1_s1) {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK7x1S1(pad_input,
filter_data,
extra_input_shape,
extra_output_shape,
pad_output);
};
} else if (use_neon_7x7_s1) {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK7x7S1(pad_input,
......
......@@ -388,8 +388,7 @@ inline void GemmTile(const float *A,
"v22",
"v23",
"v24",
"v25",
"v26"
"v25"
);
w = (width >> 2) << 2;
......
......@@ -168,6 +168,10 @@ BM_CONV_2D(1, 1024, 7, 7, 1, 1, 1, 1, SAME, 1024);
BM_CONV_2D(64, 32, 34, 34, 3, 3, 1, 1, VALID, 32);
BM_CONV_2D(1, 32, 34, 34, 3, 3, 1, 1, VALID, 32);
BM_CONV_2D(1, 192, 17, 17, 1, 7, 1, 1, SAME, 192);
BM_CONV_2D(1, 192, 17, 17, 7, 1, 1, 1, SAME, 192);
BM_CONV_2D(1, 160, 17, 17, 7, 1, 1, 1, SAME, 192);
BM_CONV_2D(1, 32, 256, 256, 1, 15, 1, 1, SAME, 2);
BM_CONV_2D(1, 32, 256, 256, 15, 1, 1, 1, SAME, 2);
BM_CONV_2D(1, 64, 64, 64, 15, 1, 1, 1, SAME, 2);
......
......@@ -776,6 +776,36 @@ TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv3x3S12) {
{1, 1});
}
TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv5x5S12) {
TestHalfComplexConvNxNS12<DeviceType::GPU>({32, 32}, {5, 5, 3, 64},
{1, 1});
TestHalfComplexConvNxNS12<DeviceType::GPU>({32, 32}, {5, 5, 3, 63},
{1, 1});
}
TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv1x7S1) {
TestHalfComplexConvNxNS12<DeviceType::GPU>({17, 17}, {1, 7, 192, 192},
{1, 1});
TestHalfComplexConvNxNS12<DeviceType::GPU>({17, 17}, {1, 7, 192, 191},
{1, 1});
}
TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv7x1S1) {
TestHalfComplexConvNxNS12<DeviceType::GPU>({17, 17}, {7, 1, 192, 192},
{1, 1});
TestHalfComplexConvNxNS12<DeviceType::GPU>({17, 17}, {7, 1, 160, 192},
{1, 1});
TestHalfComplexConvNxNS12<DeviceType::GPU>({17, 17}, {7, 1, 160, 191},
{1, 1});
}
TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv7x7S12) {
TestHalfComplexConvNxNS12<DeviceType::GPU>({32, 32}, {7, 7, 3, 64},
{1, 1});
TestHalfComplexConvNxNS12<DeviceType::GPU>({32, 32}, {7, 7, 3, 63},
{1, 1});
}
TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv15x1S12) {
TestHalfComplexConvNxNS12<DeviceType::GPU>({32, 32}, {15, 1, 256, 2},
{1, 1});
......@@ -792,11 +822,6 @@ TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv1x15S12) {
{1, 1});
}
TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv7x75S12) {
TestHalfComplexConvNxNS12<DeviceType::GPU>({32, 32}, {7, 7, 3, 64},
{1, 1});
}
TEST_F(Conv2dOpTest, OPENCLHalfUnalignedConv1x1S12) {
TestHalfComplexConvNxNS12<DeviceType::GPU>({107, 113}, {1, 1, 5, 7},
{1, 1});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册