Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
15cca8f9
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
15cca8f9
编写于
4月 25, 2020
作者:
M
Megvii Engine Team
提交者:
Xinran Xu
5月 12, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/arm): add arm nchw44 filter4x4 stride1 and stride2 max pooling
GitOrigin-RevId: 3a4da20c597bac89a667e1c9d085591184260b1a
上级
ab401aba
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
381 addition
and
2 deletion
+381
-2
dnn/src/arm_common/pooling/algo.cpp
dnn/src/arm_common/pooling/algo.cpp
+70
-0
dnn/src/arm_common/pooling/algo.h
dnn/src/arm_common/pooling/algo.h
+8
-0
dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.cpp
dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.cpp
+168
-0
dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.h
dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.h
+30
-0
dnn/src/arm_common/pooling/opr_impl.cpp
dnn/src/arm_common/pooling/opr_impl.cpp
+3
-1
dnn/src/arm_common/pooling/opr_impl.h
dnn/src/arm_common/pooling/opr_impl.h
+2
-1
dnn/test/arm_common/pooling.cpp
dnn/test/arm_common/pooling.cpp
+50
-0
dnn/test/arm_common/pooling_multi_thread.cpp
dnn/test/arm_common/pooling_multi_thread.cpp
+50
-0
未找到文件。
dnn/src/arm_common/pooling/algo.cpp
浏览文件 @
15cca8f9
...
...
@@ -12,6 +12,7 @@
#include "src/arm_common/pooling/algo.h"
#include "megdnn/opr_param_defs.h"
#include "src/arm_common/pooling/do_max_pooling_2x2_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_4x4_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h"
...
...
@@ -736,6 +737,75 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec(
#undef DISPATCH_FUNC
}
/*!
 * \brief Tell whether this algorithm can serve the given pooling problem.
 *
 * Accepted problems are QuantizedS8 MAX pooling in NCHW44 format with a 4x4
 * window, equal strides of 1 or 2 in both directions, and no padding.
 */
bool PoolingImpl::AlgoFilter4MaxStridexNCHW44::usable(
        const PoolingKernSizeParam& param) const {
    // Data type, layout and pooling mode.
    if (!(param.src_type.enumv() == DTypeEnum::QuantizedS8)) {
        return false;
    }
    if (!(param.format == Param::Format::NCHW44)) {
        return false;
    }
    if (!(param.mode == Mode::MAX)) {
        return false;
    }

    // Window must be exactly 4x4.
    if (!(param.filter[0] == 4 && param.filter[1] == 4)) {
        return false;
    }

    // Strides must match in both directions and be either 1 or 2.
    const auto stride_h = param.stride[0];
    const auto stride_w = param.stride[1];
    if (!(stride_h == stride_w)) {
        return false;
    }
    if (!(stride_w == 1 || stride_w == 2)) {
        return false;
    }

    // The kernels have no padding support.
    return param.padding[0] == 0 && param.padding[1] == 0;
}
/*!
 * \brief Run 4x4 int8 NCHW44 max pooling.
 *
 * The work is split into N * C independent tasks, one per (batch, channel
 * block) pair, which are handed to MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN.
 * Each task calls the stride-1 or stride-2 NEON kernel on the plane that
 * starts at (n * C + c) * IH * IW * 4 in the source and at
 * (n * C + c) * OH * OW * 4 in the destination.
 *
 * \param param kernel parameters; the stride is expected to be 1 or 2, any
 *        other value hits the assertion in the default switch branch.
 */
