提交 1918ed2c 编写于 作者: M Megvii Engine Team

fix(midout): add more midout tag

GitOrigin-RevId: 6096aa2f66cb8f0c89bb3d5003623fca22dcfcc3
上级 8dbc602e
......@@ -6,6 +6,9 @@
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_elemwise_multi_type)
namespace {
using namespace megdnn;
......@@ -370,10 +373,15 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \
thin_function<void(const src_ctype*, dst_ctype*, DType, DType, size_t)> run = \
OpCallerUnary<_op<src_ctype, dst_ctype>, VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src.layout.dtype, \
dst.layout.dtype, nr_elems)); \
return; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(0), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src.layout.dtype, \
dst.layout.dtype, nr_elems)); \
return; \
} \
MIDOUT_END(); \
}
DISPATCH()
......@@ -469,10 +477,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
src0.layout.dtype, src1.layout.dtype, dst.layout.dtype, nr_elems)); \
return; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(1), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, nr_elems)); \
return; \
} \
MIDOUT_END(); \
}
DISPATCH()
......@@ -502,11 +516,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype, dst_ctype*, DType, DType, DType, \
size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_SCALAR>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>()[0], \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, src0.layout.total_nr_elems())); \
return; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(2), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>()[0], \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, src0.layout.total_nr_elems())); \
return; \
} \
MIDOUT_END(); \
}
DISPATCH()
......@@ -525,11 +544,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, SCALAR_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>()[0], src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, src1.layout.total_nr_elems())); \
return; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(3), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>()[0], src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, src1.layout.total_nr_elems())); \
return; \
} \
MIDOUT_END(); \
}
DISPATCH()
......@@ -563,11 +587,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
return; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(4), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
return; \
} \
MIDOUT_END(); \
}
DISPATCH()
......@@ -586,11 +615,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
return; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(5), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
return; \
} \
MIDOUT_END(); \
}
DISPATCH()
......@@ -613,11 +647,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101xX>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(6), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
} \
MIDOUT_END(); \
}
size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
DISPATCH()
......@@ -636,11 +675,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101xX_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(7), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
} \
MIDOUT_END(); \
}
size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
DISPATCH()
......@@ -685,19 +729,25 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
//! VEC + VEC + VEC
if (is_vector(src0.layout) && is_vector(src1.layout) && is_vector(src2.layout)) {
size_t nr_elems = src0.layout.total_nr_elems();
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
case _mode: { \
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \
thin_function<void( \
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \
DType, DType, DType, DType, size_t)> \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), src2.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
src2.layout.dtype, dst.layout.dtype, nr_elems)); \
return; \
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
case _mode: { \
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \
thin_function<void( \
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \
DType, DType, DType, DType, size_t)> \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_VEC>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(8), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, nr_elems)); \
return; \
} \
MIDOUT_END(); \
}
DISPATCH()
......@@ -716,12 +766,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, const src_ctype, dst_ctype*, \
DType, DType, DType, DType, size_t)> \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_SCALAR>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>()[0], dst.ptr<dst_ctype>(), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
src0.layout.total_nr_elems())); \
return; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(9), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>()[0], dst.ptr<dst_ctype>(), \
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, src0.layout.total_nr_elems())); \
return; \
} \
MIDOUT_END(); \
}
DISPATCH()
......@@ -745,12 +800,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, BCAST101_VEC_BCAST101>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(10), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, binfo.x, \
binfo.y, binfo.z, binfo.y* binfo.z)); \
return; \
return; \
} \
MIDOUT_END(); \
}
DISPATCH()
......@@ -766,21 +826,26 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
(is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
is_broadcastedx_channel_like<8>(src1.layout, binfo)) &&
src0.layout.eq_shape(src2.layout)) {
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
case _mode: { \
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \
thin_function<void( \
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, VEC_BCAST101xX_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
case _mode: { \
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \
thin_function<void( \
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, VEC_BCAST101xX_VEC>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(11), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
} \
MIDOUT_END(); \
}
size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
......@@ -803,12 +868,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, BCAST101xX_VEC_BCAST101xX>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(12), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
} \
MIDOUT_END(); \
}
size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
......
......@@ -913,31 +913,51 @@ void remap(
} \
}
#define DISPATCH_IMODE(_imode, _bmode, _ch, _cb) \
switch (_imode) { \
case InterpolationMode::NEAREST: { \
DISPATCH_BMODE(InterpolationMode::NEAREST, _bmode, _ch, _cb); \
break; \
} \
case InterpolationMode::LINEAR: { \
DISPATCH_BMODE(InterpolationMode::LINEAR, _bmode, _ch, _cb); \
break; \
} \
case InterpolationMode::AREA: { \
DISPATCH_BMODE(InterpolationMode::AREA, _bmode, _ch, _cb); \
break; \
} \
case InterpolationMode::CUBIC: { \
DISPATCH_BMODE(InterpolationMode::CUBIC, _bmode, _ch, _cb); \
break; \
} \
case InterpolationMode::LANCZOS4: { \
DISPATCH_BMODE(InterpolationMode::LANCZOS4, _bmode, _ch, _cb); \
break; \
} \
default: { \
megdnn_assert(0, "unsupport interpolation mode for cv"); \
} \
#define DISPATCH_IMODE(_imode, _bmode, _ch, _cb) \
switch (_imode) { \
case InterpolationMode::NEAREST: { \
MIDOUT_BEGIN( \
megdnn_warp, midout_iv(0), midout_iv("DISPATCH_IMODE"_hash)) { \
DISPATCH_BMODE(InterpolationMode::NEAREST, _bmode, _ch, _cb); \
} \
MIDOUT_END(); \
break; \
} \
case InterpolationMode::LINEAR: { \
MIDOUT_BEGIN( \
megdnn_warp, midout_iv(1), midout_iv("DISPATCH_IMODE"_hash)) { \
DISPATCH_BMODE(InterpolationMode::LINEAR, _bmode, _ch, _cb); \
} \
MIDOUT_END(); \
break; \
} \
case InterpolationMode::AREA: { \
MIDOUT_BEGIN( \
megdnn_warp, midout_iv(2), midout_iv("DISPATCH_IMODE"_hash)) { \
DISPATCH_BMODE(InterpolationMode::AREA, _bmode, _ch, _cb); \
} \
MIDOUT_END(); \
break; \
} \
case InterpolationMode::CUBIC: { \
MIDOUT_BEGIN( \
megdnn_warp, midout_iv(3), midout_iv("DISPATCH_IMODE"_hash)) { \
DISPATCH_BMODE(InterpolationMode::CUBIC, _bmode, _ch, _cb); \
} \
MIDOUT_END(); \
break; \
} \
case InterpolationMode::LANCZOS4: { \
MIDOUT_BEGIN( \
megdnn_warp, midout_iv(4), midout_iv("DISPATCH_IMODE"_hash)) { \
DISPATCH_BMODE(InterpolationMode::LANCZOS4, _bmode, _ch, _cb); \
} \
MIDOUT_END(); \
break; \
} \
default: { \
megdnn_assert(0, "unsupport interpolation mode for cv"); \
} \
}
} // namespace warp
......
......@@ -3,6 +3,9 @@
#include "src/fallback/elemwise_multi_type/opr_impl.h"
#include "src/naive/handle.h"
#include "midout.h"
MIDOUT_DECL(megdnn_fallback_elemwise_multi_type_quantized)
using namespace megdnn;
using namespace fallback;
using namespace elemwise;
......@@ -52,10 +55,15 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \
thin_function<void(const src_ctype*, dst_ctype*, DType, DType, size_t)> run = \
OpCallerUnary<_op<src_ctype, dst_ctype>, VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src.layout.dtype, \
dst.layout.dtype, nr_elems)); \
return; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(0), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src.layout.dtype, \
dst.layout.dtype, nr_elems)); \
return; \
} \
MIDOUT_END(); \
}
DISPATCH()
......@@ -130,10 +138,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
src0.layout.dtype, src1.layout.dtype, dst.layout.dtype, nr_elems)); \
return; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(1), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, nr_elems)); \
return; \
} \
MIDOUT_END(); \
}
DISPATCH()
......@@ -164,11 +178,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype, dst_ctype*, DType, DType, DType, \
size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_SCALAR>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>()[0], \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, src0.layout.total_nr_elems())); \
return; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(2), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>()[0], \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, src0.layout.total_nr_elems())); \
return; \
} \
MIDOUT_END(); \
}
DISPATCH()
......@@ -187,11 +206,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, SCALAR_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>()[0], src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, src1.layout.total_nr_elems())); \
return; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(3), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>()[0], src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, src1.layout.total_nr_elems())); \
return; \
} \
MIDOUT_END(); \
}
DISPATCH()
......@@ -225,11 +249,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
return; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(4), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
return; \
} \
MIDOUT_END(); \
}
DISPATCH()
......@@ -248,11 +277,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
return; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(5), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
return; \
} \
MIDOUT_END(); \
}
DISPATCH()
......@@ -275,11 +309,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101xX>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(6), _src_dt, \
_dst_dt, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
} \
MIDOUT_END(); \
}
size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
DISPATCH()
......@@ -298,11 +337,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101xX_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(7), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
} \
MIDOUT_END(); \
}
size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
DISPATCH()
......@@ -347,19 +391,25 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
//! VEC + VEC + VEC
if (is_vector(src0.layout) && is_vector(src1.layout) && is_vector(src2.layout)) {
size_t nr_elems = src0.layout.total_nr_elems();
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
case _mode: { \
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \
thin_function<void( \
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \
DType, DType, DType, DType, size_t)> \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), src2.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
src2.layout.dtype, dst.layout.dtype, nr_elems)); \
return; \
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
case _mode: { \
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \
thin_function<void( \
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \
DType, DType, DType, DType, size_t)> \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_VEC>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(8), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, nr_elems)); \
return; \
} \
MIDOUT_END(); \
}
DISPATCH()
......@@ -378,12 +428,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, const src_ctype, dst_ctype*, \
DType, DType, DType, DType, size_t)> \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_SCALAR>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>()[0], dst.ptr<dst_ctype>(), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
src0.layout.total_nr_elems())); \
return; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(9), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>()[0], dst.ptr<dst_ctype>(), \
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, src0.layout.total_nr_elems())); \
return; \
} \
MIDOUT_END(); \
}
DISPATCH()
......@@ -407,12 +462,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, BCAST101_VEC_BCAST101>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(10), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, binfo.x, \
binfo.y, binfo.z, binfo.y* binfo.z)); \
return; \
return; \
} \
MIDOUT_END(); \
}
DISPATCH()
......@@ -428,21 +488,26 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
(is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
is_broadcastedx_channel_like<8>(src1.layout, binfo)) &&
src0.layout.eq_shape(src2.layout)) {
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
case _mode: { \
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \
thin_function<void( \
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, VEC_BCAST101xX_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
case _mode: { \
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \
thin_function<void( \
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, VEC_BCAST101xX_VEC>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(11), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
} \
MIDOUT_END(); \
}
size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
......@@ -465,12 +530,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, BCAST101xX_VEC_BCAST101xX>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(12), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
} \
MIDOUT_END(); \
}
size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
......
......@@ -2,6 +2,9 @@
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "midout.h"
MIDOUT_DECL(megdnn_naive_elemwise_multi_type)
using namespace megdnn;
using namespace naive;
......@@ -16,7 +19,12 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
typedef ElemwiseKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \
KernImpl; \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<1>>(param, dst); \
MIDOUT_BEGIN( \
megdnn_naive_elemwise_multi_type, midout_iv(1), \
param_enumv::Elemwise::Mode::_mode) { \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<1>>(param, dst); \
} \
MIDOUT_END(); \
break; \
}
......
......@@ -2,6 +2,9 @@
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "midout.h"
MIDOUT_DECL(megdnn_naive_elemwise_multi_type)
using namespace megdnn;
using namespace naive;
......@@ -18,7 +21,12 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
typedef ElemwiseKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \
KernImpl; \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<2>>(param, dst); \
MIDOUT_BEGIN( \
megdnn_naive_elemwise_multi_type, midout_iv(2), \
param_enumv::Elemwise::Mode::_mode) { \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<2>>(param, dst); \
} \
MIDOUT_END(); \
break; \
}
......@@ -59,16 +67,20 @@ void ElemwiseMultiTypeImpl::dest_type_bool_mode(
switch (mode) {
case Elemwise::Mode::ISINF: {
switch (param[0].layout.dtype.enumv()) {
#define DISPATCH(_dt, _mode) \
case DTypeTrait<_dt>::enumv: { \
typedef ElemwiseBoolKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, \
typename DTypeTrait<_dt>::ctype, dt_bool> \
KernImpl##_mode; \
dispatch_dst_bool_op< \
KernImpl##_mode, typename DTypeTrait<_dt>::ctype, dt_bool>( \
param, dst); \
break; \
#define DISPATCH(_dt, _mode) \
case DTypeTrait<_dt>::enumv: { \
typedef ElemwiseBoolKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, \
typename DTypeTrait<_dt>::ctype, dt_bool> \
KernImpl##_mode; \
using _ctype = typename DTypeTrait<_dt>::ctype; \
MIDOUT_BEGIN( \
megdnn_naive_elemwise_multi_type, midout_iv(0), _ctype, \
param_enumv::Elemwise::Mode::_mode) { \
dispatch_dst_bool_op<KernImpl##_mode, _ctype, dt_bool>(param, dst); \
} \
MIDOUT_END(); \
break; \
}
#define DISPATCH_MODE(_mode) \
DISPATCH(megdnn::dtype::Float32, _mode); \
......@@ -105,16 +117,20 @@ void ElemwiseMultiTypeImpl::dest_type_bool_mode(
switch (mode) {
case Elemwise::Mode::EQ: {
switch (param[0].layout.dtype.enumv()) {
#define DISPATCH(_dt, _mode) \
case DTypeTrait<_dt>::enumv: { \
typedef ElemwiseBoolKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, \
typename DTypeTrait<_dt>::ctype, dt_bool> \
KernImpl##_mode; \
dispatch_dst_bool_op< \
KernImpl##_mode, typename DTypeTrait<_dt>::ctype, dt_bool>( \
param, dst); \
break; \
#define DISPATCH(_dt, _mode) \
case DTypeTrait<_dt>::enumv: { \
typedef ElemwiseBoolKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, \
typename DTypeTrait<_dt>::ctype, dt_bool> \
KernImpl##_mode; \
using _ctype = typename DTypeTrait<_dt>::ctype; \
MIDOUT_BEGIN( \
megdnn_naive_elemwise_multi_type, midout_iv(1), _ctype, \
param_enumv::Elemwise::Mode::_mode) { \
dispatch_dst_bool_op<KernImpl##_mode, _ctype, dt_bool>(param, dst); \
} \
MIDOUT_END(); \
break; \
};
#define DISPATCH_MODE(_mode) \
DISPATCH(megdnn::dtype::Float32, _mode); \
......
......@@ -2,6 +2,9 @@
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "midout.h"
MIDOUT_DECL(megdnn_naive_elemwise_multi_type)
using namespace megdnn;
using namespace naive;
......@@ -19,7 +22,12 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
typedef ElemwiseKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \
KernImpl; \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<3>>(param, dst); \
MIDOUT_BEGIN( \
megdnn_naive_elemwise_multi_type, midout_iv(3), \
param_enumv::Elemwise::Mode::_mode) { \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<3>>(param, dst); \
} \
MIDOUT_END(); \
break; \
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册