Commit 36e3bb6e authored by Megvii Engine Team

feat(mgb/dnn): add armv7 mk4_dot matmul

GitOrigin-RevId: d4206f8e21d1f58a7e07e1345d2738dc76e7bfbd
Parent 580a2753
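For context: the new algorithm targets MegDNN's MK4_DOT blocked layout, in which A is stored as (M/4, K/4, 4, 4), B as (K/4, N, 4), and C as (M/4, N, 4) — the shapes used by the benchmark at the end of this diff. Below is a minimal scalar reference of such a matmul, assuming the inner 4x4 block of A is ordered (m, k) so that four consecutive k values feed one dot-product instruction; the exact inner ordering is a packing detail of the kernels, so treat this as an illustrative sketch rather than MegEngine code.

#include <cstddef>
#include <cstdint>

// Scalar reference for an MK4_DOT-style matmul (illustrative sketch only).
// Assumed layouts: A (M/4, K/4, 4m, 4k) int8; B (K/4, N, 4k) int8;
// C (M/4, N, 4m) int32; M and K are multiples of 4.
void mk4_dot_matmul_ref(const int8_t* A, const int8_t* B, int32_t* C,
                        size_t M, size_t N, size_t K) {
    for (size_t mo = 0; mo < M / 4; ++mo)
        for (size_t n = 0; n < N; ++n)
            for (size_t mi = 0; mi < 4; ++mi) {
                int32_t acc = 0;
                for (size_t ko = 0; ko < K / 4; ++ko)
                    for (size_t ki = 0; ki < 4; ++ki)
                        acc += static_cast<int32_t>(
                                       A[((mo * (K / 4) + ko) * 4 + mi) * 4 + ki]) *
                               static_cast<int32_t>(B[(ko * N + n) * 4 + ki]);
                C[(mo * N + n) * 4 + mi] = acc;
            }
}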
@@ -706,6 +706,73 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8DotK4x8x4,
"AlgoQuint8DotK4x8x4"_hash,
armv7::matmul::gemm_dot_quint8_4x8,
uint8_t, int32_t);
/* ======================== Int8 MK4 8x6x4 dot algo ======================== */
namespace {
void int8_mk4_8x6x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_armv7_matmul_kern,
midout_iv("int8_mk4_8x6x4_dotprod_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto trA = kern_param.trA, trB = kern_param.trB;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
auto A_type = kern_param.A_type, B_type = kern_param.B_type,
C_type = kern_param.C_type;
const auto Aptr = kern_param.A<dt_int8>(),
Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
armv7::matmul::gemm_mk4_dots8_8x6 strategy(M, N, K, A_type, B_type,
C_type);
megdnn::matmul::GemmInterleaved<armv7::matmul::gemm_mk4_dots8_8x6>(
M, N, K, trA, trB, strategy)
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
kern_param.workspace_ptr);
}
MIDOUT_END();
}
} // namespace
bool MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::usable(
const KernSizeParam& kern_size_param) const {
return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
(kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
(kern_size_param.C_type.enumv() == DTypeEnum::Int32 ||
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32) &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::MK4_DOT &&
!kern_size_param.trA && !kern_size_param.trB;
}
size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::get_workspace(
const KernSizeParam& kern_size_param) const {
MIDOUT_BEGIN(
megdnn_armv7_matmul_kern,
midout_iv("AlgoInt8x8x32MK4_8x6x4DotProd::get_workspace"_hash)) {
auto M = kern_size_param.M, N = kern_size_param.N,
K = kern_size_param.K;
auto trA = kern_size_param.trA, trB = kern_size_param.trB;
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
C_type = kern_size_param.C_type;
armv7::matmul::gemm_mk4_dots8_8x6 strategy(M, N, K, A_type, B_type,
C_type);
return megdnn::matmul::GemmInterleaved<
armv7::matmul::gemm_mk4_dots8_8x6>(M, N, K, trA, trB,
strategy)
.get_workspace_size();
}
MIDOUT_END();
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::get_kern(
const KernSizeParam&) const {
return int8_mk4_8x6x4_dotprod_kern;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x6x4DotProd,
megdnn_armv7_matmul_kern,
"AlgoInt8x8x32MK4_8x6x4DotProd"_hash,
armv7::matmul::gemm_mk4_dots8_8x6, int8_t,
int32_t);
#endif
/* ===================== F32 algo K4x8 ===================== */
......
@@ -93,6 +93,18 @@ public:
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override {
return "AARCH32_INT8_MK4_8X6X4_DOTPROD";
}
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
#endif
class MatrixMulImpl::AlgoF32Gemv final
......
@@ -125,6 +125,20 @@ static inline void interleave_4x1_2_d(const int64_t*& inptr0,
: "q0", "q1", "q2", "q3", "cc", "memory");
}
static inline void interleave_2x1_4_s(const int32_t*& inptr0,
const int32_t*& inptr1,
int32_t*& outptr) {
asm volatile(
"vld1.32 {d0, d1}, [%[inptr0]]!\n" // A0A1A2A3
"vld1.32 {d2, d3}, [%[inptr1]]!\n" // A0A1A2A3
"vst1.32 {d0, d1}, [%[outptr]]!\n"
"vst1.32 {d2, d3}, [%[outptr]]!\n"
:
[inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr)
:
: "d0", "d1", "d2", "d3", "cc", "memory");
}
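// Scalar equivalent of the NEON sequence above (illustrative):
//   for (int i = 0; i < 4; ++i) outptr[i] = inptr0[i];
//   for (int i = 0; i < 4; ++i) outptr[4 + i] = inptr1[i];
//   inptr0 += 4; inptr1 += 4; outptr += 8;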
template <typename T>
static inline void interleave_8x8_1_b(const T*& inptr0, const T*& inptr1,
const T*& inptr2, const T*& inptr3,
@@ -188,6 +202,17 @@ static inline void interleave_4x4_4_b(const T*& inptr0, const T*& inptr1,
: "q0", "q1", "q2", "q3", "memory");
}
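//! interleave_2x4_4_b below reinterprets every 4 consecutive int8 of a row as
//! one int32 lane, so interleaving two rows of 4x4 bytes reduces to the 32-bit
//! copy done by interleave_2x1_4_s.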
template <typename T>
static inline void interleave_2x4_4_b(const T*& inptr0, const T*& inptr1,
T*& outptr) {
static_assert(
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
"interleave_2x4_4_b only support uint8_t and int8_t");
interleave_2x1_4_s(reinterpret_cast<const int32_t*&>(inptr0),
reinterpret_cast<const int32_t*&>(inptr1),
reinterpret_cast<int32_t*&>(outptr));
}
template <typename T>
static inline void interleave_6x4_4_b(const T*& inptr0, const T*& inptr1,
const T*& inptr2, const T*& inptr3,
......
This diff is collapsed.
@@ -16,6 +16,7 @@
#include "src/armv7/matrix_mul/int8/kernel_4x8x8.h"
#include "src/armv7/matrix_mul/int8/kernel_6x8x4.h"
#include "src/armv7/matrix_mul/int8/kernel_mk4_4x2x16.h"
#include "src/armv7/matrix_mul/int8/kernel_mk4_dot_8x6x4.h"
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_common.h"
@@ -252,6 +253,89 @@ void gemm_dots8_6x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
}
}
}
// ===========================gemm_mk4_dots8_8x6======================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_dots8_8x6);
void gemm_mk4_dots8_8x6::pack_A(dt_int8* out, const dt_int8* in, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
megdnn_assert(!transpose,
"matrix mul mk4 with transposed matrix A is not supported.");
megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0,
              "mk4 format matmul requires m to be a multiple of 4.");
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0,
              "mk4 format matmul requires k to be a multiple of 4.");
matmul_mk4_dot_8x6x4::gemm_dots8_8x6_pack_A(out, in, ldin, y0, ymax, k0,
kmax);
}
void gemm_mk4_dots8_8x6::pack_B(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax,
bool transpose) const {
megdnn_assert(!transpose,
              "matrix mul mk4 with transposed matrix B is not supported.");
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0,
              "mk4 format matmul requires k to be a multiple of 4.");
matmul_mk4_dot_8x6x4::gemm_dots8_8x6_pack_B(out, in, ldin, x0, xmax, k0,
kmax);
}
void gemm_mk4_dots8_8x6::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int32* C,
size_t LDC, bool is_first_k, const dt_int32* bias,
dt_int32* workspace) const {
MEGDNN_MARK_USED_VAR(bias);
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t B_INTERLEAVE = 6;
//! K is padded up to a multiple of 4
K = round_up<size_t>(K, 4);
const int K4 = K * 4;
const int K6 = K * 6;
const int K8 = K * 8;
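//! C is also stored in MK4 layout: each group of 4 output rows is packed
//! contiguously, so row-block m starts at C + (m / 4) * LDC. An 8x6 tile
//! covers two such blocks and writes 6 * 4 = 24 int32 per block, hence the
//! output += 24 (and += 16 for 4-wide tails) steps in the n-loops below.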
size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
int32_t* output = C + ((m >> 2) * LDC);
const dt_int8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_mk4_dot_8x6x4::kern_8x6(packA, cur_packB, K, output, LDC,
is_first_k);
output += 24;
cur_packB += K6;
}
for (; n < N; n += 4) {
size_t n_remain = std::min<size_t>(N - n, 4);
matmul_mk4_dot_8x6x4::kern_8x4(packA, cur_packB, K, output, LDC,
is_first_k, n_remain);
output += 16;
cur_packB += K4;
}
packA += K8;
}
for (; m < M; m += 4) {
int32_t* output = C + ((m >> 2) * LDC);
const dt_int8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_mk4_dot_8x6x4::kern_4x6(packA, cur_packB, K, output, LDC,
is_first_k);
output += 24;
cur_packB += K6;
}
for (; n < N; n += 4) {
size_t n_remain = std::min<size_t>(N - n, 4);
matmul_mk4_dot_8x6x4::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, n_remain);
output += 16;
cur_packB += K4;
}
packA += K4;
}
}
#endif
// ===========================gemm_mk4_s8_4x2======================================
......
@@ -26,6 +26,9 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 2, 16, false, false,
#if __ARM_FEATURE_DOTPROD
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 6, 8, 4, false, false,
gemm_dots8_6x8);
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 6, 4, false, false,
gemm_mk4_dots8_8x6);
#endif
} // namespace matmul
} // namespace armv7
......
@@ -29,6 +29,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#if __ARM_FEATURE_DOTPROD
AlgoInt8x8x32K6x8x4 int8_k6x8x4;
AlgoQuint8DotK4x8x4 quint8_k4x8x4;
AlgoInt8x8x32MK4_8x6x4DotProd int8x8x32_mk4_8x6x4_dotprod;
#endif
AlgoF32Gemv f32_gemv;
AlgoInt8x8x32MK4_4x2x16 int8x8x32_mk4_4x2x16;
@@ -56,6 +57,7 @@ public:
all_algos.emplace_back(&f16_mk8_4x8);
#endif
#if __ARM_FEATURE_DOTPROD
all_algos.emplace_back(&int8x8x32_mk4_8x6x4_dotprod);
all_algos.emplace_back(&int8_k6x8x4);
all_algos.emplace_back(&quint8_k4x8x4);
#endif
......
@@ -42,6 +42,8 @@ private:
#if __ARM_FEATURE_DOTPROD
class AlgoInt8x8x32K6x8x4; // Armv7 Int8 Kernel 6x8x4
class AlgoQuint8DotK4x8x4; // Armv7 Quint8 Kernel 4x8x4
class AlgoInt8x8x32MK4_8x6x4DotProd; // Armv7 nchw44 Int8x8x32 Kernel 8x6x4
// DotProduct
#endif
class AlgoPack;
};
......
@@ -86,6 +86,18 @@ TEST_F(ARMV7, MATRIX_MUL_UDOT) {
dtype::Quantized8Asymm(4.0f, static_cast<uint8_t>(10)), dtype::Quantized8Asymm(3.0f, static_cast<uint8_t>(54)),
dtype::QuantizedS32(12.0f), handle(), "AARCH32_QUINT8_K4X8X4");
}
TEST_F(ARMV7, MATRIX_MUL_MK4_DOT_INT8) {
std::vector<matrix_mul::TestArg> args;
for (size_t m : {1, 2, 3, 4, 5, 7, 10, 11})
for (size_t n : {1, 2, 3, 4, 5, 8, 16, 24, 25, 32})
for (size_t k : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 33, 34})
args.emplace_back(m, n, k, 0);
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
handle(), "AARCH32_INT8_MK4_8X6X4_DOTPROD",
param::MatrixMul::Format::MK4_DOT, 1, 1e-3,
std::move(args));
}
#endif
#if MEGDNN_WITH_BENCHMARK
@@ -286,6 +298,53 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x32_K6x8x4) {
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_QUINT8x8x32_K4x8x4) {
run_8x8x32_quint_benchmark(handle());
}
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x32_MK4_DOT) {
constexpr size_t RUNS = 50;
param::MatrixMul param;
Benchmarker<MatrixMul> benchmarker_default(handle());
benchmarker_default.set_times(RUNS)
.set_dtype(0, dtype::Int8())
.set_dtype(1, dtype::Int8())
.set_dtype(2, dtype::Int32())
.set_param(param)
.set_display(false);
benchmarker_default.set_before_exec_callback(
AlgoChecker<MatrixMul>("AARCH32_INT8_K6X8X4"));
param.format = MatrixMul::Param::Format::MK4_DOT;
Benchmarker<MatrixMul> benchmarker_mk4_dot(handle());
benchmarker_mk4_dot.set_before_exec_callback(
AlgoChecker<MatrixMul>("AARCH32_INT8_MK4_8X6X4_DOTPROD"));
benchmarker_mk4_dot.set_param(param)
.set_dtype(0, dtype::Int8())
.set_dtype(1, dtype::Int8())
.set_dtype(2, dtype::Int32())
.set_display(false)
.set_times(RUNS);
auto run = [&](size_t M, size_t N, size_t K) {
auto default_used =
benchmarker_default.exec({{M, K}, {K, N}, {}}) / RUNS;
auto mk4_dot_used = benchmarker_mk4_dot.exec(
{{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
RUNS;
float computations = 2.f * M * K * N * 1e-6;
printf("run: {%zu{M} %zu{K} %zu{N}} default: %f ms %f Gflops mk4_dot: "
"%f ms "
"%f Gflops speedup: %f\n",
M, K, N, default_used, computations / default_used, mk4_dot_used,
computations / mk4_dot_used, default_used / mk4_dot_used);
};
for (size_t M = 4; M < 512; M *= 2) {
for (size_t K = 4; K < 512; K *= 2) {
for (size_t N : {4, 8, 33, 113, 128}) {
run(M, N, K);
}
}
}
}
#endif
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x2x16) {
......