提交 39d98d45 编写于 作者: M Megvii Engine Team

feat(fallback): add fallback typecvt with general intrinsic

GitOrigin-RevId: 1e6fcd929b02e4745b4a641e5101df8f624d4bea
上级 d2278f02
...@@ -20,7 +20,7 @@ struct ReluOpBase : UnaryOpBase<src_ctype, dst_ctype> { ...@@ -20,7 +20,7 @@ struct ReluOpBase : UnaryOpBase<src_ctype, dst_ctype> {
template <typename src_ctype, typename dst_type = src_ctype> template <typename src_ctype, typename dst_type = src_ctype>
struct ReluOp; struct ReluOp;
#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ #define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width, zero) \
template <> \ template <> \
struct ReluOp<_ctype> : ReluOpBase<_ctype> { \ struct ReluOp<_ctype> : ReluOpBase<_ctype> { \
using ReluOpBase::ReluOpBase; \ using ReluOpBase::ReluOpBase; \
...@@ -32,9 +32,8 @@ struct ReluOp; ...@@ -32,9 +32,8 @@ struct ReluOp;
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \ } \
_simd_type2 operator()(const _simd_type2& src) const { \ _simd_type2 operator()(const _simd_type2& src) const { \
auto vzero = GiBroadcast##_func_suffix(0); \ auto vitem0 = GiMaximum##_func_suffix(src.val[0], zero); \
auto vitem0 = GiMaximum##_func_suffix(src.val[0], vzero); \ auto vitem1 = GiMaximum##_func_suffix(src.val[1], zero); \
auto vitem1 = GiMaximum##_func_suffix(src.val[1], vzero); \
return {{vitem0, vitem1}}; \ return {{vitem0, vitem1}}; \
} \ } \
void operator()(const _simd_type& src, _ctype* dst) const { \ void operator()(const _simd_type& src, _ctype* dst) const { \
...@@ -42,14 +41,16 @@ struct ReluOp; ...@@ -42,14 +41,16 @@ struct ReluOp;
GiStore##_func_suffix(dst, vitem); \ GiStore##_func_suffix(dst, vitem); \
} \ } \
_simd_type operator()(const _simd_type& src) const { \ _simd_type operator()(const _simd_type& src) const { \
auto vzero = GiBroadcast##_func_suffix(0); \ return GiMaximum##_func_suffix(src, zero); \
return GiMaximum##_func_suffix(src, vzero); \
} \ } \
}; };
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float),
OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t)) vfzero)
OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t)) OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t),
vzero)
OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t),
vzero_int8)
#undef OP #undef OP
template <> template <>
...@@ -75,11 +76,10 @@ struct ReluOp<dt_qint8, dt_qint8> : ReluOpBase<dt_qint8, dt_qint8> { ...@@ -75,11 +76,10 @@ struct ReluOp<dt_qint8, dt_qint8> : ReluOpBase<dt_qint8, dt_qint8> {
OPERATOR_UNARY_QINT8_FALLBACK; OPERATOR_UNARY_QINT8_FALLBACK;
} }
GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const { GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
auto vzero = GiBroadcastFloat32(0.f);
auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale); auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale);
auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale); auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale);
vitem0 = GiMaximumFloat32(vitem0, vzero); vitem0 = GiMaximumFloat32(vitem0, vfzero);
vitem1 = GiMaximumFloat32(vitem1, vzero); vitem1 = GiMaximumFloat32(vitem1, vfzero);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}}); return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
} }
}; };
...@@ -114,12 +114,11 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase ...@@ -114,12 +114,11 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase
void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const { void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const {
vst1_s8(reinterpret_cast<int8_t*>(dst), vget_low_s8(operator()(vsrc))); vst1_s8(reinterpret_cast<int8_t*>(dst), vget_low_s8(operator()(vsrc)));
} }
int8x16_t operator()(const int32x4x2_t& vsrc) const { int8x16_t operator()(const int32x4x2_t& vsrc) const {
int32x4_t vitem0 = vqrdmulhq_s32(vsrc.val[0], vmultiplier); int32x4_t vitem0 = vqrdmulhq_s32(vsrc.val[0], vmultiplier);
int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier); int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier);
vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); vitem0 = vmaxq_s32(vitem0, vzero);
vitem1 = vmaxq_s32(vitem1, QConverterBase::vzero()); vitem1 = vmaxq_s32(vitem1, vzero);
auto tmp = vqmovn_s16(vcombine_s16( auto tmp = vqmovn_s16(vcombine_s16(
vqmovn_s32(vrshlq_s32(vitem0, vshift)), vqmovn_s32(vrshlq_s32(vitem0, vshift)),
vqmovn_s32(vrshlq_s32(vitem1, vshift)))); vqmovn_s32(vrshlq_s32(vitem1, vshift))));
...@@ -127,7 +126,7 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase ...@@ -127,7 +126,7 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase
} }
int8x16_t operator()(const float32x4_t& vsrc) const { int8x16_t operator()(const float32x4_t& vsrc) const {
int32x4_t vitem0 = vqrdmulhq_s32(vcvtq_s32_f32(vsrc), vmultiplier); int32x4_t vitem0 = vqrdmulhq_s32(vcvtq_s32_f32(vsrc), vmultiplier);
vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); vitem0 = vmaxq_s32(vitem0, vzero);
vitem0 = vrshlq_s32(vitem0, vshift); vitem0 = vrshlq_s32(vitem0, vshift);
int16x4_t vitem = vqmovn_s32(vitem0); int16x4_t vitem = vqmovn_s32(vitem0);
auto tmp = vqmovn_s16(vcombine_s16(vitem, vitem)); auto tmp = vqmovn_s16(vcombine_s16(vitem, vitem));
...@@ -135,13 +134,13 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase ...@@ -135,13 +134,13 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase
} }
void operator()(const int32x4_t& src, dt_qint8* dst) const { void operator()(const int32x4_t& src, dt_qint8* dst) const {
auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale); auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale);
vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); vitem0 = vmaxq_f32(vitem0, vfzero);
auto result = QConverter::convert<int8x16_t, float32x4_t>(vitem0); auto result = QConverter::convert<int8x16_t, float32x4_t>(vitem0);
vst1q_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x4_t)result, 0); vst1q_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x4_t)result, 0);
} }
void operator()(const float32x4_t& src, dt_qint8* dst) const { void operator()(const float32x4_t& src, dt_qint8* dst) const {
auto vitem0 = vmulq_f32(src, this->vscale); auto vitem0 = vmulq_f32(src, this->vscale);
vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); vitem0 = vmaxq_f32(vitem0, vfzero);
auto result = QConverter::convert<int8x16_t, float32x4_t>(vitem0); auto result = QConverter::convert<int8x16_t, float32x4_t>(vitem0);
vst1q_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x4_t)result, 0); vst1q_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x4_t)result, 0);
} }
...@@ -165,19 +164,19 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8> { ...@@ -165,19 +164,19 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8> {
GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const { GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale); auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale);
auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale); auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale);
vitem0 = GiMaximumFloat32(vitem0, QConverterBase::vfzero()); vitem0 = GiMaximumFloat32(vitem0, vfzero);
vitem1 = GiMaximumFloat32(vitem1, QConverterBase::vfzero()); vitem1 = GiMaximumFloat32(vitem1, vfzero);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}}); return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
} }
GI_INT8_t operator()(const GI_INT32_t& src) const { GI_INT8_t operator()(const GI_INT32_t& src) const {
auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(src), this->vscale); auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(src), this->vscale);
vitem0 = GiMaximumFloat32(vitem0, QConverterBase::vfzero()); vitem0 = GiMaximumFloat32(vitem0, vfzero);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(vitem0); return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(vitem0);
} }
GI_INT8_t operator()(const GI_FLOAT32_t& src) const { GI_INT8_t operator()(const GI_FLOAT32_t& src) const {
auto vitem0 = GiMultiplyFloat32(src, this->vscale); auto vitem0 = GiMultiplyFloat32(src, this->vscale);
vitem0 = GiMaximumFloat32(vitem0, QConverterBase::vfzero()); vitem0 = GiMaximumFloat32(vitem0, vfzero);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(vitem0); return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(vitem0);
} }
}; };
......
...@@ -213,4 +213,57 @@ GI_INT32_t GiXorInt32(GI_INT32_t Vector1, GI_INT32_t Vector2) { ...@@ -213,4 +213,57 @@ GI_INT32_t GiXorInt32(GI_INT32_t Vector1, GI_INT32_t Vector2) {
#endif #endif
} }
GI_FORCEINLINE
// Fill every lane of a GI_FLOAT32_t vector with the same scalar value.
// Dispatches to the native broadcast intrinsic on NEON/SSE2, otherwise
// falls back to a scalar loop over GI_SIMD_LEN_BYTE / sizeof(float) lanes.
GI_FLOAT32_t GiBroadcastFloat32(float Value) {
#if defined(GI_NEON_INTRINSICS)
    return vdupq_n_f32(Value);
#elif defined(GI_SSE2_INTRINSICS)
    return _mm_set1_ps(Value);
#else
    GI_FLOAT32_t ret;
    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) {
        ret[i] = Value;
    }
    return ret;
