MegEngine 天元 / MegEngine — commit
6c29548d
Authored on Jun 02, 2020 by Megvii Engine Team; committed on Jun 19, 2020 by Xu Xinran
fix(dnn/arm): fix nchw_nchw44 dot stride1 support
GitOrigin-RevId: c8d3d55b258e2a43c27b903808566f2ea1857842
Parent: 02cbb13b
Showing 12 changed files with 1432 additions and 200 deletions (+1432, -200)
dnn/src/arm_common/conv_bias/block_helper.h                        +36   -0
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp       +7    -21
dnn/src/arm_common/conv_bias/int8/algos.h                          +15   -0
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp  +321  -0
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h    +779  -0
dnn/src/arm_common/conv_bias/intrinsic_helper.h                    +200  -169
dnn/src/arm_common/conv_bias/opr_impl.cpp                          +2    -0
dnn/src/arm_common/conv_bias/opr_impl.h                            +1    -0
dnn/src/arm_common/neon_struct.h                                   +8    -0
dnn/src/arm_common/simd_macro/marm_neon.h                          +40   -6
dnn/test/arm_common/conv_bias.cpp                                  +13   -4
dnn/test/arm_common/conv_bias_multi_thread.cpp                     +10   -0
dnn/src/arm_common/conv_bias/block_helper.h (new file, mode 100644)
```cpp
/**
 * \file dnn/src/arm_common/conv_bias/block_helper.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/common/utils.h"
namespace megdnn {
namespace {
// block_helper is used to calculate oh block size
static inline int l2_block_helper(const int nthread, const int amount,
                                  const int size_per_unit) {
    constexpr int l2_cache_size = 256 * 1024;
    const int block_per_thread = div_ceil(amount, nthread);
    const int best_block = std::min(
            amount, (l2_cache_size + size_per_unit / 2) / size_per_unit);
    const int max_block_num = div_ceil(block_per_thread, best_block);
    const int min_block_num = std::max(max_block_num - 1, 1);
    const int max_block = div_ceil(block_per_thread, max_block_num);
    const int min_block = div_ceil(block_per_thread, min_block_num);
    const int max_loss = std::abs(max_block_num * max_block - block_per_thread);
    const int min_loss = std::abs(min_block_num * min_block - block_per_thread);
    int block = max_loss > min_loss ? min_block : max_block;
    return block;
}
}  // namespace
}  // namespace megdnn

// vim: syntax=cpp.doxygen
```
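To see what the heuristic chooses, here is a standalone sketch of the same arithmetic (a local `div_ceil` stands in for the MegDNN utility; the shape numbers are illustrative only): it caps the block at what fits a hard-coded 256 KiB L2 budget, then keeps whichever of the two candidate block counts wastes fewer rows in its last block.

```cpp
#include <algorithm>
#include <cstdio>
#include <cstdlib>

static int div_ceil(int a, int b) {
    return (a + b - 1) / b;  // stand-in for megdnn's div_ceil
}

// Same arithmetic as l2_block_helper above, for experimenting off-target.
static int l2_block(int nthread, int amount, int size_per_unit) {
    constexpr int l2_cache_size = 256 * 1024;
    const int block_per_thread = div_ceil(amount, nthread);
    const int best_block = std::min(
            amount, (l2_cache_size + size_per_unit / 2) / size_per_unit);
    const int max_block_num = div_ceil(block_per_thread, best_block);
    const int min_block_num = std::max(max_block_num - 1, 1);
    const int max_block = div_ceil(block_per_thread, max_block_num);
    const int min_block = div_ceil(block_per_thread, min_block_num);
    const int max_loss = std::abs(max_block_num * max_block - block_per_thread);
    const int min_loss = std::abs(min_block_num * min_block - block_per_thread);
    return max_loss > min_loss ? min_block : max_block;
}

int main() {
    // Example call mirroring the f32 call site below: oh = 112, one thread,
    // size_per_unit = ic * iw * sizeof(float) * 2 = 3 * 224 * 4 * 2 bytes.
    printf("oh block = %d\n", l2_block(1, 112, 3 * 224 * 4 * 2));  // prints 56
    return 0;
}
```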
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp
```diff
@@ -11,6 +11,7 @@
  */
 #include "megdnn/oprs.h"
+#include "src/arm_common/conv_bias/block_helper.h"
 #include "src/arm_common/conv_bias/fp32/algos.h"
 #include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h"
 #include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h"
@@ -26,22 +27,7 @@ using conv_fun = std::function<void(
         const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>;
 MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw44_stride1)
 namespace {
-// block_helper is used to calculate oh block size
-static inline int block_helper(const int nthread, const int amount,
-                               const int size_per_unit) {
-    constexpr int l2_cache_size = 256 * 1024;
-    const int block_per_thread = div_ceil(amount, nthread);
-    const int best_block = std::min(
-            amount, (l2_cache_size + size_per_unit / 2) / size_per_unit);
-    const int max_block_num = div_ceil(block_per_thread, best_block);
-    const int min_block_num = std::max(max_block_num - 1, 1);
-    const int max_block = div_ceil(block_per_thread, max_block_num);
-    const int min_block = div_ceil(block_per_thread, min_block_num);
-    const int max_loss = std::abs(max_block_num * max_block - block_per_thread);
-    const int min_loss = std::abs(min_block_num * min_block - block_per_thread);
-    int block = max_loss > min_loss ? min_block : max_block;
-    return block;
-}
 static inline size_t get_perthread_cache_bytes(const int ic, const int ih2,
                                                const int iw2) {
     // border_size is used to avoid read illegal memory
@@ -60,7 +46,7 @@ static void get_rectified_size(
     ow2 = ow;
     constexpr int cacheline = 64 / sizeof(float);
-    int block_oh = block_helper(param.nr_threads, oh,
-                                ic * iw * sizeof(float) * 2);
+    int block_oh = l2_block_helper(param.nr_threads, oh,
+                                   ic * iw * sizeof(float) * 2);
     auto&& fm = param.filter_meta;
     const int stride_h = static_cast<int>(fm.stride[0]);
     const int filter_h = static_cast<int>(fm.spatial[0]);
@@ -106,8 +92,8 @@ static void do_conv_kern(WorkspaceBundle bundle,
     const int group_id = ncb_index.ndrange_id[1];
     constexpr int oc_idx = 0;
     int oc_block = oc;
-    int oh_block = block_helper(kern_param.nr_threads, oh2,
-                                ic * iw * sizeof(float) * stride_h);
+    int oh_block = l2_block_helper(kern_param.nr_threads, oh2,
+                                   ic * iw * sizeof(float) * stride_h);
     const int oh_idx = ncb_index.ndrange_id[2];
     const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block);
     const int ih_real = oh_block_real * stride_h + fh - stride_h;
@@ -298,8 +284,8 @@ ConvBiasImpl::AlgoF32DirectNCHW44::dispatch_kerns(
     int ic = param.filter_meta.icpg;
     int iw = param.isz[1];
     int stride_h = param.filter_meta.stride[0];
-    int oh_block = block_helper(param.nr_threads, oh,
-                                ic * iw * sizeof(float) * stride_h);
+    int oh_block = l2_block_helper(param.nr_threads, oh,
+                                   ic * iw * sizeof(float) * stride_h);
     CpuNDRange ncb_range = {static_cast<size_t>(batch),
                            static_cast<size_t>(group),
                            static_cast<size_t>(div_ceil(oh, oh_block))};
```
dnn/src/arm_common/conv_bias/int8/algos.h
```diff
@@ -133,6 +133,21 @@ public:
 };
 #if __ARM_FEATURE_DOTPROD
+class ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44 final : public AlgoBase {
+public:
+    bool is_reproducible() const override { return true; }
+    const char* name() const override { return "ARMDOTS8_NCHW_NCHW44"; }
+    bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam&,
+                AlgoSelectionStrategy algo_selection_strategy) const override;
+    size_t get_workspace(FallbackConvBiasImpl*,
+                         const NCBKernSizeParam&) const override;
+    virtual SmallVector<NCBKern> dispatch_kerns(
+            fallback::ConvBiasImpl* opr,
+            const NCBKernSizeParam& param) const override;
+};
 class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase {
     bool m_large_group;
```
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp (new file, mode 100644)
```cpp
/**
 * \file dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#if __ARM_FEATURE_DOTPROD
#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"

#include "midout.h"

using namespace megdnn;
using namespace arm_common;
using conv_fun = std::function<void(
        WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
        const ConvBiasImpl::NCBKernIndex& ncb_index,
        const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>;
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_dot)

namespace {
static inline size_t get_perthread_cache_bytes(const int ic, const int ih2,
                                               const int iw2, const int stride) {
    //! border_size is used to avoid read illegal memory
    constexpr int cacheline_size = 64;
    constexpr int border_size = 2 * cacheline_size;
    const int pack_iw_len = stride == 1 ? 4 : 1;
    return round_up(
            ic * ih2 * iw2 * pack_iw_len * (int)sizeof(int8_t) + border_size,
            cacheline_size);
}

static inline size_t get_temp_bytes(const int iw, const int pw) {
    //! border_size is used to avoid read illegal memory
    constexpr int cacheline_size = 64;
    constexpr int border_size = 1 * cacheline_size;
    return round_up(iw + pw * 2, cacheline_size) + border_size;
}
```
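As a worked instance of the two helpers (shape numbers illustrative only): with ic = 3, ih2 = 17, iw2 = 230 and stride 1 (so pack_iw_len = 4), the per-thread packed source costs round_up(3·17·230·4 + 128, 64) = 47104 bytes, and the stride-1 temp row for iw = 224, pw = 3 costs round_up(230, 64) + 64 = 320 bytes. A minimal sketch of that arithmetic:

```cpp
#include <cstdio>

static size_t round_up(size_t v, size_t align) {  // stand-in for megdnn's
    return (v + align - 1) / align * align;
}

int main() {
    // Illustrative shapes: ic = 3, ih2 = 17, iw2 = 230, stride 1.
    size_t src = round_up(3 * 17 * 230 * 4 + 2 * 64, 64);  // 47104
    size_t tmp = round_up(224 + 2 * 3, 64) + 64;           // 320
    printf("src = %zu bytes, temp = %zu bytes\n", src, tmp);
    return 0;
}
```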
```cpp
static void get_rectified_size(
        const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2,
        int& iw2) {
    auto&& fm = param.filter_meta;
    const int stride_h = static_cast<int>(fm.stride[0]);
    const int filter_h = static_cast<int>(fm.spatial[0]);
    int ic = param.filter_meta.icpg;
    int iw = param.isz[1];
    int oh = param.osz[0];
    int block_oh = l2_block_helper(param.nr_threads, oh,
                                   ic * iw * sizeof(int8_t) * stride_h);
    ih2 = block_oh * stride_h + filter_h - stride_h;
    iw2 = iw + 2 * static_cast<int>(fm.padding[1]);
}

static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
    auto&& fm = param.filter_meta;
    int ic = fm.icpg;
    int fh = fm.spatial[0];
    int fw = fm.spatial[1];
    int iw = param.isz[1];
    int pw = param.filter_meta.padding[1];
    int stride_w = param.filter_meta.stride[1];
    int ih2, iw2;
    get_rectified_size(param, ih2, iw2);
    size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2, stride_w);
    size_t weight_size = fm.group * fm.icpg * fm.ocpg * fh * round_up(fw, 4);
    size_t temp_size = 0;
    if (fm.stride[0] == 1) {
        temp_size = get_temp_bytes(iw, pw);
    }
    return {nullptr,
            {src_size * param.nr_threads, weight_size,
             temp_size * param.nr_threads}};
};

void do_weight_trans(WorkspaceBundle bundle,
                     const ConvBiasImpl::NCBKernParam& kern_param,
                     const ConvBiasImpl::NCBKernIndex&, const CpuNDRange&) {
    const int ic = kern_param.filter_meta.icpg;
    const int oc = kern_param.filter_meta.ocpg;
    const int fh = kern_param.filter_meta.spatial[0];
    const int fw = kern_param.filter_meta.spatial[1];
    const int fw2 = round_up(fw, 4);
    bundle.set(kern_param.workspace_ptr);
    auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1));
    auto origin_weight = kern_param.filter<dt_int8>();
    pack_weight_int8_nchw_nchw44_dot(packed_weight, origin_weight, oc, ic, fh,
                                     fw, fw2);
}

template <size_t filter, BiasMode bias_mode, typename Op, int stride>
static void do_conv_kern(WorkspaceBundle bundle,
                         const ConvBiasImpl::NCBKernParam& kern_param,
                         const ConvBiasImpl::NCBKernIndex& ncb_index,
                         const CpuNDRange&, const CpuNDRange&) {
    const int oh = kern_param.osz[0];
    const int ow = kern_param.osz[1];
    const int fh = kern_param.filter_meta.spatial[0];
    const int fw = kern_param.filter_meta.spatial[1];
    const int ic = kern_param.filter_meta.icpg;
    const int oc = kern_param.filter_meta.ocpg;
    const int ih = kern_param.isz[0];
    const int iw = kern_param.isz[1];
    const int stride_h = kern_param.filter_meta.stride[0];
    const int stride_w = kern_param.filter_meta.stride[1];
    const int ph = kern_param.filter_meta.padding[0];
    const int pw = kern_param.filter_meta.padding[1];
    int ih2 = 0;
    int iw2 = 0;
    get_rectified_size(kern_param, ih2, iw2);
    bundle.set(kern_param.workspace_ptr);
    constexpr int pack_c = 4;
    const int batch_id = ncb_index.ndrange_id[0];
    const int group_id = ncb_index.ndrange_id[1];
    constexpr int oc_idx = 0;
    int oc_block = oc;
    int oh_block = l2_block_helper(kern_param.nr_threads, oh,
                                   ic * iw * sizeof(int8_t) * stride_h);
    const int oh_idx = ncb_index.ndrange_id[2];
    const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block);
    const int ih_real = oh_block_real * stride_h + fh - stride_h;
    const int src_top_pad = std::max(ph - oh_idx * oh_block * stride_h, 0);
    const int src_bottom_pad = std::max(
            (oh_idx * oh_block + oh_block_real - 1) * stride_h + fh - ih - ph,
            0);
    const int remain_right_pad = std::max(iw2 - iw - pw, 0);
    const int src_offset = std::max(oh_idx * oh_block * stride_h - ph, 0) * iw;
    const int8_t* origin_sptr =
            static_cast<const int8_t*>(
                    kern_param.src<int8_t>(batch_id, group_id, 0, 1, 1)) +
            src_offset;
    const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2, stride_w);
    int8_t* sptr = reinterpret_cast<int8_t*>(bundle.get(0)) +
                   ncb_index.thread_id * src_size;
    int8_t* tmp_ptr = nullptr;
    if (stride == 1) {
        const size_t tmp_size = get_temp_bytes(iw, pw);
        tmp_ptr = reinterpret_cast<int8_t*>(bundle.get(2)) +
                  ncb_index.thread_id * tmp_size;
    }
    pack_src_int8_nchw_nchw44_dot<stride>(
            sptr, origin_sptr, ph, pw, remain_right_pad,
            ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad,
            src_bottom_pad, ic, ih * iw, tmp_ptr);

    const int8_t* fptr =
            reinterpret_cast<int8_t*>(bundle.get(1)) + oc_idx * fh * fw * ic;
    int8_t* dst = kern_param.dst<int8_t>(batch_id, group_id) +
                  oh_idx * oh_block * ow * pack_c;
    const int bias_offset = oc_idx;
    const int32_t* bptr =
            kern_param.bias<dt_int32>(batch_id, group_id) + bias_offset;

    float scale_bias = kern_param.bias_type.param<dtype::QuantizedS32>().scale;
    float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
    Op op(scale_bias, scale_dst);
    conv_direct_int8_nchw_nchw44_dot<bias_mode, Op, filter, stride>(
            sptr, fptr, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, oh,
            oh_block_real, ow, op);
}
}  // namespace
```
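One way to sanity-check the top/bottom padding arithmetic in `do_conv_kern`: for an illustrative stride-1 case with ih = oh = 32, fh = 3, ph = 1 and oh_block = 16, the first block needs one padded row on top and none on the bottom, while the second block needs the opposite. A standalone sketch of exactly that computation:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    // Illustrative stride-1 case: ih = 32, fh = 3, ph = 1, oh_block = 16.
    const int ih = 32, fh = 3, ph = 1, stride_h = 1;
    const int oh_block = 16, oh_block_real = 16;
    for (int oh_idx = 0; oh_idx < 2; ++oh_idx) {
        int top = std::max(ph - oh_idx * oh_block * stride_h, 0);
        int bottom = std::max(
                (oh_idx * oh_block + oh_block_real - 1) * stride_h + fh - ih -
                        ph,
                0);
        printf("block %d: top pad %d, bottom pad %d\n", oh_idx, top, bottom);
    }
    // block 0: top pad 1, bottom pad 0
    // block 1: top pad 0, bottom pad 1
    return 0;
}
```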
```cpp
bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
        AlgoSelectionStrategy) const {
    auto&& fm = param.filter_meta;
    auto fh = fm.spatial[0];
    int oc = fm.ocpg;
    int ic = fm.icpg;
    bool ok_type = ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
                     param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
                     (param.dst_type.enumv() == DTypeEnum::QuantizedS8))) &&
                   (fm.format == param::Convolution::Format::NCHW44);
    bool ok_src_dst = (oc % 4 == 0 && oc >= 4 && ic < 4);
    bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
                     (fh == 2 || fh == 3 || fh == 5 || fh == 7);
    bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                    fm.stride[0] == fm.stride[1] &&
                    (fm.stride[0] == 1 || fm.stride[0] == 2);
    bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS;
    bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv;
    return avaible;
}

size_t ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::get_workspace(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    return get_bundle(param).total_size_in_bytes();
}
```
```cpp
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    auto fm = param.filter_meta;
    const int batch = param.n;
    const int group = fm.group;
    WorkspaceBundle wbundle = get_bundle(param);
    conv_fun do_conv_fun = nullptr;
// NOTE: remain_w is not used to gen hash of midout for compatible with
// shape runtime
#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op)              \
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_dot,        \
                 midout_iv(#stride #filter #bias_mode #op##_hash)) { \
        do_conv_fun = do_conv_kern<filter, bias_mode, op, stride>;   \
    }                                                                \
    MIDOUT_END();

#define GET_OP_PARAM(stride, filter, bias_mode)                          \
    switch (param.nonlineMode) {                                         \
        case param::ConvBias::NonlineMode::IDENTITY:                     \
            DO_CONV_KERN_FUN(stride, filter, bias_mode,                  \
                             TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
            break;                                                       \
        case param::ConvBias::NonlineMode::RELU:                         \
            DO_CONV_KERN_FUN(stride, filter, bias_mode,                  \
                             ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
            break;                                                       \
        case param::ConvBias::NonlineMode::H_SWISH:                      \
            DO_CONV_KERN_FUN(stride, filter, bias_mode,                  \
                             HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)  \
            break;                                                       \
        default:                                                         \
            megdnn_assert(0);                                            \
            break;                                                       \
    }

#define GET_BIAS_MODE_PARAM(stride, filter)                                \
    switch (param.bias_mode) {                                             \
        case BiasMode::NO_BIAS:                                            \
            GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS)                \
            break;                                                         \
        case BiasMode::BROADCAST_CHANNEL_BIAS:                             \
            GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
            break;                                                         \
        default:                                                           \
            megdnn_assert(0);                                              \
            break;                                                         \
    }

#define DISPATCH_CONV_KERN(stride)          \
    switch (param.filter_meta.spatial[0]) { \
        case 2:                             \
            GET_BIAS_MODE_PARAM(stride, 2)  \
            break;                          \
        case 3:                             \
            GET_BIAS_MODE_PARAM(stride, 3)  \
            break;                          \
        case 5:                             \
            GET_BIAS_MODE_PARAM(stride, 5)  \
            break;                          \
        case 7:                             \
            GET_BIAS_MODE_PARAM(stride, 7)  \
            break;                          \
        default:                            \
            megdnn_assert(0);               \
            break;                          \
    }
```
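The macro stack above is a compile-time dispatch table: a runtime switch over filter size, bias mode and nonlinearity picks one fully specialized instantiation of `do_conv_kern`. For instance (illustrative case), stride 1 with a 3x3 filter, per-channel bias and RELU resolves to:

```cpp
// What the cascade expands to for one concrete case (MEGDNN_COMMA is a comma):
do_conv_fun = do_conv_kern<3, BiasMode::BROADCAST_CHANNEL_BIAS,
                           ReluOp<dt_qint32, dt_qint8>, 1>;
```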
```cpp
    switch (param.filter_meta.stride[0]) {
        case 1:
            DISPATCH_CONV_KERN(1);
            break;
        case 2:
            DISPATCH_CONV_KERN(2);
            break;
        default:
            megdnn_assert(0);
            break;
    }
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN

    megdnn_assert(do_conv_fun);

    SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
    WorkspaceBundle bundle = wbundle;
    int oh = param.osz[0];
    int ic = param.filter_meta.icpg;
    int iw = param.isz[1];
    int stride_h = param.filter_meta.stride[0];
    int oh_block = l2_block_helper(param.nr_threads, oh,
                                   ic * iw * sizeof(int8_t) * stride_h);
    CpuNDRange ncb_range = {static_cast<size_t>(batch),
                            static_cast<size_t>(group),
                            static_cast<size_t>(div_ceil(oh, oh_block))};
    auto do_trans_weight = [bundle](const NCBKernParam& kern_param,
                                    const NCBKernIndex& ncb_index) {
        do_weight_trans(bundle, kern_param, ncb_index, ncb_index.ndrange_id);
    };
    ret_kerns.push_back({do_trans_weight, {1}});
    auto do_conv = [bundle, do_conv_fun, ncb_range](
                           const NCBKernParam& kern_param,
                           const NCBKernIndex& ncb_index) {
        do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id,
                    ncb_range);
    };
    ret_kerns.push_back({do_conv, ncb_range});
    return ret_kerns;
}
#endif
// vim: syntax=cpp.doxygen
```
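`dispatch_kerns` thus returns two kernels: the weight repack runs once (its range is {1}), and the convolution body runs once per cell of the {batch, group, oh blocks} grid, each cell packing its own source block into a per-thread scratch area. For illustration (numbers are made up):

```cpp
// batch = 2, group = 1, oh = 112, oh_block = 56  =>  ncb_range {2, 1, 2}:
// the conv kernel body executes 2 * 1 * 2 = 4 times, after the single
// weight-transform task has completed.
CpuNDRange ncb_range = {2, 1, static_cast<size_t>(div_ceil(112, 56))};
```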
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h (new file, mode 100644)
```cpp
/**
 * \file dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once
#if __ARM_FEATURE_DOTPROD
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"

using namespace megdnn;
using namespace arm_common;
namespace {

template <int src_idx, int weight_idx, int c_dim, typename Func, int ow_block,
          int stride, typename T, typename T2, typename T3, typename T4>
struct ShiftCalHelper {
    static void impl(T& c, T2& src, T3& weight);
};

template <int src_idx, int weight_idx, typename Func, int stride, typename T,
          typename T2, typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, stride, T, T2, T3, T4> {
    static void impl(T& c, T2& src, T3& weight) {
#define cb(step)                                                     \
    c[0][step * 2] = Func::template impl<(src_idx + step) % 4>(      \
            c[0][step * 2], weight[0][weight_idx],                   \
            src[0][(src_idx + step) / 4]);                           \
    c[1][step * 2] = Func::template impl<(src_idx + step) % 4>(      \
            c[1][step * 2], weight[1][weight_idx],                   \
            src[0][(src_idx + step) / 4]);                           \
    c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>(  \
            c[0][step * 2 + 1], weight[0][weight_idx],               \
            src[1][(src_idx + step) / 4]);                           \
    c[1][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>(  \
            c[1][step * 2 + 1], weight[1][weight_idx],               \
            src[1][(src_idx + step) / 4]);
        UNROLL_CALL_RAW(4, cb);
#undef cb
    }
};

template <int src_idx, int weight_idx, typename Func, int stride, typename T,
          typename T2, typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, stride, T, T2, T3, T4> {
    static void impl(T& c, T2& src, T3& weight) {
#define cb(step)                                                     \
    c[0][step * 2] = Func::template impl<(src_idx + step) % 4>(      \
            c[0][step * 2], weight[0][weight_idx],                   \
            src[0][(src_idx + step) / 4]);                           \
    c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>(  \
            c[0][step * 2 + 1], weight[0][weight_idx],               \
            src[1][(src_idx + step) / 4]);
        UNROLL_CALL_RAW(4, cb);
#undef cb
    }
};

template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
          typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 1, T, T2, T3, T4> {
    static void impl(T& c, T2& src, T3& weight) {
#define cb(step)                                                           \
    c[0][step] = Func::template impl<(src_idx + step) % 4>(                \
            c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]); \
    c[1][step] = Func::template impl<(src_idx + step) % 4>(                \
            c[1][step], weight[1][weight_idx], src[(src_idx + step) / 4]);
        UNROLL_CALL_RAW(8, cb);
#undef cb
    }
};

template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
          typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 1, T, T2, T3, T4> {
    static void impl(T& c, T2& src, T3& weight) {
#define cb(step)                                            \
    c[0][step] = Func::template impl<(src_idx + step) % 4>( \
            c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]);
        UNROLL_CALL_RAW(8, cb);
#undef cb
    }
};

template <int src_idx, int weight_idx, int c_dim, typename FUNC, int ow_block,
          int stride, typename T, typename T2, typename T3>
inline void cal_helper(T& c, T2& src, T3& weight) {
    ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, ow_block, stride, T, T2,
                   T3, int>::impl(c, src, weight);
};
```
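`cal_helper` is always instantiated with `Vdotq_laneq_s32` (wrapped in marm_neon.h, one of the files this commit touches), which maps to the ARMv8.2 dot-product instruction SDOT (by element). A scalar model of that semantics, for reference only — this is not the MegEngine wrapper itself:

```cpp
#include <cstdint>

// Scalar model of vdotq_laneq_s32(acc, a, b, lane): every int32 lane `out`
// of the accumulator gains the dot product of four int8 values of `a` with
// the `lane`-th group of four int8 values of `b`.
void sdot_lane_model(int32_t acc[4], const int8_t a[16], const int8_t b[16],
                     int lane) {
    for (int out = 0; out < 4; ++out) {
        int32_t sum = 0;
        for (int k = 0; k < 4; ++k)
            sum += int32_t(a[out * 4 + k]) * int32_t(b[lane * 4 + k]);
        acc[out] += sum;
    }
}
```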
```cpp
//! OCHelper is used to trans oc_block to row number of result regs
template <int oc>
struct OCHelper {
public:
    static const int val = -1;
};

template <>
struct OCHelper<4> {
public:
    static const int val = 1;
};
#if MEGDNN_AARCH64
template <>
struct OCHelper<8> {
public:
    static const int val = 2;
};
#endif

/**
 * oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel
 * */
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          int oc_block, int ow_block, int stride>
struct KerNeonDotXXs2Nchw44Int8 {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op);
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int ow_block, int stride>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
                                stride> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int filter_hight = 2;
        constexpr int filter_width = 4;
        constexpr int weight_reg = 1;
        constexpr int src_reg = 1;
        constexpr int oc_step = 4;
        constexpr int ic_step = 1;
        constexpr int pack_iw_len = 1;
        constexpr int simd_len = 16;

        const int ld_bias = oc_step;
        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];
        init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
            int8x16_t src[2][src_reg];
            int8x16_t weight[c_dim][weight_reg];
            // row 0
            load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
                    src, src_ptr + 0 * iw, stride);
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    weight, weight_ptr, ld_weight_oc);
            cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
                                                                       weight);
            // row 1
            load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
                    src, src_ptr + 1 * iw, stride);
            load_helper<weight_reg, 1 * simd_len, simd_len, c_dim, Vld1q_s8>(
                    weight, weight_ptr, ld_weight_oc);
            cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
                                                                       weight);

            src_ptr += ic_stride;
            weight_ptr += filter_hight * filter_width * oc_step;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int ow_block, int stride>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block,
                                stride> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int filter_hight = 3;
        constexpr int filter_width = 4;
        constexpr int weight_reg = 1;
        constexpr int src_reg = 1;
        constexpr int oc_step = 4;
        constexpr int ic_step = 1;
        constexpr int pack_iw_len = 1;
        constexpr int simd_len = 16;

        const int ld_bias = oc_step;
        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];
        init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
            int8x16_t src[2][src_reg];
            int8x16_t weight[c_dim][weight_reg];
            // row 0
            load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
                    src, src_ptr + 0 * iw, stride);
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    weight, weight_ptr, ld_weight_oc);
            cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
                                                                       weight);
            // row 1
            load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
                    src, src_ptr + 1 * iw, stride);
            load_helper<weight_reg, 1 * simd_len, simd_len, c_dim, Vld1q_s8>(
                    weight, weight_ptr, ld_weight_oc);
            cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
                                                                       weight);
            // row 2
            load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
                    src, src_ptr + 2 * iw, stride);
            load_helper<weight_reg, 2 * simd_len, simd_len, c_dim, Vld1q_s8>(
                    weight, weight_ptr, ld_weight_oc);
            cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
                                                                       weight);

            src_ptr += ic_stride;
            weight_ptr += filter_hight * filter_width * oc_step;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int ow_block, int stride>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block,
                                stride> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int filter_hight = 5;
        constexpr int filter_width = 8;
        constexpr int src_reg = 2;
        constexpr int weight_reg = 2;
        constexpr int oc_step = 4;
        constexpr int ic_step = 1;
        constexpr int pack_iw_len = 1;
        constexpr int simd_len = 16;

        const int ld_bias = oc_step;
        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];
        init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
            int8x16_t src[2][src_reg];
            int8x16_t weight[c_dim][weight_reg];
#define cb(step)                                                             \
    load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(src, src_ptr + step * iw, \
                                                   stride);                  \
    load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(                   \
            weight, weight_ptr + step * 2 * simd_len, ld_weight_oc);         \
    cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,       \
                                                               weight);      \
    cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);
            UNROLL_CALL_RAW(5, cb);
#undef cb
            src_ptr += ic_stride;
            weight_ptr += 5 * 32;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};

/**
 * oc = 8, ow = 8
 * dot 4 element, pad last filter and do twice dot every row filter, filter like
 * below
 * --------------------------
 * |x, x, x, x,| x, x, x, 0 |
 * --------------------------
 **/
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int ow_block, int stride>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block,
                                stride> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int filter_hight = 7;
        constexpr int filter_width = 8;
        constexpr int src_reg = 2;
        constexpr int weight_reg = 2;
        constexpr int oc_step = 4;
        constexpr int ic_step = 1;
        constexpr int pack_iw_len = 1;
        constexpr int simd_len = 16;

        const int ld_bias = oc_step;
        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];
        init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
            int8x16_t src[2][src_reg];
            int8x16_t weight[c_dim][weight_reg];
#define cb(step)                                                             \
    load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(src, src_ptr + step * iw, \
                                                   stride);                  \
    load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(                   \
            weight, weight_ptr + step * 2 * simd_len, ld_weight_oc);         \
    cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,       \
                                                               weight);      \
    cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);
            UNROLL_CALL_RAW(7, cb);
#undef cb
            src_ptr += ic_stride;
            weight_ptr += 7 * 32;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};

////////////////////stride 1///////////////////
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
                                1> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_hight = 2;
        constexpr int filter_width = 4;
        constexpr int weight_reg = 2;
        constexpr int src_reg = 2;
        constexpr int oc_step = 4;
        constexpr int ic_step = 1;
        constexpr int pack_iw_len = 4;
        constexpr int simd_len = 16;

        const int ld_bias = oc_step;
        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];
        init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
            int8x16_t src[src_reg];
            int8x16_t weight[c_dim][weight_reg];
            // row 0
            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, src_ptr + 0 * iw * pack_iw_len, 0);
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    weight, weight_ptr, ld_weight_oc);
            cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
                                                                       weight);
            // row 1
            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, src_ptr + 1 * iw * pack_iw_len, 0);
            cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
                                                                       weight);

            src_ptr += ic_stride;
            weight_ptr += filter_hight * filter_width * oc_step;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block,
                                1> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_hight = 3;
        constexpr int filter_width = 4;
        constexpr int weight_reg = 3;
        constexpr int src_reg = 2;
        constexpr int oc_step = 4;
        constexpr int ic_step = 1;
        constexpr int pack_iw_len = 4;
        constexpr int simd_len = 16;

        const int ld_bias = oc_step;
        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];
        init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
            int8x16_t src[src_reg];
            int8x16_t weight[c_dim][weight_reg];
            // row 0
            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, src_ptr + 0 * iw * pack_iw_len, 0);
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    weight, weight_ptr, ld_weight_oc);
            cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
                                                                       weight);
            // row 1
            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, src_ptr + 1 * iw * pack_iw_len, 0);
            cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
                                                                       weight);
            // row 2
            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, src_ptr + 2 * iw * pack_iw_len, 0);
            cal_helper<0, 2, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
                                                                       weight);

            src_ptr += ic_stride;
            weight_ptr += filter_hight * filter_width * oc_step;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block,
                                1> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_hight = 5;
        constexpr int filter_width = 8;
        constexpr int src_reg = 3;
        constexpr int weight_reg = 2;
        constexpr int oc_step = 4;
        constexpr int ic_step = 1;
        constexpr int pack_iw_len = 4;
        constexpr int simd_len = 16;

        const int ld_bias = oc_step;
        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];
        init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
            int8x16_t src[src_reg];
            int8x16_t weight[c_dim][weight_reg];
#define cb(step)                                                        \
    load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(                     \
            src, src_ptr + step * iw * pack_iw_len, 0);                 \
    load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(              \
            weight, weight_ptr + step * 2 * simd_len, ld_weight_oc);    \
    cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,  \
                                                               weight); \
    cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);
            UNROLL_CALL_RAW(5, cb);
#undef cb
            src_ptr += ic_stride;
            weight_ptr += filter_hight * filter_width * oc_step;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block,
                                1> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_hight = 7;
        constexpr int filter_width = 8;
        constexpr int src_reg = 3;
        constexpr int weight_reg = 2;
        constexpr int oc_step = 4;
        constexpr int ic_step = 1;
        constexpr int pack_iw_len = 4;
        constexpr int simd_len = 16;

        const int ld_bias = oc_step;
        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];
        init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
            int8x16_t src[src_reg];
            int8x16_t weight[c_dim][weight_reg];
#define cb(step)                                                        \
    load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(                     \
            src, src_ptr + step * iw * pack_iw_len, 0);                 \
    load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(              \
            weight, weight_ptr + step * 2 * simd_len, ld_weight_oc);    \
    cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,  \
                                                               weight); \
    cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);
            UNROLL_CALL_RAW(7, cb);
#undef cb
            src_ptr += ic_stride;
            weight_ptr += filter_hight * filter_width * oc_step;
        }
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};

template <int stride>
void pack_src_int8_nchw_nchw44_dot(int8_t* sptr_base, const int8_t* sptr_origin,
                                   const int, const int pw, const int,
                                   const int ih, const int iw, const int iw2,
                                   const int pad_top, const int pad_bottom,
                                   const int ic, const int ic_stride, int8_t*) {
    constexpr int ic_step = 1;
    rep_step(ic_idx, ic, ic_step) {
        const int8_t* sptr = sptr_origin + ic_idx * ic_stride;
        memset(sptr_base, 0,
               sizeof(int8_t) * ic_step * iw2 * (ih + pad_top + pad_bottom));
        sptr_base += iw2 * pad_top * ic_step;
        rep(ih_idx, ih) {
            memcpy(sptr_base + pw * ic_step, sptr,
                   sizeof(int8_t) * iw * ic_step);
            sptr_base += iw2 * ic_step;
            sptr += iw * ic_step;
        }
        sptr_base += iw2 * pad_bottom * ic_step;
    }
}

template <>
void pack_src_int8_nchw_nchw44_dot<1>(
        int8_t* sptr_base, const int8_t* sptr_origin, const int, const int pw,
        const int, const int ih, const int iw, const int iw2, const int pad_top,
        const int pad_bottom, const int ic, const int ic_stride,
        int8_t* temp_ptr) {
    static uint8_t reorder_idx[16] = {0, 1, 2, 3, 1, 2, 3, 4,
                                      2, 3, 4, 5, 3, 4, 5, 6};
    uint8x16_t tbl_idx = vld1q_u8(&reorder_idx[0]);
    constexpr int iw_step = 16;
    constexpr int pack_iw_len = 4;
    const int iw_with_pad = iw + 2 * pw;
    const int iw_with_pad_end = iw_with_pad / iw_step * iw_step;
    rep(ic_idx, ic) {
        const int8_t* sptr = sptr_origin + ic_idx * ic_stride;
        memset(sptr_base, 0,
               sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) *
                       pack_iw_len);
        sptr_base += iw2 * pad_top * pack_iw_len;
        rep(ih_idx, ih) {
            memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t));
            memcpy(temp_ptr + pw, sptr, sizeof(int8_t) * iw);
            for (int iw_idx = 0; iw_idx < iw_with_pad_end; iw_idx += iw_step) {
                int8x16_t src[4];
                int8x16_t dst[4];
                src[0] = vld1q_s8(temp_ptr + iw_idx);
                src[1] = vld1q_s8(temp_ptr + iw_idx + 4);
                src[2] = vld1q_s8(temp_ptr + iw_idx + 8);
                src[3] = vld1q_s8(temp_ptr + iw_idx + 12);
                dst[0] = vqtbl1q_s8(src[0], tbl_idx);
                dst[1] = vqtbl1q_s8(src[1], tbl_idx);
                dst[2] = vqtbl1q_s8(src[2], tbl_idx);
                dst[3] = vqtbl1q_s8(src[3], tbl_idx);
                vst1q_s8(sptr_base + iw_idx * pack_iw_len + 0, dst[0]);
                vst1q_s8(sptr_base + iw_idx * pack_iw_len + 16, dst[1]);
                vst1q_s8(sptr_base + iw_idx * pack_iw_len + 32, dst[2]);
                vst1q_s8(sptr_base + iw_idx * pack_iw_len + 48, dst[3]);
            }
            for (int iw_idx = iw_with_pad_end; iw_idx < iw_with_pad; ++iw_idx) {
                *(sptr_base + iw_idx * pack_iw_len + 0) =
                        *(temp_ptr + iw_idx + 0);
                *(sptr_base + iw_idx * pack_iw_len + 1) =
                        *(temp_ptr + iw_idx + 1);
                *(sptr_base + iw_idx * pack_iw_len + 2) =
                        *(temp_ptr + iw_idx + 2);
                *(sptr_base + iw_idx * pack_iw_len + 3) =
                        *(temp_ptr + iw_idx + 3);
            }
            sptr_base += iw2 * pack_iw_len;
            sptr += iw;
        }
        sptr_base += iw2 * pad_bottom * pack_iw_len;
    }
}
```
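The `reorder_idx` table above realizes the stride-1 `pack_iw_len = 4` layout: every padded input column is widened into the 4-byte window starting at that column, so one dot-product lane consumes four neighbouring pixels. A scalar reference of the repack (a sketch only; the real buffers keep a cache-line border after the row so the trailing window reads stay in bounds):

```cpp
#include <cstdint>

// Scalar model of the vqtbl1q_s8 repack: dst holds, for each column i of the
// padded row, the window {row[i], row[i+1], row[i+2], row[i+3]}.
void pack_row_stride1_scalar(int8_t* dst, const int8_t* padded_row,
                             int iw_with_pad) {
    for (int i = 0; i < iw_with_pad; ++i)
        for (int k = 0; k < 4; ++k)
            dst[i * 4 + k] = padded_row[i + k];  // reads up to 3 bytes past i
}
```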
```cpp
static inline void pack_weight_int8_nchw_nchw44_dot(int8_t* dst_ptr,
                                                    const int8_t* src_ptr,
                                                    const int oc, const int ic,
                                                    const int fh, const int fw,
                                                    const int fw2) {
    constexpr int oc_step = 4;
    const int fw_remain = fw2 - fw;
    const int dst_ic_stride = fh * fw2;
    const int oc_step_stride = fh * fw2 * ic * oc_step;
    static const uint8_t transpose_4x4_idx[16] = {0, 4, 8,  12, 1, 5, 9,  13,
                                                  2, 6, 10, 14, 3, 7, 11, 15};
    uint8x16_t tbl_transpose_4x4 = vld1q_u8(&transpose_4x4_idx[0]);
    rep_step(oc_idx, oc, oc_step) {
        int32_t* dst_temp_ptr =
                reinterpret_cast<int32_t*>(dst_ptr + oc_idx * ic * fh * fw2);
        const int32_t* src_temp_ptr = reinterpret_cast<const int32_t*>(
                src_ptr + oc_idx * ic * fh * fw);
        // transpose ic and pad
        rep(fh_idx, fh) {
            rep(fw_idx, fw) {
                rep(ic_idx, ic) {
                    *(dst_temp_ptr + ic_idx * dst_ic_stride) = *src_temp_ptr;
                    src_temp_ptr++;
                }
                dst_temp_ptr++;
            }
            rep(ic_idx, ic) {
                memset(dst_temp_ptr + ic_idx * dst_ic_stride, 0,
                       sizeof(int8_t) * oc_step * fw_remain);
            }
            dst_temp_ptr += fw_remain;
        }
        // transpose fw oc
        int8_t* trans_dst_temp_ptr =
                reinterpret_cast<int8_t*>(dst_ptr + oc_idx * ic * fh * fw2);
        rep_step(idx, oc_step_stride, 16) {
            int8x16_t temp = vld1q_s8(trans_dst_temp_ptr + idx);
            vst1q_s8(trans_dst_temp_ptr + idx,
                     vqtbl1q_s8(temp, tbl_transpose_4x4));
        }
    }
}
```
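Two things happen above: each filter row is zero-padded from fw to fw2 = round_up(fw, 4), so every row splits into whole 4-tap dot groups (the `|x, x, x, x,| x, x, x, 0|` picture earlier), and a final vqtbl1q_s8 pass reshuffles each 16-byte tile. A scalar model of just that shuffle, under the assumption that the tile is viewed as a 4x4 matrix of bytes (sketch only; the resulting interleave depends on what the preceding ic-transpose left in the tile):

```cpp
#include <cstdint>

// Scalar model of the shuffle with transpose_4x4_idx = {0, 4, 8, 12, ...}:
// out[j] = in[idx[j]], i.e. a 4x4 int8 tile is transposed in place.
void transpose_4x4_scalar(int8_t tile[16]) {
    int8_t tmp[16];
    for (int row = 0; row < 4; ++row)
        for (int col = 0; col < 4; ++col)
            tmp[row * 4 + col] = tile[col * 4 + row];
    for (int i = 0; i < 16; ++i) tile[i] = tmp[i];
}
```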
```cpp
template <BiasMode bias_mode, typename Op, int filter_size, int stride>
static void conv_direct_int8_nchw_nchw44_dot(
        const int8_t* src, const int8_t* filter, const int32_t* bias,
        int32_t* temp, int8_t* dst, const int oc, const int ic, const int ih,
        const int iw, const int oh, const int oh_block, const int ow,
        const Op& op) {
    MEGDNN_MARK_USED_VAR(temp);
    constexpr int fh = filter_size;
    constexpr int fw = (filter_size + 3) / 4 * 4;
#if MEGDNN_AARCH64
    constexpr int big_oc_step = 8;
#else
    constexpr int big_oc_step = 4;
#endif
    constexpr int oc_step = 4;
    constexpr int ih_step = 1;
    constexpr int oh_step = 1;
    constexpr int ow_step = 8;
    constexpr int stride_h = stride;
    constexpr int stride_w = stride;
    constexpr int pack_iw_len = stride == 2 ? 1 : 4;

    const int img_stride = oh * ow;
    const int ow_end = ow / ow_step * ow_step;
    const int ow_remain = ow - ow_end;
    const int oc_end = oc / big_oc_step * big_oc_step;
    const int oc_remain = oc - oc_end;
    const int ld_dst_oc = oc_step * img_stride;

    using remain_fun = std::function<void(
            const int8_t* src_ptr, const int8_t* weight_ptr,
            const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, int iw,
            int ld_dst_oc, const Op& op)>;
    remain_fun kern_big_oc_remain = nullptr;
    remain_fun kern_small_oc_remain = nullptr;
    switch (ow_remain) {
#define cb(step)                                                          \
    case step:                                                            \
        kern_big_oc_remain =                                              \
                KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \
                                         big_oc_step, ow_step, stride>::impl; \
        kern_small_oc_remain =                                            \
                KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \
                                         oc_step, ow_step, stride>::impl; \
        break;
        UNROLL_CALL_RAW(8, cb);
        default:
            megdnn_assert(0, "no remain %d for kern", ow_remain);
    }
    for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
        const int weight_offset = oc_idx * ic * fh * fw;
        for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
            for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const int src_offset =
                        (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
                        pack_iw_len;
                const int dst_offset = oc_idx * img_stride +
                                       (oh_idx * ow + ow_idx) * oc_step;
                KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size,
                                         big_oc_step, ow_step, stride>::
                        impl(src + src_offset, filter + weight_offset,
                             bias + oc_idx, dst + dst_offset, ic, ih, iw,
                             ld_dst_oc, op);
            }
            if (ow_remain > 0) {
                const int src_offset =
                        (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
                        pack_iw_len;
                const int dst_offset = oc_idx * img_stride +
                                       (oh_idx * ow + ow_end) * oc_step;
                kern_big_oc_remain(src + src_offset, filter + weight_offset,
                                   bias + oc_idx, dst + dst_offset, ic, ih, iw,
                                   ld_dst_oc, op);
            }
        }
    }
    if (oc_remain > 0) {
        int oc_idx = oc_end;
        const int weight_offset = oc_idx * ic * fh * fw;
        for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
            for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const int src_offset =
                        (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
                        pack_iw_len;
                const int dst_offset = oc_idx * img_stride +
                                       (oh_idx * ow + ow_idx) * oc_step;
                KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size,
                                         oc_step, ow_step, stride>::
                        impl(src + src_offset, filter + weight_offset,
                             bias + oc_idx, dst + dst_offset, ic, ih, iw,
                             ld_dst_oc, op);
            }
            if (ow_remain > 0) {
                const int src_offset =
                        (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
                        pack_iw_len;
                const int dst_offset = oc_idx * img_stride +
                                       (oh_idx * ow + ow_end) * oc_step;
                kern_small_oc_remain(src + src_offset, filter + weight_offset,
                                     bias + oc_idx, dst + dst_offset, ic, ih,
                                     iw, ld_dst_oc, op);
            }
        }
    }
}
}  // namespace
#endif
// vim: syntax=cpp.doxygen
```
dnn/src/arm_common/conv_bias/intrinsic_helper.h
```diff
@@ -176,187 +176,202 @@ inline void store_ocx_ow4_remain_static(T& c, const Op& op, int8_t* dst_ptr,
     StoreOcxOw4Remain<c_dim, ow_remain, Op, T>::impl(c, op, dst_ptr, ld_dst_oc);
 }
 ////////////////////Store_OCX_OW8_Remain/////////////////////////
-template <int c_dim, int ow_remain, typename Op, typename T>
+template <int c_dim, int ow_remain, typename Op, typename T, typename T2,
+          typename T3>
 struct StoreOcxOw8Remain {
-    static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc);
+    static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc);
 };
-template <typename Op, typename T>
-struct StoreOcxOw8Remain<2, 0, Op, T> {
-    static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
-        op({{c[0][0], c[0][1]}}, dst_ptr);
-        op({{c[0][2], c[0][3]}}, dst_ptr + 8);
-        op({{c[0][4], c[0][5]}}, dst_ptr + 16);
-        op({{c[0][6], c[0][7]}}, dst_ptr + 24);
-        op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
-        op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8);
-        op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16);
-        op({{c[1][6], c[1][7]}}, dst_ptr + ld_dst_oc + 24);
+template <typename Op, typename T, typename T2, typename T3>
+struct StoreOcxOw8Remain<2, 0, Op, T, T2, T3> {
+    static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
+        op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
+        op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
+        op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
+        op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24));
+        op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
+        op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
+        op({{c[1][4], c[1][5]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
+        op({{c[1][6], c[1][7]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24));
     }
 };
-template <typename Op, typename T>
-struct StoreOcxOw8Remain<2, 8, Op, T> {
-    static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
-        op({{c[0][0], c[0][1]}}, dst_ptr);
-        op({{c[0][2], c[0][3]}}, dst_ptr + 8);
-        op({{c[0][4], c[0][5]}}, dst_ptr + 16);
-        op({{c[0][6], c[0][7]}}, dst_ptr + 24);
-        op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
-        op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8);
-        op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16);
-        op({{c[1][6], c[1][7]}}, dst_ptr + ld_dst_oc + 24);
+template <typename Op, typename T, typename T2, typename T3>
+struct StoreOcxOw8Remain<2, 8, Op, T, T2, T3> {
+    static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
+        op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
+        op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
+        op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
+        op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24));
+        op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
+        op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
+        op({{c[1][4], c[1][5]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
+        op({{c[1][6], c[1][7]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24));
     }
 };
-template <typename Op, typename T>
-struct StoreOcxOw8Remain<2, 7, Op, T> {
-    static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
-        op({{c[0][0], c[0][1]}}, dst_ptr);
-        op({{c[0][2], c[0][3]}}, dst_ptr + 8);
-        op({{c[0][4], c[0][5]}}, dst_ptr + 16);
-        op(c[0][6], dst_ptr + 24);
-        op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
-        op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8);
-        op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16);
-        op(c[1][6], dst_ptr + ld_dst_oc + 24);
+template <typename Op, typename T, typename T2, typename T3>
+struct StoreOcxOw8Remain<2, 7, Op, T, T2, T3> {
+    static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
+        op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
+        op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
+        op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
+        op(c[0][6], reinterpret_cast<T3>(dst_ptr + 24));
+        op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
+        op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
+        op({{c[1][4], c[1][5]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
+        op(c[1][6], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24));
     }
 };
-template <typename Op, typename T>
-struct StoreOcxOw8Remain<2, 6, Op, T> {
-    static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
-        op({{c[0][0], c[0][1]}}, dst_ptr);
-        op({{c[0][2], c[0][3]}}, dst_ptr + 8);
-        op({{c[0][4], c[0][5]}}, dst_ptr + 16);
-        op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
-        op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8);
-        op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16);
+template <typename Op, typename T, typename T2, typename T3>
+struct StoreOcxOw8Remain<2, 6, Op, T, T2, T3> {
+    static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
+        op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
+        op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
+        op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
+        op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
+        op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
+        op({{c[1][4], c[1][5]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
    }
 };
-template <typename Op, typename T>
-struct StoreOcxOw8Remain<2, 5, Op, T> {
-    static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
-        op({{c[0][0], c[0][1]}}, dst_ptr);
-        op({{c[0][2], c[0][3]}}, dst_ptr + 8);
-        op(c[0][4], dst_ptr + 16);
-        op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
-        op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8);
-        op(c[1][4], dst_ptr + ld_dst_oc + 16);
+template <typename Op, typename T, typename T2, typename T3>
+struct StoreOcxOw8Remain<2, 5, Op, T, T2, T3> {
+    static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
+        op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
+        op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
+        op(c[0][4], reinterpret_cast<T3>(dst_ptr + 16));
+        op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
+        op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
+        op(c[1][4], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
     }
 };
-template <typename Op, typename T>
-struct StoreOcxOw8Remain<2, 4, Op, T> {
-    static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
-        op({{c[0][0], c[0][1]}}, dst_ptr);
-        op({{c[0][2], c[0][3]}}, dst_ptr + 8);
-        op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
-        op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8);
+template <typename Op, typename T, typename T2, typename T3>
+struct StoreOcxOw8Remain<2, 4, Op, T, T2, T3> {
+    static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
+        op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
+        op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
+        op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
+        op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
     }
 };
-template <typename Op, typename T>
-struct StoreOcxOw8Remain<2, 3, Op, T> {
-    static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
-        op({{c[0][0], c[0][1]}}, dst_ptr);
-        op(c[0][2], dst_ptr + 8);
-        op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
-        op(c[1][2], dst_ptr + ld_dst_oc + 8);
+template <typename Op, typename T, typename T2, typename T3>
+struct StoreOcxOw8Remain<2, 3, Op, T, T2, T3> {
+    static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
+        op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
+        op(c[0][2], reinterpret_cast<T3>(dst_ptr + 8));
+        op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
+        op(c[1][2], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
     }
 };
-template <typename Op, typename T>
-struct StoreOcxOw8Remain<2, 2, Op, T> {
-    static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
-        op({{c[0][0], c[0][1]}}, dst_ptr);
-        op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc);
+template <typename Op, typename T, typename T2, typename T3>
+struct StoreOcxOw8Remain<2, 2, Op, T, T2, T3> {
+    static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
+        op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
+        op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
     }
 };
-template <typename Op, typename T>
-struct StoreOcxOw8Remain<2, 1, Op, T> {
-    static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) {
-        op(c[0][0], dst_ptr);
-        op(c[1][0], dst_ptr + ld_dst_oc);
+template <typename Op, typename T, typename T2, typename T3>
+struct StoreOcxOw8Remain<2, 1, Op, T, T2, T3> {
+    static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
+        op(c[0][0], reinterpret_cast<T3>(dst_ptr));
+        op(c[1][0], reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
     }
 };
-template <typename Op, typename T>
-struct StoreOcxOw8Remain<1, 0, Op, T> {
-    static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
-        op({{c[0][0], c[0][1]}}, dst_ptr);
-        op({{c[0][2], c[0][3]}}, dst_ptr + 8);
-        op({{c[0][4], c[0][5]}}, dst_ptr + 16);
-        op({{c[0][6], c[0][7]}}, dst_ptr + 24);
+template <typename Op, typename T, typename T2, typename T3>
+struct StoreOcxOw8Remain<1, 0, Op, T, T2, T3> {
+    static void impl(T& c, const Op& op, T2 dst_ptr, int) {
+        op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
+        op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
+        op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
+        op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24));
     }
 };
-template <typename Op, typename T>
-struct StoreOcxOw8Remain<1, 8, Op, T> {
-    static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
-        op({{c[0][0], c[0][1]}}, dst_ptr);
-        op({{c[0][2], c[0][3]}}, dst_ptr + 8);
-        op({{c[0][4], c[0][5]}}, dst_ptr + 16);
-        op({{c[0][6], c[0][7]}}, dst_ptr + 24);
+template <typename Op, typename T, typename T2, typename T3>
+struct StoreOcxOw8Remain<1, 8, Op, T, T2, T3> {
+    static void impl(T& c, const Op& op, T2 dst_ptr, int) {
+        op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
+        op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
+        op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
+        op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24));
     }
 };
-template <typename Op, typename T>
-struct StoreOcxOw8Remain<1, 7, Op, T> {
-    static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
-        op({{c[0][0], c[0][1]}}, dst_ptr);
-        op({{c[0][2], c[0][3]}}, dst_ptr + 8);
-        op({{c[0][4], c[0][5]}}, dst_ptr + 16);
-        op(c[0][6], dst_ptr + 24);
+template <typename Op, typename T, typename T2, typename T3>
+struct StoreOcxOw8Remain<1, 7, Op, T, T2, T3> {
+    static void impl(T& c, const Op& op, T2 dst_ptr, int) {
+        op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
+        op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
+        op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
+        op(c[0][6], reinterpret_cast<T3>(dst_ptr + 24));
     }
 };
-template <typename Op, typename T>
-struct StoreOcxOw8Remain<1, 6, Op, T> {
-    static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
-        op({{c[0][0], c[0][1]}}, dst_ptr);
-        op({{c[0][2], c[0][3]}}, dst_ptr + 8);
-        op({{c[0][4], c[0][5]}}, dst_ptr + 16);
+template <typename Op, typename T, typename T2, typename T3>
+struct StoreOcxOw8Remain<1, 6, Op, T, T2, T3> {
+    static void impl(T& c, const Op& op, T2 dst_ptr, int) {
+        op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
+        op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
+        op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
    }
 };
-template <typename Op, typename T>
-struct StoreOcxOw8Remain<1, 5, Op, T> {
-    static void impl(T& c, const Op& op, float32_t* dst_ptr, int) {
-        op({{c[0][0], c[0][1]}}, dst_ptr);
-        op({{c[0][2], c[0][3]}}, dst_ptr + 8);
-        op(c[0][
```
4
],
dst_ptr
+
16
);
template
<
typename
Op
,
typename
T
,
typename
T2
,
typename
T3
>
struct
StoreOcxOw8Remain
<
1
,
5
,
Op
,
T
,
T2
,
T3
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
T2
dst_ptr
,
int
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
reinterpret_cast
<
T3
>
(
dst_ptr
)
);
op
({{
c
[
0
][
2
],
c
[
0
][
3
]}},
reinterpret_cast
<
T3
>
(
dst_ptr
+
8
)
);
op
(
c
[
0
][
4
],
reinterpret_cast
<
T3
>
(
dst_ptr
+
16
)
);
}
};
template
<
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
<
1
,
4
,
Op
,
T
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
dst_ptr
);
op
({{
c
[
0
][
2
],
c
[
0
][
3
]}},
dst_ptr
+
8
);
template
<
typename
Op
,
typename
T
,
typename
T2
,
typename
T3
>
struct
StoreOcxOw8Remain
<
1
,
4
,
Op
,
T
,
T2
,
T3
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
T2
dst_ptr
,
int
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
reinterpret_cast
<
T3
>
(
dst_ptr
)
);
op
({{
c
[
0
][
2
],
c
[
0
][
3
]}},
reinterpret_cast
<
T3
>
(
dst_ptr
+
8
)
);
}
};
template
<
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
<
1
,
3
,
Op
,
T
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
dst_ptr
);
op
(
c
[
0
][
2
],
dst_ptr
+
8
);
template
<
typename
Op
,
typename
T
,
typename
T2
,
typename
T3
>
struct
StoreOcxOw8Remain
<
1
,
3
,
Op
,
T
,
T2
,
T3
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
T2
dst_ptr
,
int
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
reinterpret_cast
<
T3
>
(
dst_ptr
)
);
op
(
c
[
0
][
2
],
reinterpret_cast
<
T3
>
(
dst_ptr
+
8
)
);
}
};
template
<
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
<
1
,
2
,
Op
,
T
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
dst_ptr
);
template
<
typename
Op
,
typename
T
,
typename
T2
,
typename
T3
>
struct
StoreOcxOw8Remain
<
1
,
2
,
Op
,
T
,
T2
,
T3
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
T2
dst_ptr
,
int
)
{
op
({{
c
[
0
][
0
],
c
[
0
][
1
]}},
reinterpret_cast
<
T3
>
(
dst_ptr
)
);
}
};
template
<
typename
Op
,
typename
T
>
struct
StoreOcxOw8Remain
<
1
,
1
,
Op
,
T
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
int
)
{
op
(
c
[
0
][
0
],
dst_ptr
);
template
<
typename
Op
,
typename
T
,
typename
T2
,
typename
T3
>
struct
StoreOcxOw8Remain
<
1
,
1
,
Op
,
T
,
T2
,
T3
>
{
static
void
impl
(
T
&
c
,
const
Op
&
op
,
T2
dst_ptr
,
int
)
{
op
(
c
[
0
][
0
],
reinterpret_cast
<
T3
>
(
dst_ptr
)
);
}
};
template
<
int
c_dim
,
int
ow_remain
,
typename
Op
,
typename
T
>
inline
void
store_ocx_ow8_remain_static
(
T
&
c
,
const
Op
&
op
,
float32_t
*
dst_ptr
,
template
<
int
c_dim
,
int
ow_remain
,
typename
Op
,
typename
T
,
typename
T2
>
inline
void
store_ocx_ow8_remain_static
(
T
&
c
,
const
Op
&
op
,
T2
dst_ptr
,
int
ld_dst_oc
)
{
StoreOcxOw8Remain
<
c_dim
,
ow_remain
,
Op
,
T
>::
impl
(
c
,
op
,
dst_ptr
,
ld_dst_oc
);
StoreOcxOw8Remain
<
c_dim
,
ow_remain
,
Op
,
T
,
T2
,
T2
>::
impl
(
c
,
op
,
dst_ptr
,
ld_dst_oc
);
}
template
<
int
c_dim
,
int
ow_remain
,
typename
Op
,
typename
T3
,
typename
T
,
typename
T2
>
inline
void
store_ocx_ow8_remain_static_dt
(
T
&
c
,
const
Op
&
op
,
T2
dst_ptr
,
int
ld_dst_oc
)
{
StoreOcxOw8Remain
<
c_dim
,
ow_remain
,
Op
,
T
,
T2
,
T3
>::
impl
(
c
,
op
,
dst_ptr
,
ld_dst_oc
);
}
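The point of the extra `T2`/`T3` template parameters shows up in `store_ocx_ow8_remain_static_dt`: the accumulator type `T`, the pointer type the kernel carries (`T2`) and the pointer type the output op actually writes through (`T3`) are decoupled, so the same remainder-store helpers can serve both the fp32 kernels and the new dot-product int8 kernels. A minimal, NEON-free sketch of the same dispatch pattern (all names below are illustrative, not MegEngine API):

#include <cstdint>
#include <cstdio>

// ow_remain plays the same role as in StoreOcxOw8Remain: how many output
// elements are left to store at the row tail.
template <int ow_remain, typename Op, typename T, typename T2, typename T3>
struct StoreRemain {
    static void impl(T& c, const Op& op, T2 dst_ptr) {
        for (int i = 0; i < ow_remain; ++i) {
            // accumulate in T, store through T3; T2 is whatever pointer
            // type the caller happens to hold
            op(c[i], reinterpret_cast<T3>(dst_ptr) + i);
        }
    }
};

// stand-in for a quantized output op: scale and narrow to int8
struct QuantizeOp {
    float scale;
    void operator()(int32_t acc, int8_t* dst) const {
        *dst = static_cast<int8_t>(acc * scale);  // saturation omitted
    }
};

int main() {
    int32_t acc[4] = {10, 20, 30, 40};
    int8_t out[4];
    StoreRemain<4, QuantizeOp, int32_t[4], void*, int8_t*>::impl(
            acc, QuantizeOp{0.5f}, out);
    for (int v : out) printf("%d ", v);  // prints: 5 10 15 20
    return 0;
}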
////////////////////Store_OC8_OW8_Remain/////////////////////////
...
...
@@ -522,68 +537,84 @@ inline void init_oc8_ow8(int32x4_t c[2][8], const int32_t* bias_ptr,
}
}
/////////////////////////init_ocx_ow8////////////////////
+inline float32x4_t neon_vdupq_n(float val) {
+    return vdupq_n_f32(val);
+}
+inline int32x4_t neon_vdupq_n(int val) {
+    return vdupq_n_s32(val);
+}
+inline float32x4_t neon_vld1q(const float* ptr) {
+    return vld1q_f32(ptr);
+}
+inline int32x4_t neon_vld1q(const int* ptr) {
+    return vld1q_s32(ptr);
+}
 template <int c_dim, BiasMode bias_mode, int ow_block, typename T, typename T2>
 struct InitOcxOw8 {
-    static void impl(T& c, T2 bias_ptr, int oc_step);
+    static void impl(T& c, const T2* bias_ptr, int oc_step);
 };
 template <typename T, typename T2>
 struct InitOcxOw8<2, BiasMode::NO_BIAS, 8, T, T2> {
-    static void impl(T& c, const float32_t*, int) {
-#define BAIS_INIT(step)          \
-    c[0][step] = vdupq_n_f32(0); \
-    c[1][step] = vdupq_n_f32(0);
+    static void impl(T& c, const T2*, int) {
+#define BAIS_INIT(step)                            \
+    c[0][step] = neon_vdupq_n(static_cast<T2>(0)); \
+    c[1][step] = neon_vdupq_n(static_cast<T2>(0));
         UNROLL_CALL_RAW(8, BAIS_INIT);
 #undef BAIS_INIT
     }
 };
 template <typename T, typename T2>
 struct InitOcxOw8<2, BiasMode::NO_BIAS, 4, T, T2> {
-    static void impl(T& c, const float32_t*, int) {
-#define BAIS_INIT(step)          \
-    c[0][step] = vdupq_n_f32(0); \
-    c[1][step] = vdupq_n_f32(0);
+    static void impl(T& c, const T2*, int) {
+#define BAIS_INIT(step)                            \
+    c[0][step] = neon_vdupq_n(static_cast<T2>(0)); \
+    c[1][step] = neon_vdupq_n(static_cast<T2>(0));
         UNROLL_CALL_RAW(4, BAIS_INIT);
 #undef BAIS_INIT
     }
 };
 template <typename T, typename T2>
 struct InitOcxOw8<2, BiasMode::BROADCAST_CHANNEL_BIAS, 8, T, T2> {
-    static void impl(T& c, const float32_t* bias_ptr, int oc_step) {
-#define BAIS_INIT(step)               \
-    c[0][step] = vld1q_f32(bias_ptr); \
-    c[1][step] = vld1q_f32(bias_ptr + oc_step);
+    static void impl(T& c, const T2* bias_ptr, int oc_step) {
+#define BAIS_INIT(step)                \
+    c[0][step] = neon_vld1q(bias_ptr); \
+    c[1][step] = neon_vld1q(bias_ptr + oc_step);
         UNROLL_CALL_RAW(8, BAIS_INIT);
 #undef BAIS_INIT
     }
 };
 template <typename T, typename T2>
 struct InitOcxOw8<2, BiasMode::BROADCAST_CHANNEL_BIAS, 4, T, T2> {
-    static void impl(T& c, const float32_t* bias_ptr, int oc_step) {
-#define BAIS_INIT(step)               \
-    c[0][step] = vld1q_f32(bias_ptr); \
-    c[1][step] = vld1q_f32(bias_ptr + oc_step);
+    static void impl(T& c, const T2* bias_ptr, int oc_step) {
+#define BAIS_INIT(step)                \
+    c[0][step] = neon_vld1q(bias_ptr); \
+    c[1][step] = neon_vld1q(bias_ptr + oc_step);
         UNROLL_CALL_RAW(4, BAIS_INIT);
 #undef BAIS_INIT
     }
 };
 template <typename T, typename T2>
 struct InitOcxOw8<2, BiasMode::BIAS, 8, T, T2> {
-    static void impl(T& c, const float32_t* bias_ptr, int oc_step) {
+    static void impl(T& c, const T2* bias_ptr, int oc_step) {
         constexpr int simd_len = 4;
-#define BAIS_INIT(step)                                 \
-    c[0][step] = vld1q_f32(bias_ptr + step * simd_len); \
-    c[1][step] = vld1q_f32(bias_ptr + oc_step + step * simd_len);
+#define BAIS_INIT(step)                                  \
+    c[0][step] = neon_vld1q(bias_ptr + step * simd_len); \
+    c[1][step] = neon_vld1q(bias_ptr + oc_step + step * simd_len);
         UNROLL_CALL_RAW(8, BAIS_INIT);
 #undef BAIS_INIT
     }
 };
 template <typename T, typename T2>
 struct InitOcxOw8<2, BiasMode::BIAS, 4, T, T2> {
-    static void impl(T& c, const float32_t* bias_ptr, int oc_step) {
+    static void impl(T& c, const T2* bias_ptr, int oc_step) {
         constexpr int simd_len = 4;
-#define BAIS_INIT(step)                                 \
-    c[0][step] = vld1q_f32(bias_ptr + step * simd_len); \
-    c[1][step] = vld1q_f32(bias_ptr + oc_step + step * simd_len);
+#define BAIS_INIT(step)                                  \
+    c[0][step] = neon_vld1q(bias_ptr + step * simd_len); \
+    c[1][step] = neon_vld1q(bias_ptr + oc_step + step * simd_len);
         UNROLL_CALL_RAW(4, BAIS_INIT);
 #undef BAIS_INIT
     }
...
...
@@ -591,57 +622,57 @@ struct InitOcxOw8<2, BiasMode::BIAS, 4, T, T2> {
 template <typename T, typename T2>
 struct InitOcxOw8<1, BiasMode::NO_BIAS, 8, T, T2> {
-    static void impl(T& c, const float32_t*, int) {
-#define BAIS_INIT(step) c[0][step] = vdupq_n_f32(0);
+    static void impl(T& c, const T2*, int) {
+#define BAIS_INIT(step) c[0][step] = neon_vdupq_n(static_cast<T2>(0));
         UNROLL_CALL_RAW(8, BAIS_INIT);
 #undef BAIS_INIT
     }
 };
 template <typename T, typename T2>
 struct InitOcxOw8<1, BiasMode::NO_BIAS, 4, T, T2> {
-    static void impl(T& c, const float32_t*, int) {
-#define BAIS_INIT(step) c[0][step] = vdupq_n_f32(0);
+    static void impl(T& c, const T2*, int) {
+#define BAIS_INIT(step) c[0][step] = neon_vdupq_n(static_cast<T2>(0));
         UNROLL_CALL_RAW(4, BAIS_INIT);
 #undef BAIS_INIT
     }
 };
 template <typename T, typename T2>
 struct InitOcxOw8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 8, T, T2> {
-    static void impl(T& c, const float32_t* bias_ptr, int) {
-#define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr);
+    static void impl(T& c, const T2* bias_ptr, int) {
+#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr);
         UNROLL_CALL_RAW(8, BAIS_INIT);
 #undef BAIS_INIT
     }
 };
 template <typename T, typename T2>
 struct InitOcxOw8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 4, T, T2> {
-    static void impl(T& c, const float32_t* bias_ptr, int) {
-#define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr);
+    static void impl(T& c, const T2* bias_ptr, int) {
+#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr);
         UNROLL_CALL_RAW(4, BAIS_INIT);
 #undef BAIS_INIT
     }
 };
 template <typename T, typename T2>
 struct InitOcxOw8<1, BiasMode::BIAS, 8, T, T2> {
-    static void impl(T& c, const float32_t* bias_ptr, int) {
+    static void impl(T& c, const T2* bias_ptr, int) {
         constexpr int simd_len = 4;
-#define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr + step * simd_len);
+#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr + step * simd_len);
         UNROLL_CALL_RAW(8, BAIS_INIT);
 #undef BAIS_INIT
     }
 };
 template <typename T, typename T2>
 struct InitOcxOw8<1, BiasMode::BIAS, 4, T, T2> {
-    static void impl(T& c, const float32_t* bias_ptr, int) {
+    static void impl(T& c, const T2* bias_ptr, int) {
         constexpr int simd_len = 4;
-#define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr + step * simd_len);
+#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr + step * simd_len);
         UNROLL_CALL_RAW(4, BAIS_INIT);
 #undef BAIS_INIT
     }
 };
 template <int c_dim, BiasMode bias_mode, int ow_block, typename T, typename T2>
-inline void init_ocx_ow8(T& c, T2 bias_ptr, int oc_step) {
+inline void init_ocx_ow8(T& c, const T2* bias_ptr, int oc_step) {
     InitOcxOw8<c_dim, bias_mode, ow_block, T, T2>::impl(c, bias_ptr, oc_step);
 }
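What makes `InitOcxOw8` dtype-generic is nothing deeper than the `neon_vdupq_n`/`neon_vld1q` overload pairs introduced above: the bias element type `T2` selects the f32 or s32 intrinsic through ordinary overload resolution, so one macro body initializes both float and int32 accumulators. A small sketch of the idea, assuming an ARM NEON toolchain (names are illustrative):

#include <arm_neon.h>

inline float32x4_t neon_vdupq_n(float val) { return vdupq_n_f32(val); }
inline int32x4_t neon_vdupq_n(int val) { return vdupq_n_s32(val); }

// one template body, two vector types: the scalar argument type decides
// which intrinsic is called and therefore which register type comes back
template <typename T2>
auto broadcast_zero() -> decltype(neon_vdupq_n(T2(0))) {
    return neon_vdupq_n(static_cast<T2>(0));
}
// broadcast_zero<float>() -> float32x4_t, broadcast_zero<int>() -> int32x4_t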
/////////////////////init_ocx_ow4/////////////////////
...
...
dnn/src/arm_common/conv_bias/opr_impl.cpp
...
...
@@ -55,6 +55,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
     AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44;
 #if __ARM_FEATURE_DOTPROD
+    AlgoDotS8DirectNCHWNCHW44 ds8_direct_stride2_nchw_nchw44;
     AlgoDotS8DirectStride1 ds8_direct_stride1_large_group{true};
     AlgoDotS8DirectStride1 ds8_direct_stride1_small_group{false};
     AlgoDotS8DirectStride2 ds8_direct_stride2_large_group{true};
...
...
@@ -93,6 +94,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
 public:
     AlgoPack() {
 #if __ARM_FEATURE_DOTPROD
+        direct_algos.emplace_back(&ds8_direct_stride2_nchw_nchw44);
         direct_algos.emplace_back(&ds8_direct_stride1_large_group);
         direct_algos.emplace_back(&ds8_direct_stride1_small_group);
         direct_algos.emplace_back(&ds8_direct_stride2_large_group);
...
...
dnn/src/arm_common/conv_bias/opr_impl.h
...
...
@@ -62,6 +62,7 @@ private:
     class AlgoFP16WinogradF23_8x8;
 #endif
 #if __ARM_FEATURE_DOTPROD
+    class AlgoDotS8DirectNCHWNCHW44;
     class AlgoDotS8DirectStride1;
     class AlgoDotS8DirectStride2;
     class AlgoDotU8DirectStride1;
...
...
dnn/src/arm_common/neon_struct.h
...
...
@@ -60,6 +60,14 @@ struct Vfmaq_laneq_f32 {
         return vfmaq_laneq_f32(a, b, v, lane);
     }
 };
+#if __ARM_FEATURE_DOTPROD
+struct Vdotq_laneq_s32 {
+    template <const int lane>
+    static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
+        return vdotq_laneq_s32(a, b, v, lane);
+    }
+};
+#endif
 }  // namespace
 }  // namespace megdnn
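The new `Vdotq_laneq_s32` wrapper is a thin functor over the SDOT by-element form: each of the four 32-bit accumulator lanes gains a 4-way int8 dot product against the four bytes of the selected 32-bit lane of `v`. A scalar reference of the semantics as I read them from the instruction, not MegDNN code:

#include <cstdint>

// acc[i] += sum over j of b[4*i + j] * v[4*lane + j]  (i, j in 0..3)
void sdot_laneq_ref(int32_t acc[4], const int8_t b[16], const int8_t v[16],
                    int lane) {
    for (int i = 0; i < 4; ++i)
        for (int j = 0; j < 4; ++j)
            acc[i] += int32_t(b[4 * i + j]) * int32_t(v[4 * lane + j]);
}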
...
...
dnn/src/arm_common/simd_macro/marm_neon.h
...
...
@@ -481,37 +481,71 @@ UNROLL_CALL_RAW(4, cb);
#define vdup_laneq_s16(vec, lane) Vdup_laneq_s16_armv7<lane>::impl(vec)
 namespace {
 template <int lane>
-struct Vfmap_laneq_f32_armv7 {
+struct Vfmaq_laneq_f32_armv7 {
     static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v);
 };
 template <>
-struct Vfmap_laneq_f32_armv7<0> {
+struct Vfmaq_laneq_f32_armv7<0> {
     static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
         return vmlaq_lane_f32(a, b, vget_low_f32(v), 0);
     }
 };
 template <>
-struct Vfmap_laneq_f32_armv7<1> {
+struct Vfmaq_laneq_f32_armv7<1> {
     static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
         return vmlaq_lane_f32(a, b, vget_low_f32(v), 1);
     }
 };
 template <>
-struct Vfmap_laneq_f32_armv7<2> {
+struct Vfmaq_laneq_f32_armv7<2> {
     static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
         return vmlaq_lane_f32(a, b, vget_high_f32(v), 0);
     }
 };
 template <>
-struct Vfmap_laneq_f32_armv7<3> {
+struct Vfmaq_laneq_f32_armv7<3> {
     static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) {
         return vmlaq_lane_f32(a, b, vget_high_f32(v), 1);
     }
 };
 }  // namespace
 #define vfmaq_laneq_f32(a, b, v, lane) \
-    Vfmap_laneq_f32_armv7<lane>::impl(a, b, v)
+    Vfmaq_laneq_f32_armv7<lane>::impl(a, b, v)
+#if __ARM_FEATURE_DOTPROD
+template <int lane>
+struct Vdotq_laneq_s32_armv7 {
+    static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v);
+};
+template <>
+struct Vdotq_laneq_s32_armv7<0> {
+    static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
+        return vdotq_lane_s32(a, b, vget_low_s32(v), 0);
+    }
+};
+template <>
+struct Vdotq_laneq_s32_armv7<1> {
+    static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
+        return vdotq_lane_s32(a, b, vget_low_s32(v), 1);
+    }
+};
+template <>
+struct Vdotq_laneq_s32_armv7<2> {
+    static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
+        return vdotq_lane_s32(a, b, vget_high_s32(v), 0);
+    }
+};
+template <>
+struct Vdotq_laneq_s32_armv7<3> {
+    static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
+        return vdotq_lane_s32(a, b, vget_high_s32(v), 1);
+    }
+};
+#define vdotq_laneq_s32(a, b, v, lane) \
+    Vdotq_laneq_s32_armv7<lane>::impl(a, b, v)
+#endif
 #endif
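Both armv7 fallbacks above follow one mapping: a `laneq` index names one of four 32-bit lanes of a 128-bit vector, while the armv7 `lane` intrinsics only index a 64-bit half, so laneq 0/1 become lanes 0/1 of `vget_low_*` and laneq 2/3 become lanes 0/1 of `vget_high_*`. The fma case, folded into one template as a sketch (assumes an armv7 NEON toolchain; the source spells it out per specialization instead):

#include <arm_neon.h>

template <int laneq>
float32x4_t fmaq_laneq_emulated(float32x4_t a, float32x4_t b, float32x4_t v) {
    static_assert(laneq >= 0 && laneq < 4, "laneq out of range");
    // pick the 64-bit half, then re-index within it; laneq & 1 stays a
    // compile-time constant because laneq is a template parameter
    float32x2_t half = laneq < 2 ? vget_low_f32(v) : vget_high_f32(v);
    return vmlaq_lane_f32(a, b, half, laneq & 1);
}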
...
...
dnn/test/arm_common/conv_bias.cpp
...
...
@@ -109,14 +109,12 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
             .set_dtype(4, dtype::QuantizedS8(60.25))
             .set_display(false);
     benchmarker_int.set_before_exec_callback(
-            conv_bias::ConvBiasAlgoChecker<ConvBias>(
-                    "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384"));
+            conv_bias::ConvBiasAlgoChecker<ConvBias>("IM2COLMATMUL:.+"));
     Benchmarker<ConvBias> benchmarker_float(handle);
     benchmarker_float.set_display(false).set_times(RUNS);
     benchmarker_float.set_before_exec_callback(
-            conv_bias::ConvBiasAlgoChecker<ConvBias>(
-                    "IM2COLMATMUL:AARCH64_F32K8X12X1:192"));
+            conv_bias::ConvBiasAlgoChecker<ConvBias>("IM2COLMATMUL:.+"));
     Benchmarker<ConvBias> benchmarker_nchw44(handle);
     if (is_fp32) {
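Loosening the checker from a pinned kernel name to "IM2COLMATMUL:.+" keeps the benchmark on the im2col path while letting the heuristic choose the inner matmul. Presumably the checker regex-matches the selected algorithm's name, along these lines (a sketch, not the actual test infrastructure):

#include <regex>
#include <string>

bool algo_name_matches(const std::string& name, const std::string& pattern) {
    return std::regex_match(name, std::regex{pattern});
}
// algo_name_matches("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
//                   "IM2COLMATMUL:.+") -> true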
...
...
@@ -213,6 +211,15 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
         run(1, 256, 256, 14, 14, 3, 1, false);
         run(1, 512, 512, 7, 7, 3, 1, false);
     } else {
         run(1, 1, 4, 112, 112, 2, 2, true);
         run(1, 3, 32, 224, 224, 3, 2, true);
         run(1, 3, 32, 224, 224, 5, 2, true);
         run(1, 3, 64, 224, 224, 7, 2, true);
+        run(1, 1, 4, 112, 112, 2, 1, true);
+        run(1, 3, 32, 224, 224, 3, 1, true);
+        run(1, 3, 32, 224, 224, 5, 1, true);
+        run(1, 3, 64, 224, 224, 7, 1, true);
         for (size_t stride : {1, 2}) {
             printf("stride %zu\n", stride);
             for (size_t filter_size : {2, 3, 5, 7}) {
...
...
@@ -228,9 +235,11 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
 }
 TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) {
     benchmark_convbias(handle(), true);
+    benchmark_convbias(handle(), false);
 }
 TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) {
     benchmark_convbias(handle(), true);
+    benchmark_convbias(handle(), false);
 }
 #endif
...
...
dnn/test/arm_common/conv_bias_multi_thread.cpp
...
...
@@ -557,6 +557,16 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) {
/****************************dot qint8 direct*************************/
#if __ARM_FEATURE_DOTPROD
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
+    checker_conv_bias_qint8x8x8(
+            get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false,
+                                      true),
+            handle(), "ARMDOTS8_NCHW_NCHW44");
+    checker_conv_bias_qint8x8x8(
+            get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false,
+                                      true),
+            handle(), "ARMDOTS8_NCHW_NCHW44");
+}
 TEST_F(ARM_COMMON_MULTI_THREADS,
        CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) {
     checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
...
...
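For readers new to the layout names: the `nchw_nchw44` kernels exercised here read a plain NCHW input and write an NCHW44 output, where channels are packed four at a time into the innermost dimension. The addressing relation, as an illustrative helper (not MegDNN code; assumes the channel count is a multiple of 4):

#include <cstddef>

// element (n, c, h, w) of an NCHW tensor with C channels lives at
// (n, c / 4, h, w, c % 4) in the NCHW44 layout
inline size_t nchw44_offset(size_t n, size_t c, size_t h, size_t w,
                            size_t C, size_t H, size_t W) {
    return (((n * (C / 4) + c / 4) * H + h) * W + w) * 4 + c % 4;
}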