提交 f19646b5 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(dnn/arm_common/elemwise): elemwise ternary support chw44

GitOrigin-RevId: ef19a636ba4e47712585b0d627ef5c2c7d19d3b3
上级 3d9d4b9b
......@@ -31,7 +31,10 @@ class ElemwiseImpl::AlgoPack {
AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec;
AlgoTernaryFma3VecVecScalar algo_ternaryfma3_vec_vecsca;
AlgoTernaryFma3Bcast101VecBcast101 algo_ternaryfma3_bcast101_vec_bcast101;
AlgoTernaryFma3Bcast101x4VecBcast101x4
algo_ternaryfma3_bcast101x4_vec_bcast101x4;
AlgoTernaryFma3VecBcast101Vec algo_ternaryfma3_vec_bcast101_vec;
AlgoTernaryFma3VecBcast101x4Vec algo_ternaryfma3_vec_bcast101x4_vec;
AlgoTernaryFma3VecScalarVec algo_ternaryfma3_vec_sca_vec;
AlgoTernaryFma3VecScalarScalar algo_ternaryfma3_vec_sca_sca;
......@@ -45,7 +48,9 @@ public:
all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_vecsca);
all_algos.emplace_back(&algo_ternaryfma3_bcast101_vec_bcast101);
all_algos.emplace_back(&algo_ternaryfma3_bcast101x4_vec_bcast101x4);
all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101x4_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_sca_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_sca_sca);
}
......@@ -112,12 +117,25 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
return kern_param;
}
if (is_vector(src1.layout) &&
is_broadcastedx_channel_like<4>(src0.layout, binfo) &&
src0.layout.eq_layout(src2.layout)) {
kern_param.broad_cast_type = BcastType::BCAST101x4_VEC_BCAST101x4;
return kern_param;
}
if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) &&
is_broadcasted_channel_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCAST101_VEC;
return kern_param;
}
if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) &&
is_broadcastedx_channel_like<4>(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCAST101x4_VEC;
return kern_param;
}
if (is_vector(src0.layout) && is_vector(src2.layout) &&
is_broadcasted_scalar(src1.layout)) {
kern_param.broad_cast_type = BcastType::VEC_SCALAR_VEC;
......
......@@ -41,7 +41,9 @@ private:
class AlgoTernaryFma3VecVecVec;
class AlgoTernaryFma3VecVecScalar;
class AlgoTernaryFma3Bcast101VecBcast101;
class AlgoTernaryFma3Bcast101x4VecBcast101x4;
class AlgoTernaryFma3VecBcast101Vec;
class AlgoTernaryFma3VecBcast101x4Vec;
class AlgoTernaryFma3VecScalarVec;
class AlgoTernaryFma3VecScalarScalar;
class AlgoPack;
......
......@@ -42,7 +42,9 @@ using namespace arm_common;
DECL_AVAILABLE(VecVecVec, BcastType::VEC_VEC_VEC);
DECL_AVAILABLE(VecVecScalar, BcastType::VEC_VEC_SCALAR);
DECL_AVAILABLE(Bcast101VecBcast101, BcastType::BCAST101_VEC_BCAST101);
DECL_AVAILABLE(Bcast101x4VecBcast101x4, BcastType::BCAST101x4_VEC_BCAST101x4);
DECL_AVAILABLE(VecBcast101Vec, BcastType::VEC_BCAST101_VEC);
DECL_AVAILABLE(VecBcast101x4Vec, BcastType::VEC_BCAST101x4_VEC);
DECL_AVAILABLE(VecScalarVec, BcastType::VEC_SCALAR_VEC);
DECL_AVAILABLE(VecScalarScalar, BcastType::VEC_SCALAR_SCALAR);
#undef DECL_CB
......@@ -158,6 +160,82 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec(
return;
}
void ElemwiseImpl::AlgoTernaryFma3Bcast101x4VecBcast101x4::exec(
const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];
BroadcastChannelInfo binfo;
is_broadcastedx_channel_like<4>(src0.layout, binfo);
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void(const _type*, const _type*, const _type*, \
_type*, DType, DType, DType, DType, size_t, \
size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<_type, _type>, \
BcastType::BCAST101x4_VEC_BCAST101x4>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \
static_cast<const _type*>(src1.raw_ptr), \
static_cast<const _type*>(src2.raw_ptr), \
static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, \
binfo.z)); \
} \
MIDOUT_END(); \
return
size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE("AlgoTernaryFma3Bcast101x4VecBcast101x4::exec"_hash);
#undef DISPATCH_TERNARY
return;
}
void ElemwiseImpl::AlgoTernaryFma3VecBcast101x4Vec::exec(
const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];
BroadcastChannelInfo binfo;
is_broadcastedx_channel_like<4>(src1.layout, binfo);
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void(const _type*, const _type*, const _type*, \
_type*, DType, DType, DType, DType, size_t, \
size_t, size_t, size_t)> \
run = OpCallerTernary<_op<_type, _type>, \
BcastType::VEC_BCAST101x4_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \
static_cast<const _type*>(src1.raw_ptr), \
static_cast<const _type*>(src2.raw_ptr), \
static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, \
binfo.z)); \
} \
MIDOUT_END(); \
return
size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE("AlgoTernaryFma3VecBcast101x4Vec::exec"_hash);
#undef DISPATCH_TERNARY
return;
}
void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec(
const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam;
......@@ -193,6 +271,7 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec(
return;
}
void ElemwiseImpl::AlgoTernaryFma3VecScalarVec::exec(
const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam;
......
......@@ -33,7 +33,9 @@ namespace arm_common {
DECL_CB(VecVecVec);
DECL_CB(VecVecScalar);
DECL_CB(Bcast101VecBcast101);
DECL_CB(Bcast101x4VecBcast101x4);
DECL_CB(VecBcast101Vec);
DECL_CB(VecBcast101x4Vec);
DECL_CB(VecScalarVec);
DECL_CB(VecScalarScalar);
#undef DECL_CB
......
......@@ -810,6 +810,65 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param,
DISPATCH()
#undef DISPATCH_SINGLE_MODE
}
}
//! VEC + BCAST101x4 + VEC
{
BroadcastChannelInfo binfo;
if (is_vector(src0.layout) &&
is_broadcastedx_channel_like<4>(src1.layout, binfo) &&
src0.layout.eq_shape(src2.layout)) {
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
case _mode: { \
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \
thin_function<void(const src_ctype*, const src_ctype*, \
const src_ctype*, dst_ctype*, DType, DType, DType, \
DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, \
VEC_BCAST101x4_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
}
size_t batch_size =
src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
DISPATCH()
#undef DISPATCH_SINGLE_MODE
}
//! BCAST101x + VEC +BCAST101x
if (is_vector(src1.layout) &&
is_broadcastedx_channel_like<4>(src0.layout, binfo) &&
src0.layout.eq_shape(src2.layout)) {
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
case _mode: { \
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \
thin_function<void(const src_ctype*, const src_ctype*, \
const src_ctype*, dst_ctype*, DType, DType, DType, \
DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, \
BCAST101x4_VEC_BCAST101x4>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
}
size_t batch_size =
src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
DISPATCH()
#undef DISPATCH_SINGLE_MODE
}
}
......
......@@ -105,7 +105,9 @@ enum BcastType {
VEC_VEC_VEC,
VEC_VEC_SCALAR,
BCAST101_VEC_BCAST101,
BCAST101x4_VEC_BCAST101x4,
VEC_BCAST101_VEC,
VEC_BCAST101x4_VEC,
VEC_SCALAR_VEC,
VEC_SCALAR_SCALAR,
UNKNOWN_BCAST_TYPE
......@@ -681,6 +683,54 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> {
}
};
//! src0: CHW44, src1: vector, src2: CHW44
template <typename Op>
struct OpCallerTernary<Op, BCAST101x4_VEC_BCAST101x4> {
static void run(const typename Op::src_ctype* src0,
const typename Op::src_ctype* src1,
const typename Op::src_ctype* src2,
typename Op::dst_ctype* dst, DType src0_dtype,
DType src1_dtype, DType src2_dtype, DType dst_dtype,
size_t batch, size_t nr_channel_blocks,
size_t channel_stride, size_t channel_block_dim) {
megdnn_assert(channel_block_dim == 4, "only imp for nchw44");
Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
ParamElemVisitorBcast101x4<typename Op::src_ctype> vis0;
ParamElemVisitor<typename Op::src_ctype> vis1;
ParamElemVisitorBcast101x4<typename Op::src_ctype> vis2;
for (size_t b = 0; b < batch; b++) {
auto src0_ptr = src0;
auto src2_ptr = src2;
for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
auto src0_block_ptr = src0_ptr + cb * channel_block_dim;
auto src2_block_ptr = src2_ptr + cb * channel_block_dim;
auto channel_block_vec0 = vis0(src0_block_ptr);
auto channel_block_vec2 = vis2(src2_block_ptr);
size_t img_index = 0;
auto src1_offset = Op::SIMD_WIDTH / channel_block_dim;
for (; img_index + 2 * src1_offset <= channel_stride;
img_index += 2 * src1_offset) {
op({{channel_block_vec0, channel_block_vec0}},
{{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}},
{{channel_block_vec2, channel_block_vec2}}, dst);
src1 += Op::SIMD_WIDTH * 2;
dst += Op::SIMD_WIDTH * 2;
}
// TODO:all elemwise_multi_type op imp one simd mode
for (; img_index < channel_stride; img_index++) {
for (size_t c_iter = 0; c_iter < channel_block_dim;
c_iter++) {
op(*(src0_block_ptr + c_iter), *src1,
*(src2_block_ptr + c_iter), dst);
src1++;
dst++;
}
}
}
}
}
};
//! src1: 1C11, src0 and src2 are contig
template <typename Op>
struct OpCallerTernary<Op, VEC_BCAST101_VEC> {
......@@ -725,6 +775,52 @@ struct OpCallerTernary<Op, VEC_BCAST101_VEC> {
}
};
//! src1: CHW44, src0 and src2 are contig
template <typename Op>
struct OpCallerTernary<Op, VEC_BCAST101x4_VEC> {
static void run(const typename Op::src_ctype* src0,
const typename Op::src_ctype* src1,
const typename Op::src_ctype* src2,
typename Op::dst_ctype* dst, DType src0_dtype,
DType src1_dtype, DType src2_dtype, DType dst_dtype,
size_t batch, size_t nr_channel_blocks,
size_t channel_stride, size_t channel_block_dim) {
megdnn_assert(channel_block_dim == 4, "only imp for nchw44");
Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
ParamElemVisitor<typename Op::src_ctype> vis0;
ParamElemVisitorBcast101x4<typename Op::src_ctype> vis1;
ParamElemVisitor<typename Op::src_ctype> vis2;
for (size_t b = 0; b < batch; b++) {
auto src1_ptr = src1;
for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
auto src1_block_ptr = src1_ptr + cb * channel_block_dim;
auto channel_block_vec = vis1(src1_block_ptr);
size_t img_index = 0;
auto offset = Op::SIMD_WIDTH / channel_block_dim;
for (; img_index + 2 * offset <= channel_stride;
img_index += 2 * offset) {
op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}},
{{channel_block_vec, channel_block_vec}},
{{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst);
src0 += Op::SIMD_WIDTH * 2;
src2 += Op::SIMD_WIDTH * 2;
dst += Op::SIMD_WIDTH * 2;
}
// TODO:all elemwise_multi_type op imp one simd mode
for (; img_index < channel_stride; img_index++) {
for (size_t c_iter = 0; c_iter < channel_block_dim;
c_iter++) {
op(*src0, *(src1_block_ptr + c_iter), *src2, dst);
src0++;
src2++;
dst++;
}
}
}
}
}
};
//! src1: scalar, src0 and src2 has the same shape
template <typename Op>
struct OpCallerTernary<Op, VEC_SCALAR_VEC> {
......
......@@ -26,50 +26,53 @@ TYPED_TEST(ARM_ELEMWISE, run) {
elemwise::run_test<TypeParam>(this->handle());
}
#define TERNARY_COMPLATE_TEST_CASE(_optr) \
printf("Check binary optr %s by all cases.\n", #_optr); \
checker.set_param(Mode::_optr) \
.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); \
checker.set_param(Mode::_optr) \
.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \
checker.set_param(Mode::_optr) \
.execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); \
checker.set_param(Mode::_optr) \
.execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \
checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {1, 7}, {}}); \
checker.set_param(Mode::_optr) \
.execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); \
checker.set_param(Mode::_optr) \
.execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); \
checker.set_param(Mode::_optr) \
.execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}}); \
checker.set_param(Mode::_optr).execs({{3, 4, 5}, {1}, {1}, {}}); \
checker.set_param(Mode::_optr).execs({{1}, {3, 4, 5}, {1}, {}});
#define BUILD_TERNARY_COMPLATE_TEST_CASE \
TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3)
TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle());
checker.set_param(Mode::FUSE_MUL_ADD3);
auto run = [&] {
//! nchw44
checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
//! nchw44
checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
checker.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}});
checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}});
checker.execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}});
checker.execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}});
checker.execs({{1, 7}, {1, 7}, {1, 7}, {}});
checker.execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}});
checker.execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}});
checker.execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});
checker.execs({{3, 4, 5}, {1}, {1}, {}});
checker.execs({{1}, {3, 4, 5}, {1}, {}});
};
// case int
checker.set_dtype(0, dtype::Int8());
checker.set_dtype(1, dtype::Int8());
checker.set_dtype(2, dtype::Int8());
// BUILD_TERNARY_TEST_CASE
BUILD_TERNARY_COMPLATE_TEST_CASE
run();
checker.set_dtype(0, dtype::Int16());
checker.set_dtype(1, dtype::Int16());
checker.set_dtype(2, dtype::Int16());
// BUILD_TERNARY_TEST_CASE
BUILD_TERNARY_COMPLATE_TEST_CASE
run();
checker.set_dtype(0, dtype::Int32());
checker.set_dtype(1, dtype::Int32());
checker.set_dtype(2, dtype::Int32());
// BUILD_TERNARY_TEST_CASE
BUILD_TERNARY_COMPLATE_TEST_CASE
run();
// case float
UniformFloatRNG rng(1e-5, 7e1);
......@@ -78,9 +81,7 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) {
checker.set_dtype(0, dtype::Float32());
checker.set_dtype(1, dtype::Float32());
checker.set_dtype(2, dtype::Float32());
// BUILD_TERNARY_TEST_CASE
BUILD_TERNARY_COMPLATE_TEST_CASE
run();
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// case half
......@@ -90,9 +91,7 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) {
checker.set_dtype(0, dtype::Float16());
checker.set_dtype(1, dtype::Float16());
checker.set_dtype(2, dtype::Float16());
// BUILD_TERNARY_TEST_CASE
BUILD_TERNARY_COMPLATE_TEST_CASE
run();
#endif
}
......
......@@ -214,6 +214,30 @@ TEST_F(ARM_COMMON, ELEMWISE_QUANTIZED_MODE_TERNARY) {
using Mode = ElemwiseMultiType::Param::Mode;
Checker<ElemwiseMultiType> checker(handle());
auto run = [&]() {
//! nchw44
checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
//! nchw44
checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
checker.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {1, 1, 1, 1}, {}});
checker.execs({{1, 4, 1, 1}, {3, 4, 5, 6}, {1, 4, 1, 1}, {}});
checker.execs({{3}, {3}, {3}, {}});
checker.execs({{9}, {9}, {9}, {}});
checker.execs({{17}, {17}, {17}, {}});
checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}, {}});
};
for (auto mode : {Mode::QFUSE_MUL_ADD3}) {
checker.set_param({mode});
......@@ -226,14 +250,7 @@ TEST_F(ARM_COMMON, ELEMWISE_QUANTIZED_MODE_TERNARY) {
.set_dtype(1, dtype::QuantizedS8(1.15f))
.set_dtype(2, dtype::QuantizedS8(1.75f))
.set_dtype(3, dtype::QuantizedS8(1.35f));
checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {1, 1, 1, 1}, {}});
checker.execs({{1, 4, 1, 1}, {3, 4, 5, 6}, {1, 4, 1, 1}, {}});
checker.execs({{3}, {3}, {3}, {}});
checker.execs({{9}, {9}, {9}, {}});
checker.execs({{17}, {17}, {17}, {}});
checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}, {}});
run();
// quint8 to quint8
UniformIntRNG rng_uint8{0, 225};
......@@ -248,14 +265,7 @@ TEST_F(ARM_COMMON, ELEMWISE_QUANTIZED_MODE_TERNARY) {
static_cast<uint8_t>(128)))
.set_dtype(3, dtype::Quantized8Asymm(
1.45f, static_cast<uint8_t>(128)));
checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {1, 1, 1, 1}, {}});
checker.execs({{1, 4, 1, 1}, {3, 4, 5, 6}, {1, 4, 1, 1}, {}});
checker.execs({{3}, {3}, {3}, {}});
checker.execs({{9}, {9}, {9}, {}});
checker.execs({{17}, {17}, {17}, {}});
checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}, {}});
run();
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册