Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
25b6a131
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
25b6a131
编写于
6月 12, 2020
作者:
M
Megvii Engine Team
提交者:
Xu Xinran
6月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/x86): add x86 avx2 8x8x16 matmul
GitOrigin-RevId: d2172c50b244a0683b00710a88bc507d02a9734f
上级
273f891b
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
464 addition
and
125 deletion
+464
-125
dnn/src/x86/matrix_mul/algos.cpp
dnn/src/x86/matrix_mul/algos.cpp
+73
-0
dnn/src/x86/matrix_mul/algos.h
dnn/src/x86/matrix_mul/algos.h
+20
-2
dnn/src/x86/matrix_mul/common/common.h
dnn/src/x86/matrix_mul/common/common.h
+38
-7
dnn/src/x86/matrix_mul/int8/avx2_strategy_4x16x2.cpp
dnn/src/x86/matrix_mul/int8/avx2_strategy_4x16x2.cpp
+70
-23
dnn/src/x86/matrix_mul/int8/kernel_avx2_4x16x2.h
dnn/src/x86/matrix_mul/int8/kernel_avx2_4x16x2.h
+123
-86
dnn/src/x86/matrix_mul/int8/strategy.h
dnn/src/x86/matrix_mul/int8/strategy.h
+6
-1
dnn/src/x86/matrix_mul/opr_impl.cpp
dnn/src/x86/matrix_mul/opr_impl.cpp
+2
-0
dnn/src/x86/matrix_mul/opr_impl.h
dnn/src/x86/matrix_mul/opr_impl.h
+3
-1
dnn/test/x86/conv_bias.cpp
dnn/test/x86/conv_bias.cpp
+47
-2
dnn/test/x86/convolution.cpp
dnn/test/x86/convolution.cpp
+59
-2
dnn/test/x86/matrix_mul.cpp
dnn/test/x86/matrix_mul.cpp
+23
-1
未找到文件。
dnn/src/x86/matrix_mul/algos.cpp
浏览文件 @
25b6a131
...
@@ -318,6 +318,79 @@ void gemm_s8s8s32_sse_4x8x2(const MatrixMulImpl::KernParam& kern_param) {
...
@@ -318,6 +318,79 @@ void gemm_s8s8s32_sse_4x8x2(const MatrixMulImpl::KernParam& kern_param) {
}
}
}
// namespace
}
// namespace
void
MatrixMulImpl
::
AlgoInt8x8x16AVX2
::
gemm_s8s8s16_avx2_4x16x2
(
const
MatrixMulImpl
::
KernParam
&
kern_param
)
{
MEGDNN_MARK_USED_VAR
(
kern_param
);
MIDOUT_BEGIN
(
megdnn_x86_matmul_kern_avx2_4x16x2
,
midout_iv
(
1
))
{
constexpr
int
cacheline
=
64
;
const
size_t
m
=
kern_param
.
M
;
const
size_t
n
=
kern_param
.
N
;
const
size_t
k
=
kern_param
.
K
;
const
bool
trans_a
=
kern_param
.
trA
;
const
bool
trans_b
=
kern_param
.
trB
;
const
size_t
lda
=
kern_param
.
LDA
;
const
size_t
ldb
=
kern_param
.
LDB
;
const
size_t
ldc
=
kern_param
.
LDC
;
auto
a_type
=
kern_param
.
A_type
;
auto
b_type
=
kern_param
.
B_type
;
auto
c_type
=
kern_param
.
C_type
;
const
auto
a_ptr
=
kern_param
.
A
<
dt_int8
>
();
const
auto
b_ptr
=
kern_param
.
B
<
dt_int8
>
();
auto
c_ptr
=
kern_param
.
C
<
dt_int16
>
();
x86
::
matmul
::
gemm_avx2_s8s8s16_4x16x2
strategy
(
m
,
n
,
k
,
a_type
,
b_type
,
c_type
);
megdnn
::
matmul
::
GemmInterleaved
<
x86
::
matmul
::
gemm_avx2_s8s8s16_4x16x2
>
(
m
,
n
,
k
,
trans_a
,
trans_b
,
strategy
,
cacheline
)
.
execute
(
a_ptr
,
lda
,
b_ptr
,
ldb
,
c_ptr
,
ldc
,
kern_param
.
workspace_ptr
);
}
MIDOUT_END
();
}
MatrixMulImpl
::
kern_t
MatrixMulImpl
::
AlgoInt8x8x16AVX2
::
get_kern
(
const
KernSizeParam
&
)
const
{
return
gemm_s8s8s16_avx2_4x16x2
;
}
bool
MatrixMulImpl
::
AlgoInt8x8x16AVX2
::
usable
(
const
KernSizeParam
&
kern_size_param
)
const
{
bool
is_ab_same
=
kern_size_param
.
A_type
.
enumv
()
==
kern_size_param
.
B_type
.
enumv
();
bool
is_type_ok
=
((
kern_size_param
.
A_type
.
enumv
()
==
DTypeEnum
::
Int8
&&
kern_size_param
.
C_type
.
enumv
()
==
DTypeEnum
::
Int16
)
||
(
kern_size_param
.
A_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
kern_size_param
.
C_type
.
enumv
()
==
DTypeEnum
::
QuantizedS16
));
bool
is_mode_ok
=
kern_size_param
.
compute_mode
==
Param
::
ComputeMode
::
DEFAULT
&&
is_supported
(
SIMDType
::
AVX2
);
bool
is_param_ok
=
is_ab_same
&&
is_type_ok
&&
is_mode_ok
;
return
is_param_ok
;
}
bool
MatrixMulImpl
::
AlgoInt8x8x16AVX2
::
preferred
(
const
KernSizeParam
&
)
const
{
return
true
;
}
size_t
MatrixMulImpl
::
AlgoInt8x8x16AVX2
::
get_workspace
(
const
KernSizeParam
&
kern_param
)
const
{
constexpr
int
cacheline
=
64
;
const
size_t
m
=
kern_param
.
M
;
const
size_t
n
=
kern_param
.
N
;
const
size_t
k
=
kern_param
.
K
;
const
bool
trans_a
=
kern_param
.
trA
;
const
bool
trans_b
=
kern_param
.
trB
;
auto
a_type
=
kern_param
.
A_type
;
auto
b_type
=
kern_param
.
B_type
;
auto
c_type
=
kern_param
.
C_type
;
x86
::
matmul
::
gemm_avx2_s8s8s16_4x16x2
strategy
(
m
,
n
,
k
,
a_type
,
b_type
,
c_type
);
return
megdnn
::
matmul
::
GemmInterleaved
<
x86
::
matmul
::
gemm_avx2_s8s8s16_4x16x2
>
(
m
,
n
,
k
,
trans_a
,
trans_b
,
strategy
,
cacheline
)
.
get_workspace_size
();
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL
(
AlgoInt8x8x16AVX2
,
megdnn_x86_matmul_kern
,
8
,
x86
::
matmul
::
gemm_avx2_s8s8s16_4x16x2
,
dt_int8
,
dt_int16
,
dt_int16
);
MatrixMulImpl
::
kern_t
MatrixMulImpl
::
AlgoInt8x8x32AVX2M4N16K2
::
get_kern
(
MatrixMulImpl
::
kern_t
MatrixMulImpl
::
AlgoInt8x8x32AVX2M4N16K2
::
get_kern
(
const
KernSizeParam
&
)
const
{
const
KernSizeParam
&
)
const
{
...
...
dnn/src/x86/matrix_mul/algos.h
浏览文件 @
25b6a131
...
@@ -6,13 +6,14 @@
...
@@ -6,13 +6,14 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#pragma once
#pragma once
#include "src/x86/matrix_mul/opr_impl.h"
#include "src/fallback/matrix_mul/gemm_common.h"
#include "src/fallback/matrix_mul/gemm_common.h"
#include "src/x86/matrix_mul/opr_impl.h"
namespace
megdnn
{
namespace
megdnn
{
namespace
x86
{
namespace
x86
{
...
@@ -71,6 +72,23 @@ public:
...
@@ -71,6 +72,23 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
};
class
MatrixMulImpl
::
AlgoInt8x8x16AVX2
:
public
AlgoBase
{
private:
static
void
gemm_s8s8s16_avx2_4x16x2
(
const
MatrixMulImpl
::
KernParam
&
kern_param
);
static
MatrixMulImpl
::
AlgoInt8x8x32AVX2M4N16K2
m_algo
;
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"X86_INT8X8X16_AVX2"
;
}
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_x86_algo_type
;
}
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
class
MatrixMulImpl
::
AlgoInt8x8x32SSEM4N8K2
:
public
AlgoBase
{
class
MatrixMulImpl
::
AlgoInt8x8x32SSEM4N8K2
:
public
AlgoBase
{
public:
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
bool
is_reproducible
()
const
override
{
return
true
;
}
...
...
dnn/src/x86/matrix_mul/common/common.h
浏览文件 @
25b6a131
...
@@ -6,16 +6,17 @@
...
@@ -6,16 +6,17 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#pragma once
#pragma once
#include <x86intrin.h>
#include <x86intrin.h>
#ifdef WIN32
#ifdef WIN32
#include <avxintrin.h>
#include <smmintrin.h>
#include <avx2intrin.h>
#include <avx2intrin.h>
#include <avxintrin.h>
#include <fmaintrin.h>
#include <fmaintrin.h>
#include <smmintrin.h>
#endif
#endif
#include <cmath>
#include <cmath>
#include <cstdint>
#include <cstdint>
...
@@ -787,19 +788,49 @@ static inline void transpose_4x8_k2_int8_to_int16(const int8_t* inptr0,
...
@@ -787,19 +788,49 @@ static inline void transpose_4x8_k2_int8_to_int16(const int8_t* inptr0,
MEGDNN_ATTRIBUTE_TARGET
(
"avx2"
)
MEGDNN_ATTRIBUTE_TARGET
(
"avx2"
)
static
inline
__v8si
_m256_continue_mask_v8si
(
const
int
&
x
)
{
static
inline
__v8si
_m256_continue_mask_v8si
(
const
int
&
x
)
{
// clang-format off
static
__v8si
map
[
9
]
=
{
static
__v8si
map
[
9
]
=
{
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
},
{
-
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
},
{
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
},
{
-
1
,
-
1
,
0
,
0
,
0
,
0
,
0
,
0
},
{
-
1
,
-
1
,
-
1
,
0
,
0
,
0
,
0
,
0
},
{
-
1
,
00
,
00
,
00
,
00
,
00
,
00
,
00
},
{
-
1
,
-
1
,
-
1
,
-
1
,
0
,
0
,
0
,
0
},
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
0
,
0
,
0
},
{
-
1
,
-
1
,
00
,
00
,
00
,
00
,
00
,
00
},
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
0
,
0
},
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
0
},
{
-
1
,
-
1
,
-
1
,
00
,
00
,
00
,
00
,
00
},
{
-
1
,
-
1
,
-
1
,
-
1
,
00
,
00
,
00
,
00
},
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
00
,
00
,
00
},
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
00
,
00
},
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
00
},
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
}};
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
}};
return
map
[
x
];
return
map
[
x
];
// clang-format on
}
}
MEGDNN_ATTRIBUTE_TARGET
(
"avx2"
)
MEGDNN_ATTRIBUTE_TARGET
(
"avx2"
)
static
inline
__m256i
_m256_continue_mask
(
const
int
&
x
)
{
static
inline
__m256i
_m256_continue_mask
(
const
int
&
x
)
{
return
(
__m256i
)
_m256_continue_mask_v8si
(
x
);
return
(
__m256i
)
_m256_continue_mask_v8si
(
x
);
}
}
MEGDNN_ATTRIBUTE_TARGET
(
"sse2"
)
static
inline
__m128i
_mm_continue_mask
(
const
int
&
x
)
{
static
__v16qi
map
[
17
]
=
{
{
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
},
{
-
1
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
},
{
-
1
,
-
1
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
},
{
-
1
,
-
1
,
-
1
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
},
{
-
1
,
-
1
,
-
1
,
-
1
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
},
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
},
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
},
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
},
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
00
,
00
,
00
,
00
,
00
,
00
,
00
,
00
},
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
00
,
00
,
00
,
00
,
00
,
00
,
00
},
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
00
,
00
,
00
,
00
,
00
,
00
},
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
00
,
00
,
00
,
00
,
00
},
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
00
,
00
,
00
,
00
},
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
00
,
00
,
00
},
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
00
,
00
},
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
00
},
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
},
};
return
(
__m128i
)
map
[
x
];
}
MEGDNN_ATTRIBUTE_TARGET
(
"sse2"
)
MEGDNN_ATTRIBUTE_TARGET
(
"sse2"
)
static
inline
void
transpose_4xk_int8_to_int16_pad
(
const
int8_t
*
inptr0
,
static
inline
void
transpose_4xk_int8_to_int16_pad
(
const
int8_t
*
inptr0
,
const
int8_t
*
inptr1
,
const
int8_t
*
inptr1
,
...
...
dnn/src/x86/matrix_mul/int8/avx2_strategy_4x16x2.cpp
浏览文件 @
25b6a131
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#include "src/common/utils.h"
#include "src/common/utils.h"
...
@@ -18,10 +19,9 @@ using namespace megdnn;
...
@@ -18,10 +19,9 @@ using namespace megdnn;
using
namespace
x86
;
using
namespace
x86
;
using
namespace
x86
::
matmul
;
using
namespace
x86
::
matmul
;
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gemm_avx2_s8s8s32_4x16x2
);
static
inline
void
gemm_packa
(
dt_int16
*
out
,
const
dt_int8
*
in
,
int
ldin
,
void
gemm_avx2_s8s8s32_4x16x2
::
pack_A
(
dt_int16
*
out
,
const
dt_int8
*
in
,
int
y0
,
int
ymax
,
int
k0
,
int
kmax
,
int
ldin
,
int
y0
,
int
ymax
,
int
k0
,
bool
transpose
)
{
int
kmax
,
bool
transpose
)
const
{
if
(
transpose
)
{
if
(
transpose
)
{
matmul_avx2_4x16x2
::
gemm_s8s8s32_avx2_4x16x2_pack_at
(
out
,
in
,
ldin
,
y0
,
matmul_avx2_4x16x2
::
gemm_s8s8s32_avx2_4x16x2_pack_at
(
out
,
in
,
ldin
,
y0
,
ymax
,
k0
,
kmax
);
ymax
,
k0
,
kmax
);
...
@@ -30,10 +30,8 @@ void gemm_avx2_s8s8s32_4x16x2::pack_A(dt_int16* out, const dt_int8* in,
...
@@ -30,10 +30,8 @@ void gemm_avx2_s8s8s32_4x16x2::pack_A(dt_int16* out, const dt_int8* in,
ymax
,
k0
,
kmax
);
ymax
,
k0
,
kmax
);
}
}
}
}
static
inline
void
gemm_packb
(
dt_int8
*
out
,
const
dt_int8
*
in
,
int
ldin
,
int
x0
,
void
gemm_avx2_s8s8s32_4x16x2
::
pack_B
(
dt_int8
*
out
,
const
dt_int8
*
in
,
int
ldin
,
int
xmax
,
int
k0
,
int
kmax
,
bool
transpose
)
{
int
x0
,
int
xmax
,
int
k0
,
int
kmax
,
bool
transpose
)
const
{
if
(
transpose
)
{
if
(
transpose
)
{
matmul_avx2_4x16x2
::
gemm_s8s8s32_avx2_4x16x2_pack_bt
(
out
,
in
,
ldin
,
x0
,
matmul_avx2_4x16x2
::
gemm_s8s8s32_avx2_4x16x2_pack_bt
(
out
,
in
,
ldin
,
x0
,
xmax
,
k0
,
kmax
);
xmax
,
k0
,
kmax
);
...
@@ -42,20 +40,11 @@ void gemm_avx2_s8s8s32_4x16x2::pack_B(dt_int8* out, const dt_int8* in, int ldin,
...
@@ -42,20 +40,11 @@ void gemm_avx2_s8s8s32_4x16x2::pack_B(dt_int8* out, const dt_int8* in, int ldin,
xmax
,
k0
,
kmax
);
xmax
,
k0
,
kmax
);
}
}
}
}
template
<
typename
CType
>
void
gemm_avx2_s8s8s32_4x16x2
::
kern
(
const
dt_int16
*
pack_a_ptr
,
static
inline
void
gemm_kern
(
const
dt_int16
*
pack_a_ptr
,
const
dt_int8
*
pack_b_ptr
,
size_t
m
,
const
dt_int8
*
pack_b_ptr
,
size_t
m
,
size_t
n
,
size_t
n
,
size_t
k
,
dt_int32
*
c_ptr
,
size_t
k
,
CType
*
c_ptr
,
size_t
ldc
,
size_t
ldc
,
bool
is_first_k
,
bool
is_first_k
)
{
const
dt_int32
*
,
dt_int32
*
)
const
{
megdnn_assert
(
A_dtype
.
enumv
()
==
B_dtype
.
enumv
()
&&
((
A_dtype
.
enumv
()
==
DTypeEnum
::
Int8
&&
C_dtype
.
enumv
()
==
DTypeEnum
::
Int32
)
||
(
A_dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
C_dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS32
)),
"A: %s B: %s C: %s"
,
A_dtype
.
name
(),
B_dtype
.
name
(),
C_dtype
.
name
());
megdnn_assert
(
is_first_k
==
true
);
constexpr
size_t
m_tile
=
4
;
constexpr
size_t
m_tile
=
4
;
constexpr
size_t
n_tile
=
16
;
constexpr
size_t
n_tile
=
16
;
constexpr
size_t
k_tile
=
2
;
constexpr
size_t
k_tile
=
2
;
...
@@ -109,4 +98,62 @@ void gemm_avx2_s8s8s32_4x16x2::kern(const dt_int16* pack_a_ptr,
...
@@ -109,4 +98,62 @@ void gemm_avx2_s8s8s32_4x16x2::kern(const dt_int16* pack_a_ptr,
}
}
}
}
}
}
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gemm_avx2_s8s8s32_4x16x2
);
void
gemm_avx2_s8s8s32_4x16x2
::
pack_A
(
dt_int16
*
out
,
const
dt_int8
*
in
,
int
ldin
,
int
y0
,
int
ymax
,
int
k0
,
int
kmax
,
bool
transpose
)
const
{
gemm_packa
(
out
,
in
,
ldin
,
y0
,
ymax
,
k0
,
kmax
,
transpose
);
}
void
gemm_avx2_s8s8s32_4x16x2
::
pack_B
(
dt_int8
*
out
,
const
dt_int8
*
in
,
int
ldin
,
int
x0
,
int
xmax
,
int
k0
,
int
kmax
,
bool
transpose
)
const
{
gemm_packb
(
out
,
in
,
ldin
,
x0
,
xmax
,
k0
,
kmax
,
transpose
);
}
void
gemm_avx2_s8s8s32_4x16x2
::
kern
(
const
dt_int16
*
pack_a_ptr
,
const
dt_int8
*
pack_b_ptr
,
size_t
m
,
size_t
n
,
size_t
k
,
dt_int32
*
c_ptr
,
size_t
ldc
,
bool
is_first_k
,
const
dt_int32
*
,
dt_int32
*
)
const
{
megdnn_assert
(
A_dtype
.
enumv
()
==
B_dtype
.
enumv
()
&&
((
A_dtype
.
enumv
()
==
DTypeEnum
::
Int8
&&
C_dtype
.
enumv
()
==
DTypeEnum
::
Int32
)
||
(
A_dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
C_dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS32
)),
"A: %s B: %s C: %s"
,
A_dtype
.
name
(),
B_dtype
.
name
(),
C_dtype
.
name
());
megdnn_assert
(
is_first_k
==
true
);
gemm_kern
(
pack_a_ptr
,
pack_b_ptr
,
m
,
n
,
k
,
c_ptr
,
ldc
,
is_first_k
);
}
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gemm_avx2_s8s8s16_4x16x2
);
void
gemm_avx2_s8s8s16_4x16x2
::
pack_A
(
dt_int16
*
out
,
const
dt_int8
*
in
,
int
ldin
,
int
y0
,
int
ymax
,
int
k0
,
int
kmax
,
bool
transpose
)
const
{
gemm_packa
(
out
,
in
,
ldin
,
y0
,
ymax
,
k0
,
kmax
,
transpose
);
}
void
gemm_avx2_s8s8s16_4x16x2
::
pack_B
(
dt_int8
*
out
,
const
dt_int8
*
in
,
int
ldin
,
int
x0
,
int
xmax
,
int
k0
,
int
kmax
,
bool
transpose
)
const
{
gemm_packb
(
out
,
in
,
ldin
,
x0
,
xmax
,
k0
,
kmax
,
transpose
);
}
void
gemm_avx2_s8s8s16_4x16x2
::
kern
(
const
dt_int16
*
pack_a_ptr
,
const
dt_int8
*
pack_b_ptr
,
size_t
m
,
size_t
n
,
size_t
k
,
dt_int16
*
c_ptr
,
size_t
ldc
,
bool
is_first_k
,
const
dt_int32
*
,
dt_int32
*
)
const
{
megdnn_assert
(
A_dtype
.
enumv
()
==
B_dtype
.
enumv
()
&&
((
A_dtype
.
enumv
()
==
DTypeEnum
::
Int8
&&
C_dtype
.
enumv
()
==
DTypeEnum
::
Int16
)
||
(
A_dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
C_dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS16
)),
"A: %s B: %s C: %s"
,
A_dtype
.
name
(),
B_dtype
.
name
(),
C_dtype
.
name
());
megdnn_assert
(
is_first_k
==
true
);
gemm_kern
(
pack_a_ptr
,
pack_b_ptr
,
m
,
n
,
k
,
c_ptr
,
ldc
,
is_first_k
);
}
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
dnn/src/x86/matrix_mul/int8/kernel_avx2_4x16x2.h
浏览文件 @
25b6a131
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#include <immintrin.h>
#include <immintrin.h>
...
@@ -20,11 +21,47 @@ namespace megdnn {
...
@@ -20,11 +21,47 @@ namespace megdnn {
namespace
x86
{
namespace
x86
{
namespace
matmul_avx2_4x16x2
{
namespace
matmul_avx2_4x16x2
{
template
<
typename
CType
>
MEGDNN_ATTRIBUTE_TARGET
(
"avx2"
)
void
store_overflow
(
void
*
ptr
,
__m256i
a
);
template
<
>
void
store_overflow
<
int16_t
>
(
void
*
ptr
,
__m256i
a
)
{
static
__m256i
idx
=
_mm256_setr_epi32
(
0
,
2
,
4
,
6
,
0
,
0
,
0
,
0
);
a
=
_mm256_shufflelo_epi16
(
a
,
0x08
);
a
=
_mm256_shufflehi_epi16
(
a
,
0x08
);
a
=
_mm256_permutevar8x32_epi32
(
a
,
idx
);
_mm_storeu_si128
((
__m128i
*
)
ptr
,
_mm256_extractf128_si256
(
a
,
0
));
}
template
<
>
void
store_overflow
<
int32_t
>
(
void
*
ptr
,
__m256i
a
)
{
_mm256_storeu_si256
((
__m256i
*
)(
ptr
),
a
);
}
template
<
typename
CType
>
MEGDNN_ATTRIBUTE_TARGET
(
"avx2"
)
void
store_overflow
(
void
*
ptr
,
__m256i
a
,
int
remain
);
template
<
>
void
store_overflow
<
int16_t
>
(
void
*
ptr
,
__m256i
a
,
int
remain
)
{
__m128i
mask
=
_mm_continue_mask
(
remain
*
sizeof
(
int16_t
));
static
__m256i
idx
=
_mm256_setr_epi32
(
0
,
2
,
4
,
6
,
0
,
0
,
0
,
0
);
a
=
_mm256_shufflelo_epi16
(
a
,
0x08
);
a
=
_mm256_shufflehi_epi16
(
a
,
0x08
);
a
=
_mm256_permutevar8x32_epi32
(
a
,
idx
);
_mm_maskmoveu_si128
(
_mm256_extractf128_si256
(
a
,
0
),
mask
,
reinterpret_cast
<
char
*>
(
ptr
));
}
template
<
>
void
store_overflow
<
int32_t
>
(
void
*
ptr
,
__m256i
a
,
int
remain
)
{
__m256i
mask
=
_m256_continue_mask
(
remain
);
_mm256_maskstore_epi32
(
reinterpret_cast
<
int32_t
*>
(
ptr
),
mask
,
a
);
}
template
<
typename
CType
>
MEGDNN_ATTRIBUTE_TARGET
(
"avx2"
)
MEGDNN_ATTRIBUTE_TARGET
(
"avx2"
)
static
inline
void
kern_gemm_s8s8s32_avx2_4x16x2
(
const
int16_t
*
pack_a_ptr
,
static
inline
void
kern_gemm_s8s8s32_avx2_4x16x2
(
const
int16_t
*
pack_a_ptr
,
const
int8_t
*
pack_b_ptr
,
const
int8_t
*
pack_b_ptr
,
int32_t
*
c_ptr
,
CType
*
c_ptr
,
const
uint32_t
ldc
,
const
uint32_t
ldc
,
const
uint32_t
k
)
{
const
uint32_t
k
)
{
constexpr
uint32_t
k_step
=
2
;
constexpr
uint32_t
k_step
=
2
;
...
@@ -104,19 +141,19 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2(const int16_t* pack_a_ptr,
...
@@ -104,19 +141,19 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2(const int16_t* pack_a_ptr,
pack_b_ptr
+=
32
;
pack_b_ptr
+=
32
;
}
}
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
)
,
c_vec
[
0
]);
store_overflow
<
CType
>
(
c_ptr
,
c_vec
[
0
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
8
)
,
c_vec
[
1
]);
store_overflow
<
CType
>
(
c_ptr
+
8
,
c_vec
[
1
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
ldc
)
,
c_vec
[
2
]);
store_overflow
<
CType
>
(
c_ptr
+
ldc
,
c_vec
[
2
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
ldc
+
8
)
,
c_vec
[
3
]);
store_overflow
<
CType
>
(
c_ptr
+
ldc
+
8
,
c_vec
[
3
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
2
*
ldc
)
,
c_vec
[
4
]);
store_overflow
<
CType
>
(
c_ptr
+
2
*
ldc
,
c_vec
[
4
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
2
*
ldc
+
8
)
,
c_vec
[
5
]);
store_overflow
<
CType
>
(
c_ptr
+
2
*
ldc
+
8
,
c_vec
[
5
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
3
*
ldc
)
,
c_vec
[
6
]);
store_overflow
<
CType
>
(
c_ptr
+
3
*
ldc
,
c_vec
[
6
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
3
*
ldc
+
8
)
,
c_vec
[
7
]);
store_overflow
<
CType
>
(
c_ptr
+
3
*
ldc
+
8
,
c_vec
[
7
]);
}
}
template
<
typename
CType
>
MEGDNN_ATTRIBUTE_TARGET
(
"avx2"
)
MEGDNN_ATTRIBUTE_TARGET
(
"avx2"
)
static
inline
void
kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_n
(
static
inline
void
kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_n
(
const
int16_t
*
pack_a_ptr
,
const
int8_t
*
pack_b_ptr
,
int32_t
*
c_ptr
,
const
int16_t
*
pack_a_ptr
,
const
int8_t
*
pack_b_ptr
,
CType
*
c_ptr
,
const
uint32_t
ldc
,
const
uint32_t
k
,
const
uint32_t
remain_n
)
{
const
uint32_t
ldc
,
const
uint32_t
k
,
const
uint32_t
remain_n
)
{
constexpr
uint32_t
k_step
=
2
;
constexpr
uint32_t
k_step
=
2
;
...
@@ -173,15 +210,15 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_n(
...
@@ -173,15 +210,15 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_n(
pack_b_ptr
+=
32
;
pack_b_ptr
+=
32
;
}
}
__m256i
mask
=
_m256_continue_mask
(
remain_n
);
store_overflow
<
CType
>
(
c_ptr
,
c_vec
[
0
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
),
mask
,
c_vec
[
0
]);
store_overflow
<
CType
>
(
c_ptr
+
ldc
,
c_vec
[
2
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
+
ldc
),
mask
,
c_vec
[
2
]);
store_overflow
<
CType
>
(
c_ptr
+
2
*
ldc
,
c_vec
[
4
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
+
2
*
ldc
),
mask
,
c_vec
[
4
]);
store_overflow
<
CType
>
(
c_ptr
+
3
*
ldc
,
c_vec
[
6
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
+
3
*
ldc
),
mask
,
c_vec
[
6
]);
}
}
template
<
typename
CType
>
MEGDNN_ATTRIBUTE_TARGET
(
"avx2"
)
MEGDNN_ATTRIBUTE_TARGET
(
"avx2"
)
static
inline
void
kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_m_n
(
static
inline
void
kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_m_n
(
const
int16_t
*
pack_a_ptr
,
const
int8_t
*
pack_b_ptr
,
int32_t
*
c_ptr
,
const
int16_t
*
pack_a_ptr
,
const
int8_t
*
pack_b_ptr
,
CType
*
c_ptr
,
const
uint32_t
ldc
,
const
uint32_t
k
,
const
uint32_t
remain_m
,
const
uint32_t
ldc
,
const
uint32_t
k
,
const
uint32_t
remain_m
,
uint32_t
remain_n
)
{
uint32_t
remain_n
)
{
constexpr
uint32_t
k_step
=
2
;
constexpr
uint32_t
k_step
=
2
;
...
@@ -239,29 +276,29 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_m_n(
...
@@ -239,29 +276,29 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_m_n(
pack_b_ptr
+=
32
;
pack_b_ptr
+=
32
;
}
}
__m256i
mask
=
_m256_continue_mask
(
remain_n
);
store_overflow
<
CType
>
(
c_ptr
,
c_vec
[
0
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
),
mask
,
c_vec
[
0
]);
switch
(
remain_m
)
{
switch
(
remain_m
)
{
case
2
:
case
2
:
_mm256_maskstore_epi32
((
c_ptr
+
ldc
),
mask
,
c_vec
[
2
]);
store_overflow
<
CType
>
(
c_ptr
+
ldc
,
c_vec
[
2
],
remain_n
);
break
;
break
;
case
3
:
case
3
:
_mm256_maskstore_epi32
((
c_ptr
+
ldc
),
mask
,
c_vec
[
2
]
);
store_overflow
<
CType
>
(
c_ptr
+
ldc
,
c_vec
[
2
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
+
2
*
ldc
),
mask
,
c_vec
[
4
]
);
store_overflow
<
CType
>
(
c_ptr
+
2
*
ldc
,
c_vec
[
4
],
remain_n
);
break
;
break
;
case
4
:
case
4
:
_mm256_maskstore_epi32
((
c_ptr
+
ldc
),
mask
,
c_vec
[
2
]
);
store_overflow
<
CType
>
(
c_ptr
+
ldc
,
c_vec
[
2
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
+
2
*
ldc
),
mask
,
c_vec
[
4
]
);
store_overflow
<
CType
>
(
c_ptr
+
2
*
ldc
,
c_vec
[
4
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
+
3
*
ldc
),
mask
,
c_vec
[
6
]
);
store_overflow
<
CType
>
(
c_ptr
+
3
*
ldc
,
c_vec
[
6
],
remain_n
);
break
;
break
;
default:
default:
break
;
break
;
}
}
}
}
template
<
typename
CType
>
MEGDNN_ATTRIBUTE_TARGET
(
"avx2"
)
MEGDNN_ATTRIBUTE_TARGET
(
"avx2"
)
static
inline
void
kern_gemm_s8s8s32_avx2_4x16x2_remain_m
(
static
inline
void
kern_gemm_s8s8s32_avx2_4x16x2_remain_m
(
const
int16_t
*
pack_a_ptr
,
const
int8_t
*
pack_b_ptr
,
int32_t
*
c_ptr
,
const
int16_t
*
pack_a_ptr
,
const
int8_t
*
pack_b_ptr
,
CType
*
c_ptr
,
const
uint32_t
ldc
,
const
uint32_t
k
,
const
uint32_t
remain_m
)
{
const
uint32_t
ldc
,
const
uint32_t
k
,
const
uint32_t
remain_m
)
{
constexpr
uint32_t
k_step
=
2
;
constexpr
uint32_t
k_step
=
2
;
...
@@ -339,34 +376,36 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m(
...
@@ -339,34 +376,36 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m(
pack_a_ptr
+=
8
;
pack_a_ptr
+=
8
;
pack_b_ptr
+=
32
;
pack_b_ptr
+=
32
;
}
}
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
),
c_vec
[
0
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
8
),
c_vec
[
1
]);
store_overflow
<
CType
>
(
c_ptr
,
c_vec
[
0
]);
store_overflow
<
CType
>
(
c_ptr
+
8
,
c_vec
[
1
]);
switch
(
remain_m
)
{
switch
(
remain_m
)
{
case
2
:
case
2
:
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
ldc
)
,
c_vec
[
2
]);
store_overflow
<
CType
>
(
c_ptr
+
ldc
,
c_vec
[
2
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
ldc
+
8
)
,
c_vec
[
3
]);
store_overflow
<
CType
>
(
c_ptr
+
ldc
+
8
,
c_vec
[
3
]);
break
;
break
;
case
3
:
case
3
:
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
ldc
)
,
c_vec
[
2
]);
store_overflow
<
CType
>
(
c_ptr
+
ldc
,
c_vec
[
2
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
ldc
+
8
)
,
c_vec
[
3
]);
store_overflow
<
CType
>
(
c_ptr
+
ldc
+
8
,
c_vec
[
3
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
2
*
ldc
)
,
c_vec
[
4
]);
store_overflow
<
CType
>
(
c_ptr
+
2
*
ldc
,
c_vec
[
4
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
2
*
ldc
+
8
)
,
c_vec
[
5
]);
store_overflow
<
CType
>
(
c_ptr
+
2
*
ldc
+
8
,
c_vec
[
5
]);
break
;
break
;
case
4
:
case
4
:
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
ldc
)
,
c_vec
[
2
]);
store_overflow
<
CType
>
(
c_ptr
+
ldc
,
c_vec
[
2
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
ldc
+
8
)
,
c_vec
[
3
]);
store_overflow
<
CType
>
(
c_ptr
+
ldc
+
8
,
c_vec
[
3
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
2
*
ldc
)
,
c_vec
[
4
]);
store_overflow
<
CType
>
(
c_ptr
+
2
*
ldc
,
c_vec
[
4
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
2
*
ldc
+
8
)
,
c_vec
[
5
]);
store_overflow
<
CType
>
(
c_ptr
+
2
*
ldc
+
8
,
c_vec
[
5
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
3
*
ldc
)
,
c_vec
[
6
]);
store_overflow
<
CType
>
(
c_ptr
+
3
*
ldc
,
c_vec
[
6
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
3
*
ldc
+
8
)
,
c_vec
[
7
]);
store_overflow
<
CType
>
(
c_ptr
+
3
*
ldc
+
8
,
c_vec
[
7
]);
default:
default:
break
;
break
;
}
}
}
}
template
<
typename
CType
>
MEGDNN_ATTRIBUTE_TARGET
(
"avx2"
)
MEGDNN_ATTRIBUTE_TARGET
(
"avx2"
)
static
inline
void
kern_gemm_s8s8s32_avx2_4x16x2_remain_n
(
static
inline
void
kern_gemm_s8s8s32_avx2_4x16x2_remain_n
(
const
int16_t
*
pack_a_ptr
,
const
int8_t
*
pack_b_ptr
,
int32_t
*
c_ptr
,
const
int16_t
*
pack_a_ptr
,
const
int8_t
*
pack_b_ptr
,
CType
*
c_ptr
,
const
uint32_t
ldc
,
const
uint32_t
k
,
uint32_t
remain_n
)
{
const
uint32_t
ldc
,
const
uint32_t
k
,
uint32_t
remain_n
)
{
constexpr
uint32_t
k_step
=
2
;
constexpr
uint32_t
k_step
=
2
;
...
@@ -446,29 +485,28 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_n(
...
@@ -446,29 +485,28 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_n(
}
}
if
(
remain_n
>=
8
)
{
if
(
remain_n
>=
8
)
{
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
)
,
c_vec
[
0
]);
store_overflow
<
CType
>
(
c_ptr
,
c_vec
[
0
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
ldc
)
,
c_vec
[
2
]);
store_overflow
<
CType
>
(
c_ptr
+
ldc
,
c_vec
[
2
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
2
*
ldc
)
,
c_vec
[
4
]);
store_overflow
<
CType
>
(
c_ptr
+
2
*
ldc
,
c_vec
[
4
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
3
*
ldc
)
,
c_vec
[
6
]);
store_overflow
<
CType
>
(
c_ptr
+
3
*
ldc
,
c_vec
[
6
]);
remain_n
-=
8
;
remain_n
-=
8
;
if
(
remain_n
>
0
)
{
if
(
remain_n
>
0
)
{
__m256i
mask
=
_m256_continue_mask
(
remain_n
);
store_overflow
<
CType
>
(
c_ptr
+
8
,
c_vec
[
1
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
+
8
),
mask
,
c_vec
[
1
]);
store_overflow
<
CType
>
(
c_ptr
+
ldc
+
8
,
c_vec
[
3
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
+
ldc
+
8
),
mask
,
c_vec
[
3
]);
store_overflow
<
CType
>
(
c_ptr
+
2
*
ldc
+
8
,
c_vec
[
5
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
+
2
*
ldc
+
8
),
mask
,
c_vec
[
5
]);
store_overflow
<
CType
>
(
c_ptr
+
3
*
ldc
+
8
,
c_vec
[
7
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
+
3
*
ldc
+
8
),
mask
,
c_vec
[
7
]);
}
}
}
else
{
}
else
{
__m256i
mask
=
_m256_continue_mask
(
remain_n
);
store_overflow
<
CType
>
(
c_ptr
,
c_vec
[
0
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
),
mask
,
c_vec
[
0
]);
store_overflow
<
CType
>
(
c_ptr
+
ldc
,
c_vec
[
2
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
+
ldc
),
mask
,
c_vec
[
2
]);
store_overflow
<
CType
>
(
c_ptr
+
2
*
ldc
,
c_vec
[
4
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
+
2
*
ldc
),
mask
,
c_vec
[
4
]);
store_overflow
<
CType
>
(
c_ptr
+
3
*
ldc
,
c_vec
[
6
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
+
3
*
ldc
),
mask
,
c_vec
[
6
]);
}
}
}
}
template
<
typename
CType
>
MEGDNN_ATTRIBUTE_TARGET
(
"avx2"
)
MEGDNN_ATTRIBUTE_TARGET
(
"avx2"
)
static
inline
void
kern_gemm_s8s8s32_avx2_4x16x2_remain_m_n
(
static
inline
void
kern_gemm_s8s8s32_avx2_4x16x2_remain_m_n
(
const
int16_t
*
pack_a_ptr
,
const
int8_t
*
pack_b_ptr
,
int32_t
*
c_ptr
,
const
int16_t
*
pack_a_ptr
,
const
int8_t
*
pack_b_ptr
,
CType
*
c_ptr
,
const
uint32_t
ldc
,
const
uint32_t
k
,
const
uint32_t
remain_m
,
const
uint32_t
ldc
,
const
uint32_t
k
,
const
uint32_t
remain_m
,
uint32_t
remain_n
)
{
uint32_t
remain_n
)
{
constexpr
uint32_t
k_step
=
2
;
constexpr
uint32_t
k_step
=
2
;
...
@@ -549,19 +587,19 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m_n(
...
@@ -549,19 +587,19 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m_n(
}
}
if
(
remain_n
>=
8
)
{
if
(
remain_n
>=
8
)
{
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
)
,
c_vec
[
0
]);
store_overflow
<
CType
>
(
c_ptr
,
c_vec
[
0
]);
switch
(
remain_m
)
{
switch
(
remain_m
)
{
case
2
:
case
2
:
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
ldc
)
,
c_vec
[
2
]);
store_overflow
<
CType
>
(
c_ptr
+
ldc
,
c_vec
[
2
]);
break
;
break
;
case
3
:
case
3
:
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
ldc
)
,
c_vec
[
2
]);
store_overflow
<
CType
>
(
c_ptr
+
ldc
,
c_vec
[
2
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
2
*
ldc
)
,
c_vec
[
4
]);
store_overflow
<
CType
>
(
c_ptr
+
2
*
ldc
,
c_vec
[
4
]);
break
;
break
;
case
4
:
case
4
:
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
ldc
)
,
c_vec
[
2
]);
store_overflow
<
CType
>
(
c_ptr
+
ldc
,
c_vec
[
2
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
2
*
ldc
)
,
c_vec
[
4
]);
store_overflow
<
CType
>
(
c_ptr
+
2
*
ldc
,
c_vec
[
4
]);
_mm256_storeu_si256
((
__m256i
*
)(
c_ptr
+
3
*
ldc
)
,
c_vec
[
6
]);
store_overflow
<
CType
>
(
c_ptr
+
3
*
ldc
,
c_vec
[
6
]);
break
;
break
;
default:
default:
break
;
break
;
...
@@ -569,43 +607,41 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m_n(
...
@@ -569,43 +607,41 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m_n(
remain_n
-=
8
;
remain_n
-=
8
;
if
(
remain_n
>
0
)
{
if
(
remain_n
>
0
)
{
__m256i
mask
=
_m256_continue_mask
(
remain_n
);
store_overflow
<
CType
>
(
c_ptr
+
8
,
c_vec
[
1
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
+
8
),
mask
,
c_vec
[
1
]);
switch
(
remain_m
)
{
switch
(
remain_m
)
{
case
2
:
case
2
:
_mm256_maskstore_epi32
((
c_ptr
+
ldc
+
8
),
mask
,
c_vec
[
3
]
);
store_overflow
<
CType
>
(
c_ptr
+
ldc
+
8
,
c_vec
[
3
],
remain_n
);
break
;
break
;
case
3
:
case
3
:
_mm256_maskstore_epi32
((
c_ptr
+
ldc
+
8
),
mask
,
c_vec
[
3
]
);
store_overflow
<
CType
>
(
c_ptr
+
ldc
+
8
,
c_vec
[
3
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
+
2
*
ldc
+
8
),
mask
,
store_overflow
<
CType
>
(
c_ptr
+
2
*
ldc
+
8
,
c_vec
[
5
]
,
c_vec
[
5
]
);
remain_n
);
break
;
break
;
case
4
:
case
4
:
_mm256_maskstore_epi32
((
c_ptr
+
ldc
+
8
),
mask
,
c_vec
[
3
]
);
store_overflow
<
CType
>
(
c_ptr
+
ldc
+
8
,
c_vec
[
3
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
+
2
*
ldc
+
8
),
mask
,
store_overflow
<
CType
>
(
c_ptr
+
2
*
ldc
+
8
,
c_vec
[
5
]
,
c_vec
[
5
]
);
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
+
3
*
ldc
+
8
),
mask
,
store_overflow
<
CType
>
(
c_ptr
+
3
*
ldc
+
8
,
c_vec
[
7
]
,
c_vec
[
7
]
);
remain_n
);
break
;
break
;
default:
default:
break
;
break
;
}
}
}
}
}
else
{
}
else
{
__m256i
mask
=
_m256_continue_mask
(
remain_n
);
store_overflow
<
CType
>
(
c_ptr
,
c_vec
[
0
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
),
mask
,
c_vec
[
0
]);
switch
(
remain_m
)
{
switch
(
remain_m
)
{
case
2
:
case
2
:
_mm256_maskstore_epi32
((
c_ptr
+
ldc
),
mask
,
c_vec
[
2
]
);
store_overflow
<
CType
>
(
c_ptr
+
ldc
,
c_vec
[
2
],
remain_n
);
break
;
break
;
case
3
:
case
3
:
_mm256_maskstore_epi32
((
c_ptr
+
ldc
),
mask
,
c_vec
[
2
]
);
store_overflow
<
CType
>
(
c_ptr
+
ldc
,
c_vec
[
2
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
+
2
*
ldc
),
mask
,
c_vec
[
4
]
);
store_overflow
<
CType
>
(
c_ptr
+
2
*
ldc
,
c_vec
[
4
],
remain_n
);
break
;
break
;
case
4
:
case
4
:
_mm256_maskstore_epi32
((
c_ptr
+
ldc
),
mask
,
c_vec
[
2
]
);
store_overflow
<
CType
>
(
c_ptr
+
ldc
,
c_vec
[
2
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
+
2
*
ldc
),
mask
,
c_vec
[
4
]
);
store_overflow
<
CType
>
(
c_ptr
+
2
*
ldc
,
c_vec
[
4
],
remain_n
);
_mm256_maskstore_epi32
((
c_ptr
+
3
*
ldc
),
mask
,
c_vec
[
6
]
);
store_overflow
<
CType
>
(
c_ptr
+
3
*
ldc
,
c_vec
[
6
],
remain_n
);
break
;
break
;
default:
default:
break
;
break
;
...
@@ -833,4 +869,5 @@ static inline void gemm_s8s8s32_avx2_4x16x2_pack_at(dt_int16* out,
...
@@ -833,4 +869,5 @@ static inline void gemm_s8s8s32_avx2_4x16x2_pack_at(dt_int16* out,
}
// namespace x86
}
// namespace x86
}
// namespace megdnn
}
// namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
// vim: syntax=cpp.doxygen
\ No newline at end of file
dnn/src/x86/matrix_mul/int8/strategy.h
浏览文件 @
25b6a131
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#pragma once
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"
#include "src/fallback/matrix_mul/gemm_common.h"
...
@@ -29,6 +30,10 @@ MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int32, dt_int32,
...
@@ -29,6 +30,10 @@ MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int32, dt_int32,
4
,
16
,
2
,
false
,
false
,
4
,
16
,
2
,
false
,
false
,
gemm_avx2_s8s8s32_4x16x2
);
gemm_avx2_s8s8s32_4x16x2
);
MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE
(
dt_int8
,
dt_int16
,
dt_int16
,
dt_int32
,
4
,
16
,
2
,
false
,
false
,
gemm_avx2_s8s8s16_4x16x2
);
MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE
(
dt_int8
,
dt_int16
,
dt_int32
,
dt_int32
,
MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE
(
dt_int8
,
dt_int16
,
dt_int32
,
dt_int32
,
4
,
8
,
2
,
false
,
false
,
4
,
8
,
2
,
false
,
false
,
gemm_sse_s8s8s32_4x8x2
);
gemm_sse_s8s8s32_4x8x2
);
...
...
dnn/src/x86/matrix_mul/opr_impl.cpp
浏览文件 @
25b6a131
...
@@ -37,6 +37,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
...
@@ -37,6 +37,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoInt8x8x32AVX2M4N16K2
algoint8x8x32avx2_m4n16k2
;
AlgoInt8x8x32AVX2M4N16K2
algoint8x8x32avx2_m4n16k2
;
AlgoInt8x8x32AVX2M2N4K16
algoint8x8x32avx2_m2n4k16
;
AlgoInt8x8x32AVX2M2N4K16
algoint8x8x32avx2_m2n4k16
;
AlgoInt8x8x32SSEM4N8K2
algoint8x8x32sse_m4n8k2
;
AlgoInt8x8x32SSEM4N8K2
algoint8x8x32sse_m4n8k2
;
AlgoInt8x8x16AVX2
algoint8x8x16avx2_m4n16k2
;
AlgoF32MK8_8x8
algof32mk8_8x8
;
AlgoF32MK8_8x8
algof32mk8_8x8
;
public:
public:
...
@@ -47,6 +48,7 @@ public:
...
@@ -47,6 +48,7 @@ public:
#endif
#endif
}
}
all_algos
.
emplace_back
(
&
algoint8x8x32avx2_m4n16k2
);
all_algos
.
emplace_back
(
&
algoint8x8x32avx2_m4n16k2
);
all_algos
.
emplace_back
(
&
algoint8x8x16avx2_m4n16k2
);
all_algos
.
emplace_back
(
&
algoint8x8x32avx2_m2n4k16
);
all_algos
.
emplace_back
(
&
algoint8x8x32avx2_m2n4k16
);
all_algos
.
emplace_back
(
&
algoint8x8x32sse_m4n8k2
);
all_algos
.
emplace_back
(
&
algoint8x8x32sse_m4n8k2
);
all_algos
.
emplace_back
(
&
algof32mk8_8x8
);
all_algos
.
emplace_back
(
&
algof32mk8_8x8
);
...
...
dnn/src/x86/matrix_mul/opr_impl.h
浏览文件 @
25b6a131
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#pragma once
#pragma once
...
@@ -54,6 +55,7 @@ protected:
...
@@ -54,6 +55,7 @@ protected:
class
AlgoInt8x8x32AVX2M2N4K16
;
class
AlgoInt8x8x32AVX2M2N4K16
;
class
AlgoInt8x8x32AVX2M4N16K2
;
class
AlgoInt8x8x32AVX2M4N16K2
;
class
AlgoInt8x8x32SSEM4N8K2
;
class
AlgoInt8x8x32SSEM4N8K2
;
class
AlgoInt8x8x16AVX2
;
class
AlgoPack
;
class
AlgoPack
;
class
AlgoF32MK8_8x8
;
class
AlgoF32MK8_8x8
;
};
};
...
...
dnn/test/x86/conv_bias.cpp
浏览文件 @
25b6a131
...
@@ -752,7 +752,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2) {
...
@@ -752,7 +752,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2) {
}
}
}
}
TEST_F
(
X86_MULTI_THREADS
,
CONV_BIAS_IM2COLMATMUL_INT8
x8x32
)
{
TEST_F
(
X86_MULTI_THREADS
,
CONV_BIAS_IM2COLMATMUL_INT8
X8X
)
{
using
namespace
conv_bias
;
using
namespace
conv_bias
;
std
::
vector
<
TestArg
>
args
;
std
::
vector
<
TestArg
>
args
;
...
@@ -807,6 +807,16 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
...
@@ -807,6 +807,16 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
.set_param(arg.param) \
.set_param(arg.param) \
.execs({arg.src, arg.filter, {}, {}, {}}); \
.execs({arg.src, arg.filter, {}, {}, {}}); \
}
}
#define cb2(algo_name) \
checker.set_before_exec_callback( \
conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name)); \
checker.set_dtype(0, dtype::Int8()); \
checker.set_dtype(1, dtype::Int8()); \
checker.set_dtype(2, dtype::Int16()); \
checker.set_dtype(4, dtype::Int16()); \
for (auto&& arg : args) { \
checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}}); \
}
#if MEGDNN_X86_WITH_MKL_DNN
#if MEGDNN_X86_WITH_MKL_DNN
if
(
megdnn
::
x86
::
is_supported
(
x86
::
SIMDType
::
VNNI
))
{
if
(
megdnn
::
x86
::
is_supported
(
x86
::
SIMDType
::
VNNI
))
{
...
@@ -821,12 +831,14 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
...
@@ -821,12 +831,14 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
if
(
megdnn
::
x86
::
is_supported
(
x86
::
SIMDType
::
AVX2
))
{
if
(
megdnn
::
x86
::
is_supported
(
x86
::
SIMDType
::
AVX2
))
{
cb
(
"IM2COLMATMUL:X86_INT8X8X32_AVX2_2X4X16"
);
cb
(
"IM2COLMATMUL:X86_INT8X8X32_AVX2_2X4X16"
);
cb
(
"IM2COLMATMUL:X86_INT8X8X32_AVX2_4X16X2"
);
cb
(
"IM2COLMATMUL:X86_INT8X8X32_AVX2_4X16X2"
);
cb2
(
"IM2COLMATMUL:X86_INT8X8X16_AVX2"
);
}
}
if
(
::
megdnn
::
x86
::
is_supported
(
::
megdnn
::
x86
::
SIMDType
::
SSE4_2
))
{
if
(
::
megdnn
::
x86
::
is_supported
(
::
megdnn
::
x86
::
SIMDType
::
SSE4_2
))
{
cb
(
"IM2COLMATMUL:X86_INT8X8X32_SSE_4X8X2"
);
cb
(
"IM2COLMATMUL:X86_INT8X8X32_SSE_4X8X2"
);
}
}
#undef cb
#undef cb
#undef cb2
}
}
TEST_F
(
X86_MULTI_THREADS
,
CONV_BIAS_IM2COLMATMUL_FP32
)
{
TEST_F
(
X86_MULTI_THREADS
,
CONV_BIAS_IM2COLMATMUL_FP32
)
{
...
@@ -1964,6 +1976,39 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECT_AVX2_INT8) {
...
@@ -1964,6 +1976,39 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECT_AVX2_INT8) {
shapes_and_computation
.
clear
();
shapes_and_computation
.
clear
();
}
}
TEST_F
(
X86_BENCHMARK_MULTI_THREADS
,
BENCHMARK_CONVBIAS_8816
)
{
constexpr
size_t
RUNS
=
30
;
param
::
ConvBias
param
;
param
.
stride_h
=
1
;
param
.
stride_w
=
1
;
param
.
sparse
=
param
::
ConvBias
::
Sparse
::
DENSE
;
std
::
vector
<
DType
>
data_type
=
{
dtype
::
Int8
(),
dtype
::
Int8
(),
dtype
::
Int16
(),
dtype
::
Int16
()};
std
::
vector
<
std
::
pair
<
SmallVector
<
TensorShape
>
,
float
>>
shapes_and_computation
;
auto
bench_case
=
[
&
](
size_t
N
,
size_t
IC
,
size_t
OC
,
size_t
H
,
size_t
W
,
size_t
FS
)
{
param
.
pad_h
=
FS
/
2
;
param
.
pad_w
=
FS
/
2
;
SmallVector
<
TensorShape
>
shapes
{
{
N
,
IC
,
H
,
W
},
{
OC
,
IC
,
FS
,
FS
},
{},
{},
{}};
TensorShape
dst
{
N
,
OC
,
(
H
+
2
*
param
.
pad_h
-
FS
)
/
param
.
stride_h
+
1
,
(
W
+
2
*
param
.
pad_w
-
FS
)
/
param
.
stride_w
+
1
};
float
computations
=
(
IC
*
FS
*
FS
*
dst
.
total_nr_elems
()
*
2
)
*
1e-6
;
shapes_and_computation
.
push_back
(
std
::
make_pair
(
shapes
,
computations
));
};
bench_case
(
1
,
48
,
192
,
15
,
15
,
1
);
std
::
string
algo_name
=
"IM2COLMATMUL:X86_INT8X8X16_AVX2"
;
benchmark_impl
(
param
,
shapes_and_computation
,
algo_name
,
RUNS
,
{
4
,
{
4
,
5
,
6
,
7
}},
{
1
,
{
4
}},
data_type
);
shapes_and_computation
.
clear
();
}
TEST_F
(
X86_BENCHMARK_MULTI_THREADS
,
TEST_F
(
X86_BENCHMARK_MULTI_THREADS
,
BENCHMARK_CONVBIAS_DIRECT_AVX2_INT8_STRIDE2
)
{
BENCHMARK_CONVBIAS_DIRECT_AVX2_INT8_STRIDE2
)
{
constexpr
size_t
RUNS
=
50
;
constexpr
size_t
RUNS
=
50
;
...
@@ -1985,7 +2030,7 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS,
...
@@ -1985,7 +2030,7 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS,
SmallVector
<
TensorShape
>
shapes
{
SmallVector
<
TensorShape
>
shapes
{
{
N
,
IC
,
H
,
W
},
{
OC
,
IC
,
FS
,
FS
},
{},
{},
{}};
{
N
,
IC
,
H
,
W
},
{
OC
,
IC
,
FS
,
FS
},
{},
{},
{}};
TensorShape
dst
{
N
,
OC
,
(
H
+
2
*
param
.
pad_h
-
FS
)
/
param
.
stride_h
+
1
,
TensorShape
dst
{
N
,
OC
,
(
H
+
2
*
param
.
pad_h
-
FS
)
/
param
.
stride_h
+
1
,
(
W
+
2
*
param
.
pad_w
-
FS
)
/
param
.
pad
_w
+
1
};
(
W
+
2
*
param
.
pad_w
-
FS
)
/
param
.
stride
_w
+
1
};
float
computations
=
(
IC
*
FS
*
FS
*
dst
.
total_nr_elems
()
*
2
)
*
1e-6
;
float
computations
=
(
IC
*
FS
*
FS
*
dst
.
total_nr_elems
()
*
2
)
*
1e-6
;
shapes_and_computation
.
push_back
(
std
::
make_pair
(
shapes
,
computations
));
shapes_and_computation
.
push_back
(
std
::
make_pair
(
shapes
,
computations
));
};
};
...
...
dnn/test/x86/convolution.cpp
浏览文件 @
25b6a131
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#include "test/x86/fixture.h"
#include "test/x86/fixture.h"
...
@@ -369,6 +370,63 @@ TEST_F(X86, CONVOLUTION_DIRECT_MKLDNN_C8) {
...
@@ -369,6 +370,63 @@ TEST_F(X86, CONVOLUTION_DIRECT_MKLDNN_C8) {
#endif
#endif
#if MEGDNN_WITH_BENCHMARK
#if MEGDNN_WITH_BENCHMARK
TEST_F
(
X86
,
BENCHMARK_CONVOLUTION_I8x8x16
)
{
using
namespace
convolution
;
using
Param
=
param
::
Convolution
;
std
::
vector
<
TestArg
>
args
;
auto
run
=
[
&
](
size_t
oc
,
size_t
ic
,
size_t
w
,
size_t
h
,
size_t
kernel
,
size_t
stride
,
size_t
group
=
1
)
{
Param
param
;
param
.
stride_h
=
stride
;
param
.
stride_w
=
stride
;
param
.
pad_h
=
kernel
/
2
;
param
.
pad_w
=
kernel
/
2
;
if
(
group
>
1
)
{
param
.
sparse
=
param
::
Convolution
::
Sparse
::
GROUP
;
args
.
emplace_back
(
param
,
TensorShape
{
1
,
ic
,
h
,
w
},
TensorShape
{
group
,
oc
/
group
,
ic
/
group
,
kernel
,
kernel
});
}
else
{
param
.
sparse
=
param
::
Convolution
::
Sparse
::
DENSE
;
args
.
emplace_back
(
param
,
TensorShape
{
1
,
ic
,
h
,
w
},
TensorShape
{
oc
,
ic
,
kernel
,
kernel
});
}
};
run
(
48
,
96
,
15
,
15
,
1
,
1
);
run
(
64
,
64
,
60
,
60
,
3
,
1
);
run
(
64
,
64
,
60
,
60
,
3
,
1
,
64
);
constexpr
size_t
RUN
=
30
;
Benchmarker
<
Convolution
>
benchmark
(
handle
());
benchmark
.
set_dtype
(
0
,
dtype
::
Int8
())
.
set_dtype
(
1
,
dtype
::
Int8
())
.
set_dtype
(
2
,
dtype
::
Int16
());
benchmark
.
set_display
(
false
);
benchmark
.
set_times
(
RUN
);
for
(
auto
&&
arg
:
args
)
{
TensorLayout
dst_layout
;
auto
opr
=
handle
()
->
create_operator
<
Convolution
>
();
opr
->
param
()
=
arg
.
param
;
opr
->
deduce_layout
({
arg
.
src
,
dtype
::
Float32
()},
{
arg
.
filter
,
dtype
::
Float32
()},
dst_layout
);
//! dst.nr_elems * IC * FH * FW * 2
float
icpg
=
arg
.
filter
.
ndim
==
4
?
arg
.
filter
[
1
]
:
arg
.
filter
[
2
];
float
filter
=
arg
.
filter
.
ndim
==
4
?
arg
.
filter
[
2
]
:
arg
.
filter
[
3
];
float
computations
=
dst_layout
.
total_nr_elems
()
*
icpg
*
filter
*
filter
*
2.0
/
(
1024
*
1024
*
1024
)
*
1e3
;
auto
used_int
=
benchmark
.
set_param
(
arg
.
param
).
exec
({
arg
.
src
,
arg
.
filter
,
{}})
/
RUN
;
printf
(
"%s %s: int: %f ms %f Gflops
\n
"
,
arg
.
src
.
to_string
().
c_str
(),
arg
.
filter
.
to_string
().
c_str
(),
used_int
,
computations
/
used_int
);
}
}
#if MEGDNN_X86_WITH_MKL_DNN
#if MEGDNN_X86_WITH_MKL_DNN
TEST_F
(
X86
,
BENCHMARK_CONVOLUTION_I8x8x32_MKLDNN
)
{
TEST_F
(
X86
,
BENCHMARK_CONVOLUTION_I8x8x32_MKLDNN
)
{
using
namespace
convolution
;
using
namespace
convolution
;
...
@@ -419,7 +477,6 @@ TEST_F(X86, BENCHMARK_CONVOLUTION_I8x8x32_MKLDNN) {
...
@@ -419,7 +477,6 @@ TEST_F(X86, BENCHMARK_CONVOLUTION_I8x8x32_MKLDNN) {
float
computations
=
dst_layout
.
total_nr_elems
()
*
arg
.
filter
[
1
]
*
float
computations
=
dst_layout
.
total_nr_elems
()
*
arg
.
filter
[
1
]
*
arg
.
filter
[
2
]
*
arg
.
filter
[
3
]
*
2.0
/
arg
.
filter
[
2
]
*
arg
.
filter
[
3
]
*
2.0
/
(
1024
*
1024
*
1024
)
*
1e3
;
(
1024
*
1024
*
1024
)
*
1e3
;
auto
used_int
=
auto
used_int
=
benchmark
.
set_param
(
arg
.
param
).
exec
({
arg
.
src
,
arg
.
filter
,
{}})
/
benchmark
.
set_param
(
arg
.
param
).
exec
({
arg
.
src
,
arg
.
filter
,
{}})
/
RUN
;
RUN
;
...
...
dnn/test/x86/matrix_mul.cpp
浏览文件 @
25b6a131
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#include "test/x86/fixture.h"
#include "test/x86/fixture.h"
...
@@ -47,6 +48,10 @@ TEST_F(X86, MATRIX_MUL_AVX2_8X8X32) {
...
@@ -47,6 +48,10 @@ TEST_F(X86, MATRIX_MUL_AVX2_8X8X32) {
matrix_mul
::
check_matrix_mul
(
dtype
::
Int8
{},
dtype
::
Int8
{},
dtype
::
Int32
{},
matrix_mul
::
check_matrix_mul
(
dtype
::
Int8
{},
dtype
::
Int8
{},
dtype
::
Int32
{},
handle
(),
"X86_INT8X8X32_AVX2_4X16X2"
);
handle
(),
"X86_INT8X8X32_AVX2_4X16X2"
);
}
}
TEST_F
(
X86
,
MATRIX_MUL_AVX2_8X8X16
)
{
matrix_mul
::
check_matrix_mul
(
dtype
::
Int8
{},
dtype
::
Int8
{},
dtype
::
Int16
{},
handle
(),
"X86_INT8X8X16_AVX2"
);
}
TEST_F
(
X86
,
MATRIX_MUL_SSE_8X8X32
)
{
TEST_F
(
X86
,
MATRIX_MUL_SSE_8X8X32
)
{
matrix_mul
::
check_matrix_mul
(
dtype
::
Int8
{},
dtype
::
Int8
{},
dtype
::
Int32
{},
matrix_mul
::
check_matrix_mul
(
dtype
::
Int8
{},
dtype
::
Int8
{},
dtype
::
Int32
{},
handle
(),
"X86_INT8X8X32_SSE_4X8X2"
);
handle
(),
"X86_INT8X8X32_SSE_4X8X2"
);
...
@@ -116,6 +121,17 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
...
@@ -116,6 +121,17 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
benchmarker_avx2_4x16x2
.
set_before_exec_callback
(
benchmarker_avx2_4x16x2
.
set_before_exec_callback
(
AlgoChecker
<
MatrixMul
>
(
"X86_INT8X8X32_AVX2_4X16X2"
));
AlgoChecker
<
MatrixMul
>
(
"X86_INT8X8X32_AVX2_4X16X2"
));
Benchmarker
<
MatrixMul
>
benchmarker_avx2_4x16x2_8816
(
handle
());
benchmarker_avx2_4x16x2_8816
.
set_display
(
false
)
.
set_times
(
RUNS
)
.
set_dtype
(
0
,
dtype
::
Int8
{})
.
set_dtype
(
1
,
dtype
::
Int8
{})
.
set_dtype
(
2
,
dtype
::
Int16
{})
.
set_rng
(
0
,
rng
.
get
())
.
set_rng
(
1
,
rng
.
get
());
benchmarker_avx2_4x16x2_8816
.
set_before_exec_callback
(
AlgoChecker
<
MatrixMul
>
(
"X86_INT8X8X16_AVX2"
));
Benchmarker
<
MatrixMul
>
benchmarker_avx2_2x4x16
(
handle
());
Benchmarker
<
MatrixMul
>
benchmarker_avx2_2x4x16
(
handle
());
benchmarker_avx2_2x4x16
.
set_display
(
false
)
benchmarker_avx2_2x4x16
.
set_display
(
false
)
.
set_times
(
RUNS
)
.
set_times
(
RUNS
)
...
@@ -183,6 +199,12 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
...
@@ -183,6 +199,12 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
<<
"k2_speed_up "
<<
float_used
/
avx2_used_4x16x2
<<
"k2_speed_up "
<<
float_used
/
avx2_used_4x16x2
<<
", k16_speed_up "
<<
float_used
/
avx2_used_2x4x16
<<
", k16_speed_up "
<<
float_used
/
avx2_used_2x4x16
<<
","
;
<<
","
;
auto
avx2_used_4x16x2_8816
=
benchmarker_avx2_4x16x2_8816
.
exec
({{
M
,
K
},
{
K
,
N
},
{}})
/
RUNS
;
std
::
cout
<<
"avx2_8816: "
<<
avx2_used_4x16x2_8816
<<
" ms, 8816 throughput "
<<
computations
/
avx2_used_4x16x2_8816
<<
" Gflops,"
;
}
}
if
(
is_supported
(
SIMDType
::
SSE4_1
))
{
if
(
is_supported
(
SIMDType
::
SSE4_1
))
{
auto
sse_used
=
auto
sse_used
=
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录