Commit 9939c334 authored by 吴承辉

Merge branch 'gemm' into 'master'

Optimize gemm x84 (v8/v7) and gemv (v7)

See merge request !544
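For context, a minimal usage sketch of the two entry points this merge optimizes, kernels::Gemm and kernels::Gemv. The shapes follow the reference implementations and tests in the diff below; the buffer sizes and the ExampleUsage wrapper are illustrative assumptions, not part of the patch.

#include <vector>
#include "mace/kernels/gemm.h"

// Illustrative sketch only (not part of this merge request).
// Gemm: batched C = A * B with A: batch x height x K, B: batch x K x width.
// Gemv: out = M * v per batch with M: height x width, v: batch x width.
void ExampleUsage() {
  const int batch = 1, height = 17, K = 64, width = 128;
  std::vector<float> A(batch * height * K);
  std::vector<float> B(batch * K * width);
  std::vector<float> C(batch * height * width);
  mace::kernels::Gemm(A.data(), B.data(), batch, height, K, width, C.data());

  std::vector<float> M(height * width);
  std::vector<float> v(batch * width);
  std::vector<float> out(batch * height);
  mace::kernels::Gemv(M.data(), v.data(), batch, width, height, out.data());
}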
@@ -20,11 +20,13 @@
#include <arm_neon.h>
#endif
#include "mace/core/macros.h"
#include "mace/kernels/gemm.h" #include "mace/kernels/gemm.h"
#include "mace/utils/utils.h"
#include "mace/utils/logging.h" #include "mace/utils/logging.h"
#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
#define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
#endif
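Note on the vaddvq_f32 shim above: vaddvq_f32 is an AArch64-only intrinsic that sums the four lanes of a float32x4_t, so the ARMv7 build emulates it by indexing the vector directly (a GCC/Clang vector-extension feature). A hypothetical equivalent written with plain ARMv7 NEON intrinsics, shown only for clarity and not part of this patch, would be:

// Hypothetical ARMv7 helper (illustration only, not in this patch):
// reduces a float32x4_t to the sum of its four lanes, like vaddvq_f32.
static inline float HorizontalSumF32(float32x4_t v) {
  float32x2_t s = vadd_f32(vget_low_f32(v), vget_high_f32(v));  // {v0+v2, v1+v3}
  s = vpadd_f32(s, s);  // pairwise add -> {sum, sum}
  return vget_lane_f32(s, 0);
}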
namespace mace {
namespace kernels {
@@ -47,13 +49,36 @@ inline void GemmBlock(const float *A,
}
}
#if defined(MACE_ENABLE_NEON)
#if defined(__aarch64__)
#define MACE_GEMM_PART_CAL(RC, RA, RAN) \
c##RC = vfmaq_laneq_f32(c##RC, b0, a##RA, 0); \
c##RC = vfmaq_laneq_f32(c##RC, b1, a##RA, 1); \
c##RC = vfmaq_laneq_f32(c##RC, b2, a##RA, 2); \
c##RC = vfmaq_laneq_f32(c##RC, b3, a##RA, 3); \
c##RC = vfmaq_laneq_f32(c##RC, b4, a##RAN, 0); \
c##RC = vfmaq_laneq_f32(c##RC, b5, a##RAN, 1); \
c##RC = vfmaq_laneq_f32(c##RC, b6, a##RAN, 2); \
c##RC = vfmaq_laneq_f32(c##RC, b7, a##RAN, 3);
#else
#define MACE_GEMM_PART_CAL(RC, RA, RAN) \
c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RA), 0); \
c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RA), 1); \
c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RA), 0); \
c##RC = vmlaq_lane_f32(c##RC, b3, vget_high_f32(a##RA), 1); \
c##RC = vmlaq_lane_f32(c##RC, b4, vget_low_f32(a##RAN), 0); \
c##RC = vmlaq_lane_f32(c##RC, b5, vget_low_f32(a##RAN), 1); \
c##RC = vmlaq_lane_f32(c##RC, b6, vget_high_f32(a##RAN), 0); \
c##RC = vmlaq_lane_f32(c##RC, b7, vget_high_f32(a##RAN), 1);
#endif
#endif
inline void Gemm884(const float *a_ptr,
const float *b_ptr,
index_t stride_k,
index_t stride_w,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14,
a15;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
@@ -94,24 +119,25 @@ inline void Gemm884(const float *a_ptr,
c6 = vld1q_f32(c_ptr + 6 * stride_w);
c7 = vld1q_f32(c_ptr + 7 * stride_w);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
MACE_GEMM_PART_CAL(4, 8, 9);
MACE_GEMM_PART_CAL(5, 10, 11);
MACE_GEMM_PART_CAL(6, 12, 13);
MACE_GEMM_PART_CAL(7, 14, 15);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
MACE_GEMM_PART_CAL(4, 8, 9);
MACE_GEMM_PART_CAL(5, 10, 11);
MACE_GEMM_PART_CAL(6, 12, 13);
MACE_GEMM_PART_CAL(7, 14, 15);
#endif
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_w, c1);
@@ -121,12 +147,428 @@ inline void Gemm884(const float *a_ptr,
vst1q_f32(c_ptr + 5 * stride_w, c5);
vst1q_f32(c_ptr + 6 * stride_w, c6);
vst1q_f32(c_ptr + 7 * stride_w, c7);
#else
GemmBlock(a_ptr, b_ptr, 8, 8, 4, stride_k, stride_w, c_ptr);
#endif
}
inline void Gemm184(const float *a_ptr,
const float *b_ptr,
index_t stride_k,
index_t stride_w,
float *c_ptr) {
MACE_UNUSED(stride_k);
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
float32x4_t c0;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_w);
b2 = vld1q_f32(b_ptr + 2 * stride_w);
b3 = vld1q_f32(b_ptr + 3 * stride_w);
b4 = vld1q_f32(b_ptr + 4 * stride_w);
b5 = vld1q_f32(b_ptr + 5 * stride_w);
b6 = vld1q_f32(b_ptr + 6 * stride_w);
b7 = vld1q_f32(b_ptr + 7 * stride_w);
c0 = vld1q_f32(c_ptr);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
#endif
vst1q_f32(c_ptr, c0);
#else
GemmBlock(a_ptr, b_ptr, 1, 8, 4, stride_k, stride_w, c_ptr);
#endif
}
inline void Gemm284(const float *a_ptr,
const float *b_ptr,
index_t stride_k,
index_t stride_w,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
float32x4_t c0, c1;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_k);
a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_w);
b2 = vld1q_f32(b_ptr + 2 * stride_w);
b3 = vld1q_f32(b_ptr + 3 * stride_w);
b4 = vld1q_f32(b_ptr + 4 * stride_w);
b5 = vld1q_f32(b_ptr + 5 * stride_w);
b6 = vld1q_f32(b_ptr + 6 * stride_w);
b7 = vld1q_f32(b_ptr + 7 * stride_w);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_w);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
#endif
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_w, c1);
#else
GemmBlock(a_ptr, b_ptr, 2, 8, 4, stride_k, stride_w, c_ptr);
#endif
}
inline void Gemm384(const float *a_ptr,
const float *b_ptr,
index_t stride_k,
index_t stride_w,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
float32x4_t c0, c1, c2;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_k);
a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_k);
a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_w);
b2 = vld1q_f32(b_ptr + 2 * stride_w);
b3 = vld1q_f32(b_ptr + 3 * stride_w);
b4 = vld1q_f32(b_ptr + 4 * stride_w);
b5 = vld1q_f32(b_ptr + 5 * stride_w);
b6 = vld1q_f32(b_ptr + 6 * stride_w);
b7 = vld1q_f32(b_ptr + 7 * stride_w);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_w);
c2 = vld1q_f32(c_ptr + 2 * stride_w);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
#endif
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_w, c1);
vst1q_f32(c_ptr + 2 * stride_w, c2);
#else
GemmBlock(a_ptr, b_ptr, 3, 8, 4, stride_k, stride_w, c_ptr);
#endif
}
inline void Gemm484(const float *a_ptr,
const float *b_ptr,
index_t stride_k,
index_t stride_w,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
float32x4_t c0, c1, c2, c3;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_k);
a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_k);
a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_k);
a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_w);
b2 = vld1q_f32(b_ptr + 2 * stride_w);
b3 = vld1q_f32(b_ptr + 3 * stride_w);
b4 = vld1q_f32(b_ptr + 4 * stride_w);
b5 = vld1q_f32(b_ptr + 5 * stride_w);
b6 = vld1q_f32(b_ptr + 6 * stride_w);
b7 = vld1q_f32(b_ptr + 7 * stride_w);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_w);
c2 = vld1q_f32(c_ptr + 2 * stride_w);
c3 = vld1q_f32(c_ptr + 3 * stride_w);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
#endif
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_w, c1);
vst1q_f32(c_ptr + 2 * stride_w, c2);
vst1q_f32(c_ptr + 3 * stride_w, c3);
#else
GemmBlock(a_ptr, b_ptr, 4, 8, 4, stride_k, stride_w, c_ptr);
#endif
}
inline void Gemm584(const float *a_ptr,
const float *b_ptr,
index_t stride_k,
index_t stride_w,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
float32x4_t c0, c1, c2, c3, c4;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_k);
a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_k);
a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_k);
a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
a8 = vld1q_f32(a_ptr + 4 * stride_k);
a9 = vld1q_f32(a_ptr + 4 * stride_k + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_w);
b2 = vld1q_f32(b_ptr + 2 * stride_w);
b3 = vld1q_f32(b_ptr + 3 * stride_w);
b4 = vld1q_f32(b_ptr + 4 * stride_w);
b5 = vld1q_f32(b_ptr + 5 * stride_w);
b6 = vld1q_f32(b_ptr + 6 * stride_w);
b7 = vld1q_f32(b_ptr + 7 * stride_w);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_w);
c2 = vld1q_f32(c_ptr + 2 * stride_w);
c3 = vld1q_f32(c_ptr + 3 * stride_w);
c4 = vld1q_f32(c_ptr + 4 * stride_w);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
MACE_GEMM_PART_CAL(4, 8, 9);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
MACE_GEMM_PART_CAL(4, 8, 9);
#endif
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_w, c1);
vst1q_f32(c_ptr + 2 * stride_w, c2);
vst1q_f32(c_ptr + 3 * stride_w, c3);
vst1q_f32(c_ptr + 4 * stride_w, c4);
#else
GemmBlock(a_ptr, b_ptr, 5, 8, 4, stride_k, stride_w, c_ptr);
#endif
}
inline void Gemm684(const float *a_ptr,
const float *b_ptr,
index_t stride_k,
index_t stride_w,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
float32x4_t c0, c1, c2, c3, c4, c5;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_k);
a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_k);
a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_k);
a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
a8 = vld1q_f32(a_ptr + 4 * stride_k);
a9 = vld1q_f32(a_ptr + 4 * stride_k + 4);
a10 = vld1q_f32(a_ptr + 5 * stride_k);
a11 = vld1q_f32(a_ptr + 5 * stride_k + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_w);
b2 = vld1q_f32(b_ptr + 2 * stride_w);
b3 = vld1q_f32(b_ptr + 3 * stride_w);
b4 = vld1q_f32(b_ptr + 4 * stride_w);
b5 = vld1q_f32(b_ptr + 5 * stride_w);
b6 = vld1q_f32(b_ptr + 6 * stride_w);
b7 = vld1q_f32(b_ptr + 7 * stride_w);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_w);
c2 = vld1q_f32(c_ptr + 2 * stride_w);
c3 = vld1q_f32(c_ptr + 3 * stride_w);
c4 = vld1q_f32(c_ptr + 4 * stride_w);
c5 = vld1q_f32(c_ptr + 5 * stride_w);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
MACE_GEMM_PART_CAL(4, 8, 9);
MACE_GEMM_PART_CAL(5, 10, 11);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
MACE_GEMM_PART_CAL(4, 8, 9);
MACE_GEMM_PART_CAL(5, 10, 11);
#endif
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_w, c1);
vst1q_f32(c_ptr + 2 * stride_w, c2);
vst1q_f32(c_ptr + 3 * stride_w, c3);
vst1q_f32(c_ptr + 4 * stride_w, c4);
vst1q_f32(c_ptr + 5 * stride_w, c5);
#else
GemmBlock(a_ptr, b_ptr, 6, 8, 4, stride_k, stride_w, c_ptr);
#endif
}
inline void Gemm784(const float *a_ptr,
const float *b_ptr,
index_t stride_k,
index_t stride_w,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
float32x4_t c0, c1, c2, c3, c4, c5, c6;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_k);
a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_k);
a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_k);
a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
a8 = vld1q_f32(a_ptr + 4 * stride_k);
a9 = vld1q_f32(a_ptr + 4 * stride_k + 4);
a10 = vld1q_f32(a_ptr + 5 * stride_k);
a11 = vld1q_f32(a_ptr + 5 * stride_k + 4);
a12 = vld1q_f32(a_ptr + 6 * stride_k);
a13 = vld1q_f32(a_ptr + 6 * stride_k + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_w);
b2 = vld1q_f32(b_ptr + 2 * stride_w);
b3 = vld1q_f32(b_ptr + 3 * stride_w);
b4 = vld1q_f32(b_ptr + 4 * stride_w);
b5 = vld1q_f32(b_ptr + 5 * stride_w);
b6 = vld1q_f32(b_ptr + 6 * stride_w);
b7 = vld1q_f32(b_ptr + 7 * stride_w);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_w);
c2 = vld1q_f32(c_ptr + 2 * stride_w);
c3 = vld1q_f32(c_ptr + 3 * stride_w);
c4 = vld1q_f32(c_ptr + 4 * stride_w);
c5 = vld1q_f32(c_ptr + 5 * stride_w);
c6 = vld1q_f32(c_ptr + 6 * stride_w);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
MACE_GEMM_PART_CAL(4, 8, 9);
MACE_GEMM_PART_CAL(5, 10, 11);
MACE_GEMM_PART_CAL(6, 12, 13);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
MACE_GEMM_PART_CAL(4, 8, 9);
MACE_GEMM_PART_CAL(5, 10, 11);
MACE_GEMM_PART_CAL(6, 12, 13);
#endif
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_w, c1);
vst1q_f32(c_ptr + 2 * stride_w, c2);
vst1q_f32(c_ptr + 3 * stride_w, c3);
vst1q_f32(c_ptr + 4 * stride_w, c4);
vst1q_f32(c_ptr + 5 * stride_w, c5);
vst1q_f32(c_ptr + 6 * stride_w, c6);
#else
GemmBlock(a_ptr, b_ptr, 7, 8, 4, stride_k, stride_w, c_ptr);
#endif
}
inline void GemmX84(const float *a_ptr,
const float *b_ptr,
index_t stride_k,
index_t stride_w,
float *c_ptr,
int row) {
switch (row) {
case 1:
Gemm184(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
break;
case 2:
Gemm284(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
break;
case 3:
Gemm384(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
break;
case 4:
Gemm484(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
break;
case 5:
Gemm584(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
break;
case 6:
Gemm684(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
break;
case 7:
Gemm784(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
break;
case 8:
Gemm884(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
break;
default:
MACE_NOT_IMPLEMENTED;
}
}
inline void GemmTile(const float *A,
const float *B,
const index_t height,
@@ -137,18 +579,15 @@ inline void GemmTile(const float *A,
float *C) {
#if defined(MACE_ENABLE_NEON)
index_t h, w, k;
for (h = 0; h < height - 7; h += 8) {
for (k = 0; k < K - 7; k += 8) {
const float *a_ptr = A + (h * stride_k + k);
#if defined(__aarch64__) && defined(__clang__)
int nw = width >> 2;
if (nw > 0) {
// load A
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13,
a14, a15;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_k);
@@ -378,30 +817,19 @@ inline void GemmTile(const float *A,
"w"(a11), // 47
"w"(a13), // 48
"w"(a15) // 49
: "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
"v23", "v24", "v25");
w = (width >> 2) << 2;
}
#else // gcc || armv7a
for (w = 0; w + 3 < width; w += 4) {
const float *b_ptr = B + (k * stride_w + w);
float *c_ptr = C + (h * stride_w + w);
Gemm884(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
}
#endif // clang && armv8a
if (w < width) {
const float *b_ptr = B + (k * stride_w + w);
float *c_ptr = C + (h * stride_w + w);
GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_k, stride_w, c_ptr);
@@ -411,154 +839,37 @@ inline void GemmTile(const float *A,
const float *a_ptr = A + (h * stride_k + k);
const float *b_ptr = B + k * stride_w;
float *c_ptr = C + h * stride_w;
GemmBlock(a_ptr, b_ptr, 8, K - k, width, stride_k, stride_w, c_ptr);
}
}
if (h < height) {
index_t remain_h = height - h;
for (k = 0; k < K - 7; k += 8) {
const float *a_ptr = A + (h * stride_k + k);
index_t w;
for (w = 0; w + 3 < width; w += 4) {
const float *b_ptr = B + (k * stride_w + w);
float *c_ptr = C + (h * stride_w + w);
GemmX84(a_ptr, b_ptr, stride_k, stride_w, c_ptr, remain_h);
}
if (w < width) {
const float *b_ptr = B + (k * stride_w + w);
float *c_ptr = C + (h * stride_w + w);
GemmBlock(a_ptr, b_ptr, remain_h, 8, width - w, stride_k, stride_w,
c_ptr);
}
}
if (k < K) {
const float *a_ptr = A + (h * stride_k + k);
const float *b_ptr = B + k * stride_w;
float *c_ptr = C + h * stride_w;
GemmBlock(a_ptr, b_ptr, remain_h, K - k, width, stride_k, stride_w,
c_ptr);
}
}
#else
GemmBlock(A, B, height, K, width, stride_k, stride_w, C);
#endif // MACE_ENABLE_NEON
}
} // namespace
@@ -602,29 +913,25 @@ void Gemm(const float *A,
const index_t ih_begin = bh * block_size;
const index_t ih_end =
bh * block_size + (bh == block_tile_height - 1 && remain_height > 0
? remain_height
: block_size);
const index_t iw_begin = bw * block_size;
const index_t iw_end =
bw * block_size + (bw == block_tile_width - 1 && remain_width > 0
? remain_width
: block_size);
for (index_t bk = 0; bk < block_tile_k; ++bk) {
const index_t ik_begin = bk * block_size;
const index_t ik_end =
bk * block_size +
(bk == block_tile_k - 1 && remain_k > 0 ? remain_k : block_size);
// inside block:
// calculate C[bh, bw] += A[bh, bk] * B[bk, bw] for one k
GemmTile(a_base + (ih_begin * K + ik_begin),
b_base + (ik_begin * width + iw_begin), ih_end - ih_begin,
ik_end - ik_begin, iw_end - iw_begin, K, width,
c_base + (ih_begin * width + iw_begin));
} // bk
} // bw
@@ -635,59 +942,60 @@ void Gemm(const float *A,
// A: height x K, B: K x width, C: height x width
void GemmRef(const float *A,
const float *B,
const index_t batch,
const index_t height,
const index_t K,
const index_t width,
float *C) {
memset(C, 0, sizeof(float) * batch * height * width);
for (index_t b = 0; b < batch; ++b) {
for (index_t i = 0; i < height; ++i) {
for (index_t j = 0; j < width; ++j) {
for (index_t k = 0; k < K; ++k) {
C[(b * height + i) * width + j] +=
A[(b * height + i) * K + k] * B[(b * K + k) * width + j];
}
}
}
}
}
void GemvRef(const float *m_ptr,
const float *v_ptr,
const index_t batch,
const index_t width,
const index_t height,
float *out_ptr) {
memset(out_ptr, 0, batch * height * sizeof(float));
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
out_ptr[b * height + h] += v_ptr[b * width + w] * m_ptr[h * width + w];
}
}
}
}
// TODO(liyin): batched gemv can be transformed to gemm (w/ transpose)
void Gemv(const float *m_ptr,
const float *v_ptr,
const index_t batch,
const index_t width,
const index_t height,
float *out_ptr) {
#if defined(MACE_ENABLE_NEON)
// TODO(liyin/wch): try height tiling = 8
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t h = 0; h < height; h += 4) {
if (h + 3 < height) {
const float *m_ptr0 = m_ptr + h * width;
const float *m_ptr1 = m_ptr0 + width;
const float *m_ptr2 = m_ptr1 + width;
const float *m_ptr3 = m_ptr2 + width;
const float *v_ptr0 = v_ptr + b * width;
float *out_ptr0 = out_ptr + b * height + h;
float32x4_t vm0, vm1, vm2, vm3;
float32x4_t vv;
@@ -697,7 +1005,8 @@ void Gemv(const float *m_ptr,
float32x4_t vsum2 = vdupq_n_f32(0.f);
float32x4_t vsum3 = vdupq_n_f32(0.f);
index_t w;
for (w = 0; w + 3 < width; w += 4) {
vm0 = vld1q_f32(m_ptr0);
vm1 = vld1q_f32(m_ptr1);
vm2 = vld1q_f32(m_ptr2);
@@ -721,7 +1030,7 @@ void Gemv(const float *m_ptr,
float sum3 = vaddvq_f32(vsum3);
// handle remaining w
for (; w < width; ++w) {
sum0 += m_ptr0[0] * v_ptr0[0];
sum1 += m_ptr1[0] * v_ptr0[0];
sum2 += m_ptr2[0] * v_ptr0[0];
@@ -736,16 +1045,13 @@ void Gemv(const float *m_ptr,
*out_ptr0++ = sum1;
*out_ptr0++ = sum2;
*out_ptr0++ = sum3;
} else {
for (index_t hh = h; hh < height; ++hh) {
float32x4_t vsum0 = vdupq_n_f32(0.f);
const float *m_ptr0 = m_ptr + hh * width;
const float *v_ptr0 = v_ptr + b * width;
index_t w;
for (w = 0; w + 3 < width; w += 4) {
float32x4_t vm = vld1q_f32(m_ptr0);
float32x4_t vv = vld1q_f32(v_ptr0);
vsum0 = vmlaq_f32(vsum0, vm, vv);
@@ -753,14 +1059,16 @@ void Gemv(const float *m_ptr,
v_ptr0 += 4;
}
float sum = vaddvq_f32(vsum0);
for (; w < width; ++w) {
sum += m_ptr0[0] * v_ptr0[0];
m_ptr0++;
v_ptr0++;
}
out_ptr[b * height + hh] = sum;
}
} // if
} // h
} // b
#else
GemvRef(m_ptr, v_ptr, batch, width, height, out_ptr);
#endif
...
@@ -34,6 +34,7 @@ void Gemm(const float *A,
void GemmRef(const float *A,
const float *B,
const index_t batch,
const index_t height,
const index_t K,
const index_t width,
...
@@ -21,62 +21,98 @@
namespace mace {
namespace {
void GemmTest(index_t batch, index_t N, index_t K, index_t M) {
std::unique_ptr<float[]> A(new float[batch * N * K]);
std::unique_ptr<float[]> B(new float[batch * K * M]);
std::unique_ptr<float[]> C(new float[batch * N * M]);
std::unique_ptr<float[]> C_ref(new float[batch * N * M]);
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1);
std::generate(A.get(), A.get() + batch * N * K,
[&gen, &nd] {
return nd(gen);
});
std::generate(B.get(), B.get() + batch * K * M,
[&gen, &nd] {
return nd(gen);
});
kernels::Gemm(A.get(), B.get(), batch, N, K, M, C.get());
kernels::GemmRef(A.get(), B.get(), batch, N, K, M, C_ref.get());
for (int i = 0; i < batch * N * M; ++i) {
EXPECT_NEAR(C_ref[i], C[i], 0.1);
}
}
void GemvTest(index_t batch, index_t N, index_t M) {
std::unique_ptr<float[]> A(new float[N * M]);
std::unique_ptr<float[]> B(new float[batch * M]);
std::unique_ptr<float[]> C(new float[batch * N]);
std::unique_ptr<float[]> C_ref(new float[batch * N]);
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1);
std::generate(A.get(), A.get() + N * M,
[&gen, &nd] {
return nd(gen);
});
std::generate(B.get(), B.get() + batch * M,
[&gen, &nd] {
return nd(gen);
});
kernels::Gemv(A.get(), B.get(), batch, M, N, C.get());
kernels::GemvRef(A.get(), B.get(), batch, M, N, C_ref.get());
for (int i = 0; i < batch * N; ++i) {
EXPECT_NEAR(C_ref[i], C[i], 0.1);
}
}
} // namespace
TEST(GEMMTest, AlignedWithoutBatch) {
GemmTest(1, 1, 64, 128);
GemmTest(1, 2, 64, 128);
GemmTest(1, 3, 64, 128);
GemmTest(1, 4, 64, 128);
GemmTest(1, 5, 64, 128);
GemmTest(1, 6, 64, 128);
GemmTest(1, 7, 64, 128);
GemmTest(1, 17, 64, 128);
}
TEST(GEMMTest, UnalignedWithoutBatch) {
GemmTest(1, 1, 63, 127);
GemmTest(1, 2, 63, 127);
GemmTest(1, 3, 63, 127);
GemmTest(1, 4, 63, 127);
GemmTest(1, 5, 63, 127);
GemmTest(1, 6, 63, 127);
GemmTest(1, 7, 63, 127);
GemmTest(1, 17, 63, 127);
}
TEST(GEMMTest, UnalignedWithBatch) {
GemmTest(3, 1, 63, 127);
GemmTest(3, 2, 63, 127);
GemmTest(3, 3, 63, 127);
GemmTest(3, 4, 63, 127);
GemmTest(3, 5, 63, 127);
GemmTest(3, 6, 63, 127);
GemmTest(3, 7, 63, 127);
GemmTest(3, 17, 63, 127);
}
TEST(GEMMTest, gemv) {
GemvTest(1, 17, 63);
GemvTest(3, 17, 63);
}
} // namespace mace