Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2ae9fdef
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
2ae9fdef
编写于
5月 09, 2020
作者:
M
Megvii Engine Team
提交者:
Xinran Xu
5月 12, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/arm): add arm common nchw44 avg pooling
GitOrigin-RevId: 25eab33e14ef4000480acea17a49e4360d42bdd7
上级
0293d58a
变更
13
展开全部
隐藏空白更改
内联
并排
Showing
13 changed file
with
1214 addition
and
179 deletion
+1214
-179
dnn/src/arm_common/pooling/algo.cpp
dnn/src/arm_common/pooling/algo.cpp
+143
-78
dnn/src/arm_common/pooling/algo.h
dnn/src/arm_common/pooling/algo.h
+9
-9
dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp
dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp
+203
-3
dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.h
dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.h
+10
-11
dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.cpp
dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.cpp
+238
-3
dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.h
dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.h
+10
-11
dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.cpp
dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.cpp
+248
-2
dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.h
dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.h
+8
-7
dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp
dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp
+281
-3
dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.h
dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.h
+8
-7
dnn/src/arm_common/pooling/opr_impl.cpp
dnn/src/arm_common/pooling/opr_impl.cpp
+8
-8
dnn/src/arm_common/pooling/opr_impl.h
dnn/src/arm_common/pooling/opr_impl.h
+4
-4
dnn/test/arm_common/pooling_multi_thread.cpp
dnn/test/arm_common/pooling_multi_thread.cpp
+44
-33
未找到文件。
dnn/src/arm_common/pooling/algo.cpp
浏览文件 @
2ae9fdef
此差异已折叠。
点击以展开。
dnn/src/arm_common/pooling/algo.h
浏览文件 @
2ae9fdef
...
...
@@ -83,34 +83,34 @@ public:
void
exec
(
const
PoolingKernParam
&
param
)
const
override
;
};
class
PoolingImpl
::
AlgoFilter3M
a
xStridexNCHW44
final
:
public
AlgoBase
{
class
PoolingImpl
::
AlgoFilter3M
ode
xStridexNCHW44
final
:
public
AlgoBase
{
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"ARM_POOLING_FILTER3_M
A
X_STRIDEX_NCHW44"
;
}
const
char
*
name
()
const
override
{
return
"ARM_POOLING_FILTER3_M
ODE
X_STRIDEX_NCHW44"
;
}
bool
usable
(
const
PoolingKernSizeParam
&
param
)
const
override
;
void
exec
(
const
PoolingKernParam
&
param
)
const
override
;
};
class
PoolingImpl
::
AlgoFilter2M
a
xStridexNCHW44
final
:
public
AlgoBase
{
class
PoolingImpl
::
AlgoFilter2M
ode
xStridexNCHW44
final
:
public
AlgoBase
{
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"ARM_POOLING_FILTER2_M
A
X_STRIDEX_NCHW44"
;
}
const
char
*
name
()
const
override
{
return
"ARM_POOLING_FILTER2_M
ODE
X_STRIDEX_NCHW44"
;
}
bool
usable
(
const
PoolingKernSizeParam
&
param
)
const
override
;
void
exec
(
const
PoolingKernParam
&
param
)
const
override
;
};
class
PoolingImpl
::
AlgoFilter4M
a
xStridexNCHW44
final
:
public
AlgoBase
{
class
PoolingImpl
::
AlgoFilter4M
ode
xStridexNCHW44
final
:
public
AlgoBase
{
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"ARM_POOLING_FILTER4_M
A
X_STRIDEX_NCHW44"
;
}
const
char
*
name
()
const
override
{
return
"ARM_POOLING_FILTER4_M
ODE
X_STRIDEX_NCHW44"
;
}
bool
usable
(
const
PoolingKernSizeParam
&
param
)
const
override
;
void
exec
(
const
PoolingKernParam
&
param
)
const
override
;
};
class
PoolingImpl
::
AlgoFilter5M
a
xStridexNCHW44
final
:
public
AlgoBase
{
class
PoolingImpl
::
AlgoFilter5M
ode
xStridexNCHW44
final
:
public
AlgoBase
{
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"ARM_POOLING_FILTER5_M
A
X_STRIDEX_NCHW44"
;
}
const
char
*
name
()
const
override
{
return
"ARM_POOLING_FILTER5_M
ODE
X_STRIDEX_NCHW44"
;
}
bool
usable
(
const
PoolingKernSizeParam
&
param
)
const
override
;
void
exec
(
const
PoolingKernParam
&
param
)
const
override
;
};
...
...
@@ -122,7 +122,7 @@ WorkspaceBundle get_bundle_nchw44(
const
int8_t
*
handle_padding
(
const
int8_t
*
src
,
size_t
IH
,
size_t
IW
,
size_t
&
IH2
,
size_t
&
IW2
,
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
);
const
WorkspaceBundle
&
ws
,
bool
is_max_mode
);
}
// namespace arm_common
}
// namespace megdnn
...
...
dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp
浏览文件 @
2ae9fdef
/**
* \file dnn/src/arm_common/pooling/do_
max_
pooling_2x2_nchw44.cpp
* \file dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
...
...
@@ -24,7 +24,7 @@ void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
const
WorkspaceBundle
&
ws
)
{
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
);
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
,
true
);
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
;
...
...
@@ -70,7 +70,7 @@ void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
const
WorkspaceBundle
&
ws
)
{
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
);
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
,
true
);
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
<<
1
;
...
...
@@ -120,6 +120,206 @@ void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
}
}
void
do_avg_pooling_2x2_stride1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
)
{
int16_t
filter_size
=
4
;
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
,
false
);
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
;
const
int8_t
*
__restrict
sptr0
=
sptr
+
(
ih
+
0
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr1
=
sptr
+
(
ih
+
1
)
*
IW2
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
size_t
ow
=
0
;
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
int8x16_t
src0123
,
src1234
;
int16x8_t
src01
,
src23
,
src12
,
src34
;
int16x8_t
sum01
=
vdupq_n_s16
(
0
);
int16x8_t
sum23
=
vdupq_n_s16
(
0
);
#define CACULATE_ROW(i) \
src0123 = vld1q_s8(sptr##i); \
src1234 = vld1q_s8(sptr##i + 4); \
src01 = vmovl_s8(vget_low_s8(src0123)); \
src23 = vmovl_s8(vget_high_s8(src0123)); \
src12 = vmovl_s8(vget_low_s8(src1234)); \
src34 = vmovl_s8(vget_high_s8(src1234)); \
sum01 = vaddq_s16(sum01, src01); \
sum01 = vaddq_s16(sum01, src12); \
sum23 = vaddq_s16(sum23, src23); \
sum23 = vaddq_s16(sum23, src34);
UNROLL_CALL_NOWRAPPER
(
2
,
CACULATE_ROW
)
#define sum_define(i) int16_t sum##i;
UNROLL_CALL_NOWRAPPER
(
8
,
sum_define
)
#define sum01_avg(i) \
sum##i = vgetq_lane_s16(sum01, i) > 0 \
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \
filter_size;
#define sum23_avg(i) \
sum##i = vgetq_lane_s16(sum23, i) > 0 \
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \
filter_size;
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i);
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER
(
8
,
sum01_avg
)
UNROLL_CALL_NOWRAPPER
(
8
,
store_sum01
)
UNROLL_CALL_NOWRAPPER
(
8
,
sum23_avg
)
UNROLL_CALL_NOWRAPPER
(
8
,
store_sum23
)
sptr0
+=
16
;
sptr1
+=
16
;
dptr
+=
16
;
#undef store_sum01
#undef store_sum23
#undef sum01_avg
#undef sum23_avg
#undef sum_define
#undef CACULATE_ROW
}
for
(;
ow
<
OW
;
++
ow
)
{
int8x8_t
src001
=
vld1_s8
(
sptr0
);
int8x8_t
src101
=
vld1_s8
(
sptr1
);
int16x8_t
src00
=
vmovl_s8
(
src001
);
int16x8_t
src10
=
vmovl_s8
(
src101
);
int16x8_t
max_tmp
=
vaddq_s16
(
src00
,
src10
);
#define do_acc(i) \
int16_t sum##i = \
vgetq_lane_s16(max_tmp, i) + vgetq_lane_s16(max_tmp, i + 4);
#define do_avg(i) \
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \
: (sum##i - filter_size / 2) / filter_size;
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER
(
4
,
do_acc
)
UNROLL_CALL_NOWRAPPER
(
4
,
do_avg
)
UNROLL_CALL_NOWRAPPER
(
4
,
store
)
#undef store
#undef do_avg
#undef do_acc
sptr0
+=
4
;
sptr1
+=
4
;
dptr
+=
4
;
}
}
}
void
do_avg_pooling_2x2_stride2_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
)
{
int16_t
filter_size
=
4
;
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
,
false
);
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
<<
1
;
const
int8_t
*
__restrict
sptr0
=
sptr
+
(
ih
+
0
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr1
=
sptr
+
(
ih
+
1
)
*
IW2
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
size_t
ow
=
0
;
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
int32x4x2_t
src_tmp
;
int8x16_t
src00
,
src04
;
int32x4_t
src0246
,
src1357
;
int16x8_t
src02
,
src13
,
src46
,
src57
;
int16x8_t
sum01
=
vdupq_n_s16
(
0
);
int16x8_t
sum23
=
vdupq_n_s16
(
0
);
#define CACULATE_ROW(i) \
src00 = vld1q_s8(sptr##i); \
src04 = vld1q_s8(sptr##i + 4 * 4); \
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \
vreinterpretq_s32_s8(src04)); \
src0246 = src_tmp.val[0]; \
src1357 = src_tmp.val[1]; \
src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \
src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \
src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \
src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \
sum01 = vaddq_s16(sum01, src02); \
sum01 = vaddq_s16(sum01, src13); \
sum23 = vaddq_s16(sum23, src46); \
sum23 = vaddq_s16(sum23, src57);
UNROLL_CALL_NOWRAPPER
(
2
,
CACULATE_ROW
)
#define sum_define(i) int16_t sum##i;
UNROLL_CALL_NOWRAPPER
(
8
,
sum_define
)
#define sum01_avg(i) \
sum##i = vgetq_lane_s16(sum01, i) > 0 \
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \
filter_size;
#define sum23_avg(i) \
sum##i = vgetq_lane_s16(sum23, i) > 0 \
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \
filter_size;
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i);
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER
(
8
,
sum01_avg
)
UNROLL_CALL_NOWRAPPER
(
8
,
store_sum01
)
UNROLL_CALL_NOWRAPPER
(
8
,
sum23_avg
)
UNROLL_CALL_NOWRAPPER
(
8
,
store_sum23
)
sptr0
+=
32
;
sptr1
+=
32
;
dptr
+=
16
;
#undef store_sum01
#undef store_sum23
#undef sum01_avg
#undef sum23_avg
#undef sum_define
#undef CACULATE_ROW
}
for
(;
ow
<
OW
;
++
ow
)
{
int8x8_t
src001
=
vld1_s8
(
sptr0
);
int8x8_t
src101
=
vld1_s8
(
sptr1
);
int16x8_t
src00
=
vmovl_s8
(
src001
);
int16x8_t
src10
=
vmovl_s8
(
src101
);
int16x8_t
max_tmp
=
vaddq_s16
(
src00
,
src10
);
#define do_acc(i) \
int16_t sum##i = \
vgetq_lane_s16(max_tmp, i) + vgetq_lane_s16(max_tmp, i + 4);
#define do_avg(i) \
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \
: (sum##i - filter_size / 2) / filter_size;
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER
(
4
,
do_acc
)
UNROLL_CALL_NOWRAPPER
(
4
,
do_avg
)
UNROLL_CALL_NOWRAPPER
(
4
,
store
)
#undef do_avg
#undef do_acc
#undef store
sptr0
+=
8
;
sptr1
+=
8
;
dptr
+=
4
;
}
}
}
}
// namespace arm_common
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.h
浏览文件 @
2ae9fdef
/**
* \file dnn/src/arm_common/pooling/do_
max_
pooling_2x2_nchw44.h
* \file dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
...
...
@@ -15,16 +15,15 @@
namespace
megdnn
{
namespace
arm_common
{
void
do_max_pooling_2x2_stride1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
);
void
do_max_pooling_2x2_stride2_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
);
// Declares one 2x2 pooling kernel per (mode, stride, ctype) combination.
// Each expansion yields a prototype named
// do_<mode>_pooling_2x2_<stride>_<ctype>_nchw44_NEON; all of them share the
// same parameter list (src, dst, IH, IW, OH, OW, PH, PW, ws).
#define KERN(mode, stride, ctype) \
void do_##mode##_pooling_2x2_##stride##_##ctype##_nchw44_NEON( \
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \
size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws);
// Max-pooling kernels, stride 1 and stride 2.
KERN
(
max
,
stride1
,
int8
)
KERN
(
max
,
stride2
,
int8
)
// Average-pooling kernels, stride 1 and stride 2.
KERN
(
avg
,
stride1
,
int8
)
KERN
(
avg
,
stride2
,
int8
)
// KERN is only a local declaration helper; do not let it escape the header.
#undef KERN
}
// namespace arm_common
}
// namespace megdnn
...
...
dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.cpp
浏览文件 @
2ae9fdef
/**
* \file dnn/src/arm_common/pooling/do_
max_pooling_3x3_s2x2
_nchw44.cpp
* \file dnn/src/arm_common/pooling/do_
pooling_3x3
_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
...
...
@@ -24,7 +24,7 @@ void do_max_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
const
WorkspaceBundle
&
ws
)
{
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
);
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
,
true
);
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
;
...
...
@@ -99,7 +99,7 @@ void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
const
WorkspaceBundle
&
ws
)
{
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
);
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
,
true
);
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
<<
1
;
...
...
@@ -190,6 +190,241 @@ void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
}
}
void
do_avg_pooling_3x3_stride1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
)
{
int16_t
filter_size
=
9
;
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
,
false
);
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
;
const
int8_t
*
__restrict
sptr0
=
sptr
+
(
ih
+
0
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr1
=
sptr
+
(
ih
+
1
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr2
=
sptr
+
(
ih
+
2
)
*
IW2
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
size_t
ow
=
0
;
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
int8x16_t
src0123
,
src1234
,
src2345
;
int16x8_t
src01
,
src23
,
src12
,
src34
,
src45
;
int16x8_t
sum01
=
vdupq_n_s16
(
0
);
int16x8_t
sum23
=
vdupq_n_s16
(
0
);
#define CACULATE_ROW(i) \
src0123 = vld1q_s8(sptr##i); \
src1234 = vld1q_s8(sptr##i + 4); \
src2345 = vld1q_s8(sptr##i + 8); \
src01 = vmovl_s8(vget_low_s8(src0123)); \
src23 = vmovl_s8(vget_high_s8(src0123)); \
src12 = vmovl_s8(vget_low_s8(src1234)); \
src34 = vmovl_s8(vget_high_s8(src1234)); \
src45 = vmovl_s8(vget_high_s8(src2345)); \
sum01 = vaddq_s16(sum01, src01); \
sum01 = vaddq_s16(sum01, src12); \
sum01 = vaddq_s16(sum01, src23); \
sum23 = vaddq_s16(sum23, src23); \
sum23 = vaddq_s16(sum23, src34); \
sum23 = vaddq_s16(sum23, src45);
UNROLL_CALL_NOWRAPPER
(
3
,
CACULATE_ROW
)
#define sum_define(i) int16_t sum##i;
UNROLL_CALL_NOWRAPPER
(
8
,
sum_define
)
#define sum01_avg(i) \
sum##i = vgetq_lane_s16(sum01, i) > 0 \
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \
filter_size;
#define sum23_avg(i) \
sum##i = vgetq_lane_s16(sum23, i) > 0 \
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \
filter_size;
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i);
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER
(
8
,
sum01_avg
)
UNROLL_CALL_NOWRAPPER
(
8
,
store_sum01
)
UNROLL_CALL_NOWRAPPER
(
8
,
sum23_avg
)
UNROLL_CALL_NOWRAPPER
(
8
,
store_sum23
)
sptr0
+=
16
;
sptr1
+=
16
;
sptr2
+=
16
;
dptr
+=
16
;
#undef store_sum01
#undef store_sum23
#undef sum01_avg
#undef sum23_avg
#undef sum_define
#undef CACULATE_ROW
}
for
(;
ow
<
OW
;
++
ow
)
{
int8x8_t
src001
,
src012
;
int16x8_t
src01
,
src12
,
sum01
,
sum02
;
sum01
=
vdupq_n_s16
(
0
);
sum02
=
vdupq_n_s16
(
0
);
#define CACULATE_ROW(i) \
src001 = vld1_s8(sptr##i); \
src012 = vld1_s8(sptr##i + 4); \
src01 = vmovl_s8(src001); \
src12 = vmovl_s8(src012); \
sum01 = vaddq_s16(sum01, src01); \
sum02 = vaddq_s16(sum02, src12);
UNROLL_CALL_NOWRAPPER
(
3
,
CACULATE_ROW
)
#define do_acc(i) \
int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4) + \
vgetq_lane_s16(sum02, i + 4);
#define do_avg(i) \
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \
: (sum##i - filter_size / 2) / filter_size;
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER
(
4
,
do_acc
)
UNROLL_CALL_NOWRAPPER
(
4
,
do_avg
)
UNROLL_CALL_NOWRAPPER
(
4
,
store
)
#undef store
#undef do_avg
#undef do_acc
#undef CACULATE_ROW
sptr0
+=
4
;
sptr1
+=
4
;
sptr2
+=
4
;
dptr
+=
4
;
}
}
}
/**
 * Average pooling with a 3x3 window and stride 2 over int8 data (NEON).
 *
 * Every pixel occupies 4 consecutive int8 values, which is why all pointer
 * offsets below are scaled by 4. Window sums are accumulated in int16 lanes
 * and divided by the window area, rounding to nearest with ties away from
 * zero.
 *
 * \param src input buffer, forwarded to handle_padding()
 * \param dst output buffer, written as OH rows of OW * 4 int8 values
 * \param IH input height, forwarded to handle_padding()
 * \param IW input width, forwarded to handle_padding()
 * \param OH output height (number of output rows produced)
 * \param OW output width (number of output pixels per row)
 * \param PH vertical padding, forwarded to handle_padding()
 * \param PW horizontal padding, forwarded to handle_padding()
 * \param ws workspace bundle, forwarded to handle_padding()
 */
