提交 5e306b75 编写于 作者: M Megvii Engine Team

feat(x86): make conv1x1 and im2col available on with x86-NCHW44

add AlgoF32GiMK4Pack4x12 matrix_mul algo

GitOrigin-RevId: 47cfe1d733d80c4c8ca8f4a8fa2d8469e530da75
上级 481a6cbb
......@@ -197,10 +197,17 @@ bool ConvBiasImpl::AlgoConv1x1::usable(
return false;
}
}
#else //! x86 only support nchw mode
if (format != param::ConvBias::Format::NCHW) {
#else //! x86 and RISC-V do not support NCHW44_DOT
if (format != param::ConvBias::Format::NCHW &&
format != param::ConvBias::Format::NCHW44) {
return false;
}
//! hybird mode is not support
if (param.filter_meta.format == param::ConvBias::Format::NCHW44) {
if (param.filter_meta.icpg < 4_z || param.filter_meta.ocpg == 1) {
return false;
}
}
#endif
//! param
if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1) {
......
......@@ -345,8 +345,20 @@ bool ConvBiasImpl::AlgoIm2col::usable(
}
}
#else
if (format != param::ConvBias::Format::NCHW) {
if (format != param::ConvBias::Format::NCHW &&
format != param::ConvBias::Format::NCHW44) {
return false;
}
if (format == param::ConvBias::Format::NCHW44) {
//! current NCHW44 im2col only support DEFAULT mode matmul
if (matmul_desc.packmode != Pack_Mode::DEFAULT) {
return false;
//! nchw44 hybird mode and channel wise is not support
} else if (
param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 ||
param.filter_meta.ocpg == 1) {
return false;
}
}
#endif
if (param.src_type.enumv() != param.filter_type.enumv() ||
......
......@@ -216,10 +216,9 @@ public:
cb1(NCHW, DEFAULT, dt_float32, dt_float32, PostprocessMode::FLOAT,
"DefaultStrategyType::FLOAT"_hash);
} else if (format == param::ConvBias::Format::NCHW44) {
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
auto matmul_block = matmul_algo->get_inner_block_size();
//! Optimize NCHW44 3x3s2 aarch64 8X12X1 and armv7 4x12x1
//! im2col+pack fuse
//! Optimize NCHW44 3x3s2 on aarch64 8X12X4 and fallback/armv7
//! 4x12x4 im2col+pack fuse
if ((matmul_block.m == 8 || matmul_block.m == 4) &&
matmul_block.n == 12 && matmul_block.k == 1 &&
param.filter_meta.spatial[0] == 3 &&
......@@ -236,7 +235,6 @@ public:
MIDOUT_END();
return {};
}
#endif
cb1(NCHW44, DEFAULT, dt_float32, dt_float32, PostprocessMode::FLOAT,
"DefaultStrategyTypeNCHW44::FLOAT"_hash);
......
......@@ -530,7 +530,6 @@ public:
};
#endif
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
template <
typename op_ctype, typename op_dtype, megdnn::PostprocessMode postprocess_mode>
class StrategyFuseXx12x1Nchw44K3x3S2
......@@ -553,7 +552,6 @@ public:
fallback::MatrixMulImpl::KernParam matmul_param,
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
};
#endif
} // namespace megdnn
// vim: syntax=cpp.doxygen
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
#include <arm_neon.h>
#include "src/fallback/general_intrinsic/gi_float.h"
using namespace megdnn;
......@@ -11,32 +10,32 @@ namespace {
int out_index = 0; \
outptr = output_base; \
for (; out_index + 11 < block_size; out_index += 12) { \
float32x4x4_t v0 = vld4q_f32(tmp_output); \
float32x4x4_t v1 = vld4q_f32(tmp_output + 16); \
float32x4x4_t v2 = vld4q_f32(tmp_output + 32); \
vst1q_f32(outptr, v0.val[0]); \
vst1q_f32(outptr + 4, v1.val[0]); \
vst1q_f32(outptr + 8, v2.val[0]); \
vst1q_f32(outptr + 12, v0.val[1]); \
vst1q_f32(outptr + 16, v1.val[1]); \
vst1q_f32(outptr + 20, v2.val[1]); \
vst1q_f32(outptr + 24, v0.val[2]); \
vst1q_f32(outptr + 28, v1.val[2]); \
vst1q_f32(outptr + 32, v2.val[2]); \
vst1q_f32(outptr + 36, v0.val[3]); \
vst1q_f32(outptr + 40, v1.val[3]); \
vst1q_f32(outptr + 44, v2.val[3]); \
GI_FLOAT32_V4_t v0 = GiLoadUzipFloat32V4(tmp_output); \
GI_FLOAT32_V4_t v1 = GiLoadUzipFloat32V4(tmp_output + 16); \
GI_FLOAT32_V4_t v2 = GiLoadUzipFloat32V4(tmp_output + 32); \
GiStoreFloat32(outptr, GiGetSubVectorFloat32V4(v0, 0)); \
GiStoreFloat32(outptr + 4, GiGetSubVectorFloat32V4(v1, 0)); \
GiStoreFloat32(outptr + 8, GiGetSubVectorFloat32V4(v2, 0)); \
GiStoreFloat32(outptr + 12, GiGetSubVectorFloat32V4(v0, 1)); \
GiStoreFloat32(outptr + 16, GiGetSubVectorFloat32V4(v1, 1)); \
GiStoreFloat32(outptr + 20, GiGetSubVectorFloat32V4(v2, 1)); \
GiStoreFloat32(outptr + 24, GiGetSubVectorFloat32V4(v0, 2)); \
GiStoreFloat32(outptr + 28, GiGetSubVectorFloat32V4(v1, 2)); \
GiStoreFloat32(outptr + 32, GiGetSubVectorFloat32V4(v2, 2)); \
GiStoreFloat32(outptr + 36, GiGetSubVectorFloat32V4(v0, 3)); \
GiStoreFloat32(outptr + 40, GiGetSubVectorFloat32V4(v1, 3)); \
GiStoreFloat32(outptr + 44, GiGetSubVectorFloat32V4(v2, 3)); \
outptr += ksize12; \
tmp_output += 48; \
} \
\
outptr = output_base4; \
for (; out_index + 3 < block_size; out_index += 4) { \
float32x4x4_t v0 = vld4q_f32(tmp_output); \
vst1q_f32(outptr, v0.val[0]); \
vst1q_f32(outptr + 4, v0.val[1]); \
vst1q_f32(outptr + 8, v0.val[2]); \
vst1q_f32(outptr + 12, v0.val[3]); \
GI_FLOAT32_V4_t v0 = GiLoadUzipFloat32V4(tmp_output); \
GiStoreFloat32(outptr, GiGetSubVectorFloat32V4(v0, 0)); \
GiStoreFloat32(outptr + 4, GiGetSubVectorFloat32V4(v0, 1)); \
GiStoreFloat32(outptr + 8, GiGetSubVectorFloat32V4(v0, 2)); \
GiStoreFloat32(outptr + 12, GiGetSubVectorFloat32V4(v0, 3)); \
outptr += ksize4; \
tmp_output += 16; \
} \
......@@ -45,21 +44,21 @@ namespace {
float zerobuffer[16] = {0}; \
size_t out_remain = std::min(block_size - out_index, 4); \
std::memcpy(zerobuffer, tmp_output, out_remain * sizeof(float) * 4); \
float32x4x4_t v0 = vld4q_f32(zerobuffer); \
vst1q_f32(outptr, v0.val[0]); \
vst1q_f32(outptr + 4, v0.val[1]); \
vst1q_f32(outptr + 8, v0.val[2]); \
vst1q_f32(outptr + 12, v0.val[3]); \
GI_FLOAT32_V4_t v0 = GiLoadUzipFloat32V4(zerobuffer); \
GiStoreFloat32(outptr, GiGetSubVectorFloat32V4(v0, 0)); \
GiStoreFloat32(outptr + 4, GiGetSubVectorFloat32V4(v0, 1)); \
GiStoreFloat32(outptr + 8, GiGetSubVectorFloat32V4(v0, 2)); \
GiStoreFloat32(outptr + 12, GiGetSubVectorFloat32V4(v0, 3)); \
} \
output_base += 48; \
output_base4 += 16;
#define LOAD_AND_STOR_IM2COL_DST() \
float32x4_t v1 = vld1q_f32(&src[index + 4]); \
float32x4_t v2 = vld1q_f32(&src[index + 8]); \
vst1q_f32(&output0[i], v0); \
vst1q_f32(&output1[i], v1); \
vst1q_f32(&output2[i], v2); \
GI_FLOAT32_t v1 = GiLoadFloat32(&src[index + 4]); \
GI_FLOAT32_t v2 = GiLoadFloat32(&src[index + 8]); \
GiStoreFloat32(&output0[i], v0); \
GiStoreFloat32(&output1[i], v1); \
GiStoreFloat32(&output2[i], v2); \
i += 4; \
index += 8; \
v0 = v2;
......@@ -94,12 +93,12 @@ void fuse_packb(
size_t index = 4 * (ic * IH * IW + (start_h * SH + fh) * IW +
cur_remain_w * SW);
for (int w = cur_remain_w; w < end_remain_w; w++) {
vst1q_f32(&output02[i], vld1q_f32(&src[index]));
vst1q_f32(&output1[i], vld1q_f32(&src[index + 4]));
GiStoreFloat32(&output02[i], GiLoadFloat32(&src[index]));
GiStoreFloat32(&output1[i], GiLoadFloat32(&src[index + 4]));
i += 4;
index += 8;
}
vst1q_f32(&output02[i], vld1q_f32(&src[index]));
GiStoreFloat32(&output02[i], GiLoadFloat32(&src[index]));
float* output[3];
output[0] = output02;
output[1] = output1;
......@@ -120,19 +119,19 @@ void fuse_packb(
size_t index = 4 * (ic * IH * IW + (start_h * SH + fh) * IW +
(cur_remain_w * SW));
float32x4_t v0 = vld1q_f32(&src[index]);
GI_FLOAT32_t v0 = GiLoadFloat32(&src[index]);
for (int w = cur_remain_w; w < OW; w++) {
LOAD_AND_STOR_IM2COL_DST();
}
for (int h = start_h + 1; h < end_h; h++) {
size_t index = 4 * (ic * IH * IW + (h * SH + fh) * IW);
v0 = vld1q_f32(&src[index]);
v0 = GiLoadFloat32(&src[index]);
rep(ow, OW) { LOAD_AND_STOR_IM2COL_DST(); }
}
index = 4 * (ic * IH * IW + (end_h * SH + fh) * IW);
v0 = vld1q_f32(&src[index]);
v0 = GiLoadFloat32(&src[index]);
for (int w = 0; w < end_remain_w; w++) {
LOAD_AND_STOR_IM2COL_DST();
}
......@@ -190,6 +189,4 @@ template class StrategyFuseXx12x1Nchw44K3x3S2<
float, float, megdnn::PostprocessMode::FLOAT>;
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
......@@ -57,7 +57,7 @@ void kern_naive(const MatrixMulImpl::KernParam& kern_param) {
size_t pack_size = get_pack_size();
megdnn_assert(
(M % pack_size == 0 && K % pack_size == 0),
"M and N must time of pack_size M: %zu N: %zu pack_size: %zu", M, N,
"M and K must time of pack_size M: %zu K: %zu pack_size: %zu", M, N,
pack_size);
#define DISPATCH(TA, TB) \
......@@ -263,12 +263,15 @@ void gi_f32_mk4_4x8_kern(const MatrixMulImpl::KernParam& kern_param) {
} // anonymous namespace
bool MatrixMulImpl::AlgoF32GiMK4_4x8::usable(
const KernSizeParam& kern_size_param) const {
constexpr size_t MB = 4;
constexpr size_t KB = 4;
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::MK4 &&
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA &&
!kern_size_param.trB;
!kern_size_param.trB && kern_size_param.M % MB == 0 &&
kern_size_param.K % KB == 0;
}
size_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_workspace(
......@@ -295,6 +298,71 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_kern(
return gi_f32_mk4_4x8_kern;
}
/* ===================== F32 algo gi mk4 pack K4x12 ===================== */
namespace {
void f32_gi_mk4_pack_4x12_kern(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(
megdnn_fb_gi_matmul_kern, midout_iv("f32_gi_mk4_pack_4x12_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto trA = kern_param.trA, trB = kern_param.trB;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
auto A_type = kern_param.A_type, B_type = kern_param.B_type,
C_type = kern_param.C_type;
const auto Aptr = kern_param.A<float>(), Bptr = kern_param.B<float>();
auto Cptr = kern_param.C<float>();
matmul::fallback::gi_sgemm_mk4_pack_4x12 strategy(
M, N, K, A_type, B_type, C_type);
megdnn::matmul::GemmInterleaved<matmul::fallback::gi_sgemm_mk4_pack_4x12>(
M, N, K, trA, trB, strategy)
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
}
MIDOUT_END();
}
} // anonymous namespace
bool MatrixMulImpl::AlgoF32GiMK4Pack4x12::usable(
const KernSizeParam& kern_size_param) const {
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::MK4 &&
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA &&
!kern_size_param.trB && kern_size_param.M % 4 == 0 &&
kern_size_param.K % 4 == 0 && !kern_size_param.trA && !kern_size_param.trB;
}
size_t MatrixMulImpl::AlgoF32GiMK4Pack4x12::get_workspace(
const KernSizeParam& kern_size_param) const {
MIDOUT_BEGIN(
megdnn_fb_gi_matmul_kern,
midout_iv("AlgoF32GiMK4Pack4x12::get_workspace"_hash)) {
auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
auto trA = kern_size_param.trA, trB = kern_size_param.trB;
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
C_type = kern_size_param.C_type;
matmul::fallback::gi_sgemm_mk4_pack_4x12 strategy(
M, N, K, A_type, B_type, C_type);
return megdnn::matmul::GemmInterleaved<
matmul::fallback::gi_sgemm_mk4_pack_4x12>(
M, N, K, trA, trB, strategy)
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4Pack4x12::get_kern(
const KernSizeParam&) const {
return f32_gi_mk4_pack_4x12_kern;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
AlgoF32GiMK4Pack4x12, megdnn_fb_gi_matmul_kern, "AlgoF32GiMK4Pack4x12"_hash,
matmul::fallback::gi_sgemm_mk4_pack_4x12, float, float, AlgoDataType::FLOAT32,
MK4);
/* ===================== F32 algo ===================== */
namespace {
void f32_kern(const MatrixMulImpl::KernParam& kern_param) {
......
......@@ -97,6 +97,19 @@ public:
MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_4x8)
};
class MatrixMulImpl::AlgoF32GiMK4Pack4x12 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "FB_GI_F32_MK4_PACK_4x12"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_PACK_4x12)
};
class MatrixMulImpl::AlgoF32Gi4x12 final : public AlgoBase {
public:
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
......
......@@ -9,6 +9,8 @@ MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12)
MEGDNN_REG_GEMM_STRATEGY_NOPACK(
float, float, float, 4, 8, 1, false, true, gi_sgemm_nopack_4x8);
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, true, gi_sgemm_4x12);
MEGDNN_REG_GEMM_STRATEGY(
float, float, float, 4, 12, 1, false, false, gi_sgemm_mk4_pack_4x12);
} // namespace fallback
} // namespace matmul
......
......@@ -214,6 +214,113 @@ static GI_FORCEINLINE void transpose_4x4_1_s(
outptr += stride;
}
template <typename T>
static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) {
static_assert(sizeof(T) == 4, "transpose_1x12_4_s only support sizeof(T) == 4");
GI_FLOAT32_t tmp_a, tmp_b;
#define LOAD() \
tmp_a = GiLoadFloat32(inptr0); \
inptr0 += 4; \
tmp_b = GiLoadFloat32(inptr0); \
inptr0 += 4;
LOAD();
GI_FLOAT32_V2_t d0d1d2d3 = GiZipqFloat32(tmp_a, tmp_b);
LOAD();
GI_FLOAT32_V2_t d4d5d6d7 = GiZipqFloat32(tmp_a, tmp_b);
LOAD();
GI_FLOAT32_V2_t d8d9d10d11 = GiZipqFloat32(tmp_a, tmp_b);
LOAD();
GI_FLOAT32_V2_t d12d13d14d15 = GiZipqFloat32(tmp_a, tmp_b);
LOAD();
GI_FLOAT32_V2_t d16d17d18d19 = GiZipqFloat32(tmp_a, tmp_b);
LOAD();
GI_FLOAT32_V2_t d20d21d22d23 = GiZipqFloat32(tmp_a, tmp_b);
#undef LOAD
GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 0)));
GiSt1Float32(outptr + 1 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 0)));
GiSt1Float32(
outptr + 2 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d8d9d10d11, 0)));
GiSt1Float32(
outptr + 3 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d12d13d14d15, 0)));
GiSt1Float32(
outptr + 4 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d16d17d18d19, 0)));
GiSt1Float32(
outptr + 5 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d20d21d22d23, 0)));
GiSt1Float32(
outptr + 6 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 0)));
GiSt1Float32(
outptr + 7 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 0)));
GiSt1Float32(
outptr + 8 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d8d9d10d11, 0)));
GiSt1Float32(
outptr + 9 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d12d13d14d15, 0)));
GiSt1Float32(
outptr + 10 * 2,
GiGetHighFloat32(GiGetSubVectorFloat32V2(d16d17d18d19, 0)));
GiSt1Float32(
outptr + 11 * 2,
GiGetHighFloat32(GiGetSubVectorFloat32V2(d20d21d22d23, 0)));
GiSt1Float32(
outptr + 12 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 1)));
GiSt1Float32(
outptr + 13 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 1)));
GiSt1Float32(
outptr + 14 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d8d9d10d11, 1)));
GiSt1Float32(
outptr + 15 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d12d13d14d15, 1)));
GiSt1Float32(
outptr + 16 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d16d17d18d19, 1)));
GiSt1Float32(
outptr + 17 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d20d21d22d23, 1)));
GiSt1Float32(
outptr + 18 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 1)));
GiSt1Float32(
outptr + 19 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 1)));
GiSt1Float32(
outptr + 20 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d8d9d10d11, 1)));
GiSt1Float32(
outptr + 21 * 2,
GiGetHighFloat32(GiGetSubVectorFloat32V2(d12d13d14d15, 1)));
GiSt1Float32(
outptr + 22 * 2,
GiGetHighFloat32(GiGetSubVectorFloat32V2(d16d17d18d19, 1)));
GiSt1Float32(
outptr + 23 * 2,
GiGetHighFloat32(GiGetSubVectorFloat32V2(d20d21d22d23, 1)));
outptr += 23 * 2;
}
template <typename T>
static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) {
static_assert(sizeof(T) == 4, "transpose_1x4_4_s only support sizeof(T) == 4");
GI_FLOAT32_t tmp_a, tmp_b;
#define LOAD() \
tmp_a = GiLoadFloat32(inptr0); \
inptr0 += 4; \
tmp_b = GiLoadFloat32(inptr0); \
inptr0 += 4;
LOAD();
GI_FLOAT32_V2_t d0d1d2d3 = GiZipqFloat32(tmp_a, tmp_b);
LOAD();
GI_FLOAT32_V2_t d4d5d6d7 = GiZipqFloat32(tmp_a, tmp_b);
#undef LOAD
GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 0)));
GiSt1Float32(outptr + 1 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 0)));
GiSt1Float32(
outptr + 2 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 0)));
GiSt1Float32(
outptr + 3 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 0)));
GiSt1Float32(outptr + 4 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 1)));
GiSt1Float32(outptr + 5 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 1)));
GiSt1Float32(
outptr + 6 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 1)));
GiSt1Float32(
outptr + 7 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 1)));
outptr += 7 * 2;
}
} // namespace fallback
} // namespace matmul
} // namespace megdnn
......
//! risc-v gcc will error report uninitialized var at if/else case when use RVV type
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wuninitialized"
#ifdef __GNUC__
#ifndef __has_warning
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#else
#if __has_warning("-Wmaybe-uninitialized")
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#endif
#endif
#endif
#include "src/fallback/matrix_mul/generic_strategy.h"
#include "src/fallback/matrix_mul/gi/fp32/common.h"
using namespace megdnn;
using namespace matmul::fallback;
namespace {
void kern_4x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k) {
MEGDNN_MARK_USED_VAR(LDC);
const float* a_ptr = packA;
const float* b_ptr = packB;
float* output0 = output;
int oddk = (K & 1);
K = ((K + 1) / 2) - 1;
float* r1 = output;
GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15, d16d17, d18d19,
d20d21, d22d23, d24d25, d26d27, d28d29, d30d31;
if (is_first_k) {
d8d9 = GiBroadcastFloat32(0.0f);
d10d11 = GiBroadcastFloat32(0.0f);
d12d13 = GiBroadcastFloat32(0.0f);
d14d15 = GiBroadcastFloat32(0.0f);
d0d1 = GiLoadFloat32(a_ptr);
a_ptr = a_ptr + 4;
d16d17 = GiBroadcastFloat32(0.0f);
d18d19 = GiBroadcastFloat32(0.0f);
d20d21 = GiBroadcastFloat32(0.0f);
d22d23 = GiBroadcastFloat32(0.0f);
d4d5 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;
d6d7 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;
d24d25 = GiBroadcastFloat32(0.0f);
d26d27 = GiBroadcastFloat32(0.0f);
d28d29 = GiBroadcastFloat32(0.0f);
d30d31 = GiBroadcastFloat32(0.0f);
} else {
d8d9 = GiLoadFloat32(r1);
r1 = r1 + 4;
d10d11 = GiLoadFloat32(r1);
r1 = r1 + 4;
d12d13 = GiLoadFloat32(r1);
r1 = r1 + 4;
d14d15 = GiLoadFloat32(r1);
r1 = r1 + 4;
d16d17 = GiLoadFloat32(r1);
r1 = r1 + 4;
d18d19 = GiLoadFloat32(r1);
r1 = r1 + 4;
d20d21 = GiLoadFloat32(r1);
r1 = r1 + 4;
d22d23 = GiLoadFloat32(r1);
r1 = r1 + 4;
d24d25 = GiLoadFloat32(r1);
r1 = r1 + 4;
d26d27 = GiLoadFloat32(r1);
r1 = r1 + 4;
d28d29 = GiLoadFloat32(r1);
r1 = r1 + 4;
d30d31 = GiLoadFloat32(r1);
r1 = r1 + 4;
d0d1 = GiLoadFloat32(a_ptr);
a_ptr = a_ptr + 4;
d4d5 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;
}
for (; K > 0; K--) {
d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0);
d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1);
d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2);
d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3);
d4d5 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;
d16d17 = GiSimdFmaLane(d16d17, d0d1, d6d7, 0);
d18d19 = GiSimdFmaLane(d18d19, d0d1, d6d7, 1);
d20d21 = GiSimdFmaLane(d20d21, d0d1, d6d7, 2);
d2d3 = GiLoadFloat32(a_ptr);
a_ptr = a_ptr + 4;
d22d23 = GiSimdFmaLane(d22d23, d0d1, d6d7, 3);
d6d7 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;
d24d25 = GiSimdFmaLane(d24d25, d0d1, d4d5, 0);
d26d27 = GiSimdFmaLane(d26d27, d0d1, d4d5, 1);
d28d29 = GiSimdFmaLane(d28d29, d0d1, d4d5, 2);
d30d31 = GiSimdFmaLane(d30d31, d0d1, d4d5, 3);
d4d5 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;
d8d9 = GiSimdFmaLane(d8d9, d2d3, d6d7, 0);
d10d11 = GiSimdFmaLane(d10d11, d2d3, d6d7, 1);
d12d13 = GiSimdFmaLane(d12d13, d2d3, d6d7, 2);
d14d15 = GiSimdFmaLane(d14d15, d2d3, d6d7, 3);
d6d7 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;
d16d17 = GiSimdFmaLane(d16d17, d2d3, d4d5, 0);
d18d19 = GiSimdFmaLane(d18d19, d2d3, d4d5, 1);
d0d1 = GiLoadFloat32(a_ptr);
a_ptr = a_ptr + 4;
d20d21 = GiSimdFmaLane(d20d21, d2d3, d4d5, 2);
d22d23 = GiSimdFmaLane(d22d23, d2d3, d4d5, 3);
d4d5 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;
d24d25 = GiSimdFmaLane(d24d25, d2d3, d6d7, 0);
d26d27 = GiSimdFmaLane(d26d27, d2d3, d6d7, 1);
d28d29 = GiSimdFmaLane(d28d29, d2d3, d6d7, 2);
d30d31 = GiSimdFmaLane(d30d31, d2d3, d6d7, 3);
d6d7 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;
}
if (1 == oddk) {
d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0);
d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1);
d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2);
d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3);
d4d5 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;
d16d17 = GiSimdFmaLane(d16d17, d0d1, d6d7, 0);
GiStoreFloat32(output0, d8d9);
output0 = output0 + 4;
GiStoreFloat32(output0, d10d11);
output0 = output0 + 4;
d18d19 = GiSimdFmaLane(d18d19, d0d1, d6d7, 1);
d20d21 = GiSimdFmaLane(d20d21, d0d1, d6d7, 2);
GiStoreFloat32(output0, d12d13);
output0 = output0 + 4;
GiStoreFloat32(output0, d14d15);
output0 = output0 + 4;
d22d23 = GiSimdFmaLane(d22d23, d0d1, d6d7, 3);
d24d25 = GiSimdFmaLane(d24d25, d0d1, d4d5, 0);
GiStoreFloat32(output0, d16d17);
output0 = output0 + 4;
GiStoreFloat32(output0, d18d19);
output0 = output0 + 4;
d26d27 = GiSimdFmaLane(d26d27, d0d1, d4d5, 1);
GiStoreFloat32(output0, d20d21);
output0 = output0 + 4;
GiStoreFloat32(output0, d22d23);
output0 = output0 + 4;
d28d29 = GiSimdFmaLane(d28d29, d0d1, d4d5, 2);
GiStoreFloat32(output0, d24d25);
output0 = output0 + 4;
GiStoreFloat32(output0, d26d27);
output0 = output0 + 4;
d30d31 = GiSimdFmaLane(d30d31, d0d1, d4d5, 3);
GiStoreFloat32(output0, d28d29);
output0 = output0 + 4;
GiStoreFloat32(output0, d30d31);
output0 = output0 + 4;
} else {
d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0);
d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1);
d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2);
d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3);
d4d5 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;
d16d17 = GiSimdFmaLane(d16d17, d0d1, d6d7, 0);
d18d19 = GiSimdFmaLane(d18d19, d0d1, d6d7, 1);
d20d21 = GiSimdFmaLane(d20d21, d0d1, d6d7, 2);
d2d3 = GiLoadFloat32(a_ptr);
a_ptr = a_ptr + 4;
d22d23 = GiSimdFmaLane(d22d23, d0d1, d6d7, 3);
d6d7 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;
d24d25 = GiSimdFmaLane(d24d25, d0d1, d4d5, 0);
d26d27 = GiSimdFmaLane(d26d27, d0d1, d4d5, 1);
d28d29 = GiSimdFmaLane(d28d29, d0d1, d4d5, 2);
d30d31 = GiSimdFmaLane(d30d31, d0d1, d4d5, 3);
d4d5 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;
d8d9 = GiSimdFmaLane(d8d9, d2d3, d6d7, 0);
d10d11 = GiSimdFmaLane(d10d11, d2d3, d6d7, 1);
d12d13 = GiSimdFmaLane(d12d13, d2d3, d6d7, 2);
d14d15 = GiSimdFmaLane(d14d15, d2d3, d6d7, 3);
d6d7 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;
d16d17 = GiSimdFmaLane(d16d17, d2d3, d4d5, 0);
d18d19 = GiSimdFmaLane(d18d19, d2d3, d4d5, 1);
GiStoreFloat32(output0, d8d9);
output0 = output0 + 4;
GiStoreFloat32(output0, d10d11);
output0 = output0 + 4;
d20d21 = GiSimdFmaLane(d20d21, d2d3, d4d5, 2);
d22d23 = GiSimdFmaLane(d22d23, d2d3, d4d5, 3);
GiStoreFloat32(output0, d12d13);
output0 = output0 + 4;
GiStoreFloat32(output0, d14d15);
output0 = output0 + 4;
d24d25 = GiSimdFmaLane(d24d25, d2d3, d6d7, 0);
d26d27 = GiSimdFmaLane(d26d27, d2d3, d6d7, 1);
GiStoreFloat32(output0, d16d17);
output0 = output0 + 4;
GiStoreFloat32(output0, d18d19);
output0 = output0 + 4;
d28d29 = GiSimdFmaLane(d28d29, d2d3, d6d7, 2);
d30d31 = GiSimdFmaLane(d30d31, d2d3, d6d7, 3);
GiStoreFloat32(output0, d20d21);
output0 = output0 + 4;
GiStoreFloat32(output0, d22d23);
output0 = output0 + 4;
GiStoreFloat32(output0, d24d25);
output0 = output0 + 4;
GiStoreFloat32(output0, d26d27);
output0 = output0 + 4;
GiStoreFloat32(output0, d28d29);
output0 = output0 + 4;
GiStoreFloat32(output0, d30d31);
output0 = output0 + 4;
}
}
void kern_4x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int n_remain) {
MEGDNN_MARK_USED_VAR(LDC);
const float* a_ptr = packA;
const float* b_ptr = packB;
int oddk = (K & 1);
K = ((K + 1) / 2) - 1;
float* r1 = output;
GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15;
if (is_first_k) {
d8d9 = GiBroadcastFloat32(0.0f);
d10d11 = GiBroadcastFloat32(0.0f);
d0d1 = GiLoadFloat32(a_ptr);
a_ptr = a_ptr + 4;
d12d13 = GiBroadcastFloat32(0.0f);
d4d5 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;
d14d15 = GiBroadcastFloat32(0.0f);
} else {
if (n_remain == 4) {
d8d9 = GiLoadFloat32(r1);
r1 = r1 + 4;
d10d11 = GiLoadFloat32(r1);
r1 = r1 + 4;
d12d13 = GiLoadFloat32(r1);
r1 = r1 + 4;
d14d15 = GiLoadFloat32(r1);
r1 = r1 + 4;
} else if (n_remain == 3) {
d8d9 = GiLoadFloat32(r1);
r1 = r1 + 4;
d10d11 = GiLoadFloat32(r1);
r1 = r1 + 4;
d12d13 = GiLoadFloat32(r1);
r1 = r1 + 4;
} else if (n_remain == 2) {
d8d9 = GiLoadFloat32(r1);
r1 = r1 + 4;
d10d11 = GiLoadFloat32(r1);
r1 = r1 + 4;
} else if (n_remain == 1) {
d8d9 = GiLoadFloat32(r1);
r1 = r1 + 4;
}
}
for (; K > 0; K--) {
d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0);
d2d3 = GiLoadFloat32(a_ptr);
a_ptr = a_ptr + 4;
d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1);
d6d7 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;
d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2);
d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3);
d4d5 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;
d8d9 = GiSimdFmaLane(d8d9, d2d3, d6d7, 0);
d10d11 = GiSimdFmaLane(d10d11, d2d3, d6d7, 1);
d0d1 = GiLoadFloat32(a_ptr);
a_ptr = a_ptr + 4;
d12d13 = GiSimdFmaLane(d12d13, d2d3, d6d7, 2);
d14d15 = GiSimdFmaLane(d14d15, d2d3, d6d7, 3);
}
if (1 == oddk) {
d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0);
d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1);
d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2);
d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3);
} else {
d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0);
d2d3 = GiLoadFloat32(a_ptr);
a_ptr = a_ptr + 4;
d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1);
d6d7 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;
d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2);
d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3);
d8d9 = GiSimdFmaLane(d8d9, d2d3, d6d7, 0);
d10d11 = GiSimdFmaLane(d10d11, d2d3, d6d7, 1);
d12d13 = GiSimdFmaLane(d12d13, d2d3, d6d7, 2);
d14d15 = GiSimdFmaLane(d14d15, d2d3, d6d7, 3);
}
if (n_remain == 4) {
GiStoreFloat32(output, d8d9);
output = output + 4;
GiStoreFloat32(output, d10d11);
output = output + 4;
GiStoreFloat32(output, d12d13);
output = output + 4;
GiStoreFloat32(output, d14d15);
output = output + 4;
} else if (n_remain == 3) {
GiStoreFloat32(output, d8d9);
output = output + 4;
GiStoreFloat32(output, d10d11);
output = output + 4;
GiStoreFloat32(output, d12d13);
output = output + 4;
} else if (n_remain == 2) {
GiStoreFloat32(output, d8d9);
output = output + 4;
GiStoreFloat32(output, d10d11);
output = output + 4;
} else if (n_remain == 1) {
GiStoreFloat32(output, d8d9);
output = output + 4;
}
}
} // namespace
MEGDNN_REG_GEMM_STRATEGY_IMPL(gi_sgemm_mk4_pack_4x12);
//! Now no matmul mode of only packB support in conv1x1 and im2col, so just copy
//! the weight
void gi_sgemm_mk4_pack_4x12::pack_A(
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax,
bool) const {
megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4");
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
constexpr int PACK_C_SIZE = 4;
size_t cp_length = (kmax - k0) * PACK_C_SIZE;
for (int m = y0; m < ymax; m += 4) {
const float* src = in + (m / PACK_C_SIZE) * ldin + k0 * PACK_C_SIZE;
memcpy(out, src, cp_length * sizeof(float));
out += cp_length;
}
}
void gi_sgemm_mk4_pack_4x12::pack_B(
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose_B) const {
megdnn_assert(!transpose_B);
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
float tmpbuff[16] = {0.0f};
constexpr int PACK_C_SIZE = 4;
int ksize = kmax - k0;
int ksize12 = ksize * 12;
int ksize4 = (ksize << 2);
float* outptr_base = out;
float* outptr_base4 = outptr_base + (xmax - x0) / 12 * ksize12;
int k = k0;
for (; k + 3 < kmax; k += 4) {
const float* inptr = in + k / PACK_C_SIZE * ldin + x0 * PACK_C_SIZE;
int x = x0;
auto outptr = outptr_base;
for (; x + 12 <= xmax; x += 12) {
auto outptr_interleave = outptr;
transpose_1x12_4_s(inptr, outptr_interleave);
outptr += ksize12;
}
outptr = outptr_base4;
for (; x + 4 <= xmax; x += 4) {
auto outptr_interleave = outptr;
transpose_1x4_4_s(inptr, outptr_interleave);
outptr += ksize4;
}
if (x < xmax) {
memcpy(tmpbuff, inptr, sizeof(float) * (xmax - x) * PACK_C_SIZE);
auto outptr_interleave = outptr;
const float* tmp_ptr = &tmpbuff[0];
transpose_1x4_4_s<float>(tmp_ptr, outptr_interleave);
outptr += ksize4;
}
outptr_base += 12 * PACK_C_SIZE;
outptr_base4 += 4 * PACK_C_SIZE;
}
}
void gi_sgemm_mk4_pack_4x12::kern(
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C,
size_t LDC, bool is_first_k, const float*, float*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
constexpr int PACK_C_SIZE = 4;
constexpr size_t A_INTERLEAVE = 4;
constexpr size_t B_INTERLEAVE = 12;
const int K12 = K * 12;
const int K4 = K * 4;
size_t m = 0;
for (; m < M; m += A_INTERLEAVE) {
float* output = C + (m / 4 * LDC);
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
kern_4x12(packA, cur_packB, K, output, LDC, is_first_k);
output += PACK_C_SIZE * B_INTERLEAVE;
cur_packB += K12;
}
for (; n < N; n += 4) {
kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += PACK_C_SIZE * 4;
cur_packB += K4;
}
packA += K4;
}
}
// vim: syntax=cpp.doxygen
......@@ -28,6 +28,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoNaive naive;
AlgoF32GiGemvMK4 f32_gemv_mk4;
AlgoF32GiMK4_4x8 f32_mk4_4x8;
AlgoF32GiMK4Pack4x12 f32_mk4_gi_pack_4x12;
AlgoF32Gi4x12 f32_4x8;
SmallVector<AlgoBase*> m_all_algos;
AlgoBase::Mapper m_all_algos_map;
......@@ -36,6 +37,7 @@ public:
AlgoPack() {
m_all_algos.emplace_back(&f32_gemv_mk4);
m_all_algos.emplace_back(&f32_mk4_4x8);
m_all_algos.emplace_back(&f32_mk4_gi_pack_4x12);
m_all_algos.emplace_back(&f32_4x8);
m_all_algos.emplace_back(&gemv);
m_all_algos.emplace_back(&f32_k8x12x1);
......
......@@ -103,6 +103,7 @@ public:
FB_NAIVE,
FB_GI_F32_GEMV_MK4,
FB_GI_F32_MK4_4x8,
FB_GI_F32_MK4_PACK_4x12,
FB_GI_F32_4x12,
#if MEGDNN_X86
......@@ -233,6 +234,7 @@ private:
class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1
class AlgoF32GiGemvMK4; // fallback F32 gi Gemv NCHW44
class AlgoF32GiMK4_4x8; // fallback F32 gi Gemm NCHW44
class AlgoF32GiMK4Pack4x12; // fallback F32 gi Gemm pack NCHW44
class AlgoF32Gi4x12; // fallback F32 gi Gemm
class AlgoGemv;
class AlgoNaive;
......
......@@ -364,7 +364,9 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> {
DType dst_type, size_t N, size_t OC, size_t OH, size_t OW,
size_t pack_oc_size = 1) {
MEGDNN_MARK_USED_VAR(pack_oc_size);
megdnn_assert(pack_oc_size == 1, "PostProcess only support nchw in x86");
megdnn_assert(
pack_oc_size == 1 || pack_oc_size == 4,
"PostProcess only support nchw/44 in x86");
megdnn_assert(
nonlineMode == megdnn::param::ConvBiasV0::NonlineMode::IDENTITY,
"Add bias PostProcess only support IDENTITY");
......
......@@ -59,6 +59,11 @@ cb(dt_float32, float, "avx2", float, __m256, mm256, ps, ps, SIMDType::AVX2);
template <typename ctype, SIMDType simd_type = SIMDType::AVX2>
struct ParamElemVisitorHalfBoardCast;
//! some compiler do not define _mm256_set_m128
#define _mm256_set_m128ff(xmm1, xmm2) \
_mm256_permute2f128_ps( \
_mm256_castps128_ps256(xmm1), _mm256_castps128_ps256(xmm2), 2)
#define cb( \
_ctype, _simd_ptr_type, load_half_fuc, half_type, _simd_type, board_cast_func) \
template <> \
......@@ -78,9 +83,10 @@ cb(dt_int32, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_int16, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_int8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_uint8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i);
cb(dt_float32, float, _mm_load_ps, __m128, __m256, _mm256_set_m128);
cb(dt_float32, float, _mm_load_ps, __m128, __m256, _mm256_set_m128ff);
#undef cb
#undef _mm256_set_m128ff
/*!
* \brief broadcast type
* BCAST_x[0]x[1]...: x[i] == !stride[i]
......
......@@ -239,6 +239,74 @@ void checker_conv_bias(
}
}
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_1X1_S1_MK4_PACK_F32) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({1}, FULL_NLMODE, ALL_BIASMODE, 1, true);
check_conv_bias(args, handle(), "CONV1x1:FB_GI_F32_MK4_PACK_4x12:24");
}
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_IM2COL_S1_MK4_PACK_F32_PREPROCESS) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({2, 4, 7}, FULL_NLMODE, BR_AND_NO_BIASMODE, 1);
#define cb(name) \
check_conv_bias_preprocess( \
args, handle(), nullptr, 0.001, dtype::Float32(), dtype::Float32(), \
dtype::Float32(), dtype::Float32(), name);
cb("IM2COLMATMUL:FB_GI_F32_MK4_PACK_4x12");
#undef cb
}
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_IM2COL_S2_MK4_PACK_F32_FUSE_PREPROCESS) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({3}, FULL_NLMODE, BR_AND_BIAS_BIASMODE, 2);
#define cb(name) \
check_conv_bias_preprocess( \
args, handle(), nullptr, 0.001, dtype::Float32(), dtype::Float32(), \
dtype::Float32(), dtype::Float32(), name);
cb("IM2COLMATMUL:FB_GI_F32_MK4_PACK_4x12");
#undef cb
}
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_1X1_S1_MK4_PACK_F32_PREPROCESS) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({1}, FULL_NLMODE, ALL_BIASMODE, 1, true);
#define cb(name) \
check_conv_bias_preprocess( \
args, handle(), nullptr, 0.001, dtype::Float32(), dtype::Float32(), \
dtype::Float32(), dtype::Float32(), name);
cb("CONV1x1:FB_GI_F32_MK4_PACK_4x12:24");
#undef cb
}
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_IM2COL_S1_MK4_PACK_F32) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({2, 4, 7}, FULL_NLMODE, BR_AND_BIAS_BIASMODE, 1);
check_conv_bias(args, handle(), "IM2COLMATMUL:FB_GI_F32_MK4_PACK_4x12");
}
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_IM2COL_S2_MK4_PACK_F32) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({3, 5, 6}, FULL_NLMODE, BR_AND_BIAS_BIASMODE, 2);
#define cb(name) check_conv_bias(args, handle(), name);
cb("IM2COLMATMUL:FB_GI_F32_MK4_PACK_4x12");
#undef cb
}
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_IM2COL_S2_MK4_PACK_F32_FUSE) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({3}, FULL_NLMODE, ALL_BIASMODE, 2);
#define cb(name) check_conv_bias(args, handle(), name);
cb("IM2COLMATMUL:FB_GI_F32_MK4_PACK_4x12");
#undef cb
}
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_IM2COL_8X8X16) {
using namespace conv_bias;
param::ConvBias cur_param;
......
......@@ -42,12 +42,18 @@ TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) {
"FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1);
}
TEST_F(FALLBACK, MATRIX_MULF_GI_F32_4x12) {
TEST_F(FALLBACK, MATRIX_MUL_GI_F32_4x12) {
matrix_mul::check_matrix_mul(
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
"FB_GI_F32_4x12");
}
TEST_F(FALLBACK, MATRIX_MUL_GI_PACK_MK4) {
matrix_mul::check_matrix_mul(
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
"FB_GI_F32_MK4_PACK_4x12", param::MatrixMul::Format::MK4, 1);
}
TEST_F(FALLBACK, MATRIX_MUL_RECORD) {
TaskRecordChecker<MatrixMul> checker(1);
using Param = MatrixMul::Param;
......@@ -163,6 +169,13 @@ TEST_F(FALLBACK, BENCHMARK_MATRIX_MUL_FB_GI_F32_4x12) {
"FB_GI_F32_4x12", param::MatrixMul::Format::DEFAULT);
}
TEST_F(FALLBACK, BENCHMARK_MATRIX_MUL_GI_PACK_MK4) {
auto args = matrix_mul::get_benchmark_matmul_args();
matrix_mul::benchmark_single_algo(
handle(), args, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
"FB_GI_F32_MK4_PACK_4x12", param::MatrixMul::Format::MK4);
}
#endif
} // namespace test
} // namespace megdnn
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册