diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 119790748dc45d1536f2a92193ed82d7541aa07e..3984fd9dc321a46faa120a4c5f49227453f469fe 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -85,7 +85,7 @@ ndk_versions_compatible_tests:
     - DEFAULT_NDK_PATH=$ANDROID_NDK_HOME
     - prefix_path=${DEFAULT_NDK_PATH%android-ndk-*}
     - >
-      for ndk in android-ndk-r12b android-ndk-r15c android-ndk-r16 android-ndk-r17b;
+      for ndk in android-ndk-r15c android-ndk-r16 android-ndk-r17b;
       do
        new_ndk_path=${prefix_path}${ndk};
        if [ "$new_ndk_path" != "$DEFAULT_NDK_PATH" ]; then
diff --git a/mace/core/tensor.h b/mace/core/tensor.h
index f7e509876f1564b06cbcd94e433a8ca3c03e197f..5c4c807b1cfe12a4cdf2e771a36857bf978c83ff 100644
--- a/mace/core/tensor.h
+++ b/mace/core/tensor.h
@@ -399,6 +399,10 @@ class Tensor {
     zero_point_ = zero_point;
   }

+  inline void SetIsWeight(bool is_weight) {
+    is_weight_ = is_weight;
+  }
+
  private:
   Allocator *allocator_;
   DataType dtype_;
@@ -409,7 +413,7 @@ class Tensor {
   bool is_buffer_owner_;
   bool unused_;
   std::string name_;
-  const bool is_weight_;
+  bool is_weight_;
   float scale_;
   int32_t zero_point_;
diff --git a/mace/core/testing/test_benchmark_main.cc b/mace/core/testing/test_benchmark_main.cc
index 569a8345c147a763a2c2036b4ac082e60caed856..49c2632653626671aa2d4cfa410c79eed208df29 100644
--- a/mace/core/testing/test_benchmark_main.cc
+++ b/mace/core/testing/test_benchmark_main.cc
@@ -33,7 +33,8 @@ int main(int argc, char **argv) {
   // config runtime
   mace::MaceStatus status = mace::SetOpenMPThreadsAndAffinityPolicy(
       FLAGS_omp_num_threads,
-      static_cast<mace::CPUAffinityPolicy>(FLAGS_cpu_affinity_policy));
+      static_cast<mace::CPUAffinityPolicy>(FLAGS_cpu_affinity_policy),
+      true);
   if (status != mace::MACE_SUCCESS) {
     LOG(WARNING) << "Set openmp or cpu affinity failed.";
   }
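Note: is_weight_ loses its const qualifier above so that tests and benchmarks can flip the flag after a tensor is built (see the SetIsWeight calls in mace/ops/matmul_benchmark.cc further down). The flag matters because SGemm packs a constant operand once and then reuses the packed copy on every run. A minimal standalone sketch of that caching rule (illustrative names, not the MACE API):

    #include <vector>

    // Packs src once if it is constant; re-packs on every call otherwise.
    struct PackOnceCache {
      bool packed = false;
      std::vector<float> packed_data;

      void MaybePack(const std::vector<float> &src, bool is_const) {
        if (!is_const || !packed) {  // mirrors "!lhs.is_const() || !packed_"
          packed_data = src;         // stand-in for the real packing routine
          packed = true;
        }
      }
    };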
diff --git a/mace/kernels/gemm_test.cc b/mace/kernels/gemm_test.cc
index def7c417ce72abe14a5f2025b42b0299a22ec1a4..0942247d4b40b04578507bbf127073d8fca73194 100644
--- a/mace/kernels/gemm_test.cc
+++ b/mace/kernels/gemm_test.cc
@@ -74,25 +74,26 @@ void GemvTest(index_t batch, index_t N, index_t M) {
   }
 }

-void SGemmTest(index_t N,
+void SGemmTest(index_t batch,
+               index_t N,
                index_t K,
                index_t M,
                bool transpose_a,
                bool transpose_b) {
-  std::unique_ptr<float[]> A(new float[N * K]);
-  std::unique_ptr<float[]> B(new float[K * M]);
-  std::unique_ptr<float[]> C(new float[N * M]);
-  std::unique_ptr<float[]> C_ref(new float[N * M]);
+  std::unique_ptr<float[]> A(new float[batch * N * K]);
+  std::unique_ptr<float[]> B(new float[batch * K * M]);
+  std::unique_ptr<float[]> C(new float[batch * N * M]);
+  std::unique_ptr<float[]> C_ref(new float[batch * N * M]);

   std::random_device rd;
   std::mt19937 gen(rd());
   std::normal_distribution<float> nd(0, 1);
-  std::generate(A.get(), A.get() + N * K,
+  std::generate(A.get(), A.get() + batch * N * K,
                 [&gen, &nd] { return nd(gen); });
-  std::generate(B.get(), B.get() + K * M,
+  std::generate(B.get(), B.get() + batch * K * M,
                 [&gen, &nd] { return nd(gen); });

-  kernels::GemmRef(A.get(), B.get(), 1, N, K, M, C_ref.get(), transpose_a,
+  kernels::GemmRef(A.get(), B.get(), batch, N, K, M, C_ref.get(), transpose_a,
                    transpose_b);

   kernels::MatrixMap<const float> matrix_a;
@@ -100,22 +101,38 @@ void SGemmTest(index_t N,

   if (!transpose_a) {
     matrix_a =
-        kernels::MatrixMap<const float>(N, K, kernels::RowMajor, A.get());
+        kernels::MatrixMap<const float>(batch,
+                                        N,
+                                        K,
+                                        kernels::RowMajor,
+                                        A.get());
   } else {
     matrix_a =
-        kernels::MatrixMap<const float>(K, N, kernels::RowMajor, A.get());
+        kernels::MatrixMap<const float>(batch,
+                                        K,
+                                        N,
+                                        kernels::RowMajor,
+                                        A.get());
     matrix_a = matrix_a.transpose();
   }
   if (!transpose_b) {
     matrix_b =
-        kernels::MatrixMap<const float>(K, M, kernels::RowMajor, B.get());
+        kernels::MatrixMap<const float>(batch,
+                                        K,
+                                        M,
+                                        kernels::RowMajor,
+                                        B.get());
   } else {
     matrix_b =
-        kernels::MatrixMap<const float>(M, K, kernels::RowMajor, B.get());
+        kernels::MatrixMap<const float>(batch,
+                                        M,
+                                        K,
+                                        kernels::RowMajor,
+                                        B.get());
     matrix_b = matrix_b.transpose();
   }
-  kernels::MatrixMap<float> matrix_c(N, M, kernels::RowMajor, C.get());
+  kernels::MatrixMap<float> matrix_c(batch, N, M, kernels::RowMajor, C.get());

   kernels::SGemm sgemm;
   sgemm(matrix_a, matrix_b, &matrix_c);
@@ -168,26 +185,21 @@ TEST(GEMMTest, gemv) {
 }

 namespace {
-void TestSGemmTranspose(index_t N, index_t K, index_t M) {
-  SGemmTest(N, K, M, false, false);
-  SGemmTest(N, K, M, true, false);
-  SGemmTest(N, K, M, false, true);
-  SGemmTest(N, K, M, true, true);
+void TestSGemmTranspose(index_t batch, index_t N, index_t K, index_t M) {
+  SGemmTest(batch, N, K, M, false, false);
+  SGemmTest(batch, N, K, M, true, false);
+  SGemmTest(batch, N, K, M, false, true);
+  SGemmTest(batch, N, K, M, true, true);
 }
 }

-TEST(SGEMMTest, AlignedWithoutBatch) {
-  TestSGemmTranspose(4, 4, 4);
-  TestSGemmTranspose(8, 8, 8);
-  TestSGemmTranspose(16, 16, 16);
-}
-
 TEST(SGEMMTest, UnalignedWithoutBatch) {
   std::vector<index_t> tests{1, 5, 14, 31, 47};
   for (index_t N : tests) {
     for (index_t K : tests) {
       for (index_t M : tests) {
-        TestSGemmTranspose(N, K, M);
+        TestSGemmTranspose(1, N, K, M);
+        TestSGemmTranspose(16, N, K, M);
       }
     }
   }
diff --git a/mace/kernels/matmul.h b/mace/kernels/matmul.h
index 4b6c5cf1ef8309281178fd52f545556b87a80190..9c5292d2ac9f74cf29de36c2dc0f75502e875cdd 100644
--- a/mace/kernels/matmul.h
+++ b/mace/kernels/matmul.h
@@ -32,6 +32,7 @@
 #include "mace/kernels/kernel.h"
 #include "mace/utils/utils.h"
 #include "mace/kernels/gemmlowp_util.h"
+#include "mace/kernels/sgemm.h"

 #ifdef MACE_ENABLE_OPENCL
 #include "mace/core/runtime/opencl/cl2_header.h"
@@ -83,39 +84,34 @@ struct MatMulFunctor : OpKernel {
     const T *b_ptr_base = B->data<T>();
     T *c_ptr_base = C->mutable_data<T>();

-    memset(c_ptr_base, 0, batch * height * width * sizeof(T));
-
-    if (height == 1 && width > 1 && B->is_weight()) {
-      // A * B = (B^T * A^T)^T
-      if (!transpose_b) {
-        if (B_transpose_.get() == nullptr) {
-          B_transpose_.reset(new Tensor(context_->device()->allocator(),
-                                        DataTypeToEnum<T>::v()));
-          B_transpose_->Resize({batch, width, K});
-          Tensor::MappingGuard guardbt(B_transpose_.get());
-          T *bt_ptr_base = B_transpose_->mutable_data<T>();
-          Transpose(b_ptr_base, K, width, width, bt_ptr_base);
-        }
-        Tensor::MappingGuard guardbt(B_transpose_.get());
-        T *bt_ptr_base = B_transpose_->mutable_data<T>();
-        Gemv(bt_ptr_base, a_ptr_base, batch, K, width, c_ptr_base);
-      } else {
-        Gemv(b_ptr_base, a_ptr_base, batch, K, width, c_ptr_base);
-      }
-    } else {
-      Gemm(a_ptr_base, b_ptr_base, batch, height, K, width, c_ptr_base,
-           transpose_a, transpose_b);
-    }
-
+    const index_t height_a = A->dim(rank - 2);
+    const index_t width_a = A->dim(rank - 1);
+    const index_t height_b = B->dim(rank - 2);
+    const index_t width_b = B->dim(rank - 1);
+
+    sgemm_.Run(a_ptr_base,
+               b_ptr_base,
+               batch,
+               height_a,
+               width_a,
+               height_b,
+               width_b,
+               transpose_a,
+               transpose_b,
+               A->is_weight(),
+               B->is_weight(),
+               c_ptr_base,
+               context_->workspace()->GetScratchBuffer(D));

     return MACE_SUCCESS;
   }

-  std::unique_ptr<Tensor> B_transpose_;
+  SGemm sgemm_;
 };

 template <>
 struct MatMulFunctor<DeviceType::CPU, uint8_t> : OpKernel {
   explicit MatMulFunctor(OpKernelContext *context) : OpKernel(context) {}

+  template <gemmlowp::MapOrder AOrder, gemmlowp::MapOrder BOrder>
   void MatMulImpl(const Tensor *A,
                   const Tensor *B,
@@ -213,6 +209,7 @@ struct MatMulFunctor<DeviceType::CPU, uint8_t> : OpKernel {
 template <typename T>
 struct MatMulFunctor<DeviceType::GPU, T> : OpKernel {
   explicit MatMulFunctor(OpKernelContext *context) : OpKernel(context) {}
+
   MaceStatus operator()(const Tensor *A,
                         const Tensor *B,
                         Tensor *C,
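Note: the CPU float path above now hands raw pointers plus the last two dimensions of A and B to SGemm::Run and lets it derive the output shape from the transpose flags. A standalone sketch of that shape rule (illustrative names, not the MACE API):

    #include <cstdint>
    #include <utility>

    typedef int64_t index_t;

    // C = op(A) * op(B), where op(X) is X or X^T depending on the flag.
    std::pair<index_t, index_t> OutputShape(index_t height_a, index_t width_a,
                                            index_t height_b, index_t width_b,
                                            bool transpose_a, bool transpose_b) {
      const index_t height_c = transpose_a ? width_a : height_a;
      const index_t width_c = transpose_b ? height_b : width_b;
      return std::make_pair(height_c, width_c);
    }
    // e.g. OutputShape(3, 2, 3, 5, true, false) yields {2, 5} for A^T * B.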
diff --git a/mace/kernels/matmul_benchmark.cc b/mace/kernels/matmul_benchmark.cc
index 0b78945637d93f96e13132322ed49a46328e68eb..be76a88ec9ea7efe7cfd3ac1e93dc313b8141dcf 100644
--- a/mace/kernels/matmul_benchmark.cc
+++ b/mace/kernels/matmul_benchmark.cc
@@ -114,9 +114,11 @@ void MatmulBenchmark_Mace_SGemm(int iters, int m, int k, int n) {
   std::vector<float> rhs(k * n);
   std::vector<float> result(m * n);

-  kernels::MatrixMap<const float> matrix_lhs(m, k, RowMajor, lhs.data(), true);
-  kernels::MatrixMap<const float> matrix_rhs(k, n, RowMajor, rhs.data(), true);
-  kernels::MatrixMap<float> matrix_result(m, n, RowMajor, result.data());
+  kernels::MatrixMap<const float> matrix_lhs(1, m, k, RowMajor, lhs.data(),
+                                             true);
+  kernels::MatrixMap<const float> matrix_rhs(1, k, n, RowMajor, rhs.data(),
+                                             true);
+  kernels::MatrixMap<float> matrix_result(1, m, n, RowMajor, result.data());

   kernels::SGemm sgemm;
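Note: MatrixMap now takes a leading batch count as its first constructor argument (1 for a single matrix, as in the benchmark above), with batches laid out contiguously. A standalone sketch of the addressing that MatrixMap::batch_data() in mace/kernels/sgemm.h implements:

    #include <cstdint>

    typedef int64_t index_t;

    // Batch b of a contiguous [batch, row, col] buffer starts at b * row * col.
    const float *BatchData(const float *data, index_t b,
                           index_t row, index_t col) {
      return data + b * row * col;
    }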
diff --git a/mace/kernels/sgemm.cc b/mace/kernels/sgemm.cc
index 49c2717815c49d6fcc96ba7f5ef65d08cade8f3b..e1be6f99cb1ea81072b07c85b208ef53153ab1c5 100644
--- a/mace/kernels/sgemm.cc
+++ b/mace/kernels/sgemm.cc
@@ -12,9 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include <memory>
+
 #include "mace/kernels/sgemm.h"
+#include "mace/core/runtime/cpu/cpu_runtime.h"

-#include <memory>
 #if defined(MACE_ENABLE_NEON)
 #include <arm_neon.h>
@@ -29,29 +31,118 @@ namespace kernels {

 void SGemm::operator()(const MatrixMap<const float> &lhs,
                        const MatrixMap<const float> &rhs,
-                       MatrixMap<float> *result) {
+                       MatrixMap<float> *result,
+                       ScratchBuffer *scratch_buffer) {
   if (rhs.col() < lhs.row()) {
     MatrixMap<const float> lhs_transpose = lhs.transpose();
     MatrixMap<const float> rhs_transpose = rhs.transpose();
     MatrixMap<float> result_transpose = result->transpose();
-    return operator()(rhs_transpose, lhs_transpose, &result_transpose);
+    return operator()(rhs_transpose,
+                      lhs_transpose,
+                      &result_transpose,
+                      scratch_buffer);
   }

-  if (!packed_ || !lhs.is_const()) {
-    PackLhs(lhs, &packed_lhs_);
+  if (scratch_buffer != nullptr) {
+    scratch_buffer->Rewind();
+    index_t total_size = result->size();
+    if (!lhs.is_const()) {
+      total_size += lhs.size();
+    }
+    if (!rhs.is_const()) {
+      total_size += rhs.size();
+    }
+    scratch_buffer->GrowSize(total_size * sizeof(float));
+
+    scratch_buffer->Rewind();
+    if (!lhs.is_const()) {
+      packed_lhs_.reset(new Tensor(scratch_buffer->Scratch(
+          lhs.size() * sizeof(float)), DT_FLOAT));
+    }
+    if (!rhs.is_const()) {
+      packed_rhs_.reset(new Tensor(scratch_buffer->Scratch(
+          rhs.size() * sizeof(float)), DT_FLOAT));
+    }
+    packed_result_.reset(new Tensor(scratch_buffer->Scratch(
+        result->size() * sizeof(float)), DT_FLOAT));
+  }
+
+  if (packed_lhs_.get() == nullptr) {
+    packed_lhs_.reset(new Tensor(GetCPUAllocator(), DT_FLOAT));
+    packed_lhs_->Resize({lhs.size()});
+  }
+  if (packed_rhs_.get() == nullptr) {
+    packed_rhs_.reset(new Tensor(GetCPUAllocator(), DT_FLOAT));
+    packed_rhs_->Resize({rhs.size()});
   }
-  if (!packed_ || !rhs.is_const()) {
-    PackRhs(rhs, &packed_rhs_);
+  if (packed_result_.get() == nullptr) {
+    packed_result_.reset(new Tensor(GetCPUAllocator(), DT_FLOAT));
+    packed_result_->Resize({result->size()});
+  }
+
+  if (!lhs.is_const() || !packed_) {
+    PackLhs(lhs, packed_lhs_.get());
+  }
+  if (!rhs.is_const() || !packed_) {
+    PackRhs(rhs, packed_rhs_.get());
   }
   packed_ = true;

-  operator()(packed_lhs_,
-             packed_rhs_,
-             lhs.row(),
-             lhs.col(),
-             rhs.col(),
-             &packed_result_);
-  UnPack(packed_result_, result);
+  RunInternal(*packed_lhs_,
+              *packed_rhs_,
+              lhs.batch(),
+              lhs.row(),
+              lhs.col(),
+              rhs.col(),
+              packed_result_.get());
+
+  UnPack(*packed_result_, result);
+}
+
+void SGemm::Run(const float *A,
+                const float *B,
+                const index_t batch,
+                const index_t height_a,
+                const index_t width_a,
+                const index_t height_b,
+                const index_t width_b,
+                const bool transpose_a,
+                const bool transpose_b,
+                const bool is_a_weight,
+                const bool is_b_weight,
+                float *C,
+                ScratchBuffer *scratch_buffer) {
+  index_t height_c = height_a;
+  index_t width_c = width_b;
+  if (transpose_a) {
+    height_c = width_a;
+  }
+  if (transpose_b) {
+    width_c = height_b;
+  }
+
+  MatrixMap<const float> matrix_a =
+      MatrixMap<const float>(batch,
+                             height_a,
+                             width_a,
+                             kernels::RowMajor,
+                             A,
+                             is_a_weight);
+  MatrixMap<const float> matrix_b =
+      kernels::MatrixMap<const float>(batch,
+                                      height_b,
+                                      width_b,
+                                      kernels::RowMajor,
+                                      B,
+                                      is_b_weight);
+  if (transpose_a) {
+    matrix_a = matrix_a.transpose();
+  }
+  if (transpose_b) {
+    matrix_b = matrix_b.transpose();
+  }
+  MatrixMap<float> matrix_c(batch, height_c, width_c, kernels::RowMajor, C);
+  operator()(matrix_a, matrix_b, &matrix_c, scratch_buffer);
 }

 #if defined(MACE_ENABLE_NEON)
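Note: a usage sketch of the Run() entry point added above, written against the signature in mace/kernels/sgemm.h (the batch/shape values are illustrative, and the buffers are assumed to be caller-owned and correctly sized):

    #include "mace/kernels/sgemm.h"

    // C = A * B^T for 16 batches; B is marked as a weight so its packed
    // form is cached across calls.
    void RunBatchedMatMul(const float *a_data, const float *b_data,
                          float *c_data) {
      mace::kernels::SGemm sgemm;
      sgemm.Run(a_data, b_data,
                /*batch=*/16,
                /*height_a=*/128, /*width_a=*/49,
                /*height_b=*/128, /*width_b=*/49,
                /*transpose_a=*/false, /*transpose_b=*/true,
                /*is_a_weight=*/false, /*is_b_weight=*/true,
                c_data,
                /*scratch_buffer=*/nullptr);  // fall back to member tensors
    }

A long-lived SGemm instance (like the sgemm_ member in MatMulFunctor) is what makes the weight caching pay off; a fresh instance per call re-packs everything.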
@@ -141,17 +232,43 @@ void SGemm::operator()(const MatrixMap<const float> &lhs,
 #endif  // __aarch64__
 #endif  // MACE_ENABLE_NEON

-void SGemm::operator()(const PackedBlock<float> &lhs,
-                       const PackedBlock<float> &rhs,
-                       const index_t height,
-                       const index_t depth,
-                       const index_t width,
-                       PackedBlock<float> *result) {
-  result->tensor()->Resize({height * width});
-  const float *lhs_data = lhs.data();
-  const float *rhs_data = rhs.data();
-  float *result_data = result->mutable_data();
+void SGemm::RunInternal(const PackedBlock &lhs,
+                        const PackedBlock &rhs,
+                        const index_t batch,
+                        const index_t height,
+                        const index_t depth,
+                        const index_t width,
+                        PackedBlock *result) {
+  const float *lhs_data = lhs.data<float>();
+  const float *rhs_data = rhs.data<float>();
+  float *result_data = result->mutable_data<float>();
+
+#define MACE_SGEMM_RUN_PER_BATCH                      \
+  for (index_t b = 0; b < batch; ++b) {               \
+    RunPerBatch(lhs_data + b * height * depth,        \
+                rhs_data + b * depth * width,         \
+                height,                               \
+                depth,                                \
+                width,                                \
+                result_data + b * height * width);    \
+  }
+
+  if (batch >= MaceOpenMPThreadCount) {
+#pragma omp parallel for
+    MACE_SGEMM_RUN_PER_BATCH
+  } else {
+    MACE_SGEMM_RUN_PER_BATCH
+  }
+
+#undef MACE_SGEMM_RUN_PER_BATCH
+}

+void SGemm::RunPerBatch(const float *lhs_data,
+                        const float *rhs_data,
+                        const index_t height,
+                        const index_t depth,
+                        const index_t width,
+                        float *result_data) {
 #if defined(MACE_ENABLE_NEON)
   const index_t block_w = width >> 2;
   const index_t remain_w = width - (block_w << 2);
@@ -508,11 +625,10 @@ void SGemm::operator()(const PackedBlock<float> &lhs,

       float32x4_t c0 = vdupq_n_f32(0.f);

-#if defined(__aarch64__)
+      // d: 8
       block_d = remain_d >> 3;
       remain_d -= (block_d << 3);

-      // d: 8
       for (index_t bd = 0; bd < block_d; ++bd) {
         // 1.8.4
         float32x4_t a0, a1;
@@ -535,7 +651,6 @@ void SGemm::operator()(const PackedBlock<float> &lhs,
         lhs_ptr += 8;
         rhs_ptr += 32;
       }
-#endif  // __aarch64__

       block_d = remain_d >> 2;
       remain_d -= (block_d << 2);
@@ -608,7 +723,7 @@ void SGemm::operator()(const PackedBlock<float> &lhs,

       index_t remain_d = depth;

-      float32x4_t c0, c1, c2, c3, c4, c5, c6, c7;
+      float32x4_t c0, c1;
       c0 = vdupq_n_f32(0.f);
       c1 = vdupq_n_f32(0.f);
@@ -787,93 +902,74 @@ void SGemm::operator()(const PackedBlock<float> &lhs,
 }

 void SGemm::PackLhs(const MatrixMap<const float> &lhs,
-                    PackedBlock<float> *packed_block) {
+                    PackedBlock *packed_block) {
   Pack(lhs, PackOrder::ColMajor, packed_block);
 }

 void SGemm::PackRhs(const MatrixMap<const float> &rhs,
-                    PackedBlock<float> *packed_block) {
+                    PackedBlock *packed_block) {
   Pack(rhs, PackOrder::RowMajor, packed_block);
 }

-void SGemm::UnPack(const PackedBlock<float> &packed_result,
+void SGemm::Pack(const MatrixMap<const float> &src,
+                 const PackOrder order,
+                 PackedBlock *packed_block) {
+  MACE_CHECK_NOTNULL(packed_block);
+
+  const index_t height = src.row();
+  const index_t width = src.col();
+  auto packed_data = packed_block->mutable_data<float>();
+
+#define MACE_SGEMM_PACK_PER_BATCH                                    \
+  for (index_t b = 0; b < src.batch(); ++b) {                        \
+    PackPerBatch(src, order, b, packed_data + b * height * width);   \
+  }
+
+  if (src.batch() >= MaceOpenMPThreadCount) {
+#pragma omp parallel for
+    MACE_SGEMM_PACK_PER_BATCH
+  } else {
+    MACE_SGEMM_PACK_PER_BATCH
+  }
+#undef MACE_SGEMM_PACK_PER_BATCH
+}
+
+void SGemm::UnPack(const PackedBlock &packed_result,
                    MatrixMap<float> *matrix_map) {
   MACE_CHECK_NOTNULL(matrix_map);

   const index_t height = matrix_map->row();
   const index_t width = matrix_map->col();
-  auto packed_data = packed_result.data();
-  auto unpacked_data = matrix_map->data();
+  auto packed_data = packed_result.data<float>();

-  if (matrix_map->major() == Major::RowMajor) {
-    // This is for non-transposed result
-    index_t w = 0;
-#if defined(MACE_ENABLE_NEON)
-  #pragma omp parallel for
-    for (index_t iw = w; iw <= width - 4; iw += 4) {
-      const float *packed_data_ptr = packed_data + iw * height;
-      float *unpacked_data_ptr = unpacked_data + iw;
-      for (index_t h = 0; h < height; ++h) {
-        const index_t packed_offset = h * 4;
-        const index_t unpacked_offset = h * width;
-        float32x4_t vs = vld1q_f32(packed_data_ptr + packed_offset);
-        vst1q_f32(unpacked_data_ptr + unpacked_offset, vs);
-      }
-    }
-    w += (width - w) / 4 * 4;
-#endif
+#define MACE_SGEMM_UNPACK_PER_BATCH                                  \
+  for (index_t b = 0; b < matrix_map->batch(); ++b) {                \
+    UnPackPerBatch(packed_data + b * height * width, b, matrix_map); \
+  }
+
+  if (matrix_map->batch() >= MaceOpenMPThreadCount) {
 #pragma omp parallel for
-    for (index_t iw = w; iw < width; ++iw) {
-      const float *packed_data_ptr = packed_data + iw * height;
-      float *unpacked_data_ptr = unpacked_data + iw;
-      for (index_t h = 0; h < height; ++h) {
-        unpacked_data_ptr[h * width] = packed_data_ptr[h];
-      }
-    }
+    MACE_SGEMM_UNPACK_PER_BATCH
   } else {
-    // This is for transposed result
-    index_t w = 0;
-#if defined(MACE_ENABLE_NEON)
-  #pragma omp parallel for
-    for (index_t iw = w; iw <= width - 4; iw += 4) {
-      const float *packed_data_ptr = packed_data + iw * height;
-      float *unpacked_data_ptr = unpacked_data + iw * height;
-      for (index_t h = 0; h < height; ++h) {
-        const index_t packed_offset = h * 4;
-        const index_t unpacked_offset = h;
-        float32x4_t vs = vld1q_f32(packed_data_ptr + packed_offset);
-        unpacked_data_ptr[unpacked_offset] = vs[0];
-        unpacked_data_ptr[unpacked_offset + height] = vs[1];
-        unpacked_data_ptr[unpacked_offset + 2 * height] = vs[2];
-        unpacked_data_ptr[unpacked_offset + 3 * height] = vs[3];
-      }
-    }
-    w += (width - w) / 4 * 4;
-#endif
-#pragma omp parallel for
-    for (index_t iw = w; iw < width; ++iw) {
-      std::copy_n(
-          packed_data + iw * height, height, unpacked_data + iw * height);
-    }
+    MACE_SGEMM_UNPACK_PER_BATCH
   }
+#undef MACE_SGEMM_UNPACK_PER_BATCH
 }
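Note: RunInternal, Pack, and UnPack above all share one dispatch rule: parallelize across batches only when there are at least MaceOpenMPThreadCount of them, and otherwise run the batches serially so the OpenMP threads stay available to the per-batch inner loops. A standalone sketch of the pattern (thread_count stands in for MaceOpenMPThreadCount):

    #include <cstdint>

    typedef int64_t index_t;

    template <typename Fn>
    void ForEachBatch(index_t batch, index_t thread_count, Fn fn) {
      if (batch >= thread_count) {
    #pragma omp parallel for
        for (index_t b = 0; b < batch; ++b) {
          fn(b);  // whole batches run in parallel
        }
      } else {
        for (index_t b = 0; b < batch; ++b) {
          fn(b);  // inner kernels use the threads instead
        }
      }
    }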
-void SGemm::Pack(const MatrixMap<const float> &src,
-                 const PackOrder order,
-                 PackedBlock<float> *packed_block) {
-  MACE_CHECK_NOTNULL(packed_block);
+void SGemm::PackPerBatch(const MatrixMap<const float> &src,
+                         const PackOrder order,
+                         const index_t batch_index,
+                         float *packed_data) {
+  MACE_CHECK_NOTNULL(packed_data);

   const index_t height = src.row();
   const index_t width = src.col();
-  packed_block->tensor()->Resize({height * width});
-  auto src_data = src.data();
-  auto packed_data = packed_block->mutable_data();
+  auto src_data = src.batch_data(batch_index);

   if (src.major() == Major::RowMajor && order == PackOrder::ColMajor) {
     // This is for packing no-transpose lhs.
     index_t h = 0;
 #if defined(MACE_ENABLE_NEON)
-  #if defined(__aarch64__)
+#if defined(__aarch64__)
 #pragma omp parallel for
     for (index_t ih = h; ih <= height - 8; ih += 8) {
       const float *src_data_ptr = src_data + ih * width;
@@ -919,7 +1015,7 @@ void SGemm::Pack(const MatrixMap<const float> &src,
     // This is for packing transpose-needed lhs.
     index_t h = 0;
 #if defined(MACE_ENABLE_NEON)
-  #if defined(__aarch64__)
+#if defined(__aarch64__)
 #pragma omp parallel for
     for (index_t ih = h; ih <= height - 8; ih += 8) {
       const float *src_data_ptr = src_data + ih;
@@ -960,7 +1056,7 @@ void SGemm::Pack(const MatrixMap<const float> &src,
     // This is for packing no-transpose rhs.
     index_t w = 0;
 #if defined(MACE_ENABLE_NEON)
-  #pragma omp parallel for
+#pragma omp parallel for
     for (index_t iw = w; iw <= width - 4; iw += 4) {
       const float *src_data_ptr = src_data + iw;
       float *packed_data_ptr = packed_data + iw * height;
@@ -985,7 +1081,7 @@ void SGemm::Pack(const MatrixMap<const float> &src,
     // This is for packing transpose-needed rhs.
     index_t w = 0;
 #if defined(MACE_ENABLE_NEON)
-  #pragma omp parallel for
+#pragma omp parallel for
     for (index_t iw = w; iw <= width - 4; iw += 4) {
       const float *src_data_ptr = src_data + iw * height;
       float *packed_data_ptr = packed_data + iw * height;
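Note: PackPerBatch keeps the pre-existing panel layout, only re-based onto a single batch's slice via batch_data(): with NEON, lhs rows are grouped in panels of 4 (8 on aarch64) written depth-major, and rhs columns in panels of 4. A standalone scalar sketch of the 4-row lhs panel layout the inner kernels consume (my reading of the code above, simplified):

    #include <cstdint>

    typedef int64_t index_t;

    // Row-major src (height x width) -> panels of 4 rows, each panel stored
    // column-by-column; leftover rows are appended row-major.
    void PackLhsPanels(const float *src, index_t height, index_t width,
                       float *packed) {
      index_t h = 0;
      for (; h + 4 <= height; h += 4) {
        for (index_t w = 0; w < width; ++w) {
          for (index_t r = 0; r < 4; ++r) {
            *packed++ = src[(h + r) * width + w];
          }
        }
      }
      for (; h < height; ++h) {
        for (index_t w = 0; w < width; ++w) {
          *packed++ = src[h * width + w];
        }
      }
    }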
@@ -1008,5 +1104,67 @@ void SGemm::Pack(const MatrixMap<const float> &src,
   }
 }

+void SGemm::UnPackPerBatch(const float *packed_data,
+                           const index_t batch_index,
+                           MatrixMap<float> *matrix_map) {
+  MACE_CHECK_NOTNULL(matrix_map);
+
+  const index_t height = matrix_map->row();
+  const index_t width = matrix_map->col();
+  auto unpacked_data = matrix_map->batch_data(batch_index);
+
+  if (matrix_map->major() == Major::RowMajor) {
+    // This is for non-transposed result
+    index_t w = 0;
+#if defined(MACE_ENABLE_NEON)
+#pragma omp parallel for
+    for (index_t iw = w; iw <= width - 4; iw += 4) {
+      const float *packed_data_ptr = packed_data + iw * height;
+      float *unpacked_data_ptr = unpacked_data + iw;
+      for (index_t h = 0; h < height; ++h) {
+        const index_t packed_offset = h * 4;
+        const index_t unpacked_offset = h * width;
+        float32x4_t vs = vld1q_f32(packed_data_ptr + packed_offset);
+        vst1q_f32(unpacked_data_ptr + unpacked_offset, vs);
+      }
+    }
+    w += (width - w) / 4 * 4;
+#endif
+#pragma omp parallel for
+    for (index_t iw = w; iw < width; ++iw) {
+      const float *packed_data_ptr = packed_data + iw * height;
+      float *unpacked_data_ptr = unpacked_data + iw;
+      for (index_t h = 0; h < height; ++h) {
+        unpacked_data_ptr[h * width] = packed_data_ptr[h];
+      }
+    }
+  } else {
+    // This is for transposed result
+    index_t w = 0;
+#if defined(MACE_ENABLE_NEON)
+#pragma omp parallel for
+    for (index_t iw = w; iw <= width - 4; iw += 4) {
+      const float *packed_data_ptr = packed_data + iw * height;
+      float *unpacked_data_ptr = unpacked_data + iw * height;
+      for (index_t h = 0; h < height; ++h) {
+        const index_t packed_offset = h * 4;
+        const index_t unpacked_offset = h;
+        float32x4_t vs = vld1q_f32(packed_data_ptr + packed_offset);
+        unpacked_data_ptr[unpacked_offset] = vs[0];
+        unpacked_data_ptr[unpacked_offset + height] = vs[1];
+        unpacked_data_ptr[unpacked_offset + 2 * height] = vs[2];
+        unpacked_data_ptr[unpacked_offset + 3 * height] = vs[3];
+      }
+    }
+    w += (width - w) / 4 * 4;
+#endif
+#pragma omp parallel for
+    for (index_t iw = w; iw < width; ++iw) {
+      std::copy_n(
+          packed_data + iw * height, height, unpacked_data + iw * height);
+    }
+  }
+}
+
 }  // namespace kernels
 }  // namespace mace
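Note: the scratch-buffer path in SGemm::operator() above sizes its request from a simple rule: constant (weight) operands keep their packed copies in the member tensors, so scratch only has to hold the packed result plus any non-const operand. A standalone sketch:

    #include <cstddef>
    #include <cstdint>

    typedef int64_t index_t;

    size_t ScratchBytes(index_t lhs_size, bool lhs_is_const,
                        index_t rhs_size, bool rhs_is_const,
                        index_t result_size) {
      index_t total = result_size;
      if (!lhs_is_const) total += lhs_size;  // non-const lhs is re-packed per run
      if (!rhs_is_const) total += rhs_size;  // likewise for rhs
      return static_cast<size_t>(total) * sizeof(float);
    }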
diff --git a/mace/kernels/sgemm.h b/mace/kernels/sgemm.h
index daed206a349d86d119637b6d367b8cfa0d8dc5c8..c24c9c03e55047e157208caed5139e3a1b93380f 100644
--- a/mace/kernels/sgemm.h
+++ b/mace/kernels/sgemm.h
@@ -39,11 +39,13 @@ class MatrixMap {
  public:
   MatrixMap() {}

-  MatrixMap(const index_t row,
+  MatrixMap(const index_t batch,
+            const index_t row,
             const index_t col,
             const Major major,
             T *data,
             const bool is_const = false) :
+      batch_(batch),
       row_(row),
       col_(col),
       stride_(major == RowMajor ? col : row),
@@ -53,7 +55,11 @@ class MatrixMap {

   MatrixMap transpose() const {
     Major transpose_major = major_ == RowMajor ? ColMajor : RowMajor;
-    return MatrixMap(col_, row_, transpose_major, data_, is_const_);
+    return MatrixMap(batch_, col_, row_, transpose_major, data_, is_const_);
+  }
+
+  index_t batch() const {
+    return batch_;
   }

   index_t row() const {
@@ -76,8 +82,12 @@ class MatrixMap {
     return data_;
   }

-  T *data(int row, int col) const {
-    return data_ + row * stride_ + col;
+  T *batch_data(index_t batch) const {
+    return data_ + batch * row_ * col_;
+  }
+
+  index_t size() const {
+    return batch_ * row_ * col_;
   }

   bool is_const() const {
@@ -85,6 +95,7 @@ class MatrixMap {
   }

  private:
+  index_t batch_;
   index_t row_;
   index_t col_;
   index_t stride_;
@@ -94,61 +105,76 @@ class MatrixMap {
 };

 typedef Major PackOrder;
-
-template <typename T>
-class PackedBlock {
- public:
-  PackedBlock() : data_tensor_(GetCPUAllocator(),
-                               DataTypeToEnum<T>::v()) {}
-
-  const T *data() const {
-    return data_tensor_.data<T>();
-  }
-
-  T *mutable_data() {
-    return data_tensor_.mutable_data<T>();
-  }
-
-  Tensor *tensor() {
-    return &data_tensor_;
-  }
-
- private:
-  Tensor data_tensor_;
-};
+typedef Tensor PackedBlock;

 class SGemm {
  public:
-  SGemm(): packed_(false) {}
+  SGemm()
+      : packed_lhs_(nullptr),
+        packed_rhs_(nullptr),
+        packed_(false) {}

   void operator()(const MatrixMap<const float> &lhs,
                   const MatrixMap<const float> &rhs,
-                  MatrixMap<float> *result);
-
-  void operator()(const PackedBlock<float> &lhs,
-                  const PackedBlock<float> &rhs,
-                  const index_t height,
-                  const index_t depth,
-                  const index_t width,
-                  PackedBlock<float> *result);
+                  MatrixMap<float> *result,
+                  ScratchBuffer *scratch_buffer = nullptr);
+
+  void Run(const float *A,
+           const float *B,
+           const index_t batch,
+           const index_t height_a,
+           const index_t width_a,
+           const index_t height_b,
+           const index_t width_b,
+           const bool transpose_a,
+           const bool transpose_b,
+           const bool is_a_weight,
+           const bool is_b_weight,
+           float *C,
+           ScratchBuffer *scratch_buffer = nullptr);

   void PackLhs(const MatrixMap<const float> &lhs,
-               PackedBlock<float> *packed_block);
+               PackedBlock *packed_block);

   void PackRhs(const MatrixMap<const float> &rhs,
-               PackedBlock<float> *packed_block);
+               PackedBlock *packed_block);

-  void UnPack(const PackedBlock<float> &packed_result,
+  void UnPack(const PackedBlock &packed_result,
               MatrixMap<float> *matrix_map);

  private:
   void Pack(const MatrixMap<const float> &src,
             const PackOrder order,
-            PackedBlock<float> *packed_block);
+            PackedBlock *packed_block);
+
+  void PackPerBatch(const MatrixMap<const float> &src,
+                    const PackOrder order,
+                    const index_t batch_index,
+                    float *packed_data);
+
+  void UnPackPerBatch(const float *packed_data,
+                      const index_t batch_index,
+                      MatrixMap<float> *matrix_map);
+
+  void RunInternal(const PackedBlock &lhs,
+                   const PackedBlock &rhs,
+                   const index_t batch,
+                   const index_t height,
+                   const index_t depth,
+                   const index_t width,
+                   PackedBlock *result);
+
+  void RunPerBatch(const float *lhs,
+                   const float *rhs,
+                   const index_t height,
+                   const index_t depth,
+                   const index_t width,
+                   float *result);
+
+  std::unique_ptr<Tensor> packed_lhs_;
+  std::unique_ptr<Tensor> packed_rhs_;
+  std::unique_ptr<Tensor> packed_result_;

-  PackedBlock<float> packed_lhs_;
-  PackedBlock<float> packed_rhs_;
-  PackedBlock<float> packed_result_;
   bool packed_;
 };
diff --git a/mace/kernels/sgemm_test.cc b/mace/kernels/sgemm_pack_test.cc
similarity index 89%
rename from mace/kernels/sgemm_test.cc
rename to mace/kernels/sgemm_pack_test.cc
index 095ea1b185cdb78078062781b1b62ed7b6cdbcc2..3e7aaa98e362896eb9da1452b589f3d678a56c54 100644
--- a/mace/kernels/sgemm_test.cc
+++ b/mace/kernels/sgemm_pack_test.cc
@@ -31,17 +31,17 @@ void TestPack(const std::vector<float> &data,
               Major src_order,
               PackOrder pack_order) {
   SGemm sg;
-  MatrixMap<const float> src_matrix(height, width, src_order, data.data());
-  PackedBlock<float> packed;
-  packed.tensor()->Resize({height, width});
+  MatrixMap<const float> src_matrix(1, height, width, src_order, data.data());
+  PackedBlock packed;
+  packed.Resize({height, width});
   if (pack_order == PackOrder::ColMajor) {
     sg.PackLhs(src_matrix, &packed);
   } else {
     sg.PackRhs(src_matrix, &packed);
   }
-  auto packed_data = packed.data();
-  for (index_t i = 0; i < packed.tensor()->size(); ++i) {
+  auto packed_data = packed.data<float>();
+  for (index_t i = 0; i < packed.size(); ++i) {
     EXPECT_EQ(expected_data[i], packed_data[i]);
   }
 }
@@ -57,9 +57,9 @@ void TestUnPack(const index_t height,
     data[i] = rand_r(&seed);
   }

-  MatrixMap<const float> src_matrix(height, width, src_order, data.data());
-  PackedBlock<float> packed;
-  packed.tensor()->Resize({height, width});
+  MatrixMap<const float> src_matrix(1, height, width, src_order, data.data());
+  PackedBlock packed;
+  packed.Resize({height, width});
   SGemm sg;
   if (pack_order == PackOrder::ColMajor) {
     sg.PackLhs(src_matrix, &packed);
@@ -68,17 +68,18 @@ void TestUnPack(const index_t height,
   }

   std::vector<float> unpacked(matrix_size);
-  MatrixMap<float> unpacked_matrix(height, width, src_order, unpacked.data());
+  MatrixMap<float>
+      unpacked_matrix(1, height, width, src_order, unpacked.data());
   sg.UnPack(packed, &unpacked_matrix);
   auto unpacked_data = unpacked.data();
-  for (index_t i = 0; i < packed.tensor()->size(); ++i) {
+  for (index_t i = 0; i < packed.size(); ++i) {
     EXPECT_EQ(data[i], unpacked_data[i]);
   }
 }
 }  // namespace

-TEST(SGemmTest, Pack) {
+TEST(SGemmPackTest, Pack) {
   std::vector<float> data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
                              11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
                              21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
                              31, 32, 33, 34, 35, 36};
@@ -149,7 +150,7 @@ TEST(SGemmPackTest, Pack) {
 #endif
 }

-TEST(SGemmTest, UnPack) {
+TEST(SGemmPackTest, UnPack) {
   TestUnPack(4, 3, Major::RowMajor, PackOrder::RowMajor);
   TestUnPack(4, 4, Major::RowMajor, PackOrder::RowMajor);
   TestUnPack(4, 5, Major::RowMajor, PackOrder::RowMajor);
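Note: a round-trip usage sketch in the spirit of the renamed test above (assumes the types from mace/kernels/sgemm.h; mirrors TestUnPack with a RowMajor source and RowMajor pack order, where UnPack restores the packed values):

    #include <vector>
    #include "mace/kernels/sgemm.h"

    namespace mk = mace::kernels;

    void PackUnpackRoundTrip(const std::vector<float> &data,
                             mace::index_t height, mace::index_t width,
                             std::vector<float> *out) {
      mk::MatrixMap<const float> src(1, height, width, mk::Major::RowMajor,
                                     data.data());
      mk::PackedBlock packed;
      packed.Resize({height, width});
      mk::SGemm sg;
      sg.PackRhs(src, &packed);  // RowMajor source, RowMajor pack order
      mk::MatrixMap<float> dst(1, height, width, mk::Major::RowMajor,
                               out->data());
      sg.UnPack(packed, &dst);   // *out now holds the original values
    }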
diff --git a/mace/ops/matmul.h b/mace/ops/matmul.h
index ceccb9398aaa7d5b730951672c0370e5509e1f7f..64b336a38d98a52be95917a7d9750abdc1c6e1a9 100644
--- a/mace/ops/matmul.h
+++ b/mace/ops/matmul.h
@@ -40,7 +40,11 @@ class MatMulOp : public Operator<D, T> {
                "than or equal to 2");
     index_t rank = A->dim_size();
     for (index_t i = 0; i < rank - 2; ++i) {
-      MACE_CHECK(A->dim(i) == B->dim(i), "batch dimensions are not equal");
+      MACE_CHECK(A->dim(i) == B->dim(i),
+                 "batch dimensions are not equal: ",
+                 A->dim(i),
+                 " vs. ",
+                 B->dim(i));
     }
     index_t ak = transpose_a_ ? A->dim(rank - 2) : A->dim(rank - 1);
     index_t bk = transpose_b_ ? B->dim(rank - 1) : B->dim(rank - 2);
diff --git a/mace/ops/matmul_benchmark.cc b/mace/ops/matmul_benchmark.cc
index 146f8d1c541771eb23ee16ad737b2435ce641f78..08b06fa7e83cc9d07436b1f2149766f798d6fb1a 100644
--- a/mace/ops/matmul_benchmark.cc
+++ b/mace/ops/matmul_benchmark.cc
@@ -33,13 +33,15 @@ void MatMulBenchmark(
   // Add input data
   net.AddRandomInput<D, T>("A", {batch, height, channels});
   net.AddRandomInput<D, T>("B", {batch, channels, out_width});
+  net.GetTensor("A")->SetIsWeight(true);
+  net.GetTensor("B")->SetIsWeight(true);
   if (DataTypeToEnum<T>::value == DT_UINT8) {
     net.GetTensor("A")->SetScale(0.1);
     net.GetTensor("B")->SetScale(0.1);
   }
-
   if (D == DeviceType::GPU) {
-    BufferToImage<D, T>(&net, "A", "AImage", kernels::BufferType::IN_OUT_WIDTH);
+    BufferToImage<D, T>(&net, "A", "AImage",
+                        kernels::BufferType::IN_OUT_WIDTH);
     BufferToImage<D, T>(&net, "B", "BImage",
                         kernels::BufferType::IN_OUT_HEIGHT);
@@ -71,7 +73,7 @@ void MatMulBenchmark(

   mace::testing::StartTiming();
   while (iters--) {
-    net.RunOp(D);
+    net.Run();
   }
   net.Sync();
 }
@@ -86,6 +88,8 @@ void MatMulTransposeBenchmark(
   // Add input data
   net.AddRandomInput<D, T>("A", {batch, height, channels});
   net.AddRandomInput<D, T>("B", {batch, out_width, channels});
+  net.GetTensor("A")->SetIsWeight(true);
+  net.GetTensor("B")->SetIsWeight(true);
   if (DataTypeToEnum<T>::value == DT_UINT8) {
     net.GetTensor("A")->SetScale(0.1);
     net.GetTensor("B")->SetScale(0.1);
@@ -116,7 +120,7 @@ void MatMulTransposeBenchmark(

   mace::testing::StartTiming();
   while (iters--) {
-    net.RunOp(D);
+    net.Run();
   }
   net.Sync();
 }
@@ -154,10 +158,15 @@ void MatMulTransposeBenchmark(
     MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, float, CPU);    \
     MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, uint8_t, CPU);

+MACE_BM_MATMUL(1, 128, 128, 49);
+MACE_BM_MATMUL(2, 128, 128, 49);
+MACE_BM_MATMUL(3, 128, 128, 49);
+MACE_BM_MATMUL(4, 128, 128, 49);
 MACE_BM_MATMUL(16, 32, 128, 49);
 MACE_BM_MATMUL(16, 32, 128, 961);
 MACE_BM_MATMUL(16, 32, 128, 3969);
 MACE_BM_MATMUL(16, 128, 128, 49);
+MACE_BM_MATMUL(16, 49, 128, 128);
 MACE_BM_MATMUL(16, 128, 128, 961);
 MACE_BM_MATMUL(16, 128, 128, 3969);
diff --git a/mace/ops/winograd_transform_benchmark.cc b/mace/ops/winograd_transform_benchmark.cc
index ecba841742cc23f8585fb7b30d294f3fe7ce9173..9955c9abc35f63c02f7b5461da2cbc425657265f 100644
--- a/mace/ops/winograd_transform_benchmark.cc
+++ b/mace/ops/winograd_transform_benchmark.cc
@@ -211,8 +211,8 @@ void WinoMatMulBenchmark(
   const index_t round_w = (width + block_size - 1) / block_size;
   const index_t out_width = round_h * round_w;
   // Add input data
-  net.AddRandomInput<D, float>("A", {batch, out_channels, in_channels, 1});
-  net.AddRandomInput<D, float>("B", {batch, in_channels, out_width, 1});
+  net.AddRandomInput<D, float>("A", {batch, out_channels, in_channels});
+  net.AddRandomInput<D, float>("B", {batch, in_channels, out_width});
   if (D == DeviceType::GPU) {
     BufferToImage<D, T>(&net, "A", "AImage",
                         kernels::BufferType::IN_OUT_WIDTH);
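Note: the winograd benchmark drops the trailing singleton dimension above because MatMulOp validates shapes rank by rank: every leading dimension is a batch dimension and must match between A and B, and only the last two are matrix dimensions. A standalone sketch of that check:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    typedef int64_t index_t;

    // Leading (batch) dims of A and B must agree; ranks must match, >= 2.
    bool BatchDimsCompatible(const std::vector<index_t> &a,
                             const std::vector<index_t> &b) {
      if (a.size() != b.size() || a.size() < 2) return false;
      for (std::size_t i = 0; i + 2 < a.size(); ++i) {
        if (a[i] != b[i]) return false;
      }
      return true;
    }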