提交 ffbf8fad 编写于 作者: M Megvii Engine Team

feat(fallback): add general intrinsic to elemwise multitype

GitOrigin-RevId: fe7b335545fd959f917b7df8ee48739ccb2a86ab
上级 484e1f11
......@@ -15,7 +15,7 @@
#include "src/arm_common/elemwise_helper/op_binary.h"
#include "src/arm_common/elemwise_helper/op_ternary.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"
#include "src/fallback/elemwise_helper/op_common.h"
namespace megdnn {
namespace elemwise {
......
......@@ -364,17 +364,9 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
}
#define DISPATCH() \
if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \
DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \
} else if ( \
param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \
if (param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \
DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, dtype::Quantized8Asymm) \
} else if ( \
param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \
DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8) \
} else if ( \
param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \
......@@ -467,16 +459,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
#define DISPATCH() \
if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \
DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8) \
} else if ( \
param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \
DISPATCH_MODE(dtype::QuantizedS32, dtype::Quantized8Asymm) \
} else if ( \
param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \
DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \
} else if ( \
param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \
......@@ -701,12 +685,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
}
#define DISPATCH() \
if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \
DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \
} else if ( \
param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \
if (param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \
dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \
DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, dtype::Quantized8Asymm) \
}
......
......@@ -12,61 +12,4 @@
#include "src/fallback/general_intrinsic/gi_float.h"
#include "src/fallback/general_intrinsic/gi_int.h"
namespace megdnn {
namespace elemwise {
///////////////////////////////// ParamElemVisitor ///////////////////////////
//! ParamElemVisitor<ctype>:    load one full SIMD vector from contiguous
//!                             memory at `src`.
//! ParamElemVisitorDup<ctype>: broadcast the single scalar at `src` to
//!                             every lane of the vector.
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix)         \
    template <>                                                   \
    struct ParamElemVisitor<_ctype> {                             \
        _simd_type operator()(const _ctype* src) const {          \
            return GiLoad##_fun_suffix(src);                      \
        }                                                         \
    };                                                            \
    template <>                                                   \
    struct ParamElemVisitorDup<_ctype> {                          \
        _simd_type operator()(const _ctype* src) const {          \
            return GiBroadcast##_fun_suffix(                      \
                    *reinterpret_cast<const _inner_ctype*>(src)); \
        }                                                         \
    }
cb(dt_qint32, int32_t, GI_INT32_t, Int32);
cb(dt_qint8, int8_t, GI_INT8_t, Int8);
cb(dt_float32, float, GI_FLOAT32_t, Float32);
cb(dt_int32, int32_t, GI_INT32_t, Int32);
cb(dt_int8, int8_t, GI_INT8_t, Int8);
#undef cb

//! Visitor for a BCAST101x4 operand: broadcast one 4-element group.
template <typename ctype>
struct ParamElemVisitorBcast101x4;
//! 8-bit types: the 4 lanes are broadcast as a single int32 and then
//! reinterpreted back to the 8-bit vector type.
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix)              \
    template <>                                                                    \
    struct ParamElemVisitorBcast101x4<_ctype> {                                    \
        _simd_type operator()(const _ctype* src) const {                           \
            return GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix( \
                    *reinterpret_cast<const _inner_ctype*>(src)));                 \
        }                                                                          \
    }
cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32);
cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32);
#undef cb
//! 32-bit types: a plain vector load already covers the 4-element group
//! (presumably GI_SIMD_LEN_BYTE == 16 here — TODO confirm for wider SIMD).
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \
    template <>                                           \
    struct ParamElemVisitorBcast101x4<_ctype> {           \
        _simd_type operator()(const _ctype* src) const {  \
            return GiLoad##_fun_suffix(src);              \
        }                                                 \
    }
