diff --git a/dnn/src/arm_common/matrix_mul/algos.cpp b/dnn/src/arm_common/matrix_mul/algos.cpp index d06c07e796a254f4d097b13985cf33c26299f9d2..6e4ffe9353dd90975de12f598a7d6b5731b1b6eb 100644 --- a/dnn/src/arm_common/matrix_mul/algos.cpp +++ b/dnn/src/arm_common/matrix_mul/algos.cpp @@ -138,6 +138,63 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern( return f32_gemv_kern; } +/* ===================== Gevm algo (F32 and Int8 kernels) ===================== */ +namespace { + +void gevm_fp32_kern(const MatrixMulImpl::KernParam& kern_param) { /* fp32 kernel: forwards to sgemm_sgemv_like, passing B ahead of A and N ahead of M */ + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto LDB = kern_param.LDB; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + arm_common::sgemm_sgemv_like(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1); /* NOTE(review): the trailing LDB, 1, 1 are presumably the strides of B, A and C; confirm against the signature of sgemm_sgemv_like */ +} + +void gevm_int8_kern(const MatrixMulImpl::KernParam& kern_param) { /* int8 kernel: same operand order as the fp32 kernel, forwarded to matmul::gemv_like_int8 */ + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto LDB = kern_param.LDB; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + arm_common::matmul::gemv_like_int8(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1); +} + +} // anonymous namespace + +bool MatrixMulImpl::AlgoGevm::usable( + const KernSizeParam& kern_size_param) const { + // usable for fp32 (DEFAULT compute mode and format, A/B/C all Float32) or for inputs accepted by can_be_treated_as_int8x8x32, and only when preferred() also holds + bool fp32_ok = + kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + kern_size_param.B_type == kern_size_param.A_type && + kern_size_param.C_type == kern_size_param.A_type && + kern_size_param.A_type == dtype::Float32(); + return (fp32_ok || can_be_treated_as_int8x8x32(kern_size_param)) && + preferred(kern_size_param); +} + +bool MatrixMulImpl::AlgoGevm::preferred( + const KernSizeParam& kern_size_param) const { /* preferred only when B is transposed (trB) and M == 1 */ + auto M = kern_size_param.M; + return kern_size_param.trB && M == 1; +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoGevm::get_kern( + const KernSizeParam& kern_size_param) const { /* selects the kernel from A_type: Float32 -> gevm_fp32_kern, Int8 or QuantizedS8 -> gevm_int8_kern, anything else trips megdnn_assert */ + if (kern_size_param.A_type == 
dtype::Float32()) { + return gevm_fp32_kern; + } else if (kern_size_param.A_type.enumv() == DTypeEnum::Int8 || + kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) { + return gevm_int8_kern; + } else { /* NOTE(review): no return follows the assert in this branch; this presumably relies on megdnn_assert(false, ...) never returning - confirm */ + megdnn_assert( + false, "no avaliable kern got A_type: %s B_type: %s C_type: %s", + kern_size_param.A_type.name(), kern_size_param.B_type.name(), + kern_size_param.C_type.name()); + } +} + #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC /* ===================== F16 Gemv algo ===================== */ namespace { diff --git a/dnn/src/arm_common/matrix_mul/algos.h b/dnn/src/arm_common/matrix_mul/algos.h index 68b13a85bb5c59a0c84eb9170f0c6b402556c858..a8f7a33902185c083ce6744b817175761a6a7e43 100644 --- a/dnn/src/arm_common/matrix_mul/algos.h +++ b/dnn/src/arm_common/matrix_mul/algos.h @@ -70,6 +70,21 @@ public: PackMode packmode() const override { return PackMode::NO_PACK; } }; #endif + +class MatrixMulImpl::AlgoGevm : public AlgoBase { /* Gevm algorithm: needs no workspace and no packing; reported under AlgoSet::ALGO_TYPE_GEMV */ +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARM_COMMON_GEVM"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override { return 0; } + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } + PackMode packmode() const override { return PackMode::NO_PACK; } +}; + + } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.cpp b/dnn/src/arm_common/matrix_mul/opr_impl.cpp index 0767751e883bc0b5da0c2ece9cd47199532e97ef..5ef42558d0d7f30876bc0b2ad7b1f996f069eda7 100644 --- a/dnn/src/arm_common/matrix_mul/opr_impl.cpp +++ b/dnn/src/arm_common/matrix_mul/opr_impl.cpp @@ -27,7 +27,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC AlgoF16Gemv f16gemv; 
#endif - AlgoInt8x8x32Gemv int8x8x32_gemv; + AlgoInt8x8x32Gemv int8x8x32_gemv; + AlgoGevm gevm; public: AlgoPack() { all_algos.emplace_back(&int8x8x16); @@ -35,6 +36,7 @@ public: all_algos.emplace_back(&f16gemv); #endif all_algos.emplace_back(&int8x8x32_gemv); + all_algos.emplace_back(&gevm); /* registers the Gevm algorithm after the existing entries */ } SmallVector all_algos; }; diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.h b/dnn/src/arm_common/matrix_mul/opr_impl.h index bfde8615dcff89c1bec72b3b32390c4d1fe25a91..bd53987d50d0e41525a1e4b9851919e5ae7745c6 100644 --- a/dnn/src/arm_common/matrix_mul/opr_impl.h +++ b/dnn/src/arm_common/matrix_mul/opr_impl.h @@ -27,6 +27,7 @@ protected: static void* const sm_arm_common_algo_type; class AlgoInt8x8x32Gemv; // Arm_common Int 8x8x32 Gemv class AlgoF32Gemv; // Arm_common F32 Gemv + class AlgoGevm; // Arm_common Gevm (supports int8 and fp32) #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC class AlgoF16Gemv; #endif diff --git a/dnn/test/arm_common/matrix_mul.cpp b/dnn/test/arm_common/matrix_mul.cpp index e4d1ad49e38488e96f79db38c6ae3673e1745d9a..ad190fc8825cb5097e5b08e364746c78fa83f61f 100644 --- a/dnn/test/arm_common/matrix_mul.cpp +++ b/dnn/test/arm_common/matrix_mul.cpp @@ -164,6 +164,62 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEMV) { run(M, K, N); } +TEST_F(ARM_COMMON, QINT8x8x32_GEVM) { /* checks ARM_COMMON_GEVM on QuantizedS8 inputs with B transposed and M fixed to 1 */ + Checker checker(handle()); /* NOTE(review): Checker, AlgoChecker, std::unique_ptr and std::make_unique appear here without template arguments - possibly stripped during extraction; confirm against the original patch */ + using Param = MatrixMul::Param; + + checker.set_before_exec_callback( + AlgoChecker("ARM_COMMON_GEVM")); + + std::unique_ptr rng = std::make_unique(-127, 127); + checker.set_rng(0, rng.get()).set_rng(1, rng.get()); + + auto run = [&](size_t M, size_t K, size_t N) { + Param param; + param.transposeA = false; + param.transposeB = true; + TensorShape A, B; + A = TensorShape{M, K}; + B = TensorShape{N, K}; + checker.set_param(param) + .set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .execs({A, B, {}}); + }; + + // M = 1 + for (size_t N : {1, 10, 16, 33, 64}) + for (size_t K : {7, 512, 1024}) + for 
(size_t M : {1}) + run(M, K, N); +} + +TEST_F(ARM_COMMON, FP32_GEVM) { /* same check with no set_dtype calls - presumably the checker then defaults to Float32, TODO confirm */ + Checker checker(handle()); + using Param = MatrixMul::Param; + + checker.set_before_exec_callback( + AlgoChecker("ARM_COMMON_GEVM")); + + checker.set_epsilon(1e-2); + auto run = [&](size_t M, size_t K, size_t N) { + Param param; + param.transposeA = false; + param.transposeB = true; + TensorShape A, B; + A = TensorShape{M, K}; + B = TensorShape{N, K}; + checker.set_param(param).execs({A, B, {}}); + }; + + // M = 1 + for (size_t M : {1}) + for (size_t K : {1000, 4096, 25088}) + for (size_t N : {1000, 4096}) + run(M, K, N); +} + #if MEGDNN_WITH_BENCHMARK TEST_F(ARM_COMMON, BENCHMARK_SGEMV) {