Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
6b2760dd
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
6b2760dd
编写于
6月 11, 2020
作者:
M
Megvii Engine Team
提交者:
Xu Xinran
6月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/fallback): add float32 nchw44 fuse packb 3x3 s2
GitOrigin-RevId: 3b664bb4f578f5e3f2c36fc963217e37676c9b78
上级
7aeb4f6c
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
868 addition
and
1 deletion
+868
-1
dnn/src/fallback/conv_bias/im2col/factory.h
dnn/src/fallback/conv_bias/im2col/factory.h
+71
-0
dnn/src/fallback/conv_bias/im2col/strategy_base.h
dnn/src/fallback/conv_bias/im2col/strategy_base.h
+69
-0
dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp
...src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp
+15
-0
dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44.cpp
dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44.cpp
+230
-0
dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot.cpp
...rc/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot.cpp
+204
-0
dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp
...allback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp
+209
-0
dnn/test/arm_common/conv_bias_multi_thread.cpp
dnn/test/arm_common/conv_bias_multi_thread.cpp
+10
-1
dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp
dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp
+60
-0
未找到文件。
dnn/src/fallback/conv_bias/im2col/factory.h
浏览文件 @
6b2760dd
...
...
@@ -226,6 +226,31 @@ public:
PostprocessMode
::
FLOAT
,
"DefaultStrategyType::FLOAT"
_hash
);
}
else
if
(
format
==
param
::
ConvBias
::
Format
::
NCHW44
)
{
#if MEGDNN_AARCH64
auto
matmul_block
=
matmul_algo
->
get_inner_block_size
();
//! Optimize NCHW44 3x3s2 8X12X1 im2col+pack fuse
if
(
matmul_block
.
m
==
8
&&
matmul_block
.
n
==
12
&&
matmul_block
.
k
==
1
&&
param
.
filter_meta
.
spatial
[
0
]
==
3
&&
param
.
filter_meta
.
spatial
[
1
]
==
3
&&
param
.
filter_meta
.
stride
[
0
]
==
2
&&
param
.
filter_meta
.
stride
[
1
]
==
2
&&
!
param
.
filter_meta
.
should_flip
)
{
MIDOUT_BEGIN
(
megdnn_fallback_im2col_factory_make_strategy
,
midout_iv
(
"DefaultStrategyType::8x12x1_fuse_packb_s2_nchw44"
_hash
))
{
return
std
::
make_unique
<
StrategyFuse8x12x1Nchw44K3x3S2
<
float
,
float
,
PostprocessMode
::
FLOAT
>>
();
}
MIDOUT_END
();
return
{};
}
#endif
cb1
(
NCHW44
,
DEFAULT
,
dt_float32
,
dt_float32
,
PostprocessMode
::
FLOAT
,
"DefaultStrategyTypeNCHW44::FLOAT"
_hash
);
...
...
@@ -320,6 +345,52 @@ public:
"DefaultStrategyType::QINT8x8x32x8"
_hash
);
}
else
if
(
format
==
param
::
ConvBias
::
Format
::
NCHW44
||
format
==
param
::
ConvBias
::
Format
::
NCHW44_DOT
)
{
#if MEGDNN_AARCH64
auto
matmul_block
=
matmul_algo
->
get_inner_block_size
();
if
(
format
==
param
::
ConvBias
::
Format
::
NCHW44
)
{
//! Optimize NCHW44 3x3s1 4X4X16 im2col+pack fuse
if
(
matmul_block
.
m
==
4
&&
matmul_block
.
n
==
4
&&
matmul_block
.
k
==
16
&&
param
.
filter_meta
.
spatial
[
0
]
==
3
&&
param
.
filter_meta
.
spatial
[
1
]
==
3
&&
param
.
filter_meta
.
stride
[
0
]
==
1
&&
param
.
filter_meta
.
stride
[
1
]
==
1
&&
!
param
.
filter_meta
.
should_flip
)
{
MIDOUT_BEGIN
(
megdnn_fallback_im2col_factory_make_strategy
,
midout_iv
(
"DefaultStrategyType::INT8x8x32_4x4x16"
_hash
))
{
return
std
::
make_unique
<
StrategyFuse4x4x16Nchw44
<
dt_qint32
,
dt_qint8
,
PostprocessMode
::
QUANTIZED
>>
();
}
MIDOUT_END
();
return
{};
}
}
else
{
//! Optimize NCHW44_DOT 3x3s1 8X12X4 im2col+pack fuse
if
(
matmul_block
.
m
==
8
&&
matmul_block
.
n
==
12
&&
matmul_block
.
k
==
4
&&
param
.
filter_meta
.
spatial
[
0
]
==
3
&&
param
.
filter_meta
.
spatial
[
1
]
==
3
&&
param
.
filter_meta
.
stride
[
0
]
==
1
&&
param
.
filter_meta
.
stride
[
1
]
==
1
&&
!
param
.
filter_meta
.
should_flip
)
{
MIDOUT_BEGIN
(
megdnn_fallback_im2col_factory_make_strategy
,
midout_iv
(
"DefaultStrategyType::INT8x8x32_8x12x4"
_hash
))
{
return
std
::
make_unique
<
StrategyFuse8x12x4Nchw44Dot
<
dt_qint32
,
dt_qint8
,
PostprocessMode
::
QUANTIZED
>>
();
}
MIDOUT_END
();
return
{};
}
}
#endif
cb2
(
NCHW44
,
DEFAULT
,
dtype
::
QuantizedS8
,
dtype
::
QuantizedS32
,
dtype
::
QuantizedS8
,
dt_int8
,
dt_int32
,
dt_int8
,
PostprocessMode
::
QUANTIZED
,
...
...
dnn/src/fallback/conv_bias/im2col/strategy_base.h
浏览文件 @
6b2760dd
...
...
@@ -445,6 +445,75 @@ public:
THREAD_BUNDLE_BIAS_INDEX
);
}
};
#if MEGDNN_AARCH64
template
<
typename
op_ctype
,
typename
op_dtype
,
megdnn
::
PostprocessMode
postprocess_mode
>
class
StrategyFuse4x4x16Nchw44
:
public
Strategy
<
dt_int8
,
dt_int32
,
dt_int8
,
op_ctype
,
op_dtype
,
postprocess_mode
,
PackMode
::
DEFAULT
,
FormatMode
::
NCHW44
>
{
public:
StrategyFuse4x4x16Nchw44
()
=
default
;
constexpr
static
size_t
BUNDLE_PADDING_INDEX
=
0
;
constexpr
static
size_t
BUNDLE_PACKA_INDEX
=
1
;
constexpr
static
size_t
THREAD_BUNDLE_PACKB_INDEX
=
0
;
constexpr
static
size_t
THREAD_BUNDLE_IM2COL_INDEX
=
1
;
constexpr
static
size_t
THREAD_BUNDLE_BIAS_INDEX
=
2
;
void
exec_im2col
(
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
override
;
};
template
<
typename
op_ctype
,
typename
op_dtype
,
megdnn
::
PostprocessMode
postprocess_mode
>
class
StrategyFuse8x12x1Nchw44K3x3S2
:
public
Strategy
<
float
,
float
,
float
,
op_ctype
,
op_dtype
,
postprocess_mode
,
PackMode
::
DEFAULT
,
FormatMode
::
NCHW44
>
{
public:
StrategyFuse8x12x1Nchw44K3x3S2
()
=
default
;
constexpr
static
size_t
BUNDLE_PADDING_INDEX
=
0
;
constexpr
static
size_t
BUNDLE_PACKA_INDEX
=
1
;
constexpr
static
size_t
THREAD_BUNDLE_PACKB_INDEX
=
0
;
constexpr
static
size_t
THREAD_BUNDLE_IM2COL_INDEX
=
1
;
constexpr
static
size_t
THREAD_BUNDLE_BIAS_INDEX
=
2
;
void
exec_im2col
(
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
override
;
};
template
<
typename
op_ctype
,
typename
op_dtype
,
megdnn
::
PostprocessMode
postprocess_mode
>
class
StrategyFuse8x12x4Nchw44Dot
:
public
Strategy
<
dt_int8
,
dt_int32
,
dt_int8
,
op_ctype
,
op_dtype
,
postprocess_mode
,
PackMode
::
DEFAULT
,
FormatMode
::
NCHW44
>
{
public:
StrategyFuse8x12x4Nchw44Dot
()
=
default
;
constexpr
static
size_t
BUNDLE_PADDING_INDEX
=
0
;
constexpr
static
size_t
BUNDLE_PACKA_INDEX
=
1
;
constexpr
static
size_t
THREAD_BUNDLE_PACKB_INDEX
=
0
;
constexpr
static
size_t
THREAD_BUNDLE_IM2COL_INDEX
=
1
;
constexpr
static
size_t
THREAD_BUNDLE_BIAS_INDEX
=
2
;
void
exec_im2col
(
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
override
;
};
#endif
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp
浏览文件 @
6b2760dd
...
...
@@ -14,6 +14,9 @@
#include "src/x86/conv_bias/postprocess_helper.h"
#endif
#if (MEGDNN_ARMV7 || MEGDNN_AARCH64)
#include "src/arm_common/conv_bias/postprocess_helper.h"
#endif
using
namespace
megdnn
;
#if MEGDNN_X86
...
...
@@ -101,11 +104,23 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
INSTANTIAL_CLASS
(
dt_float32
,
dt_float32
,
dt_float32
,
dt_float32
,
dt_float32
,
megdnn
::
PostprocessMode
::
FLOAT
)
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
INSTANTIAL_CLASS
(
dt_float16
,
dt_float16
,
dt_float16
,
__fp16
,
__fp16
,
megdnn
::
PostprocessMode
::
FLOAT
)
#else
#if !MEGDNN_DISABLE_FLOAT16
INSTANTIAL_CLASS
(
dt_float16
,
dt_float16
,
dt_float16
,
dt_float16
,
dt_float16
,
megdnn
::
PostprocessMode
::
NO_PROCESS
)
#endif
#endif
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8
INSTANTIAL_CLASS
(
dt_uint8
,
dt_int32
,
dt_uint8
,
dt_qint32
,
dt_quint8
,
megdnn
::
PostprocessMode
::
QUANTIZED
)
INSTANTIAL_CLASS
(
dt_uint8
,
dt_int32
,
dt_int32
,
dt_qint32
,
dt_qint32
,
megdnn
::
PostprocessMode
::
NO_PROCESS
)
#endif
INSTANTIAL_CLASS
(
dt_int8
,
dt_int32
,
dt_int8
,
dt_qint32
,
dt_qint8
,
megdnn
::
PostprocessMode
::
QUANTIZED
)
...
...
dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44.cpp
浏览文件 @
6b2760dd
...
...
@@ -11,5 +11,235 @@
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#if MEGDNN_AARCH64
#include <arm_neon.h>
using
namespace
megdnn
;
namespace
{
#define TRANS_AND_STORE(input0, input1, input2, input3) \
{ \
auto tmp01 = vzipq_s32(input0, input1); \
auto tmp23 = vzipq_s32(input2, input3); \
auto dst0 = vzip1q_s64(vreinterpretq_s64_s32(tmp01.val[0]), \
vreinterpretq_s64_s32(tmp23.val[0])); \
auto dst1 = vzip2q_s64(vreinterpretq_s64_s32(tmp01.val[0]), \
vreinterpretq_s64_s32(tmp23.val[0])); \
auto dst2 = vzip1q_s64(vreinterpretq_s64_s32(tmp01.val[1]), \
vreinterpretq_s64_s32(tmp23.val[1])); \
auto dst3 = vzip2q_s64(vreinterpretq_s64_s32(tmp01.val[1]), \
vreinterpretq_s64_s32(tmp23.val[1])); \
vst1q_s32(dst, vreinterpretq_s32_s64(dst0)); \
vst1q_s32(dst + 4, vreinterpretq_s32_s64(dst1)); \
vst1q_s32(dst + 8, vreinterpretq_s32_s64(dst2)); \
vst1q_s32(dst + 12, vreinterpretq_s32_s64(dst3)); \
dst += 16; \
}
#define TRANS_AND_STORE_REMAIN(input0, input1, input2, input3, remain) \
{ \
auto tmp01 = vzipq_s32(input0, input1); \
auto tmp23 = vzipq_s32(input2, input3); \
vdst[0] = vzip1q_s64(vreinterpretq_s64_s32(tmp01.val[0]), \
vreinterpretq_s64_s32(tmp23.val[0])); \
vdst[1] = vzip2q_s64(vreinterpretq_s64_s32(tmp01.val[0]), \
vreinterpretq_s64_s32(tmp23.val[0])); \
vdst[2] = vzip1q_s64(vreinterpretq_s64_s32(tmp01.val[1]), \
vreinterpretq_s64_s32(tmp23.val[1])); \
vdst[3] = vzip2q_s64(vreinterpretq_s64_s32(tmp01.val[1]), \
vreinterpretq_s64_s32(tmp23.val[1])); \
for (size_t i = 0; i < remain; i++) { \
vst1q_s32(dst + i * 4, vreinterpretq_s32_s64(vdst[i])); \
} \
dst += 16; \
}
void
optimize_fuse_im2col_packB
(
dt_int8
*
src
,
size_t
ic
,
size_t
iw
,
size_t
ih
,
size_t
curr_iw
,
size_t
curr_ih
,
dt_int8
*
dst_ptr
)
{
int
*
src_line0
=
reinterpret_cast
<
int
*>
(
src
+
curr_ih
*
iw
*
4
+
curr_iw
*
4
);
int
*
src_line1
=
reinterpret_cast
<
int
*>
(
src
+
(
curr_ih
+
1
)
*
iw
*
4
+
curr_iw
*
4
);
int
*
src_line2
=
reinterpret_cast
<
int
*>
(
src
+
(
curr_ih
+
2
)
*
iw
*
4
+
curr_iw
*
4
);
int
*
dst
=
reinterpret_cast
<
int
*>
(
dst_ptr
);
int32x4_t
input
[
12
];
int
remain
=
0
;
for
(
size_t
c
=
0
;
c
<
ic
;
c
++
)
{
input
[
remain
]
=
vld1q_s32
(
src_line0
);
input
[
remain
+
1
]
=
vld1q_s32
(
src_line0
+
1
);
input
[
remain
+
2
]
=
vld1q_s32
(
src_line0
+
2
);
input
[
remain
+
3
]
=
vld1q_s32
(
src_line1
);
input
[
remain
+
4
]
=
vld1q_s32
(
src_line1
+
1
);
input
[
remain
+
5
]
=
vld1q_s32
(
src_line1
+
2
);
input
[
remain
+
6
]
=
vld1q_s32
(
src_line2
);
input
[
remain
+
7
]
=
vld1q_s32
(
src_line2
+
1
);
input
[
remain
+
8
]
=
vld1q_s32
(
src_line2
+
2
);
TRANS_AND_STORE
(
input
[
0
],
input
[
1
],
input
[
2
],
input
[
3
]);
TRANS_AND_STORE
(
input
[
4
],
input
[
5
],
input
[
6
],
input
[
7
]);
if
(
remain
==
3
)
{
TRANS_AND_STORE
(
input
[
8
],
input
[
9
],
input
[
10
],
input
[
11
]);
remain
=
0
;
}
else
{
for
(
int
i
=
0
;
i
<=
remain
;
i
++
)
{
input
[
i
]
=
input
[
8
+
i
];
}
remain
++
;
}
src_line0
+=
ih
*
iw
;
src_line1
+=
ih
*
iw
;
src_line2
+=
ih
*
iw
;
}
//! pad remain to 4
if
(
remain
>
0
)
{
TRANS_AND_STORE
(
input
[
0
],
input
[
1
],
input
[
2
],
input
[
3
]);
}
}
void
naive_fuse_im2col_packB
(
dt_int8
*
src
,
size_t
ic
,
size_t
iw
,
size_t
ih
,
size_t
curr_iw
,
size_t
curr_ih
,
size_t
num_point
,
size_t
ow
,
dt_int8
*
dst_ptr
)
{
megdnn_assert
(
num_point
<=
4
_z
,
"fuse im2col and packB of 4x4x16 num_point must less than 4"
);
int
*
src_line0
=
reinterpret_cast
<
int
*>
(
src
+
curr_ih
*
iw
*
4
);
int
*
src_line1
=
reinterpret_cast
<
int
*>
(
src
+
(
curr_ih
+
1
)
*
iw
*
4
);
int
*
src_line2
=
reinterpret_cast
<
int
*>
(
src
+
(
curr_ih
+
2
)
*
iw
*
4
);
int
remain
=
0
;
int
out
[
9
][
4
]
=
{{
0
}};
int32x4_t
input
[
12
];
int
*
dst
=
reinterpret_cast
<
int
*>
(
dst_ptr
);
for
(
size_t
c
=
0
;
c
<
ic
;
c
++
)
{
//! Read int buffer out
size_t
index
=
0
,
w
=
curr_iw
,
dalta_h
=
0
;
while
(
index
<
num_point
)
{
int
*
src_next_line0
=
src_line0
+
dalta_h
*
iw
;
int
*
src_next_line1
=
src_next_line0
+
iw
;
int
*
src_next_line2
=
src_next_line1
+
iw
;
for
(;
index
<
num_point
&&
w
<
ow
;
index
++
,
w
++
)
{
out
[
0
][
index
]
=
src_next_line0
[
w
];
out
[
1
][
index
]
=
src_next_line0
[
w
+
1
];
out
[
2
][
index
]
=
src_next_line0
[
w
+
2
];
out
[
3
][
index
]
=
src_next_line1
[
w
];
out
[
4
][
index
]
=
src_next_line1
[
w
+
1
];
out
[
5
][
index
]
=
src_next_line1
[
w
+
2
];
out
[
6
][
index
]
=
src_next_line2
[
w
];
out
[
7
][
index
]
=
src_next_line2
[
w
+
1
];
out
[
8
][
index
]
=
src_next_line2
[
w
+
2
];
}
//! next line
w
=
0
;
dalta_h
+=
1
;
}
//! load int vector
input
[
remain
]
=
vld1q_s32
(
out
[
0
]);
input
[
remain
+
1
]
=
vld1q_s32
(
out
[
1
]);
input
[
remain
+
2
]
=
vld1q_s32
(
out
[
2
]);
input
[
remain
+
3
]
=
vld1q_s32
(
out
[
3
]);
input
[
remain
+
4
]
=
vld1q_s32
(
out
[
4
]);
input
[
remain
+
5
]
=
vld1q_s32
(
out
[
5
]);
input
[
remain
+
6
]
=
vld1q_s32
(
out
[
6
]);
input
[
remain
+
7
]
=
vld1q_s32
(
out
[
7
]);
input
[
remain
+
8
]
=
vld1q_s32
(
out
[
8
]);
int64x2_t
vdst
[
4
];
TRANS_AND_STORE_REMAIN
(
input
[
0
],
input
[
1
],
input
[
2
],
input
[
3
],
num_point
);
TRANS_AND_STORE_REMAIN
(
input
[
4
],
input
[
5
],
input
[
6
],
input
[
7
],
num_point
);
if
(
remain
==
3
)
{
TRANS_AND_STORE_REMAIN
(
input
[
8
],
input
[
9
],
input
[
10
],
input
[
11
],
num_point
);
remain
=
0
;
}
else
{
for
(
int
i
=
0
;
i
<=
remain
;
i
++
)
{
input
[
i
]
=
input
[
8
+
i
];
}
remain
++
;
}
src_line0
+=
ih
*
iw
;
src_line1
+=
ih
*
iw
;
src_line2
+=
ih
*
iw
;
}
//! pad remain to 4
if
(
remain
>
0
)
{
int64x2_t
vdst
[
4
];
TRANS_AND_STORE_REMAIN
(
input
[
0
],
input
[
1
],
input
[
2
],
input
[
3
],
num_point
);
}
}
}
// namespace
template
<
typename
op_ctype
,
typename
op_dtype
,
megdnn
::
PostprocessMode
postprocess_mode
>
void
StrategyFuse4x4x16Nchw44
<
op_ctype
,
op_dtype
,
postprocess_mode
>::
exec_im2col
(
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
)
{
size_t
ow
=
param
.
osz
[
1
];
size_t
ic
=
param
.
filter_meta
.
icpg
;
size_t
ih
=
param
.
isz
[
0
]
+
param
.
filter_meta
.
padding
[
0
]
*
2
;
size_t
iw
=
param
.
isz
[
1
]
+
param
.
filter_meta
.
padding
[
1
]
*
2
;
constexpr
static
size_t
pack_size
=
4
;
size_t
input_offset
=
ih
*
iw
*
ic
*
(
sparam
.
group_id
+
param
.
filter_meta
.
group
*
sparam
.
batch_id
)
*
sizeof
(
dt_int8
);
dt_int8
*
src2
=
reinterpret_cast
<
dt_int8
*>
(
reinterpret_cast
<
uintptr_t
>
(
bundle
.
get
(
BUNDLE_PADDING_INDEX
))
+
input_offset
);
bool
is_phpwzero
=
param
.
filter_meta
.
padding
[
0
]
==
0
&&
param
.
filter_meta
.
padding
[
1
]
==
0
;
if
(
is_phpwzero
)
{
src2
=
const_cast
<
dt_int8
*>
(
param
.
src
<
dt_int8
>
(
sparam
.
batch_id
,
sparam
.
group_id
));
}
dt_int8
*
b_panel
=
reinterpret_cast
<
dt_int8
*>
(
reinterpret_cast
<
uintptr_t
>
(
bundle_thread
.
get
(
THREAD_BUNDLE_PACKB_INDEX
)));
megdnn_assert
(
ic
%
4
==
0
,
"nchw44 with ic is not of time 4"
);
const
int
packed_k
=
(
ic
*
3
*
3
)
/
pack_size
;
const
int
ksize4
=
round_up
<
int
>
(
packed_k
,
4
)
*
16
*
sizeof
(
dt_int8
);
size_t
out_size
=
sparam
.
output_block_size
;
size_t
curr_index
=
sparam
.
ohw_cur_index
;
size_t
curr_ih
=
curr_index
/
ow
;
size_t
curr_iw
=
curr_index
%
ow
;
size_t
out_index
=
0
;
while
(
out_index
<
out_size
)
{
for
(;
curr_iw
+
3
<
ow
&&
out_index
+
3
<
out_size
;
curr_iw
+=
4
,
out_index
+=
4
)
{
dt_int8
*
dst
=
b_panel
+
(
out_index
/
4
)
*
ksize4
;
optimize_fuse_im2col_packB
(
src2
,
ic
/
4
,
iw
,
ih
,
curr_iw
,
curr_ih
,
dst
);
}
if
(
curr_iw
<
ow
&&
out_index
<
out_size
)
{
size_t
out_remain
=
std
::
min
(
out_size
-
out_index
,
4
_z
);
size_t
remain_point_this_line
=
std
::
min
(
ow
-
curr_iw
,
out_remain
);
size_t
start_point_next_line
=
(
out_remain
-
remain_point_this_line
)
%
ow
;
size_t
pass_lines
=
(
out_remain
-
remain_point_this_line
)
/
ow
;
dt_int8
*
dst
=
b_panel
+
(
out_index
/
4
)
*
ksize4
;
naive_fuse_im2col_packB
(
src2
,
ic
/
4
,
iw
,
ih
,
curr_iw
,
curr_ih
,
out_remain
,
ow
,
dst
);
out_index
+=
out_remain
;
curr_iw
=
start_point_next_line
;
curr_ih
+=
(
pass_lines
+
1
);
}
else
{
curr_iw
=
0
;
curr_ih
++
;
}
}
}
#undef TRANS_AND_STORE_REMAIN
#undef TRANS_AND_STORE
namespace
megdnn
{
template
class
StrategyFuse4x4x16Nchw44
<
dt_qint32
,
dt_qint8
,
megdnn
::
PostprocessMode
::
QUANTIZED
>;
}
// namespace megdnn
#endif
// vim: syntax=cpp.doxygen
dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot.cpp
浏览文件 @
6b2760dd
...
...
@@ -11,5 +11,209 @@
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#if MEGDNN_AARCH64
#include <arm_neon.h>
using
namespace
megdnn
;
namespace
{
#define PACKB_ONELINE() \
int out_index = 0; \
outptr = output_base; \
for (; out_index + 11 < block_size; out_index += 12) { \
std::memcpy(outptr, tmp_output, 48); \
outptr += ksize12; \
tmp_output += 12; \
} \
\
outptr = output_base4; \
for (; out_index + 3 < block_size; out_index += 4) { \
std::memcpy(outptr, tmp_output, 16); \
outptr += ksize4; \
tmp_output += 4; \
} \
\
if (out_index < block_size) { \
uint32_t zerobuffer[4] = {0}; \
size_t out_remain = std::min(block_size - out_index, 4); \
std::memcpy(outptr, tmp_output, out_remain * sizeof(uint32_t)); \
outptr += out_remain; \
std::memcpy(outptr, zerobuffer, (4 - out_remain) * sizeof(uint32_t)); \
} \
output_base += 12; \
output_base4 += 4;
#define STOR_IM2COL_DST() \
output0[count] = uint32_src[index + 0]; \
output1[count] = uint32_src[index + 1]; \
output2[count] = uint32_src[index + 2];
#define LOAD_AND_STOR_IM2COL_DST() \
uint32x4_t v_tmp = vld1q_u32(&uint32_src[index + 4]); \
uint32x4_t v_o1 = vextq_u32(v_o0, v_tmp, 1); \
uint32x4_t v_o2 = vextq_u32(v_o0, v_tmp, 2); \
vst1q_u32(&output0[count], v_o0); \
vst1q_u32(&output1[count], v_o1); \
vst1q_u32(&output2[count], v_o2); \
v_o0 = v_tmp;
void
fuse_packb
(
const
dt_int8
*
__restrict
src
,
dt_int8
*
__restrict
dst
,
dt_int8
*
__restrict
b_panel
,
const
int
OW
,
const
int
IC
,
const
int
IH
,
const
int
IW
,
const
int
cur_index
,
const
int
block_size
)
{
int
start_h
=
cur_index
/
OW
;
int
cur_remain_w
=
cur_index
%
OW
;
int
end_h
=
(
cur_index
+
block_size
)
/
OW
;
int
end_remain_w
=
(
cur_index
+
block_size
)
%
OW
;
bool
same_line
=
start_h
==
end_h
?
true
:
false
;
size_t
newIC
=
IC
/
4
;
const
uint32_t
*
uint32_src
=
static_cast
<
const
uint32_t
*>
(
static_cast
<
const
void
*>
(
src
));
uint32_t
*
output
=
static_cast
<
uint32_t
*>
(
static_cast
<
void
*>
(
dst
));
uint32_t
*
b_output
=
static_cast
<
uint32_t
*>
(
static_cast
<
void
*>
(
b_panel
));
const
int
packed_k
=
newIC
*
3
*
3
;
const
int
ksize12
=
packed_k
*
12
*
sizeof
(
dt_int8
);
const
int
ksize4
=
packed_k
*
4
*
sizeof
(
dt_int8
);
uint32_t
*
outptr
=
b_output
;
uint32_t
*
output_base
=
b_output
;
uint32_t
*
output_base4
=
b_output
+
block_size
/
12
*
ksize12
;
constexpr
int
FH
=
3
;
if
(
same_line
)
{
rep
(
ic
,
newIC
)
{
rep
(
fh
,
FH
)
{
size_t
count
=
0
;
size_t
index
=
0
;
int
w
=
cur_remain_w
;
index
=
(
ic
*
IH
+
(
start_h
+
fh
))
*
IW
+
w
;
for
(;
w
+
3
<
end_remain_w
;
w
+=
4
)
{
vst1q_u32
(
&
output
[
count
],
vld1q_u32
(
&
uint32_src
[
index
]));
count
+=
4
;
index
+=
4
;
}
for
(;
w
<
end_remain_w
;
w
++
)
{
output
[
count
++
]
=
uint32_src
[
index
++
];
}
output
[
count
++
]
=
uint32_src
[
index
];
output
[
count
++
]
=
uint32_src
[
index
+
1
];
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
const
uint32_t
*
tmp_output
=
output
+
i
;
PACKB_ONELINE
();
}
}
}
}
else
{
rep
(
ic
,
newIC
)
{
rep
(
fh
,
FH
)
{
size_t
count
=
0
;
size_t
index
=
0
;
uint32_t
*
output0
=
output
;
uint32_t
*
output1
=
output
+
block_size
;
uint32_t
*
output2
=
output1
+
block_size
;
int
w
=
cur_remain_w
;
index
=
(
ic
*
IH
+
(
start_h
+
fh
))
*
IW
+
w
;
uint32x4_t
v_o0
=
vld1q_u32
(
&
uint32_src
[
index
]);
for
(
;
w
+
3
<
OW
;
w
+=
4
)
{
LOAD_AND_STOR_IM2COL_DST
();
count
+=
4
;
index
+=
4
;
}
for
(;
w
<
OW
;
w
++
)
{
STOR_IM2COL_DST
();
count
++
;
index
++
;
}
for
(
int
h
=
start_h
+
1
;
h
<
end_h
;
h
++
)
{
int
ow
=
0
;
index
=
(
ic
*
IH
+
(
h
+
fh
))
*
IW
+
ow
;
v_o0
=
vld1q_u32
(
&
uint32_src
[
index
]);
for
(;
ow
+
3
<
OW
;
ow
+=
4
)
{
LOAD_AND_STOR_IM2COL_DST
();
count
+=
4
;
index
+=
4
;
}
for
(;
ow
<
OW
;
ow
++
)
{
STOR_IM2COL_DST
();
count
++
;
index
++
;
}
}
index
=
(
ic
*
IH
+
(
end_h
+
fh
))
*
IW
;
w
=
0
;
v_o0
=
vld1q_u32
(
&
uint32_src
[
index
]);
for
(
;
w
+
3
<
end_remain_w
;
w
+=
4
)
{
LOAD_AND_STOR_IM2COL_DST
();
count
+=
4
;
index
+=
4
;
}
for
(
;
w
<
end_remain_w
;
w
++
)
{
STOR_IM2COL_DST
();
count
++
;
index
++
;
}
for
(
int
k
=
0
;
k
<
3
;
k
++
)
{
const
uint32_t
*
tmp_output
=
output
+
k
*
block_size
;
PACKB_ONELINE
();
}
}
}
}
}
#undef PACKB_ONELINE
#undef STOR_IM2COL_DST
#undef LOAD_AND_STOR_IM2COL_DST
}
// namespace
template
<
typename
op_ctype
,
typename
op_dtype
,
megdnn
::
PostprocessMode
postprocess_mode
>
void
StrategyFuse8x12x4Nchw44Dot
<
op_ctype
,
op_dtype
,
postprocess_mode
>::
exec_im2col
(
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
/*matmul_param*/
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
/*matmul_algo*/
)
{
size_t
ow
=
param
.
osz
[
1
];
size_t
ic
=
param
.
filter_meta
.
icpg
;
size_t
ih
=
param
.
isz
[
0
]
+
param
.
filter_meta
.
padding
[
0
]
*
2
;
size_t
iw
=
param
.
isz
[
1
]
+
param
.
filter_meta
.
padding
[
1
]
*
2
;
size_t
input_offset
=
ih
*
iw
*
ic
*
(
sparam
.
group_id
+
param
.
filter_meta
.
group
*
sparam
.
batch_id
)
*
sizeof
(
dt_int8
);
dt_int8
*
src2
=
reinterpret_cast
<
dt_int8
*>
(
reinterpret_cast
<
uintptr_t
>
(
bundle
.
get
(
BUNDLE_PADDING_INDEX
))
+
input_offset
);
bool
is_phpwzero
=
param
.
filter_meta
.
padding
[
0
]
==
0
&&
param
.
filter_meta
.
padding
[
1
]
==
0
;
if
(
is_phpwzero
)
{
src2
=
const_cast
<
dt_int8
*>
(
param
.
src
<
dt_int8
>
(
sparam
.
batch_id
,
sparam
.
group_id
));
}
dt_int8
*
b_panel
=
reinterpret_cast
<
dt_int8
*>
(
reinterpret_cast
<
uintptr_t
>
(
bundle_thread
.
get
(
THREAD_BUNDLE_PACKB_INDEX
)));
megdnn_assert
(
ic
%
4
==
0
,
"nchw44_dot with ic is not of time 4"
);
int8_t
*
im2col_dst
=
static_cast
<
int8_t
*>
(
bundle_thread
.
get
(
THREAD_BUNDLE_IM2COL_INDEX
));
fuse_packb
(
src2
,
im2col_dst
,
b_panel
,
ow
,
ic
,
ih
,
iw
,
sparam
.
ohw_cur_index
,
sparam
.
output_block_size
);
}
namespace
megdnn
{
template
class
StrategyFuse8x12x4Nchw44Dot
<
dt_qint32
,
dt_qint8
,
megdnn
::
PostprocessMode
::
QUANTIZED
>;
}
// namespace megdnn
#endif
// vim: syntax=cpp.doxygen
dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp
0 → 100644
浏览文件 @
6b2760dd
/**
* \file dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/convolution/img2col_helper.h"
#if MEGDNN_AARCH64
#include <arm_neon.h>
using
namespace
megdnn
;
namespace
{
#define PACKB_ONELINE() \
int out_index = 0; \
outptr = output_base; \
for (; out_index + 11 < block_size; out_index += 12) { \
float32x4x4_t v0 = vld4q_f32(tmp_output); \
float32x4x4_t v1 = vld4q_f32(tmp_output + 16); \
float32x4x4_t v2 = vld4q_f32(tmp_output + 32); \
vst1q_f32(outptr, v0.val[0]); \
vst1q_f32(outptr + 4, v1.val[0]); \
vst1q_f32(outptr + 8, v2.val[0]); \
vst1q_f32(outptr + 12, v0.val[1]); \
vst1q_f32(outptr + 16, v1.val[1]); \
vst1q_f32(outptr + 20, v2.val[1]); \
vst1q_f32(outptr + 24, v0.val[2]); \
vst1q_f32(outptr + 28, v1.val[2]); \
vst1q_f32(outptr + 32, v2.val[2]); \
vst1q_f32(outptr + 36, v0.val[3]); \
vst1q_f32(outptr + 40, v1.val[3]); \
vst1q_f32(outptr + 44, v2.val[3]); \
outptr += ksize12; \
tmp_output += 48; \
} \
\
outptr = output_base4; \
for (; out_index + 3 < block_size; out_index += 4) { \
float32x4x4_t v0 = vld4q_f32(tmp_output); \
vst1q_f32(outptr, v0.val[0]); \
vst1q_f32(outptr + 4, v0.val[1]); \
vst1q_f32(outptr + 8, v0.val[2]); \
vst1q_f32(outptr + 12, v0.val[3]); \
outptr += ksize4; \
tmp_output += 16; \
} \
\
if (out_index < block_size) { \
float zerobuffer[16] = {0}; \
size_t out_remain = std::min(block_size - out_index, 4); \
std::memcpy(zerobuffer, tmp_output, out_remain * sizeof(float) * 4); \
float32x4x4_t v0 = vld4q_f32(zerobuffer); \
vst1q_f32(outptr, v0.val[0]); \
vst1q_f32(outptr + 4, v0.val[1]); \
vst1q_f32(outptr + 8, v0.val[2]); \
vst1q_f32(outptr + 12, v0.val[3]); \
} \
output_base += 48; \
output_base4 += 16;
#define LOAD_AND_STOR_IM2COL_DST() \
float32x4_t v1 = vld1q_f32(&src[index + 4]); \
float32x4_t v2 = vld1q_f32(&src[index + 8]); \
vst1q_f32(&output0[i], v0); \
vst1q_f32(&output1[i], v1); \
vst1q_f32(&output2[i], v2); \
i += 4; \
index += 8; \
v0 = v2;
void
fuse_packb
(
const
float
*
__restrict
src
,
float
*
__restrict
dst
,
float
*
__restrict
b_panel
,
const
int
OW
,
const
int
IC
,
const
int
IH
,
const
int
IW
,
const
int
cur_index
,
const
int
block_size
)
{
int
start_h
=
cur_index
/
OW
;
int
cur_remain_w
=
cur_index
%
OW
;
int
end_h
=
(
cur_index
+
block_size
)
/
OW
;
int
end_remain_w
=
(
cur_index
+
block_size
)
%
OW
;
bool
same_line
=
start_h
==
end_h
?
true
:
false
;
size_t
newIC
=
IC
/
4
;
float
*
b_output
=
b_panel
;
const
int
packed_k
=
IC
*
3
*
3
;
const
int
ksize12
=
packed_k
*
12
;
const
int
ksize4
=
packed_k
*
4
;
float
*
outptr
=
b_output
;
float
*
output_base
=
b_output
;
float
*
output_base4
=
b_output
+
block_size
/
12
*
ksize12
;
constexpr
int
FH
=
3
;
constexpr
int
SH
=
2
;
constexpr
int
SW
=
2
;
if
(
same_line
)
{
rep
(
ic
,
newIC
)
{
rep
(
fh
,
FH
)
{
float
*
output02
=
dst
;
float
*
output1
=
dst
+
block_size
*
4
+
4
;
size_t
i
=
0
;
size_t
index
=
4
*
(
ic
*
IH
*
IW
+
(
start_h
*
SH
+
fh
)
*
IW
+
cur_remain_w
*
SW
);
for
(
int
w
=
cur_remain_w
;
w
<
end_remain_w
;
w
++
)
{
vst1q_f32
(
&
output02
[
i
],
vld1q_f32
(
&
src
[
index
]));
vst1q_f32
(
&
output1
[
i
],
vld1q_f32
(
&
src
[
index
+
4
]));
i
+=
4
;
index
+=
8
;
}
vst1q_f32
(
&
output02
[
i
],
vld1q_f32
(
&
src
[
index
]));
float
*
output
[
3
];
output
[
0
]
=
output02
;
output
[
1
]
=
output1
;
output
[
2
]
=
output02
+
4
;
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
const
float
*
tmp_output
=
output
[
i
];
PACKB_ONELINE
();
}
}
}
}
else
{
rep
(
ic
,
newIC
)
{
rep
(
fh
,
FH
)
{
float
*
output0
=
dst
;
float
*
output1
=
dst
+
block_size
*
4
;
float
*
output2
=
output1
+
block_size
*
4
;
size_t
i
=
0
;
size_t
index
=
4
*
(
ic
*
IH
*
IW
+
(
start_h
*
SH
+
fh
)
*
IW
+
(
cur_remain_w
*
SW
));
float32x4_t
v0
=
vld1q_f32
(
&
src
[
index
]);
for
(
int
w
=
cur_remain_w
;
w
<
OW
;
w
++
)
{
LOAD_AND_STOR_IM2COL_DST
();
}
for
(
int
h
=
start_h
+
1
;
h
<
end_h
;
h
++
)
{
size_t
index
=
4
*
(
ic
*
IH
*
IW
+
(
h
*
SH
+
fh
)
*
IW
);
v0
=
vld1q_f32
(
&
src
[
index
]);
rep
(
ow
,
OW
)
{
LOAD_AND_STOR_IM2COL_DST
();
}
}
index
=
4
*
(
ic
*
IH
*
IW
+
(
end_h
*
SH
+
fh
)
*
IW
);
v0
=
vld1q_f32
(
&
src
[
index
]);
for
(
int
w
=
0
;
w
<
end_remain_w
;
w
++
)
{
LOAD_AND_STOR_IM2COL_DST
();
}
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
const
float
*
tmp_output
=
output0
+
i
*
block_size
*
4
;
PACKB_ONELINE
();
}
}
}
}
}
#undef PACKB_ONELINE
#undef LOAD_AND_STOR_IM2COL_DST
}
// namespace
template
<
typename
op_ctype
,
typename
op_dtype
,
megdnn
::
PostprocessMode
postprocess_mode
>
void
StrategyFuse8x12x1Nchw44K3x3S2
<
op_ctype
,
op_dtype
,
postprocess_mode
>::
exec_im2col
(
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
/*matmul_param*/
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
/*matmul_algo*/
)
{
size_t
ow
=
param
.
osz
[
1
];
size_t
ic
=
param
.
filter_meta
.
icpg
;
size_t
ih
=
param
.
isz
[
0
]
+
param
.
filter_meta
.
padding
[
0
]
*
2
;
size_t
iw
=
param
.
isz
[
1
]
+
param
.
filter_meta
.
padding
[
1
]
*
2
;
size_t
input_offset
=
ih
*
iw
*
ic
*
(
sparam
.
group_id
+
param
.
filter_meta
.
group
*
sparam
.
batch_id
)
*
sizeof
(
float
);
float
*
src2
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uintptr_t
>
(
bundle
.
get
(
BUNDLE_PADDING_INDEX
))
+
input_offset
);
bool
is_phpwzero
=
param
.
filter_meta
.
padding
[
0
]
==
0
&&
param
.
filter_meta
.
padding
[
1
]
==
0
;
if
(
is_phpwzero
)
{
src2
=
const_cast
<
float
*>
(
param
.
src
<
float
>
(
sparam
.
batch_id
,
sparam
.
group_id
));
}
float
*
b_panel
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uintptr_t
>
(
bundle_thread
.
get
(
THREAD_BUNDLE_PACKB_INDEX
)));
megdnn_assert
(
ic
%
4
==
0
,
"nchw44_dot with ic is not of time 4"
);
float
*
im2col_dst
=
static_cast
<
float
*>
(
bundle_thread
.
get
(
THREAD_BUNDLE_IM2COL_INDEX
));
fuse_packb
(
src2
,
im2col_dst
,
b_panel
,
ow
,
ic
,
ih
,
iw
,
sparam
.
ohw_cur_index
,
sparam
.
output_block_size
);
}
namespace
megdnn
{
template
class
StrategyFuse8x12x1Nchw44K3x3S2
<
float
,
float
,
megdnn
::
PostprocessMode
::
FLOAT
>;
}
// namespace megdnn
#endif
// vim: syntax=cpp.doxygen
dnn/test/arm_common/conv_bias_multi_thread.cpp
浏览文件 @
6b2760dd
...
...
@@ -1838,7 +1838,16 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) {
check_conv_bias
(
args
,
handle
(),
"IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12"
);
#endif
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_IM2COL_S2_MK4_PACK_F32_FUSE
)
{
using
namespace
conv_bias
;
std
::
vector
<
conv_bias
::
TestArg
>
args
=
get_nchw44_conv_bias_args
(
{
3
},
2
,
false
,
false
,
false
,
false
,
false
,
true
,
true
,
false
);
#if MEGDNN_AARCH64
check_conv_bias
(
args
,
handle
(),
"IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"
);
#elif MEGDNN_ARMV7
check_conv_bias
(
args
,
handle
(),
"IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12"
);
#endif
}
/***************************** Conv1x1 Algo Test ***********************/
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_1X1_S1_F32
)
{
using
namespace
conv_bias
;
...
...
dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp
浏览文件 @
6b2760dd
...
...
@@ -708,6 +708,66 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT) {
}
TEST_F
(
ARM_COMMON_BENCHMARK_MULTI_THREADS
,
BENCHMARK_CONVBIAS_FLOAT_NCHW44
)
{
constexpr
size_t
RUNS
=
40
;
std
::
vector
<
DType
>
data_type
=
{
dtype
::
Float32
(),
dtype
::
Float32
(),
dtype
::
Float32
(),
dtype
::
Float32
()};
auto
bench_case
=
[
&
](
size_t
N
,
size_t
IC
,
size_t
OC
,
size_t
H
,
size_t
W
,
size_t
FS
,
size_t
group
,
size_t
P
,
size_t
S
,
bool
is_nchw
=
false
)
{
param
::
ConvBias
param
;
param
.
nonlineMode
=
param
::
ConvBias
::
NonlineMode
::
RELU
;
param
.
pad_h
=
P
;
param
.
pad_w
=
P
;
param
.
stride_h
=
S
;
param
.
stride_w
=
S
;
param
.
sparse
=
param
::
ConvBias
::
Sparse
::
DENSE
;
param
.
format
=
param
::
ConvBias
::
Format
::
NCHW44
;
auto
OH
=
(
H
+
2
*
P
-
FS
)
/
static_cast
<
size_t
>
(
S
)
+
1
;
auto
OW
=
(
W
+
2
*
P
-
FS
)
/
static_cast
<
size_t
>
(
S
)
+
1
;
TensorShape
src
=
{
N
,
IC
/
4
,
H
,
W
,
4
};
TensorShape
filter
=
{
OC
/
4
,
IC
/
4
,
FS
,
FS
,
4
,
4
};
if
(
group
>
1
)
{
filter
=
{
group
,
OC
/
group
/
4
,
IC
/
group
/
4
,
FS
,
FS
,
4
,
4
};
param
.
sparse
=
param
::
ConvBias
::
Sparse
::
GROUP
;
}
if
(
is_nchw
)
{
src
=
{
N
,
IC
,
H
,
W
};
filter
=
{
OC
/
4
,
FS
,
FS
,
IC
,
4
};
}
TensorShape
bias
=
{
1
,
OC
/
4
,
1
,
1
,
4
};
TensorShape
dst
=
{
N
,
OC
/
4
,
OH
,
OW
,
4
};
SmallVector
<
TensorShape
>
shapes
{
src
,
filter
,
bias
,
{},
dst
};
float
computations
=
(((
IC
/
group
)
*
FS
*
FS
+
1
)
*
dst
.
total_nr_elems
()
*
2
+
dst
.
total_nr_elems
())
*
1e-6
;
std
::
vector
<
std
::
pair
<
SmallVector
<
TensorShape
>
,
float
>>
shape_arg
=
{
std
::
make_pair
(
shapes
,
computations
)};
benchmark_impl
(
param
,
shape_arg
,
".+"
,
RUNS
,
{
4
,
{
4
,
5
,
6
,
7
}},
{
1
,
{
7
}},
data_type
);
};
bench_case
(
1
,
64
,
64
,
56
,
56
,
3
,
1
,
1
,
2
);
bench_case
(
1
,
128
,
128
,
28
,
28
,
3
,
1
,
1
,
2
);
bench_case
(
1
,
256
,
256
,
14
,
14
,
3
,
1
,
1
,
2
);
bench_case
(
1
,
512
,
512
,
7
,
7
,
3
,
1
,
1
,
2
);
bench_case
(
1
,
64
,
64
,
56
,
56
,
3
,
4
,
1
,
2
);
bench_case
(
1
,
128
,
128
,
28
,
28
,
3
,
4
,
1
,
2
);
bench_case
(
1
,
256
,
256
,
14
,
14
,
3
,
4
,
1
,
2
);
bench_case
(
1
,
512
,
512
,
7
,
7
,
3
,
4
,
1
,
2
);
bench_case
(
1
,
64
,
64
,
56
*
2
,
56
*
2
,
3
,
4
,
1
,
2
);
bench_case
(
1
,
128
,
128
,
28
*
2
,
28
*
2
,
3
,
4
,
1
,
2
);
bench_case
(
1
,
256
,
256
,
14
*
2
,
14
*
2
,
3
,
4
,
1
,
2
);
bench_case
(
1
,
512
,
512
,
7
*
2
,
7
*
2
,
3
,
4
,
1
,
2
);
}
TEST_F
(
ARM_COMMON_BENCHMARK_MULTI_THREADS
,
BENCHMARK_CONVBIAS_INT8_INT8_INT8_STRIDE2
)
{
constexpr
size_t
RUNS
=
50
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录