void do_avg_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
                                                 size_t IH, size_t IW,
                                                 size_t OH, size_t OW,
                                                 size_t PH, size_t PW,
                                                 const WorkspaceBundle& ws) {
    // Number of taps in the 3x3 window, i.e. the divisor of the average.
    int16_t filter_size = 9;
    const int8_t* sptr = nullptr;
    size_t IH2, IW2;
    // The trailing `false` is the is_max_mode flag of handle_padding().
    // NOTE(review): handle_padding() is defined elsewhere; presumably it
    // returns a padded copy of src whose row width is IW2 - confirm.
    sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false);
    size_t oh = 0;
    for (; oh < OH; ++oh) {
        // Stride 2: output row oh reads input rows 2 * oh .. 2 * oh + 2.
        size_t ih = oh << 1;
        const int8_t* sptr0 = sptr + (ih + 0) * IW2 * 4;
        const int8_t* sptr1 = sptr + (ih + 1) * IW2 * 4;
        const int8_t* sptr2 = sptr + (ih + 2) * IW2 * 4;
        int8_t* __restrict dptr = dst + oh * OW * 4;
        size_t ow = 0;
        // Main loop: 4 output pixels per iteration, reading input pixels
        // 0..8 of each row.
        for (; ow + 3 < OW; ow += 4) {
            int32x4x2_t src_tmp;
            int8x16_t src00, src04;
            int32x4_t src0246, src1357, src2468, src08;
            int16x8_t src02, src46, src13, src57, src24, src68;
            // sum01 holds output pixels 0 and 1, sum23 holds pixels 2 and 3
            // (4 channel lanes each).
            int16x8_t sum01 = vdupq_n_s16(0);
            int16x8_t sum23 = vdupq_n_s16(0);
            // vuzpq_s32 treats each 4-byte pixel as one 32-bit lane and
            // splits pixels 0..7 into even/odd groups; vextq_s32 shifts the
            // even group by one pixel and appends pixel 8 to form 2,4,6,8.
#define CACULATE_ROW(i)                                                        \
    src00 = vld1q_s8(sptr##i);                                                 \
    src04 = vld1q_s8(sptr##i + 4 * 4);                                         \
    src08 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 8));  \
    src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00),                           \
                        vreinterpretq_s32_s8(src04));                          \
    src0246 = src_tmp.val[0];                                                  \
    src1357 = src_tmp.val[1];                                                  \
    src2468 = vextq_s32(src0246, src08, 1);                                    \
    src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246)));              \
    src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246)));             \
    src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357)));              \
    src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357)));             \
    src24 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src2468)));              \
    src68 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src2468)));             \
    sum01 = vaddq_s16(sum01, src02);                                           \
    sum01 = vaddq_s16(sum01, src13);                                           \
    sum01 = vaddq_s16(sum01, src24);                                           \
    sum23 = vaddq_s16(sum23, src46);                                           \
    sum23 = vaddq_s16(sum23, src57);                                           \
    sum23 = vaddq_s16(sum23, src68);

            UNROLL_CALL_NOWRAPPER(3, CACULATE_ROW)

