Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
ffbf8fad
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
ffbf8fad
编写于
3月 25, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(fallback): add general intrinsic to elemwise multitype
GitOrigin-RevId: fe7b335545fd959f917b7df8ee48739ccb2a86ab
上级
484e1f11
变更
12
展开全部
隐藏空白更改
内联
并排
Showing
12 changed files
with
837 additions
and
90 deletions
+837
-90
dnn/src/arm_common/elemwise_helper/elemwise_op.h
dnn/src/arm_common/elemwise_helper/elemwise_op.h
+1
-1
dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
+5
-25
dnn/src/fallback/elemwise_helper/elemwise_op.h
dnn/src/fallback/elemwise_helper/elemwise_op.h
+0
-57
dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h
dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h
+1
-1
dnn/src/fallback/elemwise_helper/kimpl/op_base.h
dnn/src/fallback/elemwise_helper/kimpl/op_base.h
+5
-5
dnn/src/fallback/elemwise_helper/kimpl/relu.h
dnn/src/fallback/elemwise_helper/kimpl/relu.h
+1
-1
dnn/src/fallback/elemwise_helper/op_common.h
dnn/src/fallback/elemwise_helper/op_common.h
+50
-0
dnn/src/fallback/elemwise_multi_type/opr_impl.h
dnn/src/fallback/elemwise_multi_type/opr_impl.h
+12
-0
dnn/src/fallback/elemwise_multi_type/quantized_impl.cpp
dnn/src/fallback/elemwise_multi_type/quantized_impl.cpp
+499
-0
dnn/src/fallback/general_intrinsic/gi_common.h
dnn/src/fallback/general_intrinsic/gi_common.h
+1
-0
dnn/test/arm_common/elemwise_multi_type.cpp
dnn/test/arm_common/elemwise_multi_type.cpp
+93
-0
dnn/test/fallback/elemwise_multi_type.cpp
dnn/test/fallback/elemwise_multi_type.cpp
+169
-0
未找到文件。
dnn/src/arm_common/elemwise_helper/elemwise_op.h
浏览文件 @
ffbf8fad
...
...
@@ -15,7 +15,7 @@
#include "src/arm_common/elemwise_helper/op_binary.h"
#include "src/arm_common/elemwise_helper/op_ternary.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/fallback/elemwise_helper/
elemwise_op
.h"
#include "src/fallback/elemwise_helper/
op_common
.h"
namespace
megdnn
{
namespace
elemwise
{
...
...
dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
浏览文件 @
ffbf8fad
...
...
@@ -364,17 +364,9 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
}
#define DISPATCH() \
if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \
DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \
} else if ( \
param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \
if (param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \
DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, dtype::Quantized8Asymm) \
} else if ( \
param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \
DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8) \
} else if ( \
param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \
...
...
@@ -467,16 +459,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
#define DISPATCH() \
if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \
DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8) \
} else if ( \
param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \
DISPATCH_MODE(dtype::QuantizedS32, dtype::Quantized8Asymm) \
} else if ( \
param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \
DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \
} else if ( \
param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \
...
...
@@ -701,12 +685,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
}
#define DISPATCH() \
if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \
DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \
} else if ( \
param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \
if (param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \
DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, dtype::Quantized8Asymm) \
}
...
...
dnn/src/fallback/elemwise_helper/elemwise_op.h
浏览文件 @
ffbf8fad
...
...
@@ -12,61 +12,4 @@
#include "src/fallback/general_intrinsic/gi_float.h"
#include "src/fallback/general_intrinsic/gi_int.h"
namespace
megdnn
{
namespace
elemwise
{
///////////////////////////////// ParamElemVisitor ///////////////////////////
// Stamps out explicit specializations of the two basic SIMD parameter
// visitors for one element type:
//   - ParamElemVisitor<T>::operator() returns GiLoad<suffix>(src).
//   - ParamElemVisitorDup<T>::operator() dereferences src through a pointer to
//     the inner C type and returns GiBroadcast<suffix>() of that one value.
// Macro arguments: the element type, its inner C type, the GI vector type
// returned by both visitors, and the suffix pasted onto GiLoad / GiBroadcast.
#define cb(elem_t, inner_t, vec_t, gi_suffix)                              \
    template <>                                                            \
    struct ParamElemVisitor<elem_t> {                                      \
        vec_t operator()(const elem_t* src) const {                        \
            return GiLoad##gi_suffix(src);                                 \
        }                                                                  \
    };                                                                     \
    template <>                                                            \
    struct ParamElemVisitorDup<elem_t> {                                   \
        vec_t operator()(const elem_t* src) const {                        \
            const inner_t* scalar = reinterpret_cast<const inner_t*>(src); \
            return GiBroadcast##gi_suffix(*scalar);                        \
        }                                                                  \
    }

// The quantized and plain variants of a width use the same vector type and
// the same intrinsic suffix.
cb(dt_qint32, int32_t, GI_INT32_t, Int32);
cb(dt_qint8, int8_t, GI_INT8_t, Int8);

cb(dt_float32, float, GI_FLOAT32_t, Float32);
cb(dt_int32, int32_t, GI_INT32_t, Int32);
cb(dt_int8, int8_t, GI_INT8_t, Int8);
#undef cb
// Declared without a body; the specializations follow.
template <typename T>
struct ParamElemVisitorBcast101x4;

// 8-bit specializations. src is read through a pointer to the wider inner
// type (int32_t in both instantiations below), that value is broadcast with
// GiBroadcast<wide>, and the result is reinterpreted as the narrow vector
// type via GiReinter<wide>To<narrow>.
// NOTE(review): the reinterpret_cast assumes src may be read as the wider
// type (alignment / aliasing) -- confirm this holds on all targets.
#define cb(elem_t, inner_t, vec_t, gi_suffix, wide_suffix)                 \
    template <>                                                            \
    struct ParamElemVisitorBcast101x4<elem_t> {                            \
        vec_t operator()(const elem_t* src) const {                        \
            const inner_t* packed = reinterpret_cast<const inner_t*>(src); \
            return GiReinter##wide_suffix##To##gi_suffix(                  \
                    GiBroadcast##wide_suffix(*packed));                    \
        }                                                                  \
    }

cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32);
cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32);
#undef cb
// 32-bit wide specializations of ParamElemVisitorBcast101x4: operator()
// simply returns GiLoad<suffix>(src). The inner C type argument is accepted
// for symmetry with the other cb() macros but is not used in the expansion.
#define cb(elem_t, inner_t, vec_t, gi_suffix)           \
    template <>                                         \
    struct ParamElemVisitorBcast101x4<elem_t> {         \
        vec_t operator()(const elem_t* src) const {     \
            return GiLoad##gi_suffix(src);              \
        }                                               \
    }

