提交 5d950063 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

feat(dnn): refactor dot gemv for both aarch64 and aarch32

GitOrigin-RevId: 2b98867e4563bffe69f0340676e1a6c9ca8d0a2d
上级 53c288a3
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
#include "src/aarch64/matrix_mul/fp32/strategy.h" #include "src/aarch64/matrix_mul/fp32/strategy.h"
#include "src/aarch64/matrix_mul/int16/strategy.h" #include "src/aarch64/matrix_mul/int16/strategy.h"
#include "src/aarch64/matrix_mul/int8/strategy.h" #include "src/aarch64/matrix_mul/int8/strategy.h"
#include "src/aarch64/matrix_mul/int8_dot/gemv.h"
#include "src/aarch64/matrix_mul/int8_dot/strategy.h" #include "src/aarch64/matrix_mul/int8_dot/strategy.h"
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h" #include "src/aarch64/matrix_mul/int8x8x16/strategy.h"
#include "src/aarch64/matrix_mul/quint8/strategy.h" #include "src/aarch64/matrix_mul/quint8/strategy.h"
...@@ -441,39 +440,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x12x4DotProd, ...@@ -441,39 +440,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x12x4DotProd,
"AlgoInt8x8x32K8x12x4DotProdImpl"_hash, "AlgoInt8x8x32K8x12x4DotProdImpl"_hash,
aarch64::matmul::gemm_s8_8x12, int8_t, aarch64::matmul::gemm_s8_8x12, int8_t,
int32_t); int32_t);
/* ===================== Int8x8x32 Gemv DotProd algo ===================== */
namespace {
void int8x8x32_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
aarch64::matmul::gemv_like_int8(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
} // anonymous namespace
bool MatrixMulImpl::AlgoInt8x8x32GemvDotProd::usable(
const KernSizeParam& kern_size_param) const {
return can_be_treated_as_int8x8x32(kern_size_param) &&
!kern_size_param.trA && !kern_size_param.trB &&
kern_size_param.N == 1 && kern_size_param.LDB == 1;
}
bool MatrixMulImpl::AlgoInt8x8x32GemvDotProd::preferred(
const KernSizeParam& kern_size_param) const {
auto N = kern_size_param.N, LDB = kern_size_param.LDB;
return (N == 1 && LDB == 1);
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvDotProd::get_kern(
const KernSizeParam&) const {
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
midout_iv("AlgoInt8x8x32GemvDotProd::get_kern"_hash)) {
return int8x8x32_gemv_dotprod_kern;
}
MIDOUT_END();
return nullptr;
}
/* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */ /* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */
namespace { namespace {
......
...@@ -104,21 +104,6 @@ public: ...@@ -104,21 +104,6 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };
class MatrixMulImpl::AlgoInt8x8x32GemvDotProd final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override {
return "AARCH64_INT8X8X32_GEMV_DOTPROD";
}
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
};
class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase {
public: public:
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
...@@ -174,10 +159,6 @@ public: ...@@ -174,10 +159,6 @@ public:
void* type() const override { return sm_arm_common_algo_type; } void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };
class MatrixMulImpl::AlgoInt8x8x32Gemv final
: public arm_common::MatrixMulImpl::AlgoInt8x8x32Gemv {};
#endif #endif
class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase {
......
/**
* \file dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/aarch64/matrix_mul/int8_dot/gemv.h"
#include <cstddef>
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/common/unroll_macro.h"
#if __ARM_FEATURE_DOTPROD
namespace {
void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride) {
megdnn_assert(N == 1 && Bstride == 1);
size_t m = 0;
for (; m + 2 <= M; m += 2) {
int32_t acc[4];
int32x4_t acc_neon = vdupq_n_s32(0);
size_t k = 0;
for (; k + 16 <= K; k += 16) {
int64x2_t a0 = vreinterpretq_s64_s8(vld1q_s8(A + m * Astride + k));
int64x2_t a1 =
vreinterpretq_s64_s8(vld1q_s8(A + (m + 1) * Astride + k));
//! the first 8 elements is m, the last 8 elements is m + 1
int8x16_t a2 = vreinterpretq_s8_s64(vzip1q_s64(a0, a1));
int8x16_t a3 = vreinterpretq_s8_s64(vzip2q_s64(a0, a1));
int64x2_t b0 = vreinterpretq_s64_s8(vld1q_s8(B + k));
int8x16_t b2 = vreinterpretq_s8_s64(vzip1q_s64(b0, b0));
int8x16_t b3 = vreinterpretq_s8_s64(vzip2q_s64(b0, b0));
acc_neon = vdotq_s32(acc_neon, a2, b2);
acc_neon = vdotq_s32(acc_neon, a3, b3);
}
vst1q_s32(acc, acc_neon);
for (; k + 8 <= K; k += 8) {
int8x8_t a0 = vld1_s8(A + m * Astride + k);
int8x8_t a1 = vld1_s8(A + (m + 1) * Astride + k);
int8x8_t b0 = vld1_s8(B + k);
uint32x2_t zero = vdup_n_s32(0);
acc[0] += vaddv_s32(vdot_s32(zero, a0, b0));
zero = vdup_n_s32(0);
acc[3] += vaddv_s32(vdot_s32(zero, a1, b0));
}
for (; k < K; ++k) {
acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k];
acc[3] += static_cast<int32_t>(A[(m + 1) * Astride + k]) * B[k];
}
C[m * Cstride] = acc[0] + acc[1];
C[(m + 1) * Cstride] = acc[2] + acc[3];
}
for (; m < M; ++m) {
int32_t acc[4];
int32x4_t acc_neon = vdupq_n_s32(0);
size_t k = 0;
for (; k + 16 <= K; k += 16) {
int8x16_t a0 = vld1q_s8(A + m * Astride + k);
int8x16_t b0 = vld1q_s8(B + k);
acc_neon = vdotq_s32(acc_neon, a0, b0);
}
vst1q_s32(acc, acc_neon);
for (; k + 8 <= K; k += 8) {
int8x8_t a0 = vld1_s8(A + m * Astride + k);
int8x8_t b0 = vld1_s8(B + k);
uint32x2_t zero = vdup_n_s32(0);
acc[0] += vaddv_s32(vdot_s32(zero, a0, b0));
}
for (; k < K; ++k) {
acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k];
}
C[m * Cstride] = acc[0] + acc[1] + acc[2] + acc[3];
}
}
} // namespace
bool megdnn::aarch64::matmul::is_gemv_like_preferred_int8(
bool transposeA, bool transposeB, size_t M, size_t N, size_t K,
size_t /* LDA */, size_t LDB, size_t /* LDC */) {
if (transposeA)
return false;
if (transposeB)
return false;
MEGDNN_MARK_USED_VAR(K);
MEGDNN_MARK_USED_VAR(M);
return (N == 1 && LDB == 1);
}
void megdnn::aarch64::matmul::gemv_like_int8(const int8_t* __restrict A,
const int8_t* __restrict B,
int32_t* __restrict C, size_t M,
size_t N, size_t K, size_t Astride,
size_t Bstride, size_t Cstride) {
megdnn_assert(N == 1);
return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
}
#endif
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/aarch64/matrix_mul/int8_dot/gemv.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <cstddef>
#include <cstdint>
#if __ARM_FEATURE_DOTPROD
namespace megdnn {
namespace aarch64 {
namespace matmul {
bool is_gemv_like_preferred_int8(bool transposeA, bool transposeB, size_t M,
size_t N, size_t K, size_t LDA, size_t LDB,
size_t LDC);
void gemv_like_int8(const int8_t* __restrict A, const int8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
} // namespace matmul
} // namespace aarch64
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
...@@ -28,13 +28,11 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { ...@@ -28,13 +28,11 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#endif #endif
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
AlgoInt8x8x32K8x12x4DotProd int8x8x32_k8x12x4_dotprod; AlgoInt8x8x32K8x12x4DotProd int8x8x32_k8x12x4_dotprod;
AlgoInt8x8x32GemvDotProd int8x8x32_gemv_dotprod;
AlgoInt8x8x32MK4_8x12x4DotProd int8x8x32_mk4_8x12x4_dotprod; AlgoInt8x8x32MK4_8x12x4DotProd int8x8x32_mk4_8x12x4_dotprod;
#else #else
AlgoInt8x8x32MK4_4x4x16 int8x8x32_mk4_4x4x16; AlgoInt8x8x32MK4_4x4x16 int8x8x32_mk4_4x4x16;
AlgoInt8x8x32K4x4x16 int8x8x32_k4x4x16; AlgoInt8x8x32K4x4x16 int8x8x32_k4x4x16;
AlgoInt8x8x32K8x8x8 int8x8x32_k8x8x8; AlgoInt8x8x32K8x8x8 int8x8x32_k8x8x8;
AlgoInt8x8x32Gemv int8x8x32_gemv;
#endif #endif
AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8; AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8;
AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16; AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16;
...@@ -63,11 +61,9 @@ public: ...@@ -63,11 +61,9 @@ public:
all_algos.emplace_back(&f16_mk8_8x8); all_algos.emplace_back(&f16_mk8_8x8);
#endif #endif
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
all_algos.emplace_back(&int8x8x32_gemv_dotprod);
all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod); all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod);
all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod); all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod);
#else #else
all_algos.emplace_back(&int8x8x32_gemv);
all_algos.emplace_back(&int8x8x32_k4x4x16); all_algos.emplace_back(&int8x8x32_k4x4x16);
all_algos.emplace_back(&int8x8x32_k8x8x8); all_algos.emplace_back(&int8x8x32_k8x8x8);
all_algos.emplace_back(&int8x8x32_mk4_4x4x16); all_algos.emplace_back(&int8x8x32_mk4_4x4x16);
......
...@@ -34,14 +34,12 @@ private: ...@@ -34,14 +34,12 @@ private:
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel
// 8x12x4 DotProduct // 8x12x4 DotProduct
class AlgoInt8x8x32GemvDotProd; // Aarch64 Int8x8x32 Gemv DotProduct
class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel
// 8x12x4 DotProduct // 8x12x4 DotProduct
#else #else
class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16 class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16
class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16 class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16
class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8 class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8
class AlgoInt8x8x32Gemv; // Aarch64 Int8x8x32 Gemv
#endif #endif
class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8 class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8
class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16 class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16
......
...@@ -72,7 +72,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16::get_kern( ...@@ -72,7 +72,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16::get_kern(
return exec_int_8x8x16; return exec_int_8x8x16;
} }
#if !__ARM_FEATURE_DOTPROD
/* ===================== Int8x8x32 Gemv algo ===================== */ /* ===================== Int8x8x32 Gemv algo ===================== */
namespace { namespace {
void int8x8x32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { void int8x8x32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
...@@ -102,7 +101,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Gemv::get_kern( ...@@ -102,7 +101,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Gemv::get_kern(
const KernSizeParam&) const { const KernSizeParam&) const {
return int8x8x32_gemv_kern; return int8x8x32_gemv_kern;
} }
#endif
/* ===================== F32 Gemv algo ===================== */ /* ===================== F32 Gemv algo ===================== */
namespace { namespace {
...@@ -112,7 +110,6 @@ void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { ...@@ -112,7 +110,6 @@ void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
const auto Aptr = kern_param.A<dt_float32>(), const auto Aptr = kern_param.A<dt_float32>(),
Bptr = kern_param.B<dt_float32>(); Bptr = kern_param.B<dt_float32>();
auto Cptr = kern_param.C<dt_float32>(); auto Cptr = kern_param.C<dt_float32>();
arm_common::sgemm_sgemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); arm_common::sgemm_sgemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
} }
} // anonymous namespace } // anonymous namespace
......
...@@ -27,11 +27,7 @@ public: ...@@ -27,11 +27,7 @@ public:
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
}; };
#if !__ARM_FEATURE_DOTPROD
class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase {
protected:
~AlgoInt8x8x32Gemv() = default;
public: public:
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; } const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; }
...@@ -43,7 +39,6 @@ public: ...@@ -43,7 +39,6 @@ public:
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
}; };
#endif
class MatrixMulImpl::AlgoF32Gemv : public AlgoBase { class MatrixMulImpl::AlgoF32Gemv : public AlgoBase {
protected: protected:
......
...@@ -9,8 +9,6 @@ ...@@ -9,8 +9,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#if !__ARM_FEATURE_DOTPROD
#include <cstddef> #include <cstddef>
#include "src/arm_common/matrix_mul/int8/gemv.h" #include "src/arm_common/matrix_mul/int8/gemv.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
...@@ -23,6 +21,8 @@ MIDOUT_DECL(megdnn_arm_common_int8_gemv) ...@@ -23,6 +21,8 @@ MIDOUT_DECL(megdnn_arm_common_int8_gemv)
using namespace megdnn; using namespace megdnn;
using namespace arm_common; using namespace arm_common;
#if !__ARM_FEATURE_DOTPROD
namespace { namespace {
void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
...@@ -95,8 +95,82 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, ...@@ -95,8 +95,82 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
C[m * Cstride] = acc0; C[m * Cstride] = acc0;
} }
} }
} // namespace
#endif
#if __ARM_FEATURE_DOTPROD
namespace {
void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride) {
megdnn_assert(N == 1 && Bstride == 1);
size_t m = 0;
for (; m + 2 <= M; m += 2) {
int32_t acc[4];
int32x4_t acc_neon = vdupq_n_s32(0);
size_t k = 0;
for (; k + 16 <= K; k += 16) {
int64x2_t a0 = vreinterpretq_s64_s8(vld1q_s8(A + m * Astride + k));
int64x2_t a1 =
vreinterpretq_s64_s8(vld1q_s8(A + (m + 1) * Astride + k));
//! the first 8 elements is m, the last 8 elements is m + 1
int8x16_t a2 = vreinterpretq_s8_s64(vzip1q_s64(a0, a1));
int8x16_t a3 = vreinterpretq_s8_s64(vzip2q_s64(a0, a1));
int64x2_t b0 = vreinterpretq_s64_s8(vld1q_s8(B + k));
int8x16_t b2 = vreinterpretq_s8_s64(vzip1q_s64(b0, b0));
int8x16_t b3 = vreinterpretq_s8_s64(vzip2q_s64(b0, b0));
acc_neon = vdotq_s32(acc_neon, a2, b2);
acc_neon = vdotq_s32(acc_neon, a3, b3);
}
vst1q_s32(acc, acc_neon);
for (; k + 8 <= K; k += 8) {
int8x8_t a0 = vld1_s8(A + m * Astride + k);
int8x8_t a1 = vld1_s8(A + (m + 1) * Astride + k);
int8x8_t b0 = vld1_s8(B + k);
uint32x2_t zero = vdup_n_s32(0);
acc[0] += vaddv_s32(vdot_s32(zero, a0, b0));
zero = vdup_n_s32(0);
acc[3] += vaddv_s32(vdot_s32(zero, a1, b0));
}
for (; k < K; ++k) {
acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k];
acc[3] += static_cast<int32_t>(A[(m + 1) * Astride + k]) * B[k];
}
C[m * Cstride] = acc[0] + acc[1];
C[(m + 1) * Cstride] = acc[2] + acc[3];
}
for (; m < M; ++m) {
int32_t acc[4];
int32x4_t acc_neon = vdupq_n_s32(0);
size_t k = 0;
for (; k + 16 <= K; k += 16) {
int8x16_t a0 = vld1q_s8(A + m * Astride + k);
int8x16_t b0 = vld1q_s8(B + k);
acc_neon = vdotq_s32(acc_neon, a0, b0);
}
vst1q_s32(acc, acc_neon);
for (; k + 8 <= K; k += 8) {
int8x8_t a0 = vld1_s8(A + m * Astride + k);
int8x8_t b0 = vld1_s8(B + k);
uint32x2_t zero = vdup_n_s32(0);
acc[0] += vaddv_s32(vdot_s32(zero, a0, b0));
}
for (; k < K; ++k) {
acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k];
}
C[m * Cstride] = acc[0] + acc[1] + acc[2] + acc[3];
}
}
} // namespace } // namespace
#endif
bool matmul::is_gemv_like_preferred_int8(bool transposeA, bool transposeB, bool matmul::is_gemv_like_preferred_int8(bool transposeA, bool transposeB,
size_t M, size_t N, size_t K, size_t M, size_t N, size_t K,
...@@ -124,6 +198,5 @@ void matmul::gemv_like_int8(const int8_t* __restrict A, ...@@ -124,6 +198,5 @@ void matmul::gemv_like_int8(const int8_t* __restrict A,
} MIDOUT_END(); } MIDOUT_END();
} }
#endif
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#if !__ARM_FEATURE_DOTPROD
namespace megdnn { namespace megdnn {
namespace arm_common { namespace arm_common {
namespace matmul { namespace matmul {
...@@ -28,6 +27,6 @@ void gemv_like_int8(const int8_t* __restrict A, const int8_t* __restrict B, ...@@ -28,6 +27,6 @@ void gemv_like_int8(const int8_t* __restrict A, const int8_t* __restrict B,
} // namespace matmul } // namespace matmul
} // namespace arm_common } // namespace arm_common
} // namespace megdnn } // namespace megdnn
#endif
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -27,13 +27,14 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { ...@@ -27,13 +27,14 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
AlgoF16Gemv f16gemv; AlgoF16Gemv f16gemv;
#endif #endif
AlgoInt8x8x32Gemv int8x8x32_gemv;
public: public:
AlgoPack() { AlgoPack() {
all_algos.emplace_back(&int8x8x16); all_algos.emplace_back(&int8x8x16);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
all_algos.emplace_back(&f16gemv); all_algos.emplace_back(&f16gemv);
#endif #endif
all_algos.emplace_back(&int8x8x32_gemv);
} }
SmallVector<AlgoBase*> all_algos; SmallVector<AlgoBase*> all_algos;
}; };
......
...@@ -25,9 +25,7 @@ public: ...@@ -25,9 +25,7 @@ public:
protected: protected:
static void* const sm_arm_common_algo_type; static void* const sm_arm_common_algo_type;
#if !__ARM_FEATURE_DOTPROD
class AlgoInt8x8x32Gemv; // Arm_common Int 8x8x32 Gemv class AlgoInt8x8x32Gemv; // Arm_common Int 8x8x32 Gemv
#endif
class AlgoF32Gemv; // Arm_common F32 Gemv class AlgoF32Gemv; // Arm_common F32 Gemv
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class AlgoF16Gemv; class AlgoF16Gemv;
......
...@@ -388,6 +388,19 @@ __ai int64x2_t vmovl_high_s32(int32x4_t __p0) { ...@@ -388,6 +388,19 @@ __ai int64x2_t vmovl_high_s32(int32x4_t __p0) {
__ai uint64x2_t vmovl_high_u32(uint32x4_t __p0) { __ai uint64x2_t vmovl_high_u32(uint32x4_t __p0) {
return vmovl_u32(vget_high_u32(__p0)); return vmovl_u32(vget_high_u32(__p0));
} }
__ai int64x2_t vzip1q_s64(int64x2_t& a, int64x2_t& b) {
return vcombine_s64(vget_low_s64(a), vget_low_s64(b));
}
__ai int64x2_t vzip2q_s64(int64x2_t& a, int64x2_t& b) {
return vcombine_s64(vget_high_s64(a), vget_high_s64(b));
}
__ai int32_t vaddv_s32(int32x2_t a) {
return vget_lane_s32(a, 0) + vget_lane_s32(a, 1);
}
#endif // MEGDNN_ARMV7 #endif // MEGDNN_ARMV7
//! pack vmovl_low_xx() on armv7 and armv8 //! pack vmovl_low_xx() on armv7 and armv8
......
...@@ -134,11 +134,6 @@ public: ...@@ -134,11 +134,6 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };
#if !__ARM_FEATURE_DOTPROD
class MatrixMulImpl::AlgoInt8x8x32Gemv final
: public arm_common::MatrixMulImpl::AlgoInt8x8x32Gemv {};
#endif
class MatrixMulImpl::AlgoQuint8K4x8x8 final : public AlgoBase { class MatrixMulImpl::AlgoQuint8K4x8x8 final : public AlgoBase {
public: public:
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
......
...@@ -35,9 +35,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { ...@@ -35,9 +35,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoInt8x8x32MK4_4x2x16 int8x8x32_mk4_4x2x16; AlgoInt8x8x32MK4_4x2x16 int8x8x32_mk4_4x2x16;
AlgoInt8x8x32K4x2x16 int8x8x32_k4x2x16; AlgoInt8x8x32K4x2x16 int8x8x32_k4x2x16;
AlgoInt8x8x32K4x8x8 int8x8x32_k4x8x8; AlgoInt8x8x32K4x8x8 int8x8x32_k4x8x8;
#if !__ARM_FEATURE_DOTPROD
AlgoInt8x8x32Gemv int8x8x32_gemv;
#endif
AlgoQuint8K4x8x8 quint8_k4x8x8; AlgoQuint8K4x8x8 quint8_k4x8x8;
AlgoInt8x8x16K4x2x16 int8x8x16_k4x2x16; AlgoInt8x8x16K4x2x16 int8x8x16_k4x2x16;
AlgoInt8x8x16K4x8x8 int8x8x16_k4x8x8; AlgoInt8x8x16K4x8x8 int8x8x16_k4x8x8;
...@@ -60,9 +57,6 @@ public: ...@@ -60,9 +57,6 @@ public:
all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod); all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod);
all_algos.emplace_back(&int8_k6x8x4); all_algos.emplace_back(&int8_k6x8x4);
all_algos.emplace_back(&quint8_k4x8x4); all_algos.emplace_back(&quint8_k4x8x4);
#endif
#if !__ARM_FEATURE_DOTPROD
all_algos.emplace_back(&int8x8x32_gemv);
#endif #endif
all_algos.emplace_back(&int8x8x32_mk4_4x2x16); all_algos.emplace_back(&int8x8x32_mk4_4x2x16);
all_algos.emplace_back(&int8x8x32_k4x2x16); all_algos.emplace_back(&int8x8x32_k4x2x16);
......
...@@ -27,9 +27,6 @@ private: ...@@ -27,9 +27,6 @@ private:
class AlgoInt8x8x32K4x8x8; // Armv7 Int8x8x32 Kernel 4x8x8 class AlgoInt8x8x32K4x8x8; // Armv7 Int8x8x32 Kernel 4x8x8
class AlgoInt8x8x32K4x2x16; // Armv7 Int8x8x32 Kernel 4x2x16 class AlgoInt8x8x32K4x2x16; // Armv7 Int8x8x32 Kernel 4x2x16
class AlgoInt8x8x32MK4_4x2x16; // Armv7 Int8x8x32 Kernel MK4 4x2x16 class AlgoInt8x8x32MK4_4x2x16; // Armv7 Int8x8x32 Kernel MK4 4x2x16
#if !__ARM_FEATURE_DOTPROD
class AlgoInt8x8x32Gemv; // Armv7 Int8x8x32 Gemv
#endif
class AlgoQuint8K4x8x8; // Armv7 Quint8 Kernel 4x8x8 class AlgoQuint8K4x8x8; // Armv7 Quint8 Kernel 4x8x8
class AlgoInt8x8x16K4x2x16; // Armv7 Int8x8x16 Kernel 4x2x16 class AlgoInt8x8x16K4x2x16; // Armv7 Int8x8x16 Kernel 4x2x16
class AlgoInt8x8x16K4x8x8; // Armv7 Int8x8x16 Kernel 4x8x8 class AlgoInt8x8x16K4x8x8; // Armv7 Int8x8x16 Kernel 4x8x8
......
...@@ -133,6 +133,36 @@ TEST_F(ARM_COMMON, MATRIX_MUL_FP16_TEST) { ...@@ -133,6 +133,36 @@ TEST_F(ARM_COMMON, MATRIX_MUL_FP16_TEST) {
} }
#endif #endif
TEST_F(ARM_COMMON, QINT8x8x32_GEMV) {
Checker<MatrixMul> checker(handle());
using Param = MatrixMul::Param;
checker.set_before_exec_callback(
AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEMV"));
std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127);
checker.set_rng(0, rng.get()).set_rng(1, rng.get());
auto run = [&](size_t M, size_t K, size_t N) {
Param param;
param.transposeA = false;
param.transposeB = false;
TensorShape A, B;
A = TensorShape{M, K};
B = TensorShape{K, N};
checker.set_param(param)
.set_dtype(0, dtype::QuantizedS8(2.5f))
.set_dtype(1, dtype::QuantizedS8(2.5f))
.set_dtype(2, dtype::QuantizedS32(6.25f))
.execs({A, B, {}});
};
// N = 1
for (size_t M : {1, 10, 16, 33, 64})
for (size_t K : {7, 512, 1024})
for (size_t N : {1})
run(M, K, N);
}
#if MEGDNN_WITH_BENCHMARK #if MEGDNN_WITH_BENCHMARK
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册