Commit 73d84162 authored by Megvii Engine Team, committed by Xu Xinran

feat(dnn/aarch64): add matmul with dotprod for mk4

GitOrigin-RevId: feb391d635f5d19ff745e6426d385ebeedb1f0ef
Parent c1397792
......@@ -474,6 +474,76 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvDotProd::get_kern(
MIDOUT_END();
return nullptr;
}
/* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */
namespace {
void int8x8x32_mk4_8x12x4_dotprod_kern(
const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
midout_iv("int8x8x32_mk4_8x12x4_dotprod_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto trA = kern_param.trA, trB = kern_param.trB;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
auto A_type = kern_param.A_type, B_type = kern_param.B_type,
C_type = kern_param.C_type;
const auto Aptr = kern_param.A<dt_int8>(),
Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
aarch64::matmul::gemm_mk4_s8_8x12 strategy(M, N, K, A_type, B_type,
C_type);
megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_mk4_s8_8x12>(
M, N, K, trA, trB, strategy)
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
kern_param.workspace_ptr);
}
MIDOUT_END();
}
} // anonymous namespace
bool MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::usable(
const KernSizeParam& kern_size_param) const {
return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
(kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
(kern_size_param.C_type.enumv() == DTypeEnum::Int32 ||
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32) &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::MK4_DOT &&
!kern_size_param.trA && !kern_size_param.trB;
}
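// Editorial sketch, not part of this commit: usable() above only admits
// Format::MK4_DOT with non-transposed operands. Assuming the shapes used by
// the MK4 benchmark added further down, the blocked layouts relate to the
// plain GEMM sizes as below (mk4_dot_shapes is a hypothetical illustration,
// not megdnn API).
struct Mk4DotShapes {
    size_t A[4];  //!< A is packed as (M/4, K/4, 4, 4) int8
    size_t B[3];  //!< B is packed as (K/4, N, 4) int8
};
inline Mk4DotShapes mk4_dot_shapes(size_t M, size_t N, size_t K) {
    //! M and K must be multiples of 4 for MK4_DOT; N may be arbitrary.
    return {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}};
}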
size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_workspace(
const KernSizeParam& kern_size_param) const {
MIDOUT_BEGIN(
megdnn_aarch64_matmul_kern,
midout_iv("AlgoInt8x8x32MK4_8x12x4DotProd::get_workspace"_hash)) {
auto M = kern_size_param.M, N = kern_size_param.N,
K = kern_size_param.K;
auto trA = kern_size_param.trA, trB = kern_size_param.trB;
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
C_type = kern_size_param.C_type;
aarch64::matmul::gemm_mk4_s8_8x12 strategy(M, N, K, A_type, B_type,
C_type);
return megdnn::matmul::GemmInterleaved<
aarch64::matmul::gemm_mk4_s8_8x12>(M, N, K, trA, trB,
strategy)
.get_workspace_size();
}
MIDOUT_END();
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_kern(
const KernSizeParam&) const {
return int8x8x32_mk4_8x12x4_dotprod_kern;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x12x4DotProd,
megdnn_aarch64_matmul_kern,
"AlgoInt8x8x32MK4_8x12x4DotProdImpl"_hash,
aarch64::matmul::gemm_mk4_s8_8x12, int8_t,
int32_t);
#else
/* ===================== Int8x8x32 MK4 4x4x16 algo ===================== */
......
......@@ -118,6 +118,19 @@ public:
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
};
class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override {
return "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD";
}
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
#else
class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase {
......
......@@ -615,6 +615,20 @@ static inline void interleave_12x4_4_b(const T*& inptr0, const T*& inptr1,
reinterpret_cast<int32_t*&>(outptr));
}
static inline void interleave_2x1_4_s(const int32_t*& inptr0,
const int32_t*& inptr1,
int32_t*& outptr) {
asm volatile(
"ld1 {v0.4s}, [%[inptr0]], #16\n" // d0 = A0A1A2A3
"ld1 {v1.4s}, [%[inptr1]], #16\n" // d1 = B0B1B2B3
"st1 {v0.4s}, [%[outptr]], #16\n"
"st1 {v1.4s}, [%[outptr]], #16\n"
:
[inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr)
:
: "v0", "v1", "cc", "memory");
}
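// Editorial sketch, not part of this commit: a plain-C++ reference for the
// NEON routine above, assuming the same pointer-advancing contract -- copy
// four int32 lanes from each input row into the packed output buffer.
static inline void interleave_2x1_4_s_ref(const int32_t*& inptr0,
                                          const int32_t*& inptr1,
                                          int32_t*& outptr) {
    for (int i = 0; i < 4; ++i) *outptr++ = *inptr0++;  // A0 A1 A2 A3
    for (int i = 0; i < 4; ++i) *outptr++ = *inptr1++;  // B0 B1 B2 B3
}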
static inline void interleave_8x1_4_s(
const int32_t*& inptr0, const int32_t*& inptr1, const int32_t*& inptr2,
const int32_t*& inptr3, const int32_t*& inptr4, const int32_t*& inptr5,
......@@ -752,6 +766,17 @@ static inline void interleave_8x2_2_d(
"v11", "v12", "v13", "v14", "v15", "cc", "memory");
}
template <typename T>
static inline void interleave_2x4_4_b(const T*& inptr0, const T*& inptr1,
T*& outptr) {
static_assert(
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
"interleave_2x4_4_b only support uint8_t and int8_t");
interleave_2x1_4_s(reinterpret_cast<const int32_t*&>(inptr0),
reinterpret_cast<const int32_t*&>(inptr1),
reinterpret_cast<int32_t*&>(outptr));
}
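// Editorial note, not part of this commit: interleave_2x4_4_b moves the int8
// data in groups of four bytes by reinterpreting them as int32 lanes, so each
// group of 4 consecutive int8 values along K stays contiguous -- the
// granularity the dot-product kernels consume.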
template <typename T>
static inline void interleave_8x4_4_b(const T*& inptr0, const T*& inptr1,
const T*& inptr2, const T*& inptr3,
......
This diff is collapsed.
......@@ -14,12 +14,14 @@
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h"
#include "src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h"
#if __ARM_FEATURE_DOTPROD
using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;
/* ====================== gemm_s8_8x12 ===========================*/
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12);
void gemm_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
......@@ -109,5 +111,91 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
packA += K4;
}
}
/* ====================== gemm_mk4_s8_8x12 ===========================*/
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_8x12);
void gemm_mk4_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix A is not supported");
matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_A(outptr, inptr, ldin, y0, ymax, k0,
kmax);
}
void gemm_mk4_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax,
bool transpose) const {
megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix B is not supported");
matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax);
}
void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int32* C,
size_t LDC, bool is_first_k, const dt_int32*,
dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t B_INTERLEAVE = 12;
//! K is rounded up to a multiple of 4
K = round_up<size_t>(K, 4);
const int K8 = (K << 3);
const int K12 = K * 12;
const int K4 = K * 4;
size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
int32_t* output = C + ((m >> 2) * LDC);
size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_mk4_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k);
output += (B_INTERLEAVE << 2);
cur_packB += K12;
}
for (; n < N; n += 4) {
matmul_mk4_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
output += 16;
cur_packB += K4;
}
packA += K8;
}
for (; m < M; m += 4) {
int32_t* output = C + ((m >> 2) * LDC);
const dt_int8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_mk4_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k);
output += (B_INTERLEAVE << 2);
cur_packB += K12;
}
for (; n < N; n += 4) {
matmul_mk4_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
output += 16;
cur_packB += K4;
}
packA += K4;
}
}
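// Editorial sketch, not part of this commit: in the MK4 output layout every
// group of 4 rows forms one block-row, so the kern() above addresses C with
// (m >> 2) * LDC, and a 12-wide tile advances the output pointer by 12 * 4
// int32 (B_INTERLEAVE << 2). A hypothetical helper making that explicit:
static inline int32_t* mk4_output_tile(int32_t* C, size_t m, size_t n,
                                        size_t LDC) {
    //! m, n are element coordinates; m must be a multiple of 4.
    return C + (m / 4) * LDC + n * 4;
}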
#endif
// vim: syntax=cpp.doxygen
......@@ -19,6 +19,9 @@ namespace matmul {
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true,
gemm_s8_8x12);
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true,
gemm_mk4_s8_8x12);
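// Editorial note, not part of this commit: in both registrations the trailing
// 8, 12, 4 are the tile sizes -- 8 rows of A, 12 columns of B, K-unroll of 4
// -- matching the A_INTERLEAVE/B_INTERLEAVE constants used by kern().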
} // namespace matmul
} // namespace aarch64
} // namespace megdnn
......
......@@ -29,6 +29,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#if __ARM_FEATURE_DOTPROD
AlgoInt8x8x32K8x12x4DotProd int8x8x32_k8x12x4_dotprod;
AlgoInt8x8x32GemvDotProd int8x8x32_gemv_dotprod;
AlgoInt8x8x32MK4_8x12x4DotProd int8x8x32_mk4_8x12x4_dotprod;
#else
AlgoInt8x8x32MK4_4x4x16 int8x8x32_mk4_4x4x16;
AlgoInt8x8x32K4x4x16 int8x8x32_k4x4x16;
......@@ -64,6 +65,7 @@ public:
#if __ARM_FEATURE_DOTPROD
all_algos.emplace_back(&int8x8x32_gemv_dotprod);
all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod);
all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod);
#else
all_algos.emplace_back(&int8x8x32_gemv);
all_algos.emplace_back(&int8x8x32_k4x4x16);
......
......@@ -35,6 +35,8 @@ private:
class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel
// 8x12x4 DotProduct
class AlgoInt8x8x32GemvDotProd; // Aarch64 Int8x8x32 Gemv DotProduct
class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel
// 8x12x4 DotProduct
#else
class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16
class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16
......
......@@ -64,6 +64,18 @@ TEST_F(AARCH64, MATRIX_MUL_INT8X8X32_K8X12X4_DOTPROD) {
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
handle(), "AARCH64_INT8X8X32_K8X12X4_DOTPROD");
}
TEST_F(AARCH64, MATRIX_MUL_INT8X8X32_MK4_8X12X4_DOTPROD) {
std::vector<matrix_mul::TestArg> args;
for (size_t m : {1, 2, 3, 4, 5, 6, 7, 10, 11})
for (size_t n : {2, 3, 4, 5, 8, 12, 13, 14, 15, 16, 31})
for (size_t k : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 33, 34})
args.emplace_back(m, n, k, 0);
matrix_mul::check_matrix_mul(
dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(),
"AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD",
param::MatrixMul::Format::MK4_DOT, 1, 1e-3, std::move(args));
}
#else
TEST_F(AARCH64, MATRIX_MUL_INT8X8X32_K4X4X16) {
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
......@@ -460,6 +472,54 @@ TEST_F(AARCH64, BENCHMARK_GEMV_INT_8X8X32) {
run(M, N, K);
}
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT8X8X32_MK4_8X12X4) {
constexpr size_t RUNS = 50;
param::MatrixMul param;
param.transposeA = false;
param.transposeB = false;
Benchmarker<MatrixMul> benchmarker(handle());
Benchmarker<MatrixMul> benchmarker_mk4(handle());
benchmarker.set_times(RUNS)
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int32{})
.set_param(param)
.set_display(false);
benchmarker.set_before_exec_callback(
AlgoChecker<MatrixMul>("AARCH64_INT8X8X32_K8X12X4"));
param.format = MatrixMul::Param::Format::MK4_DOT;
benchmarker_mk4.set_before_exec_callback(
AlgoChecker<MatrixMul>("AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"));
benchmarker_mk4.set_times(RUNS)
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int32{})
.set_param(param)
.set_display(false);
auto run = [&](size_t M, size_t N, size_t K) {
auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS;
auto mk_used = benchmarker_mk4.exec(
{{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
RUNS;
float computations = 2.f * M * K * N * 1e-6;
printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops mk4: %f ms "
"%f Gflops speedup_vs_normal: %f\n",
M, K, N, default_used, computations / default_used, mk_used,
computations / mk_used, default_used / mk_used);
};
run(256, 256, 128);
for (size_t k = 4; k <= 512; k *= 2) {
for (size_t m = 4; m <= 512; m *= 2) {
for (size_t n = 4; n <= 512; n *= 2) {
run(m, n, k);
}
}
std::cout << std::endl;
}
}
#endif // __ARM_FEATURE_DOTPROD
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
......