提交 d2184af3 编写于 作者: 泥点无名哥's avatar 泥点无名哥

feat(dnn/src/x86/matmul): add matmul_6x16 for x86

上级 dbb3232c
......@@ -10,6 +10,7 @@ project(MegEngine LANGUAGES C CXX VERSION ${MGB_VER_STRING})
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)
set(CMAKE_POLICY_DEFAULT_CMP0048 NEW)
......
......@@ -70,16 +70,16 @@ static void choice_ohw_oc_block(
fallback::MatrixMulImpl::AlgoBase::PackMode pack_mode) {
//! calculate m_oc_tile_size in choice_ohw_oc_block() fucntion,
//! when ohw_tile_size < this value ohw_tile_size = ohw
static constexpr size_t DEFAULT_OHW_MIN_TILE_SIZE = 32;
size_t DEFAULT_OHW_MIN_TILE_SIZE = round_up(32UL, block_n);
//! when nr_threads > 1 and round(ohw,nr_threads)>nr_threads,
//! oc_tile_size = DEFAULT_OC_TILE_SIZE
static constexpr size_t DEFAULT_OC_TILE_SIZE = 512;
size_t DEFAULT_OC_TILE_SIZE = round_up(512UL, block_m);
//! when oc_tile_size > this value m_oc_tile_size =
//! DEFAULT_OC_MAX_TILE_SIZE
static constexpr size_t DEFAULT_OC_MAX_TILE_SIZE = 1024;
size_t DEFAULT_OC_MAX_TILE_SIZE = round_up(1024UL, block_m);
//! when oc_tile_size < this value oc_tile_size =
//! DEFAULT_OC_MIN_TILE_SIZE the purpose is aligning the calculation
static constexpr size_t DEFAULT_OC_MIN_TILE_SIZE = 128;
size_t DEFAULT_OC_MIN_TILE_SIZE = round_up(128UL, block_m);;
size_t nr_threads = param.nr_threads;
size_t OC = param.filter_meta.ocpg;
size_t ohw = param.osz[0] * param.osz[1];
......
......@@ -122,6 +122,7 @@ public:
X86_INT8X8X16_SSE,
X86_INT8X8X32_SSE_4X8X2,
X86_F32_MK8_8X8,
X86_F32_6x16,
X86_INT8X8X32_VNNI,
X86_INT8X8X32_MKLDNN,
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
......
......@@ -31,6 +31,15 @@ static inline __m256 _mm256_loadu2_m128_emulate(
_mm_loadu_ps(hiaddr), 1);
}
MEGDNN_ATTRIBUTE_TARGET("avx")
static inline void _mm256_storeu2_m128_emulate(float *hiaddr, float *loaddr,
__m256 reg) {
auto xmm0 = _mm256_extractf128_ps(reg, 0);
auto xmm1 = _mm256_extractf128_ps(reg, 1);
_mm_storeu_ps(loaddr, xmm0);
_mm_storeu_ps(hiaddr, xmm1);
}
template <typename ctype, size_t len>
struct Vector;
......
......@@ -320,6 +320,35 @@ void gemm_s8s8s32_sse_4x8x2(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_END();
}
void gemm_f32_avx2_6x16(const MatrixMulImpl::KernParam& kern_param) {
MEGDNN_MARK_USED_VAR(kern_param);
MIDOUT_BEGIN(megdnn_x86_matmul_kern_avx2_6x16x2, midout_iv(0)) {
constexpr int cacheline = 64;
const size_t m = kern_param.M;
const size_t n = kern_param.N;
const size_t k = kern_param.K;
const bool trans_a = kern_param.trA;
const bool trans_b = kern_param.trB;
const size_t lda = kern_param.LDA;
const size_t ldb = kern_param.LDB;
const size_t ldc = kern_param.LDC;
auto a_type = kern_param.A_type;
auto b_type = kern_param.B_type;
auto c_type = kern_param.C_type;
const auto a_ptr = kern_param.A<float>();
const auto b_ptr = kern_param.B<float>();
auto c_ptr = kern_param.C<float>();
x86::matmul::sgemm_pack_6x16_avx2 strategy(m, n, k, a_type, b_type,
c_type);
megdnn::matmul::GemmInterleaved<x86::matmul::sgemm_pack_6x16_avx2>(
m, n, k, trans_a, trans_b, strategy, cacheline)
.execute(a_ptr, lda, b_ptr, ldb, c_ptr, ldc,
kern_param.workspace_ptr);
}
MIDOUT_END();
}
} // namespace
/*************************AlgoInt8x8x16AVX2********************/
......@@ -662,4 +691,45 @@ size_t MatrixMulImpl::AlgoF32MK8_8x8::get_workspace(
MIDOUT_END();
}
/*************************AlgoFloatAVX2M6N16********************/
MatrixMulImpl::kern_t MatrixMulImpl::AlgoFloatAVX2M6N16::get_kern(
const KernSizeParam&) const {
return gemm_f32_avx2_6x16;
}
bool MatrixMulImpl::AlgoFloatAVX2M6N16::usable(
const KernSizeParam& kern_size_param) const {
bool is_param_ok =
kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
((kern_size_param.A_type.enumv() == DTypeEnum::Float32 &&
kern_size_param.C_type.enumv() == DTypeEnum::Float32)) &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == Param::Format::DEFAULT &&
is_supported(SIMDType::AVX2);
return is_param_ok;
}
size_t MatrixMulImpl::AlgoFloatAVX2M6N16::get_workspace(
const KernSizeParam& kern_param) const {
constexpr int cacheline = 64;
const size_t m = kern_param.M;
const size_t n = kern_param.N;
const size_t k = kern_param.K;
const bool trans_a = kern_param.trA;
const bool trans_b = kern_param.trB;
auto a_type = kern_param.A_type;
auto b_type = kern_param.B_type;
auto c_type = kern_param.C_type;
x86::matmul::sgemm_pack_6x16_avx2 strategy(m, n, k, a_type, b_type,
c_type);
return megdnn::matmul::GemmInterleaved<
x86::matmul::sgemm_pack_6x16_avx2>(
m, n, k, trans_a, trans_b, strategy, cacheline)
.get_workspace_size();
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
AlgoFloatAVX2M6N16, megdnn_x86_matmul_kern,
"AlgoFloatAVX2M6N16"_hash, x86::matmul::sgemm_pack_6x16_avx2,
float, float, float, AlgoDataType::FLOAT32, DEFAULT);
// vim: syntax=cpp.doxygen
......@@ -149,6 +149,19 @@ public:
MEGDNN_DECL_ALGO_TYPE(X86_F32_MK8_8X8)
};
class MatrixMulImpl::AlgoFloatAVX2M6N16 : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char *name() const override { return "X86_F32_6x16"; }
bool usable(const KernSizeParam &) const override;
size_t get_workspace(const KernSizeParam &) const override;
kern_t get_kern(const KernSizeParam &) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(X86_F32_6x16)
};
#if MEGDNN_X86_WITH_VNNI
class MatrixMulImpl::AlgoInt8x8x32Vnni : public AlgoBase {
public:
......
......@@ -19,6 +19,8 @@ namespace matmul {
MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 8, 8, 8, false, true,
sgemm_nopack_8x8_avx2);
MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(float, float, float, float,
6, 16, 1, false, false, sgemm_pack_6x16_avx2);
} // namespace matmul
} // namespace x86
} // namespace megdnn
\ No newline at end of file
/**
* \file dnn/src/x86/matrix_mul/f32/strategy_6x16.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
/**
* \file dnn/src/x86/matrix_mul/f32/strategy_6x16.cpp
*
* This file is part of MegDNN, a deep neural network run-time library
* developed by Megvii.
*
* \copyright copyright (c) 2014-2019 megvii inc. all rights reserved.
*/
#include <immintrin.h>
#include "src/common/utils.h"
#include "src/x86/avx_helper.h"
#include "src/x86/matrix_mul/common/common.h"
#include "src/x86/matrix_mul/f32/strategy.h"
#include "src/common/unroll_macro.h"
using namespace megdnn;
using namespace x86;
#define DNN_AVX2_TARGET
#if !defined(__clang__)
//! bypass gcc bug https://bugs.launchpad.net/ubuntu/+source/gcc-5/+bug/1642109
#pragma GCC target("avx2")
#else
#undef DNN_AVX2_TARGET
#define DNN_AVX2_TARGET MEGDNN_ATTRIBUTE_TARGET("avx2")
#endif
#define UNROLL_CODE(cb, i, a...) UNROLL_CALL1(i,cb,##a)
namespace {
DNN_AVX2_TARGET
void transpose_16x8_1_s(const float *inptr0, const float *inptr1,
const float *inptr2, const float *inptr3,
const float *inptr4, const float *inptr5,
const float *inptr6, const float *inptr7,
const float *inptr8, const float *inptr9,
const float *inptr10, const float *inptr11,
const float *inptr12, const float *inptr13,
const float *inptr14, const float *inptr15,
float *outptr) {
auto ymm0 = _mm256_loadu_ps(inptr0); // A0A1A2A3A4A5A6A7
auto ymm1 = _mm256_loadu_ps(inptr1); // B0B1B2B3B4B5B6B7
auto ymm2 = _mm256_loadu_ps(inptr2); // C0C1C2C3C4C5C6C7
auto ymm3 = _mm256_loadu_ps(inptr3); // D0D1D2D3D4D5D6D7
auto ymm4 = _mm256_loadu_ps(inptr4); // E0E1E2E3E4E5E6E7
auto ymm5 = _mm256_loadu_ps(inptr5); // F0F1F2F3F4F5F6F7
auto ymm6 = _mm256_loadu_ps(inptr6); // G0G1G2G3G4G5G6G7
auto ymm7 = _mm256_loadu_ps(inptr7); // H0H1H2H3H4H5H6H7
auto ymm8 = _mm256_unpacklo_ps(ymm0, ymm2); // A0C0A1C1A4C4A5C5
auto ymm9 = _mm256_unpackhi_ps(ymm0, ymm2); // A2C2A3C3A6C6A7C7
auto ymm10 = _mm256_unpacklo_ps(ymm1, ymm3); // B0D0B1D1B4D4B5D5
auto ymm11 = _mm256_unpackhi_ps(ymm1, ymm3); // B2D2B3D3B6D6B7D7
auto ymm12 = _mm256_unpacklo_ps(ymm4, ymm6); // E0G0E1G1E4G4E5G5
auto ymm13 = _mm256_unpackhi_ps(ymm4, ymm6); // E2G2E3G3E6G6E7G7
auto ymm14 = _mm256_unpacklo_ps(ymm5, ymm7); // F0H0F1H1F4H4F5H5
auto ymm15 = _mm256_unpackhi_ps(ymm5, ymm7); // F2H2F3H3F6H6F7H7
ymm0 = _mm256_unpacklo_ps(ymm8, ymm10); // A0B0C0D0A4B4C4D4
ymm1 = _mm256_unpackhi_ps(ymm8, ymm10); // A1B1C1D1A5B5C5D5
ymm2 = _mm256_unpacklo_ps(ymm9, ymm11); // A2B2C2D2A6B6C6D6
ymm3 = _mm256_unpackhi_ps(ymm9, ymm11); // A3B3C3D3A7B7C7D7
ymm4 = _mm256_unpacklo_ps(ymm12, ymm14); // E0F0G0H0E4F4G4H4
ymm5 = _mm256_unpackhi_ps(ymm12, ymm14); // E1F1G1H1E5F5G5H5
ymm6 = _mm256_unpacklo_ps(ymm13, ymm15); // E2F2G2H2E6F6G6H6
ymm7 = _mm256_unpackhi_ps(ymm13, ymm15); // E3F3G3H3E7F7G7H7
ymm8 = _mm256_permute2f128_ps(ymm0, ymm4, 0x20); // A0B0C0D0E0F0G0H0
ymm9 = _mm256_permute2f128_ps(ymm1, ymm5, 0x20); // A1B1C1D1E1F1G1H1
ymm10 = _mm256_permute2f128_ps(ymm2, ymm6, 0x20); // A2B2C2D2E2F2G2H2
ymm11 = _mm256_permute2f128_ps(ymm3, ymm7, 0x20); // A3B3C3D3E3F3G3H3
ymm12 = _mm256_permute2f128_ps(ymm0, ymm4, 0x31); // A4B4C4D4E4F4G4H4
ymm13 = _mm256_permute2f128_ps(ymm1, ymm5, 0x31); // A5B5C5D5E5F5G5H5
ymm14 = _mm256_permute2f128_ps(ymm2, ymm6, 0x31); // A6B6C6D6E6F6G6H6
ymm15 = _mm256_permute2f128_ps(ymm3, ymm7, 0x31); // A7B7C7D7E7F7G7H7
_mm256_storeu_ps(outptr + 16 * 0, ymm8);
_mm256_storeu_ps(outptr + 16 * 1, ymm9);
_mm256_storeu_ps(outptr + 16 * 2, ymm10);
_mm256_storeu_ps(outptr + 16 * 3, ymm11);
_mm256_storeu_ps(outptr + 16 * 4, ymm12);
_mm256_storeu_ps(outptr + 16 * 5, ymm13);
_mm256_storeu_ps(outptr + 16 * 6, ymm14);
_mm256_storeu_ps(outptr + 16 * 7, ymm15);
ymm0 = _mm256_loadu_ps(inptr8); // A0A1A2A3A4A5A6A7
ymm1 = _mm256_loadu_ps(inptr9); // B0B1B2B3B4B5B6B7
ymm2 = _mm256_loadu_ps(inptr10); // C0C1C2C3C4C5C6C7
ymm3 = _mm256_loadu_ps(inptr11); // D0D1D2D3D4D5D6D7
ymm4 = _mm256_loadu_ps(inptr12); // E0E1E2E3E4E5E6E7
ymm5 = _mm256_loadu_ps(inptr13); // F0F1F2F3F4F5F6F7
ymm6 = _mm256_loadu_ps(inptr14); // G0G1G2G3G4G5G6G7
ymm7 = _mm256_loadu_ps(inptr15); // H0H1H2H3H4H5H6H7
ymm8 = _mm256_unpacklo_ps(ymm0, ymm2); // A0C0A1C1A4C4A5C5
ymm9 = _mm256_unpackhi_ps(ymm0, ymm2); // A2C2A3C3A6C6A7C7
ymm10 = _mm256_unpacklo_ps(ymm1, ymm3); // B0D0B1D1B4D4B5D5
ymm11 = _mm256_unpackhi_ps(ymm1, ymm3); // B2D2B3D3B6D6B7D7
ymm12 = _mm256_unpacklo_ps(ymm4, ymm6); // E0G0E1G1E4G4E5G5
ymm13 = _mm256_unpackhi_ps(ymm4, ymm6); // E2G2E3G3E6G6E7G7
ymm14 = _mm256_unpacklo_ps(ymm5, ymm7); // F0H0F1H1F4H4F5H5
ymm15 = _mm256_unpackhi_ps(ymm5, ymm7); // F2H2F3H3F6H6F7H7
ymm0 = _mm256_unpacklo_ps(ymm8, ymm10); // A0B0C0D0A4B4C4D4
ymm1 = _mm256_unpackhi_ps(ymm8, ymm10); // A1B1C1D1A5B5C5D5
ymm2 = _mm256_unpacklo_ps(ymm9, ymm11); // A2B2C2D2A6B6C6D6
ymm3 = _mm256_unpackhi_ps(ymm9, ymm11); // A3B3C3D3A7B7C7D7
ymm4 = _mm256_unpacklo_ps(ymm12, ymm14); // E0F0G0H0E4F4G4H4
ymm5 = _mm256_unpackhi_ps(ymm12, ymm14); // E1F1G1H1E5F5G5H5
ymm6 = _mm256_unpacklo_ps(ymm13, ymm15); // E2F2G2H2E6F6G6H6
ymm7 = _mm256_unpackhi_ps(ymm13, ymm15); // E3F3G3H3E7F7G7H7
ymm8 = _mm256_permute2f128_ps(ymm0, ymm4, 0x20); // A0B0C0D0E0F0G0H0
ymm9 = _mm256_permute2f128_ps(ymm1, ymm5, 0x20); // A1B1C1D1E1F1G1H1
ymm10 = _mm256_permute2f128_ps(ymm2, ymm6, 0x20); // A2B2C2D2E2F2G2H2
ymm11 = _mm256_permute2f128_ps(ymm3, ymm7, 0x20); // A3B3C3D3E3F3G3H3
ymm12 = _mm256_permute2f128_ps(ymm0, ymm4, 0x31); // A4B4C4D4E4F4G4H4
ymm13 = _mm256_permute2f128_ps(ymm1, ymm5, 0x31); // A5B5C5D5E5F5G5H5
ymm14 = _mm256_permute2f128_ps(ymm2, ymm6, 0x31); // A6B6C6D6E6F6G6H6
ymm15 = _mm256_permute2f128_ps(ymm3, ymm7, 0x31); // A7B7C7D7E7F7G7H7
_mm256_storeu_ps(outptr + 16 * 0 + 8, ymm8);
_mm256_storeu_ps(outptr + 16 * 1 + 8, ymm9);
_mm256_storeu_ps(outptr + 16 * 2 + 8, ymm10);
_mm256_storeu_ps(outptr + 16 * 3 + 8, ymm11);
_mm256_storeu_ps(outptr + 16 * 4 + 8, ymm12);
_mm256_storeu_ps(outptr + 16 * 5 + 8, ymm13);
_mm256_storeu_ps(outptr + 16 * 6 + 8, ymm14);
_mm256_storeu_ps(outptr + 16 * 7 + 8, ymm15);
}
DNN_AVX2_TARGET
void transpose_16x4_1_s(const float *inptr0, const float *inptr1,
const float *inptr2, const float *inptr3,
const float *inptr4, const float *inptr5,
const float *inptr6, const float *inptr7,
const float *inptr8, const float *inptr9,
const float *inptr10, const float *inptr11,
const float *inptr12, const float *inptr13,
const float *inptr14, const float *inptr15,
float *outptr) {
const std::uint32_t arr[8] = {0, 1, 4, 5, 2, 3, 6, 7};
__m256i order = _mm256_loadu_si256((const __m256i *)arr);
auto ymm0 = _mm256_loadu2_m128_emulate(inptr2, inptr0); // A0A1A2A3C0C1C2C3
auto ymm1 = _mm256_loadu2_m128_emulate(inptr3, inptr1); // B0B1B2B3D0D1D2D3
auto ymm2 = _mm256_loadu2_m128_emulate(inptr6, inptr4); // E0E1E2E3G0G1G2G3
auto ymm3 = _mm256_loadu2_m128_emulate(inptr7, inptr5); // F0F1F2F3H0H1H2H3
auto ymm4 = _mm256_unpacklo_ps(ymm0, ymm1); // A0B0A1B1C0D0C1D1
auto ymm5 = _mm256_unpackhi_ps(ymm0, ymm1); // A2B2A3B3C2D2C3D3
auto ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); // E0F0E1F1G0H0G1H1
auto ymm7 = _mm256_unpackhi_ps(ymm2, ymm3); // E2F2E3F3G2H2G3H3
auto ymm8 = _mm256_permutevar8x32_ps(ymm4, order); // A0B0C0D0A1B1C1D1
auto ymm9 = _mm256_permutevar8x32_ps(ymm5, order); // A2B2C2D2A3B3C3D3
auto ymm10 = _mm256_permutevar8x32_ps(ymm6, order); // E0F0G0H0E1F1G1H1
auto ymm11 = _mm256_permutevar8x32_ps(ymm7, order); // E2F2G2H2E3F3G3H3
ymm0 = _mm256_permute2f128_ps(ymm8, ymm10, 0x20); // A0B0C0D0E0F0G0H0
ymm1 = _mm256_permute2f128_ps(ymm8, ymm10, 0x31); // A1B1C1D1E1F1G1H1
ymm2 = _mm256_permute2f128_ps(ymm9, ymm11, 0x20); // A2B2C2D2E2F2G2H2
ymm3 = _mm256_permute2f128_ps(ymm9, ymm11, 0x31); // A3B3C3D3E3F3G3H3
_mm256_storeu_ps(outptr + 16 * 0, ymm0);
_mm256_storeu_ps(outptr + 16 * 1, ymm1);
_mm256_storeu_ps(outptr + 16 * 2, ymm2);
_mm256_storeu_ps(outptr + 16 * 3, ymm3);
ymm0 = _mm256_loadu2_m128_emulate(inptr10, inptr8); // A0A1A2A3C0C1C2C3
ymm1 = _mm256_loadu2_m128_emulate(inptr11, inptr9); // B0B1B2B3D0D1D2D3
ymm2 = _mm256_loadu2_m128_emulate(inptr14, inptr12); // E0E1E2E3G0G1G2G3
ymm3 = _mm256_loadu2_m128_emulate(inptr15, inptr13); // F0F1F2F3H0H1H2H3
ymm4 = _mm256_unpacklo_ps(ymm0, ymm1); // A0B0A1B1C0D0C1D1
ymm5 = _mm256_unpackhi_ps(ymm0, ymm1); // A2B2A3B3C2D2C3D3
ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); // E0F0E1F1G0H0G1H1
ymm7 = _mm256_unpackhi_ps(ymm2, ymm3); // E2F2E3F3G2H2G3H3
ymm8 = _mm256_permutevar8x32_ps(ymm4, order); // A0B0C0D0A1B1C1D1
ymm9 = _mm256_permutevar8x32_ps(ymm5, order); // A2B2C2D2A3B3C3D3
ymm10 = _mm256_permutevar8x32_ps(ymm6, order); // E0F0G0H0E1F1G1H1
ymm11 = _mm256_permutevar8x32_ps(ymm7, order); // E2F2G2H2E3F3G3H3
ymm0 = _mm256_permute2f128_ps(ymm8, ymm10, 0x20); // A0B0C0D0E0F0G0H0
ymm1 = _mm256_permute2f128_ps(ymm8, ymm10, 0x31); // A1B1C1D1E1F1G1H1
ymm2 = _mm256_permute2f128_ps(ymm9, ymm11, 0x20); // A2B2C2D2E2F2G2H2
ymm3 = _mm256_permute2f128_ps(ymm9, ymm11, 0x31); // A3B3C3D3E3F3G3H3
_mm256_storeu_ps(outptr + 16 * 0 + 8, ymm0);
_mm256_storeu_ps(outptr + 16 * 1 + 8, ymm1);
_mm256_storeu_ps(outptr + 16 * 2 + 8, ymm2);
_mm256_storeu_ps(outptr + 16 * 3 + 8, ymm3);
}
static size_t min(size_t a, size_t b) { return a > b ? b : a; }
DNN_AVX2_TARGET
void transpose_6x16_1_s(const float *inptr0, const float *inptr1,
const float *inptr2, const float *inptr3,
const float *inptr4, const float *inptr5,
float *outptr) {
auto ymm0 = _mm256_loadu_ps(inptr0 + 0); // A0A1A2A3A4A5A6A7
auto ymm1 = _mm256_loadu_ps(inptr0 + 8); // a0a1a2a3a4a5a6a7
auto ymm2 = _mm256_loadu_ps(inptr1 + 0); // B0B1B2B3B4B5B6B7
auto ymm3 = _mm256_loadu_ps(inptr1 + 8); // b0b1b2b3b4b5b6b7
auto ymm4 = _mm256_loadu_ps(inptr2 + 0); // C0C1C2C3C4C5C6C7
auto ymm5 = _mm256_loadu_ps(inptr2 + 8); // c0c1c2c3c4c5c6c7
auto ymm6 = _mm256_loadu_ps(inptr3 + 0); // D0D1D2D3D4D5D6D7
auto ymm7 = _mm256_loadu_ps(inptr3 + 8); // d0d1d2d3d4d5d6d7
auto ymm8 = _mm256_unpacklo_ps(ymm0, ymm4); // A0C0A1C1A4C4A5C5
auto ymm9 = _mm256_unpackhi_ps(ymm0, ymm4); // A2C2A3C3A6C6A7C7
auto ymm10 = _mm256_unpacklo_ps(ymm2, ymm6); // B0D0B1D1B4D4B5D5
auto ymm11 = _mm256_unpackhi_ps(ymm2, ymm6); // B2D2B3D3B6D6B7D7
auto ymm12 = _mm256_unpacklo_ps(ymm1, ymm5); // a0c0a1c1a4c4a5c5
auto ymm13 = _mm256_unpackhi_ps(ymm1, ymm5); // a2c2a3c3a6c6a7c7
auto ymm14 = _mm256_unpacklo_ps(ymm3, ymm7); // b0d0b1d1b4d4b5d5
auto ymm15 = _mm256_unpackhi_ps(ymm3, ymm7); // b2d2b3d3b6d6b7d7
ymm0 = _mm256_unpacklo_ps(ymm8, ymm10); // A0B0C0D0A4B4C4D4
ymm1 = _mm256_unpackhi_ps(ymm8, ymm10); // A1B1C1D1A5B5C5D5
ymm2 = _mm256_unpacklo_ps(ymm9, ymm11); // A2B2C2D2A6B6C6D6
ymm3 = _mm256_unpackhi_ps(ymm9, ymm11); // A3B3C3D3A7B7C7D7
ymm4 = _mm256_unpacklo_ps(ymm12, ymm14); // a0b0c0d0a4b4c4d4
ymm5 = _mm256_unpackhi_ps(ymm12, ymm14); // a1b1c1d1a5b5c5d5
ymm6 = _mm256_unpacklo_ps(ymm13, ymm15); // a2b2c2d2a6b6c6d6
ymm7 = _mm256_unpackhi_ps(ymm13, ymm15); // a3b3c3d3a7b7c7d7
_mm256_storeu2_m128_emulate(outptr + 6 * 4, outptr + 6 * 0, ymm0);
_mm256_storeu2_m128_emulate(outptr + 6 * 5, outptr + 6 * 1, ymm1);
_mm256_storeu2_m128_emulate(outptr + 6 * 6, outptr + 6 * 2, ymm2);
_mm256_storeu2_m128_emulate(outptr + 6 * 7, outptr + 6 * 3, ymm3);
_mm256_storeu2_m128_emulate(outptr + 6 * 12, outptr + 6 * 8, ymm4);
_mm256_storeu2_m128_emulate(outptr + 6 * 13, outptr + 6 * 9, ymm5);
_mm256_storeu2_m128_emulate(outptr + 6 * 14, outptr + 6 * 10, ymm6);
_mm256_storeu2_m128_emulate(outptr + 6 * 15, outptr + 6 * 11, ymm7);
float other[4 * 8];
ymm8 = _mm256_loadu_ps(inptr4 + 0); // E0E1E2E3E4E5E6E7
ymm9 = _mm256_loadu_ps(inptr4 + 8); // e0e1e2e3e4e5e6e7
ymm10 = _mm256_loadu_ps(inptr5 + 0); // F0F1F2F3F4F5F6F7
ymm11 = _mm256_loadu_ps(inptr5 + 8); // f0f1f2f3f4f5f6f7
_mm256_storeu_ps(other, ymm8);
_mm256_storeu_ps(other + 8, ymm9);
_mm256_storeu_ps(other + 16, ymm10);
_mm256_storeu_ps(other + 24, ymm11);
for (size_t i = 0; i < 16; i++) {
outptr[6 * i + 4] = other[i];
outptr[6 * i + 5] = other[i + 16];
}
}
DNN_AVX2_TARGET
void transpose_6x8_1_s(const float *inptr0, const float *inptr1,
const float *inptr2, const float *inptr3,
const float *inptr4, const float *inptr5,
float *outptr) {
auto ymm0 = _mm256_loadu_ps(inptr0); // A0A1A2A3A4A5A6A7
auto ymm1 = _mm256_loadu_ps(inptr1); // B0B1B2B3B4B5B6B7
auto ymm2 = _mm256_loadu_ps(inptr2); // C0C1C2C3C4C5C6C7
auto ymm3 = _mm256_loadu_ps(inptr3); // D0D1D2D3D4D5D6D7
auto ymm4 = _mm256_unpacklo_ps(ymm0, ymm2); // A0C0A1C1A4C4A5C5
auto ymm5 = _mm256_unpackhi_ps(ymm0, ymm2); // A2C2A3C3A6C6A7C7
auto ymm6 = _mm256_unpacklo_ps(ymm1, ymm3); // B0D0B1D1B4D4B5D5
auto ymm7 = _mm256_unpackhi_ps(ymm1, ymm3); // B2D2B3D3B6D6B7D7
auto ymm8 = _mm256_unpacklo_ps(ymm4, ymm6); // A0B0C0D0A4B4C4D4
auto ymm9 = _mm256_unpackhi_ps(ymm4, ymm6); // A1B1C1D1A5B5C5D5
auto ymm10 = _mm256_unpacklo_ps(ymm5, ymm7); // A2B2C2D2A6B6C6D6
auto ymm11 = _mm256_unpackhi_ps(ymm5, ymm7); // A3B3C3D3A7B7C7D7
_mm256_storeu2_m128_emulate(outptr + 6 * 4, outptr + 6 * 0, ymm8);
_mm256_storeu2_m128_emulate(outptr + 6 * 5, outptr + 6 * 1, ymm9);
_mm256_storeu2_m128_emulate(outptr + 6 * 6, outptr + 6 * 2, ymm10);
_mm256_storeu2_m128_emulate(outptr + 6 * 7, outptr + 6 * 3, ymm11);
float other[16];
auto ymm12 = _mm256_loadu_ps(inptr4); // E0E1E2E3E4E5E6E7
auto ymm13 = _mm256_loadu_ps(inptr5); // F0F1F2F3F4F5F6F7
_mm256_storeu_ps(other, ymm12);
_mm256_storeu_ps(other + 8, ymm13);
for (size_t i = 0; i < 8; i++) {
outptr[6 * i + 4] = other[i];
outptr[6 * i + 5] = other[8 + i];
}
}
DNN_AVX2_TARGET
void transpose_6x4_1_s(const float *inptr0, const float *inptr1,
const float *inptr2, const float *inptr3,
const float *inptr4, const float *inptr5,
float *outptr) {
const std::uint32_t arr[8] = {0, 1, 4, 5, 2, 3, 6, 7};
__m256i order = _mm256_loadu_si256((const __m256i *)arr);
auto ymm0 = _mm256_loadu2_m128_emulate(inptr2, inptr0); // A0A1A2A3C0C1C2C3
auto ymm1 = _mm256_loadu2_m128_emulate(inptr3, inptr1); // B0B1B2B3D0D1D2D3
auto ymm2 = _mm256_unpacklo_ps(ymm0, ymm1); // A0B0A1B1C0D0C1D1
auto ymm3 = _mm256_unpackhi_ps(ymm0, ymm1); // A2B2A3B3C2D2C3D3
auto ymm4 = _mm256_permutevar8x32_ps(ymm2, order); // A0B0C0D0A1B1C1D1
auto ymm5 = _mm256_permutevar8x32_ps(ymm3, order); // A2B2C2D2A3B3C3D3
_mm256_storeu2_m128_emulate(outptr + 6 * 1, outptr + 6 * 0, ymm4);
_mm256_storeu2_m128_emulate(outptr + 6 * 3, outptr + 6 * 2, ymm5);
float other[8];
auto ymm6 = _mm256_loadu2_m128_emulate(inptr5, inptr4); // E0E1E2E3E4E5E6E7
_mm256_storeu_ps(other, ymm6);
for (size_t i = 0; i < 4; i++) {
outptr[6 * i + 4] = other[i];
outptr[6 * i + 5] = other[4 + i];
}
}
DNN_AVX2_TARGET
void transpose_4x8_1_s(const float *inptr0, const float *inptr1,
const float *inptr2, const float *inptr3,
float *outptr) {
auto ymm0 = _mm256_loadu_ps(inptr0); // A0A1A2A3A4A5A6A7
auto ymm1 = _mm256_loadu_ps(inptr1); // B0B1B2B3B4B5B6B7
auto ymm2 = _mm256_loadu_ps(inptr2); // C0C1C2C3C4C5C6C7
auto ymm3 = _mm256_loadu_ps(inptr3); // D0D1D2D3D4D5D6D7
auto ymm4 = _mm256_unpacklo_ps(ymm0, ymm2); // A0C0A1C1A4C4A5C5
auto ymm5 = _mm256_unpackhi_ps(ymm0, ymm2); // A2C2A3C3A6C6A7C7
auto ymm6 = _mm256_unpacklo_ps(ymm1, ymm3); // B0D0B1D1B4D4B5D5
auto ymm7 = _mm256_unpackhi_ps(ymm1, ymm3); // B2D2B3D3B6D6B7D7
auto ymm8 = _mm256_unpacklo_ps(ymm4, ymm6); // A0B0C0D0A4B4C4D4
auto ymm9 = _mm256_unpackhi_ps(ymm4, ymm6); // A1B1C1D1A5B5C5D5
auto ymm10 = _mm256_unpacklo_ps(ymm5, ymm7); // A2B2C2D2A6B6C6D6
auto ymm11 = _mm256_unpackhi_ps(ymm5, ymm7); // A3B3C3D3A7B7C7D7
ymm0 = _mm256_permute2f128_ps(ymm8, ymm9, 0x20); // A0B0C0D0A1B1C1D1
ymm1 = _mm256_permute2f128_ps(ymm10, ymm11, 0x20); // A2B2C2D2A3B3C3D3
ymm2 = _mm256_permute2f128_ps(ymm8, ymm9, 0x31); // A4B4C4D4A5B5C5D5
ymm3 = _mm256_permute2f128_ps(ymm10, ymm11, 0x31); // A6B6C6D6A7B7C7D7
_mm256_storeu_ps(outptr + 8 * 0, ymm0);
_mm256_storeu_ps(outptr + 8 * 1, ymm1);
_mm256_storeu_ps(outptr + 8 * 2, ymm2);
_mm256_storeu_ps(outptr + 8 * 3, ymm3);
}
DNN_AVX2_TARGET
void transpose_4x4_1_s(const float *inptr0, const float *inptr1,
const float *inptr2, const float *inptr3,
float *outptr) {
const std::uint32_t arr[8] = {0, 1, 4, 5, 2, 3, 6, 7};
__m256i order = _mm256_loadu_si256((const __m256i *)arr);
auto ymm0 = _mm256_loadu2_m128_emulate(inptr2, inptr0); // A0A1A2A3C0C1C2C3
auto ymm1 = _mm256_loadu2_m128_emulate(inptr3, inptr1); // B0B1B2B3D0D1D2D3
auto ymm2 = _mm256_unpacklo_ps(ymm0, ymm1); // A0B0A1B1C0D0C1D1
auto ymm3 = _mm256_unpackhi_ps(ymm0, ymm1); // A2B2A3B3C2D2C3D3
auto ymm4 = _mm256_permutevar8x32_ps(ymm2, order); // A0B0C0D0A1B1C1D1
auto ymm5 = _mm256_permutevar8x32_ps(ymm3, order); // A2B2C2D2A3B3C3D3
_mm256_storeu_ps(outptr, ymm4);
_mm256_storeu_ps(outptr + 8, ymm5);
}
void transpose_2x16_1_s(const float *inptr0, const float *inptr1,
float *outptr) {
for (size_t i = 0; i < 16; i++) {
*outptr++ = inptr0[i];
*outptr++ = inptr1[i];
}
}
void transpose_2x8_1_s(const float *inptr0, const float *inptr1,
float *outptr) {
for (size_t i = 0; i < 8; i++) {
*outptr++ = inptr0[i];
*outptr++ = inptr1[i];
}
}
void transpose_2x4_1_s(const float *inptr0, const float *inptr1,
float *outptr) {
for (size_t i = 0; i < 4; i++) {
*outptr++ = inptr0[i];
*outptr++ = inptr1[i];
}
}
DNN_AVX2_TARGET
void interleave_1x16_1_s(const float *inptr0, float *outptr) {
auto ymm0 = _mm256_loadu_ps(inptr0);
auto ymm1 = _mm256_loadu_ps(inptr0 + 8);
_mm256_storeu_ps(outptr, ymm0);
_mm256_storeu_ps(outptr + 8, ymm1);
}
DNN_AVX2_TARGET
void interleave_8x16_1_s(const float *inptr0, const float *inptr1,
const float *inptr2, const float *inptr3,
const float *inptr4, const float *inptr5,
const float *inptr6, const float *inptr7,
float *outptr) {
auto ymm0 = _mm256_loadu_ps(inptr0);
auto ymm1 = _mm256_loadu_ps(inptr0 + 8);
auto ymm2 = _mm256_loadu_ps(inptr1);
auto ymm3 = _mm256_loadu_ps(inptr1 + 8);
auto ymm4 = _mm256_loadu_ps(inptr2);
auto ymm5 = _mm256_loadu_ps(inptr2 + 8);
auto ymm6 = _mm256_loadu_ps(inptr3);
auto ymm7 = _mm256_loadu_ps(inptr3 + 8);
auto ymm8 = _mm256_loadu_ps(inptr4);
auto ymm9 = _mm256_loadu_ps(inptr4 + 8);
auto ymm10 = _mm256_loadu_ps(inptr5);
auto ymm11 = _mm256_loadu_ps(inptr5 + 8);
auto ymm12 = _mm256_loadu_ps(inptr6);
auto ymm13 = _mm256_loadu_ps(inptr6 + 8);
auto ymm14 = _mm256_loadu_ps(inptr7);
auto ymm15 = _mm256_loadu_ps(inptr7 + 8);
_mm256_storeu_ps(outptr + 8 * 0, ymm0);
_mm256_storeu_ps(outptr + 8 * 1, ymm1);
_mm256_storeu_ps(outptr + 8 * 2, ymm2);
_mm256_storeu_ps(outptr + 8 * 3, ymm3);
_mm256_storeu_ps(outptr + 8 * 4, ymm4);
_mm256_storeu_ps(outptr + 8 * 5, ymm5);
_mm256_storeu_ps(outptr + 8 * 6, ymm6);
_mm256_storeu_ps(outptr + 8 * 7, ymm7);
_mm256_storeu_ps(outptr + 8 * 8, ymm8);
_mm256_storeu_ps(outptr + 8 * 9, ymm9);
_mm256_storeu_ps(outptr + 8 * 10, ymm10);
_mm256_storeu_ps(outptr + 8 * 11, ymm11);
_mm256_storeu_ps(outptr + 8 * 12, ymm12);
_mm256_storeu_ps(outptr + 8 * 13, ymm13);
_mm256_storeu_ps(outptr + 8 * 14, ymm14);
_mm256_storeu_ps(outptr + 8 * 15, ymm15);
}
DNN_AVX2_TARGET
void interleave_8x4_1_s(const float *inptr0, const float *inptr1,
const float *inptr2, const float *inptr3,
const float *inptr4, const float *inptr5,
const float *inptr6, const float *inptr7,
float *outptr) {
auto ymm0 = _mm256_loadu2_m128_emulate(inptr1, inptr0); // A0A1A2A3B0B1B2B3
auto ymm1 = _mm256_loadu2_m128_emulate(inptr3, inptr2); // C0C1C2C3D0D1D2D3
auto ymm2 = _mm256_loadu2_m128_emulate(inptr5, inptr4); // E0E1E2E3F0F1F2F3
auto ymm3 = _mm256_loadu2_m128_emulate(inptr7, inptr6); // G0G1G2G3H0H1H2H3
_mm256_storeu_ps(outptr + 8 * 0, ymm0);
_mm256_storeu_ps(outptr + 8 * 1, ymm1);
_mm256_storeu_ps(outptr + 8 * 2, ymm2);
_mm256_storeu_ps(outptr + 8 * 3, ymm3);
}
void interleave_8x2_1_s(const float *inptr0, const float *inptr1,
const float *inptr2, const float *inptr3,
const float *inptr4, const float *inptr5,
const float *inptr6, const float *inptr7,
float *outptr) {
#define cb(i) \
*outptr++ = inptr##i[0]; \
*outptr++ = inptr##i[1];
UNROLL_CODE(cb, 8)
#undef cb
}
void interleave_1x4_1_s(const float *inptr0, float *outptr) {
outptr[0] = inptr0[0];
outptr[1] = inptr0[1];
outptr[2] = inptr0[2];
outptr[3] = inptr0[3];
}
void interleave_8x6_1_s(const float *inptr0, const float *inptr1,
const float *inptr2, const float *inptr3,
const float *inptr4, const float *inptr5,
const float *inptr6, const float *inptr7,
float *outptr) {
#define cb(i) auto xmm##i = _mm_loadu_ps(inptr##i);
UNROLL_CODE(cb, 8)
#undef cb
#define cb(i) _mm_storeu_ps(outptr + 6 * i, xmm##i);
UNROLL_CODE(cb, 8)
#undef cb
#define cb(i) \
outptr[6 * i + 4] = inptr##i[4]; \
outptr[6 * i + 5] = inptr##i[5];
UNROLL_CODE(cb, 8)
#undef cb
}
void interleave_1x6_1_s(const float *inptr0, float *outptr) {
outptr[0] = inptr0[0];
outptr[1] = inptr0[1];
outptr[2] = inptr0[2];
outptr[3] = inptr0[3];
outptr[4] = inptr0[4];
outptr[5] = inptr0[5];
}
void interleave_1x2_1_s(const float *inptr0, float *outptr) {
outptr[0] = inptr0[0];
outptr[1] = inptr0[1];
}
static inline void interleave_helper(const float *inptr, float *outptr,
int unroll_k, int ksize, float val) {
int k = 0;
for (; k < ksize; k++) {
*outptr++ = *inptr++;
}
for (; k < unroll_k; k++) {
*outptr++ = val;
}
}
void interleave_1(const float *inptr0, float *outptr, int unroll_k, int ksize,
float val) {
for (int k = 0; k < ksize; k += unroll_k) {
int size = min(unroll_k, ksize - k);
interleave_helper(inptr0, outptr, unroll_k, size, val);
inptr0 += size;
outptr += unroll_k;
}
}
void interleave_8(const float *inptr0, const float *inptr1, const float *inptr2,
const float *inptr3, const float *inptr4, const float *inptr5,
const float *inptr6, const float *inptr7, float *outptr,
int unroll_k, int ksize, float val) {
for (int k = 0; k < ksize; k += unroll_k) {
int size = min(unroll_k, ksize - k);
interleave_helper(inptr0, outptr, unroll_k, size, val);
inptr0 += size;
outptr += unroll_k;
interleave_helper(inptr1, outptr, unroll_k, size, val);
inptr1 += size;
outptr += unroll_k;
interleave_helper(inptr2, outptr, unroll_k, size, val);
inptr2 += size;
outptr += unroll_k;
interleave_helper(inptr3, outptr, unroll_k, size, val);
inptr3 += size;
outptr += unroll_k;
interleave_helper(inptr4, outptr, unroll_k, size, val);
inptr4 += size;
outptr += unroll_k;
interleave_helper(inptr5, outptr, unroll_k, size, val);
inptr5 += size;
outptr += unroll_k;
interleave_helper(inptr6, outptr, unroll_k, size, val);
inptr6 += size;
outptr += unroll_k;
interleave_helper(inptr7, outptr, unroll_k, size, val);
inptr7 += size;
outptr += unroll_k;
}
}
DNN_AVX2_TARGET
MEGDNN_ATTRIBUTE_TARGET("fma")
void gemm_6x16_kern2x16(const float *packA, const float *packB, int K,
float *output, int LDC, bool is_first_k, int m_remain) {
const float *cur_b = packB;
const float *cur_a = packA;
__m256 ymm0, ymm1, ymm2, ymm3;
__m256 b_tmp0, b_tmp1;
__m256 tmp;
if (is_first_k) {
#define cb(i) ymm##i = _mm256_set1_ps(0.0f);
UNROLL_CODE(cb, 4)
#undef cb
} else {
ymm0 = _mm256_loadu_ps(output + LDC * 0 + 0);
ymm1 = _mm256_loadu_ps(output + LDC * 0 + 8);
ymm2 = _mm256_loadu_ps(output + LDC * 1 + 0);
ymm3 = _mm256_loadu_ps(output + LDC * 1 + 8);
}
b_tmp0 = _mm256_loadu_ps(cur_b);
b_tmp1 = _mm256_loadu_ps(cur_b + 8);
size_t i = 0;
for (; i + 2 <= K; i += 2) {
cur_b += 16;
#define CAL_OUPUT(i, first, second) \
tmp = _mm256_broadcast_ss(cur_a + i); \
ymm##first = _mm256_fmadd_ps(b_tmp0, tmp, ymm##first); \
ymm##second = _mm256_fmadd_ps(b_tmp1, tmp, ymm##second);
CAL_OUPUT(0, 0, 1)
CAL_OUPUT(1, 2, 3)
b_tmp0 = _mm256_loadu_ps(cur_b);
b_tmp1 = _mm256_loadu_ps(cur_b + 8);
cur_b += 16;
CAL_OUPUT(2, 0, 1)
CAL_OUPUT(3, 2, 3)
cur_a += 4;
b_tmp0 = _mm256_loadu_ps(cur_b);
b_tmp1 = _mm256_loadu_ps(cur_b + 8);
}
if (i < K) {
CAL_OUPUT(0, 0, 1)
CAL_OUPUT(1, 2, 3)
}
#undef CAL_OUPUT
switch (m_remain) {
case 2:
_mm256_storeu_ps(output + LDC * 1 + 0, ymm2);
_mm256_storeu_ps(output + LDC * 1 + 8, ymm3);
case 1:
_mm256_storeu_ps(output + LDC * 0 + 0, ymm0);
_mm256_storeu_ps(output + LDC * 0 + 8, ymm1);
default:
break;
}
}
DNN_AVX2_TARGET
MEGDNN_ATTRIBUTE_TARGET("fma")
void gemm_6x16_kern6x4(const float *packA, const float *packB, int K,
float *output, int LDC, bool is_first_k, int n_remain) {
const float *cur_b = packB;
const float *cur_a = packA;
__m128 xmm0, xmm1, xmm2, xmm3, xmm4, xmm5;
__m128 tmp_a, tmp_b;
if (is_first_k) {
xmm0 = _mm_set1_ps(0.0f);
xmm1 = _mm_set1_ps(0.0f);
xmm2 = _mm_set1_ps(0.0f);
xmm3 = _mm_set1_ps(0.0f);
xmm4 = _mm_set1_ps(0.0f);
xmm5 = _mm_set1_ps(0.0f);
} else {
xmm0 = _mm_loadu_ps(output + LDC * 0);
xmm1 = _mm_loadu_ps(output + LDC * 1);
xmm2 = _mm_loadu_ps(output + LDC * 2);
xmm3 = _mm_loadu_ps(output + LDC * 3);
xmm4 = _mm_loadu_ps(output + LDC * 4);
xmm5 = _mm_loadu_ps(output + LDC * 5);
}
for (size_t i = 0; i < K; i++) {
tmp_b = _mm_loadu_ps(cur_b);
cur_b += 4;
tmp_a = _mm_broadcast_ss(cur_a);
xmm0 = _mm_fmadd_ps(tmp_a, tmp_b, xmm0);
tmp_a = _mm_broadcast_ss(cur_a + 1);
xmm1 = _mm_fmadd_ps(tmp_a, tmp_b, xmm1);
tmp_a = _mm_broadcast_ss(cur_a + 2);
xmm2 = _mm_fmadd_ps(tmp_a, tmp_b, xmm2);
tmp_a = _mm_broadcast_ss(cur_a + 3);
xmm3 = _mm_fmadd_ps(tmp_a, tmp_b, xmm3);
tmp_a = _mm_broadcast_ss(cur_a + 4);
xmm4 = _mm_fmadd_ps(tmp_a, tmp_b, xmm4);
tmp_a = _mm_broadcast_ss(cur_a + 5);
xmm5 = _mm_fmadd_ps(tmp_a, tmp_b, xmm5);
cur_a += 6;
}
if (n_remain == 4) {
_mm_storeu_ps(output + LDC * 0, xmm0);
_mm_storeu_ps(output + LDC * 1, xmm1);
_mm_storeu_ps(output + LDC * 2, xmm2);
_mm_storeu_ps(output + LDC * 3, xmm3);
_mm_storeu_ps(output + LDC * 4, xmm4);
_mm_storeu_ps(output + LDC * 5, xmm5);
} else {
float dst[6 * 4];
_mm_storeu_ps(dst + 4 * 0, xmm0);
_mm_storeu_ps(dst + 4 * 1, xmm1);
_mm_storeu_ps(dst + 4 * 2, xmm2);
_mm_storeu_ps(dst + 4 * 3, xmm3);
_mm_storeu_ps(dst + 4 * 4, xmm4);
_mm_storeu_ps(dst + 4 * 5, xmm5);
for (size_t i = 0; i < n_remain; i++) {
for (size_t j = 0; j < 6; j++) {
output[LDC * j + i] = dst[4 * j + i];
}
}
}
}
DNN_AVX2_TARGET
MEGDNN_ATTRIBUTE_TARGET("fma")
void gemm_6x16_kern2x4(const float *packA, const float *packB, int K,
float *output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
const float *cur_b = packB;
const float *cur_a = packA;
__m128 xmm0, xmm1;
__m128 tmp_a, tmp_b;
if (is_first_k) {
xmm0 = _mm_set1_ps(0.0f);
xmm1 = _mm_set1_ps(0.0f);
} else {
xmm0 = _mm_loadu_ps(output + LDC * 0);
xmm1 = _mm_loadu_ps(output + LDC * 1);
}
for (size_t i = 0; i < K; i++) {
tmp_b = _mm_loadu_ps(cur_b);
cur_b += 4;
tmp_a = _mm_broadcast_ss(cur_a);
xmm0 = _mm_fmadd_ps(tmp_a, tmp_b, xmm0);
tmp_a = _mm_broadcast_ss(cur_a + 1);
xmm1 = _mm_fmadd_ps(tmp_a, tmp_b, xmm1);
cur_a += 2;
}
float dst[2 * 4];
_mm_storeu_ps(dst + 4 * 0, xmm0);
_mm_storeu_ps(dst + 4 * 1, xmm1);
for (size_t i = 0; i < n_remain; i++) {
for (size_t j = 0; j < m_remain; j++) {
output[LDC * j + i] = dst[4 * j + i];
}
}
}
DNN_AVX2_TARGET
MEGDNN_ATTRIBUTE_TARGET("fma")
void gemm_6x16_kern6x16(const float *packA, const float *packB, int K,
float *output, int LDC, bool is_first_k) {
const float *cur_b = packB;
const float *cur_a = packA;
__m256 ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8, ymm9, ymm10,
ymm11;
__m256 b_tmp0, b_tmp1;
__m256 tmp;
if (is_first_k) {
#define cb(i) ymm##i = _mm256_set1_ps(0.0f);
UNROLL_CODE(cb, 12)
#undef cb
} else {
ymm0 = _mm256_loadu_ps(output + LDC * 0 + 0);
ymm1 = _mm256_loadu_ps(output + LDC * 0 + 8);
ymm2 = _mm256_loadu_ps(output + LDC * 1 + 0);
ymm3 = _mm256_loadu_ps(output + LDC * 1 + 8);
ymm4 = _mm256_loadu_ps(output + LDC * 2 + 0);
ymm5 = _mm256_loadu_ps(output + LDC * 2 + 8);
ymm6 = _mm256_loadu_ps(output + LDC * 3 + 0);
ymm7 = _mm256_loadu_ps(output + LDC * 3 + 8);
ymm8 = _mm256_loadu_ps(output + LDC * 4 + 0);
ymm9 = _mm256_loadu_ps(output + LDC * 4 + 8);
ymm10 = _mm256_loadu_ps(output + LDC * 5 + 0);
ymm11 = _mm256_loadu_ps(output + LDC * 5 + 8);
}
b_tmp0 = _mm256_loadu_ps(cur_b);
b_tmp1 = _mm256_loadu_ps(cur_b + 8);
size_t i = 0;
for (; i + 2 <= K; i += 2) {
cur_b += 16;
#define CAL_OUPUT(i, first, second) \
tmp = _mm256_broadcast_ss(cur_a + i); \
ymm##first = _mm256_fmadd_ps(b_tmp0, tmp, ymm##first); \
ymm##second = _mm256_fmadd_ps(b_tmp1, tmp, ymm##second);
CAL_OUPUT(0, 0, 1)
CAL_OUPUT(1, 2, 3)
CAL_OUPUT(2, 4, 5)
CAL_OUPUT(3, 6, 7)
CAL_OUPUT(4, 8, 9)
CAL_OUPUT(5, 10, 11)
b_tmp0 = _mm256_loadu_ps(cur_b);
b_tmp1 = _mm256_loadu_ps(cur_b + 8);
cur_b += 16;
CAL_OUPUT(6, 0, 1)
CAL_OUPUT(7, 2, 3)
CAL_OUPUT(8, 4, 5)
CAL_OUPUT(9, 6, 7)
CAL_OUPUT(10, 8, 9)
CAL_OUPUT(11, 10, 11)
cur_a += 12;
b_tmp0 = _mm256_loadu_ps(cur_b);
b_tmp1 = _mm256_loadu_ps(cur_b + 8);
}
if (i < K) {
CAL_OUPUT(0, 0, 1)
CAL_OUPUT(1, 2, 3)
CAL_OUPUT(2, 4, 5)
CAL_OUPUT(3, 6, 7)
CAL_OUPUT(4, 8, 9)
CAL_OUPUT(5, 10, 11)
}
#undef CAL_OUPUT
_mm256_storeu_ps(output + LDC * 0 + 0, ymm0);
_mm256_storeu_ps(output + LDC * 0 + 8, ymm1);
_mm256_storeu_ps(output + LDC * 1 + 0, ymm2);
_mm256_storeu_ps(output + LDC * 1 + 8, ymm3);
_mm256_storeu_ps(output + LDC * 2 + 0, ymm4);
_mm256_storeu_ps(output + LDC * 2 + 8, ymm5);
_mm256_storeu_ps(output + LDC * 3 + 0, ymm6);
_mm256_storeu_ps(output + LDC * 3 + 8, ymm7);
_mm256_storeu_ps(output + LDC * 4 + 0, ymm8);
_mm256_storeu_ps(output + LDC * 4 + 8, ymm9);
_mm256_storeu_ps(output + LDC * 5 + 0, ymm10);
_mm256_storeu_ps(output + LDC * 5 + 8, ymm11);
}
void gemm_6x16_kern(const float *packA, const float *packB, size_t M, size_t N,
size_t K, float *C, size_t LDC, int is_first_k) {
size_t n = 0;
const int K2 = K * 2;
const int K4 = K * 4;
const int K6 = K * 6;
const int K16 = K * 16;
const int A_INTERLEAVE6 = 6;
const int A_INTERLEAVE2 = 2;
const int B_INTERLEAVE16 = 16;
const int B_INTERLEAVE4 = 4;
auto *cur_packB = packB;
for (; n + B_INTERLEAVE16 <= N; n += B_INTERLEAVE16) {
size_t m = 0;
auto output = C + n;
auto *cur_packA = packA;
for (; m + A_INTERLEAVE6 <= M; m += A_INTERLEAVE6) {
gemm_6x16_kern6x16(cur_packA, cur_packB, K, output, LDC, is_first_k);
output += A_INTERLEAVE6 * LDC;
cur_packA += K6;
}
for (; m < M; m += A_INTERLEAVE2) {
gemm_6x16_kern2x16(cur_packA, cur_packB, K, output, LDC, is_first_k,
min(M - m, 2));
output += A_INTERLEAVE2 * LDC;
cur_packA += K2;
}
cur_packB += K16;
}
for (; n < N; n += B_INTERLEAVE4) {
size_t m = 0;
auto output = C + n;
auto *cur_packA = packA;
for (; m + A_INTERLEAVE6 <= M; m += A_INTERLEAVE6) {
gemm_6x16_kern6x4(cur_packA, cur_packB, K, output, LDC, is_first_k,
min(N - n, 4));
output += A_INTERLEAVE6 * LDC;
cur_packA += K6;
}
for (; m < M; m += A_INTERLEAVE2) {
gemm_6x16_kern2x4(cur_packA, cur_packB, K, output, LDC, is_first_k,
min(M - m, 2), min(N - n, 4));
output += A_INTERLEAVE2 * LDC;
cur_packA += K2;
}
cur_packB += K4;
}
}
void gemm_6x16_pack_A_t(float *outptr, const float *inptr, int ldin, int x0,
int xmax, int k0, int kmax) {
size_t ksize = kmax - k0;
size_t ksize6 = ksize * 6;
size_t ksize2 = ksize * 2;
float *outptr_base6 = outptr;
float *outptr_base2 = outptr_base6 + (xmax - x0) / 6 * ksize6;
size_t k = k0;
for (; k + 7 < kmax; k += 8) {
const float *cur_inptr = inptr + k * ldin + k0;
#define cb(i) const float *inptr##i = cur_inptr + ldin * i;
UNROLL_CODE(cb, 8)
#undef cb
#define cb(i) __builtin_prefetch(inptr##i, 0, 3);
UNROLL_CODE(cb, 8)
#undef cb
int x = x0;
float *outptr = outptr_base6;
for (; x + 6 <= xmax; x += 6) {
interleave_8x6_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr);
#define cb(i) inptr##i += 6;
UNROLL_CODE(cb, 8)
#undef cb
outptr += ksize6;
}
outptr = outptr_base2;
for (; x + 2 <= xmax; x += 2) {
interleave_8x2_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr);
#define cb(i) inptr##i += 2;
UNROLL_CODE(cb, 8)
#undef cb
outptr += ksize2;
}
if (x < xmax) {
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 2, xmax - x, 0);
inptr0 += xmax - x;
inptr1 += xmax - x;
inptr2 += xmax - x;
inptr3 += xmax - x;
inptr4 += xmax - x;
inptr5 += xmax - x;
inptr6 += xmax - x;
inptr7 += xmax - x;
}
outptr_base6 += 8 * 6;
outptr_base2 += 8 * 2;
}
for (; k < kmax; k++) {
const float *inptr0 = inptr + k * ldin + k0;
__builtin_prefetch(inptr0, 0, 3);
int x = x0;
float *outptr = outptr_base6;
for (; x + 6 <= xmax; x += 6) {
interleave_1x6_1_s(inptr0, outptr);
inptr0 += 6;
outptr += ksize6;
}
outptr = outptr_base2;
for (; x + 2 <= xmax; x += 2) {
interleave_1x2_1_s(inptr0, outptr);
inptr0 += 2;
outptr += ksize2;
}
if (x < xmax) {
interleave_1(inptr0, outptr, 2, xmax - x, 0);
inptr0 += xmax - x;
outptr += 2;
}
outptr_base6 += 6;
outptr_base2 += 2;
}
}
void gemm_6x16_pack_A_n(float *outptr, const float *inptr, int ldin, int y0,
int ymax, int k0, int kmax) {
float zerobuff[16];
memset(zerobuff, 0, sizeof(float) * 16);
size_t y = y0;
const size_t PACK_SIZE_96 = 6 * 16;
const size_t PACK_SIZE_48 = 6 * 8;
const size_t PACK_SIZE_24 = 6 * 4;
const size_t PACK_SIZE_32 = 4 * 8;
const size_t PACK_SIZE_16 = 4 * 4;
const size_t PACK_SIZE_8 = 4 * 2;
for (; y + 5 < ymax; y += 6) {
const float *cur_inptr = inptr + y * ldin + k0;
#define cb(i) const float *inptr##i = cur_inptr + ldin * i;
UNROLL_CODE(cb, 6)
#undef cb
#define cb(i) __builtin_prefetch(inptr##i, 0, 3);
UNROLL_CODE(cb, 6)
#undef cb
int x = (kmax - k0);
for (; x > 15; x -= 16) {
transpose_6x16_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
outptr);
#define cb(i) inptr##i += 16;
UNROLL_CODE(cb, 6)
#undef cb
outptr += PACK_SIZE_96;
}
for (; x > 7; x -= 8) {
transpose_6x8_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, outptr);
#define cb(i) inptr##i += 8;
UNROLL_CODE(cb, 6)
#undef cb
outptr += PACK_SIZE_48;
}
for (; x > 3; x -= 4) {
transpose_6x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, outptr);
#define cb(i) inptr##i += 4;
UNROLL_CODE(cb, 6)
#undef cb
outptr += PACK_SIZE_24;
}
for (; x > 0; x--) {
#define cb(i) *outptr++ = *inptr##i++;
UNROLL_CODE(cb, 6)
#undef cb
}
}
for (; y < ymax; y += 2) {
const float *cur_inptr = inptr + y * ldin + k0;
#define cb(i) const float *inptr##i = cur_inptr + ldin * i;
UNROLL_CODE(cb, 2)
#undef cb
#define cb(i) __builtin_prefetch(inptr##i, 0, 3);
UNROLL_CODE(cb, 2)
#undef cb
int x = kmax - k0;
for (; x > 15; x -= 16) {
if ((y + 1) >= ymax) {
inptr1 = zerobuff;
}
transpose_2x16_1_s(inptr0, inptr1, outptr);
#define cb(i) inptr##i += 16;
UNROLL_CODE(cb, 2)
#undef cb
outptr += PACK_SIZE_32;
}
for (; x > 7; x -= 8) {
if ((y + 1) >= ymax) {
inptr1 = zerobuff;
}
transpose_2x8_1_s(inptr0, inptr1, outptr);
#define cb(i) inptr##i += 8;
UNROLL_CODE(cb, 2)
#undef cb
outptr += PACK_SIZE_16;
}
for (; x > 3; x -= 4) {
if ((y + 1) >= ymax) {
inptr1 = zerobuff;
}
transpose_2x4_1_s(inptr0, inptr1, outptr);
#define cb(i) inptr##i += 4;
UNROLL_CODE(cb, 2)
#undef cb
outptr += PACK_SIZE_8;
}
if (x > 0) {
if ((y + 1) >= ymax) {
inptr1 = zerobuff;
}
for (size_t i = 0; i < x; i++) {
*outptr++ = *inptr0++;
*outptr++ = *inptr1++;
}
}
}
}
void gemm_6x16_pack_B_t(float *outptr, const float *inptr, int ldin, int y0,
int ymax, int k0, int kmax) {
float zerobuff[16];
memset(zerobuff, 0, sizeof(float) * 16);
const size_t PACK_SIZE_128 = 8 * 16;
const size_t PACK_SIZE_64 = 4 * 16;
const size_t PACK_SiZE_32 = 4 * 8;
const size_t PACK_SIZE_16 = 4 * 4;
size_t y = y0;
for (; y + 15 < ymax; y += 16) {
const float *cur_inptr = inptr + y * ldin + k0;
#define cb(i) const float *inptr##i = cur_inptr + ldin * i;
UNROLL_CODE(cb, 16)
#undef cb
#define cb(i) __builtin_prefetch(inptr##i, 0, 3);
UNROLL_CODE(cb, 16)
#undef cb
int x = (kmax - k0);
for (; x > 7; x -= 8) {
transpose_16x8_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, inptr8, inptr9, inptr10, inptr11, inptr12,
inptr13, inptr14, inptr15, outptr);
#define cb(i) inptr##i += 8;
UNROLL_CODE(cb, 16)
#undef cb
outptr += PACK_SIZE_128;
}
for (; x > 3; x -= 4) {
transpose_16x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, inptr8, inptr9, inptr10, inptr11, inptr12,
inptr13, inptr14, inptr15, outptr);
#define cb(i) inptr##i += 4;
UNROLL_CODE(cb, 16)
#undef cb
outptr += PACK_SIZE_64;
}
for (; x > 0; x--) {
#define cb(i) *outptr++ = *inptr##i++;
UNROLL_CODE(cb, 16)
#undef cb
}
}
for (; y < ymax; y += 4) {
const float *cur_inptr = inptr + y * ldin + k0;
#define cb(i) const float *inptr##i = cur_inptr + ldin * i;
UNROLL_CODE(cb, 4)
#undef cb
#define cb(i) __builtin_prefetch(inptr##i, 0, 3);
UNROLL_CODE(cb, 4)
#undef cb
int x = kmax - k0;
for (; x > 7; x -= 8) {
if ((y + 3) >= ymax) {
switch ((y + 3) - ymax) {
case 2:
inptr1 = zerobuff;
case 1:
inptr2 = zerobuff;
case 0:
inptr3 = zerobuff;
default:
break;
}
}
transpose_4x8_1_s(inptr0, inptr1, inptr2, inptr3, outptr);
#define cb(i) inptr##i += 8;
UNROLL_CODE(cb, 4)
#undef cb
outptr += PACK_SiZE_32;
}
for (; x > 3; x -= 4) {
if ((y + 3) >= ymax) {
switch ((y + 3) - ymax) {
case 2:
inptr1 = zerobuff;
case 1:
inptr2 = zerobuff;
case 0:
inptr3 = zerobuff;
default:
break;
}
}
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr);
#define cb(i) inptr##i += 4;
UNROLL_CODE(cb, 4)
#undef cb
outptr += PACK_SIZE_16;
}
if (x > 0) {
if ((y + 3) >= ymax) {
switch ((y + 3) - ymax) {
case 2:
inptr1 = zerobuff;
case 1:
inptr2 = zerobuff;
case 0:
inptr3 = zerobuff;
break;
}
}
for (size_t i = 0; i < x; i++) {
*outptr++ = *inptr0++;
*outptr++ = *inptr1++;
*outptr++ = *inptr2++;
*outptr++ = *inptr3++;
}
}
}
}
void gemm_6x16_pack_B_n(float *outptr, const float *inptr, int ldin, int x0,
int xmax, int k0, int kmax) {
size_t ksize = kmax - k0;
size_t ksize16 = ksize * 16;
size_t ksize4 = ksize * 4;
float *outptr_base16 = outptr;
float *outptr_base4 = outptr_base16 + (xmax - x0) / 16 * ksize16;
size_t k = k0;
for (; k + 7 < kmax; k += 8) {
const float *cur_inptr = inptr + k * ldin + k0;
#define cb(i) const float *inptr##i = cur_inptr + ldin * i;
UNROLL_CODE(cb, 8)
#undef cb
#define cb(i) __builtin_prefetch(inptr##i, 0, 3);
UNROLL_CODE(cb, 8)
#undef cb
int x = x0;
float *outptr = outptr_base16;
for (; x + 16 <= xmax; x += 16) {
interleave_8x16_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
#define cb(i) inptr##i += 16;
UNROLL_CODE(cb, 8)
#undef cb
outptr += ksize16;
}
outptr = outptr_base4;
for (; x + 4 <= xmax; x += 4) {
interleave_8x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr);
#define cb(i) inptr##i += 4;
UNROLL_CODE(cb, 8)
#undef cb
outptr += ksize4;
}
if (x < xmax) {
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 4, xmax - x, 0);
inptr0 += xmax - x;
inptr1 += xmax - x;
inptr2 += xmax - x;
inptr3 += xmax - x;
inptr4 += xmax - x;
inptr5 += xmax - x;
inptr6 += xmax - x;
inptr7 += xmax - x;
}
outptr_base16 += 8 * 16;
outptr_base4 += 8 * 4;
}
for (; k < kmax; k++) {
const float *inptr0 = inptr + k * ldin + k0;
__builtin_prefetch(inptr0, 0, 3);
int x = x0;
float *outptr = outptr_base16;
for (; x + 16 <= xmax; x += 16) {
interleave_1x16_1_s(inptr0, outptr);
inptr0 += 16;
outptr += ksize16;
}
outptr = outptr_base4;
for (; x + 4 <= xmax; x += 4) {
interleave_1x4_1_s(inptr0, outptr);
inptr0 += 4;
outptr += ksize4;
}
if (x < xmax) {
interleave_1(inptr0, outptr, 4, xmax - x, 0);
inptr0 += xmax - x;
outptr += 4;
}
outptr_base16 += 16;
outptr_base4 += 4;
}
}
} // namespace
#undef UNROLL_CODE
namespace megdnn {
namespace x86 {
namespace matmul {
void sgemm_pack_6x16_avx2::pack_A(float *out, const float *in, int ldin, int y0,
int ymax, int k0, int kmax,
bool transpose_A) const {
if (!transpose_A)
gemm_6x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax);
else
gemm_6x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax);
}
void sgemm_pack_6x16_avx2::pack_B(float *out, const float *in, int ldin, int x0,
int xmax, int k0, int kmax,
bool transpose_B) const {
if (!transpose_B)
gemm_6x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
else
gemm_6x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax);
}
void sgemm_pack_6x16_avx2::kern(const float *packA, const float *packB,
size_t M, size_t N, size_t K, float *C,
size_t LDC, bool is_first_k, const float *bias,
float *workspace) const {
MEGDNN_MARK_USED_VAR(bias);
MEGDNN_MARK_USED_VAR(workspace);
gemm_6x16_kern(packA, packB, M, N, K, C, LDC, is_first_k);
};
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_pack_6x16_avx2);
} // namespace matmul
} // namespace x86
} // namespace megdnn
......@@ -34,6 +34,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoInt8x8x16AVX2 algoint8x8x16avx2_m4n16k2;
AlgoInt8x8x16SSE algoint8x8x16sse_m4n8k2;
AlgoF32MK8_8x8 algof32mk8_8x8;
AlgoFloatAVX2M6N16 algof32_6x16;
SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos;
fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map;
......@@ -51,6 +52,7 @@ public:
m_all_algos.emplace_back(&algoint8x8x32sse_m4n8k2);
m_all_algos.emplace_back(&algoint8x8x16sse_m4n8k2);
m_all_algos.emplace_back(&algof32mk8_8x8);
m_all_algos.emplace_back(&algof32_6x16);
#if MEGDNN_X86_WITH_MKL_DNN
m_all_algos.emplace_back(&algoint8x8x32mkldnn);
#endif
......
......@@ -68,7 +68,7 @@ private:
class AlgoInt8x8x16SSE;
class AlgoPack;
class AlgoF32MK8_8x8;
class AlgoFloatAVX2M6N16;
public:
static const AlgoPack& algo_pack();
};
......
......@@ -98,6 +98,15 @@ TEST_F(X86, SHAKE_MATRIX_MUL_FORWARD) {
.exec({{20, 100}, {100, 60}, {}});
}
TEST_F(X86, SHAKE_MATRIX_MUL_6x16_FORWARD) {
AccuracyShakeChecker<MatrixMul> checker(handle());
checker.set_before_exec_callback(AlgoGenerator<MatrixMul>("X86_F32_6x16"));
checker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.exec({{20, 100}, {100, 60}, {}});
}
} // namespace test
} // namespace megdnn
......
......@@ -1171,6 +1171,110 @@ TEST_F(X86, CONV_BIAS_IM2COLMATMUL_FP32_NOPACK_PREPROCESS) {
#endif
TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32_6x16) {
using namespace conv_bias;
std::vector<TestArg> args;
auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel,
size_t p, NonlineMode nonline_mode) {
if (w + 2 * p < kernel || h + 2 * p < kernel)
return;
param::ConvBias param;
param.stride_h = 1;
param.stride_w = 1;
param.pad_h = p;
param.pad_w = p;
param.nonlineMode = nonline_mode;
//! no bias
args.emplace_back(param, TensorShape{1, ic, h, w},
TensorShape{oc, ic, kernel, kernel}, TensorShape{});
args.emplace_back(param, TensorShape{1, ic, h, w},
TensorShape{oc, ic, kernel, kernel},
TensorShape{1, oc, 1, 1});
args.emplace_back(
param, TensorShape{1, ic, h, w},
TensorShape{oc, ic, kernel, kernel},
TensorShape{1, oc, (h + 2 * p - kernel) / param.stride_h + 1,
(w + 2 * p - kernel) / param.stride_w + 1});
};
for (size_t kernel : {2, 3, 4, 5, 6, 7})
for (size_t ic : {1, 4, 8, 16})
for (size_t oc : {1, 4, 8, 16, 300})
for (size_t p : {0, 2})
for (size_t size : {8,24})
for (NonlineMode nonline_mode :
{NonlineMode::IDENTITY, NonlineMode::RELU}) {
run(oc, ic, size, size, kernel, p, nonline_mode);
}
run(2046, 8, 20, 20, 3, 1, NonlineMode::IDENTITY);
Checker<ConvBias> checker(handle());
#define cb(algo_name) \
checker.set_before_exec_callback( \
conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name)); \
for (auto&& arg : args) { \
checker.set_param(arg.param).execs( \
{arg.src, arg.filter, arg.bias, {}, {}}); \
}
cb("IM2COLMATMUL:X86_F32_6x16:192");
}
TEST_F(X86, CONV_BIAS_IM2COLMATMUL_FP32_6x16) {
using namespace conv_bias;
std::vector<TestArg> args;
auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel,
size_t p, NonlineMode nonline_mode) {
if (w + 2 * p < kernel || h + 2 * p < kernel)
return;
param::ConvBias param;
param.stride_h = 1;
param.stride_w = 1;
param.pad_h = p;
param.pad_w = p;
param.nonlineMode = nonline_mode;
//! no bias
args.emplace_back(param, TensorShape{1, ic, h, w},
TensorShape{oc, ic, kernel, kernel}, TensorShape{});
args.emplace_back(param, TensorShape{1, ic, h, w},
TensorShape{oc, ic, kernel, kernel},
TensorShape{1, oc, 1, 1});
args.emplace_back(
param, TensorShape{1, ic, h, w},
TensorShape{oc, ic, kernel, kernel},
TensorShape{1, oc, (h + 2 * p - kernel) / param.stride_h + 1,
(w + 2 * p - kernel) / param.stride_w + 1});
};
for (size_t kernel : {2, 3, 4, 5, 6, 7})
for (size_t ic : {1, 4, 8, 16})
for (size_t oc : {1, 4, 8, 16, 300})
for (size_t p : {0, 2})
for (size_t size : {8, 24})
for (NonlineMode nonline_mode :
{NonlineMode::IDENTITY, NonlineMode::RELU}) {
run(oc, ic, size, size, kernel, p, nonline_mode);
}
run(2046, 8, 20, 20, 3, 1, NonlineMode::IDENTITY);
Checker<ConvBias> checker(handle());
#define cb(algo_name) \
checker.set_before_exec_callback( \
conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name)); \
for (auto&& arg : args) { \
checker.set_param(arg.param).execs( \
{arg.src, arg.filter, arg.bias, {}, {}}); \
}
cb("IM2COLMATMUL:X86_F32_6x16:192");
#undef cb
}
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32_PACKA) {
using namespace conv_bias;
......@@ -1377,6 +1481,12 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_BLAS) {
check_conv_bias(args, handle(), "CONV1x1:X86_F32_BLAS:48");
}
TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_6x16) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
check_conv_bias(args, handle(), "CONV1x1:X86_F32_6x16:48");
}
TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_BLAS_NOPACK_REPROCESS) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
......@@ -2627,6 +2737,76 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_F32) {
shapes_and_computation.clear();
}
TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_F32_6x16) {
constexpr size_t RUNS = 50;
param::ConvBias param;
param.nonlineMode = param::ConvBias::NonlineMode::RELU;
param.pad_h = 1;
param.pad_w = 1;
param.stride_h = 1;
param.stride_w = 1;
std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(),
dtype::Float32(), dtype::Float32()};
std::vector<std::pair<SmallVector<TensorShape>, float>>
shapes_and_computation;
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W,
size_t FS, size_t group) {
SmallVector<TensorShape> shapes{{N, IC, H, W},
{OC / group, IC / group, FS, FS},
{1, OC, 1, 1},
{},
{N, OC, H, W}};
TensorShape dst{N, OC, H, W};
float computations =
((IC / group) * FS * FS * dst.total_nr_elems() * 2 +
dst.total_nr_elems()) *
1e-6;
shapes_and_computation.push_back(std::make_pair(shapes, computations));
};
bench_case(1, 32, 32, 200, 200, 3, 1);
bench_case(1, 32, 32, 200, 200, 3, 1);
bench_case(1, 32, 32, 128, 128, 3, 1);
bench_case(1, 32, 32, 128, 128, 3, 1);
bench_case(1, 32, 32, 100, 100, 3, 1);
bench_case(1, 32, 32, 100, 100, 3, 1);
bench_case(1, 32, 32, 80, 80, 3, 1);
bench_case(1, 32, 32, 80, 80, 3, 1);
bench_case(1, 64, 32, 7, 7, 3, 1);
bench_case(1, 64, 64, 7, 7, 3, 1);
bench_case(1, 64, 128, 7, 7, 3, 1);
bench_case(1, 64, 256, 7, 7, 3, 1);
bench_case(1, 64, 512, 7, 7, 3, 1);
bench_case(1, 64, 1024, 7, 7, 3, 1);
bench_case(1, 64, 32, 14, 14, 3, 1);
bench_case(1, 64, 64, 14, 14, 3, 1);
bench_case(1, 64, 128, 14, 14, 3, 1);
bench_case(1, 64, 256, 14, 14, 3, 1);
bench_case(1, 64, 512, 14, 14, 3, 1);
bench_case(1, 64, 1024, 14, 14, 3, 1);
bench_case(1, 128, 128, 14, 14, 3, 1);
bench_case(1, 128, 256, 14, 14, 3, 1);
bench_case(1, 512, 512, 14, 14, 3, 1);
bench_case(1, 256, 512, 14, 14, 3, 1);
bench_case(1, 512, 1024, 14, 14, 3, 1);
bench_case(1, 1024, 1024, 14, 14, 3, 1);
std::string algo_name = "IM2COLMATMUL:X86_F32_6x16:192";
printf("Benchmark IM2COLMATMUL:X86_F32_6x16 algo\n");
benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
{4, {4, 5, 6, 7}}, {1, {4}}, data_type);
benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
{4, {4, 5, 6, 7}}, {1, {7}}, data_type);
benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}},
{1, {4}}, data_type);
shapes_and_computation.clear();
}
TEST_F(X86_BENCHMARK_MULTI_THREADS,
BENCHMARK_CONVBIAS_IM2COL_F32_single_thread) {
constexpr size_t RUNS = 50;
......@@ -2697,6 +2877,76 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS,
shapes_and_computation.clear();
}
TEST_F(X86_BENCHMARK_MULTI_THREADS,
BENCHMARK_CONVBIAS_IM2COL_F32_6X16_single_thread) {
constexpr size_t RUNS = 50;
param::ConvBias param;
param.nonlineMode = param::ConvBias::NonlineMode::RELU;
param.pad_h = 1;
param.pad_w = 1;
param.stride_h = 1;
param.stride_w = 1;
std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(),
dtype::Float32(), dtype::Float32()};
std::vector<std::pair<SmallVector<TensorShape>, float>>
shapes_and_computation;
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W,
size_t FS, size_t group) {
SmallVector<TensorShape> shapes{{N, IC, H, W},
{OC / group, IC / group, FS, FS},
{1, OC, 1, 1},
{},
{N, OC, H, W}};
TensorShape dst{N, OC, H, W};
float computations =
((IC / group) * FS * FS * dst.total_nr_elems() * 2 +
dst.total_nr_elems()) *
1e-6;
shapes_and_computation.push_back(std::make_pair(shapes, computations));
};
bench_case(1, 32, 32, 200, 200, 3, 1);
bench_case(1, 32, 32, 200, 200, 3, 1);
bench_case(1, 32, 32, 128, 128, 3, 1);
bench_case(1, 32, 32, 128, 128, 3, 1);
bench_case(1, 32, 32, 100, 100, 3, 1);
bench_case(1, 32, 32, 100, 100, 3, 1);
bench_case(1, 32, 32, 80, 80, 3, 1);
bench_case(1, 32, 32, 80, 80, 3, 1);
bench_case(1, 64, 32, 7, 7, 3, 1);
bench_case(1, 64, 64, 7, 7, 3, 1);
bench_case(1, 64, 128, 7, 7, 3, 1);
bench_case(1, 64, 256, 7, 7, 3, 1);
bench_case(1, 64, 512, 7, 7, 3, 1);
bench_case(1, 64, 1024, 7, 7, 3, 1);
bench_case(1, 64, 32, 14, 14, 3, 1);
bench_case(1, 64, 64, 14, 14, 3, 1);
bench_case(1, 64, 128, 14, 14, 3, 1);
bench_case(1, 64, 256, 14, 14, 3, 1);
bench_case(1, 64, 512, 14, 14, 3, 1);
bench_case(1, 64, 1024, 14, 14, 3, 1);
bench_case(1, 128, 128, 14, 14, 3, 1);
bench_case(1, 128, 256, 14, 14, 3, 1);
bench_case(1, 512, 512, 14, 14, 3, 1);
bench_case(1, 256, 512, 14, 14, 3, 1);
bench_case(1, 512, 1024, 14, 14, 3, 1);
bench_case(1, 1024, 1024, 14, 14, 3, 1);
std::string algo_name = "IM2COLMATMUL:X86_F32_MKL_PACKA:192";
std::string algo_name1 = "IM2COLMATMUL:X86_F32_6x16:192";
printf("Benchmark IM2COLMATMUL:X86_F32_6x16 algo\n");
benchmark_impl_comp(param, shapes_and_computation, algo_name, algo_name1,
RUNS, {1, {4}}, {1, {4}}, data_type);
benchmark_impl_comp(param, shapes_and_computation, algo_name, algo_name1,
RUNS, {1, {7}}, {1, {7}}, data_type);
shapes_and_computation.clear();
}
TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_INT8X8X32) {
constexpr size_t RUNS = 50;
......
......@@ -85,6 +85,13 @@ TEST_F(X86, MATRIX_MUL_AVX2_MK8_8X8) {
param::MatrixMul::Format::MK8, 1, 1e-3, false);
}
TEST_F(X86, MATRIX_MUL_AVX2_6x16) {
matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
dtype::Float32{}, handle(), "X86_F32_6x16",
param::MatrixMul::Format::DEFAULT, 1, 1e-3, false);
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(X86, BENCHMARK_MATRIX_MUL_AVX2_MK8_8X8) {
......@@ -96,6 +103,14 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_AVX2_MK8_8X8) {
"X86_F32_BLAS");
}
TEST_F(X86, BENCHMARK_MATRIX_MUL_AVX2_6x16) {
auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8);
matrix_mul::benchmark_with_contrast(
handle(), args, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
"X86_F32_6x16", param::MatrixMul::Format::DEFAULT, dtype::Float32{},
dtype::Float32{}, dtype::Float32{},"X86_F32_BLAS");
}
TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
constexpr size_t RUNS = 50;
auto rng = std::make_unique<UniformIntRNG>(-127, 127);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册