Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
9e876203
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
9e876203
编写于
5月 21, 2020
作者:
M
Megvii Engine Team
提交者:
Xu Xinran
6月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn): add int8 direct conv dot nchw44
GitOrigin-RevId: 31830ba7a49f7c0b9fb3f011e09f934601a825a0
上级
09ceaaae
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
1390 addition
and
1 deletion
+1390
-1
dnn/src/arm_common/conv_bias/int8/algos.h
dnn/src/arm_common/conv_bias/int8/algos.h
+22
-0
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp
+370
-0
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h
+87
-0
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp
.../arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp
+341
-0
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h
...rc/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h
+430
-0
dnn/src/arm_common/conv_bias/intrinsic_helper.h
dnn/src/arm_common/conv_bias/intrinsic_helper.h
+1
-0
dnn/src/arm_common/conv_bias/opr_impl.cpp
dnn/src/arm_common/conv_bias/opr_impl.cpp
+4
-0
dnn/src/arm_common/conv_bias/opr_impl.h
dnn/src/arm_common/conv_bias/opr_impl.h
+2
-0
dnn/test/arm_common/conv_bias.cpp
dnn/test/arm_common/conv_bias.cpp
+75
-0
dnn/test/arm_common/conv_bias_multi_thread.cpp
dnn/test/arm_common/conv_bias_multi_thread.cpp
+58
-1
未找到文件。
dnn/src/arm_common/conv_bias/int8/algos.h
浏览文件 @
9e876203
...
...
@@ -189,6 +189,28 @@ public:
fallback
::
ConvBiasImpl
*
opr
,
const
NCBKernSizeParam
&
param
)
const
override
;
};
class
ConvBiasImpl
::
AlgoDotS8Direct_NCHW44
final
:
public
AlgoBase
{
public:
AlgoDotS8Direct_NCHW44
()
{}
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"ARMDOTS8DIRECT_NCHW44"
;
}
bool
usable
(
FallbackConvBiasImpl
*
,
const
NCBKernSizeParam
&
,
AlgoSelectionStrategy
algo_selection_strategy
)
const
override
;
size_t
get_workspace
(
FallbackConvBiasImpl
*
,
const
NCBKernSizeParam
&
)
const
override
;
SmallVector
<
NCBKern
>
dispatch_kerns
(
fallback
::
ConvBiasImpl
*
opr
,
const
NCBKernSizeParam
&
param
)
const
override
;
bool
is_preferred
(
megdnn
::
fallback
::
ConvBiasImpl
*
,
const
NCBKernSizeParam
&
param
)
const
override
;
};
#endif
class
ConvBiasImpl
::
AlgoS8WinogradF23_8x8
final
:
public
AlgoBase
{
...
...
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp
0 → 100644
浏览文件 @
9e876203
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#ifdef __ARM_FEATURE_DOTPROD
#include "src/arm_common/elemwise_helper/kimpl/typecvt.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
#include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h"
#include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h"
namespace
megdnn
{
namespace
arm_common
{
namespace
direct_dotprod_nchw44
{
template
<
>
void
copy_packed_src_int8_nchw44
<
1
>
(
int8_t
*
dst
,
const
int
dst_step
,
const
int8_t
*
src
,
const
int
src_step
,
const
int
ic
,
const
int
ic_step
,
const
int
ih
,
const
int
pad_left
,
const
int
pad_right
,
const
int
pad_top
,
const
int
pad_bottom
)
{
constexpr
int
IC_PACK_SIZE
=
4
;
rep_step
(
ic_idx
,
ic
,
IC_PACK_SIZE
)
{
const
int8_t
*
i_src
=
src
+
ic_idx
*
ic_step
;
//! pad top
int
bytes_pad_top
=
pad_top
*
dst_step
*
IC_PACK_SIZE
*
sizeof
(
int8_t
);
memset
(
dst
,
0
,
bytes_pad_top
);
dst
+=
bytes_pad_top
/
sizeof
(
int8_t
);
rep
(
ih_idx
,
ih
)
{
int
bytes_row_in_dst
=
dst_step
*
IC_PACK_SIZE
*
sizeof
(
int8_t
);
memset
(
dst
,
0
,
bytes_row_in_dst
);
//! left elements
int
pad_left_elements
=
pad_left
*
IC_PACK_SIZE
;
//! copy row [ih_idx, x]
int
bytes_copy
=
src_step
*
IC_PACK_SIZE
*
sizeof
(
int8_t
);
memcpy
(
dst
+
pad_left_elements
,
i_src
,
bytes_copy
);
//! dst move to next row
dst
+=
bytes_row_in_dst
/
sizeof
(
int8_t
);
//! src move to next row
i_src
+=
bytes_copy
/
sizeof
(
int8_t
);
}
//! pad bottom
int
bytes_pad_bottom
=
pad_bottom
*
dst_step
*
IC_PACK_SIZE
*
sizeof
(
int8_t
);
memset
(
dst
,
0
,
bytes_pad_bottom
);
dst
+=
bytes_pad_bottom
/
sizeof
(
int8_t
);
}
}
template
<
>
void
copy_packed_src_int8_nchw44
<
2
>
(
int8_t
*
dst
,
const
int
dst_step
,
const
int8_t
*
src
,
const
int
src_step
,
const
int
ic
,
const
int
ic_step
,
const
int
ih
,
const
int
pad_left
,
const
int
pad_right
,
const
int
pad_top
,
const
int
pad_bottom
)
{
constexpr
int
IC_PACK_SIZE
=
4
;
int
odd_start
=
megdnn
::
div_ceil
(
dst_step
,
2
);
bool
nochange
=
pad_left
%
2
==
0
;
rep_step
(
ic_idx
,
ic
,
IC_PACK_SIZE
)
{
const
int32_t
*
i_src
=
reinterpret_cast
<
const
int32_t
*>
(
src
+
ic_idx
*
ic_step
);
int
bytes_pad_top
=
pad_top
*
dst_step
*
IC_PACK_SIZE
*
sizeof
(
int8_t
);
memset
(
dst
,
0
,
bytes_pad_top
);
dst
+=
bytes_pad_top
/
sizeof
(
int8_t
);
rep
(
ih_idx
,
ih
)
{
int
bytes_row_in_dst
=
dst_step
*
IC_PACK_SIZE
*
sizeof
(
int8_t
);
memset
(
dst
,
0
,
bytes_row_in_dst
);
int32_t
*
dst_even
=
reinterpret_cast
<
int32_t
*>
(
dst
)
+
pad_left
/
2
+
pad_left
%
2
;
int32_t
*
dst_odd
=
reinterpret_cast
<
int32_t
*>
(
dst
)
+
odd_start
+
pad_left
/
2
;
int
i_src_idx
=
0
;
if
(
nochange
)
{
for
(;
i_src_idx
+
7
<
src_step
;
i_src_idx
+=
8
)
{
int32x4x2_t
tmp
;
tmp
=
vld2q_s32
(
i_src
+
i_src_idx
);
vst1q_s32
(
dst_even
,
tmp
.
val
[
0
]);
vst1q_s32
(
dst_odd
,
tmp
.
val
[
1
]);
dst_even
+=
4
;
dst_odd
+=
4
;
}
}
else
{
for
(;
i_src_idx
+
7
<
src_step
;
i_src_idx
+=
8
)
{
int32x4x2_t
tmp
;
tmp
=
vld2q_s32
(
i_src
+
i_src_idx
);
vst1q_s32
(
dst_even
,
tmp
.
val
[
1
]);
vst1q_s32
(
dst_odd
,
tmp
.
val
[
0
]);
dst_even
+=
4
;
dst_odd
+=
4
;
}
}
for
(;
i_src_idx
<
src_step
;
++
i_src_idx
)
{
if
(
nochange
)
{
if
(
i_src_idx
%
2
==
0
)
{
*
dst_even
=
*
(
i_src
+
i_src_idx
);
dst_even
++
;
}
else
{
*
dst_odd
=
*
(
i_src
+
i_src_idx
);
dst_odd
++
;
}
}
else
{
if
(
i_src_idx
%
2
==
0
)
{
*
dst_odd
=
*
(
i_src
+
i_src_idx
);
dst_odd
++
;
}
else
{
*
dst_even
=
*
(
i_src
+
i_src_idx
);
dst_even
++
;
}
}
}
//! dst move to next row
dst
+=
bytes_row_in_dst
/
sizeof
(
int8_t
);
//! src move to next row
i_src
+=
src_step
;
}
//! pad bottom
int
bytes_pad_bottom
=
pad_bottom
*
dst_step
*
IC_PACK_SIZE
*
sizeof
(
int8_t
);
memset
(
dst
,
0
,
bytes_pad_bottom
);
dst
+=
bytes_pad_bottom
/
sizeof
(
int8_t
);
}
}
template
<
typename
dst_type
,
int
stride
,
BiasMode
bias_mode
,
typename
Op
,
int
filter_size
>
void
conv_direct_sdot_int8_nchw44
(
dst_type
*
dst
,
const
int
oh
,
const
int
ow
,
const
int8_t
*
src
,
const
int
ih
,
const
int
iw
,
const
int8_t
*
filter
,
const
int32_t
*
bias
,
const
int
oh_size
,
const
int
oc
,
const
int
ic
,
const
Op
&
op
)
{
constexpr
int
FH
=
filter_size
;
constexpr
int
FW
=
filter_size
;
constexpr
int
IC_PACK_SIZE
=
4
;
constexpr
int
OC_PACK_SIZE
=
4
;
#if MEGDNN_AARCH64
constexpr
int
OC_BIG_INTERVAL
=
12
;
constexpr
int
OC_MID_INTERVAL
=
8
;
constexpr
int
OC_SMA_INTERVAL
=
4
;
#else
constexpr
int
OC_BIG_INTERVAL
=
4
;
constexpr
int
OC_MID_INTERVAL
=
4
;
constexpr
int
OC_SMA_INTERVAL
=
4
;
#endif
constexpr
int
OW_INTERVAL
=
8
;
constexpr
int
SH
=
stride
;
const
int
dst_numbers_per_channel
=
oh
*
ow
;
const
int
ow_remain
=
ow
%
OW_INTERVAL
;
const
int
ow_end_idx
=
ow
-
ow_remain
;
const
int
oc_remain
=
oc
%
OC_BIG_INTERVAL
;
//! NCHW44 means oc_remain = 4 or 8
const
int
oc_end_idx
=
oc
-
oc_remain
;
const
int
dst_numbers_4channel_packed
=
dst_numbers_per_channel
*
OC_PACK_SIZE
;
using
remain_fun
=
std
::
function
<
void
(
dst_type
*
dst
,
const
int
dst_step
,
const
int8_t
*
src
,
const
int
ih
,
const
int
iw
,
const
int8_t
*
filter
,
const
int32_t
*
bias
,
const
int
ic
,
const
Op
&
op
)
>
;
remain_fun
kern_big_oc_remain
=
nullptr
;
remain_fun
kern_mid_oc_remain
=
nullptr
;
remain_fun
kern_sma_oc_remain
=
nullptr
;
switch
(
ow_remain
)
{
#define cb(step) \
case step: \
kern_big_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_BIG_INTERVAL, \
OW_INTERVAL>::impl; \
kern_mid_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_MID_INTERVAL, \
OW_INTERVAL>::impl; \
kern_sma_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_SMA_INTERVAL, \
OW_INTERVAL>::impl; \
break;
UNROLL_CALL_RAW
(
8
,
cb
);
#undef cb
default:
megdnn_assert
(
0
,
"no remain %d for kern"
,
ow_remain
);
}
//! filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC]
//! cut [oc, oh, ow] into [oc/OC_INTERVAL, 1, ow/OW_INTERVAL, OW_INTERVAL,
//! oh, OC_INTERVAL] to calculate KernNeonSdotNCHW44 calculates
//! [OW_INTERVAL, 1, OC_INTERVAL] each time
for
(
int
oc_idx
=
0
;
oc_idx
<
oc_end_idx
;
oc_idx
+=
OC_BIG_INTERVAL
)
{
const
int
filter_offset_in_element
=
oc_idx
*
ic
*
FH
*
FW
;
for
(
int
oh_idx
=
0
;
oh_idx
<
oh_size
;
++
oh_idx
)
{
for
(
int
ow_idx
=
0
;
ow_idx
<
ow_end_idx
;
ow_idx
+=
OW_INTERVAL
)
{
const
int
src_offset_in_element
=
(
oh_idx
*
SH
*
iw
+
ow_idx
)
*
IC_PACK_SIZE
;
const
int
dst_offset_in_element
=
oc_idx
*
dst_numbers_per_channel
+
(
oh_idx
*
ow
+
ow_idx
)
*
OC_PACK_SIZE
;
const
int
bias_offset_in_element
=
oc_idx
;
KernNeonSdotNCHW44
<
dst_type
,
stride
,
bias_mode
,
Op
,
OW_INTERVAL
,
filter_size
,
OC_BIG_INTERVAL
,
OW_INTERVAL
>::
impl
(
dst
+
dst_offset_in_element
,
dst_numbers_4channel_packed
,
src
+
src_offset_in_element
,
ih
,
iw
,
filter
+
filter_offset_in_element
,
bias
+
bias_offset_in_element
,
ic
,
op
);
}
if
(
ow_remain
)
{
const
int
src_offset_in_element
=
(
oh_idx
*
SH
*
iw
+
ow_end_idx
)
*
IC_PACK_SIZE
;
const
int
dst_offset_in_element
=
oc_idx
*
dst_numbers_per_channel
+
(
oh_idx
*
ow
+
ow_end_idx
)
*
OC_PACK_SIZE
;
const
int
bias_offset_in_element
=
oc_idx
;
kern_big_oc_remain
(
dst
+
dst_offset_in_element
,
dst_numbers_4channel_packed
,
src
+
src_offset_in_element
,
ih
,
iw
,
filter
+
filter_offset_in_element
,
bias
+
bias_offset_in_element
,
ic
,
op
);
}
}
}
#ifdef MEGDNN_AARCH64
//! oc_remain must be 4 or 8 on aarch64 and must be 0 on aarch32
if
(
oc_remain
)
{
int
oc_idx
=
oc_end_idx
;
const
int
filter_offset_in_element
=
oc_idx
*
ic
*
FH
*
FW
;
for
(
int
oh_idx
=
0
;
oh_idx
<
oh_size
;
++
oh_idx
)
{
for
(
int
ow_idx
=
0
;
ow_idx
<
ow_end_idx
;
ow_idx
+=
OW_INTERVAL
)
{
const
int
src_offset_in_element
=
(
oh_idx
*
SH
*
iw
+
ow_idx
)
*
IC_PACK_SIZE
;
const
int
dst_offset_in_element
=
oc_idx
*
dst_numbers_per_channel
+
(
oh_idx
*
ow
+
ow_idx
)
*
OC_PACK_SIZE
;
const
int
bias_offset_in_element
=
oc_idx
;
if
(
oc_remain
==
8
)
{
KernNeonSdotNCHW44
<
dst_type
,
stride
,
bias_mode
,
Op
,
OW_INTERVAL
,
filter_size
,
OC_MID_INTERVAL
,
OW_INTERVAL
>::
impl
(
dst
+
dst_offset_in_element
,
dst_numbers_4channel_packed
,
src
+
src_offset_in_element
,
ih
,
iw
,
filter
+
filter_offset_in_element
,
bias
+
bias_offset_in_element
,
ic
,
op
);
}
else
{
KernNeonSdotNCHW44
<
dst_type
,
stride
,
bias_mode
,
Op
,
OW_INTERVAL
,
filter_size
,
OC_SMA_INTERVAL
,
OW_INTERVAL
>::
impl
(
dst
+
dst_offset_in_element
,
dst_numbers_4channel_packed
,
src
+
src_offset_in_element
,
ih
,
iw
,
filter
+
filter_offset_in_element
,
bias
+
bias_offset_in_element
,
ic
,
op
);
}
}
if
(
ow_remain
)
{
const
int
src_offset_in_element
=
(
oh_idx
*
SH
*
iw
+
ow_end_idx
)
*
IC_PACK_SIZE
;
const
int
dst_offset_in_element
=
oc_idx
*
dst_numbers_per_channel
+
(
oh_idx
*
ow
+
ow_end_idx
)
*
OC_PACK_SIZE
;
const
int
bias_offset_in_element
=
oc_idx
;
if
(
oc_remain
==
8
)
{
kern_mid_oc_remain
(
dst
+
dst_offset_in_element
,
dst_numbers_4channel_packed
,
src
+
src_offset_in_element
,
ih
,
iw
,
filter
+
filter_offset_in_element
,
bias
+
bias_offset_in_element
,
ic
,
op
);
}
else
{
kern_sma_oc_remain
(
dst
+
dst_offset_in_element
,
dst_numbers_4channel_packed
,
src
+
src_offset_in_element
,
ih
,
iw
,
filter
+
filter_offset_in_element
,
bias
+
bias_offset_in_element
,
ic
,
op
);
}
}
}
}
#endif
}
#define CONSTRUCT_FUNC(filter_size) \
template <typename dst_type, BiasMode bias_mode, typename Op, int stride> \
void conv_direct_##filter_size##x##filter_size##_int8_nchw44( \
dst_type* dst, const int oh, const int ow, const int8_t* src, \
const int ih, const int iw, const int8_t* weight, \
const int32_t* bias, const int oh_size, const int oc, \
const int ic, const Op& op) { \
conv_direct_sdot_int8_nchw44<dst_type, stride, bias_mode, Op, \
filter_size>( \
dst, oh, ow, src, ih, iw, weight, bias, oh_size, oc, ic, op); \
}
CONSTRUCT_FUNC
(
2
);
CONSTRUCT_FUNC
(
3
);
CONSTRUCT_FUNC
(
5
);
CONSTRUCT_FUNC
(
7
);
#undef CONSTRUCT_FUNC
#define INSTANTIATION(dst_type, stride, i, bias_mode, Op) \
template void conv_direct_##i##x##i##_int8_nchw44<dst_type, bias_mode, Op, \
stride>( \
dst_type * dst, const int oh, const int ow, const int8_t* src, \
const int ih, const int iw, const int8_t* weight, \
const int32_t* bias, const int oh_size, const int oc, \
const int ic, const Op& op);
#define FOR_OP(stride, i, bias_mode) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int32, stride, i, bias_mode, \
NoneOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)
#define FOR_BIAS(stride, i) \
FOR_OP(stride, i, BiasMode::NO_BIAS) \
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)
#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)
FOR_FILTER
(
1
)
FOR_FILTER
(
2
)
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
}
// namespace direct_dotprod_nchw44
}
// namespace arm_common
}
// namespace megdnn
#endif
//vim: syntax=cpp.doxygen
\ No newline at end of file
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h
0 → 100644
浏览文件 @
9e876203
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#if __ARM_FEATURE_DOTPROD
#pragma once
#include "src/arm_common/conv_bias/opr_impl.h"
namespace
megdnn
{
namespace
arm_common
{
namespace
direct_dotprod_nchw44
{
using
BiasMode
=
ConvBiasForward
::
BiasMode
;
/**
* @brief : do direct conv with no side effect
* input buffer's size is [ih, iw]
* output buffer's size is [oh, ow]
* filter layout is [OC/4, IC/4, FH, FW, 4, 4]
*
* @param : [output ptr] dst
* [input] oh -> dst rows
* [input] ow -> dst cols
* [input ptr] src
* [input] ih -> rows of src used by this this kernel
* [input] iw -> src step in elements [iw2]
* [input ptr] filter
* [input ptr] bias
* [input] oh_size -> rows of result generated by this kernel
* [input] oc -> output channels
* [input] ic -> intput channels
* [input] op -> post process operator
* @return none
*/
#define KERN(filter_size) \
template <typename dst_type, BiasMode bias_mode, typename Op, int stride> \
void conv_direct_##filter_size##x##filter_size##_int8_nchw44( \
dst_type* dst, const int oh, const int ow, const int8_t* src, \
const int ih, const int iw, const int8_t* weight, \
const int32_t* bias, const int oh_size, const int oc, \
const int ic, const Op& op)
KERN
(
2
);
KERN
(
3
);
KERN
(
5
);
KERN
(
7
);
#undef KERN
/**
* @brief : copy data from src to dst for direct conv with no side effect
* @param : [output ptr] dst
* [input] dst_step -> step of dst in numbers of elements
* [input ptr] src
* [input] src_step -> step of src in numbers of elements
* [input] ic -> input channels
* [input] ic_step -> step of ic in numbers of elements
* [input] ih -> totle rows to copy
* [input] pad_left -> cols padding at left
* [input] pad_right -> cols padding at right
* [input] pad_top -> rows padding at top
* [input] pad_bottom -> rows padding at bottom
* @return none
*/
template
<
int
stride
>
void
copy_packed_src_int8_nchw44
(
int8_t
*
dst
,
const
int
dst_step
,
const
int8_t
*
src
,
const
int
src_step
,
const
int
ic
,
const
int
ic_step
,
const
int
ih
,
const
int
pad_left
,
const
int
pad_right
,
const
int
pad_top
,
const
int
pad_bottom
);
}
// namespace direct_dotprod_nchw44
}
// namespace arm_common
}
// namespace megdnn
#endif
//vim: syntax=cpp.doxygen
\ No newline at end of file
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp
0 → 100644
浏览文件 @
9e876203
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_dotpord_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#if __ARM_FEATURE_DOTPROD
#include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h"
#include "src/arm_common/elemwise_op.h"
#include "midout.h"
using
namespace
megdnn
;
using
namespace
arm_common
;
MIDOUT_DECL
(
megdnn_arm_common_conv_bias_int8
)
using
direct_fun
=
std
::
function
<
void
(
WorkspaceBundle
bundle
,
const
ConvBiasImpl
::
NCBKernParam
&
ncb_param
,
const
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
)
>
;
namespace
{
static
void
get_rectified_size
(
const
megdnn
::
fallback
::
ConvBiasImpl
::
NCBKernSizeParam
&
param
,
int
&
ih
,
int
&
iw
,
int
&
oh
,
int
&
ow
)
{
int
IC
=
param
.
filter_meta
.
icpg
;
int
IW
=
param
.
isz
[
1
];
int
OH
=
param
.
osz
[
0
];
int
OW
=
param
.
osz
[
1
];
oh
=
OH
;
ow
=
OW
;
constexpr
int
cacheline
=
64
/
sizeof
(
int8_t
);
int
oh_tile_size
=
l2_block_helper
(
param
.
nr_threads
,
OH
,
IC
*
IW
*
sizeof
(
int8_t
)
*
2
);
auto
&&
fm
=
param
.
filter_meta
;
const
int
SH
=
static_cast
<
int
>
(
fm
.
stride
[
0
]);
const
int
FH
=
static_cast
<
int
>
(
fm
.
spatial
[
0
]);
const
int
PW
=
static_cast
<
int
>
(
fm
.
padding
[
1
]);
ih
=
oh_tile_size
*
SH
+
FH
-
SH
;
iw
=
round_up
(
IW
+
2
*
PW
,
cacheline
);
}
static
inline
int
get_perthread_cache_bytes
(
const
int
ic
,
const
int
ih
,
const
int
iw
)
{
// border_size is used to avoid read illegal memory
int
border_size
=
64
*
2
;
return
ic
*
ih
*
iw
*
sizeof
(
int8_t
)
+
border_size
;
}
static
WorkspaceBundle
get_bundle
(
const
ConvBiasImpl
::
NCBKernSizeParam
&
param
)
{
auto
&&
fm
=
param
.
filter_meta
;
int
IC
=
fm
.
icpg
;
int
ih2
,
iw2
,
oh2
,
ow2
;
get_rectified_size
(
param
,
ih2
,
iw2
,
oh2
,
ow2
);
int
bytes_of_copy_per_thread
=
get_perthread_cache_bytes
(
IC
,
ih2
,
iw2
);
return
{
nullptr
,
{
bytes_of_copy_per_thread
*
param
.
nr_threads
}};
}
template
<
typename
dst_type
,
size_t
filter_size
,
BiasMode
bias_mode
,
typename
Op
,
int
stride
>
static
void
conv_kern
(
WorkspaceBundle
bundle
,
const
ConvBiasImpl
::
NCBKernParam
&
ncb_param
,
const
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
)
{
const
int
OH
=
ncb_param
.
osz
[
0
];
const
int
OW
=
ncb_param
.
osz
[
1
];
const
int
FH
=
ncb_param
.
filter_meta
.
spatial
[
0
];
const
int
IC
=
ncb_param
.
filter_meta
.
icpg
;
const
int
OC
=
ncb_param
.
filter_meta
.
ocpg
;
const
int
IH
=
ncb_param
.
isz
[
0
];
const
int
IW
=
ncb_param
.
isz
[
1
];
const
int
SH
=
ncb_param
.
filter_meta
.
stride
[
0
];
const
int
PH
=
ncb_param
.
filter_meta
.
padding
[
0
];
const
int
PW
=
ncb_param
.
filter_meta
.
padding
[
1
];
int
ih2
=
0
;
int
iw2
=
0
;
int
oh2
=
0
;
int
ow2
=
0
;
get_rectified_size
(
ncb_param
,
ih2
,
iw2
,
oh2
,
ow2
);
constexpr
int
IC_PACK_SIZE
=
4
;
constexpr
int
OC_PACK_SIZE
=
4
;
bundle
.
set
(
ncb_param
.
workspace_ptr
);
const
int
batch_id
=
ncb_index
.
ndrange_id
[
0
];
const
int
group_id
=
ncb_index
.
ndrange_id
[
1
];
const
int
oh_tile_id
=
ncb_index
.
ndrange_id
[
2
];
const
int
thread_id
=
ncb_index
.
thread_id
;
const
int
oh_tile_size
=
l2_block_helper
(
ncb_param
.
nr_threads
,
OH
,
IC
*
IW
*
sizeof
(
int8_t
)
*
2
);
const
int
oh_start_row
=
oh_tile_id
*
oh_tile_size
;
const
int
ih_start_row
=
std
::
max
(
oh_start_row
*
SH
-
PH
,
0
);
const
int
oh_real_size
=
std
::
min
(
OH
-
oh_start_row
,
oh_tile_size
);
const
int
ih_real_size
=
oh_real_size
*
SH
+
FH
-
SH
;
const
int
rows_padding_at_top
=
std
::
max
(
PH
-
oh_start_row
*
SH
,
0
);
const
int
rows_padding_at_bottom
=
std
::
max
((
oh_start_row
+
oh_real_size
-
1
)
*
SH
+
FH
-
IH
-
PH
,
0
);
const
int
cols_padding_at_left
=
PW
;
const
int
cols_padding_at_right
=
std
::
max
(
iw2
-
IW
-
PW
,
0
);
//! src layout{IC/4, IH, IW, 4}
const
int
bytes_of_src_offset
=
ih_start_row
*
IW
*
IC_PACK_SIZE
*
sizeof
(
int8_t
);
const
int8_t
*
copy_src
=
static_cast
<
const
int8_t
*>
(
ncb_param
.
src
<
int8_t
>
(
batch_id
,
group_id
)
+
bytes_of_src_offset
);
const
int
bytes_of_copy_per_thread
=
get_perthread_cache_bytes
(
IC
,
ih2
,
iw2
);
int8_t
*
copy_dst
=
reinterpret_cast
<
int8_t
*>
(
bundle
.
get
(
0
))
+
thread_id
*
bytes_of_copy_per_thread
;
const
int
rows_copy_from_src
=
ih_real_size
-
rows_padding_at_top
-
rows_padding_at_bottom
;
direct_dotprod_nchw44
::
copy_packed_src_int8_nchw44
<
stride
>
(
copy_dst
,
iw2
,
copy_src
,
IW
,
IC
,
IH
*
IW
,
rows_copy_from_src
,
cols_padding_at_left
,
cols_padding_at_right
,
rows_padding_at_top
,
rows_padding_at_bottom
);
const
int8_t
*
weights
=
ncb_param
.
filter
<
int8_t
>
(
group_id
);
dst_type
*
dst
=
ncb_param
.
dst
<
dst_type
>
(
batch_id
,
group_id
)
+
oh_start_row
*
OW
*
OC_PACK_SIZE
;
//! only broadcast or no_bias
const
int32_t
*
bias
=
ncb_param
.
bias
<
int32_t
>
(
batch_id
,
group_id
);
Op
op
=
Op
(
1.0
f
,
4.0
f
);
if
(
ncb_param
.
dst_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
)
{
float
scale_bias
=
ncb_param
.
bias_type
.
param
<
dtype
::
QuantizedS32
>
().
scale
;
float
scale_dst
=
ncb_param
.
dst_type
.
param
<
dtype
::
QuantizedS8
>
().
scale
;
op
=
Op
(
scale_bias
,
scale_dst
);
}
#define KERN1_NCHW44_CONV(filter) \
direct_dotprod_nchw44::conv_direct_##filter##x##filter##_int8_nchw44< \
dst_type, bias_mode, Op, stride>(dst, OH, OW, copy_dst, \
ih_real_size, iw2, weights, bias, \
oh_real_size, OC, IC, op);
DISPATCH_FILTER
(
filter_size
,
KERN1_NCHW44_CONV
);
#undef KERN1_NCHW44_CONV
}
}
// namespace
bool
ConvBiasImpl
::
AlgoDotS8Direct_NCHW44
::
usable
(
FallbackConvBiasImpl
*
,
const
NCBKernSizeParam
&
param
,
AlgoSelectionStrategy
algo_selection_strategy
)
const
{
auto
&&
fm
=
param
.
filter_meta
;
auto
FH
=
fm
.
spatial
[
0
];
auto
FW
=
fm
.
spatial
[
1
];
auto
SH
=
fm
.
stride
[
0
];
auto
SW
=
fm
.
stride
[
1
];
auto
OC
=
fm
.
ocpg
;
auto
IC
=
fm
.
icpg
;
//! src and filter are qint8, dst is qint8.
bool
data_type_ok
=
param
.
src_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
param
.
filter_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
(
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
||
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
QuantizedS32
);
if
(
param
.
bias_type
.
valid
())
{
data_type_ok
&=
param
.
bias_type
.
enumv
()
==
DTypeEnum
::
QuantizedS32
;
}
data_type_ok
|=
param
.
src_type
.
enumv
()
==
DTypeEnum
::
Int8
&&
param
.
filter_type
.
enumv
()
==
DTypeEnum
::
Int8
&&
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
Int32
;
bool
layout_ok
=
fm
.
format
==
param
::
Convolution
::
Format
::
NCHW44_DOT
&&
IC
%
4
==
0
&&
OC
%
4
==
0
;
bool
param_ok
=
!
fm
.
should_flip
&&
fm
.
spatial_ndim
==
2
&&
fm
.
dilation
[
0
]
==
1
&&
fm
.
dilation
[
1
]
==
1
&&
FH
==
FW
&&
(
FH
>=
2
&&
FH
<=
7
);
bool
stride_ok
=
SH
==
SW
&&
(
SH
==
1
||
SH
==
2
);
return
data_type_ok
&&
layout_ok
&&
param_ok
&&
stride_ok
;
}
bool
ConvBiasImpl
::
AlgoDotS8Direct_NCHW44
::
is_preferred
(
megdnn
::
fallback
::
ConvBiasImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
return
true
;
}
size_t
ConvBiasImpl
::
AlgoDotS8Direct_NCHW44
::
get_workspace
(
FallbackConvBiasImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
return
get_bundle
(
param
).
total_size_in_bytes
();
}
SmallVector
<
ConvBiasImpl
::
NCBKern
>
ConvBiasImpl
::
AlgoDotS8Direct_NCHW44
::
dispatch_kerns
(
FallbackConvBiasImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
MIDOUT_BEGIN
(
megdnn_arm_common_conv_bias_int8
,
midout_iv
(
"ALGODOTS8DIRECT_NCHW44"
_hash
))
{
auto
fm
=
param
.
filter_meta
;
size_t
BATCH
=
param
.
n
;
size_t
GROUP
=
fm
.
group
;
WorkspaceBundle
wbundle
=
get_bundle
(
param
);
direct_fun
kernel
=
nullptr
;
bool
quantized
=
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
;
#define DO_CONV_KERN_FUN(dst_type, filter, bias_mode, op, stride) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, \
midout_iv(#dst_type #filter #bias_mode #op##_hash)) { \
kernel = conv_kern<dst_type, filter, bias_mode, op, stride>; \
} \
MIDOUT_END();
#define GET_OP_PARAM(i, bias_mode, stride) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
if (quantized) { \
DO_CONV_KERN_FUN(dt_int8, i, bias_mode, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>, \
stride) \
} else { \
DO_CONV_KERN_FUN(dt_int32, i, bias_mode, \
NoneOp<dt_qint32 MEGDNN_COMMA dt_qint8>, \
stride) \
} \
break; \
case param::ConvBias::NonlineMode::RELU: \
if (quantized) { \
DO_CONV_KERN_FUN(dt_int8, i, bias_mode, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>, \
stride) \
} else { \
megdnn_assert("No support NoQuantized RELU"); \
} \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
if (quantized) { \
DO_CONV_KERN_FUN(dt_int8, i, bias_mode, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>, \
stride) \
} else { \
megdnn_assert("No support NoQuantized H_SWISH"); \
} \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define GET_STRIDE_PARAM(filter, bias_mode) \
switch (fm.stride[0]) { \
case 1: \
GET_OP_PARAM(filter, bias_mode, 1); \
break; \
case 2: \
GET_OP_PARAM(filter, bias_mode, 2); \
break; \
default: \
megdnn_assert(0); \
}
#define GET_BIAS_MODE_PARAM(filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_STRIDE_PARAM(filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_STRIDE_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define SELECT_CONV_KERN() \
switch (param.filter_meta.spatial[0]) { \
case 2: \
GET_BIAS_MODE_PARAM(2) \
break; \
case 3: \
GET_BIAS_MODE_PARAM(3) \
break; \
case 5: \
GET_BIAS_MODE_PARAM(5) \
break; \
case 7: \
GET_BIAS_MODE_PARAM(7) \
break; \
default: \
megdnn_assert(0); \
break; \
}
SELECT_CONV_KERN
()
#undef DO_CONV_KERN_FUN
#undef GET_OP_PARAM
#undef GET_STRIDE_PARAM
#undef GET_BIAS_MODE_PARAM
#undef SELECT_CONV_KERN
megdnn_assert
(
kernel
);
SmallVector
<
ConvBiasImpl
::
NCBKern
>
ret_kerns
;
int
OH
=
param
.
osz
[
0
];
int
IC
=
param
.
filter_meta
.
icpg
;
int
IW
=
param
.
isz
[
1
];
int
oh_tile_size
=
l2_block_helper
(
param
.
nr_threads
,
OH
,
IC
*
IW
*
sizeof
(
int8_t
)
*
2
);
size_t
oh_tiles
=
static_cast
<
size_t
>
(
div_ceil
(
OH
,
oh_tile_size
));
auto
do_conv
=
[
wbundle
,
kernel
](
const
NCBKernParam
&
ncb_param
,
const
NCBKernIndex
&
ncb_index
)
{
kernel
(
wbundle
,
ncb_param
,
std
::
move
(
ncb_index
));
};
ret_kerns
.
push_back
({
do_conv
,
{
BATCH
,
GROUP
,
oh_tiles
}});
return
ret_kerns
;
}
MIDOUT_END
();
return
{};
}
#endif
//vim: syntax=cpp.doxygen
\ No newline at end of file
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h
0 → 100644
浏览文件 @
9e876203
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#ifdef __ARM_FEATURE_DOTPROD
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/intrinsic_helper.h"
#include "src/arm_common/neon_struct.h"
#include "src/common/unroll_macro.h"
namespace
megdnn
{
namespace
arm_common
{
namespace
direct_dotprod_nchw44
{
constexpr
int
SIMD_LEN
=
16
;
constexpr
int
IC_PACK_SIZE
=
4
;
constexpr
int
OC_PACK_SIZE
=
4
;
constexpr
int
filter_next_col
=
IC_PACK_SIZE
*
OC_PACK_SIZE
;
//! [OC/4, IC/4, FH, FW, 4OC, 4IC]
template
<
int
row
,
BiasMode
bias_mode
>
inline
void
init_ocx_ow8
(
int32x4_t
c
[][
8
],
const
int32_t
*
bias_ptr
,
int
oc_step
)
{
static_assert
(
row
==
1
||
row
==
2
||
row
==
3
,
"Invalid OC number."
);
if
(
bias_mode
==
BiasMode
::
BROADCAST_CHANNEL_BIAS
)
{
#define BIAS_INIT(step, i) c[i][step] = vld1q_s32(bias_ptr + i * oc_step);
switch
(
row
)
{
case
3
:
UNROLL_CALL_RAW
(
8
,
BIAS_INIT
,
2
);
case
2
:
UNROLL_CALL_RAW
(
8
,
BIAS_INIT
,
1
);
default:
UNROLL_CALL_RAW
(
8
,
BIAS_INIT
,
0
);
}
#undef BIAS_INIT
}
else
{
#define BIAS_INIT(step, i) c[i][step] = vdupq_n_s32(0);
switch
(
row
)
{
case
3
:
UNROLL_CALL_RAW
(
8
,
BIAS_INIT
,
2
);
case
2
:
UNROLL_CALL_RAW
(
8
,
BIAS_INIT
,
1
);
default:
UNROLL_CALL_RAW
(
8
,
BIAS_INIT
,
0
);
}
#undef BIAS_INIT
}
}
#define cb11(col) \
op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8));
#define cb21(col) \
op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); \
op(res[1][col], \
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + col / 2 * 8));
#define cb31(col) \
op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); \
op(res[1][col], \
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + col / 2 * 8)); \
op(res[2][col], reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + \
ld_dst_oc + col / 2 * 8));
#define cb12(step) \
op({{res[0][2 * step], res[0][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + step * 8));
#define cb22(step) \
op({{res[0][2 * step], res[0][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); \
op({{res[1][2 * step], res[1][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + step * 8));
#define cb32(step) \
op({{res[0][2 * step], res[0][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); \
op({{res[1][2 * step], res[1][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + step * 8)); \
op({{res[2][2 * step], res[2][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + 2 * ld_dst_oc + step * 8));
template
<
int
row
,
int
ow_remain
,
typename
Op
,
typename
T
>
struct
StoreOCxOWx
{
static
void
impl
(
int32x4_t
res
[][
8
],
const
Op
&
op
,
T
*
dst_ptr
,
const
int
ld_dst_oc
);
};
template
<
int
ow_remain
,
typename
Op
,
typename
T
>
struct
StoreOCxOWx
<
1
,
ow_remain
,
Op
,
T
>
{
static
void
impl
(
int32x4_t
res
[][
8
],
const
Op
&
op
,
T
*
dst_ptr
,
const
int
ld_dst_oc
)
{
switch
(
ow_remain
)
{
case
8
:
UNROLL_CALL_RAW
(
4
,
cb12
);
break
;
case
7
:
cb11
(
6
);
case
6
:
UNROLL_CALL_RAW
(
3
,
cb12
);
break
;
case
5
:
cb11
(
4
);
case
4
:
UNROLL_CALL_RAW
(
2
,
cb12
);
break
;
case
3
:
cb11
(
2
);
case
2
:
UNROLL_CALL_RAW
(
1
,
cb12
);
break
;
case
1
:
cb11
(
0
);
default:
break
;
}
}
};
template
<
int
ow_remain
,
typename
Op
,
typename
T
>
struct
StoreOCxOWx
<
2
,
ow_remain
,
Op
,
T
>
{
static
void
impl
(
int32x4_t
res
[][
8
],
const
Op
&
op
,
T
*
dst_ptr
,
const
int
ld_dst_oc
)
{
switch
(
ow_remain
)
{
case
8
:
UNROLL_CALL_RAW
(
4
,
cb22
);
break
;
case
7
:
cb21
(
6
);
case
6
:
UNROLL_CALL_RAW
(
3
,
cb22
);
break
;
case
5
:
cb21
(
4
);
case
4
:
UNROLL_CALL_RAW
(
2
,
cb22
);
break
;
case
3
:
cb21
(
2
);
case
2
:
UNROLL_CALL_RAW
(
1
,
cb22
);
break
;
case
1
:
cb21
(
0
);
default:
break
;
}
}
};
template
<
int
ow_remain
,
typename
Op
,
typename
T
>
struct
StoreOCxOWx
<
3
,
ow_remain
,
Op
,
T
>
{
static
void
impl
(
int32x4_t
res
[][
8
],
const
Op
&
op
,
T
*
dst_ptr
,
const
int
ld_dst_oc
)
{
switch
(
ow_remain
)
{
case
8
:
UNROLL_CALL_RAW
(
4
,
cb32
);
break
;
case
7
:
cb31
(
6
);
case
6
:
UNROLL_CALL_RAW
(
3
,
cb32
);
break
;
case
5
:
cb31
(
4
);
case
4
:
UNROLL_CALL_RAW
(
2
,
cb32
);
break
;
case
3
:
cb31
(
2
);
case
2
:
UNROLL_CALL_RAW
(
1
,
cb32
);
break
;
case
1
:
cb31
(
0
);
default:
break
;
}
}
};
#undef cb11
#undef cb21
#undef cb31
#undef cb12
#undef cb22
#undef cb32
template
<
int
row
,
int
ow_remain
,
typename
Op
,
typename
T
>
inline
void
store_ocx_owx_remain_static
(
int32x4_t
res
[][
8
],
const
Op
&
op
,
T
*
dst_ptr
,
const
int
ld_dst_oc
)
{
StoreOCxOWx
<
row
,
ow_remain
,
Op
,
T
>::
impl
(
res
,
op
,
dst_ptr
,
ld_dst_oc
);
}
template
<
int
res_row
,
int
src_row
,
int
src_start_idx
,
int
weight_idx
,
typename
FUNC
,
typename
T
,
typename
T2
,
typename
T3
>
struct
ShiftCalHelper
{
static
void
impl
(
T
&
res
,
T2
&
src
,
T3
&
weight
)
{
#define cb(step) \
res[res_row][step] = FUNC::template impl<((src_start_idx + step) % 4)>( \
res[res_row][step], weight[weight_idx], \
src[src_row][(src_start_idx + step) / 4]);
UNROLL_CALL_RAW
(
8
,
cb
);
#undef cb
}
};
template
<
int
res_row
,
int
src_row
,
int
src_start_idx
,
int
weight_idx
,
typename
FUNC
,
typename
T
,
typename
T2
,
typename
T3
>
inline
void
cal_helper
(
T
&
res
,
T2
&
src
,
T3
&
weight
)
{
ShiftCalHelper
<
res_row
,
src_row
,
src_start_idx
,
weight_idx
,
FUNC
,
T
,
T2
,
T3
>::
impl
(
res
,
src
,
weight
);
};
/**
* oc12_owx(m = 12, n = x) and oc8_owx(m = 8, n = x) and oc4_owx(m = 4, n = x)
* gemm like kernel
* */
template
<
typename
dst_type
,
int
stride
,
BiasMode
bias_mode
,
typename
Op
,
int
ow_remain
,
int
filter_size
,
int
oc_interval
,
int
ow_interval
>
struct
KernNeonSdotNCHW44
{
static
void
impl
(
dst_type
*
dst
,
const
int
dst_step
,
const
int8_t
*
src
,
const
int
ih
,
const
int
iw
,
const
int8_t
*
filter
,
const
int32_t
*
bias
,
const
int
ic
,
const
Op
&
op
);
};
template
<
typename
dst_type
,
BiasMode
bias_mode
,
typename
Op
,
int
ow_remain
,
int
filter_size
,
int
oc_interval
,
int
ow_interval
>
struct
KernNeonSdotNCHW44
<
dst_type
,
1
,
bias_mode
,
Op
,
ow_remain
,
filter_size
,
oc_interval
,
ow_interval
>
{
static
void
impl
(
dst_type
*
dst
,
const
int
dst_step
,
const
int8_t
*
src
,
const
int
ih
,
const
int
iw
,
const
int8_t
*
filter
,
const
int32_t
*
bias
,
const
int
ic
,
const
Op
&
op
)
{
constexpr
int
FH
=
filter_size
;
constexpr
int
FW
=
filter_size
;
constexpr
int
filter_next_row
=
FW
*
OC_PACK_SIZE
*
IC_PACK_SIZE
;
//! [OC/4, IC/4, FH, FW, 4OC, 4IC]
const
int
filter_next_4oc
=
FH
*
FW
*
ic
*
OC_PACK_SIZE
;
//! [OC/4, IC/4, FH, FW, 4OC, 4IC]
const
int
src_next_ic
=
ih
*
iw
;
const
int
src_next_row
=
iw
*
IC_PACK_SIZE
;
constexpr
int
NSRC
=
(
ow_interval
+
filter_size
-
1
)
/
4
+
1
;
constexpr
int
LOOP
=
oc_interval
/
4
;
int32x4_t
res
[
3
][
ow_interval
];
init_ocx_ow8
<
LOOP
,
bias_mode
>
(
res
,
bias
,
OC_PACK_SIZE
);
for
(
int
ic_idx
=
0
;
ic_idx
<
ic
;
ic_idx
+=
IC_PACK_SIZE
)
{
const
int8_t
*
i_src
=
src
+
ic_idx
*
src_next_ic
;
const
int8_t
*
i_filter
=
filter
+
ic_idx
*
FH
*
FW
*
OC_PACK_SIZE
;
for
(
int
fh_idx
=
0
;
fh_idx
<
FH
;
++
fh_idx
)
{
int8x16_t
src
[
1
][
4
];
int8x16_t
weight
[
3
];
load_helper
<
NSRC
,
0
,
SIMD_LEN
,
1
,
Vld1q_s8
>
(
src
,
i_src
,
0
);
//! do not use switch order 3,2,1 because it will slow the speed.
#define CALC_PART(step) \
switch (LOOP) { \
case 1: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \
break; \
case 2: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \
weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
filter_next_col * step); \
cal_helper<1, 0, step, 1, Vdotq_laneq_s32>(res, src, weight); \
break; \
case 3: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \
weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
filter_next_col * step); \
cal_helper<1, 0, step, 1, Vdotq_laneq_s32>(res, src, weight); \
weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \
filter_next_col * step); \
cal_helper<2, 0, step, 2, Vdotq_laneq_s32>(res, src, weight); \
break; \
default: \
break; \
}
switch
(
filter_size
)
{
case
2
:
UNROLL_CALL_RAW
(
2
,
CALC_PART
);
break
;
case
3
:
UNROLL_CALL_RAW
(
3
,
CALC_PART
);
break
;
case
5
:
UNROLL_CALL_RAW
(
5
,
CALC_PART
);
break
;
case
7
:
UNROLL_CALL_RAW
(
7
,
CALC_PART
);
break
;
default:
break
;
}
#undef CALC_PART
i_filter
+=
filter_next_row
;
i_src
+=
src_next_row
;
}
}
store_ocx_owx_remain_static
<
LOOP
,
ow_remain
,
Op
>
(
res
,
op
,
dst
,
dst_step
);
}
};
template
<
typename
dst_type
,
BiasMode
bias_mode
,
typename
Op
,
int
ow_remain
,
int
filter_size
,
int
oc_interval
,
int
ow_interval
>
struct
KernNeonSdotNCHW44
<
dst_type
,
2
,
bias_mode
,
Op
,
ow_remain
,
filter_size
,
oc_interval
,
ow_interval
>
{
static
void
impl
(
dst_type
*
dst
,
const
int
dst_step
,
const
int8_t
*
src
,
const
int
ih
,
const
int
iw
,
const
int8_t
*
filter
,
const
int32_t
*
bias
,
const
int
ic
,
const
Op
&
op
)
{
constexpr
int
FH
=
filter_size
;
constexpr
int
FW
=
filter_size
;
constexpr
int
filter_next_row
=
FW
*
OC_PACK_SIZE
*
IC_PACK_SIZE
;
//! [OC/4, IC/4, FH, FW, 4OC, 4IC]
const
int
filter_next_4oc
=
FH
*
FW
*
ic
*
OC_PACK_SIZE
;
//! [OC/4, IC/4, FH, FW, 4OC, 4IC]
const
int
src_next_ic
=
ih
*
iw
;
const
int
src_next_row
=
iw
*
IC_PACK_SIZE
;
constexpr
int
NSRC
=
(
ow_interval
*
2
+
filter_size
-
3
)
/
8
+
1
;
constexpr
int
LOOP
=
oc_interval
/
4
;
int32x4_t
res
[
3
][
ow_interval
];
init_ocx_ow8
<
LOOP
,
bias_mode
>
(
res
,
bias
,
OC_PACK_SIZE
);
for
(
int
ic_idx
=
0
;
ic_idx
<
ic
;
ic_idx
+=
IC_PACK_SIZE
)
{
const
int8_t
*
i_src
=
src
+
ic_idx
*
src_next_ic
;
const
int8_t
*
i_filter
=
filter
+
ic_idx
*
FH
*
FW
*
OC_PACK_SIZE
;
for
(
int
fh_idx
=
0
;
fh_idx
<
FH
;
++
fh_idx
)
{
int8x16_t
src
[
2
][
3
];
int8x16_t
weight
[
3
];
const
int
offset
=
megdnn
::
div_ceil
(
iw
,
2
)
*
IC_PACK_SIZE
;
load_helper
<
NSRC
,
0
,
SIMD_LEN
,
2
,
Vld1q_s8
>
(
src
,
i_src
,
offset
);
//! do not use switch order 3,2,1 because it will slow the speed.
#define CALC_PART(step) \
switch (LOOP) { \
case 1: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \
weight); \
break; \
case 2: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \
weight); \
weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
filter_next_col * step); \
cal_helper<1, step % 2, step / 2, 1, Vdotq_laneq_s32>(res, src, \
weight); \
break; \
case 3: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \
weight); \
weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
filter_next_col * step); \
cal_helper<1, step % 2, step / 2, 1, Vdotq_laneq_s32>(res, src, \
weight); \
weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \
filter_next_col * step); \
cal_helper<2, step % 2, step / 2, 2, Vdotq_laneq_s32>(res, src, \
weight); \
break; \
default: \
break; \
}
switch
(
filter_size
)
{
case
2
:
UNROLL_CALL_RAW
(
2
,
CALC_PART
);
break
;
case
3
:
UNROLL_CALL_RAW
(
3
,
CALC_PART
);
break
;
case
5
:
UNROLL_CALL_RAW
(
5
,
CALC_PART
);
break
;
case
7
:
UNROLL_CALL_RAW
(
7
,
CALC_PART
);
break
;
default:
break
;
}
#undef CALC_PART
i_filter
+=
filter_next_row
;
i_src
+=
src_next_row
;
}
}
store_ocx_owx_remain_static
<
LOOP
,
ow_remain
,
Op
>
(
res
,
op
,
dst
,
dst_step
);
}
};
}
// namespace direct_dotprod_nchw44
}
// namespace arm_common
}
// namespace megdnn
#endif
//vim: syntax=cpp.doxygen
dnn/src/arm_common/conv_bias/intrinsic_helper.h
浏览文件 @
9e876203
...
...
@@ -536,6 +536,7 @@ inline void init_oc8_ow8(int32x4_t c[2][8], const int32_t* bias_ptr,
#undef BAIS_INIT
}
}
/////////////////////////init_ocx_ow8////////////////////
inline
float32x4_t
neon_vdupq_n
(
float
val
)
{
...
...
dnn/src/arm_common/conv_bias/opr_impl.cpp
浏览文件 @
9e876203
...
...
@@ -64,6 +64,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoDotU8DirectStride1
du8_direct_stride1_small_group
{
false
};
AlgoDotU8DirectStride2
du8_direct_stride2_large_group
{
true
};
AlgoDotU8DirectStride2
du8_direct_stride2_small_group
{
false
};
AlgoDotS8Direct_NCHW44
ds8_direct_nchw44
;
#endif
AlgoF32DirectNCHWNCHW44
f32_direct_stride2_nchw_nchw44
;
...
...
@@ -103,6 +105,8 @@ public:
direct_algos
.
emplace_back
(
&
du8_direct_stride1_small_group
);
direct_algos
.
emplace_back
(
&
du8_direct_stride2_large_group
);
direct_algos
.
emplace_back
(
&
du8_direct_stride2_small_group
);
direct_algos
.
emplace_back
(
&
ds8_direct_nchw44
);
#endif
direct_algos
.
emplace_back
(
&
qu8_direct_stride2_large_group
);
direct_algos
.
emplace_back
(
&
qu8_direct_stride2_small_group
);
...
...
dnn/src/arm_common/conv_bias/opr_impl.h
浏览文件 @
9e876203
...
...
@@ -67,6 +67,8 @@ private:
class
AlgoDotS8DirectStride2
;
class
AlgoDotU8DirectStride1
;
class
AlgoDotU8DirectStride2
;
class
AlgoDotS8Direct_NCHW44
;
#endif
class
AlgoF32Direct
;
class
AlgoF32DirectStride1
;
...
...
dnn/test/arm_common/conv_bias.cpp
浏览文件 @
9e876203
...
...
@@ -1809,6 +1809,81 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) {
used1
/
used0
);
}
}
TEST_F
(
ARM_COMMON
,
BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_NCHW44_DOT
)
{
using
namespace
conv_bias
;
std
::
vector
<
TestArg
>
args
;
auto
run
=
[
&
](
size_t
oc
,
size_t
ic
,
size_t
w
,
size_t
h
,
size_t
kernel
,
size_t
p
,
size_t
stride
,
NonlineMode
nonline_mode
)
{
if
(
w
+
2
*
p
<
kernel
||
h
+
2
*
p
<
kernel
)
return
;
param
::
ConvBias
param
;
param
.
stride_h
=
stride
;
param
.
stride_w
=
stride
;
param
.
pad_h
=
p
;
param
.
pad_w
=
p
;
param
.
nonlineMode
=
nonline_mode
;
param
.
format
=
param
::
ConvBias
::
Format
::
NCHW44_DOT
;
//! channel bias
args
.
emplace_back
(
param
,
TensorShape
{
1
,
ic
/
4
,
h
,
w
,
4
},
TensorShape
{
oc
/
4
,
ic
/
4
,
kernel
,
kernel
,
4
,
4
},
TensorShape
{
1
,
oc
/
4
,
1
,
1
,
4
});
};
for
(
size_t
stride
:
{
1
,
2
})
for
(
size_t
kernel
:
{
2
,
3
,
5
,
7
})
for
(
size_t
oc
:
{
64
})
for
(
NonlineMode
nonline_mode
:
{
NonlineMode
::
IDENTITY
})
{
run
(
oc
,
oc
,
56
,
56
,
kernel
,
kernel
/
2
,
stride
,
nonline_mode
);
}
constexpr
size_t
RUN
=
50
;
Benchmarker
<
ConvBias
>
benchmark0
(
handle
());
benchmark0
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
2.5
f
))
.
set_dtype
(
1
,
dtype
::
QuantizedS8
(
2.5
f
))
.
set_dtype
(
2
,
dtype
::
QuantizedS32
(
6.25
f
))
.
set_dtype
(
4
,
dtype
::
QuantizedS8
(
60.25
f
));
benchmark0
.
set_display
(
false
);
benchmark0
.
set_times
(
RUN
);
benchmark0
.
set_before_exec_callback
(
conv_bias
::
ConvBiasAlgoChecker
<
ConvBiasForward
>
(
"ARMDOTS8DIRECT_NCHW44"
));
Benchmarker
<
ConvBias
>
benchmark1
(
handle
());
benchmark1
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
2.5
f
))
.
set_dtype
(
1
,
dtype
::
QuantizedS8
(
2.5
f
))
.
set_dtype
(
2
,
dtype
::
QuantizedS32
(
6.25
f
))
.
set_dtype
(
4
,
dtype
::
QuantizedS8
(
60.25
f
));
benchmark1
.
set_display
(
false
);
benchmark1
.
set_times
(
RUN
);
for
(
auto
&&
arg
:
args
)
{
TensorLayout
dst_layout
;
auto
opr
=
handle
()
->
create_operator
<
ConvBias
>
();
opr
->
param
()
=
arg
.
param
;
opr
->
deduce_layout
({
arg
.
src
,
dtype
::
Int8
()},
{
arg
.
filter
,
dtype
::
Int8
()},
{
arg
.
bias
,
dtype
::
Int32
()},
{},
dst_layout
);
//! dst.nr_elems * IC * FH * FW * 2
float
computations
=
dst_layout
.
total_nr_elems
()
*
arg
.
filter
[
1
]
*
arg
.
filter
[
2
]
*
arg
.
filter
[
3
]
*
8.0
/
(
1024
*
1024
*
1024
)
*
1e3
;
auto
used0
=
benchmark0
.
set_param
(
arg
.
param
).
exec
(
{
arg
.
src
,
arg
.
filter
,
arg
.
bias
,
{},
{}})
/
RUN
;
auto
used1
=
benchmark1
.
set_param
(
arg
.
param
).
exec
(
{
arg
.
src
,
arg
.
filter
,
arg
.
bias
,
{},
{}})
/
RUN
;
printf
(
"%s %s: Direct use: %f ms %f Gflops normal: %f ms %f GFlops "
"speedup: %f
\n
"
,
arg
.
src
.
to_string
().
c_str
(),
arg
.
filter
.
to_string
().
c_str
(),
used0
,
computations
/
used0
,
used1
,
computations
/
used1
,
used1
/
used0
);
}
}
#endif
#endif
...
...
dnn/test/arm_common/conv_bias_multi_thread.cpp
浏览文件 @
9e876203
...
...
@@ -155,7 +155,7 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
if
(
support_sigmoid
)
{
nonlinemode
.
emplace_back
(
NLMode
::
SIGMOID
);
}
std
::
vector
<
megdnn
::
BiasMode
>
bias_mode
=
{
megdnn
::
BiasMode
::
BROADCAST_CHANNEL_BIAS
};
if
(
no_bias
)
{
...
...
@@ -672,6 +672,63 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_SMALL_GROUP) {
get_conv_bias_args
({
2
,
3
,
5
,
7
},
2
,
false
,
true
,
true
),
handle
(),
"ARMDOTU8STRD2_SMALL_GROUP"
);
}
/******************************dot int8x8x8 nchw44 ***********************/
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x8
)
{
using
namespace
conv_bias
;
std
::
vector
<
TestArg
>
args
=
get_nchw44_conv_bias_args
({
2
,
3
,
5
,
7
},
1
);
for
(
auto
&&
arg
:
args
)
arg
.
param
.
format
=
param
::
ConvBias
::
Format
::
NCHW44_DOT
;
checker_conv_bias_qint8x8x8
(
args
,
handle
(),
"ARMDOTS8DIRECT_NCHW44"
);
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x32
)
{
using
namespace
conv_bias
;
std
::
vector
<
TestArg
>
args
=
get_nchw44_conv_bias_args
({
2
,
3
,
5
,
7
},
1
,
false
,
true
,
true
);
for
(
auto
&&
arg
:
args
)
arg
.
param
.
format
=
param
::
ConvBias
::
Format
::
NCHW44_DOT
;
checker_conv_bias_qint8x8x32
(
args
,
handle
(),
"ARMDOTS8DIRECT_NCHW44"
);
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_8x8x32
)
{
using
namespace
conv_bias
;
std
::
vector
<
TestArg
>
args
=
get_nchw44_conv_bias_args
({
2
,
3
,
5
,
7
},
1
,
false
,
true
,
true
);
for
(
auto
&&
arg
:
args
)
arg
.
param
.
format
=
param
::
ConvBias
::
Format
::
NCHW44_DOT
;
checker_conv_bias_int8x8x32_multi
(
args
,
handle
(),
"ARMDOTS8DIRECT_NCHW44"
);
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x8
)
{
using
namespace
conv_bias
;
//! test qint8x8x8
std
::
vector
<
TestArg
>
args
=
get_nchw44_conv_bias_args
({
2
,
3
,
5
,
7
},
2
);
for
(
auto
&&
arg
:
args
)
arg
.
param
.
format
=
param
::
ConvBias
::
Format
::
NCHW44_DOT
;
checker_conv_bias_qint8x8x8
(
args
,
handle
(),
"ARMDOTS8DIRECT_NCHW44"
);
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x32
)
{
using
namespace
conv_bias
;
//! test qint8x8x8
std
::
vector
<
TestArg
>
args
=
get_nchw44_conv_bias_args
({
2
,
3
,
5
,
7
},
2
,
false
,
true
,
true
);
for
(
auto
&&
arg
:
args
)
arg
.
param
.
format
=
param
::
ConvBias
::
Format
::
NCHW44_DOT
;
checker_conv_bias_qint8x8x32
(
args
,
handle
(),
"ARMDOTS8DIRECT_NCHW44"
);
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_8x8x32
)
{
using
namespace
conv_bias
;
//! test qint8x8x8
std
::
vector
<
TestArg
>
args
=
get_nchw44_conv_bias_args
({
2
,
3
,
5
,
7
},
2
,
false
,
true
,
true
);
for
(
auto
&&
arg
:
args
)
arg
.
param
.
format
=
param
::
ConvBias
::
Format
::
NCHW44_DOT
;
checker_conv_bias_int8x8x32_multi
(
args
,
handle
(),
"ARMDOTS8DIRECT_NCHW44"
);
}
#endif
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_WINOGRAD_F23_4
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录