Commit ca2828dd
Authored on Jun 28, 2021 by Megvii Engine Team; committed by huangxinda on Jul 19, 2021.
fix(dnn/x86): fix x86 int8 matmul ldc bug
GitOrigin-RevId: 2502f99000d5e90fdc410b1d0bf731668cd1077c
Parent: aa4e8476

7 changed files with 63 additions and 37 deletions (+63 / -37):
    dnn/src/x86/matrix_mul/int8/avx2_strategy_2x4x16.cpp   +4  -4
    dnn/src/x86/matrix_mul/int8/avx2_strategy_4x16x2.cpp   +4  -4
    dnn/src/x86/matrix_mul/int8/sse_strategy_4x8x2.cpp     +4  -4
    dnn/test/common/checker.h                              +11 -1
    dnn/test/common/matrix_mul.cpp                         +19 -15
    dnn/test/common/matrix_mul.h                           +5  -3
    dnn/test/x86/matrix_mul.cpp                            +16 -6
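The bug: in each x86 int8 GEMM strategy below, the base pointer of an output tile was computed as c_ptr + row_offset * n (the logical column count of C) rather than c_ptr + row_offset * ldc (the leading dimension, i.e. the element stride between consecutive rows of C). The two agree only when C is contiguous, so any strided output buffer was written at the wrong offsets. A minimal standalone sketch of the distinction (illustrative code, not from the repository):

    #include <cassert>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Fill a row-major m x n matrix whose rows are ldc elements apart.
    // Stepping between rows by `n` instead of `ldc` is only correct when
    // ldc == n; with a padded/strided C it writes into the wrong rows.
    void fill_rows(int* c_ptr, std::size_t m, std::size_t n, std::size_t ldc) {
        for (std::size_t row = 0; row < m; ++row) {
            int* row_ptr = c_ptr + row * ldc;  // correct: row stride is ldc
            for (std::size_t col = 0; col < n; ++col)
                row_ptr[col] = static_cast<int>(row * 100 + col);
        }
    }

    int main() {
        const std::size_t m = 2, n = 3, ldc = 8;  // ldc > n: strided output
        std::vector<int> buf(m * ldc, -1);
        fill_rows(buf.data(), m, n, ldc);
        assert(buf[1 * ldc] == 100);  // row 1 starts at offset ldc ...
        assert(buf[n] == -1);         // ... not at offset n (still padding)
        std::printf("ok\n");
    }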
dnn/src/x86/matrix_mul/int8/avx2_strategy_2x4x16.cpp

@@ -71,13 +71,13 @@ void gemm_avx2_s8s8s32_2x4x16::kern(const dt_int8* pack_a_ptr,
         auto iter_a_ptr = pack_a_ptr + m_offset * roundup_k;
         for (size_t n_offset = 0; n_offset < n_end; n_offset += n_tile) {
             auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k;
-            auto iter_c_ptr = c_ptr + m_offset * n + n_offset;
+            auto iter_c_ptr = c_ptr + m_offset * ldc + n_offset;
             matmul_avx2_2x4x16::kern_gemm_s8s8s32_2x4x16(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k);
         }
         if (n_end < n) {
             auto iter_b_ptr = pack_b_ptr + n_end * roundup_k;
-            auto iter_c_ptr = c_ptr + m_offset * n + n_end;
+            auto iter_c_ptr = c_ptr + m_offset * ldc + n_end;
             matmul_avx2_2x4x16::kern_gemm_s8s8s32_2x4x16_remain(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_tile,
                     n_remain);

@@ -87,14 +87,14 @@ void gemm_avx2_s8s8s32_2x4x16::kern(const dt_int8* pack_a_ptr,
         auto iter_a_ptr = pack_a_ptr + m_end * roundup_k;
         for (size_t n_offset = 0; n_offset < n_end; n_offset += n_tile) {
             auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k;
-            auto iter_c_ptr = c_ptr + m_end * n + n_offset;
+            auto iter_c_ptr = c_ptr + m_end * ldc + n_offset;
             matmul_avx2_2x4x16::kern_gemm_s8s8s32_2x4x16_remain(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain,
                     n_tile);
         }
         if (n_end < n) {
             auto iter_b_ptr = pack_b_ptr + n_end * roundup_k;
-            auto iter_c_ptr = c_ptr + m_end * n + n_end;
+            auto iter_c_ptr = c_ptr + m_end * ldc + n_end;
             matmul_avx2_2x4x16::kern_gemm_s8s8s32_2x4x16_remain(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain,
                     n_remain);
dnn/src/x86/matrix_mul/int8/avx2_strategy_4x16x2.cpp

@@ -59,13 +59,13 @@ static inline void gemm_kern(const dt_int16* pack_a_ptr,
         auto iter_a_ptr = pack_a_ptr + m_offset * roundup_k;
         for (size_t n_offset = 0; n_offset < n_end; n_offset += n_tile) {
             auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k;
-            auto iter_c_ptr = c_ptr + m_offset * n + n_offset;
+            auto iter_c_ptr = c_ptr + m_offset * ldc + n_offset;
             matmul_avx2_4x16x2::kern_gemm_s8s8s32_avx2_4x16x2(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k);
         }
         if (n_remain > 0) {
             auto iter_b_ptr = pack_b_ptr + n_end * roundup_k;
-            auto iter_c_ptr = c_ptr + m_offset * n + n_end;
+            auto iter_c_ptr = c_ptr + m_offset * ldc + n_end;
             if (n_remain <= 8) {
                 matmul_avx2_4x16x2::kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_n(
                         iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, n_remain);

@@ -79,13 +79,13 @@ static inline void gemm_kern(const dt_int16* pack_a_ptr,
         auto iter_a_ptr = pack_a_ptr + m_end * roundup_k;
         for (size_t n_offset = 0; n_offset < n_end; n_offset += n_tile) {
             auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k;
-            auto iter_c_ptr = c_ptr + m_end * n + n_offset;
+            auto iter_c_ptr = c_ptr + m_end * ldc + n_offset;
             matmul_avx2_4x16x2::kern_gemm_s8s8s32_avx2_4x16x2_remain_m(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain);
         }
         if (n_remain > 0) {
             auto iter_b_ptr = pack_b_ptr + n_end * roundup_k;
-            auto iter_c_ptr = c_ptr + m_end * n + n_end;
+            auto iter_c_ptr = c_ptr + m_end * ldc + n_end;
             if (n_remain <= 8) {
                 matmul_avx2_4x16x2::kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_m_n(
                         iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain,
 ...
dnn/src/x86/matrix_mul/int8/sse_strategy_4x8x2.cpp

@@ -59,13 +59,13 @@ static inline void gemm_kern(const dt_int16* pack_a_ptr,
         auto iter_a_ptr = pack_a_ptr + m_offset * roundup_k;
         for (int n_offset = 0; n_offset < n_end; n_offset += n_tile) {
             auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k;
-            auto iter_c_ptr = c_ptr + m_offset * n + n_offset;
+            auto iter_c_ptr = c_ptr + m_offset * ldc + n_offset;
             matmul_sse_4x8x2::kern_gemm_s8s8s32_sse_4x8x2(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k);
         }
         if (n_remain > 0) {
             auto iter_b_ptr = pack_b_ptr + n_end * roundup_k;
-            auto iter_c_ptr = c_ptr + m_offset * n + n_end;
+            auto iter_c_ptr = c_ptr + m_offset * ldc + n_end;
             matmul_sse_4x8x2::kern_gemm_s8s8s32_sse_4x8x2_remain_n(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, n_remain);
         }

@@ -74,13 +74,13 @@ static inline void gemm_kern(const dt_int16* pack_a_ptr,
         auto iter_a_ptr = pack_a_ptr + m_end * roundup_k;
         for (int n_offset = 0; n_offset < n_end; n_offset += n_tile) {
             auto iter_b_ptr = pack_b_ptr + n_offset * roundup_k;
-            auto iter_c_ptr = c_ptr + m_end * n + n_offset;
+            auto iter_c_ptr = c_ptr + m_end * ldc + n_offset;
             matmul_sse_4x8x2::kern_gemm_s8s8s32_sse_4x8x2_remain_m(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain);
         }
         if (n_remain > 0) {
             auto iter_b_ptr = pack_b_ptr + n_end * roundup_k;
-            auto iter_c_ptr = c_ptr + m_end * n + n_end;
+            auto iter_c_ptr = c_ptr + m_end * ldc + n_end;
             matmul_sse_4x8x2::kern_gemm_s8s8s32_sse_4x8x2_remain_m_n(
                     iter_a_ptr, iter_b_ptr, iter_c_ptr, ldc, k, m_remain,
                     n_remain);
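All three strategies above share the same blocking structure: M and N are walked in m_tile x n_tile steps, with dedicated *_remain kernels for the right and bottom edges; the only wrong line in each was the C-tile base address. A scalar sketch of that structure (my simplification: a naive stand-in for the SIMD micro-kernels and an idealized packing layout):

    #include <cstddef>
    #include <cstdint>

    // Naive stand-in for the SIMD micro-kernel: computes the mt x nt tile
    // of C = A_tile * B_tile. A is packed as mt rows of length k, B as nt
    // columns of length k (both contiguous).
    static void scalar_tile(const std::int8_t* a, const std::int8_t* b,
                            std::int32_t* c, std::size_t ldc, std::size_t k,
                            std::size_t mt, std::size_t nt) {
        for (std::size_t i = 0; i < mt; ++i)
            for (std::size_t j = 0; j < nt; ++j) {
                std::int32_t acc = 0;
                for (std::size_t p = 0; p < k; ++p)
                    acc += std::int32_t(a[i * k + p]) *
                           std::int32_t(b[j * k + p]);
                c[i * ldc + j] = acc;  // row stride of C is ldc, not n
            }
    }

    // Blocked driver mirroring the kernels above: full tiles in the main
    // loops, partial tiles at the edges (the real code dispatches to the
    // *_remain kernels instead of passing mt/nt).
    void gemm_blocked(const std::int8_t* pack_a, const std::int8_t* pack_b,
                      std::int32_t* c_ptr, std::size_t m, std::size_t n,
                      std::size_t k, std::size_t ldc, std::size_t m_tile,
                      std::size_t n_tile) {
        for (std::size_t m_offset = 0; m_offset < m; m_offset += m_tile) {
            const std::size_t mt =
                    m_offset + m_tile <= m ? m_tile : m - m_offset;
            const std::int8_t* iter_a = pack_a + m_offset * k;
            for (std::size_t n_offset = 0; n_offset < n; n_offset += n_tile) {
                const std::size_t nt =
                        n_offset + n_tile <= n ? n_tile : n - n_offset;
                const std::int8_t* iter_b = pack_b + n_offset * k;
                // the line this commit fixes: m_offset * ldc, not m_offset * n
                std::int32_t* iter_c = c_ptr + m_offset * ldc + n_offset;
                scalar_tile(iter_a, iter_b, iter_c, ldc, k, mt, nt);
            }
        }
    }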
dnn/test/common/checker.h

@@ -78,6 +78,7 @@ protected:
     TensorsConstriant m_tensor_constraint;
     bool m_no_naive_and_check = false;
     bool m_stable_check = false;
+    bool m_force_deduce_dst = true;
     /**
      * the offset from the start of malloc memory
      *
 ...

@@ -236,6 +237,12 @@ public:
         return *this;
     }
+    //! force deduce dst
+    Checker& set_force_deduce_dst(bool force_deduce_dst) {
+        m_force_deduce_dst = force_deduce_dst;
+        return *this;
+    }
     Checker& set_no_naive_check(bool no_naive_and_check) {
         m_no_naive_and_check = no_naive_and_check;
         return *this;

@@ -343,7 +350,10 @@ void Checker<Opr, Proxy>::exec(TensorLayoutArray layouts) {
     auto opr_cur = this->opr();
     opr_naive->param() = m_param;
     opr_cur->param() = m_param;
     bool deduce_layout = layouts.back().ndim == 0;
+    if (deduce_layout || m_force_deduce_dst) {
+        m_naive_proxy.deduce_layout(opr_naive.get(), layouts);
+    }
     auto exec_naive = [this, &opr_naive, &layouts, &opr_relayout](
             const TensorValueArray& values) {
         TensorValueArray contig_values = values;
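Before this change the checker always re-deduced the dst layout from the inputs, replacing a caller-supplied strided C layout with a contiguous one, so a wrong ldc in a kernel could never surface. set_force_deduce_dst(false) lets a test pin an explicitly strided dst. A usage fragment as it might appear in a megdnn test body (the shapes and strides here are illustrative, not from this commit):

    // Keep the strided C layout instead of letting the checker re-deduce
    // (and thereby flatten) it.
    Checker<MatrixMul> checker(handle());
    checker.set_force_deduce_dst(false);
    checker.execl({TensorLayout{{4, 6}, dtype::Int8()},             // A: 4x6
                   TensorLayout{{6, 8}, dtype::Int8()},             // B: 6x8
                   TensorLayout{{4, 8}, {20, 1}, dtype::Int32()}}); // C: ldc = 20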
dnn/test/common/matrix_mul.cpp

@@ -101,7 +101,7 @@ std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args_mask(
         size_t Astride = mask & 1 ? m + 2 : k + 2;
         // B: (k, n)
         size_t Bstride = mask & 2 ? k + 2 : n + 2;
-        size_t Cstride = n + 2;
+        size_t Cstride = n * 2 + 2;
         args.emplace_back(m, n, k, mask, Astride, Bstride, Cstride);
     }
     return args;

@@ -183,9 +183,11 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
                                   Handle* handle, const ExecutionPolicyAlgoName& algo,
                                   param::MatrixMul::Format format, size_t nbase,
-                                  float eps, std::vector<TestArg>&& user_args) {
+                                  float eps, std::vector<TestArg>&& user_args,
+                                  bool force_deduce_dst) {
     megdnn_assert(A_dtype.enumv() == B_dtype.enumv());
     Checker<Opr> checker(handle);
+    checker.set_force_deduce_dst(force_deduce_dst);
     if (!algo.name.empty()) {
         checker.set_before_exec_callback(AlgoChecker<Opr>(algo));
     }

@@ -245,8 +247,8 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
     for (auto& arg : args) {
         size_t m = arg.m, n = arg.n, k = arg.k;
 #if MEGDNN_WITH_CUDA
-        //[NOTE]: cublas can only process 4B aligned 8-bit input matrix;
         if (handle->type() == Handle::HandleType::CUDA) {
+            //! NOTE: cublas can only process 4B aligned 8-bit input matrix;
             bool is_dt_8bit = A_dtype.enumv() == DTypeEnum::Int8 ||
                               A_dtype.enumv() == DTypeEnum::QuantizedS8 ||
                               A_dtype.enumv() == DTypeEnum::Uint8 ||
 ...

@@ -254,7 +256,7 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
             if (is_dt_8bit && ((m % 4 != 0) || (n % 4 != 0))) {
                 continue;
             }
         }
 #endif
         Param param;
         param.transposeA = arg.mask & 0x1;

@@ -312,20 +314,22 @@ void matrix_mul::check_batched_matrix_mul(DType A_dtype, DType B_dtype,
                                           DType C_dtype, Handle* handle,
                                           const ExecutionPolicyAlgoName& algo,
-                                          float eps, std::vector<TestArg>&& args) {
+                                          float eps, std::vector<TestArg>&& args,
+                                          bool force_deduce_dst) {
     check_matrix_mul<megdnn::BatchedMatrixMul>(
             A_dtype, B_dtype, C_dtype, handle, algo,
             param::MatrixMul::Format::DEFAULT, 8, eps,
-            std::forward<decltype(args)>(args));
+            std::forward<decltype(args)>(args), force_deduce_dst);
 }

 void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
                                   Handle* handle,
                                   const ExecutionPolicyAlgoName& algo,
                                   param::MatrixMul::Format format, size_t nbase,
-                                  float eps) {
+                                  float eps, bool force_deduce_dst) {
-    check_matrix_mul<megdnn::MatrixMul>(A_dtype, B_dtype, C_dtype, handle, algo,
-                                        format, nbase, eps);
+    check_matrix_mul<megdnn::MatrixMul>(A_dtype, B_dtype, C_dtype, handle, algo,
+                                        format, nbase, eps, {}, force_deduce_dst);
 }

 #if MEGDNN_WITH_BENCHMARK
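The Cstride change pairs with the checker change: once force_deduce_dst is off, the strided C layout actually reaches the kernel, and a stride of n * 2 + 2 separates the correct row offsets (multiples of ldc) from the buggy ones (multiples of n) unambiguously. A quick check of the arithmetic (illustrative values):

    #include <cstdio>

    int main() {
        const unsigned n = 6, ldc = n * 2 + 2;  // Cstride = 14 in the test
        for (unsigned row = 0; row < 4; ++row)
            std::printf("row %u: correct start %u, buggy start %u\n",
                        row, row * ldc, row * n);
        // row 1: 14 vs 6, row 2: 28 vs 12, row 3: 42 vs 18 -- every row
        // after the first lands in the wrong place, so the checker's
        // comparison against the naive result catches the bad ldc handling.
    }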
dnn/test/common/matrix_mul.h

@@ -68,19 +68,21 @@ void check_matrix_mul(
         DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
         const ExecutionPolicyAlgoName& algo = {"", {}},
         param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
-        size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {});
+        size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {},
+        bool force_deduce_dst = true);

 void check_matrix_mul(
         DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
         const ExecutionPolicyAlgoName& algo = {"", {}},
         param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
-        size_t nbase = 8, float eps = 1e-3);
+        size_t nbase = 8, float eps = 1e-3, bool force_deduce_dst = true);

 void check_batched_matrix_mul(
         DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
         const ExecutionPolicyAlgoName& algo = {"", {}}, float eps = 1e-3,
-        std::vector<TestArg>&& args = {});
+        std::vector<TestArg>&& args = {}, bool force_deduce_dst = true);

 #if MEGDNN_WITH_BENCHMARK
 std::vector<TestArg> get_benchmark_matmul_args();
dnn/test/x86/matrix_mul.cpp

@@ -44,21 +44,31 @@ TEST_F(X86, MATRIX_MUL_MKLDNN_8X8X32) {
 //! FIXME: need to add tests of GEMV and QUINT8
 TEST_F(X86, MATRIX_MUL_AVX2_8X8X32) {
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
-                                 handle(), "X86_INT8X8X32_AVX2_2X4X16");
+                                 handle(), "X86_INT8X8X32_AVX2_2X4X16",
+                                 param::MatrixMul::Format::DEFAULT, 8, 1e-3,
+                                 false);
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
-                                 handle(), "X86_INT8X8X32_AVX2_4X16X2");
+                                 handle(), "X86_INT8X8X32_AVX2_4X16X2",
+                                 param::MatrixMul::Format::DEFAULT, 8, 1e-3,
+                                 false);
 }

 TEST_F(X86, MATRIX_MUL_AVX2_8X8X16) {
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
-                                 handle(), "X86_INT8X8X16_AVX2");
+                                 handle(), "X86_INT8X8X16_AVX2",
+                                 param::MatrixMul::Format::DEFAULT, 8, 1e-3,
+                                 false);
 }

 TEST_F(X86, MATRIX_MUL_SSE_8X8X16) {
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
-                                 handle(), "X86_INT8X8X16_SSE");
+                                 handle(), "X86_INT8X8X16_SSE",
+                                 param::MatrixMul::Format::DEFAULT, 8, 1e-3,
+                                 false);
 }

 TEST_F(X86, MATRIX_MUL_SSE_8X8X32) {
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
-                                 handle(), "X86_INT8X8X32_SSE_4X8X2");
+                                 handle(), "X86_INT8X8X32_SSE_4X8X2",
+                                 param::MatrixMul::Format::DEFAULT, 8, 1e-3,
+                                 false);
 }

 #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM

@@ -72,7 +82,7 @@ TEST_F(X86, MATRIX_MUL_MKL_PACKA) {
 TEST_F(X86, MATRIX_MUL_AVX2_MK8_8X8) {
     matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
                                  dtype::Float32{}, handle(), "X86_F32MK8_8X8",
-                                 param::MatrixMul::Format::MK8, 1);
+                                 param::MatrixMul::Format::MK8, 1, 1e-3, false);
 }

 #if MEGDNN_WITH_BENCHMARK