Commit 92b12685 authored by Megvii Engine Team

feat(dnn/aarch64): add aarch64 int8x8x16_mk4_k8x8x8 matmul with better performance

GitOrigin-RevId: b6af21e8e314b4edd62f0fddcf8578d2eaa0fc2a
Parent 5ee1a1c4
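The MK4 format referenced above stores tensors in 4-channel blocks: A as (M/4, K/4, 4, 4), B as (K/4, N, 4) and C as (M/4, N, 4), matching the tensor shapes used in the tests below; the k8x8x8 suffix is the register blocking (8 rows of M, 8 columns of N, 8 steps of K per inner iteration). A minimal sketch of the assumed index mapping (illustrative only, not part of this patch; the inner 4x4 block of A is assumed to be ordered (m%4, k%4)):

#include <cstddef>

// Element (m, k) of A stored as (M/4, K/4, 4, 4).
inline size_t mk4_a_index(size_t m, size_t k, size_t K) {
    return (((m / 4) * (K / 4) + k / 4) * 4 + m % 4) * 4 + k % 4;
}

// Element (k, n) of B stored as (K/4, N, 4): the inner 4 is the k remainder.
inline size_t mk4_b_index(size_t k, size_t n, size_t N) {
    return ((k / 4) * N + n) * 4 + k % 4;
}

// Element (m, n) of C stored as (M/4, N, 4): the inner 4 is the m remainder.
inline size_t mk4_c_index(size_t m, size_t n, size_t N) {
    return ((m / 4) * N + n) * 4 + m % 4;
}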
......@@ -1310,4 +1310,74 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x8,
int32_t);
#endif
/* ===================== Int8x8x16 MK4 K8x8x8 algo ===================== */
namespace {
void int8x8x16_mk4_8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
midout_iv("int8x8x16_mk4_8x8x8_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto trA = kern_param.trA, trB = kern_param.trB;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
auto A_type = kern_param.A_type, B_type = kern_param.B_type,
C_type = kern_param.C_type;
const auto Aptr = kern_param.A<dt_int8>(),
Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int16>();
aarch64::matmul::gemm_s8x8x16_mk4_8x8x8 strategy(M, N, K, A_type,
B_type, C_type);
megdnn::matmul::GemmInterleaved<
aarch64::matmul::gemm_s8x8x16_mk4_8x8x8>(M, N, K, trA, trB,
strategy)
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
kern_param.workspace_ptr);
}
MIDOUT_END();
}
} // anonymous namespace
bool MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::usable(
const KernSizeParam& kern_size_param) const {
return can_be_treated_as_int8x8x16(kern_size_param) &&
kern_size_param.format == param::MatrixMul::Format::MK4 &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
!kern_size_param.trA && !kern_size_param.trB &&
kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0;
}
bool MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::preferred(
const KernSizeParam&) const {
return true;
}
size_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_workspace(
const KernSizeParam& kern_size_param) const {
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
midout_iv("AlgoInt8x8x16_MK4_8x8x8::get_workspace"_hash)) {
auto M = kern_size_param.M, N = kern_size_param.N,
K = kern_size_param.K;
auto trA = kern_size_param.trA, trB = kern_size_param.trB;
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
C_type = kern_size_param.C_type;
aarch64::matmul::gemm_s8x8x16_mk4_8x8x8 strategy(M, N, K, A_type,
B_type, C_type);
return megdnn::matmul::GemmInterleaved<
matmul::gemm_s8x8x16_mk4_8x8x8>(M, N, K, trA, trB,
strategy)
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_kern(
const KernSizeParam&) const {
return int8x8x16_mk4_8x8x8_kern;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_K8x8x8,
megdnn_aarch64_matmul_kern,
"AlgoInt8x8x16MK4_K8x8x8Impl"_hash,
aarch64::matmul::gemm_s8x8x16_mk4_8x8x8, int8_t,
int16_t);
// vim: syntax=cpp.doxygen
......@@ -202,6 +202,22 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override {
return "AARCH64_INT8X8X16_MK4_K8X8X8";
}
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::DEFAULT; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
......
......@@ -2101,6 +2101,62 @@ static inline void transpos_12x4_s8(const int8_t* inptr0, int8_t* outptr) {
vreinterpretq_s32_s8(input2), 3);
}
template <typename T>
static inline void interleave_8x8_mk4_b(const T*& inptr0, const T*& inptr1,
T*& outptr) {
static_assert(
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
"transpose_8x4_1_b only support uint8_t and int8_t");
asm volatile(
"ld1 {v0.4s}, [%[inptr0]], #16\n"
"ld1 {v1.4s}, [%[inptr1]], #16\n"
"ld1 {v2.4s}, [%[inptr0]], #16\n"
"ld1 {v3.4s}, [%[inptr1]], #16\n"
"zip1 v4.4s, v0.4s, v1.4s \n"
"zip2 v5.4s, v0.4s, v1.4s \n"
"zip1 v6.4s, v2.4s, v3.4s\n"
"zip2 v7.4s, v2.4s, v3.4s\n"
"st1 {v4.4s},[%[outptr]],#16\n"
"st1 {v5.4s},[%[outptr]],#16\n"
"st1 {v6.4s},[%[outptr]],#16\n"
"st1 {v7.4s},[%[outptr]],#16\n"
: [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
[outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5","v6","v7","memory");
}
template <typename T>
static inline void transpose_8x8_mk4_b(const T*& inptr0, const T*& inptr1,
T* outptr) {
static_assert(
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
"transpose_8x4_1_b only support uint8_t and int8_t");
asm volatile(
"ld4 {v0.8b-v3.8b}, [%[inptr0]], #32\n"
"ld4 {v4.8b-v7.8b}, [%[inptr1]], #32\n"
"st1 {v0.2s},[%[outptr]],#8\n"
"st1 {v1.2s},[%[outptr]],#8\n"
"st1 {v2.2s},[%[outptr]],#8\n"
"st1 {v3.2s},[%[outptr]],#8\n"
"st1 {v4.2s},[%[outptr]],#8\n"
"st1 {v5.2s},[%[outptr]],#8\n"
"st1 {v6.2s},[%[outptr]],#8\n"
"st1 {v7.2s},[%[outptr]],#8\n"
: [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
[outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5","v6","v7","memory");
}
} // namespace aarch64
} // namespace megdnn
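A rough scalar model of the two helpers above may help decode the NEON (the _ref names are hypothetical, illustration only). interleave_8x8_mk4_b zips eight 4-byte MK4 cells from two rows into a0 b0 a1 b1 ... a7 b7; transpose_8x8_mk4_b uses ld4 to de-interleave byte lanes, emitting lane 0 of all eight cells, then lanes 1, 2 and 3, per row:

#include <cstdint>

// Scalar equivalent of interleave_8x8_mk4_b: zip1/zip2 on 32-bit lanes.
template <typename T>
void interleave_8x8_mk4_b_ref(const T*& inptr0, const T*& inptr1, T*& outptr) {
    for (int cell = 0; cell < 8; ++cell) {
        for (int i = 0; i < 4; ++i) *outptr++ = inptr0[4 * cell + i];
        for (int i = 0; i < 4; ++i) *outptr++ = inptr1[4 * cell + i];
    }
    inptr0 += 32;  // each input advances by two 16-byte ld1 loads
    inptr1 += 32;
}

// Scalar equivalent of transpose_8x8_mk4_b: ld4 splits off every 4th byte.
template <typename T>
void transpose_8x8_mk4_b_ref(const T*& inptr0, const T*& inptr1, T* outptr) {
    for (int lane = 0; lane < 4; ++lane)
        for (int cell = 0; cell < 8; ++cell)
            *outptr++ = inptr0[4 * cell + lane];
    for (int lane = 0; lane < 4; ++lane)
        for (int cell = 0; cell < 8; ++cell)
            *outptr++ = inptr1[4 * cell + lane];
    inptr0 += 32;
    inptr1 += 32;
}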
......
This diff is collapsed.
......@@ -13,6 +13,7 @@
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h"
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h"
......@@ -357,4 +358,81 @@ void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB,
}
}
// ===========================gemm_s8x8x16_mk4_8x8x8==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8x8);
void gemm_s8x8x16_mk4_8x8x8::pack_A(dt_int8* out, const dt_int8* in,
int ldin, int y0, int ymax, int k0,
int kmax, bool) const {
matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_A(out, in, ldin, y0,
ymax, k0, kmax);
}
void gemm_s8x8x16_mk4_8x8x8::pack_B(dt_int8* out, const dt_int8* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool) const {
matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_B(out, in, ldin, x0,
xmax, k0, kmax);
}
void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int16* C,
size_t LDC, bool is_first_k, const dt_int16*,
dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
C_dtype.enumv() == DTypeEnum::Int16 &&
A_dtype.enumv() == DTypeEnum::Int8);
megdnn_assert(is_first_k == true, "only is_first_k is implemented");
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be multiples of 4");
constexpr size_t pack_size = 4;
constexpr size_t pack_m = 8;
constexpr size_t pack_n = 8;
const size_t remain_n = N % pack_n;
size_t remain_m = M % pack_m;
K = round_up<size_t>(K, 8);
size_t KSIZE8 = K * pack_n;
size_t m_idx = 0;
for (; m_idx + pack_m <= M; m_idx += pack_m) {
int16_t* output = C + (m_idx / pack_size * LDC);
size_t n_idx = 0;
const int8_t* cur_packB = packB;
for (; n_idx + pack_n <= N; n_idx += pack_n) {
matmul_mk4_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC,
is_first_k, pack_m, pack_n);
output += pack_n * pack_size;
cur_packB += KSIZE8;
}
if (remain_n > 0) {
matmul_mk4_8x8x8::kern_8x8_remain(packA, cur_packB, K, output, LDC,
is_first_k, pack_m, remain_n);
output += remain_n * pack_size;
cur_packB += KSIZE8;
}
packA += KSIZE8;
}
if (remain_m == 4) {
int16_t* output = C + (m_idx / pack_size * LDC);
size_t n_idx = 0;
const int8_t* cur_packB = packB;
for (; n_idx + pack_n <= N; n_idx += pack_n) {
matmul_mk4_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC,
is_first_k, 4, pack_n);
output += pack_n * pack_size;
cur_packB += pack_n * K;
}
if (remain_n > 0) {
matmul_mk4_8x8x8::kern_4x8_remain(packA, cur_packB, K, output, LDC,
is_first_k, 4, remain_n);
output += remain_n * pack_size;
cur_packB += pack_n * K;
}
}
}
// vim: syntax=cpp.doxygen
......@@ -26,6 +26,8 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 8, false, false,
MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int16, dt_int16,
16, 12, 4, false, false,
gemm_s8x8x16_mk4_16x12_a53);
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, false,
gemm_s8x8x16_mk4_8x8x8);
} // namespace matmul
} // namespace aarch64
......
......@@ -39,6 +39,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16;
AlgoInt8x8x16MK4_16x12x4 int8x8x16_mk4_16x12x4;
AlgoInt8x8x16MK4_4x4x8 int8x8x16_mk4_4x4x8;
AlgoInt8x8x16MK4_K8x8x8 int8x8x16_mk4_k8x8x8;
AlgoInt16x16x32K12x8x1 int16x16x32_k12x8x1;
AlgoInt16x16x32MK8_8x8 int16x16x32_mk8_8x8;
......@@ -73,6 +74,7 @@ public:
#endif
all_algos.emplace_back(&int8x8x16_k4x4x16);
all_algos.emplace_back(&int8x8x16_k8x8x8);
all_algos.emplace_back(&int8x8x16_mk4_k8x8x8);
all_algos.emplace_back(&int8x8x16_mk4_4x4x8);
all_algos.emplace_back(&int8x8x16_mk4_16x12x4);
......
......@@ -57,6 +57,7 @@ private:
#else
class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8
#endif
class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int8x8x16 Kernel MK4 8x8x8
class AlgoPack;
};
......
......@@ -122,6 +122,20 @@ TEST_F(AARCH64, MATRIX_MUL_INT8_MK4) {
std::move(args));
}
TEST_F(AARCH64, MATRIX_MUL_INT8x8x16_MK4) {
std::vector<matrix_mul::TestArg> args;
for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17})
for (size_t n :
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 24})
for (size_t k :
{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29})
args.emplace_back(m, n, k, 0);
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
handle(), "AARCH64_INT8X8X16_MK4_K8X8X8",
param::MatrixMul::Format::MK4, 1, 1e-3,
std::move(args));
}
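The semantics this test exercises can be modeled by a scalar loop (a sketch, not in the tree; same layout assumption as above, with the inner 4x4 block of A ordered (m%4, k%4), and int16 accumulation as the int8x8x16 kernels use):

#include <cstddef>
#include <cstdint>

// Scalar reference for the MK4 int8x8x16 matmul (illustrative only).
void mk4_s8s8s16_ref(const int8_t* A, const int8_t* B, int16_t* C,
                     size_t M, size_t N, size_t K) {
    for (size_t m = 0; m < M; ++m)
        for (size_t n = 0; n < N; ++n) {
            int16_t acc = 0;
            for (size_t k = 0; k < K; ++k) {
                int8_t a = A[(((m / 4) * (K / 4) + k / 4) * 4 + m % 4) * 4 +
                             k % 4];
                int8_t b = B[((k / 4) * N + n) * 4 + k % 4];
                acc = int16_t(acc + int16_t(a) * int16_t(b));
            }
            C[((m / 4) * N + n) * 4 + m % 4] = acc;
        }
}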
TEST_F(AARCH64, MATRIX_MUL_MK4_8x8x16_4x4) {
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
handle(), "AARCH64_INT8X8X16_MK4_4X4X8",
......@@ -396,6 +410,71 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x16) {
run(384, 384, 384);
}
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x8_8x8x16_vs_4x4x16_8x8x16) {
constexpr size_t RUNS = 50;
param::MatrixMul param;
param.transposeA = false;
param.transposeB = false;
Benchmarker<MatrixMul> benchmarker(handle());
Benchmarker<MatrixMul> benchmarker_mk4(handle());
Benchmarker<MatrixMul> benchmarker_mk4_4x4x8(handle());
benchmarker.set_times(RUNS)
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int16{})
.set_param(param)
.set_display(false);
benchmarker.set_before_exec_callback(
AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_K4X4X16"));
param.format = MatrixMul::Param::Format::MK4;
benchmarker_mk4.set_before_exec_callback(
        AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_MK4_K8X8X8"));
benchmarker_mk4.set_times(RUNS)
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int16{})
.set_param(param)
.set_display(false);
benchmarker_mk4_4x4x8.set_before_exec_callback(
AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_MK4_4X4X8"));
benchmarker_mk4_4x4x8.set_times(RUNS)
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int16{})
.set_param(param)
.set_display(false);
auto run = [&](size_t M, size_t N, size_t K) {
auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS;
auto mk_used = benchmarker_mk4.exec(
{{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
RUNS;
auto mk4_4x4x8_used =
benchmarker_mk4_4x4x8.exec(
{{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
RUNS;
float computations = 2.f * M * K * N * 1e-6;
printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops mk4: %f ms "
"%f Gflops speedup: %f, mk4_4x4x8 %f Gflops %f ms speedup: %f\n",
M, K, N, default_used, computations / default_used, mk_used,
computations / mk_used, default_used / mk_used,
computations / mk4_4x4x8_used, mk4_4x4x8_used , mk4_4x4x8_used/mk_used);
};
run(384, 384, 384);
run(512, 512, 512);
run(1024, 1024, 384);
run(256, 256, 384);
for (int m = 32; m <= 512; m *= 2)
for (int n = 32; n <= 512; n *= 2)
for (int k = 32; k < 512; k *= 2) {
run(m, n, k);
}
}
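A quick sanity check of the Gflops arithmetic in the printout above (standalone example; the timing value is made up): 2*M*K*N ops scaled by 1e-6, divided by milliseconds, gives Gop/s.

#include <cstdio>

int main() {
    float M = 384, N = 384, K = 384, used_ms = 2.0f;     // hypothetical timing
    float computations = 2.f * M * K * N * 1e-6f;        // ops in mega-op units
    std::printf("%f Gflops\n", computations / used_ms);  // ~56.62 Gop/s
    return 0;
}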
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_4X4X16) {
constexpr size_t RUNS = 50;
param::MatrixMul param;
......