提交 a450d0f5 编写于 作者: M Megvii Engine Team

feat(fallback): add fp16 mk8 8x8 matmul

GitOrigin-RevId: 1a50a8a7be433f3ba688eeaa1af460458e5f0b53
上级 b85792ac
...@@ -297,7 +297,66 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_kern( ...@@ -297,7 +297,66 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_kern(
const KernSizeParam&) const { const KernSizeParam&) const {
return gi_f32_mk4_4x8_kern; return gi_f32_mk4_4x8_kern;
} }
#if defined(GI_SUPPORT_F16)
/* ================== F16 Gemm MK8 gi algo ================== */
namespace {
// Kernel entry point for the fp16 MK8 8x8 no-pack matmul algorithm:
// unpacks the runtime KernParam and dispatches to the GemmInterleaved
// wrapper around the gi_sgemm_nopack_mk8_8x8_fp16 strategy.
void gi_f16_mk8_8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_fb_gi_matmul_kern, midout_iv("gi_f16_mk8_8x8_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_float16>(), Bptr = kern_param.B<dt_float16>();
        auto Cptr = kern_param.C<dt_float16>();
        // Strategy carries the dtypes; "false" template arg selects the
        // no-pack GemmInterleaved specialization (A/B consumed in place).
        matmul::fallback::gi_sgemm_nopack_mk8_8x8_fp16 strategy(A_type, B_type, C_type);
        megdnn::matmul::GemmInterleaved<
                matmul::fallback::gi_sgemm_nopack_mk8_8x8_fp16, false>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
} // anonymous namespace
bool MatrixMulImpl::AlgoF16GiMK8_8x8::usable(
const KernSizeParam& kern_size_param) const {
constexpr size_t MB = 8;
constexpr size_t KB = 8;
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::MK8 &&
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float16() && !kern_size_param.trA &&
!kern_size_param.trB && kern_size_param.M % MB == 0 &&
kern_size_param.K % KB == 0;
}
// Workspace query: delegates to the no-pack GemmInterleaved wrapper, which
// computes the scratch size required by the fp16 MK8 strategy.
size_t MatrixMulImpl::AlgoF16GiMK8_8x8::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(
            megdnn_fb_gi_matmul_kern,
            midout_iv("AlgoF16GiMK8_8x8::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        matmul::fallback::gi_sgemm_nopack_mk8_8x8_fp16 strategy(A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<
                       matmul::fallback::gi_sgemm_nopack_mk8_8x8_fp16, false>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Unreachable when midout is active; keeps the function well-formed when
    // the MIDOUT_* macros compile the body out.
    return 0;
}
// Returns the kernel function pointer; the same kernel serves every
// problem size accepted by usable().
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16GiMK8_8x8::get_kern(
        const KernSizeParam&) const {
    return gi_f16_mk8_8x8_kern;
}
#endif
/* ===================== F32 algo gi mk4 pack K4x12 ===================== */ /* ===================== F32 algo gi mk4 pack K4x12 ===================== */
namespace { namespace {
void f32_gi_mk4_pack_4x12_kern(const MatrixMulImpl::KernParam& kern_param) { void f32_gi_mk4_pack_4x12_kern(const MatrixMulImpl::KernParam& kern_param) {
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <type_traits> #include <type_traits>
#include "src/common/algo_base.h" #include "src/common/algo_base.h"
#include "src/fallback/general_intrinsic/gi_common.h"
#include "src/fallback/matrix_mul/gemm_common.h" #include "src/fallback/matrix_mul/gemm_common.h"
#include "src/fallback/matrix_mul/opr_impl.h" #include "src/fallback/matrix_mul/opr_impl.h"
...@@ -97,6 +98,20 @@ public: ...@@ -97,6 +98,20 @@ public:
MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_4x8) MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_4x8)
}; };
#if defined(GI_SUPPORT_F16)
// fp16 matmul algorithm for MK8 format implemented with general intrinsics
// (GI); only declared when the target toolchain provides fp16 GI support.
class MatrixMulImpl::AlgoF16GiMK8_8x8 final : public AlgoBase {
public:
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    const char* name() const override { return "FB_GI_F16_MK8_8x8"; }
    bool usable(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    // NO_PACK: the strategy reads A/B directly in MK8 layout, no repack stage.
    PackMode packmode() const override { return PackMode::NO_PACK; }
    MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 8, AlgoDataType::FLOAT16, MK8)
    MEGDNN_DECL_ALGO_TYPE(FB_GI_F16_MK8_8x8)
};
#endif
class MatrixMulImpl::AlgoF32GiMK4Pack4x12 final : public AlgoBase { class MatrixMulImpl::AlgoF32GiMK4Pack4x12 final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override { AlgoAttribute attribute() const override {
......
#pragma once #pragma once
#include "src/fallback/general_intrinsic/gi_common.h"
#include "src/fallback/matrix_mul/gemm_common.h" #include "src/fallback/matrix_mul/gemm_common.h"
namespace megdnn { namespace megdnn {
...@@ -8,6 +9,11 @@ namespace fallback { ...@@ -8,6 +9,11 @@ namespace fallback {
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12); MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12);
MEGDNN_REG_GEMM_STRATEGY_NOPACK( MEGDNN_REG_GEMM_STRATEGY_NOPACK(
float, float, float, 4, 8, 1, false, true, gi_sgemm_nopack_4x8); float, float, float, 4, 8, 1, false, true, gi_sgemm_nopack_4x8);
#if defined(GI_SUPPORT_F16)
// fp16 no-pack strategy with an 8x8 register block, used by the
// FB_GI_F16_MK8_8x8 algorithm (MK8 format).
MEGDNN_REG_GEMM_STRATEGY_NOPACK(
        dt_float16, dt_float16, dt_float16, 8, 8, 1, false, true,
        gi_sgemm_nopack_mk8_8x8_fp16);
#endif
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, true, gi_sgemm_4x12); MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, true, gi_sgemm_4x12);
MEGDNN_REG_GEMM_STRATEGY( MEGDNN_REG_GEMM_STRATEGY(
float, float, float, 4, 12, 1, false, false, gi_sgemm_mk4_pack_4x12); float, float, float, 4, 12, 1, false, false, gi_sgemm_mk4_pack_4x12);
......
#include "src/fallback/general_intrinsic/gi_float16.h"
#if defined(GI_SUPPORT_F16)
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/generic_strategy.h"
using namespace megdnn;
using namespace matmul::fallback;
namespace {
#define MLA GiMultiplyAddScalarFloat16
// Compute one 8x1 output tile: C[8x1] = A(8xK, MK8 layout) * B(Kx1).
// A holds K/8 consecutive 8x8 panels (8 fp16 per column per k step); B holds
// one column with 8 scalars per k-block, consecutive k-blocks LDB elements
// apart.  Requires K >= 8 and K % 8 == 0 (guaranteed by usable()); the first
// k-block is peeled so the accumulator starts from zero without reading C.
void kern_8x1(
        const gi_float16_t* A, const gi_float16_t* B, size_t LDB, size_t K,
        gi_float16_t* C) {
    // Each k-block already advances B by 8; rebase LDB so the trailing
    // "B += LDB" adds only the remainder of the row stride.
    LDB = LDB - 8;
    K = K - 8;  // first k-block handled before the loop
    GI_FLOAT16_t d0 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d1 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d2 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d3 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d4 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d5 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d6 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d7 = GiLoadFloat16(A);
    A = A + 8;
    // Zero accumulator seed; GiZeroFloat16() keeps the idiom consistent with
    // kern_8x8 and avoids broadcasting a double literal into fp16.
    GI_FLOAT16_t vfzero = GiZeroFloat16();
    GI_FLOAT16_t d8 = MLA(vfzero, d0, *(B));
    d8 = MLA(d8, d1, *(B + 1));
    d8 = MLA(d8, d2, *(B + 2));
    d8 = MLA(d8, d3, *(B + 3));
    d8 = MLA(d8, d4, *(B + 4));
    d8 = MLA(d8, d5, *(B + 5));
    d8 = MLA(d8, d6, *(B + 6));
    d8 = MLA(d8, d7, *(B + 7));
    B += 8;
    B += LDB;
    // Remaining k-blocks: load the next 8x8 panel of A and accumulate the
    // column-scaled sums into d8.
    for (; K > 0; K -= 8) {
        d0 = GiLoadFloat16(A);
        A = A + 8;
        d1 = GiLoadFloat16(A);
        A = A + 8;
        d2 = GiLoadFloat16(A);
        A = A + 8;
        d3 = GiLoadFloat16(A);
        A = A + 8;
        d4 = GiLoadFloat16(A);
        A = A + 8;
        d5 = GiLoadFloat16(A);
        A = A + 8;
        d6 = GiLoadFloat16(A);
        A = A + 8;
        d7 = GiLoadFloat16(A);
        A = A + 8;
        d8 = MLA(d8, d0, *(B));
        d8 = MLA(d8, d1, *(B + 1));
        d8 = MLA(d8, d2, *(B + 2));
        d8 = MLA(d8, d3, *(B + 3));
        d8 = MLA(d8, d4, *(B + 4));
        d8 = MLA(d8, d5, *(B + 5));
        d8 = MLA(d8, d6, *(B + 6));
        d8 = MLA(d8, d7, *(B + 7));
        B += 8;
        B += LDB;
    }
    GiStoreFloat16(C, d8);
}
// Compute one 8x4 output tile: C[8x4] = A(8xK, MK8 layout) * B(Kx4).
// A holds K/8 consecutive 8x8 panels; B holds 4 columns of 8 scalars per
// k-block (32 elements), consecutive k-blocks LDB elements apart.  Requires
// K >= 8 and K % 8 == 0; the first k-block is peeled to zero-seed the
// accumulators d8..d11 (one per output column) without reading C.
void kern_8x4(
        const gi_float16_t* A, const gi_float16_t* B, size_t LDB, size_t K,
        gi_float16_t* C) {
    // Each k-block already advances B by 4 * 8 = 32 elements; rebase LDB so
    // the trailing "B += LDB" adds only the remainder of the row stride.
    LDB = LDB - 32;
    K = K - 8;  // first k-block handled before the loop
    GI_FLOAT16_t d0 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d1 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d2 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d3 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d4 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d5 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d6 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d7 = GiLoadFloat16(A);
    A = A + 8;
    // Zero accumulator seed; GiZeroFloat16() keeps the idiom consistent with
    // kern_8x8 and avoids broadcasting a double literal into fp16.
    GI_FLOAT16_t vfzero = GiZeroFloat16();
    GI_FLOAT16_t d8 = MLA(vfzero, d0, *(B));
    d8 = MLA(d8, d1, *(B + 1));
    d8 = MLA(d8, d2, *(B + 2));
    d8 = MLA(d8, d3, *(B + 3));
    d8 = MLA(d8, d4, *(B + 4));
    d8 = MLA(d8, d5, *(B + 5));
    d8 = MLA(d8, d6, *(B + 6));
    d8 = MLA(d8, d7, *(B + 7));
    B += 8;
    GI_FLOAT16_t d9 = MLA(vfzero, d0, *(B));
    d9 = MLA(d9, d1, *(B + 1));
    d9 = MLA(d9, d2, *(B + 2));
    d9 = MLA(d9, d3, *(B + 3));
    d9 = MLA(d9, d4, *(B + 4));
    d9 = MLA(d9, d5, *(B + 5));
    d9 = MLA(d9, d6, *(B + 6));
    d9 = MLA(d9, d7, *(B + 7));
    B += 8;
    GI_FLOAT16_t d10 = MLA(vfzero, d0, *(B));
    d10 = MLA(d10, d1, *(B + 1));
    d10 = MLA(d10, d2, *(B + 2));
    d10 = MLA(d10, d3, *(B + 3));
    d10 = MLA(d10, d4, *(B + 4));
    d10 = MLA(d10, d5, *(B + 5));
    d10 = MLA(d10, d6, *(B + 6));
    d10 = MLA(d10, d7, *(B + 7));
    B += 8;
    GI_FLOAT16_t d11 = MLA(vfzero, d0, *(B));
    d11 = MLA(d11, d1, *(B + 1));
    d11 = MLA(d11, d2, *(B + 2));
    d11 = MLA(d11, d3, *(B + 3));
    d11 = MLA(d11, d4, *(B + 4));
    d11 = MLA(d11, d5, *(B + 5));
    d11 = MLA(d11, d6, *(B + 6));
    d11 = MLA(d11, d7, *(B + 7));
    B += 8;
    B += LDB;
    // Remaining k-blocks: reload the A panel, then accumulate each of the
    // four B columns into its own accumulator.
    for (; K > 0; K -= 8) {
        d0 = GiLoadFloat16(A);
        A = A + 8;
        d1 = GiLoadFloat16(A);
        A = A + 8;
        d2 = GiLoadFloat16(A);
        A = A + 8;
        d3 = GiLoadFloat16(A);
        A = A + 8;
        d4 = GiLoadFloat16(A);
        A = A + 8;
        d5 = GiLoadFloat16(A);
        A = A + 8;
        d6 = GiLoadFloat16(A);
        A = A + 8;
        d7 = GiLoadFloat16(A);
        A = A + 8;
        d8 = MLA(d8, d0, *(B));
        d8 = MLA(d8, d1, *(B + 1));
        d8 = MLA(d8, d2, *(B + 2));
        d8 = MLA(d8, d3, *(B + 3));
        d8 = MLA(d8, d4, *(B + 4));
        d8 = MLA(d8, d5, *(B + 5));
        d8 = MLA(d8, d6, *(B + 6));
        d8 = MLA(d8, d7, *(B + 7));
        B += 8;
        d9 = MLA(d9, d0, *(B));
        d9 = MLA(d9, d1, *(B + 1));
        d9 = MLA(d9, d2, *(B + 2));
        d9 = MLA(d9, d3, *(B + 3));
        d9 = MLA(d9, d4, *(B + 4));
        d9 = MLA(d9, d5, *(B + 5));
        d9 = MLA(d9, d6, *(B + 6));
        d9 = MLA(d9, d7, *(B + 7));
        B += 8;
        d10 = MLA(d10, d0, *(B));
        d10 = MLA(d10, d1, *(B + 1));
        d10 = MLA(d10, d2, *(B + 2));
        d10 = MLA(d10, d3, *(B + 3));
        d10 = MLA(d10, d4, *(B + 4));
        d10 = MLA(d10, d5, *(B + 5));
        d10 = MLA(d10, d6, *(B + 6));
        d10 = MLA(d10, d7, *(B + 7));
        B += 8;
        d11 = MLA(d11, d0, *(B));
        d11 = MLA(d11, d1, *(B + 1));
        d11 = MLA(d11, d2, *(B + 2));
        d11 = MLA(d11, d3, *(B + 3));
        d11 = MLA(d11, d4, *(B + 4));
        d11 = MLA(d11, d5, *(B + 5));
        d11 = MLA(d11, d6, *(B + 6));
        d11 = MLA(d11, d7, *(B + 7));
        B += 8;
        B += LDB;
    }
    // Store the four output columns contiguously (MK8: 8 rows per column).
    GiStoreFloat16(C, d8);
    C = C + 8;
    GiStoreFloat16(C, d9);
    C = C + 8;
    GiStoreFloat16(C, d10);
    C = C + 8;
    GiStoreFloat16(C, d11);
    C = C + 8;
}
// Compute one 8x8 output tile: C[8x8] = A(8xK, MK8 layout) * B(Kx8).
// A holds K/8 consecutive 8x8 panels; B holds 8 columns of 8 scalars per
// k-block (64 elements), consecutive k-blocks LDB elements apart.  Requires
// K >= 8 and K % 8 == 0; the first k-block is peeled to zero-seed the eight
// accumulators d8..d15 (one per output column) without reading C.
void kern_8x8(
        const gi_float16_t* A, const gi_float16_t* B, size_t LDB, size_t K,
        gi_float16_t* C) {
    // Each k-block already advances B by 8 * 8 = 64 elements; rebase LDB so
    // the trailing "B + LDB" adds only the remainder of the row stride.
    LDB -= 64;
    K = K - 8;  // first k-block handled before the loop
    GI_FLOAT16_t d0 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d1 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d2 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d3 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d4 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d5 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d6 = GiLoadFloat16(A);
    A = A + 8;
    GI_FLOAT16_t d7 = GiLoadFloat16(A);
    A = A + 8;
    // Zero accumulator seed for the peeled first k-block.
    GI_FLOAT16_t vfzero = GiZeroFloat16();
    GI_FLOAT16_t d8 = MLA(vfzero, d0, *(B));
    d8 = MLA(d8, d1, *(B + 1));
    d8 = MLA(d8, d2, *(B + 2));
    d8 = MLA(d8, d3, *(B + 3));
    d8 = MLA(d8, d4, *(B + 4));
    d8 = MLA(d8, d5, *(B + 5));
    d8 = MLA(d8, d6, *(B + 6));
    d8 = MLA(d8, d7, *(B + 7));
    B = B + 8;
    GI_FLOAT16_t d9 = MLA(vfzero, d0, *(B));
    d9 = MLA(d9, d1, *(B + 1));
    d9 = MLA(d9, d2, *(B + 2));
    d9 = MLA(d9, d3, *(B + 3));
    d9 = MLA(d9, d4, *(B + 4));
    d9 = MLA(d9, d5, *(B + 5));
    d9 = MLA(d9, d6, *(B + 6));
    d9 = MLA(d9, d7, *(B + 7));
    B = B + 8;
    GI_FLOAT16_t d10 = MLA(vfzero, d0, *(B));
    d10 = MLA(d10, d1, *(B + 1));
    d10 = MLA(d10, d2, *(B + 2));
    d10 = MLA(d10, d3, *(B + 3));
    d10 = MLA(d10, d4, *(B + 4));
    d10 = MLA(d10, d5, *(B + 5));
    d10 = MLA(d10, d6, *(B + 6));
    d10 = MLA(d10, d7, *(B + 7));
    B = B + 8;
    GI_FLOAT16_t d11 = MLA(vfzero, d0, *(B));
    d11 = MLA(d11, d1, *(B + 1));
    d11 = MLA(d11, d2, *(B + 2));
    d11 = MLA(d11, d3, *(B + 3));
    d11 = MLA(d11, d4, *(B + 4));
    d11 = MLA(d11, d5, *(B + 5));
    d11 = MLA(d11, d6, *(B + 6));
    d11 = MLA(d11, d7, *(B + 7));
    B = B + 8;
    GI_FLOAT16_t d12 = MLA(vfzero, d0, *(B));
    d12 = MLA(d12, d1, *(B + 1));
    d12 = MLA(d12, d2, *(B + 2));
    d12 = MLA(d12, d3, *(B + 3));
    d12 = MLA(d12, d4, *(B + 4));
    d12 = MLA(d12, d5, *(B + 5));
    d12 = MLA(d12, d6, *(B + 6));
    d12 = MLA(d12, d7, *(B + 7));
    B = B + 8;
    GI_FLOAT16_t d13 = MLA(vfzero, d0, *(B));
    d13 = MLA(d13, d1, *(B + 1));
    d13 = MLA(d13, d2, *(B + 2));
    d13 = MLA(d13, d3, *(B + 3));
    d13 = MLA(d13, d4, *(B + 4));
    d13 = MLA(d13, d5, *(B + 5));
    d13 = MLA(d13, d6, *(B + 6));
    d13 = MLA(d13, d7, *(B + 7));
    B = B + 8;
    GI_FLOAT16_t d14 = MLA(vfzero, d0, *(B));
    d14 = MLA(d14, d1, *(B + 1));
    d14 = MLA(d14, d2, *(B + 2));
    d14 = MLA(d14, d3, *(B + 3));
    d14 = MLA(d14, d4, *(B + 4));
    d14 = MLA(d14, d5, *(B + 5));
    d14 = MLA(d14, d6, *(B + 6));
    d14 = MLA(d14, d7, *(B + 7));
    B = B + 8;
    GI_FLOAT16_t d15 = MLA(vfzero, d0, *(B));
    d15 = MLA(d15, d1, *(B + 1));
    d15 = MLA(d15, d2, *(B + 2));
    d15 = MLA(d15, d3, *(B + 3));
    d15 = MLA(d15, d4, *(B + 4));
    d15 = MLA(d15, d5, *(B + 5));
    d15 = MLA(d15, d6, *(B + 6));
    d15 = MLA(d15, d7, *(B + 7));
    B = B + 8;
    B = B + LDB;
    // Remaining k-blocks: reload the A panel, then accumulate each of the
    // eight B columns into its own accumulator.
    for (; K > 0; K -= 8) {
        d0 = GiLoadFloat16(A);
        A = A + 8;
        d1 = GiLoadFloat16(A);
        A = A + 8;
        d2 = GiLoadFloat16(A);
        A = A + 8;
        d3 = GiLoadFloat16(A);
        A = A + 8;
        d4 = GiLoadFloat16(A);
        A = A + 8;
        d5 = GiLoadFloat16(A);
        A = A + 8;
        d6 = GiLoadFloat16(A);
        A = A + 8;
        d7 = GiLoadFloat16(A);
        A = A + 8;
        d8 = MLA(d8, d0, *(B));
        d8 = MLA(d8, d1, *(B + 1));
        d8 = MLA(d8, d2, *(B + 2));
        d8 = MLA(d8, d3, *(B + 3));
        d8 = MLA(d8, d4, *(B + 4));
        d8 = MLA(d8, d5, *(B + 5));
        d8 = MLA(d8, d6, *(B + 6));
        d8 = MLA(d8, d7, *(B + 7));
        B = B + 8;
        d9 = MLA(d9, d0, *(B));
        d9 = MLA(d9, d1, *(B + 1));
        d9 = MLA(d9, d2, *(B + 2));
        d9 = MLA(d9, d3, *(B + 3));
        d9 = MLA(d9, d4, *(B + 4));
        d9 = MLA(d9, d5, *(B + 5));
        d9 = MLA(d9, d6, *(B + 6));
        d9 = MLA(d9, d7, *(B + 7));
        B = B + 8;
        d10 = MLA(d10, d0, *(B));
        d10 = MLA(d10, d1, *(B + 1));
        d10 = MLA(d10, d2, *(B + 2));
        d10 = MLA(d10, d3, *(B + 3));
        d10 = MLA(d10, d4, *(B + 4));
        d10 = MLA(d10, d5, *(B + 5));
        d10 = MLA(d10, d6, *(B + 6));
        d10 = MLA(d10, d7, *(B + 7));
        B = B + 8;
        d11 = MLA(d11, d0, *(B));
        d11 = MLA(d11, d1, *(B + 1));
        d11 = MLA(d11, d2, *(B + 2));
        d11 = MLA(d11, d3, *(B + 3));
        d11 = MLA(d11, d4, *(B + 4));
        d11 = MLA(d11, d5, *(B + 5));
        d11 = MLA(d11, d6, *(B + 6));
        d11 = MLA(d11, d7, *(B + 7));
        B = B + 8;
        d12 = MLA(d12, d0, *(B));
        d12 = MLA(d12, d1, *(B + 1));
        d12 = MLA(d12, d2, *(B + 2));
        d12 = MLA(d12, d3, *(B + 3));
        d12 = MLA(d12, d4, *(B + 4));
        d12 = MLA(d12, d5, *(B + 5));
        d12 = MLA(d12, d6, *(B + 6));
        d12 = MLA(d12, d7, *(B + 7));
        B = B + 8;
        d13 = MLA(d13, d0, *(B));
        d13 = MLA(d13, d1, *(B + 1));
        d13 = MLA(d13, d2, *(B + 2));
        d13 = MLA(d13, d3, *(B + 3));
        d13 = MLA(d13, d4, *(B + 4));
        d13 = MLA(d13, d5, *(B + 5));
        d13 = MLA(d13, d6, *(B + 6));
        d13 = MLA(d13, d7, *(B + 7));
        B = B + 8;
        d14 = MLA(d14, d0, *(B));
        d14 = MLA(d14, d1, *(B + 1));
        d14 = MLA(d14, d2, *(B + 2));
        d14 = MLA(d14, d3, *(B + 3));
        d14 = MLA(d14, d4, *(B + 4));
        d14 = MLA(d14, d5, *(B + 5));
        d14 = MLA(d14, d6, *(B + 6));
        d14 = MLA(d14, d7, *(B + 7));
        B = B + 8;
        d15 = MLA(d15, d0, *(B));
        d15 = MLA(d15, d1, *(B + 1));
        d15 = MLA(d15, d2, *(B + 2));
        d15 = MLA(d15, d3, *(B + 3));
        d15 = MLA(d15, d4, *(B + 4));
        d15 = MLA(d15, d5, *(B + 5));
        d15 = MLA(d15, d6, *(B + 6));
        d15 = MLA(d15, d7, *(B + 7));
        // Skip over this k-block's last column and jump to the next k-block.
        B = B + 8 + LDB;
    }
    // Store the eight output columns contiguously (MK8: 8 rows per column).
    GiStoreFloat16(C, d8);
    C = C + 8;
    GiStoreFloat16(C, d9);
    C = C + 8;
    GiStoreFloat16(C, d10);
    C = C + 8;
    GiStoreFloat16(C, d11);
    C = C + 8;
    GiStoreFloat16(C, d12);
    C = C + 8;
    GiStoreFloat16(C, d13);
    C = C + 8;
    GiStoreFloat16(C, d14);
    C = C + 8;
    GiStoreFloat16(C, d15);
    C = C + 8;
}
#undef MLA
} // namespace
MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gi_sgemm_nopack_mk8_8x8_fp16);
// No-pack fp16 MK8 kernel driver: walks the output in 8-row stripes and,
// within each stripe, consumes B in 8-, 4- and 1-column tiles via
// kern_8x8 / kern_8x4 / kern_8x1.  trA/trB must be false and M, K must be
// multiples of 8 (enforced below; guaranteed by AlgoF16GiMK8_8x8::usable()).
void gi_sgemm_nopack_mk8_8x8_fp16::kern(
        const dt_float16* A, size_t LDA, const dt_float16* B, size_t LDB, dt_float16* C,
        size_t LDC, size_t M, size_t K, size_t N, const dt_float16*, void*, bool trA,
        bool trB) const {
    constexpr size_t MB = 8;       // rows per MK8 block
    constexpr size_t KB = 8;       // k-steps per MK8 block
    constexpr size_t NB = 8;       // columns per kern_8x8 tile
    constexpr size_t NB_HALF = 4;  // columns per kern_8x4 tile
    megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0);
    for (size_t m = 0; m < M; m += MB) {
        // LDC is used as the element stride between successive 8-row blocks
        // of C (see the (m / MB) * LDC indexing below).
        gi_float16_t* output = reinterpret_cast<gi_float16_t*>(C) + (m / MB) * LDC;
        const gi_float16_t* cur_B = reinterpret_cast<const gi_float16_t*>(B);
        size_t n = 0;
        // Full 8-column tiles.
        for (; n + NB - 1 < N; n += NB) {
            kern_8x8(reinterpret_cast<const gi_float16_t*>(A), cur_B, LDB, K, output);
            cur_B += KB * NB;
            output += MB * NB;
        }
        // At most one 4-column tile from the remainder.
        if (N - n >= 4) {
            kern_8x4(reinterpret_cast<const gi_float16_t*>(A), cur_B, LDB, K, output);
            cur_B += KB * NB_HALF;
            output += MB * NB_HALF;
            n += 4;
        }
        // Remaining single columns, one at a time.
        while (n < N) {
            kern_8x1(reinterpret_cast<const gi_float16_t*>(A), cur_B, LDB, K, output);
            cur_B += KB;
            output += MB;
            n++;
        }
        // Advance A to the next 8-row stripe; LDA is the stride between stripes.
        A += LDA;
    }
}
#endif
// vim: syntax=cpp.doxygen
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "src/common/algo_chooser.h" #include "src/common/algo_chooser.h"
#include "src/common/metahelper.h" #include "src/common/metahelper.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/general_intrinsic/gi_common.h"
#include "src/fallback/matrix_mul/algos.h" #include "src/fallback/matrix_mul/algos.h"
#include "src/fallback/matrix_mul/gemm_impl.h" #include "src/fallback/matrix_mul/gemm_impl.h"
#include "src/fallback/matrix_mul/generic_strategy.h" #include "src/fallback/matrix_mul/generic_strategy.h"
...@@ -30,6 +31,9 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { ...@@ -30,6 +31,9 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoF32GiMK4_4x8 f32_mk4_4x8; AlgoF32GiMK4_4x8 f32_mk4_4x8;
AlgoF32GiMK4Pack4x12 f32_mk4_gi_pack_4x12; AlgoF32GiMK4Pack4x12 f32_mk4_gi_pack_4x12;
AlgoF32Gi4x12 f32_4x8; AlgoF32Gi4x12 f32_4x8;
#if defined(GI_SUPPORT_F16)
AlgoF16GiMK8_8x8 f16_mk8_8x8;
#endif
SmallVector<AlgoBase*> m_all_algos; SmallVector<AlgoBase*> m_all_algos;
AlgoBase::Mapper m_all_algos_map; AlgoBase::Mapper m_all_algos_map;
...@@ -39,6 +43,9 @@ public: ...@@ -39,6 +43,9 @@ public:
m_all_algos.emplace_back(&f32_mk4_4x8); m_all_algos.emplace_back(&f32_mk4_4x8);
m_all_algos.emplace_back(&f32_mk4_gi_pack_4x12); m_all_algos.emplace_back(&f32_mk4_gi_pack_4x12);
m_all_algos.emplace_back(&f32_4x8); m_all_algos.emplace_back(&f32_4x8);
#if defined(GI_SUPPORT_F16)
m_all_algos.emplace_back(&f16_mk8_8x8);
#endif
m_all_algos.emplace_back(&gemv); m_all_algos.emplace_back(&gemv);
m_all_algos.emplace_back(&f32_k8x12x1); m_all_algos.emplace_back(&f32_k8x12x1);
m_all_algos.emplace_back(&naive); m_all_algos.emplace_back(&naive);
......
...@@ -103,6 +103,7 @@ public: ...@@ -103,6 +103,7 @@ public:
FB_NAIVE, FB_NAIVE,
FB_GI_F32_GEMV_MK4, FB_GI_F32_GEMV_MK4,
FB_GI_F32_MK4_4x8, FB_GI_F32_MK4_4x8,
FB_GI_F16_MK8_8x8,
FB_GI_F32_MK4_PACK_4x12, FB_GI_F32_MK4_PACK_4x12,
FB_GI_F32_4x12, FB_GI_F32_4x12,
...@@ -237,6 +238,7 @@ private: ...@@ -237,6 +238,7 @@ private:
class AlgoF32GiMK4_4x8; // fallback F32 gi Gemm NCHW44 class AlgoF32GiMK4_4x8; // fallback F32 gi Gemm NCHW44
class AlgoF32GiMK4Pack4x12; // fallback F32 gi Gemm pack NCHW44 class AlgoF32GiMK4Pack4x12; // fallback F32 gi Gemm pack NCHW44
class AlgoF32Gi4x12; // fallback F32 gi Gemm class AlgoF32Gi4x12; // fallback F32 gi Gemm
class AlgoF16GiMK8_8x8;
class AlgoGemv; class AlgoGemv;
class AlgoNaive; class AlgoNaive;
class AlgoPack; class AlgoPack;
......
#include "test/common/matrix_mul.h" #include "test/common/matrix_mul.h"
#include "src/fallback/general_intrinsic/gi_common.h"
#include "test/common/checker.h" #include "test/common/checker.h"
#include "test/common/rng.h" #include "test/common/rng.h"
#include "test/common/task_record_check.h" #include "test/common/task_record_check.h"
...@@ -42,6 +43,14 @@ TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) { ...@@ -42,6 +43,14 @@ TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) {
"FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1); "FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1);
} }
#if defined(GI_SUPPORT_F16)
// Correctness test: checks FB_GI_F16_MK8_8x8 against the reference matmul
// for fp16 inputs/outputs in MK8 format.
TEST_F(FALLBACK, MATRIX_MUL_FP16_MK8_GI) {
    matrix_mul::check_matrix_mul(
            dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(),
            "FB_GI_F16_MK8_8x8", param::MatrixMul::Format::MK8, 1);
}
#endif
TEST_F(FALLBACK, MATRIX_MUL_GI_F32_4x12) { TEST_F(FALLBACK, MATRIX_MUL_GI_F32_4x12) {
matrix_mul::check_matrix_mul( matrix_mul::check_matrix_mul(
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
...@@ -183,6 +192,15 @@ TEST_F(FALLBACK, BENCHMARK_MATRIX_FB_GI_F32_MK4_4x8) { ...@@ -183,6 +192,15 @@ TEST_F(FALLBACK, BENCHMARK_MATRIX_FB_GI_F32_MK4_4x8) {
"FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4); "FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4);
} }
#if defined(GI_SUPPORT_F16)
// Benchmark for the fp16 MK8 GI algorithm on the standard matmul benchmark
// argument set (timing only, no correctness check).
TEST_F(FALLBACK, BENCHMARK_MATRIX_FB_GI_F16_MK8_8x8) {
    auto args = matrix_mul::get_benchmark_matmul_args();
    matrix_mul::benchmark_single_algo(
            handle(), args, dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
            "FB_GI_F16_MK8_8x8", param::MatrixMul::Format::MK8);
}
#endif
#endif #endif
} // namespace test } // namespace test
} // namespace megdnn } // namespace megdnn
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册