cb(dt_qint32, int32_t, GI_INT32_t, Int32);
cb(dt_float32, float, GI_FLOAT32_t, Float32);
cb(dt_int32, int32_t, GI_INT32_t, Int32);
#undef cb
} // namespace elemwise
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -87,7 +87,7 @@ template <>
struct FuseAddHSwishOp<dt_qint32, dt_qint8> : FuseAddHSwishOpBase<dt_qint32, dt_qint8> {
using FuseAddHSwishOpBase::FuseAddHSwishOpBase;
using FuseAddHSwishOpBase::operator();
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t);
void operator()(
const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1,
dt_qint8* dst) const {
......
......@@ -41,7 +41,7 @@ struct UnaryOpBase : OpBase<src_ctype, dst_ctype> {
GiStoreLowInt8( \
reinterpret_cast<int8_t*>(dst + 8), \
operator()({{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}})); \
GI_INT16_t vsrct2 = GiMoveHighLongInt8(vsrc.val[1]); \
GI_INT16_t vsrct2 = GiMoveLowLongInt8(vsrc.val[1]); \
GiStoreLowInt8( \
reinterpret_cast<int8_t*>(dst + 16), \
operator()({{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}})); \
......@@ -330,7 +330,7 @@ struct UnaryQuantizationOp;
template <typename Op>
struct UnaryQuantizationOp<dt_qint8, dt_qint8, Op> : UnaryOpBase<dt_qint8, dt_qint8> {
using UnaryOpBase<dt_qint8, dt_qint8>::UnaryOpBase;
constexpr static size_t SIMD_WIDTH = 16;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
Op op;
void operator()(const dt_qint8& src, dt_qint8* dst) const {
......@@ -354,7 +354,7 @@ struct UnaryQuantizationOp<dt_qint8, dt_qint8, Op> : UnaryOpBase<dt_qint8, dt_qi
auto val = this->op({{vitem0, vitem1}});
val.val[0] = GiMultiplyFloat32(val.val[0], this->vscale_dst);
val.val[1] = GiMultiplyFloat32(val.val[1], this->vscale_dst);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V4_t>(val);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>(val);
}
};
......@@ -364,7 +364,7 @@ struct BinaryQuantizationOp;
template <typename Op>
struct BinaryQuantizationOp<dt_qint8, dt_qint8, Op> : BinaryOpBase<dt_qint8, dt_qint8> {
using BinaryOpBase<dt_qint8, dt_qint8>::BinaryOpBase;
constexpr static size_t SIMD_WIDTH = 16;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
Op op;
void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const {
......@@ -403,7 +403,7 @@ template <typename Op>
struct TernaryQuantizationOp<dt_qint8, dt_qint8, Op>
: TernaryOpBase<dt_qint8, dt_qint8> {
using TernaryOpBase<dt_qint8, dt_qint8>::TernaryOpBase;
constexpr static size_t SIMD_WIDTH = 16;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
Op op;
void operator()(
......
......@@ -69,7 +69,7 @@ struct ReluOpBase<dt_qint8, dt_qint8> : UnaryOpBase<dt_qint8, dt_qint8> {
template <>
struct ReluOp<dt_qint8, dt_qint8> : ReluOpBase<dt_qint8, dt_qint8> {
using ReluOpBase::ReluOpBase;
constexpr static size_t SIMD_WIDTH = 16;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
using ReluOpBase::operator();
void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const {
......
......@@ -8,6 +8,7 @@
namespace megdnn {
namespace elemwise {
/*!
* \brief broadcast type
* BCAST_x[0]x[1]...: x[i] == !stride[i]
......@@ -49,6 +50,55 @@ struct ParamElemVisitorDup;
template <typename ctype>
struct ParamElemVisitorBcast101x4;
///////////////////////////////// ParamElemVisitor ///////////////////////////
//! ParamElemVisitor<ctype>:    load one full SIMD vector from contiguous
//!                             memory at `src`.
//! ParamElemVisitorDup<ctype>: broadcast the single scalar at `src` to
//!                             every lane of the vector.
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix)         \
    template <>                                                   \
    struct ParamElemVisitor<_ctype> {                             \
        _simd_type operator()(const _ctype* src) const {          \
            return GiLoad##_fun_suffix(src);                      \
        }                                                         \
    };                                                            \
    template <>                                                   \
    struct ParamElemVisitorDup<_ctype> {                          \
        _simd_type operator()(const _ctype* src) const {          \
            return GiBroadcast##_fun_suffix(                      \
                    *reinterpret_cast<const _inner_ctype*>(src)); \
        }                                                         \
    }
cb(dt_qint32, int32_t, GI_INT32_t, Int32);
cb(dt_qint8, int8_t, GI_INT8_t, Int8);
cb(dt_float32, float, GI_FLOAT32_t, Float32);
cb(dt_int32, int32_t, GI_INT32_t, Int32);
cb(dt_int8, int8_t, GI_INT8_t, Int8);
#undef cb

//! ParamElemVisitorBcast101x4<ctype>: visit a BCAST101x4 operand by
//! broadcasting one 4-element group. NOTE: the forward declaration already
//! appears just above this section, so it is not repeated here.
//! 8-bit types: the 4 lanes are broadcast as a single int32 and then
//! reinterpreted back to the 8-bit vector type.
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix)              \
    template <>                                                                    \
    struct ParamElemVisitorBcast101x4<_ctype> {                                    \
        _simd_type operator()(const _ctype* src) const {                           \
            return GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix( \
                    *reinterpret_cast<const _inner_ctype*>(src)));                 \
        }                                                                          \
    }
cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32);
cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32);
#undef cb
//! 32-bit types: a plain vector load already covers the 4-element group
//! (presumably GI_SIMD_LEN_BYTE == 16 here — TODO confirm for wider SIMD).
#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \
    template <>                                           \
    struct ParamElemVisitorBcast101x4<_ctype> {           \
        _simd_type operator()(const _ctype* src) const {  \
            return GiLoad##_fun_suffix(src);              \
        }                                                 \
    }
cb(dt_qint32, int32_t, GI_INT32_t, Int32);
cb(dt_float32, float, GI_FLOAT32_t, Float32);
cb(dt_int32, int32_t, GI_INT32_t, Int32);
#undef cb
///////////////////////////////// OpCaller /////////////////////////////
template <typename Op, BcastType bcast_type>
struct OpCallerUnary;
......
......@@ -50,6 +50,18 @@ protected:
void on_fuse_mul_add3_uint8xf32xf32xf32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
void on_quantized_mode(
const ElemwiseOpParamN<1>& param, const TensorND& dst,
Elemwise::Mode mode) override;
void on_quantized_mode(
const ElemwiseOpParamN<2>& param, const TensorND& dst,
Elemwise::Mode mode) override;
void on_quantized_mode(
const ElemwiseOpParamN<3>& param, const TensorND& dst,
Elemwise::Mode mode) override;
public:
using naive::ElemwiseMultiTypeImpl::ElemwiseMultiTypeImpl;
};
......
此差异已折叠。
......@@ -60,6 +60,7 @@
#define GI_NEON_INTRINSICS
#if defined(__aarch64__)
#define GI_NEON64_INTRINSICS
#define GI_NEON32_INTRINSICS
#else
#define GI_NEON32_INTRINSICS
#endif
......
......@@ -11,8 +11,10 @@
*/
#include "test/common/elemwise_multi_type.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs.h"
#include "test/arm_common/fixture.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/task_record_check.h"
#include "test/common/timer.h"
......@@ -559,4 +561,95 @@ TEST_F(ARM_COMMON, ELEMWISE_FMA3_UINT8xF32xF32xF32_RECORD) {
.execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}});
}
#if MEGDNN_WITH_BENCHMARK
namespace {
//! Benchmark one ElemwiseMultiType \p mode on \p handle_bench against a
//! single-threaded fallback CPU handle, printing MFLOPS and memory
//! throughput for both, plus the relative speedup of the bench handle.
void run_elemwise_benchmark(
        const TensorShapeArray& shapes, ElemwiseMultiType::Param::Mode mode,
        const char* mode_str, std::vector<DType> types, Handle* handle_bench) {
    // Reference implementation: plain CPU handle with one worker thread.
    auto handle_fallback = create_cpu_handle(1);
    Benchmarker<ElemwiseMultiType> benchmarker_bench(handle_bench);
    Benchmarker<ElemwiseMultiType> benchmarker_fallback(handle_fallback.get());

    // Accumulate total bytes touched (all inputs, dst added later) and a
    // human-readable description of the input layouts for the report line.
    float throughput = 0;
    SmallVector<TensorLayout> layouts;
    std::string src_strs;
    for (size_t i = 0; i < shapes.size(); i++) {
        layouts.emplace_back(shapes[i], types[i]);
        throughput += layouts.back().span().dist_byte();
        src_strs += layouts.back().to_string();
        if (i != shapes.size() - 1) {
            src_strs += ",";
        }
    }

    constexpr size_t RUN = 50;
    benchmarker_fallback.set_times(RUN).set_display(false);
    benchmarker_bench.set_times(RUN).set_display(false);
    benchmarker_fallback.set_param(mode);
    benchmarker_bench.set_param(mode);

    // Let the operator deduce the output layout; the last entry of
    // \p types is the destination dtype.
    TensorLayout dst_layout;
    dst_layout.dtype = types.back();
    auto opr = handle_bench->create_operator<ElemwiseMultiType>();
    opr->param() = mode;
    opr->deduce_layout(layouts, dst_layout);

    // One op per output element per extra input; a unary mode still counts
    // one op per element (hence the max with 2).
    float computations =
            dst_layout.total_nr_elems() * (std::max<size_t>(shapes.size(), 2) - 1);
    throughput += dst_layout.span().dist_byte();
    // Scale to mega-units; the 1e3 factor assumes the benchmarker reports
    // time in milliseconds — TODO confirm against Benchmarker::execl.
    computations *= (1e3 / (1024.0 * 1024));
    throughput *= (1e3 / (1024.0 * 1024));
    layouts.emplace_back(dst_layout);

    auto fallback_time = benchmarker_fallback.execl(layouts) / RUN;
    auto bench_time = benchmarker_bench.execl(layouts) / RUN;
    float fallback_flops = computations / fallback_time;
    float bench_flops = computations / bench_time;
    float fallback_thr = throughput / fallback_time;
    float bench_thr = throughput / bench_time;

    printf("%s = %s (mode: %s) cpu=%fMFLOPS %fMB/s, bench=%fMFLOPS "
           "%fMB/s "
           "computations: %fx, throughput: %fx\n",
           src_strs.c_str(), dst_layout.to_string().c_str(), mode_str, fallback_flops,
           fallback_thr, bench_flops, bench_thr, bench_flops / fallback_flops,
           bench_thr / fallback_thr);
}
}  // namespace
//! Invoke the benchmark helper, stringifying the mode enum for the report.
#define RUN_WITH_MODE(shape, mode, types) \
    run_elemwise_benchmark(shape, mode, #mode, types, handle());

//! Benchmark every unary quantized mode for the two supported dtype
//! combinations: qint8 -> qint8 and qint32 -> qint8.
TEST_F(ARM_COMMON, BENCHMARK_UNARY_MULTI_TYPE) {
    using Mode = ElemwiseMultiType::Param::Mode;
    for (auto mode :
         {Mode::QRELU, Mode::QABS, Mode::QSIGMOID, Mode::QEXP, Mode::QTANH,
          Mode::QFAST_TANH, Mode::QH_SWISH}) {
        // qint8 -> qint8
        std::vector<DType> types = {dtype::QuantizedS8(1.4f), dtype::QuantizedS8(3.4f)};
        TensorShapeArray shapes = {{10000}};
        RUN_WITH_MODE(shapes, mode, types);
        // qint32 -> qint8
        std::vector<DType> types2 = {
                dtype::QuantizedS32(1.4f), dtype::QuantizedS8(3.4f)};
        RUN_WITH_MODE(shapes, mode, types2);
    }
}
//! Benchmark the binary quantized modes for the two supported dtype
//! combinations: qint8 x qint8 -> qint8 and qint32 x qint32 -> qint8.
TEST_F(ARM_COMMON, BENCHMARK_BINARY_MULTI_TYPE) {
    using Mode = ElemwiseMultiType::Param::Mode;
    const TensorShapeArray shapes = {{10000}, {10000}};
    for (auto mode : {Mode::QADD, Mode::QFUSE_ADD_RELU, Mode::QFUSE_ADD_H_SWISH}) {
        // qint8 x qint8 -> qint8
        std::vector<DType> s8_types = {
                dtype::QuantizedS8(1.4f), dtype::QuantizedS8(3.4f),
                dtype::QuantizedS8(1.6f)};
        RUN_WITH_MODE(shapes, mode, s8_types);
        // qint32 x qint32 -> qint8
        std::vector<DType> s32_types = {
                dtype::QuantizedS32(1.4f), dtype::QuantizedS32(3.4f),
                dtype::QuantizedS8(1.6f)};
        RUN_WITH_MODE(shapes, mode, s32_types);
    }
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -26,6 +26,175 @@ TYPED_TEST(FALLBACK_ELEMWISE_MULTI_TYPE, run) {
elemwise_multi_type::run_test<TypeParam>(this->handle());
}
//! Check all unary quantized modes for qint8 -> qint8 and qint32 -> qint8
//! against the reference implementation, over shapes hitting both the SIMD
//! main loop and the scalar remainder (3, 9, 17 elements).
TEST_F(FALLBACK, ELEMWISE_QUANTIZED_MODE_UNARY) {
    using Mode = ElemwiseMultiType::Param::Mode;
    Checker<ElemwiseMultiType> checker(handle());
    std::unique_ptr<RNG> rng;
    for (auto mode :
         {Mode::QRELU, Mode::QABS, Mode::QSIGMOID, Mode::QEXP, Mode::QTANH,
          Mode::QFAST_TANH, Mode::QH_SWISH}) {
        checker.set_param({mode});
        for (DType src_type :
             std::vector<DType>{dtype::QuantizedS8(1.4f), dtype::QuantizedS32(1.3f)}) {
            checker.set_dtype(0, src_type);
            if (src_type.enumv() == DTypeEnum::QuantizedS8) {
                // int8 source: full representable range, fixed dst dtype.
                rng = std::make_unique<UniformIntRNG>(-127, 127);
                checker.set_dtype(1, dtype::QuantizedS8(1.7f));
            } else {
                // int32 source: restrict to half the int16 range, presumably
                // to avoid intermediate overflow/saturation — TODO confirm.
                rng = std::make_unique<UniformIntRNG>(INT16_MIN >> 1, INT16_MAX >> 1);
            }
            checker.set_rng(0, rng.get());
            auto run = [&]() {
                checker.execs({{3, 4, 5, 6}, {}});
                checker.execs({{3}, {}});
                checker.execs({{9}, {}});
                checker.execs({{17}, {}});
            };
            if (src_type.enumv() == DTypeEnum::QuantizedS32) {
                // qint32 source is paired with a qint8 destination; the large
                // scale exercises the requantization path.
                for (DType dst_type :
                     std::vector<DType>{dtype::QuantizedS8(32718.6f)}) {
                    checker.set_dtype(1, dst_type);
                    run();
                }
            } else {
                run();
            }
        }
    }
}
//! Check the binary quantized modes against the reference implementation,
//! covering nchw44 broadcast, scalar broadcast, 1C11 broadcast and plain
//! vector/vector shapes, for both qint32 and qint8 inputs.
TEST_F(FALLBACK, ELEMWISE_QUANTIZED_MODE_BINARY) {
    using Mode = ElemwiseMultiType::Param::Mode;
    Checker<ElemwiseMultiType> checker(handle());
    // Shared shape battery; the trailing {} lets the checker deduce dst.
    auto run = [&]() {
        //! nchw44
        checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
        checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
        checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
        checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
        checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
        checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
        checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
        checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
        checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
        //! VEC + SCALAR
        checker.execs({{3, 4, 5, 6}, {1, 1, 1, 1}, {}});
        checker.execs({{1, 1, 1, 1}, {3, 4, 5, 6}, {}});
        checker.execs({{3, 4, 5, 6}, {1}, {}});
        checker.execs({{1}, {3, 4, 5, 6}, {}});
        //! VEC + 1C11
        checker.execs({{3, 4, 5, 6}, {1, 4, 1, 1}, {}});
        checker.execs({{1, 4, 1, 1}, {3, 4, 5, 6}, {}});
        //! VEC + VEC
        checker.execs({{3}, {3}, {}});
        checker.execs({{9}, {9}, {}});
        checker.execs({{17}, {17}, {}});
    };
    // qint32 to qint8/quint8
    for (auto mode : {Mode::QADD, Mode::QFUSE_ADD_RELU, Mode::QFUSE_ADD_H_SWISH}) {
        checker.set_param({mode});
        // Half the int16 range, presumably to avoid intermediate
        // overflow/saturation in the int32 accumulation — TODO confirm.
        UniformIntRNG rng{INT16_MIN >> 1, INT16_MAX >> 1};
        checker.set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_dtype(0, dtype::QuantizedS32(1.3f))
                .set_dtype(1, dtype::QuantizedS32(1.2f));
        for (DType dst_type : std::vector<DType>{dtype::QuantizedS8(32718.6f)}) {
            checker.set_dtype(2, dst_type);
            run();
        }
    }
    for (auto mode :
         {Mode::QMUL, Mode::QADD, Mode::QMIN, Mode::QMAX, Mode::QSUB,
          Mode::QFUSE_ADD_RELU, Mode::QFUSE_ADD_SIGMOID, Mode::QFUSE_ADD_H_SWISH}) {
        checker.set_param({mode});
        // qint8 to qint8
        UniformIntRNG rng_int8{-127, 127};
        checker.set_rng(0, &rng_int8)
                .set_rng(1, &rng_int8)
                .set_dtype(0, dtype::QuantizedS8(1.35f))
                .set_dtype(1, dtype::QuantizedS8(1.15f))
                .set_dtype(2, dtype::QuantizedS8(1.75f));
        run();
    }
    //! TRUE_DIV : 0.0 / 0.0 will fail
    // Keep the divisor strictly negative so it can never be zero.
    checker.set_param({Mode::QTRUE_DIV});
    UniformIntRNG rng_int8_1{-127, 127};
    UniformIntRNG rng_int8_2{-127, -1};
    checker.set_rng(0, &rng_int8_1)
            .set_rng(1, &rng_int8_2)
            .set_dtype(0, dtype::QuantizedS8(1.4f))
            .set_dtype(1, dtype::QuantizedS8(1.1f))
            .set_dtype(2, dtype::QuantizedS8(1.7f));
    run();
    //! TANH
    // Small input range: tanh saturates quickly, larger values would make
    // the comparison insensitive.
    checker.set_param({Mode::QFUSE_ADD_TANH});
    UniformIntRNG rng_int8{-5, 5};
    checker.set_rng(0, &rng_int8)
            .set_rng(1, &rng_int8)
            .set_dtype(0, dtype::QuantizedS8(1.1f))
            .set_dtype(1, dtype::QuantizedS8(1.4f))
            .set_dtype(2, dtype::QuantizedS8(1.7f));
    run();
}
//! Check the ternary quantized mode (QFUSE_MUL_ADD3) for qint8 inputs
//! against the reference implementation over nchw44 broadcasts, scalar /
//! 1C11 broadcasts and plain vector shapes.
TEST_F(FALLBACK, ELEMWISE_QUANTIZED_MODE_TERNARY) {
    using Mode = ElemwiseMultiType::Param::Mode;
    Checker<ElemwiseMultiType> checker(handle());
    // Shape cases; the trailing empty shape lets the checker deduce dst.
    const std::vector<TensorShapeArray> cases = {
            //! nchw44, broadcast on operands 0 and 2
            {{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}},
            {{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}},
            {{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}},
            {{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}},
            {{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}},
            //! nchw44, broadcast on operand 1
            {{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}},
            {{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}},
            {{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}},
            {{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}},
            {{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}},
            //! scalar / 1C11 broadcast and plain vectors
            {{3, 4, 5, 6}, {3, 4, 5, 6}, {1, 1, 1, 1}, {}},
            {{1, 4, 1, 1}, {3, 4, 5, 6}, {1, 4, 1, 1}, {}},
            {{3}, {3}, {3}, {}},
            {{9}, {9}, {9}, {}},
            {{17}, {17}, {17}, {}},
            {{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}, {}},
    };
    checker.set_param({Mode::QFUSE_MUL_ADD3});
    // qint8 x qint8 x qint8 -> qint8, full representable input range.
    UniformIntRNG rng{-127, 127};
    checker.set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &rng)
            .set_dtype(0, dtype::QuantizedS8(1.45f))
            .set_dtype(1, dtype::QuantizedS8(1.15f))
            .set_dtype(2, dtype::QuantizedS8(1.75f))
            .set_dtype(3, dtype::QuantizedS8(1.35f));
    for (const auto& shapes : cases) {
        checker.execs(shapes);
    }
}
TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_RECORD_FMA3_INT16x32x32x32) {
TaskRecordChecker<ElemwiseMultiType> checker{1};
checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册