Commit 3597a6db
Authored Jun 23, 2021 by Megvii Engine Team
Committed Jul 19, 2021 by huangxinda

feat(dnn/arm): nchw_nchw44 conv support 1x1s1

GitOrigin-RevId: 8c8f7d7c763b603961ca27b1fd17425bafd019cd
Parent: c64b1c94

Showing 10 changed files with 146 additions and 7 deletions (+146 -7)
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp   +47  -0
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp   +12  -0
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp  +42  -0
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s2.cpp  +10  -0
dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp                     +3  -0
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp                 +3  -0
dnn/src/common/nchw_nchwxx_valid.h                                               +10  -6
dnn/test/arm_common/conv_bias_multi_thread.cpp                                   +16  -0
dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp                          +1  -0
src/plugin/impl/opr_footprint.cpp                                                 +2  -1
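Note for orientation (not part of the diff): NCHW44 is MegEngine's 4-channel-packed layout, i.e. an (N, C, H, W) tensor stored as (N, C/4, H, W, 4); the nchw_nchw44 kernels touched below read a plain NCHW source and write NCHW44 output. A minimal repacking sketch, assuming C is a multiple of 4; the function name and float element type are illustrative only:

#include <cstddef>
#include <vector>

// Illustrative only: repack an NCHW tensor into NCHW44, i.e.
// (N, C, H, W) -> (N, C/4, H, W, 4). Assumes c % 4 == 0.
std::vector<float> nchw_to_nchw44(const std::vector<float>& src, size_t n,
                                  size_t c, size_t h, size_t w) {
    std::vector<float> dst(src.size());
    const size_t c4 = c / 4;
    for (size_t ni = 0; ni < n; ++ni)
        for (size_t ci = 0; ci < c; ++ci)
            for (size_t hi = 0; hi < h; ++hi)
                for (size_t wi = 0; wi < w; ++wi) {
                    const size_t src_idx = ((ni * c + ci) * h + hi) * w + wi;
                    const size_t dst_idx =
                            (((ni * c4 + ci / 4) * h + hi) * w + wi) * 4 + ci % 4;
                    dst[dst_idx] = src[src_idx];
                }
    return dst;
}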
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp
@@ -47,6 +47,52 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 1, T, T2, T3, T4> {
    }
};
////////////////////stride 1///////////////////
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 1, oc_block, ow_block,
                                1> {
    MEGDNN_ATTRIBUTE_TARGET("dotprod")
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_hight = 1;
        constexpr int filter_width = 4;
        constexpr int weight_reg = 2;
        constexpr int src_reg = 2;
        constexpr int oc_step = 4;
        constexpr int ic_step = 1;
        constexpr int pack_iw_len = 4;
        constexpr int simd_len = 16;

        const int ld_bias = oc_step;
        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];
        init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
            int8x16_t src[src_reg];
            int8x16_t weight[c_dim][weight_reg];
            // row 0
            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, src_ptr + 0 * iw * pack_iw_len, 0);
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    weight, weight_ptr, ld_weight_oc);
            cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(
                    c, src, weight);

            src_ptr += ic_stride;
            weight_ptr += filter_hight * filter_width * oc_step;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
...
@@ -441,6 +487,7 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter,
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)
#define DISPATCH_CONV_KERN(stride) \
GET_BIAS_MODE_PARAM(stride, 1) \
GET_BIAS_MODE_PARAM(stride, 2) \
GET_BIAS_MODE_PARAM(stride, 3) \
GET_BIAS_MODE_PARAM(stride, 5) \
...
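To make it easier to see what the new 1x1, stride-1 specialization above (and its non-dot counterpart in int8_direct_nchw_nchw44_s1.cpp below) computes before the Op functor requantizes the accumulators, here is a naive scalar sketch of a 1x1 s1 int8 convolution from an NCHW source into NCHW44-packed int32 accumulators. It is an illustration under assumed flat layouts, not MegEngine code; the packed-source and packed-weight formats the real kernels consume are omitted.

#include <cstdint>

// Naive reference: 1x1, stride-1 convolution, NCHW int8 source to NCHW44
// int32 accumulators. Layouts and names are illustrative assumptions.
//   src    : [ic][ih][iw]
//   weight : [oc][ic]           (1x1 filter)
//   bias   : [oc]
//   dst    : [oc/4][ih][iw][4]  (NCHW44 packing, before requantization)
void conv_1x1_s1_nchw_nchw44_ref(const int8_t* src, const int8_t* weight,
                                 const int32_t* bias, int32_t* dst, int ic,
                                 int oc, int ih, int iw) {
    for (int oc_idx = 0; oc_idx < oc; ++oc_idx) {
        for (int h = 0; h < ih; ++h) {
            for (int w = 0; w < iw; ++w) {
                int32_t acc = bias[oc_idx];
                for (int ic_idx = 0; ic_idx < ic; ++ic_idx) {
                    acc += int32_t(src[(ic_idx * ih + h) * iw + w]) *
                           int32_t(weight[oc_idx * ic + ic_idx]);
                }
                // pack the output channel into groups of 4 (NCHW44)
                const int dst_idx =
                        (((oc_idx / 4) * ih + h) * iw + w) * 4 + oc_idx % 4;
                dst[dst_idx] = acc;
            }
        }
    }
}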
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp
@@ -58,6 +58,17 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 2, T, T2, T3, T4> {
    }
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 1, oc_block, ow_block,
                                2> {
    MEGDNN_ATTRIBUTE_TARGET("dotprod")
    static void impl(const int8_t*, const int8_t*, const int32_t*, int8_t*,
                     int, int, int, int, const Op&) {
        megdnn_assert(0, "not impl");
    }
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
...
@@ -429,6 +440,7 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter,
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)
#define DISPATCH_CONV_KERN(stride) \
GET_BIAS_MODE_PARAM(stride, 1) \
GET_BIAS_MODE_PARAM(stride, 2) \
GET_BIAS_MODE_PARAM(stride, 3) \
GET_BIAS_MODE_PARAM(stride, 5) \
...
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp
@@ -112,6 +112,47 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, 1, T, T2, T3, T4> {
    static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&);
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 1, oc_block, 1> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_height = 1;
        constexpr int filter_width = 4;
        constexpr int oc_step = 4;
        constexpr int loop_ic_step = 1;
        constexpr int simd_len = 16;
        constexpr int pack_iw_len = 16;
        constexpr int src_reg = 8;
        constexpr int weight_reg = 1;

        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];
        init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
            int8x16_t src[src_reg];
            int8x16_t dot4_weight[c_dim][weight_reg];
            int16x8_t temp_c[4];
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    dot4_weight, weight_ptr, ld_weight_oc);
            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, nchw_src_ptr + 0 * iw * pack_iw_len, 0);
            cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);

            weight_ptr += oc_step * filter_height * filter_width;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, 1> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
...
@@ -547,6 +588,7 @@ struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 1> {
INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)
#define INSTANCE_CONV_KERN(stride) \
INSTANCE_BIAS_MODE_PARAM(stride, 1) \
INSTANCE_BIAS_MODE_PARAM(stride, 2) \
INSTANCE_BIAS_MODE_PARAM(stride, 3) \
INSTANCE_BIAS_MODE_PARAM(stride, 5) \
...
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s2.cpp
@@ -1033,6 +1033,15 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, stride> {
    }
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int stride>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 1, oc_block, stride> {
    static void impl(const int8_t*, const int8_t*, const int32_t*, int8_t*,
                     int, int, int, int, const Op&) {
        megdnn_assert(0, "not impl nchw_nchw44 1x1 s2");
    }
};
enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 };
template <PACK_MODE mode>
MEGDNN_ALWAYS_INLINE void pack_src_one_line(const int8_t* inptr, int8_t* outptr,
...
@@ -1398,6 +1407,7 @@ struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 2> {
INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)
#define INSTANCE_CONV_KERN(stride) \
INSTANCE_BIAS_MODE_PARAM(stride, 1) \
INSTANCE_BIAS_MODE_PARAM(stride, 2) \
INSTANCE_BIAS_MODE_PARAM(stride, 3) \
INSTANCE_BIAS_MODE_PARAM(stride, 5) \
...
dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp
@@ -291,6 +291,9 @@ ConvBiasImpl::AlgoS8DirectNCHWNCHW44::dispatch_kerns(
#define DISPATCH_CONV_KERN(stride) \
switch (param.filter_meta.spatial[0]) { \
case 1: \
GET_BIAS_MODE_PARAM(stride, 1) \
break; \
case 2: \
GET_BIAS_MODE_PARAM(stride, 2) \
break; \
...
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
@@ -245,6 +245,9 @@ ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::dispatch_kerns(
#define DISPATCH_CONV_KERN(stride) \
switch (param.filter_meta.spatial[0]) { \
case 1: \
GET_BIAS_MODE_PARAM(stride, 1) \
break; \
case 2: \
GET_BIAS_MODE_PARAM(stride, 2) \
break; \
...
dnn/src/common/nchw_nchwxx_valid.h
@@ -74,9 +74,11 @@ inline bool nchw_nchwxx_valid<NCHW44_INT8>(
                           nonline_mode == param::ConvBias::NonlineMode::H_SWISH;
     bool ok_src_dst = fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) &&
                       fm.group == 1;
-    bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
-                     (fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
-                      fm.spatial[0] == 5 || fm.spatial[0] == 7);
+    bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
+                     (fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
+                      fm.spatial[0] == 5 || fm.spatial[0] == 7 ||
+                      (fm.spatial[0] == 1 && fm.stride[0] == 1 &&
+                       fm.stride[1] == 1));
     bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                     fm.stride[0] == fm.stride[1] &&
                     (fm.stride[0] == 1 || fm.stride[1] == 2);
@@ -126,9 +128,11 @@ inline bool nchw_nchwxx_valid<NCHW44_INT8_DOT>(
                           nonline_mode == param::ConvBias::NonlineMode::H_SWISH;
     bool ok_src_dst = fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) &&
                       fm.group == 1;
-    bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
-                     (fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
-                      fm.spatial[0] == 5 || fm.spatial[0] == 7);
+    bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
+                     (fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
+                      fm.spatial[0] == 5 || fm.spatial[0] == 7 ||
+                      (fm.spatial[0] == 1 && fm.stride[0] == 1 &&
+                       fm.stride[1] == 1));
     bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                     fm.stride[0] == fm.stride[1] &&
                     (fm.stride[0] == 1 || fm.stride[1] == 2);
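Read together, the two relaxed ok_filter checks above mean a 1x1 filter now passes validation only when both strides are 1; dilation and the other constraints are unchanged. A standalone restatement of the filter-size condition, using a hypothetical helper name (not a MegEngine function):

// Illustrative restatement of the updated filter check; fs is the square
// filter size, sh/sw the strides.
inline bool nchw_nchw44_filter_ok(int fs, int sh, int sw) {
    return fs == 2 || fs == 3 || fs == 5 || fs == 7 ||
           (fs == 1 && sh == 1 && sw == 1);
}
// nchw_nchw44_filter_ok(1, 1, 1) -> true  (newly accepted by this commit)
// nchw_nchw44_filter_ok(1, 2, 2) -> false (1x1 with stride 2 is still rejected)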
dnn/test/arm_common/conv_bias_multi_thread.cpp
@@ -487,6 +487,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) {
            handle(), "S8_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1_F1) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({1}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1,
                                      false, true),
            handle(), "S8_CONV_NCHW_NCHW44");
}
/*****************************quint8 direct****************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1) {
    checker_conv_bias_quint8x8x8(
            get_int8_quint8_conv_bias_args(
...
@@ -517,6 +524,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
    checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44_S1_F1) {
    auto args = get_nchw44_conv_bias_args({1}, QUAN_NLMODE, BR_AND_NO_BIASMODE,
                                          1, false, true);
    for (auto&& arg : args) {
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    }
    checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) {
    checker_conv_bias_qint8x8x8(
            get_int8_quint8_conv_bias_args(
                    {2, 3, 5, 7}, 1, false, false, false),
...
dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp
@@ -635,6 +635,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44) {
        benchmark_impl(param, shape_arg, ".+", RUNS, {4, {4, 5, 6, 7}},
                       {1, {7}}, data_type);
    };
    bench_case(1, 2, 64, 160, 160, 1, 1, 0, 1, true);
    bench_case(1, 3, 64, 224, 224, 7, 1, 3, 2, true);
    bench_case(1, 64, 64, 56, 56, 3, 1, 1, 1);
    bench_case(1, 128, 128, 28, 28, 3, 1, 1, 1);
...
src/plugin/impl/opr_footprint.cpp
@@ -131,7 +131,8 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
     if (param.format == Param::Format::NCHW44 ||
         param.format == Param::Format::NCHW44_DOT) {
         //! if channel wise weight layout is {group/4, FH, FW, 1, 1, 4}
-        if (filter_shape[1] == 1 && filter_shape[2] == 1) {
+        if (filter_shape[1] == 1 && filter_shape[2] == 1 &&
+            filter_shape.ndim == 6) {
             group *= 4;
         }
         size_t computation = dst_shape.total_nr_elems() * fh * fw *
...
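A plausible reading of this last change (an assumption, not stated in the commit): the comment marks the channel-wise NCHW44 weight layout {group/4, FH, FW, 1, 1, 4}, which has six dimensions, while a 1x1 filter in other NCHW44 weight layouts can also satisfy filter_shape[1] == 1 && filter_shape[2] == 1; adding the ndim == 6 test keeps the group *= 4 adjustment limited to genuinely channel-wise weights. A restatement of just the updated predicate, with a hypothetical helper name:

// Hypothetical restatement of the updated channel-wise test used above.
// TensorShape is the MegEngine shape type already referenced in the diff.
inline bool is_channel_wise_nchw44_weight(const TensorShape& filter_shape) {
    return filter_shape[1] == 1 && filter_shape[2] == 1 &&
           filter_shape.ndim == 6;
}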