提交 0f1afb09 编写于 作者: M Megvii Engine Team

feat(fallback): imp gi matmul AlgoF32GiMK4_4x8 algo,

move AlgoF32GemvMK4 from arm_common to fallback

GitOrigin-RevId: 6c065abf999e87ed28123d02071a1571e545df83
上级 410dcb6c
...@@ -239,46 +239,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern(const KernSizeParam&) ...@@ -239,46 +239,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern(const KernSizeParam&)
return f32_gemv_kern; return f32_gemv_kern;
} }
/* ================== F32 Gemv MK4 algo ================== */
namespace {
void f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_arm_exec_fp32, midout_iv("f32_gemv_mk4_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_float32>(), Bptr = kern_param.B<dt_float32>();
auto Cptr = kern_param.C<dt_float32>();
gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
MIDOUT_END();
}
} // anonymous namespace
bool MatrixMulImpl::AlgoF32GemvMK4::usable(const KernSizeParam& kern_size_param) const {
// enumerate the M, N, K, only usable when preferred
auto M = kern_size_param.M;
auto N = kern_size_param.N;
auto K = kern_size_param.K;
auto LDB = kern_size_param.LDB;
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::MK4 &&
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA &&
!kern_size_param.trB && M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4;
}
bool MatrixMulImpl::AlgoF32GemvMK4::preferred(
const KernSizeParam& kern_size_param) const {
MEGDNN_MARK_USED_VAR(kern_size_param);
return true;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GemvMK4::get_kern(
const KernSizeParam&) const {
return f32_gemv_mk4_kern;
}
/* ===================== F32 Gevm algo ===================== */ /* ===================== F32 Gevm algo ===================== */
namespace { namespace {
template <typename stype, typename dtype> template <typename stype, typename dtype>
......
...@@ -95,22 +95,6 @@ public: ...@@ -95,22 +95,6 @@ public:
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT)
}; };
class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4)
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_F32_GEMV_MK4)
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class MatrixMulImpl::AlgoF16Gemv : public AlgoBase { class MatrixMulImpl::AlgoF16Gemv : public AlgoBase {
public: public:
......
...@@ -26,7 +26,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { ...@@ -26,7 +26,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot; AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot;
#endif #endif
AlgoGevm gevm; AlgoGevm gevm;
AlgoF32GemvMK4 f32_gemv_mk4;
SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos; SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos;
fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map;
...@@ -42,7 +41,6 @@ public: ...@@ -42,7 +41,6 @@ public:
#endif #endif
m_all_algos.emplace_back(&int8x8x32_gemv); m_all_algos.emplace_back(&int8x8x32_gemv);
m_all_algos.emplace_back(&int8x8x32_gemv_mk4); m_all_algos.emplace_back(&int8x8x32_gemv_mk4);
m_all_algos.emplace_back(&f32_gemv_mk4);
m_all_algos.emplace_back(&gevm); m_all_algos.emplace_back(&gevm);
for (auto&& algo : m_all_algos) { for (auto&& algo : m_all_algos) {
......
...@@ -34,7 +34,6 @@ public: ...@@ -34,7 +34,6 @@ public:
protected: protected:
class AlgoF32Gemv; // Arm_common F32 Gemv class AlgoF32Gemv; // Arm_common F32 Gemv
class AlgoF32GemvMK4; // Arm_common F32 Gemv NCHW44
class AlgoInt8x8x32Gemv; // Arm_common Int8x8x32 Gemv class AlgoInt8x8x32Gemv; // Arm_common Int8x8x32 Gemv
class AlgoInt8x8x32GemvMK4; // Arm_common Int8x8x32 Gemv NCHW44 class AlgoInt8x8x32GemvMK4; // Arm_common Int8x8x32 Gemv NCHW44
class AlgoGevm; // Arm_common Gevm(support int8 and fp32) class AlgoGevm; // Arm_common Gevm(support int8 and fp32)
......
...@@ -17,11 +17,15 @@ ...@@ -17,11 +17,15 @@
#include "src/naive/matrix_mul/matrix_mul_helper.h" #include "src/naive/matrix_mul/matrix_mul_helper.h"
#include "src/fallback/matrix_mul/gi/fp32/exec_sgemv.h"
#include "midout.h" #include "midout.h"
MIDOUT_DECL(megdnn_fb_matmul_f32_kern) MIDOUT_DECL(megdnn_fb_matmul_f32_kern)
MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like) MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like)
MIDOUT_DECL(megdnn_fb_matmul_naive) MIDOUT_DECL(megdnn_fb_matmul_naive)
MIDOUT_DECL(megdnn_fb_gi_exec_fp32)
MIDOUT_DECL(megdnn_fb_gi_matmul_kern)
using namespace megdnn; using namespace megdnn;
using namespace fallback; using namespace fallback;
...@@ -205,4 +209,99 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoNaive::get_kern(const KernSizeParam&) c ...@@ -205,4 +209,99 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoNaive::get_kern(const KernSizeParam&) c
return kern_naive; return kern_naive;
} }
/* ================== F32 Gemv MK4 gi algo ================== */
namespace {
void gi_f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_fb_gi_exec_fp32, midout_iv("f32_gemv_mk4_gi_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_float32>(), Bptr = kern_param.B<dt_float32>();
auto Cptr = kern_param.C<dt_float32>();
gi_gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
MIDOUT_END();
}
} // anonymous namespace
bool MatrixMulImpl::AlgoF32GiGemvMK4::usable(
const KernSizeParam& kern_size_param) const {
// enumerate the M, N, K, only usable when preferred
auto M = kern_size_param.M;
auto N = kern_size_param.N;
auto K = kern_size_param.K;
auto LDB = kern_size_param.LDB;
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::MK4 &&
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA &&
!kern_size_param.trB && M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4;
}
bool MatrixMulImpl::AlgoF32GiGemvMK4::preferred(
const KernSizeParam& kern_size_param) const {
MEGDNN_MARK_USED_VAR(kern_size_param);
return true;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiGemvMK4::get_kern(
const KernSizeParam&) const {
return gi_f32_gemv_mk4_kern;
}
/* ================== F32 Gemm MK4 gi algo ================== */
namespace {
void gi_f32_mk4_4x8_kern(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_fb_gi_matmul_kern, midout_iv("gi_f32_mk4_4x8_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto trA = kern_param.trA, trB = kern_param.trB;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
auto A_type = kern_param.A_type, B_type = kern_param.B_type,
C_type = kern_param.C_type;
const auto Aptr = kern_param.A<float>(), Bptr = kern_param.B<float>();
auto Cptr = kern_param.C<float>();
matmul::fallback::gi_sgemm_nopack_4x8 strategy(A_type, B_type, C_type);
megdnn::matmul::GemmInterleaved<matmul::fallback::gi_sgemm_nopack_4x8, false>(
M, N, K, trA, trB, strategy)
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
}
MIDOUT_END();
}
} // anonymous namespace
bool MatrixMulImpl::AlgoF32GiMK4_4x8::usable(
const KernSizeParam& kern_size_param) const {
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::MK4 &&
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA &&
!kern_size_param.trB;
}
size_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_workspace(
const KernSizeParam& kern_size_param) const {
MIDOUT_BEGIN(
megdnn_fb_gi_matmul_kern,
midout_iv("AlgoF32GiMK4_4x8::get_workspace"_hash)) {
auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
auto trA = kern_size_param.trA, trB = kern_size_param.trB;
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
C_type = kern_size_param.C_type;
matmul::fallback::gi_sgemm_nopack_4x8 strategy(A_type, B_type, C_type);
return megdnn::matmul::GemmInterleaved<
matmul::fallback::gi_sgemm_nopack_4x8, false>(
M, N, K, trA, trB, strategy)
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_kern(
const KernSizeParam&) const {
return gi_f32_mk4_4x8_kern;
}
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -80,6 +80,34 @@ public: ...@@ -80,6 +80,34 @@ public:
DEFAULT) DEFAULT)
}; };
class MatrixMulImpl::AlgoF32GiGemvMK4 : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "FB_GI_F32_GEMV_MK4"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4)
MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_GEMV_MK4)
};
class MatrixMulImpl::AlgoF32GiMK4_4x8 final : public AlgoBase {
public:
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "FB_GI_F32_MK4_4x8"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4)
MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_4x8)
};
} // namespace fallback } // namespace fallback
} // namespace megdnn } // namespace megdnn
......
...@@ -16,6 +16,8 @@ namespace matmul { ...@@ -16,6 +16,8 @@ namespace matmul {
namespace fallback { namespace fallback {
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12); MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12);
MEGDNN_REG_GEMM_STRATEGY_NOPACK(
float, float, float, 4, 8, 1, false, true, gi_sgemm_nopack_4x8);
} // namespace fallback } // namespace fallback
} // namespace matmul } // namespace matmul
......
/**
* \file dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2022 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/fallback/matrix_mul/gi/fp32/exec_sgemv.h"
#include "include/megdnn/oprs.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/general_intrinsic/gi_float.h"
#include "midout.h"
MIDOUT_DECL(megdnn_fp32_gi_sgemv)
using namespace megdnn;
using namespace fallback;
namespace {
void sgemv_gi_naive_n_mk4(
const float* __restrict A, const float* __restrict B, float* __restrict C,
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) {
constexpr size_t PACK_SIZE = 4;
megdnn_assert(
N == 1 && Bstride == PACK_SIZE && M % PACK_SIZE == 0 && K % PACK_SIZE == 0);
auto Aptr = A;
auto Cptr = C;
size_t m = 0;
while (m < M) {
auto Aptr0 = Aptr;
auto Cptr0 = Cptr;
GI_FLOAT32_t c[4];
#define INIT(step) c[step] = GiBroadcastFloat32(0.0f);
UNROLL_CALL_RAW(4, INIT)
#undef INIT
auto Bptr = B;
size_t k = 0;
while (k < K) {
GI_FLOAT32_t b = GiLoadFloat32(Bptr);
GI_FLOAT32_V2_t a[2];
#if defined(GI_TEST_NAIVE)
#define LOAD_A(step) \
a[step].val[0] = GiLoadFloat32(Aptr0 + step * 8); \
a[step].val[1] = GiLoadFloat32(Aptr0 + step * 8 + 4);
#elif defined(__arm__) || defined(__aarch64__)
#define LOAD_A(step) a[step] = vld1q_f32_x2(Aptr0 + step * 8);
#else
#define LOAD_A(step) \
a[step].val[0] = GiLoadFloat32(Aptr0 + step * 8); \
a[step].val[1] = GiLoadFloat32(Aptr0 + step * 8 + 4);
#endif
UNROLL_CALL_RAW(2, LOAD_A)
#undef LOAD_A
#define COMPT(step) \
c[step] = GiSimdFmaLane(c[step], a[step / 2].val[step % 2], b, step % 4);
UNROLL_CALL_RAW(4, COMPT)
#undef COMPT
Bptr += Bstride;
Aptr0 += PACK_SIZE * PACK_SIZE;
k += PACK_SIZE;
}
#define ADD_C(step, stride) c[step] = GiAddFloat32(c[step], c[step + stride]);
UNROLL_CALL_RAW(2, ADD_C, 2)
UNROLL_CALL_RAW(1, ADD_C, 1)
#undef ADD_C
GiStoreFloat32(Cptr0, c[0]);
Aptr += Astride;
Cptr += Cstride;
m += PACK_SIZE;
}
}
} // namespace
namespace megdnn {
namespace fallback {
void gi_gemv_like_mk4(
const float* __restrict A, const float* __restrict B, float* __restrict C,
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) {
megdnn_assert(N == 1 && Bstride == 4);
MIDOUT_BEGIN(megdnn_fp32_gi_sgemv, midout_iv("F32_GEMV_NCHW_GI_44_N"_hash)) {
return sgemv_gi_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride);
}
MIDOUT_END();
}
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2022 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <cstddef>
namespace megdnn {
namespace fallback {
void gi_gemv_like_mk4(
const float* __restrict A, const float* __restrict B, float* __restrict C,
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride);
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2022 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/common/utils.h"
#include "src/fallback/general_intrinsic/gi_float.h"
#include "src/fallback/matrix_mul/generic_strategy.h"
using namespace megdnn;
using namespace matmul::fallback;
namespace {
void kern_4x1(const float* A, const float* B, size_t LDB, size_t K, float* C) {
LDB = LDB - 4;
K = K - 4;
GI_FLOAT32_t d8d9 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d10d11 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d12d13 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d14d15 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d16d17 = GiBroadcastFloat32(0.0f);
GI_FLOAT32_t d18d19 = GiBroadcastFloat32(0.0f);
GI_FLOAT32_t d20d21 = GiBroadcastFloat32(0.0f);
GI_FLOAT32_t d22d23 = GiBroadcastFloat32(0.0f);
GI_FLOAT32_t d0d1 = GiLoadFloat32(B);
B = B + 4;
d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0);
d18d19 = GiSimdFmaLane(d18d19, d10d11, d0d1, 1);
for (; K > 0; K -= 4) {
d8d9 = GiLoadFloat32(A);
A = A + 4;
d10d11 = GiLoadFloat32(A);
A = A + 4;
d20d21 = GiSimdFmaLane(d20d21, d12d13, d0d1, 2);
d22d23 = GiSimdFmaLane(d22d23, d14d15, d0d1, 3);
B = B + LDB;
d0d1 = GiLoadFloat32(B);
B = B + 4;
d12d13 = GiLoadFloat32(A);
A = A + 4;
d14d15 = GiLoadFloat32(A);
A = A + 4;
d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0);
d18d19 = GiSimdFmaLane(d18d19, d10d11, d0d1, 1);
}
d20d21 = GiSimdFmaLane(d20d21, d12d13, d0d1, 2);
d22d23 = GiSimdFmaLane(d22d23, d14d15, d0d1, 3);
d16d17 = GiAddFloat32(d16d17, d20d21);
d18d19 = GiAddFloat32(d18d19, d22d23);
d16d17 = GiAddFloat32(d16d17, d18d19);
GiStoreFloat32(C, d16d17);
C = C + 4;
}
void kern_4x4(const float* A, const float* B, size_t LDB, size_t K, float* C) {
LDB = (LDB - 16);
K = K - 4;
GI_FLOAT32_t d8d9 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d10d11 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d12d13 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d14d15 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d0d1 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d2d3 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d4d5 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d6d7 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0);
GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0);
GI_FLOAT32_t d20d21 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0);
GI_FLOAT32_t d22d23 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0);
d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1);
d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1);
d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1);
d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1);
for (; K > 0; K -= 4) {
d8d9 = GiLoadFloat32(A);
A = A + 4;
d10d11 = GiLoadFloat32(A);
A = A + 4;
d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2);
d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2);
d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2);
d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2);
B = B + LDB;
d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3);
d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3);
d0d1 = GiLoadFloat32(B);
B = B + 4;
d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3);
d2d3 = GiLoadFloat32(B);
B = B + 4;
d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3);
d4d5 = GiLoadFloat32(B);
B = B + 4;
d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0);
d6d7 = GiLoadFloat32(B);
B = B + 4;
d18d19 = GiSimdFmaLane(d18d19, d8d9, d2d3, 0);
d20d21 = GiSimdFmaLane(d20d21, d8d9, d4d5, 0);
d22d23 = GiSimdFmaLane(d22d23, d8d9, d6d7, 0);
d12d13 = GiLoadFloat32(A);
A = A + 4;
d14d15 = GiLoadFloat32(A);
A = A + 4;
d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1);
d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1);
d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1);
d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1);
}
d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2);
d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2);
d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2);
d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2);
d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3);
d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3);
d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3);
d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3);
GiStoreFloat32(C, d16d17);
C = C + 4;
GiStoreFloat32(C, d18d19);
C = C + 4;
GiStoreFloat32(C, d20d21);
C = C + 4;
GiStoreFloat32(C, d22d23);
C = C + 4;
}
void kern_4x8(const float* A, const float* B, size_t LDB, size_t K, float* C) {
LDB -= 32;
GI_FLOAT32_t d8d9 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d10d11 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d12d13 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d14d15 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d0d1 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d2d3 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d4d5 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d6d7 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0);
d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1);
GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0);
d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2);
d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1);
d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3);
d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2);
d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3);
d0d1 = GiLoadFloat32(B);
B = B + 4;
d2d3 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d20d21 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0);
d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1);
GI_FLOAT32_t d22d23 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0);
d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2);
d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1);
d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3);
d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2);
d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3);
d4d5 = GiLoadFloat32(B);
B = B + 4;
d6d7 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d24d25 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0);
d24d25 = GiSimdFmaLane(d24d25, d10d11, d0d1, 1);
GI_FLOAT32_t d26d27 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0);
d24d25 = GiSimdFmaLane(d24d25, d12d13, d0d1, 2);
d26d27 = GiSimdFmaLane(d26d27, d10d11, d2d3, 1);
d24d25 = GiSimdFmaLane(d24d25, d14d15, d0d1, 3);
d26d27 = GiSimdFmaLane(d26d27, d12d13, d2d3, 2);
d26d27 = GiSimdFmaLane(d26d27, d14d15, d2d3, 3);
GI_FLOAT32_t d28d29 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0);
d28d29 = GiSimdFmaLane(d28d29, d10d11, d4d5, 1);
GI_FLOAT32_t d30d31 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0);
d28d29 = GiSimdFmaLane(d28d29, d12d13, d4d5, 2);
d30d31 = GiSimdFmaLane(d30d31, d10d11, d6d7, 1);
d28d29 = GiSimdFmaLane(d28d29, d14d15, d4d5, 3);
d30d31 = GiSimdFmaLane(d30d31, d12d13, d6d7, 2);
d30d31 = GiSimdFmaLane(d30d31, d14d15, d6d7, 3);
B = B + LDB;
K = K - 4;
for (; K > 0; K -= 4) {
d8d9 = GiLoadFloat32(A);
A = A + 4;
d10d11 = GiLoadFloat32(A);
A = A + 4;
d12d13 = GiLoadFloat32(A);
A = A + 4;
d14d15 = GiLoadFloat32(A);
A = A + 4;
d0d1 = GiLoadFloat32(B);
B = B + 4;
d2d3 = GiLoadFloat32(B);
B = B + 4;
d4d5 = GiLoadFloat32(B);
B = B + 4;
d6d7 = GiLoadFloat32(B);
B = B + 4;
d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0);
d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1);
d18d19 = GiSimdFmaLane(d18d19, d8d9, d2d3, 0);
d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2);
d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1);
d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3);
d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2);
d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3);
d0d1 = GiLoadFloat32(B);
B = B + 4;
d2d3 = GiLoadFloat32(B);
B = B + 4;
d20d21 = GiSimdFmaLane(d20d21, d8d9, d4d5, 0);
d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1);
d22d23 = GiSimdFmaLane(d22d23, d8d9, d6d7, 0);
d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2);
d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1);
d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3);
d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2);
d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3);
d4d5 = GiLoadFloat32(B);
B = B + 4;
d6d7 = GiLoadFloat32(B);
B = B + 4;
d24d25 = GiSimdFmaLane(d24d25, d8d9, d0d1, 0);
d24d25 = GiSimdFmaLane(d24d25, d10d11, d0d1, 1);
d26d27 = GiSimdFmaLane(d26d27, d8d9, d2d3, 0);
d24d25 = GiSimdFmaLane(d24d25, d12d13, d0d1, 2);
d26d27 = GiSimdFmaLane(d26d27, d10d11, d2d3, 1);
d24d25 = GiSimdFmaLane(d24d25, d14d15, d0d1, 3);
d26d27 = GiSimdFmaLane(d26d27, d12d13, d2d3, 2);
d26d27 = GiSimdFmaLane(d26d27, d14d15, d2d3, 3);
d28d29 = GiSimdFmaLane(d28d29, d8d9, d4d5, 0);
d28d29 = GiSimdFmaLane(d28d29, d10d11, d4d5, 1);
d30d31 = GiSimdFmaLane(d30d31, d8d9, d6d7, 0);
d28d29 = GiSimdFmaLane(d28d29, d12d13, d4d5, 2);
d30d31 = GiSimdFmaLane(d30d31, d10d11, d6d7, 1);
d28d29 = GiSimdFmaLane(d28d29, d14d15, d4d5, 3);
d30d31 = GiSimdFmaLane(d30d31, d12d13, d6d7, 2);
d30d31 = GiSimdFmaLane(d30d31, d14d15, d6d7, 3);
B = B + LDB;
}
GiStoreFloat32(C, d16d17);
C = C + 4;
GiStoreFloat32(C, d18d19);
C = C + 4;
GiStoreFloat32(C, d20d21);
C = C + 4;
GiStoreFloat32(C, d22d23);
C = C + 4;
GiStoreFloat32(C, d24d25);
C = C + 4;
GiStoreFloat32(C, d26d27);
C = C + 4;
GiStoreFloat32(C, d28d29);
C = C + 4;
GiStoreFloat32(C, d30d31);
C = C + 4;
}
} // namespace
MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gi_sgemm_nopack_4x8);
void gi_sgemm_nopack_4x8::kern(
const float* A, size_t LDA, const float* B, size_t LDB, float* C, size_t LDC,
size_t M, size_t K, size_t N, const float*, void*, bool trA, bool trB) const {
constexpr size_t MB = 4;
constexpr size_t KB = 4;
constexpr size_t NB = 8;
constexpr size_t NB_HALF = 4;
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0);
for (size_t m = 0; m < M; m += MB) {
float* output = C + (m / MB) * LDC;
const float* cur_B = B;
size_t n = 0;
for (; n + NB - 1 < N; n += NB) {
kern_4x8(A, cur_B, LDB, K, output);
cur_B += KB * NB;
output += MB * NB;
}
if (N - n >= 4) {
kern_4x4(A, cur_B, LDB, K, output);
cur_B += KB * NB_HALF;
output += MB * NB_HALF;
n += 4;
}
while (n < N) {
kern_4x1(A, cur_B, LDB, K, output);
cur_B += KB;
output += MB;
n++;
}
A += LDA;
}
}
// vim: syntax=cpp.doxygen
...@@ -36,6 +36,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { ...@@ -36,6 +36,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoF32K8x12x1 f32_k8x12x1; AlgoF32K8x12x1 f32_k8x12x1;
AlgoGemv gemv; AlgoGemv gemv;
AlgoNaive naive; AlgoNaive naive;
AlgoF32GiGemvMK4 f32_gemv_mk4;
AlgoF32GiMK4_4x8 f32_mk4_4x8;
SmallVector<AlgoBase*> m_all_algos; SmallVector<AlgoBase*> m_all_algos;
AlgoBase::Mapper m_all_algos_map; AlgoBase::Mapper m_all_algos_map;
...@@ -44,6 +46,8 @@ public: ...@@ -44,6 +46,8 @@ public:
m_all_algos.emplace_back(&gemv); m_all_algos.emplace_back(&gemv);
m_all_algos.emplace_back(&f32_k8x12x1); m_all_algos.emplace_back(&f32_k8x12x1);
m_all_algos.emplace_back(&naive); m_all_algos.emplace_back(&naive);
m_all_algos.emplace_back(&f32_gemv_mk4);
m_all_algos.emplace_back(&f32_mk4_4x8);
for (auto&& algo : m_all_algos) { for (auto&& algo : m_all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo); m_all_algos_map.emplace(algo->info().desc, algo);
} }
......
...@@ -112,6 +112,8 @@ public: ...@@ -112,6 +112,8 @@ public:
FB_F32K8x12x1 = 1 << 0, FB_F32K8x12x1 = 1 << 0,
FB_GEMV, FB_GEMV,
FB_NAIVE, FB_NAIVE,
FB_GI_F32_GEMV_MK4,
FB_GI_F32_MK4_4x8,
#if MEGDNN_X86 #if MEGDNN_X86
//! x86 //! x86
...@@ -131,7 +133,6 @@ public: ...@@ -131,7 +133,6 @@ public:
ARM_COMMON_INT8X8X32_GEMV, ARM_COMMON_INT8X8X32_GEMV,
ARM_COMMON_INT8X8X32_GEMV_MK4, ARM_COMMON_INT8X8X32_GEMV_MK4,
ARM_COMMON_INT8X8X32_GEMV_MK4_DOT, ARM_COMMON_INT8X8X32_GEMV_MK4_DOT,
ARM_COMMON_F32_GEMV_MK4,
ARM_COMMON_F16_GEMV, ARM_COMMON_F16_GEMV,
ARM_COMMON_GEVM, ARM_COMMON_GEVM,
#if MEGDNN_AARCH64 #if MEGDNN_AARCH64
...@@ -237,6 +238,8 @@ public: ...@@ -237,6 +238,8 @@ public:
private: private:
class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1
class AlgoF32GiGemvMK4; // fallback F32 gi Gemv NCHW44
class AlgoF32GiMK4_4x8; // fallback F32 gi Gemm NCHW44
class AlgoGemv; class AlgoGemv;
class AlgoNaive; class AlgoNaive;
class AlgoPack; class AlgoPack;
......
...@@ -45,6 +45,13 @@ TEST_F(FALLBACK, MATRIX_MUL) { ...@@ -45,6 +45,13 @@ TEST_F(FALLBACK, MATRIX_MUL) {
checker.execl({AL, BL, CL}); checker.execl({AL, BL, CL});
} }
} }
TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) {
matrix_mul::check_matrix_mul(
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
"FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1);
}
TEST_F(FALLBACK, MATRIX_MUL_RECORD) { TEST_F(FALLBACK, MATRIX_MUL_RECORD) {
TaskRecordChecker<MatrixMul> checker(1); TaskRecordChecker<MatrixMul> checker(1);
using Param = MatrixMul::Param; using Param = MatrixMul::Param;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册