void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec(
        const PoolingKernParam& param) const {
    auto IH = param.isz[0], IW = param.isz[1];
    auto OH = param.osz[0], OW = param.osz[1];
    auto N = param.n, C = param.ic;
    auto PH = param.padding[0];
    auto PW = param.padding[1];
    // Index 1 is the W stride (usable() reads SH from [0] and SW from [1]);
    // the original read index 0 here, which only worked because usable()
    // requires SH == SW.
    auto SW = param.stride[1];

    void* src_ptr = param.src_ptr;
    void* dst_ptr = param.dst_ptr;

#define DISPATCH_FUNC(type, func, midout_type_id, i)                       \
    MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2),                  \
                 midout_iv(midout_type_id)) {                              \
        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr](          \
                           size_t index, size_t thread_id) {               \
            MEGDNN_MARK_USED_VAR(thread_id);                               \
            size_t n = index / C;                                          \
            size_t c = index % C;                                          \
            do_max_pooling_4x4_stride##i##_##func##_nchw44_NEON(           \
                    static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
                            c * IH * IW * 4,                               \
                    static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 +    \
                            c * OH * OW * 4,                               \
                    IH, IW, OH, OW, PH, PW);                               \
        };                                                                 \
        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(                             \
                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
                run);                                                      \
    }                                                                      \
    MIDOUT_END();

#define DISPATCH_STRIDE(type, func, midout_type_id)       \
    switch (SW) {                                         \
        case 1: {                                         \
            DISPATCH_FUNC(type, func, midout_type_id, 1); \
            break;                                        \
        }                                                 \
        case 2: {                                         \
            DISPATCH_FUNC(type, func, midout_type_id, 2); \
            break;                                        \
        }                                                 \
        default:                                          \
            megdnn_assert(0, "unsupport stride size");    \
    }

    DISPATCH_STRIDE(int8_t, int8, 11);

#undef DISPATCH_STRIDE
#undef DISPATCH_FUNC
}
}
// namespace arm_common
}
// namespace megdnn
// vim: syntax=cpp.doxygen
...
...
dnn/src/arm_common/pooling/algo.h
浏览文件 @
15cca8f9
...
...
@@ -107,6 +107,14 @@ public:
void
exec
(
const
PoolingKernParam
&
param
)
const
override
;
};
/*!
 * \brief Pooling algorithm registered under the name
 *        "ARM_POOLING_FILTER4_MAX_STRIDEX_NCHW44".
 *
 * usable() and exec() are defined out of line.
 */
class PoolingImpl::AlgoFilter4MaxStridexNCHW44 final : public AlgoBase {
public:
    //! Identifier string of this algorithm.
    const char* name() const override {
        return "ARM_POOLING_FILTER4_MAX_STRIDEX_NCHW44";
    }

    //! This algorithm always reports itself as reproducible.
    bool is_reproducible() const override { return true; }

    //! Whether the algorithm can handle the problem described by \p param.
    bool usable(const PoolingKernSizeParam& param) const override;

    //! Execute the pooling described by \p param.
    void exec(const PoolingKernParam& param) const override;
};
WorkspaceBundle
get_bundle
(
const
PoolingImpl
::
PoolingKernSizeParam
&
param
);
}
// namespace arm_common
...
...
dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.cpp
0 → 100644
浏览文件 @
15cca8f9
/**
* \file dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/pooling/do_max_pooling_4x4_nchw44.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
namespace
megdnn
{
namespace
arm_common
{
/*!
 * \brief 4x4 max pooling with stride 1 over one int8 NCHW44 plane.
 *
 * Every pixel is a group of 4 int8 channel values (4 bytes), so one
 * int8x16_t holds 4 consecutive pixels and one 32-bit lane is one pixel.
 *
 * \param src first pixel of the input plane, rows are IW * 4 bytes apart
 * \param dst first pixel of the output plane, rows are OW * 4 bytes apart
 * \param IH  input height (not read by this kernel)
 * \param IW  input width in pixels
 * \param OH  output height
 * \param OW  output width in pixels
 * \param PH  padding along H (not read by this kernel)
 * \param PW  padding along W (not read by this kernel)
 */
