Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
ab401aba
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
ab401aba
编写于
4月 25, 2020
作者:
M
Megvii Engine Team
提交者:
Xinran Xu
5月 12, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/arm): add arm nchw44 filter2x2 stride1 and stride2 max pooling
GitOrigin-RevId: 42d144a8139de203d87f1d5753487e1020b14dca
上级
b336db65
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
340 addition
and
1 deletion
+340
-1
dnn/src/arm_common/pooling/algo.cpp
dnn/src/arm_common/pooling/algo.cpp
+71
-1
dnn/src/arm_common/pooling/algo.h
dnn/src/arm_common/pooling/algo.h
+8
-0
dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.cpp
dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.cpp
+126
-0
dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.h
dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.h
+30
-0
dnn/src/arm_common/pooling/opr_impl.cpp
dnn/src/arm_common/pooling/opr_impl.cpp
+2
-0
dnn/src/arm_common/pooling/opr_impl.h
dnn/src/arm_common/pooling/opr_impl.h
+1
-0
dnn/test/arm_common/pooling.cpp
dnn/test/arm_common/pooling.cpp
+51
-0
dnn/test/arm_common/pooling_multi_thread.cpp
dnn/test/arm_common/pooling_multi_thread.cpp
+51
-0
未找到文件。
dnn/src/arm_common/pooling/algo.cpp
浏览文件 @
ab401aba
...
...
@@ -11,9 +11,10 @@
*/
#include "src/arm_common/pooling/algo.h"
#include "megdnn/opr_param_defs.h"
#include "src/arm_common/pooling/do_max_pooling_2x2_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h"
#include "src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h"
...
...
@@ -666,6 +667,75 @@ void PoolingImpl::AlgoFilter3MaxStride1NCHW44::exec(
#undef DISPATCH_FUNC
}
/**
 * Decide whether this algorithm can handle the given pooling problem.
 *
 * The problem is accepted only when every one of these holds:
 *   - source dtype is QuantizedS8, format is NCHW44 and mode is MAX;
 *   - the window is 2x2;
 *   - both strides are equal and are either 1 or 2;
 *   - padding is zero in both dimensions.
 *
 * \param param size/type description of the pooling problem
 * \return true when all conditions above are satisfied
 */
bool PoolingImpl::AlgoFilter2MaxStridexNCHW44::usable(
        const PoolingKernSizeParam& param) const {
    const auto stride_h = param.stride[0];
    const auto stride_w = param.stride[1];
    const auto window_h = param.filter[0];
    const auto window_w = param.filter[1];
    const auto pad_h = param.padding[0];
    const auto pad_w = param.padding[1];

    // Data type, layout and reduction mode.
    const bool is_int8_nchw44_max =
            param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
            param.format == Param::Format::NCHW44 && param.mode == Mode::MAX;
    if (!is_int8_nchw44_max) {
        return false;
    }

    // Window geometry.
    const bool window_is_2x2 = window_h == 2 && window_w == 2;
    if (!window_is_2x2) {
        return false;
    }

    // Only the square strides 1 and 2 have kernels.
    const bool stride_supported =
            stride_h == stride_w && (stride_w == 1 || stride_w == 2);
    if (!stride_supported) {
        return false;
    }

    // The kernels do not implement padding.
    const bool no_padding = pad_h == 0 && pad_w == 0;
    return no_padding;
}
void
PoolingImpl
::
AlgoFilter2MaxStridexNCHW44
::
exec
(
const
PoolingKernParam
&
param
)
const
{
auto
IH
=
param
.
isz
[
0
],
IW
=
param
.
isz
[
1
];
auto
OH
=
param
.
osz
[
0
],
OW
=
param
.
osz
[
1
];
auto
N
=
param
.
n
,
C
=
param
.
ic
;
auto
PH
=
param
.
padding
[
0
];
auto
PW
=
param
.
padding
[
1
];
auto
SW
=
param
.
stride
[
0
];
void
*
src_ptr
=
param
.
src_ptr
;
void
*
dst_ptr
=
param
.
dst_ptr
;
#define DISPATCH_FUNC(type, func, midout_type_id, i) \
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \
midout_iv(midout_type_id)) { \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \
size_t index, size_t thread_id) { \
MEGDNN_MARK_USED_VAR(thread_id); \
size_t n = index / C; \
size_t c = index % C; \
do_max_pooling_2x2_stride##i##_##func##_nchw44_NEON( \
static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
c * IH * IW * 4, \
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
c * OH * OW * 4, \
IH, IW, OH, OW, PH, PW); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
run); \
} \
MIDOUT_END();
#define DISPATCH_STRIDE(type, func, midout_type_id) \
switch (SW) { \
case 1: { \
DISPATCH_FUNC(type, func, midout_type_id, 1); \
break; \
} \
case 2: { \
DISPATCH_FUNC(type, func, midout_type_id, 2); \
break; \
} \
default: \
megdnn_assert(0, "unsupport stride size"); \
}
DISPATCH_STRIDE
(
int8_t
,
int8
,
10
);
#undef DISPATCH_STRIDE
#undef DISPATCH_FUNC
}
}
// namespace arm_common
}
// namespace megdnn
// vim: syntax=cpp.doxygen
...
...
dnn/src/arm_common/pooling/algo.h
浏览文件 @
ab401aba
...
...
@@ -99,6 +99,14 @@ public:
void
exec
(
const
PoolingKernParam
&
param
)
const
override
;
};
/**
 * Pooling algorithm registered under the name
 * "ARM_POOLING_FILTER2_MAX_STRIDEX_NCHW44".
 *
 * usable() and exec() are declared here and defined out of line.
 */
