提交 5885b137 编写于 作者: M Megvii Engine Team

feat(dnn/arm): support layout like NHWC channel like broadcast on arm

GitOrigin-RevId: fb4300004c4e1920d3cd1be40ca33bde822e4c72
上级 565466c2
......@@ -104,6 +104,21 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available(
return false;
}
bool ElemwiseImpl::AlgoBinaryVecBcast111C::is_available(
const KernParam& kern_param) const {
if (!is_available_common(kern_param.mode) ||
((BcastType::VEC_BCAST111C != kern_param.broad_cast_type) &&
(BcastType::BCAST111C_VEC != kern_param.broad_cast_type)))
return false;
auto& elparam = kern_param.binary_elparam;
auto& src0 = elparam[0];
DISPATCH_TYPE("AlgoBinaryVecBcast111C::is_available"_hash);
return false;
}
bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available(
const KernParam& kern_param) const {
if (!is_available_common(kern_param.mode) ||
......@@ -333,6 +348,72 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec(const KernParam& kern_param) cons
return;
}
void ElemwiseImpl::AlgoBinaryVecBcast111C::exec(const KernParam& kern_param) const {
auto& elparam = kern_param.binary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1];
auto&& dst = *(kern_param.m_dst);
BroadcastChannelInfo binfo;
// Case extra: BcastType::VEC + BCAST_111C
if (BcastType::VEC_BCAST111C == kern_param.broad_cast_type &&
is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_binary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, _type*, DType, DType, DType, size_t, \
size_t, size_t)> \
run = OpCallerBinary< \
_op<_type, _type>, BcastType::VEC_BCAST111C>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \
static_cast<const _type*>(src1.raw_ptr), \
static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
binfo.z)); \
} \
MIDOUT_END(); \
return
DISPATCH_TYPE("AlgoBinaryVecBcast111C::exec_vec_b"_hash);
#undef DISPATCH_BINARY
}
// BCAST_111C + BcastType::VEC
if (BcastType::BCAST111C_VEC == kern_param.broad_cast_type &&
is_NHWC_broadcasted_channel_like(src0.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_binary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, _type*, DType, DType, DType, size_t, \
size_t, size_t)> \
run = OpCallerBinary< \
_op<_type, _type>, BcastType::BCAST111C_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \
static_cast<const _type*>(src1.raw_ptr), \
static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
binfo.z)); \
} \
MIDOUT_END(); \
return
DISPATCH_TYPE("AlgoBinaryVecBcast111C::exec_b_vec"_hash);
#undef DISPATCH_BINARY
}
return;
}
void ElemwiseImpl::AlgoBinaryVecBcast101xX::exec(const KernParam& kern_param) const {
auto& elparam = kern_param.binary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1];
......
......@@ -33,6 +33,7 @@ namespace arm_common {
DECL_CB(VecVec);
DECL_CB(VecScalar);
DECL_CB(VecBcast101);
DECL_CB(VecBcast111C);
DECL_CB(VecBcast101xX);
#undef DECL_CB
} // namespace arm_common
......
......@@ -27,12 +27,15 @@ class ElemwiseImpl::AlgoPack {
AlgoBinaryVecVec algo_binary_vec_vec;
AlgoBinaryVecScalar algo_binary_vec_sca;
AlgoBinaryVecBcast101 algo_binary_vec_bcast101;
AlgoBinaryVecBcast111C algo_binary_vec_bcast110;
AlgoBinaryVecBcast101xX algo_binary_VEC_BCAST101xX;
AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec;
AlgoTernaryFma3VecVecScalar algo_ternaryfma3_vec_vecsca;
AlgoTernaryFma3Bcast101VecBcast101 algo_ternaryfma3_bcast101_vec_bcast101;
AlgoTernaryFma3Bcast111CVecBcast111C algo_ternaryfma3_bcast110_vec_bcast110;
AlgoTernaryFma3Bcast101xXVecBcast101xX algo_ternaryfma3_bcast101xX_vec_bcast101xX;
AlgoTernaryFma3VecBcast101Vec algo_ternaryfma3_vec_bcast101_vec;
AlgoTernaryFma3VecBcast111CVec algo_ternaryfma3_vec_bcast110_vec;
AlgoTernaryFma3VecBcast101xXVec algo_ternaryfma3_vec_bcast101xX_vec;
AlgoTernaryFma3VecScalarVec algo_ternaryfma3_vec_sca_vec;
AlgoTernaryFma3VecScalarScalar algo_ternaryfma3_vec_sca_sca;
......@@ -43,12 +46,15 @@ public:
all_algos.emplace_back(&algo_binary_vec_vec);
all_algos.emplace_back(&algo_binary_vec_sca);
all_algos.emplace_back(&algo_binary_vec_bcast101);
all_algos.emplace_back(&algo_binary_vec_bcast110);
all_algos.emplace_back(&algo_binary_VEC_BCAST101xX);
all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_vecsca);
all_algos.emplace_back(&algo_ternaryfma3_bcast101_vec_bcast101);
all_algos.emplace_back(&algo_ternaryfma3_bcast110_vec_bcast110);
all_algos.emplace_back(&algo_ternaryfma3_bcast101xX_vec_bcast101xX);
all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_bcast110_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101xX_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_sca_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_sca_sca);
......@@ -87,6 +93,14 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
kern_param.mode = opr->param().mode;
kern_param.handle = opr->handle();
auto is_legal_layout_for_nhwc = [](const TensorLayout& l) {
if (is_vector(l))
return true;
if (l.ndim == 2 && l.stride[1] == 1)
return true;
return false;
};
if ((opr->m_src->size() == 3) && (opr->param().mode == Mode::FUSE_MUL_ADD3)) {
kern_param.ternary_elparam = opr->make_elemwise_op_param<3>();
bool c_is_scalar;
......@@ -127,6 +141,20 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
return kern_param;
}
if (is_legal_layout_for_nhwc(src1.layout) &&
is_NHWC_broadcasted_channel_like(src0.layout, binfo) &&
src0.layout.eq_layout(src2.layout)) {
kern_param.broad_cast_type = BcastType::BCAST111C_VEC_BCAST111C;
return kern_param;
}
if (is_legal_layout_for_nhwc(src0.layout) &&
src2.layout.eq_layout(src0.layout) &&
is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCAST111C_VEC;
return kern_param;
}
if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) &&
(is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
......@@ -174,6 +202,18 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
return kern_param;
}
if (is_legal_layout_for_nhwc(src1.layout) &&
is_NHWC_broadcasted_channel_like(src0.layout, binfo)) {
kern_param.broad_cast_type = BcastType::BCAST111C_VEC;
return kern_param;
}
if (is_legal_layout_for_nhwc(src0.layout) &&
is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCAST111C;
return kern_param;
}
if (is_vector(src0.layout) &&
(is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
......
......@@ -38,12 +38,15 @@ private:
class AlgoBinaryVecVec;
class AlgoBinaryVecScalar;
class AlgoBinaryVecBcast101;
class AlgoBinaryVecBcast111C;
class AlgoBinaryVecBcast101xX;
class AlgoTernaryFma3VecVecVec;
class AlgoTernaryFma3VecVecScalar;
class AlgoTernaryFma3Bcast101VecBcast101;
class AlgoTernaryFma3Bcast111CVecBcast111C;
class AlgoTernaryFma3Bcast101xXVecBcast101xX;
class AlgoTernaryFma3VecBcast101Vec;
class AlgoTernaryFma3VecBcast111CVec;
class AlgoTernaryFma3VecBcast101xXVec;
class AlgoTernaryFma3VecScalarVec;
class AlgoTernaryFma3VecScalarScalar;
......
......@@ -42,8 +42,10 @@ using namespace arm_common;
DECL_AVAILABLE(VecVecVec, BcastType::VEC_VEC_VEC);
DECL_AVAILABLE(VecVecScalar, BcastType::VEC_VEC_SCALAR);
DECL_AVAILABLE(Bcast101VecBcast101, BcastType::BCAST101_VEC_BCAST101);
DECL_AVAILABLE(Bcast111CVecBcast111C, BcastType::BCAST111C_VEC_BCAST111C);
DECL_AVAILABLE(Bcast101xXVecBcast101xX, BcastType::BCAST101xX_VEC_BCAST101xX);
DECL_AVAILABLE(VecBcast101Vec, BcastType::VEC_BCAST101_VEC);
DECL_AVAILABLE(VecBcast111CVec, BcastType::VEC_BCAST111C_VEC);
DECL_AVAILABLE(VecBcast101xXVec, BcastType::VEC_BCAST101xX_VEC);
DECL_AVAILABLE(VecScalarVec, BcastType::VEC_SCALAR_VEC);
DECL_AVAILABLE(VecScalarScalar, BcastType::VEC_SCALAR_SCALAR);
......@@ -164,6 +166,45 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec(
return;
}
void ElemwiseImpl::AlgoTernaryFma3Bcast111CVecBcast111C::exec(
const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];
// Case 3: shape of src0 and src2 is {1, 1, 1, C}
BroadcastChannelInfo binfo;
is_NHWC_broadcasted_channel_like(src0.layout, binfo);
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_ternary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, size_t, const _type*, _type*, DType, \
DType, DType, DType, size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<_type, _type>, \
BcastType::BCAST111C_VEC_BCAST111C>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \
static_cast<const _type*>(src1.raw_ptr), \
is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z, \
static_cast<const _type*>(src2.raw_ptr), \
static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
binfo.x, binfo.y, binfo.z)); \
} \
MIDOUT_END(); \
return
auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE("AlgoTernaryFma3Bcast111CVecBcast111C::exec"_hash);
#undef DISPATCH_TERNARY
return;
}
void ElemwiseImpl::AlgoTernaryFma3Bcast101xXVecBcast101xX::exec(
const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam;
......@@ -282,6 +323,45 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec(
return;
}
void ElemwiseImpl::AlgoTernaryFma3VecBcast111CVec::exec(
const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];
// Case 4: shape of src1 is {1, 1, 1, C}, and src0 and src2 are contig
BroadcastChannelInfo binfo;
is_NHWC_broadcasted_channel_like(src1.layout, binfo);
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_ternary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, size_t, const _type*, const _type*, size_t, _type*, \
DType, DType, DType, DType, size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<_type, _type>, BcastType::VEC_BCAST111C_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \
is_vector(src0.layout) ? 0 : src0.layout.stride[0] - binfo.z, \
static_cast<const _type*>(src1.raw_ptr), \
static_cast<const _type*>(src2.raw_ptr), \
is_vector(src2.layout) ? 0 : src2.layout.stride[0] - binfo.z, \
static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
binfo.x, binfo.y, binfo.z)); \
} \
MIDOUT_END(); \
return
auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE("AlgoTernaryFma3VecBcast111CVec::exec"_hash);
#undef DISPATCH_TERNARY
return;
}
void ElemwiseImpl::AlgoTernaryFma3VecScalarVec::exec(
const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam;
......
......@@ -33,8 +33,10 @@ namespace arm_common {
DECL_CB(VecVecVec);
DECL_CB(VecVecScalar);
DECL_CB(Bcast101VecBcast101);
DECL_CB(Bcast111CVecBcast111C);
DECL_CB(Bcast101xXVecBcast101xX);
DECL_CB(VecBcast101Vec);
DECL_CB(VecBcast111CVec);
DECL_CB(VecBcast101xXVec);
DECL_CB(VecScalarVec);
DECL_CB(VecScalarScalar);
......
......@@ -107,16 +107,20 @@ enum BcastType {
VEC,
VEC_VEC,
VEC_BCAST101,
VEC_BCAST111C,
VEC_BCAST101xX,
VEC_SCALAR,
SCALAR_VEC,
BCAST101_VEC,
BCAST111C_VEC,
BCAST101xX_VEC,
VEC_VEC_VEC,
VEC_VEC_SCALAR,
BCAST101_VEC_BCAST101,
BCAST111C_VEC_BCAST111C,
BCAST101xX_VEC_BCAST101xX,
VEC_BCAST101_VEC,
VEC_BCAST111C_VEC,
VEC_BCAST101xX_VEC,
VEC_SCALAR_VEC,
VEC_SCALAR_SCALAR,
......@@ -226,6 +230,60 @@ struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST101> {
}
};
template <typename ctype>
struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST111C> {
using Op = PowOp<ctype, ctype>;
static void run(
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
Op op(src0_dtype, src1_dtype, dst_dtype);
for (size_t b = 0; b < batch; b++) {
for (size_t c = 0; c < channel; c++) {
size_t i = 0;
const typename Op::src_ctype* src1_ptr = src1;
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
for (; i < channel_stride; i++) {
op(*src0, *src1_ptr, dst);
src0++;
src1_ptr++;
dst++;
}
}
}
}
};
template <typename ctype>
struct OpCallerBinary<PowOp<ctype, ctype>, BCAST111C_VEC> {
using Op = PowOp<ctype, ctype>;
static void run(
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
Op op(src0_dtype, src1_dtype, dst_dtype);
for (size_t b = 0; b < batch; b++) {
for (size_t c = 0; c < channel; c++) {
size_t i = 0;
const typename Op::src_ctype* src0_ptr = src0;
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
for (; i < channel_stride; i++) {
op(*src0_ptr, *src1, dst);
src0_ptr++;
src1++;
dst++;
}
}
}
}
};
template <typename ctype>
struct OpCallerBinary<PowOp<ctype, ctype>, SCALAR_VEC> {
using Op = PowOp<ctype, ctype>;
......@@ -340,6 +398,84 @@ struct OpCallerBinary<Op, VEC_BCAST101> {
}
};
template <typename Op>
struct OpCallerBinary<Op, VEC_BCAST111C> {
static void run(
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
Op op(src0_dtype, src1_dtype, dst_dtype);
ParamElemVisitor<typename Op::src_ctype> vis;
for (size_t b = 0; b < batch; b++) {
for (size_t c = 0; c < channel; c++) {
size_t rest = channel_stride;
const typename Op::src_ctype* src1_ptr = src1;
while (rest >= Op::SIMD_WIDTH * 2) {
auto src0_neon0 = vis(src0);
auto src0_neon1 = vis(src0 + Op::SIMD_WIDTH);
auto src1_neon0 = vis(src1_ptr);
auto src1_neon1 = vis(src1_ptr + Op::SIMD_WIDTH);
src0 += Op::SIMD_WIDTH * 2;
src1_ptr += Op::SIMD_WIDTH * 2;
op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst);
dst += Op::SIMD_WIDTH * 2;
rest -= Op::SIMD_WIDTH * 2;
}
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
while (rest > 0) {
op(*src0, *src1_ptr, dst);
dst++;
src0++;
src1_ptr++;
rest--;
}
}
}
}
};
template <typename Op>
struct OpCallerBinary<Op, BCAST111C_VEC> {
static void run(
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
Op op(src0_dtype, src1_dtype, dst_dtype);
ParamElemVisitor<typename Op::src_ctype> vis;
for (size_t b = 0; b < batch; b++) {
for (size_t c = 0; c < channel; c++) {
size_t rest = channel_stride;
const typename Op::src_ctype* src0_ptr = src0;
while (rest >= Op::SIMD_WIDTH * 2) {
auto src0_neon0 = vis(src0_ptr);
auto src0_neon1 = vis(src0_ptr + Op::SIMD_WIDTH);
auto src1_neon0 = vis(src1);
auto src1_neon1 = vis(src1 + Op::SIMD_WIDTH);
src0_ptr += Op::SIMD_WIDTH * 2;
src1 += Op::SIMD_WIDTH * 2;
op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst);
dst += Op::SIMD_WIDTH * 2;
rest -= Op::SIMD_WIDTH * 2;
}
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
while (rest > 0) {
op(*src0_ptr, *src1, dst);
dst++;
src0_ptr++;
src1++;
rest--;
}
}
}
}
};
template <typename ctype>
struct OpCallerBinary<PowOp<ctype, ctype>, BCAST101xX_VEC> {
using Op = PowOp<ctype, ctype>;
......@@ -824,6 +960,54 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> {
}
};
//! src0: 111C, src1: vector, src2: 111C, src1 may not be contig
template <typename Op>
struct OpCallerTernary<Op, BCAST111C_VEC_BCAST111C> {
static void run(
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
size_t src1_offset, const typename Op::src_ctype* src2,
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
DType src2_dtype, DType dst_dtype, size_t batch_size, size_t channel_size,
size_t channel_stride) {
Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
ParamElemVisitor<typename Op::src_ctype> vis;
for (size_t batch = 0; batch < batch_size; batch++) {
for (size_t channel = 0; channel < channel_size; channel++) {
auto src0_ptr = src0;
auto src2_ptr = src2;
size_t i = 0;
for (; i + Op::SIMD_WIDTH * 2 <= channel_stride;
i += Op::SIMD_WIDTH * 2) {
auto src0_neon0 = vis(src0_ptr);
auto src0_neon1 = vis(src0_ptr + Op::SIMD_WIDTH);
auto src1_neon0 = vis(src1);
auto src1_neon1 = vis(src1 + Op::SIMD_WIDTH);
auto src2_neon0 = vis(src2_ptr);
auto src2_neon1 = vis(src2_ptr + Op::SIMD_WIDTH);
op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}},
{{src2_neon0, src2_neon1}}, dst);
src0_ptr += Op::SIMD_WIDTH * 2;
src1 += Op::SIMD_WIDTH * 2;
src2_ptr += Op::SIMD_WIDTH * 2;
dst += Op::SIMD_WIDTH * 2;
}
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
for (; i < channel_stride; i++) {
op(*src0_ptr, *src1, *src2_ptr, dst);
src0_ptr++;
src1++;
src2_ptr++;
dst++;
}
src1 += src1_offset;
}
}
}
};
template <typename src_ctype, size_t channel_block_dim>
struct OpCallerTernaryBcast101xXVecBcast101xX {
template <typename Op>
......@@ -992,6 +1176,51 @@ struct OpCallerTernary<Op, VEC_BCAST101_VEC> {
}
};
//! src1: 111C, src0 and src2 may not be contig
template <typename Op>
struct OpCallerTernary<Op, VEC_BCAST111C_VEC> {
static void run(
const typename Op::src_ctype* src0, size_t src0_offset,
const typename Op::src_ctype* src1, const typename Op::src_ctype* src2,
size_t src2_offset, typename Op::dst_ctype* dst, DType src0_dtype,
DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t batch_size,
size_t channel_size, size_t channel_stride) {
Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
ParamElemVisitor<typename Op::src_ctype> vis0;
ParamElemVisitor<typename Op::src_ctype> vis1;
ParamElemVisitor<typename Op::src_ctype> vis2;
for (size_t batch = 0; batch < batch_size; batch++) {
for (size_t channel = 0; channel < channel_size; channel++) {
auto src1_ptr = src1;
size_t i = 0;
for (; i + Op::SIMD_WIDTH * 2 <= channel_stride;
i += Op::SIMD_WIDTH * 2) {
op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}},
{{vis1(src1_ptr), vis1(src1_ptr + Op::SIMD_WIDTH)}},
{{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst);
src0 += Op::SIMD_WIDTH * 2;
src1_ptr += Op::SIMD_WIDTH * 2;
src2 += Op::SIMD_WIDTH * 2;
dst += Op::SIMD_WIDTH * 2;
}
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
for (; i < channel_stride; i++) {
op(*src0, *src1_ptr, *src2, dst);
src0++;
src1_ptr++;
src2++;
dst++;
}
src0 += src0_offset;
src2 += src2_offset;
}
}
}
};
template <typename src_ctype, size_t channel_block_dim>
struct OpCallerTernaryVecBcast101xXVec {
template <typename Op>
......
......@@ -50,6 +50,20 @@ inline dt_qint32 QConverter::convert(const float& src) {
saturate<int32_t, float>(std::round(src), -2147483648, 2147483647));
}
template <>
inline float32x4x2_t QConverter::convert(const int16x8_t& vsrc) {
int32x4_t vhi = vmovl_s16(vget_high_s16(vsrc));
int32x4_t vlo = vmovl_s16(vget_low_s16(vsrc));
return {{vcvtq_f32_s32(vlo), vcvtq_f32_s32(vhi)}};
}
template <>
inline float32x4x2_t QConverter::convert(const uint16x8_t& vsrc) {
uint32x4_t vhi = vmovl_u16(vget_high_u16(vsrc));
uint32x4_t vlo = vmovl_u16(vget_low_u16(vsrc));
return {{vcvtq_f32_u32(vlo), vcvtq_f32_u32(vhi)}};
}
#if __ARM_ARCH >= 8
template <>
inline int8x8_t QConverter::convert(const float32x4x2_t& vsrc) {
......
......@@ -17,6 +17,7 @@
#include "src/common/utils.h"
#include "src/naive/handle.h"
MIDOUT_DECL(megdnn_arm_typecvt_fix2float)
MIDOUT_DECL(megdnn_arm_typecvt_quantized)
MIDOUT_DECL(megdnn_arm_typecvt_float)
......@@ -325,6 +326,48 @@ struct FloatTypeCvter<float, __fp16> {
};
#endif
template <typename ctype, typename dtype>
struct Fix2FloatTypeCvter;
template <>
struct Fix2FloatTypeCvter<int16_t, float> {
using stype = int16_t;
using dst_type = float;
static constexpr size_t SIMD_WIDTH = 8;
Fix2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
MEGDNN_MARK_USED_VAR(src_dtype);
MEGDNN_MARK_USED_VAR(dst_dtype);
}
void cvt(const int16_t* src, float* dst) {
int16x8_t vitem = vld1q_s16(src);
auto vres = QConverter::convert<float32x4x2_t, int16x8_t>(vitem);
vst1q_f32_x2(dst, vres);
}
void cvt_remain(const int16_t* src, float* dst) { *dst = *src; }
};
template <>
struct Fix2FloatTypeCvter<uint16_t, float> {
using stype = uint16_t;
using dst_type = float;
static constexpr size_t SIMD_WIDTH = 8;
Fix2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
MEGDNN_MARK_USED_VAR(src_dtype);
MEGDNN_MARK_USED_VAR(dst_dtype);
}
void cvt(const uint16_t* src, float* dst) {
uint16x8_t vitem = vld1q_u16(src);
auto vres = QConverter::convert<float32x4x2_t, uint16x8_t>(vitem);
vst1q_f32_x2(dst, vres);
}
void cvt_remain(const uint16_t* src, float* dst) { *dst = *src; }
};
template <typename TypeCvter>
void do_typecvt(
const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst,
......@@ -347,6 +390,43 @@ void do_typecvt(
}
}
template <typename TypeCvter>
void do_typecvt(
const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst,
DType src_dtype, DType dst_dtype, const TensorLayout& src_layout) {
TypeCvter typecvt(src_dtype, dst_dtype);
size_t calc_num = 1;
size_t nr_elems = src_layout.total_nr_elems();
size_t src_stride = nr_elems;
//! adjust calc_num nr_elems and src_stride according to src_collapse_layout
auto src_collapse_layout = src_layout.collapse_contiguous();
if (src_collapse_layout.ndim == 2) {
calc_num = src_collapse_layout.shape[0];
nr_elems = src_collapse_layout.shape[1];
src_stride = src_collapse_layout.stride[0];
}
for (size_t c = 0; c < calc_num; ++c) {
size_t i = 0;
for (; i + TypeCvter::SIMD_WIDTH <= nr_elems; i += TypeCvter::SIMD_WIDTH) {
typecvt.cvt(src, dst);
src += TypeCvter::SIMD_WIDTH;
dst += TypeCvter::SIMD_WIDTH;
}
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
for (; i < nr_elems; i++) {
typecvt.cvt_remain(src, dst);
src++;
dst++;
}
src += src_stride - nr_elems;
}
}
} // anonymous namespace
void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
......@@ -354,7 +434,30 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
DType dst_dtype = dst.layout.dtype;
size_t nr_elems = src.layout.total_nr_elems();
bool execed = false;
if (src.layout.is_contiguous()) {
auto src_collapse_layout = src.layout.collapse_contiguous();
bool has_int16_special_impl =
(src.layout.dtype.enumv() == DTypeEnum::Int16 ||
src.layout.dtype.enumv() == DTypeEnum::Uint16) &&
(src.layout.is_contiguous() || src_collapse_layout.ndim == 2) &&
dst.layout.is_contiguous();
if (has_int16_special_impl) {
using namespace dtype;
#define DISPATCH_FIX2FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \
dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \
MIDOUT_BEGIN(megdnn_arm_typecvt_fix2float, midout_iv(_midout_iv)) { \
using _TypeCvter = Fix2FloatTypeCvter<_stype, _dtype>; \
MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \
src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(), \
src_dtype, dst_dtype, src.layout)); \
execed = true; \
} \
MIDOUT_END(); \
}
DISPATCH_FIX2FLOAT(Int16, int16_t, Float32, float, 0);
DISPATCH_FIX2FLOAT(Uint16, uint16_t, Float32, float, 1);
#undef DISPATCH_FIX2FLOAT
} else if (src.layout.is_contiguous()) {
using namespace dtype;
#define DISPATCH_QUANTIZED(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \
......@@ -377,6 +480,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS32, int32_t, 5);
DISPATCH_QUANTIZED(float, float, QuantizedS8, int8_t, 6);
DISPATCH_QUANTIZED(float, float, Quantized8Asymm, uint8_t, 7);
#undef DISPATCH_QUANTIZED
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#define DISPATCH_FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
......@@ -394,6 +498,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
}
DISPATCH_FLOAT(dt_float16, __fp16, float, float, 0);
DISPATCH_FLOAT(float, float, dt_float16, __fp16, 1);
#undef DISPATCH_FLOAT
#endif
}
if (!execed) {
......
......@@ -150,6 +150,19 @@ bool ElemwiseLayoutHelper::is_broadcasted_channel_like(
return false;
}
bool ElemwiseLayoutHelper::is_NHWC_broadcasted_channel_like(
const TensorLayout& layout, BroadcastChannelInfo& info) {
if (layout.format.type() == TensorFormat::Type::DEFAULT) {
if (layout.ndim == 2 && layout.stride[1] == 1 && layout.stride[0] == 0) {
info.x = 1;
info.y = layout.shape[0];
info.z = layout.shape[1];
return true;
}
}
return false;
}
bool ElemwiseLayoutHelper::is_broadcasted_1x(
const TensorLayout& layout, Broadcast1xInfo& binfo) {
if (layout.ndim == 2 && layout.stride[0] == 0 && layout.stride[1] == 1) {
......
......@@ -80,6 +80,16 @@ public:
static bool is_broadcasted_channel_like(
const TensorLayout& layout, BroadcastChannelInfo& info);
/*!
* \brief check whether layout matches BroadcastChannelInfo under NHWC
* layout
*
* Note that Input must be 2-dimensional, and must be [1, y] broadacsted
* into [z, y] and x would be set to 1.
*/
static bool is_NHWC_broadcasted_channel_like(
const TensorLayout& layout, BroadcastChannelInfo& info);
/*!
* \brief check whether layout matches BroadcastChannelInfo
*
......
......@@ -309,7 +309,8 @@ void on_dest_ctype(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
break; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool) case DTypeEnum::QuantizedS8
cb(::megdnn::dtype::Bool)
cb(::megdnn::dtype::Uint16) case DTypeEnum::QuantizedS8
: MIDOUT_BEGIN(
megdnn_fb_typecvt_src_dtype,
midout_iv(DTypeEnum::QuantizedS8)) {
......@@ -467,7 +468,8 @@ void run_contiguous(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool) case DTypeEnum::QuantizedS8
cb(::megdnn::dtype::Bool)
cb(::megdnn::dtype::Uint16) case DTypeEnum::QuantizedS8
: MIDOUT_BEGIN(
megdnn_fb_typecvt_dst_dtype,
midout_iv(DTypeEnum::QuantizedS8)) {
......
......@@ -78,7 +78,7 @@ void on_dest_ctype(HandleImpl* handle, const TensorND& dest, const TensorND& src
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
cb(::megdnn::dtype::Bool)
cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16)
#undef cb
default : megdnn_throw("bad dtype");
}
......@@ -99,7 +99,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
cb(::megdnn::dtype::Bool)
cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16)
#undef cb
default : megdnn_throw("bad dtype");
}
......
......@@ -14,6 +14,7 @@
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/general.h"
using namespace megdnn;
......@@ -298,6 +299,63 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW88_FP) {
#endif
}
TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NHWC_FP32_BCAST) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle());
UniformFloatRNG rng(1e-5, 7e1);
checker.set_rng(0, &rng);
checker.set_epsilon(1e-5);
checker.set_dtype(0, dtype::Float32());
checker.set_dtype(1, dtype::Float32());
//! 2 dim
auto run = [&](Mode mode) {
// VEC_BCAST111C
checker.set_param(mode).execs({{1, 2, 2, 12}, {1, 1, 1, 12}, {}});
checker.set_param(mode).execs({{2, 5, 3, 28}, {1, 1, 1, 28}, {}});
checker.set_param(mode).execs({{3, 5, 8, 32}, {1, 1, 1, 32}, {}});
// BCAST111C_VEC
checker.set_param(mode).execs({{1, 1, 1, 12}, {1, 2, 2, 12}, {}});
checker.set_param(mode).execs({{1, 1, 1, 28}, {2, 5, 3, 28}, {}});
checker.set_param(mode).execs({{1, 1, 1, 32}, {3, 5, 8, 32}, {}});
};
run(Mode::ADD);
run(Mode::MUL);
run(Mode::SUB);
//! 3 dim contig
auto run_3d_contig = [&](Mode mode) {
// BCAST111C_VEC_BCAST111C
checker.set_param(mode).execs(
{{1, 1, 1, 12}, {1, 2, 2, 12}, {1, 1, 1, 12}, {}});
checker.set_param(mode).execs(
{{1, 1, 1, 28}, {2, 5, 3, 28}, {1, 1, 1, 28}, {}});
checker.set_param(mode).execs(
{{1, 1, 1, 32}, {3, 5, 8, 32}, {1, 1, 1, 32}, {}});
// VEC_BCAST111C_VEC
checker.set_param(mode).execs(
{{1, 2, 2, 12}, {1, 1, 1, 12}, {1, 2, 2, 12}, {}});
checker.set_param(mode).execs(
{{2, 5, 3, 28}, {1, 1, 1, 28}, {2, 5, 3, 28}, {}});
checker.set_param(mode).execs(
{{3, 5, 8, 32}, {1, 1, 1, 32}, {3, 5, 8, 32}, {}});
};
run_3d_contig(Mode::FUSE_MUL_ADD3);
//! 3 dim incontig
auto run_3d_incontig = [&](Mode mode) {
megdnn::TensorLayout src0({1, 1, 1, 12}, dtype::Float32());
megdnn::TensorLayout src1({1, 2, 2, 12}, {80, 40, 20, 1}, dtype::Float32());
// BCAST111C_VEC_BCAST111C
checker.set_param(mode).execl({src0, src1, src0, {}});
// VEC_BCAST111C_VEC
checker.set_param(mode).execl({src1, src0, src1, {}});
};
run_3d_incontig(Mode::FUSE_MUL_ADD3);
}
#if MEGDNN_WITH_BENCHMARK
namespace {
void run_elemwise_benchmark(
......@@ -354,6 +412,39 @@ void run_elemwise_benchmark(
}
} // namespace
TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NHWC) {
Benchmarker<Elemwise> benchmarker(handle());
constexpr size_t RUN = 50;
benchmarker.set_times(RUN).set_display(false);
auto run = [&](size_t N, size_t C, size_t H, size_t W, param::Elemwise::Mode mode,
const char* mode_name) {
megdnn::param::Elemwise param;
param.mode = mode;
benchmarker.set_param(param);
megdnn::TensorShape nhwc_src0{N, H, W, C};
megdnn::TensorShape nhwc_src1{1, 1, 1, C};
megdnn::TensorShape nchw_src0{N, C, H, W};
megdnn::TensorShape nchw_src1{1, C, 1, 1};
float computations = N * C * H * W;
auto nhwc_time = benchmarker.execs({nhwc_src1, nhwc_src0, {}}) / RUN;
auto nchw_time = benchmarker.execs({nchw_src1, nchw_src0, {}}) / RUN;
auto perf_nhwc = computations / nhwc_time / 1e6;
auto perf_nchw = computations / nchw_time / 1e6;
printf("Elemwise Mode : %s\nNHWC : %fms %fGflops\nNCHW : %fms "
"%fGflops\n",
mode_name, nhwc_time, perf_nhwc, nchw_time, perf_nchw);
};
run(1, 120, 16, 24, param::Elemwise::Mode::ADD, "ADD");
run(1, 120, 16, 24, param::Elemwise::Mode::MUL, "MUL");
run(1, 120, 32, 48, param::Elemwise::Mode::ADD, "ADD");
run(1, 120, 32, 48, param::Elemwise::Mode::MUL, "MUL");
run(1, 120, 64, 96, param::Elemwise::Mode::ADD, "ADD");
run(1, 120, 64, 96, param::Elemwise::Mode::MUL, "MUL");
}
#define INT_RUN(shape, mode) \
run_elemwise_benchmark(shape, mode, #mode, dtype::Int8{}, handle()); \
run_elemwise_benchmark(shape, mode, #mode, dtype::Int16{}, handle()); \
......
......@@ -88,6 +88,26 @@ TEST_F(ARM_COMMON, TYPE_CVT) {
.execs({{1, 32, 24, 128}, {1, 32, 24, 128}});
}
TEST_F(ARM_COMMON, TYPE_CVT_16_F32) {
Checker<TypeCvt> checker(handle());
UniformIntRNG rng{INT16_MIN >> 1, INT16_MAX >> 1};
for (size_t size : {3, 7, 15, 33, 10000}) {
checker.set_rng(0, &rng);
checker.set_dtype(0, dtype::Int16()).execs({{size}, {size}});
checker.set_dtype(0, dtype::Uint16()).execs({{size}, {size}});
}
TensorLayout src_int16{
{1, 96, 64, 120}, {128 * 64 * 96, 128 * 64, 128, 1}, dtype::Int16()};
TensorLayout dst_int16{{1, 96, 64, 120}, dtype::Float32()};
checker.execl({src_int16, dst_int16});
TensorLayout src_uint16{
{1, 96, 64, 120}, {128 * 64 * 96, 128 * 64, 128, 1}, dtype::Uint16()};
TensorLayout dst_uint16{{1, 96, 64, 120}, dtype::Float32()};
checker.execl({src_uint16, dst_uint16});
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(ARM_COMMON, BENCHMARK_TYPE_CVT) {
auto run = [&](const TensorShapeArray& shapes) {
......
......@@ -158,6 +158,7 @@ void copy_tensors(
//! In order to avoid an unnecessary increase in binary size, we just
//! use QuantizedS16 dtype in winograd_filter_preprocess now.
cb(::megdnn::dtype::QuantizedS16) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
cb(::megdnn::dtype::Uint16)
#undef cb
default : megdnn_trap();
}
......
......@@ -202,6 +202,9 @@ void IIDRNG::gen(const TensorND& tensor) {
memset(tensor.raw_ptr, 0, tensor.layout.access_bytes());
return;
}
if (tensor.layout.dtype.enumv() == DTypeEnum::Uint16) {
return;
}
megdnn_assert(
0, "IIDRNG does not know how to generate value for DType %s",
tensor.layout.dtype.name());
......
......@@ -25,6 +25,11 @@ TEST_F(CUDA, TYPE_CVT) {
TensorLayout src({10, 10}, sdtype), dst({10, 10}, ddtype);
Checker<TypeCvt> checker(handle_cuda());
checker.set_rng(0, &init).exec(TensorLayoutArray{src, dst});
TensorLayout non_contig_src(
{1, 96, 64, 120}, {96 * 64 * 128, 64 * 128, 128, 1}, sdtype);
TensorLayout non_contig_dst({1, 96, 64, 120}, ddtype);
checker.exec(TensorLayoutArray{non_contig_src, non_contig_dst});
}
}
......
......@@ -37,8 +37,22 @@ TEST_F(X86, TYPE_CVT) {
for (auto ddtype : dtypes) {
checker.set_dtype(0, sdtype).set_dtype(1, ddtype).execs(
{{size}, {size}});
TensorLayout non_contig_src(
{1, 10, 10, 12}, {10 * 10 * 18, 10 * 18, 18, 1}, sdtype);
TensorLayout non_contig_dst({1, 10, 10, 12}, ddtype);
checker.exec(TensorLayoutArray{non_contig_src, non_contig_dst});
}
}
for (size_t size : {1, 7, 15, 33}) {
checker.set_dtype(0, dtype::Uint16())
.set_dtype(1, dtype::Float32())
.execs({{size}, {size}});
}
TensorLayout non_contig_src(
{1, 10, 10, 12}, {10 * 10 * 18, 10 * 18, 18, 1}, dtype::Uint16());
TensorLayout non_contig_dst({1, 10, 10, 12}, dtype::Float32());
checker.exec(TensorLayoutArray{non_contig_src, non_contig_dst});
}
TEST_F(X86, TYPE_CVT_NO_CONTIGUOUS) {
......
......@@ -772,8 +772,10 @@ void TypeCvt::perform(
}
void TypeCvt::add_input_layout_constraint() {
//! Because the implementation of typecvt on arm/x86/cuda/opencl support
//! non-contiguous memory. So we change constraint of typecvt to monotone
for (auto i : input()) {
i->add_layout_constraint_contiguous();
i->add_layout_constraint_monotone();
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册