void do_max_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
                                                 size_t IH, size_t IW,
                                                 size_t OH, size_t OW,
                                                 size_t PH, size_t PW) {
    size_t oh = 0;
    for (; oh < OH; ++oh) {
        size_t ih = oh;
        // The four input rows covered by the window of output row oh.
        const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4;
        const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4;
        const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4;
        const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4;
        int8_t* __restrict dptr = dst + oh * OW * 4;
        size_t ow = 0;
        // Vector path: 4 output pixels per iteration, which need input
        // pixels 0..6 relative to sptr.
        for (; ow + 3 < OW; ow += 4) {
            int8x16_t src00, src03, max_out, max_tmp0, max_tmp1, max_tmp2,
                    max_tmp3;
            int32x4_t src3456, src4563, src1234, src2345;
            // The second load starts at pixel 3 rather than pixel 4, so the
            // highest pixel touched is pixel 6, the last one actually needed.
            // Loading pixels 4..7 (as this code used to) reads one pixel past
            // the end of the row on the last block, assuming the usual
            // IW == OW + 3 relation of an unpadded 4-wide stride-1 window.
            // src4563 is src3456 rotated by one lane so that vextq_s32 can
            // pull pixels 4 and 5 from its low lanes.
#define CACULATE_ROW(i)                                                   \
    src00 = vld1q_s8(sptr##i);                                            \
    src03 = vld1q_s8(sptr##i + 4 * 3);                                    \
    src3456 = vreinterpretq_s32_s8(src03);                                \
    src4563 = vextq_s32(src3456, src3456, 1);                             \
    src1234 = vextq_s32(vreinterpretq_s32_s8(src00), src4563, 1);         \
    src2345 = vextq_s32(vreinterpretq_s32_s8(src00), src4563, 2);         \
    max_tmp##i = vmaxq_s8(src00, vreinterpretq_s8_s32(src1234));          \
    max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2345));     \
    max_tmp##i = vmaxq_s8(max_tmp##i, src03);

            UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW)
            // Reduce the four per-row maxima into the final window maximum.
            max_out = vmaxq_s8(max_tmp0, max_tmp1);
            max_out = vmaxq_s8(max_out, max_tmp2);
            max_out = vmaxq_s8(max_out, max_tmp3);
            vst1q_s8(dptr, max_out);

            sptr0 += 16;
            sptr1 += 16;
            sptr2 += 16;
            sptr3 += 16;
            dptr += 16;
#undef CACULATE_ROW
        }
        // Tail path: one output pixel per iteration. Each row contributes
        // exactly 4 pixels (16 bytes), loaded as two 8-byte halves.
        for (; ow < OW; ++ow) {
            int8x8_t src01, src23, max_out;

#define CACULATE_ROW(i)           \
    src01 = vld1_s8(sptr##i);     \
    src23 = vld1_s8(sptr##i + 8); \
    int8x8_t max_tmp##i = vmax_s8(src01, src23);

            UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW)
            max_out = vmax_s8(max_tmp0, max_tmp1);
            max_out = vmax_s8(max_out, max_tmp2);
            max_out = vmax_s8(max_out, max_tmp3);
            // max_out holds two pixels; fold them channel by channel.
#define store(i) *(dptr + i) = std::max(max_out[i], max_out[i + 4]);
            UNROLL_CALL_NOWRAPPER(4, store)
#undef store
#undef CACULATE_ROW
            sptr0 += 4;
            sptr1 += 4;
            sptr2 += 4;
            sptr3 += 4;
            dptr += 4;
        }
    }
}
/*!
 * \brief 4x4 max pooling with stride 2 over one int8 NCHW44 plane.
 *
 * Every pixel is a group of 4 int8 channel values (4 bytes), so one
 * int8x16_t holds 4 consecutive pixels and one 32-bit lane is one pixel.
 *
 * \param src first pixel of the input plane, rows are IW * 4 bytes apart
 * \param dst first pixel of the output plane, rows are OW * 4 bytes apart
 * \param IH  input height (not read by this kernel)
 * \param IW  input width in pixels
 * \param OH  output height
 * \param OW  output width in pixels
 * \param PH  padding along H (not read by this kernel)
 * \param PW  padding along W (not read by this kernel)
 */
