From aea546455cac8323ed3681aea407fb210a200b13 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?=
Date: Tue, 13 Mar 2018 12:58:21 +0800
Subject: [PATCH] matmul tiling impl.

---
 mace/kernels/matmul.h        | 212 +++++++++++++++++++++++++++++++----
 mace/ops/matmul_benchmark.cc |   3 +
 2 files changed, 191 insertions(+), 24 deletions(-)

diff --git a/mace/kernels/matmul.h b/mace/kernels/matmul.h
index d893e951..7610c0da 100644
--- a/mace/kernels/matmul.h
+++ b/mace/kernels/matmul.h
@@ -5,14 +5,128 @@
 #ifndef MACE_KERNELS_MATMUL_H_
 #define MACE_KERNELS_MATMUL_H_
 
+#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
+#include <arm_neon.h>
+#endif
+
+#include <algorithm>
+#include <cstring>
+#include <vector>
+
 #include "mace/core/future.h"
 #include "mace/core/runtime/opencl/cl2_header.h"
 #include "mace/core/tensor.h"
+#include "mace/utils/utils.h"
 
 namespace mace {
 namespace kernels {
 
-template <DeviceType D, typename T>
+namespace {
+template <typename T, int register_tile_size, int h_count, int w_count, int k_count>
+inline void MatMulKernelFunc(const T *A,
+                             const T *B,
+                             T *C,
+                             index_t offset_h,
+                             index_t offset_w,
+                             index_t offset_k,
+                             index_t stride_h,
+                             index_t stride_w,
+                             index_t stride_k) {
+  T a_tmp[register_tile_size][register_tile_size] = {0};
+  T b_tmp[register_tile_size][register_tile_size] = {0};
+  T c_tmp[register_tile_size][register_tile_size] = {0};
+
+  for (int h = 0; h < h_count; ++h) {
+    for (int k = 0; k < k_count; ++k) {
+      a_tmp[h][k] = A[(offset_h + h) * stride_k + (offset_k + k)];
+    }
+  }
+  for (int k = 0; k < k_count; ++k) {
+    for (int w = 0; w < w_count; ++w) {
+      b_tmp[k][w] = B[(offset_k + k) * stride_w + (offset_w + w)];
+    }
+  }
+
+#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
+  static_assert(register_tile_size == 4, "register tile size must be 4");
+  float32x4_t a_dup;
+  float32x4_t b_vec[4] =
+      {vld1q_f32(b_tmp[0]), vld1q_f32(b_tmp[1]), vld1q_f32(b_tmp[2]),
+       vld1q_f32(b_tmp[3])};
+  float32x4_t c_vec[4] =
+      {vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0)};
+
+  for (int h = 0; h < register_tile_size; ++h) {
+    for (int k = 0; k < register_tile_size; ++k) {
+      a_dup = vdupq_n_f32(a_tmp[h][k]);
+      c_vec[h] = vfmaq_f32(c_vec[h], a_dup, b_vec[k]);
+    }
+  }
+
+  for (int h = 0; h < register_tile_size; ++h) {
+    vst1q_f32(c_tmp[h], c_vec[h]);
+  }
+
+#else
+  for (int h = 0; h < register_tile_size; ++h) {
+    for (int w = 0; w < register_tile_size; ++w) {
+      for (int k = 0; k < register_tile_size; ++k) {
+        c_tmp[h][w] += a_tmp[h][k] * b_tmp[k][w];
+      }
+    }
+  }
+#endif
+
+  for (int h = 0; h < h_count; ++h) {
+    for (int w = 0; w < w_count; ++w) {
+      C[(offset_h + h) * stride_w + (offset_w + w)] += c_tmp[h][w];
+    }
+  }
+}
+}  // namespace
+
+#define CASE_K_MATMUL(HC, WC, KC) \
+  case KC: \
+    MatMulKernelFunc<T, register_tile_size, HC, WC, KC>(a_ptr_batch_base, \
+                                                        b_ptr_batch_base, \
+                                                        c_ptr_batch_base, \
+                                                        ih, \
+                                                        iw, \
+                                                        ik, \
+                                                        height, \
+                                                        width, \
+                                                        K); \
+    break;
+
+#define CASE_W_MATMUL(HC, WC) \
+  case WC: \
+    switch (k_count) { \
+      CASE_K_MATMUL(HC, WC, 1); \
+      CASE_K_MATMUL(HC, WC, 2); \
+      CASE_K_MATMUL(HC, WC, 3); \
+      CASE_K_MATMUL(HC, WC, 4); \
+      default: \
+        LOG(FATAL) << "Unsupported k tile: " << k_count; \
+    } \
+    break;
+
+#define CASE_H_MATMUL(HC) \
+  case HC: \
+    switch (w_count) { \
+      CASE_W_MATMUL(HC, 1); \
+      CASE_W_MATMUL(HC, 2); \
+      CASE_W_MATMUL(HC, 3); \
+      CASE_W_MATMUL(HC, 4); \
+      default: \
+        LOG(FATAL) << "Unsupported w tile: " << w_count; \
+    } \
+    break;
+
+template <DeviceType D, typename T>
 struct MatMulFunctor {
   void operator()(const Tensor *A,
                   const Tensor *B,
@@ -20,37 +134,87 @@ struct MatMulFunctor {
                   StatsFuture *future) {
     std::vector<index_t> c_shape = {A->dim(0), A->dim(1), B->dim(2), 1};
     C->Resize(c_shape);
-    const index_t N = C->dim(0);
-    const index_t height = C->dim(1);
-    const index_t width = C->dim(2);
-    const index_t K = A->dim(2);
+
     Tensor::MappingGuard guarda(A);
     Tensor::MappingGuard guardb(B);
     Tensor::MappingGuard guardc(C);
     const T *a_ptr_base = A->data<T>();
     const T *b_ptr_base = B->data<T>();
-    T *c_ptr = C->mutable_data<T>();
-    for (int i = 0; i < N; ++i) {
-      for (int h = 0; h < height; ++h) {
-        for (int w = 0; w < width; ++w) {
-          const T *a_ptr = a_ptr_base + h * K;
-          const T *b_ptr = b_ptr_base + w;
-          *c_ptr = 0;
-          for (int k = 0; k < K; ++k) {
-            *c_ptr += *a_ptr * *b_ptr;
-            a_ptr++;
-            b_ptr += width;
-          }
-          c_ptr++;
-        }
-      }
-      a_ptr_base += height * K;
-      b_ptr_base += K * width;
-    }
+    T *c_ptr_base = C->mutable_data<T>();
+
+    const index_t batch = C->dim(0);
+    const index_t height = C->dim(1);
+    const index_t width = C->dim(2);
+    const index_t K = A->dim(2);
+    // It is better to use a large block size if it fits in the fast cache.
+    // Assuming an L1 cache size of 32KB and three blocks (A, B, C) loaded
+    // at a time, the block size should be about sqrt(32k / sizeof(T) / 3).
+    const index_t block_size = 48;
+    const index_t block_tile_height = RoundUpDiv(height, block_size);
+    const index_t block_tile_width = RoundUpDiv(width, block_size);
+    const index_t block_tile_k = RoundUpDiv(K, block_size);
+    const index_t remain_height = height % block_size;
+    const index_t remain_width = width % block_size;
+    const index_t remain_k = K % block_size;
+    constexpr index_t register_tile_size = 4;
+    memset(c_ptr_base, 0, batch * height * width * sizeof(T));
+
+#pragma omp parallel for collapse(3)
+    for (index_t n = 0; n < batch; ++n) {
+      // handle one output block per (bh, bw)
+      for (index_t bh = 0; bh < block_tile_height; ++bh) {
+        for (index_t bw = 0; bw < block_tile_width; ++bw) {
+          const T *a_ptr_batch_base = a_ptr_base + n * height * K;
+          const T *b_ptr_batch_base = b_ptr_base + n * K * width;
+          T *c_ptr_batch_base = c_ptr_base + n * height * width;
+          const index_t ih_begin = bh * block_size;
+          const index_t ih_end =
+              bh * block_size +
+              (bh == block_tile_height - 1 && remain_height > 0
+                   ? remain_height : block_size);
+          const index_t iw_begin = bw * block_size;
+          const index_t iw_end =
+              bw * block_size +
+              (bw == block_tile_width - 1 && remain_width > 0
+                   ? remain_width : block_size);
+
+          for (index_t bk = 0; bk < block_tile_k; ++bk) {
+            const index_t ik_begin = bk * block_size;
+            const index_t ik_end =
+                bk * block_size +
+                (bk == block_tile_k - 1 && remain_k > 0
+                     ? remain_k : block_size);
+
+            // inside the block:
+            // calculate C[bh, bw] += A[bh, bk] * B[bk, bw] for one k block
+            for (index_t ih = ih_begin; ih < ih_end;
+                 ih += register_tile_size) {
+              for (index_t iw = iw_begin; iw < iw_end;
+                   iw += register_tile_size) {
+                for (index_t ik = ik_begin; ik < ik_end;
+                     ik += register_tile_size) {
+                  const int h_count = std::min(register_tile_size, ih_end - ih);
+                  const int w_count = std::min(register_tile_size, iw_end - iw);
+                  const int k_count = std::min(register_tile_size, ik_end - ik);
+
+                  switch (h_count) {
+                    CASE_H_MATMUL(1);
+                    CASE_H_MATMUL(2);
+                    CASE_H_MATMUL(3);
+                    CASE_H_MATMUL(4);
+                    default:
+                      LOG(FATAL) << "Unsupported height tile: " << h_count;
+                  }
+                }  // ik
+              }  // iw
+            }  // ih
+          }  // bk
+        }  // bw
+      }  // bh
+    }  // n
   }
 };
 
-template<typename T>
+template <typename T>
 struct MatMulFunctor<DeviceType::OPENCL, T> {
   void operator()(const Tensor *A,
                   const Tensor *B,
diff --git a/mace/ops/matmul_benchmark.cc b/mace/ops/matmul_benchmark.cc
index f6e1c6d1..c83b872a 100644
--- a/mace/ops/matmul_benchmark.cc
+++ b/mace/ops/matmul_benchmark.cc
@@ -68,4 +68,7 @@ static void MatMulBenchmark(
 BM_MATMUL(16, 32, 128, 49);
 BM_MATMUL(16, 32, 128, 961);
 BM_MATMUL(16, 32, 128, 3969);
+BM_MATMUL(16, 128, 128, 49);
+BM_MATMUL(16, 128, 128, 961);
+BM_MATMUL(16, 128, 128, 3969);
 }  // namespace mace
--
GitLab
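
For readers who want to study the tiling scheme outside of the MACE sources, below is a minimal standalone C++ sketch of the same idea: the output is partitioned into 48x48 cache blocks, and each block is accumulated from 4x4 register tiles, with edge tiles clamped. All names are illustrative and this is not the MACE API; the sketch deliberately omits the NEON micro-kernel and the OpenMP parallelism that the patch adds.

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstring>
#include <vector>

namespace {

constexpr int kBlockSize = 48;    // cache block edge, ~sqrt(32KB / 4B / 3)
constexpr int kRegisterTile = 4;  // register tile edge

// Accumulate one register tile: C[h, w] += sum_k A[h, k] * B[k, w],
// where the tile starts at (offset_h, offset_w, offset_k) and may be
// smaller than 4x4x4 at the matrix edges.
inline void RegisterTile(const float *A, const float *B, float *C,
                         int offset_h, int offset_w, int offset_k,
                         int h_count, int w_count, int k_count,
                         int stride_w, int stride_k) {
  for (int h = 0; h < h_count; ++h) {
    for (int w = 0; w < w_count; ++w) {
      float acc = 0.f;
      for (int k = 0; k < k_count; ++k) {
        acc += A[(offset_h + h) * stride_k + (offset_k + k)] *
               B[(offset_k + k) * stride_w + (offset_w + w)];
      }
      C[(offset_h + h) * stride_w + (offset_w + w)] += acc;
    }
  }
}

// C = A * B with A of shape [height, K] and B of shape [K, width],
// all row-major. Cache blocking outside, register tiling inside.
void TiledMatMul(const float *A, const float *B, float *C,
                 int height, int width, int K) {
  std::memset(C, 0, sizeof(float) * height * width);
  for (int bh = 0; bh < height; bh += kBlockSize) {
    const int bh_end = std::min(bh + kBlockSize, height);
    for (int bw = 0; bw < width; bw += kBlockSize) {
      const int bw_end = std::min(bw + kBlockSize, width);
      for (int bk = 0; bk < K; bk += kBlockSize) {
        const int bk_end = std::min(bk + kBlockSize, K);
        // Walk the cache block in 4x4x4 register tiles.
        for (int ih = bh; ih < bh_end; ih += kRegisterTile) {
          for (int iw = bw; iw < bw_end; iw += kRegisterTile) {
            for (int ik = bk; ik < bk_end; ik += kRegisterTile) {
              RegisterTile(A, B, C, ih, iw, ik,
                           std::min(kRegisterTile, bh_end - ih),
                           std::min(kRegisterTile, bw_end - iw),
                           std::min(kRegisterTile, bk_end - ik),
                           width, K);
            }
          }
        }
      }
    }
  }
}

}  // namespace

int main() {
  // Sanity check against the naive triple loop on a deliberately
  // non-multiple-of-48 problem so that the edge tiles are exercised.
  const int height = 50, width = 37, K = 61;
  std::vector<float> A(height * K), B(K * width);
  for (size_t i = 0; i < A.size(); ++i) A[i] = static_cast<float>(i % 7) - 3.f;
  for (size_t i = 0; i < B.size(); ++i) B[i] = static_cast<float>(i % 5) - 2.f;

  std::vector<float> C(height * width), ref(height * width, 0.f);
  TiledMatMul(A.data(), B.data(), C.data(), height, width, K);
  for (int h = 0; h < height; ++h)
    for (int w = 0; w < width; ++w)
      for (int k = 0; k < K; ++k)
        ref[h * width + w] += A[h * K + k] * B[k * width + w];

  for (int i = 0; i < height * width; ++i) assert(C[i] == ref[i]);
  std::printf("tiled matmul matches naive matmul\n");
  return 0;
}

The block size of 48 follows the same estimate as the comment in the patch: sqrt(32768 bytes / sizeof(float) / 3) is roughly 52, rounded down to a multiple of the 4-wide register tile.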