diff --git a/dnn/src/arm_common/elemwise_helper/elemwise_op.h b/dnn/src/arm_common/elemwise_helper/elemwise_op.h index 96b25e7698e26a1224487a011ba2e8efe4144900..2e81ed1b7af56b47369c3bc418120da055b47505 100644 --- a/dnn/src/arm_common/elemwise_helper/elemwise_op.h +++ b/dnn/src/arm_common/elemwise_helper/elemwise_op.h @@ -13,18 +13,6 @@ using BcastType = megdnn::elemwise::BcastType; ///////////////////////////////// ParamElemVistor /////////////////////////// #define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix, _neon_type_v2) \ - template <> \ - struct ParamElemVisitor<_ctype> { \ - _neon_type operator()(const _ctype* src) const { \ - return vld1q_##_fun_suffix(reinterpret_cast(src)); \ - } \ - }; \ - template <> \ - struct ParamElemVisitorDup<_ctype> { \ - _neon_type operator()(const _ctype* src) const { \ - return vdupq_n_##_fun_suffix(*reinterpret_cast(src)); \ - } \ - }; \ template <> \ struct ParamElemVisitorV2<_ctype> { \ _neon_type_v2 operator()(const _ctype* src, const _ctype* src_1) const { \ @@ -53,16 +41,7 @@ cb(__fp16, __fp16, float16x8_t, f16, float16x8x2_t); cb(dt_int16, int16_t, int16x8_t, s16, int16x8x2_t); #undef cb -template -struct ParamElemVisitorBcast101x4; #define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix, rel_suffix, _neon_type_v2) \ - template <> \ - struct ParamElemVisitorBcast101x4<_ctype> { \ - _neon_type operator()(const _ctype* src) const { \ - return vreinterpretq_##_fun_suffix##_##rel_suffix(vld1q_dup_##rel_suffix( \ - reinterpret_cast(src))); \ - } \ - }; \ template <> \ struct ParamElemVisitorBcast101x4V2<_ctype> { \ _neon_type_v2 operator()(const _ctype* src) const { \ @@ -83,16 +62,20 @@ cb(__fp16, uint64_t, float16x8_t, f16, u64, float16x8x2_t); #undef cb template -struct ParamElemVisitorBcast101x8; -#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ - template <> \ - struct ParamElemVisitorBcast101x8<_ctype> { \ - _neon_type operator()(const _ctype* src) const { \ - return vld1q_##_fun_suffix(reinterpret_cast(src)); \ - } \ +struct ParamElemVisitorBcast101x8V2; +#define cb(_ctype, _inner_ctype, _neon_type_v2, _fun_suffix) \ + template <> \ + struct ParamElemVisitorBcast101x8V2<_ctype> { \ + _neon_type_v2 operator()(const _ctype* src) const { \ + _neon_type_v2 ret; \ + ret.val[0] = \ + vld1q_##_fun_suffix(reinterpret_cast(src)); \ + ret.val[1] = ret.val[0]; \ + return ret; \ + } \ } #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -cb(__fp16, __fp16, float16x8_t, f16); +cb(__fp16, __fp16, float16x8x2_t, f16); #endif #undef cb @@ -106,8 +89,8 @@ struct OpCallerBinaryBcast101xXVec<__fp16, 8> { const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, const Op& op, size_t batch, size_t nr_channel_blocks, size_t channel_stride) { - ParamElemVisitorBcast101x8 vis0; - ParamElemVisitor vis1; + ParamElemVisitorBcast101x8V2 vis0; + ParamElemVisitorV2 vis1; OpCallerBinaryBcast101xDVec::run( src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, channel_stride); @@ -122,8 +105,8 @@ struct OpCallerBinaryVecBcast101xX<__fp16, 8> { const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, const Op& op, size_t batch, size_t nr_channel_blocks, size_t channel_stride) { - ParamElemVisitor vis0; - ParamElemVisitorBcast101x8 vis1; + ParamElemVisitorV2 vis0; + ParamElemVisitorBcast101x8V2 vis1; OpCallerBinaryVecBcast101xD::run( src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, channel_stride); @@ -138,9 +121,9 @@ struct OpCallerTernaryBcast101xXVecBcast101xX<__fp16, 8> { const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, typename Op::dst_ctype* dst, const Op& op, size_t batch, size_t nr_channel_blocks, size_t channel_stride) { - ParamElemVisitorBcast101x8 vis0; - ParamElemVisitor vis1; - ParamElemVisitorBcast101x8 vis2; + ParamElemVisitorBcast101x8V2 vis0; + ParamElemVisitorV2 vis1; + ParamElemVisitorBcast101x8V2 vis2; OpCallerTernaryBcast101xDVecBcast101xD::run( src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, channel_stride); @@ -155,9 +138,9 @@ struct OpCallerTernaryVecBcast101xXVec<__fp16, 8> { const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, typename Op::dst_ctype* dst, const Op& op, size_t batch, size_t nr_channel_blocks, size_t channel_stride) { - ParamElemVisitor vis0; - ParamElemVisitorBcast101x8 vis1; - ParamElemVisitor vis2; + ParamElemVisitorV2 vis0; + ParamElemVisitorBcast101x8V2 vis1; + ParamElemVisitorV2 vis2; OpCallerTernaryVecBcast101xDVec::run( src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, channel_stride); diff --git a/dnn/src/fallback/elemwise_helper/op_common.h b/dnn/src/fallback/elemwise_helper/op_common.h index f0ca73eb2c1ef57d13e77cf418a2518ed30f9bae..ab79186c658d8a151d480ab7c5ee9b6ca4057f01 100644 --- a/dnn/src/fallback/elemwise_helper/op_common.h +++ b/dnn/src/fallback/elemwise_helper/op_common.h @@ -36,66 +36,6 @@ enum BcastType { UNKNOWN_BCAST_TYPE }; -///////////////////////////////// ParamElemVistor /////////////////////////// -template -struct ParamElemVisitor; - -//! visitor single elemwise, and dup to vector -template -struct ParamElemVisitorDup; - -template -struct ParamElemVisitorBcast101x4; - -#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ - template <> \ - struct ParamElemVisitor<_ctype> { \ - _simd_type operator()(const _ctype* src) const { \ - return GiLoad##_fun_suffix(src); \ - } \ - }; \ - template <> \ - struct ParamElemVisitorDup<_ctype> { \ - _simd_type operator()(const _ctype* src) const { \ - return GiBroadcast##_fun_suffix( \ - *reinterpret_cast(src)); \ - } \ - } -cb(dt_qint32, int32_t, GI_INT32_t, Int32); -cb(dt_qint8, int8_t, GI_INT8_t, Int8); - -cb(dt_float32, float, GI_FLOAT32_t, Float32); -cb(dt_int32, int32_t, GI_INT32_t, Int32); -cb(dt_int8, int8_t, GI_INT8_t, Int8); -#undef cb - -template -struct ParamElemVisitorBcast101x4; -#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix) \ - template <> \ - struct ParamElemVisitorBcast101x4<_ctype> { \ - _simd_type operator()(const _ctype* src) const { \ - return GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix( \ - *reinterpret_cast(src))); \ - } \ - } - -cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32); -cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32); -#undef cb -#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ - template <> \ - struct ParamElemVisitorBcast101x4<_ctype> { \ - _simd_type operator()(const _ctype* src) const { \ - return GiLoad##_fun_suffix(src); \ - } \ - } - -cb(dt_qint32, int32_t, GI_INT32_t, Int32); -cb(dt_float32, float, GI_FLOAT32_t, Float32); -cb(dt_int32, int32_t, GI_INT32_t, Int32); -#undef cb - ///////////////////////////////// ParamElemVistor v2/////////////////////////// template struct ParamElemVisitorV2;