void do_max_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
                                                 size_t IH, size_t IW,
                                                 size_t OH, size_t OW,
                                                 size_t PH, size_t PW) {
    size_t oh = 0;
    for (; oh < OH; ++oh) {
        size_t ih = oh << 1;
        // The four input rows covered by the window of output row oh.
        const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4;
        const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4;
        const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4;
        const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4;
        int8_t* __restrict dptr = dst + oh * OW * 4;
        size_t ow = 0;
        // Vector path: 4 output pixels per iteration, which need input
        // pixels 0..9 relative to sptr.
        for (; ow + 3 < OW; ow += 4) {
            int8x16_t src00, src04, src06, max_tmp0, max_tmp1, max_tmp2,
                    max_tmp3;
            int32x4_t src0246, src1357, src2468, src3579, src6789;
            int32x4x2_t src_tmp;
            // Pixels 8 and 9 come from a single load of pixels 6..9, rotated
            // so that the wanted pixel sits in lane 0 for vextq_s32. The
            // previous loads at pixel offsets 8 and 9 touched pixels up to
            // 12, i.e. up to three pixels beyond what this block needs and
            // potentially beyond the end of the row.
#define CACULATE_ROW(i)                                                   \
    src00 = vld1q_s8(sptr##i);                                            \
    src04 = vld1q_s8(sptr##i + 4 * 4);                                    \
    src06 = vld1q_s8(sptr##i + 4 * 6);                                    \
    src6789 = vreinterpretq_s32_s8(src06);                                \
    src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00),                      \
                        vreinterpretq_s32_s8(src04));                     \
    src0246 = src_tmp.val[0];                                             \
    src1357 = src_tmp.val[1];                                             \
    src2468 = vextq_s32(src0246, vextq_s32(src6789, src6789, 2), 1);      \
    src3579 = vextq_s32(src1357, vextq_s32(src6789, src6789, 3), 1);      \
    max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246),                  \
                          vreinterpretq_s8_s32(src1357));                 \
    max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468));     \
    max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579));

            UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW)
            // Reduce the four per-row maxima into the final window maximum.
            int8x16_t max_out = vmaxq_s8(max_tmp0, max_tmp1);
            max_out = vmaxq_s8(max_out, max_tmp2);
            max_out = vmaxq_s8(max_out, max_tmp3);
            vst1q_s8(dptr, max_out);

            // 4 outputs at stride 2 consume 8 input pixels (32 bytes).
            sptr0 += 32;
            sptr1 += 32;
            sptr2 += 32;
            sptr3 += 32;
            dptr += 16;