cb(dt_qint32, int32_t, GI_INT32_t, Int32);
cb(dt_float32, float, GI_FLOAT32_t, Float32);
cb(dt_int32, int32_t, GI_INT32_t, Int32);
#undef cb
}
// namespace elemwise
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h
浏览文件 @
ffbf8fad
...
...
@@ -87,7 +87,7 @@ template <>
struct
FuseAddHSwishOp
<
dt_qint32
,
dt_qint8
>
:
FuseAddHSwishOpBase
<
dt_qint32
,
dt_qint8
>
{
using
FuseAddHSwishOpBase
::
FuseAddHSwishOpBase
;
using
FuseAddHSwishOpBase
::
operator
();
constexpr
static
size_t
SIMD_WIDTH
=
GI_SIMD_LEN_BYTE
/
sizeof
(
int
8
_t
);
constexpr
static
size_t
SIMD_WIDTH
=
GI_SIMD_LEN_BYTE
/
sizeof
(
int
32
_t
);
void
operator
()(
const
GI_INT32_V2_t
&
vsrc0
,
const
GI_INT32_V2_t
&
vsrc1
,
dt_qint8
*
dst
)
const
{
...
...
dnn/src/fallback/elemwise_helper/kimpl/op_base.h
浏览文件 @
ffbf8fad
...
...
@@ -41,7 +41,7 @@ struct UnaryOpBase : OpBase<src_ctype, dst_ctype> {
GiStoreLowInt8( \
reinterpret_cast<int8_t*>(dst + 8), \
operator()({{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}})); \
GI_INT16_t vsrct2 = GiMove
HighLongInt8(vsrc.val[1]);
\
GI_INT16_t vsrct2 = GiMove
LowLongInt8(vsrc.val[1]);
\
GiStoreLowInt8( \
reinterpret_cast<int8_t*>(dst + 16), \
operator()({{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}})); \
...
...
@@ -330,7 +330,7 @@ struct UnaryQuantizationOp;
template
<
typename
Op
>
struct
UnaryQuantizationOp
<
dt_qint8
,
dt_qint8
,
Op
>
:
UnaryOpBase
<
dt_qint8
,
dt_qint8
>
{
using
UnaryOpBase
<
dt_qint8
,
dt_qint8
>::
UnaryOpBase
;
constexpr
static
size_t
SIMD_WIDTH
=
16
;
constexpr
static
size_t
SIMD_WIDTH
=
GI_SIMD_LEN_BYTE
/
sizeof
(
int8_t
)
;
Op
op
;
void
operator
()(
const
dt_qint8
&
src
,
dt_qint8
*
dst
)
const
{
...
...
@@ -354,7 +354,7 @@ struct UnaryQuantizationOp<dt_qint8, dt_qint8, Op> : UnaryOpBase<dt_qint8, dt_qi
auto
val
=
this
->
op
({{
vitem0
,
vitem1
}});
val
.
val
[
0
]
=
GiMultiplyFloat32
(
val
.
val
[
0
],
this
->
vscale_dst
);
val
.
val
[
1
]
=
GiMultiplyFloat32
(
val
.
val
[
1
],
this
->
vscale_dst
);
return
QConverter
::
convert
<
GI_INT8_t
,
GI_FLOAT32_V
4
_t
>
(
val
);
return
QConverter
::
convert
<
GI_INT8_t
,
GI_FLOAT32_V
2
_t
>
(
val
);
}
};
...
...
@@ -364,7 +364,7 @@ struct BinaryQuantizationOp;
template
<
typename
Op
>
struct
BinaryQuantizationOp
<
dt_qint8
,
dt_qint8
,
Op
>
:
BinaryOpBase
<
dt_qint8
,
dt_qint8
>
{
using
BinaryOpBase
<
dt_qint8
,
dt_qint8
>::
BinaryOpBase
;
constexpr
static
size_t
SIMD_WIDTH
=
16
;
constexpr
static
size_t
SIMD_WIDTH
=
GI_SIMD_LEN_BYTE
/
sizeof
(
int8_t
)
;
Op
op
;
void
operator
()(
const
dt_qint8
&
src0
,
const
dt_qint8
&
src1
,
dt_qint8
*
dst
)
const
{
...
...
@@ -403,7 +403,7 @@ template <typename Op>
struct
TernaryQuantizationOp
<
dt_qint8
,
dt_qint8
,
Op
>
:
TernaryOpBase
<
dt_qint8
,
dt_qint8
>
{
using
TernaryOpBase
<
dt_qint8
,
dt_qint8
>::
TernaryOpBase
;
constexpr
static
size_t
SIMD_WIDTH
=
16
;
constexpr
static
size_t
SIMD_WIDTH
=
GI_SIMD_LEN_BYTE
/
sizeof
(
int8_t
)
;
Op
op
;
void
operator
()(
...
...
dnn/src/fallback/elemwise_helper/kimpl/relu.h
浏览文件 @
ffbf8fad
...
...
@@ -69,7 +69,7 @@ struct ReluOpBase<dt_qint8, dt_qint8> : UnaryOpBase<dt_qint8, dt_qint8> {
template
<
>
struct
ReluOp
<
dt_qint8
,
dt_qint8
>
:
ReluOpBase
<
dt_qint8
,
dt_qint8
>
{
using
ReluOpBase
::
ReluOpBase
;
constexpr
static
size_t
SIMD_WIDTH
=
16
;
constexpr
static
size_t
SIMD_WIDTH
=
GI_SIMD_LEN_BYTE
/
sizeof
(
int8_t
)
;
using
ReluOpBase
::
operator
();
void
operator
()(
const
GI_INT8_V2_t
&
vsrc
,
dt_qint8
*
dst
)
const
{
...
...
dnn/src/fallback/elemwise_helper/op_common.h
浏览文件 @
ffbf8fad
...
...
@@ -8,6 +8,7 @@
namespace
megdnn
{
namespace
elemwise
{
/*!
* \brief broadcast type
* BCAST_x[0]x[1]...: x[i] == !stride[i]
...
...
@@ -49,6 +50,55 @@ struct ParamElemVisitorDup;
template
<
typename
ctype
>
struct
ParamElemVisitorBcast101x4
;
// Stamps out explicit specializations of the two basic SIMD parameter
// visitors for one element type:
//   - ParamElemVisitor<T>::operator() returns GiLoad<suffix>(src).
//   - ParamElemVisitorDup<T>::operator() dereferences src through a pointer to
//     the inner C type and returns GiBroadcast<suffix>() of that one value.
// Macro arguments: the element type, its inner C type, the GI vector type
// returned by both visitors, and the suffix pasted onto GiLoad / GiBroadcast.
#define cb(elem_t, inner_t, vec_t, gi_suffix)                              \
    template <>                                                            \
    struct ParamElemVisitor<elem_t> {                                      \
        vec_t operator()(const elem_t* src) const {                        \
            return GiLoad##gi_suffix(src);                                 \
        }                                                                  \
    };                                                                     \
    template <>                                                            \
    struct ParamElemVisitorDup<elem_t> {                                   \
        vec_t operator()(const elem_t* src) const {                        \
            const inner_t* scalar = reinterpret_cast<const inner_t*>(src); \
            return GiBroadcast##gi_suffix(*scalar);                        \
        }                                                                  \
    }

// The quantized and plain variants of a width use the same vector type and
// the same intrinsic suffix.
cb(dt_qint32, int32_t, GI_INT32_t, Int32);
cb(dt_qint8, int8_t, GI_INT8_t, Int8);

cb(dt_float32, float, GI_FLOAT32_t, Float32);
cb(dt_int32, int32_t, GI_INT32_t, Int32);
cb(dt_int8, int8_t, GI_INT8_t, Int8);
#undef cb
// Declared without a body; the specializations follow.
template <typename T>
struct ParamElemVisitorBcast101x4;

// 8-bit specializations. src is read through a pointer to the wider inner
// type (int32_t in both instantiations below), that value is broadcast with
// GiBroadcast<wide>, and the result is reinterpreted as the narrow vector
// type via GiReinter<wide>To<narrow>.
// NOTE(review): the reinterpret_cast assumes src may be read as the wider
// type (alignment / aliasing) -- confirm this holds on all targets.
#define cb(elem_t, inner_t, vec_t, gi_suffix, wide_suffix)                 \
    template <>                                                            \
    struct ParamElemVisitorBcast101x4<elem_t> {                            \
        vec_t operator()(const elem_t* src) const {                        \
            const inner_t* packed = reinterpret_cast<const inner_t*>(src); \
            return GiReinter##wide_suffix##To##gi_suffix(                  \
                    GiBroadcast##wide_suffix(*packed));                    \
        }                                                                  \
    }

cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32);
cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32);
#undef cb
// 32-bit wide specializations of ParamElemVisitorBcast101x4: operator()
// simply returns GiLoad<suffix>(src). The inner C type argument is accepted
// for symmetry with the other cb() macros but is not used in the expansion.
#define cb(elem_t, inner_t, vec_t, gi_suffix)           \
    template <>                                         \
    struct ParamElemVisitorBcast101x4<elem_t> {         \
        vec_t operator()(const elem_t* src) const {     \
            return GiLoad##gi_suffix(src);              \
        }                                               \
    }

cb(dt_qint32, int32_t, GI_INT32_t, Int32);
cb(dt_float32, float, GI_FLOAT32_t, Float32);
cb(dt_int32, int32_t, GI_INT32_t, Int32);
#undef cb
///////////////////////////////// OpCaller /////////////////////////////
template
<
typename
Op
,
BcastType
bcast_type
>
struct
OpCallerUnary
;
...
...
dnn/src/fallback/elemwise_multi_type/opr_impl.h
浏览文件 @
ffbf8fad
...
...
@@ -50,6 +50,18 @@ protected:
//! Override taking three input operands and the destination tensor.
//! NOTE(review): the name suggests a fused multiply-add with uint8 / f32
//! operands -- confirm against the definition.
void on_fuse_mul_add3_uint8xf32xf32xf32(
        const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
//! Quantized-mode override for one input operand (ElemwiseOpParamN<1>);
//! \p mode selects the elemwise operation and \p dst is the output tensor.
void on_quantized_mode(
        const ElemwiseOpParamN<1>& param, const TensorND& dst,
        Elemwise::Mode mode) override;
//! Quantized-mode override for two input operands (ElemwiseOpParamN<2>);
//! \p mode selects the elemwise operation and \p dst is the output tensor.
void on_quantized_mode(
        const ElemwiseOpParamN<2>& param, const TensorND& dst,
        Elemwise::Mode mode) override;
//! Quantized-mode override for three input operands (ElemwiseOpParamN<3>);
//! \p mode selects the elemwise operation and \p dst is the output tensor.
void on_quantized_mode(
        const ElemwiseOpParamN<3>& param, const TensorND& dst,
        Elemwise::Mode mode) override;
public:
using
naive
::
ElemwiseMultiTypeImpl
::
ElemwiseMultiTypeImpl
;
};
...
...
dnn/src/fallback/elemwise_multi_type/quantized_impl.cpp
0 → 100644
浏览文件 @
ffbf8fad
此差异已折叠。
点击以展开。
dnn/src/fallback/general_intrinsic/gi_common.h
浏览文件 @
ffbf8fad
...
...
@@ -60,6 +60,7 @@
#define GI_NEON_INTRINSICS
#if defined(__aarch64__)
#define GI_NEON64_INTRINSICS
#define GI_NEON32_INTRINSICS
#else
#define GI_NEON32_INTRINSICS
#endif
...
...
dnn/test/arm_common/elemwise_multi_type.cpp
浏览文件 @
ffbf8fad
...
...
@@ -11,8 +11,10 @@
*/
#include "test/common/elemwise_multi_type.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs.h"
#include "test/arm_common/fixture.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/task_record_check.h"
#include "test/common/timer.h"
...
...
@@ -559,4 +561,95 @@ TEST_F(ARM_COMMON, ELEMWISE_FMA3_UINT8xF32xF32xF32_RECORD) {
.
execs
({{
16
,
128
,
16
,
16
},
{
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
},
{}});
}
#if MEGDNN_WITH_BENCHMARK
namespace
{
void
run_elemwise_benchmark
(
const
TensorShapeArray
&
shapes
,
ElemwiseMultiType
::
Param
::
Mode
mode
,
const
char
*
mode_str
,
std
::
vector
<
DType
>
types
,
Handle
*
handle_bench
)
{
auto
handle_fallback
=
create_cpu_handle
(
1
);
Benchmarker
<
ElemwiseMultiType
>
benchmarker_bench
(
handle_bench
);
Benchmarker
<
ElemwiseMultiType
>
benchmarker_fallback
(
handle_fallback
.
get
());
float
throughput
=
0
;
SmallVector
<
TensorLayout
>
layouts
;
std
::
string
src_strs
;
for
(
size_t
i
=
0
;
i
<
shapes
.
size
();
i
++
)
{
layouts
.
emplace_back
(
shapes
[
i
],
types
[
i
]);
throughput
+=
layouts
.
back
().
span
().
dist_byte
();
src_strs
+=
layouts
.
back
().
to_string
();
if
(
i
!=
shapes
.
size
()
-
1
)
{
src_strs
+=
","
;
}
}
constexpr
size_t
RUN
=
50
;
benchmarker_fallback
.
set_times
(
RUN
).
set_display
(
false
);
benchmarker_bench
.
set_times
(
RUN
).
set_display
(
false
);
benchmarker_fallback
.
set_param
(
mode
);
benchmarker_bench
.
set_param
(
mode
);
TensorLayout
dst_layout
;
dst_layout
.
dtype
=
types
.
back
();
auto
opr
=
handle_bench
->
create_operator
<
ElemwiseMultiType
>
();
opr
->
param
()
=
mode
;
opr
->
deduce_layout
(
layouts
,
dst_layout
);
float
computations
=
dst_layout
.
total_nr_elems
()
*
(
std
::
max
<
size_t
>
(
shapes
.
size
(),
2
)
-
1
);
throughput
+=
dst_layout
.
span
().
dist_byte
();
computations
*=
(
1e3
/
(
1024.0
*
1024
));
throughput
*=
(
1e3
/
(
1024.0
*
1024
));
layouts
.
emplace_back
(
dst_layout
);
auto
fallback_time
=
benchmarker_fallback
.
execl
(
layouts
)
/
RUN
;
auto
bench_time
=
benchmarker_bench
.
execl
(
layouts
)
/
RUN
;
float
fallback_flops
=
computations
/
fallback_time
;
float
bench_flops
=
computations
/
bench_time
;
float
fallback_thr
=
throughput
/
fallback_time
;
float
bench_thr
=
throughput
/
bench_time
;
printf
(
"%s = %s (mode: %s) cpu=%fMFLOPS %fMB/s, bench=%fMFLOPS "
"%fMB/s "
"computations: %fx, throughput: %fx
\n
"
,
src_strs
.
c_str
(),
dst_layout
.
to_string
().
c_str
(),
mode_str
,
fallback_flops
,
fallback_thr
,
bench_flops
,
bench_thr
,
bench_flops
/
fallback_flops
,
bench_thr
/
fallback_thr
);
}
}
// namespace
// Forwards to run_elemwise_benchmark() using the enclosing fixture's handle().
// The label passed as mode_str is `#mode`, i.e. the literal source text of the
// second macro argument: passing an enumerator spelled out in full yields its
// name, while passing a variable yields the variable's name, not its value.
// The expansion already ends in ';'.
#define RUN_WITH_MODE(shape, mode, types) \
run_elemwise_benchmark(shape, mode, #mode, types, handle());
//! Benchmarks the unary quantized modes for qint8 -> qint8 and
//! qint32 -> qint8 on a 10000-element vector.
TEST_F(ARM_COMMON, BENCHMARK_UNARY_MULTI_TYPE) {
    using Mode = ElemwiseMultiType::Param::Mode;

    // Each mode carries an explicit label. Going through RUN_WITH_MODE with a
    // loop variable would stringify the variable's name, so every report line
    // would read "(mode: mode)" instead of naming the mode being measured.
    struct ModeCase {
        Mode mode;
        const char* label;
    };
    const ModeCase cases[] = {
            {Mode::QRELU, "Mode::QRELU"},
            {Mode::QABS, "Mode::QABS"},
            {Mode::QSIGMOID, "Mode::QSIGMOID"},
            {Mode::QEXP, "Mode::QEXP"},
            {Mode::QTANH, "Mode::QTANH"},
            {Mode::QFAST_TANH, "Mode::QFAST_TANH"},
            {Mode::QH_SWISH, "Mode::QH_SWISH"},
    };

    const TensorShapeArray shapes = {{10000}};
    // {input dtype, output dtype}
    const std::vector<DType> qint8_types = {
            dtype::QuantizedS8(1.4f), dtype::QuantizedS8(3.4f)};
    const std::vector<DType> qint32_types = {
            dtype::QuantizedS32(1.4f), dtype::QuantizedS8(3.4f)};

    for (const ModeCase& c : cases) {
        run_elemwise_benchmark(shapes, c.mode, c.label, qint8_types, handle());
        run_elemwise_benchmark(shapes, c.mode, c.label, qint32_types, handle());
    }
}
//! Benchmarks the binary quantized modes for qint8 x qint8 -> qint8 and
//! qint32 x qint32 -> qint8 on two 10000-element vectors.
TEST_F(ARM_COMMON, BENCHMARK_BINARY_MULTI_TYPE) {
    using Mode = ElemwiseMultiType::Param::Mode;

    // Each mode carries an explicit label. Going through RUN_WITH_MODE with a
    // loop variable would stringify the variable's name, so every report line
    // would read "(mode: mode)" instead of naming the mode being measured.
    struct ModeCase {
        Mode mode;
        const char* label;
    };
    const ModeCase cases[] = {
            {Mode::QADD, "Mode::QADD"},
            {Mode::QFUSE_ADD_RELU, "Mode::QFUSE_ADD_RELU"},
            {Mode::QFUSE_ADD_H_SWISH, "Mode::QFUSE_ADD_H_SWISH"},
    };

    const TensorShapeArray shapes = {{10000}, {10000}};
    // {input 0 dtype, input 1 dtype, output dtype}
    const std::vector<DType> qint8_types = {
            dtype::QuantizedS8(1.4f), dtype::QuantizedS8(3.4f),
            dtype::QuantizedS8(1.6f)};
    const std::vector<DType> qint32_types = {
            dtype::QuantizedS32(1.4f), dtype::QuantizedS32(3.4f),
            dtype::QuantizedS8(1.6f)};

    for (const ModeCase& c : cases) {
        run_elemwise_benchmark(shapes, c.mode, c.label, qint8_types, handle());
        run_elemwise_benchmark(shapes, c.mode, c.label, qint32_types, handle());
    }
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
dnn/test/fallback/elemwise_multi_type.cpp
浏览文件 @
ffbf8fad
...
...
@@ -26,6 +26,175 @@ TYPED_TEST(FALLBACK_ELEMWISE_MULTI_TYPE, run) {
elemwise_multi_type
::
run_test
<
TypeParam
>
(
this
->
handle
());
}
//! Runs every unary quantized mode with a QuantizedS8 and a QuantizedS32
//! source; the destination is QuantizedS8 in both cases.
TEST_F(FALLBACK, ELEMWISE_QUANTIZED_MODE_UNARY) {
    using Mode = ElemwiseMultiType::Param::Mode;
    Checker<ElemwiseMultiType> checker(handle());
    // Owns the RNG whose raw pointer is registered for input 0.
    std::unique_ptr<RNG> input_rng;

    // Shapes executed for each (mode, source dtype) pair. The trailing {} is
    // the output shape (presumably deduced by the checker).
    const auto exec_all_shapes = [&checker]() {
        checker.execs({{3, 4, 5, 6}, {}});
        checker.execs({{3}, {}});
        checker.execs({{9}, {}});
        checker.execs({{17}, {}});
    };

    for (auto mode :
         {Mode::QRELU, Mode::QABS, Mode::QSIGMOID, Mode::QEXP, Mode::QTANH,
          Mode::QFAST_TANH, Mode::QH_SWISH}) {
        checker.set_param({mode});
        for (DType src_type : std::vector<DType>{
                     dtype::QuantizedS8(1.4f), dtype::QuantizedS32(1.3f)}) {
            checker.set_dtype(0, src_type);
            const bool src_is_qint8 =
                    src_type.enumv() == DTypeEnum::QuantizedS8;
            const bool src_is_qint32 =
                    src_type.enumv() == DTypeEnum::QuantizedS32;

            if (src_is_qint8) {
                input_rng = std::make_unique<UniformIntRNG>(-127, 127);
                checker.set_dtype(1, dtype::QuantizedS8(1.7f));
            } else {
                // Half of the int16 range for the non-qint8 source.
                input_rng = std::make_unique<UniformIntRNG>(
                        INT16_MIN >> 1, INT16_MAX >> 1);
            }
            checker.set_rng(0, input_rng.get());

            if (src_is_qint32) {
                // The only destination type exercised for a qint32 source.
                checker.set_dtype(1, dtype::QuantizedS8(32718.6f));
            }
            exec_all_shapes();
        }
    }
}
//! Runs the binary quantized modes over NCHW44-style, vector/scalar,
//! vector/1C11 and vector/vector shape pairs for several dtype setups.
TEST_F(FALLBACK, ELEMWISE_QUANTIZED_MODE_BINARY) {
    using Mode = ElemwiseMultiType::Param::Mode;
    Checker<ElemwiseMultiType> checker(handle());

    // Executes the full shape list with whatever param / dtypes / RNGs are
    // currently configured on the checker.
    const auto exec_all_shapes = [&checker]() {
        //! nchw44
        checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
        checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
        checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
        checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
        checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
        checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
        checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
        checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
        checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
        //! VEC + SCALAR
        checker.execs({{3, 4, 5, 6}, {1, 1, 1, 1}, {}});
        checker.execs({{1, 1, 1, 1}, {3, 4, 5, 6}, {}});
        checker.execs({{3, 4, 5, 6}, {1}, {}});
        checker.execs({{1}, {3, 4, 5, 6}, {}});
        //! VEC + 1C11
        checker.execs({{3, 4, 5, 6}, {1, 4, 1, 1}, {}});
        checker.execs({{1, 4, 1, 1}, {3, 4, 5, 6}, {}});
        //! VEC + VEC
        checker.execs({{3}, {3}, {}});
        checker.execs({{9}, {9}, {}});
        checker.execs({{17}, {17}, {}});
    };

    // qint32 to qint8/quint8
    for (auto mode :
         {Mode::QADD, Mode::QFUSE_ADD_RELU, Mode::QFUSE_ADD_H_SWISH}) {
        checker.set_param({mode});
        UniformIntRNG half_int16_rng{INT16_MIN >> 1, INT16_MAX >> 1};
        checker.set_rng(0, &half_int16_rng);
        checker.set_rng(1, &half_int16_rng);
        checker.set_dtype(0, dtype::QuantizedS32(1.3f));
        checker.set_dtype(1, dtype::QuantizedS32(1.2f));
        // The only destination type exercised for qint32 inputs.
        checker.set_dtype(2, dtype::QuantizedS8(32718.6f));
        exec_all_shapes();
    }

    // qint8 to qint8
    for (auto mode :
         {Mode::QMUL, Mode::QADD, Mode::QMIN, Mode::QMAX, Mode::QSUB,
          Mode::QFUSE_ADD_RELU, Mode::QFUSE_ADD_SIGMOID,
          Mode::QFUSE_ADD_H_SWISH}) {
        checker.set_param({mode});
        UniformIntRNG full_int8_rng{-127, 127};
        checker.set_rng(0, &full_int8_rng);
        checker.set_rng(1, &full_int8_rng);
        checker.set_dtype(0, dtype::QuantizedS8(1.35f));
        checker.set_dtype(1, dtype::QuantizedS8(1.15f));
        checker.set_dtype(2, dtype::QuantizedS8(1.75f));
        exec_all_shapes();
    }

    //! TRUE_DIV : 0.0 / 0.0 will fail
    // The divisor RNG is limited to [-127, -1], so it never produces zero.
    checker.set_param({Mode::QTRUE_DIV});
    UniformIntRNG dividend_rng{-127, 127};
    UniformIntRNG divisor_rng{-127, -1};
    checker.set_rng(0, &dividend_rng);
    checker.set_rng(1, &divisor_rng);
    checker.set_dtype(0, dtype::QuantizedS8(1.4f));
    checker.set_dtype(1, dtype::QuantizedS8(1.1f));
    checker.set_dtype(2, dtype::QuantizedS8(1.7f));
    exec_all_shapes();

    //! TANH
    // Inputs are drawn from the small range [-5, 5].
    checker.set_param({Mode::QFUSE_ADD_TANH});
    UniformIntRNG small_rng{-5, 5};
    checker.set_rng(0, &small_rng);
    checker.set_rng(1, &small_rng);
    checker.set_dtype(0, dtype::QuantizedS8(1.1f));
    checker.set_dtype(1, dtype::QuantizedS8(1.4f));
    checker.set_dtype(2, dtype::QuantizedS8(1.7f));
    exec_all_shapes();
}
//! Runs QFUSE_MUL_ADD3 with qint8 operands over NCHW44-style broadcast
//! shapes, vector/scalar, vector/1C11 and same-shape operand triples.
TEST_F(FALLBACK, ELEMWISE_QUANTIZED_MODE_TERNARY) {
    using Mode = ElemwiseMultiType::Param::Mode;
    Checker<ElemwiseMultiType> checker(handle());

    // Executes the full shape list with whatever param / dtypes / RNGs are
    // currently configured on the checker.
    const auto exec_all_shapes = [&checker]() {
        //! nchw44
        checker.execs(
                {{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
        checker.execs(
                {{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
        checker.execs(
                {{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
        checker.execs(
                {{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
        checker.execs(
                {{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});

        //! nchw44
        checker.execs(
                {{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
        checker.execs(
                {{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
        checker.execs(
                {{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
        checker.execs(
                {{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
        checker.execs(
                {{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});

        checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {1, 1, 1, 1}, {}});
        checker.execs({{1, 4, 1, 1}, {3, 4, 5, 6}, {1, 4, 1, 1}, {}});

        checker.execs({{3}, {3}, {3}, {}});
        checker.execs({{9}, {9}, {9}, {}});
        checker.execs({{17}, {17}, {17}, {}});
        checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}, {}});
    };

    for (auto mode : {Mode::QFUSE_MUL_ADD3}) {
        checker.set_param({mode});

        // qint8 to qint8
        UniformIntRNG full_int8_rng{-127, 127};
        checker.set_rng(0, &full_int8_rng);
        checker.set_rng(1, &full_int8_rng);
        checker.set_rng(2, &full_int8_rng);
        checker.set_dtype(0, dtype::QuantizedS8(1.45f));
        checker.set_dtype(1, dtype::QuantizedS8(1.15f));
        checker.set_dtype(2, dtype::QuantizedS8(1.75f));
        checker.set_dtype(3, dtype::QuantizedS8(1.35f));
        exec_all_shapes();
    }
}
TEST_F
(
FALLBACK
,
ELEMWISE_MULTI_TYPE_RECORD_FMA3_INT16x32x32x32
)
{
TaskRecordChecker
<
ElemwiseMultiType
>
checker
{
1
};
checker
.
set_param
({
ElemwiseMultiType
::
Mode
::
FUSE_MUL_ADD3_INT16x32x32x32
});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录