Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
f6d99094
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
f6d99094
编写于
12月 17, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn): add elemwise multi type support i16xf32 and u8xf32
GitOrigin-RevId: 2fe469bb4ec9a0b7d20f88a54d2a87e7ad42385b
上级
d9a46ea4
变更
16
显示空白变更内容
内联
并排
Showing
16 changed file
with
1481 addition
and
5 deletion
+1481
-5
dnn/scripts/opr_param_defs.py
dnn/scripts/opr_param_defs.py
+10
-1
dnn/src/arm_common/elemwise_multi_type/kernels.cpp
dnn/src/arm_common/elemwise_multi_type/kernels.cpp
+707
-0
dnn/src/arm_common/elemwise_multi_type/kernels.h
dnn/src/arm_common/elemwise_multi_type/kernels.h
+71
-0
dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
+149
-0
dnn/src/arm_common/elemwise_multi_type/opr_impl.h
dnn/src/arm_common/elemwise_multi_type/opr_impl.h
+9
-0
dnn/src/common/elemwise_multi_type/opr_impl.cpp
dnn/src/common/elemwise_multi_type/opr_impl.cpp
+26
-0
dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp
dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp
+11
-0
dnn/src/common/elemwise_multi_type/opr_impl_helper.h
dnn/src/common/elemwise_multi_type/opr_impl_helper.h
+21
-0
dnn/src/fallback/elemwise_multi_type/opr_impl.cpp
dnn/src/fallback/elemwise_multi_type/opr_impl.cpp
+210
-0
dnn/src/fallback/elemwise_multi_type/opr_impl.h
dnn/src/fallback/elemwise_multi_type/opr_impl.h
+6
-0
dnn/src/naive/elemwise_multi_type/opr_impl.cpp
dnn/src/naive/elemwise_multi_type/opr_impl.cpp
+60
-0
dnn/src/naive/elemwise_multi_type/opr_impl.h
dnn/src/naive/elemwise_multi_type/opr_impl.h
+6
-0
dnn/test/arm_common/elemwise_multi_type.cpp
dnn/test/arm_common/elemwise_multi_type.cpp
+103
-0
dnn/test/common/elemwise_multi_type.cpp
dnn/test/common/elemwise_multi_type.cpp
+67
-0
dnn/test/common/elemwise_multi_type.h
dnn/test/common/elemwise_multi_type.h
+7
-4
dnn/test/fallback/elemwise_multi_type.cpp
dnn/test/fallback/elemwise_multi_type.cpp
+18
-0
未找到文件。
dnn/scripts/opr_param_defs.py
浏览文件 @
f6d99094
...
...
@@ -497,7 +497,16 @@ pdef('ElemwiseMultiType').add_enum(
Doc
(
'QCOND_LEQ_MOV = 50'
,
'quantized cond_leq_mov'
),
Doc
(
'QH_SWISH = 51'
,
'quantized h_swish'
),
Doc
(
'QFUSE_ADD_H_SWISH = 52'
,
'quantized h_swish(x+y)'
),
Doc
(
'QH_SWISH_GRAD = 53'
,
'quantized h_swish_grad'
)
Doc
(
'QH_SWISH_GRAD = 53'
,
'quantized h_swish_grad'
),
Doc
(
'FUSE_MUL_ADD3_INT16xF32xF32xF32 = 54'
,
'compute ``a * b + c`` requiring that ``a`` be int16 and ``b`` and '
'``c`` float32, and the result is float32.'
),
Doc
(
'MUL_INT16xF32xF32 = 55'
,
'compute ``a * b `` requiring that ``a`` be int16 and ``b`` float32, '
'and the result is float32.'
),
Doc
(
'FUSE_MUL_ADD3_UINT8xF32xF32xF32 = 56'
,
'compute ``a * b + c`` requiring that ``a`` be uint8 and ``b`` and '
'``c`` float32, and the result is float32.'
)
)
pdef
(
'PowC'
,
'power with constant exponent'
).
add_fields
(
'float32'
,
'exp'
,
0
)
...
...
dnn/src/arm_common/elemwise_multi_type/kernels.cpp
0 → 100644
浏览文件 @
f6d99094
/**
* \file dnn/src/arm_common/elemwise_multi_type/kernels.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "kernels.h"
#include "src/arm_common/simd_macro/marm_neon.h"
namespace
megdnn
{
namespace
arm_common
{
#if defined(__ARM_FEATURE_FMA)
#define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m)
#else
#define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m)
#endif
void
neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast111c_bcast111c
(
size_t
batch_size
,
size_t
channel_stride
,
size_t
channel_size
,
const
int16_t
*
src0
,
const
float
*
src1
,
const
float
*
src2
,
float
*
dst
)
{
const
int16_t
*
__restrict
sptr0
=
src0
;
const
float
*
__restrict
sptr1
=
src1
;
const
float
*
__restrict
sptr2
=
src2
;
float
*
__restrict
dst_ptr
=
dst
;
for
(
size_t
batch
=
0
;
batch
<
batch_size
;
++
batch
)
{
for
(
size_t
s
=
0
;
s
<
channel_stride
;
++
s
)
{
size_t
i
=
0
;
for
(;
i
+
15
<
channel_size
;
i
+=
16
,
sptr0
+=
16
,
dst_ptr
+=
16
)
{
auto
vec0_01
=
vld1q_s16
(
sptr0
);
auto
vec0_23
=
vld1q_s16
(
sptr0
+
8
);
auto
vec1_0
=
vld1q_f32
(
sptr1
+
i
);
auto
vec1_1
=
vld1q_f32
(
sptr1
+
i
+
4
);
auto
vec1_2
=
vld1q_f32
(
sptr1
+
i
+
8
);
auto
vec1_3
=
vld1q_f32
(
sptr1
+
i
+
12
);
auto
vec2_0
=
vld1q_f32
(
sptr2
+
i
);
auto
vec2_1
=
vld1q_f32
(
sptr2
+
i
+
4
);
auto
vec2_2
=
vld1q_f32
(
sptr2
+
i
+
8
);
auto
vec2_3
=
vld1q_f32
(
sptr2
+
i
+
12
);
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
vec0_2
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_23
)));
auto
vec0_3
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_23
)));
auto
dst_vec_0
=
Vfmaq_f32
(
vec2_0
,
vec0_0
,
vec1_0
);
auto
dst_vec_1
=
Vfmaq_f32
(
vec2_1
,
vec0_1
,
vec1_1
);
auto
dst_vec_2
=
Vfmaq_f32
(
vec2_2
,
vec0_2
,
vec1_2
);
auto
dst_vec_3
=
Vfmaq_f32
(
vec2_3
,
vec0_3
,
vec1_3
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
vst1q_f32
(
dst_ptr
+
8
,
dst_vec_2
);
vst1q_f32
(
dst_ptr
+
12
,
dst_vec_3
);
}
for
(;
i
+
7
<
channel_size
;
i
+=
8
,
sptr0
+=
8
,
dst_ptr
+=
8
)
{
auto
vec0_01
=
vld1q_s16
(
sptr0
);
auto
vec1_0
=
vld1q_f32
(
sptr1
+
i
);
auto
vec1_1
=
vld1q_f32
(
sptr1
+
i
+
4
);
auto
vec2_0
=
vld1q_f32
(
sptr2
+
i
);
auto
vec2_1
=
vld1q_f32
(
sptr2
+
i
+
4
);
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
dst_vec_0
=
Vfmaq_f32
(
vec2_0
,
vec0_0
,
vec1_0
);
auto
dst_vec_1
=
Vfmaq_f32
(
vec2_1
,
vec0_1
,
vec1_1
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
}
for
(;
i
+
3
<
channel_size
;
i
+=
4
,
sptr0
+=
4
,
dst_ptr
+=
4
)
{
auto
vec0_0
=
vld1_s16
(
sptr0
);
auto
vec1_0
=
vld1q_f32
(
sptr1
+
i
);
auto
vec2_0
=
vld1q_f32
(
sptr2
+
i
);
auto
vec0_0_f32
=
vcvtq_f32_s32
(
vmovl_s16
(
vec0_0
));
auto
dst_vec_0
=
Vfmaq_f32
(
vec2_0
,
vec0_0_f32
,
vec1_0
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
}
for
(;
i
<
channel_size
;
++
i
,
++
sptr0
,
++
dst_ptr
)
{
*
dst_ptr
=
(
float
)(
*
sptr0
)
*
sptr1
[
i
]
+
sptr2
[
i
];
}
}
}
}
void
neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast111c_bcast111c
(
size_t
batch_size
,
size_t
channel_stride
,
size_t
channel_size
,
const
uint8_t
*
src0
,
const
float
*
src1
,
const
float
*
src2
,
float
*
dst
)
{
const
uint8_t
*
__restrict
sptr0
=
src0
;
const
float
*
__restrict
sptr1
=
src1
;
const
float
*
__restrict
sptr2
=
src2
;
float
*
__restrict
dst_ptr
=
dst
;
for
(
size_t
batch
=
0
;
batch
<
batch_size
;
++
batch
)
{
for
(
size_t
s
=
0
;
s
<
channel_stride
;
++
s
)
{
size_t
i
=
0
;
for
(;
i
+
15
<
channel_size
;
i
+=
16
,
sptr0
+=
16
,
dst_ptr
+=
16
)
{
auto
vec0_0123_u8
=
vld1q_u8
(
sptr0
);
auto
vec1_0
=
vld1q_f32
(
sptr1
+
i
);
auto
vec1_1
=
vld1q_f32
(
sptr1
+
i
+
4
);
auto
vec1_2
=
vld1q_f32
(
sptr1
+
i
+
8
);
auto
vec1_3
=
vld1q_f32
(
sptr1
+
i
+
12
);
auto
vec2_0
=
vld1q_f32
(
sptr2
+
i
);
auto
vec2_1
=
vld1q_f32
(
sptr2
+
i
+
4
);
auto
vec2_2
=
vld1q_f32
(
sptr2
+
i
+
8
);
auto
vec2_3
=
vld1q_f32
(
sptr2
+
i
+
12
);
auto
vec0_01
=
vreinterpretq_s16_u16
(
vmovl_u8
(
vget_low_u8
(
vec0_0123_u8
)));
auto
vec0_23
=
vreinterpretq_s16_u16
(
vmovl_u8
(
vget_high_u8
(
vec0_0123_u8
)));
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
vec0_2
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_23
)));
auto
vec0_3
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_23
)));
auto
dst_vec_0
=
Vfmaq_f32
(
vec2_0
,
vec0_0
,
vec1_0
);
auto
dst_vec_1
=
Vfmaq_f32
(
vec2_1
,
vec0_1
,
vec1_1
);
auto
dst_vec_2
=
Vfmaq_f32
(
vec2_2
,
vec0_2
,
vec1_2
);
auto
dst_vec_3
=
Vfmaq_f32
(
vec2_3
,
vec0_3
,
vec1_3
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
vst1q_f32
(
dst_ptr
+
8
,
dst_vec_2
);
vst1q_f32
(
dst_ptr
+
12
,
dst_vec_3
);
}
for
(;
i
+
7
<
channel_size
;
i
+=
8
,
sptr0
+=
8
,
dst_ptr
+=
8
)
{
auto
vec0_01_u8
=
vld1_u8
(
sptr0
);
auto
vec1_0
=
vld1q_f32
(
sptr1
+
i
);
auto
vec1_1
=
vld1q_f32
(
sptr1
+
i
+
4
);
auto
vec2_0
=
vld1q_f32
(
sptr2
+
i
);
auto
vec2_1
=
vld1q_f32
(
sptr2
+
i
+
4
);
auto
vec0_01
=
vreinterpretq_s16_u16
(
vmovl_u8
(
vec0_01_u8
));
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
dst_vec_0
=
Vfmaq_f32
(
vec2_0
,
vec0_0
,
vec1_0
);
auto
dst_vec_1
=
Vfmaq_f32
(
vec2_1
,
vec0_1
,
vec1_1
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
}
for
(;
i
<
channel_size
;
++
i
,
++
sptr0
,
++
dst_ptr
)
{
*
dst_ptr
=
(
float
)(
*
sptr0
)
*
sptr1
[
i
]
+
sptr2
[
i
];
}
}
}
}
void
neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast101_bcast101
(
size_t
batch_size
,
size_t
channel_size
,
size_t
channel_stride
,
const
int16_t
*
src0
,
const
float
*
src1
,
const
float
*
src2
,
float
*
dst
)
{
const
int16_t
*
__restrict
sptr0
=
src0
;
const
float
*
__restrict
sptr1
=
src1
;
const
float
*
__restrict
sptr2
=
src2
;
float
*
__restrict
dst_ptr
=
dst
;
for
(
size_t
batch
=
0
;
batch
<
batch_size
;
++
batch
)
{
for
(
size_t
chan
=
0
;
chan
<
channel_size
;
++
chan
)
{
auto
vec1
=
vdupq_n_f32
(
sptr1
[
chan
]);
auto
vec2
=
vdupq_n_f32
(
sptr2
[
chan
]);
size_t
i
=
0
;
for
(;
i
+
15
<
channel_stride
;
i
+=
16
,
sptr0
+=
16
,
dst_ptr
+=
16
)
{
auto
vec0_01
=
vld1q_s16
(
sptr0
);
auto
vec0_23
=
vld1q_s16
(
sptr0
+
8
);
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
vec0_2
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_23
)));
auto
vec0_3
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_23
)));
auto
dst_vec_0
=
Vfmaq_f32
(
vec2
,
vec0_0
,
vec1
);
auto
dst_vec_1
=
Vfmaq_f32
(
vec2
,
vec0_1
,
vec1
);
auto
dst_vec_2
=
Vfmaq_f32
(
vec2
,
vec0_2
,
vec1
);
auto
dst_vec_3
=
Vfmaq_f32
(
vec2
,
vec0_3
,
vec1
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
vst1q_f32
(
dst_ptr
+
8
,
dst_vec_2
);
vst1q_f32
(
dst_ptr
+
12
,
dst_vec_3
);
}
for
(;
i
+
7
<
channel_stride
;
i
+=
8
,
sptr0
+=
8
,
dst_ptr
+=
8
)
{
auto
vec0_01
=
vld1q_s16
(
sptr0
);
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
dst_vec_0
=
Vfmaq_f32
(
vec2
,
vec0_0
,
vec1
);
auto
dst_vec_1
=
Vfmaq_f32
(
vec2
,
vec0_1
,
vec1
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
}
for
(;
i
+
3
<
channel_stride
;
i
+=
4
,
sptr0
+=
4
,
dst_ptr
+=
4
)
{
auto
vec0_0
=
vld1_s16
(
sptr0
);
auto
vec0_0_f32
=
vcvtq_f32_s32
(
vmovl_s16
(
vec0_0
));
auto
dst_vec_0
=
Vfmaq_f32
(
vec2
,
vec0_0_f32
,
vec1
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
}
for
(;
i
<
channel_stride
;
++
i
,
++
sptr0
,
++
dst_ptr
)
{
*
dst_ptr
=
(
float
)(
*
sptr0
)
*
sptr1
[
chan
]
+
sptr2
[
chan
];
}
}
}
}
void
neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast101_bcast101
(
size_t
batch_size
,
size_t
channel_size
,
size_t
channel_stride
,
const
uint8_t
*
src0
,
const
float
*
src1
,
const
float
*
src2
,
float
*
dst
)
{
const
uint8_t
*
__restrict
sptr0
=
src0
;
const
float
*
__restrict
sptr1
=
src1
;
const
float
*
__restrict
sptr2
=
src2
;
float
*
__restrict
dst_ptr
=
dst
;
for
(
size_t
batch
=
0
;
batch
<
batch_size
;
++
batch
)
{
for
(
size_t
chan
=
0
;
chan
<
channel_size
;
++
chan
)
{
auto
vec1
=
vdupq_n_f32
(
sptr1
[
chan
]);
auto
vec2
=
vdupq_n_f32
(
sptr2
[
chan
]);
size_t
i
=
0
;
for
(;
i
+
15
<
channel_stride
;
i
+=
16
,
sptr0
+=
16
,
dst_ptr
+=
16
)
{
auto
vec0_0123_u8
=
vld1q_u8
(
sptr0
);
auto
vec0_01
=
vreinterpretq_s16_u16
(
vmovl_u8
(
vget_low_u8
(
vec0_0123_u8
)));
auto
vec0_23
=
vreinterpretq_s16_u16
(
vmovl_u8
(
vget_high_u8
(
vec0_0123_u8
)));
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
vec0_2
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_23
)));
auto
vec0_3
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_23
)));
auto
dst_vec_0
=
Vfmaq_f32
(
vec2
,
vec0_0
,
vec1
);
auto
dst_vec_1
=
Vfmaq_f32
(
vec2
,
vec0_1
,
vec1
);
auto
dst_vec_2
=
Vfmaq_f32
(
vec2
,
vec0_2
,
vec1
);
auto
dst_vec_3
=
Vfmaq_f32
(
vec2
,
vec0_3
,
vec1
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
vst1q_f32
(
dst_ptr
+
8
,
dst_vec_2
);
vst1q_f32
(
dst_ptr
+
12
,
dst_vec_3
);
}
for
(;
i
+
7
<
channel_stride
;
i
+=
8
,
sptr0
+=
8
,
dst_ptr
+=
8
)
{
auto
vec0_01_u8
=
vld1_u8
(
sptr0
);
auto
vec0_01
=
vreinterpretq_s16_u16
(
vmovl_u8
(
vec0_01_u8
));
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
dst_vec_0
=
Vfmaq_f32
(
vec2
,
vec0_0
,
vec1
);
auto
dst_vec_1
=
Vfmaq_f32
(
vec2
,
vec0_1
,
vec1
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
}
for
(;
i
<
channel_stride
;
++
i
,
++
sptr0
,
++
dst_ptr
)
{
*
dst_ptr
=
(
float
)(
*
sptr0
)
*
sptr1
[
chan
]
+
sptr2
[
chan
];
}
}
}
}
void
neon_fuse_mul_add3_int16xf32xf32xf32_vec_vec_vec
(
size_t
size
,
const
int16_t
*
src0
,
const
float
*
src1
,
const
float
*
src2
,
float
*
dst
)
{
const
int16_t
*
__restrict
sptr0
=
src0
;
const
float
*
__restrict
sptr1
=
src1
;
const
float
*
__restrict
sptr2
=
src2
;
float
*
__restrict
dst_ptr
=
dst
;
size_t
i
=
0
;
for
(;
i
+
15
<
size
;
i
+=
16
,
sptr0
+=
16
,
sptr1
+=
16
,
sptr2
+=
16
,
dst_ptr
+=
16
)
{
auto
vec0_01
=
vld1q_s16
(
sptr0
);
auto
vec0_23
=
vld1q_s16
(
sptr0
+
8
);
auto
vec1_0
=
vld1q_f32
(
sptr1
);
auto
vec1_1
=
vld1q_f32
(
sptr1
+
4
);
auto
vec1_2
=
vld1q_f32
(
sptr1
+
8
);
auto
vec1_3
=
vld1q_f32
(
sptr1
+
12
);
auto
vec2_0
=
vld1q_f32
(
sptr2
);
auto
vec2_1
=
vld1q_f32
(
sptr2
+
4
);
auto
vec2_2
=
vld1q_f32
(
sptr2
+
8
);
auto
vec2_3
=
vld1q_f32
(
sptr2
+
12
);
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
vec0_2
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_23
)));
auto
vec0_3
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_23
)));
auto
dst_vec_0
=
Vfmaq_f32
(
vec2_0
,
vec0_0
,
vec1_0
);
auto
dst_vec_1
=
Vfmaq_f32
(
vec2_1
,
vec0_1
,
vec1_1
);
auto
dst_vec_2
=
Vfmaq_f32
(
vec2_2
,
vec0_2
,
vec1_2
);
auto
dst_vec_3
=
Vfmaq_f32
(
vec2_3
,
vec0_3
,
vec1_3
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
vst1q_f32
(
dst_ptr
+
8
,
dst_vec_2
);
vst1q_f32
(
dst_ptr
+
12
,
dst_vec_3
);
}
for
(;
i
+
7
<
size
;
i
+=
8
,
sptr0
+=
8
,
sptr1
+=
8
,
sptr2
+=
8
,
dst_ptr
+=
8
)
{
auto
vec0_01
=
vld1q_s16
(
sptr0
);
auto
vec1_0
=
vld1q_f32
(
sptr1
);
auto
vec1_1
=
vld1q_f32
(
sptr1
+
4
);
auto
vec2_0
=
vld1q_f32
(
sptr2
);
auto
vec2_1
=
vld1q_f32
(
sptr2
+
4
);
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
dst_vec_0
=
Vfmaq_f32
(
vec2_0
,
vec0_0
,
vec1_0
);
auto
dst_vec_1
=
Vfmaq_f32
(
vec2_1
,
vec0_1
,
vec1_1
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
}
for
(;
i
+
3
<
size
;
i
+=
4
,
sptr0
+=
4
,
sptr1
+=
4
,
sptr2
+=
4
,
dst_ptr
+=
4
)
{
auto
vec0_0
=
vld1_s16
(
sptr0
);
auto
vec1_0
=
vld1q_f32
(
sptr1
);
auto
vec2_0
=
vld1q_f32
(
sptr2
);
auto
vec0_0_f32
=
vcvtq_f32_s32
(
vmovl_s16
(
vec0_0
));
auto
dst_vec_0
=
Vfmaq_f32
(
vec2_0
,
vec0_0_f32
,
vec1_0
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
}
for
(;
i
<
size
;
++
i
,
++
sptr0
,
++
sptr1
,
++
sptr2
,
++
dst_ptr
)
{
*
dst_ptr
=
(
float
)(
*
sptr0
)
*
(
*
sptr1
)
+
(
*
sptr2
);
}
}
void
neon_fuse_mul_add3_uint8xf32xf32xf32_vec_vec_vec
(
size_t
size
,
const
uint8_t
*
src0
,
const
float
*
src1
,
const
float
*
src2
,
float
*
dst
)
{
const
uint8_t
*
__restrict
sptr0
=
src0
;
const
float
*
__restrict
sptr1
=
src1
;
const
float
*
__restrict
sptr2
=
src2
;
float
*
__restrict
dst_ptr
=
dst
;
size_t
i
=
0
;
for
(;
i
+
15
<
size
;
i
+=
16
,
sptr0
+=
16
,
sptr1
+=
16
,
sptr2
+=
16
,
dst_ptr
+=
16
)
{
auto
vec0_0123
=
vld1q_u8
(
sptr0
);
auto
vec1_0
=
vld1q_f32
(
sptr1
);
auto
vec1_1
=
vld1q_f32
(
sptr1
+
4
);
auto
vec1_2
=
vld1q_f32
(
sptr1
+
8
);
auto
vec1_3
=
vld1q_f32
(
sptr1
+
12
);
auto
vec2_0
=
vld1q_f32
(
sptr2
);
auto
vec2_1
=
vld1q_f32
(
sptr2
+
4
);
auto
vec2_2
=
vld1q_f32
(
sptr2
+
8
);
auto
vec2_3
=
vld1q_f32
(
sptr2
+
12
);
auto
vec0_01
=
vreinterpretq_s16_u16
(
vmovl_u8
(
vget_low_u8
(
vec0_0123
)));
auto
vec0_23
=
vreinterpretq_s16_u16
(
vmovl_u8
(
vget_high_u8
(
vec0_0123
)));
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
vec0_2
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_23
)));
auto
vec0_3
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_23
)));
auto
dst_vec_0
=
Vfmaq_f32
(
vec2_0
,
vec0_0
,
vec1_0
);
auto
dst_vec_1
=
Vfmaq_f32
(
vec2_1
,
vec0_1
,
vec1_1
);
auto
dst_vec_2
=
Vfmaq_f32
(
vec2_2
,
vec0_2
,
vec1_2
);
auto
dst_vec_3
=
Vfmaq_f32
(
vec2_3
,
vec0_3
,
vec1_3
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
vst1q_f32
(
dst_ptr
+
8
,
dst_vec_2
);
vst1q_f32
(
dst_ptr
+
12
,
dst_vec_3
);
}
for
(;
i
+
7
<
size
;
i
+=
8
,
sptr0
+=
8
,
sptr1
+=
8
,
sptr2
+=
8
,
dst_ptr
+=
8
)
{
auto
vec0_01_u8
=
vld1_u8
(
sptr0
);
auto
vec1_0
=
vld1q_f32
(
sptr1
);
auto
vec1_1
=
vld1q_f32
(
sptr1
+
4
);
auto
vec2_0
=
vld1q_f32
(
sptr2
);
auto
vec2_1
=
vld1q_f32
(
sptr2
+
4
);
auto
vec0_01
=
vreinterpretq_s16_u16
(
vmovl_u8
(
vec0_01_u8
));
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
dst_vec_0
=
Vfmaq_f32
(
vec2_0
,
vec0_0
,
vec1_0
);
auto
dst_vec_1
=
Vfmaq_f32
(
vec2_1
,
vec0_1
,
vec1_1
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
}
for
(;
i
<
size
;
++
i
,
++
sptr0
,
++
sptr1
,
++
sptr2
,
++
dst_ptr
)
{
*
dst_ptr
=
(
float
)(
*
sptr0
)
*
(
*
sptr1
)
+
(
*
sptr2
);
}
}
void
neon_fuse_mul_add3_int16xf32xf32xf32_vec_scaler_scaler
(
size_t
size
,
const
int16_t
*
src0
,
const
float
*
src1
,
const
float
*
src2
,
float
*
dst
)
{
const
int16_t
*
__restrict
sptr0
=
src0
;
const
float
*
__restrict
sptr1
=
src1
;
const
float
*
__restrict
sptr2
=
src2
;
auto
vec1
=
vdupq_n_f32
(
sptr1
[
0
]);
auto
vec2
=
vdupq_n_f32
(
sptr2
[
0
]);
float
*
__restrict
dst_ptr
=
dst
;
size_t
i
=
0
;
for
(;
i
+
15
<
size
;
i
+=
16
,
sptr0
+=
16
,
dst_ptr
+=
16
)
{
auto
vec0_01
=
vld1q_s16
(
sptr0
);
auto
vec0_23
=
vld1q_s16
(
sptr0
+
8
);
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
vec0_2
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_23
)));
auto
vec0_3
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_23
)));
auto
dst_vec_0
=
Vfmaq_f32
(
vec2
,
vec0_0
,
vec1
);
auto
dst_vec_1
=
Vfmaq_f32
(
vec2
,
vec0_1
,
vec1
);
auto
dst_vec_2
=
Vfmaq_f32
(
vec2
,
vec0_2
,
vec1
);
auto
dst_vec_3
=
Vfmaq_f32
(
vec2
,
vec0_3
,
vec1
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
vst1q_f32
(
dst_ptr
+
8
,
dst_vec_2
);
vst1q_f32
(
dst_ptr
+
12
,
dst_vec_3
);
}
for
(;
i
+
7
<
size
;
i
+=
8
,
sptr0
+=
8
,
sptr1
+=
8
,
sptr2
+=
8
,
dst_ptr
+=
8
)
{
auto
vec0_01
=
vld1q_s16
(
sptr0
);
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
dst_vec_0
=
Vfmaq_f32
(
vec2
,
vec0_0
,
vec1
);
auto
dst_vec_1
=
Vfmaq_f32
(
vec2
,
vec0_1
,
vec1
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
}
for
(;
i
+
3
<
size
;
i
+=
4
,
sptr0
+=
4
,
sptr1
+=
4
,
sptr2
+=
4
,
dst_ptr
+=
4
)
{
auto
vec0_0
=
vld1_s16
(
sptr0
);
auto
vec0_0_f32
=
vcvtq_f32_s32
(
vmovl_s16
(
vec0_0
));
auto
dst_vec_0
=
Vfmaq_f32
(
vec2
,
vec0_0_f32
,
vec1
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
}
for
(;
i
<
size
;
++
i
,
++
sptr0
,
++
dst_ptr
)
{
*
dst_ptr
=
(
float
)(
*
sptr0
)
*
(
*
sptr1
)
+
(
*
sptr2
);
}
}
void
neon_fuse_mul_add3_uint8xf32xf32xf32_vec_scaler_scaler
(
size_t
size
,
const
uint8_t
*
src0
,
const
float
*
src1
,
const
float
*
src2
,
float
*
dst
)
{
const
uint8_t
*
__restrict
sptr0
=
src0
;
const
float
*
__restrict
sptr1
=
src1
;
const
float
*
__restrict
sptr2
=
src2
;
auto
vec1
=
vdupq_n_f32
(
sptr1
[
0
]);
auto
vec2
=
vdupq_n_f32
(
sptr2
[
0
]);
float
*
__restrict
dst_ptr
=
dst
;
size_t
i
=
0
;
for
(;
i
+
15
<
size
;
i
+=
16
,
sptr0
+=
16
,
dst_ptr
+=
16
)
{
auto
vec0_0123
=
vld1q_u8
(
sptr0
);
auto
vec0_01
=
vreinterpretq_s16_u16
(
vmovl_u8
(
vget_low_u8
(
vec0_0123
)));
auto
vec0_23
=
vreinterpretq_s16_u16
(
vmovl_u8
(
vget_high_u8
(
vec0_0123
)));
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
vec0_2
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_23
)));
auto
vec0_3
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_23
)));
auto
dst_vec_0
=
Vfmaq_f32
(
vec2
,
vec0_0
,
vec1
);
auto
dst_vec_1
=
Vfmaq_f32
(
vec2
,
vec0_1
,
vec1
);
auto
dst_vec_2
=
Vfmaq_f32
(
vec2
,
vec0_2
,
vec1
);
auto
dst_vec_3
=
Vfmaq_f32
(
vec2
,
vec0_3
,
vec1
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
vst1q_f32
(
dst_ptr
+
8
,
dst_vec_2
);
vst1q_f32
(
dst_ptr
+
12
,
dst_vec_3
);
}
for
(;
i
+
7
<
size
;
i
+=
8
,
sptr0
+=
8
,
dst_ptr
+=
8
)
{
auto
vec0_01_u8
=
vld1_u8
(
sptr0
);
auto
vec0_01
=
vreinterpretq_s16_u16
(
vmovl_u8
(
vec0_01_u8
));
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
dst_vec_0
=
Vfmaq_f32
(
vec2
,
vec0_0
,
vec1
);
auto
dst_vec_1
=
Vfmaq_f32
(
vec2
,
vec0_1
,
vec1
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
}
for
(;
i
<
size
;
++
i
,
++
sptr0
,
++
dst_ptr
)
{
*
dst_ptr
=
(
float
)(
*
sptr0
)
*
(
*
sptr1
)
+
(
*
sptr2
);
}
}
void
neon_mul_int16xf32xf32_vec_bcast111c
(
size_t
batch_size
,
size_t
channel_stride
,
size_t
channel_size
,
const
int16_t
*
src0
,
const
float
*
src1
,
float
*
dst
)
{
const
int16_t
*
__restrict
sptr0
=
src0
;
const
float
*
__restrict
sptr1
=
src1
;
float
*
__restrict
dst_ptr
=
dst
;
for
(
size_t
batch
=
0
;
batch
<
batch_size
;
++
batch
)
{
for
(
size_t
s
=
0
;
s
<
channel_stride
;
++
s
)
{
size_t
i
=
0
;
for
(;
i
+
15
<
channel_size
;
i
+=
16
,
sptr0
+=
16
,
dst_ptr
+=
16
)
{
auto
vec0_01
=
vld1q_s16
(
sptr0
);
auto
vec0_23
=
vld1q_s16
(
sptr0
+
8
);
auto
vec1_0
=
vld1q_f32
(
sptr1
+
i
);
auto
vec1_1
=
vld1q_f32
(
sptr1
+
i
+
4
);
auto
vec1_2
=
vld1q_f32
(
sptr1
+
i
+
8
);
auto
vec1_3
=
vld1q_f32
(
sptr1
+
i
+
12
);
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
vec0_2
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_23
)));
auto
vec0_3
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_23
)));
auto
dst_vec_0
=
vmulq_f32
(
vec0_0
,
vec1_0
);
auto
dst_vec_1
=
vmulq_f32
(
vec0_1
,
vec1_1
);
auto
dst_vec_2
=
vmulq_f32
(
vec0_2
,
vec1_2
);
auto
dst_vec_3
=
vmulq_f32
(
vec0_3
,
vec1_3
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
vst1q_f32
(
dst_ptr
+
8
,
dst_vec_2
);
vst1q_f32
(
dst_ptr
+
12
,
dst_vec_3
);
}
for
(;
i
+
7
<
channel_size
;
i
+=
8
,
sptr0
+=
8
,
dst_ptr
+=
8
)
{
auto
vec0_01
=
vld1q_s16
(
sptr0
);
auto
vec1_0
=
vld1q_f32
(
sptr1
+
i
);
auto
vec1_1
=
vld1q_f32
(
sptr1
+
i
+
4
);
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
dst_vec_0
=
vmulq_f32
(
vec0_0
,
vec1_0
);
auto
dst_vec_1
=
vmulq_f32
(
vec0_1
,
vec1_1
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
}
for
(;
i
<
channel_size
;
++
i
,
++
sptr0
,
++
dst_ptr
)
{
*
dst_ptr
=
(
float
)(
*
sptr0
)
*
sptr1
[
i
];
}
}
}
}
void
neon_mul_int16xf32xf32_vec_bcast101
(
size_t
batch_size
,
size_t
channel_size
,
size_t
channel_stride
,
const
int16_t
*
src0
,
const
float
*
src1
,
float
*
dst
)
{
const
int16_t
*
__restrict
sptr0
=
src0
;
const
float
*
__restrict
sptr1
=
src1
;
float
*
__restrict
dst_ptr
=
dst
;
for
(
size_t
batch
=
0
;
batch
<
batch_size
;
++
batch
)
{
for
(
size_t
chan
=
0
;
chan
<
channel_size
;
++
chan
)
{
auto
vec1
=
vdupq_n_f32
(
sptr1
[
chan
]);
size_t
i
=
0
;
for
(;
i
+
15
<
channel_stride
;
i
+=
16
,
sptr0
+=
16
,
dst_ptr
+=
16
)
{
auto
vec0_01
=
vld1q_s16
(
sptr0
);
auto
vec0_23
=
vld1q_s16
(
sptr0
+
8
);
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
vec0_2
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_23
)));
auto
vec0_3
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_23
)));
auto
dst_vec_0
=
vmulq_f32
(
vec0_0
,
vec1
);
auto
dst_vec_1
=
vmulq_f32
(
vec0_1
,
vec1
);
auto
dst_vec_2
=
vmulq_f32
(
vec0_2
,
vec1
);
auto
dst_vec_3
=
vmulq_f32
(
vec0_3
,
vec1
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
vst1q_f32
(
dst_ptr
+
8
,
dst_vec_2
);
vst1q_f32
(
dst_ptr
+
12
,
dst_vec_3
);
}
for
(;
i
+
7
<
channel_stride
;
i
+=
8
,
sptr0
+=
8
,
dst_ptr
+=
8
)
{
auto
vec0_01
=
vld1q_s16
(
sptr0
);
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
dst_vec_0
=
vmulq_f32
(
vec0_0
,
vec1
);
auto
dst_vec_1
=
vmulq_f32
(
vec0_1
,
vec1
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
}
for
(;
i
<
channel_stride
;
++
i
,
++
sptr0
,
++
dst_ptr
)
{
*
dst_ptr
=
(
float
)(
*
sptr0
)
*
sptr1
[
chan
];
}
}
}
}
void
neon_mul_int16xf32xf32_vec_vec
(
size_t
size
,
const
int16_t
*
src0
,
const
float
*
src1
,
float
*
dst
)
{
const
int16_t
*
__restrict
sptr0
=
src0
;
const
float
*
__restrict
sptr1
=
src1
;
float
*
__restrict
dst_ptr
=
dst
;
size_t
i
=
0
;
for
(;
i
+
15
<
size
;
i
+=
16
,
sptr0
+=
16
,
sptr1
+=
16
,
dst_ptr
+=
16
)
{
auto
vec0_01
=
vld1q_s16
(
sptr0
);
auto
vec0_23
=
vld1q_s16
(
sptr0
+
8
);
auto
vec1_0
=
vld1q_f32
(
sptr1
);
auto
vec1_1
=
vld1q_f32
(
sptr1
+
4
);
auto
vec1_2
=
vld1q_f32
(
sptr1
+
8
);
auto
vec1_3
=
vld1q_f32
(
sptr1
+
12
);
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
vec0_2
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_23
)));
auto
vec0_3
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_23
)));
auto
dst_vec_0
=
vmulq_f32
(
vec0_0
,
vec1_0
);
auto
dst_vec_1
=
vmulq_f32
(
vec0_1
,
vec1_1
);
auto
dst_vec_2
=
vmulq_f32
(
vec0_2
,
vec1_2
);
auto
dst_vec_3
=
vmulq_f32
(
vec0_3
,
vec1_3
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
vst1q_f32
(
dst_ptr
+
8
,
dst_vec_2
);
vst1q_f32
(
dst_ptr
+
12
,
dst_vec_3
);
}
for
(;
i
+
7
<
size
;
i
+=
8
,
sptr0
+=
8
,
sptr1
+=
8
,
dst_ptr
+=
8
)
{
auto
vec0_01
=
vld1q_s16
(
sptr0
);
auto
vec1_0
=
vld1q_f32
(
sptr1
);
auto
vec1_1
=
vld1q_f32
(
sptr1
+
4
);
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
dst_vec_0
=
vmulq_f32
(
vec0_0
,
vec1_0
);
auto
dst_vec_1
=
vmulq_f32
(
vec0_1
,
vec1_1
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
}
for
(;
i
<
size
;
++
i
,
++
sptr0
,
++
sptr1
,
++
dst_ptr
)
{
*
dst_ptr
=
(
float
)(
*
sptr0
)
*
(
*
sptr1
);
}
}
void
neon_mul_int16xf32xf32_vec_scaler
(
size_t
size
,
const
int16_t
*
src0
,
const
float
*
src1
,
float
*
dst
)
{
const
int16_t
*
__restrict
sptr0
=
src0
;
const
float
*
__restrict
sptr1
=
src1
;
auto
vec1
=
vdupq_n_f32
(
sptr1
[
0
]);
float
*
__restrict
dst_ptr
=
dst
;
size_t
i
=
0
;
for
(;
i
+
15
<
size
;
i
+=
16
,
sptr0
+=
16
,
dst_ptr
+=
16
)
{
auto
vec0_01
=
vld1q_s16
(
sptr0
);
auto
vec0_23
=
vld1q_s16
(
sptr0
+
8
);
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
vec0_2
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_23
)));
auto
vec0_3
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_23
)));
auto
dst_vec_0
=
vmulq_f32
(
vec0_0
,
vec1
);
auto
dst_vec_1
=
vmulq_f32
(
vec0_1
,
vec1
);
auto
dst_vec_2
=
vmulq_f32
(
vec0_2
,
vec1
);
auto
dst_vec_3
=
vmulq_f32
(
vec0_3
,
vec1
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
vst1q_f32
(
dst_ptr
+
8
,
dst_vec_2
);
vst1q_f32
(
dst_ptr
+
12
,
dst_vec_3
);
}
for
(;
i
+
7
<
size
;
i
+=
8
,
sptr0
+=
8
,
dst_ptr
+=
8
)
{
auto
vec0_01
=
vld1q_s16
(
sptr0
);
auto
vec0_0
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_low_s16
(
vec0_01
)));
auto
vec0_1
=
vcvtq_f32_s32
(
vmovl_s16
(
vget_high_s16
(
vec0_01
)));
auto
dst_vec_0
=
vmulq_f32
(
vec0_0
,
vec1
);
auto
dst_vec_1
=
vmulq_f32
(
vec0_1
,
vec1
);
vst1q_f32
(
dst_ptr
,
dst_vec_0
);
vst1q_f32
(
dst_ptr
+
4
,
dst_vec_1
);
}
for
(;
i
<
size
;
++
i
,
++
sptr0
,
++
dst_ptr
)
{
*
dst_ptr
=
(
float
)(
*
sptr0
)
*
(
*
sptr1
);
}
}
}
// namespace arm_common
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/arm_common/elemwise_multi_type/kernels.h
0 → 100644
浏览文件 @
f6d99094
/**
* \file dnn/src/arm_common/elemwise_multi_type/kernels.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "stddef.h"
#include "stdint.h"
namespace
megdnn
{
namespace
arm_common
{
void
neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast111c_bcast111c
(
size_t
batch_size
,
size_t
channel_stride
,
size_t
channel_size
,
const
int16_t
*
src0
,
const
float
*
src1
,
const
float
*
src2
,
float
*
dst
);
void
neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast111c_bcast111c
(
size_t
batch_size
,
size_t
channel_stride
,
size_t
channel_size
,
const
uint8_t
*
src0
,
const
float
*
src1
,
const
float
*
src2
,
float
*
dst
);
void
neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast101_bcast101
(
size_t
batch_size
,
size_t
channel_size
,
size_t
channel_stride
,
const
int16_t
*
src0
,
const
float
*
src1
,
const
float
*
src2
,
float
*
dst
);
void
neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast101_bcast101
(
size_t
batch_size
,
size_t
channel_size
,
size_t
channel_stride
,
const
uint8_t
*
src0
,
const
float
*
src1
,
const
float
*
src2
,
float
*
dst
);
void
neon_fuse_mul_add3_int16xf32xf32xf32_vec_vec_vec
(
size_t
size
,
const
int16_t
*
src0
,
const
float
*
src1
,
const
float
*
src2
,
float
*
dst
);
void
neon_fuse_mul_add3_uint8xf32xf32xf32_vec_vec_vec
(
size_t
size
,
const
uint8_t
*
src0
,
const
float
*
src1
,
const
float
*
src2
,
float
*
dst
);
void
neon_fuse_mul_add3_int16xf32xf32xf32_vec_b1x_b1x
(
size_t
size
,
size_t
vec
,
const
int16_t
*
src0
,
const
float
*
src1
,
const
float
*
src2
,
float
*
dst
);
void
neon_fuse_mul_add3_int16xf32xf32xf32_vec_scaler_scaler
(
size_t
size
,
const
int16_t
*
src0
,
const
float
*
src1
,
const
float
*
src2
,
float
*
dst
);
void
neon_fuse_mul_add3_uint8xf32xf32xf32_vec_scaler_scaler
(
size_t
size
,
const
uint8_t
*
src0
,
const
float
*
src1
,
const
float
*
src2
,
float
*
dst
);
void
neon_mul_int16xf32xf32_vec_bcast111c
(
size_t
batch_size
,
size_t
channel_stride
,
size_t
channel_size
,
const
int16_t
*
src0
,
const
float
*
src1
,
float
*
dst
);
void
neon_mul_int16xf32xf32_vec_bcast101
(
size_t
batch_size
,
size_t
channel_size
,
size_t
channel_stride
,
const
int16_t
*
src0
,
const
float
*
src1
,
float
*
dst
);
void
neon_mul_int16xf32xf32_vec_vec
(
size_t
size
,
const
int16_t
*
src0
,
const
float
*
src1
,
float
*
dst
);
void
neon_mul_int16xf32xf32_vec_scaler
(
size_t
size
,
const
int16_t
*
src0
,
const
float
*
src1
,
float
*
dst
);
}
// namespace arm_common
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
浏览文件 @
f6d99094
...
...
@@ -11,6 +11,7 @@
*/
#include "./opr_impl.h"
#include "kernels.h"
#include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "src/naive/handle.h"
...
...
@@ -851,6 +852,154 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
#undef DISPATCH_QUANTIZED_MODE
}
void
ElemwiseMultiTypeImpl
::
on_fuse_mul_add3_int16xf32xf32xf32
(
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst
)
{
auto
&
src0
=
param
[
0
],
&
src1
=
param
[
1
],
&
src2
=
param
[
2
];
BroadcastChannelInfo
binfo
;
if
(
is_vector
(
src0
.
layout
)
&&
is_NHWC_broadcasted_channel_like
(
src1
.
layout
,
binfo
)
&&
src1
.
layout
.
eq_layout
(
src2
.
layout
))
{
// VEC_BCAST111C_BCAST111C
MEGDNN_DISPATCH_CPU_KERN_OPR
(
neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast111c_bcast111c
(
binfo
.
x
,
binfo
.
y
,
binfo
.
z
,
static_cast
<
dt_int16
*>
(
src0
.
raw_ptr
()),
static_cast
<
dt_float32
*>
(
src1
.
raw_ptr
()),
static_cast
<
dt_float32
*>
(
src2
.
raw_ptr
()),
dst
.
ptr
<
dt_float32
>
()));
return
;
}
else
if
(
is_vector
(
src0
.
layout
)
&&
is_broadcasted_channel_like
(
src1
.
layout
,
binfo
)
&&
src1
.
layout
.
eq_layout
(
src2
.
layout
))
{
// VEC_BCAST101_BCAST101
MEGDNN_DISPATCH_CPU_KERN_OPR
(
neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast101_bcast101
(
binfo
.
x
,
binfo
.
y
,
binfo
.
z
,
static_cast
<
dt_int16
*>
(
src0
.
raw_ptr
()),
static_cast
<
dt_float32
*>
(
src1
.
raw_ptr
()),
static_cast
<
dt_float32
*>
(
src2
.
raw_ptr
()),
dst
.
ptr
<
dt_float32
>
()));
return
;
}
else
if
(
is_vector
(
src0
.
layout
)
&&
is_vector
(
src1
.
layout
)
&&
is_vector
(
src2
.
layout
))
{
// VEC_VEC_VEC
auto
size
=
param
.
size
;
MEGDNN_DISPATCH_CPU_KERN_OPR
(
neon_fuse_mul_add3_int16xf32xf32xf32_vec_vec_vec
(
size
,
static_cast
<
dt_int16
*>
(
src0
.
raw_ptr
()),
static_cast
<
dt_float32
*>
(
src1
.
raw_ptr
()),
static_cast
<
dt_float32
*>
(
src2
.
raw_ptr
()),
dst
.
ptr
<
dt_float32
>
()));
return
;
}
else
if
(
is_vector
(
src0
.
layout
)
&&
is_broadcasted_scalar
(
src1
.
layout
)
&&
is_broadcasted_scalar
(
src2
.
layout
))
{
// VEC_SCALAR_SCALAR
auto
size
=
param
.
size
;
MEGDNN_DISPATCH_CPU_KERN_OPR
(
neon_fuse_mul_add3_int16xf32xf32xf32_vec_scaler_scaler
(
size
,
static_cast
<
dt_int16
*>
(
src0
.
raw_ptr
()),
static_cast
<
dt_float32
*>
(
src1
.
raw_ptr
()),
static_cast
<
dt_float32
*>
(
src2
.
raw_ptr
()),
dst
.
ptr
<
dt_float32
>
()));
return
;
}
naive
::
ElemwiseMultiTypeImpl
::
on_fuse_mul_add3_int16xf32xf32xf32
(
param
,
dst
);
}
void
ElemwiseMultiTypeImpl
::
on_fuse_mul_add3_uint8xf32xf32xf32
(
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst
)
{
auto
&
src0
=
param
[
0
],
&
src1
=
param
[
1
],
&
src2
=
param
[
2
];
BroadcastChannelInfo
binfo
;
if
(
is_vector
(
src0
.
layout
)
&&
is_NHWC_broadcasted_channel_like
(
src1
.
layout
,
binfo
)
&&
src1
.
layout
.
eq_layout
(
src2
.
layout
))
{
// VEC_BCAST111C_BCAST111C
MEGDNN_DISPATCH_CPU_KERN_OPR
(
neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast111c_bcast111c
(
binfo
.
x
,
binfo
.
y
,
binfo
.
z
,
static_cast
<
dt_uint8
*>
(
src0
.
raw_ptr
()),
static_cast
<
dt_float32
*>
(
src1
.
raw_ptr
()),
static_cast
<
dt_float32
*>
(
src2
.
raw_ptr
()),
dst
.
ptr
<
dt_float32
>
()));
return
;
}
else
if
(
is_vector
(
src0
.
layout
)
&&
is_broadcasted_channel_like
(
src1
.
layout
,
binfo
)
&&
src1
.
layout
.
eq_layout
(
src2
.
layout
))
{
// VEC_BCAST101_BCAST101
MEGDNN_DISPATCH_CPU_KERN_OPR
(
neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast101_bcast101
(
binfo
.
x
,
binfo
.
y
,
binfo
.
z
,
static_cast
<
dt_uint8
*>
(
src0
.
raw_ptr
()),
static_cast
<
dt_float32
*>
(
src1
.
raw_ptr
()),
static_cast
<
dt_float32
*>
(
src2
.
raw_ptr
()),
dst
.
ptr
<
dt_float32
>
()));
return
;
}
else
if
(
is_vector
(
src0
.
layout
)
&&
is_vector
(
src1
.
layout
)
&&
is_vector
(
src2
.
layout
))
{
// VEC_VEC_VEC
auto
size
=
param
.
size
;
MEGDNN_DISPATCH_CPU_KERN_OPR
(
neon_fuse_mul_add3_uint8xf32xf32xf32_vec_vec_vec
(
size
,
static_cast
<
dt_uint8
*>
(
src0
.
raw_ptr
()),
static_cast
<
dt_float32
*>
(
src1
.
raw_ptr
()),
static_cast
<
dt_float32
*>
(
src2
.
raw_ptr
()),
dst
.
ptr
<
dt_float32
>
()));
return
;
}
else
if
(
is_vector
(
src0
.
layout
)
&&
is_broadcasted_scalar
(
src1
.
layout
)
&&
is_broadcasted_scalar
(
src2
.
layout
))
{
// VEC_SCALAR_SCALAR
auto
size
=
param
.
size
;
MEGDNN_DISPATCH_CPU_KERN_OPR
(
neon_fuse_mul_add3_uint8xf32xf32xf32_vec_scaler_scaler
(
size
,
static_cast
<
dt_uint8
*>
(
src0
.
raw_ptr
()),
static_cast
<
dt_float32
*>
(
src1
.
raw_ptr
()),
static_cast
<
dt_float32
*>
(
src2
.
raw_ptr
()),
dst
.
ptr
<
dt_float32
>
()));
return
;
}
naive
::
ElemwiseMultiTypeImpl
::
on_fuse_mul_add3_uint8xf32xf32xf32
(
param
,
dst
);
}
void
ElemwiseMultiTypeImpl
::
on_mul_int16xf32xf32
(
const
ElemwiseOpParamN
<
2
>&
param
,
const
TensorND
&
dst
)
{
auto
&
src0
=
param
[
0
],
&
src1
=
param
[
1
];
BroadcastChannelInfo
binfo
;
if
(
is_vector
(
src0
.
layout
)
&&
is_NHWC_broadcasted_channel_like
(
src1
.
layout
,
binfo
))
{
// VEC_BCAST111C
MEGDNN_DISPATCH_CPU_KERN_OPR
(
neon_mul_int16xf32xf32_vec_bcast111c
(
binfo
.
x
,
binfo
.
y
,
binfo
.
z
,
static_cast
<
dt_int16
*>
(
src0
.
raw_ptr
()),
static_cast
<
dt_float32
*>
(
src1
.
raw_ptr
()),
dst
.
ptr
<
dt_float32
>
()));
return
;
}
else
if
(
is_vector
(
src0
.
layout
)
&&
is_broadcasted_channel_like
(
src1
.
layout
,
binfo
))
{
// VEC_BCAST101
MEGDNN_DISPATCH_CPU_KERN_OPR
(
neon_mul_int16xf32xf32_vec_bcast101
(
binfo
.
x
,
binfo
.
y
,
binfo
.
z
,
static_cast
<
dt_int16
*>
(
src0
.
raw_ptr
()),
static_cast
<
dt_float32
*>
(
src1
.
raw_ptr
()),
dst
.
ptr
<
dt_float32
>
()));
return
;
}
else
if
(
is_vector
(
src0
.
layout
)
&&
is_vector
(
src1
.
layout
))
{
// VEC_VEC
auto
size
=
param
.
size
;
MEGDNN_DISPATCH_CPU_KERN_OPR
(
neon_mul_int16xf32xf32_vec_vec
(
size
,
static_cast
<
dt_int16
*>
(
src0
.
raw_ptr
()),
static_cast
<
dt_float32
*>
(
src1
.
raw_ptr
()),
dst
.
ptr
<
dt_float32
>
()));
return
;
}
else
if
(
is_vector
(
src0
.
layout
)
&&
is_broadcasted_scalar
(
src1
.
layout
))
{
auto
size
=
param
.
size
;
MEGDNN_DISPATCH_CPU_KERN_OPR
(
neon_mul_int16xf32xf32_vec_scaler
(
size
,
static_cast
<
dt_int16
*>
(
src0
.
raw_ptr
()),
static_cast
<
dt_float32
*>
(
src1
.
raw_ptr
()),
dst
.
ptr
<
dt_float32
>
()));
return
;
}
naive
::
ElemwiseMultiTypeImpl
::
on_mul_int16xf32xf32
(
param
,
dst
);
}
}
// namespace arm_common
}
// namespace megdnn
...
...
dnn/src/arm_common/elemwise_multi_type/opr_impl.h
浏览文件 @
f6d99094
...
...
@@ -48,6 +48,15 @@ protected:
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst
,
Elemwise
::
Mode
mode
)
override
;
void
on_fuse_mul_add3_int16xf32xf32xf32
(
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst
)
override
;
void
on_mul_int16xf32xf32
(
const
ElemwiseOpParamN
<
2
>&
param
,
const
TensorND
&
dst
)
override
;
void
on_fuse_mul_add3_uint8xf32xf32xf32
(
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst
)
override
;
public:
using
fallback
::
ElemwiseMultiTypeImpl
::
ElemwiseMultiTypeImpl
;
};
...
...
dnn/src/common/elemwise_multi_type/opr_impl.cpp
浏览文件 @
f6d99094
...
...
@@ -155,6 +155,29 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
dst
.
name
=
name
;
dst
.
need_specify_out_dtype
=
true
;
};
auto
init_fma3_int16xf32xf32xf32
=
[
&
](
ModeTrait
&
dst
,
const
char
*
name
)
{
dst
.
arity
=
3
;
dst
.
check_inp
[
0
]
=
make_check_dtype_func
(
dtype
::
Int16
());
dst
.
check_inp
[
1
]
=
make_check_dtype_func
(
dtype
::
Float32
());
dst
.
check_inp
[
2
]
=
make_check_dtype_func
(
dtype
::
Float32
());
dst
.
check_out
=
make_out_dtype_func
(
dtype
::
Float32
());
dst
.
name
=
name
;
};
auto
init_mul_int16xf32xf32
=
[
&
](
ModeTrait
&
dst
,
const
char
*
name
)
{
dst
.
arity
=
2
;
dst
.
check_inp
[
0
]
=
make_check_dtype_func
(
dtype
::
Int16
());
dst
.
check_inp
[
1
]
=
make_check_dtype_func
(
dtype
::
Float32
());
dst
.
check_out
=
make_out_dtype_func
(
dtype
::
Float32
());
dst
.
name
=
name
;
};
auto
init_fma3_uint8xf32xf32xf32
=
[
&
](
ModeTrait
&
dst
,
const
char
*
name
)
{
dst
.
arity
=
3
;
dst
.
check_inp
[
0
]
=
make_check_dtype_func
(
dtype
::
Uint8
());
dst
.
check_inp
[
1
]
=
make_check_dtype_func
(
dtype
::
Float32
());
dst
.
check_inp
[
2
]
=
make_check_dtype_func
(
dtype
::
Float32
());
dst
.
check_out
=
make_out_dtype_func
(
dtype
::
Float32
());
dst
.
name
=
name
;
};
#define SET(f, m) \
MIDOUT_BEGIN(megdnn_common_elemwise_multi_type, midout_iv(Mode::m)) { \
...
...
@@ -169,6 +192,9 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
SET
(
init_fuse_add_rmulh_rshr_int32x32x32x8
,
FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8
);
SET
(
init_rshrs_iXxi8xi16
,
ROUND_SHR_SATURATE_IXxI8xI16
);
SET
(
init_fma3_int16xf32xf32xf32
,
FUSE_MUL_ADD3_INT16xF32xF32xF32
);
SET
(
init_mul_int16xf32xf32
,
MUL_INT16xF32xF32
);
SET
(
init_fma3_uint8xf32xf32xf32
,
FUSE_MUL_ADD3_UINT8xF32xF32xF32
);
//! quantized opr, with specified dtype.
//! dispatch elemwise mode internally
...
...
dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp
浏览文件 @
f6d99094
...
...
@@ -43,6 +43,17 @@ void ElemwiseMultiTypeImplHelper::exec(
case
Mode
::
ROUND_SHR_SATURATE_IXxI8xI16
:
on_round_shr_saturate_iXxi8xi16
(
make_elemwise_op_param
<
2
>
(
src
,
dst
),
dst
);
break
;
case
Mode
::
FUSE_MUL_ADD3_INT16xF32xF32xF32
:
on_fuse_mul_add3_int16xf32xf32xf32
(
make_elemwise_op_param
<
3
>
(
src
,
dst
),
dst
);
break
;
case
Mode
::
MUL_INT16xF32xF32
:
on_mul_int16xf32xf32
(
make_elemwise_op_param
<
2
>
(
src
,
dst
),
dst
);
break
;
case
Mode
::
FUSE_MUL_ADD3_UINT8xF32xF32xF32
:
on_fuse_mul_add3_uint8xf32xf32xf32
(
make_elemwise_op_param
<
3
>
(
src
,
dst
),
dst
);
break
;
ON_QUANTIZED_MODE
(
RELU
,
1
);
ON_QUANTIZED_MODE
(
ABS
,
1
);
ON_QUANTIZED_MODE
(
ACOS
,
1
);
...
...
dnn/src/common/elemwise_multi_type/opr_impl_helper.h
浏览文件 @
f6d99094
...
...
@@ -50,6 +50,27 @@ protected:
virtual
void
on_round_shr_saturate_iXxi8xi16
(
const
ElemwiseOpParamN
<
2
>&
param
,
const
TensorND
&
dst
)
=
0
;
virtual
void
on_fuse_mul_add3_int16xf32xf32xf32
(
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst
)
{
MEGDNN_MARK_USED_VAR
(
param
);
MEGDNN_MARK_USED_VAR
(
dst
);
megdnn_throw
(
"unsupported ElemwiseMultiType fma3 int16xf32xf32xf32."
);
}
virtual
void
on_mul_int16xf32xf32
(
const
ElemwiseOpParamN
<
2
>&
param
,
const
TensorND
&
dst
)
{
MEGDNN_MARK_USED_VAR
(
param
);
MEGDNN_MARK_USED_VAR
(
dst
);
megdnn_throw
(
"unsupported ElemwiseMultiType fma3 int16xf32xf32."
);
}
virtual
void
on_fuse_mul_add3_uint8xf32xf32xf32
(
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst
)
{
MEGDNN_MARK_USED_VAR
(
param
);
MEGDNN_MARK_USED_VAR
(
dst
);
megdnn_throw
(
"unsupported ElemwiseMultiType fma3 uint8xf32xf32xf32."
);
}
virtual
void
on_quantized_mode
(
const
ElemwiseOpParamN
<
1
>&
param
,
const
TensorND
&
dst
,
Elemwise
::
Mode
mode
)
{
...
...
dnn/src/fallback/elemwise_multi_type/opr_impl.cpp
浏览文件 @
f6d99094
...
...
@@ -56,6 +56,216 @@ void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32(
naive
::
ElemwiseMultiTypeImpl
::
on_fuse_mul_add3_int16x32x32x32
(
param
,
dst
);
}
void
ElemwiseMultiTypeImpl
::
on_fuse_mul_add3_int16xf32xf32xf32
(
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst
)
{
BroadcastChannelInfo
binfo0
,
binfo1
;
if
(
is_vector
(
param
[
0
].
layout
)
&&
is_NHWC_broadcasted_channel_like
(
param
[
1
].
layout
,
binfo0
)
&&
is_NHWC_broadcasted_channel_like
(
param
[
2
].
layout
,
binfo1
)
&&
binfo0
==
binfo1
)
{
auto
x
=
binfo0
.
x
,
y
=
binfo0
.
y
,
z
=
binfo0
.
z
;
auto
src0
=
param
[
0
];
auto
src1
=
param
[
1
];
auto
src2
=
param
[
2
];
auto
work
=
[
=
]()
{
const
dt_int16
*
__restrict__
a
=
static_cast
<
dt_int16
*>
(
src0
.
raw_ptr
());
const
dt_float32
*
__restrict__
b
=
static_cast
<
dt_float32
*>
(
src1
.
raw_ptr
());
const
dt_float32
*
__restrict__
c
=
static_cast
<
dt_float32
*>
(
src2
.
raw_ptr
());
dt_float32
*
__restrict__
d
=
dst
.
ptr
<
dt_float32
>
();
for
(
size_t
i
=
0
;
i
<
x
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
y
;
++
j
)
{
auto
off
=
i
*
(
y
*
z
)
+
j
*
z
;
size_t
k
=
0
;
for
(;
k
+
4
<=
z
;
k
+=
4
)
{
d
[
off
+
k
+
0
]
=
a
[
off
+
k
+
0
]
*
b
[
k
+
0
]
+
c
[
k
+
0
];
d
[
off
+
k
+
1
]
=
a
[
off
+
k
+
1
]
*
b
[
k
+
1
]
+
c
[
k
+
1
];
d
[
off
+
k
+
2
]
=
a
[
off
+
k
+
2
]
*
b
[
k
+
2
]
+
c
[
k
+
2
];
d
[
off
+
k
+
3
]
=
a
[
off
+
k
+
3
]
*
b
[
k
+
3
]
+
c
[
k
+
3
];
}
for
(;
k
<
z
;
++
k
)
{
d
[
off
+
k
]
=
a
[
off
+
k
]
*
b
[
k
]
+
c
[
k
];
}
}
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR
(
work
());
return
;
}
else
if
(
is_vector
(
param
[
0
].
layout
)
&&
is_broadcasted_channel_like
(
param
[
1
].
layout
,
binfo0
)
&&
is_broadcasted_channel_like
(
param
[
2
].
layout
,
binfo1
)
&&
binfo0
==
binfo1
)
{
auto
x
=
binfo0
.
x
,
y
=
binfo0
.
y
,
z
=
binfo0
.
z
;
auto
src0
=
param
[
0
];
auto
src1
=
param
[
1
];
auto
src2
=
param
[
2
];
auto
work
=
[
=
]()
{
const
dt_int16
*
__restrict__
a
=
static_cast
<
dt_int16
*>
(
src0
.
raw_ptr
());
const
dt_float32
*
__restrict__
b
=
static_cast
<
dt_float32
*>
(
src1
.
raw_ptr
());
const
dt_float32
*
__restrict__
c
=
static_cast
<
dt_float32
*>
(
src2
.
raw_ptr
());
dt_float32
*
__restrict__
d
=
dst
.
ptr
<
dt_float32
>
();
for
(
size_t
j
=
0
;
j
<
y
;
++
j
)
{
auto
bv
=
b
[
j
],
cv
=
c
[
j
];
for
(
size_t
i
=
0
;
i
<
x
;
++
i
)
{
auto
off
=
i
*
(
y
*
z
)
+
j
*
z
,
offt
=
off
+
z
;
for
(;
off
+
4
<=
offt
;
off
+=
4
)
{
d
[
off
+
0
]
=
a
[
off
+
0
]
*
bv
+
cv
;
d
[
off
+
1
]
=
a
[
off
+
1
]
*
bv
+
cv
;
d
[
off
+
2
]
=
a
[
off
+
2
]
*
bv
+
cv
;
d
[
off
+
3
]
=
a
[
off
+
3
]
*
bv
+
cv
;
}
for
(;
off
<
offt
;
++
off
)
{
d
[
off
]
=
a
[
off
]
*
bv
+
cv
;
}
}
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR
(
work
());
return
;
}
naive
::
ElemwiseMultiTypeImpl
::
on_fuse_mul_add3_int16xf32xf32xf32
(
param
,
dst
);
}
void
ElemwiseMultiTypeImpl
::
on_mul_int16xf32xf32
(
const
ElemwiseOpParamN
<
2
>&
param
,
const
TensorND
&
dst
)
{
BroadcastChannelInfo
binfo
;
if
(
is_vector
(
param
[
0
].
layout
)
&&
is_NHWC_broadcasted_channel_like
(
param
[
1
].
layout
,
binfo
))
{
auto
x
=
binfo
.
x
,
y
=
binfo
.
y
,
z
=
binfo
.
z
;
auto
src0
=
param
[
0
];
auto
src1
=
param
[
1
];
auto
work
=
[
=
]()
{
const
dt_int16
*
__restrict__
a
=
static_cast
<
dt_int16
*>
(
src0
.
raw_ptr
());
const
dt_float32
*
__restrict__
b
=
static_cast
<
dt_float32
*>
(
src1
.
raw_ptr
());
dt_float32
*
__restrict__
d
=
dst
.
ptr
<
dt_float32
>
();
for
(
size_t
i
=
0
;
i
<
x
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
y
;
++
j
)
{
auto
off
=
i
*
(
y
*
z
)
+
j
*
z
;
size_t
k
=
0
;
for
(;
k
+
4
<=
z
;
k
+=
4
)
{
d
[
off
+
k
+
0
]
=
a
[
off
+
k
+
0
]
*
b
[
k
+
0
];
d
[
off
+
k
+
1
]
=
a
[
off
+
k
+
1
]
*
b
[
k
+
1
];
d
[
off
+
k
+
2
]
=
a
[
off
+
k
+
2
]
*
b
[
k
+
2
];
d
[
off
+
k
+
3
]
=
a
[
off
+
k
+
3
]
*
b
[
k
+
3
];
}
for
(;
k
<
z
;
++
k
)
{
d
[
off
+
k
]
=
a
[
off
+
k
]
*
b
[
k
];
}
}
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR
(
work
());
return
;
}
else
if
(
is_vector
(
param
[
0
].
layout
)
&&
is_broadcasted_channel_like
(
param
[
1
].
layout
,
binfo
))
{
auto
x
=
binfo
.
x
,
y
=
binfo
.
y
,
z
=
binfo
.
z
;
auto
src0
=
param
[
0
];
auto
src1
=
param
[
1
];
auto
work
=
[
=
]()
{
const
dt_int16
*
__restrict__
a
=
static_cast
<
dt_int16
*>
(
src0
.
raw_ptr
());
const
dt_float32
*
__restrict__
b
=
static_cast
<
dt_float32
*>
(
src1
.
raw_ptr
());
dt_float32
*
__restrict__
d
=
dst
.
ptr
<
dt_float32
>
();
for
(
size_t
j
=
0
;
j
<
y
;
++
j
)
{
auto
bv
=
b
[
j
];
for
(
size_t
i
=
0
;
i
<
x
;
++
i
)
{
auto
off
=
i
*
(
y
*
z
)
+
j
*
z
,
offt
=
off
+
z
;
for
(;
off
+
4
<=
offt
;
off
+=
4
)
{
d
[
off
+
0
]
=
a
[
off
+
0
]
*
bv
;
d
[
off
+
1
]
=
a
[
off
+
1
]
*
bv
;
d
[
off
+
2
]
=
a
[
off
+
2
]
*
bv
;
d
[
off
+
3
]
=
a
[
off
+
3
]
*
bv
;
}
for
(;
off
<
offt
;
++
off
)
{
d
[
off
]
=
a
[
off
]
*
bv
;
}
}
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR
(
work
());
return
;
}
naive
::
ElemwiseMultiTypeImpl
::
on_mul_int16xf32xf32
(
param
,
dst
);
}
void
ElemwiseMultiTypeImpl
::
on_fuse_mul_add3_uint8xf32xf32xf32
(
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst
)
{
BroadcastChannelInfo
binfo0
,
binfo1
;
if
(
is_vector
(
param
[
0
].
layout
)
&&
is_NHWC_broadcasted_channel_like
(
param
[
1
].
layout
,
binfo0
)
&&
is_NHWC_broadcasted_channel_like
(
param
[
2
].
layout
,
binfo1
)
&&
binfo0
==
binfo1
)
{
auto
x
=
binfo0
.
x
,
y
=
binfo0
.
y
,
z
=
binfo0
.
z
;
auto
src0
=
param
[
0
];
auto
src1
=
param
[
1
];
auto
src2
=
param
[
2
];
auto
work
=
[
=
]()
{
const
dt_uint8
*
__restrict__
a
=
static_cast
<
dt_uint8
*>
(
src0
.
raw_ptr
());
const
dt_float32
*
__restrict__
b
=
static_cast
<
dt_float32
*>
(
src1
.
raw_ptr
());
const
dt_float32
*
__restrict__
c
=
static_cast
<
dt_float32
*>
(
src2
.
raw_ptr
());
dt_float32
*
__restrict__
d
=
dst
.
ptr
<
dt_float32
>
();
for
(
size_t
i
=
0
;
i
<
x
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
y
;
++
j
)
{
auto
off
=
i
*
(
y
*
z
)
+
j
*
z
;
size_t
k
=
0
;
for
(;
k
+
4
<=
z
;
k
+=
4
)
{
d
[
off
+
k
+
0
]
=
a
[
off
+
k
+
0
]
*
b
[
k
+
0
]
+
c
[
k
+
0
];
d
[
off
+
k
+
1
]
=
a
[
off
+
k
+
1
]
*
b
[
k
+
1
]
+
c
[
k
+
1
];
d
[
off
+
k
+
2
]
=
a
[
off
+
k
+
2
]
*
b
[
k
+
2
]
+
c
[
k
+
2
];
d
[
off
+
k
+
3
]
=
a
[
off
+
k
+
3
]
*
b
[
k
+
3
]
+
c
[
k
+
3
];
}
for
(;
k
<
z
;
++
k
)
{
d
[
off
+
k
]
=
a
[
off
+
k
]
*
b
[
k
]
+
c
[
k
];
}
}
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR
(
work
());
return
;
}
else
if
(
is_vector
(
param
[
0
].
layout
)
&&
is_broadcasted_channel_like
(
param
[
1
].
layout
,
binfo0
)
&&
is_broadcasted_channel_like
(
param
[
2
].
layout
,
binfo1
)
&&
binfo0
==
binfo1
)
{
auto
x
=
binfo0
.
x
,
y
=
binfo0
.
y
,
z
=
binfo0
.
z
;
auto
src0
=
param
[
0
];
auto
src1
=
param
[
1
];
auto
src2
=
param
[
2
];
auto
work
=
[
=
]()
{
const
dt_uint8
*
__restrict__
a
=
static_cast
<
dt_uint8
*>
(
src0
.
raw_ptr
());
const
dt_float32
*
__restrict__
b
=
static_cast
<
dt_float32
*>
(
src1
.
raw_ptr
());
const
dt_float32
*
__restrict__
c
=
static_cast
<
dt_float32
*>
(
src2
.
raw_ptr
());
dt_float32
*
__restrict__
d
=
dst
.
ptr
<
dt_float32
>
();
for
(
size_t
j
=
0
;
j
<
y
;
++
j
)
{
auto
bv
=
b
[
j
],
cv
=
c
[
j
];
for
(
size_t
i
=
0
;
i
<
x
;
++
i
)
{
auto
off
=
i
*
(
y
*
z
)
+
j
*
z
,
offt
=
off
+
z
;
for
(;
off
+
4
<=
offt
;
off
+=
4
)
{
d
[
off
+
0
]
=
a
[
off
+
0
]
*
bv
+
cv
;
d
[
off
+
1
]
=
a
[
off
+
1
]
*
bv
+
cv
;
d
[
off
+
2
]
=
a
[
off
+
2
]
*
bv
+
cv
;
d
[
off
+
3
]
=
a
[
off
+
3
]
*
bv
+
cv
;
}
for
(;
off
<
offt
;
++
off
)
{
d
[
off
]
=
a
[
off
]
*
bv
+
cv
;
}
}
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR
(
work
());
return
;
}
naive
::
ElemwiseMultiTypeImpl
::
on_fuse_mul_add3_uint8xf32xf32xf32
(
param
,
dst
);
}
template
<
typename
ctype
>
void
ElemwiseMultiTypeImpl
::
dispatch_fma3_iXxf32xf32xi8_bcast_1x
(
const
ElemwiseOpParamN
<
3
>&
param
,
const
Broadcast1xInfo
&
binfo
,
...
...
dnn/src/fallback/elemwise_multi_type/opr_impl.h
浏览文件 @
f6d99094
...
...
@@ -43,6 +43,12 @@ protected:
const
ElemwiseOpParamN
<
6
>&
param
,
const
TensorND
&
dst
)
override
;
void
on_round_shr_saturate_iXxi8xi16
(
const
ElemwiseOpParamN
<
2
>&
param
,
const
TensorND
&
dst
)
override
;
void
on_fuse_mul_add3_int16xf32xf32xf32
(
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst
)
override
;
void
on_mul_int16xf32xf32
(
const
ElemwiseOpParamN
<
2
>&
param
,
const
TensorND
&
dst
)
override
;
void
on_fuse_mul_add3_uint8xf32xf32xf32
(
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst
)
override
;
public:
using
naive
::
ElemwiseMultiTypeImpl
::
ElemwiseMultiTypeImpl
;
...
...
dnn/src/naive/elemwise_multi_type/opr_impl.cpp
浏览文件 @
f6d99094
...
...
@@ -39,6 +39,66 @@ void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32(
MEGDNN_DISPATCH_CPU_KERN_OPR
(
work
());
}
void
ElemwiseMultiTypeImpl
::
on_fuse_mul_add3_int16xf32xf32xf32
(
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst
)
{
auto
size
=
param
.
size
;
auto
src0
=
param
[
0
];
auto
src1
=
param
[
1
];
auto
src2
=
param
[
2
];
auto
work
=
[
src0
,
src1
,
src2
,
size
,
dst
]()
{
auto
i0
=
tensor_iter_valonly
<
dt_int16
>
(
src0
).
begin
();
auto
i1
=
tensor_iter_valonly
<
dt_float32
>
(
src1
).
begin
();
auto
i2
=
tensor_iter_valonly
<
dt_float32
>
(
src2
).
begin
();
auto
dst_ptr
=
dst
.
ptr
<
dt_float32
>
();
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
dst_ptr
[
i
]
=
(
*
i0
)
*
(
*
i1
)
+
(
*
i2
);
++
i0
;
++
i1
;
++
i2
;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR
(
work
());
}
void
ElemwiseMultiTypeImpl
::
on_fuse_mul_add3_uint8xf32xf32xf32
(
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst
)
{
auto
size
=
param
.
size
;
auto
src0
=
param
[
0
];
auto
src1
=
param
[
1
];
auto
src2
=
param
[
2
];
auto
work
=
[
src0
,
src1
,
src2
,
size
,
dst
]()
{
auto
i0
=
tensor_iter_valonly
<
dt_uint8
>
(
src0
).
begin
();
auto
i1
=
tensor_iter_valonly
<
dt_float32
>
(
src1
).
begin
();
auto
i2
=
tensor_iter_valonly
<
dt_float32
>
(
src2
).
begin
();
auto
dst_ptr
=
dst
.
ptr
<
dt_float32
>
();
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
dst_ptr
[
i
]
=
(
*
i0
)
*
(
*
i1
)
+
(
*
i2
);
++
i0
;
++
i1
;
++
i2
;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR
(
work
());
}
void
ElemwiseMultiTypeImpl
::
on_mul_int16xf32xf32
(
const
ElemwiseOpParamN
<
2
>&
param
,
const
TensorND
&
dst
)
{
auto
size
=
param
.
size
;
auto
src0
=
param
[
0
];
auto
src1
=
param
[
1
];
auto
work
=
[
src0
,
src1
,
size
,
dst
]()
{
auto
i0
=
tensor_iter_valonly
<
dt_int16
>
(
src0
).
begin
();
auto
i1
=
tensor_iter_valonly
<
dt_float32
>
(
src1
).
begin
();
auto
dst_ptr
=
dst
.
ptr
<
dt_float32
>
();
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
dst_ptr
[
i
]
=
(
*
i0
)
*
(
*
i1
);
++
i0
;
++
i1
;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR
(
work
());
}
void
ElemwiseMultiTypeImpl
::
on_fuse_mul_add3_iXxf32xf32xi8
(
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst
)
{
switch
(
param
[
0
].
layout
.
dtype
.
enumv
())
{
...
...
dnn/src/naive/elemwise_multi_type/opr_impl.h
浏览文件 @
f6d99094
...
...
@@ -60,6 +60,12 @@ protected:
const
ElemwiseOpParamN
<
6
>&
param
,
const
TensorND
&
dst
)
override
;
void
on_round_shr_saturate_iXxi8xi16
(
const
ElemwiseOpParamN
<
2
>&
param
,
const
TensorND
&
dst
)
override
;
void
on_fuse_mul_add3_int16xf32xf32xf32
(
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst
)
override
;
void
on_mul_int16xf32xf32
(
const
ElemwiseOpParamN
<
2
>&
param
,
const
TensorND
&
dst
)
override
;
void
on_fuse_mul_add3_uint8xf32xf32xf32
(
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst
)
override
;
void
on_quantized_mode
(
const
ElemwiseOpParamN
<
1
>&
param
,
const
TensorND
&
dst
,
...
...
dnn/test/arm_common/elemwise_multi_type.cpp
浏览文件 @
f6d99094
...
...
@@ -456,4 +456,107 @@ TEST_F(ARM_COMMON, ELEMWISE_QUANTIZED_MODE_TERNARY_RECORD) {
}
}
TEST_F
(
ARM_COMMON
,
ELEMWISE_FMA3_INT16xF32xF32xF32
)
{
Checker
<
ElemwiseMultiType
>
checker
(
handle
());
checker
.
set_param
({
ElemwiseMultiType
::
Mode
::
FUSE_MUL_ADD3_INT16xF32xF32xF32
});
checker
.
set_dtype
(
0
,
dtype
::
Int16
());
checker
.
set_dtype
(
1
,
dtype
::
Float32
());
checker
.
set_dtype
(
2
,
dtype
::
Float32
());
UniformIntRNG
rng
{
-
100
,
100
};
checker
.
set_rng
(
0
,
&
rng
);
checker
.
set_rng
(
1
,
&
rng
);
checker
.
set_rng
(
2
,
&
rng
);
checker
.
execs
({{
5
,
7
,
16
},
{
1
,
1
,
16
},
{
1
,
1
,
16
},
{}})
.
execs
({{
2
,
700
,
600
},
{
1
,
1
,
600
},
{
1
,
1
,
600
},
{}})
.
execs
({{
2
,
700
,
600
},
{
2
,
700
,
600
},
{
2
,
700
,
600
},
{}})
.
execs
({{
16
,
16
,
128
},
{
16
,
16
,
128
},
{
16
,
16
,
128
},
{}})
.
execs
({{
16
,
128
,
16
,
16
},
{
1
,
128
,
1
,
1
},
{
1
,
128
,
1
,
1
},
{}});
}
TEST_F
(
ARM_COMMON
,
ELEMWISE_FMA3_INT16xF32xF32xF32_RECORD
)
{
TaskRecordChecker
<
ElemwiseMultiType
>
checker
(
0
);
checker
.
set_param
({
ElemwiseMultiType
::
Mode
::
FUSE_MUL_ADD3_INT16xF32xF32xF32
});
checker
.
set_dtype
(
0
,
dtype
::
Int16
());
checker
.
set_dtype
(
1
,
dtype
::
Float32
());
checker
.
set_dtype
(
2
,
dtype
::
Float32
());
UniformIntRNG
rng
{
-
100
,
100
};
checker
.
set_rng
(
0
,
&
rng
);
checker
.
set_rng
(
1
,
&
rng
);
checker
.
set_rng
(
2
,
&
rng
);
checker
.
execs
({{
5
,
7
,
16
},
{
1
,
1
,
16
},
{
1
,
1
,
16
},
{}})
.
execs
({{
2
,
700
,
600
},
{
1
,
1
,
600
},
{
1
,
1
,
600
},
{}})
.
execs
({{
2
,
700
,
600
},
{
2
,
700
,
600
},
{
2
,
700
,
600
},
{}})
.
execs
({{
16
,
16
,
128
},
{
16
,
16
,
128
},
{
16
,
16
,
128
},
{}})
.
execs
({{
16
,
128
,
16
,
16
},
{
1
,
128
,
1
,
1
},
{
1
,
128
,
1
,
1
},
{}})
.
execs
({{
16
,
128
,
16
,
18
},
{
1
,
1
,
1
,
18
},
{
1
,
1
,
1
,
18
},
{}})
.
execs
({{
16
,
128
,
16
,
16
},
{
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
},
{}});
}
TEST_F
(
ARM_COMMON
,
ELEMWISE_MUL_INT16xF32xF32
)
{
Checker
<
ElemwiseMultiType
>
checker
(
handle
());
checker
.
set_param
({
ElemwiseMultiType
::
Mode
::
MUL_INT16xF32xF32
});
checker
.
set_dtype
(
0
,
dtype
::
Int16
());
checker
.
set_dtype
(
1
,
dtype
::
Float32
());
UniformIntRNG
rng
{
-
100
,
100
};
checker
.
set_rng
(
0
,
&
rng
);
checker
.
set_rng
(
1
,
&
rng
);
checker
.
execs
({{
5
,
7
,
16
},
{
1
,
1
,
16
},
{}})
.
execs
({{
2
,
700
,
600
},
{
1
,
1
,
600
},
{}})
.
execs
({{
2
,
700
,
600
},
{
2
,
700
,
600
},
{}})
.
execs
({{
16
,
16
,
128
},
{
16
,
16
,
128
},
{}})
.
execs
({{
16
,
128
,
16
,
16
},
{
1
,
128
,
1
,
1
},
{}});
}
TEST_F
(
ARM_COMMON
,
ELEMWISE_ELEMWISE_MUL_INT16xF32xF32_RECORD
)
{
TaskRecordChecker
<
ElemwiseMultiType
>
checker
(
0
);
checker
.
set_param
({
ElemwiseMultiType
::
Mode
::
MUL_INT16xF32xF32
});
checker
.
set_dtype
(
0
,
dtype
::
Int16
());
checker
.
set_dtype
(
1
,
dtype
::
Float32
());
UniformIntRNG
rng
{
-
100
,
100
};
checker
.
set_rng
(
0
,
&
rng
);
checker
.
set_rng
(
1
,
&
rng
);
checker
.
execs
({{
5
,
7
,
16
},
{
1
,
1
,
16
},
{}})
.
execs
({{
2
,
700
,
600
},
{
1
,
1
,
600
},
{}})
.
execs
({{
2
,
700
,
600
},
{
2
,
700
,
600
},
{}})
.
execs
({{
16
,
16
,
128
},
{
16
,
16
,
128
},
{}})
.
execs
({{
16
,
128
,
16
,
16
},
{
1
,
128
,
1
,
1
},
{}})
.
execs
({{
16
,
128
,
16
,
16
},
{
1
,
1
,
1
,
1
},
{}});
}
TEST_F
(
ARM_COMMON
,
ELEMWISE_FMA3_UINT8xF32xF32xF32
)
{
Checker
<
ElemwiseMultiType
>
checker
(
handle
());
checker
.
set_param
({
ElemwiseMultiType
::
Mode
::
FUSE_MUL_ADD3_UINT8xF32xF32xF32
});
checker
.
set_dtype
(
0
,
dtype
::
Uint8
());
checker
.
set_dtype
(
1
,
dtype
::
Float32
());
checker
.
set_dtype
(
2
,
dtype
::
Float32
());
UniformIntRNG
rng
{
-
100
,
100
};
checker
.
set_rng
(
0
,
&
rng
);
checker
.
set_rng
(
1
,
&
rng
);
checker
.
set_rng
(
2
,
&
rng
);
checker
.
execs
({{
5
,
7
,
16
},
{
1
,
1
,
16
},
{
1
,
1
,
16
},
{}})
.
execs
({{
2
,
700
,
600
},
{
1
,
1
,
600
},
{
1
,
1
,
600
},
{}})
.
execs
({{
2
,
700
,
600
},
{
2
,
700
,
600
},
{
2
,
700
,
600
},
{}})
.
execs
({{
16
,
16
,
128
},
{
16
,
16
,
128
},
{
16
,
16
,
128
},
{}})
.
execs
({{
16
,
128
,
16
,
16
},
{
1
,
128
,
1
,
1
},
{
1
,
128
,
1
,
1
},
{}})
.
execs
({{
16
,
128
,
16
,
16
},
{
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
},
{}});
}
TEST_F
(
ARM_COMMON
,
ELEMWISE_FMA3_UINT8xF32xF32xF32_RECORD
)
{
TaskRecordChecker
<
ElemwiseMultiType
>
checker
(
0
);
checker
.
set_param
({
ElemwiseMultiType
::
Mode
::
FUSE_MUL_ADD3_UINT8xF32xF32xF32
});
checker
.
set_dtype
(
0
,
dtype
::
Uint8
());
checker
.
set_dtype
(
1
,
dtype
::
Float32
());
checker
.
set_dtype
(
2
,
dtype
::
Float32
());
UniformIntRNG
rng
{
-
100
,
100
};
checker
.
set_rng
(
0
,
&
rng
);
checker
.
set_rng
(
1
,
&
rng
);
checker
.
set_rng
(
2
,
&
rng
);
checker
.
execs
({{
5
,
7
,
16
},
{
1
,
1
,
16
},
{
1
,
1
,
16
},
{}})
.
execs
({{
2
,
700
,
600
},
{
1
,
1
,
600
},
{
1
,
1
,
600
},
{}})
.
execs
({{
2
,
700
,
600
},
{
2
,
700
,
600
},
{
2
,
700
,
600
},
{}})
.
execs
({{
16
,
16
,
128
},
{
16
,
16
,
128
},
{
16
,
16
,
128
},
{}})
.
execs
({{
16
,
128
,
16
,
16
},
{
1
,
128
,
1
,
1
},
{
1
,
128
,
1
,
1
},
{}})
.
execs
({{
16
,
128
,
16
,
16
},
{
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
},
{}});
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
dnn/test/common/elemwise_multi_type.cpp
浏览文件 @
f6d99094
...
...
@@ -79,6 +79,73 @@ DEF_TEST(fuse_mul_add3_int16x32x32x32) {
.
execs
({{
102
,
67
,
71
},
{
1
,
67
,
1
},
{
1
,
67
,
1
},
{}});
}
DEF_TEST
(
fuse_mul_add3_int16xf32xf32xf32
)
{
// This is not implemented on CUDA.
if
(
handle
->
type
()
==
Handle
::
HandleType
::
CUDA
)
{
return
;
}
Checker
<
ElemwiseMultiType
>
checker
(
handle
);
checker
.
set_param
({
ElemwiseMultiType
::
Mode
::
FUSE_MUL_ADD3_INT16xF32xF32xF32
});
checker
.
set_dtype
(
0
,
dtype
::
Int16
());
checker
.
set_dtype
(
1
,
dtype
::
Float32
());
checker
.
set_dtype
(
2
,
dtype
::
Float32
());
UniformIntRNG
rng
{
-
100
,
100
};
checker
.
set_rng
(
0
,
&
rng
);
checker
.
set_rng
(
1
,
&
rng
);
checker
.
set_rng
(
2
,
&
rng
);
checker
.
execs
({{
5
,
7
,
6
},
{
1
,
1
,
6
},
{
1
,
1
,
6
},
{}})
.
execs
({{
1
,
700
,
600
},
{
1
,
1
,
600
},
{
1
,
1
,
600
},
{}})
.
execs
({{
1
,
700
,
600
},
{
1
,
700
,
600
},
{
1
,
700
,
600
},
{}})
.
execs
({{
102
,
71
,
67
},
{
1
,
1
,
67
},
{
1
,
1
,
67
},
{}})
.
execs
({{
16
,
16
,
128
},
{
16
,
16
,
128
},
{
16
,
16
,
128
},
{}})
.
execs
({{
16
,
128
,
16
,
16
},
{
1
,
128
,
1
,
1
},
{
1
,
128
,
1
,
1
},
{}})
.
execs
({{
16
,
128
,
16
,
16
},
{
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
},
{}});
}
DEF_TEST
(
fuse_mul_add3_uint8xf32xf32xf32
)
{
// This is not implemented on CUDA.
if
(
handle
->
type
()
==
Handle
::
HandleType
::
CUDA
)
{
return
;
}
Checker
<
ElemwiseMultiType
>
checker
(
handle
);
checker
.
set_param
({
ElemwiseMultiType
::
Mode
::
FUSE_MUL_ADD3_UINT8xF32xF32xF32
});
checker
.
set_dtype
(
0
,
dtype
::
Uint8
());
checker
.
set_dtype
(
1
,
dtype
::
Float32
());
checker
.
set_dtype
(
2
,
dtype
::
Float32
());
UniformIntRNG
rng
{
-
100
,
100
};
checker
.
set_rng
(
0
,
&
rng
);
checker
.
set_rng
(
1
,
&
rng
);
checker
.
set_rng
(
2
,
&
rng
);
checker
.
execs
({{
5
,
7
,
6
},
{
1
,
1
,
6
},
{
1
,
1
,
6
},
{}})
.
execs
({{
1
,
700
,
600
},
{
1
,
1
,
600
},
{
1
,
1
,
600
},
{}})
.
execs
({{
1
,
700
,
600
},
{
1
,
700
,
600
},
{
1
,
700
,
600
},
{}})
.
execs
({{
102
,
71
,
67
},
{
1
,
1
,
67
},
{
1
,
1
,
67
},
{}})
.
execs
({{
16
,
16
,
128
},
{
16
,
16
,
128
},
{
16
,
16
,
128
},
{}})
.
execs
({{
16
,
128
,
16
,
16
},
{
1
,
128
,
1
,
1
},
{
1
,
128
,
1
,
1
},
{}})
.
execs
({{
16
,
128
,
16
,
16
},
{
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
},
{}});
}
DEF_TEST
(
fuse_mul_add3_int16xf32xf32
)
{
// This is not implemented on CUDA.
if
(
handle
->
type
()
==
Handle
::
HandleType
::
CUDA
)
{
return
;
}
Checker
<
ElemwiseMultiType
>
checker
(
handle
);
checker
.
set_param
({
ElemwiseMultiType
::
Mode
::
MUL_INT16xF32xF32
});
checker
.
set_dtype
(
0
,
dtype
::
Int16
());
checker
.
set_dtype
(
1
,
dtype
::
Float32
());
UniformIntRNG
rng
{
-
100
,
100
};
checker
.
set_rng
(
0
,
&
rng
);
checker
.
set_rng
(
1
,
&
rng
);
checker
.
execs
({{
5
,
7
,
6
},
{
1
,
1
,
6
},
{}})
.
execs
({{
1
,
700
,
600
},
{
1
,
1
,
600
},
{}})
.
execs
({{
1
,
700
,
600
},
{
1
,
700
,
600
},
{}})
.
execs
({{
102
,
71
,
67
},
{
1
,
1
,
67
},
{}})
.
execs
({{
16
,
16
,
128
},
{
16
,
16
,
128
},
{}})
.
execs
({{
16
,
128
,
16
,
16
},
{
1
,
128
,
1
,
1
},
{}})
.
execs
({{
16
,
128
,
16
,
16
},
{
1
,
1
,
1
,
1
},
{}});
}
DEF_TEST
(
fuse_mul_add3_iXxf32xf32xi8
)
{
Checker
<
ElemwiseMultiType
>
checker
(
handle
);
checker
.
set_param
({
ElemwiseMultiType
::
Mode
::
FUSE_MUL_ADD3_IXxF32xF32xI8
});
...
...
dnn/test/common/elemwise_multi_type.h
浏览文件 @
f6d99094
...
...
@@ -23,7 +23,10 @@ namespace elemwise_multi_type {
#define FOREACH_ELEMWISE_MULTI_TYPE_NONFIRST_CASE(cb) \
cb(fuse_mul_add3_iXxf32xf32xi8) cb(round_shr_saturate_iXxi8xi8) \
cb(fuse_add_rmulh_round_shr_saturate_int16) \
cb(fuse_add_rmulh_round_shr_saturate_int32)
cb(fuse_add_rmulh_round_shr_saturate_int32) \
cb(fuse_mul_add3_int16xf32xf32xf32) \
cb(fuse_mul_add3_uint8xf32xf32xf32) \
cb(fuse_mul_add3_int16xf32xf32)
#define FOREACH_ELEMWISE_MULTI_TYPE_CASE(cb) \
cb(FIRST_ELEMWISE_MULTI_TYPE_CASE) FOREACH_ELEMWISE_MULTI_TYPE_NONFIRST_CASE(cb)
...
...
dnn/test/fallback/elemwise_multi_type.cpp
浏览文件 @
f6d99094
...
...
@@ -40,6 +40,24 @@ TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_RECORD_FMA3_INT16x32x32x32) {
checker
.
execs
({{
A
,
B
,
C
},
{
1
,
B
,
1
},
{
1
,
B
,
1
},
{}});
}
TEST_F
(
FALLBACK
,
ELEMWISE_MULTI_TYPE_RECORD_FMA3_INT16xF32xF32xF32
)
{
TaskRecordChecker
<
ElemwiseMultiType
>
checker
{
1
};
checker
.
set_param
({
ElemwiseMultiType
::
Mode
::
FUSE_MUL_ADD3_INT16xF32xF32xF32
});
checker
.
set_dtype
(
0
,
dtype
::
Int16
());
checker
.
set_dtype
(
1
,
dtype
::
Float32
());
checker
.
set_dtype
(
2
,
dtype
::
Float32
());
UniformIntRNG
rng
{
-
10
,
10
};
checker
.
set_rng
(
0
,
&
rng
);
checker
.
set_rng
(
1
,
&
rng
);
checker
.
set_rng
(
2
,
&
rng
);
checker
.
execs
({{
5
,
7
,
16
},
{
1
,
1
,
16
},
{
1
,
1
,
16
},
{}})
.
execs
({{
2
,
700
,
600
},
{
1
,
1
,
600
},
{
1
,
1
,
600
},
{}})
.
execs
({{
2
,
700
,
600
},
{
2
,
700
,
600
},
{
2
,
700
,
600
},
{}})
.
execs
({{
16
,
16
,
128
},
{
16
,
16
,
128
},
{
16
,
16
,
128
},
{}})
.
execs
({{
16
,
128
,
16
,
16
},
{
1
,
128
,
1
,
1
},
{
1
,
128
,
1
,
1
},
{}})
.
execs
({{
16
,
128
,
16
,
16
},
{
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
},
{}});
}
#if MEGDNN_WITH_BENCHMARK
TEST_F
(
FALLBACK
,
ELEMWISE_MULTI_TYPE_BENCHMARK_FMA3_INT16x32x32x32
)
{
Benchmarker
<
ElemwiseMultiType
>
bench
{
handle
()};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录