Commit f249d387 authored by Megvii Engine Team

feat(fallback): impl gi matmul FB_GI_F32_4x12 algo

GitOrigin-RevId: 16255e7a728bf8cffbdc57094534600683af587a
Parent 03f78547
......@@ -138,8 +138,6 @@ public:
}
}
//! TODO: move arm_v7 MatrixMulImpl::AlgoF32 matmul to the gi fallback for the nchw
//! prefetch algo; this also needs an update to dnn/test/common/conv_bias.cpp:check_winograd
matmul_algos = static_cast<fallback::MatrixMulImpl*>(matmul_opr)
->select_algo_type(
{AlgoDataType::FLOAT32, MatmulFormat::DEFAULT});
......
......@@ -15,6 +15,7 @@ MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like)
MIDOUT_DECL(megdnn_fb_matmul_naive)
MIDOUT_DECL(megdnn_fb_gi_exec_fp32)
MIDOUT_DECL(megdnn_fb_gi_matmul_kern)
MIDOUT_DECL(megdnn_fb_gi_f32_4x12)
using namespace megdnn;
using namespace fallback;
......@@ -293,4 +294,61 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_kern(
const KernSizeParam&) const {
return gi_f32_mk4_4x8_kern;
}
/* ===================== F32 algo ===================== */
namespace {
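// f32_kern: kernel entry point. It packs A and B and drives the 4x12 GI
// strategy through the generic GemmInterleaved wrapper.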
void f32_kern(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_fb_gi_f32_4x12, midout_iv("f32_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto trA = kern_param.trA, trB = kern_param.trB;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
auto A_type = kern_param.A_type, B_type = kern_param.B_type,
C_type = kern_param.C_type;
const auto Aptr = kern_param.A<float>(), Bptr = kern_param.B<float>();
auto Cptr = kern_param.C<float>();
matmul::fallback::gi_sgemm_4x12 strategy(M, N, K, A_type, B_type, C_type);
megdnn::matmul::GemmInterleaved<matmul::fallback::gi_sgemm_4x12>(
M, N, K, trA, trB, strategy)
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
}
MIDOUT_END();
}
} // anonymous namespace
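// usable(): this algo only handles plain f32 GEMM -- default compute mode and
// format, with A, B and C all Float32.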
bool MatrixMulImpl::AlgoF32Gi4x12::usable(const KernSizeParam& kern_size_param) const {
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float32();
}
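// get_workspace(): the workspace holds the packed A/B panels; its size comes
// from the same GemmInterleaved wrapper the kernel uses.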
size_t MatrixMulImpl::AlgoF32Gi4x12::get_workspace(
const KernSizeParam& kern_size_param) const {
MIDOUT_BEGIN(
megdnn_fb_gi_f32_4x12, midout_iv("AlgoF32Gi4x12::get_workspace"_hash)) {
auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
auto trA = kern_size_param.trA, trB = kern_size_param.trB;
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
C_type = kern_size_param.C_type;
matmul::fallback::gi_sgemm_4x12 strategy(M, N, K, A_type, B_type, C_type);
return megdnn::matmul::GemmInterleaved<matmul::fallback::gi_sgemm_4x12>(
M, N, K, trA, trB, strategy)
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gi4x12::get_kern(
const KernSizeParam&) const {
return f32_kern;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
AlgoF32Gi4x12, megdnn_fb_gi_f32_4x12, "AlgoF32Gi4x12Impl"_hash,
matmul::fallback::gi_sgemm_4x12, float, float, AlgoDataType::FLOAT32, DEFAULT);
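// Note (assumed from the macro name): MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL
// additionally exposes gi_sgemm_4x12 to the im2col-based conv path.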
// vim: syntax=cpp.doxygen
......@@ -97,6 +97,17 @@ public:
MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_4x8)
};
class MatrixMulImpl::AlgoF32Gi4x12 final : public AlgoBase {
public:
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "FB_GI_F32_4x12"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_4x12)
};
} // namespace fallback
} // namespace megdnn
......
......@@ -8,6 +8,7 @@ namespace fallback {
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12);
MEGDNN_REG_GEMM_STRATEGY_NOPACK(
float, float, float, 4, 8, 1, false, true, gi_sgemm_nopack_4x8);
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, true, gi_sgemm_4x12);
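// Reading the arguments by the pattern of sgemm_8x12 above: gi_sgemm_4x12 is a
// packed f32 strategy with a 4 (M) x 12 (N) register tile and K unrolled by 1.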
} // namespace fallback
} // namespace matmul
......
#pragma once
#include "src/fallback/general_intrinsic/gi_float.h"
namespace megdnn {
namespace matmul {
namespace fallback {
/* ======================== transform ======================== */
/**
* interleave_INTERLEAVE_UNROLLK_BATCH_type
*
 * BATCH means processing BATCH * UNROLL_K columns at once; BATCH * sizeof(TYPE) *
 * UNROLL_K = 16 bytes (128 bits, one vector register).
*
 * the element traversal order is:
 * rep(j, 0, INTERLEAVE) rep(i, 0, UNROLL_K) *outptr++ = inptr[j, i]
*/
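/*
 * e.g. interleave_4x4_1_s (INTERLEAVE = 4, UNROLL_K = 4) with input rows
 *   r0: a0 a1 a2 a3, r1: b0 b1 b2 b3, r2: c0 c1 c2 c3, r3: d0 d1 d2 d3
 * packs the output contiguously as
 *   a0 a1 a2 a3 b0 b1 b2 b3 c0 c1 c2 c3 d0 d1 d2 d3
 */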
template <typename T>
static GI_FORCEINLINE void interleave_4x4_1_s(
const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
T*& outptr) {
static_assert(sizeof(T) == 4, "interleave_4x4_1_s only support sizeof(T) == 4");
GI_FLOAT32_t d0d1 = GiLoadFloat32(inptr0);
GI_FLOAT32_t d2d3 = GiLoadFloat32(inptr1);
GI_FLOAT32_t d4d5 = GiLoadFloat32(inptr2);
GI_FLOAT32_t d6d7 = GiLoadFloat32(inptr3);
inptr0 += 4;
inptr1 += 4;
inptr2 += 4;
inptr3 += 4;
GiStoreFloat32(outptr, d0d1);
outptr += 4;
GiStoreFloat32(outptr, d2d3);
outptr += 4;
GiStoreFloat32(outptr, d4d5);
outptr += 4;
GiStoreFloat32(outptr, d6d7);
outptr += 4;
}
template <typename T>
static GI_FORCEINLINE void interleave_4x12_1_s(
const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
T*& outptr) {
static_assert(sizeof(T) == 4, "interleave_4x12_1_s only support sizeof(T) == 4");
GI_FLOAT32_t d0d1 = GiLoadFloat32(inptr0);
inptr0 += 4;
GI_FLOAT32_t d2d3 = GiLoadFloat32(inptr0);
inptr0 += 4;
GI_FLOAT32_t d4d5 = GiLoadFloat32(inptr0);
inptr0 += 4;
GI_FLOAT32_t d6d7 = GiLoadFloat32(inptr1);
inptr1 += 4;
GI_FLOAT32_t d8d9 = GiLoadFloat32(inptr1);
inptr1 += 4;
GI_FLOAT32_t d10d11 = GiLoadFloat32(inptr1);
inptr1 += 4;
GI_FLOAT32_t d12d13 = GiLoadFloat32(inptr2);
inptr2 += 4;
GI_FLOAT32_t d14d15 = GiLoadFloat32(inptr2);
inptr2 += 4;
GI_FLOAT32_t d16d17 = GiLoadFloat32(inptr2);
inptr2 += 4;
GI_FLOAT32_t d18d19 = GiLoadFloat32(inptr3);
inptr3 += 4;
GI_FLOAT32_t d20d21 = GiLoadFloat32(inptr3);
inptr3 += 4;
GI_FLOAT32_t d22d23 = GiLoadFloat32(inptr3);
inptr3 += 4;
GiStoreFloat32(outptr, d0d1);
outptr += 4;
GiStoreFloat32(outptr, d2d3);
outptr += 4;
GiStoreFloat32(outptr, d4d5);
outptr += 4;
GiStoreFloat32(outptr, d6d7);
outptr += 4;
GiStoreFloat32(outptr, d8d9);
outptr += 4;
GiStoreFloat32(outptr, d10d11);
outptr += 4;
GiStoreFloat32(outptr, d12d13);
outptr += 4;
GiStoreFloat32(outptr, d14d15);
outptr += 4;
GiStoreFloat32(outptr, d16d17);
outptr += 4;
GiStoreFloat32(outptr, d18d19);
outptr += 4;
GiStoreFloat32(outptr, d20d21);
outptr += 4;
GiStoreFloat32(outptr, d22d23);
outptr += 4;
}
template <typename T>
static GI_FORCEINLINE void interleave_1x12_1_s(const T*& inptr0, T*& outptr) {
static_assert(sizeof(T) == 4, "interleave_1x12_1_s only support sizeof(T) == 4");
GI_FLOAT32_t d0d1 = GiLoadFloat32(inptr0);
inptr0 += 4;
GI_FLOAT32_t d2d3 = GiLoadFloat32(inptr0);
inptr0 += 4;
GI_FLOAT32_t d4d5 = GiLoadFloat32(inptr0);
inptr0 += 4;
GiStoreFloat32(outptr, d0d1);
outptr += 4;
GiStoreFloat32(outptr, d2d3);
outptr += 4;
GiStoreFloat32(outptr, d4d5);
outptr += 4;
}
template <typename T>
static GI_FORCEINLINE void interleave_1x4_1_s(const T*& inptr0, T*& outptr) {
static_assert(sizeof(T) == 4, "interleave_1x4_1_s only support sizeof(T) == 4");
GI_FLOAT32_t d0d1 = GiLoadFloat32(inptr0);
inptr0 += 4;
GiStoreFloat32(outptr, d0d1);
outptr += 4;
}
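// interleave_helper: copies the ksize valid elements of one row, then pads up
// to unroll_k with val (zero by default), so partial K-tails keep the fixed
// packed layout the kernels expect.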
template <typename T>
static GI_FORCEINLINE void interleave_helper(
const T*& inptr, T*& outptr, int unroll_k, int ksize, T val = 0) {
int k = 0;
for (; k < ksize; k++) {
*outptr++ = *inptr++;
}
for (; k < unroll_k; k++) {
*outptr++ = val;
}
}
template <typename T>
static GI_FORCEINLINE void interleave_1(
const T*& inptr0, T*& outptr, int unroll_k, int ksize, T val = 0) {
for (int k = 0; k < ksize; k += unroll_k) {
int size = std::min(unroll_k, ksize - k);
interleave_helper(inptr0, outptr, unroll_k, size, val);
}
}
template <typename T>
static GI_FORCEINLINE void interleave_4(
const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
T*& outptr, int unroll_k, int ksize, T val = 0) {
for (int k = 0; k < ksize; k += unroll_k) {
int size = std::min(unroll_k, ksize - k);
interleave_helper(inptr0, outptr, unroll_k, size, val);
interleave_helper(inptr1, outptr, unroll_k, size, val);
interleave_helper(inptr2, outptr, unroll_k, size, val);
interleave_helper(inptr3, outptr, unroll_k, size, val);
}
}
/* ======================== transpose pack B ======================== */
/**
* transpose_INTERLEAVE_UNROLLK_BATCH_type
*
 * BATCH means processing BATCH * INTERLEAVE columns at once; BATCH * sizeof(TYPE) *
 * INTERLEAVE = 16 bytes (128 bits, one vector register).
*
 * the element traversal order is:
 * rep(j, 0, INTERLEAVE) rep(i, 0, UNROLL_K) *outptr++ = inptr[i, j]
*/
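/*
 * e.g. transpose_4x4_1_s zips rows pairwise so that column j of the 4x4 input
 * tile becomes output row j:
 *   in : a0 a1 a2 a3 | b0 b1 b2 b3 | c0 c1 c2 c3 | d0 d1 d2 d3
 *   out: a0 b0 c0 d0 | a1 b1 c1 d1 | a2 b2 c2 d2 | a3 b3 c3 d3
 * where consecutive output rows start `stride` bytes apart.
 */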
template <typename T>
static GI_FORCEINLINE void transpose_4x4_1_s(
const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
T*& outptr, int stride = 16) {
static_assert(sizeof(T) == 4, "transpose_4x4_1_s only support sizeof(T) == 4");
stride = stride / sizeof(float);
stride -= 2;
GI_FLOAT32_t d0d1 = GiLoadFloat32(inptr0);
GI_FLOAT32_t d2d3 = GiLoadFloat32(inptr1);
GI_FLOAT32_t d4d5 = GiLoadFloat32(inptr2);
GI_FLOAT32_t d6d7 = GiLoadFloat32(inptr3);
inptr0 += 4;
inptr1 += 4;
inptr2 += 4;
inptr3 += 4;
GI_FLOAT32_V2_t q0q1 = GiZipqFloat32(d0d1, d2d3);
GI_FLOAT32_V2_t q2q3 = GiZipqFloat32(d4d5, d6d7);
GiSt1Float32(outptr, GiGetLowFloat32(q0q1.val[0]));
outptr += 2;
GiSt1Float32(outptr, GiGetLowFloat32(q2q3.val[0]));
outptr += stride;
GiSt1Float32(outptr, GiGetHighFloat32(q0q1.val[0]));
outptr += 2;
GiSt1Float32(outptr, GiGetHighFloat32(q2q3.val[0]));
outptr += stride;
GiSt1Float32(outptr, GiGetLowFloat32(q0q1.val[1]));
outptr += 2;
GiSt1Float32(outptr, GiGetLowFloat32(q2q3.val[1]));
outptr += stride;
GiSt1Float32(outptr, GiGetHighFloat32(q0q1.val[1]));
outptr += 2;
GiSt1Float32(outptr, GiGetHighFloat32(q2q3.val[1]));
outptr += stride;
}
} // namespace fallback
} // namespace matmul
} // namespace megdnn
// vim: syntax=cpp.doxygen
This diff is collapsed.
......@@ -28,16 +28,18 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoNaive naive;
AlgoF32GiGemvMK4 f32_gemv_mk4;
AlgoF32GiMK4_4x8 f32_mk4_4x8;
    AlgoF32Gi4x12 f32_4x12;
SmallVector<AlgoBase*> m_all_algos;
AlgoBase::Mapper m_all_algos_map;
public:
AlgoPack() {
m_all_algos.emplace_back(&f32_gemv_mk4);
m_all_algos.emplace_back(&f32_mk4_4x8);
        m_all_algos.emplace_back(&f32_4x12);
m_all_algos.emplace_back(&gemv);
m_all_algos.emplace_back(&f32_k8x12x1);
m_all_algos.emplace_back(&naive);
for (auto&& algo : m_all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
......
......@@ -103,6 +103,7 @@ public:
FB_NAIVE,
FB_GI_F32_GEMV_MK4,
FB_GI_F32_MK4_4x8,
FB_GI_F32_4x12,
#if MEGDNN_X86
//! x86
......@@ -232,6 +233,7 @@ private:
class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1
class AlgoF32GiGemvMK4; // fallback F32 gi Gemv NCHW44
class AlgoF32GiMK4_4x8; // fallback F32 gi Gemm NCHW44
class AlgoF32Gi4x12; // fallback F32 gi Gemm
class AlgoGemv;
class AlgoNaive;
class AlgoPack;
......
......@@ -42,6 +42,12 @@ TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) {
"FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1);
}
TEST_F(FALLBACK, MATRIX_MUL_GI_F32_4x12) {
matrix_mul::check_matrix_mul(
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
"FB_GI_F32_4x12");
}
TEST_F(FALLBACK, MATRIX_MUL_RECORD) {
TaskRecordChecker<MatrixMul> checker(1);
using Param = MatrixMul::Param;
......