From 247e2f59a4721439db73b66b9918e1ddae4184b5 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 9 May 2022 22:43:38 +0800 Subject: [PATCH] feat(mgb/dnn): add modes that the output type is bool in elemwise GitOrigin-RevId: fd0134fca201e64a66402a65d2748b89f64613ef --- dnn/scripts/gen_elemwise_multi_type_utils.py | 12 ++ dnn/scripts/opr_param_defs.py | 9 ++ dnn/src/common/elemwise/kern_defs.cuh | 52 ++++++++ .../common/elemwise_multi_type/kern_defs.cuh | 1 - .../common/elemwise_multi_type/opr_impl.cpp | 34 ++++- .../elemwise_multi_type/opr_impl_helper.cpp | 13 ++ .../elemwise_multi_type/opr_impl_helper.h | 18 +++ .../elemwise_multi_type/kern_impl_bool.inl | 27 ++++ dnn/src/cuda/elemwise_multi_type/kern_ops.cuh | 82 +++++++++++- .../kimpl/EQ_dt_bfloat16_dt_bool.cu | 6 + .../kimpl/EQ_dt_bool_dt_bool.cu | 6 + .../kimpl/EQ_dt_float16_dt_bool.cu | 6 + .../kimpl/EQ_dt_float32_dt_bool.cu | 6 + .../kimpl/EQ_dt_int16_dt_bool.cu | 6 + .../kimpl/EQ_dt_int32_dt_bool.cu | 6 + .../kimpl/EQ_dt_int8_dt_bool.cu | 6 + .../kimpl/EQ_dt_uint8_dt_bool.cu | 6 + .../kimpl/ISINF_dt_bfloat16_dt_bool.cu | 6 + .../kimpl/ISINF_dt_float16_dt_bool.cu | 6 + .../kimpl/ISINF_dt_float32_dt_bool.cu | 6 + .../kimpl/ISNAN_dt_bfloat16_dt_bool.cu | 6 + .../kimpl/ISNAN_dt_float16_dt_bool.cu | 6 + .../kimpl/ISNAN_dt_float32_dt_bool.cu | 6 + .../kimpl/LEQ_dt_bfloat16_dt_bool.cu | 6 + .../kimpl/LEQ_dt_bool_dt_bool.cu | 6 + .../kimpl/LEQ_dt_float16_dt_bool.cu | 6 + .../kimpl/LEQ_dt_float32_dt_bool.cu | 6 + .../kimpl/LEQ_dt_int16_dt_bool.cu | 6 + .../kimpl/LEQ_dt_int32_dt_bool.cu | 6 + .../kimpl/LEQ_dt_int8_dt_bool.cu | 6 + .../kimpl/LEQ_dt_uint8_dt_bool.cu | 6 + .../kimpl/LT_dt_bfloat16_dt_bool.cu | 6 + .../kimpl/LT_dt_bool_dt_bool.cu | 6 + .../kimpl/LT_dt_float16_dt_bool.cu | 6 + .../kimpl/LT_dt_float32_dt_bool.cu | 6 + .../kimpl/LT_dt_int16_dt_bool.cu | 6 + .../kimpl/LT_dt_int32_dt_bool.cu | 6 + .../kimpl/LT_dt_int8_dt_bool.cu | 6 + .../kimpl/LT_dt_uint8_dt_bool.cu | 6 + .../kimpl/NEQ_dt_bfloat16_dt_bool.cu | 6 + .../kimpl/NEQ_dt_bool_dt_bool.cu | 6 + .../kimpl/NEQ_dt_float16_dt_bool.cu | 6 + .../kimpl/NEQ_dt_float32_dt_bool.cu | 6 + .../kimpl/NEQ_dt_int16_dt_bool.cu | 6 + .../kimpl/NEQ_dt_int32_dt_bool.cu | 6 + .../kimpl/NEQ_dt_int8_dt_bool.cu | 6 + .../kimpl/NEQ_dt_uint8_dt_bool.cu | 6 + dnn/src/cuda/elemwise_multi_type/opr_impl.cpp | 110 ++++++++++++++++ dnn/src/cuda/elemwise_multi_type/opr_impl.h | 8 ++ dnn/src/naive/elemwise_multi_type/opr_impl.h | 51 +++++++- .../naive/elemwise_multi_type/opr_impl_4.cpp | 117 ++++++++++++++++++ dnn/test/common/checker.cpp | 4 +- dnn/test/common/elemwise.cpp | 8 ++ dnn/test/common/rng.cpp | 3 + dnn/test/cuda/elemwise_multi_type.cpp | 50 +++++++- .../megengine/core/tensor/array_method.py | 40 ++++-- .../python/megengine/functional/math.py | 6 +- src/jit/impl/ast_c.cpp | 4 +- src/opr/test/basic_arith/elemwise.cpp | 4 +- 59 files changed, 856 insertions(+), 25 deletions(-) create mode 100644 dnn/src/cuda/elemwise_multi_type/kern_impl_bool.inl create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_bfloat16_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_bool_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_float16_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_float32_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_int16_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_int32_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_int8_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_uint8_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/ISINF_dt_bfloat16_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/ISINF_dt_float16_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/ISINF_dt_float32_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/ISNAN_dt_bfloat16_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/ISNAN_dt_float16_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/ISNAN_dt_float32_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_bfloat16_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_bool_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_float16_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_float32_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_int16_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_int32_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_int8_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_uint8_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_bfloat16_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_bool_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_float16_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_float32_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_int16_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_int32_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_int8_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_uint8_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_bfloat16_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_bool_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_float16_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_float32_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_int16_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_int32_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_int8_dt_bool.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_uint8_dt_bool.cu diff --git a/dnn/scripts/gen_elemwise_multi_type_utils.py b/dnn/scripts/gen_elemwise_multi_type_utils.py index 9bbf65d65..6de5a5124 100755 --- a/dnn/scripts/gen_elemwise_multi_type_utils.py +++ b/dnn/scripts/gen_elemwise_multi_type_utils.py @@ -6,6 +6,10 @@ SUPPORT_QINT32_DTYPES = [('dt_qint32', 'dt_qint8'), ('dt_qint8', 'dt_qint32'), SUPPORT_DTYPES_Q4 = [('dt_qint4', 'dt_qint4'), ('dt_quint4', 'dt_quint4')] SUPPORT_QINT32_DTYPES_Q4 = [('dt_qint32', 'dt_qint4'), ('dt_qint32', 'dt_quint4')] +SUPPORT_ARRITY2_DTYPES = ['dt_int32', 'dt_uint8', 'dt_int8', 'dt_int16', 'dt_bool', 'dt_float32', + 'dt_float16', 'dt_bfloat16'] +SUPPORT_ARRITY1_DTYPES = ['dt_float32','dt_float16', 'dt_bfloat16'] + MODES = { 1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', 'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', @@ -34,3 +38,11 @@ QINT32_MODES = { 2: ['ADD', 'FUSE_ADD_RELU', 'FUSE_ADD_SIGMOID', 'FUSE_ADD_TANH', 'FUSE_ADD_H_SWISH'] } + +ARRITY1_BOOL_MODES = { + 1: ['ISINF','ISNAN'], +} + +ARRITY2_BOOL_MODES = { + 2: ['EQ','LEQ','NEQ','LT'], +} diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 824c9cc9a..1c7dd1938 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -421,6 +421,9 @@ pdef('Elemwise').add_enum( Doc('GELU = 58', 'unary: x Phi(x)'), Doc('GELU_GRAD = 59', 'binary: grad(x Phi(x))'), Doc('COND_LT_MOV = 60', 'ternary: x < y ? z : 0'), + Doc('NEQ = 61', 'binary: x != y'), + Doc('ISNAN = 62', 'unary: isnan(x)'), + Doc('ISINF = 63', 'unary: isinf(x)'), ) pdef('ElemwiseMultiType').add_enum( @@ -513,6 +516,12 @@ pdef('ElemwiseMultiType').add_enum( 'compute ``a * b + c`` requiring that ``a`` be uint8 and ``b`` and ' '``c`` float32, and the result is float32.'), Doc('QCOND_LT_MOV = 57', 'quantized cond_lt_mov'), + Doc('EQ = 58', 'eq'), + Doc('NEQ = 59', 'eq'), + Doc('LT = 60', 'lt'), + Doc('LEQ = 61', 'leq'), + Doc('ISNAN = 62', 'isnan'), + Doc('ISINF = 63', 'isinf') ) pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) diff --git a/dnn/src/common/elemwise/kern_defs.cuh b/dnn/src/common/elemwise/kern_defs.cuh index 35a785cae..87788c0d8 100644 --- a/dnn/src/common/elemwise/kern_defs.cuh +++ b/dnn/src/common/elemwise/kern_defs.cuh @@ -10,6 +10,7 @@ #include #include +#include "math.h" #if MEGDNN_CC_HOST #include @@ -272,6 +273,57 @@ DEF_KERN_ALL(FUSE_MUL_ADD3, x* y + z); #undef DEF_KERN_AD #undef DEF_KERN +#undef DEF_KERN_FLOAT +#undef DEF_KERN_INT +#undef DEF_KERN_ALL + +/* ================== bool kernels ================== */ +//! define kernel +template +struct ElemwiseBoolKern; + +#define DEF_KERN(_ctype, _dtype, _mode, _imp) \ + template \ + struct ElemwiseBoolKern< \ + plat, param_enumv::Elemwise::Mode::_mode, _ctype, _dtype> { \ + typedef _ctype ctype; \ + static __host__ __device__ _dtype apply(KERN_SIG) { return _dtype(_imp); } \ + } + +//! define kernel for all float types +#define DEF_KERN_FLOAT(_mode, _imp) \ + DEF_KERN(dt_float32, dt_bool, _mode, _imp); \ + DNN_INC_FLOAT16(DEF_KERN(dt_float16, dt_bool, _mode, _imp);) \ + DNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, dt_bool, _mode, _imp);) + +//! define kernel for all int types +#define DEF_KERN_INT(_mode, _imp) \ + DEF_KERN(dt_int32, dt_bool, _mode, _imp); \ + DEF_KERN(dt_int16, dt_bool, _mode, _imp); \ + DEF_KERN(dt_int8, dt_bool, _mode, _imp); \ + DEF_KERN(dt_uint8, dt_bool, _mode, _imp); + +//! define kernel for all ctypes +#define DEF_KERN_ALL(_mode, _imp) \ + DEF_KERN_INT(_mode, _imp); \ + DEF_KERN_FLOAT(_mode, _imp); \ + DEF_KERN(dt_bool, dt_bool, _mode, _imp); +#define KERN_SIG ctype x +DEF_KERN_FLOAT(ISNAN, isnan(float(x))); +DEF_KERN_FLOAT(ISINF, isinf(float(x))); +#undef KERN_SIG +#define KERN_SIG ctype x, ctype y +DEF_KERN_ALL(LT, x < y); +DEF_KERN_ALL(LEQ, x <= y); +DEF_KERN_ALL(EQ, x == y); +DEF_KERN_ALL(NEQ, x != y); +#undef KERN_SIG + +#undef DEF_KERN_AD +#undef DEF_KERN +#undef DEF_KERN_FLOAT +#undef DEF_KERN_INT +#undef DEF_KERN_ALL } // namespace megdnn diff --git a/dnn/src/common/elemwise_multi_type/kern_defs.cuh b/dnn/src/common/elemwise_multi_type/kern_defs.cuh index 405fd073f..f7c2600a1 100644 --- a/dnn/src/common/elemwise_multi_type/kern_defs.cuh +++ b/dnn/src/common/elemwise_multi_type/kern_defs.cuh @@ -28,7 +28,6 @@ MEGDNN_HOST MEGDNN_DEVICE dtype round_shr_saturate(stype x, int k) { } return static_cast(result); } - } // namespace elemwise_multi_type } // namespace megdnn diff --git a/dnn/src/common/elemwise_multi_type/opr_impl.cpp b/dnn/src/common/elemwise_multi_type/opr_impl.cpp index b4c35bd2f..f5bc9ec14 100644 --- a/dnn/src/common/elemwise_multi_type/opr_impl.cpp +++ b/dnn/src/common/elemwise_multi_type/opr_impl.cpp @@ -31,6 +31,14 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { return func; }; + auto make_not_check_dtype_func = []() { + auto func = [](DType dtype) { + megdnn_assert( + true, "This function is to not check the dtype %s", dtype.name()); + }; + return func; + }; + auto make_check_category = [](DTypeCategory expected) { auto func = [expected](DType dtype) { megdnn_assert(expected == dtype.category()); @@ -126,6 +134,23 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { dst.need_specify_out_dtype = true; }; + auto init_bool_unary_op = [&](ModeTrait& dst, const char* name) { + dst.arity = 1; + dst.check_inp[0] = make_check_category(DTypeCategory::FLOAT); + dst.check_out = make_out_dtype_func(dtype::Bool()); + dst.name = name; + dst.need_specify_out_dtype = true; + }; + + auto init_bool_binary_op = [&](ModeTrait& dst, const char* name) { + dst.arity = 2; + dst.check_inp[0] = make_not_check_dtype_func(); + dst.check_inp[1] = make_not_check_dtype_func(); + dst.check_out = make_out_dtype_func(dtype::Bool()); + dst.name = name; + dst.need_specify_out_dtype = true; + }; + auto init_quantized_binary_op = [&](ModeTrait& dst, const char* name) { dst.arity = 2; dst.check_inp[0] = make_check_category(DTypeCategory::QUANTIZED); @@ -240,6 +265,13 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { SET(init_quantized_ternary_op, QFUSE_MUL_ADD3); SET(init_quantized_ternary_op, QCOND_LEQ_MOV); SET(init_quantized_ternary_op, QCOND_LT_MOV); + + SET(init_bool_binary_op, LT); + SET(init_bool_binary_op, LEQ); + SET(init_bool_binary_op, EQ); + SET(init_bool_binary_op, NEQ); + SET(init_bool_unary_op, ISNAN); + SET(init_bool_unary_op, ISINF); #undef SET } @@ -273,4 +305,4 @@ void ElemwiseMultiType::check_layout_and_broadcast( megdnn_assert(dst.is_contiguous()); } -// vim: syntax=cpp.doxygen +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp b/dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp index 179c51e3e..506fa59fa 100644 --- a/dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp +++ b/dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp @@ -9,6 +9,12 @@ using namespace megdnn; make_elemwise_op_param<_n>(src, dst), dst, Elemwise::Mode::_MODE); \ break +#define ON_BOOL_MODE(_MODE, _n) \ + case Mode::_MODE: \ + dest_type_bool_mode( \ + make_elemwise_op_param<_n>(src, dst), dst, Elemwise::Mode::_MODE); \ + break + void ElemwiseMultiTypeImplHelper::exec( _megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) { switch (m_param.mode) { @@ -96,6 +102,13 @@ void ElemwiseMultiTypeImplHelper::exec( ON_QUANTIZED_MODE(FUSE_MUL_ADD3, 3); ON_QUANTIZED_MODE(COND_LEQ_MOV, 3); ON_QUANTIZED_MODE(COND_LT_MOV, 3); + + ON_BOOL_MODE(LT, 2); + ON_BOOL_MODE(LEQ, 2); + ON_BOOL_MODE(EQ, 2); + ON_BOOL_MODE(NEQ, 2); + ON_BOOL_MODE(ISNAN, 1); + ON_BOOL_MODE(ISINF, 1); default: megdnn_throw("invalid mode"); } diff --git a/dnn/src/common/elemwise_multi_type/opr_impl_helper.h b/dnn/src/common/elemwise_multi_type/opr_impl_helper.h index 74a3bafec..dcf7910bb 100644 --- a/dnn/src/common/elemwise_multi_type/opr_impl_helper.h +++ b/dnn/src/common/elemwise_multi_type/opr_impl_helper.h @@ -73,6 +73,24 @@ protected: const ElemwiseOpParamN<2>& param, const TensorND& dst, Elemwise::Mode mode) = 0; + virtual void dest_type_bool_mode( + const ElemwiseOpParamN<1>& param, const TensorND& dst, + Elemwise::Mode mode) { + MEGDNN_MARK_USED_VAR(param); + MEGDNN_MARK_USED_VAR(dst); + MEGDNN_MARK_USED_VAR(mode); + megdnn_throw("Unrealized except arm_common"); + } + + virtual void dest_type_bool_mode( + const ElemwiseOpParamN<2>& param, const TensorND& dst, + Elemwise::Mode mode) { + MEGDNN_MARK_USED_VAR(param); + MEGDNN_MARK_USED_VAR(dst); + MEGDNN_MARK_USED_VAR(mode); + megdnn_throw("Unrealized except arm_common"); + } + virtual void on_quantized_mode( const ElemwiseOpParamN<3>& param, const TensorND& dst, Elemwise::Mode mode) { diff --git a/dnn/src/cuda/elemwise_multi_type/kern_impl_bool.inl b/dnn/src/cuda/elemwise_multi_type/kern_impl_bool.inl new file mode 100644 index 000000000..5d95f67a6 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kern_impl_bool.inl @@ -0,0 +1,27 @@ +#pragma once + +#ifndef KERN_IMPL_MODE +#error "KERN_IMPL_MODE, KERN_IMPL_ARITY, KERN_IMPL_STYPE, KERN_IMPL_DTYPE must be defined" +#endif + +#include "src/cuda/elemwise_multi_type/kern_ops.cuh" + +namespace megdnn { +namespace cuda { + +#define cb(_m) \ + typedef ElemwiseBoolKern< \ + megcorePlatformCUDA, param_enumv::Elemwise::Mode::_m, KERN_IMPL_STYPE, \ + KERN_IMPL_DTYPE> \ + KernImpl; \ + typedef kern_ops_quantized::QuantizedMultiTypeOp< \ + KERN_IMPL_ARITY, KERN_IMPL_STYPE, KERN_IMPL_DTYPE, KernImpl> \ + Op; \ + INST_RUN_ELEMWISE(Op, KERN_IMPL_STYPE, KERN_IMPL_ARITY); + +KERN_IMPL_MODE(cb) + +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/elemwise_multi_type/kern_ops.cuh b/dnn/src/cuda/elemwise_multi_type/kern_ops.cuh index 8de5bba16..82d52b5bd 100644 --- a/dnn/src/cuda/elemwise_multi_type/kern_ops.cuh +++ b/dnn/src/cuda/elemwise_multi_type/kern_ops.cuh @@ -4,7 +4,6 @@ #include "src/cuda/elemwise_multi_type/kern.cuh" #include "src/cuda/integer_subbyte_utils.cuh" #include "src/cuda/utils.cuh" - namespace megdnn { namespace cuda { using namespace elemwise_intl; @@ -122,6 +121,7 @@ struct QuantizedMultiTypeOp< (std::is_same::value || std::is_same::value || std::is_same::value) && + !std::is_same::value && IsNotTypeQ4::value>::type> { ctype_dst* dst; CudaDTypeParam dst_param; @@ -158,6 +158,37 @@ struct QuantizedMultiTypeOp< #endif }; +template +struct QuantizedMultiTypeOp< + 1, ctype_src, ctype_dst, KernImpl, + typename std::enable_if::value>::type> { + ctype_dst* dst; + typedef typename elemwise_intl::VectTypeTrait::vect_type src_vect_type; + typedef typename elemwise_intl::VectTypeTrait::vect_type dst_vect_type; + +#if !MEGDNN_CC_CUDA + QuantizedMultiTypeOp(ctype_dst* m_dst) : dst{m_dst} {} +#endif + +#if MEGDNN_CC_CUDA + __device__ __forceinline__ ctype_dst apply(ctype_src v1) { + ctype_dst rv = KernImpl::apply(v1); + return rv; + } + + __device__ __forceinline__ void operator()(uint32_t idx, ctype_src a) { + dst[idx] = KernImpl::apply(a); + } + + __device__ __forceinline__ void operator()(uint32_t idx, src_vect_type a) { + ctype_src a_x(a.x), a_y(a.y), a_z(a.z), a_w(a.w); + ctype_dst x = apply(a_x), y = apply(a_y), z = apply(a_z), w = apply(a_w); + *(dst_vect_type*)(&dst[idx]) = + elemwise_intl::VectTypeTrait::make_vector(x, y, z, w); + } +#endif +}; + template struct QuantizedMultiTypeOp< 2, ctype_src, ctype_dst, KernImpl, @@ -165,6 +196,7 @@ struct QuantizedMultiTypeOp< (std::is_same::value || std::is_same::value || std::is_same::value) && + !std::is_same::value && IsNotTypeQ4::value>::type> { ctype_dst* dst; CudaDTypeParam dst_param; @@ -206,6 +238,40 @@ struct QuantizedMultiTypeOp< #endif }; +template +struct QuantizedMultiTypeOp< + 2, ctype_src, ctype_dst, KernImpl, + typename std::enable_if<(std::is_same::value)>::type> { + ctype_dst* dst; + typedef typename elemwise_intl::VectTypeTrait::vect_type src_vect_type; + typedef typename elemwise_intl::VectTypeTrait::vect_type dst_vect_type; + +#if !MEGDNN_CC_CUDA + QuantizedMultiTypeOp(ctype_dst* m_dst) : dst{m_dst} {} +#endif + +#if MEGDNN_CC_CUDA + __device__ __forceinline__ ctype_dst apply(ctype_src v1, ctype_src v2) { + ctype_dst rv = KernImpl::apply(v1, v2); + return rv; + } + + __device__ __forceinline__ void operator()(uint32_t idx, ctype_src a, ctype_src b) { + dst[idx] = KernImpl::apply(a, b); + } + + __device__ __forceinline__ void operator()( + uint32_t idx, src_vect_type a, src_vect_type b) { + ctype_src a_x(a.x), a_y(a.y), a_z(a.z), a_w(a.w), b_x(b.x), b_y(b.y), b_z(b.z), + b_w(b.w); + ctype_dst x = apply(a_x, b_x), y = apply(a_y, b_y), z = apply(a_z, b_z), + w = apply(a_w, b_w); + *(dst_vect_type*)(&dst[idx]) = + elemwise_intl::VectTypeTrait::make_vector(x, y, z, w); + } +#endif +}; + template struct QuantizedMultiTypeOp< 3, ctype_src, ctype_dst, KernImpl, @@ -262,7 +328,8 @@ template struct QuantizedMultiTypeOp< 1, ctype_src, ctype_dst, KernImpl, typename std::enable_if< - IsTypeQ4::value && IsNotTypeQ4::value>::type> { + IsTypeQ4::value && IsNotTypeQ4::value && + !std::is_same::value>::type> { ctype_dst* dst; CudaDTypeParam dst_param; CudaDTypeParam param_a; @@ -293,7 +360,8 @@ template struct QuantizedMultiTypeOp< 2, ctype_src, ctype_dst, KernImpl, typename std::enable_if< - IsTypeQ4::value && IsNotTypeQ4::value>::type> { + IsTypeQ4::value && IsNotTypeQ4::value && + !std::is_same::value>::type> { ctype_dst* dst; CudaDTypeParam dst_param; CudaDTypeParam param_a, param_b; @@ -326,7 +394,8 @@ template struct QuantizedMultiTypeOp< 1, ctype_src, ctype_dst, KernImpl, typename std::enable_if< - IsTypeQ4::value && IsTypeQ4::value>::type> { + IsTypeQ4::value && IsTypeQ4::value && + !std::is_same::value>::type> { using src_storage = typename elemwise_intl::VectTypeTrait::Storage; using dst_storage = typename elemwise_intl::VectTypeTrait::Storage; dst_storage* dst; @@ -371,6 +440,7 @@ struct QuantizedMultiTypeOp< (std::is_same::value || std::is_same::value || std::is_same::value) && + !std::is_same::value && IsTypeQ4::value>::type> { using dst_storage = typename elemwise_intl::VectTypeTrait::Storage; dst_storage* dst; @@ -407,7 +477,8 @@ template struct QuantizedMultiTypeOp< 2, ctype_src, ctype_dst, KernImpl, typename std::enable_if< - IsTypeQ4::value && IsTypeQ4::value>::type> { + IsTypeQ4::value && IsTypeQ4::value && + !std::is_same::value>::type> { using src_storage = typename elemwise_intl::VectTypeTrait::Storage; using dst_storage = typename elemwise_intl::VectTypeTrait::Storage; dst_storage* dst; @@ -460,6 +531,7 @@ struct QuantizedMultiTypeOp< (std::is_same::value || std::is_same::value || std::is_same::value) && + !std::is_same::value && IsTypeQ4::value>::type> { using dst_storage = typename elemwise_intl::VectTypeTrait::Storage; dst_storage* dst; diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_bfloat16_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_bfloat16_dt_bool.cu new file mode 100644 index 000000000..b792ac348 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_bfloat16_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_bfloat16 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_bool_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_bool_dt_bool.cu new file mode 100644 index 000000000..7cbc1989f --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_bool_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_bool +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_float16_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_float16_dt_bool.cu new file mode 100644 index 000000000..79a989c78 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_float16_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_float16 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_float32_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_float32_dt_bool.cu new file mode 100644 index 000000000..59902cf7d --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_float32_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_float32 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_int16_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_int16_dt_bool.cu new file mode 100644 index 000000000..136e78085 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_int16_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_int16 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_int32_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_int32_dt_bool.cu new file mode 100644 index 000000000..066ee87a3 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_int32_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_int32 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_int8_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_int8_dt_bool.cu new file mode 100644 index 000000000..d6ae05ca5 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_int8_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_int8 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_uint8_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_uint8_dt_bool.cu new file mode 100644 index 000000000..f55f7fc50 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_uint8_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_uint8 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ISINF_dt_bfloat16_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ISINF_dt_bfloat16_dt_bool.cu new file mode 100644 index 000000000..4972784d0 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ISINF_dt_bfloat16_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ISINF, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_bfloat16 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ISINF_dt_float16_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ISINF_dt_float16_dt_bool.cu new file mode 100644 index 000000000..bdd1d0dd6 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ISINF_dt_float16_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ISINF, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_float16 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ISINF_dt_float32_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ISINF_dt_float32_dt_bool.cu new file mode 100644 index 000000000..7690ad7ed --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ISINF_dt_float32_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ISINF, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_float32 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ISNAN_dt_bfloat16_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ISNAN_dt_bfloat16_dt_bool.cu new file mode 100644 index 000000000..128ff8b47 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ISNAN_dt_bfloat16_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ISNAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_bfloat16 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ISNAN_dt_float16_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ISNAN_dt_float16_dt_bool.cu new file mode 100644 index 000000000..fcf48b128 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ISNAN_dt_float16_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ISNAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_float16 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ISNAN_dt_float32_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ISNAN_dt_float32_dt_bool.cu new file mode 100644 index 000000000..497987ce9 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ISNAN_dt_float32_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ISNAN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_float32 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_bfloat16_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_bfloat16_dt_bool.cu new file mode 100644 index 000000000..5bb44b4a8 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_bfloat16_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_bfloat16 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_bool_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_bool_dt_bool.cu new file mode 100644 index 000000000..67bb2eca6 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_bool_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_bool +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_float16_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_float16_dt_bool.cu new file mode 100644 index 000000000..2a15ff9d0 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_float16_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_float16 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_float32_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_float32_dt_bool.cu new file mode 100644 index 000000000..6f7c1a1ed --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_float32_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_float32 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_int16_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_int16_dt_bool.cu new file mode 100644 index 000000000..6c4b034d7 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_int16_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_int16 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_int32_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_int32_dt_bool.cu new file mode 100644 index 000000000..9bc9a1e36 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_int32_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_int32 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_int8_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_int8_dt_bool.cu new file mode 100644 index 000000000..dd544b810 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_int8_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_int8 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_uint8_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_uint8_dt_bool.cu new file mode 100644 index 000000000..cd7e9c0d8 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_uint8_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_uint8 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_bfloat16_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_bfloat16_dt_bool.cu new file mode 100644 index 000000000..2d1b7a9b0 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_bfloat16_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_bfloat16 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_bool_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_bool_dt_bool.cu new file mode 100644 index 000000000..15d7ce8f5 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_bool_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_bool +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_float16_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_float16_dt_bool.cu new file mode 100644 index 000000000..a2115755f --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_float16_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_float16 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_float32_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_float32_dt_bool.cu new file mode 100644 index 000000000..ad57554a6 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_float32_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_float32 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_int16_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_int16_dt_bool.cu new file mode 100644 index 000000000..60f112e38 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_int16_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_int16 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_int32_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_int32_dt_bool.cu new file mode 100644 index 000000000..f1e34e577 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_int32_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_int32 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_int8_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_int8_dt_bool.cu new file mode 100644 index 000000000..2339ff188 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_int8_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_int8 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_uint8_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_uint8_dt_bool.cu new file mode 100644 index 000000000..59a4acc5c --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_uint8_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_uint8 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_bfloat16_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_bfloat16_dt_bool.cu new file mode 100644 index 000000000..96d6c2eab --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_bfloat16_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_bfloat16 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_bool_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_bool_dt_bool.cu new file mode 100644 index 000000000..72d82a341 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_bool_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_bool +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_float16_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_float16_dt_bool.cu new file mode 100644 index 000000000..d6ff7560c --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_float16_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_float16 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_float32_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_float32_dt_bool.cu new file mode 100644 index 000000000..4601f4abe --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_float32_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_float32 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_int16_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_int16_dt_bool.cu new file mode 100644 index 000000000..77f033079 --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_int16_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_int16 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_int32_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_int32_dt_bool.cu new file mode 100644 index 000000000..4289eb28b --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_int32_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_int32 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_int8_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_int8_dt_bool.cu new file mode 100644 index 000000000..00be07a5c --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_int8_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_int8 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_uint8_dt_bool.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_uint8_dt_bool.cu new file mode 100644 index 000000000..11d9b844c --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/NEQ_dt_uint8_dt_bool.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls_bool.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_uint8 +#define KERN_IMPL_DTYPE dt_bool +#include "../kern_impl_bool.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/opr_impl.cpp b/dnn/src/cuda/elemwise_multi_type/opr_impl.cpp index 5f351213a..64a90258a 100644 --- a/dnn/src/cuda/elemwise_multi_type/opr_impl.cpp +++ b/dnn/src/cuda/elemwise_multi_type/opr_impl.cpp @@ -295,6 +295,60 @@ IMPL_MODE_DISPATCHER(2, dt_qint32, dt_quint4); #undef _cb_dispatch_mode #undef IMPL_MODE_DISPATCHER +#define _cb_dispatch_mode(_m) \ + case param::Elemwise::Mode::_m: \ + do { \ + using KernImpl = ElemwiseBoolKern< \ + megcorePlatformCUDA, param_enumv::Elemwise::Mode::_m, src_ctype, \ + dt_bool>; \ + using Op = kern_ops_quantized::QuantizedMultiTypeOp< \ + arity, src_ctype, bool, KernImpl>; \ + dst_ctype* dst = dst_tensor.ptr(); \ + Op op(dst); \ + return run_elemwise(src, stream, op); \ + } while (0); + +#define IMPL_MODE_DISPATCHER_BOOL(_arity, _src_ctype, _dst_ctype) \ + template <> \ + struct ModeDispatcher<_arity, _src_ctype, _dst_ctype> { \ + static constexpr int arity = _arity; \ + using src_ctype = _src_ctype; \ + using dst_ctype = _dst_ctype; \ + static void run( \ + const ElemwiseOpParamN<_arity>& src, const TensorND& dst_tensor, \ + param::Elemwise::Mode mode, cudaStream_t stream) { \ + switch (mode) { \ + FOREACH(_cb_dispatch_mode) \ + default: \ + megdnn_throw("bad mode"); \ + } \ + } \ + } + +#define FOREACH(cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(NEQ, cb) +IMPL_MODE_DISPATCHER_BOOL(2, dt_int8, dt_bool); +IMPL_MODE_DISPATCHER_BOOL(2, dt_float32, dt_bool); +IMPL_MODE_DISPATCHER_BOOL(2, dt_bfloat16, dt_bool); +IMPL_MODE_DISPATCHER_BOOL(2, dt_float16, dt_bool); +IMPL_MODE_DISPATCHER_BOOL(2, dt_int16, dt_bool); +IMPL_MODE_DISPATCHER_BOOL(2, dt_int32, dt_bool); +IMPL_MODE_DISPATCHER_BOOL(2, dt_bool, dt_bool); +IMPL_MODE_DISPATCHER_BOOL(2, dt_uint8, dt_bool); +#undef FOREACH +#define FOREACH(cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ISNAN, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ISINF, cb) +IMPL_MODE_DISPATCHER_BOOL(1, dt_float16, dt_bool); +IMPL_MODE_DISPATCHER_BOOL(1, dt_float32, dt_bool); +IMPL_MODE_DISPATCHER_BOOL(1, dt_bfloat16, dt_bool); +#undef FOREACH +#undef _cb_dispatch_mode +#undef IMPL_MODE_DISPATCHER_BOOL + template void dispatch_src_ctype( const ElemwiseOpParamN<1>&, const TensorND& dst_tensor, Elemwise::Mode, @@ -578,6 +632,62 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( #undef DISPATCH } +void ElemwiseMultiTypeImpl::dest_type_bool_mode( + const ElemwiseOpParamN<1>& param, const TensorND& dst_tensor, + Elemwise::Mode mode) { + auto stream = cuda_stream(this->handle()); + switch (param[0].layout.dtype.enumv()) { +#define DISPATCH(_dt) \ + case DTypeTrait<_dt>::enumv: { \ + ModeDispatcher<1, typename DTypeTrait<_dt>::ctype, bool>::run( \ + param, dst_tensor, mode, stream); \ + break; \ + } + + DISPATCH(dtype::Float32); + DISPATCH(dtype::Float16); + DISPATCH(dtype::BFloat16); + + default: + megdnn_throw(ssprintf( + "Unsupported input dtype %s for ElemwiseMultiType", + param[0].layout.dtype.name())); + } + +#undef DISPATCH +} + +void ElemwiseMultiTypeImpl::dest_type_bool_mode( + const ElemwiseOpParamN<2>& param, const TensorND& dst_tensor, + Elemwise::Mode mode) { + megdnn_assert(param[0].layout.dtype.enumv() == param[1].layout.dtype.enumv()); + auto stream = cuda_stream(this->handle()); + switch (param[0].layout.dtype.enumv()) { +#define DISPATCH(_dt) \ + case DTypeTrait<_dt>::enumv: { \ + ModeDispatcher<2, typename DTypeTrait<_dt>::ctype, bool>::run( \ + param, dst_tensor, mode, stream); \ + break; \ + } + + DISPATCH(dtype::Int8); + DISPATCH(dtype::Float32); + DISPATCH(dtype::BFloat16); + DISPATCH(dtype::Bool); + DISPATCH(dtype::Float16); + DISPATCH(dtype::Int16); + DISPATCH(dtype::Int32); + DISPATCH(dtype::Uint8); + + default: + megdnn_throw(ssprintf( + "Unsupported input dtype %s for ElemwiseMultiType", + param[0].layout.dtype.name())); + } + +#undef DISPATCH +} + void ElemwiseMultiTypeImpl::on_quantized_mode( const ElemwiseOpParamN<3>& param, const TensorND& dst_tensor, Elemwise::Mode mode) { diff --git a/dnn/src/cuda/elemwise_multi_type/opr_impl.h b/dnn/src/cuda/elemwise_multi_type/opr_impl.h index 9a5475730..27c109d5e 100644 --- a/dnn/src/cuda/elemwise_multi_type/opr_impl.h +++ b/dnn/src/cuda/elemwise_multi_type/opr_impl.h @@ -36,6 +36,14 @@ class ElemwiseMultiTypeImpl final : public ElemwiseMultiTypeImplHelper { const ElemwiseOpParamN<3>& param, const TensorND& dst, Elemwise::Mode mode) override; + void dest_type_bool_mode( + const ElemwiseOpParamN<1>& param, const TensorND& dst, + Elemwise::Mode mode) override; + + void dest_type_bool_mode( + const ElemwiseOpParamN<2>& param, const TensorND& dst, + Elemwise::Mode mode) override; + public: using ElemwiseMultiTypeImplHelper::ElemwiseMultiTypeImplHelper; }; diff --git a/dnn/src/naive/elemwise_multi_type/opr_impl.h b/dnn/src/naive/elemwise_multi_type/opr_impl.h index a1c4ec1c3..d39fa3550 100644 --- a/dnn/src/naive/elemwise_multi_type/opr_impl.h +++ b/dnn/src/naive/elemwise_multi_type/opr_impl.h @@ -3,7 +3,6 @@ #include "megdnn/tensor_iter.h" #include "src/common/elemwise_multi_type/opr_impl_helper.h" #include "src/naive/handle.h" - namespace megdnn { namespace naive { @@ -67,6 +66,25 @@ class ElemwiseMultiTypeImpl : public ElemwiseMultiTypeImplHelper { MEGDNN_DISPATCH_CPU_KERN_OPR(work()); } + template + void dispatch_dst_bool_op( + const ElemwiseOpParamN<1>& param, const TensorND& dst_tensor) { + auto size = param.size; + auto src0 = param[0]; + auto work = [src0, size, dst_tensor]() { + // This is needed as these iterators are captured as const value. + auto iA = tensor_iter_valonly(src0).begin(); + auto pD = tensor_iter_valonly(dst_tensor).begin(); + for (size_t i = 0; i < size; i++) { + src_ctype a = *iA; + *pD = KernImpl::apply(a); + ++iA; + ++pD; + } + }; + MEGDNN_DISPATCH_CPU_KERN_OPR(work()); + } + template void dispatch_add_qint_op( const ElemwiseOpParamN<2>& param, const TensorND& dst_tensor) { @@ -97,6 +115,29 @@ class ElemwiseMultiTypeImpl : public ElemwiseMultiTypeImplHelper { MEGDNN_DISPATCH_CPU_KERN_OPR(work()); } + template + void dispatch_dst_bool_op( + const ElemwiseOpParamN<2>& param, const TensorND& dst_tensor) { + auto size = param.size; + auto src0 = param[0]; + auto src1 = param[1]; + auto work = [src0, src1, size, dst_tensor]() { + // This is needed as these iterators are captured as const value. + auto iA = tensor_iter_valonly(src0).begin(); + auto iB = tensor_iter_valonly(src1).begin(); + auto pD = tensor_iter_valonly(dst_tensor).begin(); + for (size_t i = 0; i < size; i++) { + src_ctype a = *iA; + src_ctype b = *iB; + *pD = KernImpl::apply(a, b); + ++iA; + ++iB; + ++pD; + } + }; + MEGDNN_DISPATCH_CPU_KERN_OPR(work()); + } + template void dispatch_add_qint_op( const ElemwiseOpParamN<3>& param, const TensorND& dst_tensor) { @@ -178,6 +219,14 @@ protected: const ElemwiseOpParamN<3>& param, const TensorND& dst, Elemwise::Mode mode) override; + void dest_type_bool_mode( + const ElemwiseOpParamN<1>& param, const TensorND& dst, + Elemwise::Mode mode) override; + + void dest_type_bool_mode( + const ElemwiseOpParamN<2>& param, const TensorND& dst, + Elemwise::Mode mode) override; + public: using ElemwiseMultiTypeImplHelper::ElemwiseMultiTypeImplHelper; }; diff --git a/dnn/src/naive/elemwise_multi_type/opr_impl_4.cpp b/dnn/src/naive/elemwise_multi_type/opr_impl_4.cpp index c94be9179..f9b86f864 100644 --- a/dnn/src/naive/elemwise_multi_type/opr_impl_4.cpp +++ b/dnn/src/naive/elemwise_multi_type/opr_impl_4.cpp @@ -54,4 +54,121 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( } } +void ElemwiseMultiTypeImpl::dest_type_bool_mode( + const ElemwiseOpParamN<1>& param, const TensorND& dst, Elemwise::Mode mode) { + switch (mode) { + case Elemwise::Mode::ISINF: { + switch (param[0].layout.dtype.enumv()) { +#define DISPATCH(_dt, _mode) \ + case DTypeTrait<_dt>::enumv: { \ + typedef ElemwiseBoolKern< \ + megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, \ + typename DTypeTrait<_dt>::ctype, dt_bool> \ + KernImpl##_mode; \ + dispatch_dst_bool_op< \ + KernImpl##_mode, typename DTypeTrait<_dt>::ctype, dt_bool>( \ + param, dst); \ + break; \ + } +#define DISPATCH_MODE(_mode) \ + DISPATCH(megdnn::dtype::Float32, _mode); \ + DNN_INC_FLOAT16(DISPATCH(megdnn::dtype::Float16, _mode);) \ + DNN_INC_FLOAT16(DISPATCH(megdnn::dtype::BFloat16, _mode);) + DISPATCH_MODE(ISINF); + default: + megdnn_throw(ssprintf( + "Unsupported input dtype %s for ElemwiseMultiType", + param[0].layout.dtype.name())); + }; + break; + }; + case Elemwise::Mode::ISNAN: { + switch (param[0].layout.dtype.enumv()) { + DISPATCH_MODE(ISNAN); + default: + megdnn_throw(ssprintf( + "Unsupported input dtype %s for ElemwiseMultiType", + param[0].layout.dtype.name())); + }; + break; + }; + default: + megdnn_assert_internal(0); + } +#undef DISPATCH_MODE +#undef DISPATCH +} + +void ElemwiseMultiTypeImpl::dest_type_bool_mode( + const ElemwiseOpParamN<2>& param, const TensorND& dst, Elemwise::Mode mode) { + megdnn_assert(param[0].layout.dtype.enumv() == param[1].layout.dtype.enumv()); + switch (mode) { + case Elemwise::Mode::EQ: { + switch (param[0].layout.dtype.enumv()) { +#define DISPATCH(_dt, _mode) \ + case DTypeTrait<_dt>::enumv: { \ + typedef ElemwiseBoolKern< \ + megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, \ + typename DTypeTrait<_dt>::ctype, dt_bool> \ + KernImpl##_mode; \ + dispatch_dst_bool_op< \ + KernImpl##_mode, typename DTypeTrait<_dt>::ctype, dt_bool>( \ + param, dst); \ + break; \ + }; +#define DISPATCH_MODE(_mode) \ + DISPATCH(megdnn::dtype::Float32, _mode); \ + DNN_INC_FLOAT16(DISPATCH(megdnn::dtype::Float16, _mode);) \ + DNN_INC_FLOAT16(DISPATCH(megdnn::dtype::BFloat16, _mode);) \ + DISPATCH(megdnn::dtype::Int32, _mode); \ + DISPATCH(megdnn::dtype::Int16, _mode); \ + DISPATCH(megdnn::dtype::Int8, _mode); \ + DISPATCH(megdnn::dtype::Uint8, _mode); \ + DISPATCH(megdnn::dtype::Bool, _mode); + DISPATCH_MODE(EQ); + break; + default: + megdnn_throw(ssprintf( + "Unsupported input dtype %s for ElemwiseMultiType", + param[0].layout.dtype.name())); + }; + break; + }; + case Elemwise::Mode::NEQ: { + switch (param[0].layout.dtype.enumv()) { + DISPATCH_MODE(NEQ); + default: + megdnn_throw(ssprintf( + "Unsupported input dtype %s for ElemwiseMultiType", + param[0].layout.dtype.name())); + }; + break; + }; + case Elemwise::Mode::LT: { + switch (param[0].layout.dtype.enumv()) { + DISPATCH_MODE(LT); + default: + megdnn_throw(ssprintf( + "Unsupported input dtype %s for ElemwiseMultiType", + param[0].layout.dtype.name())); + }; + break; + }; + case Elemwise::Mode::LEQ: { + switch (param[0].layout.dtype.enumv()) { + DISPATCH_MODE(LEQ); + default: + megdnn_throw(ssprintf( + "Unsupported input dtype %s for ElemwiseMultiType", + param[0].layout.dtype.name())); + }; + break; + }; + default: + megdnn_assert_internal(0); + } +#undef DISPATCH_MODE +#undef DISPATCH +} + // vim: syntax=cpp.doxygen diff --git a/dnn/test/common/checker.cpp b/dnn/test/common/checker.cpp index abfc37bc2..1491f5a3c 100644 --- a/dnn/test/common/checker.cpp +++ b/dnn/test/common/checker.cpp @@ -149,8 +149,9 @@ void copy_tensors( //! use QuantizedS16 dtype in winograd_filter_preprocess now. cb(::megdnn::dtype::QuantizedS16) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) cb(::megdnn::dtype::Uint16) cb(::megdnn::dtype::QuantizedS1) + cb(::megdnn::dtype::Bool) #undef cb - default : megdnn_trap(); + default : megdnn_trap(); } } @@ -325,6 +326,7 @@ void CheckerHelper::do_exec( m_output_canonizer(tensors_cur_host); m_output_canonizer(tensors_naive); } + check_tensors(tensors_naive, tensors_cur_host); if (m_extra_opr_impl) { check_tensors(tensors_naive, *tensors_extra_opr_impl); diff --git a/dnn/test/common/elemwise.cpp b/dnn/test/common/elemwise.cpp index f7d3cb795..31f9bcf02 100644 --- a/dnn/test/common/elemwise.cpp +++ b/dnn/test/common/elemwise.cpp @@ -756,6 +756,14 @@ DEF_TEST(all_modes) { auto should_ignore = [handle](Mode mode) { MEGDNN_MARK_USED_VAR(mode); + switch (mode) { + case Mode::NEQ: + case Mode::ISNAN: + case Mode::ISINF: + return true; + default: + break; + } return false; }; diff --git a/dnn/test/common/rng.cpp b/dnn/test/common/rng.cpp index 364baa695..030ccc8b0 100644 --- a/dnn/test/common/rng.cpp +++ b/dnn/test/common/rng.cpp @@ -195,6 +195,9 @@ void IIDRNG::gen(const TensorND& tensor) { if (tensor.layout.dtype.enumv() == DTypeEnum::Uint16) { return; } + if (tensor.layout.dtype.enumv() == DTypeEnum::Bool) { + return; + } megdnn_assert( 0, "IIDRNG does not know how to generate value for DType %s", tensor.layout.dtype.name()); diff --git a/dnn/test/cuda/elemwise_multi_type.cpp b/dnn/test/cuda/elemwise_multi_type.cpp index 36b985fe3..96b4b944a 100644 --- a/dnn/test/cuda/elemwise_multi_type.cpp +++ b/dnn/test/cuda/elemwise_multi_type.cpp @@ -4,7 +4,6 @@ #include "test/cuda/benchmark.h" #include "test/cuda/fixture.h" #include "test/cuda/utils.h" - #undef cuda_check #include "src/cuda/utils.h" @@ -143,6 +142,43 @@ static void run_test_q4(int arity, Checker& checker, Mode mod } } +static void run_test_bool(int arity, Checker& checker, Mode mode) { + for (DType type : + std::vector{{dtype::Int8()}, {dtype::Float32()}, {dtype::Float16()}}) { + if ((mode == Mode::ISNAN || mode == Mode::ISINF) && type == dtype::Int8()) { + continue; + } + checker.set_param(mode); + UniformIntRNG rng_int8{1, 1}; + NormalRNG rng_normal{0, 1}; + + auto set_inp_rng = [&](DType dtype, size_t i) { + if (dtype.enumv() == DTypeEnum::Int8) { + checker.set_rng(i, &rng_int8); + } else if ( + dtype.enumv() == DTypeEnum::Float32 || + dtype.enumv() == DTypeEnum::Float16) { + checker.set_rng(i, &rng_normal); + } else { + megdnn_assert(0); + } + checker.set_dtype(i, dtype); + }; + auto src_type = type; + for (int i = 0; i < arity; i++) { + set_inp_rng(src_type, i); + } + + if (arity == 1) { + checker.execs({{3, 4, 5, 6}, {}}); + } else if (arity == 2) { + checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {}}); + } else { + megdnn_assert(0); + } + } +} + TEST_F(CUDA, ELEMWISE_QUANTIZED_MODE_UNARY) { Checker checker(handle_cuda()); for (auto mode : @@ -203,6 +239,18 @@ TEST_F(CUDA, ELEMWISE_QUANTIZED_MODE_BINARY) { } } +TEST_F(CUDA, ELEMWISE_BOOL_MODE_BINARY) { + using Mode = ElemwiseMultiType::Param::Mode; + + Checker checker(handle_cuda()); + for (auto mode : {Mode::EQ, Mode::NEQ, Mode::LT, Mode::LEQ}) { + run_test_bool(2, checker, mode); + } + for (auto mode : {Mode::ISNAN, Mode::ISINF}) { + run_test_bool(1, checker, mode); + } +} + TEST_F(CUDA, ELEMWISE_QUANTIZED_MODE_TENARY) { using Mode = ElemwiseMultiType::Param::Mode; Checker checker(handle_cuda()); diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py index 136720c5c..78b2668f0 100644 --- a/imperative/python/megengine/core/tensor/array_method.py +++ b/imperative/python/megengine/core/tensor/array_method.py @@ -23,11 +23,25 @@ from .._imperative_rt.core2 import ( ) from ..ops import builtin from . import amp -from .utils import _normalize_axis, astensor1d, cast_tensors, make_shape_tuple, subgraph +from .utils import ( + _normalize_axis, + astensor1d, + cast_tensors, + convert_inputs, + make_shape_tuple, + subgraph, +) _ElwMod = builtin.Elemwise.Mode +def _elemwise_multi_type(*args, mode, **kwargs): + op = builtin.ElemwiseMultiType(mode=mode, **kwargs) + args = convert_inputs(*args) + (result,) = apply(op, *args) + return result + + def _elwise_apply(args, mode): op = builtin.Elemwise(mode) (result,) = apply(op, *args) @@ -234,13 +248,23 @@ class ArrayMethodMixin(abc.ABC): __hash__ = None # due to __eq__ diviates from python convention - __lt__ = lambda self, value: _elwise(self, value, mode=_ElwMod.LT).astype("bool") - __le__ = lambda self, value: _elwise(self, value, mode=_ElwMod.LEQ).astype("bool") - __gt__ = lambda self, value: _elwise(value, self, mode=_ElwMod.LT).astype("bool") - __ge__ = lambda self, value: _elwise(value, self, mode=_ElwMod.LEQ).astype("bool") - __eq__ = lambda self, value: _elwise(self, value, mode=_ElwMod.EQ).astype("bool") - __ne__ = lambda self, value: _elwise( - _elwise(self, value, mode=_ElwMod.EQ).astype("bool"), mode=_ElwMod.NOT, + __lt__ = lambda self, value: _elemwise_multi_type( + self, value, mode="lt", dtype="Bool" + ) + __le__ = lambda self, value: _elemwise_multi_type( + self, value, mode="leq", dtype="Bool" + ) + __gt__ = lambda self, value: _elemwise_multi_type( + value, self, mode="lt", dtype="Bool" + ) + __ge__ = lambda self, value: _elemwise_multi_type( + value, self, mode="leq", dtype="Bool" + ) + __eq__ = lambda self, value: _elemwise_multi_type( + self, value, mode="eq", dtype="Bool" + ) + __ne__ = lambda self, value: _elemwise_multi_type( + self, value, mode="neq", dtype="Bool" ) __neg__ = _unary_elwise(_ElwMod.NEGATE) diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index 8cf735623..0fa09958f 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -10,7 +10,7 @@ from ..core.tensor.array_method import _matmul from ..core.tensor.utils import _normalize_axis from ..tensor import Tensor from ..utils.deprecation import deprecated_kwargs_default -from .elemwise import clip +from .elemwise import _elemwise_multi_type, clip from .tensor import expand_dims, squeeze __all__ = [ @@ -52,7 +52,7 @@ def isnan(inp: Tensor) -> Tensor: >>> F.isnan(x).numpy() array([False, True, False]) """ - return inp != inp + return _elemwise_multi_type(inp, mode="isnan", dtype="Bool") def isinf(inp: Tensor) -> Tensor: @@ -69,7 +69,7 @@ def isinf(inp: Tensor) -> Tensor: >>> F.isinf(x).numpy() array([False, True, False]) """ - return abs(inp).astype("float32") == float("inf") + return _elemwise_multi_type(inp, mode="isinf", dtype="Bool") def sign(inp: Tensor): diff --git a/src/jit/impl/ast_c.cpp b/src/jit/impl/ast_c.cpp index 25df37758..edce00074 100644 --- a/src/jit/impl/ast_c.cpp +++ b/src/jit/impl/ast_c.cpp @@ -133,9 +133,9 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() { 0.f}) / 6.f), }; - mgb_assert(map.size() + 16 == opr::Elemwise::Param::MODE_NR_MEMBER); + mgb_assert(map.size() + 19 == opr::Elemwise::Param::MODE_NR_MEMBER); // unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH, - // ERFINV, ERFCINV, NOT, AND, OR, XOR + // ERFINV, ERFCINV, NOT, AND, OR, XOR, NEQ, ISNAN, ISINF return map; #undef ADD_OPR } diff --git a/src/opr/test/basic_arith/elemwise.cpp b/src/opr/test/basic_arith/elemwise.cpp index c220102cd..b4bab90d5 100644 --- a/src/opr/test/basic_arith/elemwise.cpp +++ b/src/opr/test/basic_arith/elemwise.cpp @@ -756,8 +756,8 @@ TYPED_TEST(TestOprBasicArithTernaryElemwise, Float32) { TEST(TestOprBasicArithElemwise, CheckAllModeTested) { size_t nr_member = opr::Elemwise::Param::MODE_NR_MEMBER; - ASSERT_EQ(nr_member, tested_mode.size() + 4); - // Not using TestRunner: NOT, AND, OR, XOR + ASSERT_EQ(nr_member, tested_mode.size() + 7); + // Not using TestRunner: NOT, AND, OR, XOR, NEQ, ISNAN, ISINF } #define TEST_OPR_BASIC_ARITH_UNARY_BOOL(_mode, _op) \ TEST(TestOprBasicArithElemwise, _mode) { \ -- GitLab