class PoolingImpl::AlgoFilter2MaxStridexNCHW44 final : public AlgoBase {
public:
    //! Identifier string of this algorithm.
    const char* name() const override {
        return "ARM_POOLING_FILTER2_MAX_STRIDEX_NCHW44";
    }

    //! This algorithm always reports itself as reproducible.
    bool is_reproducible() const override { return true; }

    //! Whether the algorithm can handle the problem described by \p param.
    bool usable(const PoolingKernSizeParam& param) const override;

    //! Execute the pooling described by \p param.
    void exec(const PoolingKernParam& param) const override;
};
WorkspaceBundle
get_bundle
(
const
PoolingImpl
::
PoolingKernSizeParam
&
param
);
}
// namespace arm_common
...
...
dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.cpp
0 → 100644
浏览文件 @
ab401aba
/**
* \file dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/pooling/do_max_pooling_2x2_nchw44.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
namespace
megdnn
{
namespace
arm_common
{
void
do_max_pooling_2x2_stride1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
)
{
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
;
const
int8_t
*
__restrict
sptr0
=
src
+
(
ih
+
0
)
*
IW
*
4
;
const
int8_t
*
__restrict
sptr1
=
src
+
(
ih
+
1
)
*
IW
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
size_t
ow
=
0
;
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
int8x16_t
src0123
=
vld1q_s8
(
sptr0
);
int8x16_t
src1234
=
vld1q_s8
(
sptr0
+
4
);
int8x16_t
max0
=
vmaxq_s8
(
src0123
,
src1234
);
src0123
=
vld1q_s8
(
sptr1
);
src1234
=
vld1q_s8
(
sptr1
+
4
);
int8x16_t
max1
=
vmaxq_s8
(
src0123
,
src1234
);
int8x16_t
max_out
=
vmaxq_s8
(
max0
,
max1
);
vst1q_s8
(
dptr
,
max_out
);
sptr0
+=
16
;
sptr1
+=
16
;
dptr
+=
16
;
}
for
(;
ow
<
OW
;
++
ow
)
{
int8x8_t
src001
=
vld1_s8
(
sptr0
);
int8x8_t
src012
=
vld1_s8
(
sptr0
+
4
);
int8x8_t
src101
=
vld1_s8
(
sptr1
);
int8x8_t
src112
=
vld1_s8
(
sptr1
+
4
);
int8x8_t
max01_tmp
=
vmax_s8
(
src001
,
src101
);
int8x8_t
max12_tmp
=
vmax_s8
(
src012
,
src112
);
int8x8_t
mat_out
=
vmax_s8
(
max01_tmp
,
max12_tmp
);
#define store(i) *(dptr + i) = mat_out[i];
UNROLL_CALL_NOWRAPPER
(
4
,
store
)
#undef store
sptr0
+=
4
;
sptr1
+=
4
;
dptr
+=
4
;
}
}
}
void
do_max_pooling_2x2_stride2_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
)
{
size_t
oh
=
0
;
for
(;
oh
<
OH
;
++
oh
)
{
size_t
ih
=
oh
<<
1
;
const
int8_t
*
__restrict
sptr0
=
src
+
(
ih
+
0
)
*
IW
*
4
;
const
int8_t
*
__restrict
sptr1
=
src
+
(
ih
+
1
)
*
IW
*
4
;
int8_t
*
__restrict
dptr
=
dst
+
oh
*
OW
*
4
;
size_t
ow
=
0
;
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
int8x16_t
src00
=
vld1q_s8
(
sptr0
);
int8x16_t
src04
=
vld1q_s8
(
sptr0
+
4
*
4
);
int32x4x2_t
src_tmp
=
vuzpq_s32
(
vreinterpretq_s32_s8
(
src00
),
vreinterpretq_s32_s8
(
src04
));
int32x4_t
src0246
=
src_tmp
.
val
[
0
];
int32x4_t
src1357
=
src_tmp
.
val
[
1
];
int8x16_t
max0
=
vmaxq_s8
(
vreinterpretq_s8_s32
(
src0246
),
vreinterpretq_s8_s32
(
src1357
));
src00
=
vld1q_s8
(
sptr1
);
src04
=
vld1q_s8
(
sptr1
+
4
*
4
);
src_tmp
=
vuzpq_s32
(
vreinterpretq_s32_s8
(
src00
),
vreinterpretq_s32_s8
(
src04
));
src0246
=
src_tmp
.
val
[
0
];
src1357
=
src_tmp
.
val
[
1
];
int8x16_t
max1
=
vmaxq_s8
(
vreinterpretq_s8_s32
(
src0246
),
vreinterpretq_s8_s32
(
src1357
));
int8x16_t
max_out
=
vmaxq_s8
(
max0
,
max1
);
vst1q_s8
(
dptr
,
max_out
);
sptr0
+=
32
;
sptr1
+=
32
;
dptr
+=
16
;
}
for
(;
ow
<
OW
;
++
ow
)
{
int8x8_t
src001
=
vld1_s8
(
sptr0
);
int8x8_t
src012
=
vld1_s8
(
sptr0
+
4
);
int8x8_t
src101
=
vld1_s8
(
sptr1
);
int8x8_t
src112
=
vld1_s8
(
sptr1
+
4
);
int8x8_t
max01_tmp
=
vmax_s8
(
src001
,
src101
);
int8x8_t
max12_tmp
=
vmax_s8
(
src012
,
src112
);
int8x8_t
mat_out
=
vmax_s8
(
max01_tmp
,
max12_tmp
);
#define store(i) *(dptr + i) = mat_out[i];
UNROLL_CALL_NOWRAPPER
(
4
,
store
)
#undef store
sptr0
+=
8
;
sptr1
+=
8
;
dptr
+=
4
;
}
}
}
}
// namespace arm_common
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.h
0 → 100644
浏览文件 @
ab401aba
/**
* \file dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/common/utils.h"
namespace
megdnn
{
namespace
arm_common
{
/**
 * 2x2 window, stride-1 int8 max-pooling kernel for NCHW44 data (as encoded in
 * the function name); processes one plane per call.
 *
 * \param src input plane
 * \param dst output plane
 * \param IH  input height
 * \param IW  input width
 * \param OH  output height
 * \param OW  output width
 * \param PH  vertical padding
 * \param PW  horizontal padding
 */
