From ee0b95e935513b9ce0fff0d1c80dc278d13b7361 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 14 Jan 2022 14:29:40 +0800 Subject: [PATCH] feat(dnn/elemwise/arm_common): support part of arm ternary elemwise multithread BCAST111C_VEC_BCAST111C and BCAST101_VEC_BCAST101 GitOrigin-RevId: 0e26553c90563b6d5f752882749ef049cf9d1d31 --- dnn/src/arm_common/elemwise/ternary/algo.cpp | 68 +++++++++++++------ .../elemwise_multi_type/opr_impl.cpp | 11 +-- dnn/src/arm_common/elemwise_op.h | 15 +++- dnn/test/arm_common/elemwise.cpp | 2 +- 4 files changed, 69 insertions(+), 27 deletions(-) diff --git a/dnn/src/arm_common/elemwise/ternary/algo.cpp b/dnn/src/arm_common/elemwise/ternary/algo.cpp index 1016f2e2f..db658be10 100644 --- a/dnn/src/arm_common/elemwise/ternary/algo.cpp +++ b/dnn/src/arm_common/elemwise/ternary/algo.cpp @@ -144,21 +144,35 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec( midout_iv(Mode::_mode), _type_midout_id) { \ thin_function \ + DType, DType, size_t, size_t, size_t, size_t)> \ run = OpCallerTernary< \ _op<_type, _type>, BcastType::BCAST101_VEC_BCAST101>::run; \ - MEGDNN_DISPATCH_CPU_KERN( \ - static_cast(kern_param.handle), \ - run(static_cast(src0.raw_ptr()), \ - static_cast(src1.raw_ptr()), \ - static_cast(src2.raw_ptr()), \ - static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ - src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ - binfo.x, binfo.y, binfo.z)); \ + auto kernel = [nr_channels, nr_channels_per_thread, src0, src1, src2, \ + binfo, dst, run](size_t task_id, size_t) { \ + size_t offset = task_id * nr_channels_per_thread; \ + size_t nr_channels_thread = \ + std::min(nr_channels - offset, nr_channels_per_thread); \ + run(static_cast(src0.raw_ptr()) + offset, \ + static_cast(src1.raw_ptr()) + offset * binfo.z, \ + static_cast(src2.raw_ptr()) + offset, \ + static_cast<_type*>(dst.raw_ptr()) + offset * binfo.z, \ + src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ + dst.layout.dtype, binfo.x, nr_channels_thread, binfo.z, \ + binfo.y * binfo.z); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast(kern_param.handle), nr_threads, \ + kernel); \ } \ MIDOUT_END(); \ return + size_t nr_threads = static_cast(kern_param.handle) + ->megcore_dispatcher() + ->nr_threads(); + + size_t nr_channels = binfo.y; + size_t nr_channels_per_thread = (nr_channels + nr_threads - 1) / nr_threads; auto&& dst = *(kern_param.m_dst); DISPATCH_TYPE("AlgoTernaryFma3Bcast101VecBcast101::exec"_hash); #undef DISPATCH_TERNARY @@ -181,23 +195,39 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast111CVecBcast111C::exec( midout_iv(Mode::_mode), _type_midout_id) { \ thin_function \ + DType, DType, DType, size_t, size_t, size_t, size_t)> \ run = OpCallerTernary< \ _op<_type, _type>, \ BcastType::BCAST111C_VEC_BCAST111C>::run; \ - MEGDNN_DISPATCH_CPU_KERN( \ - static_cast(kern_param.handle), \ - run(static_cast(src0.raw_ptr()), \ - static_cast(src1.raw_ptr()), \ - is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z, \ - static_cast(src2.raw_ptr()), \ - static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ - src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ - binfo.x, binfo.y, binfo.z)); \ + auto kernel = [nr_channels, nr_channels_per_thread, src0, src1, src2, \ + binfo, dst, run](size_t task_id, size_t) { \ + size_t offset = task_id * nr_channels_per_thread; \ + size_t nr_channels_thread = \ + std::min(nr_channels - offset, nr_channels_per_thread); \ + size_t src1_offset = \ + is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z; \ + run(static_cast(src0.raw_ptr()), \ + static_cast(src1.raw_ptr()) + \ + offset * (binfo.z + src1_offset), \ + src1_offset, static_cast(src2.raw_ptr()), \ + static_cast<_type*>(dst.raw_ptr()) + offset * binfo.z, \ + src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ + dst.layout.dtype, binfo.x, nr_channels_thread, binfo.z, \ + binfo.y * binfo.z); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast(kern_param.handle), nr_threads, \ + kernel); \ } \ MIDOUT_END(); \ return + size_t nr_threads = static_cast(kern_param.handle) + ->megcore_dispatcher() + ->nr_threads(); + + size_t nr_channels = binfo.y; + size_t nr_channels_per_thread = (nr_channels + nr_threads - 1) / nr_threads; auto&& dst = *(kern_param.m_dst); DISPATCH_TYPE("AlgoTernaryFma3Bcast111CVecBcast111C::exec"_hash); #undef DISPATCH_TERNARY diff --git a/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp b/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp index 69a10cabc..df6cbd76a 100644 --- a/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp +++ b/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp @@ -772,13 +772,14 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ thin_function \ + DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \ run = OpCallerTernary< \ _op, BCAST101_VEC_BCAST101>::run; \ - MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ - src0.ptr(), src1.ptr(), src2.ptr(), \ - dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ - src2.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + run(src0.ptr(), src1.ptr(), \ + src2.ptr(), dst.ptr(), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, binfo.x, \ + binfo.y, binfo.z, binfo.y* binfo.z)); \ return; \ } diff --git a/dnn/src/arm_common/elemwise_op.h b/dnn/src/arm_common/elemwise_op.h index 38f7731bc..4a3333cbc 100644 --- a/dnn/src/arm_common/elemwise_op.h +++ b/dnn/src/arm_common/elemwise_op.h @@ -1060,7 +1060,8 @@ struct OpCallerTernary { const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t batch_size, size_t channel_size, size_t channel_stride) { + size_t batch_size, size_t channel_size, size_t channel_stride, + size_t batch_offset) { Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); ParamElemVisitor vis1; ParamElemVisitorDup vis0; @@ -1068,6 +1069,7 @@ struct OpCallerTernary { for (size_t batch = 0; batch < batch_size; batch++) { auto src0_ptr = src0; auto src2_ptr = src2; + auto b_offset = batch_offset; for (size_t channel = 0; channel < channel_size; channel++) { size_t i = 0; auto src0_neon = vis0(src0_ptr); @@ -1079,6 +1081,7 @@ struct OpCallerTernary { {{src2_neon, src2_neon}}, dst); src1 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; + b_offset -= Op::SIMD_WIDTH * 2; } #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize @@ -1088,10 +1091,13 @@ struct OpCallerTernary { op(*src0_ptr, *src1, *src2_ptr, dst); src1++; dst++; + b_offset--; } src0_ptr++; src2_ptr++; } + src1 += b_offset; + dst += b_offset; } } }; @@ -1104,10 +1110,11 @@ struct OpCallerTernary { size_t src1_offset, const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t batch_size, size_t channel_size, - size_t channel_stride) { + size_t channel_stride, size_t batch_offset) { Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); ParamElemVisitor vis; for (size_t batch = 0; batch < batch_size; batch++) { + auto b_offset = batch_offset; for (size_t channel = 0; channel < channel_size; channel++) { auto src0_ptr = src0; auto src2_ptr = src2; @@ -1126,6 +1133,7 @@ struct OpCallerTernary { src1 += Op::SIMD_WIDTH * 2; src2_ptr += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; + b_offset -= Op::SIMD_WIDTH * 2; } #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize @@ -1137,9 +1145,12 @@ struct OpCallerTernary { src1++; src2_ptr++; dst++; + b_offset--; } src1 += src1_offset; } + src1 += b_offset; + dst += b_offset; } } }; diff --git a/dnn/test/arm_common/elemwise.cpp b/dnn/test/arm_common/elemwise.cpp index 95dba5079..fa222aec6 100644 --- a/dnn/test/arm_common/elemwise.cpp +++ b/dnn/test/arm_common/elemwise.cpp @@ -300,7 +300,7 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW88_FP) { #endif } -TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NHWC_FP32_BCAST) { +TEST_F(ARM_COMMON_MULTI_THREADS, ELEMWISE_FORWARD_NHWC_FP32_BCAST) { using Mode = ElemwiseForward::Param::Mode; Checker checker(handle()); -- GitLab