From ffbf8fad6c36f5becc43b488bb78875c5c725344 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Fri, 25 Mar 2022 15:13:33 +0800
Subject: [PATCH] feat(fallback): add general intrinsic to elemwise multitype

GitOrigin-RevId: fe7b335545fd959f917b7df8ee48739ccb2a86ab
---
 .../arm_common/elemwise_helper/elemwise_op.h  |   2 +-
 .../elemwise_multi_type/opr_impl.cpp          |  30 +-
 .../fallback/elemwise_helper/elemwise_op.h    |  57 --
 .../elemwise_helper/kimpl/fuse_add_h_swish.h  |   2 +-
 .../fallback/elemwise_helper/kimpl/op_base.h  |  10 +-
 dnn/src/fallback/elemwise_helper/kimpl/relu.h |   2 +-
 dnn/src/fallback/elemwise_helper/op_common.h  |  50 ++
 .../fallback/elemwise_multi_type/opr_impl.h   |  12 +
 .../elemwise_multi_type/quantized_impl.cpp    | 499 ++++++++++++++++++
 .../fallback/general_intrinsic/gi_common.h    |   1 +
 dnn/test/arm_common/elemwise_multi_type.cpp   |  93 ++++
 dnn/test/fallback/elemwise_multi_type.cpp     | 169 ++++++
 12 files changed, 837 insertions(+), 90 deletions(-)
 create mode 100644 dnn/src/fallback/elemwise_multi_type/quantized_impl.cpp

diff --git a/dnn/src/arm_common/elemwise_helper/elemwise_op.h b/dnn/src/arm_common/elemwise_helper/elemwise_op.h
index b62a2987e..28b4c398d 100644
--- a/dnn/src/arm_common/elemwise_helper/elemwise_op.h
+++ b/dnn/src/arm_common/elemwise_helper/elemwise_op.h
@@ -15,7 +15,7 @@
 #include "src/arm_common/elemwise_helper/op_binary.h"
 #include "src/arm_common/elemwise_helper/op_ternary.h"
 #include "src/arm_common/elemwise_helper/op_unary.h"
-#include "src/fallback/elemwise_helper/elemwise_op.h"
+#include "src/fallback/elemwise_helper/op_common.h"
 
 namespace megdnn {
 namespace elemwise {
diff --git a/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp b/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
index d56785dcb..2007f0b77 100644
--- a/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
+++ b/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
@@ -364,17 +364,9 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
 }
 
 #define DISPATCH()                                                               \
-    if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 &&               \
-        dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {                    \
-        DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8)          \
-    } else if (                                                                  \
-            param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm &&       \
-            dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) {            \
+    if (param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm &&           \
+        dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) {                \
         DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, dtype::Quantized8Asymm)  \
-    } else if (                                                                  \
-            param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 &&          \
-            dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {                \
-        DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8)                   \
     } else if (                                                                  \
             param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 &&          \
             dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) {            \
@@ -467,16 +459,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
 
 #define DISPATCH()                                                               \
     if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 &&              \
-        dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {                    \
-        DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8)                   \
-    } else if (                                                                  \
-            param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 &&          \
-            dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) {            \
+        dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) {                \
         DISPATCH_MODE(dtype::QuantizedS32, dtype::Quantized8Asymm)               \
-    } else if (                                                                  \
-            param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 &&           \
-            dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {                \
-        DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8)          \
     } else if (                                                                  \
             param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm &&       \
             dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) {            \
@@ -701,12 +685,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
 }
 
 #define DISPATCH()                                                               \
-    if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 &&               \
-        dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {                    \
-        DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8)          \
-    } else if (                                                                  \
-            param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm &&       \
-            dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) {            \
+    if (param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm &&           \
+        dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) {                \
         DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, dtype::Quantized8Asymm)  \
     }
 
diff --git a/dnn/src/fallback/elemwise_helper/elemwise_op.h b/dnn/src/fallback/elemwise_helper/elemwise_op.h
index f7b935cfc..1ed9ef9aa 100644
--- a/dnn/src/fallback/elemwise_helper/elemwise_op.h
+++ b/dnn/src/fallback/elemwise_helper/elemwise_op.h
@@ -12,61 +12,4 @@
 #include "src/fallback/general_intrinsic/gi_float.h"
 #include "src/fallback/general_intrinsic/gi_int.h"
 
-namespace megdnn {
-namespace elemwise {
-
-///////////////////////////////// ParamElemVistor ///////////////////////////
-
-#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix)         \
-    template <>                                                   \
-    struct ParamElemVisitor<_ctype> {                             \
-        _simd_type operator()(const _ctype* src) const {          \
-            return GiLoad##_fun_suffix(src);                      \
-        }                                                         \
-    };                                                            \
-    template <>                                                   \
-    struct ParamElemVisitorDup<_ctype> {                          \
-        _simd_type operator()(const _ctype* src) const {          \
-            return GiBroadcast##_fun_suffix(                      \
-                    *reinterpret_cast<const _inner_ctype*>(src)); \
-        }                                                         \
-    }
-cb(dt_qint32, int32_t, GI_INT32_t, Int32);
-cb(dt_qint8, int8_t, GI_INT8_t, Int8);
-
-cb(dt_float32, float, GI_FLOAT32_t, Float32);
-cb(dt_int32, int32_t, GI_INT32_t, Int32);
-cb(dt_int8, int8_t, GI_INT8_t, Int8);
-#undef cb
-
-template <typename ctype>
-struct ParamElemVisitorBcast101x4;
-#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix)              \
-    template <>                                                                    \
-    struct ParamElemVisitorBcast101x4<_ctype> {                                    \
-        _simd_type operator()(const _ctype* src) const {                           \
-            return GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix( \
-                    *reinterpret_cast<const _inner_ctype*>(src)));                 \
-        }                                                                          \
-    }
-
-cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32);
-cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32);
-#undef cb
-#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix)  \
-    template <>                                            \
-    struct ParamElemVisitorBcast101x4<_ctype> {            \
-        _simd_type operator()(const _ctype* src) const {   \
-            return GiLoad##_fun_suffix(src);               \
-        }                                                  \
-    }
-
-cb(dt_qint32, int32_t, GI_INT32_t, Int32);
-cb(dt_float32, float, GI_FLOAT32_t, Float32);
-cb(dt_int32, int32_t, GI_INT32_t, Int32);
-#undef cb
-
-}  // namespace elemwise
-}  // namespace megdnn
-
 // vim: syntax=cpp.doxygen
diff --git a/dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h b/dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h
index 72abc2439..987de10b0 100644
--- a/dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h
+++ b/dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h
@@ -87,7 +87,7 @@ template <>
 struct FuseAddHSwishOp<dt_qint32, dt_qint8> : FuseAddHSwishOpBase<dt_qint32, dt_qint8> {
     using FuseAddHSwishOpBase::FuseAddHSwishOpBase;
     using FuseAddHSwishOpBase::operator();
-    constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
+    constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t);
 
     void operator()(
             const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1, dt_qint8* dst) const {
diff --git a/dnn/src/fallback/elemwise_helper/kimpl/op_base.h b/dnn/src/fallback/elemwise_helper/kimpl/op_base.h
index 5affdd1a6..e5b899726 100644
--- a/dnn/src/fallback/elemwise_helper/kimpl/op_base.h
+++ b/dnn/src/fallback/elemwise_helper/kimpl/op_base.h
@@ -41,7 +41,7 @@ struct UnaryOpBase : OpBase {
         GiStoreLowInt8(                                                                   \
                 reinterpret_cast<int8_t*>(dst + 8),                                       \
                 operator()({{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}})); \
-        GI_INT16_t vsrct2 = GiMoveHighLongInt8(vsrc.val[1]);                              \
+        GI_INT16_t vsrct2 = GiMoveLowLongInt8(vsrc.val[1]);                               \
         GiStoreLowInt8(                                                                   \
                 reinterpret_cast<int8_t*>(dst + 16),                                      \
                 operator()({{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}})); \
@@ -330,7 +330,7 @@ struct UnaryQuantizationOp;
 template <typename Op>
 struct UnaryQuantizationOp<dt_qint8, dt_qint8, Op> : UnaryOpBase<dt_qint8, dt_qint8> {
     using UnaryOpBase::UnaryOpBase;
-    constexpr static size_t SIMD_WIDTH = 16;
+    constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
     Op op;
 
     void operator()(const dt_qint8& src, dt_qint8* dst) const {
@@ -354,7 +354,7 @@ struct UnaryQuantizationOp<dt_qint8, dt_qint8, Op> : UnaryOpBase<dt_qint8, dt_qint8> {
         auto val = this->op({{vitem0, vitem1}});
         val.val[0] = GiMultiplyFloat32(val.val[0], this->vscale_dst);
         val.val[1] = GiMultiplyFloat32(val.val[1], this->vscale_dst);
-        return QConverter::convert(val);
+        return QConverter::convert(val);
     }
 };
 
@@ -364,7 +364,7 @@ struct BinaryQuantizationOp;
 template <typename Op>
 struct BinaryQuantizationOp<dt_qint8, dt_qint8, Op> : BinaryOpBase<dt_qint8, dt_qint8> {
     using BinaryOpBase::BinaryOpBase;
-    constexpr static size_t SIMD_WIDTH = 16;
+    constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
     Op op;
 
     void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const {
@@ -403,7 +403,7 @@ template <typename Op>
 struct TernaryQuantizationOp<dt_qint8, dt_qint8, dt_qint8, Op> : TernaryOpBase<dt_qint8, dt_qint8> {
     using TernaryOpBase::TernaryOpBase;
-    constexpr static size_t SIMD_WIDTH = 16;
+    constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
     Op op;
 
     void operator()(
diff --git a/dnn/src/fallback/elemwise_helper/kimpl/relu.h b/dnn/src/fallback/elemwise_helper/kimpl/relu.h
index bd7412d64..ffddb422c 100644
--- a/dnn/src/fallback/elemwise_helper/kimpl/relu.h
+++ b/dnn/src/fallback/elemwise_helper/kimpl/relu.h
@@ -69,7 +69,7 @@ struct ReluOpBase : UnaryOpBase {
 template <>
 struct ReluOp<dt_qint8, dt_qint8> : ReluOpBase<dt_qint8, dt_qint8> {
     using ReluOpBase::ReluOpBase;
-    constexpr static size_t SIMD_WIDTH = 16;
+    constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
     using ReluOpBase::operator();
 
     void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const {
diff --git a/dnn/src/fallback/elemwise_helper/op_common.h b/dnn/src/fallback/elemwise_helper/op_common.h
index 11b9d3d46..a0348b47f 100644
--- a/dnn/src/fallback/elemwise_helper/op_common.h
+++ b/dnn/src/fallback/elemwise_helper/op_common.h
@@ -8,6 +8,7 @@
 
 namespace megdnn {
 namespace elemwise {
+
* \brief broadcast type * BCAST_x[0]x[1]...: x[i] == !stride[i] @@ -49,6 +50,55 @@ struct ParamElemVisitorDup; template struct ParamElemVisitorBcast101x4; +#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ + template <> \ + struct ParamElemVisitor<_ctype> { \ + _simd_type operator()(const _ctype* src) const { \ + return GiLoad##_fun_suffix(src); \ + } \ + }; \ + template <> \ + struct ParamElemVisitorDup<_ctype> { \ + _simd_type operator()(const _ctype* src) const { \ + return GiBroadcast##_fun_suffix( \ + *reinterpret_cast(src)); \ + } \ + } +cb(dt_qint32, int32_t, GI_INT32_t, Int32); +cb(dt_qint8, int8_t, GI_INT8_t, Int8); + +cb(dt_float32, float, GI_FLOAT32_t, Float32); +cb(dt_int32, int32_t, GI_INT32_t, Int32); +cb(dt_int8, int8_t, GI_INT8_t, Int8); +#undef cb + +template +struct ParamElemVisitorBcast101x4; +#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix) \ + template <> \ + struct ParamElemVisitorBcast101x4<_ctype> { \ + _simd_type operator()(const _ctype* src) const { \ + return GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix( \ + *reinterpret_cast(src))); \ + } \ + } + +cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32); +cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32); +#undef cb +#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ + template <> \ + struct ParamElemVisitorBcast101x4<_ctype> { \ + _simd_type operator()(const _ctype* src) const { \ + return GiLoad##_fun_suffix(src); \ + } \ + } + +cb(dt_qint32, int32_t, GI_INT32_t, Int32); +cb(dt_float32, float, GI_FLOAT32_t, Float32); +cb(dt_int32, int32_t, GI_INT32_t, Int32); +#undef cb + ///////////////////////////////// OpCaller ///////////////////////////// template struct OpCallerUnary; diff --git a/dnn/src/fallback/elemwise_multi_type/opr_impl.h b/dnn/src/fallback/elemwise_multi_type/opr_impl.h index 9d6035200..bcbabfa08 100644 --- a/dnn/src/fallback/elemwise_multi_type/opr_impl.h +++ b/dnn/src/fallback/elemwise_multi_type/opr_impl.h @@ -50,6 +50,18 @@ protected: void on_fuse_mul_add3_uint8xf32xf32xf32( const ElemwiseOpParamN<3>& param, const TensorND& dst) override; + void on_quantized_mode( + const ElemwiseOpParamN<1>& param, const TensorND& dst, + Elemwise::Mode mode) override; + + void on_quantized_mode( + const ElemwiseOpParamN<2>& param, const TensorND& dst, + Elemwise::Mode mode) override; + + void on_quantized_mode( + const ElemwiseOpParamN<3>& param, const TensorND& dst, + Elemwise::Mode mode) override; + public: using naive::ElemwiseMultiTypeImpl::ElemwiseMultiTypeImpl; }; diff --git a/dnn/src/fallback/elemwise_multi_type/quantized_impl.cpp b/dnn/src/fallback/elemwise_multi_type/quantized_impl.cpp new file mode 100644 index 000000000..77e1542e8 --- /dev/null +++ b/dnn/src/fallback/elemwise_multi_type/quantized_impl.cpp @@ -0,0 +1,499 @@ +/** + * \file dnn/src/fallback/elemwise_multi_type/quantized_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */
+
+#include "megdnn/tensor_iter.h"
+#include "src/fallback/elemwise_helper/elemwise_op.h"
+#include "src/fallback/elemwise_multi_type/opr_impl.h"
+#include "src/naive/handle.h"
+
+using namespace megdnn;
+using namespace fallback;
+using namespace elemwise;
+
+void ElemwiseMultiTypeImpl::on_quantized_mode(
+        const ElemwiseOpParamN<1>& param, const TensorND& dst, Elemwise::Mode mode) {
+    megdnn_assert(param[0].layout.dtype.category() == DTypeCategory::QUANTIZED);
+    megdnn_assert(dst.layout.dtype.category() == DTypeCategory::QUANTIZED);
+
+#define DISPATCH_MODE(_src_dt, _dst_dt)                                               \
+    switch (mode) {                                                                   \
+        DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::RELU, ReluOp)          \
+        DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::H_SWISH, HSwishOp)     \
+        default:                                                                      \
+            break;                                                                    \
+    }
+
+#define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt)                                     \
+    switch (mode) {                                                                   \
+        DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::RELU, ReluOp)          \
+        DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ABS, AbsOp)            \
+        DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::SIGMOID, SigmoidOp)    \
+        DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::EXP, ExpOp)            \
+        DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::TANH, TanhOp)          \
+        DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::FAST_TANH, FastTanhOp) \
+        DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::H_SWISH, HSwishOp)     \
+        default:                                                                      \
+            break;                                                                    \
+    }
+
+#define DISPATCH()                                                                    \
+    if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 &&                    \
+        dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {                         \
+        DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8)               \
+    } else if (                                                                       \
+            param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 &&               \
+            dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {                     \
+        DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8)                        \
+    }
+
+    TensorND src = param[0];
+
+    size_t nr_elems = src.layout.total_nr_elems();
+#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op)                             \
+    case _mode: {                                                                      \
+        using src_ctype = typename DTypeTrait<_src_dt>::ctype;                         \
+        using dst_ctype = typename DTypeTrait<_dst_dt>::ctype;                         \
+        thin_function<void(const src_ctype*, dst_ctype*, DType, DType, size_t)> run =  \
+                OpCallerUnary<_op<src_ctype, dst_ctype>, VEC>::run;                    \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                                                  \
+                run(src.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src.layout.dtype,      \
+                    dst.layout.dtype, nr_elems));                                      \
+        return;                                                                        \
+    }
+
+    DISPATCH()
+
+    naive::ElemwiseMultiTypeImpl::on_quantized_mode(param, dst, mode);
+
+#undef DISPATCH_SINGLE_MODE
+#undef DISPATCH
+#undef DISPATCH_QUANTIZED_MODE
+#undef DISPATCH_MODE
+}
+
+void ElemwiseMultiTypeImpl::on_quantized_mode(
+        const ElemwiseOpParamN<2>& param, const TensorND& dst, Elemwise::Mode mode) {
+    megdnn_assert(
+            param[0].layout.dtype.enumv() == param[1].layout.dtype.enumv() &&
+            param[0].layout.dtype.category() == DTypeCategory::QUANTIZED);
+    megdnn_assert(dst.layout.dtype.category() == DTypeCategory::QUANTIZED);
+
+#define DISPATCH_MODE(_src_dt, _dst_dt)                                               \
+    switch (mode) {                                                                   \
+        DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ADD, AddOp)            \
+        DISPATCH_SINGLE_MODE(                                                         \
+                _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_RELU, FuseAddReluOp)       \
+        DISPATCH_SINGLE_MODE(                                                         \
+                _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_H_SWISH, FuseAddHSwishOp)  \
+        default:                                                                      \
+            break;                                                                    \
+    }
+
+#define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt)                                     \
+    switch (mode) {                                                                   \
+        DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ADD, AddOp)            \
+        DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MIN, MinOp)            \
+        DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MAX, MaxOp)            \
+        DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::SUB, SubOp)            \
+        DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MUL, MulOp)            \
+        DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::TRUE_DIV, TrueDivOp)   \
+        DISPATCH_SINGLE_MODE(                                                         \
+                _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_RELU, FuseAddReluOp)       \
+        DISPATCH_SINGLE_MODE(                                                         \
+                _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_SIGMOID, FuseAddSigmoidOp) \
+        DISPATCH_SINGLE_MODE(                                                         \
+                _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_TANH, FuseAddTanhOp)       \
+        DISPATCH_SINGLE_MODE(                                                         \
+                _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_H_SWISH, FuseAddHSwishOp)  \
+        default:                                                                      \
+            break;                                                                    \
+    }
+
+#define DISPATCH()                                                                    \
+    if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 &&                   \
+        dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {                         \
+        DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8)                        \
+    } else if (                                                                       \
+            param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 &&                \
+            dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {                     \
+        DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8)               \
+    }
+
+    TensorND src0 = param[0];
+    TensorND src1 = param[1];
+
+    //! VEC + VEC
+    if (is_vector(src0.layout) && is_vector(src1.layout)) {
+        size_t nr_elems = src0.layout.total_nr_elems();
+#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op)                             \
+    case _mode: {                                                                      \
+        using src_ctype = typename DTypeTrait<_src_dt>::ctype;                         \
+        using dst_ctype = typename DTypeTrait<_dst_dt>::ctype;                         \
+        thin_function<void(                                                            \
+                const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType,   \
+                size_t)>                                                               \
+                run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_VEC>::run;         \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(run(                                              \
+                src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), dst.ptr<dst_ctype>(),    \
+                src0.layout.dtype, src1.layout.dtype, dst.layout.dtype, nr_elems));    \
+        return;                                                                        \
+    }
+
+        DISPATCH()
+
+#undef DISPATCH_SINGLE_MODE
+    }
+
+    //! VEC + SCALAR
+    {
+        bool normal_case = is_vector(src0.layout) && is_broadcasted_scalar(src1.layout);
+        bool swap_case = false;
+        bool commutable = false;
+        if (mode != Elemwise::Mode::SUB && mode != Elemwise::Mode::TRUE_DIV)
+            commutable = true;
+        if (!normal_case && commutable) {
+            swap_case = is_vector(src1.layout) && is_broadcasted_scalar(src0.layout);
+        }
+        if (normal_case || swap_case) {
+            auto &lhs = src0, &rhs = src1;
+            if (swap_case) {
+                std::swap(lhs, rhs);
+            }
+#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op)                             \
+    case _mode: {                                                                      \
+        using src_ctype = typename DTypeTrait<_src_dt>::ctype;                         \
+        using dst_ctype = typename DTypeTrait<_dst_dt>::ctype;                         \
+        thin_function<void(                                                            \
+                const src_ctype*, const src_ctype, dst_ctype*, DType, DType, DType,    \
+                size_t)>                                                               \
+                run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_SCALAR>::run;      \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                                                  \
+                run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>()[0],                   \
+                    dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype,        \
+                    dst.layout.dtype, src0.layout.total_nr_elems()));                  \
+        return;                                                                        \
+    }
+
+            DISPATCH()
+
+#undef DISPATCH_SINGLE_MODE
+        }
+
+        //! SCALAR + VEC
+        if (!commutable && is_vector(src1.layout) &&
+            is_broadcasted_scalar(src0.layout)) {
+#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op)                             \
+    case _mode: {                                                                      \
+        using src_ctype = typename DTypeTrait<_src_dt>::ctype;                         \
+        using dst_ctype = typename DTypeTrait<_dst_dt>::ctype;                         \
+        thin_function<void(                                                            \
+                const src_ctype, const src_ctype*, dst_ctype*, DType, DType, DType,    \
+                size_t)>                                                               \
+                run = OpCallerBinary<_op<src_ctype, dst_ctype>, SCALAR_VEC>::run;      \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                                                  \
+                run(src0.ptr<src_ctype>()[0], src1.ptr<src_ctype>(),                   \
+                    dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype,        \
+                    dst.layout.dtype, src1.layout.total_nr_elems()));                  \
+        return;                                                                        \
+    }
+
+            DISPATCH()
+
+#undef DISPATCH_SINGLE_MODE
+        }
+    }
+
+    //! VEC + BCAST101
+    {
+        BroadcastChannelInfo binfo;
+        bool normal_case = is_vector(src0.layout) &&
+                           is_broadcasted_channel_like(src1.layout, binfo);
+        bool swap_case = false;
+        bool commutable = false;
+        if (mode != Elemwise::Mode::SUB && mode != Elemwise::Mode::TRUE_DIV)
+            commutable = true;
+        if (!normal_case && commutable) {
+            swap_case = is_vector(src1.layout) &&
+                        is_broadcasted_channel_like(src0.layout, binfo);
+        }
+        if (normal_case || swap_case) {
+            auto &lhs = src0, &rhs = src1;
+            if (swap_case)
+                std::swap(lhs, rhs);
+#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op)                             \
+    case _mode: {                                                                      \
+        using src_ctype = typename DTypeTrait<_src_dt>::ctype;                         \
+        using dst_ctype = typename DTypeTrait<_dst_dt>::ctype;                         \
+        thin_function<void(                                                            \
+                const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType,   \
+                size_t, size_t, size_t)>                                               \
+                run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101>::run;    \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                                                  \
+                run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(),                      \
+                    dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype,        \
+                    dst.layout.dtype, binfo.x, binfo.y, binfo.z));                     \
+        return;                                                                        \
+    }
+
+            DISPATCH()
+
+#undef DISPATCH_SINGLE_MODE
+        }
+
+        //! BCAST101 + VEC : only for SUB or TRUE_DIV
+        if (!commutable && is_vector(src1.layout) &&
+            is_broadcasted_channel_like(src0.layout, binfo)) {
+#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op)                             \
+    case _mode: {                                                                      \
+        using src_ctype = typename DTypeTrait<_src_dt>::ctype;                         \
+        using dst_ctype = typename DTypeTrait<_dst_dt>::ctype;                         \
+        thin_function<void(                                                            \
+                const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType,   \
+                size_t, size_t, size_t)>                                               \
+                run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101_VEC>::run;    \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                                                  \
+                run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(),                      \
+                    dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype,        \
+                    dst.layout.dtype, binfo.x, binfo.y, binfo.z));                     \
+        return;                                                                        \
+    }
+
+            DISPATCH()
+
+#undef DISPATCH_SINGLE_MODE
+        }
+    }
+
+    //! VEC + BCAST101x4
+    {
+        BroadcastChannelInfo binfo;
+        if (is_vector(src0.layout) &&
+            (is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
+             is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
+#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op)                             \
+    case _mode: {                                                                      \
+        using src_ctype = typename DTypeTrait<_src_dt>::ctype;                         \
+        using dst_ctype = typename DTypeTrait<_dst_dt>::ctype;                         \
+        thin_function<void(                                                            \
+                const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType,   \
+                size_t, size_t, size_t, size_t)>                                       \
+                run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101xX>::run;  \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                                                  \
+                run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(),                      \
+                    dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype,        \
+                    dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z));         \
+        return;                                                                        \
+    }
+            size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
+            DISPATCH()
+
+#undef DISPATCH_SINGLE_MODE
+        }
+
+        //! BCAST101x + VEC
+        if (is_vector(src1.layout) &&
+            is_broadcastedx_channel_like<4>(src0.layout, binfo)) {
+#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op)                             \
+    case _mode: {                                                                      \
+        using src_ctype = typename DTypeTrait<_src_dt>::ctype;                         \
+        using dst_ctype = typename DTypeTrait<_dst_dt>::ctype;                         \
+        thin_function<void(                                                            \
+                const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType,   \
+                size_t, size_t, size_t, size_t)>                                       \
+                run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101xX_VEC>::run;  \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                                                  \
+                run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(),                      \
+                    dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype,        \
+                    dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z));         \
+        return;                                                                        \
+    }
+            size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
+            DISPATCH()
+
+#undef DISPATCH_SINGLE_MODE
+        }
+    }
+
+    naive::ElemwiseMultiTypeImpl::on_quantized_mode(param, dst, mode);
+
+#undef DISPATCH_MODE
+#undef DISPATCH_QUANTIZED_MODE
+#undef DISPATCH
+}
+
+void ElemwiseMultiTypeImpl::on_quantized_mode(
+        const ElemwiseOpParamN<3>& param, const TensorND& dst, Elemwise::Mode mode) {
+    megdnn_assert(
+            param[0].layout.dtype.enumv() == param[1].layout.dtype.enumv() &&
+            param[0].layout.dtype.enumv() == param[2].layout.dtype.enumv() &&
+            param[0].layout.dtype.category() == DTypeCategory::QUANTIZED);
+    megdnn_assert(dst.layout.dtype.category() == DTypeCategory::QUANTIZED);
+
+#define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt)                                      \
+    switch (mode) {                                                                    \
+        DISPATCH_SINGLE_MODE(                                                          \
+                _src_dt, _dst_dt, Elemwise::Mode::FUSE_MUL_ADD3, FuseMulAdd3Op)        \
+        default:                                                                       \
+            break;                                                                     \
+    }
+
+#define DISPATCH()                                                                     \
+    if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 &&                     \
+        dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {                          \
+        DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8)                \
+    }
+
+    TensorND src0 = param[0];
+    TensorND src1 = param[1];
+    TensorND src2 = param[2];
+
+    //! VEC + VEC + VEC
+    if (is_vector(src0.layout) && is_vector(src1.layout) && is_vector(src2.layout)) {
+        size_t nr_elems = src0.layout.total_nr_elems();
+#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op)                             \
+    case _mode: {                                                                      \
+        using src_ctype = typename DTypeTrait<_src_dt>::ctype;                         \
+        using dst_ctype = typename DTypeTrait<_dst_dt>::ctype;                         \
+        thin_function<void(                                                            \
+                const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*,      \
+                DType, DType, DType, DType, size_t)>                                   \
+                run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_VEC>::run;    \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(run(                                              \
+                src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), src2.ptr<src_ctype>(),   \
+                dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype,            \
+                src2.layout.dtype, dst.layout.dtype, nr_elems));                       \
+        return;                                                                        \
+    }
+
+        DISPATCH()
+
+#undef DISPATCH_SINGLE_MODE
+    }
+
+    //! VEC + VEC + SCALAR
+    if (is_vector(src0.layout) && is_vector(src1.layout) &&
+        is_broadcasted_scalar(src2.layout)) {
+#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op)                             \
+    case _mode: {                                                                      \
+        using src_ctype = typename DTypeTrait<_src_dt>::ctype;                         \
+        using dst_ctype = typename DTypeTrait<_dst_dt>::ctype;                         \
+        thin_function<void(                                                            \
+                const src_ctype*, const src_ctype*, const src_ctype, dst_ctype*,       \
+                DType, DType, DType, DType, size_t)>                                   \
+                run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_SCALAR>::run; \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                                                  \
+                run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(),                      \
+                    src2.ptr<src_ctype>()[0], dst.ptr<dst_ctype>(), src0.layout.dtype, \
+                    src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,            \
+                    src0.layout.total_nr_elems()));                                    \
+        return;                                                                        \
+    }
+
+        DISPATCH()
+
+#undef DISPATCH_SINGLE_MODE
+    }
+
+    //! BCAST101 + VEC + BCAST101
+    {
+        BroadcastChannelInfo binfo;
+        bool normal_case = is_vector(src1.layout) &&
+                           is_broadcasted_channel_like(src0.layout, binfo) &&
+                           src0.layout.eq_shape(src2.layout);
+        if (normal_case) {
+#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op)                             \
+    case _mode: {                                                                      \
+        using src_ctype = typename DTypeTrait<_src_dt>::ctype;                         \
+        using dst_ctype = typename DTypeTrait<_dst_dt>::ctype;                         \
+        thin_function<void(                                                            \
+                const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*,      \
+                DType, DType, DType, DType, size_t, size_t, size_t, size_t)>           \
+                run = OpCallerTernary<                                                 \
+                        _op<src_ctype, dst_ctype>, BCAST101_VEC_BCAST101>::run;        \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                                                  \
+                run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(),                      \
+                    src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype,    \
+                    src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, binfo.x,   \
+                    binfo.y, binfo.z, binfo.y* binfo.z));                              \
+        return;                                                                        \
+    }
+
+            DISPATCH()
+
+#undef DISPATCH_SINGLE_MODE
+        }
+    }
+
+    //! VEC + BCAST101x4 + VEC
+    {
+        BroadcastChannelInfo binfo;
+        if (is_vector(src0.layout) &&
+            (is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
+             is_broadcastedx_channel_like<8>(src1.layout, binfo)) &&
+            src0.layout.eq_shape(src2.layout)) {
+#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op)                             \
+    case _mode: {                                                                      \
+        using src_ctype = typename DTypeTrait<_src_dt>::ctype;                         \
+        using dst_ctype = typename DTypeTrait<_dst_dt>::ctype;                         \
+        thin_function<void(                                                            \
+                const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*,      \
+                DType, DType, DType, DType, size_t, size_t, size_t, size_t)>           \
+                run = OpCallerTernary<                                                 \
+                        _op<src_ctype, dst_ctype>, VEC_BCAST101xX_VEC>::run;           \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                                                  \
+                run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(),                      \
+                    src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype,    \
+                    src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,            \
+                    batch_size, binfo.x, binfo.y, binfo.z));                           \
+        return;                                                                        \
+    }
+
+            size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
+            DISPATCH()
+
+#undef DISPATCH_SINGLE_MODE
+        }
+
+        //! BCAST101x + VEC +BCAST101x
+        if (is_vector(src1.layout) &&
+            (is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
+             is_broadcastedx_channel_like<8>(src0.layout, binfo)) &&
+            src0.layout.eq_shape(src2.layout)) {
+#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op)                             \
+    case _mode: {                                                                      \
+        using src_ctype = typename DTypeTrait<_src_dt>::ctype;                         \
+        using dst_ctype = typename DTypeTrait<_dst_dt>::ctype;                         \
+        thin_function<void(                                                            \
+                const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*,      \
+                DType, DType, DType, DType, size_t, size_t, size_t, size_t)>           \
+                run = OpCallerTernary<                                                 \
+                        _op<src_ctype, dst_ctype>, BCAST101xX_VEC_BCAST101xX>::run;    \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                                                  \
+                run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(),                      \
+                    src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype,    \
+                    src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,            \
+                    batch_size, binfo.x, binfo.y, binfo.z));                           \
+        return;                                                                        \
+    }
+
+            size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
+            DISPATCH()
+
+#undef DISPATCH_SINGLE_MODE
+        }
+    }
+
+    naive::ElemwiseMultiTypeImpl::on_quantized_mode(param, dst, mode);
+#undef DISPATCH
+#undef DISPATCH_QUANTIZED_MODE
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/fallback/general_intrinsic/gi_common.h b/dnn/src/fallback/general_intrinsic/gi_common.h
index 6f6418e34..4beed1325 100644
--- a/dnn/src/fallback/general_intrinsic/gi_common.h
+++ b/dnn/src/fallback/general_intrinsic/gi_common.h
@@ -60,6 +60,7 @@
 #define GI_NEON_INTRINSICS
 #if defined(__aarch64__)
 #define GI_NEON64_INTRINSICS
+#define GI_NEON32_INTRINSICS
 #else
 #define GI_NEON32_INTRINSICS
 #endif
diff --git a/dnn/test/arm_common/elemwise_multi_type.cpp b/dnn/test/arm_common/elemwise_multi_type.cpp
index 7e8e5b3a9..ae792b2a7 100644
--- a/dnn/test/arm_common/elemwise_multi_type.cpp
+++ b/dnn/test/arm_common/elemwise_multi_type.cpp
@@ -11,8 +11,10 @@
  */
 
 #include "test/common/elemwise_multi_type.h"
+#include "megdnn/opr_param_defs.h"
 #include "megdnn/oprs.h"
 #include "test/arm_common/fixture.h"
"test/common/benchmarker.h" #include "test/common/checker.h" #include "test/common/task_record_check.h" #include "test/common/timer.h" @@ -559,4 +561,95 @@ TEST_F(ARM_COMMON, ELEMWISE_FMA3_UINT8xF32xF32xF32_RECORD) { .execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}}); } +#if MEGDNN_WITH_BENCHMARK +namespace { +void run_elemwise_benchmark( + const TensorShapeArray& shapes, ElemwiseMultiType::Param::Mode mode, + const char* mode_str, std::vector types, Handle* handle_bench) { + auto handle_fallback = create_cpu_handle(1); + Benchmarker benchmarker_bench(handle_bench); + Benchmarker benchmarker_fallback(handle_fallback.get()); + + float throughput = 0; + SmallVector layouts; + std::string src_strs; + for (size_t i = 0; i < shapes.size(); i++) { + layouts.emplace_back(shapes[i], types[i]); + throughput += layouts.back().span().dist_byte(); + src_strs += layouts.back().to_string(); + if (i != shapes.size() - 1) { + src_strs += ","; + } + } + constexpr size_t RUN = 50; + benchmarker_fallback.set_times(RUN).set_display(false); + benchmarker_bench.set_times(RUN).set_display(false); + + benchmarker_fallback.set_param(mode); + benchmarker_bench.set_param(mode); + + TensorLayout dst_layout; + dst_layout.dtype = types.back(); + auto opr = handle_bench->create_operator(); + opr->param() = mode; + opr->deduce_layout(layouts, dst_layout); + + float computations = + dst_layout.total_nr_elems() * (std::max(shapes.size(), 2) - 1); + throughput += dst_layout.span().dist_byte(); + computations *= (1e3 / (1024.0 * 1024)); + throughput *= (1e3 / (1024.0 * 1024)); + + layouts.emplace_back(dst_layout); + auto fallback_time = benchmarker_fallback.execl(layouts) / RUN; + auto bench_time = benchmarker_bench.execl(layouts) / RUN; + + float fallback_flops = computations / fallback_time; + float bench_flops = computations / bench_time; + float fallback_thr = throughput / fallback_time; + float bench_thr = throughput / bench_time; + + printf("%s = %s (mode: %s) cpu=%fMFLOPS %fMB/s, bench=%fMFLOPS " + "%fMB/s " + "computations: %fx, throughput: %fx\n", + src_strs.c_str(), dst_layout.to_string().c_str(), mode_str, fallback_flops, + fallback_thr, bench_flops, bench_thr, bench_flops / fallback_flops, + bench_thr / fallback_thr); +} +} // namespace + +#define RUN_WITH_MODE(shape, mode, types) \ + run_elemwise_benchmark(shape, mode, #mode, types, handle()); + +TEST_F(ARM_COMMON, BENCHMARK_UNARY_MULTI_TYPE) { + using Mode = ElemwiseMultiType::Param::Mode; + for (auto mode : + {Mode::QRELU, Mode::QABS, Mode::QSIGMOID, Mode::QEXP, Mode::QTANH, + Mode::QFAST_TANH, Mode::QH_SWISH}) { + std::vector types = {dtype::QuantizedS8(1.4f), dtype::QuantizedS8(3.4f)}; + TensorShapeArray shapes = {{10000}}; + RUN_WITH_MODE(shapes, mode, types); + std::vector types2 = { + dtype::QuantizedS32(1.4f), dtype::QuantizedS8(3.4f)}; + RUN_WITH_MODE(shapes, mode, types2); + } +} + +TEST_F(ARM_COMMON, BENCHMARK_BINARY_MULTI_TYPE) { + using Mode = ElemwiseMultiType::Param::Mode; + for (auto mode : {Mode::QADD, Mode::QFUSE_ADD_RELU, Mode::QFUSE_ADD_H_SWISH}) { + std::vector types = { + dtype::QuantizedS8(1.4f), dtype::QuantizedS8(3.4f), + dtype::QuantizedS8(1.6f)}; + TensorShapeArray shapes = {{10000}, {10000}}; + RUN_WITH_MODE(shapes, mode, types); + std::vector types2 = { + dtype::QuantizedS32(1.4f), dtype::QuantizedS32(3.4f), + dtype::QuantizedS8(1.6f)}; + RUN_WITH_MODE(shapes, mode, types2); + } +} + +#endif + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/test/fallback/elemwise_multi_type.cpp 
diff --git a/dnn/test/fallback/elemwise_multi_type.cpp b/dnn/test/fallback/elemwise_multi_type.cpp
index 7e0566f6d..9069ac8d8 100644
--- a/dnn/test/fallback/elemwise_multi_type.cpp
+++ b/dnn/test/fallback/elemwise_multi_type.cpp
@@ -26,6 +26,175 @@ TYPED_TEST(FALLBACK_ELEMWISE_MULTI_TYPE, run) {
     elemwise_multi_type::run_test<TypeParam>(this->handle());
 }
 
+TEST_F(FALLBACK, ELEMWISE_QUANTIZED_MODE_UNARY) {
+    using Mode = ElemwiseMultiType::Param::Mode;
+    Checker<ElemwiseMultiType> checker(handle());
+
+    std::unique_ptr<RNG> rng;
+    for (auto mode :
+         {Mode::QRELU, Mode::QABS, Mode::QSIGMOID, Mode::QEXP, Mode::QTANH,
+          Mode::QFAST_TANH, Mode::QH_SWISH}) {
+        checker.set_param({mode});
+
+        for (DType src_type :
+             std::vector<DType>{dtype::QuantizedS8(1.4f), dtype::QuantizedS32(1.3f)}) {
+            checker.set_dtype(0, src_type);
+            if (src_type.enumv() == DTypeEnum::QuantizedS8) {
+                rng = std::make_unique<UniformIntRNG>(-127, 127);
+                checker.set_dtype(1, dtype::QuantizedS8(1.7f));
+            } else {
+                rng = std::make_unique<UniformIntRNG>(INT16_MIN >> 1, INT16_MAX >> 1);
+            }
+
+            checker.set_rng(0, rng.get());
+            auto run = [&]() {
+                checker.execs({{3, 4, 5, 6}, {}});
+
+                checker.execs({{3}, {}});
+                checker.execs({{9}, {}});
+                checker.execs({{17}, {}});
+            };
+
+            if (src_type.enumv() == DTypeEnum::QuantizedS32) {
+                for (DType dst_type :
+                     std::vector<DType>{dtype::QuantizedS8(32718.6f)}) {
+                    checker.set_dtype(1, dst_type);
+                    run();
+                }
+            } else {
+                run();
+            }
+        }
+    }
+}
+
+TEST_F(FALLBACK, ELEMWISE_QUANTIZED_MODE_BINARY) {
+    using Mode = ElemwiseMultiType::Param::Mode;
+    Checker<ElemwiseMultiType> checker(handle());
+    auto run = [&]() {
+        //! nchw44
+        checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
+        checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
+        checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
+        checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
+        checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
+        checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
+        checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
+        checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
+        checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
+        //! VEC + SCALAR
+        checker.execs({{3, 4, 5, 6}, {1, 1, 1, 1}, {}});
+        checker.execs({{1, 1, 1, 1}, {3, 4, 5, 6}, {}});
+        checker.execs({{3, 4, 5, 6}, {1}, {}});
+        checker.execs({{1}, {3, 4, 5, 6}, {}});
+
+        //! VEC + 1C11
+        checker.execs({{3, 4, 5, 6}, {1, 4, 1, 1}, {}});
+        checker.execs({{1, 4, 1, 1}, {3, 4, 5, 6}, {}});
+
+        //! VEC + VEC
+        checker.execs({{3}, {3}, {}});
+        checker.execs({{9}, {9}, {}});
+        checker.execs({{17}, {17}, {}});
+    };
+
+    // qint32 to qint8/quint8
+    for (auto mode : {Mode::QADD, Mode::QFUSE_ADD_RELU, Mode::QFUSE_ADD_H_SWISH}) {
+        checker.set_param({mode});
+        UniformIntRNG rng{INT16_MIN >> 1, INT16_MAX >> 1};
+        checker.set_rng(0, &rng)
+                .set_rng(1, &rng)
+                .set_dtype(0, dtype::QuantizedS32(1.3f))
+                .set_dtype(1, dtype::QuantizedS32(1.2f));
+
+        for (DType dst_type : std::vector<DType>{dtype::QuantizedS8(32718.6f)}) {
+            checker.set_dtype(2, dst_type);
+            run();
+        }
+    }
+
+    for (auto mode :
+         {Mode::QMUL, Mode::QADD, Mode::QMIN, Mode::QMAX, Mode::QSUB,
+          Mode::QFUSE_ADD_RELU, Mode::QFUSE_ADD_SIGMOID, Mode::QFUSE_ADD_H_SWISH}) {
+        checker.set_param({mode});
+
+        // qint8 to qint8
+        UniformIntRNG rng_int8{-127, 127};
+        checker.set_rng(0, &rng_int8)
+                .set_rng(1, &rng_int8)
+                .set_dtype(0, dtype::QuantizedS8(1.35f))
+                .set_dtype(1, dtype::QuantizedS8(1.15f))
+                .set_dtype(2, dtype::QuantizedS8(1.75f));
+        run();
+    }
+
+    //! TRUE_DIV : 0.0 / 0.0 will fail
+    checker.set_param({Mode::QTRUE_DIV});
+    UniformIntRNG rng_int8_1{-127, 127};
+    UniformIntRNG rng_int8_2{-127, -1};
+    checker.set_rng(0, &rng_int8_1)
+            .set_rng(1, &rng_int8_2)
+            .set_dtype(0, dtype::QuantizedS8(1.4f))
+            .set_dtype(1, dtype::QuantizedS8(1.1f))
+            .set_dtype(2, dtype::QuantizedS8(1.7f));
+
+    run();
+    //! TANH
+    checker.set_param({Mode::QFUSE_ADD_TANH});
+    UniformIntRNG rng_int8{-5, 5};
+    checker.set_rng(0, &rng_int8)
+            .set_rng(1, &rng_int8)
+            .set_dtype(0, dtype::QuantizedS8(1.1f))
+            .set_dtype(1, dtype::QuantizedS8(1.4f))
+            .set_dtype(2, dtype::QuantizedS8(1.7f));
+
+    run();
+}
+
+TEST_F(FALLBACK, ELEMWISE_QUANTIZED_MODE_TERNARY) {
+    using Mode = ElemwiseMultiType::Param::Mode;
+    Checker<ElemwiseMultiType> checker(handle());
+
+    auto run = [&]() {
+        //! nchw44
+        checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
+        checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
+        checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
+        checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
+        checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
+
+        //! nchw44
+        checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
+        checker.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
+        checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
+        checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
+        checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
+
+        checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {1, 1, 1, 1}, {}});
+        checker.execs({{1, 4, 1, 1}, {3, 4, 5, 6}, {1, 4, 1, 1}, {}});
+
+        checker.execs({{3}, {3}, {3}, {}});
+        checker.execs({{9}, {9}, {9}, {}});
+        checker.execs({{17}, {17}, {17}, {}});
+        checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}, {}});
+    };
+
+    for (auto mode : {Mode::QFUSE_MUL_ADD3}) {
+        checker.set_param({mode});
+
+        // qint8 to qint8
+        UniformIntRNG rng_int8{-127, 127};
+        checker.set_rng(0, &rng_int8)
+                .set_rng(1, &rng_int8)
+                .set_rng(2, &rng_int8)
+                .set_dtype(0, dtype::QuantizedS8(1.45f))
+                .set_dtype(1, dtype::QuantizedS8(1.15f))
+                .set_dtype(2, dtype::QuantizedS8(1.75f))
+                .set_dtype(3, dtype::QuantizedS8(1.35f));
+        run();
+    }
+}
+
 TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_RECORD_FMA3_INT16x32x32x32) {
     TaskRecordChecker<ElemwiseMultiType> checker{1};
     checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32});
-- 
GitLab
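
The ParamElemVisitor / OpCaller layering that this patch moves into op_common.h composes in three steps: a visitor loads or broadcasts one SIMD register worth of elements, the kimpl op transforms whole registers (with a scalar overload for leftovers), and the OpCaller walks the tensor in SIMD_WIDTH strides. Below is a minimal self-contained sketch of that composition; it uses a plain 4-lane struct in place of the GI vector types, and SimdF32x4 and op_caller_unary_vec are illustrative stand-ins rather than MegDNN APIs.

#include <cstddef>
#include <cstdio>

// Stand-in for GI_FLOAT32_t: a fixed-width register of 4 floats.
struct SimdF32x4 {
    float val[4];
};

// Stand-in for the ParamElemVisitor specializations generated by the cb()
// macros above: loads SIMD_WIDTH contiguous elements into one register.
template <typename ctype>
struct ParamElemVisitor;

template <>
struct ParamElemVisitor<float> {
    SimdF32x4 operator()(const float* src) const {
        return {{src[0], src[1], src[2], src[3]}};
    }
};

// Stand-in for a kimpl op such as ReluOp: one overload per granularity.
struct ReluOp {
    static constexpr std::size_t SIMD_WIDTH = 4;
    SimdF32x4 operator()(const SimdF32x4& v) const {
        SimdF32x4 r;
        for (std::size_t i = 0; i < SIMD_WIDTH; ++i)
            r.val[i] = v.val[i] > 0.f ? v.val[i] : 0.f;
        return r;
    }
    float operator()(float s) const { return s > 0.f ? s : 0.f; }  // scalar tail
};

// Stand-in for OpCallerUnary<Op, VEC>::run: vector main loop, scalar remainder.
template <typename Op>
void op_caller_unary_vec(const float* src, float* dst, std::size_t nr_elems) {
    Op op;
    ParamElemVisitor<float> vis;
    std::size_t i = 0;
    for (; i + Op::SIMD_WIDTH <= nr_elems; i += Op::SIMD_WIDTH) {
        SimdF32x4 v = op(vis(src + i));
        for (std::size_t j = 0; j < Op::SIMD_WIDTH; ++j)
            dst[i + j] = v.val[j];
    }
    for (; i < nr_elems; ++i)
        dst[i] = op(src[i]);  // elements past the last full register
}

int main() {
    float src[6] = {-2.f, -1.f, 0.f, 1.f, 2.f, -3.f};
    float dst[6];
    op_caller_unary_vec<ReluOp>(src, dst, 6);
    for (float v : dst)
        std::printf("%g ", v);  // prints: 0 0 0 1 2 0
    std::printf("\n");
    return 0;
}

The real OpCaller additionally threads DType scale information through so the quantized ops can rescale before storing, but the loop shape, the per-register op, and the scalar tail are the same; the SIMD_WIDTH constants changed by this patch exist precisely so the caller strides by one full GI register regardless of the backing ISA.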