提交 ee0b95e9 编写于 作者: M Megvii Engine Team

feat(dnn/elemwise/arm_common): support part of arm ternary elemwise multithread

BCAST111C_VEC_BCAST111C and BCAST101_VEC_BCAST101

GitOrigin-RevId: 0e26553c90563b6d5f752882749ef049cf9d1d31
上级 7ea104d7
......@@ -144,21 +144,35 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec(
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, const _type*, _type*, DType, DType, \
DType, DType, size_t, size_t, size_t)> \
DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<_type, _type>, BcastType::BCAST101_VEC_BCAST101>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr()), \
static_cast<const _type*>(src2.raw_ptr()), \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
binfo.x, binfo.y, binfo.z)); \
auto kernel = [nr_channels, nr_channels_per_thread, src0, src1, src2, \
binfo, dst, run](size_t task_id, size_t) { \
size_t offset = task_id * nr_channels_per_thread; \
size_t nr_channels_thread = \
std::min(nr_channels - offset, nr_channels_per_thread); \
run(static_cast<const _type*>(src0.raw_ptr()) + offset, \
static_cast<const _type*>(src1.raw_ptr()) + offset * binfo.z, \
static_cast<const _type*>(src2.raw_ptr()) + offset, \
static_cast<_type*>(dst.raw_ptr()) + offset * binfo.z, \
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, binfo.x, nr_channels_thread, binfo.z, \
binfo.y * binfo.z); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), nr_threads, \
kernel); \
} \
MIDOUT_END(); \
return
size_t nr_threads = static_cast<naive::HandleImpl*>(kern_param.handle)
->megcore_dispatcher()
->nr_threads();
size_t nr_channels = binfo.y;
size_t nr_channels_per_thread = (nr_channels + nr_threads - 1) / nr_threads;
auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE("AlgoTernaryFma3Bcast101VecBcast101::exec"_hash);
#undef DISPATCH_TERNARY
......@@ -181,23 +195,39 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast111CVecBcast111C::exec(
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, size_t, const _type*, _type*, DType, \
DType, DType, DType, size_t, size_t, size_t)> \
DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<_type, _type>, \
BcastType::BCAST111C_VEC_BCAST111C>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
auto kernel = [nr_channels, nr_channels_per_thread, src0, src1, src2, \
binfo, dst, run](size_t task_id, size_t) { \
size_t offset = task_id * nr_channels_per_thread; \
size_t nr_channels_thread = \
std::min(nr_channels - offset, nr_channels_per_thread); \
size_t src1_offset = \
is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z; \
run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr()), \
is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z, \
static_cast<const _type*>(src2.raw_ptr()), \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
binfo.x, binfo.y, binfo.z)); \
static_cast<const _type*>(src1.raw_ptr()) + \
offset * (binfo.z + src1_offset), \
src1_offset, static_cast<const _type*>(src2.raw_ptr()), \
static_cast<_type*>(dst.raw_ptr()) + offset * binfo.z, \
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, binfo.x, nr_channels_thread, binfo.z, \
binfo.y * binfo.z); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), nr_threads, \
kernel); \
} \
MIDOUT_END(); \
return
size_t nr_threads = static_cast<naive::HandleImpl*>(kern_param.handle)
->megcore_dispatcher()
->nr_threads();
size_t nr_channels = binfo.y;
size_t nr_channels_per_thread = (nr_channels + nr_threads - 1) / nr_threads;
auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE("AlgoTernaryFma3Bcast111CVecBcast111C::exec"_hash);
#undef DISPATCH_TERNARY
......
......@@ -772,13 +772,14 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \
thin_function<void( \
const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \
DType, DType, DType, DType, size_t, size_t, size_t)> \
DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<src_ctype, dst_ctype>, BCAST101_VEC_BCAST101>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), src2.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
src2.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, binfo.x, \
binfo.y, binfo.z, binfo.y* binfo.z)); \
return; \
}
......
......@@ -1060,7 +1060,8 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> {
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
const typename Op::src_ctype* src2, typename Op::dst_ctype* dst,
DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype,
size_t batch_size, size_t channel_size, size_t channel_stride) {
size_t batch_size, size_t channel_size, size_t channel_stride,
size_t batch_offset) {
Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
ParamElemVisitor<typename Op::src_ctype> vis1;
ParamElemVisitorDup<typename Op::src_ctype> vis0;
......@@ -1068,6 +1069,7 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> {
for (size_t batch = 0; batch < batch_size; batch++) {
auto src0_ptr = src0;
auto src2_ptr = src2;
auto b_offset = batch_offset;
for (size_t channel = 0; channel < channel_size; channel++) {
size_t i = 0;
auto src0_neon = vis0(src0_ptr);
......@@ -1079,6 +1081,7 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> {
{{src2_neon, src2_neon}}, dst);
src1 += Op::SIMD_WIDTH * 2;
dst += Op::SIMD_WIDTH * 2;
b_offset -= Op::SIMD_WIDTH * 2;
}
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
......@@ -1088,10 +1091,13 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> {
op(*src0_ptr, *src1, *src2_ptr, dst);
src1++;
dst++;
b_offset--;
}
src0_ptr++;
src2_ptr++;
}
src1 += b_offset;
dst += b_offset;
}
}
};
......@@ -1104,10 +1110,11 @@ struct OpCallerTernary<Op, BCAST111C_VEC_BCAST111C> {
size_t src1_offset, const typename Op::src_ctype* src2,
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
DType src2_dtype, DType dst_dtype, size_t batch_size, size_t channel_size,
size_t channel_stride) {
size_t channel_stride, size_t batch_offset) {
Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
ParamElemVisitor<typename Op::src_ctype> vis;
for (size_t batch = 0; batch < batch_size; batch++) {
auto b_offset = batch_offset;
for (size_t channel = 0; channel < channel_size; channel++) {
auto src0_ptr = src0;
auto src2_ptr = src2;
......@@ -1126,6 +1133,7 @@ struct OpCallerTernary<Op, BCAST111C_VEC_BCAST111C> {
src1 += Op::SIMD_WIDTH * 2;
src2_ptr += Op::SIMD_WIDTH * 2;
dst += Op::SIMD_WIDTH * 2;
b_offset -= Op::SIMD_WIDTH * 2;
}
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
......@@ -1137,9 +1145,12 @@ struct OpCallerTernary<Op, BCAST111C_VEC_BCAST111C> {
src1++;
src2_ptr++;
dst++;
b_offset--;
}
src1 += src1_offset;
}
src1 += b_offset;
dst += b_offset;
}
}
};
......
......@@ -300,7 +300,7 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW88_FP) {
#endif
}
TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NHWC_FP32_BCAST) {
TEST_F(ARM_COMMON_MULTI_THREADS, ELEMWISE_FORWARD_NHWC_FP32_BCAST) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册