Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
f2b42bf0
MegEngine
项目概览
MegEngine 天元
/
MegEngine
接近 2 年 前同步成功
通知
414
Star
4708
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
f2b42bf0
编写于
2月 07, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
chore(dotprod): add arm dotprod attribute for easy use
GitOrigin-RevId: 78c3e72218b8db009542b00e3688315d058d37fa
上级
fa4bf168
变更
105
展开全部
隐藏空白更改
内联
并排
Showing
105 changed file
with
553 addition
and
344 deletion
+553
-344
dnn/src/aarch64/conv_bias/int8/algos.cpp
dnn/src/aarch64/conv_bias/int8/algos.cpp
+34
-8
dnn/src/aarch64/conv_bias/int8/strategy.cpp
dnn/src/aarch64/conv_bias/int8/strategy.cpp
+7
-10
dnn/src/aarch64/conv_bias/int8/strategy.h
dnn/src/aarch64/conv_bias/int8/strategy.h
+1
-4
dnn/src/aarch64/conv_bias/quint8/algos.cpp
dnn/src/aarch64/conv_bias/quint8/algos.cpp
+39
-3
dnn/src/aarch64/conv_bias/quint8/strategy.cpp
dnn/src/aarch64/conv_bias/quint8/strategy.cpp
+49
-39
dnn/src/aarch64/conv_bias/quint8/strategy.h
dnn/src/aarch64/conv_bias/quint8/strategy.h
+31
-15
dnn/src/aarch64/matrix_mul/algos.cpp
dnn/src/aarch64/matrix_mul/algos.cpp
+23
-14
dnn/src/aarch64/matrix_mul/algos.h
dnn/src/aarch64/matrix_mul/algos.h
+4
-8
dnn/src/aarch64/matrix_mul/fp32/strategy.cpp
dnn/src/aarch64/matrix_mul/fp32/strategy.cpp
+0
-3
dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h
dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h
+0
-2
dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h
dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h
+0
-2
dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h
dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h
+0
-3
dnn/src/aarch64/matrix_mul/int8/strategy.cpp
dnn/src/aarch64/matrix_mul/int8/strategy.cpp
+0
-3
dnn/src/aarch64/matrix_mul/int8/strategy.h
dnn/src/aarch64/matrix_mul/int8/strategy.h
+0
-2
dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h
dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h
+7
-5
dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h
dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h
+5
-4
dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp
dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp
+1
-1
dnn/src/aarch64/matrix_mul/int8_dot/strategy.h
dnn/src/aarch64/matrix_mul/int8_dot/strategy.h
+1
-1
dnn/src/aarch64/matrix_mul/opr_impl.cpp
dnn/src/aarch64/matrix_mul/opr_impl.cpp
+8
-12
dnn/src/aarch64/matrix_mul/opr_impl.h
dnn/src/aarch64/matrix_mul/opr_impl.h
+4
-6
dnn/src/aarch64/matrix_mul/quint8/kernel_8x8x8.h
dnn/src/aarch64/matrix_mul/quint8/kernel_8x8x8.h
+0
-2
dnn/src/aarch64/matrix_mul/quint8/strategy.cpp
dnn/src/aarch64/matrix_mul/quint8/strategy.cpp
+0
-2
dnn/src/aarch64/matrix_mul/quint8/strategy.h
dnn/src/aarch64/matrix_mul/quint8/strategy.h
+0
-2
dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp
dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp
+2
-7
dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h
dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h
+2
-3
dnn/src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h
dnn/src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h
+5
-3
dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp
dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp
+5
-5
dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h
dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h
+2
-2
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h
.../fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h
+0
-3
dnn/src/arm_common/conv_bias/int8/algos.cpp
dnn/src/arm_common/conv_bias/int8/algos.cpp
+7
-1
dnn/src/arm_common/conv_bias/int8/algos.h
dnn/src/arm_common/conv_bias/int8/algos.h
+1
-1
dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp
dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp
+10
-1
dnn/src/arm_common/conv_bias/int8/direct_dotprod.h
dnn/src/arm_common/conv_bias/int8/direct_dotprod.h
+1
-1
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp
+1
-2
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h
+2
-3
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp
.../arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp
+4
-2
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h
.../conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h
+4
-2
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp
...on/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp
+4
-2
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp
...on/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp
+4
-2
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp
...nv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp
+9
-2
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp
...nv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp
+9
-2
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
...arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
+4
-1
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h
...c/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h
+1
-1
dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp
dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp
+1
-1
dnn/src/arm_common/conv_bias/int8/stride1_dotprod.h
dnn/src/arm_common/conv_bias/int8/stride1_dotprod.h
+1
-1
dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp
dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp
+1
-1
dnn/src/arm_common/conv_bias/int8/stride2_dotprod.h
dnn/src/arm_common/conv_bias/int8/stride2_dotprod.h
+1
-1
dnn/src/arm_common/conv_bias/opr_impl.cpp
dnn/src/arm_common/conv_bias/opr_impl.cpp
+2
-2
dnn/src/arm_common/conv_bias/opr_impl.h
dnn/src/arm_common/conv_bias/opr_impl.h
+1
-1
dnn/src/arm_common/conv_bias/quint8/algos.cpp
dnn/src/arm_common/conv_bias/quint8/algos.cpp
+9
-2
dnn/src/arm_common/conv_bias/quint8/algos.h
dnn/src/arm_common/conv_bias/quint8/algos.h
+1
-1
dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp
dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp
+9
-1
dnn/src/arm_common/conv_bias/quint8/direct_dotprod.h
dnn/src/arm_common/conv_bias/quint8/direct_dotprod.h
+1
-1
dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp
dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp
+1
-1
dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.h
dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.h
+1
-1
dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp
dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp
+1
-1
dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.h
dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.h
+1
-1
dnn/src/arm_common/convolution/int8x8x32/algos.cpp
dnn/src/arm_common/convolution/int8x8x32/algos.cpp
+8
-2
dnn/src/arm_common/convolution/int8x8x32/algos.h
dnn/src/arm_common/convolution/int8x8x32/algos.h
+1
-1
dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp
...rm_common/convolution/int8x8x32/conv_backdata_stride1.cpp
+6
-3
dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h
.../arm_common/convolution/int8x8x32/conv_backdata_stride1.h
+1
-1
dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp
...rm_common/convolution/int8x8x32/conv_backdata_stride2.cpp
+5
-3
dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h
.../arm_common/convolution/int8x8x32/conv_backdata_stride2.h
+1
-1
dnn/src/arm_common/convolution/opr_impl.cpp
dnn/src/arm_common/convolution/opr_impl.cpp
+2
-2
dnn/src/arm_common/convolution/opr_impl.h
dnn/src/arm_common/convolution/opr_impl.h
+1
-1
dnn/src/arm_common/convolution/quint8/algos.cpp
dnn/src/arm_common/convolution/quint8/algos.cpp
+9
-1
dnn/src/arm_common/convolution/quint8/algos.h
dnn/src/arm_common/convolution/quint8/algos.h
+1
-1
dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp
...c/arm_common/convolution/quint8/conv_backdata_stride1.cpp
+6
-3
dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.h
...src/arm_common/convolution/quint8/conv_backdata_stride1.h
+1
-4
dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp
...c/arm_common/convolution/quint8/conv_backdata_stride2.cpp
+5
-3
dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.h
...src/arm_common/convolution/quint8/conv_backdata_stride2.h
+1
-4
dnn/src/arm_common/matrix_mul/algos.cpp
dnn/src/arm_common/matrix_mul/algos.cpp
+7
-1
dnn/src/arm_common/matrix_mul/algos.h
dnn/src/arm_common/matrix_mul/algos.h
+1
-1
dnn/src/arm_common/matrix_mul/int8/gemv.cpp
dnn/src/arm_common/matrix_mul/int8/gemv.cpp
+27
-8
dnn/src/arm_common/matrix_mul/int8/gemv.h
dnn/src/arm_common/matrix_mul/int8/gemv.h
+1
-1
dnn/src/arm_common/matrix_mul/opr_impl.cpp
dnn/src/arm_common/matrix_mul/opr_impl.cpp
+2
-2
dnn/src/arm_common/matrix_mul/opr_impl.h
dnn/src/arm_common/matrix_mul/opr_impl.h
+1
-1
dnn/src/arm_common/neon_struct.h
dnn/src/arm_common/neon_struct.h
+3
-2
dnn/src/arm_common/simd_macro/marm_neon.h
dnn/src/arm_common/simd_macro/marm_neon.h
+20
-7
dnn/src/armv7/matrix_mul/algos.cpp
dnn/src/armv7/matrix_mul/algos.cpp
+13
-1
dnn/src/armv7/matrix_mul/algos.h
dnn/src/armv7/matrix_mul/algos.h
+1
-1
dnn/src/armv7/matrix_mul/asm/common.h
dnn/src/armv7/matrix_mul/asm/common.h
+0
-1
dnn/src/armv7/matrix_mul/fp32/strategy_4x12.cpp
dnn/src/armv7/matrix_mul/fp32/strategy_4x12.cpp
+0
-1
dnn/src/armv7/matrix_mul/int8/kernel_6x8x4.h
dnn/src/armv7/matrix_mul/int8/kernel_6x8x4.h
+3
-1
dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h
dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h
+3
-2
dnn/src/armv7/matrix_mul/int8/strategy.cpp
dnn/src/armv7/matrix_mul/int8/strategy.cpp
+1
-1
dnn/src/armv7/matrix_mul/int8/strategy.h
dnn/src/armv7/matrix_mul/int8/strategy.h
+1
-1
dnn/src/armv7/matrix_mul/opr_impl.cpp
dnn/src/armv7/matrix_mul/opr_impl.cpp
+2
-2
dnn/src/armv7/matrix_mul/opr_impl.h
dnn/src/armv7/matrix_mul/opr_impl.h
+1
-1
dnn/src/armv7/matrix_mul/quint8/kernel_dot_4x8x4.h
dnn/src/armv7/matrix_mul/quint8/kernel_dot_4x8x4.h
+3
-2
dnn/src/armv7/matrix_mul/quint8/strategy.cpp
dnn/src/armv7/matrix_mul/quint8/strategy.cpp
+1
-1
dnn/src/armv7/matrix_mul/quint8/strategy.h
dnn/src/armv7/matrix_mul/quint8/strategy.h
+1
-1
dnn/src/common/utils.h
dnn/src/common/utils.h
+7
-0
dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp
dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp
+1
-1
dnn/test/aarch64/matrix_mul.cpp
dnn/test/aarch64/matrix_mul.cpp
+3
-3
dnn/test/arm_common/conv_bias.cpp
dnn/test/arm_common/conv_bias.cpp
+5
-6
dnn/test/arm_common/conv_bias_multi_thread.cpp
dnn/test/arm_common/conv_bias_multi_thread.cpp
+1
-1
dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp
dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp
+4
-4
dnn/test/arm_common/conv_bias_multi_thread_conv1x1.cpp
dnn/test/arm_common/conv_bias_multi_thread_conv1x1.cpp
+9
-11
dnn/test/arm_common/conv_bias_multi_thread_im2col.cpp
dnn/test/arm_common/conv_bias_multi_thread_im2col.cpp
+11
-10
dnn/test/arm_common/conv_bias_multi_thread_weight_preprocess.cpp
...t/arm_common/conv_bias_multi_thread_weight_preprocess.cpp
+17
-20
dnn/test/arm_common/convolution.cpp
dnn/test/arm_common/convolution.cpp
+2
-2
dnn/test/arm_common/matrix_mul.cpp
dnn/test/arm_common/matrix_mul.cpp
+6
-1
dnn/test/armv7/matrix_mul.cpp
dnn/test/armv7/matrix_mul.cpp
+3
-3
src/megbrain_build_config.h.in
src/megbrain_build_config.h.in
+22
-0
未找到文件。
dnn/src/aarch64/conv_bias/int8/algos.cpp
浏览文件 @
f2b42bf0
...
...
@@ -67,6 +67,23 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle(
size_t
K
=
IC
*
FH
*
FW
;
size_t
N
=
OH
*
OW
;
#if MGB_ENABLE_DOT
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
part2 = megdnn::matmul::GemmInterleaved< \
matmul::gemm_##_gemm##_##_bias##_##_nonline>( \
M, N, K, false, false, strategy) \
.get_workspace_size();
if
(
cpuinfo_has_arm_neon_dot
())
{
DISPATCH_GEMM_BIAS
(
s8_8x12
,
1
)
}
else
{
DISPATCH_GEMM_BIAS
(
s8_4x4
,
0
)
}
#else
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
...
...
@@ -80,11 +97,7 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle(
.get_workspace_size(); \
} \
MIDOUT_END()
#if !(__ARM_FEATURE_DOTPROD)
DISPATCH_GEMM_BIAS
(
s8_4x4
,
0
)
#else
DISPATCH_GEMM_BIAS
(
s8_8x12
,
1
)
#endif
#undef DISPATCH_GEMM_STRATEGY
}
...
...
@@ -158,6 +171,23 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param,
size_t
K
=
IC
*
FH
*
FW
;
size_t
N
=
OH
*
OW
;
#if MGB_ENABLE_DOT
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved< \
matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias);
if
(
cpuinfo_has_arm_neon_dot
())
{
DISPATCH_GEMM_BIAS
(
s8_8x12
,
1
)
}
else
{
DISPATCH_GEMM_BIAS
(
s8_4x4
,
0
)
}
#else
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
...
...
@@ -172,11 +202,7 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param,
bias); \
} \
MIDOUT_END()
#if !(__ARM_FEATURE_DOTPROD)
DISPATCH_GEMM_BIAS
(
s8_4x4
,
0
)
#else
DISPATCH_GEMM_BIAS
(
s8_8x12
,
1
)
#endif
#undef DISPATCH_GEMM_STRATEGY
}
...
...
dnn/src/aarch64/conv_bias/int8/strategy.cpp
浏览文件 @
f2b42bf0
...
...
@@ -26,7 +26,7 @@ namespace impl {
template
<
BiasMode
bmode
,
typename
Op
,
int
block_m
,
int
block_n
>
struct
KernCaller
;
#if
__ARM_FEATURE_DOTPROD
#if
MGB_ENABLE_DOT
template
<
BiasMode
bmode
,
typename
Op
>
struct
KernCaller
<
bmode
,
Op
,
8
,
12
>
{
static
void
run
(
const
dt_int8
*
packA
,
const
dt_int8
*
packB
,
size_t
M
,
...
...
@@ -118,7 +118,7 @@ struct KernCaller<bmode, Op, 8, 12> {
}
};
#e
lse
#e
ndif
template
<
BiasMode
bmode
,
typename
Op
>
struct
KernCaller
<
bmode
,
Op
,
4
,
4
>
{
...
...
@@ -196,10 +196,8 @@ struct KernCaller<bmode, Op, 4, 4> {
}
};
#endif
}
// namespace impl
#if !(__ARM_FEATURE_DOTPROD)
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gemm_s8_4x4_nobias_identity
)
void
gemm_s8_4x4_nobias_identity
::
pack_A
(
dt_int8
*
outptr
,
const
dt_int8
*
inptr
,
...
...
@@ -227,7 +225,8 @@ void gemm_s8_4x4_nobias_identity::pack_B(dt_int8* out, const dt_int8* in,
size_t
gemm_s8_4x4_nobias_identity
::
get_workspace_size
()
const
{
return
4
*
4
*
sizeof
(
dt_int32
);
}
#else
#if MGB_ENABLE_DOT
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gemm_s8_8x12_nobias_identity
)
void
gemm_s8_8x12_nobias_identity
::
pack_A
(
dt_int8
*
outptr
,
const
dt_int8
*
inptr
,
...
...
@@ -277,11 +276,10 @@ size_t gemm_s8_8x12_nobias_identity::get_workspace_size() const {
#define DEFINE_OP(_Op) \
arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B, scale_C);
#if !(__ARM_FEATURE_DOTPROD)
KERN
(
4
,
4
,
nobias
,
BiasMode
::
NO_BIAS
,
identity
,
TypeCvtOp
)
KERN
(
4
,
4
,
nobias
,
BiasMode
::
NO_BIAS
,
relu
,
ReluOp
)
KERN
(
4
,
4
,
nobias
,
BiasMode
::
NO_BIAS
,
hswish
,
HSwishOp
)
#
else
#
if MGB_ENABLE_DOT
KERN
(
8
,
12
,
nobias
,
BiasMode
::
NO_BIAS
,
identity
,
TypeCvtOp
)
KERN
(
8
,
12
,
nobias
,
BiasMode
::
NO_BIAS
,
relu
,
ReluOp
)
KERN
(
8
,
12
,
nobias
,
BiasMode
::
NO_BIAS
,
hswish
,
HSwishOp
)
...
...
@@ -291,12 +289,11 @@ KERN(8, 12, nobias, BiasMode::NO_BIAS, hswish, HSwishOp)
#define DEFINE_OP(_Op) \
arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B, \
scale_A* scale_B, scale_C);
#if !(__ARM_FEATURE_DOTPROD)
KERN
(
4
,
4
,
bias_channel
,
BiasMode
::
BROADCAST_CHANNEL_BIAS
,
identity
,
AddOp
)
KERN
(
4
,
4
,
bias_channel
,
BiasMode
::
BROADCAST_CHANNEL_BIAS
,
relu
,
FuseAddReluOp
)
KERN
(
4
,
4
,
bias_channel
,
BiasMode
::
BROADCAST_CHANNEL_BIAS
,
hswish
,
FuseAddHSwishOp
)
#
else
#
if MGB_ENABLE_DOT
KERN
(
8
,
12
,
bias_channel
,
BiasMode
::
BROADCAST_CHANNEL_BIAS
,
identity
,
AddOp
)
KERN
(
8
,
12
,
bias_channel
,
BiasMode
::
BROADCAST_CHANNEL_BIAS
,
relu
,
FuseAddReluOp
)
KERN
(
8
,
12
,
bias_channel
,
BiasMode
::
BROADCAST_CHANNEL_BIAS
,
hswish
,
...
...
dnn/src/aarch64/conv_bias/int8/strategy.h
浏览文件 @
f2b42bf0
...
...
@@ -15,7 +15,6 @@ namespace megdnn {
namespace
aarch64
{
namespace
matmul
{
#if !(__ARM_FEATURE_DOTPROD)
/**
* \brief base strategy of gemm.
*
...
...
@@ -39,8 +38,7 @@ MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_relu,
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER
(
gemm_s8_4x4_bias_channel_hswish
,
gemm_s8_4x4_nobias_identity
);
#else
#if MGB_ENABLE_DOT
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK
(
dt_int8
,
dt_int8
,
dt_int32
,
8
,
12
,
4
,
false
,
true
,
gemm_s8_8x12_nobias_identity
);
...
...
@@ -59,7 +57,6 @@ MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_relu,
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER
(
gemm_s8_8x12_bias_channel_hswish
,
gemm_s8_8x12_nobias_identity
);
#endif
}
// namespace matmul
...
...
dnn/src/aarch64/conv_bias/quint8/algos.cpp
浏览文件 @
f2b42bf0
...
...
@@ -69,6 +69,23 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle(
size_t
K
=
IC
*
FH
*
FW
;
size_t
N
=
OH
*
OW
;
#if MGB_ENABLE_DOT
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
part2 = megdnn::matmul::GemmInterleaved< \
matmul::gemm_##_gemm##_##_bias##_##_nonline>( \
M, N, K, false, false, strategy) \
.get_workspace_size();
if
(
cpuinfo_has_arm_neon_dot
())
{
DISPATCH_GEMM_BIAS
(
u8_8x8_dot
,
1
);
}
else
{
DISPATCH_GEMM_BIAS
(
u8_8x8_nodot
,
0
);
}
#else
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
...
...
@@ -82,8 +99,8 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle(
.get_workspace_size(); \
} \
MIDOUT_END()
DISPATCH_GEMM_BIAS
(
u8_8x8
,
0
)
DISPATCH_GEMM_BIAS
(
u8_8x8_nodot
,
0
)
#endif
#undef DISPATCH_GEMM_STRATEGY
}
return
{
nullptr
,
{
part0
,
part1
,
part2
}};
...
...
@@ -157,6 +174,23 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param,
size_t
K
=
IC
*
FH
*
FW
;
size_t
N
=
OH
*
OW
;
#if MGB_ENABLE_DOT
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved< \
matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias);
if
(
cpuinfo_has_arm_neon_dot
())
{
DISPATCH_GEMM_BIAS
(
u8_8x8_dot
,
1
)
}
else
{
DISPATCH_GEMM_BIAS
(
u8_8x8_nodot
,
0
)
}
#else
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
...
...
@@ -172,7 +206,9 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param,
} \
MIDOUT_END()
DISPATCH_GEMM_BIAS
(
u8_8x8
,
0
)
DISPATCH_GEMM_BIAS
(
u8_8x8_nodot
,
0
)
#endif
#undef DISPATCH_GEMM_STRATEGY
}
}
...
...
dnn/src/aarch64/conv_bias/quint8/strategy.cpp
浏览文件 @
f2b42bf0
...
...
@@ -23,12 +23,12 @@ using namespace aarch64;
using
namespace
aarch64
::
matmul
;
namespace
impl
{
template
<
BiasMode
bmode
,
typename
Op
,
int
block_m
,
int
block_n
>
template
<
BiasMode
bmode
,
typename
Op
,
int
block_m
,
int
block_n
,
bool
dot
>
struct
KernCaller
;
#if
__ARM_FEATURE_DOTPROD
#if
MGB_ENABLE_DOT
template
<
BiasMode
bmode
,
typename
Op
>
struct
KernCaller
<
bmode
,
Op
,
8
,
8
>
{
struct
KernCaller
<
bmode
,
Op
,
8
,
8
,
true
>
{
static
void
run
(
const
dt_uint8
*
packA
,
const
dt_uint8
*
packB
,
size_t
M
,
size_t
N
,
size_t
K
,
dt_uint8
*
C
,
size_t
LDC
,
bool
is_first_k
,
Op
op
,
const
dt_int32
*
bias
,
...
...
@@ -120,10 +120,10 @@ struct KernCaller<bmode, Op, 8, 8> {
}
};
#e
lse
#e
ndif
template
<
BiasMode
bmode
,
typename
Op
>
struct
KernCaller
<
bmode
,
Op
,
8
,
8
>
{
struct
KernCaller
<
bmode
,
Op
,
8
,
8
,
false
>
{
static
void
run
(
const
dt_uint8
*
packA
,
const
dt_uint8
*
packB
,
size_t
M
,
size_t
N
,
size_t
K
,
dt_uint8
*
C
,
size_t
LDC
,
bool
is_first_k
,
Op
op
,
const
dt_int32
*
bias
,
...
...
@@ -215,13 +215,11 @@ struct KernCaller<bmode, Op, 8, 8> {
}
};
#endif
}
// namespace impl
#if
__ARM_FEATURE_DOTPROD
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gemm_u8_8x8_nobias_identity
)
#if
MGB_ENABLE_DOT
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gemm_u8_8x8_
dot_
nobias_identity
)
void
gemm_u8_8x8_nobias_identity
::
pack_A
(
uint8_t
*
outptr
,
const
uint8_t
*
inptr
,
void
gemm_u8_8x8_
dot_
nobias_identity
::
pack_A
(
uint8_t
*
outptr
,
const
uint8_t
*
inptr
,
int
ldin
,
int
y0
,
int
ymax
,
int
k0
,
int
kmax
,
bool
transpose
)
const
{
if
(
transpose
)
{
...
...
@@ -233,7 +231,7 @@ void gemm_u8_8x8_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr,
}
}
void
gemm_u8_8x8_nobias_identity
::
pack_B
(
uint8_t
*
out
,
const
uint8_t
*
in
,
void
gemm_u8_8x8_
dot_
nobias_identity
::
pack_B
(
uint8_t
*
out
,
const
uint8_t
*
in
,
int
ldin
,
int
x0
,
int
xmax
,
int
k0
,
int
kmax
,
bool
transpose
)
const
{
if
(
transpose
)
{
...
...
@@ -245,10 +243,13 @@ void gemm_u8_8x8_nobias_identity::pack_B(uint8_t* out, const uint8_t* in,
}
}
#else
size_t
gemm_u8_8x8_dot_nobias_identity
::
get_workspace_size
()
const
{
return
8
*
8
*
sizeof
(
dt_int32
);
}
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gemm_u8_8x8_nobias_identity
)
void
gemm_u8_8x8_nobias_identity
::
pack_A
(
dt_uint8
*
outptr
,
#endif
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gemm_u8_8x8_nodot_nobias_identity
)
void
gemm_u8_8x8_nodot_nobias_identity
::
pack_A
(
dt_uint8
*
outptr
,
const
dt_uint8
*
inptr
,
int
ldin
,
int
y0
,
int
ymax
,
int
k0
,
int
kmax
,
bool
transpose
)
const
{
...
...
@@ -262,7 +263,7 @@ void gemm_u8_8x8_nobias_identity::pack_A(dt_uint8* outptr,
}
}
void
gemm_u8_8x8_nobias_identity
::
pack_B
(
dt_uint8
*
out
,
const
dt_uint8
*
in
,
void
gemm_u8_8x8_no
dot_no
bias_identity
::
pack_B
(
dt_uint8
*
out
,
const
dt_uint8
*
in
,
int
ldin
,
int
x0
,
int
xmax
,
int
k0
,
int
kmax
,
bool
transpose
)
const
{
uint8_t
zB
=
B_dtype
.
param
<
dtype
::
Quantized8Asymm
>
().
zero_point
;
...
...
@@ -275,43 +276,52 @@ void gemm_u8_8x8_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in,
}
}
#endif
size_t
gemm_u8_8x8_nobias_identity
::
get_workspace_size
()
const
{
size_t
gemm_u8_8x8_nodot_nobias_identity
::
get_workspace_size
()
const
{
return
8
*
8
*
sizeof
(
dt_int32
);
}
#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \
void gemm_u8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \
const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, \
size_t K, dt_uint8* C, size_t LDC, bool is_first_k, \
const dt_int32* bias, dt_int32* workspace) const { \
float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point; \
float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point; \
float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point; \
DEFINE_OP(_OP); \
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \
workspace, zp_A, zp_B); \
#define KERN(_block_m, _block_n, _dot, _suffix, _bias, _BIAS, _nonline, \
_OP) \
void gemm_u8_##_block_m##x##_block_n##_suffix##_##_bias##_##_nonline:: \
kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, \
size_t N, size_t K, dt_uint8* C, size_t LDC, bool is_first_k, \
const dt_int32* bias, dt_int32* workspace) const { \
float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point; \
float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point; \
float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point; \
DEFINE_OP(_OP); \
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n, _dot>::run( \
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \
workspace, zp_A, zp_B); \
}
#define DEFINE_OP(_Op) \
arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, scale_C, zp_C);
KERN
(
8
,
8
,
nobias
,
BiasMode
::
NO_BIAS
,
identity
,
TypeCvtOp
)
KERN
(
8
,
8
,
nobias
,
BiasMode
::
NO_BIAS
,
relu
,
ReluOp
)
KERN
(
8
,
8
,
nobias
,
BiasMode
::
NO_BIAS
,
hswish
,
HSwishOp
)
#if MGB_ENABLE_DOT
KERN
(
8
,
8
,
true
,
_dot
,
nobias
,
BiasMode
::
NO_BIAS
,
identity
,
TypeCvtOp
)
KERN
(
8
,
8
,
true
,
_dot
,
nobias
,
BiasMode
::
NO_BIAS
,
relu
,
ReluOp
)
KERN
(
8
,
8
,
true
,
_dot
,
nobias
,
BiasMode
::
NO_BIAS
,
hswish
,
HSwishOp
)
#endif
KERN
(
8
,
8
,
false
,
_nodot
,
nobias
,
BiasMode
::
NO_BIAS
,
identity
,
TypeCvtOp
)
KERN
(
8
,
8
,
false
,
_nodot
,
nobias
,
BiasMode
::
NO_BIAS
,
relu
,
ReluOp
)
KERN
(
8
,
8
,
false
,
_nodot
,
nobias
,
BiasMode
::
NO_BIAS
,
hswish
,
HSwishOp
)
#undef DEFINE_OP
#define DEFINE_OP(_Op) \
arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, \
scale_A* scale_B, scale_C, zp_C);
KERN
(
8
,
8
,
bias_channel
,
BiasMode
::
BROADCAST_CHANNEL_BIAS
,
identity
,
AddOp
)
KERN
(
8
,
8
,
bias_channel
,
BiasMode
::
BROADCAST_CHANNEL_BIAS
,
relu
,
FuseAddReluOp
)
KERN
(
8
,
8
,
bias_channel
,
BiasMode
::
BROADCAST_CHANNEL_BIAS
,
hswish
,
FuseAddHSwishOp
)
#if MGB_ENABLE_DOT
KERN
(
8
,
8
,
true
,
_dot
,
bias_channel
,
BiasMode
::
BROADCAST_CHANNEL_BIAS
,
identity
,
AddOp
)
KERN
(
8
,
8
,
true
,
_dot
,
bias_channel
,
BiasMode
::
BROADCAST_CHANNEL_BIAS
,
relu
,
FuseAddReluOp
)
KERN
(
8
,
8
,
true
,
_dot
,
bias_channel
,
BiasMode
::
BROADCAST_CHANNEL_BIAS
,
hswish
,
FuseAddHSwishOp
)
#endif
KERN
(
8
,
8
,
false
,
_nodot
,
bias_channel
,
BiasMode
::
BROADCAST_CHANNEL_BIAS
,
identity
,
AddOp
)
KERN
(
8
,
8
,
false
,
_nodot
,
bias_channel
,
BiasMode
::
BROADCAST_CHANNEL_BIAS
,
relu
,
FuseAddReluOp
)
KERN
(
8
,
8
,
false
,
_nodot
,
bias_channel
,
BiasMode
::
BROADCAST_CHANNEL_BIAS
,
hswish
,
FuseAddHSwishOp
)
#undef DEFINE_OP
#undef KERN
...
...
dnn/src/aarch64/conv_bias/quint8/strategy.h
浏览文件 @
f2b42bf0
...
...
@@ -15,30 +15,46 @@ namespace megdnn {
namespace
aarch64
{
namespace
matmul
{
#if
__ARM_FEATURE_DOTPROD
#if
MGB_ENABLE_DOT
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK
(
dt_uint8
,
dt_uint8
,
dt_int32
,
8
,
8
,
4
,
false
,
true
,
gemm_u8_8x8_nobias_identity
);
#else
gemm_u8_8x8_dot_nobias_identity
);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER
(
gemm_u8_8x8_dot_nobias_relu
,
gemm_u8_8x8_dot_nobias_identity
);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER
(
gemm_u8_8x8_dot_nobias_hswish
,
gemm_u8_8x8_dot_nobias_identity
);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER
(
gemm_u8_8x8_dot_bias_channel_identity
,
gemm_u8_8x8_dot_nobias_identity
);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER
(
gemm_u8_8x8_dot_bias_channel_relu
,
gemm_u8_8x8_dot_nobias_identity
);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER
(
gemm_u8_8x8_dot_bias_channel_hswish
,
gemm_u8_8x8_dot_nobias_identity
);
#endif
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK
(
dt_uint8
,
dt_uint8
,
dt_int32
,
8
,
8
,
8
,
false
,
true
,
gemm_u8_8x8_nobias_identity
);
#endif
gemm_u8_8x8_nodot_nobias_identity
);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER
(
gemm_u8_8x8_nobias_relu
,
gemm_u8_8x8_nobias_identity
);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER
(
gemm_u8_8x8_no
dot_no
bias_relu
,
gemm_u8_8x8_no
dot_no
bias_identity
);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER
(
gemm_u8_8x8_nobias_hswish
,
gemm_u8_8x8_nobias_identity
);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER
(
gemm_u8_8x8_no
dot_no
bias_hswish
,
gemm_u8_8x8_no
dot_no
bias_identity
);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER
(
gemm_u8_8x8_bias_channel_identity
,
gemm_u8_8x8_nobias_identity
);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER
(
gemm_u8_8x8_
nodot_
bias_channel_identity
,
gemm_u8_8x8_no
dot_no
bias_identity
);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER
(
gemm_u8_8x8_bias_channel_relu
,
gemm_u8_8x8_nobias_identity
);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER
(
gemm_u8_8x8_
nodot_
bias_channel_relu
,
gemm_u8_8x8_no
dot_no
bias_identity
);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER
(
gemm_u8_8x8_bias_channel_hswish
,
gemm_u8_8x8_nobias_identity
);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER
(
gemm_u8_8x8_
nodot_
bias_channel_hswish
,
gemm_u8_8x8_no
dot_no
bias_identity
);
}
// namespace matmul
...
...
dnn/src/aarch64/matrix_mul/algos.cpp
浏览文件 @
f2b42bf0
...
...
@@ -24,9 +24,6 @@
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_impl.h"
#if MGB_ENABLE_CPUINFO
#include "cpuinfo.h"
#endif
#include "midout.h"
MIDOUT_DECL
(
megdnn_aarch64_matmul_kern
)
...
...
@@ -394,7 +391,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_8x8::get_kern(
#endif
#if
__ARM_FEATURE_DOTPROD
#if
MGB_ENABLE_DOT
/* ==================== Int8x8x32 K8x12x4 Dotprod algo ==================== */
namespace
{
void
int8x8x32_k8x12x4_dotprod_kern
(
...
...
@@ -422,6 +419,9 @@ void int8x8x32_k8x12x4_dotprod_kern(
bool
MatrixMulImpl
::
AlgoInt8x8x32K8x12x4DotProd
::
usable
(
const
KernSizeParam
&
kern_size_param
)
const
{
if
(
!
cpuinfo_has_arm_neon_dot
()){
return
false
;
}
return
can_be_treated_as_int8x8x32
(
kern_size_param
);
}
...
...
@@ -484,6 +484,11 @@ void int8x8x32_mk4_8x12x4_dotprod_kern(
bool
MatrixMulImpl
::
AlgoInt8x8x32MK4_8x12x4DotProd
::
usable
(
const
KernSizeParam
&
kern_size_param
)
const
{
if
(
!
cpuinfo_has_arm_neon_dot
()){
return
false
;
}
return
kern_size_param
.
A_type
.
enumv
()
==
kern_size_param
.
B_type
.
enumv
()
&&
(
kern_size_param
.
A_type
.
enumv
()
==
DTypeEnum
::
Int8
||
kern_size_param
.
A_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
)
&&
...
...
@@ -527,7 +532,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x12x4DotProd,
aarch64
::
matmul
::
gemm_mk4_s8_8x12
,
int8_t
,
int32_t
,
AlgoDataType
::
QINT8X8X32
,
MK4_DOT
);
#e
lse
#e
ndif
/* ===================== Int8x8x32 MK4 4x4x16 algo ===================== */
namespace
{
...
...
@@ -727,7 +732,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x8x8,
aarch64
::
matmul
::
gemm_s8_8x8
,
int8_t
,
int32_t
,
AlgoDataType
::
QINT8X8X32
,
DEFAULT
);
#endif
/* ===================== Int8x8x16 K8x8x8 algo ===================== */
namespace
{
...
...
@@ -1151,7 +1155,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_kern(
return
kern_mk8_8x8
;
}
#if
__ARM_FEATURE_DOTPROD
#if
MGB_ENABLE_DOT
/* ==================== Quint8 K8x8x4 Dotprod algo ==================== */
namespace
{
void
quint8_k8x8x4_dotprod_kern
(
const
MatrixMulImpl
::
KernParam
&
kern_param
)
{
...
...
@@ -1166,8 +1170,8 @@ void quint8_k8x8x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
Bptr
=
kern_param
.
B
<
dt_uint8
>
();
auto
Cptr
=
kern_param
.
C
<
dt_int32
>
();
aarch64
::
matmul
::
gemm_u8_8x8
strategy
(
M
,
N
,
K
,
A_type
,
B_type
,
C_type
);
megdnn
::
matmul
::
GemmInterleaved
<
aarch64
::
matmul
::
gemm_u8_8x8
>
(
aarch64
::
matmul
::
gemm_u8_8x8
_dot
strategy
(
M
,
N
,
K
,
A_type
,
B_type
,
C_type
);
megdnn
::
matmul
::
GemmInterleaved
<
aarch64
::
matmul
::
gemm_u8_8x8
_dot
>
(
M
,
N
,
K
,
trA
,
trB
,
strategy
)
.
execute
(
Aptr
,
LDA
,
Bptr
,
LDB
,
Cptr
,
LDC
,
kern_param
.
workspace_ptr
);
...
...
@@ -1178,6 +1182,9 @@ void quint8_k8x8x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
bool
MatrixMulImpl
::
AlgoQuint8K8x8x4DotProd
::
usable
(
const
KernSizeParam
&
kern_size_param
)
const
{
if
(
!
cpuinfo_has_arm_neon_dot
()){
return
false
;
}
return
kern_size_param
.
A_type
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
&&
kern_size_param
.
B_type
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
&&
kern_size_param
.
C_type
.
enumv
()
==
DTypeEnum
::
QuantizedS32
&&
...
...
@@ -1195,8 +1202,8 @@ size_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_workspace(
auto
A_type
=
kern_size_param
.
A_type
,
B_type
=
kern_size_param
.
B_type
,
C_type
=
kern_size_param
.
C_type
;
aarch64
::
matmul
::
gemm_u8_8x8
strategy
(
M
,
N
,
K
,
A_type
,
B_type
,
C_type
);
return
megdnn
::
matmul
::
GemmInterleaved
<
aarch64
::
matmul
::
gemm_u8_8x8
>
(
aarch64
::
matmul
::
gemm_u8_8x8
_dot
strategy
(
M
,
N
,
K
,
A_type
,
B_type
,
C_type
);
return
megdnn
::
matmul
::
GemmInterleaved
<
aarch64
::
matmul
::
gemm_u8_8x8
_dot
>
(
M
,
N
,
K
,
trA
,
trB
,
strategy
)
.
get_workspace_size
();
}
...
...
@@ -1212,7 +1219,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_kern(
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL
(
AlgoQuint8K8x8x4DotProd
,
megdnn_aarch64_matmul_kern
,
"AlgoQuint8K8x8x4DotProdImpl"
_hash
,
aarch64
::
matmul
::
gemm_u8_8x8
,
uint8_t
,
aarch64
::
matmul
::
gemm_u8_8x8
_dot
,
uint8_t
,
int32_t
,
AlgoDataType
::
QUINT8X8X32
,
DEFAULT
);
/* ===================== Quint8 Gemv DotProd algo ===================== */
...
...
@@ -1238,6 +1245,9 @@ void quint8_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
bool
MatrixMulImpl
::
AlgoQuint8GemvDotProd
::
usable
(
const
KernSizeParam
&
kern_size_param
)
const
{
if
(
!
cpuinfo_has_arm_neon_dot
()){
return
false
;
}
return
kern_size_param
.
A_type
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
&&
kern_size_param
.
B_type
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
&&
kern_size_param
.
C_type
.
enumv
()
==
DTypeEnum
::
QuantizedS32
&&
...
...
@@ -1257,7 +1267,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8GemvDotProd::get_kern(
const
KernSizeParam
&
)
const
{
return
quint8_gemv_dotprod_kern
;
}
#e
lse
#e
ndif
/* ===================== Quint8 K8x8x8 algo ===================== */
namespace
{
...
...
@@ -1322,7 +1332,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x8,
aarch64
::
matmul
::
gemm_u8_8x8
,
uint8_t
,
int32_t
,
AlgoDataType
::
QUINT8X8X32
,
DEFAULT
);
#endif
/* ===================== Int8x8x16 K8x8x8 algo ===================== */
namespace
{
...
...
dnn/src/aarch64/matrix_mul/algos.h
浏览文件 @
f2b42bf0
...
...
@@ -111,7 +111,7 @@ public:
#endif
#if
__ARM_FEATURE_DOTPROD
#if
MGB_ENABLE_DOT
class
MatrixMulImpl
::
AlgoInt8x8x32K8x12x4DotProd
final
:
public
AlgoBase
{
public:
AlgoAttribute
attribute
()
const
override
{
...
...
@@ -141,7 +141,7 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
MEGDNN_DECL_ALGO_TYPE
(
AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD
)
};
#e
lse
#e
ndif
class
MatrixMulImpl
::
AlgoInt8x8x32MK4_4x4x16
final
:
public
AlgoBase
{
public:
...
...
@@ -187,7 +187,6 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
MEGDNN_DECL_ALGO_TYPE
(
AARCH64_INT8X8X32_K8X8X8
)
};
#endif
class
MatrixMulImpl
::
AlgoInt8x8x16K8x8x8
final
:
public
AlgoBase
{
public:
...
...
@@ -313,7 +312,7 @@ public:
MEGDNN_DECL_ALGO_TYPE
(
AARCH64_INT16X16X32_MK8_8X8
)
};
#if
__ARM_FEATURE_DOTPROD
#if
MGB_ENABLE_DOT
class
MatrixMulImpl
::
AlgoQuint8K8x8x4DotProd
final
:
public
AlgoBase
{
public:
AlgoAttribute
attribute
()
const
override
{
...
...
@@ -328,7 +327,6 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
MEGDNN_DECL_ALGO_TYPE
(
AARCH64_QUINT8_K8X8X4_DOTPROD
)
};
class
MatrixMulImpl
::
AlgoQuint8GemvDotProd
final
:
public
AlgoBase
{
public:
AlgoAttribute
attribute
()
const
override
{
...
...
@@ -344,8 +342,7 @@ public:
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
,
AlgoDataType
::
QUINT8X8X32
,
DEFAULT
)
MEGDNN_DECL_ALGO_TYPE
(
AARCH64_QUINT8_GEMV_DOTPROD
)
};
#else
#endif
class
MatrixMulImpl
::
AlgoQuint8K8x8x8
final
:
public
AlgoBase
{
public:
AlgoAttribute
attribute
()
const
override
{
...
...
@@ -358,7 +355,6 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
MEGDNN_DECL_ALGO_TYPE
(
AARCH64_QUINT8_K8X8X8
)
};
#endif
}
// namespace aarch64
}
// namespace megdnn
...
...
dnn/src/aarch64/matrix_mul/fp32/strategy.cpp
浏览文件 @
f2b42bf0
...
...
@@ -20,9 +20,6 @@
#include "src/aarch64/matrix_mul/fp32/strategy.h"
#include "src/common/utils.h"
#if MGB_ENABLE_CPUINFO
#include "cpuinfo.h"
#endif
using
namespace
megdnn
;
using
namespace
aarch64
;
...
...
dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h
浏览文件 @
f2b42bf0
...
...
@@ -9,7 +9,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#if !(__ARM_FEATURE_DOTPROD)
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
...
...
@@ -851,6 +850,5 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
}
// namespace matmul_4x4x16
}
// namespace aarch64
}
// namespace megdnn
#endif
// vim: syntax=cpp.doxygen
dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h
浏览文件 @
f2b42bf0
...
...
@@ -9,7 +9,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#if !(__ARM_FEATURE_DOTPROD)
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
...
...
@@ -1372,4 +1371,3 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr,
}
// namespace megdnn
// vim: syntax=cpp.doxygen
#endif
dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h
浏览文件 @
f2b42bf0
...
...
@@ -10,8 +10,6 @@
* implied.
*/
#include <cstring>
#if !(__ARM_FEATURE_DOTPROD)
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
...
...
@@ -887,6 +885,5 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin,
}
// namespace matmul_4x4x16
}
// namespace aarch64
}
// namespace megdnn
#endif
// vim: syntax=cpp.doxygen
dnn/src/aarch64/matrix_mul/int8/strategy.cpp
浏览文件 @
f2b42bf0
...
...
@@ -9,7 +9,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#if !(__ARM_FEATURE_DOTPROD)
#include "src/aarch64/matrix_mul/int8/strategy.h"
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/int8/kernel_4x4x16.h"
...
...
@@ -105,7 +104,6 @@ void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
packA
+=
K4
;
}
}
///////////////////////// gemm_mk4_s8_4x4 ////////////////////////////////////
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gemm_mk4_s8_4x4
);
...
...
@@ -258,6 +256,5 @@ void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
packA
+=
K4
;
}
}
#endif
// vim: syntax=cpp.doxygen
dnn/src/aarch64/matrix_mul/int8/strategy.h
浏览文件 @
f2b42bf0
...
...
@@ -10,7 +10,6 @@
*/
#pragma once
#if !(__ARM_FEATURE_DOTPROD)
#include "src/fallback/matrix_mul/gemm_common.h"
namespace
megdnn
{
...
...
@@ -30,5 +29,4 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true,
}
// namespace aarch64
}
// namespace megdnn
#endif
// vim: syntax=cpp.doxygen
dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h
浏览文件 @
f2b42bf0
...
...
@@ -9,8 +9,7 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#if __ARM_FEATURE_DOTPROD
#if MGB_ENABLE_DOT
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
...
...
@@ -50,7 +49,9 @@ namespace matmul_8x12x4 {
* same, I test in kirin980 with small and big core, here i just keep both the
* implementation.
*/
#if 1
MEGDNN_ATTRIBUTE_TARGET
(
"dotprod"
)
static
void
kern_8x12
(
const
int8_t
*
packA
,
const
int8_t
*
packB
,
int
K
,
int32_t
*
output
,
int
LDC
,
bool
is_first_k
)
{
K
/=
4
;
...
...
@@ -408,6 +409,7 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
);
}
#else
MEGDNN_ATTRIBUTE_TARGET
(
"dotprod"
)
static
void
kern_8x12
(
const
int8_t
*
packA
,
const
int8_t
*
packB
,
int
K
,
int32_t
*
output
,
int
LDC
,
bool
is_first_k
)
{
K
/=
4
;
...
...
@@ -650,7 +652,7 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
// +-------+-------+ - - - - +--------+--------+--------+
//
// Accumulator
MEGDNN_ATTRIBUTE_TARGET
(
"dotprod"
)
static
void
kern_4x12
(
const
int8_t
*
packA
,
const
int8_t
*
packB
,
int
K
,
int32_t
*
output
,
int
LDC
,
bool
is_first_k
,
int
m_remain
)
{
K
/=
4
;
...
...
@@ -837,7 +839,7 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K,
// +-------+-------+ - - - - +---------+
//
// Accumulator
MEGDNN_ATTRIBUTE_TARGET
(
"dotprod"
)
static
void
kern_8x4
(
const
int8_t
*
packA
,
const
int8_t
*
packB
,
int
K
,
int32_t
*
output
,
int
LDC
,
bool
is_first_k
,
int
n_remain
)
{
K
/=
4
;
...
...
@@ -1038,7 +1040,7 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
// +-------+-------+ - - - - +--------+
//
// Accumulator
MEGDNN_ATTRIBUTE_TARGET
(
"dotprod"
)
static
void
kern_4x4
(
const
int8_t
*
packA
,
const
int8_t
*
packB
,
int
K
,
int32_t
*
output
,
int
LDC
,
bool
is_first_k
,
int
m_remain
,
int
n_remain
)
{
...
...
dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h
浏览文件 @
f2b42bf0
...
...
@@ -10,8 +10,7 @@
* implied.
*/
#if __ARM_FEATURE_DOTPROD
#if MGB_ENABLE_DOT
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
...
...
@@ -40,6 +39,7 @@ namespace matmul_mk4_8x12x4 {
//
// Accumulator
MEGDNN_ATTRIBUTE_TARGET
(
"dotprod"
)
static
void
kern_8x12
(
const
int8_t
*
packA
,
const
int8_t
*
packB
,
int
K
,
int32_t
*
output
,
int
LDC
,
bool
is_first_k
)
{
K
/=
4
;
...
...
@@ -60,7 +60,6 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
int32_t
*
outptr0
=
output
;
int32_t
*
outptr1
;
asm
volatile
(
// load accumulator C
"add %[outptr1], %[outptr0], %x[LDC]
\n
"
...
...
@@ -397,6 +396,7 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
//
// Accumulator
MEGDNN_ATTRIBUTE_TARGET
(
"dotprod"
)
static
void
kern_4x12
(
const
int8_t
*
packA
,
const
int8_t
*
packB
,
int
K
,
int32_t
*
output
,
int
LDC
,
bool
is_first_k
)
{
K
/=
4
;
...
...
@@ -543,6 +543,7 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K,
// +--------+--------+ - - - - +------------+
// Accumulator
MEGDNN_ATTRIBUTE_TARGET
(
"dotprod"
)
static
void
kern_8x4
(
const
int8_t
*
packA
,
const
int8_t
*
packB
,
int
K
,
int32_t
*
output
,
int
LDC
,
bool
is_first_k
,
int
n_remain
)
{
K
/=
4
;
...
...
@@ -718,6 +719,7 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
// +--------+--------+ - - - - +------------+
// Accumulator
MEGDNN_ATTRIBUTE_TARGET
(
"dotprod"
)
static
void
kern_4x4
(
const
int8_t
*
packA
,
const
int8_t
*
packB
,
int
K
,
int32_t
*
output
,
int
LDC
,
bool
is_first_k
,
int
n_remain
)
{
K
/=
4
;
...
...
@@ -928,6 +930,5 @@ static void gemm_mk4_s8_8x12_pack_B(dt_int8* out, const dt_int8* in, int ldin,
}
// namespace matmul_mk4_8x12x4
}
// namespace aarch64
}
// namespace megdnn
#endif
// vim: syntax=cpp.doxygen
dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp
浏览文件 @
f2b42bf0
...
...
@@ -10,13 +10,13 @@
*/
#include "src/aarch64/matrix_mul/int8_dot/strategy.h"
#if MGB_ENABLE_DOT
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h"
#include "src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h"
#if __ARM_FEATURE_DOTPROD
using
namespace
megdnn
;
using
namespace
aarch64
;
using
namespace
aarch64
::
matmul
;
...
...
dnn/src/aarch64/matrix_mul/int8_dot/strategy.h
浏览文件 @
f2b42bf0
...
...
@@ -11,7 +11,7 @@
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"
#if
__ARM_FEATURE_DOTPROD
#if
MGB_ENABLE_DOT
namespace
megdnn
{
namespace
aarch64
{
namespace
matmul
{
...
...
dnn/src/aarch64/matrix_mul/opr_impl.cpp
浏览文件 @
f2b42bf0
...
...
@@ -27,14 +27,13 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoF16K8x24x1
f16_k8x24x1
;
AlgoF16MK8_8x8
f16_mk8_8x8
;
#endif
#if
__ARM_FEATURE_DOTPROD
#if
MGB_ENABLE_DOT
AlgoInt8x8x32K8x12x4DotProd
int8x8x32_k8x12x4_dotprod
;
AlgoInt8x8x32MK4_8x12x4DotProd
int8x8x32_mk4_8x12x4_dotprod
;
#e
lse
#e
ndif
AlgoInt8x8x32MK4_4x4x16
int8x8x32_mk4_4x4x16
;
AlgoInt8x8x32K4x4x16
int8x8x32_k4x4x16
;
AlgoInt8x8x32K8x8x8
int8x8x32_k8x8x8
;
#endif
AlgoInt8x8x16K8x8x8
int8x8x16_k8x8x8
;
AlgoInt8x8x16K4x4x16
int8x8x16_k4x4x16
;
AlgoInt8x8x16MK4_16x12x4
int8x8x16_mk4_16x12x4
;
...
...
@@ -44,12 +43,11 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoInt16x16x32K12x8x1
int16x16x32_k12x8x1
;
AlgoInt16x16x32MK8_8x8
int16x16x32_mk8_8x8
;
#if
__ARM_FEATURE_DOTPROD
#if
MGB_ENABLE_DOT
AlgoQuint8K8x8x4DotProd
quint8_k8x8x4_dotprod
;
AlgoQuint8GemvDotProd
quint8_gemv_dotprod
;
#else
AlgoQuint8K8x8x8
quint8_k8x8x8
;
#endif
AlgoQuint8K8x8x8
quint8_k8x8x8
;
AlgoInt4x4x16K8x8x8
int4x4x16_k8x8x8
;
SmallVector
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
m_all_algos
;
...
...
@@ -66,14 +64,13 @@ public:
m_all_algos
.
emplace_back
(
&
f16_k8x24x1
);
m_all_algos
.
emplace_back
(
&
f16_mk8_8x8
);
#endif
#if
__ARM_FEATURE_DOTPROD
#if
MGB_ENABLE_DOT
m_all_algos
.
emplace_back
(
&
int8x8x32_k8x12x4_dotprod
);
m_all_algos
.
emplace_back
(
&
int8x8x32_mk4_8x12x4_dotprod
);
#e
lse
#e
ndif
m_all_algos
.
emplace_back
(
&
int8x8x32_k4x4x16
);
m_all_algos
.
emplace_back
(
&
int8x8x32_k8x8x8
);
m_all_algos
.
emplace_back
(
&
int8x8x32_mk4_4x4x16
);
#endif
m_all_algos
.
emplace_back
(
&
int8x8x16_k4x4x16
);
m_all_algos
.
emplace_back
(
&
int8x8x16_k8x8x8
);
m_all_algos
.
emplace_back
(
&
int8x8x16_mk4_k8x8x8
);
...
...
@@ -82,12 +79,11 @@ public:
m_all_algos
.
emplace_back
(
&
int16x16x32_k12x8x1
);
m_all_algos
.
emplace_back
(
&
int16x16x32_mk8_8x8
);
#if
__ARM_FEATURE_DOTPROD
#if
MGB_ENABLE_DOT
m_all_algos
.
emplace_back
(
&
quint8_gemv_dotprod
);
m_all_algos
.
emplace_back
(
&
quint8_k8x8x4_dotprod
);
#else
m_all_algos
.
emplace_back
(
&
quint8_k8x8x8
);
#endif
m_all_algos
.
emplace_back
(
&
quint8_k8x8x8
);
m_all_algos
.
emplace_back
(
&
int4x4x16_k8x8x8
);
for
(
auto
&&
algo
:
m_all_algos
)
{
...
...
dnn/src/aarch64/matrix_mul/opr_impl.h
浏览文件 @
f2b42bf0
...
...
@@ -41,16 +41,15 @@ private:
class
AlgoF16MK8_8x8
;
// Aarch64 F16 Format MK8 block 16x8
#endif
#if
__ARM_FEATURE_DOTPROD
#if
MGB_ENABLE_DOT
class
AlgoInt8x8x32K8x12x4DotProd
;
// Aarch64 Int8x8x32 Kernel
// 8x12x4 DotProduct
class
AlgoInt8x8x32MK4_8x12x4DotProd
;
// Aarch64 nchw44 Int8x8x32 Kernel
// 8x12x4 DotProduct
#e
lse
#e
ndif
class
AlgoInt8x8x32MK4_4x4x16
;
// Aarch64 nchw44 Int8x8x32 Kernel 4x4x16
class
AlgoInt8x8x32K4x4x16
;
// Aarch64 Int8x8x32 Kernel 4x4x16
class
AlgoInt8x8x32K8x8x8
;
// Aarch64 Int8x8x32 Kernel 8x8x8
#endif
class
AlgoInt8x8x16K8x8x8
;
// Aarch64 Int8x8x16 Kernel 8x8x8
class
AlgoInt8x8x16K4x4x16
;
// Aarch64 Int8x8x16 Kernel 4x4x16
class
AlgoInt8x8x16MK4_16x12x4
;
// Aarch64 Int8x8x16 Kernel 16x12x16
...
...
@@ -59,13 +58,12 @@ private:
class
AlgoInt16x16x32K12x8x1
;
// Aarch64 Int16x16x32 Kernel 12x8x1
class
AlgoInt16x16x32MK8_8x8
;
// Aarch64 Int16x16x32 Format MK8 block 8x8
#if
__ARM_FEATURE_DOTPROD
#if
MGB_ENABLE_DOT
class
AlgoQuint8K8x8x4DotProd
;
// Aarch64 Quint8 Kernel
// 8x8x4 DotProduct
class
AlgoQuint8GemvDotProd
;
// Aarch64 Quint8 Gemv DotProduct
#else
class
AlgoQuint8K8x8x8
;
// Aarch64 Quint8 Kernel 8x8x8
#endif
class
AlgoQuint8K8x8x8
;
// Aarch64 Quint8 Kernel 8x8x8
class
AlgoInt8x8x16MK4_K8x8x8
;
// Aarch64 Int8x8x16 Kernel 4x4x16
class
AlgoInt4x4x16K8x8x8
;
// Aarch64 Int4x4x16 Kernel 4x4x16
class
AlgoPack
;
...
...
dnn/src/aarch64/matrix_mul/quint8/kernel_8x8x8.h
浏览文件 @
f2b42bf0
...
...
@@ -9,7 +9,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#if !(__ARM_FEATURE_DOTPROD)
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
...
...
@@ -1395,4 +1394,3 @@ static void gemm_u8_8x8_transpose_pack_B_n(dt_uint8* outptr,
}
// namespace megdnn
// vim: syntax=cpp.doxygen
#endif
dnn/src/aarch64/matrix_mul/quint8/strategy.cpp
浏览文件 @
f2b42bf0
...
...
@@ -9,7 +9,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#if !(__ARM_FEATURE_DOTPROD)
#include "src/aarch64/matrix_mul/quint8/strategy.h"
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h"
...
...
@@ -108,6 +107,5 @@ void gemm_u8_8x8::kern(const dt_uint8* packA, const dt_uint8* packB, size_t M,
packA
+=
K4
;
}
}
#endif
// vim: syntax=cpp.doxygen
dnn/src/aarch64/matrix_mul/quint8/strategy.h
浏览文件 @
f2b42bf0
...
...
@@ -10,7 +10,6 @@
*/
#pragma once
#if !(__ARM_FEATURE_DOTPROD)
#include "src/fallback/matrix_mul/gemm_common.h"
namespace
megdnn
{
...
...
@@ -23,6 +22,5 @@ MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 8, 8, 8, false, true,
}
// namespace matmul
}
// namespace aarch64
}
// namespace megdnn
#endif
// vim: syntax=cpp.doxygen
dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp
浏览文件 @
f2b42bf0
...
...
@@ -10,15 +10,13 @@
*/
#include "src/aarch64/matrix_mul/quint8_dot/gemv.h"
#i
nclude <cstddef>
#i
f MGB_ENABLE_DOT
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/common/unroll_macro.h"
#if __ARM_FEATURE_DOTPROD
namespace
{
MEGDNN_ATTRIBUTE_TARGET
(
"dotprod"
)
void
gemv_naive_n
(
const
uint8_t
*
__restrict
A
,
const
uint8_t
*
__restrict
B
,
int32_t
*
__restrict
C
,
size_t
M
,
size_t
N
,
size_t
K
,
size_t
Astride
,
size_t
Bstride
,
size_t
Cstride
,
...
...
@@ -146,7 +144,6 @@ void gemv_naive_n(const uint8_t* __restrict A, const uint8_t* __restrict B,
acc
[
0
]
+
acc
[
1
]
+
acc
[
2
]
+
acc
[
3
]
+
zAB
-
acc_zA
-
acc_zB
;
}
}
}
// namespace
bool
megdnn
::
aarch64
::
matmul
::
is_gemv_like_preferred_quint8
(
...
...
@@ -171,7 +168,5 @@ void megdnn::aarch64::matmul::gemv_like_quint8(
return
gemv_naive_n
(
A
,
B
,
C
,
M
,
N
,
K
,
Astride
,
Bstride
,
Cstride
,
zero_point_A
,
zero_point_B
);
}
#endif
// vim: syntax=cpp.doxygen
dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h
浏览文件 @
f2b42bf0
...
...
@@ -10,10 +10,9 @@
*/
#pragma once
#include <cstddef>
#include <cstdint>
#include "src/common/utils.h"
#if
__ARM_FEATURE_DOTPROD
#if
MGB_ENABLE_DOT
namespace
megdnn
{
namespace
aarch64
{
namespace
matmul
{
...
...
dnn/src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h
浏览文件 @
f2b42bf0
...
...
@@ -9,8 +9,7 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#if __ARM_FEATURE_DOTPROD
#if MGB_ENABLE_DOT
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
...
...
@@ -56,7 +55,7 @@ namespace matmul_8x8x4 {
// C = sum((A - zA) * (B - zB)) = sum(A * B) - sum(A) * zB - sum(B) * zA + zA *
// zB * k
// A -> v27, v28 | B -> v29, v30 | zA * zB * k -> v26
MEGDNN_ATTRIBUTE_TARGET
(
"dotprod"
)
static
void
kern_8x8
(
const
uint8_t
*
packA
,
const
uint8_t
*
packB
,
int
K
,
int32_t
*
output
,
int
LDC
,
bool
is_first_k
,
uint8_t
zero_point_A
,
uint8_t
zero_point_B
,
uint32_t
zAB
)
{
...
...
@@ -293,6 +292,7 @@ static void kern_8x8(const uint8_t* packA, const uint8_t* packB, int K,
// zB * k
// A -> v28 | B -> v29, v30 | zA * zB * k -> v26
MEGDNN_ATTRIBUTE_TARGET
(
"dotprod"
)
static
void
kern_4x8
(
const
uint8_t
*
packA
,
const
uint8_t
*
packB
,
int
K
,
int32_t
*
output
,
int
LDC
,
bool
is_first_k
,
int
m_remain
,
uint8_t
zero_point_A
,
uint8_t
zero_point_B
,
uint32_t
zAB
)
{
...
...
@@ -495,6 +495,7 @@ static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K,
// zB * k
// A -> v27, v28 | B -> v29 | zA * zB * k -> v26
MEGDNN_ATTRIBUTE_TARGET
(
"dotprod"
)
static
void
kern_8x4
(
const
uint8_t
*
packA
,
const
uint8_t
*
packB
,
int
K
,
int32_t
*
output
,
int
LDC
,
bool
is_first_k
,
int
n_remain
,
uint8_t
zero_point_A
,
uint8_t
zero_point_B
,
uint32_t
zAB
)
{
...
...
@@ -733,6 +734,7 @@ static void kern_8x4(const uint8_t* packA, const uint8_t* packB, int K,
// zB * k
// A -> v28 | B -> v29 | zA * zB * k -> v26
MEGDNN_ATTRIBUTE_TARGET
(
"dotprod"
)
static
void
kern_4x4
(
const
uint8_t
*
packA
,
const
uint8_t
*
packB
,
int
K
,
int32_t
*
output
,
int
LDC
,
bool
is_first_k
,
int
m_remain
,
int
n_remain
,
uint8_t
zero_point_A
,
uint8_t
zero_point_B
,
...
...
dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp
浏览文件 @
f2b42bf0
...
...
@@ -16,14 +16,14 @@
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#if
__ARM_FEATURE_DOTPROD
#if
MGB_ENABLE_DOT
using
namespace
megdnn
;
using
namespace
aarch64
;
using
namespace
aarch64
::
matmul
;
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gemm_u8_8x8
);
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gemm_u8_8x8
_dot
);
void
gemm_u8_8x8
::
pack_A
(
uint8_t
*
outptr
,
const
uint8_t
*
inptr
,
int
ldin
,
void
gemm_u8_8x8
_dot
::
pack_A
(
uint8_t
*
outptr
,
const
uint8_t
*
inptr
,
int
ldin
,
int
y0
,
int
ymax
,
int
k0
,
int
kmax
,
bool
transpose
)
const
{
if
(
transpose
)
{
...
...
@@ -35,7 +35,7 @@ void gemm_u8_8x8::pack_A(uint8_t* outptr, const uint8_t* inptr, int ldin,
}
}
void
gemm_u8_8x8
::
pack_B
(
uint8_t
*
out
,
const
uint8_t
*
in
,
int
ldin
,
int
x0
,
void
gemm_u8_8x8
_dot
::
pack_B
(
uint8_t
*
out
,
const
uint8_t
*
in
,
int
ldin
,
int
x0
,
int
xmax
,
int
k0
,
int
kmax
,
bool
transpose
)
const
{
if
(
transpose
)
{
matmul_8x8x4
::
gemm_u8_8x8_interleave_pack_helper
(
out
,
in
,
ldin
,
x0
,
...
...
@@ -46,7 +46,7 @@ void gemm_u8_8x8::pack_B(uint8_t* out, const uint8_t* in, int ldin, int x0,
}
}
void
gemm_u8_8x8
::
kern
(
const
uint8_t
*
packA
,
const
uint8_t
*
packB
,
size_t
M
,
void
gemm_u8_8x8
_dot
::
kern
(
const
uint8_t
*
packA
,
const
uint8_t
*
packB
,
size_t
M
,
size_t
N
,
size_t
K
,
dt_int32
*
C
,
size_t
LDC
,
bool
is_first_k
,
const
dt_int32
*
,
dt_int32
*
)
const
{
megdnn_assert
(
A_dtype
.
enumv
()
==
B_dtype
.
enumv
()
&&
...
...
dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h
浏览文件 @
f2b42bf0
...
...
@@ -11,13 +11,13 @@
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"
#if
__ARM_FEATURE_DOTPROD
#if
MGB_ENABLE_DOT
namespace
megdnn
{
namespace
aarch64
{
namespace
matmul
{
MEGDNN_REG_GEMM_STRATEGY
(
uint8_t
,
int32_t
,
int32_t
,
8
,
8
,
4
,
false
,
true
,
gemm_u8_8x8
);
gemm_u8_8x8
_dot
);
}
// namespace aarch64
}
// namespace matmul
...
...
dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h
浏览文件 @
f2b42bf0
...
...
@@ -23,9 +23,6 @@
#include "src/armv7/matrix_mul/asm/common.h"
#endif
#if MGB_ENABLE_CPUINFO
#include "cpuinfo.h"
#endif
using
namespace
megdnn
;
using
namespace
arm_common
;
...
...
dnn/src/arm_common/conv_bias/int8/algos.cpp
浏览文件 @
f2b42bf0
...
...
@@ -161,10 +161,13 @@ ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns(
return
{};
}
#if
__ARM_FEATURE_DOTPROD
#if
MGB_ENABLE_DOT
/* ===================== dot stride1 algo ======================== */
bool
ConvBiasImpl
::
AlgoDotS8DirectStride1
::
usable
(
const
NCBKernSizeParam
&
param
,
AlgoSelectionStrategy
)
const
{
if
(
!
cpuinfo_has_arm_neon_dot
())
{
return
false
;
}
return
direct_dotprod_int8_stride1
::
can_conv_direct_stride1_int8
(
param
);
}
...
...
@@ -195,6 +198,9 @@ ConvBiasImpl::AlgoDotS8DirectStride1::dispatch_kerns(
/* ===================== dot stride2 algo ======================== */
bool
ConvBiasImpl
::
AlgoDotS8DirectStride2
::
usable
(
const
NCBKernSizeParam
&
param
,
AlgoSelectionStrategy
)
const
{
if
(
!
cpuinfo_has_arm_neon_dot
()){
return
false
;
}
return
direct_dotprod_int8_stride2
::
can_conv_direct_stride2_int8
(
param
);
}
...
...
dnn/src/arm_common/conv_bias/int8/algos.h
浏览文件 @
f2b42bf0
...
...
@@ -129,7 +129,7 @@ public:
MEGDNN_DECL_ALGO_TYPE
(
ARM_COMMON_CHANWISE_STRD2_NCHW44_S8
)
};
#if
__ARM_FEATURE_DOTPROD
#if
MGB_ENABLE_DOT
class
ConvBiasImpl
::
AlgoDotS8DirectNCHWNCHW44
final
:
public
AlgoBase
{
public:
...
...
dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp
@@ -9,8 +9,8 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/int8/direct_dotprod.h"
+#if MGB_ENABLE_DOT
 #include "src/arm_common/elemwise_op.h"
 #include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/common/utils.h"
...
@@ -90,6 +90,7 @@ inline int8x16_t vqtbl1q_s8_v7(int8x16_t a, uint8x16_t index) {
     _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem);

 template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride1_2x2_int8_dot(
         const int8_t* src, const int8_t* filter, const int32_t* bias,
         int32_t* temp, int8_t* dst, const size_t IH, const size_t IW,
...
@@ -325,6 +326,7 @@ void conv_bias::conv_direct_stride1_2x2_int8_dot(
 }

 template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride1_3x3_int8_dot(
         const int8_t* src, const int8_t* filter, const int32_t* bias,
         int32_t* temp, int8_t* dst, const size_t IH, const size_t IW,
...
@@ -560,6 +562,7 @@ void conv_bias::conv_direct_stride1_3x3_int8_dot(
 }

 template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride2_2x2_int8_dot(
         const int8_t* src, const int8_t* filter, const int32_t* bias,
         int32_t* temp, int8_t* dst, const size_t IH, const size_t IW,
...
@@ -655,6 +658,7 @@ void conv_bias::conv_direct_stride2_2x2_int8_dot(
 }

 template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride2_3x3_int8_dot(
         const int8_t* src, const int8_t* filter, const int32_t* bias,
         int32_t* temp, int8_t* dst, const size_t IH, const size_t IW,
...
@@ -810,6 +814,7 @@ void conv_bias::conv_direct_stride2_3x3_int8_dot(
     _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem);

 template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride2_5x5_int8_dot(
         const int8_t* src, const int8_t* filter, const int32_t* bias,
         int32_t* temp, int8_t* dst, const size_t IH, const size_t IW,
...
@@ -1108,6 +1113,7 @@ void conv_bias::conv_direct_stride2_5x5_int8_dot(
 }

 template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride2_7x7_int8_dot(
         const int8_t* src, const int8_t* filter, const int32_t* bias,
         int32_t* temp, int8_t* dst, const size_t IH, const size_t IW,
...
@@ -1470,6 +1476,7 @@ void conv_bias::conv_direct_stride2_7x7_int8_dot(
 }

 template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride1_5x5_int8_dot(
         const int8_t* src, const int8_t* filter, const int32_t* bias,
         int32_t* temp, int8_t* dst, const size_t IH, const size_t IW,
...
@@ -1770,6 +1777,7 @@ void conv_bias::conv_direct_stride1_5x5_int8_dot(
 }

 template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride1_7x7_int8_dot(
         const int8_t* src, const int8_t* filter, const int32_t* bias,
         int32_t* temp, int8_t* dst, const size_t IH, const size_t IW,
...
@@ -2115,6 +2123,7 @@ void conv_bias::conv_direct_stride1_7x7_int8_dot(
 #undef ST1_S32X4
 #undef ST2_S32X4X2

 #define INSTANTIATION(stride, i, first_ic, last_ic, bias, Op)             \
     template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_dot< \
             first_ic, last_ic, bias, Op>(                                 \
...
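Every kernel in the file above now carries MEGDNN_ATTRIBUTE_TARGET("dotprod") because its body uses vdotq_s32. For readers unfamiliar with SDOT, this scalar reference shows what a single vdotq_s32 computes: each of the four int32 lanes accumulates a 4-way int8 dot product (a self-contained sketch, not MegEngine code):

    #include <cstdint>

    void sdot_reference(const int8_t a[16], const int8_t b[16], int32_t acc[4]) {
        for (int lane = 0; lane < 4; ++lane) {
            int32_t sum = 0;
            for (int i = 0; i < 4; ++i)
                sum += int32_t(a[lane * 4 + i]) * int32_t(b[lane * 4 + i]);
            acc[lane] += sum;  // SDOT accumulates into the destination register
        }
    }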
dnn/src/arm_common/conv_bias/int8/direct_dotprod.h
@@ -8,8 +8,8 @@
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/opr_impl.h"
+#if MGB_ENABLE_DOT
 #include "src/fallback/conv_bias/common.h"

 namespace megdnn {
...
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp
@@ -10,9 +10,8 @@
  * implied.
  */
-#ifdef __ARM_FEATURE_DOTPROD
 #include "src/arm_common/elemwise_helper/kimpl/typecvt.h"
+#if MGB_ENABLE_DOT
 #include "src/common/unroll_macro.h"
 #include "src/common/utils.h"
 #include "src/fallback/conv_bias/common.h"
...
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h
@@ -10,11 +10,10 @@
  * implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #pragma once
 #include "src/arm_common/conv_bias/opr_impl.h"
+#if MGB_ENABLE_DOT
 namespace megdnn {
 namespace arm_common {
...
@@ -78,4 +77,4 @@ void copy_packed_src_int8_nchw44(int8_t* dst, const int dst_step,
 #endif

-// vim: syntax=cpp.doxygen
\ No newline at end of file
+// vim: syntax=cpp.doxygen
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp
@@ -10,9 +10,8 @@
  * implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/block_helper.h"
+#if MGB_ENABLE_DOT
 #include "src/arm_common/conv_bias/int8/algos.h"
 #include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h"
 #include "src/arm_common/elemwise_op.h"
...
@@ -159,6 +158,9 @@ static void conv_kern(const WorkspaceBundle& bundle,
 bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::usable(
         const NCBKernSizeParam& param,
         AlgoSelectionStrategy algo_selection_strategy) const {
+    if (!cpuinfo_has_arm_neon_dot()) {
+        return false;
+    }
     MEGDNN_MARK_USED_VAR(algo_selection_strategy);
     auto&& fm = param.filter_meta;
     auto FH = fm.spatial[0];
...
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h
@@ -11,9 +11,9 @@
  * implied.
  */
 #pragma once
-#if __ARM_FEATURE_DOTPROD
 #include "megdnn/arch.h"
 #include "src/arm_common/conv_bias/intrinsic_helper.h"
+#if MGB_ENABLE_DOT
 #include "src/arm_common/elemwise_op.h"
 #include "src/arm_common/intrinsic_helper.h"
 #include "src/arm_common/neon_struct.h"
...
@@ -208,6 +208,7 @@ MEGDNN_ALWAYS_INLINE void store_ocx_owx_remain_static(int32x4_t res[][8],
 template <int res_row, int src_row, int src_start_idx, int weight_idx,
           typename T, typename T2, typename T3>
 struct ShiftCalHelper {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     static MEGDNN_ALWAYS_INLINE void impl(T& res, T2& src, T3& weight) {
 #define cb(step)            \
     res[res_row][step] =    \
...
@@ -221,6 +222,7 @@ struct ShiftCalHelper {
 template <int res_row, int src_row, int src_start_idx, int weight_idx,
           typename T, typename T2, typename T3>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 MEGDNN_ALWAYS_INLINE void cal_helper(T& res, T2& src, T3& weight) {
     ShiftCalHelper<res_row, src_row, src_start_idx, weight_idx, T, T2,
                    T3>::impl(res, src, weight);
...
@@ -242,4 +244,4 @@ struct KernNeonSdotNCHW44 {
 }  // namespace arm_common
 }  // namespace megdnn
 #endif

-// vim: syntax=cpp.doxygen
\ No newline at end of file
+// vim: syntax=cpp.doxygen
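The helper headers above attach the attribute to every inline template that touches a dot intrinsic, because the attribute must cover each function from which the intrinsic is emitted. A minimal sketch of the mechanism, assuming a GCC/Clang AArch64 toolchain (the exact feature string accepted by target() varies by compiler version):

    #include <arm_neon.h>

    // Roughly what MEGDNN_ATTRIBUTE_TARGET("dotprod") expands to: only this
    // function may emit SDOT, while the rest of the translation unit is
    // built without +dotprod in its global -march flags.
    __attribute__((target("dotprod")))
    int32x4_t dot_step(int32x4_t acc, int8x16_t a, int8x16_t b) {
        return vdotq_s32(acc, a, b);  // 4 lanes x 4-way int8 dot product
    }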
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp
@@ -10,8 +10,8 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
  * implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h"
+#if MGB_ENABLE_DOT
 namespace megdnn {
 namespace arm_common {
...
@@ -20,6 +20,7 @@ template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain,
           int filter_size, int oc_interval, int ow_interval>
 struct KernNeonSdotNCHW44<dst_type, 1, bias_mode, Op, ow_remain, filter_size,
                           oc_interval, ow_interval> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     static void impl(dst_type* dst, const int dst_step, const int8_t* src,
                      const int ih, const int iw, const int8_t* filter,
                      const int32_t* bias, const int ic, const Op& op) {
...
@@ -109,6 +110,7 @@ struct KernNeonSdotNCHW44<dst_type, 1, bias_mode, Op, ow_remain, filter_size,
 template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
           int filter_size>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_direct_sdot_int8_nchw44(
         dst_type* dst, const int oh, const int ow, const int8_t* src,
         const int ih, const int iw, const int8_t* filter, const int32_t* bias,
...
@@ -317,4 +319,4 @@ FOR_FILTER(1)
 }  // namespace arm_common
 }  // namespace megdnn
 #endif

-// vim: syntax=cpp.doxygen
\ No newline at end of file
+// vim: syntax=cpp.doxygen
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp
@@ -10,9 +10,9 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
  * implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h"
+#if MGB_ENABLE_DOT
 namespace megdnn {
 namespace arm_common {
 namespace direct_dotprod_nchw44 {
...
@@ -20,6 +20,7 @@ template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain,
           int filter_size, int oc_interval, int ow_interval>
 struct KernNeonSdotNCHW44<dst_type, 2, bias_mode, Op, ow_remain, filter_size,
                           oc_interval, ow_interval> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     static void impl(dst_type* dst, const int dst_step, const int8_t* src,
                      const int ih, const int iw, const int8_t* filter,
                      const int32_t* bias, const int ic, const Op& op) {
...
@@ -110,6 +111,7 @@ struct KernNeonSdotNCHW44<dst_type, 2, bias_mode, Op, ow_remain, filter_size,
 template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
           int filter_size>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_direct_sdot_int8_nchw44(
         dst_type* dst, const int oh, const int ow, const int8_t* src,
         const int ih, const int iw, const int8_t* filter, const int32_t* bias,
...
@@ -319,4 +321,4 @@ FOR_FILTER(2)
 }  // namespace megdnn
 #endif

-// vim: syntax=cpp.doxygen
\ No newline at end of file
+// vim: syntax=cpp.doxygen
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp
@@ -11,8 +11,8 @@
  * implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
+#if MGB_ENABLE_DOT
 namespace megdnn {
 namespace arm_common {
 namespace dot_direct_nchw_nchw44 {
...
@@ -20,6 +20,7 @@ namespace dot_direct_nchw_nchw44 {
 template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
           typename T3, typename T4>
 struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 1, T, T2, T3, T4> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     static void impl(T& c, T2& src, T3& weight) {
 #define cb(step)                                            \
     c[0][step] = Func::template impl<(src_idx + step) % 4>( \
...
@@ -35,6 +36,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 1, T, T2, T3, T4> {
 template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
           typename T3, typename T4>
 struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 1, T, T2, T3, T4> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     static void impl(T& c, T2& src, T3& weight) {
 #define cb(step)                                            \
     c[0][step] = Func::template impl<(src_idx + step) % 4>( \
...
@@ -49,6 +51,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
           int ow_block>
 struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block,
                                 ow_block, 1> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                      const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                      int iw, int ld_dst_oc, const Op& op) {
...
@@ -97,6 +100,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
           int ow_block>
 struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block,
                                 ow_block, 1> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                      const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                      int iw, int ld_dst_oc, const Op& op) {
...
@@ -151,6 +155,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
           int ow_block>
 struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block,
                                 ow_block, 1> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                      const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                      int iw, int ld_dst_oc, const Op& op) {
...
@@ -200,6 +205,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
           int ow_block>
 struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block,
                                 ow_block, 1> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                      const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                      int iw, int ld_dst_oc, const Op& op) {
...
@@ -302,6 +308,7 @@ void pack_src_int8_nchw_nchw44_dot<1>(int8_t* sptr_base,
 }

 template <BiasMode bias_mode, typename Op, int filter_size, int stride>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_direct_int8_nchw_nchw44_dot(
         const int8_t* src, const int8_t* filter, const int32_t* bias,
         int32_t* temp, int8_t* dst, const int oc, const int ic,
...
@@ -445,4 +452,4 @@ DISPATCH_CONV_KERN(1);
 }  // namespace arm_common
 }  // namespace megdnn
 #endif

-// vim: syntax=cpp.doxygen
\ No newline at end of file
+// vim: syntax=cpp.doxygen
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp
@@ -10,8 +10,8 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
  * implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
+#if MGB_ENABLE_DOT
 namespace megdnn {
 namespace arm_common {
 namespace dot_direct_nchw_nchw44 {
...
@@ -19,6 +19,7 @@ namespace dot_direct_nchw_nchw44 {
 template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
           typename T3, typename T4>
 struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 2, T, T2, T3, T4> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     static void impl(T& c, T2& src, T3& weight) {
 #define cb(step)                                                \
     c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \
...
@@ -42,6 +43,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 2, T, T2, T3, T4> {
 template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
           typename T3, typename T4>
 struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 2, T, T2, T3, T4> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     static void impl(T& c, T2& src, T3& weight) {
 #define cb(step)                                                \
     c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \
...
@@ -60,6 +62,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
           int ow_block>
 struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block,
                                 ow_block, 2> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                      const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                      int iw, int ld_dst_oc, const Op& op) {
...
@@ -111,6 +114,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
           int ow_block>
 struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block,
                                 ow_block, 2> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                      const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                      int iw, int ld_dst_oc, const Op& op) {
...
@@ -169,6 +173,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
           int ow_block>
 struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block,
                                 ow_block, 2> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                      const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                      int iw, int ld_dst_oc, const Op& op) {
...
@@ -224,6 +229,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
           int ow_block>
 struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block,
                                 ow_block, 2> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                      const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                      int iw, int ld_dst_oc, const Op& op) {
...
@@ -289,6 +295,7 @@ void pack_src_int8_nchw_nchw44_dot<2>(
 }

 template <BiasMode bias_mode, typename Op, int filter_size, int stride>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_direct_int8_nchw_nchw44_dot(
         const int8_t* src, const int8_t* filter, const int32_t* bias,
         int32_t* temp, int8_t* dst, const int oc, const int ic,
...
@@ -434,4 +441,4 @@ DISPATCH_CONV_KERN(2);
 }  // namespace megdnn
 #endif

-// vim: syntax=cpp.doxygen
\ No newline at end of file
+// vim: syntax=cpp.doxygen
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
@@ -10,8 +10,8 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express
  * or implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #include "megdnn/oprs.h"
+#if MGB_ENABLE_DOT
 #include "src/arm_common/conv_bias/block_helper.h"
 #include "src/arm_common/conv_bias/int8/algos.h"
 #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
...
@@ -175,6 +175,9 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
 bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable(
         const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
+    if (!cpuinfo_has_arm_neon_dot()) {
+        return false;
+    }
     return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_DOT>(
             param.src_type.enumv(), param.filter_type.enumv(),
             param.dst_type.enumv(), param.filter_meta, param.bias_mode,
...
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h
@@ -11,9 +11,9 @@
  * implied.
  */
 #pragma once
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/intrinsic_helper.h"
+#if MGB_ENABLE_DOT
 #include "src/arm_common/elemwise_op.h"
 #include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/common/unroll_macro.h"
...
dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp
@@ -8,9 +8,9 @@
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/int8/stride1_dotprod.h"
+#if MGB_ENABLE_DOT
 #include "megdnn/oprs.h"
 #include "src/arm_common/conv_bias/int8/direct_dotprod.h"
 #include "src/arm_common/conv_bias/int8/strategy.h"
...
dnn/src/arm_common/conv_bias/int8/stride1_dotprod.h
@@ -8,10 +8,10 @@
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #pragma once
 #include "src/arm_common/conv_bias/opr_impl.h"
+#if MGB_ENABLE_DOT
 namespace megdnn {
 namespace arm_common {
 namespace direct_dotprod_int8_stride1 {
...
dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp
@@ -9,8 +9,8 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/int8/stride2_dotprod.h"
+#if MGB_ENABLE_DOT
 #include "megdnn/oprs.h"
 #include "src/arm_common/conv_bias/int8/direct_dotprod.h"
 #include "src/arm_common/conv_bias/int8/strategy.h"
...
dnn/src/arm_common/conv_bias/int8/stride2_dotprod.h
@@ -9,9 +9,9 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #pragma once
 #include "src/arm_common/conv_bias/opr_impl.h"
+#if MGB_ENABLE_DOT
 namespace megdnn {
 namespace arm_common {
...
dnn/src/arm_common/conv_bias/opr_impl.cpp
@@ -60,7 +60,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
     AlgoS8x8x16ChanWiseStride1Stride2NCHW44 s8x8x16_channel_wise_stride1_stride2_nchw44;
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     AlgoDotS8DirectStride1 ds8_direct_stride1;
     AlgoDotS8DirectStride2 ds8_direct_stride2;
     AlgoDotU8DirectStride1 du8_direct_stride1;
...
@@ -94,7 +94,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
 public:
     AlgoPack() {
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
         m_direct_algos.emplace_back(&ds8_direct_stride1);
         m_direct_algos.emplace_back(&ds8_direct_stride2);
         m_direct_algos.emplace_back(&du8_direct_stride1);
...
dnn/src/arm_common/conv_bias/opr_impl.h
@@ -70,7 +70,7 @@ private:
     class AlgoFP16WinogradF63;
     class AlgoFP16WinogradF23_8x8;
 #endif
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     class AlgoDotS8DirectNCHWNCHW44;
     class AlgoDotS8DirectStride1;
     class AlgoDotS8DirectStride2;
...
dnn/src/arm_common/conv_bias/quint8/algos.cpp
@@ -11,7 +11,6 @@
  */
 #include "src/arm_common/conv_bias/quint8/algos.h"
-#include "midout.h"
 #include "src/arm_common/conv_bias/quint8/stride1.h"
 #include "src/arm_common/conv_bias/quint8/stride1_dotprod.h"
 #include "src/arm_common/conv_bias/quint8/stride2.h"
...
@@ -19,6 +18,8 @@
 #include "src/arm_common/elemwise_op.h"
 #include "src/fallback/conv_bias/common.h"
+#include "midout.h"
+
 MIDOUT_DECL(megdnn_arm_common_conv_bias_quint8)

 using namespace megdnn;
...
@@ -84,10 +85,13 @@ ConvBiasImpl::AlgoQU8DirectStride2::dispatch_kerns(
     MIDOUT_END();
     return {};
 }
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 /* ===================== stride1 algo ===================== */
 bool ConvBiasImpl::AlgoDotU8DirectStride1::usable(
         const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
+    if (!cpuinfo_has_arm_neon_dot()) {
+        return false;
+    }
     return direct_dotprod_quint8_stride1::can_conv_direct_stride1_quint8(param);
 }
...
@@ -118,6 +122,9 @@ ConvBiasImpl::AlgoDotU8DirectStride1::dispatch_kerns(
 /* ===================== stride2 algo ===================== */
 bool ConvBiasImpl::AlgoDotU8DirectStride2::usable(
         const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
+    if (!cpuinfo_has_arm_neon_dot()) {
+        return false;
+    }
     return direct_dotprod_quint8_stride2::can_conv_direct_stride2_quint8(param);
 }
...
dnn/src/arm_common/conv_bias/quint8/algos.h
@@ -55,7 +55,7 @@ public:
     }
     MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_QU8)
 };
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase {
 public:
...
dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp
@@ -9,8 +9,8 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/quint8/direct_dotprod.h"
+#if MGB_ENABLE_DOT
 #include "src/arm_common/elemwise_op.h"
 #include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/common/utils.h"
...
@@ -120,6 +120,7 @@ inline int8x16_t vqtbl1q_s8_v7(int8x16_t a, uint8x16_t index){
 template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode,
           typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride1_2x2_quint8_dot(
         const uint8_t* src, const uint8_t* filter, const int32_t* bias,
         int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW,
...
@@ -452,6 +453,7 @@ void conv_bias::conv_direct_stride1_2x2_quint8_dot(
 template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode,
           typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride1_3x3_quint8_dot(
         const uint8_t* src, const uint8_t* filter, const int32_t* bias,
         int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW,
...
@@ -691,6 +693,7 @@ void conv_bias::conv_direct_stride1_3x3_quint8_dot(
 template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode,
           typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride2_2x2_quint8_dot(
         const uint8_t* src, const uint8_t* filter, const int32_t* bias,
         int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW,
...
@@ -801,6 +804,7 @@ void conv_bias::conv_direct_stride2_2x2_quint8_dot(
 template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode,
           typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride2_3x3_quint8_dot(
         const uint8_t* src, const uint8_t* filter, const int32_t* bias,
         int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW,
...
@@ -1135,6 +1139,7 @@ void conv_bias::conv_direct_stride2_3x3_quint8_dot(
 template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode,
           typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride1_5x5_quint8_dot(
         const uint8_t* src, const uint8_t* filter, const int32_t* bias,
         int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW,
...
@@ -1443,6 +1448,7 @@ void conv_bias::conv_direct_stride1_5x5_quint8_dot(
 template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode,
           typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride1_7x7_quint8_dot(
         const uint8_t* src, const uint8_t* filter, const int32_t* bias,
         int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW,
...
@@ -1785,6 +1791,7 @@ void conv_bias::conv_direct_stride1_7x7_quint8_dot(
 template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode,
           typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride2_5x5_quint8_dot(
         const uint8_t* src, const uint8_t* filter, const int32_t* bias,
         int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW,
...
@@ -2090,6 +2097,7 @@ void conv_bias::conv_direct_stride2_5x5_quint8_dot(
 template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode,
           typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride2_7x7_quint8_dot(
         const uint8_t* src, const uint8_t* filter, const int32_t* bias,
         int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW,
...
dnn/src/arm_common/conv_bias/quint8/direct_dotprod.h
@@ -8,9 +8,9 @@
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/opr_impl.h"
+#if MGB_ENABLE_DOT
 #include "src/fallback/conv_bias/common.h"

 namespace megdnn {
...
dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp
@@ -8,8 +8,8 @@
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/quint8/stride1_dotprod.h"
+#if MGB_ENABLE_DOT
 #include "megdnn/oprs.h"
 #include "src/arm_common/conv_bias/quint8/direct_dotprod.h"
 #include "src/arm_common/elemwise_op.h"
...
dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.h
@@ -8,10 +8,10 @@
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #pragma once
 #include "src/arm_common/conv_bias/opr_impl.h"
+#if MGB_ENABLE_DOT
 namespace megdnn {
 namespace arm_common {
...
dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp
@@ -8,8 +8,8 @@
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/quint8/stride2_dotprod.h"
+#if MGB_ENABLE_DOT
 #include "megdnn/oprs.h"
 #include "src/arm_common/conv_bias/quint8/direct_dotprod.h"
 #include "src/arm_common/elemwise_op.h"
...
dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.h
@@ -8,10 +8,10 @@
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #pragma once
 #include "src/arm_common/conv_bias/opr_impl.h"
+#if MGB_ENABLE_DOT
 namespace megdnn {
 namespace arm_common {
...
dnn/src/arm_common/convolution/int8x8x32/algos.cpp
@@ -13,21 +13,24 @@
 #include "src/arm_common/convolution/int8x8x32/algos.h"
 #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h"
 #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h"
-#include "src/common/opr_delegate.h"
 #include "midout.h"
+#include "src/common/opr_delegate.h"

 MIDOUT_DECL(megdnn_arm_conv_int8832_kimpl)

 using namespace megdnn;
 using namespace arm_common;

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 /* ===================== ConvolutionBackwardData ===================== */
 /* ===================== direct stride 1 algo ===================== */
 bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::usable(
         fallback::ConvolutionBackwardDataImpl*,
         const NCBKernSizeParam& param) const {
+    if (!cpuinfo_has_arm_neon_dot()) {
+        return false;
+    }
     return deconv::can_stride1_int8x8x32_dot(param);
 }
...
@@ -57,6 +60,9 @@ ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::dispatch_kern(
 bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::usable(
         fallback::ConvolutionBackwardDataImpl*,
         const NCBKernSizeParam& param) const {
+    if (!cpuinfo_has_arm_neon_dot()) {
+        return false;
+    }
     return deconv::can_stride2_int8x8x32_dot(param);
 }
...
dnn/src/arm_common/convolution/int8x8x32/algos.h
@@ -17,7 +17,7 @@
 namespace megdnn {
 namespace arm_common {

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 /* ===================== ConvolutionBackwardData ===================== */
 class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1 final
...
dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp
@@ -9,11 +9,9 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h"
+#if MGB_ENABLE_DOT
 #include "src/common/utils.h"
 #include <cstring>
 #include "src/arm_common/simd_macro/marm_neon.h"

 using namespace megdnn;
...
@@ -94,6 +92,7 @@ inline int8x16_t vqtbl1q_s8_common(int8x16_t a, uint8x16_t index) {
     _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k1_idx, _elem); \
     _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem);

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) {
     MEGDNN_MARK_USED_VAR(IH);
...
@@ -328,6 +327,7 @@ void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst,
     }
 }

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) {
     MEGDNN_MARK_USED_VAR(IH);
...
@@ -530,6 +530,7 @@ void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst,
     _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k01_idx, _elem); \
     _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem);

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) {
     MEGDNN_MARK_USED_VAR(IH);
...
@@ -777,6 +778,7 @@ void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst,
     }
 }

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_7x7(const int8_t* src, const int8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) {
     MEGDNN_MARK_USED_VAR(IH);
...
@@ -1070,6 +1072,7 @@ void deconv_direct_7x7(const int8_t* src, const int8_t* filter, int32_t* dst,
 }  // anonymous namespace

 size_t deconv::get_workspace_in_bytes_stride1_int8x8x32_dot(
         const NCBKernSizeParam& param) {
     return get_bundle(param).total_size_in_bytes();
...
dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h
@@ -10,8 +10,8 @@
  */
 #pragma once
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/convolution/opr_impl.h"
+#if MGB_ENABLE_DOT
 #include <cstddef>
 #include <cstdint>
...
dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp
@@ -9,11 +9,9 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h"
+#if MGB_ENABLE_DOT
 #include "src/common/utils.h"
 #include <cstring>
 #include "src/arm_common/simd_macro/marm_neon.h"

 using namespace megdnn;
...
@@ -83,6 +81,7 @@ inline int8x16_t vqtbl1q_s8_common(int8x16_t a, uint8x16_t index) {
     _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem);

 template <bool even>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) {
     MEGDNN_MARK_USED_VAR(IH);
...
@@ -334,6 +333,7 @@ void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst,
 }

 template <bool even>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) {
     MEGDNN_MARK_USED_VAR(IH);
...
@@ -558,6 +558,7 @@ void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst,
     _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem);

 template <bool even>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) {
     MEGDNN_MARK_USED_VAR(IH);
...
@@ -835,6 +836,7 @@ void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst,
 }

 template <bool even>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_7x7(const int8_t* src, const int8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) {
     MEGDNN_MARK_USED_VAR(IH);
...
dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h
@@ -10,8 +10,8 @@
  */
 #pragma once
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/convolution/opr_impl.h"
+#if MGB_ENABLE_DOT
 #include <cstddef>
 #include <cstdint>
...
dnn/src/arm_common/convolution/opr_impl.cpp
@@ -24,7 +24,7 @@ using namespace arm_common;
 /* ===================== ConvolutionBackwardData ===================== */

 class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     AlgoSdot8DirectStride1 i8x8x32_direct_stride1_sdot;
     AlgoSdot8DirectStride2 i8x8x32_direct_stride2_sdot;
     AlgoUdot8DirectStride1 quint8_direct_stride1_udot;
...
@@ -37,7 +37,7 @@ class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
 public:
     AlgoPack() {
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
         m_all_algos.emplace_back(&i8x8x32_direct_stride1_sdot);
         m_all_algos.emplace_back(&i8x8x32_direct_stride2_sdot);
         m_all_algos.emplace_back(&quint8_direct_stride1_udot);
...
dnn/src/arm_common/convolution/opr_impl.h
@@ -56,7 +56,7 @@ public:
     MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl);

 private:
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     class AlgoSdot8DirectStride1;
     class AlgoSdot8DirectStride2;
     class AlgoUdot8DirectStride1;
...
dnn/src/arm_common/convolution/quint8/algos.cpp
@@ -14,6 +14,7 @@
 #include "src/arm_common/convolution/quint8/conv_backdata_stride1.h"
 #include "src/arm_common/convolution/quint8/conv_backdata_stride2.h"
 #include "src/common/opr_delegate.h"
+#include "midout.h"

 MIDOUT_DECL(megdnn_arm_conv_quint8_kimpl)
...
@@ -21,7 +22,7 @@ MIDOUT_DECL(megdnn_arm_conv_quint8_kimpl)
 using namespace megdnn;
 using namespace arm_common;

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 /* ===================== ConvolutionBackwardData ===================== */
...
@@ -29,6 +30,10 @@ using namespace arm_common;
 bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::usable(
         fallback::ConvolutionBackwardDataImpl*,
         const NCBKernSizeParam& param) const {
+    if (!cpuinfo_has_arm_neon_dot()) {
+        return false;
+    }
     return deconv::can_stride1_quint8_dot(param);
 }
...
@@ -58,6 +63,9 @@ ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::dispatch_kern(
 bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::usable(
         fallback::ConvolutionBackwardDataImpl*,
         const NCBKernSizeParam& param) const {
+    if (!cpuinfo_has_arm_neon_dot()) {
+        return false;
+    }
     return deconv::can_stride2_quint8_dot(param);
 }
...
dnn/src/arm_common/convolution/quint8/algos.h
@@ -17,7 +17,7 @@
 namespace megdnn {
 namespace arm_common {

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 /* ===================== ConvolutionBackwardData ===================== */
 class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1 final : public AlgoBase {
...
dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp
@@ -9,11 +9,9 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/convolution/quint8/conv_backdata_stride1.h"
+#if MGB_ENABLE_DOT
 #include "src/common/utils.h"
 #include <cstring>
 #include "src/arm_common/simd_macro/marm_neon.h"

 using namespace megdnn;
...
@@ -109,6 +107,7 @@ inline uint8x16_t vqtbl1q_u8_common(uint8x16_t a, uint8x16_t index) {
     _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem));

 template <bool last_oc = false>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC,
                        uint8_t src_zp, uint8_t filter_zp,
...
@@ -385,6 +384,7 @@ void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst,
 }

 template <bool last_oc = false>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC,
                        uint8_t src_zp, uint8_t filter_zp,
...
@@ -636,6 +636,7 @@ void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst,
     _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2);

 template <bool last_oc = false>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC,
                        uint8_t src_zp, uint8_t filter_zp,
...
@@ -907,6 +908,7 @@ void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst,
 }

 template <bool last_oc = false>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC,
                        uint8_t src_zp, uint8_t filter_zp,
...
@@ -1220,6 +1222,7 @@ void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst,
 }  // anonymous namespace

 size_t deconv::get_workspace_in_bytes_stride1_quint8_dot(
         const NCBKernSizeParam& param) {
     return get_bundle(param).total_size_in_bytes();
...
dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.h
@@ -10,11 +10,8 @@
  */
 #pragma once
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/convolution/opr_impl.h"
-#include <cstddef>
-#include <cstdint>
+#if MGB_ENABLE_DOT
 namespace megdnn {
 namespace arm_common {
...
dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp
@@ -9,11 +9,9 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/convolution/quint8/conv_backdata_stride2.h"
+#if MGB_ENABLE_DOT
 #include "src/common/utils.h"
 #include <cstring>
 #include "src/arm_common/simd_macro/marm_neon.h"

 using namespace megdnn;
...
@@ -110,6 +108,7 @@ inline uint8x16_t vqtbx1q_u8_common(uint8x16_t a, uint8x16_t t,
     _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem));

 template <bool even, bool last_oc = false>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC,
                        uint8_t src_zp, uint8_t filter_zp,
...
@@ -402,6 +401,7 @@ void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst,
 }

 template <bool even, bool last_oc = false>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC,
                        uint8_t src_zp, uint8_t filter_zp,
...
@@ -673,6 +673,7 @@ void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst,
     _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2);

 template <bool even, bool last_oc = false>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC,
                        uint8_t src_zp, uint8_t filter_zp,
...
@@ -972,6 +973,7 @@ void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst,
 }

 template <bool even, bool last_oc = false>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC,
                        uint8_t src_zp, uint8_t filter_zp,
...
dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.h
@@ -10,11 +10,8 @@
  */
 #pragma once
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/convolution/opr_impl.h"
-#include <cstddef>
-#include <cstdint>
+#if MGB_ENABLE_DOT
 namespace megdnn {
 namespace arm_common {
...
dnn/src/arm_common/matrix_mul/algos.cpp
@@ -14,8 +14,10 @@
 #include "src/arm_common/matrix_mul/fp16/hgemv.h"
 #include "src/arm_common/matrix_mul/fp32/exec_sgemv.h"
 #include "src/arm_common/matrix_mul/int8/gemv.h"
+#include "midout.h"
+
 MIDOUT_DECL(megdnn_arm_hgemv)
 MIDOUT_DECL(megdnn_arm_exec_int8816)
 MIDOUT_DECL(megdnn_arm_exec_int8832)
...
@@ -158,7 +160,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4::get_kern(
     return int8x8x32_gemv_mk4_kern;
 }

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 /* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */
 namespace {
 void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) {
...
@@ -176,6 +178,10 @@ void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) {
 bool MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::usable(
         const KernSizeParam& kern_size_param) const {
+    if (!cpuinfo_has_arm_neon_dot()) {
+        return false;
+    }
     auto M = kern_size_param.M;
     auto N = kern_size_param.N;
     auto K = kern_size_param.K;
...
dnn/src/arm_common/matrix_mul/algos.h
@@ -63,7 +63,7 @@ public:
     MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4)
 };

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
...
dnn/src/arm_common/matrix_mul/int8/gemv.cpp
@@ -9,7 +9,6 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
 #include <cstddef>
-#include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/arm_common/matrix_mul/int8/gemv.h"
 #include "src/common/utils.h"
...
@@ -21,7 +20,6 @@ MIDOUT_DECL(megdnn_arm_common_int8_gemv)
 using namespace megdnn;
 using namespace arm_common;

-#if !__ARM_FEATURE_DOTPROD
 namespace {
...
@@ -170,12 +168,11 @@ void gemv_naive_n_mk4(const int8_t* __restrict A, const int8_t* __restrict B,
     }
 }
 }  // namespace
-#endif

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 namespace {
-void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
+void gemv_naive_n_dot(const int8_t* __restrict A, const int8_t* __restrict B,
                   int32_t* __restrict C, size_t M, size_t N, size_t K,
                   size_t Astride, size_t Bstride, size_t Cstride) {
     megdnn_assert(N == 1 && Bstride == 1);
...
@@ -244,7 +241,8 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
     }
 }

-void gemv_naive_n_mk4(const int8_t* __restrict A, const int8_t* __restrict B,
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
+void gemv_naive_n_mk4_dotprod(const int8_t* __restrict A, const int8_t* __restrict B,
                       int32_t* __restrict C, size_t M, size_t N, size_t K,
                       size_t Astride, size_t Bstride, size_t Cstride) {
     constexpr size_t PACK_SIZE = 4;
...
@@ -323,6 +321,7 @@ ...
 }

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void gemv_naive_n_mk4_dot(const int8_t* __restrict A, const int8_t* __restrict B,
                           int32_t* __restrict C, size_t M, size_t N, size_t K,
                           size_t Astride,
...
@@ -403,7 +402,16 @@ void arm_common::gemv_like(const int8_t* __restrict A,
     megdnn_assert(N == 1);
     MIDOUT_BEGIN(megdnn_arm_common_int8_gemv,
                  midout_iv("INT8_gemv_like"_hash)) {
+#if MGB_ENABLE_DOT
+        if (cpuinfo_has_arm_neon_dot()) {
+            return gemv_naive_n_dot(A, B, C, M, N, K, Astride, Bstride,
+                                    Cstride);
+        } else {
+            return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
+        }
+#else
         return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
+#endif
     }
     MIDOUT_END();
 }
...
@@ -416,12 +424,22 @@ void arm_common::gemv_like_mk4(const int8_t* __restrict A,
     megdnn_assert(N == 1);
     MIDOUT_BEGIN(megdnn_arm_common_int8_gemv,
                  midout_iv("INT8_gemv_like_mk4"_hash)) {
+#if MGB_ENABLE_DOT
+        if (cpuinfo_has_arm_neon_dot()) {
+            return gemv_naive_n_mk4_dotprod(A, B, C, M, N, K, Astride,
+                                            Bstride, Cstride);
+        } else {
+            return gemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride,
+                                    Cstride);
+        }
+#else
         return gemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride);
+#endif
     }
     MIDOUT_END();
 }

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 void arm_common::gemv_like_mk4_dot(const int8_t* __restrict A,
                                    const int8_t* __restrict B,
                                    int32_t* __restrict C, size_t M, size_t N,
...
@@ -437,4 +455,5 @@ void arm_common::gemv_like_mk4_dot(const int8_t* __restrict A,
 }
 #endif

// vim: syntax=cpp.doxygen
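gemv.cpp shows the other half of the scheme: the pre-dotprod NEON implementations are no longer compiled out, and the exported entry points branch at runtime, so one binary serves both CPU generations. The selection logic, reduced to a function-pointer sketch with stub kernels (illustrative names and trimmed signatures, not the real ones):

    #include <cpuinfo.h>
    #include <cstddef>
    #include <cstdint>

    using GemvKern = void (*)(const int8_t*, const int8_t*, int32_t*,
                              size_t, size_t, size_t);

    static void gemv_plain(const int8_t*, const int8_t*, int32_t*,
                           size_t, size_t, size_t) { /* plain NEON path */ }
    static void gemv_dot(const int8_t*, const int8_t*, int32_t*,
                         size_t, size_t, size_t) { /* SDOT path */ }

    GemvKern select_gemv() {
    #if MGB_ENABLE_DOT
        if (cpuinfo_has_arm_neon_dot()) return gemv_dot;  // runtime choice
    #endif
        return gemv_plain;  // always-available fallback
    }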
dnn/src/arm_common/matrix_mul/int8/gemv.h
@@ -28,7 +28,7 @@ void gemv_like_mk4(const int8_t* __restrict A, const int8_t* __restrict B,
                    int32_t* __restrict C, size_t M, size_t N, size_t K,
                    size_t Astride, size_t Bstride, size_t Cstride);

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 void gemv_like_mk4_dot(const int8_t* __restrict A, const int8_t* __restrict B,
                        int32_t* __restrict C, size_t M, size_t N, size_t K,
                        size_t Astride, size_t Bstride, size_t Cstride);
...
dnn/src/arm_common/matrix_mul/opr_impl.cpp
@@ -22,7 +22,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
 #endif
     AlgoInt8x8x32Gemv int8x8x32_gemv;
     AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4;
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot;
 #endif
     AlgoGevm gevm;
...
@@ -37,7 +37,7 @@ public:
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         m_all_algos.emplace_back(&f16gemv);
 #endif
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
         m_all_algos.emplace_back(&int8x8x32_gemv_mk4_dot);
 #endif
         m_all_algos.emplace_back(&int8x8x32_gemv);
...
dnn/src/arm_common/matrix_mul/opr_impl.h
@@ -42,7 +42,7 @@ protected:
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
     class AlgoF16Gemv;
 #endif
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     class AlgoInt8x8x32GemvMK4Dot;  // Arm_common Int8x8x32 Gemv NCHW44_DOT
 #endif
     class AlgoInt8x8x16;  // Arm_common Int 8x8x16
...
dnn/src/arm_common/neon_struct.h
@@ -69,9 +69,10 @@ struct Vfmaq_laneq_f32 {
         return vfmaq_laneq_f32(a, b, v, lane);
     }
 };
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 struct Vdotq_laneq_s32 {
     template <const int lane>
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     static __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
         return vdotq_laneq_s32(a, b, v, lane);
     }
...
@@ -82,4 +83,4 @@ struct Vdotq_laneq_s32 {
 }  // namespace megdnn

 #undef __ai

-// vim: syntax=cpp.doxygen
\ No newline at end of file
+// vim: syntax=cpp.doxygen
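Vdotq_laneq_s32 above routes the lane index through a template parameter because vdotq_laneq_s32 requires an immediate lane operand; the template bakes the constant in at compile time. The same idea in isolation (a sketch, assuming an AArch64 GCC/Clang toolchain with the dotprod intrinsics declared):

    #include <arm_neon.h>

    template <int lane>
    __attribute__((target("dotprod")))
    int32x4_t dot_lane(int32x4_t acc, int8x16_t b, int8x16_t v) {
        // lane is a constant expression, as the intrinsic demands
        return vdotq_laneq_s32(acc, b, v, lane);
    }
    // usage: acc = dot_lane<0>(acc, weights, inputs);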
dnn/src/arm_common/simd_macro/marm_neon.h
@@ -10,7 +10,12 @@
  * implied.
  */
 #pragma once
+#if MGB_ENABLE_DOT
+#if defined(__ARM_FEATURE_DOTPROD)
+#undef __ARM_FEATURE_DOTPROD
+#endif
+#define __ARM_FEATURE_DOTPROD 1
+#endif
 #include <arm_neon.h>
 #include "megdnn/arch.h"
 #include "src/common/unroll_macro.h"
...
@@ -249,13 +254,14 @@ __ai float16x8_t vdupq_n_f16(__fp16 a) {
 #endif  // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 __ai int32x4_t vdotq2_s32(int8x16_t a, int8x16_t b) {
     int32x4_t c = vdupq_n_s32(0);
     return vdotq_s32(c, a, b);
 }

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 __ai uint32x4_t vdotq2_u32(uint8x16_t a, uint8x16_t b) {
     uint32x4_t c = vdupq_n_u32(0);
     return vdotq_u32(c, a, b);
...
@@ -275,11 +281,13 @@ __ai uint32x4_t vdotq2_u32(uint8x16_t a, uint8x16_t b) {
         c;                                            \
     })

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 __ai int32x2_t vdot2_s32(int8x8_t a, int8x8_t b) {
     int32x2_t c = vdup_n_s32(0);
     return vdot_s32(c, a, b);
 }

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 __ai uint32x2_t vdot2_u8(uint8x8_t a, uint8x8_t b) {
     uint32x2_t c = vdup_n_u32(0);
     return vdot_u32(c, a, b);
...
@@ -298,8 +306,7 @@ __ai uint32x2_t vdot2_u8(uint8x8_t a, uint8x8_t b) {
         c = vdot_lane_u32(c, a, b, lane);             \
         c;                                            \
     })
-#endif  // __ARM_FEATURE_DOTPROD
+#endif  // MGB_ENABLE_DOT

 #if __GNUC__ < 8
 #undef vld1q_f32_x2
...
@@ -575,7 +582,7 @@ struct Vfmsq_laneq_f32_armv7<3> {
 #define vfmsq_laneq_f32(a, b, v, lane) \
     Vfmsq_laneq_f32_armv7<lane>::impl(a, b, v)

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 namespace {
 template <int lane>
 struct Vdotq_laneq_s32_armv7 {
...
@@ -583,24 +590,28 @@ struct Vdotq_laneq_s32_armv7 {
 };
 template <>
 struct Vdotq_laneq_s32_armv7<0> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
         return vdotq_lane_s32(a, b, vget_low_s32(v), 0);
     }
 };
 template <>
 struct Vdotq_laneq_s32_armv7<1> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
         return vdotq_lane_s32(a, b, vget_low_s32(v), 1);
     }
 };
 template <>
 struct Vdotq_laneq_s32_armv7<2> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
         return vdotq_lane_s32(a, b, vget_high_s32(v), 0);
     }
 };
 template <>
 struct Vdotq_laneq_s32_armv7<3> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
         return vdotq_lane_s32(a, b, vget_high_s32(v), 1);
     }
...
@@ -765,7 +776,9 @@ __ai float32x4_t Vfmsq_f32(float32x4_t& a, float32x4_t& b, float32x4_t& v) {
             :);
     return a;
 }
+#if MGB_ENABLE_DOT
+#undef __ARM_FEATURE_DOTPROD
+#endif

 #undef __ai
 #pragma GCC diagnostic pop
...
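The marm_neon.h change is the trick that makes the rest of the commit possible: before <arm_neon.h> is parsed, __ARM_FEATURE_DOTPROD is force-defined so the vdot* intrinsics get declared even though the translation unit is not globally built with +dotprod, and the macro is undefined again at the end of the header so no other code can branch on it by accident. In isolation (a sketch, assuming a toolchain whose arm_neon.h keys the declarations on this macro, as common GCC/Clang versions do):

    #if MGB_ENABLE_DOT
    #if defined(__ARM_FEATURE_DOTPROD)
    #undef __ARM_FEATURE_DOTPROD
    #endif
    #define __ARM_FEATURE_DOTPROD 1  // locally fake the feature-test macro
    #endif

    #include <arm_neon.h>  // now declares vdotq_s32 and friends

    #if MGB_ENABLE_DOT
    __attribute__((target("dotprod")))  // each wrapper still needs the attribute
    static inline int32x4_t vdotq2_s32_sketch(int8x16_t a, int8x16_t b) {
        return vdotq_s32(vdupq_n_s32(0), a, b);
    }
    #undef __ARM_FEATURE_DOTPROD  // leave the rest of the TU clean
    #endif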
dnn/src/armv7/matrix_mul/algos.cpp
@@ -19,6 +19,9 @@
 #include "src/armv7/matrix_mul/quint8/strategy.h"
 #include "src/common/utils.h"
 #include "src/fallback/matrix_mul/gemm_impl.h"
+#if MGB_ENABLE_CPUINFO
+#include "cpuinfo.h"
+#endif

 #include "midout.h"
...
@@ -744,7 +747,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x4x1,
                                      armv7::matmul::gemm_s16x16x32_12x4,
                                      int16_t, int32_t,
                                      AlgoDataType::INT16X16X32, DEFAULT);
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 /* ===================== Int8 K6x8x4 algo ===================== */
 namespace {
 void int8_k6x8x4_kern(const MatrixMulImpl::KernParam& kern_param) {
...
@@ -769,6 +772,9 @@ void int8_k6x8x4_kern(const MatrixMulImpl::KernParam& kern_param) {
 bool MatrixMulImpl::AlgoInt8x8x32K6x8x4::usable(
         const KernSizeParam& kern_size_param) const {
+    if (!cpuinfo_has_arm_neon_dot()) {
+        return false;
+    }
     return can_be_treated_as_int8x8x32(kern_size_param);
 }
...
@@ -827,6 +833,9 @@ void quint8_dot_k4x8x4_kern(const MatrixMulImpl::KernParam& kern_param) {
 bool MatrixMulImpl::AlgoQuint8DotK4x8x4::usable(
         const KernSizeParam& kern_size_param) const {
+    if (!cpuinfo_has_arm_neon_dot()) {
+        return false;
+    }
     return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm &&
            kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm &&
            kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 &&
...
@@ -891,6 +900,9 @@ void int8_mk4_8x4x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
 bool MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::usable(
         const KernSizeParam& kern_size_param) const {
+    if (!cpuinfo_has_arm_neon_dot()) {
+        return false;
+    }
     return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
            (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
             kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
...
dnn/src/armv7/matrix_mul/algos.h
@@ -86,7 +86,7 @@ public:
     MEGDNN_DECL_ALGO_TYPE(ARMV7_F16_MK8_4X8)
 };
 #endif
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 class MatrixMulImpl::AlgoInt8x8x32K6x8x4 final : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
...
dnn/src/armv7/matrix_mul/asm/common.h
@@ -10,7 +10,6 @@
  * implied.
  */
 #pragma once
-#include <arm_neon.h>
 #include <cmath>
 #include <cstdint>
 #include <type_traits>
...
dnn/src/armv7/matrix_mul/fp32/strategy_4x12.cpp
@@ -10,7 +10,6 @@
  * implied.
  */
-#include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/armv7/matrix_mul/asm/common.h"
 #include "src/armv7/matrix_mul/fp32/strategy.h"
 #include "src/common/utils.h"
...
dnn/src/armv7/matrix_mul/int8/kernel_6x8x4.h
@@ -9,7 +9,7 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 #include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/armv7/matrix_mul/asm/common.h"
...
@@ -43,6 +43,7 @@ namespace matmul_dot_6x8x4 {
 //
 // Accumulator
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_6x8(const int8_t* packA, const int8_t* packB, int K,
                      int32_t* output, int LDC, bool is_first_k,
                      size_t m_remain = 6) {
...
@@ -274,6 +275,7 @@ static void kern_6x8(const int8_t* packA, const int8_t* packB, int K,
 //
 // Accumulator
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_6x4(const int8_t* packA, const int8_t* packB, int K,
                      int32_t* output, int LDC, bool is_first_k,
                      size_t n_remain = 8, size_t m_remain = 6) {
...
dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h
@@ -10,7 +10,7 @@
  * implied.
  */
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 #include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/armv7/matrix_mul/asm/common.h"
...
@@ -42,7 +42,7 @@ namespace matmul_mk4_dot_8x4x4 {
 // |q14[0-4]|
 // +--------+
 // Accumulator
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
                      int32_t* output, int LDC, bool is_first_k, int n_remain) {
     K /= 4;
...
@@ -211,6 +211,7 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
 // +--------+
 // Accumulator
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
                      int32_t* output, int LDC, bool is_first_k, int n_remain) {
     K /= 4;
...
dnn/src/armv7/matrix_mul/int8/strategy.cpp
@@ -175,7 +175,7 @@ void gemm_s8_4x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
     }
 }

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 // ===========================gemm_s8_6x8======================================
 MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_dots8_6x8);

 void gemm_dots8_6x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0,
...
dnn/src/armv7/matrix_mul/int8/strategy.h
@@ -23,7 +23,7 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 8, 8, false, true,
 MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 2, 16, false, false,
                          gemm_mk4_s8_4x2);

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 6, 8, 4, false, false,
                          gemm_dots8_6x8);
...
dnn/src/armv7/matrix_mul/opr_impl.cpp
@@ -27,7 +27,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
     AlgoF16K4x16x1 f16_k4x16x1;
     AlgoF16MK8_4x8 f16_mk8_4x8;
 #endif
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     AlgoInt8x8x32K6x8x4 int8_k6x8x4;
     AlgoQuint8DotK4x8x4 quint8_k4x8x4;
     AlgoInt8x8x32MK4_8x4x4DotProd int8x8x32_mk4_8x4x4_dotprod;
...
@@ -57,7 +57,7 @@ public:
         m_all_algos.emplace_back(&f16_k4x16x1);
         m_all_algos.emplace_back(&f16_mk8_4x8);
 #endif
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
         m_all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod);
         m_all_algos.emplace_back(&int8_k6x8x4);
         m_all_algos.emplace_back(&quint8_k4x8x4);
...
dnn/src/armv7/matrix_mul/opr_impl.h
...
@@ -49,7 +49,7 @@ private:
     class AlgoF16K4x16x1;   // Armv7 F16 Kernel 4x16x1
     class AlgoF16MK8_4x8;   // Armv7 F16 MK8 Format block 4x8
 #endif
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     class AlgoInt8x8x32K6x8x4;            // Armv7 Int8 Kernel 6x8x4
     class AlgoQuint8DotK4x8x4;            // Armv7 Quint8 Kernel 4x8x4
     class AlgoInt8x8x32MK4_8x4x4DotProd;  // Armv7 nchw44 Int8x8x32 Kernel 8x4x4
...
dnn/src/armv7/matrix_mul/quint8/kernel_dot_4x8x4.h
...
@@ -9,7 +9,7 @@
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 #include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/armv7/matrix_mul/asm/common.h"
...
@@ -41,7 +41,7 @@ namespace matmul_dot_4x8x4 {
 // +-------+-------+ - - - - +--------+--------+--------+
 //
 // Accumulator
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K,
                      int32_t* output, int LDC, bool is_first_k, uint8_t zA,
                      uint8_t zB, uint32_t zAB, size_t m_remain = 4) {
...
@@ -257,6 +257,7 @@ static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K,
 // +-------+-------+ - - - - +--------+--------+--------+
 //
 // Accumulator
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K,
                      int32_t* output, int LDC, bool is_first_k, uint8_t zA,
                      uint8_t zB, uint32_t zAB, size_t m_remain = 4,
...
dnn/src/armv7/matrix_mul/quint8/strategy.cpp
...
@@ -88,7 +88,7 @@ void gemm_u8_4x8::kern(const dt_uint8* packA, const dt_uint8* packB, size_t M,
     }
 }
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 // ===========================gemm_dot_quint8_4x8======================================
 MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_dot_quint8_4x8);
 void gemm_dot_quint8_4x8::pack_A(dt_uint8* out, const dt_uint8* in, int ldin,
...
dnn/src/armv7/matrix_mul/quint8/strategy.h
...
@@ -17,7 +17,7 @@ namespace matmul {
 MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 4, 8, 8, false, true,
                          gemm_u8_4x8);
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 4, 8, 4, false, false,
                          gemm_dot_quint8_4x8);
 #endif
...
dnn/src/common/utils.h
...
@@ -60,6 +60,13 @@
 #include <windows.h>
 #endif
+#if MEGDNN_AARCH64 || MEGDNN_ARMV7
+#if MGB_ENABLE_CPUINFO
+#include "cpuinfo.h"
+#endif
+#endif
 #if __cplusplus >= 201703L || __clang_major__ >= 4
 #define MEGDNN_FALLTHRU [[fallthrough]];
 #elif __GNUC__ >= 7
...
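The cpuinfo include added above supplies the runtime half of the scheme: MGB_ENABLE_DOT only says the dot kernels were compiled in, so dispatch still needs a CPU capability check. A hedged sketch using the cpuinfo library's public API (can_use_dot_kernels is a hypothetical helper, not a MegEngine function):

#include "cpuinfo.h"

// Pair the compile-time guard with a runtime check, so a binary that carries
// dot kernels still runs correctly on cores without the extension.
static bool can_use_dot_kernels() {
#if MGB_ENABLE_DOT
    return cpuinfo_initialize() && cpuinfo_has_arm_neon_dot();
#else
    return false;
#endif
}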
dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp
...
@@ -148,7 +148,7 @@ struct GemvLike<stype, btype, param::ConvBias::Format::NCHW44> {
     }
 };
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 template <typename stype, typename btype>
 struct GemvLike<stype, btype, param::ConvBias::Format::NCHW44_DOT> {
     inline static void do_gemv(const stype* A, const stype* B, btype* C,
...
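The specialization above follows a compile-time format-dispatch pattern: one GemvLike primary template plus a partial specialization per ConvBias::Format, so the NCHW44_DOT path only exists in builds where MGB_ENABLE_DOT is set. In outline (signatures abbreviated; illustration only):

// Primary template, specialized per packing format.
template <typename stype, typename btype, param::ConvBias::Format format>
struct GemvLike;

template <typename stype, typename btype>
struct GemvLike<stype, btype, param::ConvBias::Format::NCHW44> { /* ... */ };

#if MGB_ENABLE_DOT
template <typename stype, typename btype>
struct GemvLike<stype, btype, param::ConvBias::Format::NCHW44_DOT> { /* ... */ };
#endif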
dnn/test/aarch64/matrix_mul.cpp
...
@@ -87,7 +87,7 @@ TEST_F(AARCH64, MATRIX_MUL_F16_MK8) {
 }
 #endif
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(AARCH64, MATRIX_MUL_INT8X8X32_K8X12X4_DOTPROD) {
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
                                  handle(), "AARCH64_INT8X8X32_K8X12X4_DOTPROD");
...
@@ -690,7 +690,7 @@ TEST_F(AARCH64, BENCHMARK_GEMV) {
         run(M, K, N);
 }
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(AARCH64, BENCHMARK_TRANSPOSED_MATRIX_MUL_INT_8X8X32) {
     constexpr size_t RUNS = 50;
     param::MatrixMul param;
...
@@ -803,7 +803,7 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT8X8X32_MK4_8X12X4) {
         std::cout << std::endl;
     }
 }
-#endif  // __ARM_FEATURE_DOTPROD
+#endif  // MGB_ENABLE_DOT
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_F16_MK8) {
...
dnn/test/arm_common/conv_bias.cpp
...
@@ -166,7 +166,7 @@ static void benchmark_convbias(Handle* handle, std::string int_name,
                 .set_display(false);
     }
     auto nchw44_algo_regx = ".*(DIRECT|NCHW_NCHW44).*";
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     if (!is_fp32) {
         nchw44_algo_regx = ".*DOT.*";
     }
...
@@ -1852,7 +1852,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) {
 #endif
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 #if MEGDNN_WITH_BENCHMARK
 TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) {
     // have to remove the preferred restriction in the usable func before
     // running the benchmark
...
@@ -2440,7 +2440,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_QUANTIZEDSYM) {
     dtype::QuantizedS8 stype(2.5f);
     dtype::QuantizedS32 dtype(6.25f);
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     benchmark_conv1x1("AARCH64_INT8X8X32_K8X12X4_DOTPROD", handle(), stype,
                       dtype, dtype, dtype);
 #else
...
@@ -2460,7 +2460,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_QUANTIZEDASYM) {
     dtype::QuantizedS32 dtype(1.2 * 1.2);
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     benchmark_conv1x1("AARCH64_QUINT8_K8X8X4_DOTPROD", handle(), stype, dtype,
                       dtype, dtype);
 #else
...
@@ -2565,7 +2565,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_GEMV_FP32) {
     }
 }
 #ifndef __ARM_FEATURE_DOTPROD
 //! only non-dot algos are enabled for now
 TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_1X1_S1_NCHW_VS_NCHW44_INT8x8x32) {
     std::vector<TestArg> conv_bias_1x1_args_nchw44 =
             get_conv_bias_1x1_benchmark_args(4);
...
@@ -2634,7 +2634,6 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_1X1_S1_NCHW_VS_NCHW44_INT8x8x32) {
                computations / conv1x1_nchw44, conv1x1_nchw / conv1x1_nchw44);
     }
 }
 #endif
 TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_WINOGRAD_VS_IM2COL_INT8) {
     auto&& args = get_winograd_benchmark_args(3, 8);
...
dnn/test/arm_common/conv_bias_multi_thread.cpp
...
@@ -500,7 +500,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) {
 }
 /****************************dot qint8 direct*************************/
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
     auto args = get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
                                           BR_AND_NO_BIASMODE, 2, false, true);
...
dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp
...
@@ -655,7 +655,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44) {
     bench_case(1, 512, 256, 28, 28, 3, 4, 1, 2);
 }
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT) {
     constexpr size_t RUNS = 40;
     std::vector<DType> data_type = {
...
@@ -892,7 +892,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
     benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
                    {2, {4, 5}}, {1, {4}}, data_type);
 }
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
        BENCHMARK_CONVBIAS_INT8_INT8_INT8_STRIDE1_WITHDOTPROD) {
     constexpr size_t RUNS = 50;
...
@@ -1157,7 +1157,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
     benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
                    {2, {4, 5}}, {1, {4}}, data_type);
 }
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
        BENCHMARK_CONVBIAS_QUINT8_QUINT8_QUINT8_STRIDE1_WITHDOTPROD) {
     constexpr size_t RUNS = 50;
...
@@ -1977,7 +1977,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
     dtype::QuantizedS32 btype(0.04f);
     dtype::Quantized8Asymm dtype(1.4f, 110);
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     conv1x1_multithread_benchmark("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:8",
                                   stype, ftype, btype, dtype);
 #else
...
dnn/test/arm_common/conv_bias_multi_thread_conv1x1.cpp
...
@@ -20,7 +20,7 @@ using namespace megdnn;
 using namespace test;
 using namespace conv_bias;
-#ifdef __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CONV1x1_QUANTIZEDSYM_MK4_DOT) {
     UniformIntRNG rng{-50, 50};
...
@@ -138,7 +138,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM) {
             dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),          \
             dtype::QuantizedS8(60.25f), name);
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:24");
 #else
     cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
...
@@ -174,7 +174,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) {
             name);
     float epsilon = 0.001;
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:48");
 #else
     cb("CONV1x1:AARCH64_QUINT8_K8X8X8:24");
...
@@ -210,13 +210,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32) {
             dtype::QuantizedS32(1.2 * 1.3), {}, name);
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:24");
 #else
     cb("CONV1x1:AARCH64_QUINT8_K8X8X8:48");
 #endif
 #elif MEGDNN_ARMV7
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH32_QUINT8_K4X8X4:48");
 #endif
     cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24");
...
@@ -287,14 +287,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) {
 #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:48");
 #else
     cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
     cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:24");
 #endif
 #elif MEGDNN_ARMV7
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH32_INT8_K6X8X4:48");
 #endif
     cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:24");
...
@@ -312,8 +312,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) {
     }
     checker_conv_bias_mul_int8x8x32(gemv_args, handle(), "CONV1x1_GEMV");
 }
 #ifndef __ARM_FEATURE_DOTPROD
 //! only non-dot algos are enabled for now
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) {
     using namespace conv_bias;
     std::vector<conv_bias::TestArg> args =
...
@@ -345,7 +344,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) {
 #endif
 #undef cb
 }
 #endif
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44) {
     using namespace conv_bias;
...
@@ -364,7 +362,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44) {
                                     "CONV1x1_GEMV");
 }
-#ifdef __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44_DOT) {
     using namespace conv_bias;
     std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
...
dnn/test/arm_common/conv_bias_multi_thread_im2col.cpp
...
@@ -135,7 +135,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) {
     float epsilon = 0.001;
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD");
 #else
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8");
...
@@ -148,7 +148,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) {
 #undef cb
 }
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT) {
     UniformIntRNG rng{-50, 50};
...
@@ -173,6 +173,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT) {
 #if MEGDNN_AARCH64
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
 #elif MEGDNN_ARMV7
+    epsilon = 1;
     cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96");
 #endif
 #undef cb
...
@@ -194,6 +195,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS,
 #if MEGDNN_AARCH64
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
 #elif MEGDNN_ARMV7
+    epsilon = 1;
     cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96");
 #endif
 #undef cb
...
@@ -273,7 +275,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) {
             dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);
     float epsilon = 0.001;
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD");
 #else
     cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8");
...
@@ -305,13 +307,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) {
             dtype::QuantizedS32(1.2 * 1.3), {}, name);
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD");
 #else
     cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8");
 #endif
 #elif MEGDNN_ARMV7
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH32_QUINT8_K4X8X4");
 #endif
     cb("IM2COLMATMUL:ARMV7_QUINT8_K4X8X8");
...
@@ -392,7 +394,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP16) {
 #endif
 #if MEGDNN_AARCH64 || MEGDNN_ARMV7
 #if !__ARM_FEATURE_DOTPROD
 //! only non-dot algos are enabled for now
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) {
     using namespace conv_bias;
     std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
...
@@ -481,12 +483,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS,
 #undef cb
 }
 #endif
 #endif
 #endif
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_MULTI_THREADS,
        CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44DOT_FUSE) {
     UniformIntRNG rng{-50, 50};
...
@@ -516,14 +517,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
 #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD");
 #else
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8");
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16");
 #endif
 #elif MEGDNN_ARMV7
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH32_INT8_K6X8X4");
 #endif
     cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8");
...
dnn/test/arm_common/conv_bias_multi_thread_weight_preprocess.cpp
(diff collapsed)
dnn/test/arm_common/convolution.cpp
(diff collapsed)
dnn/test/arm_common/matrix_mul.cpp
(diff collapsed)
dnn/test/armv7/matrix_mul.cpp
(diff collapsed)
src/megbrain_build_config.h.in
(diff collapsed)
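src/megbrain_build_config.h.in, collapsed above, is presumably where the new MGB_ENABLE_DOT switch is declared; its contents are not visible in this diff. A hypothetical sketch of the idea only:

// Hypothetical: the real definition is generated by the build system from
// megbrain_build_config.h.in and is not shown in this collapsed diff.
#ifndef MGB_ENABLE_DOT
#define MGB_ENABLE_DOT 0  // set to 1 on ARM toolchains that support
                          // __attribute__((target("dotprod")))
#endif

Unlike __ARM_FEATURE_DOTPROD, which the compiler defines only when the entire build targets +dotprod, a build-system flag can stay on while individual kernels opt in through MEGDNN_ATTRIBUTE_TARGET, which is the point of this commit.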