Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
36ba1d6d
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
36ba1d6d
编写于
6月 16, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(riscv): fix ci fp16 build and move test GI_TEST_NAIVE by megdnn_gi_api_test
GitOrigin-RevId: e463855d925d6ea8eb2da82c2c911dde4fcb3d45
上级
dcce4610
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
22 addition
and
99 deletion
+22
-99
dnn/src/arm_common/elemwise_helper/elemwise_op.h
dnn/src/arm_common/elemwise_helper/elemwise_op.h
+22
-39
dnn/src/fallback/elemwise_helper/op_common.h
dnn/src/fallback/elemwise_helper/op_common.h
+0
-60
未找到文件。
dnn/src/arm_common/elemwise_helper/elemwise_op.h
浏览文件 @
36ba1d6d
...
@@ -13,18 +13,6 @@ using BcastType = megdnn::elemwise::BcastType;
...
@@ -13,18 +13,6 @@ using BcastType = megdnn::elemwise::BcastType;
///////////////////////////////// ParamElemVistor ///////////////////////////
///////////////////////////////// ParamElemVistor ///////////////////////////
#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix, _neon_type_v2) \
#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix, _neon_type_v2) \
template <> \
struct ParamElemVisitor<_ctype> { \
_neon_type operator()(const _ctype* src) const { \
return vld1q_##_fun_suffix(reinterpret_cast<const _inner_ctype*>(src)); \
} \
}; \
template <> \
struct ParamElemVisitorDup<_ctype> { \
_neon_type operator()(const _ctype* src) const { \
return vdupq_n_##_fun_suffix(*reinterpret_cast<const _inner_ctype*>(src)); \
} \
}; \
template <> \
template <> \
struct ParamElemVisitorV2<_ctype> { \
struct ParamElemVisitorV2<_ctype> { \
_neon_type_v2 operator()(const _ctype* src, const _ctype* src_1) const { \
_neon_type_v2 operator()(const _ctype* src, const _ctype* src_1) const { \
...
@@ -53,16 +41,7 @@ cb(__fp16, __fp16, float16x8_t, f16, float16x8x2_t);
...
@@ -53,16 +41,7 @@ cb(__fp16, __fp16, float16x8_t, f16, float16x8x2_t);
cb
(
dt_int16
,
int16_t
,
int16x8_t
,
s16
,
int16x8x2_t
);
cb
(
dt_int16
,
int16_t
,
int16x8_t
,
s16
,
int16x8x2_t
);
#undef cb
#undef cb
template
<
typename
ctype
>
struct
ParamElemVisitorBcast101x4
;
#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix, rel_suffix, _neon_type_v2) \
#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix, rel_suffix, _neon_type_v2) \
template <> \
struct ParamElemVisitorBcast101x4<_ctype> { \
_neon_type operator()(const _ctype* src) const { \
return vreinterpretq_##_fun_suffix##_##rel_suffix(vld1q_dup_##rel_suffix( \
reinterpret_cast<const _inner_ctype*>(src))); \
} \
}; \
template <> \
template <> \
struct ParamElemVisitorBcast101x4V2<_ctype> { \
struct ParamElemVisitorBcast101x4V2<_ctype> { \
_neon_type_v2 operator()(const _ctype* src) const { \
_neon_type_v2 operator()(const _ctype* src) const { \
...
@@ -83,16 +62,20 @@ cb(__fp16, uint64_t, float16x8_t, f16, u64, float16x8x2_t);
...
@@ -83,16 +62,20 @@ cb(__fp16, uint64_t, float16x8_t, f16, u64, float16x8x2_t);
#undef cb
#undef cb
template
<
typename
ctype
>
template
<
typename
ctype
>
struct
ParamElemVisitorBcast101x8
;
struct
ParamElemVisitorBcast101x8
V2
;
#define cb(_ctype, _inner_ctype, _neon_type
, _fun_suffix)
\
#define cb(_ctype, _inner_ctype, _neon_type
_v2, _fun_suffix)
\
template <> \
template <> \
struct ParamElemVisitorBcast101x8<_ctype> { \
struct ParamElemVisitorBcast101x8V2<_ctype> { \
_neon_type operator()(const _ctype* src) const { \
_neon_type_v2 operator()(const _ctype* src) const { \
return vld1q_##_fun_suffix(reinterpret_cast<const _inner_ctype*>(src)); \
_neon_type_v2 ret; \
ret.val[0] = \
vld1q_##_fun_suffix(reinterpret_cast<const _inner_ctype*>(src)); \
ret.val[1] = ret.val[0]; \
return ret; \
} \
} \
}
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
cb
(
__fp16
,
__fp16
,
float16x8_t
,
f16
);
cb
(
__fp16
,
__fp16
,
float16x8
x2
_t
,
f16
);
#endif
#endif
#undef cb
#undef cb
...
@@ -106,8 +89,8 @@ struct OpCallerBinaryBcast101xXVec<__fp16, 8> {
...
@@ -106,8 +89,8 @@ struct OpCallerBinaryBcast101xXVec<__fp16, 8> {
const
src_ctype
*
src0
,
const
src_ctype
*
src1
,
typename
Op
::
dst_ctype
*
dst
,
const
src_ctype
*
src0
,
const
src_ctype
*
src1
,
typename
Op
::
dst_ctype
*
dst
,
const
Op
&
op
,
size_t
batch
,
size_t
nr_channel_blocks
,
const
Op
&
op
,
size_t
batch
,
size_t
nr_channel_blocks
,
size_t
channel_stride
)
{
size_t
channel_stride
)
{
ParamElemVisitorBcast101x8
<
src_ctype
>
vis0
;
ParamElemVisitorBcast101x8
V2
<
src_ctype
>
vis0
;
ParamElemVisitor
<
src_ctype
>
vis1
;
ParamElemVisitor
V2
<
src_ctype
>
vis1
;
OpCallerBinaryBcast101xDVec
<
src_ctype
,
8
>::
run
(
OpCallerBinaryBcast101xDVec
<
src_ctype
,
8
>::
run
(
src0
,
src1
,
dst
,
op
,
vis0
,
vis1
,
batch
,
nr_channel_blocks
,
src0
,
src1
,
dst
,
op
,
vis0
,
vis1
,
batch
,
nr_channel_blocks
,
channel_stride
);
channel_stride
);
...
@@ -122,8 +105,8 @@ struct OpCallerBinaryVecBcast101xX<__fp16, 8> {
...
@@ -122,8 +105,8 @@ struct OpCallerBinaryVecBcast101xX<__fp16, 8> {
const
src_ctype
*
src0
,
const
src_ctype
*
src1
,
typename
Op
::
dst_ctype
*
dst
,
const
src_ctype
*
src0
,
const
src_ctype
*
src1
,
typename
Op
::
dst_ctype
*
dst
,
const
Op
&
op
,
size_t
batch
,
size_t
nr_channel_blocks
,
const
Op
&
op
,
size_t
batch
,
size_t
nr_channel_blocks
,
size_t
channel_stride
)
{
size_t
channel_stride
)
{
ParamElemVisitor
<
src_ctype
>
vis0
;
ParamElemVisitor
V2
<
src_ctype
>
vis0
;
ParamElemVisitorBcast101x8
<
src_ctype
>
vis1
;
ParamElemVisitorBcast101x8
V2
<
src_ctype
>
vis1
;
OpCallerBinaryVecBcast101xD
<
src_ctype
,
8
>::
run
(
OpCallerBinaryVecBcast101xD
<
src_ctype
,
8
>::
run
(
src0
,
src1
,
dst
,
op
,
vis0
,
vis1
,
batch
,
nr_channel_blocks
,
src0
,
src1
,
dst
,
op
,
vis0
,
vis1
,
batch
,
nr_channel_blocks
,
channel_stride
);
channel_stride
);
...
@@ -138,9 +121,9 @@ struct OpCallerTernaryBcast101xXVecBcast101xX<__fp16, 8> {
...
@@ -138,9 +121,9 @@ struct OpCallerTernaryBcast101xXVecBcast101xX<__fp16, 8> {
const
src_ctype
*
src0
,
const
src_ctype
*
src1
,
const
src_ctype
*
src2
,
const
src_ctype
*
src0
,
const
src_ctype
*
src1
,
const
src_ctype
*
src2
,
typename
Op
::
dst_ctype
*
dst
,
const
Op
&
op
,
size_t
batch
,
typename
Op
::
dst_ctype
*
dst
,
const
Op
&
op
,
size_t
batch
,
size_t
nr_channel_blocks
,
size_t
channel_stride
)
{
size_t
nr_channel_blocks
,
size_t
channel_stride
)
{
ParamElemVisitorBcast101x8
<
src_ctype
>
vis0
;
ParamElemVisitorBcast101x8
V2
<
src_ctype
>
vis0
;
ParamElemVisitor
<
src_ctype
>
vis1
;
ParamElemVisitor
V2
<
src_ctype
>
vis1
;
ParamElemVisitorBcast101x8
<
src_ctype
>
vis2
;
ParamElemVisitorBcast101x8
V2
<
src_ctype
>
vis2
;
OpCallerTernaryBcast101xDVecBcast101xD
<
src_ctype
,
8
>::
run
(
OpCallerTernaryBcast101xDVecBcast101xD
<
src_ctype
,
8
>::
run
(
src0
,
src1
,
src2
,
dst
,
op
,
vis0
,
vis1
,
vis2
,
batch
,
nr_channel_blocks
,
src0
,
src1
,
src2
,
dst
,
op
,
vis0
,
vis1
,
vis2
,
batch
,
nr_channel_blocks
,
channel_stride
);
channel_stride
);
...
@@ -155,9 +138,9 @@ struct OpCallerTernaryVecBcast101xXVec<__fp16, 8> {
...
@@ -155,9 +138,9 @@ struct OpCallerTernaryVecBcast101xXVec<__fp16, 8> {
const
src_ctype
*
src0
,
const
src_ctype
*
src1
,
const
src_ctype
*
src2
,
const
src_ctype
*
src0
,
const
src_ctype
*
src1
,
const
src_ctype
*
src2
,
typename
Op
::
dst_ctype
*
dst
,
const
Op
&
op
,
size_t
batch
,
typename
Op
::
dst_ctype
*
dst
,
const
Op
&
op
,
size_t
batch
,
size_t
nr_channel_blocks
,
size_t
channel_stride
)
{
size_t
nr_channel_blocks
,
size_t
channel_stride
)
{
ParamElemVisitor
<
src_ctype
>
vis0
;
ParamElemVisitor
V2
<
src_ctype
>
vis0
;
ParamElemVisitorBcast101x8
<
src_ctype
>
vis1
;
ParamElemVisitorBcast101x8
V2
<
src_ctype
>
vis1
;
ParamElemVisitor
<
src_ctype
>
vis2
;
ParamElemVisitor
V2
<
src_ctype
>
vis2
;
OpCallerTernaryVecBcast101xDVec
<
src_ctype
,
8
>::
run
(
OpCallerTernaryVecBcast101xDVec
<
src_ctype
,
8
>::
run
(
src0
,
src1
,
src2
,
dst
,
op
,
vis0
,
vis1
,
vis2
,
batch
,
nr_channel_blocks
,
src0
,
src1
,
src2
,
dst
,
op
,
vis0
,
vis1
,
vis2
,
batch
,
nr_channel_blocks
,
channel_stride
);
channel_stride
);
...
...
dnn/src/fallback/elemwise_helper/op_common.h
浏览文件 @
36ba1d6d
...
@@ -36,66 +36,6 @@ enum BcastType {
...
@@ -36,66 +36,6 @@ enum BcastType {
UNKNOWN_BCAST_TYPE
UNKNOWN_BCAST_TYPE
};
};
///////////////////////////////// ParamElemVistor ///////////////////////////
template
<
typename
ctype
>
struct
ParamElemVisitor
;
//! visitor single elemwise, and dup to vector
template
<
typename
ctype
>
struct
ParamElemVisitorDup
;
template
<
typename
ctype
>
struct
ParamElemVisitorBcast101x4
;
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \
template <> \
struct ParamElemVisitor<_ctype> { \
_simd_type operator()(const _ctype* src) const { \
return GiLoad##_fun_suffix(src); \
} \
}; \
template <> \
struct ParamElemVisitorDup<_ctype> { \
_simd_type operator()(const _ctype* src) const { \
return GiBroadcast##_fun_suffix( \
*reinterpret_cast<const _inner_ctype*>(src)); \
} \
}
cb
(
dt_qint32
,
int32_t
,
GI_INT32_t
,
Int32
);
cb
(
dt_qint8
,
int8_t
,
GI_INT8_t
,
Int8
);
cb
(
dt_float32
,
float
,
GI_FLOAT32_t
,
Float32
);
cb
(
dt_int32
,
int32_t
,
GI_INT32_t
,
Int32
);
cb
(
dt_int8
,
int8_t
,
GI_INT8_t
,
Int8
);
#undef cb
template
<
typename
ctype
>
struct
ParamElemVisitorBcast101x4
;
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix) \
template <> \
struct ParamElemVisitorBcast101x4<_ctype> { \
_simd_type operator()(const _ctype* src) const { \
return GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix( \
*reinterpret_cast<const _inner_ctype*>(src))); \
} \
}
cb
(
dt_qint8
,
int32_t
,
GI_INT8_t
,
Int8
,
Int32
);
cb
(
dt_int8
,
int32_t
,
GI_INT8_t
,
Int8
,
Int32
);
#undef cb
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \
template <> \
struct ParamElemVisitorBcast101x4<_ctype> { \
_simd_type operator()(const _ctype* src) const { \
return GiLoad##_fun_suffix(src); \
} \
}
cb
(
dt_qint32
,
int32_t
,
GI_INT32_t
,
Int32
);
cb
(
dt_float32
,
float
,
GI_FLOAT32_t
,
Float32
);
cb
(
dt_int32
,
int32_t
,
GI_INT32_t
,
Int32
);
#undef cb
///////////////////////////////// ParamElemVistor v2///////////////////////////
///////////////////////////////// ParamElemVistor v2///////////////////////////
template
<
typename
ctype
>
template
<
typename
ctype
>
struct
ParamElemVisitorV2
;
struct
ParamElemVisitorV2
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录