diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 119790748dc45d1536f2a92193ed82d7541aa07e..3984fd9dc321a46faa120a4c5f49227453f469fe 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -85,7 +85,7 @@ ndk_versions_compatible_tests:
     - DEFAULT_NDK_PATH=$ANDROID_NDK_HOME
     - prefix_path=${DEFAULT_NDK_PATH%android-ndk-*}
     - >
-      for ndk in android-ndk-r12b android-ndk-r15c android-ndk-r16 android-ndk-r17b;
+      for ndk in android-ndk-r15c android-ndk-r16 android-ndk-r17b;
       do
        new_ndk_path=${prefix_path}${ndk};
        if [ "$new_ndk_path" != "$DEFAULT_NDK_PATH" ]; then
diff --git a/mace/core/tensor.h b/mace/core/tensor.h
index f7e509876f1564b06cbcd94e433a8ca3c03e197f..5c4c807b1cfe12a4cdf2e771a36857bf978c83ff 100644
--- a/mace/core/tensor.h
+++ b/mace/core/tensor.h
@@ -399,6 +399,10 @@ class Tensor {
     zero_point_ = zero_point;
   }
 
+  inline void SetIsWeight(bool is_weight) {
+    is_weight_ = is_weight;
+  }
+
  private:
   Allocator *allocator_;
   DataType dtype_;
@@ -409,7 +413,7 @@ class Tensor {
   bool is_buffer_owner_;
   bool unused_;
   std::string name_;
-  const bool is_weight_;
+  bool is_weight_;
   float scale_;
   int32_t zero_point_;
 
diff --git a/mace/core/testing/test_benchmark_main.cc b/mace/core/testing/test_benchmark_main.cc
index 569a8345c147a763a2c2036b4ac082e60caed856..49c2632653626671aa2d4cfa410c79eed208df29 100644
--- a/mace/core/testing/test_benchmark_main.cc
+++ b/mace/core/testing/test_benchmark_main.cc
@@ -33,7 +33,8 @@ int main(int argc, char **argv) {
   // config runtime
   mace::MaceStatus status = mace::SetOpenMPThreadsAndAffinityPolicy(
       FLAGS_omp_num_threads,
-      static_cast<mace::CPUAffinityPolicy>(FLAGS_cpu_affinity_policy));
+      static_cast<mace::CPUAffinityPolicy>(FLAGS_cpu_affinity_policy),
+      true);
   if (status != mace::MACE_SUCCESS) {
     LOG(WARNING) << "Set openmp or cpu affinity failed.";
   }
diff --git a/mace/kernels/gemm_test.cc b/mace/kernels/gemm_test.cc
index cec9491461db2ff229fdfc672a81ee7c50a7fe04..0942247d4b40b04578507bbf127073d8fca73194 100644
--- a/mace/kernels/gemm_test.cc
+++ b/mace/kernels/gemm_test.cc
@@ -13,11 +13,13 @@
 // limitations under the License.
#include +#include #include #include #include "mace/core/types.h" #include "mace/kernels/gemm.h" +#include "mace/kernels/sgemm.h" namespace mace { @@ -72,6 +74,74 @@ void GemvTest(index_t batch, index_t N, index_t M) { } } +void SGemmTest(index_t batch, + index_t N, + index_t K, + index_t M, + bool transpose_a, + bool transpose_b) { + std::unique_ptr A(new float[batch * N * K]); + std::unique_ptr B(new float[batch * K * M]); + std::unique_ptr C(new float[batch * N * M]); + std::unique_ptr C_ref(new float[batch * N * M]); + + std::random_device rd; + std::mt19937 gen(rd()); + std::normal_distribution nd(0, 1); + + std::generate(A.get(), A.get() + batch * N * K, + [&gen, &nd] { return nd(gen); }); + std::generate(B.get(), B.get() + batch * K * M, + [&gen, &nd] { return nd(gen); }); + kernels::GemmRef(A.get(), B.get(), batch, N, K, M, C_ref.get(), transpose_a, + transpose_b); + + kernels::MatrixMap matrix_a; + kernels::MatrixMap matrix_b; + + if (!transpose_a) { + matrix_a = + kernels::MatrixMap(batch, + N, + K, + kernels::RowMajor, + A.get()); + } else { + matrix_a = + kernels::MatrixMap(batch, + K, + N, + kernels::RowMajor, + A.get()); + matrix_a = matrix_a.transpose(); + } + + if (!transpose_b) { + matrix_b = + kernels::MatrixMap(batch, + K, + M, + kernels::RowMajor, + B.get()); + } else { + matrix_b = + kernels::MatrixMap(batch, + M, + K, + kernels::RowMajor, + B.get()); + matrix_b = matrix_b.transpose(); + } + kernels::MatrixMap matrix_c(batch, N, M, kernels::RowMajor, C.get()); + + kernels::SGemm sgemm; + sgemm(matrix_a, matrix_b, &matrix_c); + + for (int i = 0; i < N * M; ++i) { + EXPECT_NEAR(C_ref[i], C[i], 0.1); + } +} + } // namespace TEST(GEMMTest, AlignedWithoutBatch) { @@ -114,4 +184,25 @@ TEST(GEMMTest, gemv) { GemvTest(3, 17, 63); } +namespace { +void TestSGemmTranspose(index_t batch, index_t N, index_t K, index_t M) { + SGemmTest(batch, N, K, M, false, false); + SGemmTest(batch, N, K, M, true, false); + SGemmTest(batch, N, K, M, false, true); + SGemmTest(batch, N, K, M, true, true); +} +} + +TEST(SGEMMTest, UnalignedWithoutBatch) { + std::vector tests{1, 5, 14, 31, 47}; + for (index_t N : tests) { + for (index_t K : tests) { + for (index_t M : tests) { + TestSGemmTranspose(1, N, K, M); + TestSGemmTranspose(16, N, K, M); + } + } + } +} + } // namespace mace diff --git a/mace/kernels/matmul.h b/mace/kernels/matmul.h index 4b6c5cf1ef8309281178fd52f545556b87a80190..9c5292d2ac9f74cf29de36c2dc0f75502e875cdd 100644 --- a/mace/kernels/matmul.h +++ b/mace/kernels/matmul.h @@ -32,6 +32,7 @@ #include "mace/kernels/kernel.h" #include "mace/utils/utils.h" #include "mace/kernels/gemmlowp_util.h" +#include "mace/kernels/sgemm.h" #ifdef MACE_ENABLE_OPENCL #include "mace/core/runtime/opencl/cl2_header.h" @@ -83,39 +84,34 @@ struct MatMulFunctor : OpKernel { const T *b_ptr_base = B->data(); T *c_ptr_base = C->mutable_data(); - memset(c_ptr_base, 0, batch * height * width * sizeof(T)); - - if (height == 1 && width > 1 && B->is_weight()) { - // A * B = (B^T * A^T)^T - if (!transpose_b) { - if (B_transpose_.get() == nullptr) { - B_transpose_.reset(new Tensor(context_->device()->allocator(), - DataTypeToEnum::v())); - B_transpose_->Resize({batch, width, K}); - Tensor::MappingGuard guardbt(B_transpose_.get()); - T *bt_ptr_base = B_transpose_->mutable_data(); - Transpose(b_ptr_base, K, width, width, bt_ptr_base); - } - Tensor::MappingGuard guardbt(B_transpose_.get()); - T *bt_ptr_base = B_transpose_->mutable_data(); - Gemv(bt_ptr_base, a_ptr_base, batch, K, width, c_ptr_base); - } else { - 
Gemv(b_ptr_base, a_ptr_base, batch, K, width, c_ptr_base); - } - } else { - Gemm(a_ptr_base, b_ptr_base, batch, height, K, width, c_ptr_base, - transpose_a, transpose_b); - } - + const index_t height_a = A->dim(rank - 2); + const index_t width_a = A->dim(rank - 1); + const index_t height_b = B->dim(rank - 2); + const index_t width_b = B->dim(rank - 1); + + sgemm_.Run(a_ptr_base, + b_ptr_base, + batch, + height_a, + width_a, + height_b, + width_b, + transpose_a, + transpose_b, + A->is_weight(), + B->is_weight(), + c_ptr_base, + context_->workspace()->GetScratchBuffer(D)); return MACE_SUCCESS; } - std::unique_ptr B_transpose_; + SGemm sgemm_; }; template <> struct MatMulFunctor : OpKernel { explicit MatMulFunctor(OpKernelContext *context) : OpKernel(context) {} + template void MatMulImpl(const Tensor *A, const Tensor *B, @@ -213,6 +209,7 @@ struct MatMulFunctor : OpKernel { template struct MatMulFunctor : OpKernel { explicit MatMulFunctor(OpKernelContext *context) : OpKernel(context) {} + MaceStatus operator()(const Tensor *A, const Tensor *B, Tensor *C, diff --git a/mace/kernels/matmul_benchmark.cc b/mace/kernels/matmul_benchmark.cc index d109de478ba0b1b109633764eec61b4ca868ed42..be76a88ec9ea7efe7cfd3ac1e93dc313b8141dcf 100644 --- a/mace/kernels/matmul_benchmark.cc +++ b/mace/kernels/matmul_benchmark.cc @@ -22,6 +22,7 @@ #include "mace/core/testing/test_benchmark.h" #include "mace/kernels/gemm.h" #include "mace/kernels/gemmlowp_util.h" +#include "mace/kernels/sgemm.h" namespace gemmlowp { @@ -107,6 +108,28 @@ void MatmulBenchmark_Mace(int iters, int m, int k, int n) { } } +void MatmulBenchmark_Mace_SGemm(int iters, int m, int k, int n) { + mace::testing::StopTiming(); + std::vector lhs(m * k); + std::vector rhs(k * n); + std::vector result(m * n); + + kernels::MatrixMap matrix_lhs(1, m, k, RowMajor, lhs.data(), + true); + kernels::MatrixMap matrix_rhs(1, k, n, RowMajor, rhs.data(), + true); + kernels::MatrixMap matrix_result(1, m, n, RowMajor, result.data()); + + kernels::SGemm sgemm; + + sgemm(matrix_lhs, matrix_rhs, &matrix_result); + + mace::testing::StartTiming(); + while (iters--) { + sgemm(matrix_lhs, matrix_rhs, &matrix_result); + } +} + void MatmulBenchmark_Eigen(int iters, int m, int k, int n) { mace::testing::StopTiming(); Eigen::MatrixXf lhs = Eigen::MatrixXf::Random(m, k); @@ -202,6 +225,7 @@ void MatmulBenchmark_gemmlowp_int32(int iters, int rows, int depth, int cols) { #define MACE_BM_MATMUL(M, K, N) \ MACE_BM_MATMUL_FUNC(M, K, N, Mace, float); \ + MACE_BM_MATMUL_FUNC(M, K, N, Mace_SGemm, float); \ MACE_BM_MATMUL_FUNC(M, K, N, Eigen, float); \ MACE_BM_MATMUL_FUNC(M, K, N, gemmlowp_uint8, uint8_t); \ MACE_BM_MATMUL_FUNC(M, K, N, gemmlowp_int32, uint8_t); @@ -215,15 +239,43 @@ MACE_BM_MATMUL(15, 384, 384); MACE_BM_MATMUL(15, 384, 1536); MACE_BM_MATMUL(15, 1536, 384); -MACE_BM_MATMUL(1, 384, 384); -MACE_BM_MATMUL(1, 384, 1536); -MACE_BM_MATMUL(1, 1536, 384); -MACE_BM_MATMUL(1, 384, 44678); +MACE_BM_MATMUL(1, 256, 256); +MACE_BM_MATMUL(1, 256, 1536); +MACE_BM_MATMUL(1, 1536, 256); +MACE_BM_MATMUL(256, 256, 1); +MACE_BM_MATMUL(1536, 256, 1); +MACE_BM_MATMUL(256, 1536, 1); +MACE_BM_MATMUL(29792, 256, 1); +MACE_BM_MATMUL(1, 256, 29792); +MACE_BM_MATMUL(2, 256, 256); +MACE_BM_MATMUL(2, 256, 1536); +MACE_BM_MATMUL(2, 1536, 256); +MACE_BM_MATMUL(3, 256, 256); +MACE_BM_MATMUL(3, 256, 1536); +MACE_BM_MATMUL(3, 1536, 256); +MACE_BM_MATMUL(4, 256, 256); +MACE_BM_MATMUL(4, 256, 1536); +MACE_BM_MATMUL(4, 1536, 256); +MACE_BM_MATMUL(8, 256, 256); +MACE_BM_MATMUL(8, 256, 1536); 
+MACE_BM_MATMUL(8, 1536, 256);
+MACE_BM_MATMUL(10, 256, 256);
+MACE_BM_MATMUL(10, 256, 1536);
+MACE_BM_MATMUL(10, 1536, 256);
+MACE_BM_MATMUL(15, 256, 256);
+MACE_BM_MATMUL(15, 256, 1536);
+MACE_BM_MATMUL(15, 1536, 256);
 
 // Embedding size 128
 MACE_BM_MATMUL(1, 128, 1536);
 MACE_BM_MATMUL(1, 128, 44678);
 
+// MobileNet
+MACE_BM_MATMUL(128, 128, 3136);
+MACE_BM_MATMUL(256, 256, 784);
+MACE_BM_MATMUL(512, 512, 196);
+MACE_BM_MATMUL(1024, 1024, 49);
+
 }  // namespace test
 }  // namespace kernels
 }  // namespace mace
diff --git a/mace/kernels/sgemm.cc b/mace/kernels/sgemm.cc
index ae9a4e0fb37b50f8f7192943b26844ee2d00e227..e1be6f99cb1ea81072b07c85b208ef53153ab1c5 100644
--- a/mace/kernels/sgemm.cc
+++ b/mace/kernels/sgemm.cc
@@ -12,80 +12,1158 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include
-#include
-#include
+#include
 
 #include "mace/kernels/sgemm.h"
+#include "mace/core/runtime/cpu/cpu_runtime.h"
+
 #if defined(MACE_ENABLE_NEON)
 #include <arm_neon.h>
 #endif
 
+#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
+#define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
+#endif
+
 namespace mace {
 namespace kernels {
 
-void SGemm::operator()(const MatrixMap &lhs,
-                       const MatrixMap &rhs,
-                       MatrixMap *result) {
-  PackedBlock packed_lhs;
-  PackLhs(lhs, &packed_lhs);
-
-  PackedBlock packed_rhs;
-  PackRhs(rhs, &packed_rhs);
-
-  PackedBlock packed_result;
-  operator()(packed_lhs,
-             packed_rhs,
-             lhs.row(),
-             lhs.col(),
-             rhs.col(),
-             &packed_result);
-  UnPack(packed_result, result);
+void SGemm::operator()(const MatrixMap<const float> &lhs,
+                       const MatrixMap<const float> &rhs,
+                       MatrixMap<float> *result,
+                       ScratchBuffer *scratch_buffer) {
+  if (rhs.col() < lhs.row()) {
+    MatrixMap<const float> lhs_transpose = lhs.transpose();
+    MatrixMap<const float> rhs_transpose = rhs.transpose();
+    MatrixMap<float> result_transpose = result->transpose();
+    return operator()(rhs_transpose,
+                      lhs_transpose,
+                      &result_transpose,
+                      scratch_buffer);
+  }
+
+  if (scratch_buffer != nullptr) {
+    scratch_buffer->Rewind();
+    index_t total_size = result->size();
+    if (!lhs.is_const()) {
+      total_size += lhs.size();
+    }
+    if (!rhs.is_const()) {
+      total_size += rhs.size();
+    }
+    scratch_buffer->GrowSize(total_size * sizeof(float));
+
+    scratch_buffer->Rewind();
+    if (!lhs.is_const()) {
+      packed_lhs_.reset(new Tensor(scratch_buffer->Scratch(
+          lhs.size() * sizeof(float)), DT_FLOAT));
+    }
+    if (!rhs.is_const()) {
+      packed_rhs_.reset(new Tensor(scratch_buffer->Scratch(
+          rhs.size() * sizeof(float)), DT_FLOAT));
+    }
+    packed_result_.reset(new Tensor(scratch_buffer->Scratch(
+        result->size() * sizeof(float)), DT_FLOAT));
+  }
+
+  if (packed_lhs_.get() == nullptr) {
+    packed_lhs_.reset(new Tensor(GetCPUAllocator(), DT_FLOAT));
+    packed_lhs_->Resize({lhs.size()});
+  }
+  if (packed_rhs_.get() == nullptr) {
+    packed_rhs_.reset(new Tensor(GetCPUAllocator(), DT_FLOAT));
+    packed_rhs_->Resize({rhs.size()});
+  }
+  if (packed_result_.get() == nullptr) {
+    packed_result_.reset(new Tensor(GetCPUAllocator(), DT_FLOAT));
+    packed_result_->Resize({result->size()});
+  }
+
+  if (!lhs.is_const() || !packed_) {
+    PackLhs(lhs, packed_lhs_.get());
+  }
+  if (!rhs.is_const() || !packed_) {
+    PackRhs(rhs, packed_rhs_.get());
+  }
+  packed_ = true;
+
+  RunInternal(*packed_lhs_,
+              *packed_rhs_,
+              lhs.batch(),
+              lhs.row(),
+              lhs.col(),
+              rhs.col(),
+              packed_result_.get());
+
+  UnPack(*packed_result_, result);
+}
+
+void SGemm::Run(const float *A,
+                const float *B,
+                const index_t batch,
+                const index_t height_a,
+                const index_t width_a,
+                const index_t height_b,
+                const
index_t width_b, + const bool transpose_a, + const bool transpose_b, + const bool is_a_weight, + const bool is_b_weight, + float *C, + ScratchBuffer *scratch_buffer) { + index_t height_c = height_a; + index_t width_c = width_b; + if (transpose_a) { + height_c = width_a; + } + if (transpose_b) { + width_c = height_b; + } + + MatrixMap matrix_a = + MatrixMap(batch, + height_a, + width_a, + kernels::RowMajor, + A, + is_a_weight); + MatrixMap matrix_b = + kernels::MatrixMap(batch, + height_b, + width_b, + kernels::RowMajor, + B, + is_b_weight); + if (transpose_a) { + matrix_a = matrix_a.transpose(); + } + if (transpose_b) { + matrix_b = matrix_b.transpose(); + } + MatrixMap matrix_c(batch, height_c, width_c, kernels::RowMajor, C); + operator()(matrix_a, matrix_b, &matrix_c, scratch_buffer); } -void SGemm::operator()(const PackedBlock &lhs, - const PackedBlock &rhs, - const index_t height, - const index_t depth, - const index_t width, - PackedBlock *result) { - (void) lhs; - (void) rhs; - (void) result; - (void) height; - (void) depth; - (void) width; +#if defined(MACE_ENABLE_NEON) +#if defined(__aarch64__) + +// calculate 8 rows, 4 cols for each depth +#define MACE_SGEMM_PART_CAL_R8_C4_D1(D, VD, VDN) \ + c0 = vfmaq_laneq_f32(c0, b##D, a##VD, 0); \ + c1 = vfmaq_laneq_f32(c1, b##D, a##VD, 1); \ + c2 = vfmaq_laneq_f32(c2, b##D, a##VD, 2); \ + c3 = vfmaq_laneq_f32(c3, b##D, a##VD, 3); \ + c4 = vfmaq_laneq_f32(c4, b##D, a##VDN, 0); \ + c5 = vfmaq_laneq_f32(c5, b##D, a##VDN, 1); \ + c6 = vfmaq_laneq_f32(c6, b##D, a##VDN, 2); \ + c7 = vfmaq_laneq_f32(c7, b##D, a##VDN, 3); + +// calculate 4 rows, 4 cols for each depth +#define MACE_SGEMM_PART_CAL_R4_C4_D1(D) \ + c0 = vfmaq_laneq_f32(c0, b##D, a##D, 0); \ + c1 = vfmaq_laneq_f32(c1, b##D, a##D, 1); \ + c2 = vfmaq_laneq_f32(c2, b##D, a##D, 2); \ + c3 = vfmaq_laneq_f32(c3, b##D, a##D, 3); + +// calculate 4 cols for 8 depths for each row +#define MACE_SGEMM_PART_CAL_R1_C4_D8(R, VR, VRN) \ + c##R = vfmaq_laneq_f32(c##R, b0, a##VR, 0); \ + c##R = vfmaq_laneq_f32(c##R, b1, a##VR, 1); \ + c##R = vfmaq_laneq_f32(c##R, b2, a##VR, 2); \ + c##R = vfmaq_laneq_f32(c##R, b3, a##VR, 3); \ + c##R = vfmaq_laneq_f32(c##R, b4, a##VRN, 0); \ + c##R = vfmaq_laneq_f32(c##R, b5, a##VRN, 1); \ + c##R = vfmaq_laneq_f32(c##R, b6, a##VRN, 2); \ + c##R = vfmaq_laneq_f32(c##R, b7, a##VRN, 3); + +// calculate 4 cols for 4 depths for each row +#define MACE_SGEMM_PART_CAL_R1_C4_D4(R) \ + c##R = vfmaq_laneq_f32(c##R, b0, a##R, 0); \ + c##R = vfmaq_laneq_f32(c##R, b1, a##R, 1); \ + c##R = vfmaq_laneq_f32(c##R, b2, a##R, 2); \ + c##R = vfmaq_laneq_f32(c##R, b3, a##R, 3); + +// calculate 8 cols for 4 depths for each row +#define MACE_SGEMM_PART_CAL_R1_C8_D4(VR, VRN, R) \ + c##VR = vfmaq_laneq_f32(c##VR, b0, a##R, 0); \ + c##VR = vfmaq_laneq_f32(c##VR, b2, a##R, 1); \ + c##VR = vfmaq_laneq_f32(c##VR, b4, a##R, 2); \ + c##VR = vfmaq_laneq_f32(c##VR, b6, a##R, 3); \ + c##VRN = vfmaq_laneq_f32(c##VRN, b1, a##R, 0); \ + c##VRN = vfmaq_laneq_f32(c##VRN, b3, a##R, 1); \ + c##VRN = vfmaq_laneq_f32(c##VRN, b5, a##R, 2); \ + c##VRN = vfmaq_laneq_f32(c##VRN, b7, a##R, 3); + +#else + +#define MACE_SGEMM_PART_CAL_R8_C4_D1(D, VD, VDN) \ + c0 = vmlaq_lane_f32(c0, b##D, vget_low_f32(a##VD), 0); \ + c1 = vmlaq_lane_f32(c1, b##D, vget_low_f32(a##VD), 1); \ + c2 = vmlaq_lane_f32(c2, b##D, vget_high_f32(a##VD), 0); \ + c3 = vmlaq_lane_f32(c3, b##D, vget_high_f32(a##VD), 1); \ + c4 = vmlaq_lane_f32(c4, b##D, vget_low_f32(a##VDN), 0); \ + c5 = vmlaq_lane_f32(c5, b##D, vget_low_f32(a##VDN), 1); \ + c6 = 
vmlaq_lane_f32(c6, b##D, vget_high_f32(a##VDN), 0); \ + c7 = vmlaq_lane_f32(c7, b##D, vget_high_f32(a##VDN), 1); + +#define MACE_SGEMM_PART_CAL_R4_C4_D1(D) \ + c0 = vmlaq_lane_f32(c0, b##D, vget_low_f32(a##D), 0); \ + c1 = vmlaq_lane_f32(c1, b##D, vget_low_f32(a##D), 1); \ + c2 = vmlaq_lane_f32(c2, b##D, vget_high_f32(a##D), 0); \ + c3 = vmlaq_lane_f32(c3, b##D, vget_high_f32(a##D), 1); + +#define MACE_SGEMM_PART_CAL_R1_C4_D8(R, VR, VRN) \ + c##R = vmlaq_lane_f32(c##R, b0, vget_low_f32(a##VR), 0); \ + c##R = vmlaq_lane_f32(c##R, b1, vget_low_f32(a##VR), 1); \ + c##R = vmlaq_lane_f32(c##R, b2, vget_high_f32(a##VR), 0); \ + c##R = vmlaq_lane_f32(c##R, b3, vget_high_f32(a##VR), 1); \ + c##R = vmlaq_lane_f32(c##R, b4, vget_low_f32(a##VRN), 0); \ + c##R = vmlaq_lane_f32(c##R, b5, vget_low_f32(a##VRN), 1); \ + c##R = vmlaq_lane_f32(c##R, b6, vget_high_f32(a##VRN), 0); \ + c##R = vmlaq_lane_f32(c##R, b7, vget_high_f32(a##VRN), 1); + +#define MACE_SGEMM_PART_CAL_R1_C4_D4(R) \ + c##R = vmlaq_lane_f32(c##R, b0, vget_low_f32(a##R), 0); \ + c##R = vmlaq_lane_f32(c##R, b1, vget_low_f32(a##R), 1); \ + c##R = vmlaq_lane_f32(c##R, b2, vget_high_f32(a##R), 0); \ + c##R = vmlaq_lane_f32(c##R, b3, vget_high_f32(a##R), 1); + +#endif // __aarch64__ +#endif // MACE_ENABLE_NEON + +void SGemm::RunInternal(const PackedBlock &lhs, + const PackedBlock &rhs, + const index_t batch, + const index_t height, + const index_t depth, + const index_t width, + PackedBlock *result) { + const float *lhs_data = lhs.data(); + const float *rhs_data = rhs.data(); + float *result_data = result->mutable_data(); - // (8, 8) * (8, 4) +#define MACE_SGEMM_RUN_PER_BATCH \ + for (index_t b = 0; b < batch; ++b) { \ + RunPerBatch(lhs_data + b * height * depth, \ + rhs_data + b * depth * width, \ + height, \ + depth, \ + width, \ + result_data + b * height * width); \ + } - // (4, 4) * (4, 4) + if (batch >= MaceOpenMPThreadCount) { +#pragma omp parallel for + MACE_SGEMM_RUN_PER_BATCH + } else { + MACE_SGEMM_RUN_PER_BATCH + } - // remain +#undef MACE_SGEMM_RUN_PER_BATCH } -void SGemm::PackLhs(const MatrixMap &lhs, - PackedBlock *packed_block) { +void SGemm::RunPerBatch(const float *lhs_data, + const float *rhs_data, + const index_t height, + const index_t depth, + const index_t width, + float *result_data) { +#if defined(MACE_ENABLE_NEON) + const index_t block_w = width >> 2; + const index_t remain_w = width - (block_w << 2); +#else + const index_t remain_w = width; +#endif + +#if defined(MACE_ENABLE_NEON) + // TODO(liyin): make better use l2(l1) cache, try to fit as much lhs data as + // as possible to cache, by tiling lhs by height and rhs by width. 
+ + // w: 4 +#pragma omp parallel for + for (index_t bw = 0; bw < block_w; ++bw) { + index_t remain_h = height; + index_t block_h = 0; + + const float *lhs_ptr = lhs_data; + float *res_ptr = result_data + height * (bw << 2); + +#if defined(__aarch64__) + block_h = remain_h >> 3; + remain_h -= (block_h << 3); + + // h: 8 + for (index_t bh = 0; bh < block_h; ++bh) { + const float *rhs_ptr = rhs_data + depth * (bw << 2); + + index_t remain_d = depth; + index_t block_d = remain_d >> 3; + remain_d -= (block_d << 3); + + float32x4_t c0, c1, c2, c3, c4, c5, c6, c7; + c0 = vdupq_n_f32(0.f); + c1 = vdupq_n_f32(0.f); + c2 = vdupq_n_f32(0.f); + c3 = vdupq_n_f32(0.f); + c4 = vdupq_n_f32(0.f); + c5 = vdupq_n_f32(0.f); + c6 = vdupq_n_f32(0.f); + c7 = vdupq_n_f32(0.f); + + // d: 8 + for (index_t bd = 0; bd < block_d; ++bd) { + // 8.8.4 + float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, + a14, a15; + float32x4_t b0, b1, b2, b3, b4, b5, b6, b7; + + a0 = vld1q_f32(lhs_ptr); + a1 = vld1q_f32(lhs_ptr + 4); + a2 = vld1q_f32(lhs_ptr + 8); + a3 = vld1q_f32(lhs_ptr + 12); + a4 = vld1q_f32(lhs_ptr + 16); + a5 = vld1q_f32(lhs_ptr + 20); + a6 = vld1q_f32(lhs_ptr + 24); + a7 = vld1q_f32(lhs_ptr + 28); + a8 = vld1q_f32(lhs_ptr + 32); + a9 = vld1q_f32(lhs_ptr + 36); + a10 = vld1q_f32(lhs_ptr + 40); + a11 = vld1q_f32(lhs_ptr + 44); + a12 = vld1q_f32(lhs_ptr + 48); + a13 = vld1q_f32(lhs_ptr + 52); + a14 = vld1q_f32(lhs_ptr + 56); + a15 = vld1q_f32(lhs_ptr + 60); + + b0 = vld1q_f32(rhs_ptr); + b1 = vld1q_f32(rhs_ptr + 4); + b2 = vld1q_f32(rhs_ptr + 8); + b3 = vld1q_f32(rhs_ptr + 12); + b4 = vld1q_f32(rhs_ptr + 16); + b5 = vld1q_f32(rhs_ptr + 20); + b6 = vld1q_f32(rhs_ptr + 24); + b7 = vld1q_f32(rhs_ptr + 28); + + MACE_SGEMM_PART_CAL_R8_C4_D1(0, 0, 1); // d = 1 + MACE_SGEMM_PART_CAL_R8_C4_D1(1, 2, 3); // d = 2 + MACE_SGEMM_PART_CAL_R8_C4_D1(2, 4, 5); + MACE_SGEMM_PART_CAL_R8_C4_D1(3, 6, 7); + MACE_SGEMM_PART_CAL_R8_C4_D1(4, 8, 9); + MACE_SGEMM_PART_CAL_R8_C4_D1(5, 10, 11); + MACE_SGEMM_PART_CAL_R8_C4_D1(6, 12, 13); + MACE_SGEMM_PART_CAL_R8_C4_D1(7, 14, 15); + + lhs_ptr += 64; + rhs_ptr += 32; + } + + block_d = remain_d >> 2; + remain_d -= (block_d << 2); + + // d: 4 + for (index_t bd = 0; bd < block_d; ++bd) { + // 8.4.4 + float32x4_t a0, a1, a2, a3, a4, a5, a6, a7; + float32x4_t b0, b1, b2, b3; + + a0 = vld1q_f32(lhs_ptr); + a1 = vld1q_f32(lhs_ptr + 4); + a2 = vld1q_f32(lhs_ptr + 8); + a3 = vld1q_f32(lhs_ptr + 12); + a4 = vld1q_f32(lhs_ptr + 16); + a5 = vld1q_f32(lhs_ptr + 20); + a6 = vld1q_f32(lhs_ptr + 24); + a7 = vld1q_f32(lhs_ptr + 28); + + b0 = vld1q_f32(rhs_ptr); + b1 = vld1q_f32(rhs_ptr + 4); + b2 = vld1q_f32(rhs_ptr + 8); + b3 = vld1q_f32(rhs_ptr + 12); + + MACE_SGEMM_PART_CAL_R8_C4_D1(0, 0, 1); // d = 1 + MACE_SGEMM_PART_CAL_R8_C4_D1(1, 2, 3); // d = 2 + MACE_SGEMM_PART_CAL_R8_C4_D1(2, 4, 5); + MACE_SGEMM_PART_CAL_R8_C4_D1(3, 6, 7); + + lhs_ptr += 32; + rhs_ptr += 16; + } + + // TODO(liyin): handle remain by each case + // d: remain + for (index_t d = 0; d < remain_d; ++d) { + // 8.1.4 + float32x4_t a0, a1; + float32x4_t b0; + + a0 = vld1q_f32(lhs_ptr); + a1 = vld1q_f32(lhs_ptr + 4); + + b0 = vld1q_f32(rhs_ptr); + + MACE_SGEMM_PART_CAL_R8_C4_D1(0, 0, 1); // d = 1 + + lhs_ptr += 8; + rhs_ptr += 4; + } + + vst1q_f32(res_ptr, c0); + vst1q_f32(res_ptr + 4, c1); + vst1q_f32(res_ptr + 8, c2); + vst1q_f32(res_ptr + 12, c3); + vst1q_f32(res_ptr + 16, c4); + vst1q_f32(res_ptr + 20, c5); + vst1q_f32(res_ptr + 24, c6); + vst1q_f32(res_ptr + 28, c7); + + res_ptr += 32; + } // bh: 8 +#endif // 
__aarch64__ + + // h: 4 + block_h = remain_h >> 2; + remain_h -= (block_h << 2); + + for (index_t bh = 0; bh < block_h; ++bh) { + const float *rhs_ptr = rhs_data + depth * (bw << 2); + + index_t remain_d = depth; + index_t block_d = 0; + + float32x4_t c0, c1, c2, c3; + c0 = vdupq_n_f32(0.f); + c1 = vdupq_n_f32(0.f); + c2 = vdupq_n_f32(0.f); + c3 = vdupq_n_f32(0.f); + + // d: 8 + block_d = remain_d >> 3; + remain_d -= (block_d << 3); + +#if defined(__aarch64__) + for (index_t bd = 0; bd < block_d; ++bd) { + // 4.8.4 + float32x4_t a0, a1, a2, a3, a4, a5, a6, a7; + float32x4_t b0, b1, b2, b3, b4, b5, b6, b7; + + a0 = vld1q_f32(lhs_ptr); + a1 = vld1q_f32(lhs_ptr + 4); + a2 = vld1q_f32(lhs_ptr + 8); + a3 = vld1q_f32(lhs_ptr + 12); + a4 = vld1q_f32(lhs_ptr + 16); + a5 = vld1q_f32(lhs_ptr + 20); + a6 = vld1q_f32(lhs_ptr + 24); + a7 = vld1q_f32(lhs_ptr + 28); + + b0 = vld1q_f32(rhs_ptr); + b1 = vld1q_f32(rhs_ptr + 4); + b2 = vld1q_f32(rhs_ptr + 8); + b3 = vld1q_f32(rhs_ptr + 12); + b4 = vld1q_f32(rhs_ptr + 16); + b5 = vld1q_f32(rhs_ptr + 20); + b6 = vld1q_f32(rhs_ptr + 24); + b7 = vld1q_f32(rhs_ptr + 28); + + MACE_SGEMM_PART_CAL_R4_C4_D1(0); // d = 1 + MACE_SGEMM_PART_CAL_R4_C4_D1(1); // d = 2 + MACE_SGEMM_PART_CAL_R4_C4_D1(2); + MACE_SGEMM_PART_CAL_R4_C4_D1(3); + MACE_SGEMM_PART_CAL_R4_C4_D1(4); + MACE_SGEMM_PART_CAL_R4_C4_D1(5); + MACE_SGEMM_PART_CAL_R4_C4_D1(6); + MACE_SGEMM_PART_CAL_R4_C4_D1(7); + + lhs_ptr += 32; + rhs_ptr += 32; + } +#else // arm v7 + // 4.8.4 + if (block_d > 0) { + asm volatile( + "0: \n" + + "vld1.f32 {d0-d1}, [%[lhs_ptr]]! \n" + "vld1.f32 {d2-d3}, [%[lhs_ptr]]! \n" + "vld1.f32 {d4-d5}, [%[lhs_ptr]]! \n" + + "vld1.f32 {d20-d21}, [%[rhs_ptr]]! \n" + "vld1.f32 {d22-d23}, [%[rhs_ptr]]! \n" + "vld1.f32 {d24-d25}, [%[rhs_ptr]]! \n" + + "vmla.f32 %[c0], q10, d0[0] \n" + "vmla.f32 %[c1], q10, d0[1] \n" + "vmla.f32 %[c2], q10, d1[0] \n" + "vmla.f32 %[c3], q10, d1[1] \n" + + "vld1.f32 {d6-d7}, [%[lhs_ptr]]! \n" + "vld1.f32 {d26-d27}, [%[rhs_ptr]]! \n" + + "vmla.f32 %[c0], q11, d2[0] \n" + "vmla.f32 %[c1], q11, d2[1] \n" + "vmla.f32 %[c2], q11, d3[0] \n" + "vmla.f32 %[c3], q11, d3[1] \n" + + "vld1.f32 {d8-d9}, [%[lhs_ptr]]! \n" + "vld1.f32 {d28-d29}, [%[rhs_ptr]]! \n" + + "vmla.f32 %[c0], q12, d4[0] \n" + "vmla.f32 %[c1], q12, d4[1] \n" + "vmla.f32 %[c2], q12, d5[0] \n" + "vmla.f32 %[c3], q12, d5[1] \n" + + "vld1.f32 {d10-d11}, [%[lhs_ptr]]! \n" + "vld1.f32 {d30-d31}, [%[rhs_ptr]]! \n" + + "vmla.f32 %[c0], q13, d6[0] \n" + "vmla.f32 %[c1], q13, d6[1] \n" + "vmla.f32 %[c2], q13, d7[0] \n" + "vmla.f32 %[c3], q13, d7[1] \n" + + "vld1.f32 {d0-d1}, [%[lhs_ptr]]! \n" + "vld1.f32 {d2-d3}, [%[lhs_ptr]]! \n" + + "vld1.f32 {d20-d21}, [%[rhs_ptr]]! \n" + "vld1.f32 {d22-d23}, [%[rhs_ptr]]! 
\n" + + "vmla.f32 %[c0], q14, d8[0] \n" + "vmla.f32 %[c1], q14, d8[1] \n" + "vmla.f32 %[c2], q14, d9[0] \n" + "vmla.f32 %[c3], q14, d9[1] \n" + + "vmla.f32 %[c0], q15, d10[0] \n" + "vmla.f32 %[c1], q15, d10[1] \n" + "vmla.f32 %[c2], q15, d11[0] \n" + "vmla.f32 %[c3], q15, d11[1] \n" + + "vmla.f32 %[c0], q10, d0[0] \n" + "vmla.f32 %[c1], q10, d0[1] \n" + "vmla.f32 %[c2], q10, d1[0] \n" + "vmla.f32 %[c3], q10, d1[1] \n" + + "subs %[block_d], %[block_d], #1 \n" + + "vmla.f32 %[c0], q11, d2[0] \n" + "vmla.f32 %[c1], q11, d2[1] \n" + "vmla.f32 %[c2], q11, d3[0] \n" + "vmla.f32 %[c3], q11, d3[1] \n" + + "bne 0b \n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), + [rhs_ptr] "+r"(rhs_ptr), + [res_ptr] "+r"(res_ptr), + [block_d] "+r"(block_d), + [c0] "+w"(c0), + [c1] "+w"(c1), + [c2] "+w"(c2), + [c3] "+w"(c3) + : // inputs + : // clabbers + "cc", "memory", + "q0", "q1", "q2", "q3", "q4", "q5", + "q10", "q11", "q12", "q13", "q14", "q15"); + } +#endif // __aarch64__ + + // d: 4 + block_d = remain_d >> 2; + remain_d -= (block_d << 2); + + for (index_t bd = 0; bd < block_d; ++bd) { + // 4.4.4 + float32x4_t a0, a1, a2, a3; + float32x4_t b0, b1, b2, b3; + + a0 = vld1q_f32(lhs_ptr); + a1 = vld1q_f32(lhs_ptr + 4); + a2 = vld1q_f32(lhs_ptr + 8); + a3 = vld1q_f32(lhs_ptr + 12); + + b0 = vld1q_f32(rhs_ptr); + b1 = vld1q_f32(rhs_ptr + 4); + b2 = vld1q_f32(rhs_ptr + 8); + b3 = vld1q_f32(rhs_ptr + 12); + + MACE_SGEMM_PART_CAL_R4_C4_D1(0); // d = 1 + MACE_SGEMM_PART_CAL_R4_C4_D1(1); // d = 2 + MACE_SGEMM_PART_CAL_R4_C4_D1(2); + MACE_SGEMM_PART_CAL_R4_C4_D1(3); + + lhs_ptr += 16; + rhs_ptr += 16; + } + + // d: remain + for (index_t d = 0; d < remain_d; ++d) { + // 4.1.4 + float32x4_t a0; + float32x4_t b0; + + a0 = vld1q_f32(lhs_ptr); + + b0 = vld1q_f32(rhs_ptr); + + MACE_SGEMM_PART_CAL_R4_C4_D1(0); // d = 1 + + lhs_ptr += 4; + rhs_ptr += 4; + } + vst1q_f32(res_ptr, c0); + vst1q_f32(res_ptr + 4, c1); + vst1q_f32(res_ptr + 8, c2); + vst1q_f32(res_ptr + 12, c3); + + res_ptr += 16; + } // bh: 4 + + // h: 1 + for (index_t h = 0; h < remain_h; ++h) { + const float *rhs_ptr = rhs_data + depth * (bw << 2); + + index_t remain_d = depth; + index_t block_d = 0; + + float32x4_t c0 = vdupq_n_f32(0.f); + + // d: 8 + block_d = remain_d >> 3; + remain_d -= (block_d << 3); + + for (index_t bd = 0; bd < block_d; ++bd) { + // 1.8.4 + float32x4_t a0, a1; + float32x4_t b0, b1, b2, b3, b4, b5, b6, b7; + + a0 = vld1q_f32(lhs_ptr); + a1 = vld1q_f32(lhs_ptr + 4); + + b0 = vld1q_f32(rhs_ptr); + b1 = vld1q_f32(rhs_ptr + 4); + b2 = vld1q_f32(rhs_ptr + 8); + b3 = vld1q_f32(rhs_ptr + 12); + b4 = vld1q_f32(rhs_ptr + 16); + b5 = vld1q_f32(rhs_ptr + 20); + b6 = vld1q_f32(rhs_ptr + 24); + b7 = vld1q_f32(rhs_ptr + 28); + + MACE_SGEMM_PART_CAL_R1_C4_D8(0, 0, 1); + + lhs_ptr += 8; + rhs_ptr += 32; + } + + block_d = remain_d >> 2; + remain_d -= (block_d << 2); + + // d: 4 + for (index_t bd = 0; bd < block_d; ++bd) { + // 1.4.4 + float32x4_t a0; + float32x4_t b0, b1, b2, b3; + + a0 = vld1q_f32(lhs_ptr); + + b0 = vld1q_f32(rhs_ptr); + b1 = vld1q_f32(rhs_ptr + 4); + b2 = vld1q_f32(rhs_ptr + 8); + b3 = vld1q_f32(rhs_ptr + 12); + + MACE_SGEMM_PART_CAL_R1_C4_D4(0); + + lhs_ptr += 4; + rhs_ptr += 16; + } + + // d: remain + float s0 = 0; + float s1 = 0; + float s2 = 0; + float s3 = 0; + for (index_t d = 0; d < remain_d; ++d) { + // 1.1.4 + s0 += lhs_ptr[0] * rhs_ptr[0]; + s1 += lhs_ptr[0] * rhs_ptr[1]; + s2 += lhs_ptr[0] * rhs_ptr[2]; + s3 += lhs_ptr[0] * rhs_ptr[3]; + lhs_ptr += 1; + rhs_ptr += 4; + } + float32x4_t c0_remain = {s0, s1, s2, s3}; + c0 += c0_remain; 
+ + vst1q_f32(res_ptr, c0); + res_ptr += 4; + } // bh: remain + } // bw + +#endif // MACE_ENABLE_NEON + + // ========================== remain width =========================== + + result_data += (width - remain_w) * height; + rhs_data += (width - remain_w) * depth; + + // w: 1 +#pragma omp parallel for + for (index_t bw = 0; bw < remain_w; ++bw) { + index_t remain_h = height; + + const float *lhs_ptr = lhs_data; + float *res_ptr = result_data + height * bw; + +#if defined(MACE_ENABLE_NEON) + index_t block_h = 0; +#if defined(__aarch64__) + block_h = remain_h >> 3; + remain_h -= (block_h << 3); + + // h: 8 + for (index_t bh = 0; bh < block_h; ++bh) { + const float *rhs_ptr = rhs_data + depth * bw; + + index_t remain_d = depth; + + float32x4_t c0, c1; + c0 = vdupq_n_f32(0.f); + c1 = vdupq_n_f32(0.f); + + index_t block_d = remain_d >> 2; + remain_d -= (block_d << 2); + + // d: 4 + for (index_t bd = 0; bd < block_d; ++bd) { + // 8.4.1 + float32x4_t b0, b1, b2, b3, b4, b5, b6, b7; + float32x4_t a0; + + b0 = vld1q_f32(lhs_ptr); + b1 = vld1q_f32(lhs_ptr + 4); + b2 = vld1q_f32(lhs_ptr + 8); + b3 = vld1q_f32(lhs_ptr + 12); + b4 = vld1q_f32(lhs_ptr + 16); + b5 = vld1q_f32(lhs_ptr + 20); + b6 = vld1q_f32(lhs_ptr + 24); + b7 = vld1q_f32(lhs_ptr + 28); + + a0 = vld1q_f32(rhs_ptr); + + MACE_SGEMM_PART_CAL_R1_C8_D4(0, 1, 0); + + lhs_ptr += 32; + rhs_ptr += 4; + } + + // d: remain + for (index_t d = 0; d < remain_d; ++d) { + // 8.1.1 + float32x4_t b0, b1; + float32x4_t a0 = vdupq_n_f32(rhs_ptr[0]); + + b0 = vld1q_f32(lhs_ptr); + b1 = vld1q_f32(lhs_ptr + 4); + + c0 = vfmaq_laneq_f32(c0, b0, a0, 0); + c1 = vfmaq_laneq_f32(c1, b1, a0, 0); + + lhs_ptr += 8; + rhs_ptr += 1; + } + + vst1q_f32(res_ptr, c0); + vst1q_f32(res_ptr + 4, c1); + + res_ptr += 8; + } // bh: 8 +#endif + + // h: 4 + block_h = remain_h >> 2; + remain_h -= (block_h << 2); + + for (index_t bh = 0; bh < block_h; ++bh) { + const float *rhs_ptr = rhs_data + depth * bw; + + index_t remain_d = depth; + index_t block_d = 0; + + float32x4_t c0 = vdupq_n_f32(0.f); + + block_d = remain_d >> 2; + remain_d -= (block_d << 2); + + // d: 4 + for (index_t bd = 0; bd < block_d; ++bd) { + // 4.4.1 + float32x4_t b0, b1, b2, b3; + float32x4_t a0; + + b0 = vld1q_f32(lhs_ptr); + b1 = vld1q_f32(lhs_ptr + 4); + b2 = vld1q_f32(lhs_ptr + 8); + b3 = vld1q_f32(lhs_ptr + 12); + + a0 = vld1q_f32(rhs_ptr); + + MACE_SGEMM_PART_CAL_R1_C4_D4(0); + + lhs_ptr += 16; + rhs_ptr += 4; + } + + // d: remain + for (index_t d = 0; d < remain_d; ++d) { + // 4.1.1 + float32x4_t b0, b1; + float32x2_t a0 = vdup_n_f32(rhs_ptr[0]); + + b0 = vld1q_f32(lhs_ptr); + + c0 = vmlaq_lane_f32(c0, b0, a0, 0); + + lhs_ptr += 4; + rhs_ptr += 1; + } + vst1q_f32(res_ptr, c0); + + res_ptr += 4; + } // bh: 4 + +#endif // MACE_ENABLE_NEON + + // h: 1 + for (index_t h = 0; h < remain_h; ++h) { + const float *rhs_ptr = rhs_data + depth * bw; + + index_t remain_d = depth; + + float sum = 0.f; + +#if defined(MACE_ENABLE_NEON) + index_t block_d = 0; + + float32x4_t c0, c1; + c0 = vdupq_n_f32(0.f); + c1 = vdupq_n_f32(0.f); + + block_d = remain_d >> 3; + remain_d -= (block_d << 3); + + // d: 8 + for (index_t bd = 0; bd < block_d; ++bd) { + // 1.8.1 + float32x4_t a0, a1; + float32x4_t b0, b1; + + a0 = vld1q_f32(lhs_ptr); + a1 = vld1q_f32(lhs_ptr + 4); + b0 = vld1q_f32(rhs_ptr); + b1 = vld1q_f32(rhs_ptr + 4); + + c0 = vmlaq_f32(c0, a0, b0); + c1 = vmlaq_f32(c1, a1, b1); + + lhs_ptr += 8; + rhs_ptr += 8; + } + + block_d = remain_d >> 2; + remain_d -= (block_d << 2); + + // d: 4 + for (index_t bd = 0; bd < block_d; 
++bd) { + // 1.4.1 + float32x4_t a0; + float32x4_t b0; + + a0 = vld1q_f32(lhs_ptr); + b0 = vld1q_f32(rhs_ptr); + + c0 = vmlaq_f32(c0, a0, b0); + + lhs_ptr += 4; + rhs_ptr += 4; + } + sum += vaddvq_f32(c0); + sum += vaddvq_f32(c1); +#endif // MACE_ENABLE_NEON + + // d: remain + for (index_t d = 0; d < remain_d; ++d) { + // 1.1.1 + sum += lhs_ptr[0] * rhs_ptr[0]; + lhs_ptr += 1; + rhs_ptr += 1; + } + + *res_ptr = sum; + ++res_ptr; + } // bh: remain + } // bw +} + +void SGemm::PackLhs(const MatrixMap &lhs, + PackedBlock *packed_block) { Pack(lhs, PackOrder::ColMajor, packed_block); } -void SGemm::PackRhs(const MatrixMap &rhs, - PackedBlock *packed_block) { +void SGemm::PackRhs(const MatrixMap &rhs, + PackedBlock *packed_block) { Pack(rhs, PackOrder::RowMajor, packed_block); } -void SGemm::UnPack(const PackedBlock &packed_result, +void SGemm::Pack(const MatrixMap &src, + const PackOrder order, + PackedBlock *packed_block) { + MACE_CHECK_NOTNULL(packed_block); + + const index_t height = src.row(); + const index_t width = src.col(); + auto packed_data = packed_block->mutable_data(); + +#define MACE_SGEMM_PACK_PER_BATCH \ + for (index_t b = 0; b < src.batch(); ++b) { \ + PackPerBatch(src, order, b, packed_data + b * height * width); \ + } + if (src.batch() >= MaceOpenMPThreadCount) { +#pragma omp parallel for + MACE_SGEMM_PACK_PER_BATCH + } else { + MACE_SGEMM_PACK_PER_BATCH + } +#undef MACE_SGEMM_PACK_PER_BATCH +} + +void SGemm::UnPack(const PackedBlock &packed_result, MatrixMap *matrix_map) { - (void) packed_result; - (void) matrix_map; + MACE_CHECK_NOTNULL(matrix_map); + + const index_t height = matrix_map->row(); + const index_t width = matrix_map->col(); + auto packed_data = packed_result.data(); + +#define MACE_SGEMM_UNPACK_PER_BATCH \ + for (index_t b = 0; b < matrix_map->batch(); ++b) { \ + UnPackPerBatch(packed_data + b * height * width, b, matrix_map); \ + } + + if (matrix_map->batch() >= MaceOpenMPThreadCount) { +#pragma omp parallel for + MACE_SGEMM_UNPACK_PER_BATCH + } else { + MACE_SGEMM_UNPACK_PER_BATCH + } +#undef MACE_SGEMM_UNPACK_PER_BATCH } -void SGemm::Pack(const MatrixMap &src, - const PackOrder order, - PackedBlock *packed_block) { - (void) src; - (void) order; - (void) packed_block; +void SGemm::PackPerBatch(const MatrixMap &src, + const PackOrder order, + const index_t batch_index, + float *packed_data) { + MACE_CHECK_NOTNULL(packed_data); + + const index_t height = src.row(); + const index_t width = src.col(); + auto src_data = src.batch_data(batch_index); + + if (src.major() == Major::RowMajor && order == PackOrder::ColMajor) { + // This is for packing no-transpose lhs. 
+ index_t h = 0; +#if defined(MACE_ENABLE_NEON) +#if defined(__aarch64__) +#pragma omp parallel for + for (index_t ih = h; ih <= height - 8; ih += 8) { + const float *src_data_ptr = src_data + ih * width; + float *packed_data_ptr = packed_data + ih * width; + for (index_t w = 0; w < width; ++w) { + const index_t src_offset = w; + const index_t packed_offset = w * 8; + float32x4_t vs0 = {src_data_ptr[src_offset], + src_data_ptr[src_offset + width], + src_data_ptr[src_offset + 2 * width], + src_data_ptr[src_offset + 3 * width]}; + float32x4_t vs1 = {src_data_ptr[src_offset + 4 * width], + src_data_ptr[src_offset + 5 * width], + src_data_ptr[src_offset + 6 * width], + src_data_ptr[src_offset + 7 * width]}; + vst1q_f32(packed_data_ptr + packed_offset, vs0); + vst1q_f32(packed_data_ptr + packed_offset + 4, vs1); + } + } + h += (height - h) / 8 * 8; +#endif +#pragma omp parallel for + for (index_t ih = h; ih <= height - 4; ih += 4) { + const float *src_data_ptr = src_data + ih * width; + float *packed_data_ptr = packed_data + ih * width; + for (index_t w = 0; w < width; ++w) { + const index_t src_offset = w; + const index_t packed_offset = w * 4; + float32x4_t vs = {src_data_ptr[src_offset], + src_data_ptr[src_offset + width], + src_data_ptr[src_offset + 2 * width], + src_data_ptr[src_offset + 3 * width]}; + vst1q_f32(packed_data_ptr + packed_offset, vs); + } + } + h += (height - h) / 4 * 4; +#endif +#pragma omp parallel for + for (index_t ih = h; ih < height; ++ih) { + std::copy_n(src_data + ih * width, width, packed_data + ih * width); + } + } else if (src.major() == Major::ColMajor && order == PackOrder::ColMajor) { + // This is for packing transpose-needed lhs. + index_t h = 0; +#if defined(MACE_ENABLE_NEON) +#if defined(__aarch64__) +#pragma omp parallel for + for (index_t ih = h; ih <= height - 8; ih += 8) { + const float *src_data_ptr = src_data + ih; + float *packed_data_ptr = packed_data + ih * width; + for (index_t w = 0; w < width; ++w) { + const index_t src_offset = w * height; + const index_t packed_offset = w * 8; + float32x4_t vs0 = vld1q_f32(src_data_ptr + src_offset); + float32x4_t vs1 = vld1q_f32(src_data_ptr + src_offset + 4); + vst1q_f32(packed_data_ptr + packed_offset, vs0); + vst1q_f32(packed_data_ptr + packed_offset + 4, vs1); + } + } + h += (height - h) / 8 * 8; +#endif +#pragma omp parallel for + for (index_t ih = h; ih <= height - 4; ih += 4) { + const float *src_data_ptr = src_data + ih; + float *packed_data_ptr = packed_data + ih * width; + for (index_t w = 0; w < width; ++w) { + const index_t src_offset = w * height; + const index_t packed_offset = w * 4; + float32x4_t vs = vld1q_f32(src_data_ptr + src_offset); + vst1q_f32(packed_data_ptr + packed_offset, vs); + } + } + h += (height - h) / 4 * 4; +#endif +#pragma omp parallel for + for (index_t ih = h; ih < height; ++ih) { + const float *src_data_ptr = src_data + ih; + float *packed_data_ptr = packed_data + ih * width; + for (index_t w = 0; w < width; ++w) { + packed_data_ptr[w] = src_data_ptr[w * height]; + } + } + } else if (src.major() == Major::RowMajor && order == PackOrder::RowMajor) { + // This is for packing no-transpose rhs. 
+ index_t w = 0; +#if defined(MACE_ENABLE_NEON) +#pragma omp parallel for + for (index_t iw = w; iw <= width - 4; iw += 4) { + const float *src_data_ptr = src_data + iw; + float *packed_data_ptr = packed_data + iw * height; + for (index_t h = 0; h < height; ++h) { + const index_t src_offset = h * width; + const index_t packed_offset = h * 4; + float32x4_t vs = vld1q_f32(src_data_ptr + src_offset); + vst1q_f32(packed_data_ptr + packed_offset, vs); + } + } + w += (width - w) / 4 * 4; +#endif +#pragma omp parallel for + for (index_t iw = w; iw < width; ++iw) { + const float *src_data_ptr = src_data + iw; + float *packed_data_ptr = packed_data + iw * height; + for (index_t h = 0; h < height; ++h) { + packed_data_ptr[h] = src_data_ptr[h * width]; + } + } + } else if (src.major() == Major::ColMajor && order == PackOrder::RowMajor) { + // This is for packing transpose-needed rhs. + index_t w = 0; +#if defined(MACE_ENABLE_NEON) +#pragma omp parallel for + for (index_t iw = w; iw <= width - 4; iw += 4) { + const float *src_data_ptr = src_data + iw * height; + float *packed_data_ptr = packed_data + iw * height; + for (index_t h = 0; h < height; ++h) { + const index_t src_offset = h; + const index_t packed_offset = h * 4; + float32x4_t vs = {src_data_ptr[src_offset], + src_data_ptr[src_offset + height], + src_data_ptr[src_offset + 2 * height], + src_data_ptr[src_offset + 3 * height]}; + vst1q_f32(packed_data_ptr + packed_offset, vs); + } + } + w += (width - w) / 4 * 4; +#endif +#pragma omp parallel for + for (index_t iw = w; iw < width; ++iw) { + std::copy_n(src_data + iw * height, height, packed_data + iw * height); + } + } +} + +void SGemm::UnPackPerBatch(const float *packed_data, + const index_t batch_index, + MatrixMap *matrix_map) { + MACE_CHECK_NOTNULL(matrix_map); + + const index_t height = matrix_map->row(); + const index_t width = matrix_map->col(); + auto unpacked_data = matrix_map->batch_data(batch_index); + + if (matrix_map->major() == Major::RowMajor) { + // This is for non-transposed result + index_t w = 0; +#if defined(MACE_ENABLE_NEON) +#pragma omp parallel for + for (index_t iw = w; iw <= width - 4; iw += 4) { + const float *packed_data_ptr = packed_data + iw * height; + float *unpacked_data_ptr = unpacked_data + iw; + for (index_t h = 0; h < height; ++h) { + const index_t packed_offset = h * 4; + const index_t unpacked_offset = h * width; + float32x4_t vs = vld1q_f32(packed_data_ptr + packed_offset); + vst1q_f32(unpacked_data_ptr + unpacked_offset, vs); + } + } + w += (width - w) / 4 * 4; +#endif +#pragma omp parallel for + for (index_t iw = w; iw < width; ++iw) { + const float *packed_data_ptr = packed_data + iw * height; + float *unpacked_data_ptr = unpacked_data + iw; + for (index_t h = 0; h < height; ++h) { + unpacked_data_ptr[h * width] = packed_data_ptr[h]; + } + } + } else { + // This is for transposed result + index_t w = 0; +#if defined(MACE_ENABLE_NEON) +#pragma omp parallel for + for (index_t iw = w; iw <= width - 4; iw += 4) { + const float *packed_data_ptr = packed_data + iw * height; + float *unpacked_data_ptr = unpacked_data + iw * height; + for (index_t h = 0; h < height; ++h) { + const index_t packed_offset = h * 4; + const index_t unpacked_offset = h; + float32x4_t vs = vld1q_f32(packed_data_ptr + packed_offset); + unpacked_data_ptr[unpacked_offset] = vs[0]; + unpacked_data_ptr[unpacked_offset + height] = vs[1]; + unpacked_data_ptr[unpacked_offset + 2 * height] = vs[2]; + unpacked_data_ptr[unpacked_offset + 3 * height] = vs[3]; + } + } + w += (width - w) / 4 * 4; 
+#endif +#pragma omp parallel for + for (index_t iw = w; iw < width; ++iw) { + std::copy_n( + packed_data + iw * height, height, unpacked_data + iw * height); + } + } } } // namespace kernels diff --git a/mace/kernels/sgemm.h b/mace/kernels/sgemm.h index 3aaf5d478324ed8ec4d32452ceeb39422d89ac1f..c24c9c03e55047e157208caed5139e3a1b93380f 100644 --- a/mace/kernels/sgemm.h +++ b/mace/kernels/sgemm.h @@ -15,6 +15,9 @@ #ifndef MACE_KERNELS_SGEMM_H_ #define MACE_KERNELS_SGEMM_H_ +#include +#include + #if defined(MACE_ENABLE_NEON) #include #endif @@ -34,22 +37,29 @@ enum Major { template class MatrixMap { public: - MatrixMap(const index_t row, + MatrixMap() {} + + MatrixMap(const index_t batch, + const index_t row, const index_t col, const Major major, - T *data) : + T *data, + const bool is_const = false) : + batch_(batch), row_(row), col_(col), stride_(major == RowMajor ? col : row), major_(major), - data_(data) {} - - MatrixMap transpose(const MatrixMap &matrix_map) { - Major transpose_major = matrix_map.major_ == RowMajor ? ColMajor : RowMajor; - return MatrixMap(matrix_map.col_, - matrix_map.row_, - transpose_major, - matrix_map.data_); + data_(data), + is_const_(is_const) {} + + MatrixMap transpose() const { + Major transpose_major = major_ == RowMajor ? ColMajor : RowMajor; + return MatrixMap(batch_, col_, row_, transpose_major, data_, is_const_); + } + + index_t batch() const { + return batch_; } index_t row() const { @@ -72,66 +82,100 @@ class MatrixMap { return data_; } - T *data(int row, int col) const { - return data_ + row * stride_ + col; + T *batch_data(index_t batch) const { + return data_ + batch * row_ * col_; + } + + index_t size() const { + return batch_ * row_ * col_; + } + + bool is_const() const { + return is_const_; } private: + index_t batch_; index_t row_; index_t col_; index_t stride_; Major major_; T *data_; + bool is_const_; }; typedef Major PackOrder; - -template -class PackedBlock { - public: - PackedBlock() : data_tensor_(GetCPUAllocator(), - DataTypeToEnum::v()) {} - - const T *data() { - return data_tensor_.data(); - } - - T *mutable_data() { - return data_tensor_.mutable_data(); - } - - Tensor *tensor() { - return &data_tensor_; - } - - private: - Tensor data_tensor_; -}; +typedef Tensor PackedBlock; class SGemm { public: - void operator()(const MatrixMap &lhs, - const MatrixMap &rhs, - MatrixMap *result); - - void operator()(const PackedBlock &lhs, - const PackedBlock &rhs, - const index_t height, - const index_t depth, - const index_t width, - PackedBlock *result); - - void PackLhs(const MatrixMap &lhs, PackedBlock *packed_block); - - void PackRhs(const MatrixMap &rhs, PackedBlock *packed_block); - - void UnPack(const PackedBlock &packed_result, + SGemm() + : packed_lhs_(nullptr), + packed_rhs_(nullptr), + packed_(false) {} + + void operator()(const MatrixMap &lhs, + const MatrixMap &rhs, + MatrixMap *result, + ScratchBuffer *scratch_buffer = nullptr); + + void Run(const float *A, + const float *B, + const index_t batch, + const index_t height_a, + const index_t width_a, + const index_t height_b, + const index_t width_b, + const bool transpose_a, + const bool transpose_b, + const bool is_a_weight, + const bool is_b_weight, + float *C, + ScratchBuffer *scratch_buffer = nullptr); + + void PackLhs(const MatrixMap &lhs, + PackedBlock *packed_block); + + void PackRhs(const MatrixMap &rhs, + PackedBlock *packed_block); + + void UnPack(const PackedBlock &packed_result, MatrixMap *matrix_map); private: - void Pack(const MatrixMap &src, + void Pack(const MatrixMap &src, 
const PackOrder order, - PackedBlock *packed_block); + PackedBlock *packed_block); + + void PackPerBatch(const MatrixMap &src, + const PackOrder order, + const index_t batch_index, + float *packed_data); + + void UnPackPerBatch(const float *packed_data, + const index_t batch_index, + MatrixMap *matrix_map); + + void RunInternal(const PackedBlock &lhs, + const PackedBlock &rhs, + const index_t batch, + const index_t height, + const index_t depth, + const index_t width, + PackedBlock *result); + + void RunPerBatch(const float *lhs, + const float *rhs, + const index_t height, + const index_t depth, + const index_t width, + float *result); + + std::unique_ptr packed_lhs_; + std::unique_ptr packed_rhs_; + std::unique_ptr packed_result_; + + bool packed_; }; } // namespace kernels diff --git a/mace/kernels/sgemm_pack_test.cc b/mace/kernels/sgemm_pack_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3e7aaa98e362896eb9da1452b589f3d678a56c54 --- /dev/null +++ b/mace/kernels/sgemm_pack_test.cc @@ -0,0 +1,167 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "mace/kernels/sgemm.h" + +namespace mace { +namespace kernels { +namespace test { + +namespace { +void TestPack(const std::vector &data, + const std::vector &expected_data, + const index_t height, + const index_t width, + Major src_order, + PackOrder pack_order) { + SGemm sg; + MatrixMap src_matrix(1, height, width, src_order, data.data()); + PackedBlock packed; + packed.Resize({height, width}); + if (pack_order == PackOrder::ColMajor) { + sg.PackLhs(src_matrix, &packed); + } else { + sg.PackRhs(src_matrix, &packed); + } + + auto packed_data = packed.data(); + for (index_t i = 0; i < packed.size(); ++i) { + EXPECT_EQ(expected_data[i], packed_data[i]); + } +} + +void TestUnPack(const index_t height, + const index_t width, + Major src_order, + PackOrder pack_order) { + static auto seed = static_cast(time(nullptr)); + const index_t matrix_size = height * width; + std::vector data(matrix_size); + for (int i = 0; i < matrix_size; ++i) { + data[i] = rand_r(&seed); + } + + MatrixMap src_matrix(1, height, width, src_order, data.data()); + PackedBlock packed; + packed.Resize({height, width}); + SGemm sg; + if (pack_order == PackOrder::ColMajor) { + sg.PackLhs(src_matrix, &packed); + } else { + sg.PackRhs(src_matrix, &packed); + } + + std::vector unpacked(matrix_size); + MatrixMap + unpacked_matrix(1, height, width, src_order, unpacked.data()); + sg.UnPack(packed, &unpacked_matrix); + auto unpacked_data = unpacked.data(); + for (index_t i = 0; i < packed.size(); ++i) { + EXPECT_EQ(data[i], unpacked_data[i]); + } +} +} // namespace + + +TEST(SGemmPackTest, Pack) { + std::vector data = + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36}; + + // For no-transpose lhs + TestPack(data, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + 3, 
4, Major::RowMajor, PackOrder::ColMajor); +#if defined(MACE_ENABLE_NEON) + TestPack(data, + {1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16}, + 4, 4, Major::RowMajor, PackOrder::ColMajor); + TestPack(data, + {1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16, 17, 18, 19, + 20}, + 5, 4, Major::RowMajor, PackOrder::ColMajor); +#if defined(__aarch64__) + TestPack(data, + {1, 5, 9, 13, 17, 21, 25, 29, 2, 6, 10, 14, 18, 22, 26, 30, 3, 7, 11, + 15, 19, 23, 27, 31, 4, 8, 12, 16, 20, 24, 28, 32, 33, 34, 35, 36}, + 9, 4, Major::RowMajor, PackOrder::ColMajor); +#endif +#endif + // For transpose-needed lhs + TestPack(data, + {1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 12}, + 3, 4, Major::ColMajor, PackOrder::ColMajor); +#if defined(MACE_ENABLE_NEON) + TestPack(data, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + 4, 4, Major::ColMajor, PackOrder::ColMajor); + TestPack(data, + {1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 5, 10, 15, + 20}, + 5, 4, Major::ColMajor, PackOrder::ColMajor); +#if defined(__aarch64__) + TestPack(data, + {1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 19, 20, 21, + 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 9, 18, 27, 36}, + 9, 4, Major::ColMajor, PackOrder::ColMajor); +#endif +#endif + // For no-transpose rhs + TestPack(data, + {1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 12}, + 4, 3, Major::RowMajor, PackOrder::RowMajor); +#if defined(MACE_ENABLE_NEON) + TestPack(data, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + 4, 4, Major::RowMajor, PackOrder::RowMajor); + TestPack(data, + {1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 5, 10, 15, + 20}, + 4, 5, Major::RowMajor, PackOrder::RowMajor); +#endif + // For transpose-needed rhs + TestPack(data, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + 4, 3, Major::ColMajor, PackOrder::RowMajor); +#if defined(MACE_ENABLE_NEON) + TestPack(data, + {1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16}, + 4, 4, Major::ColMajor, PackOrder::RowMajor); + TestPack(data, + {1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16, 17, 18, 19, + 20}, + 4, 5, Major::ColMajor, PackOrder::RowMajor); +#endif +} + +TEST(SGemmPackTest, UnPack) { + TestUnPack(4, 3, Major::RowMajor, PackOrder::RowMajor); + TestUnPack(4, 4, Major::RowMajor, PackOrder::RowMajor); + TestUnPack(4, 5, Major::RowMajor, PackOrder::RowMajor); + TestUnPack(4, 100, Major::RowMajor, PackOrder::RowMajor); + TestUnPack(4, 3, Major::ColMajor, PackOrder::RowMajor); + TestUnPack(4, 4, Major::ColMajor, PackOrder::RowMajor); + TestUnPack(4, 5, Major::ColMajor, PackOrder::RowMajor); + TestUnPack(4, 100, Major::ColMajor, PackOrder::RowMajor); +} + +} // namespace test +} // namespace kernels +} // namespace mace + diff --git a/mace/ops/matmul.h b/mace/ops/matmul.h index ceccb9398aaa7d5b730951672c0370e5509e1f7f..64b336a38d98a52be95917a7d9750abdc1c6e1a9 100644 --- a/mace/ops/matmul.h +++ b/mace/ops/matmul.h @@ -40,7 +40,11 @@ class MatMulOp : public Operator { "than or equal to 2"); index_t rank = A->dim_size(); for (index_t i = 0; i < rank - 2; ++i) { - MACE_CHECK(A->dim(i) == B->dim(i), "batch dimensions are not equal"); + MACE_CHECK(A->dim(i) == B->dim(i), + "batch dimensions are not equal: ", + A->dim(i), + " vs. ", + B->dim(i)); } index_t ak = transpose_a_ ? A->dim(rank - 2) : A->dim(rank - 1); index_t bk = transpose_b_ ? 
B->dim(rank - 1) : B->dim(rank - 2); diff --git a/mace/ops/matmul_benchmark.cc b/mace/ops/matmul_benchmark.cc index 146f8d1c541771eb23ee16ad737b2435ce641f78..08b06fa7e83cc9d07436b1f2149766f798d6fb1a 100644 --- a/mace/ops/matmul_benchmark.cc +++ b/mace/ops/matmul_benchmark.cc @@ -33,13 +33,15 @@ void MatMulBenchmark( // Add input data net.AddRandomInput("A", {batch, height, channels}); net.AddRandomInput("B", {batch, channels, out_width}); + net.GetTensor("A")->SetIsWeight(true); + net.GetTensor("B")->SetIsWeight(true); if (DataTypeToEnum::value == DT_UINT8) { net.GetTensor("A")->SetScale(0.1); net.GetTensor("B")->SetScale(0.1); } - if (D == DeviceType::GPU) { - BufferToImage(&net, "A", "AImage", kernels::BufferType::IN_OUT_WIDTH); + BufferToImage(&net, "A", "AImage", + kernels::BufferType::IN_OUT_WIDTH); BufferToImage(&net, "B", "BImage", kernels::BufferType::IN_OUT_HEIGHT); @@ -71,7 +73,7 @@ void MatMulBenchmark( mace::testing::StartTiming(); while (iters--) { - net.RunOp(D); + net.Run(); } net.Sync(); } @@ -86,6 +88,8 @@ void MatMulTransposeBenchmark( // Add input data net.AddRandomInput("A", {batch, height, channels}); net.AddRandomInput("B", {batch, out_width, channels}); + net.GetTensor("A")->SetIsWeight(true); + net.GetTensor("B")->SetIsWeight(true); if (DataTypeToEnum::value == DT_UINT8) { net.GetTensor("A")->SetScale(0.1); net.GetTensor("B")->SetScale(0.1); @@ -116,7 +120,7 @@ void MatMulTransposeBenchmark( mace::testing::StartTiming(); while (iters--) { - net.RunOp(D); + net.Run(); } net.Sync(); } @@ -154,10 +158,15 @@ void MatMulTransposeBenchmark( MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, float, CPU); \ MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, uint8_t, CPU); +MACE_BM_MATMUL(1, 128, 128, 49); +MACE_BM_MATMUL(2, 128, 128, 49); +MACE_BM_MATMUL(3, 128, 128, 49); +MACE_BM_MATMUL(4, 128, 128, 49); MACE_BM_MATMUL(16, 32, 128, 49); MACE_BM_MATMUL(16, 32, 128, 961); MACE_BM_MATMUL(16, 32, 128, 3969); MACE_BM_MATMUL(16, 128, 128, 49); +MACE_BM_MATMUL(16, 49, 128, 128); MACE_BM_MATMUL(16, 128, 128, 961); MACE_BM_MATMUL(16, 128, 128, 3969); diff --git a/mace/ops/winograd_transform_benchmark.cc b/mace/ops/winograd_transform_benchmark.cc index ecba841742cc23f8585fb7b30d294f3fe7ce9173..9955c9abc35f63c02f7b5461da2cbc425657265f 100644 --- a/mace/ops/winograd_transform_benchmark.cc +++ b/mace/ops/winograd_transform_benchmark.cc @@ -211,8 +211,8 @@ void WinoMatMulBenchmark( const index_t round_w = (width + block_size - 1) / block_size; const index_t out_width = round_h * round_w; // Add input data - net.AddRandomInput("A", {batch, out_channels, in_channels, 1}); - net.AddRandomInput("B", {batch, in_channels, out_width, 1}); + net.AddRandomInput("A", {batch, out_channels, in_channels}); + net.AddRandomInput("B", {batch, in_channels, out_width}); if (D == DeviceType::GPU) { BufferToImage(&net, "A", "AImage", kernels::BufferType::IN_OUT_WIDTH);
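
Usage note for reviewers: below is a minimal sketch of how the SGemm::Run entry point added in mace/kernels/sgemm.{h,cc} above is driven, modeled on the new MatMulFunctor code path and the SGemmTest added in this diff. The wrapper function name, buffer names and shapes are illustrative assumptions, not code from the patch; only the Run() signature comes from the change itself.

// Illustrative sketch only: multiplies a row-major [batch, h, k] matrix A by a
// row-major [batch, k, w] matrix B into C ([batch, h, w]) via the packed path.
#include "mace/core/types.h"
#include "mace/kernels/sgemm.h"

void ExampleSGemmCall(const float *A,      // hypothetical caller-owned input
                      const float *B,
                      float *C,            // must hold batch * h * w floats
                      mace::index_t batch,
                      mace::index_t h,
                      mace::index_t k,
                      mace::index_t w) {
  mace::kernels::SGemm sgemm;
  sgemm.Run(A, B, batch,
            h, k,        // height_a, width_a
            k, w,        // height_b, width_b
            false,       // transpose_a
            false,       // transpose_b
            false,       // is_a_weight: true lets SGemm reuse the packed A
            false,       // is_b_weight: true lets SGemm reuse the packed B
            C,
            nullptr);    // no ScratchBuffer: packs into CPU-allocated tensors
}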