Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
39d98d45
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
39d98d45
编写于
3月 09, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(fallback): add fallback typecvt with general intrinsic
GitOrigin-RevId: 1e6fcd929b02e4745b4a641e5101df8f624d4bea
上级
d2278f02
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
485 addition
and
136 deletion
+485
-136
dnn/src/fallback/elemwise_helper/kimpl/relu.h
dnn/src/fallback/elemwise_helper/kimpl/relu.h
+40
-41
dnn/src/fallback/general_intrinsic/gi_common.h
dnn/src/fallback/general_intrinsic/gi_common.h
+53
-0
dnn/src/fallback/general_intrinsic/gi_float.h
dnn/src/fallback/general_intrinsic/gi_float.h
+3
-26
dnn/src/fallback/general_intrinsic/gi_int.h
dnn/src/fallback/general_intrinsic/gi_int.h
+20
-58
dnn/src/fallback/quantized_converter.h
dnn/src/fallback/quantized_converter.h
+6
-10
dnn/src/fallback/type_cvt/opr_impl.cpp
dnn/src/fallback/type_cvt/opr_impl.cpp
+59
-1
dnn/src/fallback/type_cvt/opr_impl.h
dnn/src/fallback/type_cvt/opr_impl.h
+2
-0
dnn/src/fallback/type_cvt/typecvt_helper.h
dnn/src/fallback/type_cvt/typecvt_helper.h
+302
-0
未找到文件。
dnn/src/fallback/elemwise_helper/kimpl/relu.h
浏览文件 @
39d98d45
...
@@ -20,7 +20,7 @@ struct ReluOpBase : UnaryOpBase<src_ctype, dst_ctype> {
...
@@ -20,7 +20,7 @@ struct ReluOpBase : UnaryOpBase<src_ctype, dst_ctype> {
template
<
typename
src_ctype
,
typename
dst_type
=
src_ctype
>
template
<
typename
src_ctype
,
typename
dst_type
=
src_ctype
>
struct
ReluOp
;
struct
ReluOp
;
#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width
, zero
) \
template <> \
template <> \
struct ReluOp<_ctype> : ReluOpBase<_ctype> { \
struct ReluOp<_ctype> : ReluOpBase<_ctype> { \
using ReluOpBase::ReluOpBase; \
using ReluOpBase::ReluOpBase; \
...
@@ -32,9 +32,8 @@ struct ReluOp;
...
@@ -32,9 +32,8 @@ struct ReluOp;
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
} \
_simd_type2 operator()(const _simd_type2& src) const { \
_simd_type2 operator()(const _simd_type2& src) const { \
auto vzero = GiBroadcast##_func_suffix(0); \
auto vitem0 = GiMaximum##_func_suffix(src.val[0], zero); \
auto vitem0 = GiMaximum##_func_suffix(src.val[0], vzero); \
auto vitem1 = GiMaximum##_func_suffix(src.val[1], zero); \
auto vitem1 = GiMaximum##_func_suffix(src.val[1], vzero); \
return {{vitem0, vitem1}}; \
return {{vitem0, vitem1}}; \
} \
} \
void operator()(const _simd_type& src, _ctype* dst) const { \
void operator()(const _simd_type& src, _ctype* dst) const { \
...
@@ -42,14 +41,16 @@ struct ReluOp;
...
@@ -42,14 +41,16 @@ struct ReluOp;
GiStore##_func_suffix(dst, vitem); \
GiStore##_func_suffix(dst, vitem); \
} \
} \
_simd_type operator()(const _simd_type& src) const { \
_simd_type operator()(const _simd_type& src) const { \
auto vzero = GiBroadcast##_func_suffix(0); \
return GiMaximum##_func_suffix(src, zero); \
return GiMaximum##_func_suffix(src, vzero); \
} \
} \
};
};
OP
(
dt_float32
,
GI_FLOAT32_t
,
GI_FLOAT32_V2_t
,
Float32
,
GI_SIMD_LEN_BYTE
/
sizeof
(
float
))
OP
(
dt_float32
,
GI_FLOAT32_t
,
GI_FLOAT32_V2_t
,
Float32
,
GI_SIMD_LEN_BYTE
/
sizeof
(
float
),
OP
(
dt_int32
,
GI_INT32_t
,
GI_INT32_V2_t
,
Int32
,
GI_SIMD_LEN_BYTE
/
sizeof
(
int32_t
))
vfzero
)
OP
(
dt_int8
,
GI_INT8_t
,
GI_INT8_V2_t
,
Int8
,
GI_SIMD_LEN_BYTE
/
sizeof
(
int8_t
))
OP
(
dt_int32
,
GI_INT32_t
,
GI_INT32_V2_t
,
Int32
,
GI_SIMD_LEN_BYTE
/
sizeof
(
int32_t
),
vzero
)
OP
(
dt_int8
,
GI_INT8_t
,
GI_INT8_V2_t
,
Int8
,
GI_SIMD_LEN_BYTE
/
sizeof
(
int8_t
),
vzero_int8
)
#undef OP
#undef OP
template
<
>
template
<
>
...
@@ -75,11 +76,10 @@ struct ReluOp<dt_qint8, dt_qint8> : ReluOpBase<dt_qint8, dt_qint8> {
...
@@ -75,11 +76,10 @@ struct ReluOp<dt_qint8, dt_qint8> : ReluOpBase<dt_qint8, dt_qint8> {
OPERATOR_UNARY_QINT8_FALLBACK
;
OPERATOR_UNARY_QINT8_FALLBACK
;
}
}
GI_INT8_t
operator
()(
const
GI_INT32_V2_t
&
vsrc
)
const
{
GI_INT8_t
operator
()(
const
GI_INT32_V2_t
&
vsrc
)
const
{
auto
vzero
=
GiBroadcastFloat32
(
0.
f
);
auto
vitem0
=
GiMultiplyFloat32
(
GiCastToFloat32
(
vsrc
.
val
[
0
]),
this
->
vscale
);
auto
vitem0
=
GiMultiplyFloat32
(
GiCastToFloat32
(
vsrc
.
val
[
0
]),
this
->
vscale
);
auto
vitem1
=
GiMultiplyFloat32
(
GiCastToFloat32
(
vsrc
.
val
[
1
]),
this
->
vscale
);
auto
vitem1
=
GiMultiplyFloat32
(
GiCastToFloat32
(
vsrc
.
val
[
1
]),
this
->
vscale
);
vitem0
=
GiMaximumFloat32
(
vitem0
,
vzero
);
vitem0
=
GiMaximumFloat32
(
vitem0
,
v
f
zero
);
vitem1
=
GiMaximumFloat32
(
vitem1
,
vzero
);
vitem1
=
GiMaximumFloat32
(
vitem1
,
v
f
zero
);
return
QConverter
::
convert
<
GI_INT8_t
,
GI_FLOAT32_V2_t
>
({{
vitem0
,
vitem1
}});
return
QConverter
::
convert
<
GI_INT8_t
,
GI_FLOAT32_V2_t
>
({{
vitem0
,
vitem1
}});
}
}
};
};
...
@@ -114,12 +114,11 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase
...
@@ -114,12 +114,11 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase
void
operator
()(
const
int32x4x2_t
&
vsrc
,
dt_qint8
*
dst
)
const
{
void
operator
()(
const
int32x4x2_t
&
vsrc
,
dt_qint8
*
dst
)
const
{
vst1_s8
(
reinterpret_cast
<
int8_t
*>
(
dst
),
vget_low_s8
(
operator
()(
vsrc
)));
vst1_s8
(
reinterpret_cast
<
int8_t
*>
(
dst
),
vget_low_s8
(
operator
()(
vsrc
)));
}
}
int8x16_t
operator
()(
const
int32x4x2_t
&
vsrc
)
const
{
int8x16_t
operator
()(
const
int32x4x2_t
&
vsrc
)
const
{
int32x4_t
vitem0
=
vqrdmulhq_s32
(
vsrc
.
val
[
0
],
vmultiplier
);
int32x4_t
vitem0
=
vqrdmulhq_s32
(
vsrc
.
val
[
0
],
vmultiplier
);
int32x4_t
vitem1
=
vqrdmulhq_s32
(
vsrc
.
val
[
1
],
vmultiplier
);
int32x4_t
vitem1
=
vqrdmulhq_s32
(
vsrc
.
val
[
1
],
vmultiplier
);
vitem0
=
vmaxq_s32
(
vitem0
,
QConverterBase
::
vzero
()
);
vitem0
=
vmaxq_s32
(
vitem0
,
vzero
);
vitem1
=
vmaxq_s32
(
vitem1
,
QConverterBase
::
vzero
()
);
vitem1
=
vmaxq_s32
(
vitem1
,
vzero
);
auto
tmp
=
vqmovn_s16
(
vcombine_s16
(
auto
tmp
=
vqmovn_s16
(
vcombine_s16
(
vqmovn_s32
(
vrshlq_s32
(
vitem0
,
vshift
)),
vqmovn_s32
(
vrshlq_s32
(
vitem0
,
vshift
)),
vqmovn_s32
(
vrshlq_s32
(
vitem1
,
vshift
))));
vqmovn_s32
(
vrshlq_s32
(
vitem1
,
vshift
))));
...
@@ -127,7 +126,7 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase
...
@@ -127,7 +126,7 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase
}
}
int8x16_t
operator
()(
const
float32x4_t
&
vsrc
)
const
{
int8x16_t
operator
()(
const
float32x4_t
&
vsrc
)
const
{
int32x4_t
vitem0
=
vqrdmulhq_s32
(
vcvtq_s32_f32
(
vsrc
),
vmultiplier
);
int32x4_t
vitem0
=
vqrdmulhq_s32
(
vcvtq_s32_f32
(
vsrc
),
vmultiplier
);
vitem0
=
vmaxq_s32
(
vitem0
,
QConverterBase
::
vzero
()
);
vitem0
=
vmaxq_s32
(
vitem0
,
vzero
);
vitem0
=
vrshlq_s32
(
vitem0
,
vshift
);
vitem0
=
vrshlq_s32
(
vitem0
,
vshift
);
int16x4_t
vitem
=
vqmovn_s32
(
vitem0
);
int16x4_t
vitem
=
vqmovn_s32
(
vitem0
);
auto
tmp
=
vqmovn_s16
(
vcombine_s16
(
vitem
,
vitem
));
auto
tmp
=
vqmovn_s16
(
vcombine_s16
(
vitem
,
vitem
));
...
@@ -135,13 +134,13 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase
...
@@ -135,13 +134,13 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase
}
}
void
operator
()(
const
int32x4_t
&
src
,
dt_qint8
*
dst
)
const
{
void
operator
()(
const
int32x4_t
&
src
,
dt_qint8
*
dst
)
const
{
auto
vitem0
=
vmulq_f32
(
vcvtq_f32_s32
(
src
),
this
->
vscale
);
auto
vitem0
=
vmulq_f32
(
vcvtq_f32_s32
(
src
),
this
->
vscale
);
vitem0
=
vmaxq_f32
(
vitem0
,
QConverterBase
::
vfzero
()
);
vitem0
=
vmaxq_f32
(
vitem0
,
vfzero
);
auto
result
=
QConverter
::
convert
<
int8x16_t
,
float32x4_t
>
(
vitem0
);
auto
result
=
QConverter
::
convert
<
int8x16_t
,
float32x4_t
>
(
vitem0
);
vst1q_lane_s32
(
reinterpret_cast
<
int32_t
*>
(
dst
),
(
int32x4_t
)
result
,
0
);
vst1q_lane_s32
(
reinterpret_cast
<
int32_t
*>
(
dst
),
(
int32x4_t
)
result
,
0
);
}
}
void
operator
()(
const
float32x4_t
&
src
,
dt_qint8
*
dst
)
const
{
void
operator
()(
const
float32x4_t
&
src
,
dt_qint8
*
dst
)
const
{
auto
vitem0
=
vmulq_f32
(
src
,
this
->
vscale
);
auto
vitem0
=
vmulq_f32
(
src
,
this
->
vscale
);
vitem0
=
vmaxq_f32
(
vitem0
,
QConverterBase
::
vfzero
()
);
vitem0
=
vmaxq_f32
(
vitem0
,
vfzero
);
auto
result
=
QConverter
::
convert
<
int8x16_t
,
float32x4_t
>
(
vitem0
);
auto
result
=
QConverter
::
convert
<
int8x16_t
,
float32x4_t
>
(
vitem0
);
vst1q_lane_s32
(
reinterpret_cast
<
int32_t
*>
(
dst
),
(
int32x4_t
)
result
,
0
);
vst1q_lane_s32
(
reinterpret_cast
<
int32_t
*>
(
dst
),
(
int32x4_t
)
result
,
0
);
}
}
...
@@ -165,19 +164,19 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8> {
...
@@ -165,19 +164,19 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8> {
GI_INT8_t
operator
()(
const
GI_INT32_V2_t
&
vsrc
)
const
{
GI_INT8_t
operator
()(
const
GI_INT32_V2_t
&
vsrc
)
const
{
auto
vitem0
=
GiMultiplyFloat32
(
GiCastToFloat32
(
vsrc
.
val
[
0
]),
this
->
vscale
);
auto
vitem0
=
GiMultiplyFloat32
(
GiCastToFloat32
(
vsrc
.
val
[
0
]),
this
->
vscale
);
auto
vitem1
=
GiMultiplyFloat32
(
GiCastToFloat32
(
vsrc
.
val
[
1
]),
this
->
vscale
);
auto
vitem1
=
GiMultiplyFloat32
(
GiCastToFloat32
(
vsrc
.
val
[
1
]),
this
->
vscale
);
vitem0
=
GiMaximumFloat32
(
vitem0
,
QConverterBase
::
vfzero
()
);
vitem0
=
GiMaximumFloat32
(
vitem0
,
vfzero
);
vitem1
=
GiMaximumFloat32
(
vitem1
,
QConverterBase
::
vfzero
()
);
vitem1
=
GiMaximumFloat32
(
vitem1
,
vfzero
);
return
QConverter
::
convert
<
GI_INT8_t
,
GI_FLOAT32_V2_t
>
({{
vitem0
,
vitem1
}});
return
QConverter
::
convert
<
GI_INT8_t
,
GI_FLOAT32_V2_t
>
({{
vitem0
,
vitem1
}});
}
}
GI_INT8_t
operator
()(
const
GI_INT32_t
&
src
)
const
{
GI_INT8_t
operator
()(
const
GI_INT32_t
&
src
)
const
{
auto
vitem0
=
GiMultiplyFloat32
(
GiCastToFloat32
(
src
),
this
->
vscale
);
auto
vitem0
=
GiMultiplyFloat32
(
GiCastToFloat32
(
src
),
this
->
vscale
);
vitem0
=
GiMaximumFloat32
(
vitem0
,
QConverterBase
::
vfzero
()
);
vitem0
=
GiMaximumFloat32
(
vitem0
,
vfzero
);
return
QConverter
::
convert
<
GI_INT8_t
,
GI_FLOAT32_t
>
(
vitem0
);
return
QConverter
::
convert
<
GI_INT8_t
,
GI_FLOAT32_t
>
(
vitem0
);
}
}
GI_INT8_t
operator
()(
const
GI_FLOAT32_t
&
src
)
const
{
GI_INT8_t
operator
()(
const
GI_FLOAT32_t
&
src
)
const
{
auto
vitem0
=
GiMultiplyFloat32
(
src
,
this
->
vscale
);
auto
vitem0
=
GiMultiplyFloat32
(
src
,
this
->
vscale
);
vitem0
=
GiMaximumFloat32
(
vitem0
,
QConverterBase
::
vfzero
()
);
vitem0
=
GiMaximumFloat32
(
vitem0
,
vfzero
);
return
QConverter
::
convert
<
GI_INT8_t
,
GI_FLOAT32_t
>
(
vitem0
);
return
QConverter
::
convert
<
GI_INT8_t
,
GI_FLOAT32_t
>
(
vitem0
);
}
}
};
};
...
...
dnn/src/fallback/general_intrinsic/gi_common.h
浏览文件 @
39d98d45
...
@@ -213,4 +213,57 @@ GI_INT32_t GiXorInt32(GI_INT32_t Vector1, GI_INT32_t Vector2) {
...
@@ -213,4 +213,57 @@ GI_INT32_t GiXorInt32(GI_INT32_t Vector1, GI_INT32_t Vector2) {
#endif
#endif
}
}
GI_FORCEINLINE
GI_FLOAT32_t
GiBroadcastFloat32
(
float
Value
)
{
#if defined(GI_NEON_INTRINSICS)
return
vdupq_n_f32
(
Value
);
#elif defined(GI_SSE2_INTRINSICS)
return
_mm_set1_ps
(
Value
);
#else
GI_FLOAT32_t
ret
;
for
(
size_t
i
=
0
;
i
<
GI_SIMD_LEN_BYTE
/
sizeof
(
float
);
i
++
)
{
ret
[
i
]
=
Value
;
}
return
ret
;
#endif
}
GI_FORCEINLINE
GI_INT32_t
GiBroadcastInt32
(
int32_t
Value
)
{
#if defined(GI_NEON_INTRINSICS)
return
vdupq_n_s32
(
Value
);
#elif defined(GI_SSE2_INTRINSICS)
return
_mm_set1_epi32
(
Value
);
#else
GI_INT32_t
ret
;
for
(
size_t
i
=
0
;
i
<
GI_SIMD_LEN_BYTE
/
sizeof
(
int32_t
);
i
++
)
{
ret
[
i
]
=
Value
;
}
return
ret
;
#endif
}
GI_FORCEINLINE
GI_INT8_t
GiBroadcastInt8
(
int8_t
Value
)
{
#if defined(GI_NEON_INTRINSICS)
return
vdupq_n_s8
(
Value
);
#elif defined(GI_SSE2_INTRINSICS)
return
_mm_set1_epi8
(
Value
);
#else
GI_INT8_t
ret
;
for
(
size_t
i
=
0
;
i
<
GI_SIMD_LEN_BYTE
/
sizeof
(
int8_t
);
i
++
)
{
ret
[
i
]
=
Value
;
}
return
ret
;
#endif
}
__attribute__
((
unused
))
const
GI_INT8_t
vzero_int8
=
GiBroadcastInt8
(
0
);
__attribute__
((
unused
))
const
GI_INT32_t
vzero
=
GiBroadcastInt32
(
0
);
__attribute__
((
unused
))
const
GI_FLOAT32_t
vfzero
=
GiBroadcastFloat32
(
0.0
f
);
__attribute__
((
unused
))
const
GI_FLOAT32_t
vfhalf
=
GiBroadcastFloat32
(
0.5
f
);
__attribute__
((
unused
))
const
GI_FLOAT32_t
vfneg_half
=
GiBroadcastFloat32
(
-
0.5
f
);
__attribute__
((
unused
))
const
GI_FLOAT32_t
vfmin_int8
=
GiBroadcastFloat32
(
-
128.0
f
);
__attribute__
((
unused
))
const
GI_FLOAT32_t
vfmax_int8
=
GiBroadcastFloat32
(
127.0
f
);
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
dnn/src/fallback/general_intrinsic/gi_float.h
浏览文件 @
39d98d45
...
@@ -71,20 +71,12 @@ GI_INT32_t GiRoundAsInt32(GI_FLOAT32_t Vector) {
...
@@ -71,20 +71,12 @@ GI_INT32_t GiRoundAsInt32(GI_FLOAT32_t Vector) {
#if __ARM_ARCH >= 8
#if __ARM_ARCH >= 8
return
vcvtaq_s32_f32
(
Vector
);
return
vcvtaq_s32_f32
(
Vector
);
#else
#else
float32x4_t
vzero
=
vdupq_n_f32
(
0.
f
);
float32x4_t
vinc0
=
vbslq_f32
(
vcgeq_f32
(
Vector
,
vfzero
),
vfhalf
,
vfneg_half
);
float32x4_t
vfhalf
=
vdupq_n_f32
(
0.5
f
);
float32x4_t
vfneg_half
=
vdupq_n_f32
(
-
0.5
f
);
float32x4_t
vinc0
=
vbslq_f32
(
vcgeq_f32
(
Vector
,
vzero
),
vfhalf
,
vfneg_half
);
return
vcvtq_s32_f32
(
vaddq_f32
(
Vector
,
vinc0
));
return
vcvtq_s32_f32
(
vaddq_f32
(
Vector
,
vinc0
));
#endif
#endif
#elif defined(GI_SSE42_INTRINSICS)
#elif defined(GI_SSE42_INTRINSICS)
__m128
vfzero
=
_mm_set1_ps
(
0.
f
);
__m128
vfhalf
=
_mm_set1_ps
(
0.5
f
);
__m128
vfneg_half
=
_mm_set1_ps
(
-
0.5
f
);
__m128
vinc0
=
_mm_blendv_ps
(
vfneg_half
,
vfhalf
,
_mm_cmpge_ps
(
Vector
,
vfzero
));
__m128
vinc0
=
_mm_blendv_ps
(
vfneg_half
,
vfhalf
,
_mm_cmpge_ps
(
Vector
,
vfzero
));
__m128
vres0
=
_mm_add_ps
(
Vector
,
vinc0
);
return
_mm_cvttps_epi32
(
_mm_add_ps
(
Vector
,
vinc0
));
return
_mm_castps_si128
(
_mm_round_ps
(
vres0
,
_MM_FROUND_TO_ZERO
|
_MM_FROUND_NO_EXC
));
#else
#else
GI_INT32_t
ret
;
GI_INT32_t
ret
;
for
(
size_t
i
=
0
;
i
<
GI_SIMD_LEN_BYTE
/
sizeof
(
float
);
i
++
)
{
for
(
size_t
i
=
0
;
i
<
GI_SIMD_LEN_BYTE
/
sizeof
(
float
);
i
++
)
{
...
@@ -118,22 +110,7 @@ GI_FLOAT32_t GiCastToFloat32(GI_INT32_t Vector) {
...
@@ -118,22 +110,7 @@ GI_FLOAT32_t GiCastToFloat32(GI_INT32_t Vector) {
#else
#else
GI_FLOAT32_t
ret
;
GI_FLOAT32_t
ret
;
for
(
size_t
i
=
0
;
i
<
GI_SIMD_LEN_BYTE
/
sizeof
(
int32_t
);
i
++
)
{
for
(
size_t
i
=
0
;
i
<
GI_SIMD_LEN_BYTE
/
sizeof
(
int32_t
);
i
++
)
{
ret
[
i
]
=
float
(
Vector
[
i
]);
ret
[
i
]
=
(
float
)
Vector
[
i
];
}
return
ret
;
#endif
}
GI_FORCEINLINE
GI_FLOAT32_t
GiBroadcastFloat32
(
float
Value
)
{
#if defined(GI_NEON_INTRINSICS)
return
vdupq_n_f32
(
Value
);
#elif defined(GI_SSE2_INTRINSICS)
return
_mm_set1_ps
(
Value
);
#else
GI_FLOAT32_t
ret
;
for
(
size_t
i
=
0
;
i
<
GI_SIMD_LEN_BYTE
/
sizeof
(
float
);
i
++
)
{
ret
[
i
]
=
Value
;
}
}
return
ret
;
return
ret
;
#endif
#endif
...
...
dnn/src/fallback/general_intrinsic/gi_int.h
浏览文件 @
39d98d45
...
@@ -13,21 +13,6 @@
...
@@ -13,21 +13,6 @@
#include "gi_common.h"
#include "gi_common.h"
GI_FORCEINLINE
GI_INT32_t
GiBroadcastInt32
(
int32_t
Value
)
{
#if defined(GI_NEON_INTRINSICS)
return
vdupq_n_s32
(
Value
);
#elif defined(GI_SSE2_INTRINSICS)
return
_mm_set1_epi32
(
Value
);
#else
GI_INT32_t
ret
;
for
(
size_t
i
=
0
;
i
<
GI_SIMD_LEN_BYTE
/
sizeof
(
int32_t
);
i
++
)
{
ret
[
i
]
=
Value
;
}
return
ret
;
#endif
}
GI_FORCEINLINE
GI_FORCEINLINE
GI_UINT32_t
GiBroadcastUint32
(
int32_t
Value
)
{
GI_UINT32_t
GiBroadcastUint32
(
int32_t
Value
)
{
#if defined(GI_NEON_INTRINSICS)
#if defined(GI_NEON_INTRINSICS)
...
@@ -44,30 +29,31 @@ GI_UINT32_t GiBroadcastUint32(int32_t Value) {
...
@@ -44,30 +29,31 @@ GI_UINT32_t GiBroadcastUint32(int32_t Value) {
}
}
GI_FORCEINLINE
GI_FORCEINLINE
GI_INT
8_t
GiBroadcastInt8
(
int8_t
Value
)
{
GI_INT
32_t
GiLoadInt32
(
const
void
*
Buffer
)
{
#if defined(GI_NEON_INTRINSICS)
#if defined(GI_NEON_INTRINSICS)
return
v
dupq_n_s8
(
Value
);
return
v
ld1q_s32
((
int32_t
*
)
Buffer
);
#elif defined(GI_SSE2_INTRINSICS)
#elif defined(GI_SSE2_INTRINSICS)
return
_mm_
set1_epi8
(
Value
);
return
_mm_
loadu_si128
((
const
__m128i
*
)
Buffer
);
#else
#else
GI_INT8_t
ret
;
GI_INT32_t
ret
;
for
(
size_t
i
=
0
;
i
<
GI_SIMD_LEN_BYTE
/
sizeof
(
int8_t
);
i
++
)
{
const
int32_t
*
ptr
=
(
int32_t
*
)
Buffer
;
ret
[
i
]
=
Value
;
for
(
size_t
i
=
0
;
i
<
GI_SIMD_LEN_BYTE
/
sizeof
(
int32_t
);
i
++
)
{
ret
[
i
]
=
ptr
[
i
];
}
}
return
ret
;
return
ret
;
#endif
#endif
}
}
GI_FORCEINLINE
GI_FORCEINLINE
GI_INT
32_t
GiLoadInt32
(
const
void
*
Buffer
)
{
GI_INT
16_t
GiLoadInt16
(
const
void
*
Buffer
)
{
#if defined(GI_NEON_INTRINSICS)
#if defined(GI_NEON_INTRINSICS)
return
vld1q_s
32
((
int32
_t
*
)
Buffer
);
return
vld1q_s
16
((
int16
_t
*
)
Buffer
);
#elif defined(GI_SSE2_INTRINSICS)
#elif defined(GI_SSE2_INTRINSICS)
return
_mm_loadu_si128
((
const
__m128i
*
)
Buffer
);
return
_mm_loadu_si128
((
const
__m128i
*
)
Buffer
);
#else
#else
GI_INT
32
_t
ret
;
GI_INT
16
_t
ret
;
const
int
32_t
*
ptr
=
(
int32
_t
*
)
Buffer
;
const
int
16_t
*
ptr
=
(
int16
_t
*
)
Buffer
;
for
(
size_t
i
=
0
;
i
<
GI_SIMD_LEN_BYTE
/
sizeof
(
int
32
_t
);
i
++
)
{
for
(
size_t
i
=
0
;
i
<
GI_SIMD_LEN_BYTE
/
sizeof
(
int
16
_t
);
i
++
)
{
ret
[
i
]
=
ptr
[
i
];
ret
[
i
]
=
ptr
[
i
];
}
}
return
ret
;
return
ret
;
...
@@ -810,21 +796,12 @@ GI_INT8_t GiCvtFromFloat32ToInt8(GI_FLOAT32_t src) {
...
@@ -810,21 +796,12 @@ GI_INT8_t GiCvtFromFloat32ToInt8(GI_FLOAT32_t src) {
int16x8_t
mid_s16
=
vcombine_s16
(
vqmovn_s32
(
vres0
),
vqmovn_s32
(
vres0
));
int16x8_t
mid_s16
=
vcombine_s16
(
vqmovn_s32
(
vres0
),
vqmovn_s32
(
vres0
));
return
vcombine_s8
(
vqmovn_s16
(
mid_s16
),
vqmovn_s16
(
mid_s16
));
return
vcombine_s8
(
vqmovn_s16
(
mid_s16
),
vqmovn_s16
(
mid_s16
));
#else
#else
float32x4_t
vzero
=
vdupq_n_f32
(
0.
f
);
float32x4_t
vinc0
=
vbslq_f32
(
vcgeq_f32
(
src
,
vfzero
),
vfhalf
,
vfneg_half
);
float32x4_t
vfhalf
=
vdupq_n_f32
(
0.5
f
);
float32x4_t
vfneg_half
=
vdupq_n_f32
(
-
0.5
f
);
float32x4_t
vinc0
=
vbslq_f32
(
vcgeq_f32
(
src
,
vzero
),
vfhalf
,
vfneg_half
);
int32x4_t
vres0
=
vcvtq_s32_f32
(
vaddq_f32
(
src
,
vinc0
));
int32x4_t
vres0
=
vcvtq_s32_f32
(
vaddq_f32
(
src
,
vinc0
));
int16x8_t
mid_s16
=
vcombine_s16
(
vqmovn_s32
(
vres0
),
vqmovn_s32
(
vres0
));
int16x8_t
mid_s16
=
vcombine_s16
(
vqmovn_s32
(
vres0
),
vqmovn_s32
(
vres0
));
return
vcombine_s8
(
vqmovn_s16
(
mid_s16
),
vqmovn_s16
(
mid_s16
));
return
vcombine_s8
(
vqmovn_s16
(
mid_s16
),
vqmovn_s16
(
mid_s16
));
#endif
#endif
#elif defined(GI_SSE42_INTRINSICS)
#elif defined(GI_SSE42_INTRINSICS)
__m128
vfzero
=
_mm_set1_ps
(
0.
f
);
__m128
vfhalf
=
_mm_set1_ps
(
0.5
f
);
__m128
vfneg_half
=
_mm_set1_ps
(
-
0.5
f
);
__m128
vfmin_int8
=
_mm_set1_ps
(
-
128.
f
);
__m128
vfmax_int8
=
_mm_set1_ps
(
127.
f
);
__m128
vinc0
=
_mm_blendv_ps
(
vfneg_half
,
vfhalf
,
_mm_cmpge_ps
(
src
,
vfzero
));
__m128
vinc0
=
_mm_blendv_ps
(
vfneg_half
,
vfhalf
,
_mm_cmpge_ps
(
src
,
vfzero
));
__m128
vres0
=
_mm_add_ps
(
src
,
vinc0
);
__m128
vres0
=
_mm_add_ps
(
src
,
vinc0
);
vres0
=
_mm_round_ps
(
vres0
,
_MM_FROUND_TO_ZERO
|
_MM_FROUND_NO_EXC
);
vres0
=
_mm_round_ps
(
vres0
,
_MM_FROUND_TO_ZERO
|
_MM_FROUND_NO_EXC
);
...
@@ -857,23 +834,14 @@ GI_INT8_t GiCvtFromFloat32V2ToInt8(GI_FLOAT32_V2_t vsrc) {
...
@@ -857,23 +834,14 @@ GI_INT8_t GiCvtFromFloat32V2ToInt8(GI_FLOAT32_V2_t vsrc) {
int8x8_t
mid1
=
vqmovn_s16
(
vcombine_s16
(
vqmovn_s32
(
vres0
),
vqmovn_s32
(
vres1
)));
int8x8_t
mid1
=
vqmovn_s16
(
vcombine_s16
(
vqmovn_s32
(
vres0
),
vqmovn_s32
(
vres1
)));
return
vcombine_s8
(
mid1
,
mid1
);
return
vcombine_s8
(
mid1
,
mid1
);
#else
#else
float32x4_t
vzero
=
vdupq_n_f32
(
0.
f
);
float32x4_t
vinc0
=
vbslq_f32
(
vcgeq_f32
(
vsrc
.
val
[
0
],
vfzero
),
vfhalf
,
vfneg_half
);
float32x4_t
vfhalf
=
vdupq_n_f32
(
0.5
f
);
float32x4_t
vinc1
=
vbslq_f32
(
vcgeq_f32
(
vsrc
.
val
[
1
],
vfzero
),
vfhalf
,
vfneg_half
);
float32x4_t
vfneg_half
=
vdupq_n_f32
(
-
0.5
f
);
float32x4_t
vinc0
=
vbslq_f32
(
vcgeq_f32
(
vsrc
.
val
[
0
],
vzero
),
vfhalf
,
vfneg_half
);
float32x4_t
vinc1
=
vbslq_f32
(
vcgeq_f32
(
vsrc
.
val
[
1
],
vzero
),
vfhalf
,
vfneg_half
);
int32x4_t
vres0
=
vcvtq_s32_f32
(
vaddq_f32
(
vsrc
.
val
[
0
],
vinc0
));
int32x4_t
vres0
=
vcvtq_s32_f32
(
vaddq_f32
(
vsrc
.
val
[
0
],
vinc0
));
int32x4_t
vres1
=
vcvtq_s32_f32
(
vaddq_f32
(
vsrc
.
val
[
1
],
vinc1
));
int32x4_t
vres1
=
vcvtq_s32_f32
(
vaddq_f32
(
vsrc
.
val
[
1
],
vinc1
));
int8x8_t
mid1
=
vqmovn_s16
(
vcombine_s16
(
vqmovn_s32
(
vres0
),
vqmovn_s32
(
vres1
)));
int8x8_t
mid1
=
vqmovn_s16
(
vcombine_s16
(
vqmovn_s32
(
vres0
),
vqmovn_s32
(
vres1
)));
return
vcombine_s8
(
mid1
,
mid1
);
return
vcombine_s8
(
mid1
,
mid1
);
#endif
#endif
#elif defined(GI_SSE42_INTRINSICS)
#elif defined(GI_SSE42_INTRINSICS)
__m128
vfzero
=
_mm_set1_ps
(
0.
f
);
__m128
vfhalf
=
_mm_set1_ps
(
0.5
f
);
__m128
vfneg_half
=
_mm_set1_ps
(
-
0.5
f
);
__m128
vfmin_int8
=
_mm_set1_ps
(
-
128.
f
);
__m128
vfmax_int8
=
_mm_set1_ps
(
127.
f
);
__m128
vinc0
=
_mm_blendv_ps
(
vfneg_half
,
vfhalf
,
_mm_cmpge_ps
(
vsrc
.
val
[
0
],
vfzero
));
__m128
vinc0
=
_mm_blendv_ps
(
vfneg_half
,
vfhalf
,
_mm_cmpge_ps
(
vsrc
.
val
[
0
],
vfzero
));
__m128
vinc1
=
_mm_blendv_ps
(
vfneg_half
,
vfhalf
,
_mm_cmpge_ps
(
vsrc
.
val
[
1
],
vfzero
));
__m128
vinc1
=
_mm_blendv_ps
(
vfneg_half
,
vfhalf
,
_mm_cmpge_ps
(
vsrc
.
val
[
1
],
vfzero
));
...
@@ -913,13 +881,13 @@ GI_INT8_t GiCvtFromFloat32V4ToInt8(GI_FLOAT32_V4_t vsrc) {
...
@@ -913,13 +881,13 @@ GI_INT8_t GiCvtFromFloat32V4ToInt8(GI_FLOAT32_V4_t vsrc) {
int8x8_t
mid2
=
vqmovn_s16
(
vcombine_s16
(
vqmovn_s32
(
vres2
),
vqmovn_s32
(
vres3
)));
int8x8_t
mid2
=
vqmovn_s16
(
vcombine_s16
(
vqmovn_s32
(
vres2
),
vqmovn_s32
(
vres3
)));
return
vcombine_s8
(
mid1
,
mid2
);
return
vcombine_s8
(
mid1
,
mid2
);
#else
#else
float32x4_t
vzero
=
vdupq_n_f32
(
0.
f
);
float32x4_t
v
f
zero
=
vdupq_n_f32
(
0.
f
);
float32x4_t
vfhalf
=
vdupq_n_f32
(
0.5
f
);
float32x4_t
vfhalf
=
vdupq_n_f32
(
0.5
f
);
float32x4_t
vfneg_half
=
vdupq_n_f32
(
-
0.5
f
);
float32x4_t
vfneg_half
=
vdupq_n_f32
(
-
0.5
f
);
float32x4_t
vinc0
=
vbslq_f32
(
vcgeq_f32
(
vsrc
.
val
[
0
],
vzero
),
vfhalf
,
vfneg_half
);
float32x4_t
vinc0
=
vbslq_f32
(
vcgeq_f32
(
vsrc
.
val
[
0
],
v
f
zero
),
vfhalf
,
vfneg_half
);
float32x4_t
vinc1
=
vbslq_f32
(
vcgeq_f32
(
vsrc
.
val
[
1
],
vzero
),
vfhalf
,
vfneg_half
);
float32x4_t
vinc1
=
vbslq_f32
(
vcgeq_f32
(
vsrc
.
val
[
1
],
v
f
zero
),
vfhalf
,
vfneg_half
);
float32x4_t
vinc2
=
vbslq_f32
(
vcgeq_f32
(
vsrc
.
val
[
2
],
vzero
),
vfhalf
,
vfneg_half
);
float32x4_t
vinc2
=
vbslq_f32
(
vcgeq_f32
(
vsrc
.
val
[
2
],
v
f
zero
),
vfhalf
,
vfneg_half
);
float32x4_t
vinc3
=
vbslq_f32
(
vcgeq_f32
(
vsrc
.
val
[
3
],
vzero
),
vfhalf
,
vfneg_half
);
float32x4_t
vinc3
=
vbslq_f32
(
vcgeq_f32
(
vsrc
.
val
[
3
],
v
f
zero
),
vfhalf
,
vfneg_half
);
int32x4_t
vres0
=
vcvtq_s32_f32
(
vaddq_f32
(
vsrc
.
val
[
0
],
vinc0
));
int32x4_t
vres0
=
vcvtq_s32_f32
(
vaddq_f32
(
vsrc
.
val
[
0
],
vinc0
));
int32x4_t
vres1
=
vcvtq_s32_f32
(
vaddq_f32
(
vsrc
.
val
[
1
],
vinc1
));
int32x4_t
vres1
=
vcvtq_s32_f32
(
vaddq_f32
(
vsrc
.
val
[
1
],
vinc1
));
int32x4_t
vres2
=
vcvtq_s32_f32
(
vaddq_f32
(
vsrc
.
val
[
2
],
vinc2
));
int32x4_t
vres2
=
vcvtq_s32_f32
(
vaddq_f32
(
vsrc
.
val
[
2
],
vinc2
));
...
@@ -929,12 +897,6 @@ GI_INT8_t GiCvtFromFloat32V4ToInt8(GI_FLOAT32_V4_t vsrc) {
...
@@ -929,12 +897,6 @@ GI_INT8_t GiCvtFromFloat32V4ToInt8(GI_FLOAT32_V4_t vsrc) {
return
vcombine_s8
(
mid1
,
mid2
);
return
vcombine_s8
(
mid1
,
mid2
);
#endif
#endif
#elif defined(GI_SSE42_INTRINSICS)
#elif defined(GI_SSE42_INTRINSICS)
__m128
vfzero
=
_mm_set1_ps
(
0.
f
);
__m128
vfhalf
=
_mm_set1_ps
(
0.5
f
);
__m128
vfneg_half
=
_mm_set1_ps
(
-
0.5
f
);
__m128
vfmin_int8
=
_mm_set1_ps
(
-
128.
f
);
__m128
vfmax_int8
=
_mm_set1_ps
(
127.
f
);
__m128
vinc0
=
_mm_blendv_ps
(
vfneg_half
,
vfhalf
,
_mm_cmpge_ps
(
vsrc
.
val
[
0
],
vfzero
));
__m128
vinc0
=
_mm_blendv_ps
(
vfneg_half
,
vfhalf
,
_mm_cmpge_ps
(
vsrc
.
val
[
0
],
vfzero
));
__m128
vinc1
=
_mm_blendv_ps
(
vfneg_half
,
vfhalf
,
_mm_cmpge_ps
(
vsrc
.
val
[
1
],
vfzero
));
__m128
vinc1
=
_mm_blendv_ps
(
vfneg_half
,
vfhalf
,
_mm_cmpge_ps
(
vsrc
.
val
[
1
],
vfzero
));
__m128
vinc2
=
_mm_blendv_ps
(
vfneg_half
,
vfhalf
,
_mm_cmpge_ps
(
vsrc
.
val
[
2
],
vfzero
));
__m128
vinc2
=
_mm_blendv_ps
(
vfneg_half
,
vfhalf
,
_mm_cmpge_ps
(
vsrc
.
val
[
2
],
vfzero
));
...
...
dnn/src/fallback/quantized_converter.h
浏览文件 @
39d98d45
...
@@ -20,16 +20,6 @@
...
@@ -20,16 +20,6 @@
namespace
megdnn
{
namespace
megdnn
{
namespace
fallback
{
namespace
fallback
{
struct
QConverterBase
{
inline
static
GI_INT32_t
vzero
()
{
return
GiBroadcastInt32
(
0
);
}
inline
static
GI_FLOAT32_t
vfzero
()
{
return
GiBroadcastFloat32
(
0.
f
);
}
inline
static
GI_FLOAT32_t
vfhalf
()
{
return
GiBroadcastFloat32
(
0.5
f
);
}
inline
static
GI_FLOAT32_t
vfneg_half
()
{
return
GiBroadcastFloat32
(
-
0.5
f
);
}
};
struct
QConverter
{
struct
QConverter
{
template
<
typename
dst_type
,
typename
...
src_type
>
template
<
typename
dst_type
,
typename
...
src_type
>
static
inline
dst_type
convert
(
const
src_type
&
...
src
);
static
inline
dst_type
convert
(
const
src_type
&
...
src
);
...
@@ -66,6 +56,12 @@ template <>
...
@@ -66,6 +56,12 @@ template <>
inline
GI_INT8_t
QConverter
::
convert
(
const
GI_FLOAT32_V2_t
&
vsrc
)
{
inline
GI_INT8_t
QConverter
::
convert
(
const
GI_FLOAT32_V2_t
&
vsrc
)
{
return
GiCvtFromFloat32V2ToInt8
(
vsrc
);
return
GiCvtFromFloat32V2ToInt8
(
vsrc
);
}
}
template
<
>
inline
GI_INT8_t
QConverter
::
convert
(
const
GI_FLOAT32_V4_t
&
vsrc
)
{
return
GiCvtFromFloat32V4ToInt8
(
vsrc
);
}
template
<
>
template
<
>
inline
GI_INT8_t
QConverter
::
convert
(
const
GI_FLOAT32_t
&
src
)
{
inline
GI_INT8_t
QConverter
::
convert
(
const
GI_FLOAT32_t
&
src
)
{
return
GiCvtFromFloat32ToInt8
(
src
);
return
GiCvtFromFloat32ToInt8
(
src
);
...
...
dnn/src/fallback/type_cvt/opr_impl.cpp
浏览文件 @
39d98d45
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
*/
#include "src/fallback/type_cvt/opr_impl.h"
#include "src/fallback/type_cvt/opr_impl.h"
#include "src/fallback/type_cvt/typecvt_helper.h"
#include "midout.h"
#include "midout.h"
#include "src/common/utils.h"
#include "src/common/utils.h"
...
@@ -17,6 +18,7 @@
...
@@ -17,6 +18,7 @@
// MIDOUT_DECL(megdnn_fb_typecvt_src)
// MIDOUT_DECL(megdnn_fb_typecvt_src)
MIDOUT_DECL
(
megdnn_fb_typecvt_dst_dtype
)
MIDOUT_DECL
(
megdnn_fb_typecvt_dst_dtype
)
MIDOUT_DECL
(
megdnn_fb_typecvt_src_dtype
)
MIDOUT_DECL
(
megdnn_fb_typecvt_src_dtype
)
MIDOUT_DECL
(
megdnn_fb_typecvt_optimized
)
namespace
{
namespace
{
...
@@ -513,12 +515,68 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
...
@@ -513,12 +515,68 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
!
is_quantize_lowbit
(
dst
.
layout
.
dtype
)
&&
!
is_quantize_lowbit
(
dst
.
layout
.
dtype
)
&&
dst
.
layout
.
dtype
.
enumv
()
!=
DTypeEnum
::
QuantizedS1
&&
dst
.
layout
.
dtype
.
enumv
()
!=
DTypeEnum
::
QuantizedS1
&&
src
.
layout
.
dtype
.
enumv
()
!=
DTypeEnum
::
QuantizedS1
)
{
src
.
layout
.
dtype
.
enumv
()
!=
DTypeEnum
::
QuantizedS1
)
{
if
(
!
exec_optimized
(
src
,
dst
))
{
MEGDNN_DISPATCH_CPU_KERN_OPR
(
run_contiguous
(
src
,
dst
));
MEGDNN_DISPATCH_CPU_KERN_OPR
(
run_contiguous
(
src
,
dst
));
}
}
else
{
}
else
{
naive
::
TypeCvtImpl
::
exec
(
src
,
dst
);
naive
::
TypeCvtImpl
::
exec
(
src
,
dst
);
}
}
}
}
bool
TypeCvtImpl
::
exec_optimized
(
_megdnn_tensor_in
src
,
_megdnn_tensor_out
dst
)
{
DType
src_dtype
=
src
.
layout
.
dtype
;
DType
dst_dtype
=
dst
.
layout
.
dtype
;
bool
execed
=
false
;
using
namespace
dtype
;
size_t
nr_elems
=
src
.
layout
.
total_nr_elems
();
#define DISPATCH_QUANTIZED(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \
dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \
MIDOUT_BEGIN(megdnn_fb_typecvt_optimized, midout_iv(_midout_iv)) { \
using _TypeCvter = QuantizedTypeCvter<_stype, _dtype>; \
MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \
src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(), \
src_dtype, dst_dtype, nr_elems)); \
execed = true; \
} \
MIDOUT_END(); \
}
DISPATCH_QUANTIZED
(
QuantizedS32
,
int32_t
,
QuantizedS8
,
int8_t
,
1
);
DISPATCH_QUANTIZED
(
QuantizedS8
,
int8_t
,
QuantizedS32
,
int32_t
,
2
);
DISPATCH_QUANTIZED
(
QuantizedS8
,
int8_t
,
QuantizedS8
,
int8_t
,
3
);
DISPATCH_QUANTIZED
(
QuantizedS32
,
int32_t
,
QuantizedS32
,
int32_t
,
4
);
DISPATCH_QUANTIZED
(
float
,
float
,
QuantizedS8
,
int8_t
,
5
);
#undef DISPATCH_QUANTIZED
#define DISPATCH_FIX2FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \
dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \
MIDOUT_BEGIN(megdnn_fb_typecvt_optimized, midout_iv(_midout_iv)) { \
using _TypeCvter = Fix2FloatTypeCvter<_stype, _dtype>; \
MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \
src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(), \
src_dtype, dst_dtype, src.layout.total_nr_elems())); \
execed = true; \
} \
MIDOUT_END(); \
}
DISPATCH_FIX2FLOAT
(
Int16
,
int16_t
,
Float32
,
float
,
6
);
DISPATCH_FIX2FLOAT
(
Int8
,
int8_t
,
Float32
,
float
,
7
);
if
(
src_dtype
.
enumv
()
==
DTypeTrait
<
QuantizedS8
>::
enumv
&&
dst_dtype
.
enumv
()
==
DTypeTrait
<
Float32
>::
enumv
)
{
MIDOUT_BEGIN
(
megdnn_fb_typecvt_optimized
,
midout_iv
(
8
))
{
using
TypeCvter
=
Quan2FloatTypeCvter
<
int8_t
,
float
>
;
MEGDNN_DISPATCH_CPU_KERN_OPR
(
do_typecvt
<
TypeCvter
>
(
src
.
compatible_ptr
<
int8_t
>
(),
dst
.
compatible_ptr
<
float
>
(),
src_dtype
,
dst_dtype
,
src
.
layout
.
total_nr_elems
()));
execed
=
true
;
}
MIDOUT_END
();
}
return
execed
;
}
}
// namespace fallback
}
// namespace fallback
}
// namespace megdnn
}
// namespace megdnn
...
...
dnn/src/fallback/type_cvt/opr_impl.h
浏览文件 @
39d98d45
...
@@ -15,6 +15,8 @@ namespace megdnn {
...
@@ -15,6 +15,8 @@ namespace megdnn {
namespace
fallback
{
namespace
fallback
{
class
TypeCvtImpl
:
public
naive
::
TypeCvtImpl
{
class
TypeCvtImpl
:
public
naive
::
TypeCvtImpl
{
bool
exec_optimized
(
_megdnn_tensor_in
src
,
_megdnn_tensor_out
dst
);
public:
public:
using
naive
::
TypeCvtImpl
::
TypeCvtImpl
;
using
naive
::
TypeCvtImpl
::
TypeCvtImpl
;
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_out
dst
)
override
;
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_out
dst
)
override
;
...
...
dnn/src/fallback/type_cvt/typecvt_helper.h
0 → 100644
浏览文件 @
39d98d45
/**
* \file dnn/src/fallback/type_cvt/typecvt_helper.h
*/
#include "src/fallback/general_intrinsic/gi_float.h"
#include "src/fallback/general_intrinsic/gi_int.h"
#include "src/fallback/quantized_converter.h"
namespace
megdnn
{
namespace
fallback
{
template
<
typename
ctype
,
typename
dtype
>
struct
QuantizedTypeCvter
;
template
<
>
struct
QuantizedTypeCvter
<
int32_t
,
int8_t
>
{
using
stype
=
int32_t
;
using
dst_type
=
int8_t
;
static
constexpr
size_t
SIMD_WIDTH
=
GI_SIMD_LEN_BYTE
/
sizeof
(
int32_t
)
*
2
;
static
constexpr
size_t
SIMD_STEP
=
GI_SIMD_LEN_BYTE
/
sizeof
(
int32_t
);
float
scale
;
GI_FLOAT32_t
vscale
;
QuantizedTypeCvter
(
DType
src_dtype
,
DType
dst_dtype
)
{
float
src_scale
=
src_dtype
.
param
<
dtype
::
QuantizedS32
>
().
scale
;
float
dst_scale
=
dst_dtype
.
param
<
dtype
::
QuantizedS8
>
().
scale
;
scale
=
src_scale
/
dst_scale
;
vscale
=
GiBroadcastFloat32
(
scale
);
}
void
cvt
(
const
int32_t
*
src
,
int8_t
*
dst
)
{
GI_FLOAT32_t
vitem0
=
GiMultiplyFloat32
(
GiCastToFloat32
(
GiLoadInt32
(
src
)),
vscale
);
GI_FLOAT32_t
vitem1
=
GiMultiplyFloat32
(
GiCastToFloat32
(
GiLoadInt32
(
src
+
SIMD_STEP
)),
vscale
);
auto
vres
=
QConverter
::
convert
<
GI_INT8_t
,
GI_FLOAT32_V2_t
>
({{
vitem0
,
vitem1
}});
GiStoreLowInt8
(
dst
,
vres
);
}
void
cvt_remain
(
const
int32_t
*
src
,
int8_t
*
dst
)
{
*
dst
=
saturate
<
int8_t
,
float
>
(
std
::
round
(
*
src
*
scale
),
-
128.
f
,
127.
f
);
}
};
template
<
>
struct
QuantizedTypeCvter
<
int8_t
,
int32_t
>
{
using
stype
=
int8_t
;
using
dst_type
=
int32_t
;
static
constexpr
size_t
SIMD_WIDTH
=
GI_SIMD_LEN_BYTE
/
sizeof
(
int8_t
);
float
scale
;
GI_FLOAT32_t
vscale
;
QuantizedTypeCvter
(
DType
src_dtype
,
DType
dst_dtype
)
{
float
src_scale
=
src_dtype
.
param
<
dtype
::
QuantizedS8
>
().
scale
;
float
dst_scale
=
dst_dtype
.
param
<
dtype
::
QuantizedS32
>
().
scale
;
scale
=
src_scale
/
dst_scale
;
vscale
=
GiBroadcastFloat32
(
scale
);
}
void
cvt
(
const
int8_t
*
src
,
int32_t
*
dst
)
{
GI_INT8_t
data
=
GiLoadInt8
(
src
);
GI_INT16_t
vitem0
=
GiMoveLowLongInt8
(
data
);
GI_INT16_t
vitem1
=
GiMoveHighLongInt8
(
data
);
auto
vret0
=
QConverter
::
round
<
GI_INT32_t
,
GI_FLOAT32_t
>
(
GiMultiplyFloat32
(
GiCastToFloat32
(
GiMoveLowLongInt16
(
vitem0
)),
vscale
));
auto
vret1
=
QConverter
::
round
<
GI_INT32_t
,
GI_FLOAT32_t
>
(
GiMultiplyFloat32
(
GiCastToFloat32
(
GiMoveHighLongInt16
(
vitem0
)),
vscale
));
auto
vret2
=
QConverter
::
round
<
GI_INT32_t
,
GI_FLOAT32_t
>
(
GiMultiplyFloat32
(
GiCastToFloat32
(
GiMoveLowLongInt16
(
vitem1
)),
vscale
));
auto
vret3
=
QConverter
::
round
<
GI_INT32_t
,
GI_FLOAT32_t
>
(
GiMultiplyFloat32
(
GiCastToFloat32
(
GiMoveHighLongInt16
(
vitem1
)),
vscale
));
constexpr
size_t
step
=
GI_SIMD_LEN_BYTE
/
sizeof
(
int32_t
);
GiStoreInt32
(
dst
,
vret0
);
GiStoreInt32
(
dst
+
step
,
vret1
);
GiStoreInt32
(
dst
+
2
*
step
,
vret2
);
GiStoreInt32
(
dst
+
3
*
step
,
vret3
);
}
void
cvt_remain
(
const
int8_t
*
src
,
int32_t
*
dst
)
{
*
dst
=
saturate
<
int32_t
,
float
>
(
std
::
round
(
*
src
*
scale
),
-
2147483648.
f
,
2147483647.
f
);
}
};
template
<
>
struct
QuantizedTypeCvter
<
float
,
int8_t
>
{
using
stype
=
float
;
using
dst_type
=
int8_t
;
static
constexpr
size_t
SIMD_WIDTH
=
GI_SIMD_LEN_BYTE
/
sizeof
(
float
)
*
2
;
static
constexpr
size_t
SIMD_STEP
=
GI_SIMD_LEN_BYTE
/
sizeof
(
float
);
float
scale
;
GI_FLOAT32_t
vscale
;
QuantizedTypeCvter
(
DType
src_dtype
,
DType
dst_dtype
)
{
MEGDNN_MARK_USED_VAR
(
src_dtype
);
float
src_scale
=
1
;
float
dst_scale
=
dst_dtype
.
param
<
dtype
::
QuantizedS8
>
().
scale
;
scale
=
src_scale
/
dst_scale
;
vscale
=
GiBroadcastFloat32
(
scale
);
}
void
cvt
(
const
float
*
src
,
int8_t
*
dst
)
{
GI_FLOAT32_t
vitem0
=
GiMultiplyFloat32
(
GiLoadFloat32
(
src
),
vscale
);
GI_FLOAT32_t
vitem1
=
GiMultiplyFloat32
(
GiLoadFloat32
(
src
+
SIMD_STEP
),
vscale
);
auto
vres
=
QConverter
::
convert
<
GI_INT8_t
,
GI_FLOAT32_V2_t
>
({{
vitem0
,
vitem1
}});
GiStoreLowInt8
(
dst
,
vres
);
}
void
cvt_remain
(
const
float
*
src
,
int8_t
*
dst
)
{
*
dst
=
saturate
<
int8_t
,
float
>
(
std
::
round
(
*
src
*
scale
),
-
128.
f
,
127.
f
);
}
};
template
<
>
struct
QuantizedTypeCvter
<
int32_t
,
int32_t
>
{
using
stype
=
int32_t
;
using
dst_type
=
int32_t
;
static
constexpr
size_t
SIMD_WIDTH
=
GI_SIMD_LEN_BYTE
/
sizeof
(
int32_t
);
float
scale
;
GI_FLOAT32_t
vscale
;
QuantizedTypeCvter
(
DType
src_dtype
,
DType
dst_dtype
)
{
float
src_scale
=
src_dtype
.
param
<
dtype
::
QuantizedS32
>
().
scale
;
float
dst_scale
=
dst_dtype
.
param
<
dtype
::
QuantizedS32
>
().
scale
;
scale
=
src_scale
/
dst_scale
;
vscale
=
GiBroadcastFloat32
(
scale
);
}
void
cvt
(
const
int32_t
*
src
,
int32_t
*
dst
)
{
GI_FLOAT32_t
vitem
=
GiMultiplyFloat32
(
GiCastToFloat32
(
GiLoadInt32
(
src
)),
vscale
);
auto
vres
=
QConverter
::
round
<
GI_INT32_t
,
GI_FLOAT32_t
>
(
vitem
);
GiStoreInt32
(
dst
,
vres
);
}
void
cvt_remain
(
const
int32_t
*
src
,
int32_t
*
dst
)
{
*
dst
=
saturate
<
int32_t
,
float
>
(
std
::
round
(
*
src
*
scale
),
-
2147483648.
f
,
2147483647.
f
);
}
};
template
<
>
struct
QuantizedTypeCvter
<
int8_t
,
int8_t
>
{
using
stype
=
int8_t
;
using
dst_type
=
int8_t
;
static
constexpr
size_t
SIMD_WIDTH
=
GI_SIMD_LEN_BYTE
/
sizeof
(
int8_t
);
float
scale
;
GI_FLOAT32_t
vscale
;
QuantizedTypeCvter
(
DType
src_dtype
,
DType
dst_dtype
)
{
float
src_scale
=
src_dtype
.
param
<
dtype
::
QuantizedS8
>
().
scale
;
float
dst_scale
=
dst_dtype
.
param
<
dtype
::
QuantizedS8
>
().
scale
;
scale
=
src_scale
/
dst_scale
;
vscale
=
GiBroadcastFloat32
(
scale
);
}
void
cvt
(
const
int8_t
*
src
,
int8_t
*
dst
)
{
GI_INT8_t
data
=
GiLoadInt8
(
src
);
GI_INT16_t
vitem0
=
GiMoveLowLongInt8
(
data
);
GI_INT16_t
vitem1
=
GiMoveHighLongInt8
(
data
);
auto
vret0
=
GiMultiplyFloat32
(
GiCastToFloat32
(
GiMoveLowLongInt16
(
vitem0
)),
vscale
);
auto
vret1
=
GiMultiplyFloat32
(
GiCastToFloat32
(
GiMoveHighLongInt16
(
vitem0
)),
vscale
);
auto
vret2
=
GiMultiplyFloat32
(
GiCastToFloat32
(
GiMoveLowLongInt16
(
vitem1
)),
vscale
);
auto
vret3
=
GiMultiplyFloat32
(
GiCastToFloat32
(
GiMoveHighLongInt16
(
vitem1
)),
vscale
);
auto
vres
=
QConverter
::
convert
<
GI_INT8_t
,
GI_FLOAT32_V4_t
>
(
{{
vret0
,
vret1
,
vret2
,
vret3
}});
GiStoreInt8
(
dst
,
vres
);
}
void
cvt_remain
(
const
int8_t
*
src
,
int8_t
*
dst
)
{
*
dst
=
saturate
<
int8_t
,
float
>
(
std
::
round
(
*
src
*
scale
),
-
128.
f
,
127.
f
);
}
};
template
<
typename
ctype
,
typename
dtype
>
struct
Fix2FloatTypeCvter
;
template
<
typename
ctype
,
typename
dtype
>
struct
Quan2FloatTypeCvter
;
template
<
>
struct
Fix2FloatTypeCvter
<
int16_t
,
float
>
{
using
stype
=
int16_t
;
using
dst_type
=
float
;
static
constexpr
size_t
SIMD_WIDTH
=
GI_SIMD_LEN_BYTE
/
sizeof
(
int16_t
);
static
constexpr
size_t
SIMD_STEP
=
GI_SIMD_LEN_BYTE
/
sizeof
(
float
);
Fix2FloatTypeCvter
(
DType
src_dtype
,
DType
dst_dtype
)
{
MEGDNN_MARK_USED_VAR
(
src_dtype
);
MEGDNN_MARK_USED_VAR
(
dst_dtype
);
}
void
cvt
(
const
int16_t
*
src
,
float
*
dst
)
{
GI_INT16_t
vitem
=
GiLoadInt16
(
src
);
auto
vret0
=
GiCastToFloat32
(
GiMoveLowLongInt16
(
vitem
));
auto
vret1
=
GiCastToFloat32
(
GiMoveHighLongInt16
(
vitem
));
GiStoreFloat32
(
dst
,
vret0
);
GiStoreFloat32
(
dst
+
SIMD_STEP
,
vret1
);
}
void
cvt_remain
(
const
int16_t
*
src
,
float
*
dst
)
{
*
dst
=
*
src
;
}
};
template
<
>
struct
Fix2FloatTypeCvter
<
int8_t
,
float
>
{
using
stype
=
int8_t
;
using
dst_type
=
float
;
static
constexpr
size_t
SIMD_WIDTH
=
GI_SIMD_LEN_BYTE
/
sizeof
(
int8_t
);
static
constexpr
size_t
SIMD_STEP
=
GI_SIMD_LEN_BYTE
/
sizeof
(
float
);
Fix2FloatTypeCvter
(
DType
src_dtype
,
DType
dst_dtype
)
{
MEGDNN_MARK_USED_VAR
(
src_dtype
);
MEGDNN_MARK_USED_VAR
(
dst_dtype
);
}
void
cvt
(
const
int8_t
*
src
,
float
*
dst
)
{
GI_INT8_t
data
=
GiLoadInt8
(
src
);
GI_INT16_t
vitem0
=
GiMoveLowLongInt8
(
data
);
GI_INT16_t
vitem1
=
GiMoveHighLongInt8
(
data
);
auto
vret0
=
GiCastToFloat32
(
GiMoveLowLongInt16
(
vitem0
));
auto
vret1
=
GiCastToFloat32
(
GiMoveHighLongInt16
(
vitem0
));
auto
vret2
=
GiCastToFloat32
(
GiMoveLowLongInt16
(
vitem1
));
auto
vret3
=
GiCastToFloat32
(
GiMoveHighLongInt16
(
vitem1
));
GiStoreFloat32
(
dst
,
vret0
);
GiStoreFloat32
(
dst
+
SIMD_STEP
,
vret1
);
GiStoreFloat32
(
dst
+
2
*
SIMD_STEP
,
vret2
);
GiStoreFloat32
(
dst
+
3
*
SIMD_STEP
,
vret3
);
}
void
cvt_remain
(
const
int8_t
*
src
,
float
*
dst
)
{
*
dst
=
*
src
;
}
};
template
<
>
struct
Quan2FloatTypeCvter
<
int8_t
,
float
>
{
using
stype
=
int8_t
;
using
dst_type
=
float
;
static
constexpr
size_t
SIMD_WIDTH
=
GI_SIMD_LEN_BYTE
/
sizeof
(
int8_t
);
static
constexpr
size_t
SIMD_STEP
=
GI_SIMD_LEN_BYTE
/
sizeof
(
float
);
float
_scale
=
0.0
f
;
GI_FLOAT32_t
vscale
;
Quan2FloatTypeCvter
(
DType
src_dtype
,
DType
dst_dtype
)
{
_scale
=
src_dtype
.
param
<
dtype
::
QuantizedS8
>
().
scale
;
vscale
=
GiBroadcastFloat32
(
_scale
);
MEGDNN_MARK_USED_VAR
(
dst_dtype
);
}
void
cvt
(
const
int8_t
*
src
,
float
*
dst
)
{
GI_INT8_t
data
=
GiLoadInt8
(
src
);
GI_INT16_t
vitem0
=
GiMoveLowLongInt8
(
data
);
GI_INT16_t
vitem1
=
GiMoveHighLongInt8
(
data
);
auto
vret0
=
GiMultiplyFloat32
(
GiCastToFloat32
(
GiMoveLowLongInt16
(
vitem0
)),
vscale
);
auto
vret1
=
GiMultiplyFloat32
(
GiCastToFloat32
(
GiMoveHighLongInt16
(
vitem0
)),
vscale
);
auto
vret2
=
GiMultiplyFloat32
(
GiCastToFloat32
(
GiMoveLowLongInt16
(
vitem1
)),
vscale
);
auto
vret3
=
GiMultiplyFloat32
(
GiCastToFloat32
(
GiMoveHighLongInt16
(
vitem1
)),
vscale
);
GiStoreFloat32
(
dst
,
vret0
);
GiStoreFloat32
(
dst
+
SIMD_STEP
,
vret1
);
GiStoreFloat32
(
dst
+
2
*
SIMD_STEP
,
vret2
);
GiStoreFloat32
(
dst
+
3
*
SIMD_STEP
,
vret3
);
}
void
cvt_remain
(
const
int8_t
*
src
,
float
*
dst
)
{
*
dst
=
*
src
*
_scale
;
}
};
template
<
typename
TypeCvter
>
void
do_typecvt
(
const
typename
TypeCvter
::
stype
*
src
,
typename
TypeCvter
::
dst_type
*
dst
,
DType
src_dtype
,
DType
dst_dtype
,
size_t
nr_elems
)
{
TypeCvter
typecvt
(
src_dtype
,
dst_dtype
);
size_t
i
=
0
;
for
(;
i
+
TypeCvter
::
SIMD_WIDTH
<=
nr_elems
;
i
+=
TypeCvter
::
SIMD_WIDTH
)
{
typecvt
.
cvt
(
src
,
dst
);
src
+=
TypeCvter
::
SIMD_WIDTH
;
dst
+=
TypeCvter
::
SIMD_WIDTH
;
}
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
for
(;
i
<
nr_elems
;
i
++
)
{
typecvt
.
cvt_remain
(
src
,
dst
);
src
++
;
dst
++
;
}
}
}
// namespace fallback
}
// namespace megdnn
// vim: syntax=cpp.doxygen
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录