Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
7b0dbe6a
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
7b0dbe6a
编写于
6月 03, 2020
作者:
M
Megvii Engine Team
提交者:
Xu Xinran
6月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(dnn/arm): fix stride 1 support for int8 nchw_nchw44
GitOrigin-RevId: 9d718eb7a4dae3c2724ea07ba2b639fbfb319f78
上级
198f3eb5
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
1682 addition
and
52 deletion
+1682
-52
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp
+3
-2
dnn/src/arm_common/conv_bias/int8/algos.h
dnn/src/arm_common/conv_bias/int8/algos.h
+2
-2
dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp
...src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp
+373
-0
dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h
dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h
+1287
-0
dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h
...m_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h
+0
-44
dnn/src/arm_common/conv_bias/opr_impl.cpp
dnn/src/arm_common/conv_bias/opr_impl.cpp
+2
-2
dnn/src/arm_common/conv_bias/opr_impl.h
dnn/src/arm_common/conv_bias/opr_impl.h
+1
-1
dnn/test/arm_common/conv_bias.cpp
dnn/test/arm_common/conv_bias.cpp
+8
-0
dnn/test/arm_common/conv_bias_multi_thread.cpp
dnn/test/arm_common/conv_bias_multi_thread.cpp
+6
-1
未找到文件。
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp
浏览文件 @
7b0dbe6a
...
...
@@ -37,7 +37,7 @@ static inline size_t get_perthread_cache_bytes(const int ic, const int ih2,
static
void
get_rectified_size
(
const
megdnn
::
fallback
::
ConvBiasImpl
::
NCBKernSizeParam
&
param
,
int
&
ih2
,
int
&
iw2
,
int
&
oh2
,
int
&
ow2
)
{
constexpr
int
cacheline
=
64
/
sizeof
(
float
);
constexpr
int
nr_elements_in_
cacheline
=
64
/
sizeof
(
float
);
int
ic
=
param
.
filter_meta
.
icpg
;
int
iw
=
param
.
isz
[
1
];
int
oh
=
param
.
osz
[
0
];
...
...
@@ -52,7 +52,8 @@ static void get_rectified_size(
int
block_oh
=
l2_block_helper
(
param
.
nr_threads
,
oh
,
ic
*
iw
*
sizeof
(
float
)
*
stride_h
);
ih2
=
block_oh
*
stride_h
+
filter_h
-
stride_h
;
iw2
=
round_up
(
iw
+
2
*
static_cast
<
int
>
(
fm
.
padding
[
1
]),
cacheline
);
iw2
=
round_up
(
iw
+
2
*
static_cast
<
int
>
(
fm
.
padding
[
1
]),
nr_elements_in_cacheline
);
}
static
WorkspaceBundle
get_bundle
(
const
ConvBiasImpl
::
NCBKernSizeParam
&
param
)
{
...
...
dnn/src/arm_common/conv_bias/int8/algos.h
浏览文件 @
7b0dbe6a
...
...
@@ -90,9 +90,9 @@ public:
const
NCBKernSizeParam
&
param
)
const
override
;
};
class
ConvBiasImpl
::
AlgoS8Direct
Stride2
NCHWNCHW44
final
:
public
AlgoBase
{
class
ConvBiasImpl
::
AlgoS8DirectNCHWNCHW44
final
:
public
AlgoBase
{
public:
AlgoS8Direct
Stride2
NCHWNCHW44
()
{}
AlgoS8DirectNCHWNCHW44
()
{}
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"S8_CONV_NCHW_NCHW44"
;
}
bool
usable
(
fallback
::
ConvBiasImpl
*
,
const
NCBKernSizeParam
&
param
,
...
...
dnn/src/arm_common/conv_bias/int8/direct_
stride2_
nchw_nchw44_algo.cpp
→
dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp
浏览文件 @
7b0dbe6a
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_
stride2_
nchw_nchw44_algo.cpp
* \file dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
...
...
@@ -12,7 +12,7 @@
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/direct_
stride2_
nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/int8/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/opr_delegate.h"
...
...
@@ -25,93 +25,147 @@ using conv_fun = std::function<void(
WorkspaceBundle
bundle
,
const
ConvBiasImpl
::
NCBKernParam
&
kern_param
,
const
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
CpuNDRange
&
workspace_ids
,
const
CpuNDRange
&
ncb_range
)
>
;
MIDOUT_DECL
(
megdnn_arm_common_conv_bias_int8_nchw_nchw44
_stride2
)
MIDOUT_DECL
(
megdnn_arm_common_conv_bias_int8_nchw_nchw44
)
static
void
get_rectified_size
(
const
megdnn
::
fallback
::
ConvBiasImpl
::
NCBKernSizeParam
&
param
,
size_t
&
IH2
,
size_t
&
IW2
,
size_t
&
OH2
,
size_t
&
OW
2
)
{
const
megdnn
::
fallback
::
ConvBiasImpl
::
NCBKernSizeParam
&
param
,
int
&
ih2
,
int
&
iw2
,
int
&
oh2
,
int
&
ow
2
)
{
auto
&&
fm
=
param
.
filter_meta
;
size_t
IH
=
param
.
isz
[
0
];
size_t
IW
=
param
.
isz
[
1
];
size_t
OH
=
param
.
osz
[
0
];
size_t
OW
=
param
.
osz
[
1
];
int
ih
=
param
.
isz
[
0
];
int
iw
=
param
.
isz
[
1
];
int
oh
=
param
.
osz
[
0
];
int
ow
=
param
.
osz
[
1
];
int
ph
=
fm
.
padding
[
0
];
int
pw
=
fm
.
padding
[
1
];
int
stride_h
=
fm
.
stride
[
0
];
OH2
=
OH
;
OW2
=
OW
;
IH2
=
round_up
(
IH
+
2
*
fm
.
padding
[
0
],
static_cast
<
size_t
>
(
2
));
IW2
=
IW
+
2
*
fm
.
padding
[
1
];
oh2
=
oh
;
ow2
=
ow
;
ih2
=
stride_h
==
2
?
round_up
(
ih
+
2
*
ph
,
2
)
:
ih
+
2
*
ph
;
iw2
=
iw
+
2
*
pw
;
}
static
inline
size_t
get_temp_bytes
(
const
int
iw
,
const
int
pw
)
{
//! border_size is used to avoid read illegal memory
constexpr
int
cacheline_size
=
64
;
constexpr
int
border_size
=
1
*
cacheline_size
;
return
round_up
(
iw
+
pw
*
2
,
cacheline_size
)
+
border_size
;
}
static
WorkspaceBundle
get_bundle
(
const
ConvBiasImpl
::
NCBKernSizeParam
&
param
)
{
constexpr
size_t
src_expand
=
4
;
auto
&&
fm
=
param
.
filter_meta
;
size_t
group
=
fm
.
group
;
size_t
batch
=
param
.
n
;
size_t
IC
=
fm
.
icpg
;
size_t
OC
=
fm
.
ocpg
;
size_t
FH
=
fm
.
spatial
[
0
];
size_t
FW
=
fm
.
spatial
[
1
];
size_t
IH2
,
IW2
,
OH2
,
OW2
;
get_rectified_size
(
param
,
IH2
,
IW2
,
OH2
,
OW2
);
int
group
=
fm
.
group
;
int
batch
=
param
.
n
;
int
ic
=
fm
.
icpg
;
int
oc
=
fm
.
ocpg
;
int
fh
=
fm
.
spatial
[
0
];
int
fw
=
fm
.
spatial
[
1
];
int
stride_h
=
fm
.
stride
[
0
];
int
iw
=
param
.
isz
[
1
];
int
pw
=
fm
.
padding
[
1
];
int
ih2
,
iw2
,
oh2
,
ow2
;
const
size_t
src_expand
=
stride_h
==
2
?
4
:
16
;
get_rectified_size
(
param
,
ih2
,
iw2
,
oh2
,
ow2
);
megdnn_assert
(
group
==
1
,
"only support group == 1 now"
);
size_t
src_size
=
batch
*
group
*
IC
*
IH2
*
IW2
*
sizeof
(
int8_t
)
*
src_expand
;
size_t
weight_size
=
group
*
OC
*
IC
*
FH
*
FW
*
sizeof
(
int8_t
);
return
{
nullptr
,
{
src_size
,
weight_size
}};
batch
*
group
*
ic
*
ih2
*
iw2
*
sizeof
(
int8_t
)
*
src_expand
;
size_t
weight_size
=
group
*
oc
*
ic
*
fh
*
fw
*
sizeof
(
int8_t
);
size_t
tmp_size
=
0
;
if
(
stride_h
==
1
)
{
weight_size
=
group
*
oc
*
ic
*
fh
*
round_up
(
fw
,
4
)
*
sizeof
(
int8_t
);
tmp_size
=
get_temp_bytes
(
iw
,
pw
);
}
return
{
nullptr
,
{
src_size
,
weight_size
,
tmp_size
*
param
.
nr_threads
}};
};
static
void
copy_padding_kern
(
WorkspaceBundle
bundle
,
const
ConvBiasImpl
::
NCBKernParam
&
kern_param
,
const
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
CpuNDRange
&
workspace_ids
)
{
size_t
IH
=
kern_param
.
isz
[
0
];
size_t
IW
=
kern_param
.
isz
[
1
];
size_t
IC
=
kern_param
.
filter_meta
.
icpg
;
size_t
PH
=
kern_param
.
filter_meta
.
padding
[
0
];
size_t
PW
=
kern_param
.
filter_meta
.
padding
[
1
];
size_t
GROUP
=
kern_param
.
filter_meta
.
group
;
int
ih
=
kern_param
.
isz
[
0
];
int
iw
=
kern_param
.
isz
[
1
];
int
ic
=
kern_param
.
filter_meta
.
icpg
;
int
ph
=
kern_param
.
filter_meta
.
padding
[
0
];
int
pw
=
kern_param
.
filter_meta
.
padding
[
1
];
int
group
=
kern_param
.
filter_meta
.
group
;
int
stride_h
=
kern_param
.
filter_meta
.
stride
[
0
];
size_t
IH2
,
IW2
,
OH2
,
OW
2
;
get_rectified_size
(
kern_param
,
IH2
,
IW2
,
OH2
,
OW
2
);
size_t
padding_group_size
=
IH2
*
IW2
*
IC
;
int
ih2
,
iw2
,
oh2
,
ow
2
;
get_rectified_size
(
kern_param
,
ih2
,
iw2
,
oh2
,
ow
2
);
int
padding_group_size
=
ih2
*
iw2
*
ic
;
bundle
.
set
(
kern_param
.
workspace_ptr
);
//! Used for get the workspace offset
constexpr
int
expend_element
=
4
;
// TODO: block dim is better to get from arg
size_t
workspace_ic_block
=
1
;
size_t
workspace_batch_id
=
workspace_ids
[
0
];
size_t
workspace_group_id
=
workspace_ids
[
1
];
size_t
workspace_ic_id
=
workspace_ids
[
2
];
size_t
workspace_ic
=
workspace_ic_id
*
workspace_ic_block
;
size_t
batch_id
=
ncb_index
.
ndrange_id
[
0
];
size_t
group_id
=
ncb_index
.
ndrange_id
[
1
];
const
int
src_expand
=
stride_h
==
2
?
4
:
16
;
//! TODO: block dim is better to get from arg
int
workspace_ic_block
=
1
;
int
workspace_batch_id
=
workspace_ids
[
0
];
int
workspace_group_id
=
workspace_ids
[
1
];
int
workspace_ic_id
=
workspace_ids
[
2
];
int
workspace_ic
=
workspace_ic_id
*
workspace_ic_block
;
int
batch_id
=
ncb_index
.
ndrange_id
[
0
];
int
group_id
=
ncb_index
.
ndrange_id
[
1
];
const
int8_t
*
sptr
=
static_cast
<
const
int8_t
*>
(
kern_param
.
src
<
int8_t
>
(
batch_id
,
group_id
,
workspace_ic_id
,
1
,
1
));
//! copy to sptr_base to eliminate padding effect
int8_t
*
sptr_base
=
static_cast
<
int8_t
*>
(
bundle
.
get
(
0
))
+
(
workspace_batch_id
*
GROUP
*
padding_group_size
+
(
workspace_batch_id
*
group
*
padding_group_size
+
workspace_group_id
*
padding_group_size
+
workspace_ic
*
IH2
*
IW2
)
*
expend_element
;
conv_bias
::
pack_nchw_src_for_nchw44_conv
(
sptr
,
sptr_base
,
1
,
PH
,
PH
,
PW
,
PW
,
IH
,
IW
);
workspace_ic
*
ih2
*
iw2
)
*
src_expand
;
if
(
stride_h
==
1
)
{
const
size_t
tmp_size
=
get_temp_bytes
(
iw
,
pw
);
int8_t
*
tmp_ptr
=
reinterpret_cast
<
int8_t
*>
(
bundle
.
get
(
2
))
+
ncb_index
.
thread_id
*
tmp_size
;
pack_nchw_src_for_nchw44_conv
<
1
>
(
sptr
,
sptr_base
,
1
,
ph
,
ph
,
pw
,
pw
,
ih
,
iw
,
iw2
,
pw
,
tmp_ptr
);
}
else
{
pack_nchw_src_for_nchw44_conv
<
2
>
(
sptr
,
sptr_base
,
1
,
ph
,
ph
,
pw
,
pw
,
ih
,
iw
,
iw2
,
pw
,
nullptr
);
}
}
static
void
pack_weight
(
WorkspaceBundle
bundle
,
const
ConvBiasImpl
::
NCBKernParam
&
kern_param
,
const
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
)
{
bundle
.
set
(
kern_param
.
workspace_ptr
);
const
int
group_id
=
ncb_index
.
ndrange_id
[
0
];
int
fh
=
kern_param
.
filter_meta
.
spatial
[
0
];
int
fw
=
kern_param
.
filter_meta
.
spatial
[
1
];
int
oc
=
kern_param
.
filter_meta
.
ocpg
;
int
ic
=
kern_param
.
filter_meta
.
icpg
;
int
stride_h
=
kern_param
.
filter_meta
.
stride
[
0
];
int
fw2
=
stride_h
==
2
?
fw
:
round_up
(
fw
,
4
);
int
oc_block
=
oc
;
int
oc_idx
=
0
;
const
int8_t
*
fptr
=
kern_param
.
filter
<
dt_int8
>
(
group_id
)
+
oc_idx
*
fh
*
fw
*
ic
;
auto
packed_weight
=
reinterpret_cast
<
int8_t
*>
(
bundle
.
get
(
1
))
+
group_id
*
oc
*
ic
*
fh
*
fw2
+
oc_idx
*
ic
*
fh
*
fw2
;
template
<
size_t
filter
,
BiasMode
bias_mode
,
typename
Op
>
if
(
stride_h
==
1
)
{
pack_nchw44_weight_for_nchw_conv
<
1
>
(
fptr
,
packed_weight
,
ic
,
fh
,
fw
,
oc_block
);
}
else
{
pack_nchw44_weight_for_nchw_conv
<
2
>
(
fptr
,
packed_weight
,
ic
,
fh
,
fw
,
oc_block
);
}
}
template
<
size_t
filter
,
BiasMode
bias_mode
,
typename
Op
,
int
stride
>
static
void
do_conv_kern
(
WorkspaceBundle
bundle
,
const
ConvBiasImpl
::
NCBKernParam
&
kern_param
,
const
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
CpuNDRange
&
workspace_ids
,
const
CpuNDRange
&
ncb_range
)
{
size_t
OH
=
kern_param
.
osz
[
0
];
size_t
OW
=
kern_param
.
osz
[
1
];
size_t
FH
=
kern_param
.
filter_meta
.
spatial
[
0
];
size_t
FW
=
kern_param
.
filter_meta
.
spatial
[
1
];
size_t
IC
=
kern_param
.
filter_meta
.
icpg
;
size_t
OC
=
kern_param
.
filter_meta
.
ocpg
;
size_t
GROUP
=
kern_param
.
filter_meta
.
group
;
size_t
IH2
,
IW2
,
OH2
,
OW2
;
get_rectified_size
(
kern_param
,
IH2
,
IW2
,
OH2
,
OW2
);
int
oh
=
kern_param
.
osz
[
0
];
int
ow
=
kern_param
.
osz
[
1
];
int
fh
=
kern_param
.
filter_meta
.
spatial
[
0
];
int
fw
=
kern_param
.
filter_meta
.
spatial
[
1
];
int
fw2
=
stride
==
2
?
fw
:
round_up
(
fw
,
4
);
int
ic
=
kern_param
.
filter_meta
.
icpg
;
int
oc
=
kern_param
.
filter_meta
.
ocpg
;
int
group
=
kern_param
.
filter_meta
.
group
;
int
ih2
,
iw2
,
oh2
,
ow2
;
get_rectified_size
(
kern_param
,
ih2
,
iw2
,
oh2
,
ow2
);
bool
need_post_process
=
kern_param
.
dst_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
;
//! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f)
...
...
@@ -122,54 +176,46 @@ static void do_conv_kern(WorkspaceBundle bundle,
float
scale_dst
=
kern_param
.
dst_type
.
param
<
dtype
::
QuantizedS8
>
().
scale
;
op
=
Op
(
scale_bias
,
scale_dst
);
}
size_t
padding_group_size
=
IH2
*
IW2
*
IC
;
int
padding_group_size
=
ih2
*
iw2
*
ic
;
bundle
.
set
(
kern_param
.
workspace_ptr
);
constexpr
size_
t
pack_c
=
4
;
constexpr
size_t
src_expand_size
=
4
;
const
size_
t
workspace_batch_id
=
workspace_ids
[
0
];
const
size_
t
workspace_group_id
=
workspace_ids
[
1
];
const
size_
t
batch_id
=
ncb_index
.
ndrange_id
[
0
];
const
size_
t
group_id
=
ncb_index
.
ndrange_id
[
1
];
const
size_
t
oc_id
=
ncb_index
.
ndrange_id
[
2
];
const
size_
t
oc_block_num
=
ncb_range
[
2
];
size_t
nr_pack_per_step
=
div_ceil
(
div_ceil
(
OC
,
pack_c
),
oc_block_num
);
size_
t
oc_block
=
nr_pack_per_step
*
pack_c
;
const
size_
t
oc_idx
=
oc_id
*
oc_block
;
constexpr
in
t
pack_c
=
4
;
constexpr
int
src_expand_size
=
stride
==
2
?
4
:
16
;
const
in
t
workspace_batch_id
=
workspace_ids
[
0
];
const
in
t
workspace_group_id
=
workspace_ids
[
1
];
const
in
t
batch_id
=
ncb_index
.
ndrange_id
[
0
];
const
in
t
group_id
=
ncb_index
.
ndrange_id
[
1
];
const
in
t
oc_id
=
ncb_index
.
ndrange_id
[
2
];
const
in
t
oc_block_num
=
ncb_range
[
2
];
int
nr_pack_per_step
=
div_ceil
(
div_ceil
(
oc
,
pack_c
),
oc_block_num
);
in
t
oc_block
=
nr_pack_per_step
*
pack_c
;
const
in
t
oc_idx
=
oc_id
*
oc_block
;
if
(
oc_id
==
(
oc_block_num
-
1
))
{
oc_block
=
OC
-
oc_id
*
nr_pack_per_step
*
pack_c
;
oc_block
=
oc
-
oc_id
*
nr_pack_per_step
*
pack_c
;
}
megdnn_assert
(
oc_block
%
pack_c
==
0
,
"oc must be devisible by 4, but oc = %
zu
"
,
oc_block
);
"oc must be devisible by 4, but oc = %
d
"
,
oc_block
);
const
int8_t
*
sptr
=
static_cast
<
int8_t
*>
(
bundle
.
get
(
0
))
+
workspace_batch_id
*
GROUP
*
padding_group_size
*
src_expand_size
+
workspace_batch_id
*
group
*
padding_group_size
*
src_expand_size
+
workspace_group_id
*
padding_group_size
*
src_expand_size
;
const
int8_t
*
fptr
=
kern_param
.
filter
<
dt_int8
>
(
group_id
)
+
oc_idx
*
FH
*
FW
*
IC
;
void
*
dst
=
reinterpret_cast
<
void
*>
(
int8_t
*
dst
=
reinterpret_cast
<
int8_t
*>
(
reinterpret_cast
<
ptrdiff_t
>
(
kern_param
.
dst
<
void
>
(
batch_id
,
group_id
))
+
oc_idx
*
OH
*
OW
);
oc_idx
*
oh
*
ow
);
const
int32_t
*
bptr
=
kern_param
.
bias
<
dt_int32
>
(
batch_id
,
group_id
)
+
oc_idx
;
auto
packed_weight
=
reinterpret_cast
<
int8_t
*>
(
bundle
.
get
(
1
))
+
group_id
*
OC
*
IC
*
FH
*
FW
+
oc_idx
*
IC
*
FH
*
FW
;
conv_bias
::
pack_nchw44_weight_for_nchw_conv
(
fptr
,
packed_weight
,
IC
,
FH
,
FW
,
oc_block
);
#define KERN1_NCHW44_CONV(filter) \
conv_bias::conv_direct_stride2_##filter##x##filter##_int8_nchw_nchw44< \
bias_mode, Op>(sptr, packed_weight, bptr, nullptr, \
static_cast<int8_t*>(dst), oc_block, IC, IH2, IW2, \
OH, OW, op)
DISPATCH_FILTER
(
filter
,
KERN1_NCHW44_CONV
);
#undef KERN1_NCHW44_CONV
int8_t
*
packed_weight
=
reinterpret_cast
<
int8_t
*>
(
bundle
.
get
(
1
))
+
group_id
*
oc
*
ic
*
fh
*
fw2
+
oc_idx
*
ic
*
fh
*
fw2
;
conv_direct_int8_nchw_nchw44
<
bias_mode
,
Op
,
filter
,
stride
>
(
sptr
,
packed_weight
,
bptr
,
nullptr
,
dst
,
oc_block
,
ic
,
ih2
,
iw2
,
oh
,
ow
,
op
);
}
/* ===================== stride2 algo ===================== */
bool
ConvBiasImpl
::
AlgoS8DirectStride2NCHWNCHW44
::
usable
(
bool
ConvBiasImpl
::
AlgoS8DirectNCHWNCHW44
::
usable
(
fallback
::
ConvBiasImpl
*
,
const
NCBKernSizeParam
&
param
,
AlgoSelectionStrategy
algo_selection_strategy
)
const
{
MEGDNN_MARK_USED_VAR
(
algo_selection_strategy
);
...
...
@@ -184,13 +230,14 @@ bool ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::usable(
(
fm
.
format
==
param
::
Convolution
::
Format
::
NCHW44
)
&&
(
OC
%
4
==
0
&&
OC
>=
4
)
&&
!
fm
.
should_flip
&&
fm
.
group
==
1
&&
fm
.
spatial_ndim
==
2
&&
fm
.
dilation
[
0
]
==
1
&&
fm
.
dilation
[
1
]
==
1
&&
fm
.
stride
[
0
]
==
2
&&
fm
.
stride
[
1
]
==
2
&&
FH
==
fm
.
spatial
[
1
]
&&
(
FH
==
3
||
FH
==
5
||
FH
==
7
)
&&
fm
.
group
==
1
&&
param
.
bias_mode
!=
BiasMode
::
BIAS
;
fm
.
dilation
[
1
]
==
1
&&
fm
.
stride
[
0
]
==
fm
.
stride
[
1
]
&&
(
fm
.
stride
[
0
]
==
1
||
fm
.
stride
[
0
]
==
2
)
&&
FH
==
fm
.
spatial
[
1
]
&&
(
FH
==
2
||
FH
==
3
||
FH
==
5
||
FH
==
7
)
&&
fm
.
group
==
1
&&
param
.
bias_mode
!=
BiasMode
::
BIAS
;
return
avaible
;
}
bool
ConvBiasImpl
::
AlgoS8Direct
Stride2
NCHWNCHW44
::
is_preferred
(
bool
ConvBiasImpl
::
AlgoS8DirectNCHWNCHW44
::
is_preferred
(
megdnn
::
fallback
::
ConvBiasImpl
*
conv_bias_impl_ptr
,
const
NCBKernSizeParam
&
param
)
const
{
// TODO: benchmark and fix
...
...
@@ -199,13 +246,13 @@ bool ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::is_preferred(
return
false
;
}
size_t
ConvBiasImpl
::
AlgoS8Direct
Stride2
NCHWNCHW44
::
get_workspace
(
size_t
ConvBiasImpl
::
AlgoS8DirectNCHWNCHW44
::
get_workspace
(
fallback
::
ConvBiasImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
return
get_bundle
(
param
).
total_size_in_bytes
();
}
SmallVector
<
ConvBiasImpl
::
NCBKern
>
ConvBiasImpl
::
AlgoS8Direct
Stride2
NCHWNCHW44
::
dispatch_kerns
(
ConvBiasImpl
::
AlgoS8DirectNCHWNCHW44
::
dispatch_kerns
(
fallback
::
ConvBiasImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
auto
fm
=
param
.
filter_meta
;
size_t
N
=
param
.
n
;
...
...
@@ -215,61 +262,76 @@ ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::dispatch_kerns(
conv_fun
do_conv_fun
=
nullptr
;
// NOTE: remain_w is not used to gen hash of midout for compatible with changing
// shape runtime
#define DO_CONV_KERN_FUN(
filter, bias_mode, op)
\
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw_nchw44
_stride2,
\
midout_iv(#
filter #bias_mode #op##_hash)) {
\
do_conv_fun = do_conv_kern<filter, bias_mode, op
>;
\
#define DO_CONV_KERN_FUN(
stride, filter, bias_mode, op)
\
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw_nchw44
,
\
midout_iv(#
stride #filter #bias_mode #op##_hash)) {
\
do_conv_fun = do_conv_kern<filter, bias_mode, op
, stride>;
\
} \
MIDOUT_END();
#define GET_OP_PARAM(
filter, bias_mode)
\
#define GET_OP_PARAM(
stride, filter, bias_mode)
\
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(
filter, bias_mode,
\
DO_CONV_KERN_FUN(
stride, filter, bias_mode,
\
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(
filter, bias_mode,
\
DO_CONV_KERN_FUN(
stride, filter, bias_mode,
\
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(
filter, bias_mode,
\
DO_CONV_KERN_FUN(
stride, filter, bias_mode,
\
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define GET_BIAS_MODE_PARAM(filter) \
#define GET_BIAS_MODE_PARAM(
stride,
filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(filter, BiasMode::NO_BIAS) \
GET_OP_PARAM(
stride,
filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
GET_OP_PARAM(
stride,
filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define DISPATCH_CONV_KERN(
)
\
#define DISPATCH_CONV_KERN(
stride)
\
switch (param.filter_meta.spatial[0]) { \
case 2: \
GET_BIAS_MODE_PARAM(stride, 2) \
break; \
case 3: \
GET_BIAS_MODE_PARAM(
3)
\
GET_BIAS_MODE_PARAM(
stride, 3)
\
break; \
case 5: \
GET_BIAS_MODE_PARAM(
5)
\
GET_BIAS_MODE_PARAM(
stride, 5)
\
break; \
case 7: \
GET_BIAS_MODE_PARAM(
7)
\
GET_BIAS_MODE_PARAM(
stride, 7)
\
break; \
default: \
megdnn_assert(0); \
break; \
}
DISPATCH_CONV_KERN
();
switch
(
param
.
filter_meta
.
stride
[
0
])
{
case
1
:
DISPATCH_CONV_KERN
(
1
);
break
;
case
2
:
DISPATCH_CONV_KERN
(
2
);
break
;
default:
megdnn_throw
(
ssprintf
(
"Unsupport stride size %u for the first conv"
,
param
.
filter_meta
.
stride
[
0
])
.
c_str
());
break
;
}
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
...
...
@@ -290,6 +352,12 @@ ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::dispatch_kerns(
};
ret_kerns
.
push_back
({
copy_padding
,
{
N
,
group
,
fm
.
icpg
}});
auto
do_pack_weight
=
[
bundle
](
const
NCBKernParam
&
kern_param
,
const
NCBKernIndex
&
ncb_index
)
{
pack_weight
(
bundle
,
kern_param
,
ncb_index
);
};
ret_kerns
.
push_back
({
do_pack_weight
,
{
static_cast
<
size_t
>
(
group
)}});
CpuNDRange
ncb_range
=
{
N
,
group
,
div_ceil
(
OC
,
oc_step
)};
auto
do_conv
=
[
bundle
,
do_conv_fun
,
ncb_range
](
const
NCBKernParam
&
kern_param
,
...
...
dnn/src/arm_common/conv_bias/int8/direct_
stride2_nchw_nchw44_kern.cpp
→
dnn/src/arm_common/conv_bias/int8/direct_
nchw_nchw44_kern.h
浏览文件 @
7b0dbe6a
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_
stride2_nchw44_kern_nchw.cpp
* \file dnn/src/arm_common/conv_bias/int8/direct_
nchw_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
...
...
@@ -9,28 +9,40 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h"
#pragma once
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
using
namespace
megdnn
;
using
namespace
arm_common
;
namespace
megdnn
{
namespace
arm_common
{
namespace
{
template
<
int
src_idx
,
int
weight_idx
,
int
c_dim
,
typename
Func
,
typename
T
,
typename
T2
,
typename
T3
,
typename
T4
>
/**
* @brief core code for calculation patten
*
* @tparam src_idx is offset of src reg
* @tparam weight_idx is offset of weight reg
* @tparam c_dim is output channel
* @tparam Func mla operation funcion
* @tparam stride
* @tparam T outpur regs type
* @tparam T2 src regs type
* @tparam T3 weight regs type
* @tparam T4 temp regs type
*/
template
<
int
src_idx
,
int
weight_idx
,
int
c_dim
,
typename
Func
,
int
stride
,
typename
T
,
typename
T2
,
typename
T3
,
typename
T4
>
struct
ShiftCalHelper
{
static
void
impl
(
T
&
c
,
T2
&
src
,
T3
&
weight
,
T4
&
temp
);
static
void
impl
(
T
&
c
,
T2
&
src
,
T3
&
weight
);
};
template
<
int
src_idx
,
int
weight_idx
,
typename
Func
,
typename
T
,
typename
T2
,
typename
T3
,
typename
T4
>
struct
ShiftCalHelper
<
src_idx
,
weight_idx
,
2
,
Func
,
T
,
T2
,
T3
,
T4
>
{
struct
ShiftCalHelper
<
src_idx
,
weight_idx
,
2
,
Func
,
2
,
T
,
T2
,
T3
,
T4
>
{
static
void
impl
(
T
&
c
,
T2
&
src
,
T3
&
weight
,
T4
&
temp
)
{
c
[
0
][
0
]
=
Func
::
impl
(
src
[
0
+
src_idx
],
weight
[
0
][
weight_idx
],
c
[
0
][
0
],
temp
[
0
]);
...
...
@@ -62,7 +74,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, T, T2, T3, T4> {
};
template
<
int
src_idx
,
int
weight_idx
,
typename
Func
,
typename
T
,
typename
T2
,
typename
T3
,
typename
T4
>
struct
ShiftCalHelper
<
src_idx
,
weight_idx
,
1
,
Func
,
T
,
T2
,
T3
,
T4
>
{
struct
ShiftCalHelper
<
src_idx
,
weight_idx
,
1
,
Func
,
2
,
T
,
T2
,
T3
,
T4
>
{
static
void
impl
(
T
&
c
,
T2
&
src
,
T3
&
weight
,
T4
&
temp
)
{
c
[
0
][
0
]
=
Func
::
impl
(
src
[
0
+
src_idx
],
weight
[
0
][
weight_idx
],
c
[
0
][
0
],
temp
[
0
]);
...
...
@@ -81,17 +93,81 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, T, T2, T3, T4> {
}
};
template
<
int
src_idx
,
int
weight_idx
,
int
c_dim
,
typename
FUNC
,
typename
T
,
typename
T2
,
typename
T3
,
typename
T4
>
template
<
int
src_idx
,
int
weight_idx
,
typename
Func
,
typename
T
,
typename
T2
,
typename
T3
,
typename
T4
>
struct
ShiftCalHelper
<
src_idx
,
weight_idx
,
2
,
Func
,
1
,
T
,
T2
,
T3
,
T4
>
{
static
void
impl
(
T
&
c
,
T2
&
src
,
T3
&
weight
,
T4
&
temp
)
{
c
[
0
][
0
]
=
Func
::
impl
(
src
[(
0
+
src_idx
)
%
8
],
weight
[
0
][
weight_idx
],
c
[
0
][
0
],
temp
[
0
]);
c
[
1
][
0
]
=
Func
::
impl
(
src
[(
0
+
src_idx
)
%
8
],
weight
[
1
][
weight_idx
],
c
[
1
][
0
],
temp
[
1
]);
c
[
0
][
1
]
=
Func
::
impl
(
src
[(
1
+
src_idx
)
%
8
],
weight
[
0
][
weight_idx
],
c
[
0
][
1
],
temp
[
2
]);
c
[
1
][
1
]
=
Func
::
impl
(
src
[(
1
+
src_idx
)
%
8
],
weight
[
1
][
weight_idx
],
c
[
1
][
1
],
temp
[
3
]);
c
[
0
][
2
]
=
Func
::
impl
(
src
[(
2
+
src_idx
)
%
8
],
weight
[
0
][
weight_idx
],
c
[
0
][
2
],
temp
[
0
]);
c
[
1
][
2
]
=
Func
::
impl
(
src
[(
2
+
src_idx
)
%
8
],
weight
[
1
][
weight_idx
],
c
[
1
][
2
],
temp
[
1
]);
c
[
0
][
3
]
=
Func
::
impl
(
src
[(
3
+
src_idx
)
%
8
],
weight
[
0
][
weight_idx
],
c
[
0
][
3
],
temp
[
2
]);
c
[
1
][
3
]
=
Func
::
impl
(
src
[(
3
+
src_idx
)
%
8
],
weight
[
1
][
weight_idx
],
c
[
1
][
3
],
temp
[
3
]);
c
[
0
][
4
]
=
Func
::
impl
(
src
[(
4
+
src_idx
)
%
8
],
weight
[
0
][
weight_idx
],
c
[
0
][
4
],
temp
[
0
]);
c
[
1
][
4
]
=
Func
::
impl
(
src
[(
4
+
src_idx
)
%
8
],
weight
[
1
][
weight_idx
],
c
[
1
][
4
],
temp
[
1
]);
c
[
0
][
5
]
=
Func
::
impl
(
src
[(
5
+
src_idx
)
%
8
],
weight
[
0
][
weight_idx
],
c
[
0
][
5
],
temp
[
2
]);
c
[
1
][
5
]
=
Func
::
impl
(
src
[(
5
+
src_idx
)
%
8
],
weight
[
1
][
weight_idx
],
c
[
1
][
5
],
temp
[
3
]);
c
[
0
][
6
]
=
Func
::
impl
(
src
[(
6
+
src_idx
)
%
8
],
weight
[
0
][
weight_idx
],
c
[
0
][
6
],
temp
[
0
]);
c
[
1
][
6
]
=
Func
::
impl
(
src
[(
6
+
src_idx
)
%
8
],
weight
[
1
][
weight_idx
],
c
[
1
][
6
],
temp
[
1
]);
c
[
0
][
7
]
=
Func
::
impl
(
src
[(
7
+
src_idx
)
%
8
],
weight
[
0
][
weight_idx
],
c
[
0
][
7
],
temp
[
2
]);
c
[
1
][
7
]
=
Func
::
impl
(
src
[(
7
+
src_idx
)
%
8
],
weight
[
1
][
weight_idx
],
c
[
1
][
7
],
temp
[
3
]);
}
static
void
impl
(
T
&
,
T2
&
,
T3
&
);
};
template
<
int
src_idx
,
int
weight_idx
,
typename
Func
,
typename
T
,
typename
T2
,
typename
T3
,
typename
T4
>
struct
ShiftCalHelper
<
src_idx
,
weight_idx
,
1
,
Func
,
1
,
T
,
T2
,
T3
,
T4
>
{
static
void
impl
(
T
&
c
,
T2
&
src
,
T3
&
weight
,
T4
&
temp
)
{
c
[
0
][
0
]
=
Func
::
impl
(
src
[(
0
+
src_idx
)
%
8
],
weight
[
0
][
weight_idx
],
c
[
0
][
0
],
temp
[
0
]);
c
[
0
][
1
]
=
Func
::
impl
(
src
[(
1
+
src_idx
)
%
8
],
weight
[
0
][
weight_idx
],
c
[
0
][
1
],
temp
[
1
]);
c
[
0
][
2
]
=
Func
::
impl
(
src
[(
2
+
src_idx
)
%
8
],
weight
[
0
][
weight_idx
],
c
[
0
][
2
],
temp
[
2
]);
c
[
0
][
3
]
=
Func
::
impl
(
src
[(
3
+
src_idx
)
%
8
],
weight
[
0
][
weight_idx
],
c
[
0
][
3
],
temp
[
3
]);
c
[
0
][
4
]
=
Func
::
impl
(
src
[(
4
+
src_idx
)
%
8
],
weight
[
0
][
weight_idx
],
c
[
0
][
4
],
temp
[
0
]);
c
[
0
][
5
]
=
Func
::
impl
(
src
[(
5
+
src_idx
)
%
8
],
weight
[
0
][
weight_idx
],
c
[
0
][
5
],
temp
[
1
]);
c
[
0
][
6
]
=
Func
::
impl
(
src
[(
6
+
src_idx
)
%
8
],
weight
[
0
][
weight_idx
],
c
[
0
][
6
],
temp
[
2
]);
c
[
0
][
7
]
=
Func
::
impl
(
src
[(
7
+
src_idx
)
%
8
],
weight
[
0
][
weight_idx
],
c
[
0
][
7
],
temp
[
3
]);
}
static
void
impl
(
T
&
,
T2
&
,
T3
&
);
};
template
<
int
src_idx
,
int
weight_idx
,
int
c_dim
,
typename
FUNC
,
int
stride
,
typename
T
,
typename
T2
,
typename
T3
,
typename
T4
>
inline
void
cal_helper
(
T
&
c
,
T2
&
src
,
T3
&
weight
,
T4
&
temp
)
{
ShiftCalHelper
<
src_idx
,
weight_idx
,
c_dim
,
FUNC
,
T
,
T2
,
T3
,
T4
>::
impl
(
c
,
src
,
weight
,
temp
);
ShiftCalHelper
<
src_idx
,
weight_idx
,
c_dim
,
FUNC
,
stride
,
T
,
T2
,
T3
,
T4
>::
impl
(
c
,
src
,
weight
,
temp
);
}
template
<
int
src_idx
,
int
weight_idx
,
int
c_dim
,
typename
FUNC
,
typename
T
,
typename
T2
,
typename
T3
>
template
<
int
src_idx
,
int
weight_idx
,
int
c_dim
,
typename
FUNC
,
int
stride
,
typename
T
,
typename
T
2
,
typename
T3
>
inline
void
cal_helper
(
T
&
c
,
T2
&
src
,
T3
&
weight
)
{
ShiftCalHelper
<
src_idx
,
weight_idx
,
c_dim
,
FUNC
,
T
,
T2
,
T3
,
int
>::
impl
(
c
,
src
,
weight
);
ShiftCalHelper
<
src_idx
,
weight_idx
,
c_dim
,
FUNC
,
stride
,
T
,
T2
,
T3
,
int
>::
impl
(
c
,
src
,
weight
);
};
template
<
int
oc
>
...
...
@@ -111,7 +187,7 @@ public:
};
template
<
BiasMode
bias_mode
,
typename
Op
,
int
remain_w
,
int
filter_size
,
int
oc_block
>
int
oc_block
,
int
stride
>
struct
KerNeonXXs2NchwNchw44
{
static
void
impl
(
const
int8_t
*
src_ptr
,
const
int8_t
*
weight_ptr
,
const
int32_t
*
bias_ptr
,
int8_t
*
dst_ptr
,
int
ic
,
int
ih
,
...
...
@@ -143,8 +219,9 @@ struct KerNeonXXs2NchwNchw44 {
* |x x|x x|x x|x|
* |---|---|---|-|
**/
template
<
BiasMode
bias_mode
,
typename
Op
,
int
remain_w
,
int
oc_block
>
struct
KerNeonXXs2NchwNchw44
<
bias_mode
,
Op
,
remain_w
,
7
,
oc_block
>
{
template
<
BiasMode
bias_mode
,
typename
Op
,
int
remain_w
,
int
oc_block
,
int
stride
>
struct
KerNeonXXs2NchwNchw44
<
bias_mode
,
Op
,
remain_w
,
7
,
oc_block
,
stride
>
{
static
void
impl
(
const
int8_t
*
src_ptr
,
const
int8_t
*
weight_ptr
,
const
int32_t
*
bias_ptr
,
int8_t
*
dst_ptr
,
int
ic
,
int
ih
,
int
iw
,
int
ld_dst_oc
,
const
Op
&
op
)
{
...
...
@@ -176,12 +253,12 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block> {
load_helper
<
3
,
0
,
16
,
c_dim
,
Vld1q_s8
>
(
dot4_weight
,
weight_ptr
,
ld_dot4_weight_oc
);
load_helper
<
6
,
0
,
16
,
0
,
Vld1q_s8
>
(
src
,
nchw_src_ptr
,
0
);
cal_helper
<
0
,
0
,
c_dim
,
Vdotq_s32_h
>
(
c
,
src
,
dot4_weight
,
temp_c
);
cal_helper
<
1
,
1
,
c_dim
,
Vdotq_s32_h
>
(
c
,
src
,
dot4_weight
,
temp_c
);
cal_helper
<
2
,
2
,
c_dim
,
Vdotq_s32_h
>
(
c
,
src
,
dot4_weight
,
temp_c
);
cal_helper
<
0
,
0
,
c_dim
,
Vdotq_s32_h
,
stride
>
(
c
,
src
,
dot4_weight
,
temp_c
);
cal_helper
<
1
,
1
,
c_dim
,
Vdotq_s32_h
,
stride
>
(
c
,
src
,
dot4_weight
,
temp_c
);
cal_helper
<
2
,
2
,
c_dim
,
Vdotq_s32_h
,
stride
>
(
c
,
src
,
dot4_weight
,
temp_c
);
int8x8_t
src_dot2
[
4
];
int8x8_t
dot2_weight
[
c_dim
][
1
];
...
...
@@ -189,8 +266,8 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block> {
dot2_weight
,
weight_ptr
,
ld_dot4_weight_oc
);
load_helper
<
4
,
3
*
16
,
16
,
0
,
Vld1_s8
>
(
src_dot2
,
nchw_src_ptr
,
0
);
cal_helper
<
0
,
0
,
c_dim
,
Vdot2_s32_h
>
(
c
,
src_dot2
,
dot2_weight
,
temp_c
);
cal_helper
<
0
,
0
,
c_dim
,
Vdot2_s32_h
,
stride
>
(
c
,
src_dot2
,
dot2_weight
,
temp_c
);
weight_ptr
+=
filter_size
*
pack_iw_len
*
fh_step
;
}
const
int8_t
*
nchw_src_ptr
=
src_ptr
+
ic_idx
*
ic_stride
+
...
...
@@ -204,12 +281,12 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block> {
ld_dot4_weight_oc
);
load_helper_x
<
6
,
0
,
16
,
0
,
Vldq_tbl_low_s8
>
(
src_dot2
,
nchw_src_ptr
,
0
,
tbl
);
cal_helper
<
0
,
0
,
c_dim
,
Vdot2_s32_h
>
(
c
,
src_dot2
,
dot2_weight
,
temp_c
);
cal_helper
<
1
,
1
,
c_dim
,
Vdot2_s32_h
>
(
c
,
src_dot2
,
dot2_weight
,
temp_c
);
cal_helper
<
2
,
2
,
c_dim
,
Vdot2_s32_h
>
(
c
,
src_dot2
,
dot2_weight
,
temp_c
);
cal_helper
<
0
,
0
,
c_dim
,
Vdot2_s32_h
,
stride
>
(
c
,
src_dot2
,
dot2_weight
,
temp_c
);
cal_helper
<
1
,
1
,
c_dim
,
Vdot2_s32_h
,
stride
>
(
c
,
src_dot2
,
dot2_weight
,
temp_c
);
cal_helper
<
2
,
2
,
c_dim
,
Vdot2_s32_h
,
stride
>
(
c
,
src_dot2
,
dot2_weight
,
temp_c
);
int16x8_t
dot1_weight
[
c_dim
][
1
];
int16x8_t
src_dot1
[
4
];
...
...
@@ -217,14 +294,16 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block> {
dot1_weight
,
weight_ptr
,
ld_dot4_weight_oc
);
load_helper
<
4
,
3
*
16
,
16
,
0
,
Vld1_dup_s8_s16
>
(
src_dot1
,
nchw_src_ptr
,
0
);
cal_helper
<
0
,
0
,
c_dim
,
Vmlal_s16
>
(
c
,
src_dot1
,
dot1_weight
);
cal_helper
<
0
,
0
,
c_dim
,
Vmlal_s16
,
stride
>
(
c
,
src_dot1
,
dot1_weight
);
weight_ptr
+=
filter_size
*
pack_iw_len
;
}
store_ocx_ow4_remain_static
<
c_dim
,
remain_w
>
(
c
,
op
,
dst_ptr
,
ld_dst_oc
);
}
};
template
<
BiasMode
bias_mode
,
typename
Op
,
int
remain_w
,
int
oc_block
>
struct
KerNeonXXs2NchwNchw44
<
bias_mode
,
Op
,
remain_w
,
5
,
oc_block
>
{
template
<
BiasMode
bias_mode
,
typename
Op
,
int
remain_w
,
int
oc_block
,
int
stride
>
struct
KerNeonXXs2NchwNchw44
<
bias_mode
,
Op
,
remain_w
,
5
,
oc_block
,
stride
>
{
static
void
impl
(
const
int8_t
*
src_ptr
,
const
int8_t
*
weight_ptr
,
const
int32_t
*
bias_ptr
,
int8_t
*
dst_ptr
,
int
ic
,
int
ih
,
int
iw
,
int
ld_dst_oc
,
const
Op
&
op
)
{
...
...
@@ -255,10 +334,10 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block> {
load_helper
<
2
,
0
,
16
,
c_dim
,
Vld1q_s8
>
(
dot4_weight
,
weight_ptr
,
ld_dot4_weight_oc
);
load_helper
<
5
,
0
,
16
,
0
,
Vld1q_s8
>
(
src
,
nchw_src_ptr
,
0
);
cal_helper
<
0
,
0
,
c_dim
,
Vdotq_s32_h
>
(
c
,
src
,
dot4_weight
,
temp_c
);
cal_helper
<
1
,
1
,
c_dim
,
Vdotq_s32_h
>
(
c
,
src
,
dot4_weight
,
temp_c
);
cal_helper
<
0
,
0
,
c_dim
,
Vdotq_s32_h
,
stride
>
(
c
,
src
,
dot4_weight
,
temp_c
);
cal_helper
<
1
,
1
,
c_dim
,
Vdotq_s32_h
,
stride
>
(
c
,
src
,
dot4_weight
,
temp_c
);
int8x8_t
src_dot2
[
4
];
int8x8_t
dot2_weight
[
c_dim
][
1
];
...
...
@@ -266,8 +345,8 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block> {
dot2_weight
,
weight_ptr
,
ld_dot4_weight_oc
);
load_helper
<
4
,
2
*
16
,
16
,
0
,
Vld1_s8
>
(
src_dot2
,
nchw_src_ptr
,
0
);
cal_helper
<
0
,
0
,
c_dim
,
Vdot2_s32_h
>
(
c
,
src_dot2
,
dot2_weight
,
temp_c
);
cal_helper
<
0
,
0
,
c_dim
,
Vdot2_s32_h
,
stride
>
(
c
,
src_dot2
,
dot2_weight
,
temp_c
);
weight_ptr
+=
filter_size
*
pack_iw_len
*
ih_step
;
}
const
int8_t
*
nchw_src_ptr
=
src_ptr
+
ic_idx
*
ic_stride
+
...
...
@@ -282,10 +361,10 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block> {
load_helper_x
<
5
,
0
,
16
,
0
,
Vldq_tbl_low_s8
>
(
src_dot2
,
nchw_src_ptr
,
0
,
tbl
);
cal_helper
<
0
,
0
,
c_dim
,
Vdot2_s32_h
>
(
c
,
src_dot2
,
dot2_weight
,
temp_c
);
cal_helper
<
1
,
1
,
c_dim
,
Vdot2_s32_h
>
(
c
,
src_dot2
,
dot2_weight
,
temp_c
);
cal_helper
<
0
,
0
,
c_dim
,
Vdot2_s32_h
,
stride
>
(
c
,
src_dot2
,
dot2_weight
,
temp_c
);
cal_helper
<
1
,
1
,
c_dim
,
Vdot2_s32_h
,
stride
>
(
c
,
src_dot2
,
dot2_weight
,
temp_c
);
int16x8_t
dot1_weight
[
c_dim
][
1
];
int16x8_t
src_dot1
[
4
];
...
...
@@ -294,7 +373,8 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block> {
load_helper
<
4
,
2
*
16
,
16
,
0
,
Vld1_dup_s8_s16
>
(
src_dot1
,
nchw_src_ptr
,
0
);
cal_helper
<
0
,
0
,
c_dim
,
Vmlal_s16
>
(
c
,
src_dot1
,
dot1_weight
);
cal_helper
<
0
,
0
,
c_dim
,
Vmlal_s16
,
stride
>
(
c
,
src_dot1
,
dot1_weight
);
weight_ptr
+=
filter_size
*
pack_iw_len
;
}
store_ocx_ow4_remain_static
<
c_dim
,
remain_w
>
(
c
,
op
,
dst_ptr
,
ld_dst_oc
);
...
...
@@ -315,8 +395,9 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block> {
* |x x|x|
* |-----|
**/
template
<
BiasMode
bias_mode
,
typename
Op
,
int
remain_w
,
int
oc_block
>
struct
KerNeonXXs2NchwNchw44
<
bias_mode
,
Op
,
remain_w
,
3
,
oc_block
>
{
template
<
BiasMode
bias_mode
,
typename
Op
,
int
remain_w
,
int
oc_block
,
int
stride
>
struct
KerNeonXXs2NchwNchw44
<
bias_mode
,
Op
,
remain_w
,
3
,
oc_block
,
stride
>
{
static
void
impl
(
const
int8_t
*
src_ptr
,
const
int8_t
*
weight_ptr
,
const
int32_t
*
bias_ptr
,
int8_t
*
dst_ptr
,
int
ic
,
int
ih
,
int
iw
,
int
ld_dst_oc
,
const
Op
&
op
)
{
...
...
@@ -345,8 +426,8 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block> {
load_helper
<
1
,
0
,
16
,
c_dim
,
Vld1q_s8
>
(
dot4_weight
,
weight_ptr
,
ld_weight_oc
);
load_helper
<
4
,
0
,
16
,
0
,
Vld1q_s8
>
(
src
,
nchw_src_ptr
,
0
);
cal_helper
<
0
,
0
,
c_dim
,
Vdotq_s32_h
>
(
c
,
src
,
dot4_weight
,
temp_c
);
cal_helper
<
0
,
0
,
c_dim
,
Vdotq_s32_h
,
stride
>
(
c
,
src
,
dot4_weight
,
temp_c
);
int8x8_t
src_dot2
[
4
];
int8x8_t
dot2_weight
[
c_dim
][
1
];
...
...
@@ -354,8 +435,8 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block> {
dot2_weight
,
weight_ptr
,
ld_weight_oc
);
load_helper
<
4
,
1
*
16
,
16
,
0
,
Vld1_s8
>
(
src_dot2
,
nchw_src_ptr
,
0
);
cal_helper
<
0
,
0
,
c_dim
,
Vdot2_s32_h
>
(
c
,
src_dot2
,
dot2_weight
,
temp_c
);
cal_helper
<
0
,
0
,
c_dim
,
Vdot2_s32_h
,
stride
>
(
c
,
src_dot2
,
dot2_weight
,
temp_c
);
}
// last line
{
...
...
@@ -369,23 +450,257 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block> {
ld_weight_oc
);
load_helper_x
<
4
,
0
,
16
,
0
,
Vldq_tbl_low_s8
>
(
src_dot2
,
nchw_src_ptr
,
0
,
tbl
);
cal_helper
<
0
,
0
,
c_dim
,
Vdot2_s32_h
>
(
c
,
src_dot2
,
dot2_weight
,
temp_c
);
cal_helper
<
0
,
0
,
c_dim
,
Vdot2_s32_h
,
stride
>
(
c
,
src_dot2
,
dot2_weight
,
temp_c
);
int16x8_t
dot1_weight
[
c_dim
][
1
];
int16x8_t
src_dot1
[
4
];
load_helper
<
1
,
32
,
8
,
c_dim
,
Vldq_dup_4s8_8s16
>
(
dot1_weight
,
weight_ptr
,
ld_weight_oc
);
load_helper
<
4
,
1
*
16
,
16
,
0
,
Vld1_dup_s8_s16
>
(
src_dot1
,
nchw_src_ptr
,
0
);
cal_helper
<
0
,
0
,
c_dim
,
Vmlal_s16
>
(
c
,
src_dot1
,
dot1_weight
);
cal_helper
<
0
,
0
,
c_dim
,
Vmlal_s16
,
stride
>
(
c
,
src_dot1
,
dot1_weight
);
weight_ptr
+=
filter_size
*
filter_size
*
pack_iw_len
;
}
}
store_ocx_ow4_remain_static
<
c_dim
,
remain_w
>
(
c
,
op
,
dst_ptr
,
ld_dst_oc
);
}
};
template
<
BiasMode
bias_mode
,
typename
Op
,
int
remain_w
,
int
oc_block
,
int
stride
>
struct
KerNeonXXs2NchwNchw44
<
bias_mode
,
Op
,
remain_w
,
2
,
oc_block
,
stride
>
{
static
void
impl
(
const
int8_t
*
src_ptr
,
const
int8_t
*
weight_ptr
,
const
int32_t
*
bias_ptr
,
int8_t
*
dst_ptr
,
int
ic
,
int
ih
,
int
iw
,
int
ld_dst_oc
,
const
Op
&
op
)
{
constexpr
int
filter_size
=
2
;
constexpr
int
oc_step
=
4
;
constexpr
int
loop_ic_step
=
1
;
constexpr
int
pack_iw_len
=
4
;
const
int
ic_stride
=
ih
*
iw
*
pack_iw_len
;
const
int
ld_weight_oc
=
oc_step
*
filter_size
*
filter_size
*
ic
;
constexpr
int
c_dim
=
OCHelper
<
oc_block
>::
val
;
int32x4_t
c
[
c_dim
][
4
];
init_ocx_ow4
<
c_dim
,
bias_mode
>
(
c
,
bias_ptr
,
oc_step
);
for
(
int
ic_idx
=
0
;
ic_idx
<
ic
;
ic_idx
+=
loop_ic_step
)
{
const
int8_t
*
nchw_src_ptr
=
src_ptr
+
ic_idx
*
ic_stride
;
int8x16_t
src
[
4
];
int8x16_t
dot4_weight
[
c_dim
][
1
];
int16x8_t
temp_c
[
4
];
load_helper
<
1
,
0
,
16
,
c_dim
,
Vld1q_s8
>
(
dot4_weight
,
weight_ptr
,
ld_weight_oc
);
load_helper
<
4
,
0
,
16
,
0
,
Vld1q_s8
>
(
src
,
nchw_src_ptr
,
0
);
cal_helper
<
0
,
0
,
c_dim
,
Vdotq_s32_h
,
stride
>
(
c
,
src
,
dot4_weight
,
temp_c
);
weight_ptr
+=
oc_step
*
filter_size
*
filter_size
;
}
store_ocx_ow4_remain_static
<
c_dim
,
remain_w
>
(
c
,
op
,
dst_ptr
,
ld_dst_oc
);
}
};
template
<
BiasMode
bias_mode
,
typename
Op
,
int
remain_w
,
int
oc_block
>
struct
KerNeonXXs2NchwNchw44
<
bias_mode
,
Op
,
remain_w
,
2
,
oc_block
,
1
>
{
static
void
impl
(
const
int8_t
*
src_ptr
,
const
int8_t
*
weight_ptr
,
const
int32_t
*
bias_ptr
,
int8_t
*
dst_ptr
,
int
ic
,
int
ih
,
int
iw
,
int
ld_dst_oc
,
const
Op
&
op
)
{
constexpr
int
stride
=
1
;
constexpr
int
filter_height
=
2
;
constexpr
int
filter_width
=
4
;
constexpr
int
oc_step
=
4
;
constexpr
int
loop_ic_step
=
1
;
constexpr
int
simd_len
=
16
;
constexpr
int
pack_iw_len
=
16
;
constexpr
int
src_reg
=
8
;
constexpr
int
weight_reg
=
1
;
const
int
ic_stride
=
ih
*
iw
*
pack_iw_len
;
const
int
ld_weight_oc
=
oc_step
*
filter_height
*
filter_width
*
ic
;
constexpr
int
c_dim
=
OCHelper
<
oc_block
>::
val
;
int32x4_t
c
[
c_dim
][
8
];
init_ocx_ow8
<
c_dim
,
bias_mode
,
8
>
(
c
,
bias_ptr
,
oc_step
);
for
(
int
ic_idx
=
0
;
ic_idx
<
ic
;
ic_idx
+=
loop_ic_step
)
{
const
int8_t
*
nchw_src_ptr
=
src_ptr
+
ic_idx
*
ic_stride
;
int8x16_t
src
[
src_reg
];
int8x16_t
dot4_weight
[
c_dim
][
weight_reg
];
int16x8_t
temp_c
[
4
];
load_helper
<
weight_reg
,
0
,
simd_len
,
c_dim
,
Vld1q_s8
>
(
dot4_weight
,
weight_ptr
,
ld_weight_oc
);
load_helper
<
src_reg
,
0
,
simd_len
,
0
,
Vld1q_s8
>
(
src
,
nchw_src_ptr
+
0
*
iw
*
pack_iw_len
,
0
);
cal_helper
<
0
,
0
,
c_dim
,
Vdotq_s32_h
,
stride
>
(
c
,
src
,
dot4_weight
,
temp_c
);
load_helper
<
weight_reg
,
0
,
simd_len
,
c_dim
,
Vld1q_s8
>
(
dot4_weight
,
weight_ptr
+
1
*
filter_width
*
oc_step
,
ld_weight_oc
);
load_helper
<
src_reg
,
0
,
simd_len
,
0
,
Vld1q_s8
>
(
src
,
nchw_src_ptr
+
1
*
iw
*
pack_iw_len
,
0
);
cal_helper
<
0
,
0
,
c_dim
,
Vdotq_s32_h
,
stride
>
(
c
,
src
,
dot4_weight
,
temp_c
);
weight_ptr
+=
oc_step
*
filter_height
*
filter_width
;
}
store_ocx_ow8_remain_static_dt
<
c_dim
,
remain_w
,
Op
,
dt_qint8
*>
(
c
,
op
,
dst_ptr
,
ld_dst_oc
);
}
};
template
<
BiasMode
bias_mode
,
typename
Op
,
int
remain_w
,
int
oc_block
>
struct
KerNeonXXs2NchwNchw44
<
bias_mode
,
Op
,
remain_w
,
3
,
oc_block
,
1
>
{
static
void
impl
(
const
int8_t
*
src_ptr
,
const
int8_t
*
weight_ptr
,
const
int32_t
*
bias_ptr
,
int8_t
*
dst_ptr
,
int
ic
,
int
ih
,
int
iw
,
int
ld_dst_oc
,
const
Op
&
op
)
{
constexpr
int
stride
=
1
;
constexpr
int
filter_height
=
3
;
constexpr
int
filter_width
=
4
;
constexpr
int
oc_step
=
4
;
constexpr
int
loop_ic_step
=
1
;
constexpr
int
simd_len
=
16
;
constexpr
int
pack_iw_len
=
16
;
constexpr
int
src_reg
=
8
;
constexpr
int
weight_reg
=
1
;
const
int
ic_stride
=
ih
*
iw
*
pack_iw_len
;
const
int
ld_weight_oc
=
oc_step
*
filter_height
*
filter_width
*
ic
;
constexpr
int
c_dim
=
OCHelper
<
oc_block
>::
val
;
int32x4_t
c
[
c_dim
][
8
];
init_ocx_ow8
<
c_dim
,
bias_mode
,
8
>
(
c
,
bias_ptr
,
oc_step
);
for
(
int
ic_idx
=
0
;
ic_idx
<
ic
;
ic_idx
+=
loop_ic_step
)
{
const
int8_t
*
nchw_src_ptr
=
src_ptr
+
ic_idx
*
ic_stride
;
int8x16_t
src
[
src_reg
];
int8x16_t
dot4_weight
[
c_dim
][
weight_reg
];
int16x8_t
temp_c
[
4
];
load_helper
<
weight_reg
,
0
,
simd_len
,
c_dim
,
Vld1q_s8
>
(
dot4_weight
,
weight_ptr
,
ld_weight_oc
);
load_helper
<
src_reg
,
0
,
simd_len
,
0
,
Vld1q_s8
>
(
src
,
nchw_src_ptr
+
0
*
iw
*
pack_iw_len
,
0
);
cal_helper
<
0
,
0
,
c_dim
,
Vdotq_s32_h
,
stride
>
(
c
,
src
,
dot4_weight
,
temp_c
);
load_helper
<
weight_reg
,
0
,
simd_len
,
c_dim
,
Vld1q_s8
>
(
dot4_weight
,
weight_ptr
+
1
*
filter_width
*
oc_step
,
ld_weight_oc
);
load_helper
<
src_reg
,
0
,
simd_len
,
0
,
Vld1q_s8
>
(
src
,
nchw_src_ptr
+
1
*
iw
*
pack_iw_len
,
0
);
cal_helper
<
0
,
0
,
c_dim
,
Vdotq_s32_h
,
stride
>
(
c
,
src
,
dot4_weight
,
temp_c
);
load_helper
<
weight_reg
,
0
,
simd_len
,
c_dim
,
Vld1q_s8
>
(
dot4_weight
,
weight_ptr
+
2
*
filter_width
*
oc_step
,
ld_weight_oc
);
load_helper
<
src_reg
,
0
,
simd_len
,
0
,
Vld1q_s8
>
(
src
,
nchw_src_ptr
+
2
*
iw
*
pack_iw_len
,
0
);
cal_helper
<
0
,
0
,
c_dim
,
Vdotq_s32_h
,
stride
>
(
c
,
src
,
dot4_weight
,
temp_c
);
weight_ptr
+=
oc_step
*
filter_height
*
filter_width
;
}
store_ocx_ow8_remain_static_dt
<
c_dim
,
remain_w
,
Op
,
dt_qint8
*>
(
c
,
op
,
dst_ptr
,
ld_dst_oc
);
}
};
template
<
BiasMode
bias_mode
,
typename
Op
,
int
remain_w
,
int
oc_block
>
struct
KerNeonXXs2NchwNchw44
<
bias_mode
,
Op
,
remain_w
,
5
,
oc_block
,
1
>
{
static
void
impl
(
const
int8_t
*
src_ptr
,
const
int8_t
*
weight_ptr
,
const
int32_t
*
bias_ptr
,
int8_t
*
dst_ptr
,
int
ic
,
int
ih
,
int
iw
,
int
ld_dst_oc
,
const
Op
&
op
)
{
constexpr
int
stride
=
1
;
constexpr
int
filter_height
=
5
;
constexpr
int
filter_width
=
8
;
constexpr
int
oc_step
=
4
;
constexpr
int
loop_ic_step
=
1
;
constexpr
int
simd_len
=
16
;
constexpr
int
pack_iw_len
=
16
;
constexpr
int
src_reg
=
8
;
constexpr
int
weight_reg
=
2
;
const
int
ic_stride
=
ih
*
iw
*
pack_iw_len
;
const
int
ld_weight_oc
=
oc_step
*
filter_height
*
filter_width
*
ic
;
constexpr
int
c_dim
=
OCHelper
<
oc_block
>::
val
;
int32x4_t
c
[
c_dim
][
8
];
init_ocx_ow8
<
c_dim
,
bias_mode
,
8
>
(
c
,
bias_ptr
,
oc_step
);
for
(
int
ic_idx
=
0
;
ic_idx
<
ic
;
ic_idx
+=
loop_ic_step
)
{
const
int8_t
*
nchw_src_ptr
=
src_ptr
+
ic_idx
*
ic_stride
;
int8x16_t
src
[
src_reg
];
int8x16_t
dot4_weight
[
c_dim
][
weight_reg
];
int16x8_t
temp_c
[
4
];
#define cb(step) \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
dot4_weight, weight_ptr + step * filter_width * oc_step, \
ld_weight_oc); \
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \
src, nchw_src_ptr + step * iw * pack_iw_len, 0); \
cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, temp_c); \
load_helper<4, 0, simd_len, 0, Vld1q_s8>( \
src, \
nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, \
0); \
cal_helper<4, 1, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, temp_c);
UNROLL_CALL_RAW
(
5
,
cb
);
#undef cb
weight_ptr
+=
oc_step
*
filter_height
*
filter_width
;
}
store_ocx_ow8_remain_static_dt
<
c_dim
,
remain_w
,
Op
,
dt_qint8
*>
(
c
,
op
,
dst_ptr
,
ld_dst_oc
);
}
};
template
<
BiasMode
bias_mode
,
typename
Op
,
int
remain_w
,
int
oc_block
>
struct
KerNeonXXs2NchwNchw44
<
bias_mode
,
Op
,
remain_w
,
7
,
oc_block
,
1
>
{
static
void
impl
(
const
int8_t
*
src_ptr
,
const
int8_t
*
weight_ptr
,
const
int32_t
*
bias_ptr
,
int8_t
*
dst_ptr
,
int
ic
,
int
ih
,
int
iw
,
int
ld_dst_oc
,
const
Op
&
op
)
{
constexpr
int
stride
=
1
;
constexpr
int
filter_height
=
7
;
constexpr
int
filter_width
=
8
;
constexpr
int
oc_step
=
4
;
constexpr
int
loop_ic_step
=
1
;
constexpr
int
simd_len
=
16
;
constexpr
int
pack_iw_len
=
16
;
constexpr
int
src_reg
=
8
;
constexpr
int
weight_reg
=
2
;
const
int
ic_stride
=
ih
*
iw
*
pack_iw_len
;
const
int
ld_weight_oc
=
oc_step
*
filter_height
*
filter_width
*
ic
;
constexpr
int
c_dim
=
OCHelper
<
oc_block
>::
val
;
int32x4_t
c
[
c_dim
][
8
];
init_ocx_ow8
<
c_dim
,
bias_mode
,
8
>
(
c
,
bias_ptr
,
oc_step
);
for
(
int
ic_idx
=
0
;
ic_idx
<
ic
;
ic_idx
+=
loop_ic_step
)
{
const
int8_t
*
nchw_src_ptr
=
src_ptr
+
ic_idx
*
ic_stride
;
int8x16_t
src
[
src_reg
];
int8x16_t
dot4_weight
[
c_dim
][
weight_reg
];
int16x8_t
temp_c
[
4
];
#define cb(step) \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
dot4_weight, weight_ptr + step * filter_width * oc_step, \
ld_weight_oc); \
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \
src, nchw_src_ptr + step * iw * pack_iw_len, 0); \
cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, temp_c); \
load_helper<4, 0, simd_len, 0, Vld1q_s8>( \
src, \
nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, \
0); \
cal_helper<4, 1, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, temp_c);
UNROLL_CALL_RAW
(
7
,
cb
);
#undef cb
weight_ptr
+=
oc_step
*
filter_height
*
filter_width
;
}
store_ocx_ow8_remain_static_dt
<
c_dim
,
remain_w
,
Op
,
dt_qint8
*>
(
c
,
op
,
dst_ptr
,
ld_dst_oc
);
}
};
}
// namespace
enum
PACK_MODE
{
NO_PAD
=
0
,
FIRST_PAD
=
1
,
LAST_PAD
=
2
};
template
<
PACK_MODE
mode
>
inline
void
pack_src_one_line
(
const
int8_t
*
inptr
,
int8_t
*
outptr
,
int
left_pad
,
...
...
@@ -443,14 +758,24 @@ inline void pack_src_one_line(const int8_t* inptr, int8_t* outptr, int left_pad,
memset
(
outptr
,
0
,
combine_row
*
right_pad
*
src_expand
*
sizeof
(
int8_t
));
outptr
+=
combine_row
*
right_pad
*
src_expand
;
}
template
<
int
stride
>
void
pack_nchw_src_for_nchw44_conv
(
const
int8_t
*
inptr
,
int8_t
*
outptr
,
const
int
ic
,
const
int
top_pad
,
const
int
bottom_pad
,
const
int
left_pad
,
const
int
right_pad
,
const
int
ih
,
const
int
iw
,
const
int
iw2
,
const
int
pw
,
int8_t
*
temp_ptr
);
/**
* pack (ic, h, w) to (ic, h / 2, 2 * w)
* pack interleave two adjacent row in src and repeat 4 times, store to one row
* */
void
conv_bias
::
pack_nchw_src_for_nchw44_conv
(
const
int8_t
*
inptr
,
int8_t
*
outptr
,
const
int
ic
,
const
int
top_pad
,
const
int
bottom_pad
,
const
int
left_pad
,
const
int
right_pad
,
const
int
ih
,
const
int
iw
)
{
template
<
>
void
pack_nchw_src_for_nchw44_conv
<
2
>
(
const
int8_t
*
inptr
,
int8_t
*
outptr
,
const
int
ic
,
const
int
top_pad
,
const
int
bottom_pad
,
const
int
left_pad
,
const
int
right_pad
,
const
int
ih
,
const
int
iw
,
const
int
,
const
int
,
int8_t
*
)
{
constexpr
int
src_expand
=
4
;
constexpr
int
oh_step
=
2
;
const
int
oh
=
ih
+
top_pad
+
bottom_pad
;
...
...
@@ -490,15 +815,75 @@ void conv_bias::pack_nchw_src_for_nchw44_conv(
}
}
}
/**
* pack (ic, h, w) to (ic, h, w * 16)
* pack interleave two adjacent row in src and repeat 4 times, store to one row
* */
template
<
>
void
pack_nchw_src_for_nchw44_conv
<
1
>
(
const
int8_t
*
sptr_origin
,
int8_t
*
sptr_base
,
const
int
ic
,
const
int
pad_top
,
const
int
pad_bottom
,
const
int
,
const
int
,
const
int
ih
,
const
int
iw
,
const
int
iw2
,
const
int
pw
,
int8_t
*
temp_ptr
)
{
static
uint8_t
reorder_idx
[
16
]
=
{
0
,
1
,
0
,
1
,
0
,
1
,
0
,
1
,
2
,
3
,
2
,
3
,
2
,
3
,
2
,
3
};
uint8x16_t
tbl_idx
=
vld1q_u8
(
&
reorder_idx
[
0
]);
constexpr
int
iw_step
=
4
;
constexpr
int
pack_iw_len
=
16
;
const
int
ic_stride
=
ih
*
iw
;
const
int
iw_with_pad
=
iw
+
2
*
pw
;
const
int
iw_with_pad_end
=
iw_with_pad
/
iw_step
*
iw_step
;
rep
(
ic_idx
,
ic
)
{
const
int8_t
*
sptr
=
sptr_origin
+
ic_idx
*
ic_stride
;
memset
(
sptr_base
,
0
,
sizeof
(
int8_t
)
*
iw2
*
(
ih
+
pad_top
+
pad_bottom
)
*
pack_iw_len
);
sptr_base
+=
iw2
*
pad_top
*
pack_iw_len
;
rep
(
ih_idx
,
ih
)
{
memset
(
temp_ptr
,
0
,
iw_with_pad
*
sizeof
(
int8_t
));
memcpy
(
temp_ptr
+
pw
,
sptr
,
sizeof
(
int8_t
)
*
iw
);
for
(
int
iw_idx
=
0
;
iw_idx
<
iw_with_pad_end
;
iw_idx
+=
iw_step
)
{
int8x16_t
src
[
4
];
int8x16_t
dst
[
4
];
src
[
0
]
=
vld1q_s8
(
temp_ptr
+
iw_idx
);
src
[
1
]
=
vld1q_s8
(
temp_ptr
+
iw_idx
+
1
);
src
[
2
]
=
vld1q_s8
(
temp_ptr
+
iw_idx
+
2
);
src
[
3
]
=
vld1q_s8
(
temp_ptr
+
iw_idx
+
3
);
dst
[
0
]
=
vqtbl1q_s8
(
src
[
0
],
tbl_idx
);
dst
[
1
]
=
vqtbl1q_s8
(
src
[
1
],
tbl_idx
);
dst
[
2
]
=
vqtbl1q_s8
(
src
[
2
],
tbl_idx
);
dst
[
3
]
=
vqtbl1q_s8
(
src
[
3
],
tbl_idx
);
vst1q_s8
(
sptr_base
+
iw_idx
*
pack_iw_len
+
0
,
dst
[
0
]);
vst1q_s8
(
sptr_base
+
iw_idx
*
pack_iw_len
+
16
,
dst
[
1
]);
vst1q_s8
(
sptr_base
+
iw_idx
*
pack_iw_len
+
32
,
dst
[
2
]);
vst1q_s8
(
sptr_base
+
iw_idx
*
pack_iw_len
+
48
,
dst
[
3
]);
}
for
(
int
iw_idx
=
iw_with_pad_end
;
iw_idx
<
iw_with_pad
;
++
iw_idx
)
{
int8x16_t
src
=
vld1q_s8
(
temp_ptr
+
iw_idx
);
int8x16_t
dst
=
vqtbl1q_s8
(
src
,
tbl_idx
);
vst1q_s8
(
sptr_base
+
iw_idx
*
pack_iw_len
,
dst
);
}
sptr_base
+=
iw2
*
pack_iw_len
;
sptr
+=
iw
;
}
sptr_base
+=
iw2
*
pad_bottom
*
pack_iw_len
;
}
}
template
<
int
stride
>
void
pack_nchw44_weight_for_nchw_conv
(
const
int8_t
*
inptr
,
int8_t
*
outptr
,
const
int
ic
,
const
int
fh
,
const
int
fw
,
const
int
oc
);
/**
* pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh * fw, 4(oc)}
* pack interleave two adjacent row in filter to one row
* */
void
conv_bias
::
pack_nchw44_weight_for_nchw_conv
(
const
int8_t
*
inptr
,
int8_t
*
outptr
,
const
int
ic
,
const
int
fh
,
const
int
fw
,
const
int
oc
)
{
template
<
>
void
pack_nchw44_weight_for_nchw_conv
<
2
>
(
const
int8_t
*
inptr
,
int8_t
*
outptr
,
const
int
ic
,
const
int
fh
,
const
int
fw
,
const
int
oc
)
{
constexpr
int
oc_step
=
4
;
constexpr
int
ic_step
=
2
;
constexpr
int
fh_step
=
2
;
...
...
@@ -610,24 +995,72 @@ void conv_bias::pack_nchw44_weight_for_nchw_conv(const int8_t* inptr,
outptr
+=
oc_step
*
fh
*
fw
*
ic
;
}
}
/**
* pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh ,fw/4, 4(oc)*4(fw)}
* pack interleave two adjacent row in filter to one row
* */
template
<
>
void
pack_nchw44_weight_for_nchw_conv
<
1
>
(
const
int8_t
*
src_ptr
,
int8_t
*
dst_ptr
,
const
int
ic
,
const
int
fh
,
const
int
fw
,
const
int
oc
)
{
constexpr
int
oc_step
=
4
;
const
int
fw2
=
round_up
(
fw
,
4
);
const
int
fw_remain
=
fw2
-
fw
;
const
int
dst_ic_stride
=
fh
*
fw2
;
const
int
oc_step_stride
=
fh
*
fw2
*
ic
*
oc_step
;
static
const
uint8_t
transpose_4x4_idx
[
16
]
=
{
0
,
4
,
1
,
5
,
2
,
6
,
3
,
7
,
8
,
12
,
9
,
13
,
10
,
14
,
11
,
15
};
uint8x16_t
tbl_transpose_4x4
=
vld1q_u8
(
&
transpose_4x4_idx
[
0
]);
rep_step
(
oc_idx
,
oc
,
oc_step
)
{
int32_t
*
dst_temp_ptr
=
reinterpret_cast
<
int32_t
*>
(
dst_ptr
+
oc_idx
*
ic
*
fh
*
fw2
);
const
int32_t
*
src_temp_ptr
=
reinterpret_cast
<
const
int32_t
*>
(
src_ptr
+
oc_idx
*
ic
*
fh
*
fw
);
// transpose ic and pad
rep
(
fh_idx
,
fh
)
{
rep
(
fw_idx
,
fw
)
{
rep
(
ic_idx
,
ic
)
{
*
(
dst_temp_ptr
+
ic_idx
*
dst_ic_stride
)
=
*
src_temp_ptr
;
src_temp_ptr
++
;
}
dst_temp_ptr
++
;
}
rep
(
ic_idx
,
ic
)
{
memset
(
dst_temp_ptr
+
ic_idx
*
dst_ic_stride
,
0
,
sizeof
(
int8_t
)
*
oc_step
*
fw_remain
);
}
dst_temp_ptr
+=
fw_remain
;
}
// transpose fw oc
int8_t
*
trans_dst_temp_ptr
=
reinterpret_cast
<
int8_t
*>
(
dst_ptr
+
oc_idx
*
ic
*
fh
*
fw2
);
template
<
BiasMode
bias_mode
,
typename
Op
,
size_t
filter_size
>
static
void
conv_direct_stride2_int8_nchw_nchw44
(
const
int8_t
*
src
,
const
int8_t
*
filter
,
const
int32_t
*
bias
,
int32_t
*
temp
,
int8_t
*
dst
,
const
size_t
oc
,
const
size_t
ic
,
const
size_t
ih
,
const
size_t
iw
,
const
size_t
oh
,
const
size_t
ow
,
rep_step
(
idx
,
oc_step_stride
,
16
)
{
int8x16_t
temp
=
vld1q_s8
(
trans_dst_temp_ptr
+
idx
);
vst1q_s8
(
trans_dst_temp_ptr
+
idx
,
vqtbl1q_s8
(
temp
,
tbl_transpose_4x4
));
}
}
};
template
<
BiasMode
bias_mode
,
typename
Op
,
size_t
filter_size
,
int
stride
>
struct
ConvDiectStrideInt8NchwNchw44
{
static
void
impl
(
const
int8_t
*
src
,
const
int8_t
*
filter
,
const
int32_t
*
bias
,
int32_t
*
temp
,
int8_t
*
dst
,
const
size_t
oc
,
const
size_t
ic
,
const
size_t
ih
,
const
size_t
iw
,
const
size_t
oh
,
const
size_t
ow
,
const
Op
&
op
)
{
MEGDNN_MARK_USED_VAR
(
temp
);
constexpr
size_t
fh
=
filter_size
;
constexpr
size_t
fw
=
filter_size
;
constexpr
size_t
fw
=
stride
==
2
?
filter_size
:
(
filter_size
+
3
)
/
4
*
4
;
constexpr
size_t
ic_step
=
1
;
constexpr
size_t
big_oc_step
=
8
;
constexpr
size_t
oc_step
=
4
;
constexpr
size_t
ih_step
=
2
;
constexpr
size_t
ih_step
=
stride
==
2
?
2
:
1
;
constexpr
size_t
oh_step
=
1
;
constexpr
size_t
ow_step
=
4
;
constexpr
size_t
stride_h
=
2
;
constexpr
size_t
stride_w
=
2
;
constexpr
size_t
ow_step
=
stride
==
2
?
4
:
8
;
constexpr
size_t
stride_h
=
stride
;
constexpr
size_t
stride_w
=
stride
;
constexpr
int
pack_iw_len
=
4
;
const
size_t
img_stride
=
oh
*
ow
;
...
...
@@ -637,10 +1070,10 @@ static void conv_direct_stride2_int8_nchw_nchw44(
const
size_t
oc_remain
=
oc
-
oc_end
;
const
int
ld_dst_oc
=
oc_step
*
img_stride
;
using
remain_fun
=
std
::
function
<
void
(
const
int8_t
*
src_ptr
,
const
int8_t
*
weight_ptr
,
const
int32_t
*
bias_ptr
,
int8_t
*
dst_ptr
,
int
ic
,
int
ih
,
int
iw
,
int
ld_dst_oc
,
const
Op
&
op
)
>
;
using
remain_fun
=
std
::
function
<
void
(
const
int8_t
*
src_ptr
,
const
int8_t
*
weight_ptr
,
const
int32_t
*
bias_ptr
,
int8_t
*
dst_ptr
,
int
ic
,
int
ih
,
int
iw
,
int
ld_dst_oc
,
const
Op
&
op
)
>
;
remain_fun
kern_big_oc_remain
=
nullptr
;
remain_fun
kern_small_oc_remain
=
nullptr
;
switch
(
ow_remain
)
{
...
...
@@ -648,10 +1081,10 @@ static void conv_direct_stride2_int8_nchw_nchw44(
case step: \
kern_big_oc_remain = \
KerNeonXXs2NchwNchw44<bias_mode, Op, step, filter_size, \
big_oc_step
>::impl;
\
big_oc_step
, stride>::impl;
\
kern_small_oc_remain = \
KerNeonXXs2NchwNchw44<bias_mode, Op, step, filter_size, \
oc_step
>::impl;
\
oc_step
, stride>::impl;
\
break;
UNROLL_CALL_RAW
(
4
,
cb
);
...
...
@@ -663,27 +1096,28 @@ static void conv_direct_stride2_int8_nchw_nchw44(
const
size_t
weight_offset
=
oc_idx
*
ic
*
fh
*
fw
;
for
(
size_t
oh_idx
=
0
;
oh_idx
<
oh
;
oh_idx
+=
oh_step
)
{
for
(
size_t
ow_idx
=
0
;
ow_idx
<
ow_end
;
ow_idx
+=
ow_step
)
{
const
size_t
src_offset
=
(
oh_idx
*
stride_h
*
iw
+
ow_idx
*
stride_w
*
ih_step
)
*
const
size_t
src_offset
=
(
oh_idx
*
stride_h
*
iw
+
ow_idx
*
stride_w
*
ih_step
)
*
ic_step
*
pack_iw_len
;
const
size_t
dst_offset
=
oc_idx
*
img_stride
+
(
oh_idx
*
ow
+
ow_idx
)
*
oc_step
;
const
size_t
dst_offset
=
oc_idx
*
img_stride
+
(
oh_idx
*
ow
+
ow_idx
)
*
oc_step
;
KerNeonXXs2NchwNchw44
<
bias_mode
,
Op
,
0
,
filter_size
,
big_oc_step
>::
impl
(
src
+
src_offset
,
big_oc_step
,
stride
>::
impl
(
src
+
src_offset
,
filter
+
weight_offset
,
bias
+
oc_idx
,
dst
+
dst_offset
,
ic
,
ih
,
iw
,
ld_dst_oc
,
op
);
}
if
(
ow_remain
>
0
)
{
const
size_t
src_offset
=
(
oh_idx
*
stride_h
*
iw
+
ow_end
*
stride_w
*
ih_step
)
*
const
size_t
src_offset
=
(
oh_idx
*
stride_h
*
iw
+
ow_end
*
stride_w
*
ih_step
)
*
ic_step
*
pack_iw_len
;
const
size_t
dst_offset
=
oc_idx
*
img_stride
+
(
oh_idx
*
ow
+
ow_end
)
*
oc_step
;
const
size_t
dst_offset
=
oc_idx
*
img_stride
+
(
oh_idx
*
ow
+
ow_end
)
*
oc_step
;
kern_big_oc_remain
(
src
+
src_offset
,
filter
+
weight_offset
,
bias
+
oc_idx
,
dst
+
dst_offset
,
ic
,
ih
,
iw
,
ld_dst_oc
,
op
);
bias
+
oc_idx
,
dst
+
dst_offset
,
ic
,
ih
,
iw
,
ld_dst_oc
,
op
);
}
}
}
...
...
@@ -692,98 +1126,162 @@ static void conv_direct_stride2_int8_nchw_nchw44(
const
size_t
weight_offset
=
oc_idx
*
ic
*
fh
*
fw
;
for
(
size_t
oh_idx
=
0
;
oh_idx
<
oh
;
oh_idx
+=
oh_step
)
{
for
(
size_t
ow_idx
=
0
;
ow_idx
<
ow_end
;
ow_idx
+=
ow_step
)
{
const
size_t
src_offset
=
(
oh_idx
*
stride_h
*
iw
+
ow_idx
*
stride_w
*
ih_step
)
*
const
size_t
src_offset
=
(
oh_idx
*
stride_h
*
iw
+
ow_idx
*
stride_w
*
ih_step
)
*
ic_step
*
pack_iw_len
;
const
size_t
dst_offset
=
oc_idx
*
img_stride
+
(
oh_idx
*
ow
+
ow_idx
)
*
oc_step
;
const
size_t
dst_offset
=
oc_idx
*
img_stride
+
(
oh_idx
*
ow
+
ow_idx
)
*
oc_step
;
KerNeonXXs2NchwNchw44
<
bias_mode
,
Op
,
0
,
filter_size
,
oc_step
>::
impl
(
src
+
src_offset
,
oc_step
,
stride
>::
impl
(
src
+
src_offset
,
filter
+
weight_offset
,
bias
+
oc_idx
,
dst
+
dst_offset
,
ic
,
ih
,
iw
,
ld_dst_oc
,
op
);
dst
+
dst_offset
,
ic
,
ih
,
iw
,
ld_dst_oc
,
op
);
}
if
(
ow_remain
>
0
)
{
const
size_t
src_offset
=
(
oh_idx
*
stride_h
*
iw
+
ow_end
*
stride_w
*
ih_step
)
*
const
size_t
src_offset
=
(
oh_idx
*
stride_h
*
iw
+
ow_end
*
stride_w
*
ih_step
)
*
ic_step
*
pack_iw_len
;
const
size_t
dst_offset
=
oc_idx
*
img_stride
+
(
oh_idx
*
ow
+
ow_end
)
*
oc_step
;
kern_small_oc_remain
(
src
+
src_offset
,
filter
+
weight_offset
,
bias
+
oc_idx
,
dst
+
dst_offset
,
ic
,
ih
,
iw
,
ld_dst_oc
,
op
);
const
size_t
dst_offset
=
oc_idx
*
img_stride
+
(
oh_idx
*
ow
+
ow_end
)
*
oc_step
;
kern_small_oc_remain
(
src
+
src_offset
,
filter
+
weight_offset
,
bias
+
oc_idx
,
dst
+
dst_offset
,
ic
,
ih
,
iw
,
ld_dst_oc
,
op
);
}
}
}
}
#define CONSTRUCT_FUNC(filter_size) \
template <BiasMode bias_mode, typename Op> \
void conv_bias:: \
conv_direct_stride2_##filter_size##x##filter_size##_int8_nchw_nchw44( \
const int8_t* src, const int8_t* filter, \
const int32_t* bias, int32_t* temp, int8_t* dst, \
const size_t oc, const size_t ic, const size_t ih, \
const size_t iw, const size_t oh, const size_t ow, \
const Op& op) { \
conv_direct_stride2_int8_nchw_nchw44<bias_mode, Op, filter_size>( \
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); \
}
CONSTRUCT_FUNC
(
3
);
CONSTRUCT_FUNC
(
5
);
CONSTRUCT_FUNC
(
7
);
#undef CONSTRUCT_FUNC
template
<
BiasMode
bias_mode
,
typename
Op
>
void
conv_bias
::
conv_direct_stride2_2x2_int8_nchw_nchw44
(
const
int8_t
*
src
,
const
int8_t
*
filter
,
const
int32_t
*
bias
,
int32_t
*
temp
,
int8_t
*
dst
,
const
size_t
oc
,
const
size_t
ic
,
const
size_t
ih
,
const
size_t
iw
,
const
size_t
oh
,
const
size_t
ow
,
}
};
template
<
BiasMode
bias_mode
,
typename
Op
,
size_t
filter_size
>
struct
ConvDiectStrideInt8NchwNchw44
<
bias_mode
,
Op
,
filter_size
,
1
>
{
static
void
impl
(
const
int8_t
*
src
,
const
int8_t
*
filter
,
const
int32_t
*
bias
,
int32_t
*
temp
,
int8_t
*
dst
,
const
size_t
oc
,
const
size_t
ic
,
const
size_t
ih
,
const
size_t
iw
,
const
size_t
oh
,
const
size_t
ow
,
const
Op
&
op
)
{
MEGDNN_MARK_USED_VAR
(
src
);
MEGDNN_MARK_USED_VAR
(
filter
);
MEGDNN_MARK_USED_VAR
(
bias
);
MEGDNN_MARK_USED_VAR
(
temp
);
MEGDNN_MARK_USED_VAR
(
dst
);
MEGDNN_MARK_USED_VAR
(
oc
);
MEGDNN_MARK_USED_VAR
(
ic
);
MEGDNN_MARK_USED_VAR
(
ih
);
MEGDNN_MARK_USED_VAR
(
iw
);
MEGDNN_MARK_USED_VAR
(
oh
);
MEGDNN_MARK_USED_VAR
(
ow
);
MEGDNN_MARK_USED_VAR
(
op
);
megdnn_assert
(
0
,
"not imple nchw_nchw44 2x2s2 conv"
);
constexpr
int
stride
=
1
;
constexpr
size_t
fh
=
filter_size
;
constexpr
size_t
fw
=
(
filter_size
+
3
)
/
4
*
4
;
constexpr
size_t
ic_step
=
1
;
constexpr
size_t
big_oc_step
=
8
;
constexpr
size_t
oc_step
=
4
;
constexpr
size_t
ih_step
=
1
;
constexpr
size_t
oh_step
=
1
;
constexpr
size_t
ow_step
=
8
;
constexpr
size_t
stride_h
=
stride
;
constexpr
size_t
stride_w
=
stride
;
constexpr
int
pack_iw_len
=
16
;
const
size_t
img_stride
=
oh
*
ow
;
const
size_t
ow_end
=
ow
/
ow_step
*
ow_step
;
const
size_t
ow_remain
=
ow
-
ow_end
;
const
size_t
oc_end
=
oc
/
big_oc_step
*
big_oc_step
;
const
size_t
oc_remain
=
oc
-
oc_end
;
const
int
ld_dst_oc
=
oc_step
*
img_stride
;
using
remain_fun
=
std
::
function
<
void
(
const
int8_t
*
src_ptr
,
const
int8_t
*
weight_ptr
,
const
int32_t
*
bias_ptr
,
int8_t
*
dst_ptr
,
int
ic
,
int
ih
,
int
iw
,
int
ld_dst_oc
,
const
Op
&
op
)
>
;
remain_fun
kern_big_oc_remain
=
nullptr
;
remain_fun
kern_small_oc_remain
=
nullptr
;
switch
(
ow_remain
)
{
#define cb(step) \
case step: \
kern_big_oc_remain = \
KerNeonXXs2NchwNchw44<bias_mode, Op, step, filter_size, \
big_oc_step, stride>::impl; \
kern_small_oc_remain = \
KerNeonXXs2NchwNchw44<bias_mode, Op, step, filter_size, \
oc_step, stride>::impl; \
break;
UNROLL_CALL_RAW
(
8
,
cb
);
default:
megdnn_assert
(
0
,
"no remain %zu for kern"
,
ow_remain
);
}
for
(
size_t
oc_idx
=
0
;
oc_idx
<
oc_end
;
oc_idx
+=
big_oc_step
)
{
const
size_t
weight_offset
=
oc_idx
*
ic
*
fh
*
fw
;
for
(
size_t
oh_idx
=
0
;
oh_idx
<
oh
;
oh_idx
+=
oh_step
)
{
for
(
size_t
ow_idx
=
0
;
ow_idx
<
ow_end
;
ow_idx
+=
ow_step
)
{
const
size_t
src_offset
=
(
oh_idx
*
stride_h
*
iw
+
ow_idx
*
stride_w
*
ih_step
)
*
ic_step
*
pack_iw_len
;
const
size_t
dst_offset
=
oc_idx
*
img_stride
+
(
oh_idx
*
ow
+
ow_idx
)
*
oc_step
;
KerNeonXXs2NchwNchw44
<
bias_mode
,
Op
,
ow_step
,
filter_size
,
big_oc_step
,
stride
>::
impl
(
src
+
src_offset
,
filter
+
weight_offset
,
bias
+
oc_idx
,
dst
+
dst_offset
,
ic
,
ih
,
iw
,
ld_dst_oc
,
op
);
}
if
(
ow_remain
>
0
)
{
const
size_t
src_offset
=
(
oh_idx
*
stride_h
*
iw
+
ow_end
*
stride_w
*
ih_step
)
*
ic_step
*
pack_iw_len
;
const
size_t
dst_offset
=
oc_idx
*
img_stride
+
(
oh_idx
*
ow
+
ow_end
)
*
oc_step
;
kern_big_oc_remain
(
src
+
src_offset
,
filter
+
weight_offset
,
bias
+
oc_idx
,
dst
+
dst_offset
,
ic
,
ih
,
iw
,
ld_dst_oc
,
op
);
}
}
}
if
(
oc_remain
>
0
)
{
size_t
oc_idx
=
oc_end
;
const
size_t
weight_offset
=
oc_idx
*
ic
*
fh
*
fw
;
for
(
size_t
oh_idx
=
0
;
oh_idx
<
oh
;
oh_idx
+=
oh_step
)
{
for
(
size_t
ow_idx
=
0
;
ow_idx
<
ow_end
;
ow_idx
+=
ow_step
)
{
const
size_t
src_offset
=
(
oh_idx
*
stride_h
*
iw
+
ow_idx
*
stride_w
*
ih_step
)
*
ic_step
*
pack_iw_len
;
const
size_t
dst_offset
=
oc_idx
*
img_stride
+
(
oh_idx
*
ow
+
ow_idx
)
*
oc_step
;
KerNeonXXs2NchwNchw44
<
bias_mode
,
Op
,
ow_step
,
filter_size
,
oc_step
,
stride
>::
impl
(
src
+
src_offset
,
filter
+
weight_offset
,
bias
+
oc_idx
,
dst
+
dst_offset
,
ic
,
ih
,
iw
,
ld_dst_oc
,
op
);
}
if
(
ow_remain
>
0
)
{
const
size_t
src_offset
=
(
oh_idx
*
stride_h
*
iw
+
ow_end
*
stride_w
*
ih_step
)
*
ic_step
*
pack_iw_len
;
const
size_t
dst_offset
=
oc_idx
*
img_stride
+
(
oh_idx
*
ow
+
ow_end
)
*
oc_step
;
kern_small_oc_remain
(
src
+
src_offset
,
filter
+
weight_offset
,
bias
+
oc_idx
,
dst
+
dst_offset
,
ic
,
ih
,
iw
,
ld_dst_oc
,
op
);
}
}
}
}
};
template
<
BiasMode
bias_mode
,
typename
Op
,
size_t
filter_size
,
int
stride
>
static
void
conv_direct_int8_nchw_nchw44
(
const
int8_t
*
src
,
const
int8_t
*
filter
,
const
int32_t
*
bias
,
int32_t
*
temp
,
int8_t
*
dst
,
const
size_t
oc
,
const
size_t
ic
,
const
size_t
ih
,
const
size_t
iw
,
const
size_t
oh
,
const
size_t
ow
,
const
Op
&
op
)
{
ConvDiectStrideInt8NchwNchw44
<
bias_mode
,
Op
,
filter_size
,
stride
>::
impl
(
src
,
filter
,
bias
,
temp
,
dst
,
oc
,
ic
,
ih
,
iw
,
oh
,
ow
,
op
);
}
#define INSTANTIATION(stride, i, bias, Op) \
template void conv_bias:: \
conv_direct_##stride##_##i##x##i##_int8_nchw_nchw44<bias, Op>( \
const int8_t*, const int8_t*, const int32_t*, int32_t*, \
int8_t*, const size_t, const size_t, const size_t, \
const size_t, const size_t, const size_t, const Op&);
#define FOR_OP(stride, i, bias) \
INSTANTIATION(stride, i, bias, TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(stride, i, bias, ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(stride, i, bias, HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)
#define FOR_BIAS(stride, i) \
FOR_OP(stride, i, BiasMode::NO_BIAS) \
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)
#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)
FOR_FILTER
(
stride2
)
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
}
// namespace
}
// namespace arm_common
}
// namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h
已删除
100644 → 0
浏览文件 @
198f3eb5
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
namespace
megdnn
{
namespace
arm_common
{
namespace
conv_bias
{
#define KERN(stride, i, layout) \
template <BiasMode bias_mode, typename Op> \
void conv_direct_##stride##_##i##x##i##_int8_nchw_##layout( \
const int8_t* src, const int8_t* filter, const int32_t* bias, \
int32_t* temp, int8_t* dst, const size_t OC, const size_t IC, \
const size_t IH, const size_t IW, const size_t OH, \
const size_t OW, const Op& op);
KERN
(
stride2
,
2
,
nchw44
)
KERN
(
stride2
,
3
,
nchw44
)
KERN
(
stride2
,
5
,
nchw44
)
KERN
(
stride2
,
7
,
nchw44
)
#undef KERN
void
pack_nchw44_weight_for_nchw_conv
(
const
int8_t
*
inptr
,
int8_t
*
outptr
,
const
int
ic
,
const
int
fh
,
const
int
fw
,
const
int
oc
);
void
pack_nchw_src_for_nchw44_conv
(
const
int8_t
*
inptr
,
int8_t
*
outptr
,
const
int
ic
,
const
int
top_pad
,
const
int
bottom_pad
,
const
int
left_pad
,
const
int
right_pad
,
const
int
ih
,
const
int
iw
);
}
// namespace conv_bias
}
// namespace arm_common
}
// namespace megdnn
\ No newline at end of file
dnn/src/arm_common/conv_bias/opr_impl.cpp
浏览文件 @
7b0dbe6a
...
...
@@ -47,7 +47,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoS8DirectStride2
s8_direct_stride2_large_group
{
true
};
AlgoS8DirectStride2
s8_direct_stride2_small_group
{
false
};
AlgoS8DirectStride2NCHW44
s8_direct_stride2_nchw44
;
AlgoS8Direct
Stride2NCHWNCHW44
s8_direct_stride2
_nchw_nchw44
;
AlgoS8Direct
NCHWNCHW44
s8_direct
_nchw_nchw44
;
AlgoS8DirectStride1
s8_direct_stride1_large_group
{
true
};
AlgoS8DirectStride1
s8_direct_stride1_small_group
{
false
};
AlgoS8DirectStride1NCHW44
s8_direct_stride1_nchw44
;
...
...
@@ -115,7 +115,7 @@ public:
direct_algos
.
emplace_back
(
&
s8_direct_stride2_large_group
);
direct_algos
.
emplace_back
(
&
s8_direct_stride2_small_group
);
direct_algos
.
emplace_back
(
&
s8_direct_stride2_nchw44
);
direct_algos
.
emplace_back
(
&
s8_direct_
stride2_
nchw_nchw44
);
direct_algos
.
emplace_back
(
&
s8_direct_nchw_nchw44
);
direct_algos
.
emplace_back
(
&
s8_direct_stride1_large_group
);
direct_algos
.
emplace_back
(
&
s8_direct_stride1_small_group
);
direct_algos
.
emplace_back
(
&
s8_direct_stride1_nchw44
);
...
...
dnn/src/arm_common/conv_bias/opr_impl.h
浏览文件 @
7b0dbe6a
...
...
@@ -40,7 +40,7 @@ private:
class
AlgoS8DirectStride1NCHW44
;
class
AlgoS8DirectStride2
;
class
AlgoS8DirectStride2NCHW44
;
class
AlgoS8Direct
Stride2
NCHWNCHW44
;
class
AlgoS8DirectNCHWNCHW44
;
class
AlgoQU8DirectStride1
;
class
AlgoQU8DirectStride2
;
class
AlgoFP32WinogradF23_4x4
;
...
...
dnn/test/arm_common/conv_bias.cpp
浏览文件 @
7b0dbe6a
...
...
@@ -244,18 +244,26 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) {
#if MEGDNN_AARCH64
benchmark_convbias
(
handle
(),
"IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384"
,
"IM2COLMATMUL:AARCH64_F32K8X12X1:192"
,
true
);
benchmark_convbias
(
handle
(),
"IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384"
,
"IM2COLMATMUL:AARCH64_F32K8X12X1:192"
,
false
);
#else
benchmark_convbias
(
handle
(),
"IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384"
,
"IM2COLMATMUL:ARMV7_F32:192"
,
true
);
benchmark_convbias
(
handle
(),
"IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384"
,
"IM2COLMATMUL:ARMV7_F32:192"
,
false
);
#endif
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
BENCHMARK_CONVBIAS_NCHW44
)
{
#if MEGDNN_AARCH64
benchmark_convbias
(
handle
(),
"IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384"
,
"IM2COLMATMUL:AARCH64_F32K8X12X1:192"
,
true
);
benchmark_convbias
(
handle
(),
"IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384"
,
"IM2COLMATMUL:AARCH64_F32K8X12X1:192"
,
false
);
#else
benchmark_convbias
(
handle
(),
"IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384"
,
"IM2COLMATMUL:ARMV7_F32:192"
,
true
);
benchmark_convbias
(
handle
(),
"IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384"
,
"IM2COLMATMUL:ARMV7_F32:192"
,
false
);
#endif
}
...
...
dnn/test/arm_common/conv_bias_multi_thread.cpp
浏览文件 @
7b0dbe6a
...
...
@@ -541,7 +541,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_INT8_NCHW_NCHW44
)
{
checker_conv_bias_qint8x8x8
(
get_nchw44_conv_bias_args
({
3
,
5
,
7
},
2
,
false
,
false
,
false
,
true
),
get_nchw44_conv_bias_args
({
2
,
3
,
5
,
7
},
2
,
false
,
false
,
false
,
true
),
handle
(),
"S8_CONV_NCHW_NCHW44"
);
checker_conv_bias_qint8x8x8
(
get_nchw44_conv_bias_args
({
2
,
3
,
5
,
7
},
1
,
false
,
false
,
false
,
true
),
handle
(),
"S8_CONV_NCHW_NCHW44"
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录