From f19646b51fcba82eb9f2230d8fcd7e5006097275 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 8 May 2020 15:48:48 +0800 Subject: [PATCH] feat(dnn/arm_common/elemwise): elemwise ternary support chw44 GitOrigin-RevId: ef19a636ba4e47712585b0d627ef5c2c7d19d3b3 --- dnn/src/arm_common/elemwise/opr_impl.cpp | 18 ++++ dnn/src/arm_common/elemwise/opr_impl.h | 2 + dnn/src/arm_common/elemwise/ternary/algo.cpp | 79 +++++++++++++++ dnn/src/arm_common/elemwise/ternary/algo.h | 2 + .../elemwise_multi_type/opr_impl.cpp | 59 ++++++++++++ dnn/src/arm_common/elemwise_op.h | 96 +++++++++++++++++++ dnn/test/arm_common/elemwise.cpp | 69 +++++++------ dnn/test/arm_common/elemwise_multi_type.cpp | 42 ++++---- 8 files changed, 316 insertions(+), 51 deletions(-) diff --git a/dnn/src/arm_common/elemwise/opr_impl.cpp b/dnn/src/arm_common/elemwise/opr_impl.cpp index f3b5be56..c049cf5b 100644 --- a/dnn/src/arm_common/elemwise/opr_impl.cpp +++ b/dnn/src/arm_common/elemwise/opr_impl.cpp @@ -31,7 +31,10 @@ class ElemwiseImpl::AlgoPack { AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec; AlgoTernaryFma3VecVecScalar algo_ternaryfma3_vec_vecsca; AlgoTernaryFma3Bcast101VecBcast101 algo_ternaryfma3_bcast101_vec_bcast101; + AlgoTernaryFma3Bcast101x4VecBcast101x4 + algo_ternaryfma3_bcast101x4_vec_bcast101x4; AlgoTernaryFma3VecBcast101Vec algo_ternaryfma3_vec_bcast101_vec; + AlgoTernaryFma3VecBcast101x4Vec algo_ternaryfma3_vec_bcast101x4_vec; AlgoTernaryFma3VecScalarVec algo_ternaryfma3_vec_sca_vec; AlgoTernaryFma3VecScalarScalar algo_ternaryfma3_vec_sca_sca; @@ -45,7 +48,9 @@ public: all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec); all_algos.emplace_back(&algo_ternaryfma3_vec_vecsca); all_algos.emplace_back(&algo_ternaryfma3_bcast101_vec_bcast101); + all_algos.emplace_back(&algo_ternaryfma3_bcast101x4_vec_bcast101x4); all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101_vec); + all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101x4_vec); 
all_algos.emplace_back(&algo_ternaryfma3_vec_sca_vec); all_algos.emplace_back(&algo_ternaryfma3_vec_sca_sca); } @@ -112,12 +117,25 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { return kern_param; } + if (is_vector(src1.layout) && + is_broadcastedx_channel_like<4>(src0.layout, binfo) && + src0.layout.eq_layout(src2.layout)) { + kern_param.broad_cast_type = BcastType::BCAST101x4_VEC_BCAST101x4; + return kern_param; + } + if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) && is_broadcasted_channel_like(src1.layout, binfo)) { kern_param.broad_cast_type = BcastType::VEC_BCAST101_VEC; return kern_param; } + if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) && + is_broadcastedx_channel_like<4>(src1.layout, binfo)) { + kern_param.broad_cast_type = BcastType::VEC_BCAST101x4_VEC; + return kern_param; + } + if (is_vector(src0.layout) && is_vector(src2.layout) && is_broadcasted_scalar(src1.layout)) { kern_param.broad_cast_type = BcastType::VEC_SCALAR_VEC; diff --git a/dnn/src/arm_common/elemwise/opr_impl.h b/dnn/src/arm_common/elemwise/opr_impl.h index a0d7743f..eab11604 100644 --- a/dnn/src/arm_common/elemwise/opr_impl.h +++ b/dnn/src/arm_common/elemwise/opr_impl.h @@ -41,7 +41,9 @@ private: class AlgoTernaryFma3VecVecVec; class AlgoTernaryFma3VecVecScalar; class AlgoTernaryFma3Bcast101VecBcast101; + class AlgoTernaryFma3Bcast101x4VecBcast101x4; class AlgoTernaryFma3VecBcast101Vec; + class AlgoTernaryFma3VecBcast101x4Vec; class AlgoTernaryFma3VecScalarVec; class AlgoTernaryFma3VecScalarScalar; class AlgoPack; diff --git a/dnn/src/arm_common/elemwise/ternary/algo.cpp b/dnn/src/arm_common/elemwise/ternary/algo.cpp index 624099a1..ffae4970 100644 --- a/dnn/src/arm_common/elemwise/ternary/algo.cpp +++ b/dnn/src/arm_common/elemwise/ternary/algo.cpp @@ -42,7 +42,9 @@ using namespace arm_common; DECL_AVAILABLE(VecVecVec, BcastType::VEC_VEC_VEC); DECL_AVAILABLE(VecVecScalar, BcastType::VEC_VEC_SCALAR); 
DECL_AVAILABLE(Bcast101VecBcast101, BcastType::BCAST101_VEC_BCAST101); +DECL_AVAILABLE(Bcast101x4VecBcast101x4, BcastType::BCAST101x4_VEC_BCAST101x4); DECL_AVAILABLE(VecBcast101Vec, BcastType::VEC_BCAST101_VEC); +DECL_AVAILABLE(VecBcast101x4Vec, BcastType::VEC_BCAST101x4_VEC); DECL_AVAILABLE(VecScalarVec, BcastType::VEC_SCALAR_VEC); DECL_AVAILABLE(VecScalarScalar, BcastType::VEC_SCALAR_SCALAR); #undef DECL_CB @@ -158,6 +160,82 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec( return; } + +void ElemwiseImpl::AlgoTernaryFma3Bcast101x4VecBcast101x4::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.ternary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; + + BroadcastChannelInfo binfo; + is_broadcastedx_channel_like<4>(src0.layout, binfo); +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, \ + BcastType::BCAST101x4_VEC_BCAST101x4>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast(src2.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, \ + dst.layout.dtype, batch_size, binfo.x, binfo.y, \ + binfo.z)); \ + } \ + MIDOUT_END(); \ + return + + size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE("AlgoTernaryFma3Bcast101x4VecBcast101x4::exec"_hash); +#undef DISPATCH_TERNARY + + return; +} + +void ElemwiseImpl::AlgoTernaryFma3VecBcast101x4Vec::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.ternary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; + + BroadcastChannelInfo binfo; + 
is_broadcastedx_channel_like<4>(src1.layout, binfo); +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary<_op<_type, _type>, \ + BcastType::VEC_BCAST101x4_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast(src2.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, \ + dst.layout.dtype, batch_size, binfo.x, binfo.y, \ + binfo.z)); \ + } \ + MIDOUT_END(); \ + return + + size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE("AlgoTernaryFma3VecBcast101x4Vec::exec"_hash); +#undef DISPATCH_TERNARY + + return; +} + void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec( const KernParam& kern_param) const { auto& elparam = kern_param.ternary_elparam; @@ -193,6 +271,7 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec( return; } + void ElemwiseImpl::AlgoTernaryFma3VecScalarVec::exec( const KernParam& kern_param) const { auto& elparam = kern_param.ternary_elparam; diff --git a/dnn/src/arm_common/elemwise/ternary/algo.h b/dnn/src/arm_common/elemwise/ternary/algo.h index 975af62a..62a02d7e 100644 --- a/dnn/src/arm_common/elemwise/ternary/algo.h +++ b/dnn/src/arm_common/elemwise/ternary/algo.h @@ -33,7 +33,9 @@ namespace arm_common { DECL_CB(VecVecVec); DECL_CB(VecVecScalar); DECL_CB(Bcast101VecBcast101); +DECL_CB(Bcast101x4VecBcast101x4); DECL_CB(VecBcast101Vec); +DECL_CB(VecBcast101x4Vec); DECL_CB(VecScalarVec); DECL_CB(VecScalarScalar); #undef DECL_CB diff --git a/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp b/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp index bcebe699..855984a1 100644 --- 
a/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp +++ b/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp @@ -810,6 +810,65 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param, DISPATCH() +#undef DISPATCH_SINGLE_MODE + } + } + + //! VEC + BCAST101x4 + VEC + { + BroadcastChannelInfo binfo; + if (is_vector(src0.layout) && + is_broadcastedx_channel_like<4>(src1.layout, binfo) && + src0.layout.eq_shape(src2.layout)) { +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerTernary<_op, \ + VEC_BCAST101x4_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + run(src0.ptr(), src1.ptr(), \ + src2.ptr(), dst.ptr(), \ + src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ + dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ + return; \ + } + + size_t batch_size = + src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); + DISPATCH() + +#undef DISPATCH_SINGLE_MODE + } + + //! 
BCAST101x4 + VEC + BCAST101x4 + if (is_vector(src1.layout) && + is_broadcastedx_channel_like<4>(src0.layout, binfo) && + src0.layout.eq_shape(src2.layout)) { +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerTernary<_op, \ + BCAST101x4_VEC_BCAST101x4>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + run(src0.ptr(), src1.ptr(), \ + src2.ptr(), dst.ptr(), \ + src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ + dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ + return; \ + } + + size_t batch_size = + src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); + DISPATCH() + #undef DISPATCH_SINGLE_MODE } } diff --git a/dnn/src/arm_common/elemwise_op.h b/dnn/src/arm_common/elemwise_op.h index 5cb7312e..54795ebc 100644 --- a/dnn/src/arm_common/elemwise_op.h +++ b/dnn/src/arm_common/elemwise_op.h @@ -105,7 +105,9 @@ enum BcastType { VEC_VEC_VEC, VEC_VEC_SCALAR, BCAST101_VEC_BCAST101, + BCAST101x4_VEC_BCAST101x4, VEC_BCAST101_VEC, + VEC_BCAST101x4_VEC, VEC_SCALAR_VEC, VEC_SCALAR_SCALAR, UNKNOWN_BCAST_TYPE @@ -681,6 +683,54 @@ struct OpCallerTernary { } }; +//! 
src0: CHW44, src1: vector, src2: CHW44 +template +struct OpCallerTernary { + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t batch, size_t nr_channel_blocks, + size_t channel_stride, size_t channel_block_dim) { + megdnn_assert(channel_block_dim == 4, "only imp for nchw44"); + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitorBcast101x4 vis0; + ParamElemVisitor vis1; + ParamElemVisitorBcast101x4 vis2; + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + auto src2_ptr = src2; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src0_block_ptr = src0_ptr + cb * channel_block_dim; + auto src2_block_ptr = src2_ptr + cb * channel_block_dim; + auto channel_block_vec0 = vis0(src0_block_ptr); + auto channel_block_vec2 = vis2(src2_block_ptr); + size_t img_index = 0; + auto src1_offset = Op::SIMD_WIDTH / channel_block_dim; + for (; img_index + 2 * src1_offset <= channel_stride; + img_index += 2 * src1_offset) { + op({{channel_block_vec0, channel_block_vec0}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, + {{channel_block_vec2, channel_block_vec2}}, dst); + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } + // TODO:all elemwise_multi_type op imp one simd mode + for (; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; + c_iter++) { + op(*(src0_block_ptr + c_iter), *src1, + *(src2_block_ptr + c_iter), dst); + src1++; + dst++; + } + } + } + } + } +}; + //! src1: 1C11, src0 and src2 are contig template struct OpCallerTernary { @@ -725,6 +775,52 @@ struct OpCallerTernary { } }; +//! 
src1: CHW44, src0 and src2 are contig +template +struct OpCallerTernary { + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t batch, size_t nr_channel_blocks, + size_t channel_stride, size_t channel_block_dim) { + megdnn_assert(channel_block_dim == 4, "only imp for nchw44"); + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitorBcast101x4 vis1; + ParamElemVisitor vis2; + for (size_t b = 0; b < batch; b++) { + auto src1_ptr = src1; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src1_block_ptr = src1_ptr + cb * channel_block_dim; + auto channel_block_vec = vis1(src1_block_ptr); + size_t img_index = 0; + auto offset = Op::SIMD_WIDTH / channel_block_dim; + for (; img_index + 2 * offset <= channel_stride; + img_index += 2 * offset) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{channel_block_vec, channel_block_vec}}, + {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src2 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } + // TODO:all elemwise_multi_type op imp one simd mode + for (; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; + c_iter++) { + op(*src0, *(src1_block_ptr + c_iter), *src2, dst); + src0++; + src2++; + dst++; + } + } + } + } + } +}; + //! 
src1: scalar, src0 and src2 has the same shape template struct OpCallerTernary { diff --git a/dnn/test/arm_common/elemwise.cpp b/dnn/test/arm_common/elemwise.cpp index e46a3efc..653d3fe6 100644 --- a/dnn/test/arm_common/elemwise.cpp +++ b/dnn/test/arm_common/elemwise.cpp @@ -26,50 +26,53 @@ TYPED_TEST(ARM_ELEMWISE, run) { elemwise::run_test(this->handle()); } -#define TERNARY_COMPLATE_TEST_CASE(_optr) \ - printf("Check binary optr %s by all cases.\n", #_optr); \ - checker.set_param(Mode::_optr) \ - .execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); \ - checker.set_param(Mode::_optr) \ - .execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \ - checker.set_param(Mode::_optr) \ - .execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); \ - checker.set_param(Mode::_optr) \ - .execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \ - checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {1, 7}, {}}); \ - checker.set_param(Mode::_optr) \ - .execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); \ - checker.set_param(Mode::_optr) \ - .execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); \ - checker.set_param(Mode::_optr) \ - .execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}}); \ - checker.set_param(Mode::_optr).execs({{3, 4, 5}, {1}, {1}, {}}); \ - checker.set_param(Mode::_optr).execs({{1}, {3, 4, 5}, {1}, {}}); - -#define BUILD_TERNARY_COMPLATE_TEST_CASE \ - TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3) - TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) { using Mode = ElemwiseForward::Param::Mode; Checker checker(handle()); + checker.set_param(Mode::FUSE_MUL_ADD3); + + auto run = [&] { + //! nchw44 + checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); + checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); + + //! 
nchw44 + checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); + checker.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); + checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); + checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); + + checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); + checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); + checker.execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); + checker.execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); + checker.execs({{1, 7}, {1, 7}, {1, 7}, {}}); + checker.execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); + checker.execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); + checker.execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}}); + checker.execs({{3, 4, 5}, {1}, {1}, {}}); + checker.execs({{1}, {3, 4, 5}, {1}, {}}); + }; + // case int checker.set_dtype(0, dtype::Int8()); checker.set_dtype(1, dtype::Int8()); checker.set_dtype(2, dtype::Int8()); - // BUILD_TERNARY_TEST_CASE - BUILD_TERNARY_COMPLATE_TEST_CASE + run(); checker.set_dtype(0, dtype::Int16()); checker.set_dtype(1, dtype::Int16()); checker.set_dtype(2, dtype::Int16()); - // BUILD_TERNARY_TEST_CASE - BUILD_TERNARY_COMPLATE_TEST_CASE + run(); checker.set_dtype(0, dtype::Int32()); checker.set_dtype(1, dtype::Int32()); checker.set_dtype(2, dtype::Int32()); - // BUILD_TERNARY_TEST_CASE - BUILD_TERNARY_COMPLATE_TEST_CASE + run(); // case float UniformFloatRNG rng(1e-5, 7e1); @@ -78,9 +81,7 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) { checker.set_dtype(0, dtype::Float32()); checker.set_dtype(1, dtype::Float32()); checker.set_dtype(2, dtype::Float32()); - - // BUILD_TERNARY_TEST_CASE - BUILD_TERNARY_COMPLATE_TEST_CASE + run(); #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC // case half @@ -90,9 +91,7 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) { checker.set_dtype(0, dtype::Float16()); checker.set_dtype(1, 
dtype::Float16()); checker.set_dtype(2, dtype::Float16()); - - // BUILD_TERNARY_TEST_CASE - BUILD_TERNARY_COMPLATE_TEST_CASE + run(); #endif } diff --git a/dnn/test/arm_common/elemwise_multi_type.cpp b/dnn/test/arm_common/elemwise_multi_type.cpp index 0a14b2f3..0725b0c8 100644 --- a/dnn/test/arm_common/elemwise_multi_type.cpp +++ b/dnn/test/arm_common/elemwise_multi_type.cpp @@ -214,6 +214,30 @@ TEST_F(ARM_COMMON, ELEMWISE_QUANTIZED_MODE_TERNARY) { using Mode = ElemwiseMultiType::Param::Mode; Checker checker(handle()); + auto run = [&]() { + //! nchw44 + checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); + checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); + + //! nchw44 + checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); + checker.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); + checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); + checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); + + checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {1, 1, 1, 1}, {}}); + checker.execs({{1, 4, 1, 1}, {3, 4, 5, 6}, {1, 4, 1, 1}, {}}); + + checker.execs({{3}, {3}, {3}, {}}); + checker.execs({{9}, {9}, {9}, {}}); + checker.execs({{17}, {17}, {17}, {}}); + checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}, {}}); + }; + for (auto mode : {Mode::QFUSE_MUL_ADD3}) { checker.set_param({mode}); @@ -226,14 +250,7 @@ TEST_F(ARM_COMMON, ELEMWISE_QUANTIZED_MODE_TERNARY) { .set_dtype(1, dtype::QuantizedS8(1.15f)) .set_dtype(2, dtype::QuantizedS8(1.75f)) .set_dtype(3, dtype::QuantizedS8(1.35f)); - - checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {1, 1, 1, 1}, {}}); - checker.execs({{1, 4, 1, 
1}, {3, 4, 5, 6}, {1, 4, 1, 1}, {}}); - - checker.execs({{3}, {3}, {3}, {}}); - checker.execs({{9}, {9}, {9}, {}}); - checker.execs({{17}, {17}, {17}, {}}); - checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}, {}}); + run(); // quint8 to quint8 UniformIntRNG rng_uint8{0, 225}; @@ -248,14 +265,7 @@ TEST_F(ARM_COMMON, ELEMWISE_QUANTIZED_MODE_TERNARY) { static_cast(128))) .set_dtype(3, dtype::Quantized8Asymm( 1.45f, static_cast(128))); - - checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {1, 1, 1, 1}, {}}); - checker.execs({{1, 4, 1, 1}, {3, 4, 5, 6}, {1, 4, 1, 1}, {}}); - - checker.execs({{3}, {3}, {3}, {}}); - checker.execs({{9}, {9}, {9}, {}}); - checker.execs({{17}, {17}, {17}, {}}); - checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}, {}}); + run(); } } -- GitLab