From 5885b137fa8ed76b6a14d72767ae8b8efee3a472 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 8 Oct 2021 10:47:57 +0800 Subject: [PATCH] feat(dnn/arm): support layout like NHWC channel like broadcast on arm GitOrigin-RevId: fb4300004c4e1920d3cd1be40ca33bde822e4c72 --- dnn/src/arm_common/elemwise/binary/algo.cpp | 81 +++++++ dnn/src/arm_common/elemwise/binary/algo.h | 1 + dnn/src/arm_common/elemwise/opr_impl.cpp | 40 ++++ dnn/src/arm_common/elemwise/opr_impl.h | 3 + dnn/src/arm_common/elemwise/ternary/algo.cpp | 80 +++++++ dnn/src/arm_common/elemwise/ternary/algo.h | 2 + dnn/src/arm_common/elemwise_op.h | 229 +++++++++++++++++++ dnn/src/arm_common/quantized_converter.h | 14 ++ dnn/src/arm_common/type_cvt/opr_impl.cpp | 107 ++++++++- dnn/src/common/elemwise/opr_impl_helper.cpp | 13 ++ dnn/src/common/elemwise/opr_impl_helper.h | 10 + dnn/src/fallback/type_cvt/opr_impl.cpp | 6 +- dnn/src/naive/type_cvt/opr_impl.cpp | 4 +- dnn/test/arm_common/elemwise.cpp | 91 ++++++++ dnn/test/arm_common/type_cvt.cpp | 20 ++ dnn/test/common/checker.cpp | 3 +- dnn/test/common/rng.cpp | 3 + dnn/test/cuda/type_cvt.cpp | 5 + dnn/test/x86/type_cvt.cpp | 14 ++ src/opr/impl/basic_arith.cpp | 4 +- 20 files changed, 723 insertions(+), 7 deletions(-) diff --git a/dnn/src/arm_common/elemwise/binary/algo.cpp b/dnn/src/arm_common/elemwise/binary/algo.cpp index 5cad7e433..cf7e086b4 100644 --- a/dnn/src/arm_common/elemwise/binary/algo.cpp +++ b/dnn/src/arm_common/elemwise/binary/algo.cpp @@ -104,6 +104,21 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available( return false; } +bool ElemwiseImpl::AlgoBinaryVecBcast111C::is_available( + const KernParam& kern_param) const { + if (!is_available_common(kern_param.mode) || + ((BcastType::VEC_BCAST111C != kern_param.broad_cast_type) && + (BcastType::BCAST111C_VEC != kern_param.broad_cast_type))) + return false; + + auto& elparam = kern_param.binary_elparam; + auto& src0 = elparam[0]; + + DISPATCH_TYPE("AlgoBinaryVecBcast111C::is_available"_hash); + + return false; +} + bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available( const KernParam& kern_param) const { if (!is_available_common(kern_param.mode) || @@ -333,6 +348,72 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec(const KernParam& kern_param) cons return; } +void ElemwiseImpl::AlgoBinaryVecBcast111C::exec(const KernParam& kern_param) const { + auto& elparam = kern_param.binary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1]; + auto&& dst = *(kern_param.m_dst); + BroadcastChannelInfo binfo; + + // Case extra: BcastType::VEC + BCAST_111C + if (BcastType::VEC_BCAST111C == kern_param.broad_cast_type && + is_NHWC_broadcasted_channel_like(src1.layout, binfo)) { +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_arm_common_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary< \ + _op<_type, _type>, BcastType::VEC_BCAST111C>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \ + binfo.z)); \ + } \ + MIDOUT_END(); \ + return + + DISPATCH_TYPE("AlgoBinaryVecBcast111C::exec_vec_b"_hash); + +#undef DISPATCH_BINARY + } + + // BCAST_111C + BcastType::VEC + if (BcastType::BCAST111C_VEC == kern_param.broad_cast_type && + is_NHWC_broadcasted_channel_like(src0.layout, binfo)) { +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_arm_common_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary< \ + _op<_type, _type>, BcastType::BCAST111C_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \ + binfo.z)); \ + } \ + MIDOUT_END(); \ + return + + DISPATCH_TYPE("AlgoBinaryVecBcast111C::exec_b_vec"_hash); + +#undef DISPATCH_BINARY + } + return; +} + void ElemwiseImpl::AlgoBinaryVecBcast101xX::exec(const KernParam& kern_param) const { auto& elparam = kern_param.binary_elparam; auto &src0 = elparam[0], &src1 = elparam[1]; diff --git a/dnn/src/arm_common/elemwise/binary/algo.h b/dnn/src/arm_common/elemwise/binary/algo.h index 05ac6937d..f44621833 100644 --- a/dnn/src/arm_common/elemwise/binary/algo.h +++ b/dnn/src/arm_common/elemwise/binary/algo.h @@ -33,6 +33,7 @@ namespace arm_common { DECL_CB(VecVec); DECL_CB(VecScalar); DECL_CB(VecBcast101); +DECL_CB(VecBcast111C); DECL_CB(VecBcast101xX); #undef DECL_CB } // namespace arm_common diff --git a/dnn/src/arm_common/elemwise/opr_impl.cpp b/dnn/src/arm_common/elemwise/opr_impl.cpp index 5e8c2f3eb..75c94358b 100644 --- a/dnn/src/arm_common/elemwise/opr_impl.cpp +++ b/dnn/src/arm_common/elemwise/opr_impl.cpp @@ -27,12 +27,15 @@ class ElemwiseImpl::AlgoPack { AlgoBinaryVecVec algo_binary_vec_vec; AlgoBinaryVecScalar algo_binary_vec_sca; AlgoBinaryVecBcast101 algo_binary_vec_bcast101; + AlgoBinaryVecBcast111C algo_binary_vec_bcast110; AlgoBinaryVecBcast101xX algo_binary_VEC_BCAST101xX; AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec; AlgoTernaryFma3VecVecScalar algo_ternaryfma3_vec_vecsca; AlgoTernaryFma3Bcast101VecBcast101 algo_ternaryfma3_bcast101_vec_bcast101; + AlgoTernaryFma3Bcast111CVecBcast111C algo_ternaryfma3_bcast110_vec_bcast110; AlgoTernaryFma3Bcast101xXVecBcast101xX algo_ternaryfma3_bcast101xX_vec_bcast101xX; AlgoTernaryFma3VecBcast101Vec algo_ternaryfma3_vec_bcast101_vec; + AlgoTernaryFma3VecBcast111CVec algo_ternaryfma3_vec_bcast110_vec; AlgoTernaryFma3VecBcast101xXVec algo_ternaryfma3_vec_bcast101xX_vec; AlgoTernaryFma3VecScalarVec algo_ternaryfma3_vec_sca_vec; AlgoTernaryFma3VecScalarScalar algo_ternaryfma3_vec_sca_sca; @@ -43,12 +46,15 @@ public: all_algos.emplace_back(&algo_binary_vec_vec); all_algos.emplace_back(&algo_binary_vec_sca); all_algos.emplace_back(&algo_binary_vec_bcast101); + all_algos.emplace_back(&algo_binary_vec_bcast110); all_algos.emplace_back(&algo_binary_VEC_BCAST101xX); all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec); all_algos.emplace_back(&algo_ternaryfma3_vec_vecsca); all_algos.emplace_back(&algo_ternaryfma3_bcast101_vec_bcast101); + all_algos.emplace_back(&algo_ternaryfma3_bcast110_vec_bcast110); all_algos.emplace_back(&algo_ternaryfma3_bcast101xX_vec_bcast101xX); all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101_vec); + all_algos.emplace_back(&algo_ternaryfma3_vec_bcast110_vec); all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101xX_vec); all_algos.emplace_back(&algo_ternaryfma3_vec_sca_vec); all_algos.emplace_back(&algo_ternaryfma3_vec_sca_sca); @@ -87,6 +93,14 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { kern_param.mode = opr->param().mode; kern_param.handle = opr->handle(); + auto is_legal_layout_for_nhwc = [](const TensorLayout& l) { + if (is_vector(l)) + return true; + if (l.ndim == 2 && l.stride[1] == 1) + return true; + return false; + }; + if ((opr->m_src->size() == 3) && (opr->param().mode == Mode::FUSE_MUL_ADD3)) { kern_param.ternary_elparam = opr->make_elemwise_op_param<3>(); bool c_is_scalar; @@ -127,6 +141,20 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { return kern_param; } + if (is_legal_layout_for_nhwc(src1.layout) && + is_NHWC_broadcasted_channel_like(src0.layout, binfo) && + src0.layout.eq_layout(src2.layout)) { + kern_param.broad_cast_type = BcastType::BCAST111C_VEC_BCAST111C; + return kern_param; + } + + if (is_legal_layout_for_nhwc(src0.layout) && + src2.layout.eq_layout(src0.layout) && + is_NHWC_broadcasted_channel_like(src1.layout, binfo)) { + kern_param.broad_cast_type = BcastType::VEC_BCAST111C_VEC; + return kern_param; + } + if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) && (is_broadcastedx_channel_like<4>(src1.layout, binfo) || is_broadcastedx_channel_like<8>(src1.layout, binfo))) { @@ -174,6 +202,18 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { return kern_param; } + if (is_legal_layout_for_nhwc(src1.layout) && + is_NHWC_broadcasted_channel_like(src0.layout, binfo)) { + kern_param.broad_cast_type = BcastType::BCAST111C_VEC; + return kern_param; + } + + if (is_legal_layout_for_nhwc(src0.layout) && + is_NHWC_broadcasted_channel_like(src1.layout, binfo)) { + kern_param.broad_cast_type = BcastType::VEC_BCAST111C; + return kern_param; + } + if (is_vector(src0.layout) && (is_broadcastedx_channel_like<4>(src1.layout, binfo) || is_broadcastedx_channel_like<8>(src1.layout, binfo))) { diff --git a/dnn/src/arm_common/elemwise/opr_impl.h b/dnn/src/arm_common/elemwise/opr_impl.h index 769907137..33fa1c55d 100644 --- a/dnn/src/arm_common/elemwise/opr_impl.h +++ b/dnn/src/arm_common/elemwise/opr_impl.h @@ -38,12 +38,15 @@ private: class AlgoBinaryVecVec; class AlgoBinaryVecScalar; class AlgoBinaryVecBcast101; + class AlgoBinaryVecBcast111C; class AlgoBinaryVecBcast101xX; class AlgoTernaryFma3VecVecVec; class AlgoTernaryFma3VecVecScalar; class AlgoTernaryFma3Bcast101VecBcast101; + class AlgoTernaryFma3Bcast111CVecBcast111C; class AlgoTernaryFma3Bcast101xXVecBcast101xX; class AlgoTernaryFma3VecBcast101Vec; + class AlgoTernaryFma3VecBcast111CVec; class AlgoTernaryFma3VecBcast101xXVec; class AlgoTernaryFma3VecScalarVec; class AlgoTernaryFma3VecScalarScalar; diff --git a/dnn/src/arm_common/elemwise/ternary/algo.cpp b/dnn/src/arm_common/elemwise/ternary/algo.cpp index 0372114de..d5bb93595 100644 --- a/dnn/src/arm_common/elemwise/ternary/algo.cpp +++ b/dnn/src/arm_common/elemwise/ternary/algo.cpp @@ -42,8 +42,10 @@ using namespace arm_common; DECL_AVAILABLE(VecVecVec, BcastType::VEC_VEC_VEC); DECL_AVAILABLE(VecVecScalar, BcastType::VEC_VEC_SCALAR); DECL_AVAILABLE(Bcast101VecBcast101, BcastType::BCAST101_VEC_BCAST101); +DECL_AVAILABLE(Bcast111CVecBcast111C, BcastType::BCAST111C_VEC_BCAST111C); DECL_AVAILABLE(Bcast101xXVecBcast101xX, BcastType::BCAST101xX_VEC_BCAST101xX); DECL_AVAILABLE(VecBcast101Vec, BcastType::VEC_BCAST101_VEC); +DECL_AVAILABLE(VecBcast111CVec, BcastType::VEC_BCAST111C_VEC); DECL_AVAILABLE(VecBcast101xXVec, BcastType::VEC_BCAST101xX_VEC); DECL_AVAILABLE(VecScalarVec, BcastType::VEC_SCALAR_VEC); DECL_AVAILABLE(VecScalarScalar, BcastType::VEC_SCALAR_SCALAR); @@ -164,6 +166,45 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec( return; } +void ElemwiseImpl::AlgoTernaryFma3Bcast111CVecBcast111C::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.ternary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; + + // Case 3: shape of src0 and src2 is {1, 1, 1, C} + BroadcastChannelInfo binfo; + is_NHWC_broadcasted_channel_like(src0.layout, binfo); +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, \ + BcastType::BCAST111C_VEC_BCAST111C>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z, \ + static_cast(src2.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + binfo.x, binfo.y, binfo.z)); \ + } \ + MIDOUT_END(); \ + return + + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE("AlgoTernaryFma3Bcast111CVecBcast111C::exec"_hash); +#undef DISPATCH_TERNARY + + return; +} + void ElemwiseImpl::AlgoTernaryFma3Bcast101xXVecBcast101xX::exec( const KernParam& kern_param) const { auto& elparam = kern_param.ternary_elparam; @@ -282,6 +323,45 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec( return; } +void ElemwiseImpl::AlgoTernaryFma3VecBcast111CVec::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.ternary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; + + // Case 4: shape of src1 is {1, 1, 1, C}, and src0 and src2 are contig + BroadcastChannelInfo binfo; + is_NHWC_broadcasted_channel_like(src1.layout, binfo); +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, BcastType::VEC_BCAST111C_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + is_vector(src0.layout) ? 0 : src0.layout.stride[0] - binfo.z, \ + static_cast(src1.raw_ptr), \ + static_cast(src2.raw_ptr), \ + is_vector(src2.layout) ? 0 : src2.layout.stride[0] - binfo.z, \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + binfo.x, binfo.y, binfo.z)); \ + } \ + MIDOUT_END(); \ + return + + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE("AlgoTernaryFma3VecBcast111CVec::exec"_hash); +#undef DISPATCH_TERNARY + + return; +} + void ElemwiseImpl::AlgoTernaryFma3VecScalarVec::exec( const KernParam& kern_param) const { auto& elparam = kern_param.ternary_elparam; diff --git a/dnn/src/arm_common/elemwise/ternary/algo.h b/dnn/src/arm_common/elemwise/ternary/algo.h index 211ba451f..c587688d3 100644 --- a/dnn/src/arm_common/elemwise/ternary/algo.h +++ b/dnn/src/arm_common/elemwise/ternary/algo.h @@ -33,8 +33,10 @@ namespace arm_common { DECL_CB(VecVecVec); DECL_CB(VecVecScalar); DECL_CB(Bcast101VecBcast101); +DECL_CB(Bcast111CVecBcast111C); DECL_CB(Bcast101xXVecBcast101xX); DECL_CB(VecBcast101Vec); +DECL_CB(VecBcast111CVec); DECL_CB(VecBcast101xXVec); DECL_CB(VecScalarVec); DECL_CB(VecScalarScalar); diff --git a/dnn/src/arm_common/elemwise_op.h b/dnn/src/arm_common/elemwise_op.h index a69ad7a6f..db18f422f 100644 --- a/dnn/src/arm_common/elemwise_op.h +++ b/dnn/src/arm_common/elemwise_op.h @@ -107,16 +107,20 @@ enum BcastType { VEC, VEC_VEC, VEC_BCAST101, + VEC_BCAST111C, VEC_BCAST101xX, VEC_SCALAR, SCALAR_VEC, BCAST101_VEC, + BCAST111C_VEC, BCAST101xX_VEC, VEC_VEC_VEC, VEC_VEC_SCALAR, BCAST101_VEC_BCAST101, + BCAST111C_VEC_BCAST111C, BCAST101xX_VEC_BCAST101xX, VEC_BCAST101_VEC, + VEC_BCAST111C_VEC, VEC_BCAST101xX_VEC, VEC_SCALAR_VEC, VEC_SCALAR_SCALAR, @@ -226,6 +230,60 @@ struct OpCallerBinary, VEC_BCAST101> { } }; +template +struct OpCallerBinary, VEC_BCAST111C> { + using Op = PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + for (size_t c = 0; c < channel; c++) { + size_t i = 0; + const typename Op::src_ctype* src1_ptr = src1; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0, *src1_ptr, dst); + src0++; + src1_ptr++; + dst++; + } + } + } + } +}; + +template +struct OpCallerBinary, BCAST111C_VEC> { + using Op = PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + for (size_t c = 0; c < channel; c++) { + size_t i = 0; + const typename Op::src_ctype* src0_ptr = src0; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0_ptr, *src1, dst); + src0_ptr++; + src1++; + dst++; + } + } + } + } +}; + template struct OpCallerBinary, SCALAR_VEC> { using Op = PowOp; @@ -340,6 +398,84 @@ struct OpCallerBinary { } }; +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis; + for (size_t b = 0; b < batch; b++) { + for (size_t c = 0; c < channel; c++) { + size_t rest = channel_stride; + const typename Op::src_ctype* src1_ptr = src1; + while (rest >= Op::SIMD_WIDTH * 2) { + auto src0_neon0 = vis(src0); + auto src0_neon1 = vis(src0 + Op::SIMD_WIDTH); + auto src1_neon0 = vis(src1_ptr); + auto src1_neon1 = vis(src1_ptr + Op::SIMD_WIDTH); + src0 += Op::SIMD_WIDTH * 2; + src1_ptr += Op::SIMD_WIDTH * 2; + op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst); + dst += Op::SIMD_WIDTH * 2; + rest -= Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + while (rest > 0) { + op(*src0, *src1_ptr, dst); + dst++; + src0++; + src1_ptr++; + rest--; + } + } + } + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis; + for (size_t b = 0; b < batch; b++) { + for (size_t c = 0; c < channel; c++) { + size_t rest = channel_stride; + const typename Op::src_ctype* src0_ptr = src0; + while (rest >= Op::SIMD_WIDTH * 2) { + auto src0_neon0 = vis(src0_ptr); + auto src0_neon1 = vis(src0_ptr + Op::SIMD_WIDTH); + auto src1_neon0 = vis(src1); + auto src1_neon1 = vis(src1 + Op::SIMD_WIDTH); + src0_ptr += Op::SIMD_WIDTH * 2; + src1 += Op::SIMD_WIDTH * 2; + op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst); + dst += Op::SIMD_WIDTH * 2; + rest -= Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + while (rest > 0) { + op(*src0_ptr, *src1, dst); + dst++; + src0_ptr++; + src1++; + rest--; + } + } + } + } +}; + template struct OpCallerBinary, BCAST101xX_VEC> { using Op = PowOp; @@ -824,6 +960,54 @@ struct OpCallerTernary { } }; +//! src0: 111C, src1: vector, src2: 111C, src1 may not be contig +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + size_t src1_offset, const typename Op::src_ctype* src2, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType src2_dtype, DType dst_dtype, size_t batch_size, size_t channel_size, + size_t channel_stride) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis; + for (size_t batch = 0; batch < batch_size; batch++) { + for (size_t channel = 0; channel < channel_size; channel++) { + auto src0_ptr = src0; + auto src2_ptr = src2; + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + auto src0_neon0 = vis(src0_ptr); + auto src0_neon1 = vis(src0_ptr + Op::SIMD_WIDTH); + auto src1_neon0 = vis(src1); + auto src1_neon1 = vis(src1 + Op::SIMD_WIDTH); + auto src2_neon0 = vis(src2_ptr); + auto src2_neon1 = vis(src2_ptr + Op::SIMD_WIDTH); + op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, + {{src2_neon0, src2_neon1}}, dst); + src0_ptr += Op::SIMD_WIDTH * 2; + src1 += Op::SIMD_WIDTH * 2; + src2_ptr += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0_ptr, *src1, *src2_ptr, dst); + src0_ptr++; + src1++; + src2_ptr++; + dst++; + } + src1 += src1_offset; + } + } + } +}; + template struct OpCallerTernaryBcast101xXVecBcast101xX { template @@ -992,6 +1176,51 @@ struct OpCallerTernary { } }; +//! src1: 111C, src0 and src2 may not be contig +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, size_t src0_offset, + const typename Op::src_ctype* src1, const typename Op::src_ctype* src2, + size_t src2_offset, typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t batch_size, + size_t channel_size, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitor vis1; + ParamElemVisitor vis2; + for (size_t batch = 0; batch < batch_size; batch++) { + for (size_t channel = 0; channel < channel_size; channel++) { + auto src1_ptr = src1; + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{vis1(src1_ptr), vis1(src1_ptr + Op::SIMD_WIDTH)}}, + {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src1_ptr += Op::SIMD_WIDTH * 2; + src2 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0, *src1_ptr, *src2, dst); + src0++; + src1_ptr++; + src2++; + dst++; + } + src0 += src0_offset; + src2 += src2_offset; + } + } + } +}; + template struct OpCallerTernaryVecBcast101xXVec { template diff --git a/dnn/src/arm_common/quantized_converter.h b/dnn/src/arm_common/quantized_converter.h index c752dbe67..84bde87a9 100644 --- a/dnn/src/arm_common/quantized_converter.h +++ b/dnn/src/arm_common/quantized_converter.h @@ -50,6 +50,20 @@ inline dt_qint32 QConverter::convert(const float& src) { saturate(std::round(src), -2147483648, 2147483647)); } +template <> +inline float32x4x2_t QConverter::convert(const int16x8_t& vsrc) { + int32x4_t vhi = vmovl_s16(vget_high_s16(vsrc)); + int32x4_t vlo = vmovl_s16(vget_low_s16(vsrc)); + return {{vcvtq_f32_s32(vlo), vcvtq_f32_s32(vhi)}}; +} + +template <> +inline float32x4x2_t QConverter::convert(const uint16x8_t& vsrc) { + uint32x4_t vhi = vmovl_u16(vget_high_u16(vsrc)); + uint32x4_t vlo = vmovl_u16(vget_low_u16(vsrc)); + return {{vcvtq_f32_u32(vlo), vcvtq_f32_u32(vhi)}}; +} + #if __ARM_ARCH >= 8 template <> inline int8x8_t QConverter::convert(const float32x4x2_t& vsrc) { diff --git a/dnn/src/arm_common/type_cvt/opr_impl.cpp b/dnn/src/arm_common/type_cvt/opr_impl.cpp index 900e16c23..b6eb5f27a 100644 --- a/dnn/src/arm_common/type_cvt/opr_impl.cpp +++ b/dnn/src/arm_common/type_cvt/opr_impl.cpp @@ -17,6 +17,7 @@ #include "src/common/utils.h" #include "src/naive/handle.h" +MIDOUT_DECL(megdnn_arm_typecvt_fix2float) MIDOUT_DECL(megdnn_arm_typecvt_quantized) MIDOUT_DECL(megdnn_arm_typecvt_float) @@ -325,6 +326,48 @@ struct FloatTypeCvter { }; #endif +template +struct Fix2FloatTypeCvter; +template <> +struct Fix2FloatTypeCvter { + using stype = int16_t; + using dst_type = float; + static constexpr size_t SIMD_WIDTH = 8; + + Fix2FloatTypeCvter(DType src_dtype, DType dst_dtype) { + MEGDNN_MARK_USED_VAR(src_dtype); + MEGDNN_MARK_USED_VAR(dst_dtype); + } + + void cvt(const int16_t* src, float* dst) { + int16x8_t vitem = vld1q_s16(src); + auto vres = QConverter::convert(vitem); + vst1q_f32_x2(dst, vres); + } + + void cvt_remain(const int16_t* src, float* dst) { *dst = *src; } +}; + +template <> +struct Fix2FloatTypeCvter { + using stype = uint16_t; + using dst_type = float; + static constexpr size_t SIMD_WIDTH = 8; + + Fix2FloatTypeCvter(DType src_dtype, DType dst_dtype) { + MEGDNN_MARK_USED_VAR(src_dtype); + MEGDNN_MARK_USED_VAR(dst_dtype); + } + + void cvt(const uint16_t* src, float* dst) { + uint16x8_t vitem = vld1q_u16(src); + auto vres = QConverter::convert(vitem); + vst1q_f32_x2(dst, vres); + } + + void cvt_remain(const uint16_t* src, float* dst) { *dst = *src; } +}; + template void do_typecvt( const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst, @@ -347,6 +390,43 @@ void do_typecvt( } } +template +void do_typecvt( + const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst, + DType src_dtype, DType dst_dtype, const TensorLayout& src_layout) { + TypeCvter typecvt(src_dtype, dst_dtype); + size_t calc_num = 1; + size_t nr_elems = src_layout.total_nr_elems(); + size_t src_stride = nr_elems; + + //! adjust calc_num nr_elems and src_stride according to src_collapse_layout + auto src_collapse_layout = src_layout.collapse_contiguous(); + if (src_collapse_layout.ndim == 2) { + calc_num = src_collapse_layout.shape[0]; + nr_elems = src_collapse_layout.shape[1]; + src_stride = src_collapse_layout.stride[0]; + } + + for (size_t c = 0; c < calc_num; ++c) { + size_t i = 0; + for (; i + TypeCvter::SIMD_WIDTH <= nr_elems; i += TypeCvter::SIMD_WIDTH) { + typecvt.cvt(src, dst); + src += TypeCvter::SIMD_WIDTH; + dst += TypeCvter::SIMD_WIDTH; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + typecvt.cvt_remain(src, dst); + src++; + dst++; + } + src += src_stride - nr_elems; + } +} + } // anonymous namespace void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { @@ -354,7 +434,30 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { DType dst_dtype = dst.layout.dtype; size_t nr_elems = src.layout.total_nr_elems(); bool execed = false; - if (src.layout.is_contiguous()) { + auto src_collapse_layout = src.layout.collapse_contiguous(); + bool has_int16_special_impl = + (src.layout.dtype.enumv() == DTypeEnum::Int16 || + src.layout.dtype.enumv() == DTypeEnum::Uint16) && + (src.layout.is_contiguous() || src_collapse_layout.ndim == 2) && + dst.layout.is_contiguous(); + if (has_int16_special_impl) { + using namespace dtype; +#define DISPATCH_FIX2FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \ + if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \ + dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \ + MIDOUT_BEGIN(megdnn_arm_typecvt_fix2float, midout_iv(_midout_iv)) { \ + using _TypeCvter = Fix2FloatTypeCvter<_stype, _dtype>; \ + MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \ + src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(), \ + src_dtype, dst_dtype, src.layout)); \ + execed = true; \ + } \ + MIDOUT_END(); \ + } + DISPATCH_FIX2FLOAT(Int16, int16_t, Float32, float, 0); + DISPATCH_FIX2FLOAT(Uint16, uint16_t, Float32, float, 1); +#undef DISPATCH_FIX2FLOAT + } else if (src.layout.is_contiguous()) { using namespace dtype; #define DISPATCH_QUANTIZED(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \ if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \ @@ -377,6 +480,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS32, int32_t, 5); DISPATCH_QUANTIZED(float, float, QuantizedS8, int8_t, 6); DISPATCH_QUANTIZED(float, float, Quantized8Asymm, uint8_t, 7); +#undef DISPATCH_QUANTIZED #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #define DISPATCH_FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \ @@ -394,6 +498,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { } DISPATCH_FLOAT(dt_float16, __fp16, float, float, 0); DISPATCH_FLOAT(float, float, dt_float16, __fp16, 1); +#undef DISPATCH_FLOAT #endif } if (!execed) { diff --git a/dnn/src/common/elemwise/opr_impl_helper.cpp b/dnn/src/common/elemwise/opr_impl_helper.cpp index 30bf96a99..bee1cce23 100644 --- a/dnn/src/common/elemwise/opr_impl_helper.cpp +++ b/dnn/src/common/elemwise/opr_impl_helper.cpp @@ -150,6 +150,19 @@ bool ElemwiseLayoutHelper::is_broadcasted_channel_like( return false; } +bool ElemwiseLayoutHelper::is_NHWC_broadcasted_channel_like( + const TensorLayout& layout, BroadcastChannelInfo& info) { + if (layout.format.type() == TensorFormat::Type::DEFAULT) { + if (layout.ndim == 2 && layout.stride[1] == 1 && layout.stride[0] == 0) { + info.x = 1; + info.y = layout.shape[0]; + info.z = layout.shape[1]; + return true; + } + } + return false; +} + bool ElemwiseLayoutHelper::is_broadcasted_1x( const TensorLayout& layout, Broadcast1xInfo& binfo) { if (layout.ndim == 2 && layout.stride[0] == 0 && layout.stride[1] == 1) { diff --git a/dnn/src/common/elemwise/opr_impl_helper.h b/dnn/src/common/elemwise/opr_impl_helper.h index 2aabbd85b..fde396afe 100644 --- a/dnn/src/common/elemwise/opr_impl_helper.h +++ b/dnn/src/common/elemwise/opr_impl_helper.h @@ -80,6 +80,16 @@ public: static bool is_broadcasted_channel_like( const TensorLayout& layout, BroadcastChannelInfo& info); + /*! + * \brief check whether layout matches BroadcastChannelInfo under NHWC + * layout + * + * Note that Input must be 2-dimensional, and must be [1, y] broadacsted + * into [z, y] and x would be set to 1. + */ + static bool is_NHWC_broadcasted_channel_like( + const TensorLayout& layout, BroadcastChannelInfo& info); + /*! * \brief check whether layout matches BroadcastChannelInfo * diff --git a/dnn/src/fallback/type_cvt/opr_impl.cpp b/dnn/src/fallback/type_cvt/opr_impl.cpp index 9aa2102f9..5cae3fbfb 100644 --- a/dnn/src/fallback/type_cvt/opr_impl.cpp +++ b/dnn/src/fallback/type_cvt/opr_impl.cpp @@ -309,7 +309,8 @@ void on_dest_ctype(_megdnn_tensor_in src, _megdnn_tensor_out dst) { break; \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) - cb(::megdnn::dtype::Bool) case DTypeEnum::QuantizedS8 + cb(::megdnn::dtype::Bool) + cb(::megdnn::dtype::Uint16) case DTypeEnum::QuantizedS8 : MIDOUT_BEGIN( megdnn_fb_typecvt_src_dtype, midout_iv(DTypeEnum::QuantizedS8)) { @@ -467,7 +468,8 @@ void run_contiguous(_megdnn_tensor_in src, _megdnn_tensor_out dst) { } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) - cb(::megdnn::dtype::Bool) case DTypeEnum::QuantizedS8 + cb(::megdnn::dtype::Bool) + cb(::megdnn::dtype::Uint16) case DTypeEnum::QuantizedS8 : MIDOUT_BEGIN( megdnn_fb_typecvt_dst_dtype, midout_iv(DTypeEnum::QuantizedS8)) { diff --git a/dnn/src/naive/type_cvt/opr_impl.cpp b/dnn/src/naive/type_cvt/opr_impl.cpp index b36661fe1..4ab170d04 100644 --- a/dnn/src/naive/type_cvt/opr_impl.cpp +++ b/dnn/src/naive/type_cvt/opr_impl.cpp @@ -78,7 +78,7 @@ void on_dest_ctype(HandleImpl* handle, const TensorND& dest, const TensorND& src MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) - cb(::megdnn::dtype::Bool) + cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16) #undef cb default : megdnn_throw("bad dtype"); } @@ -99,7 +99,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) - cb(::megdnn::dtype::Bool) + cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16) #undef cb default : megdnn_throw("bad dtype"); } diff --git a/dnn/test/arm_common/elemwise.cpp b/dnn/test/arm_common/elemwise.cpp index 72c8b9210..f144a46f7 100644 --- a/dnn/test/arm_common/elemwise.cpp +++ b/dnn/test/arm_common/elemwise.cpp @@ -14,6 +14,7 @@ #include "test/common/benchmarker.h" #include "test/common/checker.h" +#include "megdnn/opr_param_defs.h" #include "megdnn/oprs/general.h" using namespace megdnn; @@ -298,6 +299,63 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW88_FP) { #endif } +TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NHWC_FP32_BCAST) { + using Mode = ElemwiseForward::Param::Mode; + Checker checker(handle()); + + UniformFloatRNG rng(1e-5, 7e1); + checker.set_rng(0, &rng); + checker.set_epsilon(1e-5); + checker.set_dtype(0, dtype::Float32()); + checker.set_dtype(1, dtype::Float32()); + + //! 2 dim + auto run = [&](Mode mode) { + // VEC_BCAST111C + checker.set_param(mode).execs({{1, 2, 2, 12}, {1, 1, 1, 12}, {}}); + checker.set_param(mode).execs({{2, 5, 3, 28}, {1, 1, 1, 28}, {}}); + checker.set_param(mode).execs({{3, 5, 8, 32}, {1, 1, 1, 32}, {}}); + // BCAST111C_VEC + checker.set_param(mode).execs({{1, 1, 1, 12}, {1, 2, 2, 12}, {}}); + checker.set_param(mode).execs({{1, 1, 1, 28}, {2, 5, 3, 28}, {}}); + checker.set_param(mode).execs({{1, 1, 1, 32}, {3, 5, 8, 32}, {}}); + }; + run(Mode::ADD); + run(Mode::MUL); + run(Mode::SUB); + + //! 3 dim contig + auto run_3d_contig = [&](Mode mode) { + // BCAST111C_VEC_BCAST111C + checker.set_param(mode).execs( + {{1, 1, 1, 12}, {1, 2, 2, 12}, {1, 1, 1, 12}, {}}); + checker.set_param(mode).execs( + {{1, 1, 1, 28}, {2, 5, 3, 28}, {1, 1, 1, 28}, {}}); + checker.set_param(mode).execs( + {{1, 1, 1, 32}, {3, 5, 8, 32}, {1, 1, 1, 32}, {}}); + // VEC_BCAST111C_VEC + checker.set_param(mode).execs( + {{1, 2, 2, 12}, {1, 1, 1, 12}, {1, 2, 2, 12}, {}}); + checker.set_param(mode).execs( + {{2, 5, 3, 28}, {1, 1, 1, 28}, {2, 5, 3, 28}, {}}); + checker.set_param(mode).execs( + {{3, 5, 8, 32}, {1, 1, 1, 32}, {3, 5, 8, 32}, {}}); + }; + run_3d_contig(Mode::FUSE_MUL_ADD3); + + //! 3 dim incontig + auto run_3d_incontig = [&](Mode mode) { + megdnn::TensorLayout src0({1, 1, 1, 12}, dtype::Float32()); + megdnn::TensorLayout src1({1, 2, 2, 12}, {80, 40, 20, 1}, dtype::Float32()); + + // BCAST111C_VEC_BCAST111C + checker.set_param(mode).execl({src0, src1, src0, {}}); + // VEC_BCAST111C_VEC + checker.set_param(mode).execl({src1, src0, src1, {}}); + }; + run_3d_incontig(Mode::FUSE_MUL_ADD3); +} + #if MEGDNN_WITH_BENCHMARK namespace { void run_elemwise_benchmark( @@ -354,6 +412,39 @@ void run_elemwise_benchmark( } } // namespace +TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NHWC) { + Benchmarker benchmarker(handle()); + constexpr size_t RUN = 50; + benchmarker.set_times(RUN).set_display(false); + + auto run = [&](size_t N, size_t C, size_t H, size_t W, param::Elemwise::Mode mode, + const char* mode_name) { + megdnn::param::Elemwise param; + param.mode = mode; + benchmarker.set_param(param); + megdnn::TensorShape nhwc_src0{N, H, W, C}; + megdnn::TensorShape nhwc_src1{1, 1, 1, C}; + + megdnn::TensorShape nchw_src0{N, C, H, W}; + megdnn::TensorShape nchw_src1{1, C, 1, 1}; + + float computations = N * C * H * W; + auto nhwc_time = benchmarker.execs({nhwc_src1, nhwc_src0, {}}) / RUN; + auto nchw_time = benchmarker.execs({nchw_src1, nchw_src0, {}}) / RUN; + auto perf_nhwc = computations / nhwc_time / 1e6; + auto perf_nchw = computations / nchw_time / 1e6; + printf("Elemwise Mode : %s\nNHWC : %fms %fGflops\nNCHW : %fms " + "%fGflops\n", + mode_name, nhwc_time, perf_nhwc, nchw_time, perf_nchw); + }; + run(1, 120, 16, 24, param::Elemwise::Mode::ADD, "ADD"); + run(1, 120, 16, 24, param::Elemwise::Mode::MUL, "MUL"); + run(1, 120, 32, 48, param::Elemwise::Mode::ADD, "ADD"); + run(1, 120, 32, 48, param::Elemwise::Mode::MUL, "MUL"); + run(1, 120, 64, 96, param::Elemwise::Mode::ADD, "ADD"); + run(1, 120, 64, 96, param::Elemwise::Mode::MUL, "MUL"); +} + #define INT_RUN(shape, mode) \ run_elemwise_benchmark(shape, mode, #mode, dtype::Int8{}, handle()); \ run_elemwise_benchmark(shape, mode, #mode, dtype::Int16{}, handle()); \ diff --git a/dnn/test/arm_common/type_cvt.cpp b/dnn/test/arm_common/type_cvt.cpp index 71368aeca..4bdf9e583 100644 --- a/dnn/test/arm_common/type_cvt.cpp +++ b/dnn/test/arm_common/type_cvt.cpp @@ -88,6 +88,26 @@ TEST_F(ARM_COMMON, TYPE_CVT) { .execs({{1, 32, 24, 128}, {1, 32, 24, 128}}); } +TEST_F(ARM_COMMON, TYPE_CVT_16_F32) { + Checker checker(handle()); + UniformIntRNG rng{INT16_MIN >> 1, INT16_MAX >> 1}; + + for (size_t size : {3, 7, 15, 33, 10000}) { + checker.set_rng(0, &rng); + checker.set_dtype(0, dtype::Int16()).execs({{size}, {size}}); + checker.set_dtype(0, dtype::Uint16()).execs({{size}, {size}}); + } + TensorLayout src_int16{ + {1, 96, 64, 120}, {128 * 64 * 96, 128 * 64, 128, 1}, dtype::Int16()}; + TensorLayout dst_int16{{1, 96, 64, 120}, dtype::Float32()}; + checker.execl({src_int16, dst_int16}); + + TensorLayout src_uint16{ + {1, 96, 64, 120}, {128 * 64 * 96, 128 * 64, 128, 1}, dtype::Uint16()}; + TensorLayout dst_uint16{{1, 96, 64, 120}, dtype::Float32()}; + checker.execl({src_uint16, dst_uint16}); +} + #if MEGDNN_WITH_BENCHMARK TEST_F(ARM_COMMON, BENCHMARK_TYPE_CVT) { auto run = [&](const TensorShapeArray& shapes) { diff --git a/dnn/test/common/checker.cpp b/dnn/test/common/checker.cpp index 1dfad8482..db33bb539 100644 --- a/dnn/test/common/checker.cpp +++ b/dnn/test/common/checker.cpp @@ -158,8 +158,9 @@ void copy_tensors( //! In order to avoid an unnecessary increase in binary size, we just //! use QuantizedS16 dtype in winograd_filter_preprocess now. cb(::megdnn::dtype::QuantizedS16) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) + cb(::megdnn::dtype::Uint16) #undef cb - default : megdnn_trap(); + default : megdnn_trap(); } } diff --git a/dnn/test/common/rng.cpp b/dnn/test/common/rng.cpp index 372ba46e5..19362df93 100644 --- a/dnn/test/common/rng.cpp +++ b/dnn/test/common/rng.cpp @@ -202,6 +202,9 @@ void IIDRNG::gen(const TensorND& tensor) { memset(tensor.raw_ptr, 0, tensor.layout.access_bytes()); return; } + if (tensor.layout.dtype.enumv() == DTypeEnum::Uint16) { + return; + } megdnn_assert( 0, "IIDRNG does not know how to generate value for DType %s", tensor.layout.dtype.name()); diff --git a/dnn/test/cuda/type_cvt.cpp b/dnn/test/cuda/type_cvt.cpp index ad9c404a6..11c07e232 100644 --- a/dnn/test/cuda/type_cvt.cpp +++ b/dnn/test/cuda/type_cvt.cpp @@ -25,6 +25,11 @@ TEST_F(CUDA, TYPE_CVT) { TensorLayout src({10, 10}, sdtype), dst({10, 10}, ddtype); Checker checker(handle_cuda()); checker.set_rng(0, &init).exec(TensorLayoutArray{src, dst}); + + TensorLayout non_contig_src( + {1, 96, 64, 120}, {96 * 64 * 128, 64 * 128, 128, 1}, sdtype); + TensorLayout non_contig_dst({1, 96, 64, 120}, ddtype); + checker.exec(TensorLayoutArray{non_contig_src, non_contig_dst}); } } diff --git a/dnn/test/x86/type_cvt.cpp b/dnn/test/x86/type_cvt.cpp index f0d29815f..bb3c56804 100644 --- a/dnn/test/x86/type_cvt.cpp +++ b/dnn/test/x86/type_cvt.cpp @@ -37,8 +37,22 @@ TEST_F(X86, TYPE_CVT) { for (auto ddtype : dtypes) { checker.set_dtype(0, sdtype).set_dtype(1, ddtype).execs( {{size}, {size}}); + TensorLayout non_contig_src( + {1, 10, 10, 12}, {10 * 10 * 18, 10 * 18, 18, 1}, sdtype); + TensorLayout non_contig_dst({1, 10, 10, 12}, ddtype); + checker.exec(TensorLayoutArray{non_contig_src, non_contig_dst}); } } + + for (size_t size : {1, 7, 15, 33}) { + checker.set_dtype(0, dtype::Uint16()) + .set_dtype(1, dtype::Float32()) + .execs({{size}, {size}}); + } + TensorLayout non_contig_src( + {1, 10, 10, 12}, {10 * 10 * 18, 10 * 18, 18, 1}, dtype::Uint16()); + TensorLayout non_contig_dst({1, 10, 10, 12}, dtype::Float32()); + checker.exec(TensorLayoutArray{non_contig_src, non_contig_dst}); } TEST_F(X86, TYPE_CVT_NO_CONTIGUOUS) { diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index 00acc27fb..77ff47724 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -772,8 +772,10 @@ void TypeCvt::perform( } void TypeCvt::add_input_layout_constraint() { + //! Because the implementation of typecvt on arm/x86/cuda/opencl support + //! non-contiguous memory. So we change constraint of typecvt to monotone for (auto i : input()) { - i->add_layout_constraint_contiguous(); + i->add_layout_constraint_monotone(); } } -- GitLab