提交 07d1d0ab 编写于 作者: M Megvii Engine Team

feat(dnn/arm64): add fp32 mk4 matmul

GitOrigin-RevId: f6df006547e08ba5b76be984a2fe87cf053c31de
上级 7ba641fe
...@@ -86,6 +86,67 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern, ...@@ -86,6 +86,67 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern,
"AlgoF32K8x12x1Impl"_hash, "AlgoF32K8x12x1Impl"_hash,
aarch64::matmul::sgemm_8x12, float, float); aarch64::matmul::sgemm_8x12, float, float);
/* ===================== F32_MK4_8X12X1 algo ===================== */
bool MatrixMulImpl::AlgoF32MK4_8x12x1::usable(
const KernSizeParam& kern_size_param) const {
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float32() &&
kern_size_param.format == param::MatrixMul::Format::MK4 &&
!kern_size_param.trA && !kern_size_param.trB &&
kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0;
}
size_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_workspace(
const KernSizeParam& kern_size_param) const {
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
midout_iv("AlgoF32MK4_8x12x1::get_workspace"_hash)) {
auto M = kern_size_param.M, N = kern_size_param.N,
K = kern_size_param.K;
auto trA = kern_size_param.trA, trB = kern_size_param.trB;
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
C_type = kern_size_param.C_type;
aarch64::matmul::sgemm_mk4_8x12 strategy(M, N, K, A_type, B_type,
C_type);
return megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_mk4_8x12>(
M, N, K, trA, trB, strategy)
.get_workspace_size();
}
MIDOUT_END();
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_kern(
const KernSizeParam&) const {
auto f32_kern_mk4_8x12 = [](const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
midout_iv("AlgoF32MK4_8x12x1::get_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto trA = kern_param.trA, trB = kern_param.trB;
auto LDA = kern_param.LDA, LDB = kern_param.LDB,
LDC = kern_param.LDC;
auto A_type = kern_param.A_type, B_type = kern_param.B_type,
C_type = kern_param.C_type;
const auto Aptr = kern_param.A<float>(),
Bptr = kern_param.B<float>();
auto Cptr = kern_param.C<float>();
aarch64::matmul::sgemm_mk4_8x12 strategy(M, N, K, A_type, B_type,
C_type);
megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_mk4_8x12>(
M, N, K, trA, trB, strategy)
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
kern_param.workspace_ptr);
}
MIDOUT_END();
};
return f32_kern_mk4_8x12;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4_8x12x1,
megdnn_aarch64_matmul_kern,
"AlgoF32MK4_8x12x1Impl"_hash,
aarch64::matmul::sgemm_mk4_8x12, float,
float);
/* ===================== F32K4X16X1 algo ===================== */ /* ===================== F32K4X16X1 algo ===================== */
bool MatrixMulImpl::AlgoF32K4x16x1::usable( bool MatrixMulImpl::AlgoF32K4x16x1::usable(
......
...@@ -29,6 +29,17 @@ public: ...@@ -29,6 +29,17 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };
class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase { class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase {
public: public:
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
......
...@@ -1103,6 +1103,36 @@ static inline void interleave_4x4_1_s(const T*& inptr0, const T*& inptr1, ...@@ -1103,6 +1103,36 @@ static inline void interleave_4x4_1_s(const T*& inptr0, const T*& inptr1,
: "v0", "v1", "v2", "v3", "cc", "memory"); : "v0", "v1", "v2", "v3", "cc", "memory");
} }
template <typename T>
static inline void interleave_2x4_4_s(const T*& inptr0, const T*& inptr1,
T* outptr) {
static_assert(sizeof(T) == 4, "interleave_2x4_4_s only support size == 4");
asm volatile(
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n"
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[inptr1]], #64\n"
"stp q0, q4, [%[outptr]]\n"
"stp q1, q5, [%[outptr], #32]\n"
"stp q2, q6, [%[outptr], #64]\n"
"stp q3, q7, [%[outptr], #96]\n"
: [ inptr0 ] "+r"(inptr0), [ inptr1 ] "+r"(inptr1),
[ outptr ] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "memory");
}
template <typename T>
static inline void interleave_1x4_4_s(const T*& inptr0, T* outptr) {
static_assert(sizeof(T) == 4, "interleave_1x4_4_s only support size == 4");
asm volatile(
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]]\n"
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "memory");
}
template <typename T> template <typename T>
static inline void interleave_4x8_2_b(const T*& inptr0, const T*& inptr1, static inline void interleave_4x8_2_b(const T*& inptr0, const T*& inptr1,
const T*& inptr2, const T*& inptr3, const T*& inptr2, const T*& inptr3,
...@@ -1479,6 +1509,41 @@ static inline void transpose_4x4_1_s(const T*& inptr0, const T*& inptr1, ...@@ -1479,6 +1509,41 @@ static inline void transpose_4x4_1_s(const T*& inptr0, const T*& inptr1,
"v11", "memory"); "v11", "memory");
} }
template <typename T>
static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) {
static_assert(sizeof(T) == 4,
"transpose_1x12_4_s only support sizeof(T) == 4");
asm volatile(
"ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n"
"ld4 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[inptr0]], #64\n"
"ld4 {v8.4s, v9.4s, v10.4s, v11.4s},[%[inptr0]], #64\n"
"stp q0, q4, [%[outptr]] \n"
"stp q8, q1, [%[outptr], #32] \n"
"stp q5, q9, [%[outptr], #64] \n"
"stp q2, q6, [%[outptr], #96] \n"
"stp q10, q3, [%[outptr], #128] \n"
"stp q7, q11, [%[outptr], #160] \n"
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "memory");
}
template <typename T>
static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) {
static_assert(sizeof(T) == 4,
"transpose_1x4_4_s only support sizeof(T) == 4");
asm volatile(
"ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]]\n"
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "memory");
}
template <typename T> template <typename T>
static inline void transpose_8x4_1_s(const T*& inptr0, const T*& inptr1, static inline void transpose_8x4_1_s(const T*& inptr0, const T*& inptr1,
const T*& inptr2, const T*& inptr3, const T*& inptr2, const T*& inptr3,
......
...@@ -899,6 +899,10 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output, ...@@ -899,6 +899,10 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output,
: :
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2",
"x3", "x10", "cc", "memory"); "x3", "x10", "cc", "memory");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
} }
void sgemm_8x12_pack_A_n(float* outptr, const float* inptr, int ldin, int y0, void sgemm_8x12_pack_A_n(float* outptr, const float* inptr, int ldin, int y0,
......
此差异已折叠。
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "src/aarch64/matrix_mul/fp32/strategy.h" #include "src/aarch64/matrix_mul/fp32/strategy.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h" #include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h" #include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h"
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h"
#include "src/common/utils.h" #include "src/common/utils.h"
using namespace megdnn; using namespace megdnn;
...@@ -163,4 +164,80 @@ void sgemm_8x12::kern(const float* packA, const float* packB, ...@@ -163,4 +164,80 @@ void sgemm_8x12::kern(const float* packA, const float* packB,
} }
} }
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_8x12);
void sgemm_mk4_8x12::pack_A(float* out, const float* in, int ldin, int y0,
int ymax, int k0, int kmax,
bool transpose_A) const {
megdnn_assert(!transpose_A, "mk4 float matmul not support transpose A");
matmul_mk4_8x12::sgemm_8x12_pack_A(out, in, ldin, y0, ymax, k0, kmax);
}
void sgemm_mk4_8x12::pack_B(float* out, const float* in, int ldin, int x0,
int xmax, int k0, int kmax,
bool transpose_B) const {
megdnn_assert(!transpose_B, "mk4 float matmul not support transpose B");
matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax);
}
void sgemm_mk4_8x12::kern(const float* packA, const float* packB,
size_t M, size_t N, size_t K, float* C, size_t LDC,
bool is_first_k, const float*, float*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4");
constexpr size_t PACK_C_SIZE = 4;
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t A_INTERLEAVE4 = 4;
constexpr size_t B_INTERLEAVE = 12;
const int K12 = K * 12;
const int K8 = K * 8;
const int K4 = K * 4;
size_t m = 0;
for (; m + A_INTERLEAVE <= M; m += A_INTERLEAVE) {
float* output = C + (m / PACK_C_SIZE * LDC);
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) {
matmul_mk4_8x12::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k);
output += B_INTERLEAVE * PACK_C_SIZE;
cur_packB += K12;
}
for (; n < N; n += 4) {
matmul_mk4_8x12::kern_8x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
output += 4 * PACK_C_SIZE;
cur_packB += K4;
}
packA += K8;
}
for (; m < M; m += A_INTERLEAVE4) {
float* output = C + (m / PACK_C_SIZE * LDC);
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_mk4_8x12::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k);
output += B_INTERLEAVE * PACK_C_SIZE;
cur_packB += K12;
}
for (; n < N; n += 4) {
matmul_mk4_8x12::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
output += 4 * PACK_C_SIZE;
cur_packB += K4;
}
packA += K4;
}
}
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -20,6 +20,9 @@ MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, ...@@ -20,6 +20,9 @@ MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true,
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true, MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true,
sgemm_4x16); sgemm_4x16);
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, false,
sgemm_mk4_8x12);
MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 16, 1, false, true, MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 16, 1, false, true,
sgemm_nopack_4x16); sgemm_nopack_4x16);
......
...@@ -18,6 +18,7 @@ using namespace aarch64; ...@@ -18,6 +18,7 @@ using namespace aarch64;
class MatrixMulImpl::AlgoPack : NonCopyableObj { class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoF32K8x12x1 f32K8x12x1; AlgoF32K8x12x1 f32K8x12x1;
AlgoF32MK4_8x12x1 f32_mk4_8x12x1;
AlgoF32K4x16x1 f32k4x16x1; AlgoF32K4x16x1 f32k4x16x1;
AlgoF32MK4_4x16 f32mk4_4x16; AlgoF32MK4_4x16 f32mk4_4x16;
AlgoF32Gemv f32_gemv; AlgoF32Gemv f32_gemv;
...@@ -53,6 +54,7 @@ public: ...@@ -53,6 +54,7 @@ public:
AlgoPack() { AlgoPack() {
all_algos.emplace_back(&f32_gemv); all_algos.emplace_back(&f32_gemv);
all_algos.emplace_back(&f32K8x12x1); all_algos.emplace_back(&f32K8x12x1);
all_algos.emplace_back(&f32_mk4_8x12x1);
all_algos.emplace_back(&f32k4x16x1); all_algos.emplace_back(&f32k4x16x1);
all_algos.emplace_back(&f32mk4_4x16); all_algos.emplace_back(&f32mk4_4x16);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
......
...@@ -22,6 +22,7 @@ public: ...@@ -22,6 +22,7 @@ public:
private: private:
class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1 class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1
class AlgoF32MK4_8x12x1; // Aarch64 F32 Kernel MK4 8x12x1
class AlgoF32K4x16x1; // Aarch64 F32 Kernel 4x16x1 class AlgoF32K4x16x1; // Aarch64 F32 Kernel 4x16x1
class AlgoF32MK4_4x16; // Aarch64 F32 Format MK4 block 16x4 class AlgoF32MK4_4x16; // Aarch64 F32 Format MK4 block 16x4
class AlgoF32Gemv; // Aarch64 F32 Gemv class AlgoF32Gemv; // Aarch64 F32 Gemv
......
...@@ -244,6 +244,7 @@ bool ConvBiasImpl::AlgoFP16WinogradF23_8x8::usable( ...@@ -244,6 +244,7 @@ bool ConvBiasImpl::AlgoFP16WinogradF23_8x8::usable(
if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0) if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0)
return false; return false;
using Strategy = winograd::winograd_2x3_8x8_f16; using Strategy = winograd::winograd_2x3_8x8_f16;
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode;
Strategy strategy(param.src_type, param.filter_type, param.dst_type); Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param = auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy, megdnn::winograd::ConvBias<Strategy,
...@@ -252,6 +253,7 @@ bool ConvBiasImpl::AlgoFP16WinogradF23_8x8::usable( ...@@ -252,6 +253,7 @@ bool ConvBiasImpl::AlgoFP16WinogradF23_8x8::usable(
param.osz[1], param.filter_meta.ocpg) param.osz[1], param.filter_meta.ocpg)
.get_matmul_kern_param(param); .get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) && return m_matmul_algo->usable(matmul_param) &&
m_matmul_algo->packmode() == PackMode::NO_PACK &&
(opr->param().format == param::ConvBias::Format::NCHW || (opr->param().format == param::ConvBias::Format::NCHW ||
(opr->param().format == (opr->param().format ==
param::ConvBias::Format::NCHW_WINOGRAD && param::ConvBias::Format::NCHW_WINOGRAD &&
......
...@@ -38,6 +38,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable( ...@@ -38,6 +38,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable(
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0)
return false; return false;
using Strategy = winograd::winograd_2x3_4x4_f; using Strategy = winograd::winograd_2x3_4x4_f;
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode;
Strategy strategy(param.src_type, param.filter_type, param.dst_type); Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param = auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy, megdnn::winograd::ConvBias<Strategy,
...@@ -46,6 +47,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable( ...@@ -46,6 +47,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable(
param.osz[1], param.filter_meta.ocpg) param.osz[1], param.filter_meta.ocpg)
.get_matmul_kern_param(param); .get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) && return m_matmul_algo->usable(matmul_param) &&
m_matmul_algo->packmode() == PackMode::NO_PACK &&
(opr->param().format == param::ConvBias::Format::NCHW || (opr->param().format == param::ConvBias::Format::NCHW ||
(opr->param().format == (opr->param().format ==
param::ConvBias::Format::NCHW_WINOGRAD && param::ConvBias::Format::NCHW_WINOGRAD &&
...@@ -319,6 +321,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable( ...@@ -319,6 +321,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable(
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0)
return false; return false;
using Strategy = winograd::winograd_6x3_4x4_f; using Strategy = winograd::winograd_6x3_4x4_f;
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode;
Strategy strategy(param.src_type, param.filter_type, param.dst_type); Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param = auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy, megdnn::winograd::ConvBias<Strategy,
...@@ -327,6 +330,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable( ...@@ -327,6 +330,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable(
param.osz[1], param.filter_meta.ocpg) param.osz[1], param.filter_meta.ocpg)
.get_matmul_kern_param(param); .get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) && return m_matmul_algo->usable(matmul_param) &&
m_matmul_algo->packmode() == PackMode::NO_PACK &&
(opr->param().format == param::ConvBias::Format::NCHW || (opr->param().format == param::ConvBias::Format::NCHW ||
(opr->param().format == (opr->param().format ==
param::ConvBias::Format::NCHW_WINOGRAD && param::ConvBias::Format::NCHW_WINOGRAD &&
......
...@@ -217,6 +217,7 @@ bool ConvBiasImpl::AlgoS8WinogradF23_8x8::usable( ...@@ -217,6 +217,7 @@ bool ConvBiasImpl::AlgoS8WinogradF23_8x8::usable(
if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0) if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0)
return false; return false;
using Strategy = winograd::winograd_2x3_8x8_s8; using Strategy = winograd::winograd_2x3_8x8_s8;
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode;
Strategy strategy(param.src_type, param.filter_type, param.dst_type); Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param = auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy, param::MatrixMul::Format::MK8>( megdnn::winograd::ConvBias<Strategy, param::MatrixMul::Format::MK8>(
...@@ -224,6 +225,7 @@ bool ConvBiasImpl::AlgoS8WinogradF23_8x8::usable( ...@@ -224,6 +225,7 @@ bool ConvBiasImpl::AlgoS8WinogradF23_8x8::usable(
param.osz[1], param.filter_meta.ocpg) param.osz[1], param.filter_meta.ocpg)
.get_matmul_kern_param(param); .get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) && return m_matmul_algo->usable(matmul_param) &&
m_matmul_algo->packmode() == PackMode::NO_PACK &&
((opr->param().format == param::ConvBias::Format::NCHW && ((opr->param().format == param::ConvBias::Format::NCHW &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8) || param.filter_type.enumv() == DTypeEnum::QuantizedS8) ||
(opr->param().format == param::ConvBias::Format::NCHW_WINOGRAD && (opr->param().format == param::ConvBias::Format::NCHW_WINOGRAD &&
......
...@@ -31,6 +31,12 @@ TEST_F(AARCH64, MATRIX_MUL_FP32K4X16) { ...@@ -31,6 +31,12 @@ TEST_F(AARCH64, MATRIX_MUL_FP32K4X16) {
"AARCH64_F32K4X16X1"); "AARCH64_F32K4X16X1");
} }
TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4) {
matrix_mul::check_matrix_mul(
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
"AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4, 1);
}
TEST_F(AARCH64, MATRIX_MUL_FP32_MK4) { TEST_F(AARCH64, MATRIX_MUL_FP32_MK4) {
//! nbase should be 4 in order to test the last rest 4 in N dim //! nbase should be 4 in order to test the last rest 4 in N dim
matrix_mul::check_matrix_mul( matrix_mul::check_matrix_mul(
...@@ -527,6 +533,15 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_MK4) { ...@@ -527,6 +533,15 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_MK4) {
dtype::Float32{}); dtype::Float32{});
} }
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_PACK_MK4) {
auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(16);
matrix_mul::benchmark_with_contrast(
handle(), args, dtype::Float32{}, dtype::Float32{},
dtype::Float32{}, "AARCH64_F32_MK4_K8X12X1",
param::MatrixMul::Format::MK4, dtype::Float32{}, dtype::Float32{},
dtype::Float32{}, "AARCH64_F32K8X12X1");
}
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16x16x32_MK8) { TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16x16x32_MK8) {
auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8); auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8);
matrix_mul::benchmark_with_contrast( matrix_mul::benchmark_with_contrast(
......
...@@ -40,8 +40,8 @@ std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_mk_packed_args( ...@@ -40,8 +40,8 @@ std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_mk_packed_args(
size_t nbase) { size_t nbase) {
std::vector<TestArg> args; std::vector<TestArg> args;
for (size_t m : {1, 2, 3, 4, 5}) for (size_t m : {1, 2, 3, 4, 5})
for (size_t n : {1, 2, 3, 4, 5, 8, 16, 24}) for (size_t n : {1, 2, 3, 4, 5, 8, 12, 16, 24})
for (size_t k : {1, 2, 3, 4, 5}) for (size_t k : {1, 2, 3, 4, 5, 9, 10})
args.emplace_back(m, n * nbase, k, 0); args.emplace_back(m, n * nbase, k, 0);
return args; return args;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册