提交 ee0b95e9 编写于 作者: M Megvii Engine Team

feat(dnn/elemwise/arm_common): support part of arm ternary elemwise multithread

BCAST111C_VEC_BCAST111C and BCAST101_VEC_BCAST101

GitOrigin-RevId: 0e26553c90563b6d5f752882749ef049cf9d1d31
上级 7ea104d7
...@@ -144,21 +144,35 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec( ...@@ -144,21 +144,35 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec(
midout_iv(Mode::_mode), _type_midout_id) { \ midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \ thin_function<void( \
const _type*, const _type*, const _type*, _type*, DType, DType, \ const _type*, const _type*, const _type*, _type*, DType, DType, \
DType, DType, size_t, size_t, size_t)> \ DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \ run = OpCallerTernary< \
_op<_type, _type>, BcastType::BCAST101_VEC_BCAST101>::run; \ _op<_type, _type>, BcastType::BCAST101_VEC_BCAST101>::run; \
MEGDNN_DISPATCH_CPU_KERN( \ auto kernel = [nr_channels, nr_channels_per_thread, src0, src1, src2, \
static_cast<naive::HandleImpl*>(kern_param.handle), \ binfo, dst, run](size_t task_id, size_t) { \
run(static_cast<const _type*>(src0.raw_ptr()), \ size_t offset = task_id * nr_channels_per_thread; \
static_cast<const _type*>(src1.raw_ptr()), \ size_t nr_channels_thread = \
static_cast<const _type*>(src2.raw_ptr()), \ std::min(nr_channels - offset, nr_channels_per_thread); \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ run(static_cast<const _type*>(src0.raw_ptr()) + offset, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ static_cast<const _type*>(src1.raw_ptr()) + offset * binfo.z, \
binfo.x, binfo.y, binfo.z)); \ static_cast<const _type*>(src2.raw_ptr()) + offset, \
static_cast<_type*>(dst.raw_ptr()) + offset * binfo.z, \
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, binfo.x, nr_channels_thread, binfo.z, \
binfo.y * binfo.z); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), nr_threads, \
kernel); \
} \ } \
MIDOUT_END(); \ MIDOUT_END(); \
return return
size_t nr_threads = static_cast<naive::HandleImpl*>(kern_param.handle)
->megcore_dispatcher()
->nr_threads();
size_t nr_channels = binfo.y;
size_t nr_channels_per_thread = (nr_channels + nr_threads - 1) / nr_threads;
auto&& dst = *(kern_param.m_dst); auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE("AlgoTernaryFma3Bcast101VecBcast101::exec"_hash); DISPATCH_TYPE("AlgoTernaryFma3Bcast101VecBcast101::exec"_hash);
#undef DISPATCH_TERNARY #undef DISPATCH_TERNARY
...@@ -181,23 +195,39 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast111CVecBcast111C::exec( ...@@ -181,23 +195,39 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast111CVecBcast111C::exec(
midout_iv(Mode::_mode), _type_midout_id) { \ midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \ thin_function<void( \
const _type*, const _type*, size_t, const _type*, _type*, DType, \ const _type*, const _type*, size_t, const _type*, _type*, DType, \
DType, DType, DType, size_t, size_t, size_t)> \ DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \ run = OpCallerTernary< \
_op<_type, _type>, \ _op<_type, _type>, \
BcastType::BCAST111C_VEC_BCAST111C>::run; \ BcastType::BCAST111C_VEC_BCAST111C>::run; \
MEGDNN_DISPATCH_CPU_KERN( \ auto kernel = [nr_channels, nr_channels_per_thread, src0, src1, src2, \
static_cast<naive::HandleImpl*>(kern_param.handle), \ binfo, dst, run](size_t task_id, size_t) { \
size_t offset = task_id * nr_channels_per_thread; \
size_t nr_channels_thread = \
std::min(nr_channels - offset, nr_channels_per_thread); \
size_t src1_offset = \
is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z; \
run(static_cast<const _type*>(src0.raw_ptr()), \ run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr()), \ static_cast<const _type*>(src1.raw_ptr()) + \
is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z, \ offset * (binfo.z + src1_offset), \
static_cast<const _type*>(src2.raw_ptr()), \ src1_offset, static_cast<const _type*>(src2.raw_ptr()), \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ static_cast<_type*>(dst.raw_ptr()) + offset * binfo.z, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
binfo.x, binfo.y, binfo.z)); \ dst.layout.dtype, binfo.x, nr_channels_thread, binfo.z, \
binfo.y * binfo.z); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), nr_threads, \
kernel); \
} \ } \
MIDOUT_END(); \ MIDOUT_END(); \
return return
size_t nr_threads = static_cast<naive::HandleImpl*>(kern_param.handle)
->megcore_dispatcher()
->nr_threads();
size_t nr_channels = binfo.y;
size_t nr_channels_per_thread = (nr_channels + nr_threads - 1) / nr_threads;
auto&& dst = *(kern_param.m_dst); auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE("AlgoTernaryFma3Bcast111CVecBcast111C::exec"_hash); DISPATCH_TYPE("AlgoTernaryFma3Bcast111CVecBcast111C::exec"_hash);
#undef DISPATCH_TERNARY #undef DISPATCH_TERNARY
......
...@@ -772,13 +772,14 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( ...@@ -772,13 +772,14 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \
thin_function<void( \ thin_function<void( \
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \ const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \
DType, DType, DType, DType, size_t, size_t, size_t)> \ DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \ run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, BCAST101_VEC_BCAST101>::run; \ _op<src_ctype, dst_ctype>, BCAST101_VEC_BCAST101>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), src2.ptr<src_ctype>(), \ run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \
src2.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, binfo.x, \
binfo.y, binfo.z, binfo.y* binfo.z)); \
return; \ return; \
} }
......
...@@ -1060,7 +1060,8 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> { ...@@ -1060,7 +1060,8 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> {
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, const typename Op::src_ctype* src2, typename Op::dst_ctype* dst,
DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype,
size_t batch_size, size_t channel_size, size_t channel_stride) { size_t batch_size, size_t channel_size, size_t channel_stride,
size_t batch_offset) {
Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
ParamElemVisitor<typename Op::src_ctype> vis1; ParamElemVisitor<typename Op::src_ctype> vis1;
ParamElemVisitorDup<typename Op::src_ctype> vis0; ParamElemVisitorDup<typename Op::src_ctype> vis0;
...@@ -1068,6 +1069,7 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> { ...@@ -1068,6 +1069,7 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> {
for (size_t batch = 0; batch < batch_size; batch++) { for (size_t batch = 0; batch < batch_size; batch++) {
auto src0_ptr = src0; auto src0_ptr = src0;
auto src2_ptr = src2; auto src2_ptr = src2;
auto b_offset = batch_offset;
for (size_t channel = 0; channel < channel_size; channel++) { for (size_t channel = 0; channel < channel_size; channel++) {
size_t i = 0; size_t i = 0;
auto src0_neon = vis0(src0_ptr); auto src0_neon = vis0(src0_ptr);
...@@ -1079,6 +1081,7 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> { ...@@ -1079,6 +1081,7 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> {
{{src2_neon, src2_neon}}, dst); {{src2_neon, src2_neon}}, dst);
src1 += Op::SIMD_WIDTH * 2; src1 += Op::SIMD_WIDTH * 2;
dst += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2;
b_offset -= Op::SIMD_WIDTH * 2;
} }
#if MEGDNN_FIX_AARCH32_BUG #if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize // FIXME: as llvm may cause cannot select error if enable vectorize
...@@ -1088,10 +1091,13 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> { ...@@ -1088,10 +1091,13 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> {
op(*src0_ptr, *src1, *src2_ptr, dst); op(*src0_ptr, *src1, *src2_ptr, dst);
src1++; src1++;
dst++; dst++;
b_offset--;
} }
src0_ptr++; src0_ptr++;
src2_ptr++; src2_ptr++;
} }
src1 += b_offset;
dst += b_offset;
} }
} }
}; };
...@@ -1104,10 +1110,11 @@ struct OpCallerTernary<Op, BCAST111C_VEC_BCAST111C> { ...@@ -1104,10 +1110,11 @@ struct OpCallerTernary<Op, BCAST111C_VEC_BCAST111C> {
size_t src1_offset, const typename Op::src_ctype* src2, size_t src1_offset, const typename Op::src_ctype* src2,
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
DType src2_dtype, DType dst_dtype, size_t batch_size, size_t channel_size, DType src2_dtype, DType dst_dtype, size_t batch_size, size_t channel_size,
size_t channel_stride) { size_t channel_stride, size_t batch_offset) {
Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
ParamElemVisitor<typename Op::src_ctype> vis; ParamElemVisitor<typename Op::src_ctype> vis;
for (size_t batch = 0; batch < batch_size; batch++) { for (size_t batch = 0; batch < batch_size; batch++) {
auto b_offset = batch_offset;
for (size_t channel = 0; channel < channel_size; channel++) { for (size_t channel = 0; channel < channel_size; channel++) {
auto src0_ptr = src0; auto src0_ptr = src0;
auto src2_ptr = src2; auto src2_ptr = src2;
...@@ -1126,6 +1133,7 @@ struct OpCallerTernary<Op, BCAST111C_VEC_BCAST111C> { ...@@ -1126,6 +1133,7 @@ struct OpCallerTernary<Op, BCAST111C_VEC_BCAST111C> {
src1 += Op::SIMD_WIDTH * 2; src1 += Op::SIMD_WIDTH * 2;
src2_ptr += Op::SIMD_WIDTH * 2; src2_ptr += Op::SIMD_WIDTH * 2;
dst += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2;
b_offset -= Op::SIMD_WIDTH * 2;
} }
#if MEGDNN_FIX_AARCH32_BUG #if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize // FIXME: as llvm may cause cannot select error if enable vectorize
...@@ -1137,9 +1145,12 @@ struct OpCallerTernary<Op, BCAST111C_VEC_BCAST111C> { ...@@ -1137,9 +1145,12 @@ struct OpCallerTernary<Op, BCAST111C_VEC_BCAST111C> {
src1++; src1++;
src2_ptr++; src2_ptr++;
dst++; dst++;
b_offset--;
} }
src1 += src1_offset; src1 += src1_offset;
} }
src1 += b_offset;
dst += b_offset;
} }
} }
}; };
......
...@@ -300,7 +300,7 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW88_FP) { ...@@ -300,7 +300,7 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW88_FP) {
#endif #endif
} }
TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NHWC_FP32_BCAST) { TEST_F(ARM_COMMON_MULTI_THREADS, ELEMWISE_FORWARD_NHWC_FP32_BCAST) {
using Mode = ElemwiseForward::Param::Mode; using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle()); Checker<ElemwiseForward> checker(handle());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册