Commit 5dbf218d
Authored on June 12, 2020 by Megvii Engine Team; committed by Xu Xinran on June 19, 2020.

feat(dnn/x86): add sse 8816 matmul

GitOrigin-RevId: ed8d9ee5db3342ec534dc2e8a38b8613b8733b2b
Parent: 25b6a131
Showing 10 changed files with 294 additions and 113 deletions.
dnn/src/x86/matrix_mul/algos.cpp                    +89  -6
dnn/src/x86/matrix_mul/algos.h                      +16  -1
dnn/src/x86/matrix_mul/int8/kernel_sse_4x8x2.h      +82  -81
dnn/src/x86/matrix_mul/int8/sse_strategy_4x8x2.cpp  +70  -24
dnn/src/x86/matrix_mul/int8/strategy.h              +4   -0
dnn/src/x86/matrix_mul/opr_impl.cpp                 +2   -0
dnn/src/x86/matrix_mul/opr_impl.h                   +1   -0
dnn/test/x86/conv_bias.cpp                          +8   -1
dnn/test/x86/convolution.cpp                        +1   -0
dnn/test/x86/matrix_mul.cpp                         +21  -0
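In short, the patch adds an SSE4.1 int8 x int8 -> int16 (8816) MatrixMul algorithm, X86_INT8X8X16_SSE: the existing SSE 4x8x2 kernel is templated on its output type, a gemm_sse_s8s8s16_4x8x2 strategy is registered, and the x86 tests and benchmarks are extended. For orientation only, a minimal scalar sketch of the computation (semantics inferred from the diff, not MegDNN code): products are accumulated in 32 bits and truncated to 16 bits on store.

#include <cstddef>
#include <cstdint>

// Scalar reference for the 8816 GEMM: C (int16, row-major, leading dim ldc) =
// A (m x k, int8) * B (k x n, int8). Accumulation happens in int32 and the
// final store truncates to int16, mirroring store_overflow<int16_t> in the kernel.
void gemm_s8s8s16_reference(const int8_t* A, const int8_t* B, int16_t* C,
                            size_t m, size_t n, size_t k, size_t ldc) {
    for (size_t i = 0; i < m; ++i) {
        for (size_t j = 0; j < n; ++j) {
            int32_t acc = 0;
            for (size_t p = 0; p < k; ++p)
                acc += static_cast<int32_t>(A[i * k + p]) *
                       static_cast<int32_t>(B[p * n + j]);
            C[i * ldc + j] = static_cast<int16_t>(acc);  // truncating store
        }
    }
}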
dnn/src/x86/matrix_mul/algos.cpp

@@ -184,7 +184,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Vnni::get_kern(
}

MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
-        AlgoInt8x8x32Vnni, megdnn_x86_matmul_kern, 5,
+        AlgoInt8x8x32Vnni, megdnn_x86_matmul_kern, "AlgoInt8x8x32Vnni"_hash,
        x86::matmul::gemm_int8_vnni_12x32x4, dt_int8, dt_int32, dt_uint8);

#endif
@@ -318,6 +319,8 @@ void gemm_s8s8s32_sse_4x8x2(const MatrixMulImpl::KernParam& kern_param) {
}
}  // namespace

/*************************AlgoInt8x8x16AVX2********************/
void MatrixMulImpl::AlgoInt8x8x16AVX2::gemm_s8s8s16_avx2_4x16x2(
        const MatrixMulImpl::KernParam& kern_param) {
    MEGDNN_MARK_USED_VAR(kern_param);
@@ -389,9 +392,86 @@ size_t MatrixMulImpl::AlgoInt8x8x16AVX2::get_workspace(
            .get_workspace_size();
}

MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
-        AlgoInt8x8x16AVX2, megdnn_x86_matmul_kern, 8,
+        AlgoInt8x8x16AVX2, megdnn_x86_matmul_kern, "AlgoInt8x8x16AVX2"_hash,
        x86::matmul::gemm_avx2_s8s8s16_4x16x2, dt_int8, dt_int16, dt_int16);

/*************************AlgoInt8x8x16SSE********************/
void MatrixMulImpl::AlgoInt8x8x16SSE::gemm_s8s8s16_sse_4x8x2(
        const MatrixMulImpl::KernParam& kern_param) {
    MEGDNN_MARK_USED_VAR(kern_param);
    MIDOUT_BEGIN(megdnn_x86_matmul_kern_sse_4x8x2, midout_iv(2)) {
        constexpr int cacheline = 64;
        const size_t m = kern_param.M;
        const size_t n = kern_param.N;
        const size_t k = kern_param.K;
        const bool trans_a = kern_param.trA;
        const bool trans_b = kern_param.trB;
        const size_t lda = kern_param.LDA;
        const size_t ldb = kern_param.LDB;
        const size_t ldc = kern_param.LDC;
        auto a_type = kern_param.A_type;
        auto b_type = kern_param.B_type;
        auto c_type = kern_param.C_type;
        const auto a_ptr = kern_param.A<dt_int8>();
        const auto b_ptr = kern_param.B<dt_int8>();
        auto c_ptr = kern_param.C<dt_int16>();
        x86::matmul::gemm_sse_s8s8s16_4x8x2 strategy(m, n, k, a_type, b_type,
                                                     c_type);
        megdnn::matmul::GemmInterleaved<x86::matmul::gemm_sse_s8s8s16_4x8x2>(
                m, n, k, trans_a, trans_b, strategy, cacheline)
                .execute(a_ptr, lda, b_ptr, ldb, c_ptr, ldc,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16SSE::get_kern(
        const KernSizeParam&) const {
    return gemm_s8s8s16_sse_4x8x2;
}

bool MatrixMulImpl::AlgoInt8x8x16SSE::usable(
        const KernSizeParam& kern_size_param) const {
    bool is_ab_same =
            kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv();
    bool is_type_ok =
            ((kern_size_param.A_type.enumv() == DTypeEnum::Int8 &&
              kern_size_param.C_type.enumv() == DTypeEnum::Int16) ||
             (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
              kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16));
    bool is_mode_ok =
            kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
            is_supported(SIMDType::SSE4_1);
    bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok;
    return is_param_ok;
}

bool MatrixMulImpl::AlgoInt8x8x16SSE::preferred(const KernSizeParam&) const {
    return true;
}

size_t MatrixMulImpl::AlgoInt8x8x16SSE::get_workspace(
        const KernSizeParam& kern_param) const {
    constexpr int cacheline = 64;
    const size_t m = kern_param.M;
    const size_t n = kern_param.N;
    const size_t k = kern_param.K;
    const bool trans_a = kern_param.trA;
    const bool trans_b = kern_param.trB;
    auto a_type = kern_param.A_type;
    auto b_type = kern_param.B_type;
    auto c_type = kern_param.C_type;
    x86::matmul::gemm_sse_s8s8s16_4x8x2 strategy(m, n, k, a_type, b_type, c_type);
    return megdnn::matmul::GemmInterleaved<x86::matmul::gemm_sse_s8s8s16_4x8x2>(
                   m, n, k, trans_a, trans_b, strategy, cacheline)
            .get_workspace_size();
}

MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
        AlgoInt8x8x16SSE, megdnn_x86_matmul_kern, "AlgoInt8x8x16SSE"_hash,
        x86::matmul::gemm_sse_s8s8s16_4x8x2, dt_int8, dt_int16, dt_int16);

/*************************AlgoInt8x8x32AVX2M4N16K2********************/
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern(
        const KernSizeParam&) const {
    return gemm_s8s8s32_avx2_4x16x2;
@@ -426,8 +506,9 @@ size_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_workspace(
            .get_workspace_size();
}

MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
-        AlgoInt8x8x32AVX2M4N16K2, megdnn_x86_matmul_kern, 8,
-        x86::matmul::gemm_avx2_s8s8s32_4x16x2, dt_int8, dt_int32, dt_int16);
+        AlgoInt8x8x32AVX2M4N16K2, megdnn_x86_matmul_kern,
+        "AlgoInt8x8x32AVX2M4N16K2"_hash,
+        x86::matmul::gemm_avx2_s8s8s32_4x16x2, dt_int8, dt_int32, dt_int16);

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_kern(
        const KernSizeParam&) const {
@@ -463,7 +544,8 @@ size_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_workspace(
            .get_workspace_size();
}

MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
-        AlgoInt8x8x32AVX2M2N4K16, megdnn_x86_matmul_kern, 8,
+        AlgoInt8x8x32AVX2M2N4K16, megdnn_x86_matmul_kern,
+        "AlgoInt8x8x32AVX2M2N4K16"_hash,
        x86::matmul::gemm_avx2_s8s8s32_2x4x16, dt_int8, dt_int32);
@@ -501,7 +583,8 @@ size_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_workspace(
            .get_workspace_size();
}

MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
-        AlgoInt8x8x32SSEM4N8K2, megdnn_x86_matmul_kern, 9,
+        AlgoInt8x8x32SSEM4N8K2, megdnn_x86_matmul_kern,
+        "AlgoInt8x8x32SSEM4N8K2"_hash,
        x86::matmul::gemm_sse_s8s8s32_4x8x2, dt_int8, dt_int32, dt_int16);
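A note on the pattern used by AlgoInt8x8x16SSE above: the kernel and get_workspace() build the same strategy and GemmInterleaved driver with identical parameters, so the workspace size queried before dispatch is exactly what execute() later consumes. A self-contained toy of that query/execute symmetry (hypothetical types, not MegDNN's API):

#include <cstddef>
#include <cstdio>
#include <vector>

struct ToyDriver {
    size_t m, n, k;
    // Workspace sizing and execution are derived from the same parameters.
    size_t get_workspace_size() const { return (m * k + k * n) * sizeof(int); }
    void execute(void* workspace) const {
        std::printf("running %zux%zux%zu with %zu-byte workspace at %p\n",
                    m, n, k, get_workspace_size(), workspace);
    }
};

int main() {
    ToyDriver sized{4, 8, 2};
    std::vector<char> ws(sized.get_workspace_size());  // query phase
    ToyDriver{4, 8, 2}.execute(ws.data());             // execution phase, same params
    return 0;
}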
dnn/src/x86/matrix_mul/algos.h

@@ -76,7 +76,6 @@ class MatrixMulImpl::AlgoInt8x8x16AVX2 : public AlgoBase {
private:
    static void gemm_s8s8s16_avx2_4x16x2(const MatrixMulImpl::KernParam& kern_param);
-    static MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2 m_algo;

public:
    bool is_reproducible() const override { return true; }
@@ -89,6 +88,22 @@ public:
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};

class MatrixMulImpl::AlgoInt8x8x16SSE : public AlgoBase {
private:
    static void gemm_s8s8s16_sse_4x8x2(const MatrixMulImpl::KernParam& kern_param);

public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "X86_INT8X8X16_SSE"; }
    bool usable(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_x86_algo_type; }
    bool preferred(const KernSizeParam&) const override;
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};

class MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2 : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
dnn/src/x86/matrix_mul/int8/kernel_sse_4x8x2.h

@@ -6,10 +6,17 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
 */
#include <immintrin.h>
#ifdef WIN32
#include <avx2intrin.h>
#include <avxintrin.h>
#include <fmaintrin.h>
#include <smmintrin.h>
#endif
#include <cmath>
#include <cstdint>
#include <type_traits>
@@ -21,10 +28,44 @@ namespace x86 {
namespace matmul_sse_4x8x2 {

template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
void store_overflow(void* ptr, __m128i a);

template <>
void store_overflow<int16_t>(void* ptr, __m128i a) {
    a = _mm_shufflelo_epi16(a, 0x08);
    a = _mm_shufflehi_epi16(a, 0x08);
    a = _mm_shuffle_epi32(a, 0x08);
    _mm_storel_epi64((__m128i*)ptr, a);
}
template <>
void store_overflow<int32_t>(void* ptr, __m128i a) {
    _mm_storeu_si128((__m128i*)(ptr), a);
}

template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
void store_overflow(void* ptr, __m128i a, int remain);

template <>
void store_overflow<int16_t>(void* ptr, __m128i a, int remain) {
    __m128i mask = _mm_continue_mask(remain * sizeof(int16_t));
    a = _mm_shufflelo_epi16(a, 0x08);
    a = _mm_shufflehi_epi16(a, 0x08);
    a = _mm_shuffle_epi32(a, 0x08);
    _mm_maskmoveu_si128(a, mask, reinterpret_cast<char*>(ptr));
}
template <>
void store_overflow<int32_t>(void* ptr, __m128i a, int remain) {
    __m128i mask = _mm_continue_mask(remain * sizeof(int32_t));
    _mm_maskmoveu_si128(a, mask, reinterpret_cast<char*>(ptr));
}

template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
static inline void kern_gemm_s8s8s32_sse_4x8x2(
        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr,
-        int32_t* c_ptr, const int ldc, const int k) {
+        CType* c_ptr, const int ldc, const int k) {
    constexpr int k_step = 2;
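For readers unfamiliar with the shuffle trick in store_overflow<int16_t> above: shuffle constant 0x08 selects elements {0, 2}, so the three shuffles gather the low 16 bits of each of the four int32 accumulator lanes into the low 64 bits of the register, which a single _mm_storel_epi64 then writes out as four contiguous int16 values. A standalone sketch of that lane packing (SSE2 intrinsics only; not part of the patch):

#include <emmintrin.h>  // SSE2 is sufficient for these intrinsics
#include <cstdint>
#include <cstdio>

int main() {
    // Four int32 accumulators; the last two deliberately do not fit in int16.
    int32_t acc[4] = {1, -2, 70000, 40000};
    __m128i a = _mm_loadu_si128(reinterpret_cast<const __m128i*>(acc));
    a = _mm_shufflelo_epi16(a, 0x08);  // low half:  words {0, 2} -> positions 0, 1
    a = _mm_shufflehi_epi16(a, 0x08);  // high half: words {4, 6} -> positions 4, 5
    a = _mm_shuffle_epi32(a, 0x08);    // gather dwords {0, 2} into the low 64 bits
    int16_t out[4];
    _mm_storel_epi64(reinterpret_cast<__m128i*>(out), a);
    for (int i = 0; i < 4; ++i)
        std::printf("%d -> %d\n", acc[i], out[i]);  // e.g. 70000 -> 4464 (truncated)
    return 0;
}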
@@ -102,20 +143,20 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2(const int16_t* pack_a_ptr,
        pack_a_ptr += 8;
        pack_b_ptr += 16;
    }

-    _mm_storeu_si128((__m128i*)(c_ptr), c_vec[0]);
-    _mm_storeu_si128((__m128i*)(c_ptr + 4), c_vec[1]);
-    _mm_storeu_si128((__m128i*)(c_ptr + ldc), c_vec[2]);
-    _mm_storeu_si128((__m128i*)(c_ptr + ldc + 4), c_vec[3]);
-    _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc), c_vec[4]);
-    _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc + 4), c_vec[5]);
-    _mm_storeu_si128((__m128i*)(c_ptr + 3 * ldc), c_vec[6]);
-    _mm_storeu_si128((__m128i*)(c_ptr + 3 * ldc + 4), c_vec[7]);
+    store_overflow<CType>(c_ptr, c_vec[0]);
+    store_overflow<CType>(c_ptr + 4, c_vec[1]);
+    store_overflow<CType>(c_ptr + ldc, c_vec[2]);
+    store_overflow<CType>(c_ptr + ldc + 4, c_vec[3]);
+    store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
+    store_overflow<CType>(c_ptr + 2 * ldc + 4, c_vec[5]);
+    store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6]);
+    store_overflow<CType>(c_ptr + 3 * ldc + 4, c_vec[7]);
}

template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m(
-        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr,
+        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
        const int ldc, const int k, const int remain_m) {
    constexpr int k_step = 2;
@@ -194,34 +235,35 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m(
        pack_b_ptr += 16;
    }

-    _mm_storeu_si128((__m128i*)(c_ptr), c_vec[0]);
-    _mm_storeu_si128((__m128i*)(c_ptr + 4), c_vec[1]);
+    store_overflow<CType>(c_ptr, c_vec[0]);
+    store_overflow<CType>(c_ptr + 4, c_vec[1]);
    switch (remain_m) {
        case 2:
-            _mm_storeu_si128((__m128i*)(c_ptr + ldc), c_vec[2]);
-            _mm_storeu_si128((__m128i*)(c_ptr + ldc + 4), c_vec[3]);
+            store_overflow<CType>(c_ptr + ldc, c_vec[2]);
+            store_overflow<CType>(c_ptr + ldc + 4, c_vec[3]);
            break;
        case 3:
-            _mm_storeu_si128((__m128i*)(c_ptr + ldc), c_vec[2]);
-            _mm_storeu_si128((__m128i*)(c_ptr + ldc + 4), c_vec[3]);
-            _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc), c_vec[4]);
-            _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc + 4), c_vec[5]);
+            store_overflow<CType>(c_ptr + ldc, c_vec[2]);
+            store_overflow<CType>(c_ptr + ldc + 4, c_vec[3]);
+            store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
+            store_overflow<CType>(c_ptr + 2 * ldc + 4, c_vec[5]);
            break;
        case 4:
-            _mm_storeu_si128((__m128i*)(c_ptr + ldc), c_vec[2]);
-            _mm_storeu_si128((__m128i*)(c_ptr + ldc + 4), c_vec[3]);
-            _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc), c_vec[4]);
-            _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc + 4), c_vec[5]);
-            _mm_storeu_si128((__m128i*)(c_ptr + 3 * ldc), c_vec[6]);
-            _mm_storeu_si128((__m128i*)(c_ptr + 3 * ldc + 4), c_vec[7]);
+            store_overflow<CType>(c_ptr + ldc, c_vec[2]);
+            store_overflow<CType>(c_ptr + ldc + 4, c_vec[3]);
+            store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
+            store_overflow<CType>(c_ptr + 2 * ldc + 4, c_vec[5]);
+            store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6]);
+            store_overflow<CType>(c_ptr + 3 * ldc + 4, c_vec[7]);
        default:
            break;
    }
}

template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_n(
-        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr,
+        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
        const int ldc, const int k, int remain_n) {
    constexpr int k_step = 2;
@@ -301,10 +343,10 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_n(
    }

    if (remain_n >= 4) {
-        _mm_storeu_si128((__m128i*)(c_ptr), c_vec[0]);
-        _mm_storeu_si128((__m128i*)(c_ptr + ldc), c_vec[2]);
-        _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc), c_vec[4]);
-        _mm_storeu_si128((__m128i*)(c_ptr + 3 * ldc), c_vec[6]);
+        store_overflow<CType>(c_ptr, c_vec[0]);
+        store_overflow<CType>(c_ptr + ldc, c_vec[2]);
+        store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
+        store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6]);
        c_ptr += 4;
        remain_n -= 4;
        c_vec[0] = c_vec[1];
@@ -312,35 +354,16 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_n(
        c_vec[4] = c_vec[5];
        c_vec[6] = c_vec[7];
    }
-    switch (remain_n) {
-        case 0:
-            break;
-        case 1:
-            *(c_ptr) = _mm_extract_epi32(c_vec[0], 0);
-            *(c_ptr + ldc) = _mm_extract_epi32(c_vec[2], 0);
-            *(c_ptr + 2 * ldc) = _mm_extract_epi32(c_vec[4], 0);
-            *(c_ptr + 3 * ldc) = _mm_extract_epi32(c_vec[6], 0);
-            break;
-        case 2:
-        case 3:
-            _mm_storel_epi64((__m128i*)(c_ptr), c_vec[0]);
-            _mm_storel_epi64((__m128i*)(c_ptr + ldc), c_vec[2]);
-            _mm_storel_epi64((__m128i*)(c_ptr + 2 * ldc), c_vec[4]);
-            _mm_storel_epi64((__m128i*)(c_ptr + 3 * ldc), c_vec[6]);
-            break;
-    }
-    if (remain_n == 3) {
-        *(c_ptr + 2) = _mm_extract_epi32(c_vec[0], 2);
-        *(c_ptr + ldc + 2) = _mm_extract_epi32(c_vec[2], 2);
-        *(c_ptr + 2 * ldc + 2) = _mm_extract_epi32(c_vec[4], 2);
-        *(c_ptr + 3 * ldc + 2) = _mm_extract_epi32(c_vec[6], 2);
-    }
+    store_overflow<CType>(c_ptr, c_vec[0], remain_n);
+    store_overflow<CType>(c_ptr + ldc, c_vec[2], remain_n);
+    store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4], remain_n);
+    store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6], remain_n);
}

template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m_n(
-        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr,
+        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
        const int ldc, const int k, int remain_m, int remain_n) {
    constexpr int k_step = 2;
@@ -421,8 +444,7 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m_n(
    int index_array[4]{0, 2, 4, 6};
    if (remain_n >= 4) {
        for (int m = 0; m < remain_m; ++m) {
-            _mm_storeu_si128((__m128i*)(c_ptr + m * ldc), c_vec[index_array[m]]);
+            store_overflow<CType>(c_ptr + m * ldc, c_vec[index_array[m]]);
        }
        c_ptr += 4;
        remain_n -= 4;
@@ -431,29 +453,8 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m_n(
        c_vec[4] = c_vec[5];
        c_vec[6] = c_vec[7];
    }
-    switch (remain_n) {
-        case 0:
-            break;
-        case 1:
-            for (int m = 0; m < remain_m; ++m) {
-                *(c_ptr + m * ldc) = _mm_extract_epi32(c_vec[index_array[m]], 0);
-            }
-            break;
-        case 2:
-        case 3:
-            for (int m = 0; m < remain_m; ++m) {
-                _mm_storel_epi64((__m128i*)(c_ptr + m * ldc), c_vec[index_array[m]]);
-            }
-            break;
-    }
-    if (remain_n == 3) {
-        for (int m = 0; m < remain_m; ++m) {
-            *(c_ptr + m * ldc + 2) = _mm_extract_epi32(c_vec[index_array[m]], 2);
-        }
-    }
+    for (int m = 0; m < remain_m; ++m) {
+        store_overflow<CType>(c_ptr + m * ldc, c_vec[index_array[m]], remain_n);
+    }
}
dnn/src/x86/matrix_mul/int8/sse_strategy_4x8x2.cpp

@@ -6,7 +6,8 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
 */

#include "src/common/utils.h"
@@ -18,11 +19,9 @@ using namespace megdnn;
using namespace x86;
using namespace x86::matmul;

-MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_sse_s8s8s32_4x8x2);
-void gemm_sse_s8s8s32_4x8x2::pack_A(dt_int16* out, const dt_int8* in, int ldin,
-                                    int y0, int ymax, int k0, int kmax,
-                                    bool transpose) const {
+static inline void gemm_packa(dt_int16* out, const dt_int8* in, int ldin, int y0,
+                              int ymax, int k0, int kmax, bool transpose) {
    if (transpose) {
        matmul_sse_4x8x2::gemm_s8s8s32_sse_4x8x2_pack_at(out, in, ldin, y0, ymax,
                                                         k0, kmax);
@@ -31,10 +30,8 @@ void gemm_sse_s8s8s32_4x8x2::pack_A(dt_int16* out, const dt_int8* in, int ldin,
                                                         ymax, k0, kmax);
    }
}

-void gemm_sse_s8s8s32_4x8x2::pack_B(dt_int8* out, const dt_int8* in, int ldin,
-                                    int x0, int xmax, int k0, int kmax,
-                                    bool transpose) const {
+static inline void gemm_packb(dt_int8* out, const dt_int8* in, int ldin, int x0,
+                              int xmax, int k0, int kmax, bool transpose) {
    if (transpose) {
        matmul_sse_4x8x2::gemm_s8s8s32_sse_4x8x2_pack_bt(out, in, ldin, x0, xmax,
                                                         k0, kmax);
@@ -43,20 +40,11 @@ void gemm_sse_s8s8s32_4x8x2::pack_B(dt_int8* out, const dt_int8* in, int ldin,
                                                         xmax, k0, kmax);
    }
}

-void gemm_sse_s8s8s32_4x8x2::kern(const dt_int16* pack_a_ptr,
-                                  const dt_int8* pack_b_ptr, size_t m, size_t n,
-                                  size_t k, dt_int32* c_ptr, size_t ldc,
-                                  bool is_first_k, const dt_int32*, dt_int32*) const {
-    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
-                          ((A_dtype.enumv() == DTypeEnum::Int8 &&
-                            C_dtype.enumv() == DTypeEnum::Int32) ||
-                           (A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
-                            C_dtype.enumv() == DTypeEnum::QuantizedS32)),
-                  "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
-                  C_dtype.name());
-    megdnn_assert(is_first_k == true);
+template <typename CType>
+static inline void gemm_kern(const dt_int16* pack_a_ptr, const dt_int8* pack_b_ptr,
+                             size_t m, size_t n, size_t k, CType* c_ptr, size_t ldc,
+                             bool is_first_k) {
    constexpr int m_tile = 4;
    constexpr int n_tile = 8;
    constexpr int k_tile = 2;
@@ -99,4 +87,62 @@ void gemm_sse_s8s8s32_4x8x2::kern(const dt_int16* pack_a_ptr,
        }
    }
}

MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_sse_s8s8s32_4x8x2);

void gemm_sse_s8s8s32_4x8x2::pack_A(dt_int16* out, const dt_int8* in, int ldin,
                                    int y0, int ymax, int k0, int kmax,
                                    bool transpose) const {
    gemm_packa(out, in, ldin, y0, ymax, k0, kmax, transpose);
}

void gemm_sse_s8s8s32_4x8x2::pack_B(dt_int8* out, const dt_int8* in, int ldin,
                                    int x0, int xmax, int k0, int kmax,
                                    bool transpose) const {
    gemm_packb(out, in, ldin, x0, xmax, k0, kmax, transpose);
}

void gemm_sse_s8s8s32_4x8x2::kern(const dt_int16* pack_a_ptr,
                                  const dt_int8* pack_b_ptr, size_t m, size_t n,
                                  size_t k, dt_int32* c_ptr, size_t ldc,
                                  bool is_first_k, const dt_int32*, dt_int32*) const {
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                          ((A_dtype.enumv() == DTypeEnum::Int8 &&
                            C_dtype.enumv() == DTypeEnum::Int32) ||
                           (A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                            C_dtype.enumv() == DTypeEnum::QuantizedS32)),
                  "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
                  C_dtype.name());
    megdnn_assert(is_first_k == true);
    gemm_kern(pack_a_ptr, pack_b_ptr, m, n, k, c_ptr, ldc, is_first_k);
}

MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_sse_s8s8s16_4x8x2);

void gemm_sse_s8s8s16_4x8x2::pack_A(dt_int16* out, const dt_int8* in, int ldin,
                                    int y0, int ymax, int k0, int kmax,
                                    bool transpose) const {
    gemm_packa(out, in, ldin, y0, ymax, k0, kmax, transpose);
}

void gemm_sse_s8s8s16_4x8x2::pack_B(dt_int8* out, const dt_int8* in, int ldin,
                                    int x0, int xmax, int k0, int kmax,
                                    bool transpose) const {
    gemm_packb(out, in, ldin, x0, xmax, k0, kmax, transpose);
}

void gemm_sse_s8s8s16_4x8x2::kern(const dt_int16* pack_a_ptr,
                                  const dt_int8* pack_b_ptr, size_t m, size_t n,
                                  size_t k, dt_int16* c_ptr, size_t ldc,
                                  bool is_first_k, const dt_int32*, dt_int32*) const {
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                          ((A_dtype.enumv() == DTypeEnum::Int8 &&
                            C_dtype.enumv() == DTypeEnum::Int16) ||
                           (A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                            C_dtype.enumv() == DTypeEnum::QuantizedS16)),
                  "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
                  C_dtype.name());
    megdnn_assert(is_first_k == true);
    gemm_kern(pack_a_ptr, pack_b_ptr, m, n, k, c_ptr, ldc, is_first_k);
}

// vim: syntax=cpp.doxygen
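The refactor in this file moves the pack and compute bodies into file-local helpers (gemm_packa, gemm_packb, and the gemm_kern<CType> template), so the existing s8s8s32 strategy and the new s8s8s16 strategy share one code path and differ only in the output type they instantiate. A self-contained toy of that dispatch pattern (illustrative names, not MegDNN code):

#include <cstdint>
#include <cstdio>

// One templated worker; each "strategy" wrapper only picks the output type.
template <typename CType>
static void accumulate_store(const int32_t* acc, CType* out, int n) {
    for (int i = 0; i < n; ++i)
        out[i] = static_cast<CType>(acc[i]);  // int32 keeps the value, int16 truncates
}

int main() {
    const int32_t acc[4] = {1, -2, 70000, 40000};
    int32_t c32[4];
    int16_t c16[4];
    accumulate_store(acc, c32, 4);  // the s8s8s32 path
    accumulate_store(acc, c16, 4);  // the s8s8s16 path
    std::printf("s32: %d  s16: %d\n", c32[2], c16[2]);
    return 0;
}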
dnn/src/x86/matrix_mul/int8/strategy.h

@@ -38,6 +38,10 @@ MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int32, dt_int32,
                                          4, 8, 2, false, false,
                                          gemm_sse_s8s8s32_4x8x2);

MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int16, dt_int32,
                                          4, 8, 2, false, false,
                                          gemm_sse_s8s8s16_4x8x2);

}  // namespace matmul
}  // namespace x86
}  // namespace megdnn
dnn/src/x86/matrix_mul/opr_impl.cpp

@@ -38,6 +38,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
    AlgoInt8x8x32AVX2M2N4K16 algoint8x8x32avx2_m2n4k16;
    AlgoInt8x8x32SSEM4N8K2 algoint8x8x32sse_m4n8k2;
    AlgoInt8x8x16AVX2 algoint8x8x16avx2_m4n16k2;
    AlgoInt8x8x16SSE algoint8x8x16sse_m4n8k2;
    AlgoF32MK8_8x8 algof32mk8_8x8;

public:
@@ -51,6 +52,7 @@ public:
        all_algos.emplace_back(&algoint8x8x16avx2_m4n16k2);
        all_algos.emplace_back(&algoint8x8x32avx2_m2n4k16);
        all_algos.emplace_back(&algoint8x8x32sse_m4n8k2);
        all_algos.emplace_back(&algoint8x8x16sse_m4n8k2);
        all_algos.emplace_back(&algof32mk8_8x8);
#if MEGDNN_X86_WITH_MKL_DNN
        all_algos.emplace_back(&algoint8x8x32mkldnn);
dnn/src/x86/matrix_mul/opr_impl.h

@@ -56,6 +56,7 @@ protected:
    class AlgoInt8x8x32AVX2M4N16K2;
    class AlgoInt8x8x32SSEM4N8K2;
    class AlgoInt8x8x16AVX2;
    class AlgoInt8x8x16SSE;
    class AlgoPack;
    class AlgoF32MK8_8x8;
};
dnn/test/x86/conv_bias.cpp

@@ -835,6 +835,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8X8X) {
    }
    if (::megdnn::x86::is_supported(::megdnn::x86::SIMDType::SSE4_2)) {
        cb("IM2COLMATMUL:X86_INT8X8X32_SSE_4X8X2");
        cb2("IM2COLMATMUL:X86_INT8X8X16_SSE");
    }

#undef cb
@@ -1002,7 +1003,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_BLAS) {
}
#endif

-TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) {
+TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X) {
    using namespace conv_bias;
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
@@ -1028,10 +1029,16 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) {
        checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{},
                          dtype::Int8{}, dtype::Int32{}, dtype::Int32{},
                          "CONV1x1:X86_INT8X8X32_AVX2_2X4X16:24");
        checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{},
                          dtype::Int8{}, dtype::Int16{}, dtype::Int16{},
                          "CONV1x1:X86_INT8X8X16_AVX2");
    }
    checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{},
                      dtype::Int32{}, dtype::Int32{},
                      "CONV1x1:X86_INT8X8X32_SSE_4X8X2:48");
    checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{},
                      dtype::Int16{}, dtype::Int16{}, "CONV1x1:X86_INT8X8X16_SSE");
}

/************************* End Conv1x1 PackA ************************/
dnn/test/x86/convolution.cpp

@@ -403,6 +403,7 @@ TEST_F(X86, BENCHMARK_CONVOLUTION_I8x8x16) {
    benchmark.set_dtype(0, dtype::Int8())
            .set_dtype(1, dtype::Int8())
            .set_dtype(2, dtype::Int16());
    benchmark.set_before_exec_callback(AlgoChecker<Convolution>(".*"));
    benchmark.set_display(false);
    benchmark.set_times(RUN);
dnn/test/x86/matrix_mul.cpp

@@ -52,6 +52,10 @@ TEST_F(X86, MATRIX_MUL_AVX2_8X8X16) {
    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
                                 handle(), "X86_INT8X8X16_AVX2");
}

TEST_F(X86, MATRIX_MUL_SSE_8X8X16) {
    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
                                 handle(), "X86_INT8X8X16_SSE");
}

TEST_F(X86, MATRIX_MUL_SSE_8X8X32) {
    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
                                 handle(), "X86_INT8X8X32_SSE_4X8X2");
@@ -132,6 +136,17 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
    benchmarker_avx2_4x16x2_8816.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X16_AVX2"));

    Benchmarker<MatrixMul> benchmarker_sse_4x8x2_8816(handle());
    benchmarker_sse_4x8x2_8816.set_display(false)
            .set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_sse_4x8x2_8816.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X16_SSE"));

    Benchmarker<MatrixMul> benchmarker_avx2_2x4x16(handle());
    benchmarker_avx2_2x4x16.set_display(false)
            .set_times(RUNS)
@@ -212,9 +227,15 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
            std::cout << "sse: " << sse_used << " ms, " << computations / sse_used
                      << " Gflops, "
                      << "speed_up " << float_used / sse_used << ", ";
            auto sse_used_8816 =
                    benchmarker_sse_4x8x2_8816.exec({{M, K}, {K, N}, {}}) / RUNS;
            std::cout << "sse_8816: " << sse_used_8816 << " ms, "
                      << computations / sse_used_8816 << " Gflops, ";
        }
        std::cout << std::endl;
    };

    run(256, 256, 256);
    for (size_t M : {8, 64, 112, 256, 512}) {
        for (size_t K : {8, 16, 32, 64, 112, 256, 512}) {