Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
df8931b6
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
df8931b6
编写于
5月 06, 2020
作者:
M
Megvii Engine Team
提交者:
Xinran Xu
5月 12, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/arm): add padding support for nchw44 arm pooling and opt code
GitOrigin-RevId: f125004e1f271656f2c2646913aea6afdd112e15
上级
07dd6b6c
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
552 addition
and
682 deletion
+552
-682
dnn/src/arm_common/pooling/algo.cpp
dnn/src/arm_common/pooling/algo.cpp
+138
-133
dnn/src/arm_common/pooling/algo.h
dnn/src/arm_common/pooling/algo.h
+8
-10
dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.cpp
...src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.cpp
+0
-91
dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h
dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h
+0
-25
dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp
dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp
+20
-21
dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.h
dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.h
+4
-2
dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.cpp
dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.cpp
+195
-0
dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.h
dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.h
+9
-2
dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.cpp
dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.cpp
+36
-28
dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.h
dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.h
+1
-1
dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp
dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp
+61
-45
dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.h
dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.h
+1
-1
dnn/src/arm_common/pooling/opr_impl.cpp
dnn/src/arm_common/pooling/opr_impl.cpp
+8
-4
dnn/src/arm_common/pooling/opr_impl.h
dnn/src/arm_common/pooling/opr_impl.h
+1
-2
dnn/test/arm_common/pooling.cpp
dnn/test/arm_common/pooling.cpp
+0
-205
dnn/test/arm_common/pooling_multi_thread.cpp
dnn/test/arm_common/pooling_multi_thread.cpp
+70
-112
未找到文件。
dnn/src/arm_common/pooling/algo.cpp
浏览文件 @
df8931b6
...
@@ -11,14 +11,13 @@
...
@@ -11,14 +11,13 @@
*/
*/
#include "src/arm_common/pooling/algo.h"
#include "src/arm_common/pooling/algo.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/opr_param_defs.h"
#include "src/arm_common/pooling/do_max_pooling_2x2_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_4x4_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_5x5_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h"
#include "src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h"
#include "src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h"
#include "src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h"
#include "src/arm_common/pooling/do_pooling_2x2_nchw44.h"
#include "src/arm_common/pooling/do_pooling_3x3_nchw44.h"
#include "src/arm_common/pooling/do_pooling_4x4_nchw44.h"
#include "src/arm_common/pooling/do_pooling_5x5_nchw44.h"
#include "midout.h"
#include "midout.h"
...
@@ -57,6 +56,41 @@ WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param) {
...
@@ -57,6 +56,41 @@ WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param) {
return
ws
;
return
ws
;
}
}
WorkspaceBundle
get_bundle_nchw44
(
const
PoolingImpl
::
PoolingKernSizeParam
&
param
)
{
megdnn_assert
((
param
.
src_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
)
&&
(
param
.
format
==
param
::
Pooling
::
Format
::
NCHW44
));
auto
IH
=
param
.
isz
[
0
];
auto
IW
=
param
.
isz
[
1
];
auto
PH
=
param
.
padding
[
0
];
auto
PW
=
param
.
padding
[
1
];
size_t
padding_size
=
0
;
if
((
PH
!=
0
)
||
(
PW
!=
0
))
{
padding_size
=
(
IW
+
2
*
PW
)
*
(
IH
+
2
*
PH
)
*
4
*
sizeof
(
int8_t
);
}
return
WorkspaceBundle
(
nullptr
,
{
padding_size
});
}
const
int8_t
*
handle_padding
(
const
int8_t
*
src
,
size_t
IH
,
size_t
IW
,
size_t
&
IH2
,
size_t
&
IW2
,
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
)
{
int8_t
*
sptr_base
=
nullptr
;
bool
need_pad
=
((
PH
!=
0
)
||
(
PW
!=
0
))
?
true
:
false
;
if
(
need_pad
)
{
IH2
=
IH
+
2
*
PH
;
IW2
=
IW
+
2
*
PW
;
sptr_base
=
static_cast
<
int8_t
*>
(
ws
.
get
(
0
));
memset
(
sptr_base
,
-
128
,
sizeof
(
int8_t
)
*
IH2
*
IW2
*
4
);
rep
(
ih
,
IH
)
{
std
::
memcpy
(
sptr_base
+
(
ih
+
PH
)
*
IW2
*
4
+
PW
*
4
,
src
+
ih
*
IW
*
4
,
sizeof
(
int8_t
)
*
IW
*
4
);
}
}
else
{
IH2
=
IH
;
IW2
=
IW
;
}
return
need_pad
?
sptr_base
:
src
;
}
bool
PoolingImpl
::
AlgoFilterxModexStride1
::
usable
(
bool
PoolingImpl
::
AlgoFilterxModexStride1
::
usable
(
const
PoolingKernSizeParam
&
param
)
const
{
const
PoolingKernSizeParam
&
param
)
const
{
auto
SH
=
param
.
stride
[
0
];
auto
SH
=
param
.
stride
[
0
];
...
@@ -563,47 +597,50 @@ void PoolingImpl::AlgoInt8Filter3MaxStride2::exec(
...
@@ -563,47 +597,50 @@ void PoolingImpl::AlgoInt8Filter3MaxStride2::exec(
MIDOUT_END
();
MIDOUT_END
();
}
}
bool
PoolingImpl
::
AlgoFilter3MaxStride
2
NCHW44
::
usable
(
bool
PoolingImpl
::
AlgoFilter3MaxStride
x
NCHW44
::
usable
(
const
PoolingKernSizeParam
&
param
)
const
{
const
PoolingKernSizeParam
&
param
)
const
{
auto
SH
=
param
.
stride
[
0
];
auto
SH
=
param
.
stride
[
0
];
auto
SW
=
param
.
stride
[
1
];
auto
SW
=
param
.
stride
[
1
];
auto
FH
=
param
.
filter
[
0
];
auto
FH
=
param
.
filter
[
0
];
auto
FW
=
param
.
filter
[
1
];
auto
FW
=
param
.
filter
[
1
];
auto
PH
=
param
.
padding
[
0
];
auto
PW
=
param
.
padding
[
1
];
bool
avaible
=
param
.
src_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
bool
avaible
=
param
.
src_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
param
.
format
==
Param
::
Format
::
NCHW44
&&
param
.
format
==
Param
::
Format
::
NCHW44
&&
param
.
mode
==
Mode
::
MAX
&&
FH
==
3
&&
FW
==
3
&&
S
H
==
2
&&
param
.
mode
==
Mode
::
MAX
&&
FH
==
3
&&
FW
==
3
&&
S
W
==
SH
&&
SW
==
2
&&
PH
==
0
&&
PW
==
0
;
(
SH
==
1
||
SW
==
2
)
;
return
avaible
;
return
avaible
;
}
}
void
PoolingImpl
::
AlgoFilter3MaxStride
2
NCHW44
::
exec
(
void
PoolingImpl
::
AlgoFilter3MaxStride
x
NCHW44
::
exec
(
const
PoolingKernParam
&
param
)
const
{
const
PoolingKernParam
&
param
)
const
{
auto
IH
=
param
.
isz
[
0
],
IW
=
param
.
isz
[
1
];
auto
IH
=
param
.
isz
[
0
],
IW
=
param
.
isz
[
1
];
auto
OH
=
param
.
osz
[
0
],
OW
=
param
.
osz
[
1
];
auto
OH
=
param
.
osz
[
0
],
OW
=
param
.
osz
[
1
];
auto
N
=
param
.
n
,
C
=
param
.
ic
;
auto
N
=
param
.
n
,
C
=
param
.
ic
;
auto
PH
=
param
.
padding
[
0
];
auto
PH
=
param
.
padding
[
0
];
auto
PW
=
param
.
padding
[
1
];
auto
PW
=
param
.
padding
[
1
];
auto
SW
=
param
.
stride
[
0
];
void
*
src_ptr
=
param
.
src_ptr
;
void
*
src_ptr
=
param
.
src_ptr
;
void
*
dst_ptr
=
param
.
dst_ptr
;
void
*
dst_ptr
=
param
.
dst_ptr
;
#define DISPATCH_FUNC(type, func,
midout_type_id)
\
#define DISPATCH_FUNC(type, func,
i)
\
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \
midout_iv(midout_type_id)) { \
midout_iv(#type #i##_hash)) { \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \
WorkspaceBundle wbundle = get_bundle_nchw44(param); \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
wbundle = wbundle, \
workspace_ptr = param.workspace<dt_byte>()]( \
size_t index, size_t thread_id) { \
size_t index, size_t thread_id) { \
MEGDNN_MARK_USED_VAR(thread_id); \
auto ws = wbundle; \
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
size_t n = index / C; \
size_t n = index / C; \
size_t c = index % C; \
size_t c = index % C; \
do_max_pooling_3x3_s
2x2_##func##_nchw44_NEON(
\
do_max_pooling_3x3_s
tride##i##_##func##_nchw44_NEON(
\
static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
c * IH * IW * 4, \
c * IH * IW * 4, \
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
c * OH * OW * 4, \
c * OH * OW * 4, \
IH, IW, OH, OW, PH, PW
);
\
IH, IW, OH, OW, PH, PW
, ws);
\
}; \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
...
@@ -611,61 +648,23 @@ void PoolingImpl::AlgoFilter3MaxStride2NCHW44::exec(
...
@@ -611,61 +648,23 @@ void PoolingImpl::AlgoFilter3MaxStride2NCHW44::exec(
} \
} \
MIDOUT_END();
MIDOUT_END();
DISPATCH_FUNC
(
int8_t
,
int8
,
9
);
#define DISPATCH_STRIDE(type, func) \
switch (SW) { \
#undef DISPATCH_FUNC
case 1: { \
}
DISPATCH_FUNC(type, func, 1); \
break; \
bool
PoolingImpl
::
AlgoFilter3MaxStride1NCHW44
::
usable
(
} \
const
PoolingKernSizeParam
&
param
)
const
{
case 2: { \
auto
SH
=
param
.
stride
[
0
];
DISPATCH_FUNC(type, func, 2); \
auto
SW
=
param
.
stride
[
1
];
break; \
auto
FH
=
param
.
filter
[
0
];
} \
auto
FW
=
param
.
filter
[
1
];
default: \
auto
PH
=
param
.
padding
[
0
];
megdnn_assert(0, "unsupport stride size"); \
auto
PW
=
param
.
padding
[
1
];
}
bool
avaible
=
param
.
src_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
param
.
format
==
Param
::
Format
::
NCHW44
&&
param
.
mode
==
Mode
::
MAX
&&
FH
==
3
&&
FW
==
3
&&
SH
==
1
&&
SW
==
1
&&
PH
==
0
&&
PW
==
0
;
return
avaible
;
}
void
PoolingImpl
::
AlgoFilter3MaxStride1NCHW44
::
exec
(
const
PoolingKernParam
&
param
)
const
{
auto
IH
=
param
.
isz
[
0
],
IW
=
param
.
isz
[
1
];
auto
OH
=
param
.
osz
[
0
],
OW
=
param
.
osz
[
1
];
auto
N
=
param
.
n
,
C
=
param
.
ic
;
auto
PH
=
param
.
padding
[
0
];
auto
PW
=
param
.
padding
[
1
];
void
*
src_ptr
=
param
.
src_ptr
;
void
*
dst_ptr
=
param
.
dst_ptr
;
#define DISPATCH_FUNC(type, func, midout_type_id) \
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \
midout_iv(midout_type_id)) { \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \
size_t index, size_t thread_id) { \
MEGDNN_MARK_USED_VAR(thread_id); \
size_t n = index / C; \
size_t c = index % C; \
do_max_pooling_3x3_s1x1_##func##_nchw44_NEON( \
static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
c * IH * IW * 4, \
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
c * OH * OW * 4, \
IH, IW, OH, OW, PH, PW); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
run); \
} \
MIDOUT_END();
DISPATCH_
FUNC
(
int8_t
,
int8
,
10
);
DISPATCH_
STRIDE
(
int8_t
,
int8
);
#undef DISPATCH_STRIDE
#undef DISPATCH_FUNC
#undef DISPATCH_FUNC
}
}
...
@@ -675,13 +674,11 @@ bool PoolingImpl::AlgoFilter2MaxStridexNCHW44::usable(
...
@@ -675,13 +674,11 @@ bool PoolingImpl::AlgoFilter2MaxStridexNCHW44::usable(
auto
SW
=
param
.
stride
[
1
];
auto
SW
=
param
.
stride
[
1
];
auto
FH
=
param
.
filter
[
0
];
auto
FH
=
param
.
filter
[
0
];
auto
FW
=
param
.
filter
[
1
];
auto
FW
=
param
.
filter
[
1
];
auto
PH
=
param
.
padding
[
0
];
auto
PW
=
param
.
padding
[
1
];
bool
avaible
=
param
.
src_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
bool
avaible
=
param
.
src_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
param
.
format
==
Param
::
Format
::
NCHW44
&&
param
.
format
==
Param
::
Format
::
NCHW44
&&
param
.
mode
==
Mode
::
MAX
&&
FH
==
2
&&
FW
==
2
&&
SH
==
SW
&&
param
.
mode
==
Mode
::
MAX
&&
FH
==
2
&&
FW
==
2
&&
SH
==
SW
&&
(
SW
==
1
||
SW
==
2
)
&&
PH
==
0
&&
PW
==
0
;
(
SW
==
1
||
SW
==
2
);
return
avaible
;
return
avaible
;
}
}
...
@@ -697,12 +694,16 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec(
...
@@ -697,12 +694,16 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec(
void
*
src_ptr
=
param
.
src_ptr
;
void
*
src_ptr
=
param
.
src_ptr
;
void
*
dst_ptr
=
param
.
dst_ptr
;
void
*
dst_ptr
=
param
.
dst_ptr
;
#define DISPATCH_FUNC(type, func,
midout_type_id, i)
\
#define DISPATCH_FUNC(type, func,
i)
\
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \
midout_iv(midout_type_id)) { \
midout_iv(#func #i##_hash)) { \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \
WorkspaceBundle wbundle = get_bundle_nchw44(param); \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
wbundle = wbundle, \
workspace_ptr = param.workspace<dt_byte>()]( \
size_t index, size_t thread_id) { \
size_t index, size_t thread_id) { \
MEGDNN_MARK_USED_VAR(thread_id); \
auto ws = wbundle; \
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
size_t n = index / C; \
size_t n = index / C; \
size_t c = index % C; \
size_t c = index % C; \
do_max_pooling_2x2_stride##i##_##func##_nchw44_NEON( \
do_max_pooling_2x2_stride##i##_##func##_nchw44_NEON( \
...
@@ -710,7 +711,7 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec(
...
@@ -710,7 +711,7 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec(
c * IH * IW * 4, \
c * IH * IW * 4, \
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
c * OH * OW * 4, \
c * OH * OW * 4, \
IH, IW, OH, OW, PH, PW
);
\
IH, IW, OH, OW, PH, PW
, ws);
\
}; \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
...
@@ -718,21 +719,21 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec(
...
@@ -718,21 +719,21 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec(
} \
} \
MIDOUT_END();
MIDOUT_END();
#define DISPATCH_STRIDE(type, func
, midout_type_id)
\
#define DISPATCH_STRIDE(type, func
)
\
switch (SW) {
\
switch (SW) { \
case 1: {
\
case 1: { \
DISPATCH_FUNC(type, func,
midout_type_id, 1);
\
DISPATCH_FUNC(type, func,
1);
\
break;
\
break; \
}
\
} \
case 2: {
\
case 2: { \
DISPATCH_FUNC(type, func,
midout_type_id, 2);
\
DISPATCH_FUNC(type, func,
2);
\
break;
\
break; \
}
\
} \
default:
\
default: \
megdnn_assert(0, "unsupport stride size");
\
megdnn_assert(0, "unsupport stride size"); \
}
}
DISPATCH_STRIDE
(
int8_t
,
int8
,
10
);
DISPATCH_STRIDE
(
int8_t
,
int8
);
#undef DISPATCH_STRIDE
#undef DISPATCH_STRIDE
#undef DISPATCH_FUNC
#undef DISPATCH_FUNC
...
@@ -744,13 +745,11 @@ bool PoolingImpl::AlgoFilter4MaxStridexNCHW44::usable(
...
@@ -744,13 +745,11 @@ bool PoolingImpl::AlgoFilter4MaxStridexNCHW44::usable(
auto
SW
=
param
.
stride
[
1
];
auto
SW
=
param
.
stride
[
1
];
auto
FH
=
param
.
filter
[
0
];
auto
FH
=
param
.
filter
[
0
];
auto
FW
=
param
.
filter
[
1
];
auto
FW
=
param
.
filter
[
1
];
auto
PH
=
param
.
padding
[
0
];
auto
PW
=
param
.
padding
[
1
];
bool
avaible
=
param
.
src_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
bool
avaible
=
param
.
src_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
param
.
format
==
Param
::
Format
::
NCHW44
&&
param
.
format
==
Param
::
Format
::
NCHW44
&&
param
.
mode
==
Mode
::
MAX
&&
FH
==
4
&&
FW
==
4
&&
SH
==
SW
&&
param
.
mode
==
Mode
::
MAX
&&
FH
==
4
&&
FW
==
4
&&
SH
==
SW
&&
(
SW
==
1
||
SW
==
2
)
&&
PH
==
0
&&
PW
==
0
;
(
SW
==
1
||
SW
==
2
);
return
avaible
;
return
avaible
;
}
}
...
@@ -766,12 +765,16 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec(
...
@@ -766,12 +765,16 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec(
void
*
src_ptr
=
param
.
src_ptr
;
void
*
src_ptr
=
param
.
src_ptr
;
void
*
dst_ptr
=
param
.
dst_ptr
;
void
*
dst_ptr
=
param
.
dst_ptr
;
#define DISPATCH_FUNC(type, func,
midout_type_id, i)
\
#define DISPATCH_FUNC(type, func,
i)
\
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \
midout_iv(midout_type_id)) { \
midout_iv(#func #i##_hash)) { \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \
WorkspaceBundle wbundle = get_bundle_nchw44(param); \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
wbundle = wbundle, \
workspace_ptr = param.workspace<dt_byte>()]( \
size_t index, size_t thread_id) { \
size_t index, size_t thread_id) { \
MEGDNN_MARK_USED_VAR(thread_id); \
auto ws = wbundle; \
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
size_t n = index / C; \
size_t n = index / C; \
size_t c = index % C; \
size_t c = index % C; \
do_max_pooling_4x4_stride##i##_##func##_nchw44_NEON( \
do_max_pooling_4x4_stride##i##_##func##_nchw44_NEON( \
...
@@ -779,7 +782,7 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec(
...
@@ -779,7 +782,7 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec(
c * IH * IW * 4, \
c * IH * IW * 4, \
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
c * OH * OW * 4, \
c * OH * OW * 4, \
IH, IW, OH, OW, PH, PW
);
\
IH, IW, OH, OW, PH, PW
, ws);
\
}; \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
...
@@ -787,21 +790,21 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec(
...
@@ -787,21 +790,21 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec(
} \
} \
MIDOUT_END();
MIDOUT_END();
#define DISPATCH_STRIDE(type, func
, midout_type_id)
\
#define DISPATCH_STRIDE(type, func
)
\
switch (SW) {
\
switch (SW) { \
case 1: {
\
case 1: { \
DISPATCH_FUNC(type, func,
midout_type_id, 1);
\
DISPATCH_FUNC(type, func,
1);
\
break;
\
break; \
}
\
} \
case 2: {
\
case 2: { \
DISPATCH_FUNC(type, func,
midout_type_id, 2);
\
DISPATCH_FUNC(type, func,
2);
\
break;
\
break; \
}
\
} \
default:
\
default: \
megdnn_assert(0, "unsupport stride size");
\
megdnn_assert(0, "unsupport stride size"); \
}
}
DISPATCH_STRIDE
(
int8_t
,
int8
,
11
);
DISPATCH_STRIDE
(
int8_t
,
int8
);
#undef DISPATCH_STRIDE
#undef DISPATCH_STRIDE
#undef DISPATCH_FUNC
#undef DISPATCH_FUNC
...
@@ -813,13 +816,11 @@ bool PoolingImpl::AlgoFilter5MaxStridexNCHW44::usable(
...
@@ -813,13 +816,11 @@ bool PoolingImpl::AlgoFilter5MaxStridexNCHW44::usable(
auto
SW
=
param
.
stride
[
1
];
auto
SW
=
param
.
stride
[
1
];
auto
FH
=
param
.
filter
[
0
];
auto
FH
=
param
.
filter
[
0
];
auto
FW
=
param
.
filter
[
1
];
auto
FW
=
param
.
filter
[
1
];
auto
PH
=
param
.
padding
[
0
];
auto
PW
=
param
.
padding
[
1
];
bool
avaible
=
param
.
src_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
bool
avaible
=
param
.
src_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
param
.
format
==
Param
::
Format
::
NCHW44
&&
param
.
format
==
Param
::
Format
::
NCHW44
&&
param
.
mode
==
Mode
::
MAX
&&
FH
==
5
&&
FW
==
5
&&
SH
==
SW
&&
param
.
mode
==
Mode
::
MAX
&&
FH
==
5
&&
FW
==
5
&&
SH
==
SW
&&
(
SW
==
1
||
SW
==
2
)
&&
PH
==
0
&&
PW
==
0
;
(
SW
==
1
||
SW
==
2
);
return
avaible
;
return
avaible
;
}
}
...
@@ -835,12 +836,16 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec(
...
@@ -835,12 +836,16 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec(
void
*
src_ptr
=
param
.
src_ptr
;
void
*
src_ptr
=
param
.
src_ptr
;
void
*
dst_ptr
=
param
.
dst_ptr
;
void
*
dst_ptr
=
param
.
dst_ptr
;
#define DISPATCH_FUNC(type, func,
midout_type_id, i)
\
#define DISPATCH_FUNC(type, func,
i)
\
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \
midout_iv(midout_type_id)) { \
midout_iv(#func #i##_hash)) { \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \
WorkspaceBundle wbundle = get_bundle_nchw44(param); \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
wbundle = wbundle, \
workspace_ptr = param.workspace<dt_byte>()]( \
size_t index, size_t thread_id) { \
size_t index, size_t thread_id) { \
MEGDNN_MARK_USED_VAR(thread_id); \
auto ws = wbundle; \
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
size_t n = index / C; \
size_t n = index / C; \
size_t c = index % C; \
size_t c = index % C; \
do_max_pooling_5x5_stride##i##_##func##_nchw44_NEON( \
do_max_pooling_5x5_stride##i##_##func##_nchw44_NEON( \
...
@@ -848,7 +853,7 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec(
...
@@ -848,7 +853,7 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec(
c * IH * IW * 4, \
c * IH * IW * 4, \
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
c * OH * OW * 4, \
c * OH * OW * 4, \
IH, IW, OH, OW, PH, PW
);
\
IH, IW, OH, OW, PH, PW
, ws);
\
}; \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
...
@@ -856,21 +861,21 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec(
...
@@ -856,21 +861,21 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec(
} \
} \
MIDOUT_END();
MIDOUT_END();
#define DISPATCH_STRIDE(type, func
, midout_type_id)
\
#define DISPATCH_STRIDE(type, func
)
\
switch (SW) {
\
switch (SW) { \
case 1: {
\
case 1: { \
DISPATCH_FUNC(type, func,
midout_type_id, 1);
\
DISPATCH_FUNC(type, func,
1);
\
break;
\
break; \
}
\
} \
case 2: {
\
case 2: { \
DISPATCH_FUNC(type, func,
midout_type_id, 2);
\
DISPATCH_FUNC(type, func,
2);
\
break;
\
break; \
}
\
} \
default:
\
default: \
megdnn_assert(0, "unsupport stride size");
\
megdnn_assert(0, "unsupport stride size"); \
}
}
DISPATCH_STRIDE
(
int8_t
,
int8
,
12
);
DISPATCH_STRIDE
(
int8_t
,
int8
);
#undef DISPATCH_STRIDE
#undef DISPATCH_STRIDE
#undef DISPATCH_FUNC
#undef DISPATCH_FUNC
...
...
dnn/src/arm_common/pooling/algo.h
浏览文件 @
df8931b6
...
@@ -83,18 +83,10 @@ public:
...
@@ -83,18 +83,10 @@ public:
void
exec
(
const
PoolingKernParam
&
param
)
const
override
;
void
exec
(
const
PoolingKernParam
&
param
)
const
override
;
};
};
class
PoolingImpl
::
AlgoFilter3MaxStride
2
NCHW44
final
:
public
AlgoBase
{
class
PoolingImpl
::
AlgoFilter3MaxStride
x
NCHW44
final
:
public
AlgoBase
{
public:
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"ARM_POOLING_FILTER3_MAX_STRIDE2_NCHW44"
;
}
const
char
*
name
()
const
override
{
return
"ARM_POOLING_FILTER3_MAX_STRIDEX_NCHW44"
;
}
bool
usable
(
const
PoolingKernSizeParam
&
param
)
const
override
;
void
exec
(
const
PoolingKernParam
&
param
)
const
override
;
};
class
PoolingImpl
::
AlgoFilter3MaxStride1NCHW44
final
:
public
AlgoBase
{
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"ARM_POOLING_FILTER3_MAX_STRIDE1_NCHW44"
;
}
bool
usable
(
const
PoolingKernSizeParam
&
param
)
const
override
;
bool
usable
(
const
PoolingKernSizeParam
&
param
)
const
override
;
void
exec
(
const
PoolingKernParam
&
param
)
const
override
;
void
exec
(
const
PoolingKernParam
&
param
)
const
override
;
};
};
...
@@ -125,6 +117,12 @@ public:
...
@@ -125,6 +117,12 @@ public:
WorkspaceBundle
get_bundle
(
const
PoolingImpl
::
PoolingKernSizeParam
&
param
);
WorkspaceBundle
get_bundle
(
const
PoolingImpl
::
PoolingKernSizeParam
&
param
);
WorkspaceBundle
get_bundle_nchw44
(
const
PoolingImpl
::
PoolingKernSizeParam
&
param
);
const
int8_t
*
handle_padding
(
const
int8_t
*
src
,
size_t
IH
,
size_t
IW
,
size_t
&
IH2
,
size_t
&
IW2
,
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
);
}
// namespace arm_common
}
// namespace arm_common
}
// namespace megdnn
}
// namespace megdnn
...
...
dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.cpp
已删除
100644 → 0
浏览文件 @
07dd6b6c
/**
* \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
namespace
megdnn
{
namespace
arm_common
{
void
do_max_pooling_3x3_s1x1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
)
{
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
;
const
int8_t
*
__restrict
sptr0
=
src
+
(
ih
+
0
)
*
IW
*
4
;
const
int8_t
*
__restrict
sptr1
=
src
+
(
ih
+
1
)
*
IW
*
4
;
const
int8_t
*
__restrict
sptr2
=
src
+
(
ih
+
2
)
*
IW
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
size_t
ow
=
0
;
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
int8x16_t
src0123
=
vld1q_s8
(
sptr0
);
int8x16_t
src1234
=
vld1q_s8
(
sptr0
+
4
);
int8x16_t
src2345
=
vld1q_s8
(
sptr0
+
8
);
int8x16_t
max0
=
vmaxq_s8
(
src0123
,
src1234
);
max0
=
vmaxq_s8
(
max0
,
src2345
);
src0123
=
vld1q_s8
(
sptr1
);
src1234
=
vld1q_s8
(
sptr1
+
4
);
src2345
=
vld1q_s8
(
sptr1
+
8
);
int8x16_t
max1
=
vmaxq_s8
(
src0123
,
src1234
);
max1
=
vmaxq_s8
(
max1
,
src2345
);
src0123
=
vld1q_s8
(
sptr2
);
src1234
=
vld1q_s8
(
sptr2
+
4
);
src2345
=
vld1q_s8
(
sptr2
+
8
);
int8x16_t
max2
=
vmaxq_s8
(
src0123
,
src1234
);
max2
=
vmaxq_s8
(
max2
,
src2345
);
int8x16_t
max_out
=
vmaxq_s8
(
max0
,
max1
);
max_out
=
vmaxq_s8
(
max_out
,
max2
);
vst1q_s8
(
dptr
,
max_out
);
sptr0
+=
16
;
sptr1
+=
16
;
sptr2
+=
16
;
dptr
+=
16
;
}
for
(;
ow
<
OW
;
++
ow
)
{
int8x8_t
src001
=
vld1_s8
(
sptr0
);
int8x8_t
src012
=
vld1_s8
(
sptr0
+
4
);
int8x8_t
src101
=
vld1_s8
(
sptr1
);
int8x8_t
src112
=
vld1_s8
(
sptr1
+
4
);
int8x8_t
src201
=
vld1_s8
(
sptr2
);
int8x8_t
src212
=
vld1_s8
(
sptr2
+
4
);
int8x8_t
max01_tmp
=
vmax_s8
(
src001
,
src101
);
max01_tmp
=
vmax_s8
(
max01_tmp
,
src201
);
int8x8_t
max12_tmp
=
vmax_s8
(
src012
,
src112
);
max12_tmp
=
vmax_s8
(
max12_tmp
,
src212
);
#define cb(i) \
int8_t dst##i = std::max(std::max(max01_tmp[i], max01_tmp[i + 4]), \
max12_tmp[i + 4]);
#define store(i) *(dptr + i) = dst##i;
UNROLL_CALL_NOWRAPPER
(
4
,
cb
)
UNROLL_CALL_NOWRAPPER
(
4
,
store
)
#undef store
#undef cb
sptr0
+=
4
;
sptr1
+=
4
;
sptr2
+=
4
;
dptr
+=
4
;
}
}
}
}
// namespace arm_common
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h
已删除
100644 → 0
浏览文件 @
07dd6b6c
/**
* \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/common/utils.h"
namespace
megdnn
{
namespace
arm_common
{
void
do_max_pooling_3x3_s1x1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
);
}
// namespace arm_common
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/arm_common/pooling/do_
max_
pooling_2x2_nchw44.cpp
→
dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp
浏览文件 @
df8931b6
...
@@ -9,7 +9,8 @@
...
@@ -9,7 +9,8 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
* implied.
*/
*/
#include "src/arm_common/pooling/do_max_pooling_2x2_nchw44.h"
#include "src/arm_common/pooling/do_pooling_2x2_nchw44.h"
#include "src/arm_common/pooling/algo.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/unroll_macro.h"
...
@@ -19,12 +20,16 @@ namespace arm_common {
...
@@ -19,12 +20,16 @@ namespace arm_common {
void
do_max_pooling_2x2_stride1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
void
do_max_pooling_2x2_stride1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
)
{
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
)
{
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
);
size_t
oh
=
0
;
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
;
size_t
ih
=
oh
;
const
int8_t
*
__restrict
sptr0
=
s
rc
+
(
ih
+
0
)
*
IW
*
4
;
const
int8_t
*
__restrict
sptr0
=
s
ptr
+
(
ih
+
0
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr1
=
s
rc
+
(
ih
+
1
)
*
IW
*
4
;
const
int8_t
*
__restrict
sptr1
=
s
ptr
+
(
ih
+
1
)
*
IW2
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
size_t
ow
=
0
;
size_t
ow
=
0
;
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
...
@@ -46,15 +51,10 @@ void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
...
@@ -46,15 +51,10 @@ void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
}
}
for
(;
ow
<
OW
;
++
ow
)
{
for
(;
ow
<
OW
;
++
ow
)
{
int8x8_t
src001
=
vld1_s8
(
sptr0
);
int8x8_t
src001
=
vld1_s8
(
sptr0
);
int8x8_t
src012
=
vld1_s8
(
sptr0
+
4
);
int8x8_t
src101
=
vld1_s8
(
sptr1
);
int8x8_t
src101
=
vld1_s8
(
sptr1
);
int8x8_t
src112
=
vld1_s8
(
sptr1
+
4
);
int8x8_t
max01_tmp
=
vmax_s8
(
src001
,
src101
);
int8x8_t
max_out
=
vmax_s8
(
src001
,
src101
);
int8x8_t
max12_tmp
=
vmax_s8
(
src012
,
src112
);
#define store(i) *(dptr + i) = std::max(max_out[i], max_out[i + 4]);
int8x8_t
mat_out
=
vmax_s8
(
max01_tmp
,
max12_tmp
);
#define store(i) *(dptr + i) = mat_out[i];
UNROLL_CALL_NOWRAPPER
(
4
,
store
)
UNROLL_CALL_NOWRAPPER
(
4
,
store
)
#undef store
#undef store
sptr0
+=
4
;
sptr0
+=
4
;
...
@@ -66,12 +66,16 @@ void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
...
@@ -66,12 +66,16 @@ void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
void
do_max_pooling_2x2_stride2_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
void
do_max_pooling_2x2_stride2_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
)
{
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
)
{
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
);
size_t
oh
=
0
;
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
<<
1
;
size_t
ih
=
oh
<<
1
;
const
int8_t
*
__restrict
sptr0
=
s
rc
+
(
ih
+
0
)
*
IW
*
4
;
const
int8_t
*
__restrict
sptr0
=
s
ptr
+
(
ih
+
0
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr1
=
s
rc
+
(
ih
+
1
)
*
IW
*
4
;
const
int8_t
*
__restrict
sptr1
=
s
ptr
+
(
ih
+
1
)
*
IW2
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
size_t
ow
=
0
;
size_t
ow
=
0
;
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
...
@@ -103,15 +107,10 @@ void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
...
@@ -103,15 +107,10 @@ void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
}
}
for
(;
ow
<
OW
;
++
ow
)
{
for
(;
ow
<
OW
;
++
ow
)
{
int8x8_t
src001
=
vld1_s8
(
sptr0
);
int8x8_t
src001
=
vld1_s8
(
sptr0
);
int8x8_t
src012
=
vld1_s8
(
sptr0
+
4
);
int8x8_t
src101
=
vld1_s8
(
sptr1
);
int8x8_t
src101
=
vld1_s8
(
sptr1
);
int8x8_t
src112
=
vld1_s8
(
sptr1
+
4
);
int8x8_t
max01_tmp
=
vmax_s8
(
src001
,
src101
);
int8x8_t
max_out
=
vmax_s8
(
src001
,
src101
);
int8x8_t
max12_tmp
=
vmax_s8
(
src012
,
src112
);
#define store(i) *(dptr + i) = std::max(max_out[i], max_out[i + 4]);
int8x8_t
mat_out
=
vmax_s8
(
max01_tmp
,
max12_tmp
);
#define store(i) *(dptr + i) = mat_out[i];
UNROLL_CALL_NOWRAPPER
(
4
,
store
)
UNROLL_CALL_NOWRAPPER
(
4
,
store
)
#undef store
#undef store
sptr0
+=
8
;
sptr0
+=
8
;
...
...
dnn/src/arm_common/pooling/do_
max_
pooling_2x2_nchw44.h
→
dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.h
浏览文件 @
df8931b6
...
@@ -18,11 +18,13 @@ namespace arm_common {
...
@@ -18,11 +18,13 @@ namespace arm_common {
void
do_max_pooling_2x2_stride1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
void
do_max_pooling_2x2_stride1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
);
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
);
void
do_max_pooling_2x2_stride2_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
void
do_max_pooling_2x2_stride2_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
);
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
);
}
// namespace arm_common
}
// namespace arm_common
}
// namespace megdnn
}
// namespace megdnn
...
...
dnn/src/arm_common/pooling/do_
max_pooling_3x3_s2x2
_nchw44.cpp
→
dnn/src/arm_common/pooling/do_
pooling_3x3
_nchw44.cpp
浏览文件 @
df8931b6
...
@@ -9,60 +9,143 @@
...
@@ -9,60 +9,143 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
* implied.
*/
*/
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h"
#include "src/arm_common/pooling/do_pooling_3x3_nchw44.h"
#include "src/arm_common/pooling/algo.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/unroll_macro.h"
namespace
megdnn
{
namespace
megdnn
{
namespace
arm_common
{
namespace
arm_common
{
void
do_max_pooling_3x3_s2x2_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
void
do_max_pooling_3x3_stride1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
IH
,
size_t
IW
,
size_t
OW
,
size_t
PH
,
size_t
PW
)
{
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
)
{
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
);
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
;
const
int8_t
*
__restrict
sptr0
=
sptr
+
(
ih
+
0
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr1
=
sptr
+
(
ih
+
1
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr2
=
sptr
+
(
ih
+
2
)
*
IW2
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
size_t
ow
=
0
;
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
int8x16_t
src0123
=
vld1q_s8
(
sptr0
);
int8x16_t
src1234
=
vld1q_s8
(
sptr0
+
4
);
int8x16_t
src2345
=
vld1q_s8
(
sptr0
+
8
);
int8x16_t
max0
=
vmaxq_s8
(
src0123
,
src1234
);
max0
=
vmaxq_s8
(
max0
,
src2345
);
src0123
=
vld1q_s8
(
sptr1
);
src1234
=
vld1q_s8
(
sptr1
+
4
);
src2345
=
vld1q_s8
(
sptr1
+
8
);
int8x16_t
max1
=
vmaxq_s8
(
src0123
,
src1234
);
max1
=
vmaxq_s8
(
max1
,
src2345
);
src0123
=
vld1q_s8
(
sptr2
);
src1234
=
vld1q_s8
(
sptr2
+
4
);
src2345
=
vld1q_s8
(
sptr2
+
8
);
int8x16_t
max2
=
vmaxq_s8
(
src0123
,
src1234
);
max2
=
vmaxq_s8
(
max2
,
src2345
);
int8x16_t
max_out
=
vmaxq_s8
(
max0
,
max1
);
max_out
=
vmaxq_s8
(
max_out
,
max2
);
vst1q_s8
(
dptr
,
max_out
);
sptr0
+=
16
;
sptr1
+=
16
;
sptr2
+=
16
;
dptr
+=
16
;
}
for
(;
ow
<
OW
;
++
ow
)
{
int8x8_t
src001
=
vld1_s8
(
sptr0
);
int8x8_t
src012
=
vld1_s8
(
sptr0
+
4
);
int8x8_t
src101
=
vld1_s8
(
sptr1
);
int8x8_t
src112
=
vld1_s8
(
sptr1
+
4
);
int8x8_t
src201
=
vld1_s8
(
sptr2
);
int8x8_t
src212
=
vld1_s8
(
sptr2
+
4
);
int8x8_t
max01_tmp
=
vmax_s8
(
src001
,
src101
);
max01_tmp
=
vmax_s8
(
max01_tmp
,
src201
);
int8x8_t
max12_tmp
=
vmax_s8
(
src012
,
src112
);
max12_tmp
=
vmax_s8
(
max12_tmp
,
src212
);
#define cb(i) \
int8_t dst##i = std::max(std::max(max01_tmp[i], max01_tmp[i + 4]), \
max12_tmp[i + 4]);
#define store(i) *(dptr + i) = dst##i;
UNROLL_CALL_NOWRAPPER
(
4
,
cb
)
UNROLL_CALL_NOWRAPPER
(
4
,
store
)
#undef store
#undef cb
sptr0
+=
4
;
sptr1
+=
4
;
sptr2
+=
4
;
dptr
+=
4
;
}
}
}
void
do_max_pooling_3x3_stride2_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
)
{
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
);
size_t
oh
=
0
;
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
<<
1
;
size_t
ih
=
oh
<<
1
;
const
int8_t
*
__restrict
sptr0
=
src
+
(
ih
+
0
)
*
IW
*
4
;
const
int8_t
*
sptr0
=
sptr
+
(
ih
+
0
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr1
=
src
+
(
ih
+
1
)
*
IW
*
4
;
const
int8_t
*
sptr1
=
sptr
+
(
ih
+
1
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr2
=
src
+
(
ih
+
2
)
*
IW
*
4
;
const
int8_t
*
sptr2
=
sptr
+
(
ih
+
2
)
*
IW2
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
size_t
ow
=
0
;
size_t
ow
=
0
;
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
int8x16_t
src00
=
vld1q_s8
(
sptr0
);
int8x16_t
src00
=
vld1q_s8
(
sptr0
);
int8x16_t
src04
=
vld1q_s8
(
sptr0
+
4
*
4
);
int8x16_t
src04
=
vld1q_s8
(
sptr0
+
4
*
4
);
int8x16_t
src08
=
vld1q_s8
(
sptr0
+
4
*
8
);
int32x4_t
src08
=
vld1q_dup_s32
(
reinterpret_cast
<
const
int32_t
*>
(
sptr0
+
4
*
8
));
int32x4x2_t
src_tmp
=
vuzpq_s32
(
vreinterpretq_s32_s8
(
src00
),
int32x4x2_t
src_tmp
=
vuzpq_s32
(
vreinterpretq_s32_s8
(
src00
),
vreinterpretq_s32_s8
(
src04
));
vreinterpretq_s32_s8
(
src04
));
int32x4_t
src0246
=
src_tmp
.
val
[
0
];
int32x4_t
src0246
=
src_tmp
.
val
[
0
];
int32x4_t
src1357
=
src_tmp
.
val
[
1
];
int32x4_t
src1357
=
src_tmp
.
val
[
1
];
int32x4_t
src2468
=
int32x4_t
src2468
=
vextq_s32
(
src0246
,
src08
,
1
);
vextq_s32
(
src0246
,
vreinterpretq_s32_s8
(
src08
),
1
);
int8x16_t
max_tmp
=
vmaxq_s8
(
vreinterpretq_s8_s32
(
src0246
),
int8x16_t
max_tmp
=
vmaxq_s8
(
vreinterpretq_s8_s32
(
src0246
),
vreinterpretq_s8_s32
(
src1357
));
vreinterpretq_s8_s32
(
src1357
));
int8x16_t
max0
=
vmaxq_s8
(
max_tmp
,
vreinterpretq_s8_s32
(
src2468
));
int8x16_t
max0
=
vmaxq_s8
(
max_tmp
,
vreinterpretq_s8_s32
(
src2468
));
int8x16_t
src10
=
vld1q_s8
(
sptr1
);
int8x16_t
src10
=
vld1q_s8
(
sptr1
);
int8x16_t
src14
=
vld1q_s8
(
sptr1
+
4
*
4
);
int8x16_t
src14
=
vld1q_s8
(
sptr1
+
4
*
4
);
int8x16_t
src18
=
vld1q_s8
(
sptr1
+
4
*
8
);
int32x4_t
src18
=
vld1q_dup_s32
(
reinterpret_cast
<
const
int32_t
*>
(
sptr1
+
4
*
8
));
src_tmp
=
vuzpq_s32
(
vreinterpretq_s32_s8
(
src10
),
src_tmp
=
vuzpq_s32
(
vreinterpretq_s32_s8
(
src10
),
vreinterpretq_s32_s8
(
src14
));
vreinterpretq_s32_s8
(
src14
));
src0246
=
src_tmp
.
val
[
0
];
src0246
=
src_tmp
.
val
[
0
];
src1357
=
src_tmp
.
val
[
1
];
src1357
=
src_tmp
.
val
[
1
];
src2468
=
vextq_s32
(
src0246
,
vreinterpretq_s32_s8
(
src18
)
,
1
);
src2468
=
vextq_s32
(
src0246
,
src18
,
1
);
max_tmp
=
vmaxq_s8
(
vreinterpretq_s8_s32
(
src0246
),
max_tmp
=
vmaxq_s8
(
vreinterpretq_s8_s32
(
src0246
),
vreinterpretq_s8_s32
(
src1357
));
vreinterpretq_s8_s32
(
src1357
));
int8x16_t
max1
=
vmaxq_s8
(
max_tmp
,
vreinterpretq_s8_s32
(
src2468
));
int8x16_t
max1
=
vmaxq_s8
(
max_tmp
,
vreinterpretq_s8_s32
(
src2468
));
int8x16_t
src20
=
vld1q_s8
(
sptr2
);
int8x16_t
src20
=
vld1q_s8
(
sptr2
);
int8x16_t
src24
=
vld1q_s8
(
sptr2
+
4
*
4
);
int8x16_t
src24
=
vld1q_s8
(
sptr2
+
4
*
4
);
int8x16_t
src28
=
vld1q_s8
(
sptr2
+
4
*
8
);
int32x4_t
src28
=
vld1q_dup_s32
(
reinterpret_cast
<
const
int32_t
*>
(
sptr2
+
4
*
8
));
src_tmp
=
vuzpq_s32
(
vreinterpretq_s32_s8
(
src20
),
src_tmp
=
vuzpq_s32
(
vreinterpretq_s32_s8
(
src20
),
vreinterpretq_s32_s8
(
src24
));
vreinterpretq_s32_s8
(
src24
));
src0246
=
src_tmp
.
val
[
0
];
src0246
=
src_tmp
.
val
[
0
];
src1357
=
src_tmp
.
val
[
1
];
src1357
=
src_tmp
.
val
[
1
];
src2468
=
vextq_s32
(
src0246
,
vreinterpretq_s32_s8
(
src28
)
,
1
);
src2468
=
vextq_s32
(
src0246
,
src28
,
1
);
max_tmp
=
vmaxq_s8
(
vreinterpretq_s8_s32
(
src0246
),
max_tmp
=
vmaxq_s8
(
vreinterpretq_s8_s32
(
src0246
),
vreinterpretq_s8_s32
(
src1357
));
vreinterpretq_s8_s32
(
src1357
));
...
...
dnn/src/arm_common/pooling/do_
max_pooling_3x3_s2x2
_nchw44.h
→
dnn/src/arm_common/pooling/do_
pooling_3x3
_nchw44.h
浏览文件 @
df8931b6
...
@@ -15,9 +15,16 @@
...
@@ -15,9 +15,16 @@
namespace
megdnn
{
namespace
megdnn
{
namespace
arm_common
{
namespace
arm_common
{
void
do_max_pooling_3x3_s2x2_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
void
do_max_pooling_3x3_stride1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
);
void
do_max_pooling_3x3_stride2_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
);
size_t
OW
,
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
);
}
// namespace arm_common
}
// namespace arm_common
}
// namespace megdnn
}
// namespace megdnn
...
...
dnn/src/arm_common/pooling/do_
max_
pooling_4x4_nchw44.cpp
→
dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.cpp
浏览文件 @
df8931b6
...
@@ -9,7 +9,8 @@
...
@@ -9,7 +9,8 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
* implied.
*/
*/
#include "src/arm_common/pooling/do_max_pooling_4x4_nchw44.h"
#include "src/arm_common/pooling/do_pooling_4x4_nchw44.h"
#include "src/arm_common/pooling/algo.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/unroll_macro.h"
...
@@ -19,14 +20,18 @@ namespace arm_common {
...
@@ -19,14 +20,18 @@ namespace arm_common {
void
do_max_pooling_4x4_stride1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
void
do_max_pooling_4x4_stride1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
)
{
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
)
{
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
);
size_t
oh
=
0
;
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
;
size_t
ih
=
oh
;
const
int8_t
*
__restrict
sptr0
=
s
rc
+
(
ih
+
0
)
*
IW
*
4
;
const
int8_t
*
__restrict
sptr0
=
s
ptr
+
(
ih
+
0
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr1
=
s
rc
+
(
ih
+
1
)
*
IW
*
4
;
const
int8_t
*
__restrict
sptr1
=
s
ptr
+
(
ih
+
1
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr2
=
s
rc
+
(
ih
+
2
)
*
IW
*
4
;
const
int8_t
*
__restrict
sptr2
=
s
ptr
+
(
ih
+
2
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr3
=
s
rc
+
(
ih
+
3
)
*
IW
*
4
;
const
int8_t
*
__restrict
sptr3
=
s
ptr
+
(
ih
+
3
)
*
IW2
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
size_t
ow
=
0
;
size_t
ow
=
0
;
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
...
@@ -90,35 +95,38 @@ void do_max_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
...
@@ -90,35 +95,38 @@ void do_max_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
void
do_max_pooling_4x4_stride2_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
void
do_max_pooling_4x4_stride2_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
)
{
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
)
{
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
);
size_t
oh
=
0
;
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
<<
1
;
size_t
ih
=
oh
<<
1
;
const
int8_t
*
__restrict
sptr0
=
s
rc
+
(
ih
+
0
)
*
IW
*
4
;
const
int8_t
*
__restrict
sptr0
=
s
ptr
+
(
ih
+
0
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr1
=
s
rc
+
(
ih
+
1
)
*
IW
*
4
;
const
int8_t
*
__restrict
sptr1
=
s
ptr
+
(
ih
+
1
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr2
=
s
rc
+
(
ih
+
2
)
*
IW
*
4
;
const
int8_t
*
__restrict
sptr2
=
s
ptr
+
(
ih
+
2
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr3
=
s
rc
+
(
ih
+
3
)
*
IW
*
4
;
const
int8_t
*
__restrict
sptr3
=
s
ptr
+
(
ih
+
3
)
*
IW2
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
size_t
ow
=
0
;
size_t
ow
=
0
;
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
int8x16_t
src00
,
src04
,
src08
,
src09
,
max_tmp0
,
max_tmp1
,
max_tmp2
,
int8x16_t
src00
,
src04
,
max_tmp0
,
max_tmp1
,
max_tmp2
,
max_tmp3
;
max_tmp3
;
int32x4_t
src0246
,
src1357
,
src2468
,
src3579
,
src08
,
src09
;
int32x4_t
src0246
,
src1357
,
src2468
,
src3579
;
int32x4x2_t
src_tmp
;
int32x4x2_t
src_tmp
;
#define CACULATE_ROW(i) \
#define CACULATE_ROW(i)
\
src00 = vld1q_s8(sptr##i); \
src00 = vld1q_s8(sptr##i);
\
src04 = vld1q_s8(sptr##i + 4 * 4); \
src04 = vld1q_s8(sptr##i + 4 * 4);
\
src08 = vld1q_
s8(sptr##i + 4 * 8);
\
src08 = vld1q_
dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 8));
\
src09 = vld1q_
s8(sptr##i + 4 * 9);
\
src09 = vld1q_
dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 9));
\
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00),
\
vreinterpretq_s32_s8(src04)); \
vreinterpretq_s32_s8(src04));
\
src0246 = src_tmp.val[0]; \
src0246 = src_tmp.val[0];
\
src1357 = src_tmp.val[1]; \
src1357 = src_tmp.val[1];
\
src2468 = vextq_s32(src0246,
vreinterpretq_s32_s8(src08), 1);
\
src2468 = vextq_s32(src0246,
src08, 1);
\
src3579 = vextq_s32(src1357,
vreinterpretq_s32_s8(src09), 1);
\
src3579 = vextq_s32(src1357,
src09, 1);
\
max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246), \
max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246),
\
vreinterpretq_s8_s32(src1357)); \
vreinterpretq_s8_s32(src1357));
\
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468)); \
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468));
\
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579));
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579));
UNROLL_CALL_NOWRAPPER
(
4
,
CACULATE_ROW
)
UNROLL_CALL_NOWRAPPER
(
4
,
CACULATE_ROW
)
...
...
dnn/src/arm_common/pooling/do_
max_
pooling_4x4_nchw44.h
→
dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.h
浏览文件 @
df8931b6
...
@@ -18,7 +18,7 @@ namespace arm_common {
...
@@ -18,7 +18,7 @@ namespace arm_common {
#define KERN(strdie) \
#define KERN(strdie) \
void do_max_pooling_4x4_##strdie##_int8_nchw44_NEON( \
void do_max_pooling_4x4_##strdie##_int8_nchw44_NEON( \
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \
size_t OW, size_t PH, size_t PW);
size_t OW, size_t PH, size_t PW
, const WorkspaceBundle& ws
);
KERN
(
stride1
)
KERN
(
stride1
)
KERN
(
stride2
)
KERN
(
stride2
)
...
...
dnn/src/arm_common/pooling/do_
max_
pooling_5x5_nchw44.cpp
→
dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp
浏览文件 @
df8931b6
...
@@ -9,7 +9,8 @@
...
@@ -9,7 +9,8 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
* implied.
*/
*/
#include "src/arm_common/pooling/do_max_pooling_5x5_nchw44.h"
#include "src/arm_common/pooling/do_pooling_5x5_nchw44.h"
#include "src/arm_common/pooling/algo.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/unroll_macro.h"
...
@@ -19,15 +20,19 @@ namespace arm_common {
...
@@ -19,15 +20,19 @@ namespace arm_common {
void
do_max_pooling_5x5_stride1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
void
do_max_pooling_5x5_stride1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
)
{
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
)
{
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
);
size_t
oh
=
0
;
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
;
size_t
ih
=
oh
;
const
int8_t
*
__restrict
sptr0
=
src
+
(
ih
+
0
)
*
IW
*
4
;
const
int8_t
*
sptr0
=
sptr
+
(
ih
+
0
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr1
=
src
+
(
ih
+
1
)
*
IW
*
4
;
const
int8_t
*
sptr1
=
sptr
+
(
ih
+
1
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr2
=
src
+
(
ih
+
2
)
*
IW
*
4
;
const
int8_t
*
sptr2
=
sptr
+
(
ih
+
2
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr3
=
src
+
(
ih
+
3
)
*
IW
*
4
;
const
int8_t
*
sptr3
=
sptr
+
(
ih
+
3
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr4
=
src
+
(
ih
+
4
)
*
IW
*
4
;
const
int8_t
*
sptr4
=
sptr
+
(
ih
+
4
)
*
IW2
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
size_t
ow
=
0
;
size_t
ow
=
0
;
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
...
@@ -80,13 +85,16 @@ void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
...
@@ -80,13 +85,16 @@ void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
max_out
=
vmax_s8
(
max_out
,
max_tmp3
);
max_out
=
vmax_s8
(
max_out
,
max_tmp3
);
max_out
=
vmax_s8
(
max_out
,
max_tmp4
);
max_out
=
vmax_s8
(
max_out
,
max_tmp4
);
#define COMPARE_SRC45(i) int8x8_t src##i##_45 = vld1_s8(sptr##i + 4 * 4);
#define COMPARE_SRC45(i) \
int32x2_t src##i##_45 = \
vld1_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 4));
UNROLL_CALL_NOWRAPPER
(
5
,
COMPARE_SRC45
)
UNROLL_CALL_NOWRAPPER
(
5
,
COMPARE_SRC45
)
int8x8_t
max_45
=
vmax_s8
(
src0_45
,
src1_45
);
int8x8_t
max_45
=
vmax_s8
(
vreinterpret_s8_s32
(
src0_45
),
max_45
=
vmax_s8
(
max_45
,
src1_45
);
vreinterpret_s8_s32
(
src1_45
));
max_45
=
vmax_s8
(
max_45
,
src2_45
);
max_45
=
vmax_s8
(
max_45
,
vreinterpret_s8_s32
(
src1_45
));
max_45
=
vmax_s8
(
max_45
,
src3_45
);
max_45
=
vmax_s8
(
max_45
,
vreinterpret_s8_s32
(
src2_45
));
max_45
=
vmax_s8
(
max_45
,
src4_45
);
max_45
=
vmax_s8
(
max_45
,
vreinterpret_s8_s32
(
src3_45
));
max_45
=
vmax_s8
(
max_45
,
vreinterpret_s8_s32
(
src4_45
));
#define store(i) \
#define store(i) \
*(dptr + i) = std::max(std::max(max_out[i], max_out[i + 4]), max_45[i]);
*(dptr + i) = std::max(std::max(max_out[i], max_out[i + 4]), max_45[i]);
...
@@ -106,39 +114,44 @@ void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
...
@@ -106,39 +114,44 @@ void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
void
do_max_pooling_5x5_stride2_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
void
do_max_pooling_5x5_stride2_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
)
{
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
)
{
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
);
size_t
oh
=
0
;
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
<<
1
;
size_t
ih
=
oh
<<
1
;
const
int8_t
*
__restrict
sptr0
=
src
+
(
ih
+
0
)
*
IW
*
4
;
const
int8_t
*
sptr0
=
sptr
+
(
ih
+
0
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr1
=
src
+
(
ih
+
1
)
*
IW
*
4
;
const
int8_t
*
sptr1
=
sptr
+
(
ih
+
1
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr2
=
src
+
(
ih
+
2
)
*
IW
*
4
;
const
int8_t
*
sptr2
=
sptr
+
(
ih
+
2
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr3
=
src
+
(
ih
+
3
)
*
IW
*
4
;
const
int8_t
*
sptr3
=
sptr
+
(
ih
+
3
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr4
=
src
+
(
ih
+
4
)
*
IW
*
4
;
const
int8_t
*
sptr4
=
sptr
+
(
ih
+
4
)
*
IW2
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
size_t
ow
=
0
;
size_t
ow
=
0
;
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
int8x16_t
src00
,
src04
,
src08
,
src09
,
src10
,
max_tmp0
,
max_tmp1
,
int8x16_t
src00
,
src04
,
max_tmp0
,
max_tmp1
,
max_tmp2
,
max_tmp3
,
max_tmp2
,
max_tmp3
,
max_tmp4
;
max_tmp4
;
int32x4_t
src0246
,
src1357
,
src2468
,
src3579
,
src46810
;
int32x4_t
src0246
,
src1357
,
src2468
,
src3579
,
src46810
,
src10
,
src09
,
src08
;
int32x4x2_t
src_tmp
;
int32x4x2_t
src_tmp
;
#define CACULATE_ROW(i) \
#define CACULATE_ROW(i)
\
src00 = vld1q_s8(sptr##i); \
src00 = vld1q_s8(sptr##i);
\
src04 = vld1q_s8(sptr##i + 4 * 4); \
src04 = vld1q_s8(sptr##i + 4 * 4);
\
src08 = vld1q_
s8(sptr##i + 4 * 8);
\
src08 = vld1q_
dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 8));
\
src09 = vld1q_
s8(sptr##i + 4 * 9);
\
src09 = vld1q_
dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 9));
\
src10 = vld1q_
s8(sptr##i + 4 * 10);
\
src10 = vld1q_
dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 10));
\
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00),
\
vreinterpretq_s32_s8(src04)); \
vreinterpretq_s32_s8(src04));
\
src0246 = src_tmp.val[0]; \
src0246 = src_tmp.val[0];
\
src1357 = src_tmp.val[1]; \
src1357 = src_tmp.val[1];
\
src2468 = vextq_s32(src0246,
vreinterpretq_s32_s8(src08), 1);
\
src2468 = vextq_s32(src0246,
src08, 1);
\
src3579 = vextq_s32(src1357,
vreinterpretq_s32_s8(src09), 1);
\
src3579 = vextq_s32(src1357,
src09, 1);
\
src46810 = vextq_s32(src2468,
vreinterpretq_s32_s8(src10), 1);
\
src46810 = vextq_s32(src2468,
src10, 1);
\
max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246), \
max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246),
\
vreinterpretq_s8_s32(src1357)); \
vreinterpretq_s8_s32(src1357));
\
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468)); \
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468));
\
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579)); \
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579));
\
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src46810));
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src46810));
UNROLL_CALL_NOWRAPPER
(
5
,
CACULATE_ROW
)
UNROLL_CALL_NOWRAPPER
(
5
,
CACULATE_ROW
)
...
@@ -173,13 +186,16 @@ void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
...
@@ -173,13 +186,16 @@ void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
max_out
=
vmax_s8
(
max_out
,
max_tmp3
);
max_out
=
vmax_s8
(
max_out
,
max_tmp3
);
max_out
=
vmax_s8
(
max_out
,
max_tmp4
);
max_out
=
vmax_s8
(
max_out
,
max_tmp4
);
#define COMPARE_SRC45(i) int8x8_t src##i##_45 = vld1_s8(sptr##i + 4 * 4);
#define COMPARE_SRC45(i) \
int32x2_t src##i##_45 = \
vld1_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 4));
UNROLL_CALL_NOWRAPPER
(
5
,
COMPARE_SRC45
)
UNROLL_CALL_NOWRAPPER
(
5
,
COMPARE_SRC45
)
int8x8_t
max_45
=
vmax_s8
(
src0_45
,
src1_45
);
int8x8_t
max_45
=
vmax_s8
(
vreinterpret_s8_s32
(
src0_45
),
max_45
=
vmax_s8
(
max_45
,
src1_45
);
vreinterpret_s8_s32
(
src1_45
));
max_45
=
vmax_s8
(
max_45
,
src2_45
);
max_45
=
vmax_s8
(
max_45
,
vreinterpret_s8_s32
(
src1_45
));
max_45
=
vmax_s8
(
max_45
,
src3_45
);
max_45
=
vmax_s8
(
max_45
,
vreinterpret_s8_s32
(
src2_45
));
max_45
=
vmax_s8
(
max_45
,
src4_45
);
max_45
=
vmax_s8
(
max_45
,
vreinterpret_s8_s32
(
src3_45
));
max_45
=
vmax_s8
(
max_45
,
vreinterpret_s8_s32
(
src4_45
));
#define store(i) \
#define store(i) \
*(dptr + i) = std::max(std::max(max_out[i], max_out[i + 4]), max_45[i]);
*(dptr + i) = std::max(std::max(max_out[i], max_out[i + 4]), max_45[i]);
...
...
dnn/src/arm_common/pooling/do_
max_
pooling_5x5_nchw44.h
→
dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.h
浏览文件 @
df8931b6
...
@@ -18,7 +18,7 @@ namespace arm_common {
...
@@ -18,7 +18,7 @@ namespace arm_common {
#define KERN(strdie) \
#define KERN(strdie) \
void do_max_pooling_5x5_##strdie##_int8_nchw44_NEON( \
void do_max_pooling_5x5_##strdie##_int8_nchw44_NEON( \
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \
size_t OW, size_t PH, size_t PW);
size_t OW, size_t PH, size_t PW
, const WorkspaceBundle& ws
);
KERN
(
stride1
)
KERN
(
stride1
)
KERN
(
stride2
)
KERN
(
stride2
)
...
...
dnn/src/arm_common/pooling/opr_impl.cpp
浏览文件 @
df8931b6
...
@@ -26,8 +26,7 @@ class PoolingImpl::AlgoPack : NonCopyableObj {
...
@@ -26,8 +26,7 @@ class PoolingImpl::AlgoPack : NonCopyableObj {
AlgoInt8Filter2MaxStride2
algo_int8_filter2_max_stride2
;
AlgoInt8Filter2MaxStride2
algo_int8_filter2_max_stride2
;
AlgoInt8Filter3MaxStride2
algo_int8_filter3_max_stride2
;
AlgoInt8Filter3MaxStride2
algo_int8_filter3_max_stride2
;
AlgoFilter2MaxStridexNCHW44
algo_filter2_max_stridex_nchw4
;
AlgoFilter2MaxStridexNCHW44
algo_filter2_max_stridex_nchw4
;
AlgoFilter3MaxStride2NCHW44
algo_filter3_max_stride2_nchw4
;
AlgoFilter3MaxStridexNCHW44
algo_filter3_max_stridex_nchw4
;
AlgoFilter3MaxStride1NCHW44
algo_filter3_max_stride1_nchw4
;
AlgoFilter4MaxStridexNCHW44
algo_filter4_max_stridex_nchw4
;
AlgoFilter4MaxStridexNCHW44
algo_filter4_max_stridex_nchw4
;
AlgoFilter5MaxStridexNCHW44
algo_filter5_max_stridex_nchw4
;
AlgoFilter5MaxStridexNCHW44
algo_filter5_max_stridex_nchw4
;
...
@@ -41,8 +40,7 @@ public:
...
@@ -41,8 +40,7 @@ public:
all_algos
.
emplace_back
(
&
algo_filter5_max_stride2
);
all_algos
.
emplace_back
(
&
algo_filter5_max_stride2
);
all_algos
.
emplace_back
(
&
algo_int8_filter2_max_stride2
);
all_algos
.
emplace_back
(
&
algo_int8_filter2_max_stride2
);
all_algos
.
emplace_back
(
&
algo_int8_filter3_max_stride2
);
all_algos
.
emplace_back
(
&
algo_int8_filter3_max_stride2
);
all_algos
.
emplace_back
(
&
algo_filter3_max_stride2_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter3_max_stridex_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter3_max_stride1_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter2_max_stridex_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter2_max_stridex_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter4_max_stridex_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter4_max_stridex_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter5_max_stridex_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter5_max_stridex_nchw4
);
...
@@ -119,6 +117,12 @@ size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src,
...
@@ -119,6 +117,12 @@ size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src,
arm_common_workspace
=
ws
.
total_size_in_bytes
()
*
nr_threads
;
arm_common_workspace
=
ws
.
total_size_in_bytes
()
*
nr_threads
;
}
}
if
((
param
.
src_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
)
&&
(
param
.
format
==
param
::
Pooling
::
Format
::
NCHW44
))
{
WorkspaceBundle
ws
=
get_bundle_nchw44
(
param
);
arm_common_workspace
=
ws
.
total_size_in_bytes
()
*
nr_threads
;
}
if
(
find_algo
)
{
if
(
find_algo
)
{
return
arm_common_workspace
;
return
arm_common_workspace
;
}
else
{
}
else
{
...
...
dnn/src/arm_common/pooling/opr_impl.h
浏览文件 @
df8931b6
...
@@ -84,8 +84,7 @@ private:
...
@@ -84,8 +84,7 @@ private:
class
AlgoInt8Filter2MaxStride2
;
class
AlgoInt8Filter2MaxStride2
;
class
AlgoInt8Filter3MaxStride2
;
class
AlgoInt8Filter3MaxStride2
;
class
AlgoFilter2MaxStridexNCHW44
;
class
AlgoFilter2MaxStridexNCHW44
;
class
AlgoFilter3MaxStride2NCHW44
;
class
AlgoFilter3MaxStridexNCHW44
;
class
AlgoFilter3MaxStride1NCHW44
;
class
AlgoFilter4MaxStridexNCHW44
;
class
AlgoFilter4MaxStridexNCHW44
;
class
AlgoFilter5MaxStridexNCHW44
;
class
AlgoFilter5MaxStridexNCHW44
;
class
AlgoPack
;
class
AlgoPack
;
...
...
dnn/test/arm_common/pooling.cpp
浏览文件 @
df8931b6
...
@@ -8,8 +8,6 @@
...
@@ -8,8 +8,6 @@
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
*/
#include "megdnn/dtype.h"
#include "megdnn/opr_param_defs.h"
#include "test/arm_common/fixture.h"
#include "test/arm_common/fixture.h"
#include "test/common/pooling.h"
#include "test/common/pooling.h"
...
@@ -102,209 +100,6 @@ TEST_F(ARM_COMMON, POOLING_INT8_W3x3_S2x2)
...
@@ -102,209 +100,6 @@ TEST_F(ARM_COMMON, POOLING_INT8_W3x3_S2x2)
// clang-format on
// clang-format on
}
}
TEST_F
(
ARM_COMMON
,
POOLING_MAX_W3x3_S2x2_NCHW44
)
{
// clang-format off
for
(
size_t
ih
:
{
3
,
5
,
10
})
for
(
size_t
iw
:
{
3
,
5
,
7
,
9
,
15
,
20
})
for
(
size_t
ph
:
{
0
})
for
(
size_t
pw
:
{
0
})
if
(
ih
+
2
*
ph
>=
3
&&
iw
+
2
*
pw
>=
3
)
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
Checker
<
Pooling
>
checker
(
handle
());
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
1.1
f
));
checker
.
set_rng
(
0
,
&
rng
);
param
::
Pooling
param
;
param
.
mode
=
param
::
Pooling
::
Mode
::
MAX
;
param
.
format
=
param
::
Pooling
::
Format
::
NCHW44
;
param
.
pad_h
=
ph
;
param
.
pad_w
=
pw
;
param
.
stride_h
=
param
.
stride_w
=
2
;
param
.
window_h
=
param
.
window_w
=
3
;
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
}
// clang-format on
}
TEST_F
(
ARM_COMMON
,
POOLING_MAX_W3x3_S1x1_NCHW44
)
{
// clang-format off
for
(
size_t
ih
:
{
3
,
5
,
10
})
for
(
size_t
iw
:
{
3
,
5
,
7
,
9
,
15
,
20
})
for
(
size_t
ph
:
{
0
})
for
(
size_t
pw
:
{
0
})
if
(
ih
+
2
*
ph
>=
3
&&
iw
+
2
*
pw
>=
3
)
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
Checker
<
Pooling
>
checker
(
handle
());
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
1.1
f
));
checker
.
set_rng
(
0
,
&
rng
);
param
::
Pooling
param
;
param
.
mode
=
param
::
Pooling
::
Mode
::
MAX
;
param
.
format
=
param
::
Pooling
::
Format
::
NCHW44
;
param
.
pad_h
=
ph
;
param
.
pad_w
=
pw
;
param
.
stride_h
=
param
.
stride_w
=
1
;
param
.
window_h
=
param
.
window_w
=
3
;
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
}
// clang-format on
}
TEST_F
(
ARM_COMMON
,
POOLING_MAX_W2x2_S1x1_NCHW44
)
{
// clang-format off
for
(
size_t
ih
:
{
2
,
5
,
10
,
17
})
for
(
size_t
iw
:
{
2
,
6
,
8
,
16
,
26
})
for
(
size_t
ph
:
{
0
})
for
(
size_t
pw
:
{
0
})
if
(
ih
+
2
*
ph
>=
2
&&
iw
+
2
*
pw
>=
2
)
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
Checker
<
Pooling
>
checker
(
handle
());
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
1.1
f
));
checker
.
set_rng
(
0
,
&
rng
);
param
::
Pooling
param
;
param
.
mode
=
param
::
Pooling
::
Mode
::
MAX
;
param
.
format
=
param
::
Pooling
::
Format
::
NCHW44
;
param
.
pad_h
=
ph
;
param
.
pad_w
=
pw
;
param
.
stride_h
=
param
.
stride_w
=
1
;
param
.
window_h
=
param
.
window_w
=
2
;
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
}
// clang-format on
}
TEST_F
(
ARM_COMMON
,
POOLING_MAX_W2x2_S2x2_NCHW44
)
{
// clang-format off
for
(
size_t
ih
:
{
2
,
5
,
10
,
17
})
for
(
size_t
iw
:
{
2
,
6
,
8
,
16
,
26
})
for
(
size_t
ph
:
{
0
})
for
(
size_t
pw
:
{
0
})
if
(
ih
+
2
*
ph
>=
2
&&
iw
+
2
*
pw
>=
2
)
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
Checker
<
Pooling
>
checker
(
handle
());
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
1.1
f
));
checker
.
set_rng
(
0
,
&
rng
);
param
::
Pooling
param
;
param
.
mode
=
param
::
Pooling
::
Mode
::
MAX
;
param
.
format
=
param
::
Pooling
::
Format
::
NCHW44
;
param
.
pad_h
=
ph
;
param
.
pad_w
=
pw
;
param
.
stride_h
=
param
.
stride_w
=
2
;
param
.
window_h
=
param
.
window_w
=
2
;
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
}
// clang-format on
}
TEST_F
(
ARM_COMMON
,
POOLING_MAX_W4x4_S1x1_NCHW44
)
{
// clang-format off
for
(
size_t
ih
:
{
4
,
7
,
10
,
17
,
20
})
for
(
size_t
iw
:
{
4
,
8
,
10
,
21
,
32
})
for
(
size_t
ph
:
{
0
})
for
(
size_t
pw
:
{
0
})
if
(
ih
+
2
*
ph
>=
2
&&
iw
+
2
*
pw
>=
2
)
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
Checker
<
Pooling
>
checker
(
handle
());
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
1.1
f
));
checker
.
set_rng
(
0
,
&
rng
);
param
::
Pooling
param
;
param
.
mode
=
param
::
Pooling
::
Mode
::
MAX
;
param
.
format
=
param
::
Pooling
::
Format
::
NCHW44
;
param
.
pad_h
=
ph
;
param
.
pad_w
=
pw
;
param
.
stride_h
=
param
.
stride_w
=
1
;
param
.
window_h
=
param
.
window_w
=
4
;
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
}
// clang-format on
}
TEST_F
(
ARM_COMMON
,
POOLING_MAX_W4x4_S2x2_NCHW44
)
{
// clang-format off
for
(
size_t
ih
:
{
4
,
10
,
18
,
25
,
30
})
for
(
size_t
iw
:
{
4
,
12
,
17
,
20
,
25
})
for
(
size_t
ph
:
{
0
})
for
(
size_t
pw
:
{
0
})
if
(
ih
+
2
*
ph
>=
2
&&
iw
+
2
*
pw
>=
2
)
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
Checker
<
Pooling
>
checker
(
handle
());
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
1.1
f
));
checker
.
set_rng
(
0
,
&
rng
);
param
::
Pooling
param
;
param
.
mode
=
param
::
Pooling
::
Mode
::
MAX
;
param
.
format
=
param
::
Pooling
::
Format
::
NCHW44
;
param
.
pad_h
=
ph
;
param
.
pad_w
=
pw
;
param
.
stride_h
=
param
.
stride_w
=
2
;
param
.
window_h
=
param
.
window_w
=
4
;
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
}
// clang-format on
}
TEST_F
(
ARM_COMMON
,
POOLING_MAX_W5x5_S1x1_NCHW44
)
{
// clang-format off
for
(
size_t
ih
:
{
5
,
9
,
19
,
20
,
39
})
for
(
size_t
iw
:
{
5
,
12
,
23
,
27
,
39
})
for
(
size_t
ph
:
{
0
})
for
(
size_t
pw
:
{
0
})
if
(
ih
+
2
*
ph
>=
5
&&
iw
+
2
*
pw
>=
5
)
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
Checker
<
Pooling
>
checker
(
handle
());
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
1.1
f
));
checker
.
set_rng
(
0
,
&
rng
);
param
::
Pooling
param
;
param
.
mode
=
param
::
Pooling
::
Mode
::
MAX
;
param
.
format
=
param
::
Pooling
::
Format
::
NCHW44
;
param
.
pad_h
=
ph
;
param
.
pad_w
=
pw
;
param
.
stride_h
=
param
.
stride_w
=
1
;
param
.
window_h
=
param
.
window_w
=
5
;
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
}
// clang-format on
}
TEST_F
(
ARM_COMMON
,
POOLING_MAX_W5x5_S2x2_NCHW44
)
{
// clang-format off
for
(
size_t
ih
:
{
5
,
9
,
19
,
20
,
39
})
for
(
size_t
iw
:
{
5
,
12
,
23
,
27
,
39
})
for
(
size_t
ph
:
{
0
})
for
(
size_t
pw
:
{
0
})
if
(
ih
+
2
*
ph
>=
5
&&
iw
+
2
*
pw
>=
5
)
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
Checker
<
Pooling
>
checker
(
handle
());
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
1.1
f
));
checker
.
set_rng
(
0
,
&
rng
);
param
::
Pooling
param
;
param
.
mode
=
param
::
Pooling
::
Mode
::
MAX
;
param
.
format
=
param
::
Pooling
::
Format
::
NCHW44
;
param
.
pad_h
=
ph
;
param
.
pad_w
=
pw
;
param
.
stride_h
=
param
.
stride_w
=
2
;
param
.
window_h
=
param
.
window_w
=
5
;
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
}
// clang-format on
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F
(
ARM_COMMON
,
POOLING_FP16
)
{
TEST_F
(
ARM_COMMON
,
POOLING_FP16
)
{
Checker
<
Pooling
>
checker
(
handle
());
Checker
<
Pooling
>
checker
(
handle
());
...
...
dnn/test/arm_common/pooling_multi_thread.cpp
浏览文件 @
df8931b6
...
@@ -8,6 +8,8 @@
...
@@ -8,6 +8,8 @@
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
*/
#include <vector>
#include "megdnn/dtype.h"
#include "test/arm_common/fixture.h"
#include "test/arm_common/fixture.h"
#include "test/common/pooling.h"
#include "test/common/pooling.h"
...
@@ -53,38 +55,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING) {
...
@@ -53,38 +55,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING) {
checker
.
set_param
(
param
).
exec
({{
2
,
3
,
ih
,
iw
},
{}});
checker
.
set_param
(
param
).
exec
({{
2
,
3
,
ih
,
iw
},
{}});
}
}
}
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_MAX_W3x3_S2x2_NCHW44
)
{
// clang-format off
for
(
size_t
ih
:
{
3
,
5
,
10
})
for
(
size_t
iw
:
{
3
,
5
,
7
,
9
,
15
,
20
})
for
(
size_t
ph
:
{
0
})
for
(
size_t
pw
:
{
0
})
if
(
ih
+
2
*
ph
>=
3
&&
iw
+
2
*
pw
>=
3
)
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
Checker
<
Pooling
>
checker
(
handle
());
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
1.1
f
));
checker
.
set_rng
(
0
,
&
rng
);
param
::
Pooling
param
;
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_MAX_W3x3_NCHW44
)
param
.
mode
=
param
::
Pooling
::
Mode
::
MAX
;
param
.
format
=
param
::
Pooling
::
Format
::
NCHW44
;
param
.
pad_h
=
ph
;
param
.
pad_w
=
pw
;
param
.
stride_h
=
param
.
stride_w
=
2
;
param
.
window_h
=
param
.
window_w
=
3
;
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
}
// clang-format on
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_MAX_W3x3_S1x1_NCHW44
)
{
{
// clang-format off
// clang-format off
for
(
size_t
ih
:
{
3
,
5
,
10
})
for
(
size_t
ih
:
{
3
,
5
,
10
})
for
(
size_t
iw
:
{
3
,
5
,
7
,
9
,
15
,
20
})
for
(
size_t
iw
:
{
3
,
5
,
7
,
9
,
15
,
20
})
for
(
size_t
ph
:
{
0
})
for
(
size_t
ph
:
{
0
,
1
,
2
})
for
(
size_t
pw
:
{
0
})
for
(
size_t
pw
:
{
0
,
1
,
2
})
if
(
ih
+
2
*
ph
>=
3
&&
iw
+
2
*
pw
>=
3
)
if
(
ih
+
2
*
ph
>=
3
&&
iw
+
2
*
pw
>=
3
)
{
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
...
@@ -100,18 +78,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_S1x1_NCHW44)
...
@@ -100,18 +78,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_S1x1_NCHW44)
param
.
stride_h
=
param
.
stride_w
=
1
;
param
.
stride_h
=
param
.
stride_w
=
1
;
param
.
window_h
=
param
.
window_w
=
3
;
param
.
window_h
=
param
.
window_w
=
3
;
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
param
.
stride_h
=
param
.
stride_w
=
2
;
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
}
}
// clang-format on
// clang-format on
}
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_MAX_W2x2_
S1x1_
NCHW44
)
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_MAX_W2x2_NCHW44
)
{
{
// clang-format off
// clang-format off
for
(
size_t
ih
:
{
2
,
5
,
10
,
17
})
for
(
size_t
ih
:
{
2
,
5
,
10
,
17
})
for
(
size_t
iw
:
{
2
,
6
,
8
,
16
,
26
})
for
(
size_t
iw
:
{
2
,
6
,
8
,
16
,
26
})
for
(
size_t
ph
:
{
0
})
for
(
size_t
ph
:
{
0
,
1
})
for
(
size_t
pw
:
{
0
})
for
(
size_t
pw
:
{
0
,
1
})
if
(
ih
+
2
*
ph
>=
3
&&
iw
+
2
*
pw
>=
3
)
if
(
ih
+
2
*
ph
>=
2
&&
iw
+
2
*
pw
>=
2
)
{
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
Checker
<
Pooling
>
checker
(
handle
());
Checker
<
Pooling
>
checker
(
handle
());
...
@@ -126,41 +108,20 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_S1x1_NCHW44)
...
@@ -126,41 +108,20 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_S1x1_NCHW44)
param
.
stride_h
=
param
.
stride_w
=
1
;
param
.
stride_h
=
param
.
stride_w
=
1
;
param
.
window_h
=
param
.
window_w
=
2
;
param
.
window_h
=
param
.
window_w
=
2
;
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
}
// clang-format on
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_MAX_W2x2_S2x2_NCHW44
)
{
// clang-format off
for
(
size_t
ih
:
{
2
,
5
,
10
,
17
})
for
(
size_t
iw
:
{
2
,
6
,
8
,
16
,
26
})
for
(
size_t
ph
:
{
0
})
for
(
size_t
pw
:
{
0
})
if
(
ih
+
2
*
ph
>=
3
&&
iw
+
2
*
pw
>=
3
)
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
Checker
<
Pooling
>
checker
(
handle
());
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
1.1
f
));
checker
.
set_rng
(
0
,
&
rng
);
param
::
Pooling
param
;
param
.
mode
=
param
::
Pooling
::
Mode
::
MAX
;
param
.
format
=
param
::
Pooling
::
Format
::
NCHW44
;
param
.
pad_h
=
ph
;
param
.
pad_w
=
pw
;
param
.
stride_h
=
param
.
stride_w
=
2
;
param
.
stride_h
=
param
.
stride_w
=
2
;
param
.
window_h
=
param
.
window_w
=
2
;
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
}
}
// clang-format on
// clang-format on
}
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_MAX_W4x4_S1x1_NCHW44
)
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_MAX_W4x4_NCHW44
)
{
{
// clang-format off
// clang-format off
for
(
size_t
ih
:
{
4
,
7
,
10
,
17
,
2
0
})
for
(
size_t
ih
:
{
4
,
10
,
18
,
25
,
3
0
})
for
(
size_t
iw
:
{
4
,
8
,
10
,
21
,
32
})
for
(
size_t
iw
:
{
4
,
12
,
17
,
20
,
25
})
for
(
size_t
ph
:
{
0
})
for
(
size_t
ph
:
{
0
,
1
,
2
})
for
(
size_t
pw
:
{
0
})
for
(
size_t
pw
:
{
0
,
1
,
2
})
if
(
ih
+
2
*
ph
>=
4
&&
iw
+
2
*
pw
>=
4
)
if
(
ih
+
2
*
ph
>=
4
&&
iw
+
2
*
pw
>=
4
)
{
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
...
@@ -176,41 +137,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_S1x1_NCHW44)
...
@@ -176,41 +137,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_S1x1_NCHW44)
param
.
stride_h
=
param
.
stride_w
=
1
;
param
.
stride_h
=
param
.
stride_w
=
1
;
param
.
window_h
=
param
.
window_w
=
4
;
param
.
window_h
=
param
.
window_w
=
4
;
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
}
// clang-format on
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_MAX_W4x4_S2x2_NCHW44
)
{
// clang-format off
for
(
size_t
ih
:
{
4
,
10
,
18
,
25
,
30
})
for
(
size_t
iw
:
{
4
,
12
,
17
,
20
,
25
})
for
(
size_t
ph
:
{
0
})
for
(
size_t
pw
:
{
0
})
if
(
ih
+
2
*
ph
>=
4
&&
iw
+
2
*
pw
>=
4
)
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
Checker
<
Pooling
>
checker
(
handle
());
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
1.1
f
));
checker
.
set_rng
(
0
,
&
rng
);
param
::
Pooling
param
;
param
.
mode
=
param
::
Pooling
::
Mode
::
MAX
;
param
.
format
=
param
::
Pooling
::
Format
::
NCHW44
;
param
.
pad_h
=
ph
;
param
.
pad_w
=
pw
;
param
.
stride_h
=
param
.
stride_w
=
2
;
param
.
stride_h
=
param
.
stride_w
=
2
;
param
.
window_h
=
param
.
window_w
=
4
;
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
}
}
// clang-format on
// clang-format on
}
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_MAX_W5x5_
S1x1_
NCHW44
)
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_MAX_W5x5_NCHW44
)
{
{
// clang-format off
// clang-format off
for
(
size_t
ih
:
{
5
,
9
,
19
,
20
,
39
})
for
(
size_t
ih
:
{
5
,
9
,
19
,
20
,
39
})
for
(
size_t
iw
:
{
5
,
12
,
23
,
27
,
39
})
for
(
size_t
iw
:
{
5
,
12
,
23
,
27
,
39
})
for
(
size_t
ph
:
{
0
})
for
(
size_t
ph
:
{
0
,
1
,
2
})
for
(
size_t
pw
:
{
0
})
for
(
size_t
pw
:
{
0
,
1
,
2
})
if
(
ih
+
2
*
ph
>=
5
&&
iw
+
2
*
pw
>=
5
)
if
(
ih
+
2
*
ph
>=
5
&&
iw
+
2
*
pw
>=
5
)
{
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
...
@@ -226,31 +165,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_S1x1_NCHW44)
...
@@ -226,31 +165,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_S1x1_NCHW44)
param
.
stride_h
=
param
.
stride_w
=
1
;
param
.
stride_h
=
param
.
stride_w
=
1
;
param
.
window_h
=
param
.
window_w
=
5
;
param
.
window_h
=
param
.
window_w
=
5
;
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
}
// clang-format on
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_MAX_W5x5_S2x2_NCHW44
)
{
// clang-format off
for
(
size_t
ih
:
{
5
,
9
,
19
,
20
,
39
})
for
(
size_t
iw
:
{
5
,
12
,
23
,
27
,
39
})
for
(
size_t
ph
:
{
0
})
for
(
size_t
pw
:
{
0
})
if
(
ih
+
2
*
ph
>=
5
&&
iw
+
2
*
pw
>=
5
)
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
Checker
<
Pooling
>
checker
(
handle
());
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
1.1
f
));
checker
.
set_rng
(
0
,
&
rng
);
param
::
Pooling
param
;
param
.
mode
=
param
::
Pooling
::
Mode
::
MAX
;
param
.
format
=
param
::
Pooling
::
Format
::
NCHW44
;
param
.
pad_h
=
ph
;
param
.
pad_w
=
pw
;
param
.
stride_h
=
param
.
stride_w
=
2
;
param
.
stride_h
=
param
.
stride_w
=
2
;
param
.
window_h
=
param
.
window_w
=
5
;
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
}
}
// clang-format on
// clang-format on
}
}
...
@@ -473,13 +391,15 @@ template <typename Opr>
...
@@ -473,13 +391,15 @@ template <typename Opr>
void
benchmark_impl
(
const
typename
Opr
::
Param
&
param
,
void
benchmark_impl
(
const
typename
Opr
::
Param
&
param
,
std
::
vector
<
SmallVector
<
TensorShape
>>
shapes
,
size_t
RUNS
,
std
::
vector
<
SmallVector
<
TensorShape
>>
shapes
,
size_t
RUNS
,
TaskExecutorConfig
&&
multi_thread_config
,
TaskExecutorConfig
&&
multi_thread_config
,
TaskExecutorConfig
&&
single_thread_config
)
{
TaskExecutorConfig
&&
single_thread_config
,
DType
data_type
)
{
std
::
vector
<
float
>
multi_thread_times
,
single_thread_times
;
std
::
vector
<
float
>
multi_thread_times
,
single_thread_times
;
{
{
auto
multi_thread_hanle
=
auto
multi_thread_hanle
=
create_cpu_handle
(
0
,
true
,
&
multi_thread_config
);
create_cpu_handle
(
0
,
true
,
&
multi_thread_config
);
auto
benchmarker
=
Benchmarker
<
Opr
>
(
multi_thread_hanle
.
get
());
auto
benchmarker
=
Benchmarker
<
Opr
>
(
multi_thread_hanle
.
get
());
benchmarker
.
set_times
(
RUNS
).
set_display
(
false
).
set_param
(
param
);
benchmarker
.
set_times
(
RUNS
).
set_display
(
false
).
set_param
(
param
);
benchmarker
.
set_dtype
(
0
,
data_type
);
for
(
auto
shape
:
shapes
)
{
for
(
auto
shape
:
shapes
)
{
multi_thread_times
.
push_back
(
benchmarker
.
exec
(
shape
)
/
RUNS
);
multi_thread_times
.
push_back
(
benchmarker
.
exec
(
shape
)
/
RUNS
);
}
}
...
@@ -489,6 +409,7 @@ void benchmark_impl(const typename Opr::Param& param,
...
@@ -489,6 +409,7 @@ void benchmark_impl(const typename Opr::Param& param,
create_cpu_handle
(
0
,
true
,
&
single_thread_config
);
create_cpu_handle
(
0
,
true
,
&
single_thread_config
);
auto
benchmarker
=
Benchmarker
<
Opr
>
(
single_thread_handle
.
get
());
auto
benchmarker
=
Benchmarker
<
Opr
>
(
single_thread_handle
.
get
());
benchmarker
.
set_times
(
RUNS
).
set_display
(
false
).
set_param
(
param
);
benchmarker
.
set_times
(
RUNS
).
set_display
(
false
).
set_param
(
param
);
benchmarker
.
set_dtype
(
0
,
data_type
);
for
(
auto
shape
:
shapes
)
{
for
(
auto
shape
:
shapes
)
{
single_thread_times
.
push_back
(
benchmarker
.
exec
(
shape
)
/
RUNS
);
single_thread_times
.
push_back
(
benchmarker
.
exec
(
shape
)
/
RUNS
);
}
}
...
@@ -540,10 +461,47 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_POOLING) {
...
@@ -540,10 +461,47 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_POOLING) {
param
.
stride_h
=
param
.
stride_w
=
2
;
param
.
stride_h
=
param
.
stride_w
=
2
;
param
.
pad_h
=
param
.
pad_w
=
1
;
param
.
pad_h
=
param
.
pad_w
=
1
;
printf
(
"Benchmark POOLING kernel:%d*%d stride:%d,mode %d
\n
"
,
param
.
window_h
,
printf
(
"Benchmark POOLING kernel:%d*%d stride:%d,mode %d
\n
"
,
param
.
window_h
,
param
.
stride_h
,
param
.
pad_h
,
static_cast
<
int
>
(
param
.
mode
));
param
.
window_w
,
param
.
stride_h
,
static_cast
<
int
>
(
param
.
mode
));
benchmark_impl
<
Pooling
>
(
param
,
shapes
,
RUNS
,
{
4
,
{
0
,
1
,
2
,
3
}},
{
1
,
{
0
}});
benchmark_impl
<
Pooling
>
(
param
,
shapes
,
RUNS
,
{
4
,
{
0
,
1
,
2
,
3
}},
{
1
,
{
0
}},
dtype
::
Float32
());
benchmark_impl
<
Pooling
>
(
param
,
shapes
,
RUNS
,
{
4
,
{
4
,
5
,
6
,
7
}},
{
1
,
{
4
}});
benchmark_impl
<
Pooling
>
(
param
,
shapes
,
RUNS
,
{
4
,
{
4
,
5
,
6
,
7
}},
{
1
,
{
4
}},
dtype
::
Float32
());
benchmark_impl
<
Pooling
>
(
param
,
shapes
,
RUNS
,
{
2
,
{
0
,
1
}},
{
1
,
{
0
}});
benchmark_impl
<
Pooling
>
(
param
,
shapes
,
RUNS
,
{
2
,
{
0
,
1
}},
{
1
,
{
0
}},
dtype
::
Float32
());
}
TEST_F
(
ARM_COMMON_BENCHMARK_MULTI_THREADS
,
BENCHMARK_POOLING_NCHW44
)
{
constexpr
size_t
RUNS
=
50
;
using
Param
=
param
::
Pooling
;
Param
param
;
param
.
pad_h
=
param
.
pad_w
=
0
;
param
.
mode
=
Param
::
Mode
::
MAX
;
std
::
vector
<
SmallVector
<
TensorShape
>>
shapes
;
std
::
vector
<
std
::
vector
<
size_t
>>
filter_and_stride
=
{
{
2
,
1
},
{
2
,
2
},
{
3
,
1
},
{
3
,
2
},
{
4
,
1
},
{
4
,
2
},
{
5
,
1
},
{
5
,
2
}};
for
(
auto
filter
:
filter_and_stride
){
shapes
.
push_back
({{
1
,
32
*
4
,
215
,
215
},
{}});
shapes
.
push_back
({{
1
,
32
*
4
,
128
,
128
},
{}});
shapes
.
push_back
({{
1
,
16
*
4
,
56
,
56
},
{}});
param
.
window_h
=
param
.
window_w
=
filter
[
0
];
param
.
stride_h
=
param
.
stride_w
=
filter
[
1
];
param
.
format
=
Param
::
Format
::
NCHW
;
printf
(
"NCHW Benchmark POOLING kernel:%d*%d stride:%d,mode %d
\n
"
,
param
.
window_h
,
param
.
window_h
,
param
.
stride_h
,
static_cast
<
int
>
(
param
.
mode
));
benchmark_impl
<
Pooling
>
(
param
,
shapes
,
RUNS
,
{
4
,
{
4
,
5
,
6
,
7
}},
{
1
,
{
4
}},
dtype
::
QuantizedS8
(
1.1
f
));
shapes
.
clear
();
shapes
.
push_back
({{
1
,
32
,
215
,
215
,
4
},
{}});
shapes
.
push_back
({{
1
,
32
,
128
,
128
,
4
},
{}});
shapes
.
push_back
({{
1
,
16
,
56
,
56
,
4
},
{}});
param
.
format
=
Param
::
Format
::
NCHW44
;
printf
(
"NCHW44 Benchmark POOLING kernel:%d*%d stride:%d,mode %d
\n
"
,
param
.
window_h
,
param
.
window_w
,
param
.
stride_h
,
static_cast
<
int
>
(
param
.
mode
));
benchmark_impl
<
Pooling
>
(
param
,
shapes
,
RUNS
,
{
4
,
{
4
,
5
,
6
,
7
}},
{
1
,
{
4
}},
dtype
::
QuantizedS8
(
1.1
f
));
shapes
.
clear
();
}
}
}
#endif
#endif
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录