From 39d98d4525b23214466381804d3a46f1f8c35e7d Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 9 Mar 2022 18:31:34 +0800
Subject: [PATCH] feat(fallback): add fallback typecvt with general intrinsic

GitOrigin-RevId: 1e6fcd929b02e4745b4a641e5101df8f624d4bea
---
 dnn/src/fallback/elemwise_helper/kimpl/relu.h |  81 +++--
 .../fallback/general_intrinsic/gi_common.h    |  53 +++
 dnn/src/fallback/general_intrinsic/gi_float.h |  29 +-
 dnn/src/fallback/general_intrinsic/gi_int.h   |  78 ++---
 dnn/src/fallback/quantized_converter.h        |  16 +-
 dnn/src/fallback/type_cvt/opr_impl.cpp        |  60 +++-
 dnn/src/fallback/type_cvt/opr_impl.h          |   2 +
 dnn/src/fallback/type_cvt/typecvt_helper.h    | 302 ++++++++++++++++++
 8 files changed, 485 insertions(+), 136 deletions(-)
 create mode 100644 dnn/src/fallback/type_cvt/typecvt_helper.h

diff --git a/dnn/src/fallback/elemwise_helper/kimpl/relu.h b/dnn/src/fallback/elemwise_helper/kimpl/relu.h
index 6614fe645..bd7412d64 100644
--- a/dnn/src/fallback/elemwise_helper/kimpl/relu.h
+++ b/dnn/src/fallback/elemwise_helper/kimpl/relu.h
@@ -20,36 +20,37 @@ struct ReluOpBase : UnaryOpBase<src_ctype, dst_ctype> {
 template <typename src_ctype, typename dst_ctype = src_ctype>
 struct ReluOp;
 
-#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width)         \
-    template <>                                                                \
-    struct ReluOp<_ctype> : ReluOpBase<_ctype> {                               \
-        using ReluOpBase::ReluOpBase;                                          \
-        using ReluOpBase::operator();                                          \
-        constexpr static size_t SIMD_WIDTH = _simd_width;                      \
-        void operator()(const _simd_type2& src, _ctype* dst) const {           \
-            auto vitem = operator()(src);                                      \
-            GiStore##_func_suffix(dst, vitem.val[0]);                          \
-            GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]);             \
-        }                                                                      \
-        _simd_type2 operator()(const _simd_type2& src) const {                 \
-            auto vzero = GiBroadcast##_func_suffix(0);                         \
-            auto vitem0 = GiMaximum##_func_suffix(src.val[0], vzero);          \
-            auto vitem1 = GiMaximum##_func_suffix(src.val[1], vzero);          \
-            return {{vitem0, vitem1}};                                         \
-        }                                                                      \
-        void operator()(const _simd_type& src, _ctype* dst) const {            \
-            auto vitem = operator()(src);                                      \
-            GiStore##_func_suffix(dst, vitem);                                 \
-        }                                                                      \
-        _simd_type operator()(const _simd_type& src) const {                   \
-            auto vzero = GiBroadcast##_func_suffix(0);                         \
-            return GiMaximum##_func_suffix(src, vzero);                        \
-        }                                                                      \
+#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width, zero)   \
+    template <>                                                                \
+    struct ReluOp<_ctype> : ReluOpBase<_ctype> {                               \
+        using ReluOpBase::ReluOpBase;                                          \
+        using ReluOpBase::operator();                                          \
+        constexpr static size_t SIMD_WIDTH = _simd_width;                      \
+        void operator()(const _simd_type2& src, _ctype* dst) const {           \
+            auto vitem = operator()(src);                                      \
+            GiStore##_func_suffix(dst, vitem.val[0]);                          \
+            GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]);             \
+        }                                                                      \
+        _simd_type2 operator()(const _simd_type2& src) const {                 \
+            auto vitem0 = GiMaximum##_func_suffix(src.val[0], zero);           \
+            auto vitem1 = GiMaximum##_func_suffix(src.val[1], zero);           \
+            return {{vitem0, vitem1}};                                         \
+        }                                                                      \
+        void operator()(const _simd_type& src, _ctype* dst) const {            \
+            auto vitem = operator()(src);                                      \
+            GiStore##_func_suffix(dst, vitem);                                 \
+        }                                                                      \
+        _simd_type operator()(const _simd_type& src) const {                   \
+            return GiMaximum##_func_suffix(src, zero);                         \
+        }                                                                      \
     };
 
-OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
-OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t))
-OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t))
+OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float),
+   vfzero)
+OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t),
+   vzero)
+OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t),
+   vzero_int8)
 #undef OP
 
 template <>
@@ -75,11 +76,10 @@ struct ReluOp<dt_qint8, dt_qint8> : ReluOpBase<dt_qint8> {
         OPERATOR_UNARY_QINT8_FALLBACK;
     }
     GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
-        auto vzero = GiBroadcastFloat32(0.f);
         auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale);
         auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale);
-        vitem0 = GiMaximumFloat32(vitem0, vzero);
-        vitem1 = GiMaximumFloat32(vitem1, vzero);
+        vitem0 = GiMaximumFloat32(vitem0, vfzero);
+        vitem1 = GiMaximumFloat32(vitem1, vfzero);
         return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
     }
 };
@@ -114,12 +114,11 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase
     void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const {
         vst1_s8(reinterpret_cast<int8_t*>(dst), vget_low_s8(operator()(vsrc)));
     }
-
     int8x16_t operator()(const int32x4x2_t& vsrc) const {
         int32x4_t vitem0 = vqrdmulhq_s32(vsrc.val[0], vmultiplier);
         int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier);
-        vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero());
-        vitem1 = vmaxq_s32(vitem1, QConverterBase::vzero());
+        vitem0 = vmaxq_s32(vitem0, vzero);
+        vitem1 = vmaxq_s32(vitem1, vzero);
         auto tmp = vqmovn_s16(vcombine_s16(
                 vqmovn_s32(vrshlq_s32(vitem0, vshift)),
                 vqmovn_s32(vrshlq_s32(vitem1, vshift))));
@@ -127,7 +126,7 @@
     }
     int8x16_t operator()(const float32x4_t& vsrc) const {
         int32x4_t vitem0 = vqrdmulhq_s32(vcvtq_s32_f32(vsrc), vmultiplier);
-        vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero());
+        vitem0 = vmaxq_s32(vitem0, vzero);
         vitem0 = vrshlq_s32(vitem0, vshift);
         int16x4_t vitem = vqmovn_s32(vitem0);
         auto tmp = vqmovn_s16(vcombine_s16(vitem, vitem));
@@ -135,13 +134,13 @@
     }
     void operator()(const int32x4_t& src, dt_qint8* dst) const {
         auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale);
-        vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero());
+        vitem0 = vmaxq_f32(vitem0, vfzero);
         auto result = QConverter::convert(vitem0);
         vst1q_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x4_t)result, 0);
     }
     void operator()(const float32x4_t& src, dt_qint8* dst) const {
         auto vitem0 = vmulq_f32(src, this->vscale);
-        vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero());
+        vitem0 = vmaxq_f32(vitem0, vfzero);
         auto result = QConverter::convert(vitem0);
         vst1q_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x4_t)result, 0);
     }
@@ -165,19 +164,19 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8> {
     GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
         auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale);
         auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale);
-        vitem0 = GiMaximumFloat32(vitem0, QConverterBase::vfzero());
-        vitem1 = GiMaximumFloat32(vitem1, QConverterBase::vfzero());
+        vitem0 = GiMaximumFloat32(vitem0, vfzero);
+        vitem1 = GiMaximumFloat32(vitem1, vfzero);
         return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
     }
     GI_INT8_t operator()(const GI_INT32_t& src) const {
         auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(src), this->vscale);
-        vitem0 = GiMaximumFloat32(vitem0, QConverterBase::vfzero());
+        vitem0 = GiMaximumFloat32(vitem0, vfzero);
         return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(vitem0);
     }
     GI_INT8_t operator()(const GI_FLOAT32_t& src) const {
         auto vitem0 = GiMultiplyFloat32(src, this->vscale);
-        vitem0 = GiMaximumFloat32(vitem0, QConverterBase::vfzero());
+        vitem0 = GiMaximumFloat32(vitem0, vfzero);
         return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(vitem0);
     }
 };
diff --git a/dnn/src/fallback/general_intrinsic/gi_common.h b/dnn/src/fallback/general_intrinsic/gi_common.h
index 8be13e052..6f6418e34 100644
--- a/dnn/src/fallback/general_intrinsic/gi_common.h
+++ b/dnn/src/fallback/general_intrinsic/gi_common.h
@@ -213,4 +213,57 @@ GI_INT32_t GiXorInt32(GI_INT32_t Vector1, GI_INT32_t Vector2) {
 #endif
 }
 
+GI_FORCEINLINE
+GI_FLOAT32_t GiBroadcastFloat32(float Value) {
+#if defined(GI_NEON_INTRINSICS)
+    return vdupq_n_f32(Value);
+#elif defined(GI_SSE2_INTRINSICS)
+    return _mm_set1_ps(Value);
+#else
+    GI_FLOAT32_t ret;
+    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) {
+        ret[i] = Value;
+    }
+    return ret;
+#endif
+}
+
+GI_FORCEINLINE
+GI_INT32_t GiBroadcastInt32(int32_t Value) {
+#if defined(GI_NEON_INTRINSICS)
+    return vdupq_n_s32(Value);
+#elif defined(GI_SSE2_INTRINSICS)
+    return _mm_set1_epi32(Value);
+#else
+    GI_INT32_t ret;
+    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) {
+        ret[i] = Value;
+    }
+    return ret;
+#endif
+}
+
+GI_FORCEINLINE
+GI_INT8_t GiBroadcastInt8(int8_t Value) {
+#if defined(GI_NEON_INTRINSICS)
+    return vdupq_n_s8(Value);
+#elif defined(GI_SSE2_INTRINSICS)
+    return _mm_set1_epi8(Value);
+#else
+    GI_INT8_t ret;
+    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) {
+        ret[i] = Value;
+    }
+    return ret;
+#endif
+}
+
+__attribute__((unused)) const GI_INT8_t vzero_int8 = GiBroadcastInt8(0);
+__attribute__((unused)) const GI_INT32_t vzero = GiBroadcastInt32(0);
+__attribute__((unused)) const GI_FLOAT32_t vfzero = GiBroadcastFloat32(0.0f);
+__attribute__((unused)) const GI_FLOAT32_t vfhalf = GiBroadcastFloat32(0.5f);
+__attribute__((unused)) const GI_FLOAT32_t vfneg_half = GiBroadcastFloat32(-0.5f);
+__attribute__((unused)) const GI_FLOAT32_t vfmin_int8 = GiBroadcastFloat32(-128.0f);
+__attribute__((unused)) const GI_FLOAT32_t vfmax_int8 = GiBroadcastFloat32(127.0f);
+
 // vim: syntax=cpp.doxygen
diff --git a/dnn/src/fallback/general_intrinsic/gi_float.h b/dnn/src/fallback/general_intrinsic/gi_float.h
index a6b1ac1c6..3bdbca38f 100644
--- a/dnn/src/fallback/general_intrinsic/gi_float.h
+++ b/dnn/src/fallback/general_intrinsic/gi_float.h
@@ -71,20 +71,12 @@ GI_INT32_t GiRoundAsInt32(GI_FLOAT32_t Vector) {
 #if __ARM_ARCH >= 8
     return vcvtaq_s32_f32(Vector);
 #else
-    float32x4_t vzero = vdupq_n_f32(0.f);
-    float32x4_t vfhalf = vdupq_n_f32(0.5f);
-    float32x4_t vfneg_half = vdupq_n_f32(-0.5f);
-    float32x4_t vinc0 = vbslq_f32(vcgeq_f32(Vector, vzero), vfhalf, vfneg_half);
+    float32x4_t vinc0 = vbslq_f32(vcgeq_f32(Vector, vfzero), vfhalf, vfneg_half);
     return vcvtq_s32_f32(vaddq_f32(Vector, vinc0));
 #endif
 #elif defined(GI_SSE42_INTRINSICS)
-    __m128 vfzero = _mm_set1_ps(0.f);
-    __m128 vfhalf = _mm_set1_ps(0.5f);
-    __m128 vfneg_half = _mm_set1_ps(-0.5f);
     __m128 vinc0 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(Vector, vfzero));
-    __m128 vres0 = _mm_add_ps(Vector, vinc0);
-    return _mm_castps_si128(
-            _mm_round_ps(vres0, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
+    return _mm_cvttps_epi32(_mm_add_ps(Vector, vinc0));
 #else
     GI_INT32_t ret;
     for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) {
@@ -118,22 +110,7 @@ GI_FLOAT32_t GiCastToFloat32(GI_INT32_t Vector) {
 #else
     GI_FLOAT32_t ret;
     for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) {
-        ret[i] = float(Vector[i]);
-    }
-    return ret;
-#endif
-}
-
-GI_FORCEINLINE
-GI_FLOAT32_t GiBroadcastFloat32(float Value) {
-#if defined(GI_NEON_INTRINSICS)
-    return vdupq_n_f32(Value);
-#elif defined(GI_SSE2_INTRINSICS)
-    return _mm_set1_ps(Value);
-#else
-    GI_FLOAT32_t ret;
-    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) {
-        ret[i] = Value;
+        ret[i] = (float)Vector[i];
     }
     return ret;
 #endif
 }
diff --git a/dnn/src/fallback/general_intrinsic/gi_int.h b/dnn/src/fallback/general_intrinsic/gi_int.h
index abb4e2b16..b2d95b566 100644
--- a/dnn/src/fallback/general_intrinsic/gi_int.h
+++ b/dnn/src/fallback/general_intrinsic/gi_int.h
@@ -13,21 +13,6 @@
 
 #include "gi_common.h"
 
-GI_FORCEINLINE
-GI_INT32_t GiBroadcastInt32(int32_t Value) {
-#if defined(GI_NEON_INTRINSICS)
-    return vdupq_n_s32(Value);
-#elif defined(GI_SSE2_INTRINSICS)
-    return _mm_set1_epi32(Value);
-#else
-    GI_INT32_t ret;
-    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) {
-        ret[i] = Value;
-    }
-    return ret;
-#endif
-}
-
 GI_FORCEINLINE
 GI_UINT32_t GiBroadcastUint32(int32_t Value) {
 #if defined(GI_NEON_INTRINSICS)
@@ -44,30 +29,31 @@ GI_UINT32_t GiBroadcastUint32(int32_t Value) {
 }
 
 GI_FORCEINLINE
-GI_INT8_t GiBroadcastInt8(int8_t Value) {
+GI_INT32_t GiLoadInt32(const void* Buffer) {
 #if defined(GI_NEON_INTRINSICS)
-    return vdupq_n_s8(Value);
+    return vld1q_s32((int32_t*)Buffer);
 #elif defined(GI_SSE2_INTRINSICS)
-    return _mm_set1_epi8(Value);
+    return _mm_loadu_si128((const __m128i*)Buffer);
 #else
-    GI_INT8_t ret;
-    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) {
-        ret[i] = Value;
+    GI_INT32_t ret;
+    const int32_t* ptr = (int32_t*)Buffer;
+    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) {
+        ret[i] = ptr[i];
     }
     return ret;
 #endif
 }
 
 GI_FORCEINLINE
-GI_INT32_t GiLoadInt32(const void* Buffer) {
+GI_INT16_t GiLoadInt16(const void* Buffer) {
 #if defined(GI_NEON_INTRINSICS)
-    return vld1q_s32((int32_t*)Buffer);
+    return vld1q_s16((int16_t*)Buffer);
 #elif defined(GI_SSE2_INTRINSICS)
     return _mm_loadu_si128((const __m128i*)Buffer);
 #else
-    GI_INT32_t ret;
-    const int32_t* ptr = (int32_t*)Buffer;
-    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) {
+    GI_INT16_t ret;
+    const int16_t* ptr = (int16_t*)Buffer;
+    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int16_t); i++) {
         ret[i] = ptr[i];
     }
     return ret;
@@ -810,21 +796,12 @@ GI_INT8_t GiCvtFromFloat32ToInt8(GI_FLOAT32_t src) {
     int16x8_t mid_s16 = vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres0));
     return vcombine_s8(vqmovn_s16(mid_s16), vqmovn_s16(mid_s16));
 #else
-    float32x4_t vzero = vdupq_n_f32(0.f);
-    float32x4_t vfhalf = vdupq_n_f32(0.5f);
-    float32x4_t vfneg_half = vdupq_n_f32(-0.5f);
-    float32x4_t vinc0 = vbslq_f32(vcgeq_f32(src, vzero), vfhalf, vfneg_half);
+    float32x4_t vinc0 = vbslq_f32(vcgeq_f32(src, vfzero), vfhalf, vfneg_half);
     int32x4_t vres0 = vcvtq_s32_f32(vaddq_f32(src, vinc0));
     int16x8_t mid_s16 = vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres0));
     return vcombine_s8(vqmovn_s16(mid_s16), vqmovn_s16(mid_s16));
 #endif
 #elif defined(GI_SSE42_INTRINSICS)
-    __m128 vfzero = _mm_set1_ps(0.f);
-    __m128 vfhalf = _mm_set1_ps(0.5f);
-    __m128 vfneg_half = _mm_set1_ps(-0.5f);
-    __m128 vfmin_int8 = _mm_set1_ps(-128.f);
-    __m128 vfmax_int8 = _mm_set1_ps(127.f);
-
     __m128 vinc0 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(src, vfzero));
     __m128 vres0 = _mm_add_ps(src, vinc0);
     vres0 = _mm_round_ps(vres0, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC);
@@ -857,23 +834,14 @@ GI_INT8_t GiCvtFromFloat32V2ToInt8(GI_FLOAT32_V2_t vsrc) {
     int8x8_t mid1 = vqmovn_s16(vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres1)));
     return vcombine_s8(mid1, mid1);
 #else
-    float32x4_t vzero = vdupq_n_f32(0.f);
-    float32x4_t vfhalf = vdupq_n_f32(0.5f);
-    float32x4_t vfneg_half = vdupq_n_f32(-0.5f);
-    float32x4_t vinc0 = vbslq_f32(vcgeq_f32(vsrc.val[0], vzero), vfhalf, vfneg_half);
-    float32x4_t vinc1 = vbslq_f32(vcgeq_f32(vsrc.val[1], vzero), vfhalf, vfneg_half);
+    float32x4_t vinc0 = vbslq_f32(vcgeq_f32(vsrc.val[0], vfzero), vfhalf, vfneg_half);
+    float32x4_t vinc1 = vbslq_f32(vcgeq_f32(vsrc.val[1], vfzero), vfhalf, vfneg_half);
     int32x4_t vres0 = vcvtq_s32_f32(vaddq_f32(vsrc.val[0], vinc0));
     int32x4_t vres1 = vcvtq_s32_f32(vaddq_f32(vsrc.val[1], vinc1));
     int8x8_t mid1 = vqmovn_s16(vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres1)));
     return vcombine_s8(mid1, mid1);
 #endif
 #elif defined(GI_SSE42_INTRINSICS)
-    __m128 vfzero = _mm_set1_ps(0.f);
-    __m128 vfhalf = _mm_set1_ps(0.5f);
-    __m128 vfneg_half = _mm_set1_ps(-0.5f);
-    __m128 vfmin_int8 = _mm_set1_ps(-128.f);
-    __m128 vfmax_int8 = _mm_set1_ps(127.f);
-
     __m128 vinc0 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(vsrc.val[0], vfzero));
     __m128 vinc1 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(vsrc.val[1], vfzero));
@@ -913,13 +881,13 @@ GI_INT8_t GiCvtFromFloat32V4ToInt8(GI_FLOAT32_V4_t vsrc) {
     int8x8_t mid2 = vqmovn_s16(vcombine_s16(vqmovn_s32(vres2), vqmovn_s32(vres3)));
     return vcombine_s8(mid1, mid2);
 #else
-    float32x4_t vzero = vdupq_n_f32(0.f);
+    float32x4_t vfzero = vdupq_n_f32(0.f);
     float32x4_t vfhalf = vdupq_n_f32(0.5f);
     float32x4_t vfneg_half = vdupq_n_f32(-0.5f);
-    float32x4_t vinc0 = vbslq_f32(vcgeq_f32(vsrc.val[0], vzero), vfhalf, vfneg_half);
-    float32x4_t vinc1 = vbslq_f32(vcgeq_f32(vsrc.val[1], vzero), vfhalf, vfneg_half);
-    float32x4_t vinc2 = vbslq_f32(vcgeq_f32(vsrc.val[2], vzero), vfhalf, vfneg_half);
-    float32x4_t vinc3 = vbslq_f32(vcgeq_f32(vsrc.val[3], vzero), vfhalf, vfneg_half);
+    float32x4_t vinc0 = vbslq_f32(vcgeq_f32(vsrc.val[0], vfzero), vfhalf, vfneg_half);
+    float32x4_t vinc1 = vbslq_f32(vcgeq_f32(vsrc.val[1], vfzero), vfhalf, vfneg_half);
+    float32x4_t vinc2 = vbslq_f32(vcgeq_f32(vsrc.val[2], vfzero), vfhalf, vfneg_half);
+    float32x4_t vinc3 = vbslq_f32(vcgeq_f32(vsrc.val[3], vfzero), vfhalf, vfneg_half);
     int32x4_t vres0 = vcvtq_s32_f32(vaddq_f32(vsrc.val[0], vinc0));
     int32x4_t vres1 = vcvtq_s32_f32(vaddq_f32(vsrc.val[1], vinc1));
     int32x4_t vres2 = vcvtq_s32_f32(vaddq_f32(vsrc.val[2], vinc2));
@@ -929,12 +897,6 @@ GI_INT8_t GiCvtFromFloat32V4ToInt8(GI_FLOAT32_V4_t vsrc) {
     return vcombine_s8(mid1, mid2);
 #endif
 #elif defined(GI_SSE42_INTRINSICS)
-    __m128 vfzero = _mm_set1_ps(0.f);
-    __m128 vfhalf = _mm_set1_ps(0.5f);
-    __m128 vfneg_half = _mm_set1_ps(-0.5f);
-    __m128 vfmin_int8 = _mm_set1_ps(-128.f);
-    __m128 vfmax_int8 = _mm_set1_ps(127.f);
-
     __m128 vinc0 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(vsrc.val[0], vfzero));
     __m128 vinc1 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(vsrc.val[1], vfzero));
     __m128 vinc2 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(vsrc.val[2], vfzero));
diff --git a/dnn/src/fallback/quantized_converter.h b/dnn/src/fallback/quantized_converter.h
index bc0b4f7b8..73b8dc768 100644
--- a/dnn/src/fallback/quantized_converter.h
+++ b/dnn/src/fallback/quantized_converter.h
@@ -20,16 +20,6 @@
 namespace megdnn {
 namespace fallback {
 
-struct QConverterBase {
-    inline static GI_INT32_t vzero() { return GiBroadcastInt32(0); }
-
-    inline static GI_FLOAT32_t vfzero() { return GiBroadcastFloat32(0.f); }
-
-    inline static GI_FLOAT32_t vfhalf() { return GiBroadcastFloat32(0.5f); }
-
-    inline static GI_FLOAT32_t vfneg_half() { return GiBroadcastFloat32(-0.5f); }
-};
-
 struct QConverter {
     template <typename dst_type, typename... src_type>
     static inline dst_type convert(const src_type&... src);
@@ -66,6 +56,12 @@ template <>
 inline GI_INT8_t QConverter::convert(const GI_FLOAT32_V2_t& vsrc) {
     return GiCvtFromFloat32V2ToInt8(vsrc);
 }
+
+template <>
+inline GI_INT8_t QConverter::convert(const GI_FLOAT32_V4_t& vsrc) {
+    return GiCvtFromFloat32V4ToInt8(vsrc);
+}
+
 template <>
 inline GI_INT8_t QConverter::convert(const GI_FLOAT32_t& src) {
     return GiCvtFromFloat32ToInt8(src);
diff --git a/dnn/src/fallback/type_cvt/opr_impl.cpp b/dnn/src/fallback/type_cvt/opr_impl.cpp
index 6f002b417..d48e5135d 100644
--- a/dnn/src/fallback/type_cvt/opr_impl.cpp
+++ b/dnn/src/fallback/type_cvt/opr_impl.cpp
@@ -9,6 +9,7 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
 #include "src/fallback/type_cvt/opr_impl.h"
+#include "src/fallback/type_cvt/typecvt_helper.h"
 
 #include "midout.h"
 #include "src/common/utils.h"
@@ -17,6 +18,7 @@
 // MIDOUT_DECL(megdnn_fb_typecvt_src)
 MIDOUT_DECL(megdnn_fb_typecvt_dst_dtype)
 MIDOUT_DECL(megdnn_fb_typecvt_src_dtype)
+MIDOUT_DECL(megdnn_fb_typecvt_optimized)
 
 namespace {
@@ -513,12 +515,68 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
         !is_quantize_lowbit(dst.layout.dtype) &&
         dst.layout.dtype.enumv() != DTypeEnum::QuantizedS1 &&
         src.layout.dtype.enumv() != DTypeEnum::QuantizedS1) {
-        MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst));
+        if (!exec_optimized(src, dst)) {
+            MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst));
+        }
     } else {
         naive::TypeCvtImpl::exec(src, dst);
     }
 }
 
+bool TypeCvtImpl::exec_optimized(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
+    DType src_dtype = src.layout.dtype;
+    DType dst_dtype = dst.layout.dtype;
+    bool execed = false;
+    using namespace dtype;
+    size_t nr_elems = src.layout.total_nr_elems();
+#define DISPATCH_QUANTIZED(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
+    if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv &&                    \
+        dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) {                    \
+        MIDOUT_BEGIN(megdnn_fb_typecvt_optimized, midout_iv(_midout_iv)) {         \
+            using _TypeCvter = QuantizedTypeCvter<_stype, _dtype>;                 \
+            MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>(                   \
+                    src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(),    \
+                    src_dtype, dst_dtype, nr_elems));                              \
+            execed = true;                                                         \
+        }                                                                          \
+        MIDOUT_END();                                                              \
+    }
+    DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS8, int8_t, 1);
+    DISPATCH_QUANTIZED(QuantizedS8, int8_t, QuantizedS32, int32_t, 2);
+    DISPATCH_QUANTIZED(QuantizedS8, int8_t, QuantizedS8, int8_t, 3);
+    DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS32, int32_t, 4);
+    DISPATCH_QUANTIZED(float, float, QuantizedS8, int8_t, 5);
+#undef DISPATCH_QUANTIZED
+
+#define DISPATCH_FIX2FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
+    if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv &&                    \
+        dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) {                    \
+        MIDOUT_BEGIN(megdnn_fb_typecvt_optimized, midout_iv(_midout_iv)) {         \
+            using _TypeCvter = Fix2FloatTypeCvter<_stype, _dtype>;                 \
+            MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>(                   \
+                    src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(),    \
+                    src_dtype, dst_dtype, src.layout.total_nr_elems()));           \
+            execed = true;                                                         \
+        }                                                                          \
+        MIDOUT_END();                                                              \
+    }
+    DISPATCH_FIX2FLOAT(Int16, int16_t, Float32, float, 6);
+    DISPATCH_FIX2FLOAT(Int8, int8_t, Float32, float, 7);
+
+    if (src_dtype.enumv() == DTypeTrait<QuantizedS8>::enumv &&
+        dst_dtype.enumv() == DTypeTrait<Float32>::enumv) {
+        MIDOUT_BEGIN(megdnn_fb_typecvt_optimized, midout_iv(8)) {
+            using TypeCvter = Quan2FloatTypeCvter<int8_t, float>;
+            MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<TypeCvter>(
+                    src.compatible_ptr<int8_t>(), dst.compatible_ptr<float>(),
+                    src_dtype, dst_dtype, src.layout.total_nr_elems()));
+            execed = true;
+        }
+        MIDOUT_END();
+    }
+    return execed;
+}
+
 }  // namespace fallback
 }  // namespace megdnn
diff --git a/dnn/src/fallback/type_cvt/opr_impl.h b/dnn/src/fallback/type_cvt/opr_impl.h
index dea64fce1..de9719066 100644
--- a/dnn/src/fallback/type_cvt/opr_impl.h
+++ b/dnn/src/fallback/type_cvt/opr_impl.h
@@ -15,6 +15,8 @@ namespace megdnn {
 namespace fallback {
 
 class TypeCvtImpl : public naive::TypeCvtImpl {
+    bool exec_optimized(_megdnn_tensor_in src, _megdnn_tensor_out dst);
+
 public:
     using naive::TypeCvtImpl::TypeCvtImpl;
     void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) override;
diff --git a/dnn/src/fallback/type_cvt/typecvt_helper.h b/dnn/src/fallback/type_cvt/typecvt_helper.h
new file mode 100644
index 000000000..a9d13c298
--- /dev/null
+++ b/dnn/src/fallback/type_cvt/typecvt_helper.h
@@ -0,0 +1,302 @@
+/**
+ * \file dnn/src/fallback/type_cvt/typecvt_helper.h
+ */
+#include "src/fallback/general_intrinsic/gi_float.h"
+#include "src/fallback/general_intrinsic/gi_int.h"
+#include "src/fallback/quantized_converter.h"
+
+namespace megdnn {
+namespace fallback {
+
+template <typename src_type, typename dst_type>
+struct QuantizedTypeCvter;
+
+template <>
+struct QuantizedTypeCvter<int32_t, int8_t> {
+    using stype = int32_t;
+    using dst_type = int8_t;
+    static constexpr size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t) * 2;
+    static constexpr size_t SIMD_STEP = GI_SIMD_LEN_BYTE / sizeof(int32_t);
+    float scale;
+    GI_FLOAT32_t vscale;
+
+    QuantizedTypeCvter(DType src_dtype, DType dst_dtype) {
+        float src_scale = src_dtype.param<dtype::QuantizedS32>().scale;
+        float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale;
+        scale = src_scale / dst_scale;
+        vscale = GiBroadcastFloat32(scale);
+    }
+
+    void cvt(const int32_t* src, int8_t* dst) {
+        GI_FLOAT32_t vitem0 =
+                GiMultiplyFloat32(GiCastToFloat32(GiLoadInt32(src)), vscale);
+        GI_FLOAT32_t vitem1 = GiMultiplyFloat32(
+                GiCastToFloat32(GiLoadInt32(src + SIMD_STEP)), vscale);
+
+        auto vres = QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>(
+                {{vitem0, vitem1}});
+        GiStoreLowInt8(dst, vres);
+    }
+
+    void cvt_remain(const int32_t* src, int8_t* dst) {
+        *dst = saturate<int8_t, float>(std::round(*src * scale), -128.f, 127.f);
+    }
+};
+
+template <>
+struct QuantizedTypeCvter<int8_t, int32_t> {
+    using stype = int8_t;
+    using dst_type = int32_t;
+    static constexpr size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
+    float scale;
+    GI_FLOAT32_t vscale;
+
+    QuantizedTypeCvter(DType src_dtype, DType dst_dtype) {
+        float src_scale = src_dtype.param<dtype::QuantizedS8>().scale;
+        float dst_scale = dst_dtype.param<dtype::QuantizedS32>().scale;
+        scale = src_scale / dst_scale;
+        vscale = GiBroadcastFloat32(scale);
+    }
+
+    void cvt(const int8_t* src, int32_t* dst) {
+        GI_INT8_t data = GiLoadInt8(src);
+        GI_INT16_t vitem0 = GiMoveLowLongInt8(data);
+        GI_INT16_t vitem1 = GiMoveHighLongInt8(data);
+        auto vret0 = QConverter::round(
+                GiMultiplyFloat32(GiCastToFloat32(GiMoveLowLongInt16(vitem0)), vscale));
+        auto vret1 = QConverter::round(GiMultiplyFloat32(
+                GiCastToFloat32(GiMoveHighLongInt16(vitem0)), vscale));
+        auto vret2 = QConverter::round(
+                GiMultiplyFloat32(GiCastToFloat32(GiMoveLowLongInt16(vitem1)), vscale));
+        auto vret3 = QConverter::round(GiMultiplyFloat32(
+                GiCastToFloat32(GiMoveHighLongInt16(vitem1)), vscale));
+
+        constexpr size_t step = GI_SIMD_LEN_BYTE / sizeof(int32_t);
+        GiStoreInt32(dst, vret0);
+        GiStoreInt32(dst + step, vret1);
+        GiStoreInt32(dst + 2 * step, vret2);
+        GiStoreInt32(dst + 3 * step, vret3);
+    }
+
+    void cvt_remain(const int8_t* src, int32_t* dst) {
+        *dst = saturate<int32_t, float>(
+                std::round(*src * scale), -2147483648.f, 2147483647.f);
+    }
+};
+
+template <>
+struct QuantizedTypeCvter<float, int8_t> {
+    using stype = float;
+    using dst_type = int8_t;
+    static constexpr size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(float) * 2;
+    static constexpr size_t SIMD_STEP = GI_SIMD_LEN_BYTE / sizeof(float);
+    float scale;
+    GI_FLOAT32_t vscale;
+
+    QuantizedTypeCvter(DType src_dtype, DType dst_dtype) {
+        MEGDNN_MARK_USED_VAR(src_dtype);
+        float src_scale = 1;
+        float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale;
+        scale = src_scale / dst_scale;
+        vscale = GiBroadcastFloat32(scale);
+    }
+
+    void cvt(const float* src, int8_t* dst) {
+        GI_FLOAT32_t vitem0 = GiMultiplyFloat32(GiLoadFloat32(src), vscale);
+        GI_FLOAT32_t vitem1 = GiMultiplyFloat32(GiLoadFloat32(src + SIMD_STEP), vscale);
+
+        auto vres = QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>(
+                {{vitem0, vitem1}});
+        GiStoreLowInt8(dst, vres);
+    }
+
+    void cvt_remain(const float* src, int8_t* dst) {
+        *dst = saturate<int8_t, float>(std::round(*src * scale), -128.f, 127.f);
+    }
+};
+
+template <>
+struct QuantizedTypeCvter<int32_t, int32_t> {
+    using stype = int32_t;
+    using dst_type = int32_t;
+    static constexpr size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t);
+    float scale;
+    GI_FLOAT32_t vscale;
+
+    QuantizedTypeCvter(DType src_dtype, DType dst_dtype) {
+        float src_scale = src_dtype.param<dtype::QuantizedS32>().scale;
+        float dst_scale = dst_dtype.param<dtype::QuantizedS32>().scale;
+        scale = src_scale / dst_scale;
+        vscale = GiBroadcastFloat32(scale);
+    }
+
+    void cvt(const int32_t* src, int32_t* dst) {
+        GI_FLOAT32_t vitem =
+                GiMultiplyFloat32(GiCastToFloat32(GiLoadInt32(src)), vscale);
+
+        auto vres = QConverter::round(vitem);
+        GiStoreInt32(dst, vres);
+    }
+
+    void cvt_remain(const int32_t* src, int32_t* dst) {
+        *dst = saturate<int32_t, float>(
+                std::round(*src * scale), -2147483648.f, 2147483647.f);
+    }
+};
+
+template <>
+struct QuantizedTypeCvter<int8_t, int8_t> {
+    using stype = int8_t;
+    using dst_type = int8_t;
+    static constexpr size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
+    float scale;
+    GI_FLOAT32_t vscale;
+
+    QuantizedTypeCvter(DType src_dtype, DType dst_dtype) {
+        float src_scale = src_dtype.param<dtype::QuantizedS8>().scale;
+        float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale;
+        scale = src_scale / dst_scale;
+        vscale = GiBroadcastFloat32(scale);
+    }
+
+    void cvt(const int8_t* src, int8_t* dst) {
+        GI_INT8_t data = GiLoadInt8(src);
+        GI_INT16_t vitem0 = GiMoveLowLongInt8(data);
+        GI_INT16_t vitem1 = GiMoveHighLongInt8(data);
+        auto vret0 =
+                GiMultiplyFloat32(GiCastToFloat32(GiMoveLowLongInt16(vitem0)), vscale);
+        auto vret1 =
+                GiMultiplyFloat32(GiCastToFloat32(GiMoveHighLongInt16(vitem0)), vscale);
+        auto vret2 =
+                GiMultiplyFloat32(GiCastToFloat32(GiMoveLowLongInt16(vitem1)), vscale);
+        auto vret3 =
+                GiMultiplyFloat32(GiCastToFloat32(GiMoveHighLongInt16(vitem1)), vscale);
+
+        auto vres = QConverter::convert<GI_INT8_t, GI_FLOAT32_V4_t>(
+                {{vret0, vret1, vret2, vret3}});
+        GiStoreInt8(dst, vres);
+    }
+
+    void cvt_remain(const int8_t* src, int8_t* dst) {
+        *dst = saturate<int8_t, float>(std::round(*src * scale), -128.f, 127.f);
+    }
+};
+
+template <typename src_type, typename dst_type>
+struct Fix2FloatTypeCvter;
+
+template <typename src_type, typename dst_type>
+struct Quan2FloatTypeCvter;
+
+template <>
+struct Fix2FloatTypeCvter<int16_t, float> {
+    using stype = int16_t;
+    using dst_type = float;
+    static constexpr size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int16_t);
+    static constexpr size_t SIMD_STEP = GI_SIMD_LEN_BYTE / sizeof(float);
+
+    Fix2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
+        MEGDNN_MARK_USED_VAR(src_dtype);
+        MEGDNN_MARK_USED_VAR(dst_dtype);
+    }
+
+    void cvt(const int16_t* src, float* dst) {
+        GI_INT16_t vitem = GiLoadInt16(src);
+        auto vret0 = GiCastToFloat32(GiMoveLowLongInt16(vitem));
+        auto vret1 = GiCastToFloat32(GiMoveHighLongInt16(vitem));
+        GiStoreFloat32(dst, vret0);
+        GiStoreFloat32(dst + SIMD_STEP, vret1);
+    }
+
+    void cvt_remain(const int16_t* src, float* dst) { *dst = *src; }
+};
+
+template <>
+struct Fix2FloatTypeCvter<int8_t, float> {
+    using stype = int8_t;
+    using dst_type = float;
+    static constexpr size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
+    static constexpr size_t SIMD_STEP = GI_SIMD_LEN_BYTE / sizeof(float);
+
+    Fix2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
+        MEGDNN_MARK_USED_VAR(src_dtype);
+        MEGDNN_MARK_USED_VAR(dst_dtype);
+    }
+
+    void cvt(const int8_t* src, float* dst) {
+        GI_INT8_t data = GiLoadInt8(src);
+        GI_INT16_t vitem0 = GiMoveLowLongInt8(data);
+        GI_INT16_t vitem1 = GiMoveHighLongInt8(data);
+        auto vret0 = GiCastToFloat32(GiMoveLowLongInt16(vitem0));
+        auto vret1 = GiCastToFloat32(GiMoveHighLongInt16(vitem0));
+        auto vret2 = GiCastToFloat32(GiMoveLowLongInt16(vitem1));
+        auto vret3 = GiCastToFloat32(GiMoveHighLongInt16(vitem1));
+        GiStoreFloat32(dst, vret0);
+        GiStoreFloat32(dst + SIMD_STEP, vret1);
+        GiStoreFloat32(dst + 2 * SIMD_STEP, vret2);
+        GiStoreFloat32(dst + 3 * SIMD_STEP, vret3);
+    }
+
+    void cvt_remain(const int8_t* src, float* dst) { *dst = *src; }
+};
+
+template <>
+struct Quan2FloatTypeCvter<int8_t, float> {
+    using stype = int8_t;
+    using dst_type = float;
+    static constexpr size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
+    static constexpr size_t SIMD_STEP = GI_SIMD_LEN_BYTE / sizeof(float);
+    float _scale = 0.0f;
+    GI_FLOAT32_t vscale;
+
+    Quan2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
+        _scale = src_dtype.param<dtype::QuantizedS8>().scale;
+        vscale = GiBroadcastFloat32(_scale);
+        MEGDNN_MARK_USED_VAR(dst_dtype);
+    }
+
+    void cvt(const int8_t* src, float* dst) {
+        GI_INT8_t data = GiLoadInt8(src);
+        GI_INT16_t vitem0 = GiMoveLowLongInt8(data);
+        GI_INT16_t vitem1 = GiMoveHighLongInt8(data);
+        auto vret0 =
+                GiMultiplyFloat32(GiCastToFloat32(GiMoveLowLongInt16(vitem0)), vscale);
+        auto vret1 =
+                GiMultiplyFloat32(GiCastToFloat32(GiMoveHighLongInt16(vitem0)), vscale);
+        auto vret2 =
+                GiMultiplyFloat32(GiCastToFloat32(GiMoveLowLongInt16(vitem1)), vscale);
+        auto vret3 =
+                GiMultiplyFloat32(GiCastToFloat32(GiMoveHighLongInt16(vitem1)), vscale);
+
+        GiStoreFloat32(dst, vret0);
+        GiStoreFloat32(dst + SIMD_STEP, vret1);
+        GiStoreFloat32(dst + 2 * SIMD_STEP, vret2);
+        GiStoreFloat32(dst + 3 * SIMD_STEP, vret3);
+    }
+
+    void cvt_remain(const int8_t* src, float* dst) { *dst = *src * _scale; }
+};
+
+template <typename TypeCvter>
+void do_typecvt(
+        const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst,
+        DType src_dtype, DType dst_dtype, size_t nr_elems) {
+    TypeCvter typecvt(src_dtype, dst_dtype);
+    size_t i = 0;
+    for (; i + TypeCvter::SIMD_WIDTH <= nr_elems; i += TypeCvter::SIMD_WIDTH) {
+        typecvt.cvt(src, dst);
+        src += TypeCvter::SIMD_WIDTH;
+        dst += TypeCvter::SIMD_WIDTH;
+    }
+#if MEGDNN_FIX_AARCH32_BUG
+// FIXME: as llvm may cause cannot select error if enable vectorize
+#pragma clang loop vectorize(disable)
+#endif
+    for (; i < nr_elems; i++) {
+        typecvt.cvt_remain(src, dst);
+        src++;
+        dst++;
+    }
+}
+
+}  // namespace fallback
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
-- 
GitLab
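
Notes on the patterns used above (standalone sketches with made-up names; none of this code is from the commit):

1) The gi_common.h hunk hoists the broadcast helpers and materializes file-scope
constants (vfzero, vzero, vzero_int8, vfhalf, ...) so that hot loops such as
ReluOp no longer re-broadcast a zero vector on every call. A minimal sketch of
that one-API/three-backends shape, where V4F, broadcast_f32 and kZero are
illustrative stand-ins rather than the real GI symbols:

    #include <cstddef>

    #if defined(__ARM_NEON)
    #include <arm_neon.h>
    typedef float32x4_t V4F;
    static inline V4F broadcast_f32(float v) { return vdupq_n_f32(v); }
    #elif defined(__SSE2__)
    #include <emmintrin.h>
    typedef __m128 V4F;
    static inline V4F broadcast_f32(float v) { return _mm_set1_ps(v); }
    #else
    struct V4F { float lane[4]; };
    static inline V4F broadcast_f32(float v) {
        V4F r;
        for (size_t i = 0; i < 4; ++i) r.lane[i] = v;  // plain scalar fallback
        return r;
    }
    #endif

    // File-scope constants pay the broadcast cost once; the patch additionally
    // marks them __attribute__((unused)) to silence per-TU warnings.
    static const V4F kZero = broadcast_f32(0.0f);
    static const V4F kHalf = broadcast_f32(0.5f);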
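2) GiRoundAsInt32 and the GiCvtFromFloat32*ToInt8 family now share the vfzero/
vfhalf/vfneg_half constants for their add-then-truncate rounding. A scalar
model of that trick, assuming the same ties-away-from-zero convention that
vcvtaq_s32_f32 implements on ARMv8:

    #include <cstdio>

    // Round to nearest, ties away from zero: add +0.5 for x >= 0, -0.5
    // otherwise, then let the integer cast truncate toward zero. This mirrors
    // the vbslq_f32/_mm_blendv_ps select plus truncating convert in the patch.
    static int round_as_int32(float x) {
        float inc = (x >= 0.0f) ? 0.5f : -0.5f;
        return (int)(x + inc);
    }

    int main() {
        printf("%d %d %d\n", round_as_int32(2.5f), round_as_int32(-2.5f),
               round_as_int32(0.4f));  // prints: 3 -3 0
        return 0;
    }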
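3) do_typecvt in typecvt_helper.h carries the whole execution strategy: convert
whole SIMD_WIDTH blocks through cvt(), then finish the tail one element at a
time through cvt_remain(). The scale folding follows from the quantization
identity q_dst = round(q_src * s_src / s_dst), which is why every converter
precomputes scale = src_scale / dst_scale exactly once. A runnable scalar mock
of the same control flow, where ToyQ32ToQ8 and toy_typecvt are hypothetical
stand-ins for QuantizedTypeCvter<int32_t, int8_t> and do_typecvt:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Hypothetical scalar stand-in for QuantizedTypeCvter<int32_t, int8_t>.
    struct ToyQ32ToQ8 {
        using stype = int32_t;
        using dst_type = int8_t;
        static constexpr size_t SIMD_WIDTH = 8;  // one "vector" step
        float scale;  // src_scale / dst_scale, folded once in the ctor
        explicit ToyQ32ToQ8(float s) : scale(s) {}
        void cvt(const int32_t* src, int8_t* dst) const {
            for (size_t i = 0; i < SIMD_WIDTH; ++i)  // real cvt() uses GI SIMD
                cvt_remain(src + i, dst + i);
        }
        void cvt_remain(const int32_t* src, int8_t* dst) const {
            float v = std::round(*src * scale);           // ties away from zero
            v = std::min(127.0f, std::max(-128.0f, v));   // saturate to int8
            *dst = static_cast<int8_t>(v);
        }
    };

    // Same control flow as do_typecvt: full vector blocks, then a scalar tail.
    template <typename Cvter>
    void toy_typecvt(const typename Cvter::stype* src,
                     typename Cvter::dst_type* dst, size_t nr_elems,
                     const Cvter& cvter) {
        size_t i = 0;
        for (; i + Cvter::SIMD_WIDTH <= nr_elems; i += Cvter::SIMD_WIDTH) {
            cvter.cvt(src, dst);
            src += Cvter::SIMD_WIDTH;
            dst += Cvter::SIMD_WIDTH;
        }
        for (; i < nr_elems; ++i)  // remainder that does not fill a vector
            cvter.cvt_remain(src++, dst++);
    }

    int main() {
        int32_t src[11];
        int8_t dst[11];
        for (int i = 0; i < 11; ++i) src[i] = 100 * i - 500;
        toy_typecvt(src, dst, 11, ToyQ32ToQ8{0.5f});
        for (int i = 0; i < 11; ++i) printf("%d ", (int)dst[i]);
        printf("\n");  // -128 -128 -128 -100 -50 0 50 100 127 127 127
        return 0;
    }

The same split also explains the MEGDNN_FIX_AARCH32_BUG pragma in the patch:
only the scalar tail loop is at risk of being auto-vectorized by llvm, so
vectorization is disabled just there.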