From 3344b580a9bf84486799e300ba1c2509d7f03d07 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 25 Aug 2021 22:06:02 +0800 Subject: [PATCH] feat(dnn): add elemwise for nchw88+fp16 GitOrigin-RevId: 63587975f8746bd8cf2443e81d433bfc07122b38 --- .../arm_common/conv_bias/postprocess_helper.h | 92 ++-- dnn/src/arm_common/elemwise/binary/algo.cpp | 42 +- dnn/src/arm_common/elemwise/binary/algo.h | 2 +- dnn/src/arm_common/elemwise/opr_impl.cpp | 35 +- dnn/src/arm_common/elemwise/opr_impl.h | 8 +- dnn/src/arm_common/elemwise/ternary/algo.cpp | 24 +- dnn/src/arm_common/elemwise/ternary/algo.h | 4 +- .../elemwise_multi_type/opr_impl.cpp | 19 +- dnn/src/arm_common/elemwise_op.h | 445 +++++++++++++++--- dnn/test/arm_common/elemwise.cpp | 86 ++++ 10 files changed, 595 insertions(+), 162 deletions(-) diff --git a/dnn/src/arm_common/conv_bias/postprocess_helper.h b/dnn/src/arm_common/conv_bias/postprocess_helper.h index 0a9178a1d..11c9f60c7 100644 --- a/dnn/src/arm_common/conv_bias/postprocess_helper.h +++ b/dnn/src/arm_common/conv_bias/postprocess_helper.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once @@ -22,7 +23,6 @@ MIDOUT_DECL(arm_common_conv_bias_postprocess_helper) namespace { - #define CONCAT_OP(_name) megdnn::arm_common::_name #define CONCAT_NL(_name) megdnn::NonlineMode::_name @@ -57,9 +57,9 @@ namespace { reinterpret_cast(dst_ptr), bias_type, bias_type, \ dst_type, N, OC, OH* OW); -#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW44(_op) \ +#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ megdnn::arm_common::OpCallerBinary<_op, \ - megdnn::arm_common::VEC_BCAST101x4>:: \ + megdnn::arm_common::VEC_BCAST101xX>:: \ run(static_cast(conv_dst_ptr), \ reinterpret_cast(bias_ptr), \ reinterpret_cast(dst_ptr), bias_type, bias_type, \ @@ -86,9 +86,9 @@ namespace { if (pack_oc_size == 1) { \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ } else { \ - megdnn_assert(pack_oc_size == 4, \ - "Only support nchw44 in ARM"); \ - FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ + megdnn_assert(pack_oc_size == 4 || pack_oc_size == 8, \ + "Only support nchw44/nchw88 in ARM"); \ + FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX); \ } \ } \ MIDOUT_END(); \ @@ -100,7 +100,7 @@ namespace { MIDOUT_END(); \ break; \ default: \ - megdnn_throw("unknow biasmode"); \ + megdnn_throw("unknow biasmode"); \ break; \ } @@ -160,7 +160,7 @@ struct PostProcess { #undef FOR_NONLINEAR_UNARY #undef FOR_NONLINEAR_BINARY_BROADCAST -#undef FOR_NONLINEAR_BINARY_BROADCAST_NCHW44 +#undef FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX #undef FOR_NONLINEAR_BINARY #undef FOR_NONLINEAR_NOBIAS #undef FOR_NONLINEAR @@ -183,16 +183,24 @@ struct PostProcess { reinterpret_cast(dst_ptr), bias_type, bias_type, \ dst_type, N, OC, OH* OW); -#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW44(_op) \ +#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ + megdnn::arm_common::OpCallerBinary<_op, \ + megdnn::arm_common::VEC_BCAST101xX>:: \ + run(static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, \ + dst_type, N, OC, OH* OW, pack_oc_size); + +#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \ megdnn::arm_common::OpCallerBinary<_op, \ - megdnn::arm_common::VEC_BCAST101x4>:: \ + megdnn::arm_common::VEC_BCAST101xX>:: \ run(static_cast(conv_dst_ptr), \ reinterpret_cast(bias_ptr), \ reinterpret_cast(dst_ptr), bias_type, bias_type, \ dst_type, N, OC, OH* OW, pack_oc_size); -#define HANDLE_IDENTITY(_caller, _op) \ - case megdnn::NonlineMode::IDENTITY: \ +#define HANDLE_IDENTITY(_caller, _op) \ + case megdnn::NonlineMode::IDENTITY: \ _caller(_op) break; #define FOR_NONLINEAR(_caller) \ @@ -220,9 +228,9 @@ struct PostProcess { if (pack_oc_size == 1) { \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ } else { \ - megdnn_assert(pack_oc_size == 4, \ - "Only support nchw44 in ARM"); \ - FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ + megdnn_assert(pack_oc_size == 4 || pack_oc_size == 8, \ + "Only support nchw44/nchw88 in ARM"); \ + FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX); \ } \ break; \ default: \ @@ -230,9 +238,9 @@ struct PostProcess { if (pack_oc_size == 1) { \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ } else { \ - megdnn_assert(pack_oc_size == 4, \ - "Only support nchw44 in ARM"); \ - FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ + megdnn_assert(pack_oc_size == 4 || pack_oc_size == 8, \ + "Only support nchw44/nchw88 in ARM"); \ + FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX); \ } \ break; \ } \ @@ -254,7 +262,7 @@ struct PostProcess { #undef FOR_NONLINEAR_UNARY #undef FOR_NONLINEAR_BINARY_BROADCAST -#undef FOR_NONLINEAR_BINARY_BROADCAST_NCHW44 +#undef FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX #undef FOR_NONLINEAR_BINARY #undef FOR_NONLINEAR_NOBIAS #undef FOR_NONLINEAR @@ -268,9 +276,9 @@ struct PostProcess { reinterpret_cast(dst_ptr), bias_type, bias_type, \ dst_type, N, OC, OH* OW); -#define FOR_BINARY_BROADCAST_NCHW44(_op) \ +#define FOR_BINARY_BROADCAST_NCHWXX(_op) \ megdnn::arm_common::OpCallerBinary<_op, \ - megdnn::arm_common::VEC_BCAST101x4>:: \ + megdnn::arm_common::VEC_BCAST101xX>:: \ run(static_cast(conv_dst_ptr), \ reinterpret_cast(bias_ptr), \ reinterpret_cast(dst_ptr), bias_type, bias_type, \ @@ -284,25 +292,25 @@ struct PostProcess { reinterpret_cast(dst_ptr), bias_type, bias_type, \ dst_type, N* OC* OH* OW* pack_oc_size); -#define FOR_BIAS(_bias_mode, OH, OW) \ - switch (_bias_mode) { \ - case megdnn::BiasMode::NO_BIAS: \ - break; \ - case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ - if (pack_oc_size == 1) { \ - FOR_BINARY_BROADCAST(CONCAT_OP(AddOp)); \ - } else { \ - megdnn_assert(pack_oc_size == 4, \ - "Only support nchw44 in ARM"); \ - FOR_BINARY_BROADCAST_NCHW44(CONCAT_OP(AddOp)); \ - } \ - break; \ - case megdnn::BiasMode::BIAS: \ - FOR_BINARY(CONCAT_OP(AddOp)); \ - break; \ - default: \ - megdnn_throw("unknow biasmode"); \ - break; \ +#define FOR_BIAS(_bias_mode, OH, OW) \ + switch (_bias_mode) { \ + case megdnn::BiasMode::NO_BIAS: \ + break; \ + case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ + if (pack_oc_size == 1) { \ + FOR_BINARY_BROADCAST(CONCAT_OP(AddOp)); \ + } else { \ + megdnn_assert(pack_oc_size == 4 || pack_oc_size == 8, \ + "Only support nchw44/nchw88 in ARM"); \ + FOR_BINARY_BROADCAST_NCHWXX(CONCAT_OP(AddOp)); \ + } \ + break; \ + case megdnn::BiasMode::BIAS: \ + FOR_BINARY(CONCAT_OP(AddOp)); \ + break; \ + default: \ + megdnn_throw("unknow biasmode"); \ + break; \ } template @@ -318,7 +326,7 @@ struct PostProcess { }; #undef FOR_BINARY_BROADCAST -#undef FOR_BINARY_BROADCAST_NCHW44 +#undef FOR_BINARY_BROADCAST_NCHWXX #undef FOR_BINARY #undef FOR_BIAS #undef CB diff --git a/dnn/src/arm_common/elemwise/binary/algo.cpp b/dnn/src/arm_common/elemwise/binary/algo.cpp index 833b7cd1e..81811136f 100644 --- a/dnn/src/arm_common/elemwise/binary/algo.cpp +++ b/dnn/src/arm_common/elemwise/binary/algo.cpp @@ -105,25 +105,21 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available( return false; } -bool ElemwiseImpl::AlgoBinaryVecBcast101x4::is_available( +bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available( const KernParam& kern_param) const { if (!is_available_common(kern_param.mode) || - ((BcastType::VEC_BCAST101x4 != kern_param.broad_cast_type) && - (BcastType::BCAST101x4_VEC != kern_param.broad_cast_type))) + ((BcastType::VEC_BCAST101xX != kern_param.broad_cast_type) && + (BcastType::BCAST101xX_VEC != kern_param.broad_cast_type))) return false; auto& elparam = kern_param.binary_elparam; auto& src0 = elparam[0]; -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - if (DNN_FLOAT16_SELECT(src0.layout.dtype == dtype::Float16{}, false)) { - return false; - } -#endif - DISPATCH_TYPE("AlgoBinaryVecBcast101x::is_available"_hash); + DISPATCH_TYPE("AlgoBinaryVecBcast101xX::is_available"_hash); return false; } + #undef DISPATCH_MODE_FLOAT #undef DISPATCH_MODE_INT @@ -334,16 +330,19 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec( return; } -void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec( +void ElemwiseImpl::AlgoBinaryVecBcast101xX::exec( const KernParam& kern_param) const { auto& elparam = kern_param.binary_elparam; auto &src0 = elparam[0], &src1 = elparam[1]; auto&& dst = *(kern_param.m_dst); BroadcastChannelInfo binfo; - // BcastType::VEC + BCAST_101x - if (BcastType::VEC_BCAST101x4 == kern_param.broad_cast_type && - is_broadcastedx_channel_like<4>(src1.layout, binfo)) { + // BcastType::VEC + BCAST_101X + if (BcastType::VEC_BCAST101xX == kern_param.broad_cast_type) { + megdnn_assert( + is_broadcastedx_channel_like<4>(src1.layout, binfo) || + is_broadcastedx_channel_like<8>(src1.layout, binfo), + "only nchw44 and nchw88 supported"); #define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ case Mode::_mode: \ MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \ @@ -351,7 +350,7 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec( thin_function \ run = OpCallerBinary<_op<_type, _type>, \ - BcastType::VEC_BCAST101x4>::run; \ + BcastType::VEC_BCAST101xX>::run; \ MEGDNN_DISPATCH_CPU_KERN( \ static_cast(kern_param.handle), \ run(static_cast(src0.raw_ptr), \ @@ -362,17 +361,19 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec( } \ MIDOUT_END(); \ return - size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); - DISPATCH_TYPE("AlgoBinaryVecBcast101x::exec_vec_b"_hash); + DISPATCH_TYPE("AlgoBinaryVecBcast101xX::exec_vec_b"_hash); #undef DISPATCH_BINARY } // BCAST_101x + BcastType::VEC - if (BcastType::BCAST101x4_VEC == kern_param.broad_cast_type && - is_broadcastedx_channel_like<4>(src0.layout, binfo)) { + if (BcastType::BCAST101xX_VEC == kern_param.broad_cast_type) { + megdnn_assert( + is_broadcastedx_channel_like<4>(src0.layout, binfo) || + is_broadcastedx_channel_like<8>(src0.layout, binfo), + "only nchw44 and nchw88 supported"); #define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ case Mode::_mode: \ MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \ @@ -380,7 +381,7 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec( thin_function \ run = OpCallerBinary<_op<_type, _type>, \ - BcastType::BCAST101x4_VEC>::run; \ + BcastType::BCAST101xX_VEC>::run; \ MEGDNN_DISPATCH_CPU_KERN( \ static_cast(kern_param.handle), \ run(static_cast(src0.raw_ptr), \ @@ -394,12 +395,13 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec( size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); - DISPATCH_TYPE("AlgoBinaryVecBcast101x::exec_b_vec"_hash); + DISPATCH_TYPE("AlgoBinaryVecBcast101xX::exec_b_vec"_hash); #undef DISPATCH_BINARY } return; } + #undef DISPATCH_MODE_FLOAT #undef DISPATCH_MODE_INT diff --git a/dnn/src/arm_common/elemwise/binary/algo.h b/dnn/src/arm_common/elemwise/binary/algo.h index b98ee269a..42669b53b 100644 --- a/dnn/src/arm_common/elemwise/binary/algo.h +++ b/dnn/src/arm_common/elemwise/binary/algo.h @@ -34,7 +34,7 @@ namespace arm_common { DECL_CB(VecVec); DECL_CB(VecScalar); DECL_CB(VecBcast101); -DECL_CB(VecBcast101x4); +DECL_CB(VecBcast101xX); #undef DECL_CB } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/elemwise/opr_impl.cpp b/dnn/src/arm_common/elemwise/opr_impl.cpp index 109eda682..6e705205b 100644 --- a/dnn/src/arm_common/elemwise/opr_impl.cpp +++ b/dnn/src/arm_common/elemwise/opr_impl.cpp @@ -27,14 +27,14 @@ class ElemwiseImpl::AlgoPack { AlgoBinaryVecVec algo_binary_vec_vec; AlgoBinaryVecScalar algo_binary_vec_sca; AlgoBinaryVecBcast101 algo_binary_vec_bcast101; - AlgoBinaryVecBcast101x4 algo_binary_VEC_BCAST101x4; + AlgoBinaryVecBcast101xX algo_binary_VEC_BCAST101xX; AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec; AlgoTernaryFma3VecVecScalar algo_ternaryfma3_vec_vecsca; AlgoTernaryFma3Bcast101VecBcast101 algo_ternaryfma3_bcast101_vec_bcast101; - AlgoTernaryFma3Bcast101x4VecBcast101x4 - algo_ternaryfma3_bcast101x4_vec_bcast101x4; + AlgoTernaryFma3Bcast101xXVecBcast101xX + algo_ternaryfma3_bcast101xX_vec_bcast101xX; AlgoTernaryFma3VecBcast101Vec algo_ternaryfma3_vec_bcast101_vec; - AlgoTernaryFma3VecBcast101x4Vec algo_ternaryfma3_vec_bcast101x4_vec; + AlgoTernaryFma3VecBcast101xXVec algo_ternaryfma3_vec_bcast101xX_vec; AlgoTernaryFma3VecScalarVec algo_ternaryfma3_vec_sca_vec; AlgoTernaryFma3VecScalarScalar algo_ternaryfma3_vec_sca_sca; @@ -44,13 +44,13 @@ public: all_algos.emplace_back(&algo_binary_vec_vec); all_algos.emplace_back(&algo_binary_vec_sca); all_algos.emplace_back(&algo_binary_vec_bcast101); - all_algos.emplace_back(&algo_binary_VEC_BCAST101x4); + all_algos.emplace_back(&algo_binary_VEC_BCAST101xX); all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec); all_algos.emplace_back(&algo_ternaryfma3_vec_vecsca); all_algos.emplace_back(&algo_ternaryfma3_bcast101_vec_bcast101); - all_algos.emplace_back(&algo_ternaryfma3_bcast101x4_vec_bcast101x4); + all_algos.emplace_back(&algo_ternaryfma3_bcast101xX_vec_bcast101xX); all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101_vec); - all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101x4_vec); + all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101xX_vec); all_algos.emplace_back(&algo_ternaryfma3_vec_sca_vec); all_algos.emplace_back(&algo_ternaryfma3_vec_sca_sca); } @@ -118,9 +118,10 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { } if (is_vector(src1.layout) && - is_broadcastedx_channel_like<4>(src0.layout, binfo) && + (is_broadcastedx_channel_like<4>(src0.layout, binfo) || + is_broadcastedx_channel_like<8>(src0.layout, binfo)) && src0.layout.eq_layout(src2.layout)) { - kern_param.broad_cast_type = BcastType::BCAST101x4_VEC_BCAST101x4; + kern_param.broad_cast_type = BcastType::BCAST101xX_VEC_BCAST101xX; return kern_param; } @@ -131,8 +132,9 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { } if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) && - is_broadcastedx_channel_like<4>(src1.layout, binfo)) { - kern_param.broad_cast_type = BcastType::VEC_BCAST101x4_VEC; + (is_broadcastedx_channel_like<4>(src1.layout, binfo) || + is_broadcastedx_channel_like<8>(src1.layout, binfo))) { + kern_param.broad_cast_type = BcastType::VEC_BCAST101xX_VEC; return kern_param; } @@ -180,17 +182,18 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { } if (is_vector(src0.layout) && - is_broadcastedx_channel_like<4>(src1.layout, binfo)) { - kern_param.broad_cast_type = BcastType::VEC_BCAST101x4; + (is_broadcastedx_channel_like<4>(src1.layout, binfo) || + is_broadcastedx_channel_like<8>(src1.layout, binfo))) { + kern_param.broad_cast_type = BcastType::VEC_BCAST101xX; return kern_param; } if (is_vector(src1.layout) && - is_broadcastedx_channel_like<4>(src0.layout, binfo)) { - kern_param.broad_cast_type = BcastType::BCAST101x4_VEC; + (is_broadcastedx_channel_like<4>(src0.layout, binfo) || + is_broadcastedx_channel_like<8>(src0.layout, binfo))) { + kern_param.broad_cast_type = BcastType::BCAST101xX_VEC; return kern_param; } - } else if (opr->m_src->size() == 1) { kern_param.broad_cast_type = BcastType::VEC; kern_param.unary_elparam = opr->make_elemwise_op_param<1>(); diff --git a/dnn/src/arm_common/elemwise/opr_impl.h b/dnn/src/arm_common/elemwise/opr_impl.h index 568245359..17f5bda0a 100644 --- a/dnn/src/arm_common/elemwise/opr_impl.h +++ b/dnn/src/arm_common/elemwise/opr_impl.h @@ -10,7 +10,9 @@ * implied. */ #pragma once + #include "src/fallback/elemwise/opr_impl.h" + #include "src/arm_common/elemwise_op.h" namespace megdnn { @@ -37,13 +39,13 @@ private: class AlgoBinaryVecVec; class AlgoBinaryVecScalar; class AlgoBinaryVecBcast101; - class AlgoBinaryVecBcast101x4; + class AlgoBinaryVecBcast101xX; class AlgoTernaryFma3VecVecVec; class AlgoTernaryFma3VecVecScalar; class AlgoTernaryFma3Bcast101VecBcast101; - class AlgoTernaryFma3Bcast101x4VecBcast101x4; + class AlgoTernaryFma3Bcast101xXVecBcast101xX; class AlgoTernaryFma3VecBcast101Vec; - class AlgoTernaryFma3VecBcast101x4Vec; + class AlgoTernaryFma3VecBcast101xXVec; class AlgoTernaryFma3VecScalarVec; class AlgoTernaryFma3VecScalarScalar; class AlgoPack; diff --git a/dnn/src/arm_common/elemwise/ternary/algo.cpp b/dnn/src/arm_common/elemwise/ternary/algo.cpp index 8df76d1ae..8070f0f84 100644 --- a/dnn/src/arm_common/elemwise/ternary/algo.cpp +++ b/dnn/src/arm_common/elemwise/ternary/algo.cpp @@ -42,9 +42,9 @@ using namespace arm_common; DECL_AVAILABLE(VecVecVec, BcastType::VEC_VEC_VEC); DECL_AVAILABLE(VecVecScalar, BcastType::VEC_VEC_SCALAR); DECL_AVAILABLE(Bcast101VecBcast101, BcastType::BCAST101_VEC_BCAST101); -DECL_AVAILABLE(Bcast101x4VecBcast101x4, BcastType::BCAST101x4_VEC_BCAST101x4); +DECL_AVAILABLE(Bcast101xXVecBcast101xX, BcastType::BCAST101xX_VEC_BCAST101xX); DECL_AVAILABLE(VecBcast101Vec, BcastType::VEC_BCAST101_VEC); -DECL_AVAILABLE(VecBcast101x4Vec, BcastType::VEC_BCAST101x4_VEC); +DECL_AVAILABLE(VecBcast101xXVec, BcastType::VEC_BCAST101xX_VEC); DECL_AVAILABLE(VecScalarVec, BcastType::VEC_SCALAR_VEC); DECL_AVAILABLE(VecScalarScalar, BcastType::VEC_SCALAR_SCALAR); #undef DECL_CB @@ -161,13 +161,15 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec( return; } -void ElemwiseImpl::AlgoTernaryFma3Bcast101x4VecBcast101x4::exec( +void ElemwiseImpl::AlgoTernaryFma3Bcast101xXVecBcast101xX::exec( const KernParam& kern_param) const { auto& elparam = kern_param.ternary_elparam; auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; BroadcastChannelInfo binfo; - is_broadcastedx_channel_like<4>(src0.layout, binfo); + megdnn_assert(is_broadcastedx_channel_like<4>(src0.layout, binfo) || + is_broadcastedx_channel_like<8>(src0.layout, binfo), + "only nchw44 and nchw88 supported"); #define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ case Mode::_mode: \ MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ @@ -177,7 +179,7 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101x4VecBcast101x4::exec( size_t, size_t, size_t)> \ run = OpCallerTernary< \ _op<_type, _type>, \ - BcastType::BCAST101x4_VEC_BCAST101x4>::run; \ + BcastType::BCAST101xX_VEC_BCAST101xX>::run; \ MEGDNN_DISPATCH_CPU_KERN( \ static_cast(kern_param.handle), \ run(static_cast(src0.raw_ptr), \ @@ -193,19 +195,21 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101x4VecBcast101x4::exec( size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); auto&& dst = *(kern_param.m_dst); - DISPATCH_TYPE("AlgoTernaryFma3Bcast101x4VecBcast101x4::exec"_hash); + DISPATCH_TYPE("AlgoTernaryFma3Bcast101xXVecBcast101xX::exec"_hash); #undef DISPATCH_TERNARY return; } -void ElemwiseImpl::AlgoTernaryFma3VecBcast101x4Vec::exec( +void ElemwiseImpl::AlgoTernaryFma3VecBcast101xXVec::exec( const KernParam& kern_param) const { auto& elparam = kern_param.ternary_elparam; auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; BroadcastChannelInfo binfo; - is_broadcastedx_channel_like<4>(src1.layout, binfo); + megdnn_assert(is_broadcastedx_channel_like<4>(src1.layout, binfo) || + is_broadcastedx_channel_like<8>(src1.layout, binfo), + "only nchw44 and nchw88 supported"); #define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ case Mode::_mode: \ MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ @@ -214,7 +218,7 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101x4Vec::exec( _type*, DType, DType, DType, DType, size_t, \ size_t, size_t, size_t)> \ run = OpCallerTernary<_op<_type, _type>, \ - BcastType::VEC_BCAST101x4_VEC>::run; \ + BcastType::VEC_BCAST101xX_VEC>::run; \ MEGDNN_DISPATCH_CPU_KERN( \ static_cast(kern_param.handle), \ run(static_cast(src0.raw_ptr), \ @@ -230,7 +234,7 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101x4Vec::exec( size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); auto&& dst = *(kern_param.m_dst); - DISPATCH_TYPE("AlgoTernaryFma3VecBcast101x4Vec::exec"_hash); + DISPATCH_TYPE("AlgoTernaryFma3VecBcast101xXVec::exec"_hash); #undef DISPATCH_TERNARY return; diff --git a/dnn/src/arm_common/elemwise/ternary/algo.h b/dnn/src/arm_common/elemwise/ternary/algo.h index d63b4ee81..2864d3522 100644 --- a/dnn/src/arm_common/elemwise/ternary/algo.h +++ b/dnn/src/arm_common/elemwise/ternary/algo.h @@ -34,9 +34,9 @@ namespace arm_common { DECL_CB(VecVecVec); DECL_CB(VecVecScalar); DECL_CB(Bcast101VecBcast101); -DECL_CB(Bcast101x4VecBcast101x4); +DECL_CB(Bcast101xXVecBcast101xX); DECL_CB(VecBcast101Vec); -DECL_CB(VecBcast101x4Vec); +DECL_CB(VecBcast101xXVec); DECL_CB(VecScalarVec); DECL_CB(VecScalarScalar); #undef DECL_CB diff --git a/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp b/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp index c975996c2..96526a642 100644 --- a/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp +++ b/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp @@ -644,7 +644,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param, { BroadcastChannelInfo binfo; if (is_vector(src0.layout) && - is_broadcastedx_channel_like<4>(src1.layout, binfo)) { + (is_broadcastedx_channel_like<4>(src1.layout, binfo) || + is_broadcastedx_channel_like<8>(src1.layout, binfo))) { #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ case _mode: { \ using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ @@ -653,14 +654,13 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param, DType, DType, DType, size_t, size_t, size_t, \ size_t)> \ run = OpCallerBinary<_op, \ - VEC_BCAST101x4>::run; \ + VEC_BCAST101xX>::run; \ MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ src0.ptr(), src1.ptr(), \ dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ return; \ } - size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); DISPATCH() @@ -679,14 +679,13 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param, DType, DType, DType, size_t, size_t, size_t, \ size_t)> \ run = OpCallerBinary<_op, \ - BCAST101x4_VEC>::run; \ + BCAST101xX_VEC>::run; \ MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ src0.ptr(), src1.ptr(), \ dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ return; \ } - size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); DISPATCH() @@ -818,7 +817,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param, { BroadcastChannelInfo binfo; if (is_vector(src0.layout) && - is_broadcastedx_channel_like<4>(src1.layout, binfo) && + (is_broadcastedx_channel_like<4>(src1.layout, binfo) || + is_broadcastedx_channel_like<8>(src1.layout, binfo)) && src0.layout.eq_shape(src2.layout)) { #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ case _mode: { \ @@ -828,7 +828,7 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param, const src_ctype*, dst_ctype*, DType, DType, DType, \ DType, size_t, size_t, size_t, size_t)> \ run = OpCallerTernary<_op, \ - VEC_BCAST101x4_VEC>::run; \ + VEC_BCAST101xX_VEC>::run; \ MEGDNN_DISPATCH_CPU_KERN_OPR( \ run(src0.ptr(), src1.ptr(), \ src2.ptr(), dst.ptr(), \ @@ -846,7 +846,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param, //! BCAST101x + VEC +BCAST101x if (is_vector(src1.layout) && - is_broadcastedx_channel_like<4>(src0.layout, binfo) && + (is_broadcastedx_channel_like<4>(src0.layout, binfo) || + is_broadcastedx_channel_like<8>(src0.layout, binfo)) && src0.layout.eq_shape(src2.layout)) { #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ case _mode: { \ @@ -856,7 +857,7 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param, const src_ctype*, dst_ctype*, DType, DType, DType, \ DType, size_t, size_t, size_t, size_t)> \ run = OpCallerTernary<_op, \ - BCAST101x4_VEC_BCAST101x4>::run; \ + BCAST101xX_VEC_BCAST101xX>::run; \ MEGDNN_DISPATCH_CPU_KERN_OPR( \ run(src0.ptr(), src1.ptr(), \ src2.ptr(), dst.ptr(), \ diff --git a/dnn/src/arm_common/elemwise_op.h b/dnn/src/arm_common/elemwise_op.h index 5f10fbb81..bc8d2373e 100644 --- a/dnn/src/arm_common/elemwise_op.h +++ b/dnn/src/arm_common/elemwise_op.h @@ -89,6 +89,21 @@ cb(dt_float32, float32_t, float32x4_t, f32); cb(dt_int32, int32_t, int32x4_t, s32); #undef cb +template +struct ParamElemVisitorBcast101x8; +#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ + template <> \ + struct ParamElemVisitorBcast101x8<_ctype> { \ + _neon_type operator()(const _ctype* src) const { \ + return vld1q_##_fun_suffix( \ + reinterpret_cast(src)); \ + } \ + } +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +cb(__fp16, __fp16, float16x8_t, f16); +#endif +#undef cb + /*! * \brief broadcast type * BCAST_x[0]x[1]...: x[i] == !stride[i] @@ -97,17 +112,17 @@ enum BcastType { VEC, VEC_VEC, VEC_BCAST101, - VEC_BCAST101x4, + VEC_BCAST101xX, VEC_SCALAR, SCALAR_VEC, BCAST101_VEC, - BCAST101x4_VEC, + BCAST101xX_VEC, VEC_VEC_VEC, VEC_VEC_SCALAR, BCAST101_VEC_BCAST101, - BCAST101x4_VEC_BCAST101x4, + BCAST101xX_VEC_BCAST101xX, VEC_BCAST101_VEC, - VEC_BCAST101x4_VEC, + VEC_BCAST101xX_VEC, VEC_SCALAR_VEC, VEC_SCALAR_SCALAR, UNKNOWN_BCAST_TYPE @@ -334,7 +349,7 @@ struct OpCallerBinary { }; template -struct OpCallerBinary, BCAST101x4_VEC> { +struct OpCallerBinary, BCAST101xX_VEC> { using Op = PowOp; static void run(const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, @@ -360,18 +375,37 @@ struct OpCallerBinary, BCAST101x4_VEC> { } }; -template -struct OpCallerBinary { - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType dst_dtype, size_t batch, - size_t nr_channel_blocks, size_t channel_stride, - size_t channel_block_dim) { - megdnn_assert(channel_block_dim == 4, "only imp for nchw44"); - Op op(src0_dtype, src1_dtype, dst_dtype); - ParamElemVisitorBcast101x4 vis0; - ParamElemVisitor vis1; +template +struct OpCallerBinaryBcast101xXVec { + template + static void run(const src_ctype* src0, const src_ctype* src1, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src0_block_ptr = src0_ptr + cb * channel_block_dim; + for (size_t img_index = 0; img_index < channel_stride; + img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; + c_iter++) { + op(*(src0_block_ptr + c_iter), *src1, dst); + src1++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerBinaryBcast101xDVec { + template + static void run(const src_ctype* src0, const src_ctype* src1, + typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0, + const Vis1& vis1, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { for (size_t b = 0; b < batch; b++) { auto src0_ptr = src0; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { @@ -400,8 +434,63 @@ struct OpCallerBinary { } }; +template +struct OpCallerBinaryBcast101xXVec { + template + static void run(const src_ctype* src0, const src_ctype* src1, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { + ParamElemVisitorBcast101x4 vis0; + ParamElemVisitor vis1; + OpCallerBinaryBcast101xDVec::run( + src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, + channel_stride); + } +}; + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> +struct OpCallerBinaryBcast101xXVec<__fp16, 8> { + using src_ctype = __fp16; + + template + static void run(const src_ctype* src0, const src_ctype* src1, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { + ParamElemVisitorBcast101x8 vis0; + ParamElemVisitor vis1; + OpCallerBinaryBcast101xDVec::run( + src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, + channel_stride); + } +}; +#endif + +template +struct OpCallerBinary { + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType dst_dtype, size_t batch, + size_t nr_channel_blocks, size_t channel_stride, + size_t channel_block_dim) { + megdnn_assert(channel_block_dim == 4 || channel_block_dim == 8, + "only imp for nchw44/nchw88"); + Op op(src0_dtype, src1_dtype, dst_dtype); + if (channel_block_dim == 4) { + OpCallerBinaryBcast101xXVec::run( + src0, src1, dst, op, batch, nr_channel_blocks, + channel_stride); + } else { + OpCallerBinaryBcast101xXVec::run( + src0, src1, dst, op, batch, nr_channel_blocks, + channel_stride); + } + } +}; + template -struct OpCallerBinary, VEC_BCAST101x4> { +struct OpCallerBinary, VEC_BCAST101xX> { using Op = PowOp; static void run(const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, @@ -427,18 +516,37 @@ struct OpCallerBinary, VEC_BCAST101x4> { } }; -template -struct OpCallerBinary { - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType dst_dtype, size_t batch, - size_t nr_channel_blocks, size_t channel_stride, - size_t channel_block_dim) { - megdnn_assert(channel_block_dim == 4, "only imp for nchw44"); - Op op(src0_dtype, src1_dtype, dst_dtype); - ParamElemVisitor vis0; - ParamElemVisitorBcast101x4 vis1; +template +struct OpCallerBinaryVecBcast101xX { + template + static void run(const src_ctype* src0, const src_ctype* src1, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { + for (size_t b = 0; b < batch; b++) { + auto src1_ptr = src1; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src1_block_ptr = src1_ptr + cb * channel_block_dim; + for (size_t img_index = 0; img_index < channel_stride; + img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; + c_iter++) { + op(*src0, *(src1_block_ptr + c_iter), dst); + src0++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerBinaryVecBcast101xD { + template + static void run(const src_ctype* src0, const src_ctype* src1, + typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0, + const Vis1& vis1, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { for (size_t b = 0; b < batch; b++) { auto src1_ptr = src1; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { @@ -467,6 +575,60 @@ struct OpCallerBinary { } }; +template +struct OpCallerBinaryVecBcast101xX { + template + static void run(const src_ctype* src0, const src_ctype* src1, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { + ParamElemVisitor vis0; + ParamElemVisitorBcast101x4 vis1; + OpCallerBinaryVecBcast101xD::run( + src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, + channel_stride); + } +}; + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> +struct OpCallerBinaryVecBcast101xX<__fp16, 8> { + using src_ctype = __fp16; + template + static void run(const src_ctype* src0, const src_ctype* src1, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { + ParamElemVisitor vis0; + ParamElemVisitorBcast101x8 vis1; + OpCallerBinaryVecBcast101xD::run( + src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, + channel_stride); + } +}; +#endif + +template +struct OpCallerBinary { + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType dst_dtype, size_t batch, + size_t nr_channel_blocks, size_t channel_stride, + size_t channel_block_dim) { + megdnn_assert(channel_block_dim == 4 || channel_block_dim == 8, + "only imp for nchw44/nchw88"); + Op op(src0_dtype, src1_dtype, dst_dtype); + if (channel_block_dim == 4) { + OpCallerBinaryVecBcast101xX::run( + src0, src1, dst, op, batch, nr_channel_blocks, + channel_stride); + } else { + OpCallerBinaryVecBcast101xX::run( + src0, src1, dst, op, batch, nr_channel_blocks, + channel_stride); + } + } +}; + template struct OpCallerBinary { static void run(const typename Op::src_ctype* src0, @@ -683,21 +845,42 @@ struct OpCallerTernary { } }; -//! src0: CHW44, src1: vector, src2: CHW44 -template -struct OpCallerTernary { - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype* src1, - const typename Op::src_ctype* src2, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t batch, size_t nr_channel_blocks, - size_t channel_stride, size_t channel_block_dim) { - megdnn_assert(channel_block_dim == 4, "only imp for nchw44"); - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - ParamElemVisitorBcast101x4 vis0; - ParamElemVisitor vis1; - ParamElemVisitorBcast101x4 vis2; +template +struct OpCallerTernaryBcast101xXVecBcast101xX { + template + static void run(const src_ctype* src0, const src_ctype* src1, + const src_ctype* src2, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + auto src2_ptr = src2; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src0_block_ptr = src0_ptr + cb * channel_block_dim; + auto src2_block_ptr = src2_ptr + cb * channel_block_dim; + for (size_t img_index = 0; img_index < channel_stride; + img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; + c_iter++) { + op(*(src0_block_ptr + c_iter), *src1, + *(src2_block_ptr + c_iter), dst); + src1++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerTernaryBcast101xDVecBcast101xD { + template + static void run(const src_ctype* src0, const src_ctype* src1, + const src_ctype* src2, typename Op::dst_ctype* dst, + const Op& op, const Vis0& vis0, const Vis1& vis1, + const Vis2& vis2, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { for (size_t b = 0; b < batch; b++) { auto src0_ptr = src0; auto src2_ptr = src2; @@ -731,6 +914,70 @@ struct OpCallerTernary { } }; +//! src0: CHW44, src1: vector, src2: CHW44 +template +struct OpCallerTernaryBcast101xXVecBcast101xX { + template + static void run(const src_ctype* src0, const src_ctype* src1, + const src_ctype* src2, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { + ParamElemVisitorBcast101x4 vis0; + ParamElemVisitor vis1; + ParamElemVisitorBcast101x4 vis2; + OpCallerTernaryBcast101xDVecBcast101xD::run( + src0, src1, src2, dst, op, vis0, vis1, vis2, batch, + nr_channel_blocks, channel_stride); + } +}; + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> +struct OpCallerTernaryBcast101xXVecBcast101xX<__fp16, 8> { + using src_ctype = __fp16; + template + static void run(const src_ctype* src0, const src_ctype* src1, + const src_ctype* src2, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { + ParamElemVisitorBcast101x8 vis0; + ParamElemVisitor vis1; + ParamElemVisitorBcast101x8 vis2; + OpCallerTernaryBcast101xDVecBcast101xD::run( + src0, src1, src2, dst, op, vis0, vis1, vis2, batch, + nr_channel_blocks, channel_stride); + } +}; +#endif + +template +struct OpCallerTernary { + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t batch, size_t nr_channel_blocks, + size_t channel_stride, size_t channel_block_dim) { + megdnn_assert(channel_block_dim == 4 || channel_block_dim == 8, + "only imp for nchw44/nchw88"); + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + if (channel_block_dim == 4) { + OpCallerTernaryBcast101xXVecBcast101xX::run(src0, src1, src2, + dst, op, batch, + nr_channel_blocks, + channel_stride); + } else { + OpCallerTernaryBcast101xXVecBcast101xX::run(src0, src1, src2, + dst, op, batch, + nr_channel_blocks, + channel_stride); + } + } +}; + //! src1: 1C11, src0 and src2 are contig template struct OpCallerTernary { @@ -775,21 +1022,41 @@ struct OpCallerTernary { } }; +template +struct OpCallerTernaryVecBcast101xXVec { + template + static void run(const src_ctype* src0, const src_ctype* src1, + const src_ctype* src2, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { + for (size_t b = 0; b < batch; b++) { + auto src1_ptr = src1; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src1_block_ptr = src1_ptr + cb * channel_block_dim; + for (size_t img_index = 0; img_index < channel_stride; + img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; + c_iter++) { + op(*src0, *(src1_block_ptr + c_iter), *src2, dst); + src0++; + src2++; + dst++; + } + } + } + } + } +}; + //! src1: CHW44, src0 and src2 are contig -template -struct OpCallerTernary { - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype* src1, - const typename Op::src_ctype* src2, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t batch, size_t nr_channel_blocks, - size_t channel_stride, size_t channel_block_dim) { - megdnn_assert(channel_block_dim == 4, "only imp for nchw44"); - Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); - ParamElemVisitor vis0; - ParamElemVisitorBcast101x4 vis1; - ParamElemVisitor vis2; +template +struct OpCallerTernaryVecBcast101xDVec { + template + static void run(const src_ctype* src0, const src_ctype* src1, + const src_ctype* src2, typename Op::dst_ctype* dst, + const Op& op, const Vis0& vis0, const Vis1& vis1, + const Vis2& vis2, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { for (size_t b = 0; b < batch; b++) { auto src1_ptr = src1; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { @@ -821,6 +1088,66 @@ struct OpCallerTernary { } }; +template +struct OpCallerTernaryVecBcast101xXVec { + template + static void run(const src_ctype* src0, const src_ctype* src1, + const src_ctype* src2, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { + ParamElemVisitor vis0; + ParamElemVisitorBcast101x4 vis1; + ParamElemVisitor vis2; + OpCallerTernaryVecBcast101xDVec::run( + src0, src1, src2, dst, op, vis0, vis1, vis2, batch, + nr_channel_blocks, channel_stride); + } +}; + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> +struct OpCallerTernaryVecBcast101xXVec<__fp16, 8> { + using src_ctype = __fp16; + template + static void run(const src_ctype* src0, const src_ctype* src1, + const src_ctype* src2, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { + ParamElemVisitor vis0; + ParamElemVisitorBcast101x8 vis1; + ParamElemVisitor vis2; + OpCallerTernaryVecBcast101xDVec::run( + src0, src1, src2, dst, op, vis0, vis1, vis2, batch, + nr_channel_blocks, channel_stride); + } +}; +#endif + +template +struct OpCallerTernary { + static void run(const typename Op::src_ctype* src0, + const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, + typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t batch, size_t nr_channel_blocks, + size_t channel_stride, size_t channel_block_dim) { + megdnn_assert(channel_block_dim == 4 || channel_block_dim == 8, + "only imp for nchw44/nchw88"); + + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + if (channel_block_dim == 4) { + OpCallerTernaryVecBcast101xXVec::run( + src0, src1, src2, dst, op, batch, nr_channel_blocks, + channel_stride); + } else { + OpCallerTernaryVecBcast101xXVec::run( + src0, src1, src2, dst, op, batch, nr_channel_blocks, + channel_stride); + } + } +}; + //! src1: scalar, src0 and src2 has the same shape template struct OpCallerTernary { diff --git a/dnn/test/arm_common/elemwise.cpp b/dnn/test/arm_common/elemwise.cpp index 6f31d389f..da415471d 100644 --- a/dnn/test/arm_common/elemwise.cpp +++ b/dnn/test/arm_common/elemwise.cpp @@ -53,6 +53,20 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) { checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); + //! nchw88 + checker.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); + checker.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); + checker.execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}}); + checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); + checker.execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}}); + + //! nchw88 + checker.execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}}); + checker.execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}}); + checker.execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}}); + checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); + checker.execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}}); + checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); checker.execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); @@ -227,6 +241,78 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_FP32) { run(Mode::POW); } +TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW88_FP) { + using Mode = ElemwiseForward::Param::Mode; + Checker checker(handle()); + + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}}); + + auto run = [&](Mode mode) { + // VEC_BCAST101x + checker.set_param(mode).execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); + checker.set_param(mode).execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); + checker.set_param(mode).execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}}); + checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); + checker.set_param(mode).execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}}); + // BCAST101x_VEC not powOp + checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}}); + checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}}); + checker.set_param(mode).execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}}); + checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); + checker.set_param(mode).execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}}); + }; + auto run_all = [&]() { + run(Mode::ADD); + run(Mode::FUSE_ADD_H_SWISH); + run(Mode::FUSE_ADD_RELU); + run(Mode::MAX); + run(Mode::MIN); + run(Mode::MUL); + run(Mode::SUB); + run(Mode::TRUE_DIV); + run(Mode::POW); + }; + + { + UniformFloatRNG rng(1e-5, 7e1); + checker.set_rng(0, &rng); + checker.set_epsilon(1e-5); + checker.set_dtype(0, dtype::Float32()); + checker.set_dtype(1, dtype::Float32()); + run_all(); + } + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + { + UniformFloatRNG rng(1, 2); + checker.set_rng(0, &rng); + checker.set_epsilon(3e-3); + checker.set_dtype(0, dtype::Float16()); + checker.set_dtype(1, dtype::Float16()); + run_all(); + } +#endif +} + #if MEGDNN_WITH_BENCHMARK namespace { void run_elemwise_benchmark(const TensorShapeArray& shapes, -- GitLab