3117bfb7
编写于
6月 08, 2020
作者:
M
Megvii Engine Team
提交者:
Xu Xinran
6月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(dnn/arm): nchw44 direct int8 support 8832
GitOrigin-RevId: 696fa05d943b28fcec3a236bb8518fb255eae9db
上级
4e0c9ad3
Showing 12 changed files with 1645 additions and 2200 deletions (+1645 −2200):

  dnn/src/arm_common/conv_bias/int8/algos.h                          +6    −23
  dnn/src/arm_common/conv_bias/int8/direct.h                         +0    −20
  dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp           +149  −136
  dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h             +1428 −0
  dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_algo.cpp   +0    −393
  dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_kern.cpp   +0    −791
  dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern.cpp   +0    −793
  dnn/src/arm_common/conv_bias/opr_impl.cpp                          +2    −4
  dnn/src/arm_common/conv_bias/opr_impl.h                            +1    −2
  dnn/src/arm_common/elemwise_helper/kimpl/none.h                    +2    −0
  dnn/test/arm_common/conv_bias.cpp                                  +9    −1
  dnn/test/arm_common/conv_bias_multi_thread.cpp                     +48   −37
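A note on the naming before the diffs: NCHW44 is MegEngine's blocked layout (N, C/4, H, W, 4), where channels are stored in groups of four so that one 32-bit load fetches a full channel block of one pixel. Reading "8832" as the usual int8x8x32 shorthand, this commit extends the direct kernels from int8 src × int8 filter → qint8 dst to also produce raw int32 output, while merging the separate stride-1 and stride-2 algorithms into one. A minimal indexing sketch of the layout (illustrative helper, not a MegEngine API):

    #include <cstddef>

    // Offset of element (n, c, h, w) in an NCHW44-packed buffer.
    // Channels are grouped in blocks of 4: layout is (N, C/4, H, W, 4).
    static inline size_t nchw44_offset(size_t n, size_t c, size_t h, size_t w,
                                       size_t C, size_t H, size_t W) {
        size_t c_blk = c / 4;  // which channel block
        size_t c_in = c % 4;   // position inside the block
        return (((n * (C / 4) + c_blk) * H + h) * W + w) * 4 + c_in;
    }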
dnn/src/arm_common/conv_bias/int8/algos.h

@@ -38,23 +38,6 @@ public:
                       const NCBKernSizeParam& param) const override;
 };

-class ConvBiasImpl::AlgoS8DirectStride1NCHW44 final : public AlgoBase {
-public:
-    AlgoS8DirectStride1NCHW44() {}
-    bool is_reproducible() const override { return true; }
-    const char* name() const override { return "S8_NCHW44_DIRECT_STRD1"; }
-    bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
-                AlgoSelectionStrategy algo_selection_strategy) const override;
-    size_t get_workspace(fallback::ConvBiasImpl*,
-                         const NCBKernSizeParam& param) const override;
-    virtual SmallVector<NCBKern> dispatch_kerns(
-            fallback::ConvBiasImpl* opr,
-            const NCBKernSizeParam& param) const override;
-    bool is_preferred(megdnn::fallback::ConvBiasImpl*,
-                      const NCBKernSizeParam& param) const override;
-};
-
 class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase {
     bool m_large_group;
@@ -74,11 +57,11 @@ public:
                       const NCBKernSizeParam& param) const override;
 };

-class ConvBiasImpl::AlgoS8DirectStride2NCHW44 final : public AlgoBase {
+class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase {
 public:
-    AlgoS8DirectStride2NCHW44() {}
+    AlgoS8DirectNCHW44() {}
     bool is_reproducible() const override { return true; }
-    const char* name() const override { return "S8_NCHW44_DIRECT_STRD2"; }
+    const char* name() const override { return "S8_NCHW44_DIRECT"; }
     bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
                 AlgoSelectionStrategy algo_selection_strategy) const override;
     size_t get_workspace(fallback::ConvBiasImpl*,
@@ -245,8 +228,8 @@ private:
 //=======================input int8 compute fp32 output int8============
 class ConvBiasImpl::AlgoS8CF32WinogradF23_4x4_NCHW44 final : public AlgoBase {
 public:
-    AlgoS8CF32WinogradF23_4x4_NCHW44(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
-                                     uint32_t tile_size)
+    AlgoS8CF32WinogradF23_4x4_NCHW44(
+            fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size)
             : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
     bool is_reproducible() const override { return true; }
     const char* name() const override {

@@ -277,7 +260,7 @@ private:
 class ConvBiasImpl::AlgoS8WinogradF23_8x8_NCHW44 final : public AlgoBase {
 public:
     AlgoS8WinogradF23_8x8_NCHW44(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
-                                     uint32_t tile_size)
+                                 uint32_t tile_size)
             : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
     bool is_reproducible() const override { return true; }
     const char* name() const override {
dnn/src/arm_common/conv_bias/int8/direct.h

@@ -36,26 +36,6 @@ KERN(stride2, 7, nchw)
 #undef KERN
-#define KERN(stride, i, layout)                                            \
-    template <BiasMode bias_mode, typename Op, int remain_w>               \
-    void conv_direct_##stride##_##i##x##i##_int8_##layout(                 \
-            const int8_t* src, const int8_t* filter, const int32_t* bias,  \
-            int32_t* temp, int8_t* dst, const size_t OC, const size_t IC,  \
-            const size_t IH, const size_t IW, const size_t OH,             \
-            const size_t OW, const Op& op);
-
-KERN(stride1, 2, nchw44)
-KERN(stride1, 3, nchw44)
-KERN(stride1, 5, nchw44)
-KERN(stride1, 7, nchw44)
-
-KERN(stride2, 2, nchw44)
-KERN(stride2, 3, nchw44)
-KERN(stride2, 5, nchw44)
-KERN(stride2, 7, nchw44)
-#undef KERN
-
-void nchw44_pack_filter(const int8_t* src, int8_t* dst, int filter);
-void nchw44_pack_src(const int8_t* src, int8_t* dst, int length);
-
 }  // namespace conv_bias
 }  // namespace arm_common
 }  // namespace megdnn
dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_algo.cpp
  → dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp

 /**
- * \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_algo.cpp
+ * \file dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.

@@ -13,6 +13,7 @@
 #include "megdnn/oprs.h"
 #include "src/arm_common/conv_bias/int8/algos.h"
 #include "src/arm_common/conv_bias/int8/direct.h"
+#include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h"
 #include "src/arm_common/conv_bias/int8/strategy.h"
 #include "src/arm_common/elemwise_op.h"
 #include "src/common/opr_delegate.h"

@@ -25,28 +26,19 @@ using conv_fun = std::function<void(
         WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
         const ConvBiasImpl::NCBKernIndex& ncb_index,
         const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>;

-MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_stride2)
+MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44)
-static void get_rectified_size(
-        const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param,
-        size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) {
-    auto&& fm = param.filter_meta;
-    size_t SW = fm.stride[1];
-    size_t IH = param.isz[0];
-    size_t IW = param.isz[1];
-    size_t OH = param.osz[0];
-    size_t OW = param.osz[1];
-    size_t FH = fm.spatial[0];
-    size_t FW = fm.spatial[1];
-
-    OH2 = OH;
-    OW2 = (OW + 7) & ~7;
-    IH2 = SW * OH + FH - SW;
-    IW2 = SW * OW2 + FW - SW;
-    // Because stride is 2, sometimes IW == IW2+1. Do a max update to
-    // handle this case.
-    IH2 = std::max(IH2, IH);
-    IW2 = std::max(IW2, IW);
-}
+static void get_rectified_size(
+        const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2,
+        int& iw2) {
+    auto&& fm = param.filter_meta;
+    int ih = param.isz[0];
+    int iw = param.isz[1];
+    int ph = fm.padding[0];
+    int pw = fm.padding[1];
+
+    ih2 = ih + ph * 2;
+    iw2 = iw + pw * 2;
+}
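The replacement is worth a second look: the old stride-2-specific helper derived the padded extent from the output size, rounded OW up to a multiple of 8, and patched the result with std::max, while the new helper simply pads the input symmetrically, since the packed-source copy now carries the borders for both strides. A self-contained sketch of the new rule (free-standing, outside MegEngine types):

    #include <cstdio>

    // The rectified (padded) input extent is input + 2 * padding per dimension.
    static void rectified_size(int ih, int iw, int ph, int pw, int& ih2,
                               int& iw2) {
        ih2 = ih + ph * 2;
        iw2 = iw + pw * 2;
    }

    int main() {
        int ih2, iw2;
        rectified_size(15, 15, 1, 1, ih2, iw2);
        std::printf("%d x %d\n", ih2, iw2);  // prints "17 x 17"
    }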
 static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
     constexpr size_t src_expand = 4;

@@ -57,8 +49,8 @@ static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
     size_t OC = fm.ocpg;
     size_t FH = fm.spatial[0];
     size_t FW = fm.spatial[1];
-    size_t IH2, IW2, OH2, OW2;
-    get_rectified_size(param, IH2, IW2, OH2, OW2);
+    int IH2, IW2;
+    get_rectified_size(param, IH2, IW2);
     if (group == 1) {
         size_t src_size = batch * group * IC * IH2 * IW2 * sizeof(int8_t) *
                           src_expand;
@@ -76,16 +68,16 @@ static void copy_padding_kern(WorkspaceBundle bundle,
                               const ConvBiasImpl::NCBKernParam& kern_param,
                               const ConvBiasImpl::NCBKernIndex& ncb_index,
                               const CpuNDRange& workspace_ids) {
-    size_t IH = kern_param.isz[0];
-    size_t IW = kern_param.isz[1];
-    size_t IC = kern_param.filter_meta.icpg;
-    size_t PH = kern_param.filter_meta.padding[0];
-    size_t PW = kern_param.filter_meta.padding[1];
-    size_t GROUP = kern_param.filter_meta.group;
+    int IH = kern_param.isz[0];
+    int IW = kern_param.isz[1];
+    int IC = kern_param.filter_meta.icpg;
+    int PH = kern_param.filter_meta.padding[0];
+    int PW = kern_param.filter_meta.padding[1];
+    int GROUP = kern_param.filter_meta.group;

-    size_t IH2, IW2, OH2, OW2;
-    get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
-    size_t padding_group_size = IH2 * IW2 * IC;
+    int IH2, IW2;
+    get_rectified_size(kern_param, IH2, IW2);
+    int padding_group_size = IH2 * IW2 * IC;
     bundle.set(kern_param.workspace_ptr);

     //! Used for get the workspace offset
     constexpr int pack_ic = 4;
@@ -100,16 +92,10 @@ static void copy_padding_kern(WorkspaceBundle bundle,
     size_t group_id = ncb_index.ndrange_id[1];
     size_t group_pack_size = 1;

-    int nr_pad_h = PH * IW2 * pack_ic * expend_element;
     int nr_pad_w = PW * pack_ic * expend_element;
-    int over_pad = std::max(0_z, IW2 - IW - 2 * PW) * pack_ic * expend_element;
-    int row_last_pad = ((int)IW2 - (int)IW - 2 * (int)PW) >= 0
-                               ? nr_pad_w + over_pad
-                               : (IW2 - IW - PW) * pack_ic * expend_element;
-    int col_last_pad = ((int)IH2 - (int)IH - 2 * (int)PH) >= 0
-                               ? nr_pad_h
-                               : (IH2 - IH - PH) * IW2 * pack_ic * expend_element;
+    int nr_pad_h = PH * IW2 * pack_ic * expend_element;
+    int row_last_pad = (IW2 - IW - PW) * pack_ic * expend_element;
+    int col_last_pad = (IH2 - IH - PH) * IW2 * pack_ic * expend_element;

     const int8_t* sptr = static_cast<const int8_t*>(kern_param.src<int8_t>(
             batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic));
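With symmetric padding the border arithmetic collapses: IW2 − IW − PW is exactly PW, so the over_pad value and the sign-checking conditionals above become dead weight and are dropped. A worked check of the simplified bookkeeping (illustrative values):

    // With IW = 15 and PW = 1, IW2 = IW + 2 * PW = 17; each padded row is
    // [nr_pad_w zeros][packed src row][row_last_pad zeros].
    constexpr int IW = 15, PW = 1, IW2 = IW + 2 * PW;
    constexpr int pack_ic = 4, expend_element = 4;
    constexpr int nr_pad_w = PW * pack_ic * expend_element;                   // 16
    constexpr int row_last_pad = (IW2 - IW - PW) * pack_ic * expend_element;  // 16
    static_assert(nr_pad_w == 16 && row_last_pad == 16,
                  "left and right borders are symmetric");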
@@ -129,7 +115,7 @@ static void copy_padding_kern(WorkspaceBundle bundle,
     rep(ih_idx, IH) {
         std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t));
         sptr_base += nr_pad_w;
-        conv_bias::nchw44_pack_src(sptr, sptr_base, IW);
+        nchw44_pack_src(sptr, sptr_base, IW);
         sptr_base += IW * pack_ic * expend_element;
         sptr += IW * pack_ic;
         std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t));
@@ -140,7 +126,8 @@ static void copy_padding_kern(WorkspaceBundle bundle,
     }
 }

-template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain>
+template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain,
+          typename DstType, int stride>
 static void do_conv_kern(WorkspaceBundle bundle,
                          const ConvBiasImpl::NCBKernParam& kern_param,
                          const ConvBiasImpl::NCBKernIndex& ncb_index,
@@ -153,12 +140,12 @@ static void do_conv_kern(WorkspaceBundle bundle,
     size_t IC = kern_param.filter_meta.icpg;
     size_t OC = kern_param.filter_meta.ocpg;
     size_t GROUP = kern_param.filter_meta.group;
-    size_t IH2, IW2, OH2, OW2;
-    get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
+    int IH2, IW2;
+    get_rectified_size(kern_param, IH2, IW2);
     bool need_post_process =
             kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
     //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f)
-    Op op = Op(1.0f, 4.0f);
+    Op op(1.f, 4.f);
     if (need_post_process) {
         float scale_bias =
                 kern_param.bias_type.param<dtype::QuantizedS32>().scale;
@@ -191,49 +178,43 @@ static void do_conv_kern(WorkspaceBundle bundle,
     const int8_t* fptr =
             kern_param.filter<dt_int8>(group_id) + oc_idx * FH * FW * IC;
-    void* dst = reinterpret_cast<void*>(
-            reinterpret_cast<ptrdiff_t>(
-                    kern_param.dst<void>(batch_id, group_id)) +
-            oc_idx * OH * OW);
+    DstType* dst = reinterpret_cast<DstType*>(
+            kern_param.dst<void>(batch_id, group_id, oc_idx));
     const int32_t* bptr =
             kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx;
     auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
                          group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW;
-    conv_bias::nchw44_pack_filter(fptr, packed_weight,
-                                  oc_block / 4 * IC / 4 * FH * FW);
-
-#define KERN1_NCHW44_CONV(filter)                                          \
-    conv_bias::conv_direct_stride2_##filter##x##filter##_int8_nchw44<      \
-            bias_mode, Op, ow_remain>(sptr, packed_weight, bptr, nullptr,  \
-                                      static_cast<int8_t*>(dst), oc_block, \
-                                      IC, IH2, IW2, OH, OW, op)
-    DISPATCH_FILTER(filter, KERN1_NCHW44_CONV)
-#undef KERN1_NCHW44_CONV
+    nchw44_pack_filter(fptr, packed_weight, oc_block / 4 * IC / 4 * FH * FW);
+    conv_direct_int8_nchw44<bias_mode, Op, ow_remain, filter, DstType,
+                            stride>(sptr, packed_weight, bptr, nullptr,
+                                    static_cast<DstType*>(dst), oc_block, IC,
+                                    IH2, IW2, OH, OW, op);
 }
 /* ===================== stride2 algo ===================== */
-bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::usable(
+bool ConvBiasImpl::AlgoS8DirectNCHW44::usable(
         fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
         AlgoSelectionStrategy algo_selection_strategy) const {
     MEGDNN_MARK_USED_VAR(algo_selection_strategy);
     auto&& fm = param.filter_meta;
-    auto FH = fm.spatial[0];
-    auto OC = fm.ocpg;
-    auto IC = fm.icpg;
-    bool avaible =  //! src and filter are qint8, dst is qint8 or qint32
+    const int fh = fm.spatial[0];
+    const int fw = fm.spatial[1];
+    const int oc = fm.ocpg;
+    const int ic = fm.icpg;
+    const bool avaible =  //! src and filter are qint8, dst is qint8 or qint32
             ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
               param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
              (param.dst_type.enumv() == DTypeEnum::QuantizedS8 ||
              param.dst_type.enumv() == DTypeEnum::QuantizedS32))) &&
            (fm.format == param::Convolution::Format::NCHW44) &&
-            (OC % 4 == 0 && IC % 4 == 0 && OC >= 4) && !fm.should_flip &&
-            fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
-            fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 &&
-            FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7) &&
+            (oc % 4 == 0 && ic % 4 == 0 && oc >= 4) && !fm.should_flip &&
+            fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
+            fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] &&
+            (fm.stride[0] == 2 || fm.stride[0] == 1) && fh == fw &&
+            (fh == 2 || fh == 3 || fh == 5 || fh == 7) &&
             param.bias_mode != BiasMode::BIAS;
     return avaible;
 }
-bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::is_preferred(
+bool ConvBiasImpl::AlgoS8DirectNCHW44::is_preferred(
         megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr,
         const NCBKernSizeParam& param) const {
     // TODO: benchmark and fix

@@ -242,13 +223,13 @@ bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::is_preferred(
     return false;
 }

-size_t ConvBiasImpl::AlgoS8DirectStride2NCHW44::get_workspace(
+size_t ConvBiasImpl::AlgoS8DirectNCHW44::get_workspace(
         fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
     return get_bundle(param).total_size_in_bytes();
 }

 SmallVector<ConvBiasImpl::NCBKern>
-ConvBiasImpl::AlgoS8DirectStride2NCHW44::dispatch_kerns(
+ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns(
         fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
     auto fm = param.filter_meta;
     size_t N = param.n;
@@ -261,97 +242,129 @@ ConvBiasImpl::AlgoS8DirectStride2NCHW44::dispatch_kerns(
     WorkspaceBundle wbundle = get_bundle(param);
     conv_fun do_conv_fun = nullptr;
     int ow_remain = OW % 8;
+    bool need_post_process = param.dst_type.enumv() == DTypeEnum::QuantizedS8;

+// NOTE: remain_w is not used to gen hash of midout for compatible with
+// changing shape runtime
-#define DO_CONV_KERN_FUN(filter, bias_mode, remain_w, op) \
-    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_stride2, \
-                 midout_iv(#filter #bias_mode #op##_hash)) { \
-        do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w>; \
-    } \
-    MIDOUT_END();
+#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode, remain_w, op) \
+    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44, \
+                 midout_iv(#stride #dst_type #filter #bias_mode #op##_hash)) { \
+        do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w, dst_type, \
+                                   stride>; \
+    } \
+    MIDOUT_END();

-#define GET_OP_PARAM(filter, bias_mode, remain_w) \
-    switch (param.nonlineMode) { \
-        case param::ConvBias::NonlineMode::IDENTITY: \
-            DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \
-                             TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
-            break; \
-        case param::ConvBias::NonlineMode::RELU: \
-            DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \
-                             ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
-            break; \
-        case param::ConvBias::NonlineMode::H_SWISH: \
-            DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \
-                             HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
-            break; \
-        default: \
-            megdnn_assert(0); \
-            break; \
-    }
+#define GET_OP_PARAM(stride, filter, bias_mode, remain_w) \
+    if (need_post_process) { \
+        switch (param.nonlineMode) { \
+            case param::ConvBias::NonlineMode::IDENTITY: \
+                DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
+                                 remain_w, \
+                                 TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
+                break; \
+            case param::ConvBias::NonlineMode::RELU: \
+                DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
+                                 remain_w, \
+                                 ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
+                break; \
+            case param::ConvBias::NonlineMode::H_SWISH: \
+                DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
+                                 remain_w, \
+                                 HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
+                break; \
+            default: \
+                megdnn_assert(0, "no supported noline mode"); \
+                break; \
+        } \
+    } else { \
+        switch (param.nonlineMode) { \
+            case param::ConvBias::NonlineMode::IDENTITY: \
+                DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, \
+                                 remain_w, NoneOp<dt_int32>) \
+                break; \
+            default: \
+                megdnn_assert( \
+                        0, \
+                        "only support IDENTITY mode when dst is not qint8"); \
+                break; \
+        } \
+    }

-#define GET_REMAIN_W_PARAM(filter, bias_mode) \
+#define GET_REMAIN_W_PARAM(stride, filter, bias_mode) \
     switch (ow_remain) { \
         case 0: \
-            GET_OP_PARAM(filter, bias_mode, 0); \
+            GET_OP_PARAM(stride, filter, bias_mode, 0); \
             break; \
         case 1: \
-            GET_OP_PARAM(filter, bias_mode, 1); \
+            GET_OP_PARAM(stride, filter, bias_mode, 1); \
             break; \
         case 2: \
-            GET_OP_PARAM(filter, bias_mode, 2); \
+            GET_OP_PARAM(stride, filter, bias_mode, 2); \
             break; \
         case 3: \
-            GET_OP_PARAM(filter, bias_mode, 3); \
+            GET_OP_PARAM(stride, filter, bias_mode, 3); \
             break; \
         case 4: \
-            GET_OP_PARAM(filter, bias_mode, 4); \
+            GET_OP_PARAM(stride, filter, bias_mode, 4); \
             break; \
         case 5: \
-            GET_OP_PARAM(filter, bias_mode, 5); \
+            GET_OP_PARAM(stride, filter, bias_mode, 5); \
             break; \
         case 6: \
-            GET_OP_PARAM(filter, bias_mode, 6); \
+            GET_OP_PARAM(stride, filter, bias_mode, 6); \
             break; \
         case 7: \
-            GET_OP_PARAM(filter, bias_mode, 7); \
+            GET_OP_PARAM(stride, filter, bias_mode, 7); \
             break; \
         default: \
            megdnn_assert(0); \
     }

-#define GET_BIAS_MODE_PARAM(filter) \
+#define GET_BIAS_MODE_PARAM(stride, filter) \
     switch (param.bias_mode) { \
         case BiasMode::NO_BIAS: \
-            GET_REMAIN_W_PARAM(filter, BiasMode::NO_BIAS) \
+            GET_REMAIN_W_PARAM(stride, filter, BiasMode::NO_BIAS) \
             break; \
         case BiasMode::BROADCAST_CHANNEL_BIAS: \
-            GET_REMAIN_W_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
+            GET_REMAIN_W_PARAM(stride, filter, \
+                               BiasMode::BROADCAST_CHANNEL_BIAS) \
             break; \
         default: \
             megdnn_assert(0); \
             break; \
     }

-#define DISPATCH_CONV_KERN() \
+#define DISPATCH_CONV_KERN(stride) \
     switch (param.filter_meta.spatial[0]) { \
         case 2: \
-            GET_BIAS_MODE_PARAM(2) \
+            GET_BIAS_MODE_PARAM(stride, 2) \
             break; \
         case 3: \
-            GET_BIAS_MODE_PARAM(3) \
+            GET_BIAS_MODE_PARAM(stride, 3) \
             break; \
         case 5: \
-            GET_BIAS_MODE_PARAM(5) \
+            GET_BIAS_MODE_PARAM(stride, 5) \
             break; \
         case 7: \
-            GET_BIAS_MODE_PARAM(7) \
+            GET_BIAS_MODE_PARAM(stride, 7) \
             break; \
         default: \
             megdnn_assert(0); \
             break; \
     }

-    DISPATCH_CONV_KERN();
+    switch (param.filter_meta.stride[0]) {
+        case 1:
+            DISPATCH_CONV_KERN(1);
+            break;
+        case 2:
+            DISPATCH_CONV_KERN(2);
+            break;
+        default:
+            megdnn_throw(
+                    ssprintf("Unsupport stride size %u for the first conv",
+                             param.filter_meta.stride[0])
+                            .c_str());
+            break;
+    }

 #undef DO_CONV_KERN_FUN
 #undef GET_REMAIN_W_PARAM
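Taken together, the macro rework turns a stride-2, qint8-only dispatch into a two-axis one: stride (1 or 2) is picked at runtime from filter_meta, and the destination dtype (qint8 with a quantizing op, or raw int32 with NoneOp) from dst_type. A stripped-down sketch of that selection, with hypothetical stand-in names (pick_kernel is not the MegEngine API):

    #include <cstdint>
    #include <cstdio>

    enum class Nonline { IDENTITY, RELU, H_SWISH };

    // Stand-in for DO_CONV_KERN_FUN: the real code binds
    // do_conv_kern<filter, bias_mode, op, remain_w, dst_type, stride>.
    template <int stride, typename DstType>
    void pick_kernel(int filter, Nonline) {
        std::printf("stride=%d dst=%zu-byte filter=%d\n", stride,
                    sizeof(DstType), filter);
    }

    void dispatch(int stride, bool need_post_process, int filter,
                  Nonline mode) {
        // int32 output carries no quantizing op, so only IDENTITY applies.
        if (stride == 1) {
            need_post_process
                    ? pick_kernel<1, std::int8_t>(filter, mode)
                    : pick_kernel<1, std::int32_t>(filter, Nonline::IDENTITY);
        } else {
            need_post_process
                    ? pick_kernel<2, std::int8_t>(filter, mode)
                    : pick_kernel<2, std::int32_t>(filter, Nonline::IDENTITY);
        }
    }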
dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h  (new file, 0 → 100644)

/**
 * \file dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once
#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"

namespace megdnn {
namespace arm_common {
namespace {
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          int c_dim, typename DstType>
static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             DstType* dst_ptr, int ic, int ih,
                                             int iw, int ld_dst_oc,
                                             const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int oc_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;
    constexpr int pack_iw_len = 4;

    const int ic_stride = ih * iw * pack_iw_len;
    const int ld_weight_oc4 = oc_step * fh * fw * ic;

    int32x4_t c[2][8];
    int8x16_t weight[2][2];
    int8x16_t src[8 + 1];
    int16x8_t temp_c[4];
    init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;

            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));

            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;

            weight[0][0] = vld1q_s8(read_weight_ptr);
            weight[0][1] = vld1q_s8(read_weight_ptr + 16);
            weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
            weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);

            c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][0], src[1], c[1][1], temp_c[3]);
            c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][1], src[2], c[1][1], temp_c[3]);

            c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][0], src[2], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][0], src[3], c[1][3], temp_c[3]);
            c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][1], src[3], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][1], src[4], c[1][3], temp_c[3]);

            c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][0], src[4], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][0], src[5], c[1][5], temp_c[3]);
            c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][1], src[5], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][1], src[6], c[1][5], temp_c[3]);

            c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][0], src[6], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][0], src[7], c[1][7], temp_c[3]);
            c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][1], src[7], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][1], src[8], c[1][7], temp_c[3]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
            c, op, dst_ptr, ld_dst_oc);
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          int c_dim, typename DstType>
static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             DstType* dst_ptr, int ic, int ih,
                                             int iw, int ld_dst_oc,
                                             const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int oc_step = 4;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;
    constexpr int pack_iw_len = 4;

    const int ic_stride = ih * iw * pack_iw_len;

    int32x4_t c[1][8];
    int8x16_t weight[1][2];
    int8x16_t src[8 + 1];
    int16x8_t temp_c[2];
    init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;

            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));

            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;

            weight[0][0] = vld1q_s8(read_weight_ptr);
            weight[0][1] = vld1q_s8(read_weight_ptr + 16);

            c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
            c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[1]);
            c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
            c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[1]);

            c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]);
            c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[1]);
            c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]);
            c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[1]);

            c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]);
            c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[1]);
            c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]);
            c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[1]);

            c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]);
            c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[1]);
            c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]);
            c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
            c, op, dst_ptr, ld_dst_oc);
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          int c_dim, typename DstType>
struct KerNeonDirectStride1Int8 {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
                     int iw, const Op& op, int ld_dst_oc);
};

template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
          typename DstType>
struct KerNeonDirectStride1Int8<bias_mode, Op, remain_w, 2, c_dim, DstType> {
    static void impl(const int8_t*, const int8_t*, const int32_t*, DstType*,
                     int, int, int, const Op&, int) {
        megdnn_throw("no impl");
    }
};
/**
 * dot like impl. dot 4 ic to 1 oc, accumulate to c <ow, oc>
 * example: (format like weight<oc, ic>)
 * packed weight
 * low 64 bit  <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
 * ---------------------------------------------------------------------
 * high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
 * dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0>
 **/
//! TODO: can try oh = 2 impl, oc = 8 impl
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
          typename DstType>
struct KerNeonDirectStride1Int8<bias_mode, Op, remain_w, 3, c_dim, DstType> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
                     int iw, const Op& op, int ld_dst_oc) {
        constexpr int filter_size = 3;
        constexpr int fh = filter_size;
        constexpr int fw = filter_size;
        constexpr int oc_step = 4;
        constexpr int ic_step = 4;
        constexpr int loop_ic_step = 4;
        constexpr int ld_weight_ic4 = 16;
        constexpr int pack_iw_len = 4;

        const int ic_stride = ih * iw * pack_iw_len;

        int32x4_t c[c_dim][8];
        int8x16_t weight[3];
        int8x16_t src[8 + 2];
        int16x8_t temp_c[2];
        init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
                const int8_t* src_ic_0_3 =
                        src_ptr + ic_idx * ic_stride +
                        fh_idx * iw * ic_step * pack_iw_len;

                src[0] = vld1q_s8(src_ic_0_3);
                src[1] = vld1q_s8((src_ic_0_3 + 16));
                src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
                src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
                src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
                src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
                src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
                src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
                src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
                src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));

                // oc == 0
                const int8_t* read_weight_ptr =
                        weight_ptr + fh_idx * fw * ld_weight_ic4;

                weight[0] = vld1q_s8(read_weight_ptr);
                weight[1] = vld1q_s8(read_weight_ptr + 16);
                weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);

                c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]);

                c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]);

                c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]);

                c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]);
            }
            weight_ptr += fh * fw * ld_weight_ic4;
        }
        store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
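The doc comment before these specializations describes the trick behind vdotq_s32_h: each 16-byte weight register carries a permuted 4×4 (oc × ic) block, and pairwise int8 multiplies accumulated through an int16 temporary reduce four input channels into each of four output-channel lanes. In plain scalar terms the accumulation being performed is the following (my reading of the comment, not the NEON code itself):

    #include <cstdint>

    // Scalar reference for one 4x4 (oc x ic) block of the NCHW44 dot kernel:
    // acc[oc] += sum over ic of weight[oc][ic] * src[ic], held in int32.
    static void dot_4ic_to_4oc(const int8_t w[4][4], const int8_t s[4],
                               int32_t acc[4]) {
        for (int oc = 0; oc < 4; ++oc) {
            int32_t sum = 0;
            for (int ic = 0; ic < 4; ++ic)
                sum += int32_t(w[oc][ic]) * int32_t(s[ic]);
            acc[oc] += sum;
        }
    }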
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
          typename DstType>
struct KerNeonDirectStride1Int8<bias_mode, Op, remain_w, 5, c_dim, DstType> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
                     int iw, const Op& op, int ld_dst_oc) {
        constexpr int filter_size = 5;
        constexpr int fh = filter_size;
        constexpr int fw = filter_size;
        constexpr int oc_step = 4;
        constexpr int ic_step = 4;
        constexpr int loop_ic_step = 4;
        constexpr int ld_weight_ic4 = 16;
        constexpr int pack_iw_len = 4;

        const int ic_stride = ih * iw * pack_iw_len;

        int32x4_t c[c_dim][8];
        int8x16_t weight[5];
        int8x16_t src[8 + 2];
        int16x8_t temp_c[2];
        init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
                const int8_t* src_ic_0_3 =
                        src_ptr + ic_idx * ic_stride +
                        fh_idx * iw * ic_step * pack_iw_len;

                src[0] = vld1q_s8(src_ic_0_3);
                src[1] = vld1q_s8((src_ic_0_3 + 16));
                src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
                src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
                src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
                src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
                src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
                src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
                src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
                src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));

                // oc == 0
                const int8_t* read_weight_ptr =
                        weight_ptr + fh_idx * fw * ld_weight_ic4;

                weight[0] = vld1q_s8(read_weight_ptr);
                weight[1] = vld1q_s8(read_weight_ptr + 16);
                weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
                weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
                weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);

                c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[3], src[4], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[4], src[5], c[0][1], temp_c[1]);

                c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[3], src[5], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[3], src[6], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[4], src[6], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[4], src[7], c[0][3], temp_c[1]);

                c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[3], src[7], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[3], src[8], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[4], src[8], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[4], src[9], c[0][5], temp_c[1]);

                src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
                src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
                c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[3], src[9], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[3], src[0], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[4], src[0], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[4], src[1], c[0][7], temp_c[1]);
            }
            weight_ptr += fh * fw * ld_weight_ic4;
        }
        store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
          typename DstType>
struct KerNeonDirectStride1Int8<bias_mode, Op, remain_w, 7, c_dim, DstType> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
                     int iw, const Op& op, int ld_dst_oc) {
        constexpr int filter_size = 7;
        constexpr int fh = filter_size;
        constexpr int fw = filter_size;
        constexpr int oc_step = 4;
        constexpr int ic_step = 4;
        constexpr int loop_ic_step = 4;
        constexpr int ld_weight_ic4 = 16;
        constexpr int pack_iw_len = 4;

        const int ic_stride = ih * iw * pack_iw_len;

        int32x4_t c[c_dim][8];
        int8x16_t weight[7];
        int8x16_t src[8 + 2];
        int16x8_t temp_c[2];
        init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
                const int8_t* src_ic_0_3 =
                        src_ptr + ic_idx * ic_stride +
                        fh_idx * iw * ic_step * pack_iw_len;

                src[0] = vld1q_s8(src_ic_0_3);
                src[1] = vld1q_s8((src_ic_0_3 + 16));
                src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
                src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
                src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
                src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
                src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
                src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
                src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
                src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));

                // oc == 0
                const int8_t* read_weight_ptr =
                        weight_ptr + fh_idx * fw * ld_weight_ic4;

                weight[0] = vld1q_s8(read_weight_ptr);
                weight[1] = vld1q_s8(read_weight_ptr + 16);
                weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
                weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
                weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
                weight[5] = vld1q_s8(read_weight_ptr + 5 * 16);
                weight[6] = vld1q_s8(read_weight_ptr + 6 * 16);

                c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[3], src[4], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[4], src[5], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[5], src[5], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[5], src[6], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[6], src[6], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[6], src[7], c[0][1], temp_c[1]);

                c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[3], src[5], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[3], src[6], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[4], src[6], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[4], src[7], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[5], src[7], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[5], src[8], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[6], src[8], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[6], src[9], c[0][3], temp_c[1]);

                src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
                src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
                c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[3], src[7], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[3], src[8], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[4], src[8], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[4], src[9], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[5], src[9], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[5], src[0], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[6], src[0], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[6], src[1], c[0][5], temp_c[1]);

                src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
                src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
                c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[3], src[9], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[3], src[0], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[4], src[0], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[4], src[1], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[5], src[1], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[5], src[2], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[6], src[2], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[6], src[3], c[0][7], temp_c[1]);
            }
            weight_ptr += fh * fw * ld_weight_ic4;
        }
        store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
/**
 * origin weight shape <oc/4, ic/4, fh, fw, 4, 4>
 * packed weight shape <oc/4, ic/4, fh, fw, 16>
 * example: (format like weight<oc, ic>)
 * origin
 * <0, 0> <1, 0> <2, 0> <3, 0>
 * <0, 1> <1, 1> <2, 1> <3, 1>
 * <0, 2> <1, 2> <2, 2> <3, 2>
 * <0, 3> <1, 3> <2, 3> <3, 3>
 * packed
 * low 64 bit  <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
 * ---------------------------------------------------------------------
 * high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
 **/
static inline void nchw44_pack_filter(const int8_t* src, int8_t* dst,
                                      int length) {
    static const uint8_t weight_idx_buffer[16] = {0,  4, 9, 13, 2,  6,  11, 15,
                                                  12, 8, 5, 1,  14, 10, 7,  3};
    constexpr int simd_len = 16;
    uint8x16_t weight_idx = vld1q_u8(weight_idx_buffer);
    for (int i = 0; i < length; i++) {
        int8x16_t result = vldq_tbl_s8(src + i * simd_len, weight_idx);
        vst1q_s8(dst + i * simd_len, result);
    }
}
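The permutation table above is what makes the in-register dot product work: vldq_tbl_s8 gathers the 4×4 weight block into the documented layout with a single tbl-based shuffle, so packing costs one load, one shuffle, and one store per 16 weights. A scalar equivalent of one step, using the same index table (a reference sketch, not the NEON path):

    #include <cstdint>

    // dst[j] = src[idx[j]] for the fixed 16-entry permutation above.
    static void pack_filter_block_scalar(const int8_t src[16], int8_t dst[16]) {
        static const uint8_t idx[16] = {0,  4, 9, 13, 2,  6,  11, 15,
                                        12, 8, 5, 1,  14, 10, 7,  3};
        for (int j = 0; j < 16; ++j)
            dst[j] = src[idx[j]];
    }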
/**
 * origin src shape <n, ic/4, h, w, 4>
 * packed src shape <n, ic/4, h, w, 16>
 * example: (format like <ic>)
 * origin
 * <0> <0> <0> <0>
 * packed
 * low 64 bit  <0> <1> <2> <3> | <0> <1> <2> <3>
 * ---------------------------------------------------------------------
 * high 64 bit <3> <2> <1> <0> | <3> <2> <1> <0>
 **/
static inline void nchw44_pack_src(const int8_t* src, int8_t* dst, int length) {
    static const uint8_t src_idx_buffer[16] = {0, 1, 2, 3, 0, 1, 2, 3,
                                               3, 2, 1, 0, 3, 2, 1, 0};
    constexpr int pack_ic = 4;
    constexpr int simd_len = 16;
    uint8x16_t src_idx = vld1q_u8(src_idx_buffer);
    for (int i = 0; i < length; i++) {
        int8x16_t result = vld_dup_tbl_s32(src + i * pack_ic, src_idx);
        vst1q_s8(dst + i * simd_len, result);
    }
}
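The source side mirrors the weight packing: each 4-byte channel group is broadcast into a 16-byte register, twice forward in the low half and twice reversed in the high half, so that lane-wise products line up with the permuted weights. A scalar equivalent of one step (again a reference sketch):

    // dst holds the 4-byte group src[0..3] replicated per the fixed pattern.
    static void pack_src_block_scalar(const int8_t src[4], int8_t dst[16]) {
        static const uint8_t idx[16] = {0, 1, 2, 3, 0, 1, 2, 3,
                                        3, 2, 1, 0, 3, 2, 1, 0};
        for (int j = 0; j < 16; ++j)
            dst[j] = src[idx[j]];
    }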
template <BiasMode bias_mode, typename Op, int remain_w, typename DstType>
void conv_direct_stride1_2x2_int8_nchw44(
        const int8_t* src, const int8_t* filter, const int32_t* bias,
        int32_t* temp, DstType* dst, const size_t oc, const size_t ic,
        const size_t ih, const size_t iw, const size_t oh, const size_t ow,
        const Op& op) {
    MEGDNN_MARK_USED_VAR(temp);
    constexpr size_t filter_size = 2;
    constexpr size_t fh = filter_size;
    constexpr size_t fw = filter_size;
    constexpr size_t ic_step = 4;
    constexpr size_t oc_step = 4;
    constexpr size_t big_oc_step = 8;
    constexpr size_t oh_step = 1;
    constexpr size_t ow_step = 8;
    constexpr int pack_iw_len = 4;

    const size_t img_stride = oh * ow;
    const size_t ow_end = ow / ow_step * ow_step;
    const size_t ow_remain = ow - ow_end;
    const size_t oc_end = oc / big_oc_step * big_oc_step;
    const size_t oc_remain = oc - oc_end;
    const int ld_oc = oh * ow * oc_step;

    for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_idx) * oc_step;
                ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, 0, filter_size,
                                                 2, DstType>(
                        src + src_offset, filter + weight_offset, bias + oc_idx,
                        dst + dst_offset, ic, ih, iw, ld_oc, op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * iw + ow_end) * ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_end) * oc_step;
                ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, remain_w,
                                                 filter_size, 2, DstType>(
                        src + src_offset, filter + weight_offset, bias + oc_idx,
                        dst + dst_offset, ic, ih, iw, ld_oc, op);
            }
        }
    }
    if (oc_remain > 0) {
        const size_t oc_idx = oc_end;
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_idx) * oc_step;
                ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, 0, filter_size,
                                                 1, DstType>(
                        src + src_offset, filter + weight_offset, bias + oc_idx,
                        dst + dst_offset, ic, ih, iw, ld_oc, op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * iw + ow_end) * ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_end) * oc_step;
                ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, remain_w,
                                                 filter_size, 1, DstType>(
                        src + src_offset, filter + weight_offset, bias + oc_idx,
                        dst + dst_offset, ic, ih, iw, ld_oc, op);
            }
        }
    }
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          typename DstType>
void conv_direct_stride1_int8_nchw44_kern(
        const int8_t* src, const int8_t* filter, const int32_t* bias,
        int32_t* temp, DstType* dst, const size_t oc, const size_t ic,
        const size_t ih, const size_t iw, const size_t oh, const size_t ow,
        const Op& op) {
    MEGDNN_MARK_USED_VAR(temp);
    constexpr size_t fh = filter_size;
    constexpr size_t fw = filter_size;
    constexpr size_t ic_step = 4;
    constexpr size_t oc_step = 4;
    constexpr size_t oh_step = 1;
    constexpr size_t ow_step = 8;
    constexpr int pack_iw_len = 4;

    const size_t img_stride = oh * ow;
    const int ld_dst_oc = oh * ow * oc_step;
    const size_t ow_end = ow / ow_step * ow_step;
    const size_t ow_remain = ow - ow_end;

    for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_idx) * oc_step;
                KerNeonDirectStride1Int8<bias_mode, Op, ow_step, filter_size,
                                         1, DstType>::
                        impl(src + src_offset, filter + weight_offset,
                             bias + oc_idx, dst + dst_offset, ic, ih, iw, op,
                             ld_dst_oc);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * iw + ow_end) * ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_end) * oc_step;
                KerNeonDirectStride1Int8<bias_mode, Op, remain_w, filter_size,
                                         1, DstType>::
                        impl(src + src_offset, filter + weight_offset,
                             bias + oc_idx, dst + dst_offset, ic, ih, iw, op,
                             ld_dst_oc);
            }
        }
    }
}
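Note how the tail is handled here and in the 2x2 wrapper above: the ow loop is tiled by ow_step = 8, full-width tiles take a dedicated kernel instantiation, and only the last tile re-instantiates with the compile-time remainder so its stores are masked. A worked example with illustrative numbers:

    constexpr size_t ow = 30, ow_step = 8;
    constexpr size_t ow_end = ow / ow_step * ow_step;  // 24: three full tiles
    constexpr size_t ow_remain = ow - ow_end;          // 6: one tail tile
    static_assert(ow_end == 24 && ow_remain == 6,
                  "tail width becomes the remain_w template argument");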
/////////////////////stride 2/////////////////
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          int c_dim, typename DstType>
struct KerNeonDirectStride2Int8 {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
                     int iw, const Op& op, int ld_dst_oc);
};

template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
          typename DstType>
struct KerNeonDirectStride2Int8<bias_mode, Op, remain_w, 2, c_dim, DstType> {
    static void impl(const int8_t*, const int8_t*, const int32_t*, DstType*,
                     int, int, int, const Op&, int) {
        megdnn_throw("no impl");
    }
};
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          int c_dim, typename DstType>
static void ker_neon_dirctconv_2x2s2_oc8_ow8(
        const int8_t* src_ptr, const int8_t* weight_ptr,
        const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, int iw,
        int ld_dst_oc, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int oc_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;
    const int ld_weight_oc4 = oc_step * fh * fw * ic;

    int32x4_t c[2][8];
    int8x16_t weight[2][2];
    int8x16_t src[8 + 1];
    int16x8_t temp_c[4];
    init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8(src_ic_0_3 + 16);
            src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
            src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
            src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0][0] = vld1q_s8(read_weight_ptr);
            weight[0][1] = vld1q_s8(read_weight_ptr + 16);
            weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
            weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);

            c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][0], src[2], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][0], src[2], c[1][1], temp_c[3]);
            c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][1], src[3], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][1], src[3], c[1][1], temp_c[3]);

            c[0][2] = vdotq_s32_h(weight[0][0], src[4], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][0], src[4], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][0], src[6], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][0], src[6], c[1][3], temp_c[3]);
            c[0][2] = vdotq_s32_h(weight[0][1], src[5], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][1], src[5], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][1], src[7], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][1], src[7], c[1][3], temp_c[3]);

            src[0] = vld1q_s8(src_ic_0_3 + 9 * 16);
            src[1] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[2] = vld1q_s8(src_ic_0_3 + 11 * 16);
            c[0][4] = vdotq_s32_h(weight[0][0], src[8], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][0], src[8], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][0], src[1], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][0], src[1], c[1][5], temp_c[3]);
            c[0][4] = vdotq_s32_h(weight[0][1], src[0], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][1], src[0], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][1], src[2], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][1], src[2], c[1][5], temp_c[3]);

            src[3] = vld1q_s8(src_ic_0_3 + 12 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 13 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 14 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 15 * 16);
            c[0][6] = vdotq_s32_h(weight[0][0], src[3], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][0], src[3], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][0], src[5], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][0], src[5], c[1][7], temp_c[3]);
            c[0][6] = vdotq_s32_h(weight[0][1], src[4], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][1], src[4], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][1], src[6], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][1], src[6], c[1][7], temp_c[3]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
            c, op, dst_ptr, ld_dst_oc);
}
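// The oc4 variant below is the same inner loop restricted to a single
// output-channel block (c_dim == 1); the stride-2 wrapper uses it for the
// oc % 8 == 4 tail.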
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          int c_dim, typename DstType>
static void ker_neon_dirctconv_2x2s2_oc4_ow8(
        const int8_t* src_ptr, const int8_t* weight_ptr,
        const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, int iw,
        int ld_dst_oc, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int oc_step = 4;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;

    int32x4_t c[c_dim][8];
    int8x16_t weight[2];
    int8x16_t src[8 + 1];
    int16x8_t temp_c[2];
    init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8(src_ic_0_3 + 16);
            src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
            src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
            src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);

            c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
            c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
            c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]);
            c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[1]);
            c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[0]);
            c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[1]);
            c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
            c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);

            src[0] = vld1q_s8(src_ic_0_3 + 9 * 16);
            src[1] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[2] = vld1q_s8(src_ic_0_3 + 11 * 16);
            c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
            c[0][5] = vdotq_s32_h(weight[0], src[1], c[0][5], temp_c[1]);
            c[0][4] = vdotq_s32_h(weight[1], src[0], c[0][4], temp_c[0]);
            c[0][5] = vdotq_s32_h(weight[1], src[2], c[0][5], temp_c[1]);

            src[3] = vld1q_s8(src_ic_0_3 + 12 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 13 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 14 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 15 * 16);
            c[0][6] = vdotq_s32_h(weight[0], src[3], c[0][6], temp_c[0]);
            c[0][7] = vdotq_s32_h(weight[0], src[5], c[0][7], temp_c[1]);
            c[0][6] = vdotq_s32_h(weight[1], src[4], c[0][6], temp_c[0]);
            c[0][7] = vdotq_s32_h(weight[1], src[6], c[0][7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
            c, op, dst_ptr, ld_dst_oc);
}
/**
dot like impl. dot 4 ic to 1 oc, accumulate to c <ow, oc>
example: (format like weight<oc, ic>)
packed weight
low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
---------------------------------------------------------------------
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0>
**/
// TODO: can try oh = 2 impl, oc = 8 impl
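// A scalar sketch of what one vdotq_s32_h step computes for a single output
// pixel (illustrative only; w and s are hypothetical unpacked views, the real
// layout is the interleaved packing described above):
//
//   for (int oc = 0; oc < 4; ++oc) {
//       int32_t acc = 0;
//       for (int ic = 0; ic < 4; ++ic)
//           acc += int32_t(w[oc][ic]) * int32_t(s[ic]);  // via int16 temp
//       c[oc] += acc;  // one lane of the int32x4 accumulator
//   }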
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
          typename DstType>
struct KerNeonDirectStride2Int8<bias_mode, Op, remain_w, 3, c_dim, DstType> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
                     int iw, const Op& op, int ld_dst_oc) {
        constexpr int filter_size = 3;
        constexpr int fh = filter_size;
        constexpr int fw = filter_size;
        constexpr int oc_step = 4;
        constexpr int ic_step = 4;
        constexpr int loop_ic_step = 4;
        constexpr int ld_weight_ic4 = 16;
        constexpr int pack_iw_len = 4;
        const int ic_stride = ih * iw * pack_iw_len;

        int32x4_t c[c_dim][8];
        int8x16_t weight[3];
        int8x16_t src[8 + 2];
        int16x8_t temp_c[4];
        init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
                const int8_t* src_ic_0_3 =
                        src_ptr + ic_idx * ic_stride +
                        fh_idx * iw * ic_step * pack_iw_len;
                src[0] = vld1q_s8(src_ic_0_3);
                src[1] = vld1q_s8(src_ic_0_3 + 16);
                src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
                src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
                src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
                src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
                src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
                src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
                src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
                src[9] = vld1q_s8(src_ic_0_3 + 9 * 16);
                // oc == 0
                const int8_t* read_weight_ptr =
                        weight_ptr + fh_idx * fw * ld_weight_ic4;
                weight[0] = vld1q_s8(read_weight_ptr);
                weight[1] = vld1q_s8(read_weight_ptr + 16);
                weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);

                c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]);
                c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]);
                c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]);

                c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]);
                c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]);
                c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]);
                c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]);

                src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
                src[1] = vld1q_s8(src_ic_0_3 + 11 * 16);
                src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
                c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]);
                c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]);
                c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]);

                src[3] = vld1q_s8(src_ic_0_3 + 13 * 16);
                src[4] = vld1q_s8(src_ic_0_3 + 14 * 16);
                src[5] = vld1q_s8(src_ic_0_3 + 15 * 16);
                src[6] = vld1q_s8(src_ic_0_3 + 16 * 16);
                c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]);
                c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]);
                c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]);
                c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]);
            }
            weight_ptr += fh * fw * ld_weight_ic4;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
          typename DstType>
struct KerNeonDirectStride2Int8<bias_mode, Op, remain_w, 5, c_dim, DstType> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
                     int iw, const Op& op, int ld_dst_oc) {
        constexpr int filter_size = 5;
        constexpr int fh = filter_size;
        constexpr int fw = filter_size;
        constexpr int oc_step = 4;
        constexpr int ic_step = 4;
        constexpr int loop_ic_step = 4;
        constexpr int ld_weight_ic4 = 16;
        constexpr int pack_iw_len = 4;
        const int ic_stride = ih * iw * pack_iw_len;

        int32x4_t c[c_dim][8];
        int8x16_t weight[5];
        int8x16_t src[8 + 2];
        int16x8_t temp_c[4];
        init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
                const int8_t* src_ic_0_3 =
                        src_ptr + ic_idx * ic_stride +
                        fh_idx * iw * ic_step * pack_iw_len;
                src[0] = vld1q_s8(src_ic_0_3);
                src[1] = vld1q_s8(src_ic_0_3 + 16);
                src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
                src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
                src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
                src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
                src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
                src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
                src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
                src[9] = vld1q_s8(src_ic_0_3 + 9 * 16);
                // oc == 0
                const int8_t* read_weight_ptr =
                        weight_ptr + fh_idx * fw * ld_weight_ic4;
                weight[0] = vld1q_s8(read_weight_ptr);
                weight[1] = vld1q_s8(read_weight_ptr + 16);
                weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
                weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
                weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);

                c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]);
                c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]);
                c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[2]);
                c[0][1] = vdotq_s32_h(weight[3], src[5], c[0][1], temp_c[3]);
                c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[4], src[6], c[0][1], temp_c[1]);

                src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
                c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]);
                c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]);
                c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]);
                c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]);
                c[0][2] = vdotq_s32_h(weight[3], src[7], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[3], src[9], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[4], src[8], c[0][2], temp_c[2]);
                c[0][3] = vdotq_s32_h(weight[4], src[0], c[0][3], temp_c[3]);

                src[1] = vld1q_s8(src_ic_0_3 + 11 * 16);
                src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
                src[3] = vld1q_s8(src_ic_0_3 + 13 * 16);
                src[4] = vld1q_s8(src_ic_0_3 + 14 * 16);
                c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]);
                c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]);
                c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[3], src[1], c[0][4], temp_c[2]);
                c[0][5] = vdotq_s32_h(weight[3], src[3], c[0][5], temp_c[3]);
                c[0][4] = vdotq_s32_h(weight[4], src[2], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[4], src[4], c[0][5], temp_c[1]);

                src[5] = vld1q_s8(src_ic_0_3 + 15 * 16);
                src[6] = vld1q_s8(src_ic_0_3 + 16 * 16);
                src[7] = vld1q_s8(src_ic_0_3 + 17 * 16);
                src[8] = vld1q_s8(src_ic_0_3 + 18 * 16);
                c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]);
                c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]);
                c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]);
                c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]);
                c[0][6] = vdotq_s32_h(weight[3], src[5], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[3], src[7], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[4], src[6], c[0][6], temp_c[2]);
                c[0][7] = vdotq_s32_h(weight[4], src[8], c[0][7], temp_c[3]);
            }
            weight_ptr += fh * fw * ld_weight_ic4;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
          typename DstType>
struct KerNeonDirectStride2Int8<bias_mode, Op, remain_w, 7, c_dim, DstType> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
                     int iw, const Op& op, int ld_dst_oc) {
        constexpr int filter_size = 7;
        constexpr int fh = filter_size;
        constexpr int fw = filter_size;
        constexpr int oc_step = 4;
        constexpr int ic_step = 4;
        constexpr int loop_ic_step = 4;
        constexpr int ld_weight_ic4 = 16;
        constexpr int pack_iw_len = 4;
        const int ic_stride = ih * iw * pack_iw_len;

        int32x4_t c[c_dim][8];
        int8x16_t weight[7];
        int8x16_t src[8 + 2];
        int16x8_t temp_c[4];
        init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
                const int8_t* src_ic_0_3 =
                        src_ptr + ic_idx * ic_stride +
                        fh_idx * iw * ic_step * pack_iw_len;
                src[0] = vld1q_s8(src_ic_0_3);
                src[1] = vld1q_s8(src_ic_0_3 + 1 * 16);
                src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
                src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
                src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
                src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
                src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
                src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
                src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
                src[9] = vld1q_s8(src_ic_0_3 + 9 * 16);
                // oc == 0
                const int8_t* read_weight_ptr =
                        weight_ptr + fh_idx * fw * ld_weight_ic4;
                weight[0] = vld1q_s8(read_weight_ptr);
                weight[1] = vld1q_s8(read_weight_ptr + 16);
                weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
                weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
                weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
                weight[5] = vld1q_s8(read_weight_ptr + 5 * 16);
                weight[6] = vld1q_s8(read_weight_ptr + 6 * 16);

                c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]);
                c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]);
                c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[2]);
                c[0][1] = vdotq_s32_h(weight[3], src[5], c[0][1], temp_c[3]);
                c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[4], src[6], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[5], src[5], c[0][0], temp_c[2]);
                c[0][1] = vdotq_s32_h(weight[5], src[7], c[0][1], temp_c[3]);
                c[0][0] = vdotq_s32_h(weight[6], src[6], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[6], src[8], c[0][1], temp_c[1]);

                src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
                src[1] = vld1q_s8(src_ic_0_3 + 11 * 16);
                src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
                c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]);
                c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]);
                c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]);
                c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]);
                c[0][2] = vdotq_s32_h(weight[3], src[7], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[3], src[9], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[4], src[8], c[0][2], temp_c[2]);
                c[0][3] = vdotq_s32_h(weight[4], src[0], c[0][3], temp_c[3]);
                c[0][2] = vdotq_s32_h(weight[5], src[9], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[5], src[1], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[6], src[0], c[0][2], temp_c[2]);
                c[0][3] = vdotq_s32_h(weight[6], src[2], c[0][3], temp_c[3]);

                src[3] = vld1q_s8(src_ic_0_3 + 13 * 16);
                src[4] = vld1q_s8(src_ic_0_3 + 14 * 16);
                src[5] = vld1q_s8(src_ic_0_3 + 15 * 16);
                src[6] = vld1q_s8(src_ic_0_3 + 16 * 16);
                c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]);
                c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]);
                c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[3], src[1], c[0][4], temp_c[2]);
                c[0][5] = vdotq_s32_h(weight[3], src[3], c[0][5], temp_c[3]);
                c[0][4] = vdotq_s32_h(weight[4], src[2], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[4], src[4], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[5], src[3], c[0][4], temp_c[2]);
                c[0][5] = vdotq_s32_h(weight[5], src[5], c[0][5], temp_c[3]);
                c[0][4] = vdotq_s32_h(weight[6], src[4], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[6], src[6], c[0][5], temp_c[1]);

                src[7] = vld1q_s8(src_ic_0_3 + 17 * 16);
                src[8] = vld1q_s8(src_ic_0_3 + 18 * 16);
                src[9] = vld1q_s8(src_ic_0_3 + 19 * 16);
                src[0] = vld1q_s8(src_ic_0_3 + 20 * 16);
                c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]);
                c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]);
                c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]);
                c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]);
                c[0][6] = vdotq_s32_h(weight[3], src[5], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[3], src[7], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[4], src[6], c[0][6], temp_c[2]);
                c[0][7] = vdotq_s32_h(weight[4], src[8], c[0][7], temp_c[3]);
                c[0][6] = vdotq_s32_h(weight[5], src[7], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[5], src[9], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[6], src[8], c[0][6], temp_c[2]);
                c[0][7] = vdotq_s32_h(weight[6], src[0], c[0][7], temp_c[3]);
            }
            weight_ptr += fh * fw * ld_weight_ic4;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
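// The wrappers below tile the output plane: output channels in blocks of 8
// (with the oc4 kernel covering a remainder of 4) and output width in blocks
// of 8, where the compile-time remain_w instantiation handles the ow tail.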
template <BiasMode bias_mode, typename Op, int remain_w, typename DstType>
void conv_direct_stride2_2x2_int8_nchw44(
        const int8_t* src, const int8_t* filter, const int32_t* bias,
        int32_t*, DstType* dst, const size_t oc, const size_t ic,
        const size_t ih, const size_t iw, const size_t oh, const size_t ow,
        const Op& op) {
    constexpr size_t filter_size = 2;
    constexpr size_t fh = filter_size;
    constexpr size_t fw = filter_size;
    constexpr size_t ic_step = 4;
    constexpr size_t oc_step = 4;
    constexpr size_t big_oc_step = 8;
    constexpr size_t oh_step = 1;
    constexpr size_t ow_step = 8;
    constexpr size_t stride_h = 2;
    constexpr size_t stride_w = 2;
    constexpr int pack_iw_len = 4;

    const size_t out_img_stride = oh * ow;
    const size_t ow_end = ow / ow_step * ow_step;
    const size_t ow_remain = ow - ow_end;
    const size_t oc_end = oc / big_oc_step * big_oc_step;
    const size_t oc_remain = oc - oc_end;
    const int ld_dst_oc = oh * ow * oc_step;

    for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_idx * stride_w) *
                        ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * out_img_stride +
                                          (oh_idx * ow + ow_idx) * oc_step;
                ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, Op, ow_step,
                                                 filter_size, 2, DstType>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw,
                        ld_dst_oc, op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_end * stride_w) *
                        ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * out_img_stride +
                                          (oh_idx * ow + ow_end) * oc_step;
                ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, Op, remain_w,
                                                 filter_size, 2, DstType>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw,
                        ld_dst_oc, op);
            }
        }
    }
    if (oc_remain > 0) {
        const size_t oc_idx = oc_end;
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_idx * stride_w) *
                        ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * out_img_stride +
                                          (oh_idx * ow + ow_idx) * oc_step;
                ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, Op, ow_step,
                                                 filter_size, 1, DstType>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw,
                        ld_dst_oc, op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_end * stride_w) *
                        ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * out_img_stride +
                                          (oh_idx * ow + ow_end) * oc_step;
                ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, Op, remain_w,
                                                 filter_size, 1, DstType>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw,
                        ld_dst_oc, op);
            }
        }
    }
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          typename DstType>
void conv_direct_stride2_int8_nchw44_kern(
        const int8_t* src, const int8_t* filter, const int32_t* bias,
        int32_t*, DstType* dst, const size_t oc, const size_t ic,
        const size_t ih, const size_t iw, const size_t oh, const size_t ow,
        const Op& op) {
    constexpr size_t fh = filter_size;
    constexpr size_t fw = filter_size;
    constexpr size_t ic_step = 4;
    constexpr size_t oc_step = 4;
    constexpr size_t oh_step = 1;
    constexpr size_t ow_step = 8;
    constexpr size_t stride_h = 2;
    constexpr size_t stride_w = 2;
    constexpr int pack_iw_len = 4;

    const size_t img_stride = oh * ow;
    const size_t ow_end = ow / ow_step * ow_step;
    const size_t ow_remain = ow - ow_end;
    const int ld_dst_oc = oh * ow * oc_step;

    for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_idx * stride_w) *
                        ic_step * pack_iw_len;
                const size_t dst_offset =
                        oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
                KerNeonDirectStride2Int8<bias_mode, Op, ow_step, filter_size,
                                         1, DstType>::
                        impl(src + src_offset, filter + weight_offset,
                             bias + oc_idx, dst + dst_offset, ic, ih, iw, op,
                             ld_dst_oc);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_end * stride_w) *
                        ic_step * pack_iw_len;
                const size_t dst_offset =
                        oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
                KerNeonDirectStride2Int8<bias_mode, Op, remain_w, filter_size,
                                         1, DstType>::
                        impl(src + src_offset, filter + weight_offset,
                             bias + oc_idx, dst + dst_offset, ic, ih, iw, op,
                             ld_dst_oc);
            }
        }
    }
}
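// Stride is resolved at compile time: the primary template is declared only,
// and the stride-1/stride-2 specializations below dispatch to either the
// dedicated 2x2 kernel or the generic filter-size kernel.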
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          typename DstType, int stride>
struct ConvDirectInt8Nchw44Choose {
    static void impl(const int8_t* src, const int8_t* filter,
                     const int32_t* bias, int32_t* temp, DstType* dst,
                     const size_t oc, const size_t ic, const size_t ih,
                     const size_t iw, const size_t oh, const size_t ow,
                     const Op& op);
};
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          typename DstType>
struct ConvDirectInt8Nchw44Choose<bias_mode, Op, remain_w, filter_size,
                                  DstType, 1> {
    static void impl(const int8_t* src, const int8_t* filter,
                     const int32_t* bias, int32_t* temp, DstType* dst,
                     const size_t oc, const size_t ic, const size_t ih,
                     const size_t iw, const size_t oh, const size_t ow,
                     const Op& op) {
        if (filter_size == 2) {
            conv_direct_stride1_2x2_int8_nchw44<bias_mode, Op, remain_w,
                                                DstType>(
                    src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op);
        } else {
            conv_direct_stride1_int8_nchw44_kern<bias_mode, Op, remain_w,
                                                 filter_size, DstType>(
                    src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op);
        }
    }
};
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          typename DstType>
struct ConvDirectInt8Nchw44Choose<bias_mode, Op, remain_w, filter_size,
                                  DstType, 2> {
    static void impl(const int8_t* src, const int8_t* filter,
                     const int32_t* bias, int32_t* temp, DstType* dst,
                     const size_t oc, const size_t ic, const size_t ih,
                     const size_t iw, const size_t oh, const size_t ow,
                     const Op& op) {
        if (filter_size == 2) {
            conv_direct_stride2_2x2_int8_nchw44<bias_mode, Op, remain_w,
                                                DstType>(
                    src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op);
        } else {
            conv_direct_stride2_int8_nchw44_kern<bias_mode, Op, remain_w,
                                                 filter_size, DstType>(
                    src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op);
        }
    }
};
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          typename DstType, int stride>
void conv_direct_int8_nchw44(const int8_t* src, const int8_t* filter,
                             const int32_t* bias, int32_t* temp, DstType* dst,
                             const size_t oc, const size_t ic, const size_t ih,
                             const size_t iw, const size_t oh, const size_t ow,
                             const Op& op) {
    ConvDirectInt8Nchw44Choose<bias_mode, Op, remain_w, filter_size, DstType,
                               stride>::impl(src, filter, bias, temp, dst, oc,
                                             ic, ih, iw, oh, ow, op);
}
}  // namespace
}  // namespace arm_common
}  // namespace megdnn

// vim: syntax=cpp.doxygen
dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_algo.cpp
deleted (100644 → 0)
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
*/
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/int8/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/opr_delegate.h"
#include "midout.h"
using namespace megdnn;
using namespace arm_common;
using conv_fun = std::function<void(
        WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
        const ConvBiasImpl::NCBKernIndex& ncb_index,
        const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>;
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_stride1)
static void get_rectified_size(
        const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param,
        size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) {
    auto&& fm = param.filter_meta;
    auto SW = fm.stride[1];
    auto OH = param.osz[0];
    auto OW = param.osz[1];
    auto FH = fm.spatial[0];
    auto FW = fm.spatial[1];

    OH2 = OH;
    OW2 = (OW + 7) & ~7;
    IH2 = SW * OH + FH - SW;
    IW2 = SW * OW2 + FW - SW;
}
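// Worked example of the rectification above (illustrative numbers): with
// OW = 13, SW = 1 and FW = 3, OW2 = (13 + 7) & ~7 = 16 and
// IW2 = 1 * 16 + 3 - 1 = 18, so every kernel invocation can read a full
// 8-wide output tile without bounds checks.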
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
    constexpr size_t src_expand = 4;
    auto&& fm = param.filter_meta;
    size_t group = fm.group;
    size_t batch = param.n;
    size_t IC = fm.icpg;
    size_t OC = fm.ocpg;
    size_t FH = fm.spatial[0];
    size_t FW = fm.spatial[1];
    size_t IH2, IW2, OH2, OW2;
    get_rectified_size(param, IH2, IW2, OH2, OW2);
    if (group == 1) {
        size_t src_size =
                batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand;
        size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t);
        return {nullptr, {src_size, weight_size}};
    } else {
        size_t src_size =
                param.nr_threads * IC * IH2 * IW2 * sizeof(int8_t) * src_expand;
        size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t);
        return {nullptr, {src_size, weight_size}};
    }
};
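// Workspace layout implied above: slot 0 holds the zero-padded, 4x-expanded
// int8 source (sized per batch and group when group == 1, per thread
// otherwise) and slot 1 holds the repacked filters for all groups.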
static void copy_padding_kern(WorkspaceBundle bundle,
                              const ConvBiasImpl::NCBKernParam& kern_param,
                              const ConvBiasImpl::NCBKernIndex& ncb_index,
                              const CpuNDRange& workspace_ids) {
    size_t IH = kern_param.isz[0];
    size_t IW = kern_param.isz[1];
    size_t IC = kern_param.filter_meta.icpg;
    size_t PH = kern_param.filter_meta.padding[0];
    size_t PW = kern_param.filter_meta.padding[1];
    size_t GROUP = kern_param.filter_meta.group;

    size_t IH2, IW2, OH2, OW2;
    get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
    size_t padding_group_size = IH2 * IW2 * IC;
    bundle.set(kern_param.workspace_ptr);
    //! Used for get the workspace offset
    constexpr int pack_ic = 4;
    constexpr int expend_element = 4;
    // TODO: block dim is better to get from arg
    size_t workspace_ic_block = 4;
    size_t workspace_batch_id = workspace_ids[0];
    size_t workspace_group_id = workspace_ids[1];
    size_t workspace_ic_id = workspace_ids[2];
    size_t workspace_ic = workspace_ic_id * workspace_ic_block;
    size_t batch_id = ncb_index.ndrange_id[0];
    size_t group_id = ncb_index.ndrange_id[1];
    size_t group_pack_size = 1;

    int nr_pad_h = PH * IW2 * pack_ic * expend_element;
    int nr_pad_w = PW * pack_ic * expend_element;
    int over_pad = std::max(0_z, IW2 - IW - 2 * PW) * pack_ic * expend_element;
    //! copy to sptr_base to eliminate padding effect
    const int8_t* sptr = static_cast<const int8_t*>(kern_param.src<int8_t>(
            batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic));
    int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) +
                        (workspace_batch_id * GROUP * padding_group_size +
                         workspace_group_id * padding_group_size +
                         workspace_ic * IH2 * IW2) *
                                expend_element;
    size_t nr_ic = workspace_ic_block;
    if (GROUP > 1) {
        nr_ic = IC;
    }
    rep_step(ic_idx, nr_ic, pack_ic) {
        std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t));
        sptr_base += nr_pad_h;
        rep(ih_idx, IH) {
            std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t));
            sptr_base += nr_pad_w;
            conv_bias::nchw44_pack_src(sptr, sptr_base, IW);
            sptr_base += IW * pack_ic * expend_element;
            sptr += IW * pack_ic;
            std::memset(sptr_base, 0, (nr_pad_w + over_pad) * sizeof(int8_t));
            sptr_base += nr_pad_w + over_pad;
        }
        std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t));
        sptr_base += nr_pad_h;
    }
}
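// For each 4-channel slice, the kernel above emits PH zeroed rows, then per
// input row a PW zero prefix, the nchw44-packed (4x expanded) row and a zero
// suffix padded out to IW2, then PH zeroed rows again, so the convolution
// kernels can read unconditionally.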
template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain>
static void do_conv_kern(WorkspaceBundle bundle,
                         const ConvBiasImpl::NCBKernParam& kern_param,
                         const ConvBiasImpl::NCBKernIndex& ncb_index,
                         const CpuNDRange& workspace_ids,
                         const CpuNDRange& ncb_range) {
    size_t OH = kern_param.osz[0];
    size_t OW = kern_param.osz[1];
    size_t FH = kern_param.filter_meta.spatial[0];
    size_t FW = kern_param.filter_meta.spatial[1];
    size_t IC = kern_param.filter_meta.icpg;
    size_t OC = kern_param.filter_meta.ocpg;
    size_t GROUP = kern_param.filter_meta.group;
    size_t IH2, IW2, OH2, OW2;
    get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
    bool need_post_process =
            kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
    //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f)
    Op op = Op(1.0f, 4.0f);
    if (need_post_process) {
        float scale_bias =
                kern_param.bias_type.param<dtype::QuantizedS32>().scale;
        float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
        op = Op(scale_bias, scale_dst);
    }
    size_t padding_group_size = IH2 * IW2 * IC;
    bundle.set(kern_param.workspace_ptr);

    constexpr size_t pack_c = 4;
    constexpr size_t src_expand_size = 4;
    const size_t workspace_batch_id = workspace_ids[0];
    const size_t workspace_group_id = workspace_ids[1];
    const size_t batch_id = ncb_index.ndrange_id[0];
    const size_t group_id = ncb_index.ndrange_id[1];
    const size_t oc_id = ncb_index.ndrange_id[2];
    const size_t oc_block_num = ncb_range[2];
    size_t nr_pack_per_step = div_ceil(div_ceil(OC, pack_c), oc_block_num);
    size_t oc_block = nr_pack_per_step * pack_c;
    const size_t oc_idx = oc_id * oc_block;
    if (oc_id == (oc_block_num - 1)) {
        oc_block = OC - oc_id * nr_pack_per_step * pack_c;
    }
    megdnn_assert(oc_block % pack_c == 0,
                  "oc must be divisible by 4, but oc = %zu", oc_block);

    const int8_t* sptr =
            static_cast<int8_t*>(bundle.get(0)) +
            workspace_batch_id * GROUP * padding_group_size * src_expand_size +
            workspace_group_id * padding_group_size * src_expand_size;
    const int8_t* fptr =
            kern_param.filter<dt_int8>(group_id) + oc_idx * FH * FW * IC;
    void* dst = reinterpret_cast<void*>(
            reinterpret_cast<ptrdiff_t>(
                    kern_param.dst<void>(batch_id, group_id)) +
            oc_idx * OH * OW);
    const int32_t* bptr =
            kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx;
    auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
                         group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW;
    conv_bias::nchw44_pack_filter(fptr, packed_weight,
                                  oc_block / 4 * IC / 4 * FH * FW);
#define KERN1_NCHW44_CONV(filter) \
conv_bias::conv_direct_stride1_##filter##x##filter##_int8_nchw44< \
bias_mode, Op, ow_remain>(sptr, packed_weight, bptr, nullptr, \
static_cast<int8_t*>(dst), oc_block, IC, \
IH2, IW2, OH, OW, op)
    DISPATCH_FILTER(filter, KERN1_NCHW44_CONV)
#undef KERN1_NCHW44_CONV
}
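// do_conv_kern binds filter size, bias mode, activation Op and the ow tail at
// compile time; DISPATCH_FILTER expands KERN1_NCHW44_CONV once per supported
// filter size (2/3/5/7) and picks the matching instantiation at runtime.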
/* ===================== stride1 algo ===================== */
bool ConvBiasImpl::AlgoS8DirectStride1NCHW44::usable(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
        AlgoSelectionStrategy algo_selection_strategy) const {
    MEGDNN_MARK_USED_VAR(algo_selection_strategy);
    auto&& fm = param.filter_meta;
    auto FH = fm.spatial[0];
    auto OC = fm.ocpg;
    auto IC = fm.icpg;
    bool avaible =
            //! src and filter are qint8, dst is qint8 or qint32
            ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
              param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
              (param.dst_type.enumv() == DTypeEnum::QuantizedS8 ||
               param.dst_type.enumv() == DTypeEnum::QuantizedS32))) &&
            (fm.format == param::Convolution::Format::NCHW44) &&
            (OC % 4 == 0 && IC % 4 == 0 && OC >= 4) && !fm.should_flip &&
            fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
            fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 &&
            FH == fm.spatial[1] &&
            (FH == 2 || FH == 3 || FH == 5 || FH == 7) &&
            param.bias_mode != BiasMode::BIAS;
    return avaible;
}

bool ConvBiasImpl::AlgoS8DirectStride1NCHW44::is_preferred(
        megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr,
        const NCBKernSizeParam& param) const {
    // TODO: benchmark and fix
    MEGDNN_MARK_USED_VAR(conv_bias_impl_ptr);
    MEGDNN_MARK_USED_VAR(param);
    return false;
}
size_t ConvBiasImpl::AlgoS8DirectStride1NCHW44::get_workspace(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    return get_bundle(param).total_size_in_bytes();
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8DirectStride1NCHW44::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    auto fm = param.filter_meta;
    size_t N = param.n;
    size_t IC = fm.icpg;
    size_t OC = fm.ocpg;
    size_t OW = param.osz[1];
    size_t group = fm.group;
    size_t fh = fm.spatial[0];
    size_t fw = fm.spatial[1];
    WorkspaceBundle wbundle = get_bundle(param);
    conv_fun do_conv_fun = nullptr;
    int ow_remain = OW % 8;
// NOTE: remain_w is not used to gen hash of midout for compatible with changing
// shape runtime
#define DO_CONV_KERN_FUN(filter, bias_mode, remain_w, op) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_stride1, \
midout_iv(#filter #bias_mode #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w>; \
} \
MIDOUT_END();
#define GET_OP_PARAM(filter, bias_mode, remain_w) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define GET_REMAIN_W_PARAM(filter, bias_mode) \
switch (ow_remain) { \
case 0: \
GET_OP_PARAM(filter, bias_mode, 0); \
break; \
case 1: \
GET_OP_PARAM(filter, bias_mode, 1); \
break; \
case 2: \
GET_OP_PARAM(filter, bias_mode, 2); \
break; \
case 3: \
GET_OP_PARAM(filter, bias_mode, 3); \
break; \
case 4: \
GET_OP_PARAM(filter, bias_mode, 4); \
break; \
case 5: \
GET_OP_PARAM(filter, bias_mode, 5); \
break; \
case 6: \
GET_OP_PARAM(filter, bias_mode, 6); \
break; \
case 7: \
GET_OP_PARAM(filter, bias_mode, 7); \
break; \
default: \
megdnn_assert(0); \
}
#define GET_BIAS_MODE_PARAM(filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_REMAIN_W_PARAM(filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_REMAIN_W_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define DISPATCH_CONV_KERN() \
switch (param.filter_meta.spatial[0]) { \
case 2: \
GET_BIAS_MODE_PARAM(2) \
break; \
case 3: \
GET_BIAS_MODE_PARAM(3) \
break; \
case 5: \
GET_BIAS_MODE_PARAM(5) \
break; \
case 7: \
GET_BIAS_MODE_PARAM(7) \
break; \
default: \
megdnn_assert(0); \
break; \
}
    DISPATCH_CONV_KERN();
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN

    megdnn_assert(do_conv_fun);

    SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
    WorkspaceBundle bundle = wbundle;
    constexpr size_t pack_oc = 4;
    size_t oc_step = pack_oc;
    if (fh == 2 && fw == 2 && OC >= 8) {
        oc_step = 8;
    }
    if (group == 1) {
        CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)};
        auto copy_padding = [bundle](const NCBKernParam& kern_param,
                                     const NCBKernIndex& ncb_index) {
            copy_padding_kern(bundle, kern_param, ncb_index,
                              ncb_index.ndrange_id);
        };
        constexpr size_t pack_ic = 4;
        ret_kerns.push_back({copy_padding, {N, group, div_ceil(IC, pack_ic)}});
        auto do_conv = [bundle, do_conv_fun, ncb_range](
                               const NCBKernParam& kern_param,
                               const NCBKernIndex& ncb_index) {
            do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id,
                        ncb_range);
        };
        ret_kerns.push_back({do_conv, ncb_range});
    } else {
        CpuNDRange ncb_range = {N, group, 1};
        auto do_conv = [bundle, do_conv_fun, ncb_range](
                               const NCBKernParam& kern_param,
                               const NCBKernIndex& ncb_index) {
            copy_padding_kern(bundle, kern_param, ncb_index,
                              {0, ncb_index.thread_id, 0});
            do_conv_fun(bundle, kern_param, ncb_index,
                        {0, ncb_index.thread_id, 0}, ncb_range);
        };
        ret_kerns.push_back({do_conv, ncb_range});
    }
    return ret_kerns;
}
// vim: syntax=cpp.doxygen
dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_kern.cpp
deleted (100644 → 0)
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_kern.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
*/
#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
using namespace megdnn;
using namespace arm_common;

namespace {
/**
dot like impl. dot 4 ic to 1 oc, accumulate to c <ow, oc>
example: (format like weight<oc, ic>)
packed weight
low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
---------------------------------------------------------------------
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0>
**/
// TODO: can try oh = 2 impl, oc = 8 impl
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_3x3s1_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;

    int32x4_t c[2 * 4];
    int8x16_t weight[3];
    int8x16_t src[8 + 2];
    int16x8_t temp_c[2];
    init_oc4_ow8<bias_mode>(c, bias_ptr);

    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8(src_ic_0_3 + 16);
            src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
            src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
            src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
            src[9] = vld1q_s8(src_ic_0_3 + 9 * 16);
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);

            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]);

            c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]);

            c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]);

            c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
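// Rough register budget for the loop above (an estimate, not from the
// source): 8 accumulators, 3 weight, 10 source and 2 temporary vectors fit
// within the 32 NEON q-registers of AArch64 without spilling.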
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s1_oc8_ow8(
        const int8_t* src_ptr, const int8_t* weight_ptr,
        const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, int iw,
        int ld_dst_oc, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int oc_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;
    const int ld_weight_oc4 = oc_step * fh * fw * ic;

    int32x4_t c[2][8];
    int8x16_t weight[2][2];
    int8x16_t src[8 + 1];
    int16x8_t temp_c[4];
    init_oc8_ow8<bias_mode>(c, bias_ptr, oc_step);

    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8(src_ic_0_3 + 16);
            src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
            src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
            src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0][0] = vld1q_s8(read_weight_ptr);
            weight[0][1] = vld1q_s8(read_weight_ptr + 16);
            weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
            weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);

            c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][0], src[1], c[1][1], temp_c[3]);
            c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][1], src[2], c[1][1], temp_c[3]);

            c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][0], src[2], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][0], src[3], c[1][3], temp_c[3]);
            c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][1], src[3], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][1], src[4], c[1][3], temp_c[3]);

            c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][0], src[4], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][0], src[5], c[1][5], temp_c[3]);
            c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][1], src[5], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][1], src[6], c[1][5], temp_c[3]);

            c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][0], src[6], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][0], src[7], c[1][7], temp_c[3]);
            c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][1], src[7], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][1], src[8], c[1][7], temp_c[3]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_oc8_ow8_remain_static<remain_w>(c, op, dst_ptr, ld_dst_oc);
}
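// Handling two output-channel blocks per pass (oc8) reuses every loaded
// source vector against both weight sets, roughly halving source loads per
// multiply compared with two passes of the oc4 kernel below.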
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;

    int32x4_t c[2 * 4];
    int8x16_t weight[2];
    int8x16_t src[8 + 1];
    int16x8_t temp_c[2];
    init_oc4_ow8<bias_mode>(c, bias_ptr);

    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8(src_ic_0_3 + 16);
            src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
            src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
            src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);

            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]);

            c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]);

            c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]);

            c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_5x5s1_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;

    int32x4_t c[2 * 4];
    int8x16_t weight[5];
    int8x16_t src[8 + 2];
    int16x8_t temp_c[2];
    init_oc4_ow8<bias_mode>(c, bias_ptr);

    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8(src_ic_0_3 + 16);
            src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
            src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
            src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
            src[9] = vld1q_s8(src_ic_0_3 + 9 * 16);
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
            weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
            weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);

            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[3], src[4], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[4], src[5], c[1], temp_c[1]);

            c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[3], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[3], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[4], src[6], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[4], src[7], c[3], temp_c[1]);

            c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[3], src[7], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[3], src[8], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[4], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[4], src[9], c[5], temp_c[1]);

            src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[1] = vld1q_s8(src_ic_0_3 + 11 * 16);
            c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[3], src[9], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[3], src[0], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[4], src[0], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[4], src[1], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_7x7s1_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;

    int32x4_t c[2 * 4];
    int8x16_t weight[7];
    int8x16_t src[8 + 2];
    int16x8_t temp_c[2];
    init_oc4_ow8<bias_mode>(c, bias_ptr);

    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8(src_ic_0_3 + 16);
            src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
            src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
            src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
            src[9] = vld1q_s8(src_ic_0_3 + 9 * 16);
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
            weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
            weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
            weight[5] = vld1q_s8(read_weight_ptr + 5 * 16);
            weight[6] = vld1q_s8(read_weight_ptr + 6 * 16);

            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[3], src[4], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[4], src[5], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[5], src[5], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[5], src[6], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[6], src[6], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[6], src[7], c[1], temp_c[1]);

            c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[3], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[3], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[4], src[6], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[4], src[7], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[5], src[7], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[5], src[8], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[6], src[8], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[6], src[9], c[3], temp_c[1]);

            src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[1] = vld1q_s8(src_ic_0_3 + 11 * 16);
            c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[3], src[7], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[3], src[8], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[4], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[4], src[9], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[5], src[9], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[5], src[0], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[6], src[0], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[6], src[1], c[5], temp_c[1]);

            src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
            src[3] = vld1q_s8(src_ic_0_3 + 13 * 16);
            c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[3], src[9], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[3], src[0], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[4], src[0], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[4], src[1], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[5], src[1], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[5], src[2], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[6], src[2], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[6], src[3], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
}  // namespace
/**
origin weight shape <oc/4, ic/4, fh, fw, 4, 4>
packed weight shape <oc/4, ic/4, fh, fw, 16>
example: (format like weight<oc, ic>)
origin
<0, 0> <1, 0> <2, 0> <3, 0>
<0, 1> <1, 1> <2, 1> <3, 1>
<0, 2> <1, 2> <2, 2> <3, 2>
<0, 3> <1, 3> <2, 3> <3, 3>
packed
low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
---------------------------------------------------------------------
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
**/
void conv_bias::nchw44_pack_filter(const int8_t* src, int8_t* dst,
                                   int length) {
    static const uint8_t weight_idx_buffer[16] = {0,  4, 9, 13, 2,  6,  11, 15,
                                                  12, 8, 5, 1,  14, 10, 7,  3};
    constexpr int simd_len = 16;
    uint8x16_t weight_idx = vld1q_u8(weight_idx_buffer);
    for (int i = 0; i < length; i++) {
        int8x16_t result = vldq_tbl_s8(src + i * simd_len, weight_idx);
        vst1q_s8(dst + i * simd_len, result);
    }
}
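// A scalar reference for the packing above may make the index table easier
// to audit. This is a hypothetical illustration, not part of the library;
// it assumes vldq_tbl_s8(ptr, idx) gathers ptr[idx[j]] into lane j, the
// usual NEON tbl semantics.
static inline void nchw44_pack_filter_ref(const int8_t* src, int8_t* dst,
                                          int length) {
    static const uint8_t weight_idx_buffer[16] = {0,  4, 9, 13, 2,  6,  11, 15,
                                                  12, 8, 5, 1,  14, 10, 7,  3};
    for (int i = 0; i < length; ++i) {
        for (int j = 0; j < 16; ++j) {
            // lane j of the packed group takes byte weight_idx_buffer[j]
            // of the original <4, 4> weight block
            dst[i * 16 + j] = src[i * 16 + weight_idx_buffer[j]];
        }
    }
}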
/**
origin src shape <n, ic/4, h, w, 4>
packed src shape <n, ic/4, h, w, 16>
example: (format like <ic>)
origin
<0> <0> <0> <0>
packed
low 64 bit <0> <1> <2> <3> | <0> <1> <2> <3>
---------------------------------------------------------------------
high 64 bit <3> <2> <1> <0> | <3> <2> <1> <0>
**/
void conv_bias::nchw44_pack_src(const int8_t* src, int8_t* dst, int length) {
    static const uint8_t src_idx_buffer[16] = {0, 1, 2, 3, 0, 1, 2, 3,
                                               3, 2, 1, 0, 3, 2, 1, 0};
    constexpr int pack_ic = 4;
    constexpr int simd_len = 16;
    uint8x16_t src_idx = vld1q_u8(src_idx_buffer);
    for (int i = 0; i < length; i++) {
        int8x16_t result = vld_dup_tbl_s32(src + i * pack_ic, src_idx);
        vst1q_s8(dst + i * simd_len, result);
    }
}
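// Matching scalar sketch for the source packing, again illustrative only:
// each 4-byte <ic> group is expanded to 16 bytes, repeated forward in the
// low half and reversed in the high half, assuming vld_dup_tbl_s32 has
// broadcast-then-table-lookup semantics.
static inline void nchw44_pack_src_ref(const int8_t* src, int8_t* dst,
                                       int length) {
    static const uint8_t src_idx_buffer[16] = {0, 1, 2, 3, 0, 1, 2, 3,
                                               3, 2, 1, 0, 3, 2, 1, 0};
    for (int i = 0; i < length; ++i) {
        for (int j = 0; j < 16; ++j) {
            dst[i * 16 + j] = src[i * 4 + src_idx_buffer[j]];
        }
    }
}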
template <BiasMode bias_mode, typename Op, int remain_w>
void conv_bias::conv_direct_stride1_2x2_int8_nchw44(
        const int8_t* src, const int8_t* filter, const int32_t* bias,
        int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
        const size_t ih, const size_t iw, const size_t oh, const size_t ow,
        const Op& op) {
    MEGDNN_MARK_USED_VAR(temp);
    constexpr size_t filter_size = 2;
    constexpr size_t fh = filter_size;
    constexpr size_t fw = filter_size;
    constexpr size_t ic_step = 4;
    constexpr size_t oc_step = 4;
    constexpr size_t big_oc_step = 8;
    constexpr size_t oh_step = 1;
    constexpr size_t ow_step = 8;
    constexpr int pack_iw_len = 4;

    const size_t img_stride = oh * ow;
    const size_t ow_end = ow / ow_step * ow_step;
    const size_t ow_remain = ow - ow_end;
    const size_t oc_end = oc / big_oc_step * big_oc_step;
    const size_t oc_remain = oc - oc_end;
    const int ld_oc = oh * ow * ic_step;
    for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_idx) * oc_step;
                ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, 0,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_oc,
                        op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * iw + ow_end) * ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_end) * oc_step;
                ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, remain_w,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_oc,
                        op);
            }
        }
    }
    if (oc_remain > 0) {
        const size_t oc_idx = oc_end;
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_idx) * oc_step;
                ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, 0,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw, op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * iw + ow_end) * ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_end) * oc_step;
                ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, remain_w,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw, op);
            }
        }
    }
}
template <BiasMode bias_mode, typename Op, int remain_w>
void conv_bias::conv_direct_stride1_3x3_int8_nchw44(
        const int8_t* src, const int8_t* filter, const int32_t* bias,
        int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
        const size_t ih, const size_t iw, const size_t oh, const size_t ow,
        const Op& op) {
    MEGDNN_MARK_USED_VAR(temp);
    constexpr size_t filter_size = 3;
    constexpr size_t fh = filter_size;
    constexpr size_t fw = filter_size;
    constexpr size_t ic_step = 4;
    constexpr size_t oc_step = 4;
    constexpr size_t oh_step = 1;
    constexpr size_t ow_step = 8;
    constexpr int pack_iw_len = 4;

    const size_t img_stride = oh * ow;
    const size_t ow_end = ow / ow_step * ow_step;
    const size_t ow_remain = ow - ow_end;
    for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_idx) * oc_step;
                ker_neon_dirctconv_3x3s1_oc4_ow8<bias_mode, Op, 0,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw, op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * iw + ow_end) * ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_end) * oc_step;
                ker_neon_dirctconv_3x3s1_oc4_ow8<bias_mode, Op, remain_w,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw, op);
            }
        }
    }
}
template <BiasMode bias_mode, typename Op, int remain_w>
void conv_bias::conv_direct_stride1_5x5_int8_nchw44(
        const int8_t* src, const int8_t* filter, const int32_t* bias,
        int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
        const size_t ih, const size_t iw, const size_t oh, const size_t ow,
        const Op& op) {
    MEGDNN_MARK_USED_VAR(temp);
    constexpr size_t filter_size = 5;
    constexpr size_t fh = filter_size;
    constexpr size_t fw = filter_size;
    constexpr size_t ic_step = 4;
    constexpr size_t oc_step = 4;
    constexpr size_t oh_step = 1;
    constexpr size_t ow_step = 8;
    constexpr int pack_iw_len = 4;

    const size_t img_stride = oh * ow;
    const size_t ow_end = ow / ow_step * ow_step;
    const size_t ow_remain = ow - ow_end;
    for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_idx) * oc_step;
                ker_neon_dirctconv_5x5s1_oc4_ow8<bias_mode, Op, 0,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw, op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * iw + ow_end) * ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_end) * oc_step;
                ker_neon_dirctconv_5x5s1_oc4_ow8<bias_mode, Op, remain_w,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw, op);
            }
        }
    }
}
template <BiasMode bias_mode, typename Op, int remain_w>
void conv_bias::conv_direct_stride1_7x7_int8_nchw44(
        const int8_t* src, const int8_t* filter, const int32_t* bias,
        int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
        const size_t ih, const size_t iw, const size_t oh, const size_t ow,
        const Op& op) {
    MEGDNN_MARK_USED_VAR(temp);
    constexpr size_t filter_size = 7;
    constexpr size_t fh = filter_size;
    constexpr size_t fw = filter_size;
    constexpr size_t ic_step = 4;
    constexpr size_t oc_step = 4;
    constexpr size_t oh_step = 1;
    constexpr size_t ow_step = 8;
    constexpr int pack_iw_len = 4;

    const size_t img_stride = oh * ow;
    const size_t ow_end = ow / ow_step * ow_step;
    const size_t ow_remain = ow - ow_end;
    for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_idx) * oc_step;
                ker_neon_dirctconv_7x7s1_oc4_ow8<bias_mode, Op, 0,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw, op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * iw + ow_end) * ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_end) * oc_step;
                ker_neon_dirctconv_7x7s1_oc4_ow8<bias_mode, Op, remain_w,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw, op);
            }
        }
    }
}
#define INSTANTIATION(stride, i, bias, remain_w, Op) \
template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_nchw44< \
bias, Op, remain_w>(const int8_t*, const int8_t*, const int32_t*, \
int32_t*, int8_t*, const size_t, const size_t, \
const size_t, const size_t, const size_t, \
const size_t, const Op&);
#define FOR_OP(stride, i, bias, remain_w) \
INSTANTIATION(stride, i, bias, remain_w, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(stride, i, bias, remain_w, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(stride, i, bias, remain_w, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)
#define FOR_REMAIN(stride, i, bias) \
FOR_OP(stride, i, bias, 0) \
FOR_OP(stride, i, bias, 1) \
FOR_OP(stride, i, bias, 2) \
FOR_OP(stride, i, bias, 3) \
FOR_OP(stride, i, bias, 4) \
FOR_OP(stride, i, bias, 5) \
FOR_OP(stride, i, bias, 6) \
FOR_OP(stride, i, bias, 7)
#define FOR_BIAS(stride, i) \
FOR_REMAIN(stride, i, BiasMode::NO_BIAS) \
FOR_REMAIN(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)
#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)
FOR_FILTER(stride1)
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
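For reference, one hand-expanded instance of the chain above is the explicit
instantiation below; this is illustrative only, the macros generate all the
2/3/5/7 x bias x remain_w x op combinations the same way.

// INSTANTIATION(stride1, 2, BiasMode::NO_BIAS, 0,
//               TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) expands to:
template void conv_bias::conv_direct_stride1_2x2_int8_nchw44<
        BiasMode::NO_BIAS, TypeCvtOp<dt_qint32, dt_qint8>, 0>(
        const int8_t*, const int8_t*, const int32_t*, int32_t*, int8_t*,
        const size_t, const size_t, const size_t, const size_t, const size_t,
        const size_t, const TypeCvtOp<dt_qint32, dt_qint8>&);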
dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern.cpp
deleted 100644 → 0
View file @ 4e0c9ad3
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
using namespace megdnn;
using namespace arm_common;

namespace {
/**
dot like impl. dot 4 ic to 1 oc, accumulate to c <ow, oc>
example: (format like weight<oc, ic>)
packed weight
low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
---------------------------------------------------------------------
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0>
**/
// TODO: can try oh = 2 impl, oc = 8 impl
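// A plausible scalar model of vdotq_s32_h(weight, src, c, temp_c) -- a
// sketch only, the real helper lives in marm_neon.h: the low and high
// int8x8 halves are multiplied into a widened 16-bit temporary (hence the
// int16x8_t temp_c arrays in the kernels below), then pairwise-accumulated
// into the four int32 lanes of c. With the packed layouts above, lane 0
// collects (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0> as described.
static inline void vdotq_s32_h_scalar_model(const int8_t w[16],
                                            const int8_t s[16],
                                            int32_t c[4]) {
    int16_t temp[8];
    for (int j = 0; j < 8; ++j) {
        temp[j] = static_cast<int16_t>(w[j] * s[j] + w[j + 8] * s[j + 8]);
    }
    for (int lane = 0; lane < 4; ++lane) {
        c[lane] += temp[2 * lane] + temp[2 * lane + 1];
    }
}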
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_3x3s2_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;

    int32x4_t c[2 * 4];
    int8x16_t weight[3];
    int8x16_t src[8 + 2];
    int16x8_t temp_c[2];
    init_oc4_ow8<bias_mode>(c, bias_ptr);

    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8(src_ic_0_3 + 16);
            src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
            src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
            src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
            src[9] = vld1q_s8(src_ic_0_3 + 9 * 16);
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);

            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]);

            c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]);

            src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[1] = vld1q_s8(src_ic_0_3 + 11 * 16);
            src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
            c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]);

            src[3] = vld1q_s8(src_ic_0_3 + 13 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 14 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 15 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 16 * 16);
            c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, int ld_dst_oc,
                                             const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int oc_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;
    const int ld_weight_oc4 = oc_step * fh * fw * ic;

    int32x4_t c[2][8];
    int8x16_t weight[2][2];
    int8x16_t src[8 + 1];
    int16x8_t temp_c[4];
    init_oc8_ow8<bias_mode>(c, bias_ptr, oc_step);

    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8(src_ic_0_3 + 16);
            src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
            src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
            src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0][0] = vld1q_s8(read_weight_ptr);
            weight[0][1] = vld1q_s8(read_weight_ptr + 16);
            weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
            weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);

            c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][0], src[2], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][0], src[2], c[1][1], temp_c[3]);
            c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][1], src[3], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][1], src[3], c[1][1], temp_c[3]);

            c[0][2] = vdotq_s32_h(weight[0][0], src[4], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][0], src[4], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][0], src[6], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][0], src[6], c[1][3], temp_c[3]);
            c[0][2] = vdotq_s32_h(weight[0][1], src[5], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][1], src[5], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][1], src[7], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][1], src[7], c[1][3], temp_c[3]);

            src[0] = vld1q_s8(src_ic_0_3 + 9 * 16);
            src[1] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[2] = vld1q_s8(src_ic_0_3 + 11 * 16);
            c[0][4] = vdotq_s32_h(weight[0][0], src[8], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][0], src[8], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][0], src[1], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][0], src[1], c[1][5], temp_c[3]);
            c[0][4] = vdotq_s32_h(weight[0][1], src[0], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][1], src[0], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][1], src[2], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][1], src[2], c[1][5], temp_c[3]);

            src[3] = vld1q_s8(src_ic_0_3 + 12 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 13 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 14 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 15 * 16);
            c[0][6] = vdotq_s32_h(weight[0][0], src[3], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][0], src[3], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][0], src[5], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][0], src[5], c[1][7], temp_c[3]);
            c[0][6] = vdotq_s32_h(weight[0][1], src[4], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][1], src[4], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][1], src[6], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][1], src[6], c[1][7], temp_c[3]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_oc8_ow8_remain_static<remain_w>(c, op, dst_ptr, ld_dst_oc);
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;

    int32x4_t c[2 * 4];
    int8x16_t weight[2];
    int8x16_t src[8 + 1];
    int16x8_t temp_c[2];
    init_oc4_ow8<bias_mode>(c, bias_ptr);

    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8(src_ic_0_3 + 16);
            src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
            src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
            src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);

            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]);
            c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]);

            src[0] = vld1q_s8(src_ic_0_3 + 9 * 16);
            src[1] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[2] = vld1q_s8(src_ic_0_3 + 11 * 16);
            c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[1], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[0], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[2], c[5], temp_c[1]);

            src[3] = vld1q_s8(src_ic_0_3 + 12 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 13 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 14 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 15 * 16);
            c[6] = vdotq_s32_h(weight[0], src[3], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[5], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[4], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[6], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_5x5s2_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;

    int32x4_t c[2 * 4];
    int8x16_t weight[5];
    int8x16_t src[8 + 2];
    int16x8_t temp_c[2];
    init_oc4_ow8<bias_mode>(c, bias_ptr);

    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8(src_ic_0_3 + 16);
            src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
            src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
            src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
            src[9] = vld1q_s8(src_ic_0_3 + 9 * 16);
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
            weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
            weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);

            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[3], src[5], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[4], src[6], c[1], temp_c[1]);

            src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
            c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[3], src[7], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[3], src[9], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[4], src[8], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[4], src[0], c[3], temp_c[1]);

            src[1] = vld1q_s8(src_ic_0_3 + 11 * 16);
            src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
            src[3] = vld1q_s8(src_ic_0_3 + 13 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 14 * 16);
            c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[3], src[1], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[3], src[3], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[4], src[2], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[4], src[4], c[5], temp_c[1]);

            src[5] = vld1q_s8(src_ic_0_3 + 15 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 16 * 16);
            src[7] = vld1q_s8(src_ic_0_3 + 17 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 18 * 16);
            c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[3], src[5], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[3], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[4], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[4], src[8], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_7x7s2_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;

    int32x4_t c[2 * 4];
    int8x16_t weight[7];
    int8x16_t src[8 + 2];
    int16x8_t temp_c[2];
    init_oc4_ow8<bias_mode>(c, bias_ptr);

    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8(src_ic_0_3 + 1 * 16);
            src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
            src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
            src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
            src[9] = vld1q_s8(src_ic_0_3 + 9 * 16);
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
            weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
            weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
            weight[5] = vld1q_s8(read_weight_ptr + 5 * 16);
            weight[6] = vld1q_s8(read_weight_ptr + 6 * 16);

            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[3], src[5], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[4], src[6], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[5], src[5], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[5], src[7], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[6], src[6], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[6], src[8], c[1], temp_c[1]);

            src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[1] = vld1q_s8(src_ic_0_3 + 11 * 16);
            src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
            c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[3], src[7], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[3], src[9], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[4], src[8], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[4], src[0], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[5], src[9], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[5], src[1], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[6], src[0], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[6], src[2], c[3], temp_c[1]);

            src[3] = vld1q_s8(src_ic_0_3 + 13 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 14 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 15 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 16 * 16);
            c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[3], src[1], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[3], src[3], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[4], src[2], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[4], src[4], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[5], src[3], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[5], src[5], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[6], src[4], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[6], src[6], c[5], temp_c[1]);

            src[7] = vld1q_s8(src_ic_0_3 + 17 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 18 * 16);
            src[9] = vld1q_s8(src_ic_0_3 + 19 * 16);
            src[0] = vld1q_s8(src_ic_0_3 + 20 * 16);
            c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[3], src[5], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[3], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[4], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[4], src[8], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[5], src[7], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[5], src[9], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[6], src[8], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[6], src[0], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
}  // namespace
template <BiasMode bias_mode, typename Op, int remain_w>
void conv_bias::conv_direct_stride2_2x2_int8_nchw44(
        const int8_t* src, const int8_t* filter, const int32_t* bias,
        int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
        const size_t ih, const size_t iw, const size_t oh, const size_t ow,
        const Op& op) {
    MEGDNN_MARK_USED_VAR(temp);
    constexpr size_t filter_size = 2;
    constexpr size_t fh = filter_size;
    constexpr size_t fw = filter_size;
    constexpr size_t ic_step = 4;
    constexpr size_t oc_step = 4;
    constexpr size_t big_oc_step = 8;
    constexpr size_t oh_step = 1;
    constexpr size_t ow_step = 8;
    constexpr size_t stride_h = 2;
    constexpr size_t stride_w = 2;
    constexpr int pack_iw_len = 4;

    const size_t out_img_stride = oh * ow;
    const size_t ow_end = ow / ow_step * ow_step;
    const size_t ow_remain = ow - ow_end;
    const size_t oc_end = oc / big_oc_step * big_oc_step;
    const size_t oc_remain = oc - oc_end;
    const int ld_oc = oh * ow * ic_step;
    for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_idx * stride_w) *
                        ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * out_img_stride +
                                          (oh_idx * ow + ow_idx) * oc_step;
                ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, Op, 0,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_oc,
                        op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_end * stride_w) *
                        ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * out_img_stride +
                                          (oh_idx * ow + ow_end) * oc_step;
                ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, Op, remain_w,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_oc,
                        op);
            }
        }
    }
    if (oc_remain > 0) {
        const size_t oc_idx = oc_end;
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_idx * stride_w) *
                        ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * out_img_stride +
                                          (oh_idx * ow + ow_idx) * oc_step;
                ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, Op, 0,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw, op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_end * stride_w) *
                        ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * out_img_stride +
                                          (oh_idx * ow + ow_end) * oc_step;
                ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, Op, remain_w,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw, op);
            }
        }
    }
}
template <BiasMode bias_mode, typename Op, int remain_w>
void conv_bias::conv_direct_stride2_3x3_int8_nchw44(
        const int8_t* src, const int8_t* filter, const int32_t* bias,
        int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
        const size_t ih, const size_t iw, const size_t oh, const size_t ow,
        const Op& op) {
    MEGDNN_MARK_USED_VAR(temp);
    constexpr size_t filter_size = 3;
    constexpr size_t fh = filter_size;
    constexpr size_t fw = filter_size;
    constexpr size_t ic_step = 4;
    constexpr size_t oc_step = 4;
    constexpr size_t oh_step = 1;
    constexpr size_t ow_step = 8;
    constexpr size_t stride_h = 2;
    constexpr size_t stride_w = 2;
    constexpr int pack_iw_len = 4;

    const size_t img_stride = oh * ow;
    const size_t ow_end = ow / ow_step * ow_step;
    const size_t ow_remain = ow - ow_end;
    for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_idx * stride_w) *
                        ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_idx) * oc_step;
                ker_neon_dirctconv_3x3s2_oc4_ow8<bias_mode, Op, 0,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw, op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_end * stride_w) *
                        ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_end) * oc_step;
                ker_neon_dirctconv_3x3s2_oc4_ow8<bias_mode, Op, remain_w,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw, op);
            }
        }
    }
}
template <BiasMode bias_mode, typename Op, int remain_w>
void conv_bias::conv_direct_stride2_5x5_int8_nchw44(
        const int8_t* src, const int8_t* filter, const int32_t* bias,
        int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
        const size_t ih, const size_t iw, const size_t oh, const size_t ow,
        const Op& op) {
    MEGDNN_MARK_USED_VAR(temp);
    constexpr size_t filter_size = 5;
    constexpr size_t fh = filter_size;
    constexpr size_t fw = filter_size;
    constexpr size_t ic_step = 4;
    constexpr size_t oc_step = 4;
    constexpr size_t oh_step = 1;
    constexpr size_t ow_step = 8;
    constexpr size_t stride_h = 2;
    constexpr size_t stride_w = 2;
    constexpr int pack_iw_len = 4;

    const size_t img_stride = oh * ow;
    const size_t ow_end = ow / ow_step * ow_step;
    const size_t ow_remain = ow - ow_end;
    for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_idx * stride_w) *
                        ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_idx) * oc_step;
                ker_neon_dirctconv_5x5s2_oc4_ow8<bias_mode, Op, 0,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw, op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_end * stride_w) *
                        ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_end) * oc_step;
                ker_neon_dirctconv_5x5s2_oc4_ow8<bias_mode, Op, remain_w,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw, op);
            }
        }
    }
}
template <BiasMode bias_mode, typename Op, int remain_w>
void conv_bias::conv_direct_stride2_7x7_int8_nchw44(
        const int8_t* src, const int8_t* filter, const int32_t* bias,
        int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
        const size_t ih, const size_t iw, const size_t oh, const size_t ow,
        const Op& op) {
    MEGDNN_MARK_USED_VAR(temp);
    constexpr size_t filter_size = 7;
    constexpr size_t fh = filter_size;
    constexpr size_t fw = filter_size;
    constexpr size_t ic_step = 4;
    constexpr size_t oc_step = 4;
    constexpr size_t oh_step = 1;
    constexpr size_t ow_step = 8;
    constexpr size_t stride_h = 2;
    constexpr size_t stride_w = 2;
    constexpr int pack_iw_len = 4;

    const size_t img_stride = oh * ow;
    const size_t ow_end = ow / ow_step * ow_step;
    const size_t ow_remain = ow - ow_end;
    for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_idx * stride_w) *
                        ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_idx) * oc_step;
                ker_neon_dirctconv_7x7s2_oc4_ow8<bias_mode, Op, 0,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw, op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_end * stride_w) *
                        ic_step * pack_iw_len;
                const size_t dst_offset = oc_idx * img_stride +
                                          (oh_idx * ow + ow_end) * oc_step;
                ker_neon_dirctconv_7x7s2_oc4_ow8<bias_mode, Op, remain_w,
                                                 filter_size>(
                        src + src_offset, filter + weight_offset,
                        bias + oc_idx, dst + dst_offset, ic, ih, iw, op);
            }
        }
    }
}
#define INSTANTIATION(stride, i, bias, remain_w, Op) \
template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_nchw44< \
bias, Op, remain_w>(const int8_t*, const int8_t*, const int32_t*, \
int32_t*, int8_t*, const size_t, const size_t, \
const size_t, const size_t, const size_t, \
const size_t, const Op&);
#define FOR_OP(stride, i, bias, remain_w) \
INSTANTIATION(stride, i, bias, remain_w, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(stride, i, bias, remain_w, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(stride, i, bias, remain_w, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)
#define FOR_REMAIN(stride, i, bias) \
FOR_OP(stride, i, bias, 0) \
FOR_OP(stride, i, bias, 1) \
FOR_OP(stride, i, bias, 2) \
FOR_OP(stride, i, bias, 3) \
FOR_OP(stride, i, bias, 4) \
FOR_OP(stride, i, bias, 5) \
FOR_OP(stride, i, bias, 6) \
FOR_OP(stride, i, bias, 7)
#define FOR_BIAS(stride, i) \
FOR_REMAIN(stride, i, BiasMode::NO_BIAS) \
FOR_REMAIN(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)
#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)
FOR_FILTER(stride2)
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
dnn/src/arm_common/conv_bias/opr_impl.cpp
View file @ 3117bfb7
...
@@ -46,11 +46,10 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
     AlgoQU8DirectStride1 qu8_direct_stride1_small_group{false};
     AlgoS8DirectStride2 s8_direct_stride2_large_group{true};
     AlgoS8DirectStride2 s8_direct_stride2_small_group{false};
-    AlgoS8DirectStride2NCHW44 s8_direct_stride2_nchw44;
+    AlgoS8DirectNCHW44 s8_direct_nchw44;
     AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44;
     AlgoS8DirectStride1 s8_direct_stride1_large_group{true};
     AlgoS8DirectStride1 s8_direct_stride1_small_group{false};
-    AlgoS8DirectStride1NCHW44 s8_direct_stride1_nchw44;
     AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44;
     AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44;
...
@@ -114,11 +113,10 @@ public:
         direct_algos.emplace_back(&qu8_direct_stride1_small_group);
         direct_algos.emplace_back(&s8_direct_stride2_large_group);
         direct_algos.emplace_back(&s8_direct_stride2_small_group);
-        direct_algos.emplace_back(&s8_direct_stride2_nchw44);
+        direct_algos.emplace_back(&s8_direct_nchw44);
         direct_algos.emplace_back(&s8_direct_nchw_nchw44);
         direct_algos.emplace_back(&s8_direct_stride1_large_group);
         direct_algos.emplace_back(&s8_direct_stride1_small_group);
-        direct_algos.emplace_back(&s8_direct_stride1_nchw44);
         direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44);
         direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44);
...
dnn/src/arm_common/conv_bias/opr_impl.h    View file @ 3117bfb7
...
...
@@ -37,9 +37,8 @@ protected:
 private:
     class AlgoS8DirectStride1;
     class AlgoS8DirectStride1NCHW44;
     class AlgoS8DirectStride2;
-    class AlgoS8DirectStride2NCHW44;
+    class AlgoS8DirectNCHW44;
     class AlgoS8DirectNCHWNCHW44;
     class AlgoQU8DirectStride1;
     class AlgoQU8DirectStride2;
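The pattern across these three files: the former `AlgoS8DirectStride2NCHW44` ("S8_NCHW44_DIRECT_STRD2") becomes a stride-agnostic `AlgoS8DirectNCHW44` registered as "S8_NCHW44_DIRECT"; the stride-1 NCHW44 class stays registered, but the renamed tests below exercise both stride 1 and stride 2 through the merged algorithm. A sketch of how a test pins it by name, following the ConvBiasAlgoChecker pattern the winograd tests in this diff already use:

```cpp
// Sketch: force the merged NCHW44 direct algorithm in a Checker-based test.
// Assumes the same checker scaffolding used elsewhere in this commit.
checker.set_before_exec_callback(
        conv_bias::ConvBiasAlgoChecker<ConvBias>("S8_NCHW44_DIRECT"));
```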
...
...
dnn/src/arm_common/elemwise_helper/kimpl/none.h    View file @ 3117bfb7
...
...
@@ -27,6 +27,8 @@ struct NoneOp;
 #define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \
     template <>                                                        \
     struct NoneOp<_ctype> : NoneOpBase<_ctype> {                       \
+        NoneOp(){};                                                    \
+        NoneOp(float, float){};                                        \
         using NoneOpBase::NoneOpBase;                                  \
         using NoneOpBase::operator();                                  \
         constexpr static size_t SIMD_WIDTH = _simd_width;              \
...
...
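The two added constructors make `NoneOp` constructible the same way as the quantizing ops: the direct-conv kernel templates build their epilogue op from a (src_scale, dst_scale) pair, and the 8832 path, which writes raw int32 accumulators, plugs `NoneOp` into that slot. A self-contained sketch of the idea (hypothetical simplified types, not MegDNN's real signatures):

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical stand-ins: what matters is the shared (float, float)
// constructor, so one kernel template can take either epilogue.
struct RequantOp {
    float scale;
    RequantOp(float s_src, float s_dst) : scale(s_src / s_dst) {}
    int32_t operator()(int32_t v) const { return int32_t(v * scale); }
};
struct NoneOp {
    NoneOp() {}
    NoneOp(float, float) {}  // accepted and ignored, as in the diff above
    int32_t operator()(int32_t v) const { return v; }
};

template <typename Op>
void epilogue(const int32_t* acc, int n, float s0, float s1) {
    Op op(s0, s1);  // now compiles for both ops
    for (int i = 0; i < n; ++i)
        printf("%d\n", (int)op(acc[i]));
}

int main() {
    int32_t acc[2] = {100, -200};
    epilogue<RequantOp>(acc, 2, 2.f, 4.f);  // requantizing path
    epilogue<NoneOp>(acc, 2, 2.f, 4.f);     // 8832: raw int32 out
}
```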
dnn/test/arm_common/conv_bias.cpp    View file @ 3117bfb7
...
...
@@ -226,7 +226,15 @@ static void benchmark_convbias(Handle* handle, std::string int_name,
     run(1, 3, 32, 224, 224, 5, 1, true);
     run(1, 3, 64, 224, 224, 7, 1, true);
-    for (size_t stride : {1, 2}) {
+    run(1, 64, 128, 56, 56, 3, 2, false);
+    run(1, 128, 256, 28, 28, 3, 2, false);
+    run(1, 256, 512, 14, 14, 3, 2, false);
+    run(1, 128, 128, 28, 28, 3, 1, false);
+    run(1, 256, 256, 14, 14, 3, 1, false);
+    run(1, 512, 512, 7, 7, 3, 1, false);
+    for (size_t stride : {1}) {
         printf("stride %zu\n", stride);
         for (size_t filter_size : {2, 3, 5, 7}) {
             for (size_t img_size : {32}) {
...
...
dnn/test/arm_common/conv_bias_multi_thread.cpp    View file @ 3117bfb7
...
...
@@ -527,12 +527,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_SMALL_GROUP) {
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) {
     checker_conv_bias_qint8x8x8(
             get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
-            handle(), "S8_NCHW44_DIRECT_STRD1");
+            handle(), "S8_NCHW44_DIRECT");
 }
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8832) {
+    checker_conv_bias_qint8x8x32(
+            get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, true),
+            handle(), "S8_NCHW44_DIRECT");
+}
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8832) {
+    checker_conv_bias_qint8x8x32(
+            get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, true),
+            handle(), "S8_NCHW44_DIRECT");
+}
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) {
     checker_conv_bias_qint8x8x8(
             get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
-            handle(), "S8_NCHW44_DIRECT_STRD2");
+            handle(), "S8_NCHW44_DIRECT");
 }
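"8832" in the new test names reads as int8 src, int8 filter, int32 bias and int32 dst: the output keeps the raw 32-bit accumulator instead of requantizing to int8. The last `true` passed to `get_nchw44_conv_bias_args` presumably selects these no-requant argument sets, and `checker_conv_bias_qint8x8x32` wires up an int32 output dtype. In the Checker vocabulary used later in this file, that combination corresponds to something like:

```cpp
// Illustrative dtype setup for an 8832 run (scale values are placeholders);
// slot 2 is the bias and slot 4 the destination, as in the tests below.
checker.set_dtype(0, dtype::QuantizedS8(2.5f))
        .set_dtype(1, dtype::QuantizedS8(2.5f))
        .set_dtype(2, dtype::QuantizedS32(6.25f))
        .set_dtype(4, dtype::QuantizedS32(6.25f));
```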
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) {
     checker_conv_bias_qint8x8x8(
...
@@ -1085,7 +1095,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_INT8) {
             dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3);
 }
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) {
     using namespace conv_bias;
...
...
@@ -1096,17 +1105,17 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) {
                      param::MatrixMul::Format format, float eps) {
         for (auto&& arg : args) {
             for (uint32_t m : out_size) {
                 checker.set_extra_opr_impl(std::bind(
                         winograd_algo_extra_impl, std::placeholders::_1, m,
                         arg.param, handle, format));
                 checker.set_dtype(0, A_dtype)
                         .set_dtype(1, B_dtype)
                         .set_dtype(2, C_dtype)
                         .set_dtype(4, D_dtype)
                         .set_epsilon(eps)
                         .set_param(arg.param)
                         .execs({arg.src, arg.filter, arg.bias, {}, {}});
             }
         }
     };
...
...
@@ -1118,7 +1127,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) {
     checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
             ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str()));
     std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4);
     UniformIntRNG int_rng{-50, 50};
     checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
     run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f),
...
...
@@ -1126,8 +1135,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) {
         dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3);
 }
 TEST_F(ARM_COMMON_MULTI_THREADS,
        CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPMODE) {
     using namespace conv_bias;
     Checker<ConvBiasForward> checker(handle());
...
...
@@ -1137,17 +1146,17 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPM
                      param::MatrixMul::Format format, float eps) {
         for (auto&& arg : args) {
             for (uint32_t m : out_size) {
                 checker.set_extra_opr_impl(std::bind(
                         winograd_algo_extra_impl, std::placeholders::_1, m,
                         arg.param, handle, format));
                 checker.set_dtype(0, A_dtype)
                         .set_dtype(1, B_dtype)
                         .set_dtype(2, C_dtype)
                         .set_dtype(4, D_dtype)
                         .set_epsilon(eps)
                         .set_param(arg.param)
                         .execs({arg.src, arg.filter, arg.bias, {}, {}});
             }
         }
     };
...
...
@@ -1168,7 +1177,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPM
         dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3);
 }
 TEST_F(ARM_COMMON_MULTI_THREADS,
        CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32) {
     using namespace conv_bias;
     Checker<ConvBiasForward> checker(handle());
...
...
@@ -1196,21 +1206,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F
 #if MEGDNN_AARCH64
     const char* matmul_name = "AARCH64_F32_MK4_4x16";
 #else
     const char* matmul_name = "ARMV7_F32_MK4_4x8";
 #endif
     checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
             ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
     std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, true);
     UniformIntRNG int_rng{-50, 50};
     checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
     run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f),
         dtype::QuantizedS8(0.01887994f),
         dtype::QuantizedS32(0.41113496f * 0.01887994f),
         dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4,
         epsilon);
 }
 TEST_F(ARM_COMMON_MULTI_THREADS,
        CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32_GROUPMODE) {
     using namespace conv_bias;
     Checker<ConvBiasForward> checker(handle());
...
...
@@ -1238,7 +1249,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F
 #if MEGDNN_AARCH64
     const char* matmul_name = "AARCH64_F32_MK4_4x16";
 #else
     const char* matmul_name = "ARMV7_F32_MK4_4x8";
 #endif
     checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
             ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
...
...
@@ -1249,10 +1260,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F
     run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f),
         dtype::QuantizedS8(0.01887994f),
         dtype::QuantizedS32(0.41113496f * 0.01887994f),
         dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4,
         epsilon);
 }
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) {
     using namespace conv_bias;
...
...