#undef CACULATE_ROW
        }
        // Tail path: one output pixel per iteration. Each row contributes
        // exactly 4 pixels (16 bytes), loaded as two 8-byte halves.
        for (; ow < OW; ++ow) {
            int8x8_t src01, src23, max_out;

#define CACULATE_ROW(i)           \
    src01 = vld1_s8(sptr##i);     \
    src23 = vld1_s8(sptr##i + 8); \
    int8x8_t max_tmp##i = vmax_s8(src01, src23);

            UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW)
            max_out = vmax_s8(max_tmp0, max_tmp1);
            max_out = vmax_s8(max_out, max_tmp2);
            max_out = vmax_s8(max_out, max_tmp3);
            // max_out holds two pixels; fold them channel by channel.
#define store(i) *(dptr + i) = std::max(max_out[i], max_out[i + 4]);
            UNROLL_CALL_NOWRAPPER(4, store)
#undef store
#undef CACULATE_ROW
            // One output at stride 2 consumes 2 input pixels (8 bytes).
            sptr0 += 8;
            sptr1 += 8;
            sptr2 += 8;
            sptr3 += 8;
            dptr += 4;
        }
    }
}
}
// namespace arm_common
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.h
0 → 100644
浏览文件 @
15cca8f9
/**
* \file dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/common/utils.h"
namespace
megdnn
{
namespace
arm_common
{
//! 4x4 int8 NCHW44 max pooling kernel, stride-1 variant.
void do_max_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
                                                 size_t IH, size_t IW,
                                                 size_t OH, size_t OW,
                                                 size_t PH, size_t PW);

//! 4x4 int8 NCHW44 max pooling kernel, stride-2 variant.
void do_max_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
                                                 size_t IH, size_t IW,
                                                 size_t OH, size_t OW,
                                                 size_t PH, size_t PW);
}
// namespace arm_common
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/arm_common/pooling/opr_impl.cpp
浏览文件 @
15cca8f9
...
...
@@ -25,9 +25,10 @@ class PoolingImpl::AlgoPack : NonCopyableObj {
AlgoFilter5MaxStride2
algo_filter5_max_stride2
;
AlgoInt8Filter2MaxStride2
algo_int8_filter2_max_stride2
;
AlgoInt8Filter3MaxStride2
algo_int8_filter3_max_stride2
;
AlgoFilter2MaxStridexNCHW44
algo_filter2_max_stridex_nchw4
;
AlgoFilter3MaxStride2NCHW44
algo_filter3_max_stride2_nchw4
;
AlgoFilter3MaxStride1NCHW44
algo_filter3_max_stride1_nchw4
;
AlgoFilter
2MaxStridexNCHW44
algo_filter2
_max_stridex_nchw4
;
AlgoFilter
4MaxStridexNCHW44
algo_filter4
_max_stridex_nchw4
;
public:
AlgoPack
()
{
...
...
@@ -42,6 +43,7 @@ public:
all_algos
.
emplace_back
(
&
algo_filter3_max_stride2_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter3_max_stride1_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter2_max_stridex_nchw4
);
all_algos
.
emplace_back
(
&
algo_filter4_max_stridex_nchw4
);
}
SmallVector
<
AlgoBase
*>
all_algos
;
};
...
...
dnn/src/arm_common/pooling/opr_impl.h
浏览文件 @
15cca8f9
...
...
@@ -83,9 +83,10 @@ private:
class
AlgoFilter5MaxStride2
;
class
AlgoInt8Filter2MaxStride2
;
class
AlgoInt8Filter3MaxStride2
;
class
AlgoFilter2MaxStridexNCHW44
;
class
AlgoFilter3MaxStride2NCHW44
;
class
AlgoFilter3MaxStride1NCHW44
;
class
AlgoFilter
2
MaxStridexNCHW44
;
class
AlgoFilter
4
MaxStridexNCHW44
;
class
AlgoPack
;
};
}
// namespace arm_common
...
...
dnn/test/arm_common/pooling.cpp
浏览文件 @
15cca8f9
...
...
@@ -204,6 +204,56 @@ TEST_F(ARM_COMMON, POOLING_MAX_W2x2_S2x2_NCHW44)
}
// clang-format on
}
// Checks QuantizedS8 NCHW44 max pooling with a 4x4 window and stride 1 on a
// grid of input sizes, without padding.
TEST_F(ARM_COMMON, POOLING_MAX_W4x4_S1x1_NCHW44) {
    // clang-format off
    for (size_t ih: {4, 7, 10, 17, 20})
    for (size_t iw: {4, 8, 10, 21, 32})
    for (size_t ph: {0})
    for (size_t pw: {0})
    // The padded input must hold at least one full 4x4 window (the guard
    // used to compare against 2, the window size of the 2x2 test).
    if (ih+2*ph >= 4 && iw+2*pw >= 4)
    {
        UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
        Checker<Pooling> checker(handle());
        checker.set_dtype(0, dtype::QuantizedS8(1.1f));
        checker.set_rng(0, &rng);

        param::Pooling param;
        param.mode = param::Pooling::Mode::MAX;
        param.format = param::Pooling::Format::NCHW44;
        param.pad_h = ph;
        param.pad_w = pw;
        param.stride_h = param.stride_w = 1;
        param.window_h = param.window_w = 4;
        checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
    }
    // clang-format on
}
// Checks QuantizedS8 NCHW44 max pooling with a 4x4 window and stride 2 on a
// grid of input sizes, without padding.
TEST_F(ARM_COMMON, POOLING_MAX_W4x4_S2x2_NCHW44) {
    // clang-format off
    for (size_t ih: {4, 10, 18, 25, 30})
    for (size_t iw: {4, 12, 17, 20, 25})
    for (size_t ph: {0})
    for (size_t pw: {0})
    // The padded input must hold at least one full 4x4 window (the guard
    // used to compare against 2, the window size of the 2x2 test).
    if (ih+2*ph >= 4 && iw+2*pw >= 4)
    {
        UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
        Checker<Pooling> checker(handle());
        checker.set_dtype(0, dtype::QuantizedS8(1.1f));
        checker.set_rng(0, &rng);

        param::Pooling param;
        param.mode = param::Pooling::Mode::MAX;
        param.format = param::Pooling::Format::NCHW44;
        param.pad_h = ph;
        param.pad_w = pw;
        param.stride_h = param.stride_w = 2;
        param.window_h = param.window_w = 4;
        checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
    }
    // clang-format on
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F
(
ARM_COMMON
,
POOLING_FP16
)
{
...
...
dnn/test/arm_common/pooling_multi_thread.cpp
浏览文件 @
15cca8f9
...
...
@@ -154,6 +154,56 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_S2x2_NCHW44)
}
// clang-format on
}
// Multi-thread handle variant: QuantizedS8 NCHW44 max pooling, 4x4 window,
// stride 1, no padding, over a grid of input sizes.
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_S1x1_NCHW44) {
    for (size_t ih : {4, 7, 10, 17, 20}) {
        for (size_t iw : {4, 8, 10, 21, 32}) {
            for (size_t ph : {0}) {
                for (size_t pw : {0}) {
                    // Skip inputs too small to hold one 4x4 window.
                    if (ih + 2 * ph < 4 || iw + 2 * pw < 4) {
                        continue;
                    }

                    UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
                    Checker<Pooling> checker(handle());
                    checker.set_dtype(0, dtype::QuantizedS8(1.1f));
                    checker.set_rng(0, &rng);

                    param::Pooling pool_param;
                    pool_param.mode = param::Pooling::Mode::MAX;
                    pool_param.format = param::Pooling::Format::NCHW44;
                    pool_param.pad_h = ph;
                    pool_param.pad_w = pw;
                    pool_param.stride_w = 1;
                    pool_param.stride_h = pool_param.stride_w;
                    pool_param.window_w = 4;
                    pool_param.window_h = pool_param.window_w;

                    checker.set_param(pool_param);
                    checker.exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
                }
            }
        }
    }
}
// Multi-thread handle variant: QuantizedS8 NCHW44 max pooling, 4x4 window,
// stride 2, no padding, over a grid of input sizes.
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_S2x2_NCHW44) {
    for (size_t ih : {4, 10, 18, 25, 30}) {
        for (size_t iw : {4, 12, 17, 20, 25}) {
            for (size_t ph : {0}) {
                for (size_t pw : {0}) {
                    // Skip inputs too small to hold one 4x4 window.
                    if (ih + 2 * ph < 4 || iw + 2 * pw < 4) {
                        continue;
                    }

                    UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
                    Checker<Pooling> checker(handle());
                    checker.set_dtype(0, dtype::QuantizedS8(1.1f));
                    checker.set_rng(0, &rng);

                    param::Pooling pool_param;
                    pool_param.mode = param::Pooling::Mode::MAX;
                    pool_param.format = param::Pooling::Format::NCHW44;
                    pool_param.pad_h = ph;
                    pool_param.pad_w = pw;
                    pool_param.stride_w = 2;
                    pool_param.stride_h = pool_param.stride_w;
                    pool_param.window_w = 4;
                    pool_param.window_h = pool_param.window_w;

                    checker.set_param(pool_param);
                    checker.exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
                }
            }
        }
    }
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
POOLING_INT8_W3x3_S2x2
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录