diff --git a/mace/kernels/gemm.cc b/mace/kernels/gemm.cc index 477831c7acc7b94544118d05a988d63b0c49128a..40c90f58c0a4c6c8fda054194b6a5cced71cece6 100644 --- a/mace/kernels/gemm.cc +++ b/mace/kernels/gemm.cc @@ -12,18 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include #include +#include "mace/core/tensor.h" +#include "mace/kernels/gemm.h" + #if defined(MACE_ENABLE_NEON) #include #endif -#include "mace/core/macros.h" -#include "mace/kernels/gemm.h" -#include "mace/utils/logging.h" - #if defined(MACE_ENABLE_NEON) && !defined(__aarch64__) #define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3]) #endif @@ -37,13 +35,14 @@ inline void GemmBlock(const float *A, const index_t height, const index_t K, const index_t width, - const index_t stride_k, - const index_t stride_w, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, float *C) { for (int i = 0; i < height; ++i) { for (int j = 0; j < width; ++j) { for (int k = 0; k < K; ++k) { - C[i * stride_w + j] += A[i * stride_k + k] * B[k * stride_w + j]; + C[i * stride_c + j] += A[i * stride_a + k] * B[k * stride_b + j]; } } } @@ -75,8 +74,9 @@ inline void GemmBlock(const float *A, inline void Gemm884(const float *a_ptr, const float *b_ptr, - index_t stride_k, - index_t stride_w, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, float *c_ptr) { #if defined(MACE_ENABLE_NEON) float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, @@ -86,38 +86,38 @@ inline void Gemm884(const float *a_ptr, a0 = vld1q_f32(a_ptr); a1 = vld1q_f32(a_ptr + 4); - a2 = vld1q_f32(a_ptr + 1 * stride_k); - a3 = vld1q_f32(a_ptr + 1 * stride_k + 4); - a4 = vld1q_f32(a_ptr + 2 * stride_k); - a5 = vld1q_f32(a_ptr + 2 * stride_k + 4); - a6 = vld1q_f32(a_ptr + 3 * stride_k); - a7 = vld1q_f32(a_ptr + 3 * stride_k + 4); - a8 = vld1q_f32(a_ptr + 4 * stride_k); - a9 = vld1q_f32(a_ptr + 4 * stride_k + 4); - a10 = vld1q_f32(a_ptr + 5 * stride_k); - a11 = vld1q_f32(a_ptr + 5 * stride_k + 4); - a12 = vld1q_f32(a_ptr + 6 * stride_k); - a13 = vld1q_f32(a_ptr + 6 * stride_k + 4); - a14 = vld1q_f32(a_ptr + 7 * stride_k); - a15 = vld1q_f32(a_ptr + 7 * stride_k + 4); + a2 = vld1q_f32(a_ptr + 1 * stride_a); + a3 = vld1q_f32(a_ptr + 1 * stride_a + 4); + a4 = vld1q_f32(a_ptr + 2 * stride_a); + a5 = vld1q_f32(a_ptr + 2 * stride_a + 4); + a6 = vld1q_f32(a_ptr + 3 * stride_a); + a7 = vld1q_f32(a_ptr + 3 * stride_a + 4); + a8 = vld1q_f32(a_ptr + 4 * stride_a); + a9 = vld1q_f32(a_ptr + 4 * stride_a + 4); + a10 = vld1q_f32(a_ptr + 5 * stride_a); + a11 = vld1q_f32(a_ptr + 5 * stride_a + 4); + a12 = vld1q_f32(a_ptr + 6 * stride_a); + a13 = vld1q_f32(a_ptr + 6 * stride_a + 4); + a14 = vld1q_f32(a_ptr + 7 * stride_a); + a15 = vld1q_f32(a_ptr + 7 * stride_a + 4); b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_w); - b2 = vld1q_f32(b_ptr + 2 * stride_w); - b3 = vld1q_f32(b_ptr + 3 * stride_w); - b4 = vld1q_f32(b_ptr + 4 * stride_w); - b5 = vld1q_f32(b_ptr + 5 * stride_w); - b6 = vld1q_f32(b_ptr + 6 * stride_w); - b7 = vld1q_f32(b_ptr + 7 * stride_w); + b1 = vld1q_f32(b_ptr + 1 * stride_b); + b2 = vld1q_f32(b_ptr + 2 * stride_b); + b3 = vld1q_f32(b_ptr + 3 * stride_b); + b4 = vld1q_f32(b_ptr + 4 * stride_b); + b5 = vld1q_f32(b_ptr + 5 * stride_b); + b6 = vld1q_f32(b_ptr + 6 * stride_b); + b7 = vld1q_f32(b_ptr + 7 * stride_b); c0 = vld1q_f32(c_ptr); - c1 = vld1q_f32(c_ptr + 1 * stride_w); - c2 = vld1q_f32(c_ptr + 2 * stride_w); - c3 = vld1q_f32(c_ptr 
+ 3 * stride_w); - c4 = vld1q_f32(c_ptr + 4 * stride_w); - c5 = vld1q_f32(c_ptr + 5 * stride_w); - c6 = vld1q_f32(c_ptr + 6 * stride_w); - c7 = vld1q_f32(c_ptr + 7 * stride_w); + c1 = vld1q_f32(c_ptr + 1 * stride_c); + c2 = vld1q_f32(c_ptr + 2 * stride_c); + c3 = vld1q_f32(c_ptr + 3 * stride_c); + c4 = vld1q_f32(c_ptr + 4 * stride_c); + c5 = vld1q_f32(c_ptr + 5 * stride_c); + c6 = vld1q_f32(c_ptr + 6 * stride_c); + c7 = vld1q_f32(c_ptr + 7 * stride_c); #if defined(__aarch64__) MACE_GEMM_PART_CAL(0, 0, 1); @@ -140,25 +140,28 @@ inline void Gemm884(const float *a_ptr, #endif vst1q_f32(c_ptr, c0); - vst1q_f32(c_ptr + 1 * stride_w, c1); - vst1q_f32(c_ptr + 2 * stride_w, c2); - vst1q_f32(c_ptr + 3 * stride_w, c3); - vst1q_f32(c_ptr + 4 * stride_w, c4); - vst1q_f32(c_ptr + 5 * stride_w, c5); - vst1q_f32(c_ptr + 6 * stride_w, c6); - vst1q_f32(c_ptr + 7 * stride_w, c7); + vst1q_f32(c_ptr + 1 * stride_c, c1); + vst1q_f32(c_ptr + 2 * stride_c, c2); + vst1q_f32(c_ptr + 3 * stride_c, c3); + vst1q_f32(c_ptr + 4 * stride_c, c4); + vst1q_f32(c_ptr + 5 * stride_c, c5); + vst1q_f32(c_ptr + 6 * stride_c, c6); + vst1q_f32(c_ptr + 7 * stride_c, c7); #else - GemmBlock(a_ptr, b_ptr, 8, 8, 4, stride_k, stride_w, c_ptr); + GemmBlock(a_ptr, b_ptr, 8, 8, 4, stride_a, stride_b, stride_c, c_ptr); #endif } inline void Gemm184(const float *a_ptr, const float *b_ptr, - index_t stride_k, - index_t stride_w, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, float *c_ptr) { - MACE_UNUSED(stride_k); #if defined(MACE_ENABLE_NEON) + MACE_UNUSED(stride_a); + MACE_UNUSED(stride_c); + float32x4_t a0, a1; float32x4_t b0, b1, b2, b3, b4, b5, b6, b7; float32x4_t c0; @@ -167,13 +170,13 @@ inline void Gemm184(const float *a_ptr, a1 = vld1q_f32(a_ptr + 4); b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_w); - b2 = vld1q_f32(b_ptr + 2 * stride_w); - b3 = vld1q_f32(b_ptr + 3 * stride_w); - b4 = vld1q_f32(b_ptr + 4 * stride_w); - b5 = vld1q_f32(b_ptr + 5 * stride_w); - b6 = vld1q_f32(b_ptr + 6 * stride_w); - b7 = vld1q_f32(b_ptr + 7 * stride_w); + b1 = vld1q_f32(b_ptr + 1 * stride_b); + b2 = vld1q_f32(b_ptr + 2 * stride_b); + b3 = vld1q_f32(b_ptr + 3 * stride_b); + b4 = vld1q_f32(b_ptr + 4 * stride_b); + b5 = vld1q_f32(b_ptr + 5 * stride_b); + b6 = vld1q_f32(b_ptr + 6 * stride_b); + b7 = vld1q_f32(b_ptr + 7 * stride_b); c0 = vld1q_f32(c_ptr); @@ -185,14 +188,15 @@ inline void Gemm184(const float *a_ptr, vst1q_f32(c_ptr, c0); #else - GemmBlock(a_ptr, b_ptr, 1, 8, 4, stride_k, stride_w, c_ptr); + GemmBlock(a_ptr, b_ptr, 1, 8, 4, stride_a, stride_b, stride_c, c_ptr); #endif } inline void Gemm284(const float *a_ptr, const float *b_ptr, - index_t stride_k, - index_t stride_w, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, float *c_ptr) { #if defined(MACE_ENABLE_NEON) float32x4_t a0, a1, a2, a3; @@ -201,20 +205,20 @@ inline void Gemm284(const float *a_ptr, a0 = vld1q_f32(a_ptr); a1 = vld1q_f32(a_ptr + 4); - a2 = vld1q_f32(a_ptr + 1 * stride_k); - a3 = vld1q_f32(a_ptr + 1 * stride_k + 4); + a2 = vld1q_f32(a_ptr + 1 * stride_a); + a3 = vld1q_f32(a_ptr + 1 * stride_a + 4); b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_w); - b2 = vld1q_f32(b_ptr + 2 * stride_w); - b3 = vld1q_f32(b_ptr + 3 * stride_w); - b4 = vld1q_f32(b_ptr + 4 * stride_w); - b5 = vld1q_f32(b_ptr + 5 * stride_w); - b6 = vld1q_f32(b_ptr + 6 * stride_w); - b7 = vld1q_f32(b_ptr + 7 * stride_w); + b1 = vld1q_f32(b_ptr + 1 * stride_b); + b2 = vld1q_f32(b_ptr + 2 * stride_b); + b3 = vld1q_f32(b_ptr + 3 * 
stride_b); + b4 = vld1q_f32(b_ptr + 4 * stride_b); + b5 = vld1q_f32(b_ptr + 5 * stride_b); + b6 = vld1q_f32(b_ptr + 6 * stride_b); + b7 = vld1q_f32(b_ptr + 7 * stride_b); c0 = vld1q_f32(c_ptr); - c1 = vld1q_f32(c_ptr + 1 * stride_w); + c1 = vld1q_f32(c_ptr + 1 * stride_c); #if defined(__aarch64__) MACE_GEMM_PART_CAL(0, 0, 1); @@ -225,16 +229,17 @@ inline void Gemm284(const float *a_ptr, #endif vst1q_f32(c_ptr, c0); - vst1q_f32(c_ptr + 1 * stride_w, c1); + vst1q_f32(c_ptr + 1 * stride_c, c1); #else - GemmBlock(a_ptr, b_ptr, 2, 8, 4, stride_k, stride_w, c_ptr); + GemmBlock(a_ptr, b_ptr, 2, 8, 4, stride_a, stride_b, stride_c, c_ptr); #endif } inline void Gemm384(const float *a_ptr, const float *b_ptr, - index_t stride_k, - index_t stride_w, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, float *c_ptr) { #if defined(MACE_ENABLE_NEON) float32x4_t a0, a1, a2, a3, a4, a5; @@ -243,23 +248,23 @@ inline void Gemm384(const float *a_ptr, a0 = vld1q_f32(a_ptr); a1 = vld1q_f32(a_ptr + 4); - a2 = vld1q_f32(a_ptr + 1 * stride_k); - a3 = vld1q_f32(a_ptr + 1 * stride_k + 4); - a4 = vld1q_f32(a_ptr + 2 * stride_k); - a5 = vld1q_f32(a_ptr + 2 * stride_k + 4); + a2 = vld1q_f32(a_ptr + 1 * stride_a); + a3 = vld1q_f32(a_ptr + 1 * stride_a + 4); + a4 = vld1q_f32(a_ptr + 2 * stride_a); + a5 = vld1q_f32(a_ptr + 2 * stride_a + 4); b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_w); - b2 = vld1q_f32(b_ptr + 2 * stride_w); - b3 = vld1q_f32(b_ptr + 3 * stride_w); - b4 = vld1q_f32(b_ptr + 4 * stride_w); - b5 = vld1q_f32(b_ptr + 5 * stride_w); - b6 = vld1q_f32(b_ptr + 6 * stride_w); - b7 = vld1q_f32(b_ptr + 7 * stride_w); + b1 = vld1q_f32(b_ptr + 1 * stride_b); + b2 = vld1q_f32(b_ptr + 2 * stride_b); + b3 = vld1q_f32(b_ptr + 3 * stride_b); + b4 = vld1q_f32(b_ptr + 4 * stride_b); + b5 = vld1q_f32(b_ptr + 5 * stride_b); + b6 = vld1q_f32(b_ptr + 6 * stride_b); + b7 = vld1q_f32(b_ptr + 7 * stride_b); c0 = vld1q_f32(c_ptr); - c1 = vld1q_f32(c_ptr + 1 * stride_w); - c2 = vld1q_f32(c_ptr + 2 * stride_w); + c1 = vld1q_f32(c_ptr + 1 * stride_c); + c2 = vld1q_f32(c_ptr + 2 * stride_c); #if defined(__aarch64__) MACE_GEMM_PART_CAL(0, 0, 1); @@ -272,17 +277,18 @@ inline void Gemm384(const float *a_ptr, #endif vst1q_f32(c_ptr, c0); - vst1q_f32(c_ptr + 1 * stride_w, c1); - vst1q_f32(c_ptr + 2 * stride_w, c2); + vst1q_f32(c_ptr + 1 * stride_c, c1); + vst1q_f32(c_ptr + 2 * stride_c, c2); #else - GemmBlock(a_ptr, b_ptr, 3, 8, 4, stride_k, stride_w, c_ptr); + GemmBlock(a_ptr, b_ptr, 3, 8, 4, stride_a, stride_b, stride_c, c_ptr); #endif } inline void Gemm484(const float *a_ptr, const float *b_ptr, - index_t stride_k, - index_t stride_w, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, float *c_ptr) { #if defined(MACE_ENABLE_NEON) float32x4_t a0, a1, a2, a3, a4, a5, a6, a7; @@ -291,26 +297,26 @@ inline void Gemm484(const float *a_ptr, a0 = vld1q_f32(a_ptr); a1 = vld1q_f32(a_ptr + 4); - a2 = vld1q_f32(a_ptr + 1 * stride_k); - a3 = vld1q_f32(a_ptr + 1 * stride_k + 4); - a4 = vld1q_f32(a_ptr + 2 * stride_k); - a5 = vld1q_f32(a_ptr + 2 * stride_k + 4); - a6 = vld1q_f32(a_ptr + 3 * stride_k); - a7 = vld1q_f32(a_ptr + 3 * stride_k + 4); + a2 = vld1q_f32(a_ptr + 1 * stride_a); + a3 = vld1q_f32(a_ptr + 1 * stride_a + 4); + a4 = vld1q_f32(a_ptr + 2 * stride_a); + a5 = vld1q_f32(a_ptr + 2 * stride_a + 4); + a6 = vld1q_f32(a_ptr + 3 * stride_a); + a7 = vld1q_f32(a_ptr + 3 * stride_a + 4); b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_w); - b2 = vld1q_f32(b_ptr + 2 * 
stride_w); - b3 = vld1q_f32(b_ptr + 3 * stride_w); - b4 = vld1q_f32(b_ptr + 4 * stride_w); - b5 = vld1q_f32(b_ptr + 5 * stride_w); - b6 = vld1q_f32(b_ptr + 6 * stride_w); - b7 = vld1q_f32(b_ptr + 7 * stride_w); + b1 = vld1q_f32(b_ptr + 1 * stride_b); + b2 = vld1q_f32(b_ptr + 2 * stride_b); + b3 = vld1q_f32(b_ptr + 3 * stride_b); + b4 = vld1q_f32(b_ptr + 4 * stride_b); + b5 = vld1q_f32(b_ptr + 5 * stride_b); + b6 = vld1q_f32(b_ptr + 6 * stride_b); + b7 = vld1q_f32(b_ptr + 7 * stride_b); c0 = vld1q_f32(c_ptr); - c1 = vld1q_f32(c_ptr + 1 * stride_w); - c2 = vld1q_f32(c_ptr + 2 * stride_w); - c3 = vld1q_f32(c_ptr + 3 * stride_w); + c1 = vld1q_f32(c_ptr + 1 * stride_c); + c2 = vld1q_f32(c_ptr + 2 * stride_c); + c3 = vld1q_f32(c_ptr + 3 * stride_c); #if defined(__aarch64__) MACE_GEMM_PART_CAL(0, 0, 1); @@ -325,18 +331,19 @@ inline void Gemm484(const float *a_ptr, #endif vst1q_f32(c_ptr, c0); - vst1q_f32(c_ptr + 1 * stride_w, c1); - vst1q_f32(c_ptr + 2 * stride_w, c2); - vst1q_f32(c_ptr + 3 * stride_w, c3); + vst1q_f32(c_ptr + 1 * stride_c, c1); + vst1q_f32(c_ptr + 2 * stride_c, c2); + vst1q_f32(c_ptr + 3 * stride_c, c3); #else - GemmBlock(a_ptr, b_ptr, 4, 8, 4, stride_k, stride_w, c_ptr); + GemmBlock(a_ptr, b_ptr, 4, 8, 4, stride_a, stride_b, stride_c, c_ptr); #endif } inline void Gemm584(const float *a_ptr, const float *b_ptr, - index_t stride_k, - index_t stride_w, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, float *c_ptr) { #if defined(MACE_ENABLE_NEON) float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9; @@ -345,29 +352,29 @@ inline void Gemm584(const float *a_ptr, a0 = vld1q_f32(a_ptr); a1 = vld1q_f32(a_ptr + 4); - a2 = vld1q_f32(a_ptr + 1 * stride_k); - a3 = vld1q_f32(a_ptr + 1 * stride_k + 4); - a4 = vld1q_f32(a_ptr + 2 * stride_k); - a5 = vld1q_f32(a_ptr + 2 * stride_k + 4); - a6 = vld1q_f32(a_ptr + 3 * stride_k); - a7 = vld1q_f32(a_ptr + 3 * stride_k + 4); - a8 = vld1q_f32(a_ptr + 4 * stride_k); - a9 = vld1q_f32(a_ptr + 4 * stride_k + 4); + a2 = vld1q_f32(a_ptr + 1 * stride_a); + a3 = vld1q_f32(a_ptr + 1 * stride_a + 4); + a4 = vld1q_f32(a_ptr + 2 * stride_a); + a5 = vld1q_f32(a_ptr + 2 * stride_a + 4); + a6 = vld1q_f32(a_ptr + 3 * stride_a); + a7 = vld1q_f32(a_ptr + 3 * stride_a + 4); + a8 = vld1q_f32(a_ptr + 4 * stride_a); + a9 = vld1q_f32(a_ptr + 4 * stride_a + 4); b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_w); - b2 = vld1q_f32(b_ptr + 2 * stride_w); - b3 = vld1q_f32(b_ptr + 3 * stride_w); - b4 = vld1q_f32(b_ptr + 4 * stride_w); - b5 = vld1q_f32(b_ptr + 5 * stride_w); - b6 = vld1q_f32(b_ptr + 6 * stride_w); - b7 = vld1q_f32(b_ptr + 7 * stride_w); + b1 = vld1q_f32(b_ptr + 1 * stride_b); + b2 = vld1q_f32(b_ptr + 2 * stride_b); + b3 = vld1q_f32(b_ptr + 3 * stride_b); + b4 = vld1q_f32(b_ptr + 4 * stride_b); + b5 = vld1q_f32(b_ptr + 5 * stride_b); + b6 = vld1q_f32(b_ptr + 6 * stride_b); + b7 = vld1q_f32(b_ptr + 7 * stride_b); c0 = vld1q_f32(c_ptr); - c1 = vld1q_f32(c_ptr + 1 * stride_w); - c2 = vld1q_f32(c_ptr + 2 * stride_w); - c3 = vld1q_f32(c_ptr + 3 * stride_w); - c4 = vld1q_f32(c_ptr + 4 * stride_w); + c1 = vld1q_f32(c_ptr + 1 * stride_c); + c2 = vld1q_f32(c_ptr + 2 * stride_c); + c3 = vld1q_f32(c_ptr + 3 * stride_c); + c4 = vld1q_f32(c_ptr + 4 * stride_c); #if defined(__aarch64__) MACE_GEMM_PART_CAL(0, 0, 1); @@ -384,19 +391,20 @@ inline void Gemm584(const float *a_ptr, #endif vst1q_f32(c_ptr, c0); - vst1q_f32(c_ptr + 1 * stride_w, c1); - vst1q_f32(c_ptr + 2 * stride_w, c2); - vst1q_f32(c_ptr + 3 * stride_w, c3); - vst1q_f32(c_ptr + 4 
* stride_w, c4); + vst1q_f32(c_ptr + 1 * stride_c, c1); + vst1q_f32(c_ptr + 2 * stride_c, c2); + vst1q_f32(c_ptr + 3 * stride_c, c3); + vst1q_f32(c_ptr + 4 * stride_c, c4); #else - GemmBlock(a_ptr, b_ptr, 5, 8, 4, stride_k, stride_w, c_ptr); + GemmBlock(a_ptr, b_ptr, 5, 8, 4, stride_a, stride_b, stride_c, c_ptr); #endif } inline void Gemm684(const float *a_ptr, const float *b_ptr, - index_t stride_k, - index_t stride_w, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, float *c_ptr) { #if defined(MACE_ENABLE_NEON) float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11; @@ -405,32 +413,32 @@ inline void Gemm684(const float *a_ptr, a0 = vld1q_f32(a_ptr); a1 = vld1q_f32(a_ptr + 4); - a2 = vld1q_f32(a_ptr + 1 * stride_k); - a3 = vld1q_f32(a_ptr + 1 * stride_k + 4); - a4 = vld1q_f32(a_ptr + 2 * stride_k); - a5 = vld1q_f32(a_ptr + 2 * stride_k + 4); - a6 = vld1q_f32(a_ptr + 3 * stride_k); - a7 = vld1q_f32(a_ptr + 3 * stride_k + 4); - a8 = vld1q_f32(a_ptr + 4 * stride_k); - a9 = vld1q_f32(a_ptr + 4 * stride_k + 4); - a10 = vld1q_f32(a_ptr + 5 * stride_k); - a11 = vld1q_f32(a_ptr + 5 * stride_k + 4); + a2 = vld1q_f32(a_ptr + 1 * stride_a); + a3 = vld1q_f32(a_ptr + 1 * stride_a + 4); + a4 = vld1q_f32(a_ptr + 2 * stride_a); + a5 = vld1q_f32(a_ptr + 2 * stride_a + 4); + a6 = vld1q_f32(a_ptr + 3 * stride_a); + a7 = vld1q_f32(a_ptr + 3 * stride_a + 4); + a8 = vld1q_f32(a_ptr + 4 * stride_a); + a9 = vld1q_f32(a_ptr + 4 * stride_a + 4); + a10 = vld1q_f32(a_ptr + 5 * stride_a); + a11 = vld1q_f32(a_ptr + 5 * stride_a + 4); b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_w); - b2 = vld1q_f32(b_ptr + 2 * stride_w); - b3 = vld1q_f32(b_ptr + 3 * stride_w); - b4 = vld1q_f32(b_ptr + 4 * stride_w); - b5 = vld1q_f32(b_ptr + 5 * stride_w); - b6 = vld1q_f32(b_ptr + 6 * stride_w); - b7 = vld1q_f32(b_ptr + 7 * stride_w); + b1 = vld1q_f32(b_ptr + 1 * stride_b); + b2 = vld1q_f32(b_ptr + 2 * stride_b); + b3 = vld1q_f32(b_ptr + 3 * stride_b); + b4 = vld1q_f32(b_ptr + 4 * stride_b); + b5 = vld1q_f32(b_ptr + 5 * stride_b); + b6 = vld1q_f32(b_ptr + 6 * stride_b); + b7 = vld1q_f32(b_ptr + 7 * stride_b); c0 = vld1q_f32(c_ptr); - c1 = vld1q_f32(c_ptr + 1 * stride_w); - c2 = vld1q_f32(c_ptr + 2 * stride_w); - c3 = vld1q_f32(c_ptr + 3 * stride_w); - c4 = vld1q_f32(c_ptr + 4 * stride_w); - c5 = vld1q_f32(c_ptr + 5 * stride_w); + c1 = vld1q_f32(c_ptr + 1 * stride_c); + c2 = vld1q_f32(c_ptr + 2 * stride_c); + c3 = vld1q_f32(c_ptr + 3 * stride_c); + c4 = vld1q_f32(c_ptr + 4 * stride_c); + c5 = vld1q_f32(c_ptr + 5 * stride_c); #if defined(__aarch64__) MACE_GEMM_PART_CAL(0, 0, 1); @@ -449,21 +457,22 @@ inline void Gemm684(const float *a_ptr, #endif vst1q_f32(c_ptr, c0); - vst1q_f32(c_ptr + 1 * stride_w, c1); - vst1q_f32(c_ptr + 2 * stride_w, c2); - vst1q_f32(c_ptr + 3 * stride_w, c3); - vst1q_f32(c_ptr + 4 * stride_w, c4); - vst1q_f32(c_ptr + 5 * stride_w, c5); + vst1q_f32(c_ptr + 1 * stride_c, c1); + vst1q_f32(c_ptr + 2 * stride_c, c2); + vst1q_f32(c_ptr + 3 * stride_c, c3); + vst1q_f32(c_ptr + 4 * stride_c, c4); + vst1q_f32(c_ptr + 5 * stride_c, c5); #else - GemmBlock(a_ptr, b_ptr, 6, 8, 4, stride_k, stride_w, c_ptr); + GemmBlock(a_ptr, b_ptr, 6, 8, 4, stride_a, stride_b, stride_c, c_ptr); #endif } inline void Gemm784(const float *a_ptr, const float *b_ptr, - index_t stride_k, - index_t stride_w, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, float *c_ptr) { #if defined(MACE_ENABLE_NEON) float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13; 
@@ -472,35 +481,35 @@ inline void Gemm784(const float *a_ptr, a0 = vld1q_f32(a_ptr); a1 = vld1q_f32(a_ptr + 4); - a2 = vld1q_f32(a_ptr + 1 * stride_k); - a3 = vld1q_f32(a_ptr + 1 * stride_k + 4); - a4 = vld1q_f32(a_ptr + 2 * stride_k); - a5 = vld1q_f32(a_ptr + 2 * stride_k + 4); - a6 = vld1q_f32(a_ptr + 3 * stride_k); - a7 = vld1q_f32(a_ptr + 3 * stride_k + 4); - a8 = vld1q_f32(a_ptr + 4 * stride_k); - a9 = vld1q_f32(a_ptr + 4 * stride_k + 4); - a10 = vld1q_f32(a_ptr + 5 * stride_k); - a11 = vld1q_f32(a_ptr + 5 * stride_k + 4); - a12 = vld1q_f32(a_ptr + 6 * stride_k); - a13 = vld1q_f32(a_ptr + 6 * stride_k + 4); + a2 = vld1q_f32(a_ptr + 1 * stride_a); + a3 = vld1q_f32(a_ptr + 1 * stride_a + 4); + a4 = vld1q_f32(a_ptr + 2 * stride_a); + a5 = vld1q_f32(a_ptr + 2 * stride_a + 4); + a6 = vld1q_f32(a_ptr + 3 * stride_a); + a7 = vld1q_f32(a_ptr + 3 * stride_a + 4); + a8 = vld1q_f32(a_ptr + 4 * stride_a); + a9 = vld1q_f32(a_ptr + 4 * stride_a + 4); + a10 = vld1q_f32(a_ptr + 5 * stride_a); + a11 = vld1q_f32(a_ptr + 5 * stride_a + 4); + a12 = vld1q_f32(a_ptr + 6 * stride_a); + a13 = vld1q_f32(a_ptr + 6 * stride_a + 4); b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_w); - b2 = vld1q_f32(b_ptr + 2 * stride_w); - b3 = vld1q_f32(b_ptr + 3 * stride_w); - b4 = vld1q_f32(b_ptr + 4 * stride_w); - b5 = vld1q_f32(b_ptr + 5 * stride_w); - b6 = vld1q_f32(b_ptr + 6 * stride_w); - b7 = vld1q_f32(b_ptr + 7 * stride_w); + b1 = vld1q_f32(b_ptr + 1 * stride_b); + b2 = vld1q_f32(b_ptr + 2 * stride_b); + b3 = vld1q_f32(b_ptr + 3 * stride_b); + b4 = vld1q_f32(b_ptr + 4 * stride_b); + b5 = vld1q_f32(b_ptr + 5 * stride_b); + b6 = vld1q_f32(b_ptr + 6 * stride_b); + b7 = vld1q_f32(b_ptr + 7 * stride_b); c0 = vld1q_f32(c_ptr); - c1 = vld1q_f32(c_ptr + 1 * stride_w); - c2 = vld1q_f32(c_ptr + 2 * stride_w); - c3 = vld1q_f32(c_ptr + 3 * stride_w); - c4 = vld1q_f32(c_ptr + 4 * stride_w); - c5 = vld1q_f32(c_ptr + 5 * stride_w); - c6 = vld1q_f32(c_ptr + 6 * stride_w); + c1 = vld1q_f32(c_ptr + 1 * stride_c); + c2 = vld1q_f32(c_ptr + 2 * stride_c); + c3 = vld1q_f32(c_ptr + 3 * stride_c); + c4 = vld1q_f32(c_ptr + 4 * stride_c); + c5 = vld1q_f32(c_ptr + 5 * stride_c); + c6 = vld1q_f32(c_ptr + 6 * stride_c); #if defined(__aarch64__) MACE_GEMM_PART_CAL(0, 0, 1); @@ -521,48 +530,49 @@ inline void Gemm784(const float *a_ptr, #endif vst1q_f32(c_ptr, c0); - vst1q_f32(c_ptr + 1 * stride_w, c1); - vst1q_f32(c_ptr + 2 * stride_w, c2); - vst1q_f32(c_ptr + 3 * stride_w, c3); - vst1q_f32(c_ptr + 4 * stride_w, c4); - vst1q_f32(c_ptr + 5 * stride_w, c5); - vst1q_f32(c_ptr + 6 * stride_w, c6); + vst1q_f32(c_ptr + 1 * stride_c, c1); + vst1q_f32(c_ptr + 2 * stride_c, c2); + vst1q_f32(c_ptr + 3 * stride_c, c3); + vst1q_f32(c_ptr + 4 * stride_c, c4); + vst1q_f32(c_ptr + 5 * stride_c, c5); + vst1q_f32(c_ptr + 6 * stride_c, c6); #else - GemmBlock(a_ptr, b_ptr, 7, 8, 4, stride_k, stride_w, c_ptr); + GemmBlock(a_ptr, b_ptr, 7, 8, 4, stride_a, stride_b, stride_c, c_ptr); #endif } inline void GemmX84(const float *a_ptr, const float *b_ptr, - index_t stride_k, - index_t stride_w, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, float *c_ptr, int row) { switch (row) { case 1: - Gemm184(a_ptr, b_ptr, stride_k, stride_w, c_ptr); + Gemm184(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); break; case 2: - Gemm284(a_ptr, b_ptr, stride_k, stride_w, c_ptr); + Gemm284(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); break; case 3: - Gemm384(a_ptr, b_ptr, stride_k, stride_w, c_ptr); + Gemm384(a_ptr, b_ptr, stride_a, 
stride_b, stride_c, c_ptr); break; case 4: - Gemm484(a_ptr, b_ptr, stride_k, stride_w, c_ptr); + Gemm484(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); break; case 5: - Gemm584(a_ptr, b_ptr, stride_k, stride_w, c_ptr); + Gemm584(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); break; case 6: - Gemm684(a_ptr, b_ptr, stride_k, stride_w, c_ptr); + Gemm684(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); break; case 7: - Gemm784(a_ptr, b_ptr, stride_k, stride_w, c_ptr); + Gemm784(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); break; case 8: - Gemm884(a_ptr, b_ptr, stride_k, stride_w, c_ptr); + Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); break; default: MACE_NOT_IMPLEMENTED; @@ -574,14 +584,15 @@ inline void GemmTile(const float *A, const index_t height, const index_t K, const index_t width, - const index_t stride_k, - const index_t stride_w, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, float *C) { #if defined(MACE_ENABLE_NEON) index_t h, w, k; for (h = 0; h < height - 7; h += 8) { for (k = 0; k < K - 7; k += 8) { - const float *a_ptr = A + (h * stride_k + k); + const float *a_ptr = A + (h * stride_a + k); #if defined(__aarch64__) && defined(__clang__) int nw = width >> 2; if (nw > 0) { @@ -590,38 +601,38 @@ inline void GemmTile(const float *A, a14, a15; a0 = vld1q_f32(a_ptr); a1 = vld1q_f32(a_ptr + 4); - a2 = vld1q_f32(a_ptr + 1 * stride_k); - a3 = vld1q_f32(a_ptr + 1 * stride_k + 4); - a4 = vld1q_f32(a_ptr + 2 * stride_k); - a5 = vld1q_f32(a_ptr + 2 * stride_k + 4); - a6 = vld1q_f32(a_ptr + 3 * stride_k); - a7 = vld1q_f32(a_ptr + 3 * stride_k + 4); - a8 = vld1q_f32(a_ptr + 4 * stride_k); - a9 = vld1q_f32(a_ptr + 4 * stride_k + 4); - a10 = vld1q_f32(a_ptr + 5 * stride_k); - a11 = vld1q_f32(a_ptr + 5 * stride_k + 4); - a12 = vld1q_f32(a_ptr + 6 * stride_k); - a13 = vld1q_f32(a_ptr + 6 * stride_k + 4); - a14 = vld1q_f32(a_ptr + 7 * stride_k); - a15 = vld1q_f32(a_ptr + 7 * stride_k + 4); - - const float *b_ptr0 = B + k * stride_w; - const float *b_ptr1 = B + (k + 1) * stride_w; - const float *b_ptr2 = B + (k + 2) * stride_w; - const float *b_ptr3 = B + (k + 3) * stride_w; - const float *b_ptr4 = B + (k + 4) * stride_w; - const float *b_ptr5 = B + (k + 5) * stride_w; - const float *b_ptr6 = B + (k + 6) * stride_w; - const float *b_ptr7 = B + (k + 7) * stride_w; - - float *c_ptr0 = C + h * stride_w; - float *c_ptr1 = C + (h + 1) * stride_w; - float *c_ptr2 = C + (h + 2) * stride_w; - float *c_ptr3 = C + (h + 3) * stride_w; - float *c_ptr4 = C + (h + 4) * stride_w; - float *c_ptr5 = C + (h + 5) * stride_w; - float *c_ptr6 = C + (h + 6) * stride_w; - float *c_ptr7 = C + (h + 7) * stride_w; + a2 = vld1q_f32(a_ptr + 1 * stride_a); + a3 = vld1q_f32(a_ptr + 1 * stride_a + 4); + a4 = vld1q_f32(a_ptr + 2 * stride_a); + a5 = vld1q_f32(a_ptr + 2 * stride_a + 4); + a6 = vld1q_f32(a_ptr + 3 * stride_a); + a7 = vld1q_f32(a_ptr + 3 * stride_a + 4); + a8 = vld1q_f32(a_ptr + 4 * stride_a); + a9 = vld1q_f32(a_ptr + 4 * stride_a + 4); + a10 = vld1q_f32(a_ptr + 5 * stride_a); + a11 = vld1q_f32(a_ptr + 5 * stride_a + 4); + a12 = vld1q_f32(a_ptr + 6 * stride_a); + a13 = vld1q_f32(a_ptr + 6 * stride_a + 4); + a14 = vld1q_f32(a_ptr + 7 * stride_a); + a15 = vld1q_f32(a_ptr + 7 * stride_a + 4); + + const float *b_ptr0 = B + k * stride_b; + const float *b_ptr1 = B + (k + 1) * stride_b; + const float *b_ptr2 = B + (k + 2) * stride_b; + const float *b_ptr3 = B + (k + 3) * stride_b; + const float *b_ptr4 = B + (k + 4) * stride_b; + const float *b_ptr5 = B + (k + 
5) * stride_b; + const float *b_ptr6 = B + (k + 6) * stride_b; + const float *b_ptr7 = B + (k + 7) * stride_b; + + float *c_ptr0 = C + h * stride_c; + float *c_ptr1 = C + (h + 1) * stride_c; + float *c_ptr2 = C + (h + 2) * stride_c; + float *c_ptr3 = C + (h + 3) * stride_c; + float *c_ptr4 = C + (h + 4) * stride_c; + float *c_ptr5 = C + (h + 5) * stride_c; + float *c_ptr6 = C + (h + 6) * stride_c; + float *c_ptr7 = C + (h + 7) * stride_c; asm volatile( "prfm pldl1keep, [%9, #128] \n" @@ -824,53 +835,68 @@ inline void GemmTile(const float *A, } #else // gcc || armv7a for (w = 0; w + 3 < width; w += 4) { - const float *b_ptr = B + (k * stride_w + w); - float *c_ptr = C + (h * stride_w + w); - Gemm884(a_ptr, b_ptr, stride_k, stride_w, c_ptr); + const float *b_ptr = B + (k * stride_b + w); + float *c_ptr = C + (h * stride_c + w); + Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); } #endif // clang && armv8a if (w < width) { - const float *b_ptr = B + (k * stride_w + w); - float *c_ptr = C + (h * stride_w + w); - GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_k, stride_w, c_ptr); + const float *b_ptr = B + (k * stride_b + w); + float *c_ptr = C + (h * stride_c + w); + GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_a, stride_b, stride_c, + c_ptr); } } if (k < K) { - const float *a_ptr = A + (h * stride_k + k); - const float *b_ptr = B + k * stride_w; - float *c_ptr = C + h * stride_w; - GemmBlock(a_ptr, b_ptr, 8, K - k, width, stride_k, stride_w, c_ptr); + const float *a_ptr = A + (h * stride_a + k); + const float *b_ptr = B + k * stride_b; + float *c_ptr = C + h * stride_c; + GemmBlock(a_ptr, b_ptr, 8, K - k, width, stride_a, stride_b, stride_c, + c_ptr); } } if (h < height) { index_t remain_h = height - h; for (k = 0; k < K - 7; k += 8) { - const float *a_ptr = A + (h * stride_k + k); + const float *a_ptr = A + (h * stride_a + k); index_t w; for (w = 0; w + 3 < width; w += 4) { - const float *b_ptr = B + (k * stride_w + w); - float *c_ptr = C + (h * stride_w + w); - GemmX84(a_ptr, b_ptr, stride_k, stride_w, c_ptr, remain_h); + const float *b_ptr = B + (k * stride_b + w); + float *c_ptr = C + (h * stride_c + w); + GemmX84(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr, remain_h); } if (w < width) { - const float *b_ptr = B + (k * stride_w + w); - float *c_ptr = C + (h * stride_w + w); - GemmBlock(a_ptr, b_ptr, remain_h, 8, width - w, stride_k, stride_w, - c_ptr); + const float *b_ptr = B + (k * stride_b + w); + float *c_ptr = C + (h * stride_c + w); + GemmBlock(a_ptr, b_ptr, remain_h, 8, width - w, stride_a, stride_b, + stride_c, c_ptr); } } if (k < K) { - const float *a_ptr = A + (h * stride_k + k); - const float *b_ptr = B + k * stride_w; - float *c_ptr = C + h * stride_w; - GemmBlock(a_ptr, b_ptr, remain_h, K - k, width, stride_k, stride_w, - c_ptr); + const float *a_ptr = A + (h * stride_a + k); + const float *b_ptr = B + k * stride_b; + float *c_ptr = C + h * stride_c; + GemmBlock(a_ptr, b_ptr, remain_h, K - k, width, stride_a, stride_b, + stride_c, c_ptr); } } #else - GemmBlock(A, B, height, K, width, stride_k, stride_w, C); + GemmBlock(A, B, height, K, width, stride_a, stride_b, stride_c, C); #endif // MACE_ENABLE_NEON } + +void Transpose(const float *src, + index_t height, + index_t width, + index_t stride_w, + float *dst) { + for (index_t h = 0; h < height; ++h) { + for (index_t w = 0; w < width; ++w) { + dst[w * height + h] = src[h * stride_w + w]; + } + } +} + } // namespace // A: height x K, B: K x width, C: height x width @@ -880,7 +906,9 @@ void Gemm(const float 
*A, const index_t height, const index_t K, const index_t width, - float *C) { + float *C, + const bool transpose_a, + const bool transpose_b) { if (width == 1) { for (index_t b = 0; b < batch; ++b) { Gemv(A + b * height * K, B + b * K, 1, K, height, C + b * height); @@ -898,41 +926,77 @@ void Gemm(const float *A, const index_t block_tile_height = RoundUpDiv(height, block_size); const index_t block_tile_width = RoundUpDiv(width, block_size); const index_t block_tile_k = RoundUpDiv(K, block_size); + const index_t block_tile[3] = {block_tile_height, block_tile_width, + block_tile_k}; const index_t remain_height = height % block_size; const index_t remain_width = width % block_size; const index_t remain_k = K % block_size; + const index_t remain[3] = {remain_height, remain_width, remain_k}; #pragma omp parallel for collapse(3) for (index_t n = 0; n < batch; ++n) { - for (index_t bh = 0; bh < block_tile_height; ++bh) { - for (index_t bw = 0; bw < block_tile_width; ++bw) { + for (index_t bh = 0; bh < block_tile[0]; ++bh) { + for (index_t bw = 0; bw < block_tile[1]; ++bw) { const float *a_base = A + n * height * K; const float *b_base = B + n * K * width; float *c_base = C + n * height * width; const index_t ih_begin = bh * block_size; const index_t ih_end = - bh * block_size + (bh == block_tile_height - 1 && remain_height > 0 - ? remain_height - : block_size); + bh * block_size + + (bh == block_tile[0] - 1 && remain[0] > 0 ? remain[0] : block_size); const index_t iw_begin = bw * block_size; const index_t iw_end = - bw * block_size + (bw == block_tile_width - 1 && remain_width > 0 - ? remain_width - : block_size); + bw * block_size + + (bw == block_tile[1] - 1 && remain[1] > 0 ? remain[1] : block_size); - for (index_t bk = 0; bk < block_tile_k; ++bk) { + for (index_t bk = 0; bk < block_tile[2]; ++bk) { const index_t ik_begin = bk * block_size; const index_t ik_end = - bk * block_size + - (bk == block_tile_k - 1 && remain_k > 0 ? remain_k : block_size); + bk * block_size + (bk == block_tile[2] - 1 && remain[2] > 0 + ? 
remain[2] + : block_size); + + Tensor trans_a; + Tensor trans_b; + const float *real_a = nullptr; + const float *real_b = nullptr; + float *real_c = c_base + (ih_begin * width + iw_begin); + index_t stride_a; + index_t stride_b; + index_t stride_c = width; + + if (transpose_a) { + trans_a.Resize({block_size, block_size}); + float *trans_a_data = trans_a.mutable_data(); + // A[K, H] -> A[H, K] + Transpose(a_base + (ik_begin * height + ih_begin), + ik_end - ik_begin, ih_end - ih_begin, height, + trans_a_data); + real_a = trans_a_data; + stride_a = ik_end - ik_begin; + } else { + real_a = a_base + (ih_begin * K + ik_begin); + stride_a = K; + } + + if (transpose_b) { + trans_b.Resize({block_size, block_size}); + float *trans_b_data = trans_b.mutable_data(); + // B[W, K] -> B[K, W] + Transpose(b_base + (iw_begin * K + ik_begin), iw_end - iw_begin, + ik_end - ik_begin, K, trans_b_data); + real_b = trans_b_data; + stride_b = iw_end - iw_begin; + } else { + real_b = b_base + (ik_begin * width + iw_begin); + stride_b = width; + } // inside block: // calculate C[bh, bw] += A[bh, bk] * B[bk, bw] for one k - GemmTile(a_base + (ih_begin * K + ik_begin), - b_base + (ik_begin * width + iw_begin), ih_end - ih_begin, - ik_end - ik_begin, iw_end - iw_begin, K, width, - c_base + (ih_begin * width + iw_begin)); + GemmTile(real_a, real_b, ih_end - ih_begin, ik_end - ik_begin, + iw_end - iw_begin, stride_a, stride_b, stride_c, real_c); } // bk } // bw } // bh @@ -946,14 +1010,47 @@ void GemmRef(const float *A, const index_t height, const index_t K, const index_t width, - float *C) { + float *C, + const bool transpose_a, + const bool transpose_b) { memset(C, 0, sizeof(float) * batch * height * width); + + Tensor trans_a; + Tensor trans_b; + float *trans_a_data = nullptr; + float *trans_b_data = nullptr; + if (transpose_a) { + trans_a.Resize({height, K}); + trans_a_data = trans_a.mutable_data(); + } + if (transpose_b) { + trans_b.Resize({K, width}); + trans_b_data = trans_b.mutable_data(); + } + for (index_t b = 0; b < batch; ++b) { + const float *real_a = nullptr; + const float *real_b = nullptr; + float *real_c = C + b * height * width; + if (transpose_a) { + // A[K, H] -> A[H, K] + Transpose(A + b * height * K, K, height, height, trans_a_data); + real_a = trans_a_data; + } else { + real_a = A + b * height * K; + } + if (transpose_b) { + // B[W, K] -> B[K, W] + Transpose(B + b * width * K, width, K, K, trans_b_data); + real_b = trans_b_data; + } else { + real_b = B + b * width * K; + } + for (index_t i = 0; i < height; ++i) { for (index_t j = 0; j < width; ++j) { for (index_t k = 0; k < K; ++k) { - C[(b * height + i) * width + j] += - A[(b * height + i) * K + k] * B[(b * K + k) * width + j]; + real_c[i * width + j] += real_a[i * K + k] * real_b[k * width + j]; } } } diff --git a/mace/kernels/gemm.h b/mace/kernels/gemm.h index 9a7ce77bcab2138bc98bcf9863cf9e2e146f4637..ade517a08891270458d2eb7dfb20b114caf4f19c 100644 --- a/mace/kernels/gemm.h +++ b/mace/kernels/gemm.h @@ -30,7 +30,9 @@ void Gemm(const float *A, const index_t height, const index_t K, const index_t width, - float *C); + float *C, + const bool transpose_a = false, + const bool transpose_b = false); void GemmRef(const float *A, const float *B, @@ -38,7 +40,9 @@ void GemmRef(const float *A, const index_t height, const index_t K, const index_t width, - float *C); + float *C, + const bool transpose_a = false, + const bool transpose_b = false); void Gemv(const float *m_ptr, const float *v_ptr, diff --git a/mace/kernels/gemm_test.cc 
b/mace/kernels/gemm_test.cc index 90a792ef236455617107b311c7b2eea7c7d56aa0..00c4ee0278c12c429e57c1a620d4bf6a275afab2 100644 --- a/mace/kernels/gemm_test.cc +++ b/mace/kernels/gemm_test.cc @@ -13,17 +13,22 @@ // limitations under the License. #include -#include #include +#include -#include "mace/kernels/gemm.h" #include "mace/core/types.h" +#include "mace/kernels/gemm.h" namespace mace { namespace { -void GemmTest(index_t batch, index_t N, index_t K, index_t M) { +void GemmTest(index_t batch, + index_t N, + index_t K, + index_t M, + bool transpose_a, + bool transpose_b) { std::unique_ptr A(new float[batch * N * K]); std::unique_ptr B(new float[batch * K * M]); std::unique_ptr C(new float[batch * N * M]); @@ -34,15 +39,13 @@ void GemmTest(index_t batch, index_t N, index_t K, index_t M) { std::normal_distribution nd(0, 1); std::generate(A.get(), A.get() + batch * N * K, - [&gen, &nd] { - return nd(gen); - }); + [&gen, &nd] { return nd(gen); }); std::generate(B.get(), B.get() + batch * K * M, - [&gen, &nd] { - return nd(gen); - }); - kernels::Gemm(A.get(), B.get(), batch, N, K, M, C.get()); - kernels::GemmRef(A.get(), B.get(), batch, N, K, M, C_ref.get()); + [&gen, &nd] { return nd(gen); }); + kernels::Gemm(A.get(), B.get(), batch, N, K, M, C.get(), transpose_a, + transpose_b); + kernels::GemmRef(A.get(), B.get(), batch, N, K, M, C_ref.get(), transpose_a, + transpose_b); for (int i = 0; i < batch * N * M; ++i) { EXPECT_NEAR(C_ref[i], C[i], 0.1); @@ -59,14 +62,8 @@ void GemvTest(index_t batch, index_t N, index_t M) { std::mt19937 gen(rd()); std::normal_distribution nd(0, 1); - std::generate(A.get(), A.get() + N * M, - [&gen, &nd] { - return nd(gen); - }); - std::generate(B.get(), B.get() + batch * M, - [&gen, &nd] { - return nd(gen); - }); + std::generate(A.get(), A.get() + N * M, [&gen, &nd] { return nd(gen); }); + std::generate(B.get(), B.get() + batch * M, [&gen, &nd] { return nd(gen); }); kernels::Gemv(A.get(), B.get(), batch, M, N, C.get()); kernels::GemvRef(A.get(), B.get(), batch, M, N, C_ref.get()); @@ -78,36 +75,36 @@ void GemvTest(index_t batch, index_t N, index_t M) { } // namespace TEST(GEMMTest, AlignedWithoutBatch) { - GemmTest(1, 1, 64, 128); - GemmTest(1, 2, 64, 128); - GemmTest(1, 3, 64, 128); - GemmTest(1, 4, 64, 128); - GemmTest(1, 5, 64, 128); - GemmTest(1, 6, 64, 128); - GemmTest(1, 7, 64, 128); - GemmTest(1, 17, 64, 128); + GemmTest(1, 1, 64, 128, false, false); + GemmTest(1, 2, 64, 128, false, true); + GemmTest(1, 3, 64, 128, true, false); + GemmTest(1, 4, 64, 128, true, true); + GemmTest(1, 5, 64, 128, false, false); + GemmTest(1, 6, 64, 128, false, true); + GemmTest(1, 7, 64, 128, true, false); + GemmTest(1, 17, 64, 128, true, true); } TEST(GEMMTest, UnalignedWithoutBatch) { - GemmTest(1, 1, 63, 127); - GemmTest(1, 2, 63, 127); - GemmTest(1, 3, 63, 127); - GemmTest(1, 4, 63, 127); - GemmTest(1, 5, 63, 127); - GemmTest(1, 6, 63, 127); - GemmTest(1, 7, 63, 127); - GemmTest(1, 17, 63, 127); + GemmTest(1, 1, 63, 127, false, false); + GemmTest(1, 2, 63, 127, false, true); + GemmTest(1, 3, 63, 127, true, false); + GemmTest(1, 4, 63, 127, true, true); + GemmTest(1, 5, 63, 127, false, false); + GemmTest(1, 6, 63, 127, false, true); + GemmTest(1, 7, 63, 127, true, false); + GemmTest(1, 17, 63, 127, true, true); } TEST(GEMMTest, UnalignedWithBatch) { - GemmTest(3, 1, 63, 127); - GemmTest(3, 2, 63, 127); - GemmTest(3, 3, 63, 127); - GemmTest(3, 4, 63, 127); - GemmTest(3, 5, 63, 127); - GemmTest(3, 6, 63, 127); - GemmTest(3, 7, 63, 127); - GemmTest(3, 17, 63, 127); + GemmTest(3, 1, 
63, 127, false, false); + GemmTest(3, 2, 63, 127, false, true); + GemmTest(3, 3, 63, 127, true, false); + GemmTest(3, 4, 63, 127, true, true); + GemmTest(3, 5, 63, 127, false, false); + GemmTest(3, 6, 63, 127, false, true); + GemmTest(3, 7, 63, 127, true, false); + GemmTest(3, 17, 63, 127, true, true); } TEST(GEMMTest, gemv) { diff --git a/mace/kernels/matmul.h b/mace/kernels/matmul.h index dee53c4a6b09cfd42bc8b2f09eb3a2f4f1a9773d..7b54aa1b6ddaf97186f8b1aeb97efbb9cc0ff245 100644 --- a/mace/kernels/matmul.h +++ b/mace/kernels/matmul.h @@ -20,6 +20,8 @@ #endif #include +#include +#include #include #include #include @@ -36,14 +38,39 @@ namespace mace { namespace kernels { -template +template struct MatMulFunctor { MaceStatus operator()(const Tensor *A, - const Tensor *B, - Tensor *C, - StatsFuture *future) { + const Tensor *B, + Tensor *C, + bool transpose_a, + bool transpose_b, + StatsFuture *future) { MACE_UNUSED(future); - std::vector c_shape = {A->dim(0), A->dim(1), B->dim(2), 1}; + + index_t batch; + index_t height; + index_t K; + index_t width; + + index_t rank = A->dim_size(); + height = A->dim(rank - 2); + K = A->dim(rank - 1); + if (transpose_a) { + std::swap(height, K); + } + if (transpose_b) { + width = B->dim(rank - 2); + } else { + width = B->dim(rank - 1); + } + batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1, + std::multiplies()); + + std::vector c_shape = A->shape(); + c_shape[rank - 2] = height; + c_shape[rank - 1] = width; + MACE_RETURN_IF_ERROR(C->Resize(c_shape)); Tensor::MappingGuard guarda(A); @@ -53,28 +80,27 @@ struct MatMulFunctor { const T *b_ptr_base = B->data(); T *c_ptr_base = C->mutable_data(); - const index_t batch = C->dim(0); - const index_t height = C->dim(1); - const index_t width = C->dim(2); - const index_t K = A->dim(2); // It is better to use large block size if it fits for fast cache. // Assume l1 cache size is 32k, we load three blocks at a time (A, B, C), // the block size should be sqrt(32k / sizeof(T) / 3). 
memset(c_ptr_base, 0, batch * height * width * sizeof(T)); - Gemm(a_ptr_base, b_ptr_base, batch, height, K, width, c_ptr_base); + Gemm(a_ptr_base, b_ptr_base, batch, height, K, width, c_ptr_base, + transpose_a, transpose_b); return MACE_SUCCESS; } }; #ifdef MACE_ENABLE_OPENCL -template +template struct MatMulFunctor { MaceStatus operator()(const Tensor *A, - const Tensor *B, - Tensor *C, - StatsFuture *future); + const Tensor *B, + Tensor *C, + bool transpose_a, + bool transpose_b, + StatsFuture *future); cl::Kernel kernel_; uint32_t kwg_size_; diff --git a/mace/kernels/opencl/buffer_to_image.cc b/mace/kernels/opencl/buffer_to_image.cc index c21d67a1f3531f63f8027cdc0f47ec8f5a033f43..14cc931275e95464d3a488695013a65e512a5588 100644 --- a/mace/kernels/opencl/buffer_to_image.cc +++ b/mace/kernels/opencl/buffer_to_image.cc @@ -134,7 +134,11 @@ MaceStatus BufferToImageFunctor::operator()( } else { b2f_kernel.setArg(idx++, static_cast(buffer->dim(1))); b2f_kernel.setArg(idx++, static_cast(buffer->dim(2))); - b2f_kernel.setArg(idx++, static_cast(buffer->dim(3))); + if (buffer->dim_size() < 4) { + b2f_kernel.setArg(idx++, static_cast(1)); + } else { + b2f_kernel.setArg(idx++, static_cast(buffer->dim(3))); + } } b2f_kernel.setArg(idx++, *(image->opencl_image())); diff --git a/mace/kernels/opencl/helper.cc b/mace/kernels/opencl/helper.cc index fff04eb1c53353f011e3acaf9ab4d22976a6080b..f01cd0a6a7ae41e8c229ed8696322b7bc2322b0e 100644 --- a/mace/kernels/opencl/helper.cc +++ b/mace/kernels/opencl/helper.cc @@ -76,19 +76,27 @@ void CalWinogradFilterImageShape( // [W * C, N * RoundUp<4>(H)] void CalInOutHeightImageShape(const std::vector &shape, /* NHWC */ std::vector *image_shape) { - MACE_CHECK(shape.size() == 4); + std::vector padded_shape = shape; + while (padded_shape.size() < 4) { + padded_shape.push_back(1); + } + MACE_CHECK(padded_shape.size() == 4); image_shape->resize(2); - (*image_shape)[0] = shape[2] * shape[3]; - (*image_shape)[1] = shape[0] * RoundUpDiv4(shape[1]); + (*image_shape)[0] = padded_shape[2] * padded_shape[3]; + (*image_shape)[1] = padded_shape[0] * RoundUpDiv4(padded_shape[1]); } // [RoundUp<4>(W) * C, N * H] void CalInOutWidthImageShape(const std::vector &shape, /* NHWC */ std::vector *image_shape) { - MACE_CHECK(shape.size() == 4); + std::vector padded_shape = shape; + while (padded_shape.size() < 4) { + padded_shape.push_back(1); + } + MACE_CHECK(padded_shape.size() == 4); image_shape->resize(2); - (*image_shape)[0] = RoundUpDiv4(shape[2]) * shape[3]; - (*image_shape)[1] = shape[0] * shape[1]; + (*image_shape)[0] = RoundUpDiv4(padded_shape[2]) * padded_shape[3]; + (*image_shape)[1] = padded_shape[0] * padded_shape[1]; } // [Ic * H * W, (Oc + 3) / 4] @@ -150,10 +158,10 @@ void CalImage2DShape(const std::vector &shape, /* NHWC */ std::vector CalWinogradShape(const std::vector &shape, const BufferType type) { if (type == WINOGRAD_FILTER) { - return {16, shape[0], shape[1], 1}; + return {16, shape[0], shape[1]}; } else if (type == IN_OUT_HEIGHT) { index_t out_width = shape[0] * ((shape[1] - 1) / 2) * ((shape[2] - 1) / 2); - return {16, shape[3], out_width, 1}; + return {16, shape[3], out_width}; } else { LOG(FATAL) << "Mace not supported yet."; return std::vector(); diff --git a/mace/kernels/opencl/image_to_buffer.cc b/mace/kernels/opencl/image_to_buffer.cc index 132b146c5e8e5e844df782819982088c8f24bd5e..e22c6e31664719cce6e2c8310a11591de2219954 100644 --- a/mace/kernels/opencl/image_to_buffer.cc +++ b/mace/kernels/opencl/image_to_buffer.cc @@ -122,7 +122,11 @@ MaceStatus 
ImageToBufferFunctor::operator()( } else { b2f_kernel.setArg(idx++, static_cast(buffer->dim(1))); b2f_kernel.setArg(idx++, static_cast(buffer->dim(2))); - b2f_kernel.setArg(idx++, static_cast(buffer->dim(3))); + if (buffer->dim_size() < 4) { + b2f_kernel.setArg(idx++, static_cast(1)); + } else { + b2f_kernel.setArg(idx++, static_cast(buffer->dim(3))); + } } b2f_kernel.setArg(idx++, *(image->opencl_image())); diff --git a/mace/kernels/opencl/matmul.cc b/mace/kernels/opencl/matmul.cc index eb7f0e53bcf2228c7711270ae7f632b3dc55edb9..e222ae8d6d6d62ac442663c632baaadd00c533a1 100644 --- a/mace/kernels/opencl/matmul.cc +++ b/mace/kernels/opencl/matmul.cc @@ -24,17 +24,27 @@ template MaceStatus MatMulFunctor::operator()(const Tensor *A, const Tensor *B, Tensor *C, + bool transpose_a, + bool transpose_b, StatsFuture *future) { MACE_UNUSED(future); - std::vector c_shape = {A->dim(0), A->dim(1), B->dim(2), 1}; + MACE_CHECK(!transpose_a && !transpose_b, + "GPU does not support transpose matmul"); + + index_t rank = A->dim_size(); + index_t height = A->dim(rank - 2); + index_t K = A->dim(rank - 1); + index_t width = B->dim(rank - 1); + index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1, + std::multiplies()); + + std::vector c_shape = A->shape(); + c_shape[rank - 2] = height; + c_shape[rank - 1] = width; std::vector c_image_shape; CalImage2DShape(c_shape, BufferType::IN_OUT_HEIGHT, &c_image_shape); MACE_RETURN_IF_ERROR(C->ResizeImage(c_shape, c_image_shape)); - const index_t batch = C->dim(0); - const index_t height = C->dim(1); - const index_t width = C->dim(2); - const index_t height_blocks = RoundUpDiv4(height); const index_t width_blocks = RoundUpDiv4(width); const uint32_t gws[2] = { @@ -82,13 +92,12 @@ MaceStatus MatMulFunctor::operator()(const Tensor *A, kernel_.setArg(idx++, *(C->opencl_image())); kernel_.setArg(idx++, static_cast(height)); kernel_.setArg(idx++, static_cast(width)); - kernel_.setArg(idx++, static_cast(A->dim(2))); + kernel_.setArg(idx++, static_cast(K)); kernel_.setArg(idx++, static_cast(height_blocks)); - kernel_.setArg(idx++, static_cast(RoundUpDiv4(A->dim(2)))); + kernel_.setArg(idx++, static_cast(RoundUpDiv4(K))); const std::vector lws = {kwg_size_ / 64, 64, 0}; - std::string tuning_key = Concat("matmul_opencl_kernel", C->dim(0), C->dim(1), - C->dim(2), C->dim(3)); + std::string tuning_key = Concat("matmul_opencl_kernel", batch, height, width); TuningOrRun2DKernel(kernel_, tuning_key, gws, lws, future); if (runtime->IsOutOfRangeCheckEnabled()) { diff --git a/mace/kernels/opencl/winograd_transform.cc b/mace/kernels/opencl/winograd_transform.cc index 91f0ad0798e8b792cf2a416fa5943221f6e46c06..1dfe5f27ca2a42955e49fcd9ba00fe718d40a8a7 100644 --- a/mace/kernels/opencl/winograd_transform.cc +++ b/mace/kernels/opencl/winograd_transform.cc @@ -74,7 +74,7 @@ MaceStatus WinogradTransformFunctor::operator()( static_cast(RoundUpDiv4(input_tensor->dim(3)))}; if (!IsVecEqual(input_shape_, input_tensor->shape())) { - output_shape = {16, input_tensor->dim(3), out_width, 1}; + output_shape = {16, input_tensor->dim(3), out_width}; std::vector image_shape; CalImage2DShape(output_shape, BufferType::IN_OUT_HEIGHT, &image_shape); MACE_RETURN_IF_ERROR(output_tensor->ResizeImage(output_shape, image_shape)); @@ -104,7 +104,7 @@ MaceStatus WinogradTransformFunctor::operator()( const std::vector lws = {kwg_size_ / 8, 8, 0}; std::string tuning_key = Concat("winograd_transform_kernel", output_tensor->dim(0), output_tensor->dim(1), - output_tensor->dim(2), output_tensor->dim(3)); + 
output_tensor->dim(2)); TuningOrRun2DKernel(kernel_, tuning_key, gws, lws, future); if (runtime->IsOutOfRangeCheckEnabled()) { diff --git a/mace/ops/matmul.h b/mace/ops/matmul.h index 10a4357868aa845b36f67bda9d9d29bbed803ff2..e5e0dafaafdf547817727dad8079373858406dc6 100644 --- a/mace/ops/matmul.h +++ b/mace/ops/matmul.h @@ -25,24 +25,37 @@ template class MatMulOp : public Operator { public: MatMulOp(const OperatorDef &operator_def, Workspace *ws) - : Operator(operator_def, ws) {} + : Operator(operator_def, ws), + transpose_a_(OperatorBase::GetOptionalArg("transpose_a", false)), + transpose_b_(OperatorBase::GetOptionalArg("transpose_b", false)) { + } MaceStatus Run(StatsFuture *future) override { - const Tensor *A = this->Input(0); - const Tensor *B = this->Input(1); - Tensor *C = this->Output(0); - MACE_CHECK(A->dim_size() == 4 && 4 == B->dim_size()) - << "The dimension of A and B should be 4"; - MACE_CHECK(A->dim(0) == B->dim(0)) << "A and B must have same batch size"; - MACE_CHECK(A->dim(2) == B->dim(1)) - << "the number of A's column " << A->dim(2) - << " must be equal to B's row " << B->dim(1); - - return functor_(A, B, C, future); + const Tensor *A = this->Input(INPUT_A); + const Tensor *B = this->Input(INPUT_B); + Tensor *C = this->Output(OUTPUT); + MACE_CHECK(A->dim_size() == B->dim_size() && A->dim_size() >= 2, + "rank(A) should be equal to rank(B), rank should be greater " + "than or equal to 2"); + index_t rank = A->dim_size(); + for (index_t i = 0; i < rank - 2; ++i) { + MACE_CHECK(A->dim(i) == B->dim(i), "batch dimensions are not equal"); + } + index_t ak = transpose_a_ ? A->dim(rank - 2) : A->dim(rank - 1); + index_t bk = transpose_b_ ? B->dim(rank - 1) : B->dim(rank - 2); + MACE_CHECK(ak == bk, "the number of A's column ", ak, + " must be equal to B's row ", bk); + + return functor_(A, B, C, transpose_a_, transpose_b_, future); } private: + MACE_OP_INPUT_TAGS(INPUT_A, INPUT_B); + MACE_OP_OUTPUT_TAGS(OUTPUT); + kernels::MatMulFunctor functor_; + bool transpose_a_; + bool transpose_b_; }; } // namespace ops diff --git a/mace/ops/matmul_benchmark.cc b/mace/ops/matmul_benchmark.cc index 382fdf7c8829a5887f7920a74285ca0a33178c4b..3e3327d8bf81b54050c251961062a91007c7f1c1 100644 --- a/mace/ops/matmul_benchmark.cc +++ b/mace/ops/matmul_benchmark.cc @@ -31,8 +31,8 @@ void MatMulBenchmark( OpsTestNet net; // Add input data - net.AddRandomInput("A", {batch, height, channels, 1}); - net.AddRandomInput("B", {batch, channels, out_width, 1}); + net.AddRandomInput("A", {batch, height, channels}); + net.AddRandomInput("B", {batch, channels, out_width}); if (D == DeviceType::GPU) { BufferToImage(&net, "A", "AImage", kernels::BufferType::IN_OUT_WIDTH); @@ -65,6 +65,41 @@ void MatMulBenchmark( } net.Sync(); } + +template +void MatMulTransposeBenchmark( + int iters, int batch, int height, int channels, int out_width) { + mace::testing::StopTiming(); + + OpsTestNet net; + + // Add input data + net.AddRandomInput("A", {batch, height, channels}); + net.AddRandomInput("B", {batch, out_width, channels}); + + if (D == DeviceType::CPU) { + OpDefBuilder("MatMul", "MatMulBM") + .Input("A") + .Input("B") + .AddIntArg("transpose_b", 1) + .Output("Output") + .Finalize(net.NewOperatorDef()); + } else { + MACE_NOT_IMPLEMENTED; + } + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + } + net.Sync(); +} } // namespace #define MACE_BM_MATMUL_MACRO(N, H, C, W, TYPE, DEVICE) \ @@ -83,6 +118,20 @@ void MatMulBenchmark( 
MACE_BM_MATMUL_MACRO(N, H, C, W, float, GPU); \ MACE_BM_MATMUL_MACRO(N, H, C, W, half, GPU); +#define MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, TYPE, DEVICE) \ + static void MACE_BM_MATMUL_##T_##N##_##H##_##C##_##W##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t macc = static_cast(iters) * N * C * H * W; \ + const int64_t tot = static_cast(iters) * N * (C * H + H * W); \ + mace::testing::MaccProcessed(macc); \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + MatMulTransposeBenchmark(iters, N, H, C, W); \ + } \ + MACE_BENCHMARK(MACE_BM_MATMUL_##T_##N##_##H##_##C##_##W##_##TYPE##_##DEVICE) + +#define MACE_BM_MATMUL_TRANPOSE(N, H, C, W) \ + MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, float, CPU); + MACE_BM_MATMUL(16, 32, 128, 49); MACE_BM_MATMUL(16, 32, 128, 961); MACE_BM_MATMUL(16, 32, 128, 3969); @@ -90,6 +139,13 @@ MACE_BM_MATMUL(16, 128, 128, 49); MACE_BM_MATMUL(16, 128, 128, 961); MACE_BM_MATMUL(16, 128, 128, 3969); +MACE_BM_MATMUL_TRANPOSE(16, 32, 128, 49); +MACE_BM_MATMUL_TRANPOSE(16, 32, 128, 961); +MACE_BM_MATMUL_TRANPOSE(16, 32, 128, 3969); +MACE_BM_MATMUL_TRANPOSE(16, 128, 128, 49); +MACE_BM_MATMUL_TRANPOSE(16, 128, 128, 961); +MACE_BM_MATMUL_TRANPOSE(16, 128, 128, 3969); + } // namespace test } // namespace ops } // namespace mace diff --git a/mace/ops/matmul_test.cc b/mace/ops/matmul_test.cc index 8999fface5ed208830e37491c9c8d709255ca4fe..397b00fe686a10d38cb04d6bc0c3b3faa35525cc 100644 --- a/mace/ops/matmul_test.cc +++ b/mace/ops/matmul_test.cc @@ -72,46 +72,46 @@ void Simple(const std::vector &A_shape, } // namespace TEST_F(MatMulOpTest, SimpleCPU) { - Simple({1, 2, 3, 1}, {1, 2, 3, 4, 5, 6}, {1, 3, 2, 1}, - {1, 2, 3, 4, 5, 6}, {1, 2, 2, 1}, {22, 28, 49, 64}); + Simple({1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 3, 2}, + {1, 2, 3, 4, 5, 6}, {1, 2, 2}, {22, 28, 49, 64}); Simple( - {1, 5, 5, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, - 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}, - {1, 5, 5, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, - 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}, - {1, 5, 5, 1}, {215, 230, 245, 260, 275, 490, 530, 570, 610, - 650, 765, 830, 895, 960, 1025, 1040, 1130, 1220, - 1310, 1400, 1315, 1430, 1545, 1660, 1775}); + {1, 5, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}, + {1, 5, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}, + {1, 5, 5}, {215, 230, 245, 260, 275, 490, 530, 570, 610, + 650, 765, 830, 895, 960, 1025, 1040, 1130, 1220, + 1310, 1400, 1315, 1430, 1545, 1660, 1775}); } TEST_F(MatMulOpTest, SimpleCPUWithBatch) { - Simple({2, 2, 3, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, - {2, 3, 2, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, - {2, 2, 2, 1}, {22, 28, 49, 64, 22, 28, 49, 64}); + Simple({2, 2, 3}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, + {2, 3, 2}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, + {2, 2, 2}, {22, 28, 49, 64, 22, 28, 49, 64}); } TEST_F(MatMulOpTest, SimpleOPENCL) { - Simple({1, 2, 3, 1}, {1, 2, 3, 4, 5, 6}, {1, 3, 2, 1}, - {1, 2, 3, 4, 5, 6}, {1, 2, 2, 1}, {22, 28, 49, 64}); + Simple({1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 3, 2}, + {1, 2, 3, 4, 5, 6}, {1, 2, 2}, {22, 28, 49, 64}); Simple( - {1, 5, 5, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, - 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}, - {1, 5, 5, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, - 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}, - {1, 5, 5, 1}, {215, 230, 245, 260, 275, 490, 530, 570, 610, - 650, 765, 830, 895, 960, 1025, 1040, 1130, 1220, - 1310, 
1400, 1315, 1430, 1545, 1660, 1775}); + {1, 5, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}, + {1, 5, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}, + {1, 5, 5}, {215, 230, 245, 260, 275, 490, 530, 570, 610, + 650, 765, 830, 895, 960, 1025, 1040, 1130, 1220, + 1310, 1400, 1315, 1430, 1545, 1660, 1775}); } TEST_F(MatMulOpTest, SimpleGPUWithBatch) { - Simple({2, 2, 3, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, - {2, 3, 2, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, - {2, 2, 2, 1}, {22, 28, 49, 64, 22, 28, 49, 64}); + Simple({2, 2, 3}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, + {2, 3, 2}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, + {2, 2, 2}, {22, 28, 49, 64, 22, 28, 49, 64}); } namespace { template -void Complex(const index_t batch, +void Complex(const std::vector &batch, const index_t height, const index_t channels, const index_t out_width) { @@ -119,23 +119,14 @@ void Complex(const index_t batch, // Construct graph OpsTestNet net; - OpDefBuilder("MatMul", "MatMulTest") - .Input("A") - .Input("B") - .Output("Output") - .Finalize(net.NewOperatorDef()); // Add input data - net.AddRandomInput("A", {batch, height, channels, 1}); - net.AddRandomInput("B", - {batch, channels, out_width, 1}); - - // run cpu - net.RunOp(); - - // Check - Tensor expected; - expected.Copy(*net.GetOutput("Output")); + index_t batch_count = std::accumulate(batch.begin(), batch.end(), 1, + std::multiplies()); + net.AddRandomInput("A", + {batch_count, height, channels}); + net.AddRandomInput( + "B", {batch_count, channels, out_width}); // Run on opencl BufferToImage(&net, "A", "AImage", @@ -150,11 +141,40 @@ void Complex(const index_t batch, .AddIntArg("T", static_cast(DataTypeToEnum::value)) .Finalize(net.NewOperatorDef()); - // Run on opencl net.RunOp(DeviceType::GPU); ImageToBuffer(&net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT_HEIGHT); + + // run cpu + std::vector shape_a = batch; + shape_a.push_back(height); + shape_a.push_back(channels); + std::vector shape_b = batch; + shape_b.push_back(channels); + shape_b.push_back(out_width); + std::vector expected_output_shape = batch; + expected_output_shape.push_back(height); + expected_output_shape.push_back(out_width); + + net.GetTensor("A")->Reshape(shape_a); + net.GetTensor("B")->Reshape(shape_b); + + OpDefBuilder("MatMul", "MatMulTest") + .Input("A") + .Input("B") + .Output("Output") + .Finalize(net.NewOperatorDef()); + + net.RunOp(); + + // Check + EXPECT_EQ(expected_output_shape, net.GetOutput("Output")->shape()); + + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + expected.Reshape({batch_count, height, out_width}); + if (DataTypeToEnum::value == DataType::DT_HALF) { ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 1e-2, 1e-1); @@ -166,28 +186,36 @@ void Complex(const index_t batch, } // namespace TEST_F(MatMulOpTest, OPENCLAlignedWithoutBatch) { - Complex(1, 64, 128, 32); - Complex(1, 64, 32, 128); + Complex({1}, 64, 128, 32); + Complex({1}, 64, 32, 128); + Complex({2, 3}, 64, 32, 128); } TEST_F(MatMulOpTest, OPENCLUnAlignedWithoutBatch) { - Complex(1, 31, 113, 61); - Complex(1, 113, 31, 73); + Complex({1}, 31, 113, 61); + Complex({1}, 113, 31, 73); + Complex({2, 3}, 113, 31, 73); } TEST_F(MatMulOpTest, OPENCLUnAlignedWithBatch) { - Complex(2, 3, 3, 3); - Complex(16, 31, 61, 67); - Complex(31, 31, 61, 67); + Complex({2}, 3, 3, 3); + Complex({16}, 31, 61, 67); + Complex({31}, 31, 61, 67); + Complex({2, 3}, 31, 61, 67); } 
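The Complex({2, 3}, ...) cases above exercise the new rank-N contract in MatMulOp::Run and MatMulFunctor: A and B must have the same rank (>= 2), equal leading batch dimensions, and inner dimensions that match after the optional transposes. A standalone sketch of that shape rule follows; the function name, container alias, and error handling are illustrative only, not part of the patch:

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

using Shape = std::vector<int64_t>;

// A: [batch..., H, K]  (or [batch..., K, H] if transpose_a)
// B: [batch..., K, W]  (or [batch..., W, K] if transpose_b)
// C: [batch..., H, W]
Shape MatMulOutputShape(const Shape &a, const Shape &b,
                        bool transpose_a, bool transpose_b) {
  if (a.size() != b.size() || a.size() < 2)
    throw std::invalid_argument("rank(A) must equal rank(B) and be >= 2");
  const size_t rank = a.size();
  for (size_t i = 0; i + 2 < rank; ++i)   // leading dims are the batch
    if (a[i] != b[i]) throw std::invalid_argument("batch dimensions differ");
  const int64_t height = transpose_a ? a[rank - 1] : a[rank - 2];
  const int64_t ak     = transpose_a ? a[rank - 2] : a[rank - 1];
  const int64_t bk     = transpose_b ? b[rank - 1] : b[rank - 2];
  const int64_t width  = transpose_b ? b[rank - 2] : b[rank - 1];
  if (ak != bk)
    throw std::invalid_argument("A's columns must equal B's rows");
  Shape c = a;
  c[rank - 2] = height;
  c[rank - 1] = width;
  return c;
}

// MatMulOutputShape({2, 3, 31, 61}, {2, 3, 61, 67}, false, false) -> {2, 3, 31, 67},
// which is the shape the Complex({2, 3}, 31, 61, 67) cases check end to end.
```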
TEST_F(MatMulOpTest, OPENCLHalfAlignedWithoutBatch) { - Complex(1, 64, 128, 32); - Complex(1, 64, 32, 128); + Complex({1}, 64, 128, 32); + Complex({1}, 64, 32, 128); + Complex({2, 3}, 64, 32, 128); } TEST_F(MatMulOpTest, OPENCLHalfUnAlignedWithBatch) { - Complex(2, 31, 113, 61); - Complex(16, 32, 64, 64); - Complex(31, 31, 61, 67); + Complex({2}, 31, 113, 61); + Complex({16}, 32, 64, 64); + Complex({31}, 31, 61, 67); + Complex({2, 3}, 31, 61, 67); } +// TODO(liyin): test transpose after implementing gpu runtime +// now transpose test is in kernels_test + } // namespace test } // namespace ops } // namespace mace diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index c707e4249100922b8aed079f05615e25fe16ce03..008281a5106bd04197fc094d4ef4227d84aeb303 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -518,7 +518,7 @@ class Transformer(base_converter.ConverterInterface): wt_output_width = batch * ( (out_height + 1) / 2) * ((out_width + 1) / 2) wt_output_shape.dims.extend( - [16, in_channels, wt_output_width, 1]) + [16, in_channels, wt_output_width]) if ConverterUtil.get_arg(op, MaceKeyword.mace_padding_str) \ @@ -543,7 +543,7 @@ class Transformer(base_converter.ConverterInterface): matmul_op.output.extend([matmul_op.name]) matmul_output_shape = matmul_op.output_shape.add() matmul_output_shape.dims.extend( - [16, out_channels, wt_output_width, 1]) + [16, out_channels, wt_output_width]) arg = matmul_op.arg.add() arg.name = MaceKeyword.mace_winograd_filter_transformed diff --git a/mace/python/tools/memory_optimizer.py b/mace/python/tools/memory_optimizer.py index 9b0947e091cfe78253a8b66fc9a844c0ef3b633a..c05b94be7780c628495dbb85523e657dd55aeddc 100644 --- a/mace/python/tools/memory_optimizer.py +++ b/mace/python/tools/memory_optimizer.py @@ -167,7 +167,7 @@ class GPUMemoryOptimizer(MemoryOptimizer): def get_op_mem_block(self, op_type, output_shape): mem_block = [0, 0] if op_type == 'WinogradTransform' or op_type == 'MatMul': - mem_block[0] = output_shape[2] * output_shape[3] + mem_block[0] = output_shape[2] mem_block[1] = output_shape[0] * int((output_shape[1] + 3) / 4) else: mem_block[0] = output_shape[2] * int((output_shape[3] + 3) / 4)
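A closing note on the CPU side of this change: GemmRef (and, tile by tile, the blocked Gemm) handles transpose_a/transpose_b by transposing the affected operand into a temporary buffer and then running the usual row-major C += A * B. Below is a self-contained scalar sketch of that flow, with std::vector standing in for mace::Tensor and all names illustrative rather than part of the patch:

```cpp
#include <cstddef>
#include <vector>

// src is height x width with row stride `stride_w`; dst becomes width x height.
// Mirrors the Transpose() helper added to gemm.cc.
static void TransposeRef(const float *src, int height, int width,
                         int stride_w, float *dst) {
  for (int h = 0; h < height; ++h)
    for (int w = 0; w < width; ++w)
      dst[w * height + h] = src[h * stride_w + w];
}

// Accumulates C(height x width) += A' * B', where A' is A or A^T and B' is B
// or B^T, resolved the same way GemmRef does: transpose first, multiply second.
// The caller is expected to zero C beforehand, as GemmRef's memset does.
static void GemmRefSketch(const float *A, const float *B,
                          int height, int K, int width,
                          bool transpose_a, bool transpose_b, float *C) {
  std::vector<float> ta, tb;
  const float *a = A;
  const float *b = B;
  if (transpose_a) {  // A is stored as [K, height]; make it [height, K]
    ta.resize(static_cast<size_t>(height) * K);
    TransposeRef(A, K, height, height, ta.data());
    a = ta.data();
  }
  if (transpose_b) {  // B is stored as [width, K]; make it [K, width]
    tb.resize(static_cast<size_t>(K) * width);
    TransposeRef(B, width, K, K, tb.data());
    b = tb.data();
  }
  for (int i = 0; i < height; ++i)
    for (int j = 0; j < width; ++j)
      for (int k = 0; k < K; ++k)
        C[i * width + j] += a[i * K + k] * b[k * width + j];
}
```

The optimized Gemm() applies the same idea per block: a transposed tile lives in a compact temporary whose row stride is the tile width rather than K or width, which is why GemmTile and the GemmN84 micro-kernels now take separate stride_a, stride_b and stride_c instead of the old stride_k/stride_w pair.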