Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
b336db65
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
b336db65
编写于
4月 25, 2020
作者:
M
Megvii Engine Team
提交者:
Xinran Xu
5月 12, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/arm): add arm nchw44 filter3x3 strdie1x1 max pooling
GitOrigin-RevId: b0d54d38ad9e9782a49853767f1192290df2c854
上级
9837bc00
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
233 addition
and
0 deletion
+233
-0
dnn/src/arm_common/pooling/algo.cpp
dnn/src/arm_common/pooling/algo.cpp
+54
-0
dnn/src/arm_common/pooling/algo.h
dnn/src/arm_common/pooling/algo.h
+8
-0
dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.cpp
...src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.cpp
+91
-0
dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h
dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h
+25
-0
dnn/src/arm_common/pooling/opr_impl.cpp
dnn/src/arm_common/pooling/opr_impl.cpp
+2
-0
dnn/src/arm_common/pooling/opr_impl.h
dnn/src/arm_common/pooling/opr_impl.h
+1
-0
dnn/test/arm_common/pooling.cpp
dnn/test/arm_common/pooling.cpp
+26
-0
dnn/test/arm_common/pooling_multi_thread.cpp
dnn/test/arm_common/pooling_multi_thread.cpp
+26
-0
未找到文件。
dnn/src/arm_common/pooling/algo.cpp
浏览文件 @
b336db65
...
...
@@ -13,6 +13,7 @@
#include "megdnn/opr_param_defs.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h"
#include "src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h"
...
...
@@ -612,6 +613,59 @@ void PoolingImpl::AlgoFilter3MaxStride2NCHW44::exec(
#undef DISPATCH_FUNC
}
bool
PoolingImpl
::
AlgoFilter3MaxStride1NCHW44
::
usable
(
const
PoolingKernSizeParam
&
param
)
const
{
auto
SH
=
param
.
stride
[
0
];
auto
SW
=
param
.
stride
[
1
];
auto
FH
=
param
.
filter
[
0
];
auto
FW
=
param
.
filter
[
1
];
auto
PH
=
param
.
padding
[
0
];
auto
PW
=
param
.
padding
[
1
];
bool
avaible
=
param
.
src_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
param
.
format
==
Param
::
Format
::
NCHW44
&&
param
.
mode
==
Mode
::
MAX
&&
FH
==
3
&&
FW
==
3
&&
SH
==
1
&&
SW
==
1
&&
PH
==
0
&&
PW
==
0
;
return
avaible
;
}
void
PoolingImpl
::
AlgoFilter3MaxStride1NCHW44
::
exec
(
const
PoolingKernParam
&
param
)
const
{
auto
IH
=
param
.
isz
[
0
],
IW
=
param
.
isz
[
1
];
auto
OH
=
param
.
osz
[
0
],
OW
=
param
.
osz
[
1
];
auto
N
=
param
.
n
,
C
=
param
.
ic
;
auto
PH
=
param
.
padding
[
0
];
auto
PW
=
param
.
padding
[
1
];
void
*
src_ptr
=
param
.
src_ptr
;
void
*
dst_ptr
=
param
.
dst_ptr
;
#define DISPATCH_FUNC(type, func, midout_type_id) \
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \
midout_iv(midout_type_id)) { \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \
size_t index, size_t thread_id) { \
MEGDNN_MARK_USED_VAR(thread_id); \
size_t n = index / C; \
size_t c = index % C; \
do_max_pooling_3x3_s1x1_##func##_nchw44_NEON( \
static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
c * IH * IW * 4, \
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
c * OH * OW * 4, \
IH, IW, OH, OW, PH, PW); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
run); \
} \
MIDOUT_END();
DISPATCH_FUNC
(
int8_t
,
int8
,
10
);
#undef DISPATCH_FUNC
}
}
// namespace arm_common
}
// namespace megdnn
// vim: syntax=cpp.doxygen
...
...
dnn/src/arm_common/pooling/algo.h
浏览文件 @
b336db65
...
...
@@ -91,6 +91,14 @@ public:
void
exec
(
const
PoolingKernParam
&
param
)
const
override
;
};
class
PoolingImpl
::
AlgoFilter3MaxStride1NCHW44
final
:
public
AlgoBase
{
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"ARM_POOLING_FILTER3_MAX_STRIDE1_NCHW44"
;
}
bool
usable
(
const
PoolingKernSizeParam
&
param
)
const
override
;
void
exec
(
const
PoolingKernParam
&
param
)
const
override
;
};
WorkspaceBundle
get_bundle
(
const
PoolingImpl
::
PoolingKernSizeParam
&
param
);
}
// namespace arm_common
...
...
dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.cpp
0 → 100644
浏览文件 @
b336db65
/**
* \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
namespace
megdnn
{
namespace
arm_common
{
void
do_max_pooling_3x3_s1x1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
)
{
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
;
const
int8_t
*
__restrict
sptr0
=
src
+
(
ih
+
0
)
*
IW
*
4
;
const
int8_t
*
__restrict
sptr1
=
src
+
(
ih
+
1
)
*
IW
*
4
;
const
int8_t
*
__restrict
sptr2
=
src
+
(
ih
+
2
)
*
IW
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
size_t
ow
=
0
;
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
int8x16_t
src0123
=
vld1q_s8
(
sptr0
);
int8x16_t
src1234
=
vld1q_s8
(
sptr0
+
4
);
int8x16_t
src2345
=
vld1q_s8
(
sptr0
+
8
);
int8x16_t
max0
=
vmaxq_s8
(
src0123
,
src1234
);
max0
=
vmaxq_s8
(
max0
,
src2345
);
src0123
=
vld1q_s8
(
sptr1
);
src1234
=
vld1q_s8
(
sptr1
+
4
);
src2345
=
vld1q_s8
(
sptr1
+
8
);
int8x16_t
max1
=
vmaxq_s8
(
src0123
,
src1234
);
max1
=
vmaxq_s8
(
max1
,
src2345
);
src0123
=
vld1q_s8
(
sptr2
);
src1234
=
vld1q_s8
(
sptr2
+
4
);
src2345
=
vld1q_s8
(
sptr2
+
8
);
int8x16_t
max2
=
vmaxq_s8
(
src0123
,
src1234
);
max2
=
vmaxq_s8
(
max2
,
src2345
);
int8x16_t
max_out
=
vmaxq_s8
(
max0
,
max1
);
max_out
=
vmaxq_s8
(
max_out
,
max2
);
vst1q_s8
(
dptr
,
max_out
);
sptr0
+=
16
;
sptr1
+=
16
;
sptr2
+=
16
;
dptr
+=
16
;
}
for
(;
ow
<
OW
;
++
ow
)
{
int8x8_t
src001
=
vld1_s8
(
sptr0
);
int8x8_t
src012
=
vld1_s8
(
sptr0
+
4
);
int8x8_t
src101
=
vld1_s8
(
sptr1
);
int8x8_t
src112
=
vld1_s8
(
sptr1
+
4
);
int8x8_t
src201
=
vld1_s8
(
sptr2
);
int8x8_t
src212
=
vld1_s8
(
sptr2
+
4
);
int8x8_t
max01_tmp
=
vmax_s8
(
src001
,
src101
);
max01_tmp
=
vmax_s8
(
max01_tmp
,
src201
);
int8x8_t
max12_tmp
=
vmax_s8
(
src012
,
src112
);
max12_tmp
=
vmax_s8
(
max12_tmp
,
src212
);
#define cb(i) \
int8_t dst##i = std::max(std::max(max01_tmp[i], max01_tmp[i + 4]), \
max12_tmp[i + 4]);
#define store(i) *(dptr + i) = dst##i;
UNROLL_CALL_NOWRAPPER
(
4
,
cb
)
UNROLL_CALL_NOWRAPPER
(
4
,
store
)
#undef store
#undef cb
sptr0
+=
4
;
sptr1
+=
4
;
sptr2
+=
4
;
dptr
+=
4
;
}
}
}
}
// namespace arm_common
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h
0 → 100644
浏览文件 @
b336db65
/**
* \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/common/utils.h"
namespace
megdnn
{
namespace
arm_common
{
void
do_max_pooling_3x3_s1x1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
);
}
// namespace arm_common
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/arm_common/pooling/opr_impl.cpp
浏览文件 @
b336db65
...
...
@@ -26,6 +26,7 @@ class PoolingImpl::AlgoPack : NonCopyableObj {
AlgoInt8Filter2MaxStride2
algo_int8_filter2_max_stride2
;
AlgoInt8Filter3MaxStride2
algo_int8_filter3_max_stride2
;
AlgoFilter3MaxStride2NCHW44
algo_filter3_max_stride2_nchw4
;
AlgoFilter3MaxStride1NCHW44
algo_filter3_max_stride1_nchw4
;
public:
AlgoPack
()
{
...
...
@@ -38,6 +39,7 @@ public:
all_algos
.
emplace_back
(
&
algo_int8_filter2_max_stride2
);
all_algos
.
emplace_back
(
&
algo_int8_filter3_max_stride2
);
all_algos
.
emplace_back
(
&
algo_filter3_max_stride2_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter3_max_stride1_nchw4
);
}
SmallVector
<
AlgoBase
*>
all_algos
;
};
...
...
dnn/src/arm_common/pooling/opr_impl.h
浏览文件 @
b336db65
...
...
@@ -84,6 +84,7 @@ private:
class
AlgoInt8Filter2MaxStride2
;
class
AlgoInt8Filter3MaxStride2
;
class
AlgoFilter3MaxStride2NCHW44
;
class
AlgoFilter3MaxStride1NCHW44
;
class
AlgoPack
;
};
}
// namespace arm_common
...
...
dnn/test/arm_common/pooling.cpp
浏览文件 @
b336db65
...
...
@@ -128,6 +128,32 @@ TEST_F(ARM_COMMON, POOLING_MAX_W3x3_S2x2_NCHW44)
// clang-format on
}
TEST_F
(
ARM_COMMON
,
POOLING_MAX_W3x3_S1x1_NCHW44
)
{
// clang-format off
for
(
size_t
ih
:
{
3
,
5
,
10
})
for
(
size_t
iw
:
{
3
,
5
,
7
,
9
,
15
,
20
})
for
(
size_t
ph
:
{
0
})
for
(
size_t
pw
:
{
0
})
if
(
ih
+
2
*
ph
>=
3
&&
iw
+
2
*
pw
>=
3
)
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
Checker
<
Pooling
>
checker
(
handle
());
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
1.1
f
));
checker
.
set_rng
(
0
,
&
rng
);
param
::
Pooling
param
;
param
.
mode
=
param
::
Pooling
::
Mode
::
MAX
;
param
.
format
=
param
::
Pooling
::
Format
::
NCHW44
;
param
.
pad_h
=
ph
;
param
.
pad_w
=
pw
;
param
.
stride_h
=
param
.
stride_w
=
1
;
param
.
window_h
=
param
.
window_w
=
3
;
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
}
// clang-format on
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F
(
ARM_COMMON
,
POOLING_FP16
)
{
Checker
<
Pooling
>
checker
(
handle
());
...
...
dnn/test/arm_common/pooling_multi_thread.cpp
浏览文件 @
b336db65
...
...
@@ -78,6 +78,32 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_S2x2_NCHW44)
}
// clang-format on
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_MAX_W3x3_S1x1_NCHW44
)
{
// clang-format off
for
(
size_t
ih
:
{
3
,
5
,
10
})
for
(
size_t
iw
:
{
3
,
5
,
7
,
9
,
15
,
20
})
for
(
size_t
ph
:
{
0
})
for
(
size_t
pw
:
{
0
})
if
(
ih
+
2
*
ph
>=
3
&&
iw
+
2
*
pw
>=
3
)
{
UniformIntRNG
rng
{
INT8_MIN
>>
1
,
INT8_MAX
>>
1
};
Checker
<
Pooling
>
checker
(
handle
());
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
1.1
f
));
checker
.
set_rng
(
0
,
&
rng
);
param
::
Pooling
param
;
param
.
mode
=
param
::
Pooling
::
Mode
::
MAX
;
param
.
format
=
param
::
Pooling
::
Format
::
NCHW44
;
param
.
pad_h
=
ph
;
param
.
pad_w
=
pw
;
param
.
stride_h
=
param
.
stride_w
=
1
;
param
.
window_h
=
param
.
window_w
=
3
;
checker
.
set_param
(
param
).
exec
(
TensorShapeArray
{{
2
,
2
,
ih
,
iw
,
4
},
{}});
}
// clang-format on
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_INT8_W3x3_S2x2
)
{
for
(
size_t
ih
:
{
2
,
3
,
7
,
13
,
52
,
53
,
54
,
55
})
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录