Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
3de1fa5b
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
3de1fa5b
编写于
5月 13, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(arm/dnn): add support for nchw_nchw44 filter 2
GitOrigin-RevId: 013242911ec2e3acdf2e67676e3f4a821a9a5eb0
上级
f3547242
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
74 addition
and
28 deletion
+74
-28
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp
...on/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp
+4
-1
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp
...on/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp
+57
-15
dnn/test/arm_common/conv_bias.cpp
dnn/test/arm_common/conv_bias.cpp
+10
-9
dnn/test/arm_common/conv_bias_multi_thread.cpp
dnn/test/arm_common/conv_bias_multi_thread.cpp
+3
-3
未找到文件。
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp
浏览文件 @
3de1fa5b
...
@@ -207,7 +207,7 @@ bool ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::usable(
...
@@ -207,7 +207,7 @@ bool ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::usable(
(
fm
.
format
==
param
::
Convolution
::
Format
::
NCHW44
);
(
fm
.
format
==
param
::
Convolution
::
Format
::
NCHW44
);
bool
ok_src_dst
=
fm
.
icpg
<
4
&&
(
oc
%
4
==
0
&&
oc
>=
4
)
&&
fm
.
group
==
1
;
bool
ok_src_dst
=
fm
.
icpg
<
4
&&
(
oc
%
4
==
0
&&
oc
>=
4
)
&&
fm
.
group
==
1
;
bool
ok_filter
=
fm
.
spatial_ndim
==
2
&&
fh
==
fm
.
spatial
[
1
]
&&
bool
ok_filter
=
fm
.
spatial_ndim
==
2
&&
fh
==
fm
.
spatial
[
1
]
&&
(
fh
==
3
||
fh
==
5
||
fh
==
7
);
(
fh
==
2
||
fh
==
3
||
fh
==
5
||
fh
==
7
);
bool
ok_slide
=
fm
.
dilation
[
0
]
==
1
&&
fm
.
dilation
[
1
]
==
1
&&
bool
ok_slide
=
fm
.
dilation
[
0
]
==
1
&&
fm
.
dilation
[
1
]
==
1
&&
fm
.
stride
[
0
]
==
2
&&
fm
.
stride
[
1
]
==
2
;
fm
.
stride
[
0
]
==
2
&&
fm
.
stride
[
1
]
==
2
;
bool
ok_conv
=
!
fm
.
should_flip
&&
param
.
bias_mode
!=
BiasMode
::
BIAS
;
bool
ok_conv
=
!
fm
.
should_flip
&&
param
.
bias_mode
!=
BiasMode
::
BIAS
;
...
@@ -267,6 +267,9 @@ ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::dispatch_kerns(
...
@@ -267,6 +267,9 @@ ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::dispatch_kerns(
#define DISPATCH_CONV_KERN() \
#define DISPATCH_CONV_KERN() \
switch (param.filter_meta.spatial[0]) { \
switch (param.filter_meta.spatial[0]) { \
case 2: \
GET_BIAS_MODE_PARAM(2) \
break; \
case 3: \
case 3: \
GET_BIAS_MODE_PARAM(3) \
GET_BIAS_MODE_PARAM(3) \
break; \
break; \
...
...
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp
浏览文件 @
3de1fa5b
...
@@ -207,24 +207,27 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block> {
...
@@ -207,24 +207,27 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block> {
float32x4_t
src
[
src_reg_size
];
float32x4_t
src
[
src_reg_size
];
float32x4_t
weight
[
c_dim
][
filter_size
];
float32x4_t
weight
[
c_dim
][
filter_size
];
// row 0
// row 0
load_helper
<
5
,
0
,
simd_len
,
0
,
Vld1q_f32
>
(
src
,
src_ptr
,
0
);
load_helper
<
src_reg_size
,
0
,
simd_len
,
0
,
Vld1q_f32
>
(
src
,
src_ptr
,
load_helper
<
3
,
0
,
oc_step
,
c_dim
,
Vld1q_f32
>
(
weight
,
weight_ptr
,
0
);
ld_weight_oc
);
load_helper
<
filter_size
,
0
,
oc_step
,
c_dim
,
Vld1q_f32
>
(
weight
,
weight_ptr
,
ld_weight_oc
);
cal_helper
<
0
,
0
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
cal_helper
<
0
,
0
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
cal_helper
<
1
,
1
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
cal_helper
<
1
,
1
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
cal_helper
<
2
,
2
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
cal_helper
<
2
,
2
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
// row 1
// row 1
load_helper
<
5
,
0
,
simd_len
,
0
,
Vld1q_f32
>
(
src
,
src_ptr
+
iw
,
0
);
load_helper
<
src_reg_size
,
0
,
simd_len
,
0
,
Vld1q_f32
>
(
load_helper
<
3
,
0
,
oc_step
,
c_dim
,
Vld1q_f32
>
(
src
,
src_ptr
+
iw
,
0
);
load_helper
<
filter_size
,
0
,
oc_step
,
c_dim
,
Vld1q_f32
>
(
weight
,
weight_ptr
+
1
*
ld_weight_fw
,
ld_weight_oc
);
weight
,
weight_ptr
+
1
*
ld_weight_fw
,
ld_weight_oc
);
cal_helper
<
0
,
0
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
cal_helper
<
0
,
0
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
cal_helper
<
1
,
1
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
cal_helper
<
1
,
1
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
cal_helper
<
2
,
2
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
cal_helper
<
2
,
2
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
// row 2
// row 2
load_helper
<
5
,
0
,
simd_len
,
0
,
Vld1q_f32
>
(
src
,
src_ptr
+
2
*
iw
,
0
);
load_helper
<
src_reg_size
,
0
,
simd_len
,
0
,
Vld1q_f32
>
(
load_helper
<
3
,
0
,
oc_step
,
c_dim
,
Vld1q_f32
>
(
src
,
src_ptr
+
2
*
iw
,
0
);
load_helper
<
filter_size
,
0
,
oc_step
,
c_dim
,
Vld1q_f32
>
(
weight
,
weight_ptr
+
2
*
ld_weight_fw
,
ld_weight_oc
);
weight
,
weight_ptr
+
2
*
ld_weight_fw
,
ld_weight_oc
);
cal_helper
<
0
,
0
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
cal_helper
<
0
,
0
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
cal_helper
<
1
,
1
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
cal_helper
<
1
,
1
,
c_dim
,
Vfmaq_laneq_f32
>
(
c
,
src
,
weight
);
...
@@ -238,6 +241,52 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block> {
...
@@ -238,6 +241,52 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block> {
}
}
};
};
//! Partial specialization of KerNeonXXs2NchwNchw44FP32 selected when the
//! fourth template argument (the filter size) is 2.
//!
//! impl() seeds a c_dim x 8 block of float32x4_t accumulators through
//! init_ocx_ow8, walks every input channel accumulating the two filter rows
//! into that block, and finally passes the block, op and dst_ptr to
//! store_ocx_ow8_remain_static.
//!
//! bias_mode is forwarded to init_ocx_ow8; remain_w and Op are forwarded to
//! store_ocx_ow8_remain_static; oc_block selects the accumulator row count via
//! OCHelper<oc_block>::val.
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 2, oc_block> {
    //! \param src_ptr    first source row read for the current input channel
    //! \param weight_ptr weights read for the current input channel
    //! \param bias_ptr   handed to init_ocx_ow8 to seed the accumulators
    //! \param dst_ptr    handed to store_ocx_ow8_remain_static
    //! \param ic         number of input channels iterated over
    //! \param ih, iw     ih * iw is the per-channel step applied to src_ptr;
    //!                   iw is the step between the two source rows
    //! \param ld_dst_oc  handed to store_ocx_ow8_remain_static
    //! \param op         handed to store_ocx_ow8_remain_static
    static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
                     const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
                     int ih, int iw, int ld_dst_oc, const Op& op) {
        constexpr int ic_step = 1;
        constexpr int filter_size = 2;
        constexpr int oc_step = 4;
        constexpr int simd_len = 4;
        //! four vectors of simd_len floats are loaded per source row
        //! NOTE(review): presumably the span needed for 8 outputs of a
        //! 2-wide, stride-2 window -- confirm against the caller
        constexpr int src_vec_num = 4;
        constexpr int acc_rows = OCHelper<oc_block>::val;

        //! weight steps: one filter row, one input channel, one oc block
        constexpr int weight_row_stride = oc_step * filter_size;
        const int weight_ic_stride = oc_step * filter_size * filter_size;
        const int weight_oc_stride = weight_ic_stride * ic;
        const int src_ic_stride = ih * iw;

        float32x4_t acc[acc_rows][8];
        init_ocx_ow8<acc_rows, bias_mode>(acc, bias_ptr, oc_step);

        for (int ic_iter = 0; ic_iter < ic; ic_iter += ic_step) {
            float32x4_t src_row[src_vec_num];
            float32x4_t filter_row[acc_rows][filter_size];

            // first filter row
            load_helper<src_vec_num, 0, simd_len, 0, Vld1q_f32>(src_row,
                                                                src_ptr, 0);
            load_helper<filter_size, 0, oc_step, acc_rows, Vld1q_f32>(
                    filter_row, weight_ptr, weight_oc_stride);
            cal_helper<0, 0, acc_rows, Vfmaq_laneq_f32>(acc, src_row,
                                                        filter_row);
            cal_helper<1, 1, acc_rows, Vfmaq_laneq_f32>(acc, src_row,
                                                        filter_row);

            // second filter row: source advances by iw, weights by one row
            load_helper<src_vec_num, 0, simd_len, 0, Vld1q_f32>(
                    src_row, src_ptr + iw, 0);
            load_helper<filter_size, 0, oc_step, acc_rows, Vld1q_f32>(
                    filter_row, weight_ptr + weight_row_stride,
                    weight_oc_stride);
            cal_helper<0, 0, acc_rows, Vfmaq_laneq_f32>(acc, src_row,
                                                        filter_row);
            cal_helper<1, 1, acc_rows, Vfmaq_laneq_f32>(acc, src_row,
                                                        filter_row);

            // move on to the next input channel
            src_ptr += src_ic_stride;
            weight_ptr += weight_ic_stride;
        }

        store_ocx_ow8_remain_static<acc_rows, remain_w, Op>(acc, op, dst_ptr,
                                                            ld_dst_oc);
    }
};
}
// namespace
}
// namespace
void
conv_bias
::
pack_weight_fp32_nchw_nchw44
(
const
float32_t
*
in_ptr
,
void
conv_bias
::
pack_weight_fp32_nchw_nchw44
(
const
float32_t
*
in_ptr
,
...
@@ -383,19 +432,12 @@ static void conv_direct_stride2_fp32_nchw_nchw44(
...
@@ -383,19 +432,12 @@ static void conv_direct_stride2_fp32_nchw_nchw44(
ow, op, ph, pw); \
ow, op, ph, pw); \
}
}
CONSTRUCT_FUNC
(
2
);
CONSTRUCT_FUNC
(
3
);
CONSTRUCT_FUNC
(
3
);
CONSTRUCT_FUNC
(
5
);
CONSTRUCT_FUNC
(
5
);
CONSTRUCT_FUNC
(
7
);
CONSTRUCT_FUNC
(
7
);
#undef CONSTRUCT_FUNC
#undef CONSTRUCT_FUNC
//! Stub definition of the 2x2 stride-2 nchw -> nchw44 fp32 entry point.
//! All parameters are unnamed and unused; the body only invokes
//! megdnn_assert with a constant-false condition and a fixed message.
template <BiasMode bias_mode, typename Op>
void conv_bias::conv_direct_stride2_2x2_fp32_nchw_nchw44(
        const float32_t*, const float32_t*, const float32_t*, float32_t*,
        float32_t*, const int, const int, const int, const int, const int,
        const int, const int, const Op&, const int, const int) {
    megdnn_assert(0, "not imple nchw_nchw44 2x2s2 conv");
}
#define INSTANTIATION(stride, i, bias, Op) \
#define INSTANTIATION(stride, i, bias, Op) \
template void conv_bias:: \
template void conv_bias:: \
conv_direct_##stride##_##i##x##i##_fp32_nchw_nchw44<bias, Op>( \
conv_direct_##stride##_##i##x##i##_fp32_nchw_nchw44<bias, Op>( \
...
...
dnn/test/arm_common/conv_bias.cpp
浏览文件 @
3de1fa5b
...
@@ -195,6 +195,7 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
...
@@ -195,6 +195,7 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
};
};
if
(
is_fp32
)
{
if
(
is_fp32
)
{
run
(
1
,
1
,
4
,
112
,
112
,
2
,
2
,
true
);
run
(
1
,
3
,
32
,
224
,
224
,
3
,
2
,
true
);
run
(
1
,
3
,
32
,
224
,
224
,
3
,
2
,
true
);
run
(
1
,
3
,
64
,
224
,
224
,
7
,
2
,
true
);
run
(
1
,
3
,
64
,
224
,
224
,
7
,
2
,
true
);
}
else
{
}
else
{
...
@@ -1806,12 +1807,12 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_WINOGRAD_VS_IM2COL_INT8) {
...
@@ -1806,12 +1807,12 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_WINOGRAD_VS_IM2COL_INT8) {
auto
opr
=
handle
()
->
create_operator
<
ConvBias
>
();
auto
opr
=
handle
()
->
create_operator
<
ConvBias
>
();
opr
->
param
()
=
arg
.
param
;
opr
->
param
()
=
arg
.
param
;
opr
->
deduce_layout
({
arg
.
src
,
dtype
::
Float32
()},
opr
->
deduce_layout
({
arg
.
src
,
dtype
::
Float32
()},
{
arg
.
filter
,
dtype
::
Float32
()},
{
arg
.
filter
,
dtype
::
Float32
()},
{
arg
.
bias
,
dtype
::
Float32
()},
{},
dst_layout
);
{
arg
.
bias
,
dtype
::
Float32
()},
{},
dst_layout
);
//! dst.nr_elems * IC * FH * FW * 2
//! dst.nr_elems * IC * FH * FW * 2
float
computations
=
dst_layout
.
total_nr_elems
()
*
arg
.
filter
[
1
]
*
float
computations
=
dst_layout
.
total_nr_elems
()
*
arg
.
filter
[
1
]
*
arg
.
filter
[
2
]
*
arg
.
filter
[
3
]
*
2.0
/
arg
.
filter
[
2
]
*
arg
.
filter
[
3
]
*
2.0
/
(
1024
*
1024
*
1024
)
*
1e3
;
(
1024
*
1024
*
1024
)
*
1e3
;
benchmark_im2col
.
set_param
(
arg
.
param
);
benchmark_im2col
.
set_param
(
arg
.
param
);
auto
im2col_used
=
auto
im2col_used
=
...
@@ -1828,11 +1829,11 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_WINOGRAD_VS_IM2COL_INT8) {
...
@@ -1828,11 +1829,11 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_WINOGRAD_VS_IM2COL_INT8) {
RUN
;
RUN
;
printf
(
"%s %s: im2col: %f ms %f Gflops winograd: %f ms %f GFlops "
printf
(
"%s %s: im2col: %f ms %f Gflops winograd: %f ms %f GFlops "
"speedup: "
"speedup: "
"%f
\n
"
,
"%f
\n
"
,
arg
.
src
.
to_string
().
c_str
(),
arg
.
filter
.
to_string
().
c_str
(),
arg
.
src
.
to_string
().
c_str
(),
arg
.
filter
.
to_string
().
c_str
(),
im2col_used
,
computations
/
im2col_used
,
winograd_used
,
im2col_used
,
computations
/
im2col_used
,
winograd_used
,
computations
/
winograd_used
,
im2col_used
/
winograd_used
);
computations
/
winograd_used
,
im2col_used
/
winograd_used
);
}
}
}
}
...
...
dnn/test/arm_common/conv_bias_multi_thread.cpp
浏览文件 @
3de1fa5b
...
@@ -342,9 +342,9 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) {
...
@@ -342,9 +342,9 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) {
handle
(),
"F32STRD2_SMALL_GROUP"
);
handle
(),
"F32STRD2_SMALL_GROUP"
);
}
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONVBIAS_NCHW_NCHW44_F32
)
{
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONVBIAS_NCHW_NCHW44_F32
)
{
check_conv_bias
(
check_conv_bias
(
get_nchw44_conv_bias_args
({
2
,
3
,
5
,
7
},
2
,
false
,
false
,
get_nchw44_conv_bias_args
({
3
,
5
,
7
},
2
,
false
,
false
,
false
,
true
),
false
,
true
),
handle
(),
"F32_CONV_NCHW_NCHW44"
);
handle
(),
"F32_CONV_NCHW_NCHW44"
);
}
}
/**********************************F16 direct************************/
/**********************************F16 direct************************/
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录