void
do_max_pooling_2x2_stride1_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
);
/**
 * 2x2 window, stride-2 int8 max-pooling kernel for NCHW44 data (as encoded in
 * the function name); processes one plane per call.
 *
 * \param src input plane
 * \param dst output plane
 * \param IH  input height
 * \param IW  input width
 * \param OH  output height
 * \param OW  output width
 * \param PH  vertical padding
 * \param PW  horizontal padding
 */
void
do_max_pooling_2x2_stride2_int8_nchw44_NEON
(
const
int8_t
*
src
,
int8_t
*
dst
,
size_t
IH
,
size_t
IW
,
size_t
OH
,
size_t
OW
,
size_t
PH
,
size_t
PW
);
}
// namespace arm_common
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/arm_common/pooling/opr_impl.cpp
浏览文件 @
ab401aba
...
...
@@ -27,6 +27,7 @@ class PoolingImpl::AlgoPack : NonCopyableObj {
AlgoInt8Filter3MaxStride2
algo_int8_filter3_max_stride2
;
AlgoFilter3MaxStride2NCHW44
algo_filter3_max_stride2_nchw4
;
AlgoFilter3MaxStride1NCHW44
algo_filter3_max_stride1_nchw4
;
AlgoFilter2MaxStridexNCHW44
algo_filter2_max_stridex_nchw4
;
public:
AlgoPack
()
{
...
...
@@ -40,6 +41,7 @@ public:
all_algos
.
emplace_back
(
&
algo_int8_filter3_max_stride2
);
all_algos
.
emplace_back
(
&
algo_filter3_max_stride2_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter3_max_stride1_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter2_max_stridex_nchw4
);
}
SmallVector
<
AlgoBase
*>
all_algos
;
};
...
...
dnn/src/arm_common/pooling/opr_impl.h
浏览文件 @
ab401aba
...
...
@@ -85,6 +85,7 @@ private:
class
AlgoInt8Filter3MaxStride2
;
class
AlgoFilter3MaxStride2NCHW44
;
class
AlgoFilter3MaxStride1NCHW44
;
class
AlgoFilter2MaxStridexNCHW44
;
class
AlgoPack
;
};
}
// namespace arm_common
...
...
dnn/test/arm_common/pooling.cpp
浏览文件 @
ab401aba
...
...
@@ -154,6 +154,57 @@ TEST_F(ARM_COMMON, POOLING_MAX_W3x3_S1x1_NCHW44)
// clang-format on
}
// Checks 2x2 / stride-1 max pooling on QuantizedS8 NCHW44 inputs of shape
// {2, 2, ih, iw, 4} for a grid of heights and widths, without padding.
TEST_F(ARM_COMMON, POOLING_MAX_W2x2_S1x1_NCHW44) {
    for (size_t ih : {2, 5, 10, 17}) {
        for (size_t iw : {2, 6, 8, 16, 26}) {
            for (size_t ph : {0}) {
                for (size_t pw : {0}) {
                    // Skip inputs smaller than the 2x2 window.
                    if (!(ih + 2 * ph >= 2 && iw + 2 * pw >= 2)) {
                        continue;
                    }

                    UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
                    Checker<Pooling> checker(handle());
                    checker.set_dtype(0, dtype::QuantizedS8(1.1f));
                    checker.set_rng(0, &rng);

                    param::Pooling pooling_param;
                    pooling_param.mode = param::Pooling::Mode::MAX;
                    pooling_param.format = param::Pooling::Format::NCHW44;
                    pooling_param.pad_h = ph;
                    pooling_param.pad_w = pw;
                    pooling_param.stride_h = pooling_param.stride_w = 1;
                    pooling_param.window_h = pooling_param.window_w = 2;

                    checker.set_param(pooling_param)
                            .exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
                }
            }
        }
    }
}
// Checks 2x2 / stride-2 max pooling on QuantizedS8 NCHW44 inputs of shape
// {2, 2, ih, iw, 4} for a grid of heights and widths, without padding.
TEST_F(ARM_COMMON, POOLING_MAX_W2x2_S2x2_NCHW44) {
    for (size_t ih : {2, 5, 10, 17}) {
        for (size_t iw : {2, 6, 8, 16, 26}) {
            for (size_t ph : {0}) {
                for (size_t pw : {0}) {
                    // Skip inputs smaller than the 2x2 window.
                    if (!(ih + 2 * ph >= 2 && iw + 2 * pw >= 2)) {
                        continue;
                    }

                    UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
                    Checker<Pooling> checker(handle());
                    checker.set_dtype(0, dtype::QuantizedS8(1.1f));
                    checker.set_rng(0, &rng);

                    param::Pooling pooling_param;
                    pooling_param.mode = param::Pooling::Mode::MAX;
                    pooling_param.format = param::Pooling::Format::NCHW44;
                    pooling_param.pad_h = ph;
                    pooling_param.pad_w = pw;
                    pooling_param.stride_h = pooling_param.stride_w = 2;
                    pooling_param.window_h = pooling_param.window_w = 2;

                    checker.set_param(pooling_param)
                            .exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
                }
            }
        }
    }
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F
(
ARM_COMMON
,
POOLING_FP16
)
{
Checker
<
Pooling
>
checker
(
handle
());
...
...
dnn/test/arm_common/pooling_multi_thread.cpp
浏览文件 @
ab401aba
...
...
@@ -104,6 +104,57 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_S1x1_NCHW44)
// clang-format on
}
// Multi-thread variant: checks 2x2 / stride-1 max pooling on QuantizedS8
// NCHW44 inputs of shape {2, 2, ih, iw, 4}, without padding.
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_S1x1_NCHW44) {
    // clang-format off
    for (size_t ih: {2, 5, 10, 17})
    for (size_t iw: {2, 6, 8, 16, 26})
    for (size_t ph: {0})
    for (size_t pw: {0})
    // The window is 2x2, so the smallest valid (padded) input is 2x2. The
    // guard used to require >= 3, which silently dropped the ih == 2 and
    // iw == 2 cases listed above; it now matches the single-thread test.
    if (ih + 2 * ph >= 2 && iw + 2 * pw >= 2)
    {
        UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
        Checker<Pooling> checker(handle());
        checker.set_dtype(0, dtype::QuantizedS8(1.1f));
        checker.set_rng(0, &rng);

        param::Pooling param;
        param.mode = param::Pooling::Mode::MAX;
        param.format = param::Pooling::Format::NCHW44;
        param.pad_h = ph;
        param.pad_w = pw;
        param.stride_h = param.stride_w = 1;
        param.window_h = param.window_w = 2;
        checker.set_param(param).exec(
                TensorShapeArray{{2, 2, ih, iw, 4}, {}});
    }
    // clang-format on
}
// Multi-thread variant: checks 2x2 / stride-2 max pooling on QuantizedS8
// NCHW44 inputs of shape {2, 2, ih, iw, 4}, without padding.
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_S2x2_NCHW44) {
    // clang-format off
    for (size_t ih: {2, 5, 10, 17})
    for (size_t iw: {2, 6, 8, 16, 26})
    for (size_t ph: {0})
    for (size_t pw: {0})
    // The window is 2x2, so the smallest valid (padded) input is 2x2. The
    // guard used to require >= 3, which silently dropped the ih == 2 and
    // iw == 2 cases listed above; it now matches the single-thread test.
    if (ih + 2 * ph >= 2 && iw + 2 * pw >= 2)
    {
        UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
        Checker<Pooling> checker(handle());
        checker.set_dtype(0, dtype::QuantizedS8(1.1f));
        checker.set_rng(0, &rng);

        param::Pooling param;
        param.mode = param::Pooling::Mode::MAX;
        param.format = param::Pooling::Format::NCHW44;
        param.pad_h = ph;
        param.pad_w = pw;
        param.stride_h = param.stride_w = 2;
        param.window_h = param.window_w = 2;
        checker.set_param(param).exec(
                TensorShapeArray{{2, 2, ih, iw, 4}, {}});
    }
    // clang-format on
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_INT8_W3x3_S2x2
)
{
for
(
size_t
ih
:
{
2
,
3
,
7
,
13
,
52
,
53
,
54
,
55
})
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录