#endif
}
GI_FORCEINLINE
// Fill every lane of a GI_INT32_t vector with the same scalar value.
// Native broadcast intrinsic on NEON/SSE2; scalar-loop fallback otherwise.
GI_INT32_t GiBroadcastInt32(int32_t Value) {
#if defined(GI_NEON_INTRINSICS)
    return vdupq_n_s32(Value);
#elif defined(GI_SSE2_INTRINSICS)
    return _mm_set1_epi32(Value);
#else
    GI_INT32_t ret;
    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) {
        ret[i] = Value;
    }
    return ret;
#endif
}
GI_FORCEINLINE
// Fill every lane of a GI_INT8_t vector with the same scalar value.
// Native broadcast intrinsic on NEON/SSE2; scalar-loop fallback otherwise.
GI_INT8_t GiBroadcastInt8(int8_t Value) {
#if defined(GI_NEON_INTRINSICS)
    return vdupq_n_s8(Value);
#elif defined(GI_SSE2_INTRINSICS)
    return _mm_set1_epi8(Value);
#else
    GI_INT8_t ret;
    for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) {
        ret[i] = Value;
    }
    return ret;
#endif
}
// Pre-broadcast constants shared by the GI kernels (replaces the previous
// per-call GiBroadcast*(…) / QConverterBase::v*() constructions).
// NOTE(review): these are namespace-scope `const` objects defined in a
// header; `const` gives them internal linkage, so every translation unit
// that includes this header gets its own copy, initialized at static-init
// time. `__attribute__((unused))` (GCC/Clang-specific) silences
// -Wunused-const-variable in TUs that do not reference all of them.
__attribute__((unused)) const GI_INT8_t vzero_int8 = GiBroadcastInt8(0);
__attribute__((unused)) const GI_INT32_t vzero = GiBroadcastInt32(0);
__attribute__((unused)) const GI_FLOAT32_t vfzero = GiBroadcastFloat32(0.0f);
__attribute__((unused)) const GI_FLOAT32_t vfhalf = GiBroadcastFloat32(0.5f);
__attribute__((unused)) const GI_FLOAT32_t vfneg_half = GiBroadcastFloat32(-0.5f);
__attribute__((unused)) const GI_FLOAT32_t vfmin_int8 = GiBroadcastFloat32(-128.0f);
__attribute__((unused)) const GI_FLOAT32_t vfmax_int8 = GiBroadcastFloat32(127.0f);
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -71,20 +71,12 @@ GI_INT32_t GiRoundAsInt32(GI_FLOAT32_t Vector) { ...@@ -71,20 +71,12 @@ GI_INT32_t GiRoundAsInt32(GI_FLOAT32_t Vector) {
#if __ARM_ARCH >= 8 #if __ARM_ARCH >= 8
return vcvtaq_s32_f32(Vector); return vcvtaq_s32_f32(Vector);
#else #else
float32x4_t vzero = vdupq_n_f32(0.f); float32x4_t vinc0 = vbslq_f32(vcgeq_f32(Vector, vfzero), vfhalf, vfneg_half);
float32x4_t vfhalf = vdupq_n_f32(0.5f);
float32x4_t vfneg_half = vdupq_n_f32(-0.5f);
float32x4_t vinc0 = vbslq_f32(vcgeq_f32(Vector, vzero), vfhalf, vfneg_half);
return vcvtq_s32_f32(vaddq_f32(Vector, vinc0)); return vcvtq_s32_f32(vaddq_f32(Vector, vinc0));
#endif #endif
#elif defined(GI_SSE42_INTRINSICS) #elif defined(GI_SSE42_INTRINSICS)
__m128 vfzero = _mm_set1_ps(0.f);
__m128 vfhalf = _mm_set1_ps(0.5f);
__m128 vfneg_half = _mm_set1_ps(-0.5f);
__m128 vinc0 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(Vector, vfzero)); __m128 vinc0 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(Vector, vfzero));
__m128 vres0 = _mm_add_ps(Vector, vinc0); return _mm_cvttps_epi32(_mm_add_ps(Vector, vinc0));
return _mm_castps_si128(
_mm_round_ps(vres0, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
#else #else
GI_INT32_t ret; GI_INT32_t ret;
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) { for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) {
...@@ -118,22 +110,7 @@ GI_FLOAT32_t GiCastToFloat32(GI_INT32_t Vector) { ...@@ -118,22 +110,7 @@ GI_FLOAT32_t GiCastToFloat32(GI_INT32_t Vector) {
#else #else
GI_FLOAT32_t ret; GI_FLOAT32_t ret;
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) { for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) {
ret[i] = float(Vector[i]); ret[i] = (float)Vector[i];
}
return ret;
#endif
}
GI_FORCEINLINE
GI_FLOAT32_t GiBroadcastFloat32(float Value) {
#if defined(GI_NEON_INTRINSICS)
return vdupq_n_f32(Value);
#elif defined(GI_SSE2_INTRINSICS)
return _mm_set1_ps(Value);
#else
GI_FLOAT32_t ret;
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(float); i++) {
ret[i] = Value;
} }
return ret; return ret;
#endif #endif
......
...@@ -13,21 +13,6 @@ ...@@ -13,21 +13,6 @@
#include "gi_common.h" #include "gi_common.h"
GI_FORCEINLINE
GI_INT32_t GiBroadcastInt32(int32_t Value) {
#if defined(GI_NEON_INTRINSICS)
return vdupq_n_s32(Value);
#elif defined(GI_SSE2_INTRINSICS)
return _mm_set1_epi32(Value);
#else
GI_INT32_t ret;
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) {
ret[i] = Value;
}
return ret;
#endif
}
GI_FORCEINLINE GI_FORCEINLINE
GI_UINT32_t GiBroadcastUint32(int32_t Value) { GI_UINT32_t GiBroadcastUint32(int32_t Value) {
#if defined(GI_NEON_INTRINSICS) #if defined(GI_NEON_INTRINSICS)
...@@ -44,30 +29,31 @@ GI_UINT32_t GiBroadcastUint32(int32_t Value) { ...@@ -44,30 +29,31 @@ GI_UINT32_t GiBroadcastUint32(int32_t Value) {
} }
GI_FORCEINLINE GI_FORCEINLINE
GI_INT8_t GiBroadcastInt8(int8_t Value) { GI_INT32_t GiLoadInt32(const void* Buffer) {
#if defined(GI_NEON_INTRINSICS) #if defined(GI_NEON_INTRINSICS)
return vdupq_n_s8(Value); return vld1q_s32((int32_t*)Buffer);
#elif defined(GI_SSE2_INTRINSICS) #elif defined(GI_SSE2_INTRINSICS)
return _mm_set1_epi8(Value); return _mm_loadu_si128((const __m128i*)Buffer);
#else #else
GI_INT8_t ret; GI_INT32_t ret;
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) { const int32_t* ptr = (int32_t*)Buffer;
ret[i] = Value; for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) {
ret[i] = ptr[i];
} }
return ret; return ret;
#endif #endif
} }
GI_FORCEINLINE GI_FORCEINLINE
GI_INT32_t GiLoadInt32(const void* Buffer) { GI_INT16_t GiLoadInt16(const void* Buffer) {
#if defined(GI_NEON_INTRINSICS) #if defined(GI_NEON_INTRINSICS)
return vld1q_s32((int32_t*)Buffer); return vld1q_s16((int16_t*)Buffer);
#elif defined(GI_SSE2_INTRINSICS) #elif defined(GI_SSE2_INTRINSICS)
return _mm_loadu_si128((const __m128i*)Buffer); return _mm_loadu_si128((const __m128i*)Buffer);
#else #else
GI_INT32_t ret; GI_INT16_t ret;
const int32_t* ptr = (int32_t*)Buffer; const int16_t* ptr = (int16_t*)Buffer;
for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) { for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int16_t); i++) {
ret[i] = ptr[i]; ret[i] = ptr[i];
} }
return ret; return ret;
...@@ -810,21 +796,12 @@ GI_INT8_t GiCvtFromFloat32ToInt8(GI_FLOAT32_t src) { ...@@ -810,21 +796,12 @@ GI_INT8_t GiCvtFromFloat32ToInt8(GI_FLOAT32_t src) {
int16x8_t mid_s16 = vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres0)); int16x8_t mid_s16 = vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres0));
return vcombine_s8(vqmovn_s16(mid_s16), vqmovn_s16(mid_s16)); return vcombine_s8(vqmovn_s16(mid_s16), vqmovn_s16(mid_s16));
#else #else
float32x4_t vzero = vdupq_n_f32(0.f); float32x4_t vinc0 = vbslq_f32(vcgeq_f32(src, vfzero), vfhalf, vfneg_half);
float32x4_t vfhalf = vdupq_n_f32(0.5f);
float32x4_t vfneg_half = vdupq_n_f32(-0.5f);
float32x4_t vinc0 = vbslq_f32(vcgeq_f32(src, vzero), vfhalf, vfneg_half);
int32x4_t vres0 = vcvtq_s32_f32(vaddq_f32(src, vinc0)); int32x4_t vres0 = vcvtq_s32_f32(vaddq_f32(src, vinc0));
int16x8_t mid_s16 = vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres0)); int16x8_t mid_s16 = vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres0));
return vcombine_s8(vqmovn_s16(mid_s16), vqmovn_s16(mid_s16)); return vcombine_s8(vqmovn_s16(mid_s16), vqmovn_s16(mid_s16));
#endif #endif
#elif defined(GI_SSE42_INTRINSICS) #elif defined(GI_SSE42_INTRINSICS)
__m128 vfzero = _mm_set1_ps(0.f);
__m128 vfhalf = _mm_set1_ps(0.5f);
__m128 vfneg_half = _mm_set1_ps(-0.5f);
__m128 vfmin_int8 = _mm_set1_ps(-128.f);
__m128 vfmax_int8 = _mm_set1_ps(127.f);
__m128 vinc0 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(src, vfzero)); __m128 vinc0 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(src, vfzero));
__m128 vres0 = _mm_add_ps(src, vinc0); __m128 vres0 = _mm_add_ps(src, vinc0);
vres0 = _mm_round_ps(vres0, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); vres0 = _mm_round_ps(vres0, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC);
...@@ -857,23 +834,14 @@ GI_INT8_t GiCvtFromFloat32V2ToInt8(GI_FLOAT32_V2_t vsrc) { ...@@ -857,23 +834,14 @@ GI_INT8_t GiCvtFromFloat32V2ToInt8(GI_FLOAT32_V2_t vsrc) {
int8x8_t mid1 = vqmovn_s16(vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres1))); int8x8_t mid1 = vqmovn_s16(vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres1)));
return vcombine_s8(mid1, mid1); return vcombine_s8(mid1, mid1);
#else #else
float32x4_t vzero = vdupq_n_f32(0.f); float32x4_t vinc0 = vbslq_f32(vcgeq_f32(vsrc.val[0], vfzero), vfhalf, vfneg_half);
float32x4_t vfhalf = vdupq_n_f32(0.5f); float32x4_t vinc1 = vbslq_f32(vcgeq_f32(vsrc.val[1], vfzero), vfhalf, vfneg_half);
float32x4_t vfneg_half = vdupq_n_f32(-0.5f);
float32x4_t vinc0 = vbslq_f32(vcgeq_f32(vsrc.val[0], vzero), vfhalf, vfneg_half);
float32x4_t vinc1 = vbslq_f32(vcgeq_f32(vsrc.val[1], vzero), vfhalf, vfneg_half);
int32x4_t vres0 = vcvtq_s32_f32(vaddq_f32(vsrc.val[0], vinc0)); int32x4_t vres0 = vcvtq_s32_f32(vaddq_f32(vsrc.val[0], vinc0));
int32x4_t vres1 = vcvtq_s32_f32(vaddq_f32(vsrc.val[1], vinc1)); int32x4_t vres1 = vcvtq_s32_f32(vaddq_f32(vsrc.val[1], vinc1));
int8x8_t mid1 = vqmovn_s16(vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres1))); int8x8_t mid1 = vqmovn_s16(vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres1)));
return vcombine_s8(mid1, mid1); return vcombine_s8(mid1, mid1);
#endif #endif
#elif defined(GI_SSE42_INTRINSICS) #elif defined(GI_SSE42_INTRINSICS)
__m128 vfzero = _mm_set1_ps(0.f);
__m128 vfhalf = _mm_set1_ps(0.5f);
__m128 vfneg_half = _mm_set1_ps(-0.5f);
__m128 vfmin_int8 = _mm_set1_ps(-128.f);
__m128 vfmax_int8 = _mm_set1_ps(127.f);
__m128 vinc0 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(vsrc.val[0], vfzero)); __m128 vinc0 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(vsrc.val[0], vfzero));
__m128 vinc1 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(vsrc.val[1], vfzero)); __m128 vinc1 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(vsrc.val[1], vfzero));
...@@ -913,13 +881,13 @@ GI_INT8_t GiCvtFromFloat32V4ToInt8(GI_FLOAT32_V4_t vsrc) { ...@@ -913,13 +881,13 @@ GI_INT8_t GiCvtFromFloat32V4ToInt8(GI_FLOAT32_V4_t vsrc) {
int8x8_t mid2 = vqmovn_s16(vcombine_s16(vqmovn_s32(vres2), vqmovn_s32(vres3))); int8x8_t mid2 = vqmovn_s16(vcombine_s16(vqmovn_s32(vres2), vqmovn_s32(vres3)));
return vcombine_s8(mid1, mid2); return vcombine_s8(mid1, mid2);
#else #else
float32x4_t vzero = vdupq_n_f32(0.f); float32x4_t vfzero = vdupq_n_f32(0.f);
float32x4_t vfhalf = vdupq_n_f32(0.5f); float32x4_t vfhalf = vdupq_n_f32(0.5f);
float32x4_t vfneg_half = vdupq_n_f32(-0.5f); float32x4_t vfneg_half = vdupq_n_f32(-0.5f);
float32x4_t vinc0 = vbslq_f32(vcgeq_f32(vsrc.val[0], vzero), vfhalf, vfneg_half); float32x4_t vinc0 = vbslq_f32(vcgeq_f32(vsrc.val[0], vfzero), vfhalf, vfneg_half);
float32x4_t vinc1 = vbslq_f32(vcgeq_f32(vsrc.val[1], vzero), vfhalf, vfneg_half); float32x4_t vinc1 = vbslq_f32(vcgeq_f32(vsrc.val[1], vfzero), vfhalf, vfneg_half);
float32x4_t vinc2 = vbslq_f32(vcgeq_f32(vsrc.val[2], vzero), vfhalf, vfneg_half); float32x4_t vinc2 = vbslq_f32(vcgeq_f32(vsrc.val[2], vfzero), vfhalf, vfneg_half);
float32x4_t vinc3 = vbslq_f32(vcgeq_f32(vsrc.val[3], vzero), vfhalf, vfneg_half); float32x4_t vinc3 = vbslq_f32(vcgeq_f32(vsrc.val[3], vfzero), vfhalf, vfneg_half);
int32x4_t vres0 = vcvtq_s32_f32(vaddq_f32(vsrc.val[0], vinc0)); int32x4_t vres0 = vcvtq_s32_f32(vaddq_f32(vsrc.val[0], vinc0));
int32x4_t vres1 = vcvtq_s32_f32(vaddq_f32(vsrc.val[1], vinc1)); int32x4_t vres1 = vcvtq_s32_f32(vaddq_f32(vsrc.val[1], vinc1));
int32x4_t vres2 = vcvtq_s32_f32(vaddq_f32(vsrc.val[2], vinc2)); int32x4_t vres2 = vcvtq_s32_f32(vaddq_f32(vsrc.val[2], vinc2));
...@@ -929,12 +897,6 @@ GI_INT8_t GiCvtFromFloat32V4ToInt8(GI_FLOAT32_V4_t vsrc) { ...@@ -929,12 +897,6 @@ GI_INT8_t GiCvtFromFloat32V4ToInt8(GI_FLOAT32_V4_t vsrc) {
return vcombine_s8(mid1, mid2); return vcombine_s8(mid1, mid2);
#endif #endif
#elif defined(GI_SSE42_INTRINSICS) #elif defined(GI_SSE42_INTRINSICS)
__m128 vfzero = _mm_set1_ps(0.f);
__m128 vfhalf = _mm_set1_ps(0.5f);
__m128 vfneg_half = _mm_set1_ps(-0.5f);
__m128 vfmin_int8 = _mm_set1_ps(-128.f);
__m128 vfmax_int8 = _mm_set1_ps(127.f);
__m128 vinc0 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(vsrc.val[0], vfzero)); __m128 vinc0 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(vsrc.val[0], vfzero));
__m128 vinc1 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(vsrc.val[1], vfzero)); __m128 vinc1 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(vsrc.val[1], vfzero));
__m128 vinc2 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(vsrc.val[2], vfzero)); __m128 vinc2 = _mm_blendv_ps(vfneg_half, vfhalf, _mm_cmpge_ps(vsrc.val[2], vfzero));
......
...@@ -20,16 +20,6 @@ ...@@ -20,16 +20,6 @@
namespace megdnn { namespace megdnn {
namespace fallback { namespace fallback {
struct QConverterBase {
inline static GI_INT32_t vzero() { return GiBroadcastInt32(0); }
inline static GI_FLOAT32_t vfzero() { return GiBroadcastFloat32(0.f); }
inline static GI_FLOAT32_t vfhalf() { return GiBroadcastFloat32(0.5f); }
inline static GI_FLOAT32_t vfneg_half() { return GiBroadcastFloat32(-0.5f); }
};
struct QConverter { struct QConverter {
template <typename dst_type, typename... src_type> template <typename dst_type, typename... src_type>
static inline dst_type convert(const src_type&... src); static inline dst_type convert(const src_type&... src);
...@@ -66,6 +56,12 @@ template <> ...@@ -66,6 +56,12 @@ template <>
inline GI_INT8_t QConverter::convert(const GI_FLOAT32_V2_t& vsrc) { inline GI_INT8_t QConverter::convert(const GI_FLOAT32_V2_t& vsrc) {
return GiCvtFromFloat32V2ToInt8(vsrc); return GiCvtFromFloat32V2ToInt8(vsrc);
} }
// Convert four float32 vectors to one saturated int8 vector in a single
// shot (delegates rounding/saturation to GiCvtFromFloat32V4ToInt8).
template <>
inline GI_INT8_t QConverter::convert(const GI_FLOAT32_V4_t& vsrc) {
    return GiCvtFromFloat32V4ToInt8(vsrc);
}
template <> template <>
inline GI_INT8_t QConverter::convert(const GI_FLOAT32_t& src) { inline GI_INT8_t QConverter::convert(const GI_FLOAT32_t& src) {
return GiCvtFromFloat32ToInt8(src); return GiCvtFromFloat32ToInt8(src);
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#include "src/fallback/type_cvt/opr_impl.h" #include "src/fallback/type_cvt/opr_impl.h"
#include "src/fallback/type_cvt/typecvt_helper.h"
#include "midout.h" #include "midout.h"
#include "src/common/utils.h" #include "src/common/utils.h"
...@@ -17,6 +18,7 @@ ...@@ -17,6 +18,7 @@
// MIDOUT_DECL(megdnn_fb_typecvt_src) // MIDOUT_DECL(megdnn_fb_typecvt_src)
MIDOUT_DECL(megdnn_fb_typecvt_dst_dtype) MIDOUT_DECL(megdnn_fb_typecvt_dst_dtype)
MIDOUT_DECL(megdnn_fb_typecvt_src_dtype) MIDOUT_DECL(megdnn_fb_typecvt_src_dtype)
MIDOUT_DECL(megdnn_fb_typecvt_optimized)
namespace { namespace {
...@@ -513,12 +515,68 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { ...@@ -513,12 +515,68 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
!is_quantize_lowbit(dst.layout.dtype) && !is_quantize_lowbit(dst.layout.dtype) &&
dst.layout.dtype.enumv() != DTypeEnum::QuantizedS1 && dst.layout.dtype.enumv() != DTypeEnum::QuantizedS1 &&
src.layout.dtype.enumv() != DTypeEnum::QuantizedS1) { src.layout.dtype.enumv() != DTypeEnum::QuantizedS1) {
if (!exec_optimized(src, dst)) {
MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst)); MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst));
}
} else { } else {
naive::TypeCvtImpl::exec(src, dst); naive::TypeCvtImpl::exec(src, dst);
} }
} }
// Fast path for contiguous TypeCvt: dispatches to the SIMD (general
// intrinsic) type-converter kernels for a fixed set of (src, dst) dtype
// pairs. Returns true iff an optimized kernel was dispatched; exec() falls
// back to the generic contiguous implementation when this returns false.
bool TypeCvtImpl::exec_optimized(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
    DType src_dtype = src.layout.dtype;
    DType dst_dtype = dst.layout.dtype;
    bool execed = false;
    using namespace dtype;
    size_t nr_elems = src.layout.total_nr_elems();
// Quantized <-> quantized conversions; the requantization scale is handled
// inside QuantizedTypeCvter. Each dtype pair gets its own midout case id.
#define DISPATCH_QUANTIZED(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
    if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv &&                    \
        dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) {                    \
        MIDOUT_BEGIN(megdnn_fb_typecvt_optimized, midout_iv(_midout_iv)) {         \
            using _TypeCvter = QuantizedTypeCvter<_stype, _dtype>;                 \
            MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>(                   \
                    src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(),    \
                    src_dtype, dst_dtype, nr_elems));                              \
            execed = true;                                                         \
        }                                                                          \
        MIDOUT_END();                                                              \
    }

    DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS8, int8_t, 1);
    DISPATCH_QUANTIZED(QuantizedS8, int8_t, QuantizedS32, int32_t, 2);
    DISPATCH_QUANTIZED(QuantizedS8, int8_t, QuantizedS8, int8_t, 3);
    DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS32, int32_t, 4);
    // NOTE(review): `float` is used as the dtype tag here, relying on
    // DTypeTrait<float> resolving to Float32 — confirm against the
    // project's DTypeTrait specializations.
    DISPATCH_QUANTIZED(float, float, QuantizedS8, int8_t, 5);
#undef DISPATCH_QUANTIZED

// Fixed-point integer -> float widening conversions (no scale involved).
#define DISPATCH_FIX2FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
    if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv &&                    \
        dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) {                    \
        MIDOUT_BEGIN(megdnn_fb_typecvt_optimized, midout_iv(_midout_iv)) {         \
            using _TypeCvter = Fix2FloatTypeCvter<_stype, _dtype>;                 \
            MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>(                   \
                    src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(),    \
                    src_dtype, dst_dtype, src.layout.total_nr_elems()));           \
            execed = true;                                                         \
        }                                                                          \
        MIDOUT_END();                                                              \
    }

    DISPATCH_FIX2FLOAT(Int16, int16_t, Float32, float, 6);
    DISPATCH_FIX2FLOAT(Int8, int8_t, Float32, float, 7);

    // QuantizedS8 -> Float32 (dequantize) uses a different cvter template,
    // so it is dispatched without the macro.
    if (src_dtype.enumv() == DTypeTrait<QuantizedS8>::enumv &&
        dst_dtype.enumv() == DTypeTrait<Float32>::enumv) {
        MIDOUT_BEGIN(megdnn_fb_typecvt_optimized, midout_iv(8)) {
            using TypeCvter = Quan2FloatTypeCvter<int8_t, float>;
            MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<TypeCvter>(
                    src.compatible_ptr<int8_t>(), dst.compatible_ptr<float>(),
                    src_dtype, dst_dtype, src.layout.total_nr_elems()));
            execed = true;
        }
        MIDOUT_END();
    }
    return execed;
}
} // namespace fallback } // namespace fallback
} // namespace megdnn } // namespace megdnn
......
...@@ -15,6 +15,8 @@ namespace megdnn { ...@@ -15,6 +15,8 @@ namespace megdnn {
namespace fallback { namespace fallback {
class TypeCvtImpl : public naive::TypeCvtImpl { class TypeCvtImpl : public naive::TypeCvtImpl {
bool exec_optimized(_megdnn_tensor_in src, _megdnn_tensor_out dst);
public: public:
using naive::TypeCvtImpl::TypeCvtImpl; using naive::TypeCvtImpl::TypeCvtImpl;
void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) override; void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) override;
......
/**
* \file dnn/src/fallback/type_cvt/typecvt_helper.h
*/
#include "src/fallback/general_intrinsic/gi_float.h"
#include "src/fallback/general_intrinsic/gi_int.h"
#include "src/fallback/quantized_converter.h"
namespace megdnn {
namespace fallback {
template <typename ctype, typename dtype>
struct QuantizedTypeCvter;
template <>
struct QuantizedTypeCvter<int32_t, int8_t> {
using stype = int32_t;
using dst_type = int8_t;
static constexpr size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t) * 2;
static constexpr size_t SIMD_STEP = GI_SIMD_LEN_BYTE / sizeof(int32_t);
float scale;
GI_FLOAT32_t vscale;
QuantizedTypeCvter(DType src_dtype, DType dst_dtype) {
float src_scale = src_dtype.param<dtype::QuantizedS32>().scale;
float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale;
scale = src_scale / dst_scale;
vscale = GiBroadcastFloat32(scale);
}
void cvt(const int32_t* src, int8_t* dst) {
GI_FLOAT32_t vitem0 =
GiMultiplyFloat32(GiCastToFloat32(GiLoadInt32(src)), vscale);
GI_FLOAT32_t vitem1 = GiMultiplyFloat32(
GiCastToFloat32(GiLoadInt32(src + SIMD_STEP)), vscale);
auto vres = QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
GiStoreLowInt8(dst, vres);
}
void cvt_remain(const int32_t* src, int8_t* dst) {
*dst = saturate<int8_t, float>(std::round(*src * scale), -128.f, 127.f);
}
};
//! Requantize QuantizedS8 -> QuantizedS32.
template <>
struct QuantizedTypeCvter<int8_t, int32_t> {
    using stype = int8_t;
    using dst_type = int32_t;
    // one full int8 vector is widened into four int32 vectors per cvt()
    static constexpr size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
    float scale;
    GI_FLOAT32_t vscale;
    // requantization factor is src_scale / dst_scale
    QuantizedTypeCvter(DType src_dtype, DType dst_dtype) {
        float src_scale = src_dtype.param<dtype::QuantizedS8>().scale;
        float dst_scale = dst_dtype.param<dtype::QuantizedS32>().scale;
        scale = src_scale / dst_scale;
        vscale = GiBroadcastFloat32(scale);
    }
    // widen int8 -> int16 -> int32, rescale in float, round back to int32
    void cvt(const int8_t* src, int32_t* dst) {
        GI_INT8_t data = GiLoadInt8(src);
        GI_INT16_t vitem0 = GiMoveLowLongInt8(data);
        GI_INT16_t vitem1 = GiMoveHighLongInt8(data);
        auto vret0 = QConverter::round<GI_INT32_t, GI_FLOAT32_t>(
                GiMultiplyFloat32(GiCastToFloat32(GiMoveLowLongInt16(vitem0)), vscale));
        auto vret1 = QConverter::round<GI_INT32_t, GI_FLOAT32_t>(GiMultiplyFloat32(
                GiCastToFloat32(GiMoveHighLongInt16(vitem0)), vscale));
        auto vret2 = QConverter::round<GI_INT32_t, GI_FLOAT32_t>(
                GiMultiplyFloat32(GiCastToFloat32(GiMoveLowLongInt16(vitem1)), vscale));
        auto vret3 = QConverter::round<GI_INT32_t, GI_FLOAT32_t>(GiMultiplyFloat32(
                GiCastToFloat32(GiMoveHighLongInt16(vitem1)), vscale));
        constexpr size_t step = GI_SIMD_LEN_BYTE / sizeof(int32_t);
        GiStoreInt32(dst, vret0);
        GiStoreInt32(dst + step, vret1);
        GiStoreInt32(dst + 2 * step, vret2);
        GiStoreInt32(dst + 3 * step, vret3);
    }
    // scalar tail.
    // NOTE(review): 2147483647.f is not exactly representable in float (it
    // rounds to 2147483648.f), so the upper clamp can still yield a value
    // that overflows int32_t on conversion — confirm `saturate`'s behavior
    // at this bound. In practice an int8 source times a sane scale stays
    // far below this limit.
    void cvt_remain(const int8_t* src, int32_t* dst) {
        *dst = saturate<int32_t, float>(
                std::round(*src * scale), -2147483648.f, 2147483647.f);
    }
};
template <>
struct QuantizedTypeCvter<float, int8_t> {
using stype = float;
using dst_type = int8_t;
static constexpr size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(float) * 2;
static constexpr size_t SIMD_STEP = GI_SIMD_LEN_BYTE / sizeof(float);
float scale;
GI_FLOAT32_t vscale;
QuantizedTypeCvter(DType src_dtype, DType dst_dtype) {
MEGDNN_MARK_USED_VAR(src_dtype);
float src_scale = 1;
float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale;
scale = src_scale / dst_scale;
vscale = GiBroadcastFloat32(scale);
}
void cvt(const float* src, int8_t* dst) {
GI_FLOAT32_t vitem0 = GiMultiplyFloat32(GiLoadFloat32(src), vscale);
GI_FLOAT32_t vitem1 = GiMultiplyFloat32(GiLoadFloat32(src + SIMD_STEP), vscale);
auto vres = QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
GiStoreLowInt8(dst, vres);
}
void cvt_remain(const float* src, int8_t* dst) {
*dst = saturate<int8_t, float>(std::round(*src * scale), -128.f, 127.f);
}
};
//! Requantize QuantizedS32 -> QuantizedS32 (scale change only).
template <>
struct QuantizedTypeCvter<int32_t, int32_t> {
    using stype = int32_t;
    using dst_type = int32_t;
    static constexpr size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t);
    float scale;
    GI_FLOAT32_t vscale;
    // requantization factor is src_scale / dst_scale
    QuantizedTypeCvter(DType src_dtype, DType dst_dtype) {
        float src_scale = src_dtype.param<dtype::QuantizedS32>().scale;
        float dst_scale = dst_dtype.param<dtype::QuantizedS32>().scale;
        scale = src_scale / dst_scale;
        vscale = GiBroadcastFloat32(scale);
    }
    // rescale one int32 vector in float and round back to int32
    void cvt(const int32_t* src, int32_t* dst) {
        GI_FLOAT32_t vitem =
                GiMultiplyFloat32(GiCastToFloat32(GiLoadInt32(src)), vscale);
        auto vres = QConverter::round<GI_INT32_t, GI_FLOAT32_t>(vitem);
        GiStoreInt32(dst, vres);
    }
    // scalar tail.
    // NOTE(review): 2147483647.f rounds to 2147483648.f in float, so the
    // upper clamp can still produce a value that overflows int32_t on
    // conversion — confirm `saturate`'s behavior at this bound.
    void cvt_remain(const int32_t* src, int32_t* dst) {
        *dst = saturate<int32_t, float>(
                std::round(*src * scale), -2147483648.f, 2147483647.f);
    }
};
template <>
struct QuantizedTypeCvter<int8_t, int8_t> {
using stype = int8_t;
using dst_type = int8_t;
static constexpr size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
float scale;
GI_FLOAT32_t vscale;
QuantizedTypeCvter(DType src_dtype, DType dst_dtype) {
float src_scale = src_dtype.param<dtype::QuantizedS8>().scale;
float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale;
scale = src_scale / dst_scale;
vscale = GiBroadcastFloat32(scale);
}
void cvt(const int8_t* src, int8_t* dst) {
GI_INT8_t data = GiLoadInt8(src);
GI_INT16_t vitem0 = GiMoveLowLongInt8(data);
GI_INT16_t vitem1 = GiMoveHighLongInt8(data);
auto vret0 =
GiMultiplyFloat32(GiCastToFloat32(GiMoveLowLongInt16(vitem0)), vscale);
auto vret1 =
GiMultiplyFloat32(GiCastToFloat32(GiMoveHighLongInt16(vitem0)), vscale);
auto vret2 =
GiMultiplyFloat32(GiCastToFloat32(GiMoveLowLongInt16(vitem1)), vscale);
auto vret3 =
GiMultiplyFloat32(GiCastToFloat32(GiMoveHighLongInt16(vitem1)), vscale);
auto vres = QConverter::convert<GI_INT8_t, GI_FLOAT32_V4_t>(
{{vret0, vret1, vret2, vret3}});
GiStoreInt8(dst, vres);
}
void cvt_remain(const int8_t* src, int8_t* dst) {
*dst = saturate<int8_t, float>(std::round(*src * scale), -128.f, 127.f);
}
};
// Converts fixed-point integers to float with a plain widening cast (no scale).
template <typename ctype, typename dtype>
struct Fix2FloatTypeCvter;
// Converts quantized integers to float, multiplying by the source dtype scale.
template <typename ctype, typename dtype>
struct Quan2FloatTypeCvter;
template <>
struct Fix2FloatTypeCvter<int16_t, float> {
using stype = int16_t;
using dst_type = float;
static constexpr size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int16_t);
static constexpr size_t SIMD_STEP = GI_SIMD_LEN_BYTE / sizeof(float);
Fix2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
MEGDNN_MARK_USED_VAR(src_dtype);
MEGDNN_MARK_USED_VAR(dst_dtype);
}
void cvt(const int16_t* src, float* dst) {
GI_INT16_t vitem = GiLoadInt16(src);
auto vret0 = GiCastToFloat32(GiMoveLowLongInt16(vitem));
auto vret1 = GiCastToFloat32(GiMoveHighLongInt16(vitem));
GiStoreFloat32(dst, vret0);
GiStoreFloat32(dst + SIMD_STEP, vret1);
}
void cvt_remain(const int16_t* src, float* dst) { *dst = *src; }
};
template <>
struct Fix2FloatTypeCvter<int8_t, float> {
using stype = int8_t;
using dst_type = float;
static constexpr size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
static constexpr size_t SIMD_STEP = GI_SIMD_LEN_BYTE / sizeof(float);
Fix2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
MEGDNN_MARK_USED_VAR(src_dtype);
MEGDNN_MARK_USED_VAR(dst_dtype);
}
void cvt(const int8_t* src, float* dst) {
GI_INT8_t data = GiLoadInt8(src);
GI_INT16_t vitem0 = GiMoveLowLongInt8(data);
GI_INT16_t vitem1 = GiMoveHighLongInt8(data);
auto vret0 = GiCastToFloat32(GiMoveLowLongInt16(vitem0));
auto vret1 = GiCastToFloat32(GiMoveHighLongInt16(vitem0));
auto vret2 = GiCastToFloat32(GiMoveLowLongInt16(vitem1));
auto vret3 = GiCastToFloat32(GiMoveHighLongInt16(vitem1));
GiStoreFloat32(dst, vret0);
GiStoreFloat32(dst + SIMD_STEP, vret1);
GiStoreFloat32(dst + 2 * SIMD_STEP, vret2);
GiStoreFloat32(dst + 3 * SIMD_STEP, vret3);
}
void cvt_remain(const int8_t* src, float* dst) { *dst = *src; }
};
template <>
struct Quan2FloatTypeCvter<int8_t, float> {
using stype = int8_t;
using dst_type = float;
static constexpr size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
static constexpr size_t SIMD_STEP = GI_SIMD_LEN_BYTE / sizeof(float);
float _scale = 0.0f;
GI_FLOAT32_t vscale;
Quan2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
_scale = src_dtype.param<dtype::QuantizedS8>().scale;
vscale = GiBroadcastFloat32(_scale);
MEGDNN_MARK_USED_VAR(dst_dtype);
}
void cvt(const int8_t* src, float* dst) {
GI_INT8_t data = GiLoadInt8(src);
GI_INT16_t vitem0 = GiMoveLowLongInt8(data);
GI_INT16_t vitem1 = GiMoveHighLongInt8(data);
auto vret0 =
GiMultiplyFloat32(GiCastToFloat32(GiMoveLowLongInt16(vitem0)), vscale);
auto vret1 =
GiMultiplyFloat32(GiCastToFloat32(GiMoveHighLongInt16(vitem0)), vscale);
auto vret2 =
GiMultiplyFloat32(GiCastToFloat32(GiMoveLowLongInt16(vitem1)), vscale);
auto vret3 =
GiMultiplyFloat32(GiCastToFloat32(GiMoveHighLongInt16(vitem1)), vscale);
GiStoreFloat32(dst, vret0);
GiStoreFloat32(dst + SIMD_STEP, vret1);
GiStoreFloat32(dst + 2 * SIMD_STEP, vret2);
GiStoreFloat32(dst + 3 * SIMD_STEP, vret3);
}
void cvt_remain(const int8_t* src, float* dst) { *dst = *src * _scale; }
};
//! Drive a TypeCvter over nr_elems elements: full SIMD_WIDTH batches go
//! through the vectorized cvt(), the remainder through the scalar
//! cvt_remain() path.
template <typename TypeCvter>
void do_typecvt(
        const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst,
        DType src_dtype, DType dst_dtype, size_t nr_elems) {
    TypeCvter typecvt(src_dtype, dst_dtype);
    size_t done = 0;
    //! vectorized main loop: whole SIMD_WIDTH batches only
    while (done + TypeCvter::SIMD_WIDTH <= nr_elems) {
        typecvt.cvt(src, dst);
        src += TypeCvter::SIMD_WIDTH;
        dst += TypeCvter::SIMD_WIDTH;
        done += TypeCvter::SIMD_WIDTH;
    }
    //! scalar tail loop
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
    for (; done < nr_elems; done++) {
        typecvt.cvt_remain(src, dst);
        src++;
        dst++;
    }
}
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册