diff --git a/mace/kernels/gemm.cc b/mace/kernels/gemm.cc
index 0e05106fe0d6ef492c20f53b6afca9008445b062..2003d3ec9a2a3bb262c605d77657255d48eca68d 100644
--- a/mace/kernels/gemm.cc
+++ b/mace/kernels/gemm.cc
@@ -18,6 +18,17 @@
 #include "mace/core/tensor.h"
 #include "mace/kernels/gemm.h"
 
+/**
+ * Gemm does fast batched matrix multiplication.
+ * It is optimized for arm64-v8a and armeabi-v7a using NEON.
+ *
+ * We adopt two-level tiling to make better use of the L1 cache and registers.
+ * For register tiling, functions like GemmXYZ compute GEMM for
+ * matrix[X, Y] * matrix[Y, Z] with all data fitting in registers.
+ * For cache tiling, we try to compute one block of the multiplication with
+ * the two input matrices and the output matrix fitting in the L1 cache.
+ */
+
 #if defined(MACE_ENABLE_NEON)
 #include <arm_neon.h>
 #endif
diff --git a/mace/kernels/gemm.h b/mace/kernels/gemm.h
index ade517a08891270458d2eb7dfb20b114caf4f19c..f6ea31c41b04414d767a2b647dfaf1242600d4bd 100644
--- a/mace/kernels/gemm.h
+++ b/mace/kernels/gemm.h
@@ -21,9 +21,15 @@
 
 #include "mace/core/types.h"
 
+// The Gemm function does fast batched matrix-matrix multiplication.
+// The Gemv function does fast batched matrix-vector multiplication.
+
 namespace mace {
 namespace kernels {
 
+// Gemm calculates A[batch, height, K] dot B[batch, K, width] within each batch
+// and outputs to C[batch, height, width].
+// height, K, and width are the matrix dimension sizes after transposition (if any).
 void Gemm(const float *A,
           const float *B,
           const index_t batch,
@@ -44,6 +50,8 @@ void GemmRef(const float *A,
              const bool transpose_a = false,
              const bool transpose_b = false);
 
+// Gemv calculates M[height, width] dot V[batch, height] within each batch of V
+// and outputs to out[batch, width].
 void Gemv(const float *m_ptr,
           const float *v_ptr,
           const index_t batch,
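
For reviewers unfamiliar with the layout convention documented above, here is a minimal scalar sketch of the Gemm semantics C[batch, height, width] = A[batch, height, K] dot B[batch, K, width]. It is not MACE's optimized Gemm or its GemmRef; the helper name NaiveBatchGemm is made up for illustration, and it assumes row-major storage, no transposition, and that index_t is a signed 64-bit integer as declared in mace/core/types.h.

#include <cstdint>

using index_t = int64_t;  // assumption: matches the typedef in mace/core/types.h

// Scalar reference for the batched product described in the new header comment.
void NaiveBatchGemm(const float *A, const float *B,
                    const index_t batch, const index_t height,
                    const index_t K, const index_t width,
                    float *C) {
  for (index_t b = 0; b < batch; ++b) {
    const float *a = A + b * height * K;    // A[b] is height x K
    const float *bm = B + b * K * width;    // B[b] is K x width
    float *c = C + b * height * width;      // C[b] is height x width
    for (index_t i = 0; i < height; ++i) {
      for (index_t j = 0; j < width; ++j) {
        float sum = 0.0f;
        for (index_t k = 0; k < K; ++k) {
          sum += a[i * K + k] * bm[k * width + j];
        }
        c[i * width + j] = sum;
      }
    }
  }
}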
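
To illustrate the register-tiling level described in the gemm.cc comment: each GemmXYZ helper computes a small X-by-Y times Y-by-Z product whose operands and accumulators fit in registers. The sketch below uses a hypothetical Gemm444Sketch name and a 4x4x4 tile; the actual tile shapes and the NEON implementation in gemm.cc differ. It shows the scalar shape of such a micro-kernel: all partial sums live in locals that the compiler can keep in registers after unrolling, and the result is accumulated into the C tile.

// Hypothetical 4x4x4 micro-kernel sketch; the real GemmXYZ kernels in gemm.cc
// use NEON float32x4_t accumulators and different tile shapes.
inline void Gemm444Sketch(const float *a, const int stride_a,
                          const float *b, const int stride_b,
                          float *c, const int stride_c) {
  float acc[4][4] = {{0.0f}};  // partial sums, expected to stay in registers
  for (int k = 0; k < 4; ++k) {
    for (int i = 0; i < 4; ++i) {
      const float a_ik = a[i * stride_a + k];
      for (int j = 0; j < 4; ++j) {
        acc[i][j] += a_ik * b[k * stride_b + j];
      }
    }
  }
  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 4; ++j) {
      c[i * stride_c + j] += acc[i][j];
    }
  }
}

The cache-tiling level then picks block sizes so that the A block, B block, and C block touched by a run of these micro-kernels stay resident in the L1 cache.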
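
Similarly, a scalar sketch of the Gemv semantics as documented in the new gemm.h comment: out[batch, width] accumulated from V[batch, height] dot M[height, width]. The helper name NaiveBatchGemv is hypothetical, row-major storage is assumed, and the real Gemv's argument order and operand orientation should be checked against the implementation.

#include <cstdint>

// Scalar sketch of the documented Gemv semantics; not MACE's Gemv.
void NaiveBatchGemv(const float *M, const float *V,
                    const int64_t batch, const int64_t height,
                    const int64_t width, float *out) {
  for (int64_t b = 0; b < batch; ++b) {
    const float *v = V + b * height;  // V[b] has height elements
    float *o = out + b * width;       // out[b] has width elements
    for (int64_t j = 0; j < width; ++j) {
      float sum = 0.0f;
      for (int64_t i = 0; i < height; ++i) {
        sum += v[i] * M[i * width + j];  // M is height x width, row-major
      }
      o[j] = sum;
    }
  }
}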