Commit 09ceaaae

Authored Jun 03, 2020 by Megvii Engine Team
Committed by Xu Xinran, Jun 19, 2020

fix(dnn/arm): stride1 support for nchw_nchw44 fp32 conv

GitOrigin-RevId: 744c5db3dc3a867d1577f1c870e47472945234f5

Parent: 50db9b84
Showing 8 changed files with 196 additions and 250 deletions (+196 −250)
dnn/src/arm_common/conv_bias/fp32/algos.h                                   +2    −2
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp           +62   −56
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h             +118  −149
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h     +0    −38
dnn/src/arm_common/conv_bias/opr_impl.cpp                                   +1    −1
dnn/src/arm_common/conv_bias/opr_impl.h                                     +1    −1
dnn/test/arm_common/conv_bias.cpp                                           +5    −0
dnn/test/arm_common/conv_bias_multi_thread.cpp                              +7    −3
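At a high level, the commit folds the formerly stride-2-only NCHW→NCHW44 fp32 direct convolution into a single algorithm (AlgoF32DirectStride2NCHWNCHW44 becomes AlgoF32DirectNCHWNCHW44) whose kernel takes the stride as a template parameter, and dispatch_kerns() selects the stride-1 or stride-2 instantiation at runtime. The following is a minimal, self-contained sketch of that dispatch pattern only; the names and bodies are illustrative and are not MegEngine's actual kernels:

#include <cstdio>
#include <stdexcept>

// Illustrative stand-in for the unified direct-conv kernel: the stride is a
// template parameter, so stride-1 and stride-2 get separate, fully
// specialized instantiations from one implementation.
template <int filter_size, int stride>
void conv_direct_fp32_nchw_nchw44_sketch(int ih, int iw) {
    std::printf("filter=%dx%d stride=%d input=%dx%d\n", filter_size,
                filter_size, stride, ih, iw);
}

// Illustrative stand-in for the runtime selection added to dispatch_kerns():
// map the layer's stride onto the matching instantiation.
template <int filter_size>
void dispatch_by_stride(int stride, int ih, int iw) {
    switch (stride) {
        case 1:
            conv_direct_fp32_nchw_nchw44_sketch<filter_size, 1>(ih, iw);
            break;
        case 2:
            conv_direct_fp32_nchw_nchw44_sketch<filter_size, 2>(ih, iw);
            break;
        default:
            throw std::runtime_error("unsupported stride for the first conv");
    }
}

int main() {
    dispatch_by_stride<3>(1, 224, 224);  // stride-1 path added by this commit
    dispatch_by_stride<3>(2, 224, 224);  // pre-existing stride-2 path
    return 0;
}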
dnn/src/arm_common/conv_bias/fp32/algos.h

@@ -293,11 +293,11 @@ public:
                 const NCBKernSizeParam& param) const override;
 };
 
-class ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44 final : public AlgoBase {
+class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase {
     SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
 
 public:
-    AlgoF32DirectStride2NCHWNCHW44() {}
+    AlgoF32DirectNCHWNCHW44() {}
     bool is_reproducible() const override { return true; }
     const char* name() const override { return "F32_CONV_NCHW_NCHW44"; }
     bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp → dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp

 /**
- * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp
+ * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -13,7 +13,7 @@
 #include "megdnn/oprs.h"
 #include "src/arm_common/conv_bias/fp32/algos.h"
-#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h"
+#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h"
 #include "src/arm_common/conv_bias/fp32/strategy.h"
 #include "src/arm_common/elemwise_op.h"
 #include "src/common/opr_delegate.h"
@@ -26,7 +26,7 @@ using conv_fun = std::function<void(
         WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
         const ConvBiasImpl::NCBKernIndex& ncb_index,
         const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>;
-MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw_nchw44_stride2)
+MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw_nchw44)
 
 namespace {
 static inline int block_helper(const int nthread, const int amount,
                                const int per_unit_bytes) {
@@ -120,11 +120,10 @@ static void pack_weight(WorkspaceBundle bundle,
             kern_param.filter<dt_float32>(group_id) + oc_idx * fh * fw * ic;
     auto packed_weight = reinterpret_cast<float*>(bundle.get(1)) +
                          group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw;
-    conv_bias::pack_weight_fp32_nchw_nchw44(fptr, packed_weight, oc_block, fh,
-                                            fw, ic);
+    pack_weight_fp32_nchw_nchw44(fptr, packed_weight, oc_block, fh, fw, ic);
 }
 
-template <size_t filter, BiasMode bias_mode, typename Op>
+template <size_t filter_size, BiasMode bias_mode, typename Op, size_t stride>
 static void do_conv_kern(WorkspaceBundle bundle,
                          const ConvBiasImpl::NCBKernParam& kern_param,
                          const ConvBiasImpl::NCBKernIndex& ncb_index,
@@ -137,7 +136,7 @@ static void do_conv_kern(WorkspaceBundle bundle,
     const int oc = kern_param.filter_meta.ocpg;
     const int ih = kern_param.isz[0];
     const int iw = kern_param.isz[1];
-    const int stride_h = kern_param.filter_meta.stride[0];
+    const int stride_h = stride;
     const int ph = kern_param.filter_meta.padding[0];
     const int pw = kern_param.filter_meta.padding[1];
     int ih2 = 0;
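In the hunk above, stride_h now comes from the new stride template parameter of do_conv_kern rather than from a runtime read of kern_param.filter_meta.stride[0], so each instantiation sees its stride as a compile-time constant. A tiny, purely hypothetical illustration of why that matters (a toy loop, not the MegEngine kernel):

#include <cstddef>
#include <vector>

// Hypothetical helper: with `stride` as a template parameter, the index
// arithmetic below is done against a compile-time constant, so each
// instantiation can be unrolled and strength-reduced independently.
template <int stride>
float strided_row_sum(const std::vector<float>& row) {
    float acc = 0.f;
    for (std::size_t i = 0; i * stride < row.size(); ++i) {
        acc += row[i * stride];
    }
    return acc;
}

int main() {
    std::vector<float> row(16, 1.f);
    float s1 = strided_row_sum<1>(row);  // visits all 16 elements
    float s2 = strided_row_sum<2>(row);  // visits every other element (8)
    return (s1 == 16.f && s2 == 8.f) ? 0 : 1;
}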
@@ -181,21 +180,15 @@ static void do_conv_kern(WorkspaceBundle bundle,
     const float* bptr =
             kern_param.bias<dt_float32>(batch_id, group_id) + oc_idx;
     Op op;
-#define KERN1_NCHW44_CONV(filter)                                             \
-    conv_bias::conv_direct_stride2_##filter##x##filter##_fp32_nchw_nchw44<    \
-            bias_mode, Op>(sptr, packed_weight, bptr, nullptr, dst, oc_block, \
-                           ic, ih_real, iw2, oh, oh_block_real, ow, op, ph,   \
-                           pw)
-    DISPATCH_FILTER(filter, KERN1_NCHW44_CONV);
-#undef KERN1_NCHW44_CONV
+    conv_direct_fp32_nchw_nchw44<bias_mode, Op, filter_size, stride>(
+            sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih_real,
+            iw2, oh, oh_block_real, ow, op, ph, pw);
 }
 
 }  // namespace
 
-/* ===================== stride2 algo ===================== */
-bool ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::usable(
+bool ConvBiasImpl::AlgoF32DirectNCHWNCHW44::usable(
         fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
         AlgoSelectionStrategy) const {
     auto&& fm = param.filter_meta;
@@ -209,19 +202,20 @@ bool ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::usable(
     bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
                      (fh == 2 || fh == 3 || fh == 5 || fh == 7);
     bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
-                    fm.stride[0] == 2 && fm.stride[1] == 2;
+                    fm.stride[0] == fm.stride[1] &&
+                    (fm.stride[0] == 1 || fm.stride[0] == 2);
     bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS;
     bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv;
     return avaible;
 }
 
-size_t ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::get_workspace(
+size_t ConvBiasImpl::AlgoF32DirectNCHWNCHW44::get_workspace(
         fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
     return get_bundle(param).total_size_in_bytes();
 }
 
 SmallVector<ConvBiasImpl::NCBKern>
-ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::dispatch_kerns(
+ConvBiasImpl::AlgoF32DirectNCHWNCHW44::dispatch_kerns(
         fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
     auto fm = param.filter_meta;
     const int batch = param.n;
@@ -230,61 +224,73 @@ ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::dispatch_kerns(
     conv_fun do_conv_fun = nullptr;
     // NOTE: remain_w is not used to gen hash of midout for compatible with
     // shape runtime
-#define DO_CONV_KERN_FUN(filter, bias_mode, op)                               \
-    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw_nchw44_stride2,        \
-                 midout_iv(#filter #bias_mode #op##_hash)) {                  \
-        do_conv_fun = do_conv_kern<filter, bias_mode, op>;                    \
-    }                                                                         \
+#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op)                       \
+    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw_nchw44,                \
+                 midout_iv(#stride #filter #bias_mode #op##_hash)) {          \
+        do_conv_fun = do_conv_kern<filter, bias_mode, op, stride>;            \
+    }                                                                         \
     MIDOUT_END();
 
-#define GET_OP_PARAM(filter, bias_mode)                                       \
+#define GET_OP_PARAM(stride, filter, bias_mode)                               \
     switch (param.nonlineMode) {                                              \
         case param::ConvBias::NonlineMode::IDENTITY:                          \
-            DO_CONV_KERN_FUN(filter, bias_mode, NoneOp<dt_float32>)           \
+            DO_CONV_KERN_FUN(stride, filter, bias_mode, NoneOp<dt_float32>)   \
             break;                                                            \
         case param::ConvBias::NonlineMode::RELU:                              \
-            DO_CONV_KERN_FUN(filter, bias_mode, ReluOp<dt_float32>)           \
+            DO_CONV_KERN_FUN(stride, filter, bias_mode, ReluOp<dt_float32>)   \
            break;                                                             \
         case param::ConvBias::NonlineMode::H_SWISH:                           \
-            DO_CONV_KERN_FUN(filter, bias_mode, HSwishOp<dt_float32>)         \
+            DO_CONV_KERN_FUN(stride, filter, bias_mode, HSwishOp<dt_float32>) \
             break;                                                            \
         default:                                                              \
            megdnn_assert(0);                                                  \
            break;                                                             \
     }
 
-#define GET_BIAS_MODE_PARAM(filter)                                           \
+#define GET_BIAS_MODE_PARAM(stride, filter)                                   \
     switch (param.bias_mode) {                                                \
         case BiasMode::NO_BIAS:                                               \
-            GET_OP_PARAM(filter, BiasMode::NO_BIAS)                           \
+            GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS)                   \
             break;                                                            \
         case BiasMode::BROADCAST_CHANNEL_BIAS:                                \
-            GET_OP_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS)            \
+            GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)    \
             break;                                                            \
         default:                                                              \
             megdnn_assert(0);                                                 \
             break;                                                            \
     }
 
-#define DISPATCH_CONV_KERN()                                                  \
+#define DISPATCH_CONV_KERN(stride)                                            \
     switch (param.filter_meta.spatial[0]) {                                   \
         case 2:                                                               \
-            GET_BIAS_MODE_PARAM(2)                                            \
+            GET_BIAS_MODE_PARAM(stride, 2)                                    \
             break;                                                            \
         case 3:                                                               \
-            GET_BIAS_MODE_PARAM(3)                                            \
+            GET_BIAS_MODE_PARAM(stride, 3)                                    \
             break;                                                            \
         case 5:                                                               \
-            GET_BIAS_MODE_PARAM(5)                                            \
+            GET_BIAS_MODE_PARAM(stride, 5)                                    \
             break;                                                            \
         case 7:                                                               \
-            GET_BIAS_MODE_PARAM(7)                                            \
+            GET_BIAS_MODE_PARAM(stride, 7)                                    \
             break;                                                            \
         default:                                                              \
            megdnn_assert(0);                                                  \
            break;                                                             \
     }
 
-    DISPATCH_CONV_KERN();
+    switch (param.filter_meta.stride[0]) {
+        case 1:
+            DISPATCH_CONV_KERN(1);
+            break;
+        case 2:
+            DISPATCH_CONV_KERN(2);
+            break;
+        default:
+            megdnn_throw(
+                    ssprintf("Unsupport stride size %u for the first conv",
+                             param.filter_meta.stride[0])
+                            .c_str());
+            break;
+    }
 
 #undef DO_CONV_KERN_FUN
 #undef GET_REMAIN_W_PARAM
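The macro cascade above maps the runtime tuple (stride, filter size, bias mode, nonlinearity) onto one do_conv_kern instantiation and stores it in do_conv_fun for the NCB kernels to call later; for instance, stride 1 with a 3x3 filter, BiasMode::NO_BIAS and ReluOp resolves to do_conv_kern<3, BiasMode::NO_BIAS, ReluOp<dt_float32>, 1>. Below is a condensed, standalone sketch of that select-then-store pattern with simplified stand-in enums; it is illustrative only and covers just two combinations, not MegEngine's full dispatch:

#include <cstdio>
#include <functional>
#include <stdexcept>

enum class Bias { None, BroadcastChannel };
enum class Nonline { Identity, Relu, HSwish };

// Stand-in for do_conv_kern<filter, bias_mode, Op, stride>: every supported
// combination becomes its own instantiation.
template <int filter, Bias bias, Nonline op, int stride>
void do_conv_kern_sketch() {
    std::printf("selected kernel: filter=%d stride=%d\n", filter, stride);
}

using ConvFun = std::function<void()>;

// Stand-in for the DISPATCH_CONV_KERN / GET_BIAS_MODE_PARAM / GET_OP_PARAM
// cascade: pick one instantiation and remember it as a callable.
ConvFun pick_kernel(int stride, int filter, Bias bias, Nonline op) {
    if (stride == 1 && filter == 3 && bias == Bias::None && op == Nonline::Relu)
        return do_conv_kern_sketch<3, Bias::None, Nonline::Relu, 1>;
    if (stride == 2 && filter == 3 && bias == Bias::None && op == Nonline::Relu)
        return do_conv_kern_sketch<3, Bias::None, Nonline::Relu, 2>;
    // The real macros enumerate every (stride, filter, bias, op) tuple and
    // assert on anything unsupported.
    throw std::runtime_error("combination not covered in this sketch");
}

int main() {
    ConvFun do_conv_fun = pick_kernel(1, 3, Bias::None, Nonline::Relu);
    do_conv_fun();  // later invoked by the dispatched per-thread kernels
    return 0;
}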
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp → dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h

(This diff is collapsed in the captured page; its contents are not shown.)
dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h (deleted, 100644 → 0)

/**
 * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"

namespace megdnn {
namespace arm_common {
namespace conv_bias {

#define KERN(stride, i, layout)                                                \
    template <BiasMode bias_mode, typename Op>                                 \
    void conv_direct_##stride##_##i##x##i##_fp32_nchw_##layout(                \
            const float* src, const float* filter, const float* bias,          \
            float* temp, float* dst, const int oc, const int ic, const int ih, \
            const int iw, const int oh, const int oh_block, const int ow,      \
            const Op& op, const int ph, const int pw);

KERN(stride2, 2, nchw44)
KERN(stride2, 3, nchw44)
KERN(stride2, 5, nchw44)
KERN(stride2, 7, nchw44)
#undef KERN

void pack_weight_fp32_nchw_nchw44(const float_t* in_ptr, float_t* dst_ptr,
                                  const int oc, const int kh, const int kw,
                                  const int ic);

}  // namespace conv_bias
}  // namespace arm_common
}  // namespace megdnn
\ No newline at end of file
dnn/src/arm_common/conv_bias/opr_impl.cpp

@@ -66,7 +66,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
     AlgoDotU8DirectStride2 du8_direct_stride2_small_group{false};
 #endif
-    AlgoF32DirectStride2NCHWNCHW44 f32_direct_stride2_nchw_nchw44;
+    AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44;
     AlgoF32ChannelWiseNCHW44 f32_chanel_wise_nchw44;
     AlgoF32DirectNCHW44 f32_direct_nchw44;
dnn/src/arm_common/conv_bias/opr_impl.h

@@ -71,7 +71,7 @@ private:
     class AlgoF32Direct;
     class AlgoF32DirectStride1;
     class AlgoF32DirectStride2;
-    class AlgoF32DirectStride2NCHWNCHW44;
+    class AlgoF32DirectNCHWNCHW44;
     class AlgoF32ChannelWiseNCHW44;
     class AlgoF32DirectNCHW44;
dnn/test/arm_common/conv_bias.cpp

@@ -204,6 +204,11 @@ static void benchmark_convbias(Handle* handle, std::string int_name,
     run(1, 3, 32, 224, 224, 3, 2, true);
     run(1, 3, 64, 224, 224, 7, 2, true);
+    run(1, 1, 4, 112, 112, 2, 1, true);
+    run(1, 3, 32, 224, 224, 3, 1, true);
+    run(1, 3, 64, 224, 224, 3, 1, true);
+    run(1, 3, 64, 224, 224, 7, 1, true);
+
     run(1, 64, 128, 56, 56, 3, 2, false);
     run(1, 128, 256, 28, 28, 3, 2, false);
     run(1, 256, 512, 14, 14, 3, 2, false);
dnn/test/arm_common/conv_bias_multi_thread.cpp

@@ -392,6 +392,9 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32) {
     check_conv_bias(
             get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false, true),
             handle(), "F32_CONV_NCHW_NCHW44");
+    check_conv_bias(
+            get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false, true),
+            handle(), "F32_CONV_NCHW_NCHW44");
 }
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_1) {
     check_conv_bias(

@@ -824,13 +827,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_PREPROCESS_NCHW44) {
         auto conv_bias_opr = handle->create_operator<ConvBias>();
         conv_bias_opr->param() = param;
-        conv_bias_opr->param().format = param::ConvBias::Format::NCHW44_WINOGRAD;
+        conv_bias_opr->param().format =
+                param::ConvBias::Format::NCHW44_WINOGRAD;
         conv_bias_opr->param().output_block_size = m;
         size_t conv_bias_workspace_in_bytes =
                 conv_bias_opr->get_workspace_in_bytes(
                         tensors[0].layout, filter_transform_layout,
-                        tensors[2].layout, tensors[3].layout, tensors[4].layout,
-                        nullptr);
+                        tensors[2].layout, tensors[3].layout,
+                        tensors[4].layout, nullptr);
 
         WorkspaceBundle wb(nullptr, {filter_transform_layout.span().dist_byte(),
                                      conv_bias_workspace_in_bytes,