提交 a025ac02 编写于 作者: 李寅

Optimize gemm v7

上级 dad3d11a
...@@ -50,7 +50,7 @@ inline void GemmBlock(const float *A, ...@@ -50,7 +50,7 @@ inline void GemmBlock(const float *A,
#if defined(MACE_ENABLE_NEON) #if defined(MACE_ENABLE_NEON)
#if defined(__aarch64__) #if defined(__aarch64__)
#define MACE_GEMM_PART_CAL(RC, RA, RAN) \ #define MACE_GEMM_PART_CAL_8(RC, RA, RAN) \
c##RC = vfmaq_laneq_f32(c##RC, b0, a##RA, 0); \ c##RC = vfmaq_laneq_f32(c##RC, b0, a##RA, 0); \
c##RC = vfmaq_laneq_f32(c##RC, b1, a##RA, 1); \ c##RC = vfmaq_laneq_f32(c##RC, b1, a##RA, 1); \
c##RC = vfmaq_laneq_f32(c##RC, b2, a##RA, 2); \ c##RC = vfmaq_laneq_f32(c##RC, b2, a##RA, 2); \
...@@ -60,7 +60,7 @@ inline void GemmBlock(const float *A, ...@@ -60,7 +60,7 @@ inline void GemmBlock(const float *A,
c##RC = vfmaq_laneq_f32(c##RC, b6, a##RAN, 2); \ c##RC = vfmaq_laneq_f32(c##RC, b6, a##RAN, 2); \
c##RC = vfmaq_laneq_f32(c##RC, b7, a##RAN, 3); c##RC = vfmaq_laneq_f32(c##RC, b7, a##RAN, 3);
#else #else
#define MACE_GEMM_PART_CAL(RC, RA, RAN) \ #define MACE_GEMM_PART_CAL_8(RC, RA, RAN) \
c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RA), 0); \ c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RA), 0); \
c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RA), 1); \ c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RA), 1); \
c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RA), 0); \ c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RA), 0); \
...@@ -72,6 +72,283 @@ inline void GemmBlock(const float *A, ...@@ -72,6 +72,283 @@ inline void GemmBlock(const float *A,
#endif #endif
#endif #endif
#if defined(MACE_ENABLE_NEON)
#if defined(__aarch64__)
#define MACE_GEMM_PART_CAL_4(RC) \
c##RC = vfmaq_laneq_f32(c##RC, b0, a##RC, 0); \
c##RC = vfmaq_laneq_f32(c##RC, b1, a##RC, 1); \
c##RC = vfmaq_laneq_f32(c##RC, b2, a##RC, 2); \
c##RC = vfmaq_laneq_f32(c##RC, b3, a##RC, 3);
#else
#define MACE_GEMM_PART_CAL_4(RC) \
c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RC), 0); \
c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RC), 1); \
c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RC), 0); \
c##RC = vmlaq_lane_f32(c##RC, b3, vget_high_f32(a##RC), 1);
#endif
#endif
inline void Gemm144(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
MACE_UNUSED(stride_a);
MACE_UNUSED(stride_c);
float32x4_t a0;
float32x4_t b0, b1, b2, b3;
float32x4_t c0;
a0 = vld1q_f32(a_ptr);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
c0 = vld1q_f32(c_ptr);
MACE_GEMM_PART_CAL_4(0);
vst1q_f32(c_ptr, c0);
#else
GemmBlock(a_ptr, b_ptr, 1, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm244(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1;
float32x4_t b0, b1, b2, b3;
float32x4_t c0, c1;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 1 * stride_a);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
MACE_GEMM_PART_CAL_4(0);
MACE_GEMM_PART_CAL_4(1);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
#else
GemmBlock(a_ptr, b_ptr, 2, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm344(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2;
float32x4_t b0, b1, b2, b3;
float32x4_t c0, c1, c2;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 1 * stride_a);
a2 = vld1q_f32(a_ptr + 2 * stride_a);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
MACE_GEMM_PART_CAL_4(0);
MACE_GEMM_PART_CAL_4(1);
MACE_GEMM_PART_CAL_4(2);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
#else
GemmBlock(a_ptr, b_ptr, 3, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm444(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3;
float32x4_t b0, b1, b2, b3;
float32x4_t c0, c1, c2, c3;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 1 * stride_a);
a2 = vld1q_f32(a_ptr + 2 * stride_a);
a3 = vld1q_f32(a_ptr + 3 * stride_a);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c);
MACE_GEMM_PART_CAL_4(0);
MACE_GEMM_PART_CAL_4(1);
MACE_GEMM_PART_CAL_4(2);
MACE_GEMM_PART_CAL_4(3);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
vst1q_f32(c_ptr + 3 * stride_c, c3);
#else
GemmBlock(a_ptr, b_ptr, 4, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm544(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4;
float32x4_t b0, b1, b2, b3;
float32x4_t c0, c1, c2, c3, c4;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 1 * stride_a);
a2 = vld1q_f32(a_ptr + 2 * stride_a);
a3 = vld1q_f32(a_ptr + 3 * stride_a);
a4 = vld1q_f32(a_ptr + 4 * stride_a);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c);
c4 = vld1q_f32(c_ptr + 4 * stride_c);
MACE_GEMM_PART_CAL_4(0);
MACE_GEMM_PART_CAL_4(1);
MACE_GEMM_PART_CAL_4(2);
MACE_GEMM_PART_CAL_4(3);
MACE_GEMM_PART_CAL_4(4);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
vst1q_f32(c_ptr + 3 * stride_c, c3);
vst1q_f32(c_ptr + 4 * stride_c, c4);
#else
GemmBlock(a_ptr, b_ptr, 5, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm644(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5;
float32x4_t b0, b1, b2, b3;
float32x4_t c0, c1, c2, c3, c4, c5;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 1 * stride_a);
a2 = vld1q_f32(a_ptr + 2 * stride_a);
a3 = vld1q_f32(a_ptr + 3 * stride_a);
a4 = vld1q_f32(a_ptr + 4 * stride_a);
a5 = vld1q_f32(a_ptr + 5 * stride_a);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c);
c4 = vld1q_f32(c_ptr + 4 * stride_c);
c5 = vld1q_f32(c_ptr + 5 * stride_c);
MACE_GEMM_PART_CAL_4(0);
MACE_GEMM_PART_CAL_4(1);
MACE_GEMM_PART_CAL_4(2);
MACE_GEMM_PART_CAL_4(3);
MACE_GEMM_PART_CAL_4(4);
MACE_GEMM_PART_CAL_4(5);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
vst1q_f32(c_ptr + 3 * stride_c, c3);
vst1q_f32(c_ptr + 4 * stride_c, c4);
vst1q_f32(c_ptr + 5 * stride_c, c5);
#else
GemmBlock(a_ptr, b_ptr, 6, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void GemmX44(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr,
int row) {
switch (row) {
case 1:
Gemm144(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
break;
case 2:
Gemm244(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
break;
case 3:
Gemm344(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
break;
case 4:
Gemm444(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
break;
case 5:
Gemm544(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
break;
case 6:
Gemm644(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
break;
default:
MACE_NOT_IMPLEMENTED;
}
}
inline void Gemm884(const float *a_ptr, inline void Gemm884(const float *a_ptr,
const float *b_ptr, const float *b_ptr,
const index_t stride_a, const index_t stride_a,
...@@ -119,25 +396,14 @@ inline void Gemm884(const float *a_ptr, ...@@ -119,25 +396,14 @@ inline void Gemm884(const float *a_ptr,
c6 = vld1q_f32(c_ptr + 6 * stride_c); c6 = vld1q_f32(c_ptr + 6 * stride_c);
c7 = vld1q_f32(c_ptr + 7 * stride_c); c7 = vld1q_f32(c_ptr + 7 * stride_c);
#if defined(__aarch64__) MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL(0, 0, 1); MACE_GEMM_PART_CAL_8(1, 2, 3);
MACE_GEMM_PART_CAL(1, 2, 3); MACE_GEMM_PART_CAL_8(2, 4, 5);
MACE_GEMM_PART_CAL(2, 4, 5); MACE_GEMM_PART_CAL_8(3, 6, 7);
MACE_GEMM_PART_CAL(3, 6, 7); MACE_GEMM_PART_CAL_8(4, 8, 9);
MACE_GEMM_PART_CAL(4, 8, 9); MACE_GEMM_PART_CAL_8(5, 10, 11);
MACE_GEMM_PART_CAL(5, 10, 11); MACE_GEMM_PART_CAL_8(6, 12, 13);
MACE_GEMM_PART_CAL(6, 12, 13); MACE_GEMM_PART_CAL_8(7, 14, 15);
MACE_GEMM_PART_CAL(7, 14, 15);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
MACE_GEMM_PART_CAL(4, 8, 9);
MACE_GEMM_PART_CAL(5, 10, 11);
MACE_GEMM_PART_CAL(6, 12, 13);
MACE_GEMM_PART_CAL(7, 14, 15);
#endif
vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1); vst1q_f32(c_ptr + 1 * stride_c, c1);
...@@ -180,11 +446,7 @@ inline void Gemm184(const float *a_ptr, ...@@ -180,11 +446,7 @@ inline void Gemm184(const float *a_ptr,
c0 = vld1q_f32(c_ptr); c0 = vld1q_f32(c_ptr);
#if defined(__aarch64__) MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL(0, 0, 1);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
#endif
vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr, c0);
#else #else
...@@ -220,13 +482,8 @@ inline void Gemm284(const float *a_ptr, ...@@ -220,13 +482,8 @@ inline void Gemm284(const float *a_ptr,
c0 = vld1q_f32(c_ptr); c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c); c1 = vld1q_f32(c_ptr + 1 * stride_c);
#if defined(__aarch64__) MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL(0, 0, 1); MACE_GEMM_PART_CAL_8(1, 2, 3);
MACE_GEMM_PART_CAL(1, 2, 3);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
#endif
vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1); vst1q_f32(c_ptr + 1 * stride_c, c1);
...@@ -266,15 +523,9 @@ inline void Gemm384(const float *a_ptr, ...@@ -266,15 +523,9 @@ inline void Gemm384(const float *a_ptr,
c1 = vld1q_f32(c_ptr + 1 * stride_c); c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c); c2 = vld1q_f32(c_ptr + 2 * stride_c);
#if defined(__aarch64__) MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL(0, 0, 1); MACE_GEMM_PART_CAL_8(1, 2, 3);
MACE_GEMM_PART_CAL(1, 2, 3); MACE_GEMM_PART_CAL_8(2, 4, 5);
MACE_GEMM_PART_CAL(2, 4, 5);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
#endif
vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1); vst1q_f32(c_ptr + 1 * stride_c, c1);
...@@ -318,17 +569,10 @@ inline void Gemm484(const float *a_ptr, ...@@ -318,17 +569,10 @@ inline void Gemm484(const float *a_ptr,
c2 = vld1q_f32(c_ptr + 2 * stride_c); c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c); c3 = vld1q_f32(c_ptr + 3 * stride_c);
#if defined(__aarch64__) MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL(0, 0, 1); MACE_GEMM_PART_CAL_8(1, 2, 3);
MACE_GEMM_PART_CAL(1, 2, 3); MACE_GEMM_PART_CAL_8(2, 4, 5);
MACE_GEMM_PART_CAL(2, 4, 5); MACE_GEMM_PART_CAL_8(3, 6, 7);
MACE_GEMM_PART_CAL(3, 6, 7);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
#endif
vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1); vst1q_f32(c_ptr + 1 * stride_c, c1);
...@@ -376,19 +620,11 @@ inline void Gemm584(const float *a_ptr, ...@@ -376,19 +620,11 @@ inline void Gemm584(const float *a_ptr,
c3 = vld1q_f32(c_ptr + 3 * stride_c); c3 = vld1q_f32(c_ptr + 3 * stride_c);
c4 = vld1q_f32(c_ptr + 4 * stride_c); c4 = vld1q_f32(c_ptr + 4 * stride_c);
#if defined(__aarch64__) MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL(0, 0, 1); MACE_GEMM_PART_CAL_8(1, 2, 3);
MACE_GEMM_PART_CAL(1, 2, 3); MACE_GEMM_PART_CAL_8(2, 4, 5);
MACE_GEMM_PART_CAL(2, 4, 5); MACE_GEMM_PART_CAL_8(3, 6, 7);
MACE_GEMM_PART_CAL(3, 6, 7); MACE_GEMM_PART_CAL_8(4, 8, 9);
MACE_GEMM_PART_CAL(4, 8, 9);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
MACE_GEMM_PART_CAL(4, 8, 9);
#endif
vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1); vst1q_f32(c_ptr + 1 * stride_c, c1);
...@@ -440,21 +676,12 @@ inline void Gemm684(const float *a_ptr, ...@@ -440,21 +676,12 @@ inline void Gemm684(const float *a_ptr,
c4 = vld1q_f32(c_ptr + 4 * stride_c); c4 = vld1q_f32(c_ptr + 4 * stride_c);
c5 = vld1q_f32(c_ptr + 5 * stride_c); c5 = vld1q_f32(c_ptr + 5 * stride_c);
#if defined(__aarch64__) MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL(0, 0, 1); MACE_GEMM_PART_CAL_8(1, 2, 3);
MACE_GEMM_PART_CAL(1, 2, 3); MACE_GEMM_PART_CAL_8(2, 4, 5);
MACE_GEMM_PART_CAL(2, 4, 5); MACE_GEMM_PART_CAL_8(3, 6, 7);
MACE_GEMM_PART_CAL(3, 6, 7); MACE_GEMM_PART_CAL_8(4, 8, 9);
MACE_GEMM_PART_CAL(4, 8, 9); MACE_GEMM_PART_CAL_8(5, 10, 11);
MACE_GEMM_PART_CAL(5, 10, 11);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
MACE_GEMM_PART_CAL(4, 8, 9);
MACE_GEMM_PART_CAL(5, 10, 11);
#endif
vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1); vst1q_f32(c_ptr + 1 * stride_c, c1);
...@@ -511,23 +738,13 @@ inline void Gemm784(const float *a_ptr, ...@@ -511,23 +738,13 @@ inline void Gemm784(const float *a_ptr,
c5 = vld1q_f32(c_ptr + 5 * stride_c); c5 = vld1q_f32(c_ptr + 5 * stride_c);
c6 = vld1q_f32(c_ptr + 6 * stride_c); c6 = vld1q_f32(c_ptr + 6 * stride_c);
#if defined(__aarch64__) MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL(0, 0, 1); MACE_GEMM_PART_CAL_8(1, 2, 3);
MACE_GEMM_PART_CAL(1, 2, 3); MACE_GEMM_PART_CAL_8(2, 4, 5);
MACE_GEMM_PART_CAL(2, 4, 5); MACE_GEMM_PART_CAL_8(3, 6, 7);
MACE_GEMM_PART_CAL(3, 6, 7); MACE_GEMM_PART_CAL_8(4, 8, 9);
MACE_GEMM_PART_CAL(4, 8, 9); MACE_GEMM_PART_CAL_8(5, 10, 11);
MACE_GEMM_PART_CAL(5, 10, 11); MACE_GEMM_PART_CAL_8(6, 12, 13);
MACE_GEMM_PART_CAL(6, 12, 13);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
MACE_GEMM_PART_CAL(4, 8, 9);
MACE_GEMM_PART_CAL(5, 10, 11);
MACE_GEMM_PART_CAL(6, 12, 13);
#endif
vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1); vst1q_f32(c_ptr + 1 * stride_c, c1);
...@@ -589,9 +806,19 @@ inline void GemmTile(const float *A, ...@@ -589,9 +806,19 @@ inline void GemmTile(const float *A,
const index_t stride_c, const index_t stride_c,
float *C) { float *C) {
#if defined(MACE_ENABLE_NEON) #if defined(MACE_ENABLE_NEON)
index_t h, w, k; index_t h = 0;
for (h = 0; h < height - 7; h += 8) { index_t w = 0;
for (k = 0; k < K - 7; k += 8) { index_t k = 0;
#if defined(__aarch64__)
int reg_height_tile = 8;
int reg_K_tile = 8;
#else
int reg_height_tile = 6;
int reg_K_tile = 4;
#endif
for (h = 0; h < height - reg_height_tile + 1; h += reg_height_tile) {
for (k = 0; k < K - reg_K_tile + 1; k += reg_K_tile) {
const float *a_ptr = A + (h * stride_a + k); const float *a_ptr = A + (h * stride_a + k);
#if defined(__aarch64__) && defined(__clang__) #if defined(__aarch64__) && defined(__clang__)
int nw = width >> 2; int nw = width >> 2;
...@@ -833,43 +1060,180 @@ inline void GemmTile(const float *A, ...@@ -833,43 +1060,180 @@ inline void GemmTile(const float *A,
w = (width >> 2) << 2; w = (width >> 2) << 2;
} }
#else // gcc || armv7a #elif defined(__aarch64__) // gcc
for (w = 0; w + 3 < width; w += 4) { for (w = 0; w + 3 < width; w += 4) {
const float *b_ptr = B + (k * stride_b + w); const float *b_ptr = B + (k * stride_b + w);
float *c_ptr = C + (h * stride_c + w); float *c_ptr = C + (h * stride_c + w);
Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
} }
#endif // clang && armv8a #else // armv7
int nw = width >> 2;
if (nw > 0) {
float32x4_t a0, a1, a2, a3, a4, a5;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 1 * stride_a);
a2 = vld1q_f32(a_ptr + 2 * stride_a);
a3 = vld1q_f32(a_ptr + 3 * stride_a);
a4 = vld1q_f32(a_ptr + 4 * stride_a);
a5 = vld1q_f32(a_ptr + 5 * stride_a);
const float *b_ptr0 = B + k * stride_b;
const float *b_ptr1 = B + (k + 1) * stride_b;
const float *b_ptr2 = B + (k + 2) * stride_b;
const float *b_ptr3 = B + (k + 3) * stride_b;
float *c_ptr0 = C + h * stride_c;
float *c_ptr1 = C + (h + 1) * stride_c;
float *c_ptr2 = C + (h + 2) * stride_c;
float *c_ptr3 = C + (h + 3) * stride_c;
float *c_ptr4 = C + (h + 4) * stride_c;
float *c_ptr5 = C + (h + 5) * stride_c;
asm volatile(
"pld [%7, #128] \n"
"vld1.f32 {d12-d13}, [%7]! \n"
"pld [%1, #128] \n"
"vld1.f32 {d16-d17}, [%1] \n"
"pld [%2, #128] \n"
"vld1.f32 {d18-d19}, [%2] \n"
"0: \n"
"pld [%3, #128] \n"
"vld1.f32 {d20-d21}, [%3] \n"
"pld [%4, #128] \n"
"vld1.f32 {d22-d23}, [%4] \n"
"pld [%5, #128] \n"
"vld1.f32 {d24-d25}, [%5] \n"
"pld [%6, #128] \n"
"vld1.f32 {d26-d27}, [%6] \n"
"pld [%8, #128] \n"
"vld1.f32 {d14-d15}, [%8]! \n"
"vmla.f32 q8, q6, %e22[0] \n"
"vmla.f32 q9, q6, %e23[0] \n"
"vmla.f32 q10, q6, %e24[0] \n"
"vmla.f32 q11, q6, %e25[0] \n"
"vmla.f32 q12, q6, %e26[0] \n"
"vmla.f32 q13, q6, %e27[0] \n"
"pld [%9, #128] \n"
"vld1.f32 {d12-d13}, [%9]! \n"
"vmla.f32 q8, q7, %e22[1] \n"
"vmla.f32 q9, q7, %e23[1] \n"
"vmla.f32 q10, q7, %e24[1] \n"
"vmla.f32 q11, q7, %e25[1] \n"
"vmla.f32 q12, q7, %e26[1] \n"
"vmla.f32 q13, q7, %e27[1] \n"
"pld [%10, #128] \n"
"vld1.f32 {d14-d15}, [%10]! \n"
"vmla.f32 q8, q6, %f22[0] \n"
"vmla.f32 q9, q6, %f23[0] \n"
"vmla.f32 q10, q6, %f24[0] \n"
"vmla.f32 q11, q6, %f25[0] \n"
"vmla.f32 q12, q6, %f26[0] \n"
"vmla.f32 q13, q6, %f27[0] \n"
"vmla.f32 q8, q7, %f22[1] \n"
"vmla.f32 q9, q7, %f23[1] \n"
"vmla.f32 q10, q7, %f24[1] \n"
"vmla.f32 q11, q7, %f25[1] \n"
"vmla.f32 q12, q7, %f26[1] \n"
"vmla.f32 q13, q7, %f27[1] \n"
"vst1.f32 {d16-d17}, [%1]! \n"
"vst1.f32 {d18-d19}, [%2]! \n"
"pld [%7, #128] \n"
"vld1.f32 {d12-d13}, [%7]! \n"
"vst1.f32 {d20-d21}, [%3]! \n"
"vst1.f32 {d22-d23}, [%4]! \n"
"pld [%1, #128] \n"
"vld1.f32 {d16-d17}, [%1] \n"
"vst1.f32 {d24-d25}, [%5]! \n"
"vst1.f32 {d26-d27}, [%6]! \n"
"pld [%2, #128] \n"
"vld1.f32 {d18-d19}, [%2] \n"
"subs %0, #1 \n"
"bne 0b \n"
: "=r"(nw), // 0
"=r"(c_ptr0), // 1
"=r"(c_ptr1), // 2
"=r"(c_ptr2), // 3
"=r"(c_ptr3), // 4
"=r"(c_ptr4), // 5
"=r"(c_ptr5), // 6
"=r"(b_ptr0), // 7
"=r"(b_ptr1), // 8
"=r"(b_ptr2), // 9
"=r"(b_ptr3) // 10
: "0"(nw), // 11
"1"(c_ptr0), // 12
"2"(c_ptr1), // 13
"3"(c_ptr2), // 14
"4"(c_ptr3), // 15
"5"(c_ptr4), // 16
"6"(c_ptr5), // 17
"7"(b_ptr0), // 18
"8"(b_ptr1), // 19
"9"(b_ptr2), // 20
"10"(b_ptr3), // 21
"w"(a0), // 22
"w"(a1), // 23
"w"(a2), // 24
"w"(a3), // 25
"w"(a4), // 26
"w"(a5) // 27
: "cc", "memory", "q6", "q7", "q8", "q9", "q10", "q11", "q12",
"q13", "q14", "q15");
w = (width >> 2) << 2;
}
#endif
if (w < width) { if (w < width) {
const float *b_ptr = B + (k * stride_b + w); const float *b_ptr = B + (k * stride_b + w);
float *c_ptr = C + (h * stride_c + w); float *c_ptr = C + (h * stride_c + w);
GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_a, stride_b, stride_c, GemmBlock(a_ptr, b_ptr, reg_height_tile, reg_K_tile, width - w,
c_ptr); stride_a, stride_b, stride_c, c_ptr);
} }
} }
if (k < K) { if (k < K) {
const float *a_ptr = A + (h * stride_a + k); const float *a_ptr = A + (h * stride_a + k);
const float *b_ptr = B + k * stride_b; const float *b_ptr = B + k * stride_b;
float *c_ptr = C + h * stride_c; float *c_ptr = C + h * stride_c;
GemmBlock(a_ptr, b_ptr, 8, K - k, width, stride_a, stride_b, stride_c, GemmBlock(a_ptr, b_ptr, reg_height_tile, K - k, width, stride_a, stride_b,
c_ptr); stride_c, c_ptr);
} }
} }
if (h < height) { if (h < height) {
index_t remain_h = height - h; index_t remain_h = height - h;
for (k = 0; k < K - 7; k += 8) { for (k = 0; k < K - reg_K_tile; k += reg_K_tile) {
const float *a_ptr = A + (h * stride_a + k); const float *a_ptr = A + (h * stride_a + k);
index_t w; index_t w;
for (w = 0; w + 3 < width; w += 4) { for (w = 0; w + 3 < width; w += 4) {
const float *b_ptr = B + (k * stride_b + w); const float *b_ptr = B + (k * stride_b + w);
float *c_ptr = C + (h * stride_c + w); float *c_ptr = C + (h * stride_c + w);
#if defined(__aarch64__)
GemmX84(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr, remain_h); GemmX84(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr, remain_h);
#else
GemmX44(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr, remain_h);
#endif
} }
if (w < width) { if (w < width) {
const float *b_ptr = B + (k * stride_b + w); const float *b_ptr = B + (k * stride_b + w);
float *c_ptr = C + (h * stride_c + w); float *c_ptr = C + (h * stride_c + w);
GemmBlock(a_ptr, b_ptr, remain_h, 8, width - w, stride_a, stride_b, GemmBlock(a_ptr, b_ptr, remain_h, reg_K_tile, width - w, stride_a,
stride_c, c_ptr); stride_b, stride_c, c_ptr);
} }
} }
if (k < K) { if (k < K) {
......
...@@ -38,7 +38,9 @@ void TestShapeOp(const std::vector<index_t> &input_shape) { ...@@ -38,7 +38,9 @@ void TestShapeOp(const std::vector<index_t> &input_shape) {
std::vector<int32_t> expected_input_shape(input_shape.begin(), std::vector<int32_t> expected_input_shape(input_shape.begin(),
input_shape.end()); input_shape.end());
if (!expected_input_shape.empty()) { if (!expected_input_shape.empty()) {
net.AddInputFromArray<CPU, int32_t>("ExpectedOutput", {input_shape.size()}, net.AddInputFromArray<CPU, int32_t>("ExpectedOutput",
{static_cast<int32_t>(
input_shape.size())},
expected_input_shape); expected_input_shape);
} else { } else {
net.AddInputFromArray<CPU, int32_t>("ExpectedOutput", {}, {0}); net.AddInputFromArray<CPU, int32_t>("ExpectedOutput", {}, {0});
......
...@@ -37,11 +37,18 @@ void TestSlice(const std::vector<index_t> &input_shape, ...@@ -37,11 +37,18 @@ void TestSlice(const std::vector<index_t> &input_shape,
const std::vector<float> &output) { const std::vector<float> &output) {
OpsTestNet net; OpsTestNet net;
net.AddInputFromArray<CPU, float>("Input", input_shape, input); net.AddInputFromArray<CPU, float>("Input", input_shape, input);
net.AddInputFromArray<CPU, int32_t>("BeginIndices", {input_shape.size()}, net.AddInputFromArray<CPU, int32_t>("BeginIndices",
{static_cast<int32_t>(
input_shape.size())},
begin_indices); begin_indices);
net.AddInputFromArray<CPU, int32_t>("EndIndices", {input_shape.size()}, net.AddInputFromArray<CPU, int32_t>("EndIndices",
{static_cast<int32_t>(
input_shape.size())},
end_indices); end_indices);
net.AddInputFromArray<CPU, int32_t>("Strides", {input_shape.size()}, strides); net.AddInputFromArray<CPU, int32_t>("Strides",
{static_cast<int32_t>(
input_shape.size())},
strides);
OpDefBuilder("StridedSlice", "StridedSliceOpTest") OpDefBuilder("StridedSlice", "StridedSliceOpTest")
.Input("Input") .Input("Input")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册