提交 c48d58da 编写于 作者: M Megvii Engine Team

feat(dnn/arm_common): add N1HW like elemwise broadcast mode

GitOrigin-RevId: 28951358012c2d085f68260fd723797f943138ca
上级 669c3cda
...@@ -104,6 +104,21 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available( ...@@ -104,6 +104,21 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available(
return false; return false;
} }
bool ElemwiseImpl::AlgoBinaryVecBcastX0X::is_available(
const KernParam& kern_param) const {
if (!is_available_common(kern_param.mode) ||
((BcastType::VEC_BCASTX0X != kern_param.broad_cast_type) &&
(BcastType::BCASTX0X_VEC != kern_param.broad_cast_type)))
return false;
auto& elparam = kern_param.binary_elparam;
auto& src0 = elparam[0];
DISPATCH_TYPE("AlgoBinaryVecBcastX0X::is_available"_hash);
return false;
}
bool ElemwiseImpl::AlgoBinaryVecBcast111C::is_available( bool ElemwiseImpl::AlgoBinaryVecBcast111C::is_available(
const KernParam& kern_param) const { const KernParam& kern_param) const {
if (!is_available_common(kern_param.mode) || if (!is_available_common(kern_param.mode) ||
...@@ -348,6 +363,72 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec(const KernParam& kern_param) cons ...@@ -348,6 +363,72 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec(const KernParam& kern_param) cons
return; return;
} }
void ElemwiseImpl::AlgoBinaryVecBcastX0X::exec(const KernParam& kern_param) const {
auto& elparam = kern_param.binary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1];
auto&& dst = *(kern_param.m_dst);
BroadcastChannelInfo binfo;
// Case: BcastType::VEC + BCAST_X0X
if (BcastType::VEC_BCASTX0X == kern_param.broad_cast_type &&
is_broadcasted_3dim_like(src1.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_binary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, _type*, DType, DType, DType, size_t, \
size_t, size_t)> \
run = OpCallerBinary< \
_op<_type, _type>, BcastType::VEC_BCASTX0X>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \
static_cast<const _type*>(src1.raw_ptr), \
static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
binfo.z)); \
} \
MIDOUT_END(); \
return
DISPATCH_TYPE("AlgoBinaryVecBcastX0X::exec_vec_b"_hash);
#undef DISPATCH_BINARY
}
// BCAST_X0X + BcastType::VEC
if (BcastType::BCASTX0X_VEC == kern_param.broad_cast_type &&
is_broadcasted_3dim_like(src0.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_binary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, _type*, DType, DType, DType, size_t, \
size_t, size_t)> \
run = OpCallerBinary< \
_op<_type, _type>, BcastType::BCASTX0X_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \
static_cast<const _type*>(src1.raw_ptr), \
static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
binfo.z)); \
} \
MIDOUT_END(); \
return
DISPATCH_TYPE("AlgoBinaryVecBcastX0X::exec_b_vec"_hash);
#undef DISPATCH_BINARY
}
return;
}
void ElemwiseImpl::AlgoBinaryVecBcast111C::exec(const KernParam& kern_param) const { void ElemwiseImpl::AlgoBinaryVecBcast111C::exec(const KernParam& kern_param) const {
auto& elparam = kern_param.binary_elparam; auto& elparam = kern_param.binary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1]; auto &src0 = elparam[0], &src1 = elparam[1];
......
...@@ -33,6 +33,7 @@ namespace arm_common { ...@@ -33,6 +33,7 @@ namespace arm_common {
DECL_CB(VecVec); DECL_CB(VecVec);
DECL_CB(VecScalar); DECL_CB(VecScalar);
DECL_CB(VecBcast101); DECL_CB(VecBcast101);
DECL_CB(VecBcastX0X);
DECL_CB(VecBcast111C); DECL_CB(VecBcast111C);
DECL_CB(VecBcast101xX); DECL_CB(VecBcast101xX);
#undef DECL_CB #undef DECL_CB
......
...@@ -27,6 +27,7 @@ class ElemwiseImpl::AlgoPack { ...@@ -27,6 +27,7 @@ class ElemwiseImpl::AlgoPack {
AlgoBinaryVecVec algo_binary_vec_vec; AlgoBinaryVecVec algo_binary_vec_vec;
AlgoBinaryVecScalar algo_binary_vec_sca; AlgoBinaryVecScalar algo_binary_vec_sca;
AlgoBinaryVecBcast101 algo_binary_vec_bcast101; AlgoBinaryVecBcast101 algo_binary_vec_bcast101;
AlgoBinaryVecBcastX0X algo_binary_vec_bcastX0X;
AlgoBinaryVecBcast111C algo_binary_vec_bcast110; AlgoBinaryVecBcast111C algo_binary_vec_bcast110;
AlgoBinaryVecBcast101xX algo_binary_VEC_BCAST101xX; AlgoBinaryVecBcast101xX algo_binary_VEC_BCAST101xX;
AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec; AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec;
...@@ -46,6 +47,7 @@ public: ...@@ -46,6 +47,7 @@ public:
all_algos.emplace_back(&algo_binary_vec_vec); all_algos.emplace_back(&algo_binary_vec_vec);
all_algos.emplace_back(&algo_binary_vec_sca); all_algos.emplace_back(&algo_binary_vec_sca);
all_algos.emplace_back(&algo_binary_vec_bcast101); all_algos.emplace_back(&algo_binary_vec_bcast101);
all_algos.emplace_back(&algo_binary_vec_bcastX0X);
all_algos.emplace_back(&algo_binary_vec_bcast110); all_algos.emplace_back(&algo_binary_vec_bcast110);
all_algos.emplace_back(&algo_binary_VEC_BCAST101xX); all_algos.emplace_back(&algo_binary_VEC_BCAST101xX);
all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec); all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec);
...@@ -202,6 +204,16 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { ...@@ -202,6 +204,16 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
return kern_param; return kern_param;
} }
if (is_vector(src0.layout) && is_broadcasted_3dim_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCASTX0X;
return kern_param;
}
if (is_vector(src1.layout) && is_broadcasted_3dim_like(src0.layout, binfo)) {
kern_param.broad_cast_type = BcastType::BCASTX0X_VEC;
return kern_param;
}
if (is_legal_layout_for_nhwc(src1.layout) && if (is_legal_layout_for_nhwc(src1.layout) &&
is_NHWC_broadcasted_channel_like(src0.layout, binfo)) { is_NHWC_broadcasted_channel_like(src0.layout, binfo)) {
kern_param.broad_cast_type = BcastType::BCAST111C_VEC; kern_param.broad_cast_type = BcastType::BCAST111C_VEC;
......
...@@ -38,6 +38,7 @@ private: ...@@ -38,6 +38,7 @@ private:
class AlgoBinaryVecVec; class AlgoBinaryVecVec;
class AlgoBinaryVecScalar; class AlgoBinaryVecScalar;
class AlgoBinaryVecBcast101; class AlgoBinaryVecBcast101;
class AlgoBinaryVecBcastX0X;
class AlgoBinaryVecBcast111C; class AlgoBinaryVecBcast111C;
class AlgoBinaryVecBcast101xX; class AlgoBinaryVecBcast101xX;
class AlgoTernaryFma3VecVecVec; class AlgoTernaryFma3VecVecVec;
......
...@@ -107,11 +107,13 @@ enum BcastType { ...@@ -107,11 +107,13 @@ enum BcastType {
VEC, VEC,
VEC_VEC, VEC_VEC,
VEC_BCAST101, VEC_BCAST101,
VEC_BCASTX0X,
VEC_BCAST111C, VEC_BCAST111C,
VEC_BCAST101xX, VEC_BCAST101xX,
VEC_SCALAR, VEC_SCALAR,
SCALAR_VEC, SCALAR_VEC,
BCAST101_VEC, BCAST101_VEC,
BCASTX0X_VEC,
BCAST111C_VEC, BCAST111C_VEC,
BCAST101xX_VEC, BCAST101xX_VEC,
VEC_VEC_VEC, VEC_VEC_VEC,
...@@ -230,6 +232,34 @@ struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST101> { ...@@ -230,6 +232,34 @@ struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST101> {
} }
}; };
template <typename ctype>
struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCASTX0X> {
using Op = PowOp<ctype, ctype>;
static void run(
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
Op op(src0_dtype, src1_dtype, dst_dtype);
for (size_t b = 0; b < batch; b++) {
const typename Op::src_ctype* src1_ptr_base = src1 + b * channel_stride;
for (size_t c = 0; c < channel; c++) {
size_t i = 0;
auto src1_ptr = src1_ptr_base;
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
for (; i < channel_stride; i++) {
op(*src0, *src1_ptr, dst);
src0++;
src1_ptr++;
dst++;
}
}
}
}
};
template <typename ctype> template <typename ctype>
struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST111C> { struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST111C> {
using Op = PowOp<ctype, ctype>; using Op = PowOp<ctype, ctype>;
...@@ -332,6 +362,34 @@ struct OpCallerBinary<PowOp<ctype, ctype>, BCAST101_VEC> { ...@@ -332,6 +362,34 @@ struct OpCallerBinary<PowOp<ctype, ctype>, BCAST101_VEC> {
} }
}; };
template <typename ctype>
struct OpCallerBinary<PowOp<ctype, ctype>, BCASTX0X_VEC> {
using Op = PowOp<ctype, ctype>;
static void run(
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
Op op(src0_dtype, src1_dtype, dst_dtype);
for (size_t b = 0; b < batch; b++) {
auto src0_ptr_base = src0 + b * channel_stride;
for (size_t c = 0; c < channel; c++) {
size_t i = 0;
auto src0_ptr = src0_ptr_base;
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
for (; i < channel_stride; i++) {
op(*src0_ptr, *src1, dst);
src0_ptr++;
src1++;
dst++;
}
}
}
}
};
template <typename Op> template <typename Op>
struct OpCallerBinary<Op, VEC_VEC> { struct OpCallerBinary<Op, VEC_VEC> {
static void run( static void run(
...@@ -398,6 +456,45 @@ struct OpCallerBinary<Op, VEC_BCAST101> { ...@@ -398,6 +456,45 @@ struct OpCallerBinary<Op, VEC_BCAST101> {
} }
}; };
template <typename Op>
struct OpCallerBinary<Op, VEC_BCASTX0X> {
static void run(
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
Op op(src0_dtype, src1_dtype, dst_dtype);
ParamElemVisitor<typename Op::src_ctype> vis;
for (size_t b = 0; b < batch; b++) {
const typename Op::src_ctype* src1_ptr_base = src1 + b * channel_stride;
for (size_t c = 0; c < channel; c++) {
size_t i = 0;
auto src1_ptr = src1_ptr_base;
for (; i + Op::SIMD_WIDTH * 2 <= channel_stride;
i += Op::SIMD_WIDTH * 2) {
auto src0_neon0 = vis(src0);
auto src0_neon1 = vis(src0 + Op::SIMD_WIDTH);
auto src1_neon0 = vis(src1_ptr);
auto src1_neon1 = vis(src1_ptr + Op::SIMD_WIDTH);
op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst);
src0 += Op::SIMD_WIDTH * 2;
src1_ptr += Op::SIMD_WIDTH * 2;
dst += Op::SIMD_WIDTH * 2;
}
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
for (; i < channel_stride; i++) {
op(*src0, *src1_ptr, dst);
src0++;
src1_ptr++;
dst++;
}
}
}
}
};
template <typename Op> template <typename Op>
struct OpCallerBinary<Op, VEC_BCAST111C> { struct OpCallerBinary<Op, VEC_BCAST111C> {
static void run( static void run(
...@@ -844,6 +941,45 @@ struct OpCallerBinary<Op, BCAST101_VEC> { ...@@ -844,6 +941,45 @@ struct OpCallerBinary<Op, BCAST101_VEC> {
} }
}; };
template <typename Op>
struct OpCallerBinary<Op, BCASTX0X_VEC> {
static void run(
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
Op op(src0_dtype, src1_dtype, dst_dtype);
ParamElemVisitor<typename Op::src_ctype> vis;
for (size_t b = 0; b < batch; b++) {
auto src0_ptr_base = src0 + b * channel_stride;
for (size_t c = 0; c < channel; c++) {
auto src0_ptr = src0_ptr_base;
size_t i = 0;
for (; i + Op::SIMD_WIDTH * 2 <= channel_stride;
i += Op::SIMD_WIDTH * 2) {
auto src0_neon0 = vis(src0_ptr);
auto src0_neon1 = vis(src0_ptr + Op::SIMD_WIDTH);
auto src1_neon0 = vis(src1);
auto src1_neon1 = vis(src1 + Op::SIMD_WIDTH);
op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst);
src0_ptr += Op::SIMD_WIDTH * 2;
src1 += Op::SIMD_WIDTH * 2;
dst += Op::SIMD_WIDTH * 2;
}
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
for (; i < channel_stride; i++) {
op(*src0_ptr, *src1, dst);
src0_ptr++;
src1++;
dst++;
}
}
}
}
};
template <typename Op, BcastType bcast_type> template <typename Op, BcastType bcast_type>
struct OpCallerTernary; struct OpCallerTernary;
......
...@@ -150,6 +150,20 @@ bool ElemwiseLayoutHelper::is_broadcasted_channel_like( ...@@ -150,6 +150,20 @@ bool ElemwiseLayoutHelper::is_broadcasted_channel_like(
return false; return false;
} }
bool ElemwiseLayoutHelper::is_broadcasted_3dim_like(
const TensorLayout& layout, BroadcastChannelInfo& info) {
if (layout.format.type() == TensorFormat::Type::DEFAULT) {
if (layout.ndim == 3 && (layout.stride[0] - layout.shape[2]) == 0 &&
layout.stride[1] == 0 && layout.stride[2] == 1) {
info.x = layout.shape[0];
info.y = layout.shape[1];
info.z = layout.shape[2];
return true;
}
}
return false;
}
bool ElemwiseLayoutHelper::is_NHWC_broadcasted_channel_like( bool ElemwiseLayoutHelper::is_NHWC_broadcasted_channel_like(
const TensorLayout& layout, BroadcastChannelInfo& info) { const TensorLayout& layout, BroadcastChannelInfo& info) {
if (layout.format.type() == TensorFormat::Type::DEFAULT) { if (layout.format.type() == TensorFormat::Type::DEFAULT) {
......
...@@ -80,6 +80,14 @@ public: ...@@ -80,6 +80,14 @@ public:
static bool is_broadcasted_channel_like( static bool is_broadcasted_channel_like(
const TensorLayout& layout, BroadcastChannelInfo& info); const TensorLayout& layout, BroadcastChannelInfo& info);
/*!
* \brief check whether layout matches BroadcastChannelInfo like N1HW
*
* Note layout should be [N, 1, H*W] like
*/
static bool is_broadcasted_3dim_like(
const TensorLayout& layout, BroadcastChannelInfo& info);
/*! /*!
* \brief check whether layout matches BroadcastChannelInfo under NHWC * \brief check whether layout matches BroadcastChannelInfo under NHWC
* layout * layout
......
...@@ -356,6 +356,30 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NHWC_FP32_BCAST) { ...@@ -356,6 +356,30 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NHWC_FP32_BCAST) {
run_3d_incontig(Mode::FUSE_MUL_ADD3); run_3d_incontig(Mode::FUSE_MUL_ADD3);
} }
TEST_F(ARM_COMMON, ELEMWISE_FORWARD_N1HW_FP32_BCAST) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle());
UniformFloatRNG rng(1e-5, 7e1);
checker.set_rng(0, &rng);
checker.set_epsilon(1e-5);
checker.set_dtype(0, dtype::Float32());
checker.set_dtype(1, dtype::Float32());
//! 2 dim
auto run = [&](Mode mode) {
// VEC_BCASTX0X
checker.set_param(mode).execs({{2, 8, 4, 4}, {2, 1, 4, 4}, {}});
checker.set_param(mode).execs({{4, 21, 78}, {4, 1, 78}, {}});
// BCASTX0X_VEC
checker.set_param(mode).execs({{2, 1, 4, 4}, {2, 8, 4, 4}, {}});
checker.set_param(mode).execs({{4, 1, 78}, {4, 21, 78}, {}});
};
run(Mode::ADD);
run(Mode::MUL);
run(Mode::SUB);
}
#if MEGDNN_WITH_BENCHMARK #if MEGDNN_WITH_BENCHMARK
namespace { namespace {
void run_elemwise_benchmark( void run_elemwise_benchmark(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册