Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
68c5e766
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
68c5e766
编写于
11月 30, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(midout): add more midout tag
GitOrigin-RevId: 6096aa2f66cb8f0c89bb3d5003623fca22dcfcc3
上级
9bf718b7
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
402 addition
and
210 deletion
+402
-210
dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
+151
-81
dnn/src/common/warp_common.h
dnn/src/common/warp_common.h
+45
-25
dnn/src/fallback/elemwise_multi_type/quantized_impl.cpp
dnn/src/fallback/elemwise_multi_type/quantized_impl.cpp
+151
-81
dnn/src/naive/elemwise_multi_type/opr_impl_3.cpp
dnn/src/naive/elemwise_multi_type/opr_impl_3.cpp
+9
-1
dnn/src/naive/elemwise_multi_type/opr_impl_4.cpp
dnn/src/naive/elemwise_multi_type/opr_impl_4.cpp
+37
-21
dnn/src/naive/elemwise_multi_type/opr_impl_5.cpp
dnn/src/naive/elemwise_multi_type/opr_impl_5.cpp
+9
-1
未找到文件。
dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
浏览文件 @
68c5e766
...
@@ -6,6 +6,9 @@
...
@@ -6,6 +6,9 @@
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "midout.h"
MIDOUT_DECL
(
megdnn_arm_common_elemwise_multi_type
)
namespace
{
namespace
{
using
namespace
megdnn
;
using
namespace
megdnn
;
...
@@ -370,10 +373,15 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -370,10 +373,15 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \
thin_function<void(const src_ctype*, dst_ctype*, DType, DType, size_t)> run = \
thin_function<void(const src_ctype*, dst_ctype*, DType, DType, size_t)> run = \
OpCallerUnary<_op<src_ctype, dst_ctype>, VEC>::run; \
OpCallerUnary<_op<src_ctype, dst_ctype>, VEC>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(0), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src.layout.dtype, \
run(src.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src.layout.dtype, \
dst.layout.dtype, nr_elems)); \
dst.layout.dtype, nr_elems)); \
return; \
return; \
} \
MIDOUT_END(); \
}
}
DISPATCH
()
DISPATCH
()
...
@@ -469,10 +477,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -469,10 +477,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t)> \
size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_VEC>::run; \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
MIDOUT_BEGIN( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
megdnn_arm_common_elemwise_multi_type, midout_iv(1), src_ctype, \
src0.layout.dtype, src1.layout.dtype, dst.layout.dtype, nr_elems)); \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, nr_elems)); \
return; \
return; \
} \
MIDOUT_END(); \
}
}
DISPATCH
()
DISPATCH
()
...
@@ -502,11 +516,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -502,11 +516,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype, dst_ctype*, DType, DType, DType, \
const src_ctype*, const src_ctype, dst_ctype*, DType, DType, DType, \
size_t)> \
size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_SCALAR>::run; \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_SCALAR>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(2), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>()[0], \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>()[0], \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, src0.layout.total_nr_elems())); \
dst.layout.dtype, src0.layout.total_nr_elems())); \
return; \
return; \
} \
MIDOUT_END(); \
}
}
DISPATCH
()
DISPATCH
()
...
@@ -525,11 +544,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -525,11 +544,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype, const src_ctype*, dst_ctype*, DType, DType, DType, \
const src_ctype, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t)> \
size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, SCALAR_VEC>::run; \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, SCALAR_VEC>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(3), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>()[0], src1.ptr<src_ctype>(), \
run(src0.ptr<src_ctype>()[0], src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, src1.layout.total_nr_elems())); \
dst.layout.dtype, src1.layout.total_nr_elems())); \
return; \
return; \
} \
MIDOUT_END(); \
}
}
DISPATCH
()
DISPATCH
()
...
@@ -563,11 +587,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -563,11 +587,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t)> \
size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101>::run; \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(4), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
return; \
return; \
} \
MIDOUT_END(); \
}
}
DISPATCH
()
DISPATCH
()
...
@@ -586,11 +615,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -586,11 +615,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t)> \
size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101_VEC>::run; \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101_VEC>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(5), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
return; \
return; \
} \
MIDOUT_END(); \
}
}
DISPATCH
()
DISPATCH
()
...
@@ -613,11 +647,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -613,11 +647,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t, size_t)> \
size_t, size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101xX>::run; \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101xX>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(6), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
return; \
} \
MIDOUT_END(); \
}
}
size_t
batch_size
=
src0
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
size_t
batch_size
=
src0
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
DISPATCH
()
DISPATCH
()
...
@@ -636,11 +675,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -636,11 +675,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t, size_t)> \
size_t, size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101xX_VEC>::run; \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101xX_VEC>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(7), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
return; \
} \
MIDOUT_END(); \
}
}
size_t
batch_size
=
src1
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
size_t
batch_size
=
src1
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
DISPATCH
()
DISPATCH
()
...
@@ -693,11 +737,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -693,11 +737,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \
DType, DType, DType, DType, size_t)> \
DType, DType, DType, DType, size_t)> \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_VEC>::run; \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
MIDOUT_BEGIN( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), src2.ptr<src_ctype>(), \
megdnn_arm_common_elemwise_multi_type, midout_iv(8), src_ctype, \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst_ctype, midout_iv(_mode)) { \
src2.layout.dtype, dst.layout.dtype, nr_elems)); \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, nr_elems)); \
return; \
return; \
} \
MIDOUT_END(); \
}
}
DISPATCH
()
DISPATCH
()
...
@@ -716,12 +766,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -716,12 +766,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, const src_ctype, dst_ctype*, \
const src_ctype*, const src_ctype*, const src_ctype, dst_ctype*, \
DType, DType, DType, DType, size_t)> \
DType, DType, DType, DType, size_t)> \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_SCALAR>::run; \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_SCALAR>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(9), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>()[0], dst.ptr<dst_ctype>(), src0.layout.dtype,
\
src2.ptr<src_ctype>()[0], dst.ptr<dst_ctype>(),
\
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,
\
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype,
\
src0.layout.total_nr_elems()));
\
dst.layout.dtype, src0.layout.total_nr_elems()));
\
return; \
return; \
} \
MIDOUT_END(); \
}
}
DISPATCH
()
DISPATCH
()
...
@@ -745,12 +800,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -745,12 +800,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, BCAST101_VEC_BCAST101>::run; \
_op<src_ctype, dst_ctype>, BCAST101_VEC_BCAST101>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MIDOUT_BEGIN( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
megdnn_arm_common_elemwise_multi_type, midout_iv(10), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, binfo.x, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, binfo.x, \
binfo.y, binfo.z, binfo.y* binfo.z)); \
binfo.y, binfo.z, binfo.y* binfo.z)); \
return; \
return; \
} \
MIDOUT_END(); \
}
}
DISPATCH
()
DISPATCH
()
...
@@ -775,12 +835,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -775,12 +835,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, VEC_BCAST101xX_VEC>::run; \
_op<src_ctype, dst_ctype>, VEC_BCAST101xX_VEC>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(11), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype,
\
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(),
\
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,
\
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype,
\
batch_size, binfo.x, binfo.y, binfo.z));
\
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z));
\
return; \
return; \
} \
MIDOUT_END(); \
}
}
size_t
batch_size
=
src0
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
size_t
batch_size
=
src0
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
...
@@ -803,12 +868,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -803,12 +868,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, BCAST101xX_VEC_BCAST101xX>::run; \
_op<src_ctype, dst_ctype>, BCAST101xX_VEC_BCAST101xX>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(12), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype,
\
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(),
\
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,
\
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype,
\
batch_size, binfo.x, binfo.y, binfo.z));
\
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z));
\
return; \
return; \
} \
MIDOUT_END(); \
}
}
size_t
batch_size
=
src1
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
size_t
batch_size
=
src1
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
...
...
dnn/src/common/warp_common.h
浏览文件 @
68c5e766
...
@@ -916,23 +916,43 @@ void remap(
...
@@ -916,23 +916,43 @@ void remap(
#define DISPATCH_IMODE(_imode, _bmode, _ch, _cb) \
#define DISPATCH_IMODE(_imode, _bmode, _ch, _cb) \
switch (_imode) { \
switch (_imode) { \
case InterpolationMode::NEAREST: { \
case InterpolationMode::NEAREST: { \
MIDOUT_BEGIN( \
megdnn_warp, midout_iv(0), midout_iv("DISPATCH_IMODE"_hash)) { \
DISPATCH_BMODE(InterpolationMode::NEAREST, _bmode, _ch, _cb); \
DISPATCH_BMODE(InterpolationMode::NEAREST, _bmode, _ch, _cb); \
} \
MIDOUT_END(); \
break; \
break; \
} \
} \
case InterpolationMode::LINEAR: { \
case InterpolationMode::LINEAR: { \
MIDOUT_BEGIN( \
megdnn_warp, midout_iv(1), midout_iv("DISPATCH_IMODE"_hash)) { \
DISPATCH_BMODE(InterpolationMode::LINEAR, _bmode, _ch, _cb); \
DISPATCH_BMODE(InterpolationMode::LINEAR, _bmode, _ch, _cb); \
} \
MIDOUT_END(); \
break; \
break; \
} \
} \
case InterpolationMode::AREA: { \
case InterpolationMode::AREA: { \
MIDOUT_BEGIN( \
megdnn_warp, midout_iv(2), midout_iv("DISPATCH_IMODE"_hash)) { \
DISPATCH_BMODE(InterpolationMode::AREA, _bmode, _ch, _cb); \
DISPATCH_BMODE(InterpolationMode::AREA, _bmode, _ch, _cb); \
} \
MIDOUT_END(); \
break; \
break; \
} \
} \
case InterpolationMode::CUBIC: { \
case InterpolationMode::CUBIC: { \
MIDOUT_BEGIN( \
megdnn_warp, midout_iv(3), midout_iv("DISPATCH_IMODE"_hash)) { \
DISPATCH_BMODE(InterpolationMode::CUBIC, _bmode, _ch, _cb); \
DISPATCH_BMODE(InterpolationMode::CUBIC, _bmode, _ch, _cb); \
} \
MIDOUT_END(); \
break; \
break; \
} \
} \
case InterpolationMode::LANCZOS4: { \
case InterpolationMode::LANCZOS4: { \
MIDOUT_BEGIN( \
megdnn_warp, midout_iv(4), midout_iv("DISPATCH_IMODE"_hash)) { \
DISPATCH_BMODE(InterpolationMode::LANCZOS4, _bmode, _ch, _cb); \
DISPATCH_BMODE(InterpolationMode::LANCZOS4, _bmode, _ch, _cb); \
} \
MIDOUT_END(); \
break; \
break; \
} \
} \
default: { \
default: { \
...
...
dnn/src/fallback/elemwise_multi_type/quantized_impl.cpp
浏览文件 @
68c5e766
...
@@ -3,6 +3,9 @@
...
@@ -3,6 +3,9 @@
#include "src/fallback/elemwise_multi_type/opr_impl.h"
#include "src/fallback/elemwise_multi_type/opr_impl.h"
#include "src/naive/handle.h"
#include "src/naive/handle.h"
#include "midout.h"
MIDOUT_DECL
(
megdnn_fallback_elemwise_multi_type_quantized
)
using
namespace
megdnn
;
using
namespace
megdnn
;
using
namespace
fallback
;
using
namespace
fallback
;
using
namespace
elemwise
;
using
namespace
elemwise
;
...
@@ -52,10 +55,15 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -52,10 +55,15 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \
thin_function<void(const src_ctype*, dst_ctype*, DType, DType, size_t)> run = \
thin_function<void(const src_ctype*, dst_ctype*, DType, DType, size_t)> run = \
OpCallerUnary<_op<src_ctype, dst_ctype>, VEC>::run; \
OpCallerUnary<_op<src_ctype, dst_ctype>, VEC>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(0), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src.layout.dtype, \
run(src.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src.layout.dtype, \
dst.layout.dtype, nr_elems)); \
dst.layout.dtype, nr_elems)); \
return; \
return; \
} \
MIDOUT_END(); \
}
}
DISPATCH
()
DISPATCH
()
...
@@ -130,10 +138,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -130,10 +138,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t)> \
size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_VEC>::run; \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
MIDOUT_BEGIN( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(1), \
src0.layout.dtype, src1.layout.dtype, dst.layout.dtype, nr_elems)); \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, nr_elems)); \
return; \
return; \
} \
MIDOUT_END(); \
}
}
DISPATCH
()
DISPATCH
()
...
@@ -164,11 +178,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -164,11 +178,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype, dst_ctype*, DType, DType, DType, \
const src_ctype*, const src_ctype, dst_ctype*, DType, DType, DType, \
size_t)> \
size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_SCALAR>::run; \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_SCALAR>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(2), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>()[0], \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>()[0], \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, src0.layout.total_nr_elems())); \
dst.layout.dtype, src0.layout.total_nr_elems())); \
return; \
return; \
} \
MIDOUT_END(); \
}
}
DISPATCH
()
DISPATCH
()
...
@@ -187,11 +206,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -187,11 +206,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype, const src_ctype*, dst_ctype*, DType, DType, DType, \
const src_ctype, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t)> \
size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, SCALAR_VEC>::run; \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, SCALAR_VEC>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(3), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>()[0], src1.ptr<src_ctype>(), \
run(src0.ptr<src_ctype>()[0], src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, src1.layout.total_nr_elems())); \
dst.layout.dtype, src1.layout.total_nr_elems())); \
return; \
return; \
} \
MIDOUT_END(); \
}
}
DISPATCH
()
DISPATCH
()
...
@@ -225,11 +249,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -225,11 +249,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t)> \
size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101>::run; \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(4), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
return; \
return; \
} \
MIDOUT_END(); \
}
}
DISPATCH
()
DISPATCH
()
...
@@ -248,11 +277,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -248,11 +277,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t)> \
size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101_VEC>::run; \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101_VEC>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(5), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
return; \
return; \
} \
MIDOUT_END(); \
}
}
DISPATCH
()
DISPATCH
()
...
@@ -275,11 +309,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -275,11 +309,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t, size_t)> \
size_t, size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101xX>::run; \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101xX>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(6), _src_dt, \
_dst_dt, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
return; \
} \
MIDOUT_END(); \
}
}
size_t
batch_size
=
src0
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
size_t
batch_size
=
src0
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
DISPATCH
()
DISPATCH
()
...
@@ -298,11 +337,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -298,11 +337,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t, size_t)> \
size_t, size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101xX_VEC>::run; \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101xX_VEC>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(7), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
return; \
} \
MIDOUT_END(); \
}
}
size_t
batch_size
=
src1
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
size_t
batch_size
=
src1
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
DISPATCH
()
DISPATCH
()
...
@@ -355,11 +399,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -355,11 +399,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \
DType, DType, DType, DType, size_t)> \
DType, DType, DType, DType, size_t)> \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_VEC>::run; \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
MIDOUT_BEGIN( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), src2.ptr<src_ctype>(), \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(8), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
src_ctype, dst_ctype, midout_iv(_mode)) { \
src2.layout.dtype, dst.layout.dtype, nr_elems)); \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, nr_elems)); \
return; \
return; \
} \
MIDOUT_END(); \
}
}
DISPATCH
()
DISPATCH
()
...
@@ -378,12 +428,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -378,12 +428,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, const src_ctype, dst_ctype*, \
const src_ctype*, const src_ctype*, const src_ctype, dst_ctype*, \
DType, DType, DType, DType, size_t)> \
DType, DType, DType, DType, size_t)> \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_SCALAR>::run; \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_SCALAR>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(9), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>()[0], dst.ptr<dst_ctype>(), src0.layout.dtype,
\
src2.ptr<src_ctype>()[0], dst.ptr<dst_ctype>(),
\
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,
\
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype,
\
src0.layout.total_nr_elems()));
\
dst.layout.dtype, src0.layout.total_nr_elems()));
\
return; \
return; \
} \
MIDOUT_END(); \
}
}
DISPATCH
()
DISPATCH
()
...
@@ -407,12 +462,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -407,12 +462,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, BCAST101_VEC_BCAST101>::run; \
_op<src_ctype, dst_ctype>, BCAST101_VEC_BCAST101>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MIDOUT_BEGIN( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(10), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, binfo.x, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, binfo.x, \
binfo.y, binfo.z, binfo.y* binfo.z)); \
binfo.y, binfo.z, binfo.y* binfo.z)); \
return; \
return; \
} \
MIDOUT_END(); \
}
}
DISPATCH
()
DISPATCH
()
...
@@ -437,12 +497,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -437,12 +497,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, VEC_BCAST101xX_VEC>::run; \
_op<src_ctype, dst_ctype>, VEC_BCAST101xX_VEC>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(11), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype,
\
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(),
\
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,
\
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype,
\
batch_size, binfo.x, binfo.y, binfo.z));
\
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z));
\
return; \
return; \
} \
MIDOUT_END(); \
}
}
size_t
batch_size
=
src0
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
size_t
batch_size
=
src0
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
...
@@ -465,12 +530,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -465,12 +530,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, BCAST101xX_VEC_BCAST101xX>::run; \
_op<src_ctype, dst_ctype>, BCAST101xX_VEC_BCAST101xX>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(12), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype,
\
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(),
\
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,
\
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype,
\
batch_size, binfo.x, binfo.y, binfo.z));
\
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z));
\
return; \
return; \
} \
MIDOUT_END(); \
}
}
size_t
batch_size
=
src1
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
size_t
batch_size
=
src1
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
...
...
dnn/src/naive/elemwise_multi_type/opr_impl_3.cpp
浏览文件 @
68c5e766
...
@@ -2,6 +2,9 @@
...
@@ -2,6 +2,9 @@
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "midout.h"
MIDOUT_DECL
(
megdnn_naive_elemwise_multi_type
)
using
namespace
megdnn
;
using
namespace
megdnn
;
using
namespace
naive
;
using
namespace
naive
;
...
@@ -16,7 +19,12 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -16,7 +19,12 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
typedef ElemwiseKern< \
typedef ElemwiseKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \
KernImpl; \
KernImpl; \
MIDOUT_BEGIN( \
megdnn_naive_elemwise_multi_type, midout_iv(1), \
param_enumv::Elemwise::Mode::_mode) { \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<1>>(param, dst); \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<1>>(param, dst); \
} \
MIDOUT_END(); \
break; \
break; \
}
}
...
...
dnn/src/naive/elemwise_multi_type/opr_impl_4.cpp
浏览文件 @
68c5e766
...
@@ -2,6 +2,9 @@
...
@@ -2,6 +2,9 @@
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "midout.h"
MIDOUT_DECL
(
megdnn_naive_elemwise_multi_type
)
using
namespace
megdnn
;
using
namespace
megdnn
;
using
namespace
naive
;
using
namespace
naive
;
...
@@ -18,7 +21,12 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -18,7 +21,12 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
typedef ElemwiseKern< \
typedef ElemwiseKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \
KernImpl; \
KernImpl; \
MIDOUT_BEGIN( \
megdnn_naive_elemwise_multi_type, midout_iv(2), \
param_enumv::Elemwise::Mode::_mode) { \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<2>>(param, dst); \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<2>>(param, dst); \
} \
MIDOUT_END(); \
break; \
break; \
}
}
...
@@ -65,9 +73,13 @@ void ElemwiseMultiTypeImpl::dest_type_bool_mode(
...
@@ -65,9 +73,13 @@ void ElemwiseMultiTypeImpl::dest_type_bool_mode(
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, \
typename DTypeTrait<_dt>::ctype, dt_bool> \
typename DTypeTrait<_dt>::ctype, dt_bool> \
KernImpl##_mode; \
KernImpl##_mode; \
dispatch_dst_bool_op< \
using _ctype = typename DTypeTrait<_dt>::ctype; \
KernImpl##_mode, typename DTypeTrait<_dt>::ctype, dt_bool>( \
MIDOUT_BEGIN( \
param, dst); \
megdnn_naive_elemwise_multi_type, midout_iv(0), _ctype, \
param_enumv::Elemwise::Mode::_mode) { \
dispatch_dst_bool_op<KernImpl##_mode, _ctype, dt_bool>(param, dst); \
} \
MIDOUT_END(); \
break; \
break; \
}
}
#define DISPATCH_MODE(_mode) \
#define DISPATCH_MODE(_mode) \
...
@@ -111,9 +123,13 @@ void ElemwiseMultiTypeImpl::dest_type_bool_mode(
...
@@ -111,9 +123,13 @@ void ElemwiseMultiTypeImpl::dest_type_bool_mode(
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, \
typename DTypeTrait<_dt>::ctype, dt_bool> \
typename DTypeTrait<_dt>::ctype, dt_bool> \
KernImpl##_mode; \
KernImpl##_mode; \
dispatch_dst_bool_op< \
using _ctype = typename DTypeTrait<_dt>::ctype; \
KernImpl##_mode, typename DTypeTrait<_dt>::ctype, dt_bool>( \
MIDOUT_BEGIN( \
param, dst); \
megdnn_naive_elemwise_multi_type, midout_iv(1), _ctype, \
param_enumv::Elemwise::Mode::_mode) { \
dispatch_dst_bool_op<KernImpl##_mode, _ctype, dt_bool>(param, dst); \
} \
MIDOUT_END(); \
break; \
break; \
};
};
#define DISPATCH_MODE(_mode) \
#define DISPATCH_MODE(_mode) \
...
...
dnn/src/naive/elemwise_multi_type/opr_impl_5.cpp
浏览文件 @
68c5e766
...
@@ -2,6 +2,9 @@
...
@@ -2,6 +2,9 @@
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "midout.h"
MIDOUT_DECL
(
megdnn_naive_elemwise_multi_type
)
using
namespace
megdnn
;
using
namespace
megdnn
;
using
namespace
naive
;
using
namespace
naive
;
...
@@ -19,7 +22,12 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
...
@@ -19,7 +22,12 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
typedef ElemwiseKern< \
typedef ElemwiseKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \
KernImpl; \
KernImpl; \
MIDOUT_BEGIN( \
megdnn_naive_elemwise_multi_type, midout_iv(3), \
param_enumv::Elemwise::Mode::_mode) { \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<3>>(param, dst); \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<3>>(param, dst); \
} \
MIDOUT_END(); \
break; \
break; \
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录