提交 d37229fa 编写于 作者: M Megvii Engine Team

feat(dnn): optimize f23 and f63 nchw44 winograd

GitOrigin-RevId: 8569c9dfc6db1b6853d4aa35bcf5b2bc9b6f89b1
上级 d7c0dd45
......@@ -32,29 +32,37 @@ constexpr size_t pack_size = 4;
struct InputTransformF23_NCHW44 {
template <bool inner>
static void prepare(const float* input, float* patch, float* patchT,
int ih_start, int iw_start, size_t IH, size_t IW,
size_t ic, size_t IC) {
MEGDNN_MARK_USED_VAR(patch);
static void transform(float* patchT, const float* input,
float* input_transform_buf, size_t ih_start,
size_t iw_start, size_t IH, size_t IW,
size_t unit_idx, size_t nr_units_in_tile, size_t ic,
size_t IC) {
size_t IW4 = IW * pack_size;
size_t iw4_start = iw_start * pack_size;
size_t icb = ic / pack_size;
size_t iw4_start = iw_start * pack_size;
size_t ICB = IC / pack_size;
#define cb(m, n) Vector<float, 4> d##m##n;
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb
if (!(inner && ic + pack_size < IC)) {
memset(patchT, 0, sizeof(float) * pack_size * alpha * alpha);
}
if (inner) {
MEGDNN_MARK_USED_VAR(patchT);
const float* input_ptr =
input + icb * IH * IW4 + ih_start * IW4 + iw4_start;
for (size_t ih = 0; ih < alpha; ih++) {
#define cb(i) auto v##i = vld1q_f32(input_ptr + pack_size * i);
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define cb(n, m) d##m##n = Vector<float, 4>::load(input_ptr + pack_size * n);
#define cb(i) vst1q_f32(patchT + ih * alpha * pack_size + i * pack_size, v##i);
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
UNROLL_CALL_RAW(4, cb, 0);
input_ptr += IW4;
}
UNROLL_CALL_RAW(4, cb, 1);
input_ptr += IW4;
UNROLL_CALL_RAW(4, cb, 2);
input_ptr += IW4;
UNROLL_CALL_RAW(4, cb, 3);
#undef cb
} else {
int ih0_act = std::max<int>(ih_start, 0),
ih1_act = std::min<int>(ih_start + alpha, IH),
......@@ -71,19 +79,12 @@ struct InputTransformF23_NCHW44 {
src);
}
}
}
}
static void transform(const float* patchT, float* input_transform_buf,
size_t unit_idx, size_t nr_units_in_tile, size_t ic,
size_t IC) {
// BT * d * B
#define cb(m, n) \
Vector<float, 4> d##m##n = Vector<float, 4>::load( \
patchT + m * alpha * pack_size + n * pack_size);
d##m##n = Vector<float, 4>::load(patchT + m * alpha * pack_size + \
n * pack_size);
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb
}
//! 1 0 -1 0 d00 d01 d02 d03 1 0 0 0
//! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1
//! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0
......@@ -106,8 +107,6 @@ struct InputTransformF23_NCHW44 {
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
size_t ICB = IC / 4;
size_t icb = ic / 4;
#define cb(m, n) \
d##m##n.save(input_transform_buf + \
(m * alpha + n) * ICB * nr_units_in_tile * pack_size + \
......@@ -273,7 +272,6 @@ void winograd_F23_mk4_f_nchw44::input(const float* input,
// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w =
div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
float* patch = transform_mid_buf;
float* patchT = transform_mid_buf + 4 * alpha * alpha;
for (size_t ic = 0; ic < IC; ic += 4) {
......@@ -285,20 +283,13 @@ void winograd_F23_mk4_f_nchw44::input(const float* input,
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) {
InputTransformF23_NCHW44::prepare<true>(input, patch, patchT,
ih_start, iw_start, IH,
IW, ic, IC);
InputTransformF23_NCHW44::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile,
ic, IC);
InputTransformF23_NCHW44::transform<true>(
patchT, input, input_transform_buf, ih_start, iw_start,
IH, IW, unit_idx, nr_units_in_tile, ic, IC);
} else {
InputTransformF23_NCHW44::prepare<false>(input, patch, patchT,
ih_start, iw_start, IH,
IW, ic, IC);
InputTransformF23_NCHW44::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile,
ic, IC);
InputTransformF23_NCHW44::transform<false>(
patchT, input, input_transform_buf, ih_start, iw_start,
IH, IW, unit_idx, nr_units_in_tile, ic, IC);
}
}
}
......@@ -312,8 +303,20 @@ void winograd_F23_mk4_f_nchw44::output(const float* output_transform_buf,
size_t oc_end, size_t unit_start_idx,
size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
OutputTransformF23_NCHW44<_bmode MEGDNN_COMMA _nonline_op>::transform( \
__VA_ARGS__);
for (size_t oc = oc_start; oc < oc_end; oc += 4) { \
size_t oc_index = oc - oc_start; \
rep(unit_idx, nr_units_in_tile) { \
size_t index = unit_start_idx + unit_idx; \
auto nh = index / units_w; \
auto nw = index % units_w; \
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; \
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; \
OutputTransformF23_NCHW44<_bmode, _nonline_op>::transform( \
output_transform_buf, bias, output, transform_mid_buf, \
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, \
unit_idx, nr_units_in_tile, src_dtype, dst_dtype); \
} \
}
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
constexpr size_t pack_size = 4;
......@@ -323,22 +326,8 @@ void winograd_F23_mk4_f_nchw44::output(const float* output_transform_buf,
oc_end % pack_size == 0,
"NCHW44 Winograd filter transform requires OC is times of 4");
for (size_t oc = oc_start; oc < oc_end; oc += 4) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_nchw44_fp32_F23_mk4, cb, float,
float, bmode, nonline_mode, output_transform_buf, bias,
output, transform_mid_buf, oh_start, ow_start, OH, OW,
oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile,
src_dtype, dst_dtype);
}
}
DISPATCH_CONV_WINOGRAD_BIAS(megdnn_arm_common_winograd_nchw44_fp32_F23_mk4,
cb, float, float, bmode, nonline_mode);
#undef cb
}
......
......@@ -31,6 +31,8 @@ namespace {
constexpr size_t alpha = 6 + 3 - 1;
constexpr size_t pack_size = 4;
constexpr float input_parameters[12] = {5.25f, 4.25f, 0.5f, 0.25f, 2.5f, 1.25f,
2.0f, 4.0f, 5.0f, 0.0f, 0.0f, 0.0f};
struct InputTransformF63_NCHW44 {
template <bool inner>
......@@ -80,12 +82,14 @@ struct InputTransformF63_NCHW44 {
size_t unit_idx, size_t nr_units_in_tile, size_t ic,
size_t IC) {
// BT * d * B
#define cb(m, n) \
Vector<float, 4> d##m##n = Vector<float, 4>::load( \
patchT + m * alpha * pack_size + n * pack_size);
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb);
#undef cb
size_t ICB = IC / pack_size;
size_t icb = ic / pack_size;
float32x4_t d0, d1, d2, d3, d4, d5, d6, d7;
float32x4_t v0 = vld1q_f32(input_parameters + 0);
float32x4_t v1 = vld1q_f32(input_parameters + 4);
float32x4_t v2 = vld1q_f32(input_parameters + 8);
//! B
//! 1 0 0 0 0 0 0 0
......@@ -96,49 +100,147 @@ struct InputTransformF63_NCHW44 {
//! 0 1 -1 2 -2 0.5 -0.5 -5.25
//! -1 1 1 1 1 1 1 0
//! 0 0 0 0 0 0 0 1
#define cb(m) \
auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m; \
auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f; \
auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f; \
auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f + \
d5##m * 2.f + d6##m; \
auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f - \
d4##m * 1.25f - d5##m * 2.f + d6##m; \
auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f + \
d5##m * 0.5f + d6##m; \
auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f - \
d5##m * 0.5f + d6##m; \
auto t7##m = (d7##m - d1##m) + (d3##m - d5##m) * 5.25f;
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
#define cb(m) \
d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6; \
d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 - \
(t##m##3 + t##m##4) * 4.25f; \
d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) + \
(t##m##3 - t##m##4) * 4.25f; \
d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f - \
t##m##4 * 1.25f + t##m##5 * 2.f + t##m##6; \
d##m##4 = t##m##1 * (-0.5f) + t##m##2 * 0.25f + t##m##3 * 2.5f - \
t##m##4 * 1.25f - t##m##5 * 2.f + t##m##6; \
d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f + \
t##m##5 * 0.5f + t##m##6; \
d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f - \
t##m##4 * 5.f - t##m##5 * 0.5f + t##m##6; \
d##m##7 = (t##m##7 - t##m##1) + (t##m##3 - t##m##5) * 5.25f;
UNROLL_CALL_NOWRAPPER(8, cb);
#define cb(i) \
d1 = vld1q_f32(patchT + i * alpha * pack_size + 1 * pack_size); \
d2 = vld1q_f32(patchT + i * alpha * pack_size + 2 * pack_size); \
d3 = vld1q_f32(patchT + i * alpha * pack_size + 3 * pack_size); \
d4 = vld1q_f32(patchT + i * alpha * pack_size + 4 * pack_size); \
d5 = vld1q_f32(patchT + i * alpha * pack_size + 5 * pack_size); \
d6 = vld1q_f32(patchT + i * alpha * pack_size + 6 * pack_size); \
auto t##i##0 = vld1q_f32(patchT + i * alpha * pack_size + 0 * pack_size); \
auto t##i##7 = vld1q_f32(patchT + i * alpha * pack_size + 7 * pack_size); \
auto t##i##1 = d6; \
auto t##i##2 = d6; \
auto t##i##3 = d6; \
auto t##i##4 = d6; \
auto t##i##5 = d6; \
auto t##i##6 = d6; \
t##i##0 = t##i##0 - d6; \
t##i##1 = t##i##1 + d1; \
t##i##2 = t##i##2 - d1; \
t##i##3 = vfmaq_laneq_f32(t##i##3, d1, v0, 2); \
t##i##4 = vfmsq_laneq_f32(t##i##4, d1, v0, 2); \
t##i##5 = vfmaq_laneq_f32(t##i##5, d1, v1, 2); \
t##i##6 = vfmsq_laneq_f32(t##i##6, d1, v1, 2); \
t##i##7 = t##i##7 - d1; \
t##i##0 = vfmsq_laneq_f32(t##i##0, d2, v0, 0); \
t##i##1 = t##i##1 + d2; \
t##i##2 = t##i##2 + d2; \
t##i##3 = vfmaq_laneq_f32(t##i##3, d2, v0, 3); \
t##i##4 = vfmaq_laneq_f32(t##i##4, d2, v0, 3); \
t##i##5 = vfmaq_laneq_f32(t##i##5, d2, v1, 3); \
t##i##6 = vfmaq_laneq_f32(t##i##6, d2, v1, 3); \
t##i##1 = vfmsq_laneq_f32(t##i##1, d3, v0, 1); \
t##i##2 = vfmaq_laneq_f32(t##i##2, d3, v0, 1); \
t##i##3 = vfmsq_laneq_f32(t##i##3, d3, v1, 0); \
t##i##4 = vfmaq_laneq_f32(t##i##4, d3, v1, 0); \
t##i##5 = vfmsq_laneq_f32(t##i##5, d3, v1, 0); \
t##i##6 = vfmaq_laneq_f32(t##i##6, d3, v1, 0); \
t##i##7 = vfmaq_laneq_f32(t##i##7, d3, v0, 0); \
t##i##0 = vfmaq_laneq_f32(t##i##0, d4, v0, 0); \
t##i##1 = vfmsq_laneq_f32(t##i##1, d4, v0, 1); \
t##i##2 = vfmsq_laneq_f32(t##i##2, d4, v0, 1); \
t##i##3 = vfmsq_laneq_f32(t##i##3, d4, v1, 1); \
t##i##4 = vfmsq_laneq_f32(t##i##4, d4, v1, 1); \
t##i##5 = vfmsq_laneq_f32(t##i##5, d4, v2, 0); \
t##i##6 = vfmsq_laneq_f32(t##i##6, d4, v2, 0); \
t##i##1 = t##i##1 + d5; \
t##i##2 = t##i##2 - d5; \
t##i##3 = vfmaq_laneq_f32(t##i##3, d5, v1, 2); \
t##i##4 = vfmsq_laneq_f32(t##i##4, d5, v1, 2); \
t##i##5 = vfmaq_laneq_f32(t##i##5, d5, v0, 2); \
t##i##6 = vfmsq_laneq_f32(t##i##6, d5, v0, 2); \
t##i##7 = vfmsq_laneq_f32(t##i##7, d5, v0, 0);
UNROLL_CALL_RAW(8, cb);
#undef cb
size_t ICB = IC / pack_size;
size_t icb = ic / pack_size;
#define cb(m, n) \
d##m##n.save(input_transform_buf + \
(m * alpha + n) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size);
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb)
#define cb(i) \
d0 = t0##i; \
d1 = t6##i; \
d2 = t6##i; \
d3 = t6##i; \
d4 = t6##i; \
d5 = t6##i; \
d6 = t6##i; \
d7 = t7##i; \
d0 = d0 - t6##i; \
d1 = d1 + t1##i; \
d2 = d2 - t1##i; \
d3 = vfmaq_laneq_f32(d3, t1##i, v0, 2); \
d4 = vfmsq_laneq_f32(d4, t1##i, v0, 2); \
d5 = vfmaq_laneq_f32(d5, t1##i, v1, 2); \
d6 = vfmsq_laneq_f32(d6, t1##i, v1, 2); \
d7 = d7 - t1##i; \
d0 = vfmsq_laneq_f32(d0, t2##i, v0, 0); \
d1 = d1 + t2##i; \
d2 = d2 + t2##i; \
d3 = vfmaq_laneq_f32(d3, t2##i, v0, 3); \
d4 = vfmaq_laneq_f32(d4, t2##i, v0, 3); \
d5 = vfmaq_laneq_f32(d5, t2##i, v1, 3); \
d6 = vfmaq_laneq_f32(d6, t2##i, v1, 3); \
d1 = vfmsq_laneq_f32(d1, t3##i, v0, 1); \
d2 = vfmaq_laneq_f32(d2, t3##i, v0, 1); \
d3 = vfmsq_laneq_f32(d3, t3##i, v1, 0); \
d4 = vfmaq_laneq_f32(d4, t3##i, v1, 0); \
d5 = vfmsq_laneq_f32(d5, t3##i, v1, 0); \
d6 = vfmaq_laneq_f32(d6, t3##i, v1, 0); \
d7 = vfmaq_laneq_f32(d7, t3##i, v0, 0); \
d0 = vfmaq_laneq_f32(d0, t4##i, v0, 0); \
d1 = vfmsq_laneq_f32(d1, t4##i, v0, 1); \
d2 = vfmsq_laneq_f32(d2, t4##i, v0, 1); \
d3 = vfmsq_laneq_f32(d3, t4##i, v1, 1); \
d4 = vfmsq_laneq_f32(d4, t4##i, v1, 1); \
d5 = vfmsq_laneq_f32(d5, t4##i, v2, 0); \
d6 = vfmsq_laneq_f32(d6, t4##i, v2, 0); \
d1 = d1 + t5##i; \
d2 = d2 - t5##i; \
d3 = vfmaq_laneq_f32(d3, t5##i, v1, 2); \
d4 = vfmsq_laneq_f32(d4, t5##i, v1, 2); \
d5 = vfmaq_laneq_f32(d5, t5##i, v0, 2); \
d6 = vfmsq_laneq_f32(d6, t5##i, v0, 2); \
d7 = vfmsq_laneq_f32(d7, t5##i, v0, 0); \
vst1q_f32(input_transform_buf + \
(0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + \
unit_idx * pack_size, \
d0); \
vst1q_f32(input_transform_buf + \
(1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + \
unit_idx * pack_size, \
d1); \
vst1q_f32(input_transform_buf + \
(2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + \
unit_idx * pack_size, \
d2); \
vst1q_f32(input_transform_buf + \
(3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + \
unit_idx * pack_size, \
d3); \
vst1q_f32(input_transform_buf + \
(4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + \
unit_idx * pack_size, \
d4); \
vst1q_f32(input_transform_buf + \
(5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + \
unit_idx * pack_size, \
d5); \
vst1q_f32(input_transform_buf + \
(6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + \
unit_idx * pack_size, \
d6); \
vst1q_f32(input_transform_buf + \
(7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + \
unit_idx * pack_size, \
d7);
UNROLL_CALL_RAW(8, cb);
#undef cb
}
};
......@@ -178,7 +280,7 @@ struct OutputTransformF63_NCHW44 {
* 1 -2 4 -8 16 -32
* 1 0.5 0.25 0.125 0.0625 0.03125
* 1 -0.5 0.25 -0.125 0.0625 -0.03125
* 0 0.0 0 0 0 1
* 0 0 0 0 0 1
*/
Vector<float, 4> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6;
......@@ -378,28 +480,33 @@ void winograd_F63_mk4_f_nchw44::output(const float* output_transform_buf,
size_t OW, size_t oc_start,
size_t oc_end, size_t unit_start_idx,
size_t nr_units_in_tile) {
constexpr size_t pack_size = 4;
#define cb(_bmode, _nonline_op, ...) \
OutputTransformF63_NCHW44<_bmode MEGDNN_COMMA _nonline_op>::transform( \
__VA_ARGS__);
for (size_t oc = oc_start; oc < oc_end; oc += pack_size) { \
size_t oc_index = oc - oc_start; \
rep(unit_idx, nr_units_in_tile) { \
size_t index = unit_start_idx + unit_idx; \
auto nh = index / units_w; \
auto nw = index % units_w; \
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; \
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; \
OutputTransformF63_NCHW44<_bmode MEGDNN_COMMA _nonline_op>:: \
transform(output_transform_buf, bias, output, \
transform_mid_buf, oh_start, ow_start, OH, OW, \
oc_start, oc_end, oc_index, unit_idx, \
nr_units_in_tile, src_dtype, dst_dtype); \
} \
}
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
for (size_t oc = oc_start; oc < oc_end; oc += pack_size) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_fp32_F63_mk4, cb, float, float,
bmode, nonline_mode, output_transform_buf, bias, output,
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start,
oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype,
dst_dtype);
}
}
constexpr size_t pack_size = 4;
size_t OC = oc_end - oc_start;
megdnn_assert(OC % pack_size == 0 && oc_start % pack_size == 0 &&
oc_end % pack_size == 0,
"NCHW44 Winograd filter transform requires OC is times of 4");
DISPATCH_CONV_WINOGRAD_BIAS(megdnn_arm_common_winograd_fp32_F63_mk4, cb,
float, float, bmode, nonline_mode);
#undef cb
}
......
......@@ -538,10 +538,43 @@ struct Vfmaq_laneq_f32_armv7<3> {
return vmlaq_lane_f32(a, b, vget_high_f32(v), 1);
}
};
//! ARMv7 helper for a lane-indexed multiply-subtract: impl(a, b, v) returns
//! a - b * v[lane], computed with vmlsq_lane_f32.
//! vmlsq_lane_f32 takes its scalar from a 64-bit (two-element) vector, so each
//! specialization first picks the low or high half of v and then the element
//! index (0 or 1) inside that half.
//! The primary template only declares impl; lanes 0..3 are provided by the
//! explicit specializations below and any other lane is left undefined.
template <int lane>
struct Vfmsq_laneq_f32_armv7 {
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v);
};
//! lane 0 of v == element 0 of the low half.
template <>
struct Vfmsq_laneq_f32_armv7<0> {
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vmlsq_lane_f32(a, b, vget_low_f32(v), 0);
}
};
//! lane 1 of v == element 1 of the low half.
template <>
struct Vfmsq_laneq_f32_armv7<1> {
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vmlsq_lane_f32(a, b, vget_low_f32(v), 1);
}
};
//! lane 2 of v == element 0 of the high half.
template <>
struct Vfmsq_laneq_f32_armv7<2> {
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vmlsq_lane_f32(a, b, vget_high_f32(v), 0);
}
};
//! lane 3 of v == element 1 of the high half.
template <>
struct Vfmsq_laneq_f32_armv7<3> {
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
return vmlsq_lane_f32(a, b, vget_high_f32(v), 1);
}
};
} // namespace
//! Map the lane-indexed multiply-add / multiply-subtract names onto the ARMv7
//! helper structs. `lane` is forwarded as a template argument, so it must be a
//! compile-time constant.
//! NOTE(review): presumably guarded by an ARMv7-only #if that is not visible
//! here -- confirm.
#define vfmaq_laneq_f32(a, b, v, lane) \
Vfmaq_laneq_f32_armv7<lane>::impl(a, b, v)
#define vfmsq_laneq_f32(a, b, v, lane) \
Vfmsq_laneq_f32_armv7<lane>::impl(a, b, v)
#if __ARM_FEATURE_DOTPROD
namespace {
template <int lane>
......@@ -582,7 +615,6 @@ struct Vdotq_laneq_s32_armv7<3> {
//! GCC versions older than 9 split a lane-indexed fmla into dup + fmla:
//! https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101
#if !defined(__clang__) && __GNUC__ < 9
#if MEGDNN_AARCH64
namespace {
......@@ -630,13 +662,59 @@ struct Vfmaq_laneq_f32_armv8<3> {
return a;
}
};
//! AArch64 helper that emits a single lane-indexed `fmls` through inline asm:
//! impl(a, b, v) returns a - b * v[lane].
//! `a` is bound with the "+w" constraint (read-write SIMD register), so it is
//! both the accumulator input and the result; `b` and `v` are read-only "w"
//! inputs. The lane index is baked into the instruction text, hence one
//! explicit specialization per lane (0..3); the primary template only
//! declares impl.
//! NOTE(review): presumably exists for the same reason as the
//! Vfmaq_laneq_f32_armv8 helpers (older GCC not emitting the lane form
//! directly) -- confirm against the enclosing #if.
template <int lane>
struct Vfmsq_laneq_f32_armv8 {
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v);
};
//! a - b * v[0]
template <>
struct Vfmsq_laneq_f32_armv8<0> {
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
asm volatile("fmls %0.4s, %1.4s, %2.s[0]\n"
: "+w"(a)
: "w"(b), "w"(v)
:);
return a;
}
};
//! a - b * v[1]
template <>
struct Vfmsq_laneq_f32_armv8<1> {
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
asm volatile("fmls %0.4s, %1.4s, %2.s[1]\n"
: "+w"(a)
: "w"(b), "w"(v)
:);
return a;
}
};
//! a - b * v[2]
template <>
struct Vfmsq_laneq_f32_armv8<2> {
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
asm volatile("fmls %0.4s, %1.4s, %2.s[2]\n"
: "+w"(a)
: "w"(b), "w"(v)
:);
return a;
}
};
//! a - b * v[3]
template <>
struct Vfmsq_laneq_f32_armv8<3> {
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
asm volatile("fmls %0.4s, %1.4s, %2.s[3]\n"
: "+w"(a)
: "w"(b), "w"(v)
:);
return a;
}
};
} // namespace
#undef vfmaq_laneq_f32
#define vfmaq_laneq_f32(a, b, v, lane) \
Vfmaq_laneq_f32_armv8<lane>::impl(a, b, v)
#endif
#undef vfmsq_laneq_f32
#define vfmsq_laneq_f32(a, b, v, lane) \
Vfmsq_laneq_f32_armv8<lane>::impl(a, b, v)
#endif
__ai int8x16_t vld_dup_tbl_s32(const int8_t* ptr, uint8x16_t& idx) {
......@@ -678,6 +756,16 @@ __ai int16x8_t vld1_dup_s8_s16(const int8_t* ptr) {
return vmovl_s8(vld1_dup_s8(ptr));
}
//! Full-vector multiply-subtract emitted as one `fmls`: a = a - b * v
//! (element-wise), and the updated a is also returned.
//! Added because, with cpu=aarch64_android, the compiler could not compile
//! fmsq into fmls; it used dup+fmla instead.
//! NOTE(review): `a` is taken by non-const reference and bound with "+w", so
//! the caller's variable is modified in place in addition to being returned.
//! `b` and `v` are only read, yet are also non-const references, so
//! temporaries cannot be passed for them.
__ai float32x4_t Vfmsq_f32(float32x4_t& a, float32x4_t& b, float32x4_t& v) {
asm volatile("fmls %0.4s, %1.4s, %2.4s\n"
: "+w"(a)
: "w"(b), "w"(v)
:);
return a;
}
#undef __ai
#pragma GCC diagnostic pop
......
......@@ -791,8 +791,8 @@ void benchmark_winograd_nchw_vs_nchw44(const char* algo_name, Handle* handle) {
std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
for (auto nlmode : nonlinemode)
for (size_t n : {1, 2})
for (size_t group = 1; group <= 2; ++group) {
for (size_t n : {1})
for (size_t group = 1; group <= 1; ++group) {
pack(n, 512, 512, 15, 15, group, nlmode);
pack(n, 512, 256, 15, 15, group, nlmode);
pack(n, 256, 256, 29, 29, group, nlmode);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册