#define sum_define(i) int16_t sum##i;
            UNROLL_CALL_NOWRAPPER(8, sum_define)

            // Divide by filter_size, rounding to nearest with ties away
            // from zero.
#define sum01_avg(i)                                                \
    sum##i = vgetq_lane_s16(sum01, i) > 0                           \
                     ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \
                               filter_size                          \
                     : (vgetq_lane_s16(sum01, i) - filter_size / 2) / \
                               filter_size;
#define sum23_avg(i)                                                \
    sum##i = vgetq_lane_s16(sum23, i) > 0                           \
                     ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \
                               filter_size                          \
                     : (vgetq_lane_s16(sum23, i) - filter_size / 2) / \
                               filter_size;
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i);
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i);

            // The sum0..sum7 scalars are reused: first for sum01 (outputs
            // 0..7), then for sum23 (outputs 8..15).
            UNROLL_CALL_NOWRAPPER(8, sum01_avg)
            UNROLL_CALL_NOWRAPPER(8, store_sum01)

            UNROLL_CALL_NOWRAPPER(8, sum23_avg)
            UNROLL_CALL_NOWRAPPER(8, store_sum23)

            sptr0 += 32;
            sptr1 += 32;
            sptr2 += 32;
            dptr += 16;
#undef store_sum01
#undef store_sum23
#undef sum01_avg
#undef sum23_avg
#undef sum_define
#undef CACULATE_ROW
        }
        // Tail: one output pixel per iteration. sum01 sums pixels 0 and 1 of
        // every row, sum02 sums pixels 1 and 2; only the upper half of sum02
        // (pixel 2) is used so pixel 1 is not counted twice.
        for (; ow < OW; ++ow) {
            int8x8_t src001, src012;
            int16x8_t src01, src12, sum01, sum02;
            sum01 = vdupq_n_s16(0);
            sum02 = vdupq_n_s16(0);

#define CACULATE_ROW(i)              \
    src001 = vld1_s8(sptr##i);       \
    src012 = vld1_s8(sptr##i + 4);   \
    src01 = vmovl_s8(src001);        \
    src12 = vmovl_s8(src012);        \
    sum01 = vaddq_s16(sum01, src01); \
    sum02 = vaddq_s16(sum02, src12);

            UNROLL_CALL_NOWRAPPER(3, CACULATE_ROW)

#define do_acc(i)                                                               \
    int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4) + \
                     vgetq_lane_s16(sum02, i + 4);
#define do_avg(i)                                                  \
    sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \
                        : (sum##i - filter_size / 2) / filter_size;
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i);

            UNROLL_CALL_NOWRAPPER(4, do_acc)
            UNROLL_CALL_NOWRAPPER(4, do_avg)
            UNROLL_CALL_NOWRAPPER(4, store)
#undef store
#undef do_avg
#undef do_acc
            // Fix: the original never undefined the tail-loop CACULATE_ROW,
            // leaking the macro past this function.
#undef CACULATE_ROW
            // Stride 2: skip two input pixels per output pixel.
            sptr0 += 8;
            sptr1 += 8;
            sptr2 += 8;
            dptr += 4;
        }
    }
}
}
// namespace arm_common
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.h
浏览文件 @
2ae9fdef
/**
* \file dnn/src/arm_common/pooling/do_
max_pooling_3x3_s2x2
_nchw44.h
* \file dnn/src/arm_common/pooling/do_
pooling_3x3
_nchw44.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
...
...
@@ -15,16 +15,15 @@
namespace
megdnn
{
namespace
arm_common
{
void
do_max_pooling_3x3_stride1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
);
void
do_max_pooling_3x3_stride2_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
);
// Declares one 3x3 pooling kernel per (mode, stride, ctype) combination.
// Each expansion yields a prototype named
// do_<mode>_pooling_3x3_<stride>_<ctype>_nchw44_NEON; all of them share the
// same parameter list (src, dst, IH, IW, OH, OW, PH, PW, ws).
#define KERN(mode, stride, ctype) \
void do_##mode##_pooling_3x3_##stride##_##ctype##_nchw44_NEON( \
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \
size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws);
// Max-pooling kernels, stride 1 and stride 2.
KERN
(
max
,
stride1
,
int8
)
KERN
(
max
,
stride2
,
int8
)
// Average-pooling kernels, stride 1 and stride 2.
KERN
(
avg
,
stride1
,
int8
)
KERN
(
avg
,
stride2
,
int8
)
// KERN is only a local declaration helper; do not let it escape the header.
#undef KERN
}
// namespace arm_common
}
// namespace megdnn
...
...
dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.cpp
浏览文件 @
2ae9fdef
...
...
@@ -24,7 +24,7 @@ void do_max_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
const
WorkspaceBundle
&
ws
)
{
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
);
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
,
true
);
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
;
...
...
@@ -99,7 +99,7 @@ void do_max_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
const
WorkspaceBundle
&
ws
)
{
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
);
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
,
true
);
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
<<
1
;
...
...
@@ -171,6 +171,252 @@ void do_max_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
}
}
void
do_avg_pooling_4x4_stride1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
)
{
int16_t
filter_size
=
16
;
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
,
false
);
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
;
const
int8_t
*
__restrict
sptr0
=
sptr
+
(
ih
+
0
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr1
=
sptr
+
(
ih
+
1
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr2
=
sptr
+
(
ih
+
2
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr3
=
sptr
+
(
ih
+
3
)
*
IW2
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
size_t
ow
=
0
;
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
int16x8_t
src01
,
src23
,
src12
,
src34
,
src45
,
src56
;
int16x8_t
sum01
=
vdupq_n_s16
(
0
);
int16x8_t
sum23
=
vdupq_n_s16
(
0
);
#define CACULATE_ROW(i) \
src01 = vmovl_s8(vld1_s8(sptr##i)); \
src23 = vmovl_s8(vld1_s8(sptr##i + 8)); \
src12 = vmovl_s8(vld1_s8(sptr##i + 4)); \
src34 = vmovl_s8(vld1_s8(sptr##i + 12)); \
src45 = vmovl_s8(vld1_s8(sptr##i + 16)); \
src56 = vmovl_s8(vld1_s8(sptr##i + 20)); \
sum01 = vaddq_s16(sum01, src01); \
sum01 = vaddq_s16(sum01, src12); \
sum01 = vaddq_s16(sum01, src23); \
sum01 = vaddq_s16(sum01, src34); \
sum23 = vaddq_s16(sum23, src23); \
sum23 = vaddq_s16(sum23, src34); \
sum23 = vaddq_s16(sum23, src45); \
sum23 = vaddq_s16(sum23, src56);
UNROLL_CALL_NOWRAPPER
(
4
,
CACULATE_ROW
)
#define sum_define(i) int16_t sum##i;
UNROLL_CALL_NOWRAPPER
(
8
,
sum_define
)
#define sum01_avg(i) \
sum##i = vgetq_lane_s16(sum01, i) > 0 \
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \
filter_size;
#define sum23_avg(i) \
sum##i = vgetq_lane_s16(sum23, i) > 0 \
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \
filter_size;
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i);
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER
(
8
,
sum01_avg
)
UNROLL_CALL_NOWRAPPER
(
8
,
store_sum01
)
UNROLL_CALL_NOWRAPPER
(
8
,
sum23_avg
)
UNROLL_CALL_NOWRAPPER
(
8
,
store_sum23
)
sptr0
+=
16
;
sptr1
+=
16
;
sptr2
+=
16
;
sptr3
+=
16
;
dptr
+=
16
;
#undef store_sum01
#undef store_sum23
#undef sum01_avg
#undef sum23_avg
#undef sum_define
#undef CACULATE_ROW
}
for
(;
ow
<
OW
;
++
ow
)
{
int16x8_t
src01
,
src23
,
sum01
;
sum01
=
vdupq_n_s16
(
0
);
#define CACULATE_ROW(i) \
src01 = vmovl_s8(vld1_s8(sptr##i)); \
src23 = vmovl_s8(vld1_s8(sptr##i + 8)); \
sum01 = vaddq_s16(sum01, src01); \
sum01 = vaddq_s16(sum01, src23);
UNROLL_CALL_NOWRAPPER
(
4
,
CACULATE_ROW
)
#define do_acc(i) \
int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4);
#define do_avg(i) \
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \
: (sum##i - filter_size / 2) / filter_size;
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER
(
4
,
do_acc
)
UNROLL_CALL_NOWRAPPER
(
4
,
do_avg
)
UNROLL_CALL_NOWRAPPER
(
4
,
store
)
#undef store
#undef do_avg
#undef do_acc
#undef CACULATE_ROW
sptr0
+=
4
;
sptr1
+=
4
;
sptr2
+=
4
;
sptr3
+=
4
;
dptr
+=
4
;
}
}
}
void
do_avg_pooling_4x4_stride2_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
)
{
int16_t
filter_size
=
16
;
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
,
false
);
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
<<
1
;
const
int8_t
*
sptr0
=
sptr
+
(
ih
+
0
)
*
IW2
*
4
;
const
int8_t
*
sptr1
=
sptr
+
(
ih
+
1
)
*
IW2
*
4
;
const
int8_t
*
sptr2
=
sptr
+
(
ih
+
2
)
*
IW2
*
4
;
const
int8_t
*
sptr3
=
sptr
+
(
ih
+
3
)
*
IW2
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
size_t
ow
=
0
;
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
int32x4x2_t
src_tmp
;
int8x16_t
src00
,
src04
;
int16x8_t
src02
,
src13
,
src57
,
src24
,
src68
,
src35
,
src79
,
src46
;
int32x4_t
src08
,
src09
,
src0246
,
src1357
,
src2468
,
src3579
;
int16x8_t
sum01
=
vdupq_n_s16
(
0
);
int16x8_t
sum23
=
vdupq_n_s16
(
0
);
#define CACULATE_ROW(i) \
src00 = vld1q_s8(sptr##i); \
src04 = vld1q_s8(sptr##i + 4 * 4); \
src08 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 8)); \
src09 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 9)); \
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \
vreinterpretq_s32_s8(src04)); \
src0246 = src_tmp.val[0]; \
src1357 = src_tmp.val[1]; \
src2468 = vextq_s32(src0246, src08, 1); \
src3579 = vextq_s32(src1357, src09, 1); \
src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \
src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \
src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \
src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \
src24 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src2468))); \
src68 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src2468))); \
src35 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src3579))); \
src79 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src3579))); \
sum01 = vaddq_s16(sum01, src02); \
sum01 = vaddq_s16(sum01, src13); \
sum01 = vaddq_s16(sum01, src24); \
sum01 = vaddq_s16(sum01, src35); \
sum23 = vaddq_s16(sum23, src46); \
sum23 = vaddq_s16(sum23, src57); \
sum23 = vaddq_s16(sum23, src68); \
sum23 = vaddq_s16(sum23, src79);
UNROLL_CALL_NOWRAPPER
(
4
,
CACULATE_ROW
)
#define sum_define(i) int16_t sum##i;
UNROLL_CALL_NOWRAPPER
(
8
,
sum_define
)
#define sum01_avg(i) \
sum##i = vgetq_lane_s16(sum01, i) > 0 \
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \
filter_size;
#define sum23_avg(i) \
sum##i = vgetq_lane_s16(sum23, i) > 0 \
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \
filter_size;
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i);
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER
(
8
,
sum01_avg
)
UNROLL_CALL_NOWRAPPER
(
8
,
store_sum01
)
UNROLL_CALL_NOWRAPPER
(
8
,
sum23_avg
)
UNROLL_CALL_NOWRAPPER
(
8
,
store_sum23
)
sptr0
+=
32
;
sptr1
+=
32
;
sptr2
+=
32
;
sptr3
+=
32
;
dptr
+=
16
;
#undef store_sum01
#undef store_sum23
#undef sum01_avg
#undef sum23_avg
#undef sum_define
#undef CACULATE_ROW
}
for
(;
ow
<
OW
;
++
ow
)
{
int8x8_t
src001
,
src023
;
int16x8_t
src01
,
src23
,
sum01
;
sum01
=
vdupq_n_s16
(
0
);
#define CACULATE_ROW(i) \
src001 = vld1_s8(sptr##i); \
src023 = vld1_s8(sptr##i + 8); \
src01 = vmovl_s8(src001); \
src23 = vmovl_s8(src023); \
sum01 = vaddq_s16(sum01, src01); \
sum01 = vaddq_s16(sum01, src23);
UNROLL_CALL_NOWRAPPER
(
4
,
CACULATE_ROW
)
#define do_acc(i) \
int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4);
#define do_avg(i) \
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \
: (sum##i - filter_size / 2) / filter_size;
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER
(
4
,
do_acc
)
UNROLL_CALL_NOWRAPPER
(
4
,
do_avg
)
UNROLL_CALL_NOWRAPPER
(
4
,
store
)
#undef store
#undef do_avg
#undef do_acc
#undef CACULATE_ROW
sptr0
+=
8
;
sptr1
+=
8
;
sptr2
+=
8
;
sptr3
+=
8
;
dptr
+=
4
;
}
}
}
}
// namespace arm_common
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.h
浏览文件 @
2ae9fdef
/**
* \file dnn/src/arm_common/pooling/do_
max_
pooling_4x4_nchw44.h
* \file dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
...
...
@@ -15,15 +15,16 @@
namespace
megdnn
{
namespace
arm_common
{
#define KERN(
strdie)
\
void do_
max_pooling_4x4_##strdie##_int8_nchw44_NEON(
\
#define KERN(
mode, stride, ctype)
\
void do_
##mode##_pooling_4x4_##stride##_##ctype##_nchw44_NEON(
\
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \
size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws);
KERN
(
stride1
)
KERN
(
stride2
)
KERN
(
max
,
stride1
,
int8
)
KERN
(
max
,
stride2
,
int8
)
KERN
(
avg
,
stride1
,
int8
)
KERN
(
avg
,
stride2
,
int8
)
#undef KERN
}
// namespace arm_common
}
// namespace megdnn
...
...
dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp
浏览文件 @
2ae9fdef
/**
* \file dnn/src/arm_common/pooling/do_
max_
pooling_5x5_nchw44.cpp
* \file dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
...
...
@@ -24,7 +24,7 @@ void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
const
WorkspaceBundle
&
ws
)
{
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
);
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
,
true
);
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
;
...
...
@@ -118,7 +118,7 @@ void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
const
WorkspaceBundle
&
ws
)
{
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
);
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
,
true
);
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
<<
1
;
...
...
@@ -213,6 +213,284 @@ void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
}
}
void
do_avg_pooling_5x5_stride1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
)
{
int16_t
filter_size
=
25
;
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
,
false
);
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
;
const
int8_t
*
__restrict
sptr0
=
sptr
+
(
ih
+
0
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr1
=
sptr
+
(
ih
+
1
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr2
=
sptr
+
(
ih
+
2
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr3
=
sptr
+
(
ih
+
3
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr4
=
sptr
+
(
ih
+
4
)
*
IW2
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
size_t
ow
=
0
;
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
int16x8_t
src01
,
src23
,
src12
,
src34
,
src45
,
src56
,
src67
;
int16x8_t
sum01
=
vdupq_n_s16
(
0
);
int16x8_t
sum23
=
vdupq_n_s16
(
0
);
#define CACULATE_ROW(i) \
src01 = vmovl_s8(vld1_s8(sptr##i)); \
src23 = vmovl_s8(vld1_s8(sptr##i + 8)); \
src12 = vmovl_s8(vld1_s8(sptr##i + 4)); \
src34 = vmovl_s8(vld1_s8(sptr##i + 12)); \
src45 = vmovl_s8(vld1_s8(sptr##i + 16)); \
src56 = vmovl_s8(vld1_s8(sptr##i + 20)); \
src67 = vmovl_s8(vld1_s8(sptr##i + 24)); \
sum01 = vaddq_s16(sum01, src01); \
sum01 = vaddq_s16(sum01, src12); \
sum01 = vaddq_s16(sum01, src23); \
sum01 = vaddq_s16(sum01, src34); \
sum01 = vaddq_s16(sum01, src45); \
sum23 = vaddq_s16(sum23, src23); \
sum23 = vaddq_s16(sum23, src34); \
sum23 = vaddq_s16(sum23, src45); \
sum23 = vaddq_s16(sum23, src56); \
sum23 = vaddq_s16(sum23, src67);
UNROLL_CALL_NOWRAPPER
(
5
,
CACULATE_ROW
)
#define sum_define(i) int16_t sum##i;
UNROLL_CALL_NOWRAPPER
(
8
,
sum_define
)
#define sum01_avg(i) \
sum##i = vgetq_lane_s16(sum01, i) > 0 \
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \
filter_size;
#define sum23_avg(i) \
sum##i = vgetq_lane_s16(sum23, i) > 0 \
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \
filter_size;
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i);
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER
(
8
,
sum01_avg
)
UNROLL_CALL_NOWRAPPER
(
8
,
store_sum01
)
UNROLL_CALL_NOWRAPPER
(
8
,
sum23_avg
)
UNROLL_CALL_NOWRAPPER
(
8
,
store_sum23
)
sptr0
+=
16
;
sptr1
+=
16
;
sptr2
+=
16
;
sptr3
+=
16
;
sptr4
+=
16
;
dptr
+=
16
;
#undef store_sum01
#undef store_sum23
#undef sum01_avg
#undef sum23_avg
#undef sum_define
#undef CACULATE_ROW
}
for
(;
ow
<
OW
;
++
ow
)
{
int32x2_t
src004
;
int8x8_t
src001
,
src023
;
int16x8_t
src01
,
src23
,
src04
,
sum01
,
sum02
;
sum01
=
vdupq_n_s16
(
0
);
sum02
=
vdupq_n_s16
(
0
);
#define CACULATE_ROW(i) \
src001 = vld1_s8(sptr##i); \
src023 = vld1_s8(sptr##i + 8); \
src004 = vld1_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 4)); \
src01 = vmovl_s8(src001); \
src23 = vmovl_s8(src023); \
src04 = vmovl_s8(vreinterpret_s8_s32(src004)); \
sum01 = vaddq_s16(sum01, src01); \
sum01 = vaddq_s16(sum01, src23); \
sum02 = vaddq_s16(sum02, src04);
UNROLL_CALL_NOWRAPPER
(
5
,
CACULATE_ROW
)
#define do_acc(i) \
int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4) + \
vgetq_lane_s16(sum02, i + 4);
#define do_avg(i) \
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \
: (sum##i - filter_size / 2) / filter_size;
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER
(
4
,
do_acc
)
UNROLL_CALL_NOWRAPPER
(
4
,
do_avg
)
UNROLL_CALL_NOWRAPPER
(
4
,
store
)
#undef store
#undef do_avg
#undef do_acc
#undef CACULATE_ROW
sptr0
+=
4
;
sptr1
+=
4
;
sptr2
+=
4
;
sptr3
+=
4
;
sptr4
+=
4
;
dptr
+=
4
;
}
}
}
void
do_avg_pooling_5x5_stride2_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
,
const
WorkspaceBundle
&
ws
)
{
int16_t
filter_size
=
25
;
const
int8_t
*
sptr
=
nullptr
;
size_t
IH2
,
IW2
;
sptr
=
handle_padding
(
src
,
IH
,
IW
,
IH2
,
IW2
,
PH
,
PW
,
ws
,
false
);
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
<<
1
;
const
int8_t
*
__restrict
sptr0
=
sptr
+
(
ih
+
0
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr1
=
sptr
+
(
ih
+
1
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr2
=
sptr
+
(
ih
+
2
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr3
=
sptr
+
(
ih
+
3
)
*
IW2
*
4
;
const
int8_t
*
__restrict
sptr4
=
sptr
+
(
ih
+
4
)
*
IW2
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
size_t
ow
=
0
;
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
int32x4x2_t
src_tmp
;
int8x16_t
src00
,
src04
;
int16x8_t
src02
,
src13
,
src57
,
src24
,
src68
,
src35
,
src79
,
src46
,
src810
;
int32x4_t
src08
,
src09
,
src10
,
src0246
,
src1357
,
src2468
,
src3579
,
src46810
;
int16x8_t
sum01
=
vdupq_n_s16
(
0
);
int16x8_t
sum23
=
vdupq_n_s16
(
0
);
#define CACULATE_ROW(i) \
src00 = vld1q_s8(sptr##i); \
src04 = vld1q_s8(sptr##i + 4 * 4); \
src08 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 8)); \
src09 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 9)); \
src10 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 10)); \
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \
vreinterpretq_s32_s8(src04)); \
src0246 = src_tmp.val[0]; \
src1357 = src_tmp.val[1]; \
src2468 = vextq_s32(src0246, src08, 1); \
src3579 = vextq_s32(src1357, src09, 1); \
src46810 = vextq_s32(src2468, src10, 1); \
src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \
src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \
src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \
src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \
src24 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src2468))); \
src68 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src2468))); \
src35 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src3579))); \
src79 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src3579))); \
src46 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src46810))); \
src810 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src46810))); \
sum01 = vaddq_s16(sum01, src02); \
sum01 = vaddq_s16(sum01, src13); \
sum01 = vaddq_s16(sum01, src24); \
sum01 = vaddq_s16(sum01, src35); \
sum01 = vaddq_s16(sum01, src46); \
sum23 = vaddq_s16(sum23, src46); \
sum23 = vaddq_s16(sum23, src57); \
sum23 = vaddq_s16(sum23, src68); \
sum23 = vaddq_s16(sum23, src79); \
sum23 = vaddq_s16(sum23, src810);
UNROLL_CALL_NOWRAPPER
(
5
,
CACULATE_ROW
)
#define sum_define(i) int16_t sum##i;
UNROLL_CALL_NOWRAPPER
(
8
,
sum_define
)
#define sum01_avg(i) \
sum##i = vgetq_lane_s16(sum01, i) > 0 \
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \
filter_size;
#define sum23_avg(i) \
sum##i = vgetq_lane_s16(sum23, i) > 0 \
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \
filter_size;
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i);
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER
(
8
,
sum01_avg
)
UNROLL_CALL_NOWRAPPER
(
8
,
store_sum01
)
UNROLL_CALL_NOWRAPPER
(
8
,
sum23_avg
)
UNROLL_CALL_NOWRAPPER
(
8
,
store_sum23
)
sptr0
+=
32
;
sptr1
+=
32
;
sptr2
+=
32
;
sptr3
+=
32
;
sptr4
+=
32
;
dptr
+=
16
;
#undef store_sum01
#undef store_sum23
#undef sum01_avg
#undef sum23_avg
#undef sum_define
#undef CACULATE_ROW
}
for
(;
ow
<
OW
;
++
ow
)
{
int32x2_t
src004
;
int8x8_t
src001
,
src023
;
int16x8_t
src01
,
src23
,
src04
,
sum01
,
sum02
;
sum01
=
vdupq_n_s16
(
0
);
sum02
=
vdupq_n_s16
(
0
);
#define CACULATE_ROW(i) \
src001 = vld1_s8(sptr##i); \
src023 = vld1_s8(sptr##i + 8); \
src004 = vld1_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 4)); \
src01 = vmovl_s8(src001); \
src23 = vmovl_s8(src023); \
src04 = vmovl_s8(vreinterpret_s8_s32(src004)); \
sum01 = vaddq_s16(sum01, src01); \
sum01 = vaddq_s16(sum01, src23); \
sum02 = vaddq_s16(sum02, src04);
UNROLL_CALL_NOWRAPPER
(
5
,
CACULATE_ROW
)
#define do_acc(i) \
int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4) + \
vgetq_lane_s16(sum02, i + 4);
#define do_avg(i) \
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \
: (sum##i - filter_size / 2) / filter_size;
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER
(
4
,
do_acc
)
UNROLL_CALL_NOWRAPPER
(
4
,
do_avg
)
UNROLL_CALL_NOWRAPPER
(
4
,
store
)
#undef store
#undef do_avg
#undef do_acc
#undef CACULATE_ROW
sptr0
+=
8
;
sptr1
+=
8
;
sptr2
+=
8
;
sptr3
+=
8
;
sptr4
+=
8
;
dptr
+=
4
;
}
}
}
}
// namespace arm_common
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.h
浏览文件 @
2ae9fdef
/**
* \file dnn/src/arm_common/pooling/do_
max_pooling_4x4
_nchw44.h
* \file dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
...
...
@@ -15,15 +15,16 @@
namespace
megdnn
{
namespace
arm_common
{
#define KERN(
strdie)
\
void do_
max_pooling_5x5_##strdie##_int8_nchw44_NEON(
\
#define KERN(
mode, stride, ctype)
\
void do_
##mode##_pooling_5x5_##stride##_##ctype##_nchw44_NEON(
\
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \
size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws);
KERN
(
stride1
)
KERN
(
stride2
)
KERN
(
max
,
stride1
,
int8
)
KERN
(
max
,
stride2
,
int8
)
KERN
(
avg
,
stride1
,
int8
)
KERN
(
avg
,
stride2
,
int8
)
#undef KERN
}
// namespace arm_common
}
// namespace megdnn
...
...
dnn/src/arm_common/pooling/opr_impl.cpp
浏览文件 @
2ae9fdef
...
...
@@ -25,10 +25,10 @@ class PoolingImpl::AlgoPack : NonCopyableObj {
AlgoFilter5MaxStride2
algo_filter5_max_stride2
;
AlgoInt8Filter2MaxStride2
algo_int8_filter2_max_stride2
;
AlgoInt8Filter3MaxStride2
algo_int8_filter3_max_stride2
;
AlgoFilter2M
axStridexNCHW44
algo_filter2_ma
x_stridex_nchw4
;
AlgoFilter3M
axStridexNCHW44
algo_filter3_ma
x_stridex_nchw4
;
AlgoFilter4M
axStridexNCHW44
algo_filter4_ma
x_stridex_nchw4
;
AlgoFilter5M
axStridexNCHW44
algo_filter5_ma
x_stridex_nchw4
;
AlgoFilter2M
odexStridexNCHW44
algo_filter2_mode
x_stridex_nchw4
;
AlgoFilter3M
odexStridexNCHW44
algo_filter3_mode
x_stridex_nchw4
;
AlgoFilter4M
odexStridexNCHW44
algo_filter4_mode
x_stridex_nchw4
;
AlgoFilter5M
odexStridexNCHW44
algo_filter5_mode
x_stridex_nchw4
;
public:
AlgoPack
()
{
...
...
@@ -40,10 +40,10 @@ public:
all_algos
.
emplace_back
(
&
algo_filter5_max_stride2
);
all_algos
.
emplace_back
(
&
algo_int8_filter2_max_stride2
);
all_algos
.
emplace_back
(
&
algo_int8_filter3_max_stride2
);
all_algos
.
emplace_back
(
&
algo_filter3_m
a
x_stridex_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter2_m
a
x_stridex_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter4_m
a
x_stridex_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter5_m
a
x_stridex_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter3_m
ode
x_stridex_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter2_m
ode
x_stridex_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter4_m
ode
x_stridex_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter5_m
ode
x_stridex_nchw4
);
}
SmallVector
<
AlgoBase
*>
all_algos
;
};
...
...
dnn/src/arm_common/pooling/opr_impl.h
浏览文件 @
2ae9fdef
...
...
@@ -83,10 +83,10 @@ private:
class
AlgoFilter5MaxStride2
;
class
AlgoInt8Filter2MaxStride2
;
class
AlgoInt8Filter3MaxStride2
;
class
AlgoFilter2M
a
xStridexNCHW44
;
class
AlgoFilter3M
a
xStridexNCHW44
;
class
AlgoFilter4M
a
xStridexNCHW44
;
class
AlgoFilter5M
a
xStridexNCHW44
;
class
AlgoFilter2M
ode
xStridexNCHW44
;
class
AlgoFilter3M
ode
xStridexNCHW44
;
class
AlgoFilter4M
ode
xStridexNCHW44
;
class
AlgoFilter5M
ode
xStridexNCHW44
;
class
AlgoPack
;
};
}
// namespace arm_common
...
...
dnn/test/arm_common/pooling_multi_thread.cpp
浏览文件 @
2ae9fdef
...
...
@@ -10,6 +10,7 @@
*/
#include <vector>
#include "megdnn/dtype.h"
#include "megdnn/opr_param_defs.h"
#include "test/arm_common/fixture.h"
#include "test/common/pooling.h"
...
...
@@ -56,13 +57,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING) {
}
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_
MAX_
W3x3_NCHW44
)
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_W3x3_NCHW44
)
{
// clang-format off
for
(
size_t
ih
:
{
3
,
5
,
10
})
for
(
size_t
iw
:
{
3
,
5
,
7
,
9
,
15
,
20
})
for
(
size_t
ph
:
{
0
,
1
,
2
})
for
(
size_t
pw
:
{
0
,
1
,
2
})
for
(
auto
mode
:
{
param
::
Pooling
::
Mode
::
MAX
,
param
::
Pooling
::
Mode
::
AVERAGE
})
if
(
ih
+
2
*
ph
>=
3
&&
iw
+
2
*
pw
>=
3
)
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
...
...
@@ -71,7 +73,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_NCHW44)
checker
.
set_rng
(
0
,
&
rng
);
param
::
Pooling
param
;
param
.
mode
=
param
::
Pooling
::
Mode
::
MAX
;
param
.
mode
=
mode
;
param
.
format
=
param
::
Pooling
::
Format
::
NCHW44
;
param
.
pad_h
=
ph
;
param
.
pad_w
=
pw
;
...
...
@@ -86,13 +88,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_NCHW44)
// clang-format on
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_
MAX_
W2x2_NCHW44
)
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_W2x2_NCHW44
)
{
// clang-format off
for
(
size_t
ih
:
{
2
,
5
,
10
,
17
})
for
(
size_t
iw
:
{
2
,
6
,
8
,
16
,
26
})
for
(
size_t
ph
:
{
0
,
1
})
for
(
size_t
pw
:
{
0
,
1
})
for
(
auto
mode
:
{
param
::
Pooling
::
Mode
::
MAX
,
param
::
Pooling
::
Mode
::
AVERAGE
})
if
(
ih
+
2
*
ph
>=
2
&&
iw
+
2
*
pw
>=
2
)
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
...
...
@@ -101,7 +104,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_NCHW44)
checker
.
set_rng
(
0
,
&
rng
);
param
::
Pooling
param
;
param
.
mode
=
param
::
Pooling
::
Mode
::
MAX
;
param
.
mode
=
mode
;
param
.
format
=
param
::
Pooling
::
Format
::
NCHW44
;
param
.
pad_h
=
ph
;
param
.
pad_w
=
pw
;
...
...
@@ -115,13 +118,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_NCHW44)
// clang-format on
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_
MAX_
W4x4_NCHW44
)
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_W4x4_NCHW44
)
{
// clang-format off
for
(
size_t
ih
:
{
4
,
10
,
18
,
25
,
30
})
for
(
size_t
iw
:
{
4
,
12
,
17
,
20
,
25
})
for
(
size_t
ph
:
{
0
,
1
,
2
})
for
(
size_t
pw
:
{
0
,
1
,
2
})
for
(
auto
mode
:
{
param
::
Pooling
::
Mode
::
MAX
,
param
::
Pooling
::
Mode
::
AVERAGE
})
if
(
ih
+
2
*
ph
>=
4
&&
iw
+
2
*
pw
>=
4
)
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
...
...
@@ -130,7 +134,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_NCHW44)
checker
.
set_rng
(
0
,
&
rng
);
param
::
Pooling
param
;
param
.
mode
=
param
::
Pooling
::
Mode
::
MAX
;
param
.
mode
=
mode
;
param
.
format
=
param
::
Pooling
::
Format
::
NCHW44
;
param
.
pad_h
=
ph
;
param
.
pad_w
=
pw
;
...
...
@@ -143,13 +147,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_NCHW44)
}
// clang-format on
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_
MAX_
W5x5_NCHW44
)
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_W5x5_NCHW44
)
{
// clang-format off
for
(
size_t
ih
:
{
5
,
9
,
19
,
20
,
39
})
for
(
size_t
iw
:
{
5
,
12
,
23
,
27
,
39
})
for
(
size_t
ph
:
{
0
,
1
,
2
})
for
(
size_t
pw
:
{
0
,
1
,
2
})
for
(
auto
mode
:
{
param
::
Pooling
::
Mode
::
MAX
,
param
::
Pooling
::
Mode
::
AVERAGE
})
if
(
ih
+
2
*
ph
>=
5
&&
iw
+
2
*
pw
>=
5
)
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
...
...
@@ -158,7 +163,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_NCHW44)
checker
.
set_rng
(
0
,
&
rng
);
param
::
Pooling
param
;
param
.
mode
=
param
::
Pooling
::
Mode
::
MAX
;
param
.
mode
=
mode
;
param
.
format
=
param
::
Pooling
::
Format
::
NCHW44
;
param
.
pad_h
=
ph
;
param
.
pad_w
=
pw
;
...
...
@@ -477,31 +482,37 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_POOLING_NCHW44) {
std
::
vector
<
SmallVector
<
TensorShape
>>
shapes
;
std
::
vector
<
std
::
vector
<
size_t
>>
filter_and_stride
=
{
{
2
,
1
},
{
2
,
2
},
{
3
,
1
},
{
3
,
2
},
{
4
,
1
},
{
4
,
2
},
{
5
,
1
},
{
5
,
2
}};
for
(
auto
filter
:
filter_and_stride
){
shapes
.
push_back
({{
1
,
32
*
4
,
215
,
215
},
{}});
shapes
.
push_back
({{
1
,
32
*
4
,
128
,
128
},
{}});
shapes
.
push_back
({{
1
,
16
*
4
,
56
,
56
},
{}});
param
.
window_h
=
param
.
window_w
=
filter
[
0
];
param
.
stride_h
=
param
.
stride_w
=
filter
[
1
];
param
.
format
=
Param
::
Format
::
NCHW
;
printf
(
"NCHW Benchmark POOLING kernel:%d*%d stride:%d,mode %d
\n
"
,
param
.
window_h
,
param
.
window_h
,
param
.
stride_h
,
static_cast
<
int
>
(
param
.
mode
));
benchmark_impl
<
Pooling
>
(
param
,
shapes
,
RUNS
,
{
4
,
{
4
,
5
,
6
,
7
}},
{
1
,
{
4
}},
dtype
::
QuantizedS8
(
1.1
f
));
shapes
.
clear
();
shapes
.
push_back
({{
1
,
32
,
215
,
215
,
4
},
{}});
shapes
.
push_back
({{
1
,
32
,
128
,
128
,
4
},
{}});
shapes
.
push_back
({{
1
,
16
,
56
,
56
,
4
},
{}});
param
.
format
=
Param
::
Format
::
NCHW44
;
printf
(
"NCHW44 Benchmark POOLING kernel:%d*%d stride:%d,mode %d
\n
"
,
param
.
window_h
,
param
.
window_w
,
param
.
stride_h
,
static_cast
<
int
>
(
param
.
mode
));
benchmark_impl
<
Pooling
>
(
param
,
shapes
,
RUNS
,
{
4
,
{
4
,
5
,
6
,
7
}},
{
1
,
{
4
}},
dtype
::
QuantizedS8
(
1.1
f
));
shapes
.
clear
();
}
for
(
auto
mode
:
{
param
::
Pooling
::
Mode
::
MAX
,
param
::
Pooling
::
Mode
::
AVERAGE
})
{
for
(
auto
filter
:
filter_and_stride
)
{
shapes
.
push_back
({{
1
,
32
*
4
,
215
,
215
},
{}});
shapes
.
push_back
({{
1
,
32
*
4
,
128
,
128
},
{}});
shapes
.
push_back
({{
1
,
16
*
4
,
56
,
56
},
{}});
param
.
mode
=
mode
;
param
.
window_h
=
param
.
window_w
=
filter
[
0
];
param
.
stride_h
=
param
.
stride_w
=
filter
[
1
];
param
.
format
=
Param
::
Format
::
NCHW
;
printf
(
"NCHW Benchmark POOLING kernel:%d*%d stride:%d,mode %d
\n
"
,
param
.
window_h
,
param
.
window_h
,
param
.
stride_h
,
static_cast
<
int
>
(
param
.
mode
));
benchmark_impl
<
Pooling
>
(
param
,
shapes
,
RUNS
,
{
4
,
{
4
,
5
,
6
,
7
}},
{
1
,
{
4
}},
dtype
::
QuantizedS8
(
1.1
f
));
shapes
.
clear
();
shapes
.
push_back
({{
1
,
32
,
215
,
215
,
4
},
{}});
shapes
.
push_back
({{
1
,
32
,
128
,
128
,
4
},
{}});
shapes
.
push_back
({{
1
,
16
,
56
,
56
,
4
},
{}});
param
.
format
=
Param
::
Format
::
NCHW44
;
printf
(
"NCHW44 Benchmark POOLING kernel:%d*%d stride:%d,mode %d
\n
"
,
param
.
window_h
,
param
.
window_w
,
param
.
stride_h
,
static_cast
<
int
>
(
param
.
mode
));
benchmark_impl
<
Pooling
>
(
param
,
shapes
,
RUNS
,
{
4
,
{
4
,
5
,
6
,
7
}},
{
1
,
{
4
}},
dtype
::
QuantizedS8
(
1.1
f
));
shapes
.
clear
();
}
}
}
#endif
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录