diff --git a/mace/core/allocator.h b/mace/core/allocator.h index a241fd4c004292adf263bb80dbb3bc701b6a1769..a212e7f91434e13c6d4dd101bab16ce855153842 100644 --- a/mace/core/allocator.h +++ b/mace/core/allocator.h @@ -34,8 +34,8 @@ namespace mace { #if defined(__hexagon__) constexpr size_t kMaceAlignment = 128; #elif defined(__ANDROID__) -// 16 bytes = 128 bits = 32 * 4 (Neon) -constexpr size_t kMaceAlignment = 16; +// arm cache line +constexpr size_t kMaceAlignment = 64; #else // 32 bytes = 256 bits (AVX512) constexpr size_t kMaceAlignment = 32; diff --git a/mace/core/runtime/cpu/cpu_runtime.cc b/mace/core/runtime/cpu/cpu_runtime.cc index 10bfbee83536e50ff39646d8ee31f1d887b78d47..f9b1d49f2f9dad0408a3b1922c12169444aa549c 100644 --- a/mace/core/runtime/cpu/cpu_runtime.cc +++ b/mace/core/runtime/cpu/cpu_runtime.cc @@ -35,6 +35,8 @@ namespace mace { +int MaceOpenMPThreadCount = 1; + namespace { int GetCPUCount() { @@ -136,6 +138,8 @@ MaceStatus GetCPUBigLittleCoreIDs(std::vector *big_core_ids, MaceStatus SetOpenMPThreadsAndAffinityCPUs(int omp_num_threads, const std::vector &cpu_ids) { + MaceOpenMPThreadCount = omp_num_threads; + #ifdef MACE_ENABLE_OPENMP VLOG(1) << "Set OpenMP threads number: " << omp_num_threads << ", CPU core IDs: " << MakeString(cpu_ids); diff --git a/mace/core/runtime/cpu/cpu_runtime.h b/mace/core/runtime/cpu/cpu_runtime.h index 1fb463f5e7c4b0631ba012756f6ab3be81f3f65f..3382a8f1c66de2b8fa41b3420b380efc91da5ab1 100644 --- a/mace/core/runtime/cpu/cpu_runtime.h +++ b/mace/core/runtime/cpu/cpu_runtime.h @@ -22,6 +22,8 @@ namespace mace { +extern int MaceOpenMPThreadCount; + MaceStatus GetCPUBigLittleCoreIDs(std::vector *big_core_ids, std::vector *little_core_ids); diff --git a/mace/core/tensor.h b/mace/core/tensor.h index e48edc87490d8b1c7194ed82d3ba72fac4d40984..62ea5488a87f53233c049915c8170ff8eb41d709 100644 --- a/mace/core/tensor.h +++ b/mace/core/tensor.h @@ -100,31 +100,38 @@ enum DataFormat { NHWC = 0, NCHW = 1, HWOI = 2, OIHW = 3, HWIO = 4, OHWI = 5 }; class Tensor { public: - Tensor(Allocator *alloc, DataType type) + Tensor(Allocator *alloc, DataType type, + bool is_weight = false) : allocator_(alloc), dtype_(type), buffer_(nullptr), is_buffer_owner_(true), unused_(false), name_(""), + is_weight_(is_weight), scale_(0.f), zero_point_(0) {} - Tensor(BufferBase *buffer, DataType dtype) + Tensor(BufferBase *buffer, DataType dtype, + bool is_weight = false) : dtype_(dtype), buffer_(buffer), is_buffer_owner_(false), unused_(false), name_(""), + is_weight_(is_weight), scale_(0.f), zero_point_(0) {} - Tensor(const BufferSlice &buffer_slice, DataType dtype) + Tensor(const BufferSlice &buffer_slice, + DataType dtype, + bool is_weight = false) : dtype_(dtype), buffer_slice_(buffer_slice), is_buffer_owner_(false), unused_(false), name_(""), + is_weight_(is_weight), scale_(0.f), zero_point_(0) { buffer_ = &buffer_slice_; @@ -373,6 +380,10 @@ class Tensor { MACE_DISABLE_COPY_AND_ASSIGN(MappingGuard); }; + inline bool is_weight() const { + return is_weight_; + } + inline float scale() const { return scale_; } @@ -399,6 +410,7 @@ class Tensor { bool is_buffer_owner_; bool unused_; std::string name_; + const bool is_weight_; float scale_; int32_t zero_point_; diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc index 170070cd8d27d273c049b10ac16e22076b2c18ec..07d855605ed744d64345ab722225a274bc09063c 100644 --- a/mace/core/workspace.cc +++ b/mace/core/workspace.cc @@ -105,7 +105,7 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def, std::unique_ptr tensor( new Tensor(GetDeviceAllocator(type), - const_tensor.data_type())); + const_tensor.data_type(), true)); tensor->Resize(dims); MACE_CHECK(tensor->size() == const_tensor.data_size(), @@ -159,7 +159,7 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def, tensor_buffer_.get(), const_tensor.offset(), const_tensor.data_size() * GetEnumTypeSize(const_tensor.data_type())), - const_tensor.data_type())); + const_tensor.data_type(), true)); tensor->Reshape(dims); tensor->SetScale(const_tensor.scale()); diff --git a/mace/kernels/gemm.cc b/mace/kernels/gemm.cc index 2003d3ec9a2a3bb262c605d77657255d48eca68d..c94c0af5e900e7414d452e35630b8c6f623418b7 100644 --- a/mace/kernels/gemm.cc +++ b/mace/kernels/gemm.cc @@ -14,8 +14,10 @@ #include #include +#include #include "mace/core/tensor.h" +#include "mace/core/runtime/cpu/cpu_runtime.h" #include "mace/kernels/gemm.h" /** @@ -329,37 +331,6 @@ inline void Gemm644(const float *a_ptr, #endif } -inline void GemmX44(const float *a_ptr, - const float *b_ptr, - const index_t stride_a, - const index_t stride_b, - const index_t stride_c, - float *c_ptr, - int row) { - switch (row) { - case 1: - Gemm144(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); - break; - case 2: - Gemm244(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); - break; - case 3: - Gemm344(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); - break; - case 4: - Gemm444(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); - break; - case 5: - Gemm544(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); - break; - case 6: - Gemm644(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); - break; - default: - MACE_NOT_IMPLEMENTED; - } -} - inline void Gemm884(const float *a_ptr, const float *b_ptr, const index_t stride_a, @@ -770,43 +741,6 @@ inline void Gemm784(const float *a_ptr, #endif } -inline void GemmX84(const float *a_ptr, - const float *b_ptr, - const index_t stride_a, - const index_t stride_b, - const index_t stride_c, - float *c_ptr, - int row) { - switch (row) { - case 1: - Gemm184(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); - break; - case 2: - Gemm284(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); - break; - case 3: - Gemm384(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); - break; - case 4: - Gemm484(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); - break; - case 5: - Gemm584(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); - break; - case 6: - Gemm684(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); - break; - case 7: - Gemm784(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); - break; - case 8: - Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); - break; - default: - MACE_NOT_IMPLEMENTED; - } -} - inline void GemmTile(const float *A, const float *B, const index_t height, @@ -873,6 +807,8 @@ inline void GemmTile(const float *A, float *c_ptr7 = C + (h + 7) * stride_c; asm volatile( + "0: \n" + "prfm pldl1keep, [%9, #128] \n" "ld1 {v16.4s}, [%9], #16 \n" @@ -882,8 +818,6 @@ inline void GemmTile(const float *A, "prfm pldl1keep, [%2, #128] \n" "ld1 {v19.4s}, [%2] \n" - "0: \n" - "prfm pldl1keep, [%3, #128] \n" "ld1 {v20.4s}, [%3] \n" "prfm pldl1keep, [%4, #128] \n" @@ -1002,19 +936,13 @@ inline void GemmTile(const float *A, "fmla v24.4s, v17.4s, %48.s[3] \n" "fmla v25.4s, v17.4s, %49.s[3] \n" + "subs %w0, %w0, #1 \n" + "st1 {v22.4s}, [%5], #16 \n" "st1 {v23.4s}, [%6], #16 \n" "st1 {v24.4s}, [%7], #16 \n" "st1 {v25.4s}, [%8], #16 \n" - "prfm pldl1keep, [%9, #128] \n" - "ld1 {v16.4s}, [%9], #16 \n" - "prfm pldl1keep, [%1, #128] \n" - "ld1 {v18.4s}, [%1] \n" - "prfm pldl1keep, [%2, #128] \n" - "ld1 {v19.4s}, [%2] \n" - - "subs %w0, %w0, #1 \n" "bne 0b \n" : "=r"(nw), // 0 "=r"(c_ptr0), // 1 @@ -1102,6 +1030,8 @@ inline void GemmTile(const float *A, float *c_ptr5 = C + (h + 5) * stride_c; asm volatile( + "0: \n" + "pld [%7, #128] \n" "vld1.f32 {d12-d13}, [%7]! \n" "pld [%1, #128] \n" @@ -1109,8 +1039,6 @@ inline void GemmTile(const float *A, "pld [%2, #128] \n" "vld1.f32 {d18-d19}, [%2] \n" - "0: \n" - "pld [%3, #128] \n" "vld1.f32 {d20-d21}, [%3] \n" "pld [%4, #128] \n" @@ -1159,22 +1087,11 @@ inline void GemmTile(const float *A, "vst1.f32 {d16-d17}, [%1]! \n" "vst1.f32 {d18-d19}, [%2]! \n" - - "pld [%7, #128] \n" - "vld1.f32 {d12-d13}, [%7]! \n" - "vst1.f32 {d20-d21}, [%3]! \n" "vst1.f32 {d22-d23}, [%4]! \n" - - "pld [%1, #128] \n" - "vld1.f32 {d16-d17}, [%1] \n" - "vst1.f32 {d24-d25}, [%5]! \n" "vst1.f32 {d26-d27}, [%6]! \n" - "pld [%2, #128] \n" - "vld1.f32 {d18-d19}, [%2] \n" - "subs %0, #1 \n" "bne 0b \n" : "=r"(nw), // 0 @@ -1228,17 +1145,69 @@ inline void GemmTile(const float *A, } if (h < height) { index_t remain_h = height - h; + + auto gemm_fn = Gemm184; + switch (remain_h) { + case 1: + #if defined(__aarch64__) + gemm_fn = Gemm184; + #else + gemm_fn = Gemm144; + #endif + break; + case 2: + #if defined(__aarch64__) + gemm_fn = Gemm284; + #else + gemm_fn = Gemm244; + #endif + break; + case 3: + #if defined(__aarch64__) + gemm_fn = Gemm384; + #else + gemm_fn = Gemm344; + #endif + break; + case 4: + #if defined(__aarch64__) + gemm_fn = Gemm484; + #else + gemm_fn = Gemm444; + #endif + break; + case 5: + #if defined(__aarch64__) + gemm_fn = Gemm584; + #else + gemm_fn = Gemm544; + #endif + break; + case 6: + #if defined(__aarch64__) + gemm_fn = Gemm684; + #else + LOG(FATAL) << "remain_h should < 6"; + #endif + break; + case 7: + #if defined(__aarch64__) + gemm_fn = Gemm784; + #else + LOG(FATAL) << "remain_h should < 6"; + #endif + break; + default: + LOG(FATAL) << "remain_h should < 8"; + } + for (k = 0; k < K - reg_K_tile; k += reg_K_tile) { const float *a_ptr = A + (h * stride_a + k); index_t w; for (w = 0; w + 3 < width; w += 4) { const float *b_ptr = B + (k * stride_b + w); float *c_ptr = C + (h * stride_c + w); -#if defined(__aarch64__) - GemmX84(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr, remain_h); -#else - GemmX44(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr, remain_h); -#endif + gemm_fn(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); } if (w < width) { const float *b_ptr = B + (k * stride_b + w); @@ -1260,20 +1229,27 @@ inline void GemmTile(const float *A, #endif // MACE_ENABLE_NEON } +} // namespace + void Transpose(const float *src, index_t height, index_t width, index_t stride_w, float *dst) { - for (index_t h = 0; h < height; ++h) { - for (index_t w = 0; w < width; ++w) { - dst[w * height + h] = src[h * stride_w + w]; + index_t tile_size = height > 512 || width > 512 ? 64 : 32; + for (index_t i = 0; i < height; i += tile_size) { + for (index_t j = 0; j < width; j += tile_size) { + index_t end_i = std::min(i + tile_size, height); + index_t end_j = std::min(j + tile_size, width); + for (index_t tile_i = i; tile_i < end_i; ++tile_i) { + for (index_t tile_j = j; tile_j < end_j; ++tile_j) { + dst[tile_j * height + tile_i] = src[tile_i * stride_w + tile_j]; + } + } } } } -} // namespace - // A: height x K, B: K x width, C: height x width void Gemm(const float *A, const float *B, @@ -1284,7 +1260,7 @@ void Gemm(const float *A, float *C, const bool transpose_a, const bool transpose_b) { - if (width == 1) { + if (width == 1 && !transpose_a) { for (index_t b = 0; b < batch; ++b) { Gemv(A + b * height * K, B + b * K, 1, K, height, C + b * height); } @@ -1292,45 +1268,78 @@ void Gemm(const float *A, } memset(C, 0, sizeof(float) * batch * height * width); - // It is better to use large block size if it fits for fast cache. - // Assume l1 cache size is 32k, we load three blocks at a time (A, B, C), - // the block size should be sqrt(32k / sizeof(T) / 3). - // As number of input channels of convolution is normally power of 2, and - // we have not optimized tiling remains, we use the following magic number - const index_t block_size = 64; - const index_t block_tile_height = RoundUpDiv(height, block_size); - const index_t block_tile_width = RoundUpDiv(width, block_size); - const index_t block_tile_k = RoundUpDiv(K, block_size); - const index_t block_tile[3] = {block_tile_height, block_tile_width, - block_tile_k}; - const index_t remain_height = height % block_size; - const index_t remain_width = width % block_size; - const index_t remain_k = K % block_size; - const index_t remain[3] = {remain_height, remain_width, remain_k}; + std::vector block_size_dims {height, width, K}; + index_t thread_count = MaceOpenMPThreadCount; + MACE_CHECK(thread_count >= 1, "thread should be ge 1"); + // TODO(liyin): apply gcd ? + if (height % thread_count == 0) { + block_size_dims[0] = height / thread_count; + } else if (thread_count == 4 && (height & 1) == 0 && (width & 1) == 0) { + block_size_dims[0] = height >> 1; + block_size_dims[1] = width >> 1; + } else if (width % thread_count == 0) { + block_size_dims[1] = width / thread_count; + } else { + if (height >= thread_count) { + block_size_dims[0] = height / thread_count; + } else { + thread_count = std::min(thread_count, height * width); + index_t thread_h = height; + index_t thread_w = RoundUpDiv(thread_count, thread_h); + block_size_dims[0] = 1; + block_size_dims[1] = std::max(static_cast(1), width / thread_w); + } + } + + const index_t block_tile[3] = {height / block_size_dims[0], + width / block_size_dims[1], + K / block_size_dims[2]}; + block_size_dims[0] = height / block_tile[0]; + block_size_dims[1] = width / block_tile[1]; + block_size_dims[2] = K / block_tile[2]; + + const index_t remain[3] = {height % block_tile[0], + width % block_tile[1], + K % block_tile[2]}; + #pragma omp parallel for collapse(3) for (index_t n = 0; n < batch; ++n) { for (index_t bh = 0; bh < block_tile[0]; ++bh) { for (index_t bw = 0; bw < block_tile[1]; ++bw) { + const index_t remain_height = remain[0]; + const index_t remain_width = remain[1]; + const index_t remain_k = remain[2]; + + const index_t block_size_height = block_size_dims[0]; + const index_t block_size_width = block_size_dims[1]; + const index_t block_size_k = block_size_dims[2]; + + const index_t this_block_size_height = + block_size_height + (bh < remain_height ? 1 : 0); + const index_t this_block_size_width = + block_size_width + (bw < remain_width ? 1 : 0); + const float *a_base = A + n * height * K; const float *b_base = B + n * K * width; float *c_base = C + n * height * width; - const index_t ih_begin = bh * block_size; - const index_t ih_end = - bh * block_size + - (bh == block_tile[0] - 1 && remain[0] > 0 ? remain[0] : block_size); - const index_t iw_begin = bw * block_size; - const index_t iw_end = - bw * block_size + - (bw == block_tile[1] - 1 && remain[1] > 0 ? remain[1] : block_size); + const index_t ih_begin = + bh * block_size_height + (bh < remain_height ? bh : remain_height); + const index_t + ih_end = std::min(height, ih_begin + this_block_size_height); + const index_t iw_begin = + bw * block_size_width + (bw < remain_width ? bw : remain_width); + const index_t + iw_end = std::min(width, iw_begin + this_block_size_width); for (index_t bk = 0; bk < block_tile[2]; ++bk) { - const index_t ik_begin = bk * block_size; - const index_t ik_end = - bk * block_size + (bk == block_tile[2] - 1 && remain[2] > 0 - ? remain[2] - : block_size); + const index_t + this_block_size_k = block_size_k + (bk < remain_k ? 1 : 0); + + const index_t + ik_begin = bk * block_size_k + (bk < remain_k ? bk : remain_k); + const index_t ik_end = std::min(K, ik_begin + this_block_size_k); Tensor trans_a; Tensor trans_b; @@ -1342,7 +1351,7 @@ void Gemm(const float *A, index_t stride_c = width; if (transpose_a) { - trans_a.Resize({block_size, block_size}); + trans_a.Resize({this_block_size_height, this_block_size_k}); float *trans_a_data = trans_a.mutable_data(); // A[K, H] -> A[H, K] Transpose(a_base + (ik_begin * height + ih_begin), @@ -1356,7 +1365,7 @@ void Gemm(const float *A, } if (transpose_b) { - trans_b.Resize({block_size, block_size}); + trans_b.Resize({this_block_size_k, this_block_size_width}); float *trans_b_data = trans_b.mutable_data(); // B[W, K] -> B[K, W] Transpose(b_base + (iw_begin * K + ik_begin), iw_end - iw_begin, @@ -1449,7 +1458,6 @@ void GemvRef(const float *m_ptr, } } -// TODO(liyin): batched gemv can be transformed to gemm (w/ transpose) void Gemv(const float *m_ptr, const float *v_ptr, const index_t batch, @@ -1457,88 +1465,74 @@ void Gemv(const float *m_ptr, const index_t height, float *out_ptr) { #if defined(MACE_ENABLE_NEON) -// TODO(liyin/wch): try height tiling = 8 + #pragma omp parallel for collapse(2) for (index_t b = 0; b < batch; ++b) { - for (index_t h = 0; h < height; h += 4) { - if (h + 3 < height) { - const float *m_ptr0 = m_ptr + h * width; - const float *m_ptr1 = m_ptr0 + width; - const float *m_ptr2 = m_ptr1 + width; - const float *m_ptr3 = m_ptr2 + width; - const float *v_ptr0 = v_ptr + b * width; - float *out_ptr0 = out_ptr + b * height + h; - - float32x4_t vm0, vm1, vm2, vm3; - float32x4_t vv; - - float32x4_t vsum0 = vdupq_n_f32(0.f); - float32x4_t vsum1 = vdupq_n_f32(0.f); - float32x4_t vsum2 = vdupq_n_f32(0.f); - float32x4_t vsum3 = vdupq_n_f32(0.f); - - index_t w; - for (w = 0; w + 3 < width; w += 4) { - vm0 = vld1q_f32(m_ptr0); - vm1 = vld1q_f32(m_ptr1); - vm2 = vld1q_f32(m_ptr2); - vm3 = vld1q_f32(m_ptr3); - vv = vld1q_f32(v_ptr0); - - vsum0 = vmlaq_f32(vsum0, vm0, vv); - vsum1 = vmlaq_f32(vsum1, vm1, vv); - vsum2 = vmlaq_f32(vsum2, vm2, vv); - vsum3 = vmlaq_f32(vsum3, vm3, vv); - - m_ptr0 += 4; - m_ptr1 += 4; - m_ptr2 += 4; - m_ptr3 += 4; - v_ptr0 += 4; - } - float sum0 = vaddvq_f32(vsum0); - float sum1 = vaddvq_f32(vsum1); - float sum2 = vaddvq_f32(vsum2); - float sum3 = vaddvq_f32(vsum3); - - // handle remaining w - for (; w < width; ++w) { - sum0 += m_ptr0[0] * v_ptr0[0]; - sum1 += m_ptr1[0] * v_ptr0[0]; - sum2 += m_ptr2[0] * v_ptr0[0]; - sum3 += m_ptr3[0] * v_ptr0[0]; - m_ptr0++; - m_ptr1++; - m_ptr2++; - m_ptr3++; - v_ptr0++; - } - *out_ptr0++ = sum0; - *out_ptr0++ = sum1; - *out_ptr0++ = sum2; - *out_ptr0++ = sum3; - } else { - for (index_t hh = h; hh < height; ++hh) { - float32x4_t vsum0 = vdupq_n_f32(0.f); - const float *m_ptr0 = m_ptr + hh * width; - const float *v_ptr0 = v_ptr + b * width; - index_t w; - for (w = 0; w + 3 < width; w += 4) { - float32x4_t vm = vld1q_f32(m_ptr0); - float32x4_t vv = vld1q_f32(v_ptr0); - vsum0 = vmlaq_f32(vsum0, vm, vv); - m_ptr0 += 4; - v_ptr0 += 4; - } - float sum = vaddvq_f32(vsum0); - for (; w < width; ++w) { - sum += m_ptr0[0] * v_ptr0[0]; - m_ptr0++; - v_ptr0++; - } - out_ptr[b * height + hh] = sum; - } - } // if + for (index_t h = 0; h < height; ++h) { + const float *m_ptr0 = m_ptr + h * width; + const float *v_ptr0 = v_ptr + b * width; + float *out_ptr0 = out_ptr + b * height + h; + + float32x4_t vm0, vm1, vm2, vm3; + float32x4_t vv0, vv1, vv2, vv3; + float32x4_t vsum0 = vdupq_n_f32(0.f); + float32x4_t vsum1 = vdupq_n_f32(0.f); + float32x4_t vsum2 = vdupq_n_f32(0.f); + float32x4_t vsum3 = vdupq_n_f32(0.f); + + index_t w; + for (w = 0; w + 15 < width; w += 16) { + vm0 = vld1q_f32(m_ptr0); + vv0 = vld1q_f32(v_ptr0); + vm1 = vld1q_f32(m_ptr0 + 4); + vv1 = vld1q_f32(v_ptr0 + 4); + vm2 = vld1q_f32(m_ptr0 + 8); + vv2 = vld1q_f32(v_ptr0 + 8); + vm3 = vld1q_f32(m_ptr0 + 12); + vv3 = vld1q_f32(v_ptr0 + 12); + + vsum0 = vmlaq_f32(vsum0, vm0, vv0); + vsum1 = vmlaq_f32(vsum1, vm1, vv1); + vsum2 = vmlaq_f32(vsum2, vm2, vv2); + vsum3 = vmlaq_f32(vsum3, vm3, vv3); + + m_ptr0 += 16; + v_ptr0 += 16; + } + + for (; w + 7 < width; w += 8) { + vm0 = vld1q_f32(m_ptr0); + vv0 = vld1q_f32(v_ptr0); + vm1 = vld1q_f32(m_ptr0 + 4); + vv1 = vld1q_f32(v_ptr0 + 4); + + vsum0 = vmlaq_f32(vsum0, vm0, vv0); + vsum1 = vmlaq_f32(vsum1, vm1, vv1); + + m_ptr0 += 8; + v_ptr0 += 8; + } + + for (; w + 3 < width; w += 4) { + vm0 = vld1q_f32(m_ptr0); + vv0 = vld1q_f32(v_ptr0); + vsum0 = vmlaq_f32(vsum0, vm0, vv0); + + m_ptr0 += 4; + v_ptr0 += 4; + } + vsum0 += vsum1; + vsum2 += vsum3; + vsum0 += vsum2; + float sum0 = vaddvq_f32(vsum0); + + // handle remaining w + for (; w < width; ++w) { + sum0 += m_ptr0[0] * v_ptr0[0]; + m_ptr0++; + v_ptr0++; + } + *out_ptr0++ = sum0; } // h } // b #else diff --git a/mace/kernels/gemm.h b/mace/kernels/gemm.h index f6ea31c41b04414d767a2b647dfaf1242600d4bd..17096bf5b48800425c6fbcdc29f750ee534dd239 100644 --- a/mace/kernels/gemm.h +++ b/mace/kernels/gemm.h @@ -66,6 +66,12 @@ void GemvRef(const float *m_ptr, const index_t height, float *out_ptr); +void Transpose(const float *src, + index_t height, + index_t width, + index_t stride_w, + float *dst); + } // namespace kernels } // namespace mace diff --git a/mace/kernels/gemm_test.cc b/mace/kernels/gemm_test.cc index 00c4ee0278c12c429e57c1a620d4bf6a275afab2..cec9491461db2ff229fdfc672a81ee7c50a7fe04 100644 --- a/mace/kernels/gemm_test.cc +++ b/mace/kernels/gemm_test.cc @@ -83,6 +83,8 @@ TEST(GEMMTest, AlignedWithoutBatch) { GemmTest(1, 6, 64, 128, false, true); GemmTest(1, 7, 64, 128, true, false); GemmTest(1, 17, 64, 128, true, true); + GemmTest(1, 256, 128, 4096, false, false); + GemmTest(1, 256, 128, 4104, false, false); } TEST(GEMMTest, UnalignedWithoutBatch) { diff --git a/mace/kernels/matmul.h b/mace/kernels/matmul.h index 736fdd34e746984c456fad248c97a7171e107bff..42e76002a231d3b0b5ebc38d3df0bacf0cc265a0 100644 --- a/mace/kernels/matmul.h +++ b/mace/kernels/matmul.h @@ -81,16 +81,34 @@ struct MatMulFunctor { const T *b_ptr_base = B->data(); T *c_ptr_base = C->mutable_data(); - // It is better to use large block size if it fits for fast cache. - // Assume l1 cache size is 32k, we load three blocks at a time (A, B, C), - // the block size should be sqrt(32k / sizeof(T) / 3). memset(c_ptr_base, 0, batch * height * width * sizeof(T)); - Gemm(a_ptr_base, b_ptr_base, batch, height, K, width, c_ptr_base, - transpose_a, transpose_b); + if (height == 1 && width > 1 && B->is_weight()) { + // A * B = (B^T * A^T)^T + if (!transpose_b) { + if (B_transpose_.get() == nullptr) { + B_transpose_.reset(new Tensor(GetDeviceAllocator(D), + DataTypeToEnum::v())); + B_transpose_->Resize({batch, width, K}); + Tensor::MappingGuard guardbt(B_transpose_.get()); + T *bt_ptr_base = B_transpose_->mutable_data(); + Transpose(b_ptr_base, K, width, width, bt_ptr_base); + } + Tensor::MappingGuard guardbt(B_transpose_.get()); + T *bt_ptr_base = B_transpose_->mutable_data(); + Gemv(bt_ptr_base, a_ptr_base, batch, K, width, c_ptr_base); + } else { + Gemv(b_ptr_base, a_ptr_base, batch, K, width, c_ptr_base); + } + } else { + Gemm(a_ptr_base, b_ptr_base, batch, height, K, width, c_ptr_base, + transpose_a, transpose_b); + } return MACE_SUCCESS; } + + std::unique_ptr B_transpose_; }; template <> diff --git a/mace/kernels/transpose.h b/mace/kernels/transpose.h index 5faa67c120ce194fccb657ca3bb41c473c83cf9e..8de796aa9259474639c31c37b60a7d6f1439710d 100644 --- a/mace/kernels/transpose.h +++ b/mace/kernels/transpose.h @@ -20,6 +20,7 @@ #endif #include +#include #include "mace/core/future.h" #include "mace/core/tensor.h" @@ -122,9 +123,20 @@ struct TransposeFunctor { MACE_CHECK(dims_[0] == 1 && dims_[1] == 0, "no need transform"); index_t stride_i = input_shape[0]; index_t stride_j = input_shape[1]; - for (int i = 0; i < input_shape[0]; ++i) { - for (int j = 0; j < input_shape[1]; ++j) { - output_data[j * stride_i + i] = input_data[i * stride_j + j]; + + index_t tile_size = input_shape[0] > 512 || input_shape[1] > 512 + ? 64 : 32; +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < input_shape[0]; i += tile_size) { + for (index_t j = 0; j < input_shape[1]; j += tile_size) { + index_t end_i = std::min(i + tile_size, input_shape[0]); + index_t end_j = std::min(j + tile_size, input_shape[1]); + for (index_t tile_i = i; tile_i < end_i; ++tile_i) { + for (index_t tile_j = j; tile_j < end_j; ++tile_j) { + output_data[tile_j * stride_i + tile_i] = + input_data[tile_i * stride_j + tile_j]; + } + } } } } else if (input->dim_size() == 4) { diff --git a/mace/ops/resize_bicubic_test.cc b/mace/ops/resize_bicubic_test.cc index ad4669f7ca1939ba5bf8b56966cc8c12f62f0b18..7c7bd8bc263dd579fc3576a278550a894f97a7d3 100644 --- a/mace/ops/resize_bicubic_test.cc +++ b/mace/ops/resize_bicubic_test.cc @@ -50,7 +50,7 @@ TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCorners) { // Check auto expected = CreateTensor({1, 1, 2, 3}, {0, 1, 2, 6, 7, 8}); - ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-2); } TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCornersFloat) { @@ -82,7 +82,7 @@ TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCornersFloat) { 8.223037, 9.223036, 10.223037, 24., 25., 26., 28.110298, 29.1103, 30.110298, 32.223038, 33.223038, 34.223038}); - ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-2); } TEST_F(ResizeBicubicTest, ResizeBicubicWAlignCorners) { @@ -112,7 +112,7 @@ TEST_F(ResizeBicubicTest, ResizeBicubicWAlignCorners) { // Check auto expected = CreateTensor({1, 1, 2, 3}, {0, 1, 2, 9, 10, 11}); - ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-2); } namespace { diff --git a/mace/ops/transpose_benchmark.cc b/mace/ops/transpose_benchmark.cc index aaf3faaa703f71a17a29154304c42edcaa70e01c..c5fe98cd4127c6e3a13f14af51aec4ec1f2666ec 100644 --- a/mace/ops/transpose_benchmark.cc +++ b/mace/ops/transpose_benchmark.cc @@ -90,6 +90,9 @@ MACE_BM_TRANSPOSE4D(1, 64, 64, 512, 0, 3, 1, 2); MACE_BM_TRANSPOSE4D(1, 512, 64, 64, 0, 2, 3, 1); MACE_BM_TRANSPOSE2D(128, 128); MACE_BM_TRANSPOSE2D(512, 512); +MACE_BM_TRANSPOSE2D(1024, 1024); +MACE_BM_TRANSPOSE2D(512, 2048); +MACE_BM_TRANSPOSE2D(2048, 512); } // namespace test } // namespace ops diff --git a/mace/ops/unstack_test.cc b/mace/ops/unstack_test.cc index 674ec0aea52345afe74c1b578bb4f0668b2bd8b6..306c836242426612763ea11cd573803c6d358021 100644 --- a/mace/ops/unstack_test.cc +++ b/mace/ops/unstack_test.cc @@ -43,7 +43,6 @@ void TestUnstack(const std::vector &input_shape, net.RunOp(); for (size_t i = 0; i < outputs.size(); ++i) { - LOG(INFO) << MakeString("Output", i); net.AddInputFromArray("ExpectedOutput", output_shape, outputs[i]); ExpectTensorNear(*net.GetOutput("ExpectedOutput"),