Commit d2184af3, authored Sep 05, 2021 by 泥点无名哥

feat(dnn/src/x86/matmul): add matmul_6x16 for x86

Parent: dbb3232c

Showing 13 changed files with 1632 additions and 5 deletions (+1632, -5)
Files changed:

CMakeLists.txt  (+1, -0)
dnn/src/fallback/conv_bias/im2col/algos.cpp  (+4, -4)
dnn/src/fallback/matrix_mul/opr_impl.h  (+1, -0)
dnn/src/x86/avx_helper.h  (+9, -0)
dnn/src/x86/matrix_mul/algos.cpp  (+70, -0)
dnn/src/x86/matrix_mul/algos.h  (+13, -0)
dnn/src/x86/matrix_mul/f32/strategy.h  (+2, -0)
dnn/src/x86/matrix_mul/f32/strategy_6x16.cpp  (+1255, -0)
dnn/src/x86/matrix_mul/opr_impl.cpp  (+2, -0)
dnn/src/x86/matrix_mul/opr_impl.h  (+1, -1)
dnn/test/x86/accuracy_shake.cpp  (+9, -0)
dnn/test/x86/conv_bias.cpp  (+250, -0)
dnn/test/x86/matrix_mul.cpp  (+15, -0)
CMakeLists.txt
@@ -10,6 +10,7 @@ project(MegEngine LANGUAGES C CXX VERSION ${MGB_VER_STRING})
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)
set(CMAKE_POLICY_DEFAULT_CMP0048 NEW)
dnn/src/fallback/conv_bias/im2col/algos.cpp
@@ -70,16 +70,16 @@ static void choice_ohw_oc_block(
        fallback::MatrixMulImpl::AlgoBase::PackMode pack_mode) {
    //! calculate m_oc_tile_size in choice_ohw_oc_block() function,
    //! when ohw_tile_size < this value ohw_tile_size = ohw
-   static constexpr size_t DEFAULT_OHW_MIN_TILE_SIZE = 32;
+   size_t DEFAULT_OHW_MIN_TILE_SIZE = round_up(32UL, block_n);
    //! when nr_threads > 1 and round(ohw,nr_threads)>nr_threads,
    //! oc_tile_size = DEFAULT_OC_TILE_SIZE
-   static constexpr size_t DEFAULT_OC_TILE_SIZE = 512;
+   size_t DEFAULT_OC_TILE_SIZE = round_up(512UL, block_m);
    //! when oc_tile_size > this value m_oc_tile_size =
    //! DEFAULT_OC_MAX_TILE_SIZE
-   static constexpr size_t DEFAULT_OC_MAX_TILE_SIZE = 1024;
+   size_t DEFAULT_OC_MAX_TILE_SIZE = round_up(1024UL, block_m);
    //! when oc_tile_size < this value oc_tile_size =
    //! DEFAULT_OC_MIN_TILE_SIZE the purpose is aligning the calculation
-   static constexpr size_t DEFAULT_OC_MIN_TILE_SIZE = 128;
+   size_t DEFAULT_OC_MIN_TILE_SIZE = round_up(128UL, block_m);
    size_t nr_threads = param.nr_threads;
    size_t OC = param.filter_meta.ocpg;
    size_t ohw = param.osz[0] * param.osz[1];
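The hunk above replaces compile-time tile-size constants with values rounded up to the matmul block sizes (block_m/block_n are presumably the micro-kernel's block dimensions passed into choice_ohw_oc_block), so a 6x16 kernel always gets tiles that hold whole blocks. For reference, a minimal sketch of the round_up semantics this code relies on — megdnn ships its own utility; this is just the usual definition:

```cpp
#include <cstddef>

// Round v up to the next multiple of align (align > 0).
// E.g. round_up(512, 6) == 516, so an OC tile holds whole 6-row blocks.
size_t round_up(size_t v, size_t align) {
    return (v + align - 1) / align * align;
}
```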
dnn/src/fallback/matrix_mul/opr_impl.h
@@ -122,6 +122,7 @@ public:
    X86_INT8X8X16_SSE,
    X86_INT8X8X32_SSE_4X8X2,
    X86_F32_MK8_8X8,
+   X86_F32_6x16,
    X86_INT8X8X32_VNNI,
    X86_INT8X8X32_MKLDNN,
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
dnn/src/x86/avx_helper.h
@@ -31,6 +31,15 @@ static inline __m256 _mm256_loadu2_m128_emulate(
            _mm_loadu_ps(hiaddr), 1);
}

+MEGDNN_ATTRIBUTE_TARGET("avx")
+static inline void _mm256_storeu2_m128_emulate(
+        float* hiaddr, float* loaddr, __m256 reg) {
+    auto xmm0 = _mm256_extractf128_ps(reg, 0);
+    auto xmm1 = _mm256_extractf128_ps(reg, 1);
+    _mm_storeu_ps(loaddr, xmm0);
+    _mm_storeu_ps(hiaddr, xmm1);
+}
+
template <typename ctype, size_t len>
struct Vector;
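The new helper mirrors the existing load emulation: it splits a 256-bit register into its two 128-bit lanes and stores them to two unrelated addresses, standing in for `_mm256_storeu2_m128` on toolchains that lack it. A small self-contained round-trip check of that behavior, assuming any AVX-capable compiler (the buffer names are ours, not part of the patch):

```cpp
#include <immintrin.h>
#include <cstdio>

int main() {
    float lo[4] = {0.f, 1.f, 2.f, 3.f};
    float hi[4] = {4.f, 5.f, 6.f, 7.f};
    float lo_out[4], hi_out[4];
    // Combine two unrelated 128-bit halves into one ymm register
    // (this is exactly what _mm256_loadu2_m128_emulate does)...
    __m256 reg = _mm256_insertf128_ps(
            _mm256_castps128_ps256(_mm_loadu_ps(lo)), _mm_loadu_ps(hi), 1);
    // ...then scatter the lanes back out, as the new store helper does.
    _mm_storeu_ps(lo_out, _mm256_extractf128_ps(reg, 0));
    _mm_storeu_ps(hi_out, _mm256_extractf128_ps(reg, 1));
    printf("%g %g\n", lo_out[0], hi_out[3]);  // prints: 0 7
    return 0;
}
```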
dnn/src/x86/matrix_mul/algos.cpp
@@ -320,6 +320,35 @@ void gemm_s8s8s32_sse_4x8x2(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_END();
}

void gemm_f32_avx2_6x16(const MatrixMulImpl::KernParam& kern_param) {
    MEGDNN_MARK_USED_VAR(kern_param);
    MIDOUT_BEGIN(megdnn_x86_matmul_kern_avx2_6x16x2, midout_iv(0)) {
        constexpr int cacheline = 64;
        const size_t m = kern_param.M;
        const size_t n = kern_param.N;
        const size_t k = kern_param.K;
        const bool trans_a = kern_param.trA;
        const bool trans_b = kern_param.trB;
        const size_t lda = kern_param.LDA;
        const size_t ldb = kern_param.LDB;
        const size_t ldc = kern_param.LDC;
        auto a_type = kern_param.A_type;
        auto b_type = kern_param.B_type;
        auto c_type = kern_param.C_type;
        const auto a_ptr = kern_param.A<float>();
        const auto b_ptr = kern_param.B<float>();
        auto c_ptr = kern_param.C<float>();
        x86::matmul::sgemm_pack_6x16_avx2 strategy(
                m, n, k, a_type, b_type, c_type);
        megdnn::matmul::GemmInterleaved<x86::matmul::sgemm_pack_6x16_avx2>(
                m, n, k, trans_a, trans_b, strategy, cacheline)
                .execute(a_ptr, lda, b_ptr, ldb, c_ptr, ldc,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // namespace

/*************************AlgoInt8x8x16AVX2********************/
...
@@ -662,4 +691,45 @@ size_t MatrixMulImpl::AlgoF32MK8_8x8::get_workspace(
    MIDOUT_END();
}

/*************************AlgoFloatAVX2M6N16********************/
MatrixMulImpl::kern_t MatrixMulImpl::AlgoFloatAVX2M6N16::get_kern(
        const KernSizeParam&) const {
    return gemm_f32_avx2_6x16;
}

bool MatrixMulImpl::AlgoFloatAVX2M6N16::usable(
        const KernSizeParam& kern_size_param) const {
    bool is_param_ok =
            kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
            ((kern_size_param.A_type.enumv() == DTypeEnum::Float32 &&
              kern_size_param.C_type.enumv() == DTypeEnum::Float32)) &&
            kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
            kern_size_param.format == Param::Format::DEFAULT &&
            is_supported(SIMDType::AVX2);
    return is_param_ok;
}

size_t MatrixMulImpl::AlgoFloatAVX2M6N16::get_workspace(
        const KernSizeParam& kern_param) const {
    constexpr int cacheline = 64;
    const size_t m = kern_param.M;
    const size_t n = kern_param.N;
    const size_t k = kern_param.K;
    const bool trans_a = kern_param.trA;
    const bool trans_b = kern_param.trB;
    auto a_type = kern_param.A_type;
    auto b_type = kern_param.B_type;
    auto c_type = kern_param.C_type;
    x86::matmul::sgemm_pack_6x16_avx2 strategy(m, n, k, a_type, b_type, c_type);
    return megdnn::matmul::GemmInterleaved<x86::matmul::sgemm_pack_6x16_avx2>(
                   m, n, k, trans_a, trans_b, strategy, cacheline)
            .get_workspace_size();
}

MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
        AlgoFloatAVX2M6N16, megdnn_x86_matmul_kern, "AlgoFloatAVX2M6N16"_hash,
        x86::matmul::sgemm_pack_6x16_avx2, float, float, float,
        AlgoDataType::FLOAT32, DEFAULT);
// vim: syntax=cpp.doxygen
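GemmInterleaved is megdnn's generic pack-then-compute driver: it calls the strategy's pack_A/pack_B to copy panels of A and B into the workspace in the micro-kernel's preferred layout, then sweeps the strategy's kernel over the packed panels. A simplified sketch of that flow — this is our reading, not megdnn's actual template, which additionally blocks for the 64-byte cacheline passed above and honors the transpose flags:

```cpp
// Sketch only: one monolithic pack of A and B, then one kernel sweep.
// The workspace split is an illustrative over-estimate; the real
// GemmInterleaved sizes and aligns the two regions per cache block.
void gemm_interleaved_sketch(
        const x86::matmul::sgemm_pack_6x16_avx2& strat, const float* A,
        size_t lda, const float* B, size_t ldb, float* C, size_t ldc, size_t M,
        size_t N, size_t K, float* workspace) {
    float* packA = workspace;
    float* packB = workspace + ((M + 5) / 6 + 1) * 6 * K;  // rough bound
    strat.pack_A(packA, A, lda, 0, M, 0, K, /*transpose_A=*/false);
    strat.pack_B(packB, B, ldb, 0, N, 0, K, /*transpose_B=*/false);
    strat.kern(packA, packB, M, N, K, C, ldc, /*is_first_k=*/true,
               /*bias=*/nullptr, /*workspace=*/nullptr);
}
```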
dnn/src/x86/matrix_mul/algos.h
@@ -149,6 +149,19 @@ public:
    MEGDNN_DECL_ALGO_TYPE(X86_F32_MK8_8X8)
};

class MatrixMulImpl::AlgoFloatAVX2M6N16 : public AlgoBase {
public:
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
    const char* name() const override { return "X86_F32_6x16"; }
    bool usable(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
    MEGDNN_DECL_ALGO_TYPE(X86_F32_6x16)
};

#if MEGDNN_X86_WITH_VNNI
class MatrixMulImpl::AlgoInt8x8x32Vnni : public AlgoBase {
public:
dnn/src/x86/matrix_mul/f32/strategy.h
@@ -19,6 +19,8 @@ namespace matmul {
MEGDNN_REG_GEMM_STRATEGY_NOPACK(
        float, float, float, 8, 8, 8, false, true, sgemm_nopack_8x8_avx2);

+MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(
+        float, float, float, float, 6, 16, 1, false, false,
+        sgemm_pack_6x16_avx2);

}  // namespace matmul
}  // namespace x86
}  // namespace megdnn
\ No newline at end of file
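The numeric arguments in the registration are the block shape: the micro-kernel produces a 6 (M) by 16 (N) tile of C per inner step, and we read the trailing 1 as the K step. As a mental model of what that blocking means — a naive reference, not megdnn's macro expansion:

```cpp
// Naive reference of a 6x16-blocked SGEMM (row-major, no packing, no SIMD),
// just to pin down "block_m = 6, block_n = 16, block_k = 1".
// Assumes C is zero-initialized by the caller.
void sgemm_6x16_reference(
        const float* A, const float* B, float* C, size_t M, size_t N, size_t K,
        size_t lda, size_t ldb, size_t ldc) {
    for (size_t m0 = 0; m0 < M; m0 += 6) {
        for (size_t n0 = 0; n0 < N; n0 += 16) {
            for (size_t k = 0; k < K; ++k) {  // block_k == 1
                for (size_t m = m0; m < m0 + 6 && m < M; ++m) {
                    for (size_t n = n0; n < n0 + 16 && n < N; ++n) {
                        C[m * ldc + n] += A[m * lda + k] * B[k * ldb + n];
                    }
                }
            }
        }
    }
}
```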
dnn/src/x86/matrix_mul/f32/strategy_6x16.cpp  (new file, 0 → 100644)

/**
 * \file dnn/src/x86/matrix_mul/f32/strategy_6x16.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
/**
 * \file dnn/src/x86/matrix_mul/f32/strategy_6x16.cpp
 *
 * This file is part of MegDNN, a deep neural network run-time library
 * developed by Megvii.
 *
 * \copyright copyright (c) 2014-2019 megvii inc. all rights reserved.
 */
#include <immintrin.h>
#include "src/common/utils.h"
#include "src/x86/avx_helper.h"
#include "src/x86/matrix_mul/common/common.h"
#include "src/x86/matrix_mul/f32/strategy.h"
#include "src/common/unroll_macro.h"

using namespace megdnn;
using namespace x86;

#define DNN_AVX2_TARGET
#if !defined(__clang__)
//! bypass gcc bug https://bugs.launchpad.net/ubuntu/+source/gcc-5/+bug/1642109
#pragma GCC target("avx2")
#else
#undef DNN_AVX2_TARGET
#define DNN_AVX2_TARGET MEGDNN_ATTRIBUTE_TARGET("avx2")
#endif

#define UNROLL_CODE(cb, i, a...) UNROLL_CALL1(i, cb, ##a)
namespace {

DNN_AVX2_TARGET
void transpose_16x8_1_s(
        const float* inptr0, const float* inptr1, const float* inptr2,
        const float* inptr3, const float* inptr4, const float* inptr5,
        const float* inptr6, const float* inptr7, const float* inptr8,
        const float* inptr9, const float* inptr10, const float* inptr11,
        const float* inptr12, const float* inptr13, const float* inptr14,
        const float* inptr15, float* outptr) {
    auto ymm0 = _mm256_loadu_ps(inptr0);  // A0A1A2A3A4A5A6A7
    auto ymm1 = _mm256_loadu_ps(inptr1);  // B0B1B2B3B4B5B6B7
    auto ymm2 = _mm256_loadu_ps(inptr2);  // C0C1C2C3C4C5C6C7
    auto ymm3 = _mm256_loadu_ps(inptr3);  // D0D1D2D3D4D5D6D7
    auto ymm4 = _mm256_loadu_ps(inptr4);  // E0E1E2E3E4E5E6E7
    auto ymm5 = _mm256_loadu_ps(inptr5);  // F0F1F2F3F4F5F6F7
    auto ymm6 = _mm256_loadu_ps(inptr6);  // G0G1G2G3G4G5G6G7
    auto ymm7 = _mm256_loadu_ps(inptr7);  // H0H1H2H3H4H5H6H7
    auto ymm8 = _mm256_unpacklo_ps(ymm0, ymm2);   // A0C0A1C1A4C4A5C5
    auto ymm9 = _mm256_unpackhi_ps(ymm0, ymm2);   // A2C2A3C3A6C6A7C7
    auto ymm10 = _mm256_unpacklo_ps(ymm1, ymm3);  // B0D0B1D1B4D4B5D5
    auto ymm11 = _mm256_unpackhi_ps(ymm1, ymm3);  // B2D2B3D3B6D6B7D7
    auto ymm12 = _mm256_unpacklo_ps(ymm4, ymm6);  // E0G0E1G1E4G4E5G5
    auto ymm13 = _mm256_unpackhi_ps(ymm4, ymm6);  // E2G2E3G3E6G6E7G7
    auto ymm14 = _mm256_unpacklo_ps(ymm5, ymm7);  // F0H0F1H1F4H4F5H5
    auto ymm15 = _mm256_unpackhi_ps(ymm5, ymm7);  // F2H2F3H3F6H6F7H7
    ymm0 = _mm256_unpacklo_ps(ymm8, ymm10);   // A0B0C0D0A4B4C4D4
    ymm1 = _mm256_unpackhi_ps(ymm8, ymm10);   // A1B1C1D1A5B5C5D5
    ymm2 = _mm256_unpacklo_ps(ymm9, ymm11);   // A2B2C2D2A6B6C6D6
    ymm3 = _mm256_unpackhi_ps(ymm9, ymm11);   // A3B3C3D3A7B7C7D7
    ymm4 = _mm256_unpacklo_ps(ymm12, ymm14);  // E0F0G0H0E4F4G4H4
    ymm5 = _mm256_unpackhi_ps(ymm12, ymm14);  // E1F1G1H1E5F5G5H5
    ymm6 = _mm256_unpacklo_ps(ymm13, ymm15);  // E2F2G2H2E6F6G6H6
    ymm7 = _mm256_unpackhi_ps(ymm13, ymm15);  // E3F3G3H3E7F7G7H7
    ymm8 = _mm256_permute2f128_ps(ymm0, ymm4, 0x20);   // A0B0C0D0E0F0G0H0
    ymm9 = _mm256_permute2f128_ps(ymm1, ymm5, 0x20);   // A1B1C1D1E1F1G1H1
    ymm10 = _mm256_permute2f128_ps(ymm2, ymm6, 0x20);  // A2B2C2D2E2F2G2H2
    ymm11 = _mm256_permute2f128_ps(ymm3, ymm7, 0x20);  // A3B3C3D3E3F3G3H3
    ymm12 = _mm256_permute2f128_ps(ymm0, ymm4, 0x31);  // A4B4C4D4E4F4G4H4
    ymm13 = _mm256_permute2f128_ps(ymm1, ymm5, 0x31);  // A5B5C5D5E5F5G5H5
    ymm14 = _mm256_permute2f128_ps(ymm2, ymm6, 0x31);  // A6B6C6D6E6F6G6H6
    ymm15 = _mm256_permute2f128_ps(ymm3, ymm7, 0x31);  // A7B7C7D7E7F7G7H7
    _mm256_storeu_ps(outptr + 16 * 0, ymm8);
    _mm256_storeu_ps(outptr + 16 * 1, ymm9);
    _mm256_storeu_ps(outptr + 16 * 2, ymm10);
    _mm256_storeu_ps(outptr + 16 * 3, ymm11);
    _mm256_storeu_ps(outptr + 16 * 4, ymm12);
    _mm256_storeu_ps(outptr + 16 * 5, ymm13);
    _mm256_storeu_ps(outptr + 16 * 6, ymm14);
    _mm256_storeu_ps(outptr + 16 * 7, ymm15);
    //! rows 8..15: the same shuffle sequence, stored into the upper eight
    //! lanes of each 16-wide output row
    ymm0 = _mm256_loadu_ps(inptr8);
    ymm1 = _mm256_loadu_ps(inptr9);
    ymm2 = _mm256_loadu_ps(inptr10);
    ymm3 = _mm256_loadu_ps(inptr11);
    ymm4 = _mm256_loadu_ps(inptr12);
    ymm5 = _mm256_loadu_ps(inptr13);
    ymm6 = _mm256_loadu_ps(inptr14);
    ymm7 = _mm256_loadu_ps(inptr15);
    ymm8 = _mm256_unpacklo_ps(ymm0, ymm2);
    ymm9 = _mm256_unpackhi_ps(ymm0, ymm2);
    ymm10 = _mm256_unpacklo_ps(ymm1, ymm3);
    ymm11 = _mm256_unpackhi_ps(ymm1, ymm3);
    ymm12 = _mm256_unpacklo_ps(ymm4, ymm6);
    ymm13 = _mm256_unpackhi_ps(ymm4, ymm6);
    ymm14 = _mm256_unpacklo_ps(ymm5, ymm7);
    ymm15 = _mm256_unpackhi_ps(ymm5, ymm7);
    ymm0 = _mm256_unpacklo_ps(ymm8, ymm10);
    ymm1 = _mm256_unpackhi_ps(ymm8, ymm10);
    ymm2 = _mm256_unpacklo_ps(ymm9, ymm11);
    ymm3 = _mm256_unpackhi_ps(ymm9, ymm11);
    ymm4 = _mm256_unpacklo_ps(ymm12, ymm14);
    ymm5 = _mm256_unpackhi_ps(ymm12, ymm14);
    ymm6 = _mm256_unpacklo_ps(ymm13, ymm15);
    ymm7 = _mm256_unpackhi_ps(ymm13, ymm15);
    ymm8 = _mm256_permute2f128_ps(ymm0, ymm4, 0x20);
    ymm9 = _mm256_permute2f128_ps(ymm1, ymm5, 0x20);
    ymm10 = _mm256_permute2f128_ps(ymm2, ymm6, 0x20);
    ymm11 = _mm256_permute2f128_ps(ymm3, ymm7, 0x20);
    ymm12 = _mm256_permute2f128_ps(ymm0, ymm4, 0x31);
    ymm13 = _mm256_permute2f128_ps(ymm1, ymm5, 0x31);
    ymm14 = _mm256_permute2f128_ps(ymm2, ymm6, 0x31);
    ymm15 = _mm256_permute2f128_ps(ymm3, ymm7, 0x31);
    _mm256_storeu_ps(outptr + 16 * 0 + 8, ymm8);
    _mm256_storeu_ps(outptr + 16 * 1 + 8, ymm9);
    _mm256_storeu_ps(outptr + 16 * 2 + 8, ymm10);
    _mm256_storeu_ps(outptr + 16 * 3 + 8, ymm11);
    _mm256_storeu_ps(outptr + 16 * 4 + 8, ymm12);
    _mm256_storeu_ps(outptr + 16 * 5 + 8, ymm13);
    _mm256_storeu_ps(outptr + 16 * 6 + 8, ymm14);
    _mm256_storeu_ps(outptr + 16 * 7 + 8, ymm15);
}
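The sequence above is the classic AVX 8x8 float transpose (unpacklo/unpackhi interleave within 128-bit lanes, then permute2f128 swaps lanes), applied twice to cover 16 source rows; element c of row r lands at outptr[c * 16 + r]. A scalar model of that output layout, handy for unit-testing the intrinsics — the function is our own reference, not part of the patch:

```cpp
// Scalar model of transpose_16x8_1_s: 16 rows x 8 columns in,
// written out as 8 rows x 16 columns (row-major, row stride 16).
void transpose_16x8_1_s_reference(const float* const in[16], float* out) {
    for (int r = 0; r < 16; ++r)
        for (int c = 0; c < 8; ++c)
            out[c * 16 + r] = in[r][c];
}
```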
DNN_AVX2_TARGET
void transpose_16x4_1_s(
        const float* inptr0, const float* inptr1, const float* inptr2,
        const float* inptr3, const float* inptr4, const float* inptr5,
        const float* inptr6, const float* inptr7, const float* inptr8,
        const float* inptr9, const float* inptr10, const float* inptr11,
        const float* inptr12, const float* inptr13, const float* inptr14,
        const float* inptr15, float* outptr) {
    const std::uint32_t arr[8] = {0, 1, 4, 5, 2, 3, 6, 7};
    __m256i order = _mm256_loadu_si256((const __m256i*)arr);
    auto ymm0 = _mm256_loadu2_m128_emulate(inptr2, inptr0);  // A0A1A2A3C0C1C2C3
    auto ymm1 = _mm256_loadu2_m128_emulate(inptr3, inptr1);  // B0B1B2B3D0D1D2D3
    auto ymm2 = _mm256_loadu2_m128_emulate(inptr6, inptr4);  // E0E1E2E3G0G1G2G3
    auto ymm3 = _mm256_loadu2_m128_emulate(inptr7, inptr5);  // F0F1F2F3H0H1H2H3
    auto ymm4 = _mm256_unpacklo_ps(ymm0, ymm1);  // A0B0A1B1C0D0C1D1
    auto ymm5 = _mm256_unpackhi_ps(ymm0, ymm1);  // A2B2A3B3C2D2C3D3
    auto ymm6 = _mm256_unpacklo_ps(ymm2, ymm3);  // E0F0E1F1G0H0G1H1
    auto ymm7 = _mm256_unpackhi_ps(ymm2, ymm3);  // E2F2E3F3G2H2G3H3
    auto ymm8 = _mm256_permutevar8x32_ps(ymm4, order);   // A0B0C0D0A1B1C1D1
    auto ymm9 = _mm256_permutevar8x32_ps(ymm5, order);   // A2B2C2D2A3B3C3D3
    auto ymm10 = _mm256_permutevar8x32_ps(ymm6, order);  // E0F0G0H0E1F1G1H1
    auto ymm11 = _mm256_permutevar8x32_ps(ymm7, order);  // E2F2G2H2E3F3G3H3
    ymm0 = _mm256_permute2f128_ps(ymm8, ymm10, 0x20);  // A0B0C0D0E0F0G0H0
    ymm1 = _mm256_permute2f128_ps(ymm8, ymm10, 0x31);  // A1B1C1D1E1F1G1H1
    ymm2 = _mm256_permute2f128_ps(ymm9, ymm11, 0x20);  // A2B2C2D2E2F2G2H2
    ymm3 = _mm256_permute2f128_ps(ymm9, ymm11, 0x31);  // A3B3C3D3E3F3G3H3
    _mm256_storeu_ps(outptr + 16 * 0, ymm0);
    _mm256_storeu_ps(outptr + 16 * 1, ymm1);
    _mm256_storeu_ps(outptr + 16 * 2, ymm2);
    _mm256_storeu_ps(outptr + 16 * 3, ymm3);
    //! rows 8..15: same shuffle, stored into the upper eight lanes
    ymm0 = _mm256_loadu2_m128_emulate(inptr10, inptr8);
    ymm1 = _mm256_loadu2_m128_emulate(inptr11, inptr9);
    ymm2 = _mm256_loadu2_m128_emulate(inptr14, inptr12);
    ymm3 = _mm256_loadu2_m128_emulate(inptr15, inptr13);
    ymm4 = _mm256_unpacklo_ps(ymm0, ymm1);
    ymm5 = _mm256_unpackhi_ps(ymm0, ymm1);
    ymm6 = _mm256_unpacklo_ps(ymm2, ymm3);
    ymm7 = _mm256_unpackhi_ps(ymm2, ymm3);
    ymm8 = _mm256_permutevar8x32_ps(ymm4, order);
    ymm9 = _mm256_permutevar8x32_ps(ymm5, order);
    ymm10 = _mm256_permutevar8x32_ps(ymm6, order);
    ymm11 = _mm256_permutevar8x32_ps(ymm7, order);
    ymm0 = _mm256_permute2f128_ps(ymm8, ymm10, 0x20);
    ymm1 = _mm256_permute2f128_ps(ymm8, ymm10, 0x31);
    ymm2 = _mm256_permute2f128_ps(ymm9, ymm11, 0x20);
    ymm3 = _mm256_permute2f128_ps(ymm9, ymm11, 0x31);
    _mm256_storeu_ps(outptr + 16 * 0 + 8, ymm0);
    _mm256_storeu_ps(outptr + 16 * 1 + 8, ymm1);
    _mm256_storeu_ps(outptr + 16 * 2 + 8, ymm2);
    _mm256_storeu_ps(outptr + 16 * 3 + 8, ymm3);
}

static size_t min(size_t a, size_t b) {
    return a > b ? b : a;
}
DNN_AVX2_TARGET
void transpose_6x16_1_s(
        const float* inptr0, const float* inptr1, const float* inptr2,
        const float* inptr3, const float* inptr4, const float* inptr5,
        float* outptr) {
    auto ymm0 = _mm256_loadu_ps(inptr0 + 0);  // A0A1A2A3A4A5A6A7
    auto ymm1 = _mm256_loadu_ps(inptr0 + 8);  // a0a1a2a3a4a5a6a7
    auto ymm2 = _mm256_loadu_ps(inptr1 + 0);  // B0B1B2B3B4B5B6B7
    auto ymm3 = _mm256_loadu_ps(inptr1 + 8);  // b0b1b2b3b4b5b6b7
    auto ymm4 = _mm256_loadu_ps(inptr2 + 0);  // C0C1C2C3C4C5C6C7
    auto ymm5 = _mm256_loadu_ps(inptr2 + 8);  // c0c1c2c3c4c5c6c7
    auto ymm6 = _mm256_loadu_ps(inptr3 + 0);  // D0D1D2D3D4D5D6D7
    auto ymm7 = _mm256_loadu_ps(inptr3 + 8);  // d0d1d2d3d4d5d6d7
    auto ymm8 = _mm256_unpacklo_ps(ymm0, ymm4);   // A0C0A1C1A4C4A5C5
    auto ymm9 = _mm256_unpackhi_ps(ymm0, ymm4);   // A2C2A3C3A6C6A7C7
    auto ymm10 = _mm256_unpacklo_ps(ymm2, ymm6);  // B0D0B1D1B4D4B5D5
    auto ymm11 = _mm256_unpackhi_ps(ymm2, ymm6);  // B2D2B3D3B6D6B7D7
    auto ymm12 = _mm256_unpacklo_ps(ymm1, ymm5);  // a0c0a1c1a4c4a5c5
    auto ymm13 = _mm256_unpackhi_ps(ymm1, ymm5);  // a2c2a3c3a6c6a7c7
    auto ymm14 = _mm256_unpacklo_ps(ymm3, ymm7);  // b0d0b1d1b4d4b5d5
    auto ymm15 = _mm256_unpackhi_ps(ymm3, ymm7);  // b2d2b3d3b6d6b7d7
    ymm0 = _mm256_unpacklo_ps(ymm8, ymm10);   // A0B0C0D0A4B4C4D4
    ymm1 = _mm256_unpackhi_ps(ymm8, ymm10);   // A1B1C1D1A5B5C5D5
    ymm2 = _mm256_unpacklo_ps(ymm9, ymm11);   // A2B2C2D2A6B6C6D6
    ymm3 = _mm256_unpackhi_ps(ymm9, ymm11);   // A3B3C3D3A7B7C7D7
    ymm4 = _mm256_unpacklo_ps(ymm12, ymm14);  // a0b0c0d0a4b4c4d4
    ymm5 = _mm256_unpackhi_ps(ymm12, ymm14);  // a1b1c1d1a5b5c5d5
    ymm6 = _mm256_unpacklo_ps(ymm13, ymm15);  // a2b2c2d2a6b6c6d6
    ymm7 = _mm256_unpackhi_ps(ymm13, ymm15);  // a3b3c3d3a7b7c7d7
    _mm256_storeu2_m128_emulate(outptr + 6 * 4, outptr + 6 * 0, ymm0);
    _mm256_storeu2_m128_emulate(outptr + 6 * 5, outptr + 6 * 1, ymm1);
    _mm256_storeu2_m128_emulate(outptr + 6 * 6, outptr + 6 * 2, ymm2);
    _mm256_storeu2_m128_emulate(outptr + 6 * 7, outptr + 6 * 3, ymm3);
    _mm256_storeu2_m128_emulate(outptr + 6 * 12, outptr + 6 * 8, ymm4);
    _mm256_storeu2_m128_emulate(outptr + 6 * 13, outptr + 6 * 9, ymm5);
    _mm256_storeu2_m128_emulate(outptr + 6 * 14, outptr + 6 * 10, ymm6);
    _mm256_storeu2_m128_emulate(outptr + 6 * 15, outptr + 6 * 11, ymm7);
    float other[4 * 8];
    ymm8 = _mm256_loadu_ps(inptr4 + 0);   // E0E1E2E3E4E5E6E7
    ymm9 = _mm256_loadu_ps(inptr4 + 8);   // e0e1e2e3e4e5e6e7
    ymm10 = _mm256_loadu_ps(inptr5 + 0);  // F0F1F2F3F4F5F6F7
    ymm11 = _mm256_loadu_ps(inptr5 + 8);  // f0f1f2f3f4f5f6f7
    _mm256_storeu_ps(other, ymm8);
    _mm256_storeu_ps(other + 8, ymm9);
    _mm256_storeu_ps(other + 16, ymm10);
    _mm256_storeu_ps(other + 24, ymm11);
    for (size_t i = 0; i < 16; i++) {
        outptr[6 * i + 4] = other[i];
        outptr[6 * i + 5] = other[i + 16];
    }
}
DNN_AVX2_TARGET
void transpose_6x8_1_s(
        const float* inptr0, const float* inptr1, const float* inptr2,
        const float* inptr3, const float* inptr4, const float* inptr5,
        float* outptr) {
    auto ymm0 = _mm256_loadu_ps(inptr0);  // A0A1A2A3A4A5A6A7
    auto ymm1 = _mm256_loadu_ps(inptr1);  // B0B1B2B3B4B5B6B7
    auto ymm2 = _mm256_loadu_ps(inptr2);  // C0C1C2C3C4C5C6C7
    auto ymm3 = _mm256_loadu_ps(inptr3);  // D0D1D2D3D4D5D6D7
    auto ymm4 = _mm256_unpacklo_ps(ymm0, ymm2);  // A0C0A1C1A4C4A5C5
    auto ymm5 = _mm256_unpackhi_ps(ymm0, ymm2);  // A2C2A3C3A6C6A7C7
    auto ymm6 = _mm256_unpacklo_ps(ymm1, ymm3);  // B0D0B1D1B4D4B5D5
    auto ymm7 = _mm256_unpackhi_ps(ymm1, ymm3);  // B2D2B3D3B6D6B7D7
    auto ymm8 = _mm256_unpacklo_ps(ymm4, ymm6);   // A0B0C0D0A4B4C4D4
    auto ymm9 = _mm256_unpackhi_ps(ymm4, ymm6);   // A1B1C1D1A5B5C5D5
    auto ymm10 = _mm256_unpacklo_ps(ymm5, ymm7);  // A2B2C2D2A6B6C6D6
    auto ymm11 = _mm256_unpackhi_ps(ymm5, ymm7);  // A3B3C3D3A7B7C7D7
    _mm256_storeu2_m128_emulate(outptr + 6 * 4, outptr + 6 * 0, ymm8);
    _mm256_storeu2_m128_emulate(outptr + 6 * 5, outptr + 6 * 1, ymm9);
    _mm256_storeu2_m128_emulate(outptr + 6 * 6, outptr + 6 * 2, ymm10);
    _mm256_storeu2_m128_emulate(outptr + 6 * 7, outptr + 6 * 3, ymm11);
    float other[16];
    auto ymm12 = _mm256_loadu_ps(inptr4);  // E0E1E2E3E4E5E6E7
    auto ymm13 = _mm256_loadu_ps(inptr5);  // F0F1F2F3F4F5F6F7
    _mm256_storeu_ps(other, ymm12);
    _mm256_storeu_ps(other + 8, ymm13);
    for (size_t i = 0; i < 8; i++) {
        outptr[6 * i + 4] = other[i];
        outptr[6 * i + 5] = other[8 + i];
    }
}
DNN_AVX2_TARGET
void transpose_6x4_1_s(
        const float* inptr0, const float* inptr1, const float* inptr2,
        const float* inptr3, const float* inptr4, const float* inptr5,
        float* outptr) {
    const std::uint32_t arr[8] = {0, 1, 4, 5, 2, 3, 6, 7};
    __m256i order = _mm256_loadu_si256((const __m256i*)arr);
    auto ymm0 = _mm256_loadu2_m128_emulate(inptr2, inptr0);  // A0A1A2A3C0C1C2C3
    auto ymm1 = _mm256_loadu2_m128_emulate(inptr3, inptr1);  // B0B1B2B3D0D1D2D3
    auto ymm2 = _mm256_unpacklo_ps(ymm0, ymm1);  // A0B0A1B1C0D0C1D1
    auto ymm3 = _mm256_unpackhi_ps(ymm0, ymm1);  // A2B2A3B3C2D2C3D3
    auto ymm4 = _mm256_permutevar8x32_ps(ymm2, order);  // A0B0C0D0A1B1C1D1
    auto ymm5 = _mm256_permutevar8x32_ps(ymm3, order);  // A2B2C2D2A3B3C3D3
    _mm256_storeu2_m128_emulate(outptr + 6 * 1, outptr + 6 * 0, ymm4);
    _mm256_storeu2_m128_emulate(outptr + 6 * 3, outptr + 6 * 2, ymm5);
    float other[8];
    auto ymm6 = _mm256_loadu2_m128_emulate(inptr5, inptr4);  // E0E1E2E3F0F1F2F3
    _mm256_storeu_ps(other, ymm6);
    for (size_t i = 0; i < 4; i++) {
        outptr[6 * i + 4] = other[i];
        outptr[6 * i + 5] = other[4 + i];
    }
}
DNN_AVX2_TARGET
void transpose_4x8_1_s(
        const float* inptr0, const float* inptr1, const float* inptr2,
        const float* inptr3, float* outptr) {
    auto ymm0 = _mm256_loadu_ps(inptr0);  // A0A1A2A3A4A5A6A7
    auto ymm1 = _mm256_loadu_ps(inptr1);  // B0B1B2B3B4B5B6B7
    auto ymm2 = _mm256_loadu_ps(inptr2);  // C0C1C2C3C4C5C6C7
    auto ymm3 = _mm256_loadu_ps(inptr3);  // D0D1D2D3D4D5D6D7
    auto ymm4 = _mm256_unpacklo_ps(ymm0, ymm2);  // A0C0A1C1A4C4A5C5
    auto ymm5 = _mm256_unpackhi_ps(ymm0, ymm2);  // A2C2A3C3A6C6A7C7
    auto ymm6 = _mm256_unpacklo_ps(ymm1, ymm3);  // B0D0B1D1B4D4B5D5
    auto ymm7 = _mm256_unpackhi_ps(ymm1, ymm3);  // B2D2B3D3B6D6B7D7
    auto ymm8 = _mm256_unpacklo_ps(ymm4, ymm6);   // A0B0C0D0A4B4C4D4
    auto ymm9 = _mm256_unpackhi_ps(ymm4, ymm6);   // A1B1C1D1A5B5C5D5
    auto ymm10 = _mm256_unpacklo_ps(ymm5, ymm7);  // A2B2C2D2A6B6C6D6
    auto ymm11 = _mm256_unpackhi_ps(ymm5, ymm7);  // A3B3C3D3A7B7C7D7
    ymm0 = _mm256_permute2f128_ps(ymm8, ymm9, 0x20);    // A0B0C0D0A1B1C1D1
    ymm1 = _mm256_permute2f128_ps(ymm10, ymm11, 0x20);  // A2B2C2D2A3B3C3D3
    ymm2 = _mm256_permute2f128_ps(ymm8, ymm9, 0x31);    // A4B4C4D4A5B5C5D5
    ymm3 = _mm256_permute2f128_ps(ymm10, ymm11, 0x31);  // A6B6C6D6A7B7C7D7
    _mm256_storeu_ps(outptr + 8 * 0, ymm0);
    _mm256_storeu_ps(outptr + 8 * 1, ymm1);
    _mm256_storeu_ps(outptr + 8 * 2, ymm2);
    _mm256_storeu_ps(outptr + 8 * 3, ymm3);
}

DNN_AVX2_TARGET
void transpose_4x4_1_s(
        const float* inptr0, const float* inptr1, const float* inptr2,
        const float* inptr3, float* outptr) {
    const std::uint32_t arr[8] = {0, 1, 4, 5, 2, 3, 6, 7};
    __m256i order = _mm256_loadu_si256((const __m256i*)arr);
    auto ymm0 = _mm256_loadu2_m128_emulate(inptr2, inptr0);  // A0A1A2A3C0C1C2C3
    auto ymm1 = _mm256_loadu2_m128_emulate(inptr3, inptr1);  // B0B1B2B3D0D1D2D3
    auto ymm2 = _mm256_unpacklo_ps(ymm0, ymm1);  // A0B0A1B1C0D0C1D1
    auto ymm3 = _mm256_unpackhi_ps(ymm0, ymm1);  // A2B2A3B3C2D2C3D3
    auto ymm4 = _mm256_permutevar8x32_ps(ymm2, order);  // A0B0C0D0A1B1C1D1
    auto ymm5 = _mm256_permutevar8x32_ps(ymm3, order);  // A2B2C2D2A3B3C3D3
    _mm256_storeu_ps(outptr, ymm4);
    _mm256_storeu_ps(outptr + 8, ymm5);
}
void transpose_2x16_1_s(const float* inptr0, const float* inptr1, float* outptr) {
    for (size_t i = 0; i < 16; i++) {
        *outptr++ = inptr0[i];
        *outptr++ = inptr1[i];
    }
}

void transpose_2x8_1_s(const float* inptr0, const float* inptr1, float* outptr) {
    for (size_t i = 0; i < 8; i++) {
        *outptr++ = inptr0[i];
        *outptr++ = inptr1[i];
    }
}

void transpose_2x4_1_s(const float* inptr0, const float* inptr1, float* outptr) {
    for (size_t i = 0; i < 4; i++) {
        *outptr++ = inptr0[i];
        *outptr++ = inptr1[i];
    }
}
DNN_AVX2_TARGET
void interleave_1x16_1_s(const float* inptr0, float* outptr) {
    auto ymm0 = _mm256_loadu_ps(inptr0);
    auto ymm1 = _mm256_loadu_ps(inptr0 + 8);
    _mm256_storeu_ps(outptr, ymm0);
    _mm256_storeu_ps(outptr + 8, ymm1);
}

DNN_AVX2_TARGET
void interleave_8x16_1_s(
        const float* inptr0, const float* inptr1, const float* inptr2,
        const float* inptr3, const float* inptr4, const float* inptr5,
        const float* inptr6, const float* inptr7, float* outptr) {
    auto ymm0 = _mm256_loadu_ps(inptr0);
    auto ymm1 = _mm256_loadu_ps(inptr0 + 8);
    auto ymm2 = _mm256_loadu_ps(inptr1);
    auto ymm3 = _mm256_loadu_ps(inptr1 + 8);
    auto ymm4 = _mm256_loadu_ps(inptr2);
    auto ymm5 = _mm256_loadu_ps(inptr2 + 8);
    auto ymm6 = _mm256_loadu_ps(inptr3);
    auto ymm7 = _mm256_loadu_ps(inptr3 + 8);
    auto ymm8 = _mm256_loadu_ps(inptr4);
    auto ymm9 = _mm256_loadu_ps(inptr4 + 8);
    auto ymm10 = _mm256_loadu_ps(inptr5);
    auto ymm11 = _mm256_loadu_ps(inptr5 + 8);
    auto ymm12 = _mm256_loadu_ps(inptr6);
    auto ymm13 = _mm256_loadu_ps(inptr6 + 8);
    auto ymm14 = _mm256_loadu_ps(inptr7);
    auto ymm15 = _mm256_loadu_ps(inptr7 + 8);
    _mm256_storeu_ps(outptr + 8 * 0, ymm0);
    _mm256_storeu_ps(outptr + 8 * 1, ymm1);
    _mm256_storeu_ps(outptr + 8 * 2, ymm2);
    _mm256_storeu_ps(outptr + 8 * 3, ymm3);
    _mm256_storeu_ps(outptr + 8 * 4, ymm4);
    _mm256_storeu_ps(outptr + 8 * 5, ymm5);
    _mm256_storeu_ps(outptr + 8 * 6, ymm6);
    _mm256_storeu_ps(outptr + 8 * 7, ymm7);
    _mm256_storeu_ps(outptr + 8 * 8, ymm8);
    _mm256_storeu_ps(outptr + 8 * 9, ymm9);
    _mm256_storeu_ps(outptr + 8 * 10, ymm10);
    _mm256_storeu_ps(outptr + 8 * 11, ymm11);
    _mm256_storeu_ps(outptr + 8 * 12, ymm12);
    _mm256_storeu_ps(outptr + 8 * 13, ymm13);
    _mm256_storeu_ps(outptr + 8 * 14, ymm14);
    _mm256_storeu_ps(outptr + 8 * 15, ymm15);
}
DNN_AVX2_TARGET
void interleave_8x4_1_s(
        const float* inptr0, const float* inptr1, const float* inptr2,
        const float* inptr3, const float* inptr4, const float* inptr5,
        const float* inptr6, const float* inptr7, float* outptr) {
    auto ymm0 = _mm256_loadu2_m128_emulate(inptr1, inptr0);  // A0A1A2A3B0B1B2B3
    auto ymm1 = _mm256_loadu2_m128_emulate(inptr3, inptr2);  // C0C1C2C3D0D1D2D3
    auto ymm2 = _mm256_loadu2_m128_emulate(inptr5, inptr4);  // E0E1E2E3F0F1F2F3
    auto ymm3 = _mm256_loadu2_m128_emulate(inptr7, inptr6);  // G0G1G2G3H0H1H2H3
    _mm256_storeu_ps(outptr + 8 * 0, ymm0);
    _mm256_storeu_ps(outptr + 8 * 1, ymm1);
    _mm256_storeu_ps(outptr + 8 * 2, ymm2);
    _mm256_storeu_ps(outptr + 8 * 3, ymm3);
}

void interleave_8x2_1_s(
        const float* inptr0, const float* inptr1, const float* inptr2,
        const float* inptr3, const float* inptr4, const float* inptr5,
        const float* inptr6, const float* inptr7, float* outptr) {
#define cb(i)                \
    *outptr++ = inptr##i[0]; \
    *outptr++ = inptr##i[1];
    UNROLL_CODE(cb, 8)
#undef cb
}

void interleave_1x4_1_s(const float* inptr0, float* outptr) {
    outptr[0] = inptr0[0];
    outptr[1] = inptr0[1];
    outptr[2] = inptr0[2];
    outptr[3] = inptr0[3];
}

void interleave_8x6_1_s(
        const float* inptr0, const float* inptr1, const float* inptr2,
        const float* inptr3, const float* inptr4, const float* inptr5,
        const float* inptr6, const float* inptr7, float* outptr) {
#define cb(i) auto xmm##i = _mm_loadu_ps(inptr##i);
    UNROLL_CODE(cb, 8)
#undef cb
#define cb(i) _mm_storeu_ps(outptr + 6 * i, xmm##i);
    UNROLL_CODE(cb, 8)
#undef cb
#define cb(i)                        \
    outptr[6 * i + 4] = inptr##i[4]; \
    outptr[6 * i + 5] = inptr##i[5];
    UNROLL_CODE(cb, 8)
#undef cb
}

void interleave_1x6_1_s(const float* inptr0, float* outptr) {
    outptr[0] = inptr0[0];
    outptr[1] = inptr0[1];
    outptr[2] = inptr0[2];
    outptr[3] = inptr0[3];
    outptr[4] = inptr0[4];
    outptr[5] = inptr0[5];
}

void interleave_1x2_1_s(const float* inptr0, float* outptr) {
    outptr[0] = inptr0[0];
    outptr[1] = inptr0[1];
}

static inline void interleave_helper(
        const float* inptr, float* outptr, int unroll_k, int ksize, float val) {
    int k = 0;
    for (; k < ksize; k++) {
        *outptr++ = *inptr++;
    }
    for (; k < unroll_k; k++) {
        *outptr++ = val;
    }
}
void interleave_1(
        const float* inptr0, float* outptr, int unroll_k, int ksize, float val) {
    for (int k = 0; k < ksize; k += unroll_k) {
        int size = min(unroll_k, ksize - k);
        interleave_helper(inptr0, outptr, unroll_k, size, val);
        inptr0 += size;
        outptr += unroll_k;
    }
}

void interleave_8(
        const float* inptr0, const float* inptr1, const float* inptr2,
        const float* inptr3, const float* inptr4, const float* inptr5,
        const float* inptr6, const float* inptr7, float* outptr, int unroll_k,
        int ksize, float val) {
    for (int k = 0; k < ksize; k += unroll_k) {
        int size = min(unroll_k, ksize - k);
        interleave_helper(inptr0, outptr, unroll_k, size, val);
        inptr0 += size;
        outptr += unroll_k;
        interleave_helper(inptr1, outptr, unroll_k, size, val);
        inptr1 += size;
        outptr += unroll_k;
        interleave_helper(inptr2, outptr, unroll_k, size, val);
        inptr2 += size;
        outptr += unroll_k;
        interleave_helper(inptr3, outptr, unroll_k, size, val);
        inptr3 += size;
        outptr += unroll_k;
        interleave_helper(inptr4, outptr, unroll_k, size, val);
        inptr4 += size;
        outptr += unroll_k;
        interleave_helper(inptr5, outptr, unroll_k, size, val);
        inptr5 += size;
        outptr += unroll_k;
        interleave_helper(inptr6, outptr, unroll_k, size, val);
        inptr6 += size;
        outptr += unroll_k;
        interleave_helper(inptr7, outptr, unroll_k, size, val);
        inptr7 += size;
        outptr += unroll_k;
    }
}
DNN_AVX2_TARGET
MEGDNN_ATTRIBUTE_TARGET("fma")
void gemm_6x16_kern2x16(
        const float* packA, const float* packB, int K, float* output, int LDC,
        bool is_first_k, int m_remain) {
    const float* cur_b = packB;
    const float* cur_a = packA;
    __m256 ymm0, ymm1, ymm2, ymm3;
    __m256 b_tmp0, b_tmp1;
    __m256 tmp;
    if (is_first_k) {
#define cb(i) ymm##i = _mm256_set1_ps(0.0f);
        UNROLL_CODE(cb, 4)
#undef cb
    } else {
        ymm0 = _mm256_loadu_ps(output + LDC * 0 + 0);
        ymm1 = _mm256_loadu_ps(output + LDC * 0 + 8);
        ymm2 = _mm256_loadu_ps(output + LDC * 1 + 0);
        ymm3 = _mm256_loadu_ps(output + LDC * 1 + 8);
    }
    b_tmp0 = _mm256_loadu_ps(cur_b);
    b_tmp1 = _mm256_loadu_ps(cur_b + 8);
    size_t i = 0;
    for (; i + 2 <= K; i += 2) {
        cur_b += 16;
#define CAL_OUPUT(i, first, second)                        \
    tmp = _mm256_broadcast_ss(cur_a + i);                  \
    ymm##first = _mm256_fmadd_ps(b_tmp0, tmp, ymm##first); \
    ymm##second = _mm256_fmadd_ps(b_tmp1, tmp, ymm##second);
        CAL_OUPUT(0, 0, 1)
        CAL_OUPUT(1, 2, 3)
        b_tmp0 = _mm256_loadu_ps(cur_b);
        b_tmp1 = _mm256_loadu_ps(cur_b + 8);
        cur_b += 16;
        CAL_OUPUT(2, 0, 1)
        CAL_OUPUT(3, 2, 3)
        cur_a += 4;
        b_tmp0 = _mm256_loadu_ps(cur_b);
        b_tmp1 = _mm256_loadu_ps(cur_b + 8);
    }
    if (i < K) {
        CAL_OUPUT(0, 0, 1)
        CAL_OUPUT(1, 2, 3)
    }
#undef CAL_OUPUT
    switch (m_remain) {
        case 2:
            _mm256_storeu_ps(output + LDC * 1 + 0, ymm2);
            _mm256_storeu_ps(output + LDC * 1 + 8, ymm3);
            // fallthrough: a 2-row strip also stores row 0
        case 1:
            _mm256_storeu_ps(output + LDC * 0 + 0, ymm0);
            _mm256_storeu_ps(output + LDC * 0 + 8, ymm1);
        default:
            break;
    }
}
DNN_AVX2_TARGET
MEGDNN_ATTRIBUTE_TARGET("fma")
void gemm_6x16_kern6x4(
        const float* packA, const float* packB, int K, float* output, int LDC,
        bool is_first_k, int n_remain) {
    const float* cur_b = packB;
    const float* cur_a = packA;
    __m128 xmm0, xmm1, xmm2, xmm3, xmm4, xmm5;
    __m128 tmp_a, tmp_b;
    if (is_first_k) {
        xmm0 = _mm_set1_ps(0.0f);
        xmm1 = _mm_set1_ps(0.0f);
        xmm2 = _mm_set1_ps(0.0f);
        xmm3 = _mm_set1_ps(0.0f);
        xmm4 = _mm_set1_ps(0.0f);
        xmm5 = _mm_set1_ps(0.0f);
    } else {
        xmm0 = _mm_loadu_ps(output + LDC * 0);
        xmm1 = _mm_loadu_ps(output + LDC * 1);
        xmm2 = _mm_loadu_ps(output + LDC * 2);
        xmm3 = _mm_loadu_ps(output + LDC * 3);
        xmm4 = _mm_loadu_ps(output + LDC * 4);
        xmm5 = _mm_loadu_ps(output + LDC * 5);
    }
    for (size_t i = 0; i < K; i++) {
        tmp_b = _mm_loadu_ps(cur_b);
        cur_b += 4;
        tmp_a = _mm_broadcast_ss(cur_a);
        xmm0 = _mm_fmadd_ps(tmp_a, tmp_b, xmm0);
        tmp_a = _mm_broadcast_ss(cur_a + 1);
        xmm1 = _mm_fmadd_ps(tmp_a, tmp_b, xmm1);
        tmp_a = _mm_broadcast_ss(cur_a + 2);
        xmm2 = _mm_fmadd_ps(tmp_a, tmp_b, xmm2);
        tmp_a = _mm_broadcast_ss(cur_a + 3);
        xmm3 = _mm_fmadd_ps(tmp_a, tmp_b, xmm3);
        tmp_a = _mm_broadcast_ss(cur_a + 4);
        xmm4 = _mm_fmadd_ps(tmp_a, tmp_b, xmm4);
        tmp_a = _mm_broadcast_ss(cur_a + 5);
        xmm5 = _mm_fmadd_ps(tmp_a, tmp_b, xmm5);
        cur_a += 6;
    }
    if (n_remain == 4) {
        _mm_storeu_ps(output + LDC * 0, xmm0);
        _mm_storeu_ps(output + LDC * 1, xmm1);
        _mm_storeu_ps(output + LDC * 2, xmm2);
        _mm_storeu_ps(output + LDC * 3, xmm3);
        _mm_storeu_ps(output + LDC * 4, xmm4);
        _mm_storeu_ps(output + LDC * 5, xmm5);
    } else {
        float dst[6 * 4];
        _mm_storeu_ps(dst + 4 * 0, xmm0);
        _mm_storeu_ps(dst + 4 * 1, xmm1);
        _mm_storeu_ps(dst + 4 * 2, xmm2);
        _mm_storeu_ps(dst + 4 * 3, xmm3);
        _mm_storeu_ps(dst + 4 * 4, xmm4);
        _mm_storeu_ps(dst + 4 * 5, xmm5);
        for (size_t i = 0; i < n_remain; i++) {
            for (size_t j = 0; j < 6; j++) {
                output[LDC * j + i] = dst[4 * j + i];
            }
        }
    }
}
DNN_AVX2_TARGET
MEGDNN_ATTRIBUTE_TARGET("fma")
void gemm_6x16_kern2x4(
        const float* packA, const float* packB, int K, float* output, int LDC,
        bool is_first_k, int m_remain, int n_remain) {
    const float* cur_b = packB;
    const float* cur_a = packA;
    __m128 xmm0, xmm1;
    __m128 tmp_a, tmp_b;
    if (is_first_k) {
        xmm0 = _mm_set1_ps(0.0f);
        xmm1 = _mm_set1_ps(0.0f);
    } else {
        xmm0 = _mm_loadu_ps(output + LDC * 0);
        xmm1 = _mm_loadu_ps(output + LDC * 1);
    }
    for (size_t i = 0; i < K; i++) {
        tmp_b = _mm_loadu_ps(cur_b);
        cur_b += 4;
        tmp_a = _mm_broadcast_ss(cur_a);
        xmm0 = _mm_fmadd_ps(tmp_a, tmp_b, xmm0);
        tmp_a = _mm_broadcast_ss(cur_a + 1);
        xmm1 = _mm_fmadd_ps(tmp_a, tmp_b, xmm1);
        cur_a += 2;
    }
    float dst[2 * 4];
    _mm_storeu_ps(dst + 4 * 0, xmm0);
    _mm_storeu_ps(dst + 4 * 1, xmm1);
    for (size_t i = 0; i < n_remain; i++) {
        for (size_t j = 0; j < m_remain; j++) {
            output[LDC * j + i] = dst[4 * j + i];
        }
    }
}
DNN_AVX2_TARGET
MEGDNN_ATTRIBUTE_TARGET("fma")
void gemm_6x16_kern6x16(
        const float* packA, const float* packB, int K, float* output, int LDC,
        bool is_first_k) {
    const float* cur_b = packB;
    const float* cur_a = packA;
    __m256 ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8, ymm9, ymm10,
            ymm11;
    __m256 b_tmp0, b_tmp1;
    __m256 tmp;
    if (is_first_k) {
#define cb(i) ymm##i = _mm256_set1_ps(0.0f);
        UNROLL_CODE(cb, 12)
#undef cb
    } else {
        ymm0 = _mm256_loadu_ps(output + LDC * 0 + 0);
        ymm1 = _mm256_loadu_ps(output + LDC * 0 + 8);
        ymm2 = _mm256_loadu_ps(output + LDC * 1 + 0);
        ymm3 = _mm256_loadu_ps(output + LDC * 1 + 8);
        ymm4 = _mm256_loadu_ps(output + LDC * 2 + 0);
        ymm5 = _mm256_loadu_ps(output + LDC * 2 + 8);
        ymm6 = _mm256_loadu_ps(output + LDC * 3 + 0);
        ymm7 = _mm256_loadu_ps(output + LDC * 3 + 8);
        ymm8 = _mm256_loadu_ps(output + LDC * 4 + 0);
        ymm9 = _mm256_loadu_ps(output + LDC * 4 + 8);
        ymm10 = _mm256_loadu_ps(output + LDC * 5 + 0);
        ymm11 = _mm256_loadu_ps(output + LDC * 5 + 8);
    }
    b_tmp0 = _mm256_loadu_ps(cur_b);
    b_tmp1 = _mm256_loadu_ps(cur_b + 8);
    size_t i = 0;
    for (; i + 2 <= K; i += 2) {
        cur_b += 16;
#define CAL_OUPUT(i, first, second)                        \
    tmp = _mm256_broadcast_ss(cur_a + i);                  \
    ymm##first = _mm256_fmadd_ps(b_tmp0, tmp, ymm##first); \
    ymm##second = _mm256_fmadd_ps(b_tmp1, tmp, ymm##second);
        CAL_OUPUT(0, 0, 1)
        CAL_OUPUT(1, 2, 3)
        CAL_OUPUT(2, 4, 5)
        CAL_OUPUT(3, 6, 7)
        CAL_OUPUT(4, 8, 9)
        CAL_OUPUT(5, 10, 11)
        b_tmp0 = _mm256_loadu_ps(cur_b);
        b_tmp1 = _mm256_loadu_ps(cur_b + 8);
        cur_b += 16;
        CAL_OUPUT(6, 0, 1)
        CAL_OUPUT(7, 2, 3)
        CAL_OUPUT(8, 4, 5)
        CAL_OUPUT(9, 6, 7)
        CAL_OUPUT(10, 8, 9)
        CAL_OUPUT(11, 10, 11)
        cur_a += 12;
        b_tmp0 = _mm256_loadu_ps(cur_b);
        b_tmp1 = _mm256_loadu_ps(cur_b + 8);
    }
    if (i < K) {
        CAL_OUPUT(0, 0, 1)
        CAL_OUPUT(1, 2, 3)
        CAL_OUPUT(2, 4, 5)
        CAL_OUPUT(3, 6, 7)
        CAL_OUPUT(4, 8, 9)
        CAL_OUPUT(5, 10, 11)
    }
#undef CAL_OUPUT
    _mm256_storeu_ps(output + LDC * 0 + 0, ymm0);
    _mm256_storeu_ps(output + LDC * 0 + 8, ymm1);
    _mm256_storeu_ps(output + LDC * 1 + 0, ymm2);
    _mm256_storeu_ps(output + LDC * 1 + 8, ymm3);
    _mm256_storeu_ps(output + LDC * 2 + 0, ymm4);
    _mm256_storeu_ps(output + LDC * 2 + 8, ymm5);
    _mm256_storeu_ps(output + LDC * 3 + 0, ymm6);
    _mm256_storeu_ps(output + LDC * 3 + 8, ymm7);
    _mm256_storeu_ps(output + LDC * 4 + 0, ymm8);
    _mm256_storeu_ps(output + LDC * 4 + 8, ymm9);
    _mm256_storeu_ps(output + LDC * 5 + 0, ymm10);
    _mm256_storeu_ps(output + LDC * 5 + 8, ymm11);
}
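gemm_6x16_kern6x16 is the hot loop: twelve ymm registers hold the 6x16 C tile (two 8-wide halves per row), one register streams broadcast A values and two stream B, and the K loop is unrolled by two. A scalar model of the same contract — packed A advancing 6 floats per k step, packed B 16 per k step — which is useful for reasoning about (and testing against) the packed layouts; the function below is our reference, not part of the patch:

```cpp
// Scalar reference for the 6x16 micro-kernel contract: packA holds one
// float per output row per k; packB one per output column per k; C is a
// 6-row window with row stride LDC.
void kern6x16_reference(
        const float* packA, const float* packB, int K, float* output, int LDC,
        bool is_first_k) {
    float acc[6][16] = {};
    if (!is_first_k)  // accumulate onto an earlier K-slice's partial result
        for (int r = 0; r < 6; ++r)
            for (int c = 0; c < 16; ++c)
                acc[r][c] = output[LDC * r + c];
    for (int k = 0; k < K; ++k)
        for (int r = 0; r < 6; ++r)
            for (int c = 0; c < 16; ++c)
                acc[r][c] += packA[6 * k + r] * packB[16 * k + c];
    for (int r = 0; r < 6; ++r)
        for (int c = 0; c < 16; ++c)
            output[LDC * r + c] = acc[r][c];
}
```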
void gemm_6x16_kern(
        const float* packA, const float* packB, size_t M, size_t N, size_t K,
        float* C, size_t LDC, int is_first_k) {
    size_t n = 0;
    const int K2 = K * 2;
    const int K4 = K * 4;
    const int K6 = K * 6;
    const int K16 = K * 16;
    const int A_INTERLEAVE6 = 6;
    const int A_INTERLEAVE2 = 2;
    const int B_INTERLEAVE16 = 16;
    const int B_INTERLEAVE4 = 4;
    auto* cur_packB = packB;
    for (; n + B_INTERLEAVE16 <= N; n += B_INTERLEAVE16) {
        size_t m = 0;
        auto output = C + n;
        auto* cur_packA = packA;
        for (; m + A_INTERLEAVE6 <= M; m += A_INTERLEAVE6) {
            gemm_6x16_kern6x16(cur_packA, cur_packB, K, output, LDC, is_first_k);
            output += A_INTERLEAVE6 * LDC;
            cur_packA += K6;
        }
        for (; m < M; m += A_INTERLEAVE2) {
            gemm_6x16_kern2x16(
                    cur_packA, cur_packB, K, output, LDC, is_first_k,
                    min(M - m, 2));
            output += A_INTERLEAVE2 * LDC;
            cur_packA += K2;
        }
        cur_packB += K16;
    }
    for (; n < N; n += B_INTERLEAVE4) {
        size_t m = 0;
        auto output = C + n;
        auto* cur_packA = packA;
        for (; m + A_INTERLEAVE6 <= M; m += A_INTERLEAVE6) {
            gemm_6x16_kern6x4(
                    cur_packA, cur_packB, K, output, LDC, is_first_k,
                    min(N - n, 4));
            output += A_INTERLEAVE6 * LDC;
            cur_packA += K6;
        }
        for (; m < M; m += A_INTERLEAVE2) {
            gemm_6x16_kern2x4(
                    cur_packA, cur_packB, K, output, LDC, is_first_k,
                    min(M - m, 2), min(N - n, 4));
            output += A_INTERLEAVE2 * LDC;
            cur_packA += K2;
        }
        cur_packB += K4;
    }
}
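To make the dispatch concrete, take M = 8, N = 20: columns 0-15 go through the 16-wide path (rows 0-5 via gemm_6x16_kern6x16, rows 6-7 via gemm_6x16_kern2x16 with m_remain = 2), and the remaining columns 16-19 fall to the 4-wide path (gemm_6x16_kern6x4, then gemm_6x16_kern2x4). The packed-A pointer advances by K*6 or K*2 floats per row strip and packed-B by K*16 or K*4 per column panel, which matches the layouts the packing routines below produce.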
void gemm_6x16_pack_A_t(
        float* outptr, const float* inptr, int ldin, int x0, int xmax, int k0,
        int kmax) {
    size_t ksize = kmax - k0;
    size_t ksize6 = ksize * 6;
    size_t ksize2 = ksize * 2;
    float* outptr_base6 = outptr;
    float* outptr_base2 = outptr_base6 + (xmax - x0) / 6 * ksize6;
    size_t k = k0;
    for (; k + 7 < kmax; k += 8) {
        const float* cur_inptr = inptr + k * ldin + k0;
#define cb(i) const float* inptr##i = cur_inptr + ldin * i;
        UNROLL_CODE(cb, 8)
#undef cb
#define cb(i) __builtin_prefetch(inptr##i, 0, 3);
        UNROLL_CODE(cb, 8)
#undef cb
        int x = x0;
        float* outptr = outptr_base6;
        for (; x + 6 <= xmax; x += 6) {
            interleave_8x6_1_s(
                    inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
                    inptr7, outptr);
#define cb(i) inptr##i += 6;
            UNROLL_CODE(cb, 8)
#undef cb
            outptr += ksize6;
        }
        outptr = outptr_base2;
        for (; x + 2 <= xmax; x += 2) {
            interleave_8x2_1_s(
                    inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
                    inptr7, outptr);
#define cb(i) inptr##i += 2;
            UNROLL_CODE(cb, 8)
#undef cb
            outptr += ksize2;
        }
        if (x < xmax) {
            interleave_8(
                    inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
                    inptr7, outptr, 2, xmax - x, 0);
            inptr0 += xmax - x;
            inptr1 += xmax - x;
            inptr2 += xmax - x;
            inptr3 += xmax - x;
            inptr4 += xmax - x;
            inptr5 += xmax - x;
            inptr6 += xmax - x;
            inptr7 += xmax - x;
        }
        outptr_base6 += 8 * 6;
        outptr_base2 += 8 * 2;
    }
    for (; k < kmax; k++) {
        const float* inptr0 = inptr + k * ldin + k0;
        __builtin_prefetch(inptr0, 0, 3);
        int x = x0;
        float* outptr = outptr_base6;
        for (; x + 6 <= xmax; x += 6) {
            interleave_1x6_1_s(inptr0, outptr);
            inptr0 += 6;
            outptr += ksize6;
        }
        outptr = outptr_base2;
        for (; x + 2 <= xmax; x += 2) {
            interleave_1x2_1_s(inptr0, outptr);
            inptr0 += 2;
            outptr += ksize2;
        }
        if (x < xmax) {
            interleave_1(inptr0, outptr, 2, xmax - x, 0);
            inptr0 += xmax - x;
            outptr += 2;
        }
        outptr_base6 += 6;
        outptr_base2 += 2;
    }
}
void gemm_6x16_pack_A_n(
        float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0,
        int kmax) {
    float zerobuff[16];
    memset(zerobuff, 0, sizeof(float) * 16);
    size_t y = y0;
    const size_t PACK_SIZE_96 = 6 * 16;
    const size_t PACK_SIZE_48 = 6 * 8;
    const size_t PACK_SIZE_24 = 6 * 4;
    const size_t PACK_SIZE_32 = 4 * 8;
    const size_t PACK_SIZE_16 = 4 * 4;
    const size_t PACK_SIZE_8 = 4 * 2;
    for (; y + 5 < ymax; y += 6) {
        const float* cur_inptr = inptr + y * ldin + k0;
#define cb(i) const float* inptr##i = cur_inptr + ldin * i;
        UNROLL_CODE(cb, 6)
#undef cb
#define cb(i) __builtin_prefetch(inptr##i, 0, 3);
        UNROLL_CODE(cb, 6)
#undef cb
        int x = (kmax - k0);
        for (; x > 15; x -= 16) {
            transpose_6x16_1_s(
                    inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, outptr);
#define cb(i) inptr##i += 16;
            UNROLL_CODE(cb, 6)
#undef cb
            outptr += PACK_SIZE_96;
        }
        for (; x > 7; x -= 8) {
            transpose_6x8_1_s(
                    inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, outptr);
#define cb(i) inptr##i += 8;
            UNROLL_CODE(cb, 6)
#undef cb
            outptr += PACK_SIZE_48;
        }
        for (; x > 3; x -= 4) {
            transpose_6x4_1_s(
                    inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, outptr);
#define cb(i) inptr##i += 4;
            UNROLL_CODE(cb, 6)
#undef cb
            outptr += PACK_SIZE_24;
        }
        for (; x > 0; x--) {
#define cb(i) *outptr++ = *inptr##i++;
            UNROLL_CODE(cb, 6)
#undef cb
        }
    }
    for (; y < ymax; y += 2) {
        const float* cur_inptr = inptr + y * ldin + k0;
#define cb(i) const float* inptr##i = cur_inptr + ldin * i;
        UNROLL_CODE(cb, 2)
#undef cb
#define cb(i) __builtin_prefetch(inptr##i, 0, 3);
        UNROLL_CODE(cb, 2)
#undef cb
        int x = kmax - k0;
        for (; x > 15; x -= 16) {
            if ((y + 1) >= ymax) {
                inptr1 = zerobuff;
            }
            transpose_2x16_1_s(inptr0, inptr1, outptr);
#define cb(i) inptr##i += 16;
            UNROLL_CODE(cb, 2)
#undef cb
            outptr += PACK_SIZE_32;
        }
        for (; x > 7; x -= 8) {
            if ((y + 1) >= ymax) {
                inptr1 = zerobuff;
            }
            transpose_2x8_1_s(inptr0, inptr1, outptr);
#define cb(i) inptr##i += 8;
            UNROLL_CODE(cb, 2)
#undef cb
            outptr += PACK_SIZE_16;
        }
        for (; x > 3; x -= 4) {
            if ((y + 1) >= ymax) {
                inptr1 = zerobuff;
            }
            transpose_2x4_1_s(inptr0, inptr1, outptr);
#define cb(i) inptr##i += 4;
            UNROLL_CODE(cb, 2)
#undef cb
            outptr += PACK_SIZE_8;
        }
        if (x > 0) {
            if ((y + 1) >= ymax) {
                inptr1 = zerobuff;
            }
            for (size_t i = 0; i < x; i++) {
                *outptr++ = *inptr0++;
                *outptr++ = *inptr1++;
            }
        }
    }
}
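Reading the routine above, the A panel it emits (no transpose) is: full 6-row strips first, then 2-row tail strips, each strip stored k-major with the strip's rows interleaved, and a zero row substituted when M is odd. A scalar reference matching that reading — our own helper, not part of the patch:

```cpp
// Scalar model of gemm_6x16_pack_A_n's output layout: within each strip,
// element order is (k, row-in-strip); tail strips are zero-padded to 2 rows.
void pack_A_n_reference(float* out, const float* in, int ldin, int y0,
                        int ymax, int k0, int kmax) {
    int y = y0;
    for (; y + 5 < ymax; y += 6)        // 6-row strips
        for (int k = k0; k < kmax; ++k)
            for (int r = 0; r < 6; ++r)
                *out++ = in[(y + r) * ldin + k];
    for (; y < ymax; y += 2)            // 2-row tail strips
        for (int k = k0; k < kmax; ++k)
            for (int r = 0; r < 2; ++r)
                *out++ = (y + r < ymax) ? in[(y + r) * ldin + k] : 0.f;
}
```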
void gemm_6x16_pack_B_t(
        float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0,
        int kmax) {
    float zerobuff[16];
    memset(zerobuff, 0, sizeof(float) * 16);
    const size_t PACK_SIZE_128 = 8 * 16;
    const size_t PACK_SIZE_64 = 4 * 16;
    const size_t PACK_SiZE_32 = 4 * 8;
    const size_t PACK_SIZE_16 = 4 * 4;
    size_t y = y0;
    for (; y + 15 < ymax; y += 16) {
        const float* cur_inptr = inptr + y * ldin + k0;
#define cb(i) const float* inptr##i = cur_inptr + ldin * i;
        UNROLL_CODE(cb, 16)
#undef cb
#define cb(i) __builtin_prefetch(inptr##i, 0, 3);
        UNROLL_CODE(cb, 16)
#undef cb
        int x = (kmax - k0);
        for (; x > 7; x -= 8) {
            transpose_16x8_1_s(
                    inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
                    inptr7, inptr8, inptr9, inptr10, inptr11, inptr12, inptr13,
                    inptr14, inptr15, outptr);
#define cb(i) inptr##i += 8;
            UNROLL_CODE(cb, 16)
#undef cb
            outptr += PACK_SIZE_128;
        }
        for (; x > 3; x -= 4) {
            transpose_16x4_1_s(
                    inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
                    inptr7, inptr8, inptr9, inptr10, inptr11, inptr12, inptr13,
                    inptr14, inptr15, outptr);
#define cb(i) inptr##i += 4;
            UNROLL_CODE(cb, 16)
#undef cb
            outptr += PACK_SIZE_64;
        }
        for (; x > 0; x--) {
#define cb(i) *outptr++ = *inptr##i++;
            UNROLL_CODE(cb, 16)
#undef cb
        }
    }
    for (; y < ymax; y += 4) {
        const float* cur_inptr = inptr + y * ldin + k0;
#define cb(i) const float* inptr##i = cur_inptr + ldin * i;
        UNROLL_CODE(cb, 4)
#undef cb
#define cb(i) __builtin_prefetch(inptr##i, 0, 3);
        UNROLL_CODE(cb, 4)
#undef cb
        int x = kmax - k0;
        for (; x > 7; x -= 8) {
            if ((y + 3) >= ymax) {
                switch ((y + 3) - ymax) {
                    case 2:
                        inptr1 = zerobuff;
                        // fallthrough: pad all trailing rows with zeros
                    case 1:
                        inptr2 = zerobuff;
                    case 0:
                        inptr3 = zerobuff;
                    default:
                        break;
                }
            }
            transpose_4x8_1_s(inptr0, inptr1, inptr2, inptr3, outptr);
#define cb(i) inptr##i += 8;
            UNROLL_CODE(cb, 4)
#undef cb
            outptr += PACK_SiZE_32;
        }
        for (; x > 3; x -= 4) {
            if ((y + 3) >= ymax) {
                switch ((y + 3) - ymax) {
                    case 2:
                        inptr1 = zerobuff;
                        // fallthrough
                    case 1:
                        inptr2 = zerobuff;
                    case 0:
                        inptr3 = zerobuff;
                    default:
                        break;
                }
            }
            transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr);
#define cb(i) inptr##i += 4;
            UNROLL_CODE(cb, 4)
#undef cb
            outptr += PACK_SIZE_16;
        }
        if (x > 0) {
            if ((y + 3) >= ymax) {
                switch ((y + 3) - ymax) {
                    case 2:
                        inptr1 = zerobuff;
                        // fallthrough
                    case 1:
                        inptr2 = zerobuff;
                    case 0:
                        inptr3 = zerobuff;
                        break;
                }
            }
            for (size_t i = 0; i < x; i++) {
                *outptr++ = *inptr0++;
                *outptr++ = *inptr1++;
                *outptr++ = *inptr2++;
                *outptr++ = *inptr3++;
            }
        }
    }
}
void gemm_6x16_pack_B_n(
        float* outptr, const float* inptr, int ldin, int x0, int xmax, int k0,
        int kmax) {
    size_t ksize = kmax - k0;
    size_t ksize16 = ksize * 16;
    size_t ksize4 = ksize * 4;
    float* outptr_base16 = outptr;
    float* outptr_base4 = outptr_base16 + (xmax - x0) / 16 * ksize16;
    size_t k = k0;
    for (; k + 7 < kmax; k += 8) {
        const float* cur_inptr = inptr + k * ldin + k0;
#define cb(i) const float* inptr##i = cur_inptr + ldin * i;
        UNROLL_CODE(cb, 8)
#undef cb
#define cb(i) __builtin_prefetch(inptr##i, 0, 3);
        UNROLL_CODE(cb, 8)
#undef cb
        int x = x0;
        float* outptr = outptr_base16;
        for (; x + 16 <= xmax; x += 16) {
            interleave_8x16_1_s(
                    inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
                    inptr7, outptr);
#define cb(i) inptr##i += 16;
            UNROLL_CODE(cb, 8)
#undef cb
            outptr += ksize16;
        }
        outptr = outptr_base4;
        for (; x + 4 <= xmax; x += 4) {
            interleave_8x4_1_s(
                    inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
                    inptr7, outptr);
#define cb(i) inptr##i += 4;
            UNROLL_CODE(cb, 8)
#undef cb
            outptr += ksize4;
        }
        if (x < xmax) {
            interleave_8(
                    inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
                    inptr7, outptr, 4, xmax - x, 0);
            inptr0 += xmax - x;
            inptr1 += xmax - x;
            inptr2 += xmax - x;
            inptr3 += xmax - x;
            inptr4 += xmax - x;
            inptr5 += xmax - x;
            inptr6 += xmax - x;
            inptr7 += xmax - x;
        }
        outptr_base16 += 8 * 16;
        outptr_base4 += 8 * 4;
    }
    for (; k < kmax; k++) {
        const float* inptr0 = inptr + k * ldin + k0;
        __builtin_prefetch(inptr0, 0, 3);
        int x = x0;
        float* outptr = outptr_base16;
        for (; x + 16 <= xmax; x += 16) {
            interleave_1x16_1_s(inptr0, outptr);
            inptr0 += 16;
            outptr += ksize16;
        }
        outptr = outptr_base4;
        for (; x + 4 <= xmax; x += 4) {
            interleave_1x4_1_s(inptr0, outptr);
            inptr0 += 4;
            outptr += ksize4;
        }
        if (x < xmax) {
            interleave_1(inptr0, outptr, 4, xmax - x, 0);
            inptr0 += xmax - x;
            outptr += 4;
        }
        outptr_base16 += 16;
        outptr_base4 += 4;
    }
}
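The B-side layout is the column-wise mirror of A's: full 16-column panels first (stored k-major, 16 values per k), then 4-column tail panels zero-padded on the right. A scalar reference matching our reading of the routine above — again our own helper, not part of the patch:

```cpp
// Scalar model of gemm_6x16_pack_B_n's output layout: within each panel,
// element order is (k, column-in-panel); tail panels pad to 4 columns.
void pack_B_n_reference(float* out, const float* in, int ldin, int x0,
                        int xmax, int k0, int kmax) {
    int x = x0;
    for (; x + 16 <= xmax; x += 16)     // full 16-column panels
        for (int k = k0; k < kmax; ++k)
            for (int c = 0; c < 16; ++c)
                *out++ = in[k * ldin + x + c];
    for (; x < xmax; x += 4)            // 4-column tail, zero padded
        for (int k = k0; k < kmax; ++k)
            for (int c = 0; c < 4; ++c)
                *out++ = (x + c < xmax) ? in[k * ldin + x + c] : 0.f;
}
```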
}  // namespace
#undef UNROLL_CODE

namespace megdnn {
namespace x86 {
namespace matmul {

void sgemm_pack_6x16_avx2::pack_A(
        float* out, const float* in, int ldin, int y0, int ymax, int k0,
        int kmax, bool transpose_A) const {
    if (!transpose_A)
        gemm_6x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax);
    else
        gemm_6x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax);
}

void sgemm_pack_6x16_avx2::pack_B(
        float* out, const float* in, int ldin, int x0, int xmax, int k0,
        int kmax, bool transpose_B) const {
    if (!transpose_B)
        gemm_6x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
    else
        gemm_6x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax);
}

void sgemm_pack_6x16_avx2::kern(
        const float* packA, const float* packB, size_t M, size_t N, size_t K,
        float* C, size_t LDC, bool is_first_k, const float* bias,
        float* workspace) const {
    MEGDNN_MARK_USED_VAR(bias);
    MEGDNN_MARK_USED_VAR(workspace);
    gemm_6x16_kern(packA, packB, M, N, K, C, LDC, is_first_k);
}

MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_pack_6x16_avx2);

}  // namespace matmul
}  // namespace x86
}  // namespace megdnn
dnn/src/x86/matrix_mul/opr_impl.cpp
@@ -34,6 +34,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
    AlgoInt8x8x16AVX2 algoint8x8x16avx2_m4n16k2;
    AlgoInt8x8x16SSE algoint8x8x16sse_m4n8k2;
    AlgoF32MK8_8x8 algof32mk8_8x8;
+   AlgoFloatAVX2M6N16 algof32_6x16;
    SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos;
    fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map;
@@ -51,6 +52,7 @@ public:
        m_all_algos.emplace_back(&algoint8x8x32sse_m4n8k2);
        m_all_algos.emplace_back(&algoint8x8x16sse_m4n8k2);
        m_all_algos.emplace_back(&algof32mk8_8x8);
+       m_all_algos.emplace_back(&algof32_6x16);
#if MEGDNN_X86_WITH_MKL_DNN
        m_all_algos.emplace_back(&algoint8x8x32mkldnn);
#endif
dnn/src/x86/matrix_mul/opr_impl.h
@@ -68,7 +68,7 @@ private:
    class AlgoInt8x8x16SSE;
    class AlgoPack;
    class AlgoF32MK8_8x8;
    class AlgoFloatAVX2M6N16;

public:
    static const AlgoPack& algo_pack();
};
dnn/test/x86/accuracy_shake.cpp
@@ -98,6 +98,15 @@ TEST_F(X86, SHAKE_MATRIX_MUL_FORWARD) {
            .exec({{20, 100}, {100, 60}, {}});
}

TEST_F(X86, SHAKE_MATRIX_MUL_6x16_FORWARD) {
    AccuracyShakeChecker<MatrixMul> checker(handle());
    checker.set_before_exec_callback(AlgoGenerator<MatrixMul>("X86_F32_6x16"));
    checker.set_dtype(0, dtype::Float32())
            .set_dtype(1, dtype::Float32())
            .set_dtype(2, dtype::Float32())
            .exec({{20, 100}, {100, 60}, {}});
}

}  // namespace test
}  // namespace megdnn
dnn/test/x86/conv_bias.cpp
浏览文件 @
d2184af3
...
...
@@ -1171,6 +1171,110 @@ TEST_F(X86, CONV_BIAS_IM2COLMATMUL_FP32_NOPACK_PREPROCESS) {
#endif
TEST_F
(
X86_MULTI_THREADS
,
CONV_BIAS_IM2COLMATMUL_FP32_6x16
)
{
using
namespace
conv_bias
;
std
::
vector
<
TestArg
>
args
;
auto
run
=
[
&
](
size_t
oc
,
size_t
ic
,
size_t
w
,
size_t
h
,
size_t
kernel
,
size_t
p
,
NonlineMode
nonline_mode
)
{
if
(
w
+
2
*
p
<
kernel
||
h
+
2
*
p
<
kernel
)
return
;
param
::
ConvBias
param
;
param
.
stride_h
=
1
;
param
.
stride_w
=
1
;
param
.
pad_h
=
p
;
param
.
pad_w
=
p
;
param
.
nonlineMode
=
nonline_mode
;
//! no bias
args
.
emplace_back
(
param
,
TensorShape
{
1
,
ic
,
h
,
w
},
TensorShape
{
oc
,
ic
,
kernel
,
kernel
},
TensorShape
{});
args
.
emplace_back
(
param
,
TensorShape
{
1
,
ic
,
h
,
w
},
TensorShape
{
oc
,
ic
,
kernel
,
kernel
},
TensorShape
{
1
,
oc
,
1
,
1
});
args
.
emplace_back
(
param
,
TensorShape
{
1
,
ic
,
h
,
w
},
TensorShape
{
oc
,
ic
,
kernel
,
kernel
},
TensorShape
{
1
,
oc
,
(
h
+
2
*
p
-
kernel
)
/
param
.
stride_h
+
1
,
(
w
+
2
*
p
-
kernel
)
/
param
.
stride_w
+
1
});
};
for
(
size_t
kernel
:
{
2
,
3
,
4
,
5
,
6
,
7
})
for
(
size_t
ic
:
{
1
,
4
,
8
,
16
})
for
(
size_t
oc
:
{
1
,
4
,
8
,
16
,
300
})
for
(
size_t
p
:
{
0
,
2
})
for
(
size_t
size
:
{
8
,
24
})
for
(
NonlineMode
nonline_mode
:
{
NonlineMode
::
IDENTITY
,
NonlineMode
::
RELU
})
{
run
(
oc
,
ic
,
size
,
size
,
kernel
,
p
,
nonline_mode
);
}
run
(
2046
,
8
,
20
,
20
,
3
,
1
,
NonlineMode
::
IDENTITY
);
Checker
<
ConvBias
>
checker
(
handle
());
#define cb(algo_name) \
checker.set_before_exec_callback( \
conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name)); \
for (auto&& arg : args) { \
checker.set_param(arg.param).execs( \
{arg.src, arg.filter, arg.bias, {}, {}}); \
}
cb
(
"IM2COLMATMUL:X86_F32_6x16:192"
);
}
TEST_F
(
X86
,
CONV_BIAS_IM2COLMATMUL_FP32_6x16
)
{
using
namespace
conv_bias
;
std
::
vector
<
TestArg
>
args
;
auto
run
=
[
&
](
size_t
oc
,
size_t
ic
,
size_t
w
,
size_t
h
,
size_t
kernel
,
size_t
p
,
NonlineMode
nonline_mode
)
{
if
(
w
+
2
*
p
<
kernel
||
h
+
2
*
p
<
kernel
)
return
;
param
::
ConvBias
param
;
param
.
stride_h
=
1
;
param
.
stride_w
=
1
;
param
.
pad_h
=
p
;
param
.
pad_w
=
p
;
param
.
nonlineMode
=
nonline_mode
;
//! no bias
args
.
emplace_back
(
param
,
TensorShape
{
1
,
ic
,
h
,
w
},
TensorShape
{
oc
,
ic
,
kernel
,
kernel
},
TensorShape
{});
args
.
emplace_back
(
param
,
TensorShape
{
1
,
ic
,
h
,
w
},
TensorShape
{
oc
,
ic
,
kernel
,
kernel
},
TensorShape
{
1
,
oc
,
1
,
1
});
args
.
emplace_back
(
param
,
TensorShape
{
1
,
ic
,
h
,
w
},
TensorShape
{
oc
,
ic
,
kernel
,
kernel
},
TensorShape
{
1
,
oc
,
(
h
+
2
*
p
-
kernel
)
/
param
.
stride_h
+
1
,
(
w
+
2
*
p
-
kernel
)
/
param
.
stride_w
+
1
});
};
for
(
size_t
kernel
:
{
2
,
3
,
4
,
5
,
6
,
7
})
for
(
size_t
ic
:
{
1
,
4
,
8
,
16
})
for
(
size_t
oc
:
{
1
,
4
,
8
,
16
,
300
})
for
(
size_t
p
:
{
0
,
2
})
for
(
size_t
size
:
{
8
,
24
})
for
(
NonlineMode
nonline_mode
:
{
NonlineMode
::
IDENTITY
,
NonlineMode
::
RELU
})
{
run
(
oc
,
ic
,
size
,
size
,
kernel
,
p
,
nonline_mode
);
}
run
(
2046
,
8
,
20
,
20
,
3
,
1
,
NonlineMode
::
IDENTITY
);
Checker
<
ConvBias
>
checker
(
handle
());
#define cb(algo_name) \
checker.set_before_exec_callback( \
conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name)); \
for (auto&& arg : args) { \
checker.set_param(arg.param).execs( \
{arg.src, arg.filter, arg.bias, {}, {}}); \
}
cb
(
"IM2COLMATMUL:X86_F32_6x16:192"
);
#undef cb
}
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
TEST_F
(
X86_MULTI_THREADS
,
CONV_BIAS_IM2COLMATMUL_FP32_PACKA
)
{
using
namespace
conv_bias
;
...
...
@@ -1377,6 +1481,12 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_BLAS) {
    check_conv_bias(args, handle(), "CONV1x1:X86_F32_BLAS:48");
}
TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_6x16) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
    check_conv_bias(args, handle(), "CONV1x1:X86_F32_6x16:48");
}
TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_BLAS_NOPACK_REPROCESS) {
    using namespace conv_bias;
    std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
...
...
@@ -2627,6 +2737,76 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_F32) {
    shapes_and_computation.clear();
}
TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_F32_6x16) {
    constexpr size_t RUNS = 50;

    param::ConvBias param;
    param.nonlineMode = param::ConvBias::NonlineMode::RELU;
    param.pad_h = 1;
    param.pad_w = 1;
    param.stride_h = 1;
    param.stride_w = 1;
    std::vector<DType> data_type = {
            dtype::Float32(), dtype::Float32(), dtype::Float32(),
            dtype::Float32()};

    std::vector<std::pair<SmallVector<TensorShape>, float>>
            shapes_and_computation;
    auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W,
                          size_t FS, size_t group) {
        SmallVector<TensorShape> shapes{
                {N, IC, H, W},
                {OC / group, IC / group, FS, FS},
                {1, OC, 1, 1},
                {},
                {N, OC, H, W}};
        TensorShape dst{N, OC, H, W};
        float computations =
                ((IC / group) * FS * FS * dst.total_nr_elems() * 2 +
                 dst.total_nr_elems()) *
                1e-6;
        shapes_and_computation.push_back(std::make_pair(shapes, computations));
    };
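    // A worked instance of the cost model above: for
    // bench_case(1, 32, 32, 200, 200, 3, 1) below,
    //   dst.total_nr_elems() = 1 * 32 * 200 * 200 = 1,280,000
    //   computations = (32 * 3 * 3 * 1,280,000 * 2 + 1,280,000) * 1e-6
    //                ≈ 738.56 Mflops,
    // i.e. two ops per multiply-accumulate plus one add per output element
    // for the bias.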
    bench_case(1, 32, 32, 200, 200, 3, 1);
    bench_case(1, 32, 32, 200, 200, 3, 1);
    bench_case(1, 32, 32, 128, 128, 3, 1);
    bench_case(1, 32, 32, 128, 128, 3, 1);
    bench_case(1, 32, 32, 100, 100, 3, 1);
    bench_case(1, 32, 32, 100, 100, 3, 1);
    bench_case(1, 32, 32, 80, 80, 3, 1);
    bench_case(1, 32, 32, 80, 80, 3, 1);
    bench_case(1, 64, 32, 7, 7, 3, 1);
    bench_case(1, 64, 64, 7, 7, 3, 1);
    bench_case(1, 64, 128, 7, 7, 3, 1);
    bench_case(1, 64, 256, 7, 7, 3, 1);
    bench_case(1, 64, 512, 7, 7, 3, 1);
    bench_case(1, 64, 1024, 7, 7, 3, 1);
    bench_case(1, 64, 32, 14, 14, 3, 1);
    bench_case(1, 64, 64, 14, 14, 3, 1);
    bench_case(1, 64, 128, 14, 14, 3, 1);
    bench_case(1, 64, 256, 14, 14, 3, 1);
    bench_case(1, 64, 512, 14, 14, 3, 1);
    bench_case(1, 64, 1024, 14, 14, 3, 1);
    bench_case(1, 128, 128, 14, 14, 3, 1);
    bench_case(1, 128, 256, 14, 14, 3, 1);
    bench_case(1, 512, 512, 14, 14, 3, 1);
    bench_case(1, 256, 512, 14, 14, 3, 1);
    bench_case(1, 512, 1024, 14, 14, 3, 1);
    bench_case(1, 1024, 1024, 14, 14, 3, 1);

    std::string algo_name = "IM2COLMATMUL:X86_F32_6x16:192";
    printf("Benchmark IM2COLMATMUL:X86_F32_6x16 algo\n");
    benchmark_impl(
            param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}},
            {1, {4}}, data_type);
    benchmark_impl(
            param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}},
            {1, {7}}, data_type);
    benchmark_impl(
            param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}},
            {1, {4}}, data_type);
    shapes_and_computation.clear();
}
TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_F32_single_thread) {
    constexpr size_t RUNS = 50;
...
...
@@ -2697,6 +2877,76 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS,
    shapes_and_computation.clear();
}
TEST_F(X86_BENCHMARK_MULTI_THREADS,
       BENCHMARK_CONVBIAS_IM2COL_F32_6X16_single_thread) {
    constexpr size_t RUNS = 50;

    param::ConvBias param;
    param.nonlineMode = param::ConvBias::NonlineMode::RELU;
    param.pad_h = 1;
    param.pad_w = 1;
    param.stride_h = 1;
    param.stride_w = 1;
    std::vector<DType> data_type = {
            dtype::Float32(), dtype::Float32(), dtype::Float32(),
            dtype::Float32()};

    std::vector<std::pair<SmallVector<TensorShape>, float>>
            shapes_and_computation;
    auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W,
                          size_t FS, size_t group) {
        SmallVector<TensorShape> shapes{
                {N, IC, H, W},
                {OC / group, IC / group, FS, FS},
                {1, OC, 1, 1},
                {},
                {N, OC, H, W}};
        TensorShape dst{N, OC, H, W};
        float computations =
                ((IC / group) * FS * FS * dst.total_nr_elems() * 2 +
                 dst.total_nr_elems()) *
                1e-6;
        shapes_and_computation.push_back(std::make_pair(shapes, computations));
    };

    bench_case(1, 32, 32, 200, 200, 3, 1);
    bench_case(1, 32, 32, 200, 200, 3, 1);
    bench_case(1, 32, 32, 128, 128, 3, 1);
    bench_case(1, 32, 32, 128, 128, 3, 1);
    bench_case(1, 32, 32, 100, 100, 3, 1);
    bench_case(1, 32, 32, 100, 100, 3, 1);
    bench_case(1, 32, 32, 80, 80, 3, 1);
    bench_case(1, 32, 32, 80, 80, 3, 1);
    bench_case(1, 64, 32, 7, 7, 3, 1);
    bench_case(1, 64, 64, 7, 7, 3, 1);
    bench_case(1, 64, 128, 7, 7, 3, 1);
    bench_case(1, 64, 256, 7, 7, 3, 1);
    bench_case(1, 64, 512, 7, 7, 3, 1);
    bench_case(1, 64, 1024, 7, 7, 3, 1);
    bench_case(1, 64, 32, 14, 14, 3, 1);
    bench_case(1, 64, 64, 14, 14, 3, 1);
    bench_case(1, 64, 128, 14, 14, 3, 1);
    bench_case(1, 64, 256, 14, 14, 3, 1);
    bench_case(1, 64, 512, 14, 14, 3, 1);
    bench_case(1, 64, 1024, 14, 14, 3, 1);
    bench_case(1, 128, 128, 14, 14, 3, 1);
    bench_case(1, 128, 256, 14, 14, 3, 1);
    bench_case(1, 512, 512, 14, 14, 3, 1);
    bench_case(1, 256, 512, 14, 14, 3, 1);
    bench_case(1, 512, 1024, 14, 14, 3, 1);
    bench_case(1, 1024, 1024, 14, 14, 3, 1);

    std::string algo_name = "IM2COLMATMUL:X86_F32_MKL_PACKA:192";
    std::string algo_name1 = "IM2COLMATMUL:X86_F32_6x16:192";
    printf("Benchmark IM2COLMATMUL:X86_F32_6x16 algo\n");
    benchmark_impl_comp(
            param, shapes_and_computation, algo_name, algo_name1, RUNS,
            {1, {4}}, {1, {4}}, data_type);
    benchmark_impl_comp(
            param, shapes_and_computation, algo_name, algo_name1, RUNS,
            {1, {7}}, {1, {7}}, data_type);
    shapes_and_computation.clear();
}
TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_INT8X8X32) {
    constexpr size_t RUNS = 50;
...
...
dnn/test/x86/matrix_mul.cpp View file @ d2184af3
...
...
@@ -85,6 +85,13 @@ TEST_F(X86, MATRIX_MUL_AVX2_MK8_8X8) {
            param::MatrixMul::Format::MK8, 1, 1e-3, false);
}
TEST_F(X86, MATRIX_MUL_AVX2_6x16) {
    matrix_mul::check_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
            "X86_F32_6x16", param::MatrixMul::Format::DEFAULT, 1, 1e-3, false);
}
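The "6x16" in the algorithm name refers to the output tile computed per inner-kernel invocation: 6 rows by 16 columns of f32. As a rough illustration of why that shape suits AVX2 (a stand-alone sketch under that assumption, not MegEngine's actual strategy_6x16 kernel, and ignoring panel packing and M/N/K tail handling), the tile fits in 12 of the 16 YMM registers, leaving room for two loads of B and one broadcast element of A per k step:

#include <immintrin.h>
#include <cstddef>

// One 6x16 f32 tile: C[r][0..15] += sum_k A[r][k] * B[k][0..15].
// A is row-major 6 x K (leading dimension lda), B is row-major K x 16
// (leading dimension ldb), C is row-major (leading dimension ldc).
void gemm_tile_6x16(const float* A, const float* B, float* C, size_t K,
                    size_t lda, size_t ldb, size_t ldc) {
    __m256 acc[6][2];
    for (int r = 0; r < 6; ++r) {
        acc[r][0] = _mm256_loadu_ps(C + r * ldc);      // accumulate onto C
        acc[r][1] = _mm256_loadu_ps(C + r * ldc + 8);
    }
    for (size_t k = 0; k < K; ++k) {
        __m256 b0 = _mm256_loadu_ps(B + k * ldb);      // B[k][0..7]
        __m256 b1 = _mm256_loadu_ps(B + k * ldb + 8);  // B[k][8..15]
        for (int r = 0; r < 6; ++r) {
            __m256 a = _mm256_broadcast_ss(A + r * lda + k);  // A[r][k]
            acc[r][0] = _mm256_fmadd_ps(a, b0, acc[r][0]);    // needs FMA
            acc[r][1] = _mm256_fmadd_ps(a, b1, acc[r][1]);
        }
    }
    for (int r = 0; r < 6; ++r) {
        _mm256_storeu_ps(C + r * ldc, acc[r][0]);
        _mm256_storeu_ps(C + r * ldc + 8, acc[r][1]);
    }
}

With 12 accumulators plus two B vectors and one broadcast A value, 15 of the 16 YMM registers are live, which is close to the largest f32 tile AVX2 can keep entirely in registers; a production kernel would additionally iterate over packed A/B panels and handle remainder rows and columns.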
#if MEGDNN_WITH_BENCHMARK
TEST_F(X86, BENCHMARK_MATRIX_MUL_AVX2_MK8_8X8) {
...
...
@@ -96,6 +103,14 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_AVX2_MK8_8X8) {
            "X86_F32_BLAS");
}
TEST_F(X86, BENCHMARK_MATRIX_MUL_AVX2_6x16) {
    auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8);
    matrix_mul::benchmark_with_contrast(
            handle(), args, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
            "X86_F32_6x16", param::MatrixMul::Format::DEFAULT, dtype::Float32{},
            dtype::Float32{}, dtype::Float32{}, "X86_F32_BLAS");
}
TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
    constexpr size_t RUNS = 50;
    auto rng = std::make_unique<UniformIntRNG>(-127, 127);
...
...