提交 68c5e766 编写于 作者: M Megvii Engine Team

fix(midout): add more midout tag

GitOrigin-RevId: 6096aa2f66cb8f0c89bb3d5003623fca22dcfcc3
上级 9bf718b7
...@@ -6,6 +6,9 @@ ...@@ -6,6 +6,9 @@
#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_elemwise_multi_type)
namespace { namespace {
using namespace megdnn; using namespace megdnn;
...@@ -370,10 +373,15 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -370,10 +373,15 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \
thin_function<void(const src_ctype*, dst_ctype*, DType, DType, size_t)> run = \ thin_function<void(const src_ctype*, dst_ctype*, DType, DType, size_t)> run = \
OpCallerUnary<_op<src_ctype, dst_ctype>, VEC>::run; \ OpCallerUnary<_op<src_ctype, dst_ctype>, VEC>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(0), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src.layout.dtype, \ run(src.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src.layout.dtype, \
dst.layout.dtype, nr_elems)); \ dst.layout.dtype, nr_elems)); \
return; \ return; \
} \
MIDOUT_END(); \
} }
DISPATCH() DISPATCH()
...@@ -469,10 +477,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -469,10 +477,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \ const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t)> \ size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_VEC>::run; \ run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ MIDOUT_BEGIN( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \ megdnn_arm_common_elemwise_multi_type, midout_iv(1), src_ctype, \
src0.layout.dtype, src1.layout.dtype, dst.layout.dtype, nr_elems)); \ dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, nr_elems)); \
return; \ return; \
} \
MIDOUT_END(); \
} }
DISPATCH() DISPATCH()
...@@ -502,11 +516,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -502,11 +516,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype, dst_ctype*, DType, DType, DType, \ const src_ctype*, const src_ctype, dst_ctype*, DType, DType, DType, \
size_t)> \ size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_SCALAR>::run; \ run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_SCALAR>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(2), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>()[0], \ run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>()[0], \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, src0.layout.total_nr_elems())); \ dst.layout.dtype, src0.layout.total_nr_elems())); \
return; \ return; \
} \
MIDOUT_END(); \
} }
DISPATCH() DISPATCH()
...@@ -525,11 +544,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -525,11 +544,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype, const src_ctype*, dst_ctype*, DType, DType, DType, \ const src_ctype, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t)> \ size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, SCALAR_VEC>::run; \ run = OpCallerBinary<_op<src_ctype, dst_ctype>, SCALAR_VEC>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(3), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>()[0], src1.ptr<src_ctype>(), \ run(src0.ptr<src_ctype>()[0], src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, src1.layout.total_nr_elems())); \ dst.layout.dtype, src1.layout.total_nr_elems())); \
return; \ return; \
} \
MIDOUT_END(); \
} }
DISPATCH() DISPATCH()
...@@ -563,11 +587,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -563,11 +587,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \ const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t)> \ size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101>::run; \ run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(4), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
return; \ return; \
} \
MIDOUT_END(); \
} }
DISPATCH() DISPATCH()
...@@ -586,11 +615,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -586,11 +615,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \ const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t)> \ size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101_VEC>::run; \ run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101_VEC>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(5), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
return; \ return; \
} \
MIDOUT_END(); \
} }
DISPATCH() DISPATCH()
...@@ -613,11 +647,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -613,11 +647,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \ const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t, size_t)> \ size_t, size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101xX>::run; \ run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101xX>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(6), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \ return; \
} \
MIDOUT_END(); \
} }
size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
DISPATCH() DISPATCH()
...@@ -636,11 +675,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -636,11 +675,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \ const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t, size_t)> \ size_t, size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101xX_VEC>::run; \ run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101xX_VEC>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(7), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \ return; \
} \
MIDOUT_END(); \
} }
size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
DISPATCH() DISPATCH()
...@@ -693,11 +737,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -693,11 +737,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \ const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \
DType, DType, DType, DType, size_t)> \ DType, DType, DType, DType, size_t)> \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_VEC>::run; \ run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ MIDOUT_BEGIN( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), src2.ptr<src_ctype>(), \ megdnn_arm_common_elemwise_multi_type, midout_iv(8), src_ctype, \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ dst_ctype, midout_iv(_mode)) { \
src2.layout.dtype, dst.layout.dtype, nr_elems)); \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, nr_elems)); \
return; \ return; \
} \
MIDOUT_END(); \
} }
DISPATCH() DISPATCH()
...@@ -716,12 +766,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -716,12 +766,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, const src_ctype, dst_ctype*, \ const src_ctype*, const src_ctype*, const src_ctype, dst_ctype*, \
DType, DType, DType, DType, size_t)> \ DType, DType, DType, DType, size_t)> \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_SCALAR>::run; \ run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_SCALAR>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(9), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>()[0], dst.ptr<dst_ctype>(), src0.layout.dtype, \ src2.ptr<src_ctype>()[0], dst.ptr<dst_ctype>(), \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
src0.layout.total_nr_elems())); \ dst.layout.dtype, src0.layout.total_nr_elems())); \
return; \ return; \
} \
MIDOUT_END(); \
} }
DISPATCH() DISPATCH()
...@@ -745,12 +800,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -745,12 +800,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \ DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \ run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, BCAST101_VEC_BCAST101>::run; \ _op<src_ctype, dst_ctype>, BCAST101_VEC_BCAST101>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MIDOUT_BEGIN( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ megdnn_arm_common_elemwise_multi_type, midout_iv(10), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \ src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, binfo.x, \ src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, binfo.x, \
binfo.y, binfo.z, binfo.y* binfo.z)); \ binfo.y, binfo.z, binfo.y* binfo.z)); \
return; \ return; \
} \
MIDOUT_END(); \
} }
DISPATCH() DISPATCH()
...@@ -775,12 +835,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -775,12 +835,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \ DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \ run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, VEC_BCAST101xX_VEC>::run; \ _op<src_ctype, dst_ctype>, VEC_BCAST101xX_VEC>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(11), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \ src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
batch_size, binfo.x, binfo.y, binfo.z)); \ dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \ return; \
} \
MIDOUT_END(); \
} }
size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
...@@ -803,12 +868,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -803,12 +868,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \ DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \ run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, BCAST101xX_VEC_BCAST101xX>::run; \ _op<src_ctype, dst_ctype>, BCAST101xX_VEC_BCAST101xX>::run; \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_multi_type, midout_iv(12), src_ctype, \
dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \ src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
batch_size, binfo.x, binfo.y, binfo.z)); \ dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \ return; \
} \
MIDOUT_END(); \
} }
size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
......
...@@ -916,23 +916,43 @@ void remap( ...@@ -916,23 +916,43 @@ void remap(
#define DISPATCH_IMODE(_imode, _bmode, _ch, _cb) \ #define DISPATCH_IMODE(_imode, _bmode, _ch, _cb) \
switch (_imode) { \ switch (_imode) { \
case InterpolationMode::NEAREST: { \ case InterpolationMode::NEAREST: { \
MIDOUT_BEGIN( \
megdnn_warp, midout_iv(0), midout_iv("DISPATCH_IMODE"_hash)) { \
DISPATCH_BMODE(InterpolationMode::NEAREST, _bmode, _ch, _cb); \ DISPATCH_BMODE(InterpolationMode::NEAREST, _bmode, _ch, _cb); \
} \
MIDOUT_END(); \
break; \ break; \
} \ } \
case InterpolationMode::LINEAR: { \ case InterpolationMode::LINEAR: { \
MIDOUT_BEGIN( \
megdnn_warp, midout_iv(1), midout_iv("DISPATCH_IMODE"_hash)) { \
DISPATCH_BMODE(InterpolationMode::LINEAR, _bmode, _ch, _cb); \ DISPATCH_BMODE(InterpolationMode::LINEAR, _bmode, _ch, _cb); \
} \
MIDOUT_END(); \
break; \ break; \
} \ } \
case InterpolationMode::AREA: { \ case InterpolationMode::AREA: { \
MIDOUT_BEGIN( \
megdnn_warp, midout_iv(2), midout_iv("DISPATCH_IMODE"_hash)) { \
DISPATCH_BMODE(InterpolationMode::AREA, _bmode, _ch, _cb); \ DISPATCH_BMODE(InterpolationMode::AREA, _bmode, _ch, _cb); \
} \
MIDOUT_END(); \
break; \ break; \
} \ } \
case InterpolationMode::CUBIC: { \ case InterpolationMode::CUBIC: { \
MIDOUT_BEGIN( \
megdnn_warp, midout_iv(3), midout_iv("DISPATCH_IMODE"_hash)) { \
DISPATCH_BMODE(InterpolationMode::CUBIC, _bmode, _ch, _cb); \ DISPATCH_BMODE(InterpolationMode::CUBIC, _bmode, _ch, _cb); \
} \
MIDOUT_END(); \
break; \ break; \
} \ } \
case InterpolationMode::LANCZOS4: { \ case InterpolationMode::LANCZOS4: { \
MIDOUT_BEGIN( \
megdnn_warp, midout_iv(4), midout_iv("DISPATCH_IMODE"_hash)) { \
DISPATCH_BMODE(InterpolationMode::LANCZOS4, _bmode, _ch, _cb); \ DISPATCH_BMODE(InterpolationMode::LANCZOS4, _bmode, _ch, _cb); \
} \
MIDOUT_END(); \
break; \ break; \
} \ } \
default: { \ default: { \
......
...@@ -3,6 +3,9 @@ ...@@ -3,6 +3,9 @@
#include "src/fallback/elemwise_multi_type/opr_impl.h" #include "src/fallback/elemwise_multi_type/opr_impl.h"
#include "src/naive/handle.h" #include "src/naive/handle.h"
#include "midout.h"
MIDOUT_DECL(megdnn_fallback_elemwise_multi_type_quantized)
using namespace megdnn; using namespace megdnn;
using namespace fallback; using namespace fallback;
using namespace elemwise; using namespace elemwise;
...@@ -52,10 +55,15 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -52,10 +55,15 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \
thin_function<void(const src_ctype*, dst_ctype*, DType, DType, size_t)> run = \ thin_function<void(const src_ctype*, dst_ctype*, DType, DType, size_t)> run = \
OpCallerUnary<_op<src_ctype, dst_ctype>, VEC>::run; \ OpCallerUnary<_op<src_ctype, dst_ctype>, VEC>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(0), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src.layout.dtype, \ run(src.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src.layout.dtype, \
dst.layout.dtype, nr_elems)); \ dst.layout.dtype, nr_elems)); \
return; \ return; \
} \
MIDOUT_END(); \
} }
DISPATCH() DISPATCH()
...@@ -130,10 +138,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -130,10 +138,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \ const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t)> \ size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_VEC>::run; \ run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ MIDOUT_BEGIN( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \ megdnn_fallback_elemwise_multi_type_quantized, midout_iv(1), \
src0.layout.dtype, src1.layout.dtype, dst.layout.dtype, nr_elems)); \ src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, nr_elems)); \
return; \ return; \
} \
MIDOUT_END(); \
} }
DISPATCH() DISPATCH()
...@@ -164,11 +178,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -164,11 +178,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype, dst_ctype*, DType, DType, DType, \ const src_ctype*, const src_ctype, dst_ctype*, DType, DType, DType, \
size_t)> \ size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_SCALAR>::run; \ run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_SCALAR>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(2), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>()[0], \ run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>()[0], \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, src0.layout.total_nr_elems())); \ dst.layout.dtype, src0.layout.total_nr_elems())); \
return; \ return; \
} \
MIDOUT_END(); \
} }
DISPATCH() DISPATCH()
...@@ -187,11 +206,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -187,11 +206,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype, const src_ctype*, dst_ctype*, DType, DType, DType, \ const src_ctype, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t)> \ size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, SCALAR_VEC>::run; \ run = OpCallerBinary<_op<src_ctype, dst_ctype>, SCALAR_VEC>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(3), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>()[0], src1.ptr<src_ctype>(), \ run(src0.ptr<src_ctype>()[0], src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, src1.layout.total_nr_elems())); \ dst.layout.dtype, src1.layout.total_nr_elems())); \
return; \ return; \
} \
MIDOUT_END(); \
} }
DISPATCH() DISPATCH()
...@@ -225,11 +249,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -225,11 +249,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \ const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t)> \ size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101>::run; \ run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(4), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
return; \ return; \
} \
MIDOUT_END(); \
} }
DISPATCH() DISPATCH()
...@@ -248,11 +277,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -248,11 +277,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \ const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t)> \ size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101_VEC>::run; \ run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101_VEC>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(5), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
return; \ return; \
} \
MIDOUT_END(); \
} }
DISPATCH() DISPATCH()
...@@ -275,11 +309,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -275,11 +309,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \ const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t, size_t)> \ size_t, size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101xX>::run; \ run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101xX>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(6), _src_dt, \
_dst_dt, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \ return; \
} \
MIDOUT_END(); \
} }
size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
DISPATCH() DISPATCH()
...@@ -298,11 +337,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -298,11 +337,16 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \ const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \
size_t, size_t, size_t, size_t)> \ size_t, size_t, size_t, size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101xX_VEC>::run; \ run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101xX_VEC>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(7), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \ return; \
} \
MIDOUT_END(); \
} }
size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
DISPATCH() DISPATCH()
...@@ -355,11 +399,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -355,11 +399,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \ const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \
DType, DType, DType, DType, size_t)> \ DType, DType, DType, DType, size_t)> \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_VEC>::run; \ run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ MIDOUT_BEGIN( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), src2.ptr<src_ctype>(), \ megdnn_fallback_elemwise_multi_type_quantized, midout_iv(8), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ src_ctype, dst_ctype, midout_iv(_mode)) { \
src2.layout.dtype, dst.layout.dtype, nr_elems)); \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, nr_elems)); \
return; \ return; \
} \
MIDOUT_END(); \
} }
DISPATCH() DISPATCH()
...@@ -378,12 +428,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -378,12 +428,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
const src_ctype*, const src_ctype*, const src_ctype, dst_ctype*, \ const src_ctype*, const src_ctype*, const src_ctype, dst_ctype*, \
DType, DType, DType, DType, size_t)> \ DType, DType, DType, DType, size_t)> \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_SCALAR>::run; \ run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_SCALAR>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(9), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>()[0], dst.ptr<dst_ctype>(), src0.layout.dtype, \ src2.ptr<src_ctype>()[0], dst.ptr<dst_ctype>(), \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
src0.layout.total_nr_elems())); \ dst.layout.dtype, src0.layout.total_nr_elems())); \
return; \ return; \
} \
MIDOUT_END(); \
} }
DISPATCH() DISPATCH()
...@@ -407,12 +462,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -407,12 +462,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \ DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \ run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, BCAST101_VEC_BCAST101>::run; \ _op<src_ctype, dst_ctype>, BCAST101_VEC_BCAST101>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MIDOUT_BEGIN( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ megdnn_fallback_elemwise_multi_type_quantized, midout_iv(10), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \ src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, binfo.x, \ src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, binfo.x, \
binfo.y, binfo.z, binfo.y* binfo.z)); \ binfo.y, binfo.z, binfo.y* binfo.z)); \
return; \ return; \
} \
MIDOUT_END(); \
} }
DISPATCH() DISPATCH()
...@@ -437,12 +497,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -437,12 +497,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \ DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \ run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, VEC_BCAST101xX_VEC>::run; \ _op<src_ctype, dst_ctype>, VEC_BCAST101xX_VEC>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(11), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \ src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
batch_size, binfo.x, binfo.y, binfo.z)); \ dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \ return; \
} \
MIDOUT_END(); \
} }
size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
...@@ -465,12 +530,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -465,12 +530,17 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \ DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \ run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, BCAST101xX_VEC_BCAST101xX>::run; \ _op<src_ctype, dst_ctype>, BCAST101xX_VEC_BCAST101xX>::run; \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_multi_type_quantized, midout_iv(12), \
src_ctype, dst_ctype, midout_iv(_mode)) { \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \ src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
batch_size, binfo.x, binfo.y, binfo.z)); \ dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \ return; \
} \
MIDOUT_END(); \
} }
size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
......
...@@ -2,6 +2,9 @@ ...@@ -2,6 +2,9 @@
#include "src/common/elemwise/kern_defs.cuh" #include "src/common/elemwise/kern_defs.cuh"
#include "src/common/elemwise_multi_type/kern_defs.cuh" #include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "midout.h"
MIDOUT_DECL(megdnn_naive_elemwise_multi_type)
using namespace megdnn; using namespace megdnn;
using namespace naive; using namespace naive;
...@@ -16,7 +19,12 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -16,7 +19,12 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
typedef ElemwiseKern< \ typedef ElemwiseKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \ megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \
KernImpl; \ KernImpl; \
MIDOUT_BEGIN( \
megdnn_naive_elemwise_multi_type, midout_iv(1), \
param_enumv::Elemwise::Mode::_mode) { \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<1>>(param, dst); \ dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<1>>(param, dst); \
} \
MIDOUT_END(); \
break; \ break; \
} }
......
...@@ -2,6 +2,9 @@ ...@@ -2,6 +2,9 @@
#include "src/common/elemwise/kern_defs.cuh" #include "src/common/elemwise/kern_defs.cuh"
#include "src/common/elemwise_multi_type/kern_defs.cuh" #include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "midout.h"
MIDOUT_DECL(megdnn_naive_elemwise_multi_type)
using namespace megdnn; using namespace megdnn;
using namespace naive; using namespace naive;
...@@ -18,7 +21,12 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -18,7 +21,12 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
typedef ElemwiseKern< \ typedef ElemwiseKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \ megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \
KernImpl; \ KernImpl; \
MIDOUT_BEGIN( \
megdnn_naive_elemwise_multi_type, midout_iv(2), \
param_enumv::Elemwise::Mode::_mode) { \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<2>>(param, dst); \ dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<2>>(param, dst); \
} \
MIDOUT_END(); \
break; \ break; \
} }
...@@ -65,9 +73,13 @@ void ElemwiseMultiTypeImpl::dest_type_bool_mode( ...@@ -65,9 +73,13 @@ void ElemwiseMultiTypeImpl::dest_type_bool_mode(
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, \ megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, \
typename DTypeTrait<_dt>::ctype, dt_bool> \ typename DTypeTrait<_dt>::ctype, dt_bool> \
KernImpl##_mode; \ KernImpl##_mode; \
dispatch_dst_bool_op< \ using _ctype = typename DTypeTrait<_dt>::ctype; \
KernImpl##_mode, typename DTypeTrait<_dt>::ctype, dt_bool>( \ MIDOUT_BEGIN( \
param, dst); \ megdnn_naive_elemwise_multi_type, midout_iv(0), _ctype, \
param_enumv::Elemwise::Mode::_mode) { \
dispatch_dst_bool_op<KernImpl##_mode, _ctype, dt_bool>(param, dst); \
} \
MIDOUT_END(); \
break; \ break; \
} }
#define DISPATCH_MODE(_mode) \ #define DISPATCH_MODE(_mode) \
...@@ -111,9 +123,13 @@ void ElemwiseMultiTypeImpl::dest_type_bool_mode( ...@@ -111,9 +123,13 @@ void ElemwiseMultiTypeImpl::dest_type_bool_mode(
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, \ megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, \
typename DTypeTrait<_dt>::ctype, dt_bool> \ typename DTypeTrait<_dt>::ctype, dt_bool> \
KernImpl##_mode; \ KernImpl##_mode; \
dispatch_dst_bool_op< \ using _ctype = typename DTypeTrait<_dt>::ctype; \
KernImpl##_mode, typename DTypeTrait<_dt>::ctype, dt_bool>( \ MIDOUT_BEGIN( \
param, dst); \ megdnn_naive_elemwise_multi_type, midout_iv(1), _ctype, \
param_enumv::Elemwise::Mode::_mode) { \
dispatch_dst_bool_op<KernImpl##_mode, _ctype, dt_bool>(param, dst); \
} \
MIDOUT_END(); \
break; \ break; \
}; };
#define DISPATCH_MODE(_mode) \ #define DISPATCH_MODE(_mode) \
......
...@@ -2,6 +2,9 @@ ...@@ -2,6 +2,9 @@
#include "src/common/elemwise/kern_defs.cuh" #include "src/common/elemwise/kern_defs.cuh"
#include "src/common/elemwise_multi_type/kern_defs.cuh" #include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "midout.h"
MIDOUT_DECL(megdnn_naive_elemwise_multi_type)
using namespace megdnn; using namespace megdnn;
using namespace naive; using namespace naive;
...@@ -19,7 +22,12 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -19,7 +22,12 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
typedef ElemwiseKern< \ typedef ElemwiseKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \ megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \
KernImpl; \ KernImpl; \
MIDOUT_BEGIN( \
megdnn_naive_elemwise_multi_type, midout_iv(3), \
param_enumv::Elemwise::Mode::_mode) { \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<3>>(param, dst); \ dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<3>>(param, dst); \
} \
MIDOUT_END(); \
break; \ break; \
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册