From f6d9909460ed2400406de0fc7d73806ded5e19f5 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 17 Dec 2021 11:50:48 +0800
Subject: [PATCH] feat(dnn): add elemwise multi type support for i16xf32 and
 u8xf32

GitOrigin-RevId: 2fe469bb4ec9a0b7d20f88a54d2a87e7ad42385b
---
 dnn/scripts/opr_param_defs.py                 |  11 +-
 .../elemwise_multi_type/kernels.cpp           | 707 ++++++++++++++++++
 .../arm_common/elemwise_multi_type/kernels.h  |  71 ++
 .../elemwise_multi_type/opr_impl.cpp          | 149 ++++
 .../arm_common/elemwise_multi_type/opr_impl.h |   9 +
 .../common/elemwise_multi_type/opr_impl.cpp   |  26 +
 .../elemwise_multi_type/opr_impl_helper.cpp   |  11 +
 .../elemwise_multi_type/opr_impl_helper.h     |  21 +
 .../fallback/elemwise_multi_type/opr_impl.cpp | 210 ++++++
 .../fallback/elemwise_multi_type/opr_impl.h   |   6 +
 .../naive/elemwise_multi_type/opr_impl.cpp    |  60 ++
 dnn/src/naive/elemwise_multi_type/opr_impl.h  |   6 +
 dnn/test/arm_common/elemwise_multi_type.cpp   | 103 +++
 dnn/test/common/elemwise_multi_type.cpp       |  67 ++
 dnn/test/common/elemwise_multi_type.h         |  11 +-
 dnn/test/fallback/elemwise_multi_type.cpp     |  18 +
 16 files changed, 1481 insertions(+), 5 deletions(-)
 create mode 100644 dnn/src/arm_common/elemwise_multi_type/kernels.cpp
 create mode 100644 dnn/src/arm_common/elemwise_multi_type/kernels.h

diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py
index 8eba91999..76220c99e 100755
--- a/dnn/scripts/opr_param_defs.py
+++ b/dnn/scripts/opr_param_defs.py
@@ -497,7 +497,16 @@ pdef('ElemwiseMultiType').add_enum(
     Doc('QCOND_LEQ_MOV = 50', 'quantized cond_leq_mov'),
     Doc('QH_SWISH = 51', 'quantized h_swish'),
     Doc('QFUSE_ADD_H_SWISH = 52', 'quantized h_swish(x+y)'),
-    Doc('QH_SWISH_GRAD = 53', 'quantized h_swish_grad')
+    Doc('QH_SWISH_GRAD = 53', 'quantized h_swish_grad'),
+    Doc('FUSE_MUL_ADD3_INT16xF32xF32xF32 = 54',
+        'compute ``a * b + c``, requiring that ``a`` be int16 and ``b`` and '
+        '``c`` be float32; the result is float32.'),
+    Doc('MUL_INT16xF32xF32 = 55',
+        'compute ``a * b``, requiring that ``a`` be int16 and ``b`` be '
+        'float32; the result is float32.'),
+    Doc('FUSE_MUL_ADD3_UINT8xF32xF32xF32 = 56',
+        'compute ``a * b + c``, requiring that ``a`` be uint8 and ``b`` and '
+        '``c`` be float32; the result is float32.')
 )
 
 pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
diff --git a/dnn/src/arm_common/elemwise_multi_type/kernels.cpp b/dnn/src/arm_common/elemwise_multi_type/kernels.cpp
new file mode 100644
index 000000000..7dfa10511
--- /dev/null
+++ b/dnn/src/arm_common/elemwise_multi_type/kernels.cpp
@@ -0,0 +1,707 @@
+/**
+ * \file dnn/src/arm_common/elemwise_multi_type/kernels.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */ + +#include "kernels.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace arm_common { + +#if defined(__ARM_FEATURE_FMA) +#define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m) +#else +#define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m) +#endif + +void neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast111c_bcast111c( + size_t batch_size, size_t channel_stride, size_t channel_size, + const int16_t* src0, const float* src1, const float* src2, float* dst) { + const int16_t* __restrict sptr0 = src0; + const float* __restrict sptr1 = src1; + const float* __restrict sptr2 = src2; + float* __restrict dst_ptr = dst; + for (size_t batch = 0; batch < batch_size; ++batch) { + for (size_t s = 0; s < channel_stride; ++s) { + size_t i = 0; + for (; i + 15 < channel_size; i += 16, sptr0 += 16, dst_ptr += 16) { + auto vec0_01 = vld1q_s16(sptr0); + auto vec0_23 = vld1q_s16(sptr0 + 8); + auto vec1_0 = vld1q_f32(sptr1 + i); + auto vec1_1 = vld1q_f32(sptr1 + i + 4); + auto vec1_2 = vld1q_f32(sptr1 + i + 8); + auto vec1_3 = vld1q_f32(sptr1 + i + 12); + auto vec2_0 = vld1q_f32(sptr2 + i); + auto vec2_1 = vld1q_f32(sptr2 + i + 4); + auto vec2_2 = vld1q_f32(sptr2 + i + 8); + auto vec2_3 = vld1q_f32(sptr2 + i + 12); + + auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23))); + auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23))); + + auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0); + auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1); + auto dst_vec_2 = Vfmaq_f32(vec2_2, vec0_2, vec1_2); + auto dst_vec_3 = Vfmaq_f32(vec2_3, vec0_3, vec1_3); + + vst1q_f32(dst_ptr, dst_vec_0); + vst1q_f32(dst_ptr + 4, dst_vec_1); + vst1q_f32(dst_ptr + 8, dst_vec_2); + vst1q_f32(dst_ptr + 12, dst_vec_3); + } + for (; i + 7 < channel_size; i += 8, sptr0 += 8, dst_ptr += 8) { + auto vec0_01 = vld1q_s16(sptr0); + auto vec1_0 = vld1q_f32(sptr1 + i); + auto vec1_1 = vld1q_f32(sptr1 + i + 4); + auto vec2_0 = vld1q_f32(sptr2 + i); + auto vec2_1 = vld1q_f32(sptr2 + i + 4); + + auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + + auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0); + auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1); + + vst1q_f32(dst_ptr, dst_vec_0); + vst1q_f32(dst_ptr + 4, dst_vec_1); + } + for (; i + 3 < channel_size; i += 4, sptr0 += 4, dst_ptr += 4) { + auto vec0_0 = vld1_s16(sptr0); + auto vec1_0 = vld1q_f32(sptr1 + i); + auto vec2_0 = vld1q_f32(sptr2 + i); + + auto vec0_0_f32 = vcvtq_f32_s32(vmovl_s16(vec0_0)); + + auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0_f32, vec1_0); + + vst1q_f32(dst_ptr, dst_vec_0); + } + for (; i < channel_size; ++i, ++sptr0, ++dst_ptr) { + *dst_ptr = (float)(*sptr0) * sptr1[i] + sptr2[i]; + } + } + } +} + +void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast111c_bcast111c( + size_t batch_size, size_t channel_stride, size_t channel_size, + const uint8_t* src0, const float* src1, const float* src2, float* dst) { + const uint8_t* __restrict sptr0 = src0; + const float* __restrict sptr1 = src1; + const float* __restrict sptr2 = src2; + float* __restrict dst_ptr = dst; + for (size_t batch = 0; batch < batch_size; ++batch) { + for (size_t s = 0; s < channel_stride; ++s) { + size_t i = 0; + for (; i + 15 < channel_size; i += 16, sptr0 += 16, dst_ptr += 16) { + auto vec0_0123_u8 = vld1q_u8(sptr0); + auto vec1_0 = vld1q_f32(sptr1 + i); + auto vec1_1 = 
vld1q_f32(sptr1 + i + 4); + auto vec1_2 = vld1q_f32(sptr1 + i + 8); + auto vec1_3 = vld1q_f32(sptr1 + i + 12); + auto vec2_0 = vld1q_f32(sptr2 + i); + auto vec2_1 = vld1q_f32(sptr2 + i + 4); + auto vec2_2 = vld1q_f32(sptr2 + i + 8); + auto vec2_3 = vld1q_f32(sptr2 + i + 12); + + auto vec0_01 = + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(vec0_0123_u8))); + auto vec0_23 = + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(vec0_0123_u8))); + + auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23))); + auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23))); + + auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0); + auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1); + auto dst_vec_2 = Vfmaq_f32(vec2_2, vec0_2, vec1_2); + auto dst_vec_3 = Vfmaq_f32(vec2_3, vec0_3, vec1_3); + + vst1q_f32(dst_ptr, dst_vec_0); + vst1q_f32(dst_ptr + 4, dst_vec_1); + vst1q_f32(dst_ptr + 8, dst_vec_2); + vst1q_f32(dst_ptr + 12, dst_vec_3); + } + for (; i + 7 < channel_size; i += 8, sptr0 += 8, dst_ptr += 8) { + auto vec0_01_u8 = vld1_u8(sptr0); + auto vec1_0 = vld1q_f32(sptr1 + i); + auto vec1_1 = vld1q_f32(sptr1 + i + 4); + auto vec2_0 = vld1q_f32(sptr2 + i); + auto vec2_1 = vld1q_f32(sptr2 + i + 4); + + auto vec0_01 = vreinterpretq_s16_u16(vmovl_u8(vec0_01_u8)); + + auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + + auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0); + auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1); + + vst1q_f32(dst_ptr, dst_vec_0); + vst1q_f32(dst_ptr + 4, dst_vec_1); + } + for (; i < channel_size; ++i, ++sptr0, ++dst_ptr) { + *dst_ptr = (float)(*sptr0) * sptr1[i] + sptr2[i]; + } + } + } +} + +void neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast101_bcast101( + size_t batch_size, size_t channel_size, size_t channel_stride, + const int16_t* src0, const float* src1, const float* src2, float* dst) { + const int16_t* __restrict sptr0 = src0; + const float* __restrict sptr1 = src1; + const float* __restrict sptr2 = src2; + float* __restrict dst_ptr = dst; + for (size_t batch = 0; batch < batch_size; ++batch) { + for (size_t chan = 0; chan < channel_size; ++chan) { + auto vec1 = vdupq_n_f32(sptr1[chan]); + auto vec2 = vdupq_n_f32(sptr2[chan]); + size_t i = 0; + for (; i + 15 < channel_stride; i += 16, sptr0 += 16, dst_ptr += 16) { + auto vec0_01 = vld1q_s16(sptr0); + auto vec0_23 = vld1q_s16(sptr0 + 8); + + auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23))); + auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23))); + + auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1); + auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1); + auto dst_vec_2 = Vfmaq_f32(vec2, vec0_2, vec1); + auto dst_vec_3 = Vfmaq_f32(vec2, vec0_3, vec1); + + vst1q_f32(dst_ptr, dst_vec_0); + vst1q_f32(dst_ptr + 4, dst_vec_1); + vst1q_f32(dst_ptr + 8, dst_vec_2); + vst1q_f32(dst_ptr + 12, dst_vec_3); + } + for (; i + 7 < channel_stride; i += 8, sptr0 += 8, dst_ptr += 8) { + auto vec0_01 = vld1q_s16(sptr0); + + auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + + auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1); + auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1); + + vst1q_f32(dst_ptr, dst_vec_0); + 
vst1q_f32(dst_ptr + 4, dst_vec_1); + } + for (; i + 3 < channel_stride; i += 4, sptr0 += 4, dst_ptr += 4) { + auto vec0_0 = vld1_s16(sptr0); + auto vec0_0_f32 = vcvtq_f32_s32(vmovl_s16(vec0_0)); + auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0_f32, vec1); + + vst1q_f32(dst_ptr, dst_vec_0); + } + for (; i < channel_stride; ++i, ++sptr0, ++dst_ptr) { + *dst_ptr = (float)(*sptr0) * sptr1[chan] + sptr2[chan]; + } + } + } +} + +void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast101_bcast101( + size_t batch_size, size_t channel_size, size_t channel_stride, + const uint8_t* src0, const float* src1, const float* src2, float* dst) { + const uint8_t* __restrict sptr0 = src0; + const float* __restrict sptr1 = src1; + const float* __restrict sptr2 = src2; + float* __restrict dst_ptr = dst; + for (size_t batch = 0; batch < batch_size; ++batch) { + for (size_t chan = 0; chan < channel_size; ++chan) { + auto vec1 = vdupq_n_f32(sptr1[chan]); + auto vec2 = vdupq_n_f32(sptr2[chan]); + size_t i = 0; + for (; i + 15 < channel_stride; i += 16, sptr0 += 16, dst_ptr += 16) { + auto vec0_0123_u8 = vld1q_u8(sptr0); + + auto vec0_01 = + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(vec0_0123_u8))); + auto vec0_23 = + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(vec0_0123_u8))); + + auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23))); + auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23))); + + auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1); + auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1); + auto dst_vec_2 = Vfmaq_f32(vec2, vec0_2, vec1); + auto dst_vec_3 = Vfmaq_f32(vec2, vec0_3, vec1); + + vst1q_f32(dst_ptr, dst_vec_0); + vst1q_f32(dst_ptr + 4, dst_vec_1); + vst1q_f32(dst_ptr + 8, dst_vec_2); + vst1q_f32(dst_ptr + 12, dst_vec_3); + } + for (; i + 7 < channel_stride; i += 8, sptr0 += 8, dst_ptr += 8) { + auto vec0_01_u8 = vld1_u8(sptr0); + auto vec0_01 = vreinterpretq_s16_u16(vmovl_u8(vec0_01_u8)); + + auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + + auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1); + auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1); + + vst1q_f32(dst_ptr, dst_vec_0); + vst1q_f32(dst_ptr + 4, dst_vec_1); + } + for (; i < channel_stride; ++i, ++sptr0, ++dst_ptr) { + *dst_ptr = (float)(*sptr0) * sptr1[chan] + sptr2[chan]; + } + } + } +} + +void neon_fuse_mul_add3_int16xf32xf32xf32_vec_vec_vec( + size_t size, const int16_t* src0, const float* src1, const float* src2, + float* dst) { + const int16_t* __restrict sptr0 = src0; + const float* __restrict sptr1 = src1; + const float* __restrict sptr2 = src2; + float* __restrict dst_ptr = dst; + size_t i = 0; + for (; i + 15 < size; + i += 16, sptr0 += 16, sptr1 += 16, sptr2 += 16, dst_ptr += 16) { + auto vec0_01 = vld1q_s16(sptr0); + auto vec0_23 = vld1q_s16(sptr0 + 8); + auto vec1_0 = vld1q_f32(sptr1); + auto vec1_1 = vld1q_f32(sptr1 + 4); + auto vec1_2 = vld1q_f32(sptr1 + 8); + auto vec1_3 = vld1q_f32(sptr1 + 12); + auto vec2_0 = vld1q_f32(sptr2); + auto vec2_1 = vld1q_f32(sptr2 + 4); + auto vec2_2 = vld1q_f32(sptr2 + 8); + auto vec2_3 = vld1q_f32(sptr2 + 12); + + auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23))); + auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23))); + + auto dst_vec_0 = 
Vfmaq_f32(vec2_0, vec0_0, vec1_0); + auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1); + auto dst_vec_2 = Vfmaq_f32(vec2_2, vec0_2, vec1_2); + auto dst_vec_3 = Vfmaq_f32(vec2_3, vec0_3, vec1_3); + + vst1q_f32(dst_ptr, dst_vec_0); + vst1q_f32(dst_ptr + 4, dst_vec_1); + vst1q_f32(dst_ptr + 8, dst_vec_2); + vst1q_f32(dst_ptr + 12, dst_vec_3); + } + for (; i + 7 < size; i += 8, sptr0 += 8, sptr1 += 8, sptr2 += 8, dst_ptr += 8) { + auto vec0_01 = vld1q_s16(sptr0); + auto vec1_0 = vld1q_f32(sptr1); + auto vec1_1 = vld1q_f32(sptr1 + 4); + auto vec2_0 = vld1q_f32(sptr2); + auto vec2_1 = vld1q_f32(sptr2 + 4); + + auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + + auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0); + auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1); + + vst1q_f32(dst_ptr, dst_vec_0); + vst1q_f32(dst_ptr + 4, dst_vec_1); + } + for (; i + 3 < size; i += 4, sptr0 += 4, sptr1 += 4, sptr2 += 4, dst_ptr += 4) { + auto vec0_0 = vld1_s16(sptr0); + auto vec1_0 = vld1q_f32(sptr1); + auto vec2_0 = vld1q_f32(sptr2); + + auto vec0_0_f32 = vcvtq_f32_s32(vmovl_s16(vec0_0)); + + auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0_f32, vec1_0); + + vst1q_f32(dst_ptr, dst_vec_0); + } + for (; i < size; ++i, ++sptr0, ++sptr1, ++sptr2, ++dst_ptr) { + *dst_ptr = (float)(*sptr0) * (*sptr1) + (*sptr2); + } +} + +void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_vec_vec( + size_t size, const uint8_t* src0, const float* src1, const float* src2, + float* dst) { + const uint8_t* __restrict sptr0 = src0; + const float* __restrict sptr1 = src1; + const float* __restrict sptr2 = src2; + float* __restrict dst_ptr = dst; + size_t i = 0; + for (; i + 15 < size; + i += 16, sptr0 += 16, sptr1 += 16, sptr2 += 16, dst_ptr += 16) { + auto vec0_0123 = vld1q_u8(sptr0); + auto vec1_0 = vld1q_f32(sptr1); + auto vec1_1 = vld1q_f32(sptr1 + 4); + auto vec1_2 = vld1q_f32(sptr1 + 8); + auto vec1_3 = vld1q_f32(sptr1 + 12); + auto vec2_0 = vld1q_f32(sptr2); + auto vec2_1 = vld1q_f32(sptr2 + 4); + auto vec2_2 = vld1q_f32(sptr2 + 8); + auto vec2_3 = vld1q_f32(sptr2 + 12); + + auto vec0_01 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(vec0_0123))); + auto vec0_23 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(vec0_0123))); + + auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23))); + auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23))); + + auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0); + auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1); + auto dst_vec_2 = Vfmaq_f32(vec2_2, vec0_2, vec1_2); + auto dst_vec_3 = Vfmaq_f32(vec2_3, vec0_3, vec1_3); + + vst1q_f32(dst_ptr, dst_vec_0); + vst1q_f32(dst_ptr + 4, dst_vec_1); + vst1q_f32(dst_ptr + 8, dst_vec_2); + vst1q_f32(dst_ptr + 12, dst_vec_3); + } + + for (; i + 7 < size; i += 8, sptr0 += 8, sptr1 += 8, sptr2 += 8, dst_ptr += 8) { + auto vec0_01_u8 = vld1_u8(sptr0); + auto vec1_0 = vld1q_f32(sptr1); + auto vec1_1 = vld1q_f32(sptr1 + 4); + auto vec2_0 = vld1q_f32(sptr2); + auto vec2_1 = vld1q_f32(sptr2 + 4); + + auto vec0_01 = vreinterpretq_s16_u16(vmovl_u8(vec0_01_u8)); + + auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + + auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0); + auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1); + + vst1q_f32(dst_ptr, dst_vec_0); + vst1q_f32(dst_ptr + 4, 
dst_vec_1);
+    }
+    for (; i < size; ++i, ++sptr0, ++sptr1, ++sptr2, ++dst_ptr) {
+        *dst_ptr = (float)(*sptr0) * (*sptr1) + (*sptr2);
+    }
+}
+
+void neon_fuse_mul_add3_int16xf32xf32xf32_vec_scaler_scaler(
+        size_t size, const int16_t* src0, const float* src1, const float* src2,
+        float* dst) {
+    const int16_t* __restrict sptr0 = src0;
+    const float* __restrict sptr1 = src1;
+    const float* __restrict sptr2 = src2;
+    auto vec1 = vdupq_n_f32(sptr1[0]);
+    auto vec2 = vdupq_n_f32(sptr2[0]);
+    float* __restrict dst_ptr = dst;
+    size_t i = 0;
+    for (; i + 15 < size; i += 16, sptr0 += 16, dst_ptr += 16) {
+        auto vec0_01 = vld1q_s16(sptr0);
+        auto vec0_23 = vld1q_s16(sptr0 + 8);
+
+        auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
+        auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
+        auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
+        auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
+
+        auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1);
+        auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1);
+        auto dst_vec_2 = Vfmaq_f32(vec2, vec0_2, vec1);
+        auto dst_vec_3 = Vfmaq_f32(vec2, vec0_3, vec1);
+
+        vst1q_f32(dst_ptr, dst_vec_0);
+        vst1q_f32(dst_ptr + 4, dst_vec_1);
+        vst1q_f32(dst_ptr + 8, dst_vec_2);
+        vst1q_f32(dst_ptr + 12, dst_vec_3);
+    }
+    //! src1/src2 are broadcasted scalars: only sptr0 and dst_ptr may advance
+    //! here, so that the scalar tail below still reads the single valid
+    //! element of src1/src2.
+    for (; i + 7 < size; i += 8, sptr0 += 8, dst_ptr += 8) {
+        auto vec0_01 = vld1q_s16(sptr0);
+
+        auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
+        auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
+
+        auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1);
+        auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1);
+
+        vst1q_f32(dst_ptr, dst_vec_0);
+        vst1q_f32(dst_ptr + 4, dst_vec_1);
+    }
+    for (; i + 3 < size; i += 4, sptr0 += 4, dst_ptr += 4) {
+        auto vec0_0 = vld1_s16(sptr0);
+
+        auto vec0_0_f32 = vcvtq_f32_s32(vmovl_s16(vec0_0));
+
+        auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0_f32, vec1);
+
+        vst1q_f32(dst_ptr, dst_vec_0);
+    }
+    for (; i < size; ++i, ++sptr0, ++dst_ptr) {
+        *dst_ptr = (float)(*sptr0) * (*sptr1) + (*sptr2);
+    }
+}
+
+void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_scaler_scaler(
+        size_t size, const uint8_t* src0, const float* src1, const float* src2,
+        float* dst) {
+    const uint8_t* __restrict sptr0 = src0;
+    const float* __restrict sptr1 = src1;
+    const float* __restrict sptr2 = src2;
+    auto vec1 = vdupq_n_f32(sptr1[0]);
+    auto vec2 = vdupq_n_f32(sptr2[0]);
+    float* __restrict dst_ptr = dst;
+    size_t i = 0;
+    for (; i + 15 < size; i += 16, sptr0 += 16, dst_ptr += 16) {
+        auto vec0_0123 = vld1q_u8(sptr0);
+
+        auto vec0_01 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(vec0_0123)));
+        auto vec0_23 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(vec0_0123)));
+
+        auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
+        auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
+        auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
+        auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
+
+        auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1);
+        auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1);
+        auto dst_vec_2 = Vfmaq_f32(vec2, vec0_2, vec1);
+        auto dst_vec_3 = Vfmaq_f32(vec2, vec0_3, vec1);
+
+        vst1q_f32(dst_ptr, dst_vec_0);
+        vst1q_f32(dst_ptr + 4, dst_vec_1);
+        vst1q_f32(dst_ptr + 8, dst_vec_2);
+        vst1q_f32(dst_ptr + 12, dst_vec_3);
+    }
+
+    for (; i + 7 < size; i += 8, sptr0 += 8, dst_ptr += 8) {
+        auto vec0_01_u8 = vld1_u8(sptr0);
+
+        auto vec0_01 = vreinterpretq_s16_u16(vmovl_u8(vec0_01_u8));
+
+        auto vec0_0 =
vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + + auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1); + auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1); + + vst1q_f32(dst_ptr, dst_vec_0); + vst1q_f32(dst_ptr + 4, dst_vec_1); + } + for (; i < size; ++i, ++sptr0, ++dst_ptr) { + *dst_ptr = (float)(*sptr0) * (*sptr1) + (*sptr2); + } +} + +void neon_mul_int16xf32xf32_vec_bcast111c( + size_t batch_size, size_t channel_stride, size_t channel_size, + const int16_t* src0, const float* src1, float* dst) { + const int16_t* __restrict sptr0 = src0; + const float* __restrict sptr1 = src1; + float* __restrict dst_ptr = dst; + for (size_t batch = 0; batch < batch_size; ++batch) { + for (size_t s = 0; s < channel_stride; ++s) { + size_t i = 0; + for (; i + 15 < channel_size; i += 16, sptr0 += 16, dst_ptr += 16) { + auto vec0_01 = vld1q_s16(sptr0); + auto vec0_23 = vld1q_s16(sptr0 + 8); + auto vec1_0 = vld1q_f32(sptr1 + i); + auto vec1_1 = vld1q_f32(sptr1 + i + 4); + auto vec1_2 = vld1q_f32(sptr1 + i + 8); + auto vec1_3 = vld1q_f32(sptr1 + i + 12); + + auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23))); + auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23))); + + auto dst_vec_0 = vmulq_f32(vec0_0, vec1_0); + auto dst_vec_1 = vmulq_f32(vec0_1, vec1_1); + auto dst_vec_2 = vmulq_f32(vec0_2, vec1_2); + auto dst_vec_3 = vmulq_f32(vec0_3, vec1_3); + + vst1q_f32(dst_ptr, dst_vec_0); + vst1q_f32(dst_ptr + 4, dst_vec_1); + vst1q_f32(dst_ptr + 8, dst_vec_2); + vst1q_f32(dst_ptr + 12, dst_vec_3); + } + for (; i + 7 < channel_size; i += 8, sptr0 += 8, dst_ptr += 8) { + auto vec0_01 = vld1q_s16(sptr0); + auto vec1_0 = vld1q_f32(sptr1 + i); + auto vec1_1 = vld1q_f32(sptr1 + i + 4); + + auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + + auto dst_vec_0 = vmulq_f32(vec0_0, vec1_0); + auto dst_vec_1 = vmulq_f32(vec0_1, vec1_1); + + vst1q_f32(dst_ptr, dst_vec_0); + vst1q_f32(dst_ptr + 4, dst_vec_1); + } + for (; i < channel_size; ++i, ++sptr0, ++dst_ptr) { + *dst_ptr = (float)(*sptr0) * sptr1[i]; + } + } + } +} + +void neon_mul_int16xf32xf32_vec_bcast101( + size_t batch_size, size_t channel_size, size_t channel_stride, + const int16_t* src0, const float* src1, float* dst) { + const int16_t* __restrict sptr0 = src0; + const float* __restrict sptr1 = src1; + float* __restrict dst_ptr = dst; + for (size_t batch = 0; batch < batch_size; ++batch) { + for (size_t chan = 0; chan < channel_size; ++chan) { + auto vec1 = vdupq_n_f32(sptr1[chan]); + size_t i = 0; + for (; i + 15 < channel_stride; i += 16, sptr0 += 16, dst_ptr += 16) { + auto vec0_01 = vld1q_s16(sptr0); + auto vec0_23 = vld1q_s16(sptr0 + 8); + + auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23))); + auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23))); + + auto dst_vec_0 = vmulq_f32(vec0_0, vec1); + auto dst_vec_1 = vmulq_f32(vec0_1, vec1); + auto dst_vec_2 = vmulq_f32(vec0_2, vec1); + auto dst_vec_3 = vmulq_f32(vec0_3, vec1); + + vst1q_f32(dst_ptr, dst_vec_0); + vst1q_f32(dst_ptr + 4, dst_vec_1); + vst1q_f32(dst_ptr + 8, dst_vec_2); + vst1q_f32(dst_ptr + 12, dst_vec_3); + } + for (; i + 7 < channel_stride; i += 8, sptr0 += 8, 
dst_ptr += 8) { + auto vec0_01 = vld1q_s16(sptr0); + + auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + + auto dst_vec_0 = vmulq_f32(vec0_0, vec1); + auto dst_vec_1 = vmulq_f32(vec0_1, vec1); + + vst1q_f32(dst_ptr, dst_vec_0); + vst1q_f32(dst_ptr + 4, dst_vec_1); + } + for (; i < channel_stride; ++i, ++sptr0, ++dst_ptr) { + *dst_ptr = (float)(*sptr0) * sptr1[chan]; + } + } + } +} + +void neon_mul_int16xf32xf32_vec_vec( + size_t size, const int16_t* src0, const float* src1, float* dst) { + const int16_t* __restrict sptr0 = src0; + const float* __restrict sptr1 = src1; + float* __restrict dst_ptr = dst; + size_t i = 0; + for (; i + 15 < size; i += 16, sptr0 += 16, sptr1 += 16, dst_ptr += 16) { + auto vec0_01 = vld1q_s16(sptr0); + auto vec0_23 = vld1q_s16(sptr0 + 8); + auto vec1_0 = vld1q_f32(sptr1); + auto vec1_1 = vld1q_f32(sptr1 + 4); + auto vec1_2 = vld1q_f32(sptr1 + 8); + auto vec1_3 = vld1q_f32(sptr1 + 12); + + auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23))); + auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23))); + + auto dst_vec_0 = vmulq_f32(vec0_0, vec1_0); + auto dst_vec_1 = vmulq_f32(vec0_1, vec1_1); + auto dst_vec_2 = vmulq_f32(vec0_2, vec1_2); + auto dst_vec_3 = vmulq_f32(vec0_3, vec1_3); + + vst1q_f32(dst_ptr, dst_vec_0); + vst1q_f32(dst_ptr + 4, dst_vec_1); + vst1q_f32(dst_ptr + 8, dst_vec_2); + vst1q_f32(dst_ptr + 12, dst_vec_3); + } + + for (; i + 7 < size; i += 8, sptr0 += 8, sptr1 += 8, dst_ptr += 8) { + auto vec0_01 = vld1q_s16(sptr0); + auto vec1_0 = vld1q_f32(sptr1); + auto vec1_1 = vld1q_f32(sptr1 + 4); + + auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + + auto dst_vec_0 = vmulq_f32(vec0_0, vec1_0); + auto dst_vec_1 = vmulq_f32(vec0_1, vec1_1); + + vst1q_f32(dst_ptr, dst_vec_0); + vst1q_f32(dst_ptr + 4, dst_vec_1); + } + for (; i < size; ++i, ++sptr0, ++sptr1, ++dst_ptr) { + *dst_ptr = (float)(*sptr0) * (*sptr1); + } +} + +void neon_mul_int16xf32xf32_vec_scaler( + size_t size, const int16_t* src0, const float* src1, float* dst) { + const int16_t* __restrict sptr0 = src0; + const float* __restrict sptr1 = src1; + auto vec1 = vdupq_n_f32(sptr1[0]); + float* __restrict dst_ptr = dst; + size_t i = 0; + for (; i + 15 < size; i += 16, sptr0 += 16, dst_ptr += 16) { + auto vec0_01 = vld1q_s16(sptr0); + auto vec0_23 = vld1q_s16(sptr0 + 8); + + auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23))); + auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23))); + + auto dst_vec_0 = vmulq_f32(vec0_0, vec1); + auto dst_vec_1 = vmulq_f32(vec0_1, vec1); + auto dst_vec_2 = vmulq_f32(vec0_2, vec1); + auto dst_vec_3 = vmulq_f32(vec0_3, vec1); + + vst1q_f32(dst_ptr, dst_vec_0); + vst1q_f32(dst_ptr + 4, dst_vec_1); + vst1q_f32(dst_ptr + 8, dst_vec_2); + vst1q_f32(dst_ptr + 12, dst_vec_3); + } + + for (; i + 7 < size; i += 8, sptr0 += 8, dst_ptr += 8) { + auto vec0_01 = vld1q_s16(sptr0); + + auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01))); + auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01))); + + auto dst_vec_0 = vmulq_f32(vec0_0, vec1); + auto dst_vec_1 = vmulq_f32(vec0_1, vec1); + + vst1q_f32(dst_ptr, 
dst_vec_0);
+        vst1q_f32(dst_ptr + 4, dst_vec_1);
+    }
+    for (; i < size; ++i, ++sptr0, ++dst_ptr) {
+        *dst_ptr = (float)(*sptr0) * (*sptr1);
+    }
+}
+
+}  // namespace arm_common
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/elemwise_multi_type/kernels.h b/dnn/src/arm_common/elemwise_multi_type/kernels.h
new file mode 100644
index 000000000..636c5c340
--- /dev/null
+++ b/dnn/src/arm_common/elemwise_multi_type/kernels.h
@@ -0,0 +1,71 @@
+/**
+ * \file dnn/src/arm_common/elemwise_multi_type/kernels.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "stddef.h"
+#include "stdint.h"
+
+namespace megdnn {
+namespace arm_common {
+void neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast111c_bcast111c(
+        size_t batch_size, size_t channel_stride, size_t channel_size,
+        const int16_t* src0, const float* src1, const float* src2, float* dst);
+
+void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast111c_bcast111c(
+        size_t batch_size, size_t channel_stride, size_t channel_size,
+        const uint8_t* src0, const float* src1, const float* src2, float* dst);
+
+void neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast101_bcast101(
+        size_t batch_size, size_t channel_size, size_t channel_stride,
+        const int16_t* src0, const float* src1, const float* src2, float* dst);
+
+void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast101_bcast101(
+        size_t batch_size, size_t channel_size, size_t channel_stride,
+        const uint8_t* src0, const float* src1, const float* src2, float* dst);
+
+void neon_fuse_mul_add3_int16xf32xf32xf32_vec_vec_vec(
+        size_t size, const int16_t* src0, const float* src1, const float* src2,
+        float* dst);
+
+void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_vec_vec(
+        size_t size, const uint8_t* src0, const float* src1, const float* src2,
+        float* dst);
+
+void neon_fuse_mul_add3_int16xf32xf32xf32_vec_b1x_b1x(
+        size_t size, size_t vec, const int16_t* src0, const float* src1,
+        const float* src2, float* dst);
+
+void neon_fuse_mul_add3_int16xf32xf32xf32_vec_scaler_scaler(
+        size_t size, const int16_t* src0, const float* src1, const float* src2,
+        float* dst);
+
+void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_scaler_scaler(
+        size_t size, const uint8_t* src0, const float* src1, const float* src2,
+        float* dst);
+
+void neon_mul_int16xf32xf32_vec_bcast111c(
+        size_t batch_size, size_t channel_stride, size_t channel_size,
+        const int16_t* src0, const float* src1, float* dst);
+
+void neon_mul_int16xf32xf32_vec_bcast101(
+        size_t batch_size, size_t channel_size, size_t channel_stride,
+        const int16_t* src0, const float* src1, float* dst);
+
+void neon_mul_int16xf32xf32_vec_vec(
+        size_t size, const int16_t* src0, const float* src1, float* dst);
+
+void neon_mul_int16xf32xf32_vec_scaler(
+        size_t size, const int16_t* src0, const float* src1, float* dst);
+}  // namespace arm_common
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp b/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
index 810c20c9d..69a10cabc 100644
--- a/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
+++ b/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
@@ -11,6 +11,7 @@
  */
 
 #include "./opr_impl.h"
+#include "kernels.h"
 #include "src/common/elemwise_multi_type/kern_defs.cuh"
 #include "src/naive/handle.h"
 
@@ -851,6 +852,154 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
 #undef DISPATCH_QUANTIZED_MODE
 }
 
+void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32(
+        const ElemwiseOpParamN<3>& param, const TensorND& dst) {
+    auto &src0 = param[0], &src1 = param[1], &src2 = param[2];
+    BroadcastChannelInfo binfo;
+
+    if (is_vector(src0.layout) &&
+        is_NHWC_broadcasted_channel_like(src1.layout, binfo) &&
+        src1.layout.eq_layout(src2.layout)) {
+        // VEC_BCAST111C_BCAST111C
+        MEGDNN_DISPATCH_CPU_KERN_OPR(
+                neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast111c_bcast111c(
+                        binfo.x, binfo.y, binfo.z,
+                        static_cast<const int16_t*>(src0.raw_ptr()),
+                        static_cast<const float*>(src1.raw_ptr()),
+                        static_cast<const float*>(src2.raw_ptr()),
+                        dst.ptr<float>()));
+        return;
+    } else if (
+            is_vector(src0.layout) && is_broadcasted_channel_like(src1.layout, binfo) &&
+            src1.layout.eq_layout(src2.layout)) {
+        // VEC_BCAST101_BCAST101
+        MEGDNN_DISPATCH_CPU_KERN_OPR(
+                neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast101_bcast101(
+                        binfo.x, binfo.y, binfo.z,
+                        static_cast<const int16_t*>(src0.raw_ptr()),
+                        static_cast<const float*>(src1.raw_ptr()),
+                        static_cast<const float*>(src2.raw_ptr()),
+                        dst.ptr<float>()));
+        return;
+    } else if (
+            is_vector(src0.layout) && is_vector(src1.layout) &&
+            is_vector(src2.layout)) {
+        // VEC_VEC_VEC
+        auto size = param.size;
+        MEGDNN_DISPATCH_CPU_KERN_OPR(neon_fuse_mul_add3_int16xf32xf32xf32_vec_vec_vec(
+                size, static_cast<const int16_t*>(src0.raw_ptr()),
+                static_cast<const float*>(src1.raw_ptr()),
+                static_cast<const float*>(src2.raw_ptr()), dst.ptr<float>()));
+        return;
+    } else if (
+            is_vector(src0.layout) && is_broadcasted_scalar(src1.layout) &&
+            is_broadcasted_scalar(src2.layout)) {
+        // VEC_SCALAR_SCALAR
+        auto size = param.size;
+        MEGDNN_DISPATCH_CPU_KERN_OPR(
+                neon_fuse_mul_add3_int16xf32xf32xf32_vec_scaler_scaler(
+                        size, static_cast<const int16_t*>(src0.raw_ptr()),
+                        static_cast<const float*>(src1.raw_ptr()),
+                        static_cast<const float*>(src2.raw_ptr()),
+                        dst.ptr<float>()));
+        return;
+    }
+    naive::ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32(param, dst);
+}
+
+void ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32(
+        const ElemwiseOpParamN<3>& param, const TensorND& dst) {
+    auto &src0 = param[0], &src1 = param[1], &src2 = param[2];
+    BroadcastChannelInfo binfo;
+
+    if (is_vector(src0.layout) &&
+        is_NHWC_broadcasted_channel_like(src1.layout, binfo) &&
+        src1.layout.eq_layout(src2.layout)) {
+        // VEC_BCAST111C_BCAST111C
+        MEGDNN_DISPATCH_CPU_KERN_OPR(
+                neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast111c_bcast111c(
+                        binfo.x, binfo.y, binfo.z,
+                        static_cast<const uint8_t*>(src0.raw_ptr()),
+                        static_cast<const float*>(src1.raw_ptr()),
+                        static_cast<const float*>(src2.raw_ptr()),
+                        dst.ptr<float>()));
+        return;
+    } else if (
+            is_vector(src0.layout) && is_broadcasted_channel_like(src1.layout, binfo) &&
+            src1.layout.eq_layout(src2.layout)) {
+        // VEC_BCAST101_BCAST101
+        MEGDNN_DISPATCH_CPU_KERN_OPR(
+                neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast101_bcast101(
+                        binfo.x, binfo.y, binfo.z,
+                        static_cast<const uint8_t*>(src0.raw_ptr()),
+                        static_cast<const float*>(src1.raw_ptr()),
+                        static_cast<const float*>(src2.raw_ptr()),
+                        dst.ptr<float>()));
+        return;
+    } else if (
+            is_vector(src0.layout) && is_vector(src1.layout) &&
+            is_vector(src2.layout)) {
+        // VEC_VEC_VEC
+        auto size = param.size;
+        MEGDNN_DISPATCH_CPU_KERN_OPR(neon_fuse_mul_add3_uint8xf32xf32xf32_vec_vec_vec(
+                size, static_cast<const uint8_t*>(src0.raw_ptr()),
+                static_cast<const float*>(src1.raw_ptr()),
+                static_cast<const float*>(src2.raw_ptr()), dst.ptr<float>()));
+        return;
+    } else if (
+            is_vector(src0.layout) && is_broadcasted_scalar(src1.layout) &&
+            is_broadcasted_scalar(src2.layout)) {
+        // VEC_SCALAR_SCALAR
+        auto size = param.size;
+        MEGDNN_DISPATCH_CPU_KERN_OPR(
+                neon_fuse_mul_add3_uint8xf32xf32xf32_vec_scaler_scaler(
+                        size, static_cast<const uint8_t*>(src0.raw_ptr()),
+                        static_cast<const float*>(src1.raw_ptr()),
+                        static_cast<const float*>(src2.raw_ptr()),
+                        dst.ptr<float>()));
+        return;
+    }
+
+    naive::ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32(param, dst);
+}
+
+void ElemwiseMultiTypeImpl::on_mul_int16xf32xf32(
+        const ElemwiseOpParamN<2>& param, const TensorND& dst) {
+    auto &src0 = param[0], &src1 = param[1];
+    BroadcastChannelInfo binfo;
+
+    if (is_vector(src0.layout) &&
+        is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
+        // VEC_BCAST111C
+        MEGDNN_DISPATCH_CPU_KERN_OPR(neon_mul_int16xf32xf32_vec_bcast111c(
+                binfo.x, binfo.y, binfo.z,
+                static_cast<const int16_t*>(src0.raw_ptr()),
+                static_cast<const float*>(src1.raw_ptr()), dst.ptr<float>()));
+        return;
+    } else if (
+            is_vector(src0.layout) && is_broadcasted_channel_like(src1.layout, binfo)) {
+        // VEC_BCAST101
+        MEGDNN_DISPATCH_CPU_KERN_OPR(neon_mul_int16xf32xf32_vec_bcast101(
+                binfo.x, binfo.y, binfo.z,
+                static_cast<const int16_t*>(src0.raw_ptr()),
+                static_cast<const float*>(src1.raw_ptr()), dst.ptr<float>()));
+        return;
+    } else if (is_vector(src0.layout) && is_vector(src1.layout)) {
+        // VEC_VEC
+        auto size = param.size;
+        MEGDNN_DISPATCH_CPU_KERN_OPR(neon_mul_int16xf32xf32_vec_vec(
+                size, static_cast<const int16_t*>(src0.raw_ptr()),
+                static_cast<const float*>(src1.raw_ptr()), dst.ptr<float>()));
+        return;
+    } else if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout)) {
+        auto size = param.size;
+        MEGDNN_DISPATCH_CPU_KERN_OPR(neon_mul_int16xf32xf32_vec_scaler(
+                size, static_cast<const int16_t*>(src0.raw_ptr()),
+                static_cast<const float*>(src1.raw_ptr()), dst.ptr<float>()));
+        return;
+    }
+
+    naive::ElemwiseMultiTypeImpl::on_mul_int16xf32xf32(param, dst);
+}
+
 }  // namespace arm_common
 }  // namespace megdnn
diff --git a/dnn/src/arm_common/elemwise_multi_type/opr_impl.h b/dnn/src/arm_common/elemwise_multi_type/opr_impl.h
index deb348838..c32535291 100644
--- a/dnn/src/arm_common/elemwise_multi_type/opr_impl.h
+++ b/dnn/src/arm_common/elemwise_multi_type/opr_impl.h
@@ -48,6 +48,15 @@ protected:
             const ElemwiseOpParamN<3>& param, const TensorND& dst,
             Elemwise::Mode mode) override;
 
+    void on_fuse_mul_add3_int16xf32xf32xf32(
+            const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
+
+    void on_mul_int16xf32xf32(
+            const ElemwiseOpParamN<2>& param, const TensorND& dst) override;
+
+    void on_fuse_mul_add3_uint8xf32xf32xf32(
+            const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
+
 public:
     using fallback::ElemwiseMultiTypeImpl::ElemwiseMultiTypeImpl;
 };
diff --git a/dnn/src/common/elemwise_multi_type/opr_impl.cpp b/dnn/src/common/elemwise_multi_type/opr_impl.cpp
index 164aa5b14..ec284d7de 100644
--- a/dnn/src/common/elemwise_multi_type/opr_impl.cpp
+++ b/dnn/src/common/elemwise_multi_type/opr_impl.cpp
@@ -155,6 +155,29 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
         dst.name = name;
         dst.need_specify_out_dtype = true;
     };
+    auto init_fma3_int16xf32xf32xf32 = [&](ModeTrait& dst, const char* name) {
+        dst.arity = 3;
+        dst.check_inp[0] = make_check_dtype_func(dtype::Int16());
+        dst.check_inp[1] = make_check_dtype_func(dtype::Float32());
+        dst.check_inp[2] = make_check_dtype_func(dtype::Float32());
+        dst.check_out = make_out_dtype_func(dtype::Float32());
+        dst.name = name;
+    };
+    auto init_mul_int16xf32xf32 = [&](ModeTrait& dst, const char* name) {
+        dst.arity = 2;
+        dst.check_inp[0] = make_check_dtype_func(dtype::Int16());
+        dst.check_inp[1] = make_check_dtype_func(dtype::Float32());
+        dst.check_out = make_out_dtype_func(dtype::Float32());
+        dst.name = name;
+    };
+    auto init_fma3_uint8xf32xf32xf32 = [&](ModeTrait& dst, const char* name) {
+        dst.arity = 3;
+        dst.check_inp[0] = make_check_dtype_func(dtype::Uint8());
+        dst.check_inp[1] = make_check_dtype_func(dtype::Float32());
+        dst.check_inp[2] = make_check_dtype_func(dtype::Float32());
+        dst.check_out = make_out_dtype_func(dtype::Float32());
+        dst.name = name;
+    };
 
 #define SET(f, m)                                                          \
     MIDOUT_BEGIN(megdnn_common_elemwise_multi_type, midout_iv(Mode::m)) { \
@@ -169,6 +192,9 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
     SET(init_fuse_add_rmulh_rshr_int32x32x32x8,
         FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8);
     SET(init_rshrs_iXxi8xi16, ROUND_SHR_SATURATE_IXxI8xI16);
+    SET(init_fma3_int16xf32xf32xf32, FUSE_MUL_ADD3_INT16xF32xF32xF32);
+    SET(init_mul_int16xf32xf32, MUL_INT16xF32xF32);
+    SET(init_fma3_uint8xf32xf32xf32, FUSE_MUL_ADD3_UINT8xF32xF32xF32);
 
     //! quantized opr, with specified dtype.
     //! dispatch elemwise mode internally
diff --git a/dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp b/dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp
index c4a32d7d5..d938a79f4 100644
--- a/dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp
+++ b/dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp
@@ -43,6 +43,17 @@ void ElemwiseMultiTypeImplHelper::exec(
         case Mode::ROUND_SHR_SATURATE_IXxI8xI16:
             on_round_shr_saturate_iXxi8xi16(make_elemwise_op_param<2>(src, dst), dst);
             break;
+        case Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32:
+            on_fuse_mul_add3_int16xf32xf32xf32(
+                    make_elemwise_op_param<3>(src, dst), dst);
+            break;
+        case Mode::MUL_INT16xF32xF32:
+            on_mul_int16xf32xf32(make_elemwise_op_param<2>(src, dst), dst);
+            break;
+        case Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32:
+            on_fuse_mul_add3_uint8xf32xf32xf32(
+                    make_elemwise_op_param<3>(src, dst), dst);
+            break;
         ON_QUANTIZED_MODE(RELU, 1);
         ON_QUANTIZED_MODE(ABS, 1);
         ON_QUANTIZED_MODE(ACOS, 1);
diff --git a/dnn/src/common/elemwise_multi_type/opr_impl_helper.h b/dnn/src/common/elemwise_multi_type/opr_impl_helper.h
index 72e163b8f..d50010416 100644
--- a/dnn/src/common/elemwise_multi_type/opr_impl_helper.h
+++ b/dnn/src/common/elemwise_multi_type/opr_impl_helper.h
@@ -50,6 +50,27 @@ protected:
     virtual void on_round_shr_saturate_iXxi8xi16(
             const ElemwiseOpParamN<2>& param, const TensorND& dst) = 0;
 
+    virtual void on_fuse_mul_add3_int16xf32xf32xf32(
+            const ElemwiseOpParamN<3>& param, const TensorND& dst) {
+        MEGDNN_MARK_USED_VAR(param);
+        MEGDNN_MARK_USED_VAR(dst);
+        megdnn_throw("unsupported ElemwiseMultiType fma3 int16xf32xf32xf32.");
+    }
+
+    virtual void on_mul_int16xf32xf32(
+            const ElemwiseOpParamN<2>& param, const TensorND& dst) {
+        MEGDNN_MARK_USED_VAR(param);
+        MEGDNN_MARK_USED_VAR(dst);
+        megdnn_throw("unsupported ElemwiseMultiType mul int16xf32xf32.");
+    }
+
+    virtual void on_fuse_mul_add3_uint8xf32xf32xf32(
+            const ElemwiseOpParamN<3>& param, const TensorND& dst) {
+        MEGDNN_MARK_USED_VAR(param);
+        MEGDNN_MARK_USED_VAR(dst);
+        megdnn_throw("unsupported ElemwiseMultiType fma3 uint8xf32xf32xf32.");
+    }
+
     virtual void on_quantized_mode(
             const ElemwiseOpParamN<1>& param, const TensorND& dst,
             Elemwise::Mode mode) {
diff --git a/dnn/src/fallback/elemwise_multi_type/opr_impl.cpp b/dnn/src/fallback/elemwise_multi_type/opr_impl.cpp
index 95dbbc223..80eacbcd0 100644
--- a/dnn/src/fallback/elemwise_multi_type/opr_impl.cpp
+++ b/dnn/src/fallback/elemwise_multi_type/opr_impl.cpp
@@ -56,6 +56,216 @@ void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32(
     naive::ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32(param, dst);
 }
 
+void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32(
+        const ElemwiseOpParamN<3>& param, const TensorND& dst) {
+    BroadcastChannelInfo binfo0, binfo1;
+
+    if (is_vector(param[0].layout) &&
+        is_NHWC_broadcasted_channel_like(param[1].layout, binfo0) &&
+        is_NHWC_broadcasted_channel_like(param[2].layout, binfo1) && binfo0 == binfo1) {
+        auto x = binfo0.x, y = binfo0.y, z = binfo0.z;
+        auto src0 = param[0];
+        auto src1 = param[1];
+        auto src2 = param[2];
+        auto work = [=]() {
+            const dt_int16* __restrict__ a =
+                    static_cast<const dt_int16*>(src0.raw_ptr());
+            const dt_float32* __restrict__ b =
+                    static_cast<const dt_float32*>(src1.raw_ptr());
+            const dt_float32* __restrict__ c =
+                    static_cast<const dt_float32*>(src2.raw_ptr());
+            dt_float32* __restrict__ d = dst.ptr<dt_float32>();
+            for (size_t i = 0; i < x; ++i) {
+                for (size_t j = 0; j < y; ++j) {
+                    auto off = i * (y * z) + j * z;
+                    size_t k = 0;
+                    for (; k + 4 <= z; k += 4) {
+                        d[off + k + 0] = a[off + k + 0] * b[k + 0] + c[k + 0];
+                        d[off + k + 1] = a[off + k + 1] * b[k + 1] + c[k + 1];
+                        d[off + k + 2] = a[off + k + 2] * b[k + 2] + c[k + 2];
+                        d[off + k + 3] = a[off + k + 3] * b[k + 3] + c[k + 3];
+                    }
+                    for (; k < z; ++k) {
+                        d[off + k] = a[off + k] * b[k] + c[k];
+                    }
+                }
+            }
+        };
+
+        MEGDNN_DISPATCH_CPU_KERN_OPR(work());
+        return;
+    } else if (
+            is_vector(param[0].layout) &&
+            is_broadcasted_channel_like(param[1].layout, binfo0) &&
+            is_broadcasted_channel_like(param[2].layout, binfo1) && binfo0 == binfo1) {
+        auto x = binfo0.x, y = binfo0.y, z = binfo0.z;
+        auto src0 = param[0];
+        auto src1 = param[1];
+        auto src2 = param[2];
+        auto work = [=]() {
+            const dt_int16* __restrict__ a =
+                    static_cast<const dt_int16*>(src0.raw_ptr());
+            const dt_float32* __restrict__ b =
+                    static_cast<const dt_float32*>(src1.raw_ptr());
+            const dt_float32* __restrict__ c =
+                    static_cast<const dt_float32*>(src2.raw_ptr());
+            dt_float32* __restrict__ d = dst.ptr<dt_float32>();
+            for (size_t j = 0; j < y; ++j) {
+                auto bv = b[j], cv = c[j];
+                for (size_t i = 0; i < x; ++i) {
+                    auto off = i * (y * z) + j * z, offt = off + z;
+                    for (; off + 4 <= offt; off += 4) {
+                        d[off + 0] = a[off + 0] * bv + cv;
+                        d[off + 1] = a[off + 1] * bv + cv;
+                        d[off + 2] = a[off + 2] * bv + cv;
+                        d[off + 3] = a[off + 3] * bv + cv;
+                    }
+                    for (; off < offt; ++off) {
+                        d[off] = a[off] * bv + cv;
+                    }
+                }
+            }
+        };
+
+        MEGDNN_DISPATCH_CPU_KERN_OPR(work());
+        return;
+    }
+
+    naive::ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32(param, dst);
+}
+
+void ElemwiseMultiTypeImpl::on_mul_int16xf32xf32(
+        const ElemwiseOpParamN<2>& param, const TensorND& dst) {
+    BroadcastChannelInfo binfo;
+
+    if (is_vector(param[0].layout) &&
+        is_NHWC_broadcasted_channel_like(param[1].layout, binfo)) {
+        auto x = binfo.x, y = binfo.y, z = binfo.z;
+        auto src0 = param[0];
+        auto src1 = param[1];
+        auto work = [=]() {
+            const dt_int16* __restrict__ a =
+                    static_cast<const dt_int16*>(src0.raw_ptr());
+            const dt_float32* __restrict__ b =
+                    static_cast<const dt_float32*>(src1.raw_ptr());
+            dt_float32* __restrict__ d = dst.ptr<dt_float32>();
+            for (size_t i = 0; i < x; ++i) {
+                for (size_t j = 0; j < y; ++j) {
+                    auto off = i * (y * z) + j * z;
+                    size_t k = 0;
+                    for (; k + 4 <= z; k += 4) {
+                        d[off + k + 0] = a[off + k + 0] * b[k + 0];
+                        d[off + k + 1] = a[off + k + 1] * b[k + 1];
+                        d[off + k + 2] = a[off + k + 2] * b[k + 2];
+                        d[off + k + 3] = a[off + k + 3] * b[k + 3];
+                    }
+                    for (; k < z; ++k) {
+                        d[off + k] = a[off + k] * b[k];
+                    }
+                }
+            }
+        };
+
+        MEGDNN_DISPATCH_CPU_KERN_OPR(work());
+        return;
+    } else if (
+            is_vector(param[0].layout) &&
+            is_broadcasted_channel_like(param[1].layout, binfo)) {
+        auto x = binfo.x, y = binfo.y, z = binfo.z;
+        auto src0 = param[0];
+        auto src1 = param[1];
+        auto work = [=]() {
+            const dt_int16* __restrict__ a =
+                    static_cast<const dt_int16*>(src0.raw_ptr());
+            const dt_float32* __restrict__ b =
+                    static_cast<const dt_float32*>(src1.raw_ptr());
+            dt_float32* __restrict__ d = dst.ptr<dt_float32>();
+            for (size_t j = 0; j < y; ++j) {
+                auto bv = b[j];
+                for (size_t i = 0; i < x; ++i) {
+                    auto off = i * (y * z) + j * z, offt = off + z;
+                    for (; off + 4 <= offt; off += 4) {
+                        d[off + 0] = a[off + 0] * bv;
+                        d[off + 1] = a[off + 1] * bv;
+                        d[off + 2] = a[off + 2] * bv;
+                        d[off + 3] = a[off + 3] * bv;
+                    }
+                    for (; off < offt; ++off) {
+                        d[off] = a[off] * bv;
+                    }
+                }
+            }
+        };
+
+        MEGDNN_DISPATCH_CPU_KERN_OPR(work());
+        return;
+    }
+
+    naive::ElemwiseMultiTypeImpl::on_mul_int16xf32xf32(param, dst);
+}
+
+void ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32(
+        const ElemwiseOpParamN<3>& param, const TensorND& dst) {
+    BroadcastChannelInfo binfo0, binfo1;
+
+    if (is_vector(param[0].layout) &&
+        is_NHWC_broadcasted_channel_like(param[1].layout, binfo0) &&
+        is_NHWC_broadcasted_channel_like(param[2].layout, binfo1) && binfo0 == binfo1) {
+        auto x = binfo0.x, y = binfo0.y, z = binfo0.z;
+        auto src0 = param[0];
+        auto src1 = param[1];
+        auto src2 = param[2];
+        auto work = [=]() {
+            const dt_uint8* __restrict__ a =
+                    static_cast<const dt_uint8*>(src0.raw_ptr());
+            const dt_float32* __restrict__ b =
+                    static_cast<const dt_float32*>(src1.raw_ptr());
+            const dt_float32* __restrict__ c =
+                    static_cast<const dt_float32*>(src2.raw_ptr());
+            dt_float32* __restrict__ d = dst.ptr<dt_float32>();
+            for (size_t i = 0; i < x; ++i) {
+                for (size_t j = 0; j < y; ++j) {
+                    auto off = i * (y * z) + j * z;
+                    size_t k = 0;
+                    for (; k + 4 <= z; k += 4) {
+                        d[off + k + 0] = a[off + k + 0] * b[k + 0] + c[k + 0];
+                        d[off + k + 1] = a[off + k + 1] * b[k + 1] + c[k + 1];
+                        d[off + k + 2] = a[off + k + 2] * b[k + 2] + c[k + 2];
+                        d[off + k + 3] = a[off + k + 3] * b[k + 3] + c[k + 3];
+                    }
+                    for (; k < z; ++k) {
+                        d[off + k] = a[off + k] * b[k] + c[k];
+                    }
+                }
+            }
+        };
+
+        MEGDNN_DISPATCH_CPU_KERN_OPR(work());
+        return;
+    } else if (
+            is_vector(param[0].layout) &&
+            is_broadcasted_channel_like(param[1].layout, binfo0) &&
+            is_broadcasted_channel_like(param[2].layout, binfo1) && binfo0 == binfo1) {
+        auto x = binfo0.x, y = binfo0.y, z = binfo0.z;
+        auto src0 = param[0];
+        auto src1 = param[1];
+        auto src2 = param[2];
+        auto work = [=]() {
+            const dt_uint8* __restrict__ a =
+                    static_cast<const dt_uint8*>(src0.raw_ptr());
+            const dt_float32* __restrict__ b =
+                    static_cast<const dt_float32*>(src1.raw_ptr());
+            const dt_float32* __restrict__ c =
+                    static_cast<const dt_float32*>(src2.raw_ptr());
+            dt_float32* __restrict__ d = dst.ptr<dt_float32>();
+            for (size_t j = 0; j < y; ++j) {
+                auto bv = b[j], cv = c[j];
+                for (size_t i = 0; i < x; ++i) {
+                    auto off = i * (y * z) + j * z, offt = off + z;
+                    for (; off + 4 <= offt; off += 4) {
+                        d[off + 0] = a[off + 0] * bv + cv;
+                        d[off + 1] = a[off + 1] * bv + cv;
+                        d[off + 2] = a[off + 2] * bv + cv;
+                        d[off + 3] = a[off + 3] * bv + cv;
+                    }
+                    for (; off < offt; ++off) {
+                        d[off] = a[off] * bv + cv;
+                    }
+                }
+            }
+        };
+
+        MEGDNN_DISPATCH_CPU_KERN_OPR(work());
+        return;
+    }
+
+    naive::ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32(param, dst);
+}
+
 template <typename ctype>
 void ElemwiseMultiTypeImpl::dispatch_fma3_iXxf32xf32xi8_bcast_1x(
         const ElemwiseOpParamN<3>& param, const Broadcast1xInfo& binfo,
diff --git a/dnn/src/fallback/elemwise_multi_type/opr_impl.h b/dnn/src/fallback/elemwise_multi_type/opr_impl.h
index 91df43fed..9d6035200 100644
--- a/dnn/src/fallback/elemwise_multi_type/opr_impl.h
+++ b/dnn/src/fallback/elemwise_multi_type/opr_impl.h
@@ -43,6 +43,12 @@ protected:
             const ElemwiseOpParamN<6>& param, const TensorND& dst) override;
     void on_round_shr_saturate_iXxi8xi16(
             const ElemwiseOpParamN<2>& param, const TensorND& dst) override;
+    void on_fuse_mul_add3_int16xf32xf32xf32(
+            const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
+    void on_mul_int16xf32xf32(
+            const ElemwiseOpParamN<2>& param, const TensorND& dst) override;
+    void on_fuse_mul_add3_uint8xf32xf32xf32(
+            const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
 
 public:
     using naive::ElemwiseMultiTypeImpl::ElemwiseMultiTypeImpl;
diff --git a/dnn/src/naive/elemwise_multi_type/opr_impl.cpp b/dnn/src/naive/elemwise_multi_type/opr_impl.cpp
index 18769e0a9..502cef74e 100644
--- a/dnn/src/naive/elemwise_multi_type/opr_impl.cpp
+++ b/dnn/src/naive/elemwise_multi_type/opr_impl.cpp
@@ -39,6 +39,66 @@ void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32(
     MEGDNN_DISPATCH_CPU_KERN_OPR(work());
 }
 
+void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32(
+        const ElemwiseOpParamN<3>& param, const TensorND& dst) {
+    auto size = param.size;
+    auto src0 = param[0];
+    auto src1 = param[1];
+    auto src2 = param[2];
+    auto work = [src0, src1, src2, size, dst]() {
+        auto i0 = tensor_iter_valonly<dt_int16>(src0).begin();
+        auto i1 = tensor_iter_valonly<dt_float32>(src1).begin();
+        auto i2 = tensor_iter_valonly<dt_float32>(src2).begin();
+        auto dst_ptr = dst.ptr<dt_float32>();
+        for (size_t i = 0; i < size; ++i) {
+            dst_ptr[i] = (*i0) * (*i1) + (*i2);
+            ++i0;
+            ++i1;
+            ++i2;
+        }
+    };
+    MEGDNN_DISPATCH_CPU_KERN_OPR(work());
+}
+
+void ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32(
+        const ElemwiseOpParamN<3>& param, const TensorND& dst) {
+    auto size = param.size;
+    auto src0 = param[0];
+    auto src1 = param[1];
+    auto src2 = param[2];
+    auto work = [src0, src1, src2, size, dst]() {
+        auto i0 = tensor_iter_valonly<dt_uint8>(src0).begin();
+        auto i1 = tensor_iter_valonly<dt_float32>(src1).begin();
+        auto i2 = tensor_iter_valonly<dt_float32>(src2).begin();
+        auto dst_ptr = dst.ptr<dt_float32>();
+        for (size_t i = 0; i < size; ++i) {
+            dst_ptr[i] = (*i0) * (*i1) + (*i2);
+            ++i0;
+            ++i1;
+            ++i2;
+        }
+    };
+    MEGDNN_DISPATCH_CPU_KERN_OPR(work());
+}
+
+void ElemwiseMultiTypeImpl::on_mul_int16xf32xf32(
+        const ElemwiseOpParamN<2>& param, const TensorND& dst) {
+    auto size = param.size;
+    auto src0 = param[0];
+    auto src1 = param[1];
+    auto work = [src0, src1, size, dst]() {
+        auto i0 = tensor_iter_valonly<dt_int16>(src0).begin();
+        auto i1 = tensor_iter_valonly<dt_float32>(src1).begin();
+        auto dst_ptr = dst.ptr<dt_float32>();
+        for (size_t i = 0; i < size; ++i) {
+            dst_ptr[i] = (*i0) * (*i1);
+            ++i0;
+            ++i1;
+        }
+    };
+    MEGDNN_DISPATCH_CPU_KERN_OPR(work());
+}
+
 void ElemwiseMultiTypeImpl::on_fuse_mul_add3_iXxf32xf32xi8(
         const ElemwiseOpParamN<3>& param, const TensorND& dst) {
     switch (param[0].layout.dtype.enumv()) {
diff --git a/dnn/src/naive/elemwise_multi_type/opr_impl.h b/dnn/src/naive/elemwise_multi_type/opr_impl.h
index 60dcaed19..ad81568a7 100644
--- a/dnn/src/naive/elemwise_multi_type/opr_impl.h
+++ b/dnn/src/naive/elemwise_multi_type/opr_impl.h
@@ -60,6 +60,12 @@ protected:
             const ElemwiseOpParamN<6>& param, const TensorND& dst) override;
     void on_round_shr_saturate_iXxi8xi16(
             const ElemwiseOpParamN<2>& param, const TensorND& dst) override;
+    void on_fuse_mul_add3_int16xf32xf32xf32(
+            const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
+    void on_mul_int16xf32xf32(
+            const ElemwiseOpParamN<2>& param, const TensorND& dst) override;
+    void on_fuse_mul_add3_uint8xf32xf32xf32(
+            const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
 
     void on_quantized_mode(
             const ElemwiseOpParamN<1>& param, const TensorND& dst,
diff --git a/dnn/test/arm_common/elemwise_multi_type.cpp b/dnn/test/arm_common/elemwise_multi_type.cpp
index b5c2a71c8..7e8e5b3a9 100644
--- a/dnn/test/arm_common/elemwise_multi_type.cpp
+++ b/dnn/test/arm_common/elemwise_multi_type.cpp
@@ -456,4 +456,107 @@ TEST_F(ARM_COMMON, ELEMWISE_QUANTIZED_MODE_TERNARY_RECORD) {
     }
 }
 
+TEST_F(ARM_COMMON, ELEMWISE_FMA3_INT16xF32xF32xF32) {
+    Checker<ElemwiseMultiType> checker(handle());
+    checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32});
+    checker.set_dtype(0, dtype::Int16());
+    checker.set_dtype(1, dtype::Float32());
+    checker.set_dtype(2, dtype::Float32());
+    UniformIntRNG rng{-100, 100};
+    checker.set_rng(0, &rng);
+    checker.set_rng(1, &rng);
+    checker.set_rng(2, &rng);
+    checker.execs({{5, 7, 16}, {1, 1, 16}, {1, 1, 16}, {}})
+            .execs({{2, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}})
+            .execs({{2, 700, 600}, {2, 700, 600}, {2, 700, 600}, {}})
+            .execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}})
+            .execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}});
+}
+
+TEST_F(ARM_COMMON, ELEMWISE_FMA3_INT16xF32xF32xF32_RECORD) {
+    TaskRecordChecker<ElemwiseMultiType> checker(0);
+    checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32});
+    checker.set_dtype(0, dtype::Int16());
+    checker.set_dtype(1, dtype::Float32());
+    checker.set_dtype(2, dtype::Float32());
+    UniformIntRNG rng{-100, 100};
+    checker.set_rng(0, &rng);
+    checker.set_rng(1, &rng);
+    checker.set_rng(2, &rng);
+    checker.execs({{5, 7, 16}, {1, 1, 16}, {1, 1, 16}, {}})
+            .execs({{2, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}})
+            .execs({{2, 700, 600}, {2, 700, 600}, {2, 700, 600}, {}})
+            .execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}})
+            .execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}})
+            .execs({{16, 128, 16, 18}, {1, 1, 1, 18}, {1, 1, 1, 18}, {}})
+            .execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}});
+}
+
+TEST_F(ARM_COMMON, ELEMWISE_MUL_INT16xF32xF32) {
+    Checker<ElemwiseMultiType> checker(handle());
+    checker.set_param({ElemwiseMultiType::Mode::MUL_INT16xF32xF32});
+    checker.set_dtype(0, dtype::Int16());
+    checker.set_dtype(1, dtype::Float32());
+    UniformIntRNG rng{-100, 100};
+    checker.set_rng(0, &rng);
+    checker.set_rng(1, &rng);
+    checker.execs({{5, 7, 16}, {1, 1, 16}, {}})
+            .execs({{2, 700, 600}, {1, 1, 600}, {}})
+            .execs({{2, 700, 600}, {2, 700, 600}, {}})
+            .execs({{16, 16, 128}, {16, 16, 128}, {}})
+            .execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {}});
+}
+
+TEST_F(ARM_COMMON, ELEMWISE_MUL_INT16xF32xF32_RECORD) {
+    TaskRecordChecker<ElemwiseMultiType> checker(0);
+    checker.set_param({ElemwiseMultiType::Mode::MUL_INT16xF32xF32});
+    checker.set_dtype(0, dtype::Int16());
+    checker.set_dtype(1, dtype::Float32());
+    UniformIntRNG rng{-100, 100};
+    checker.set_rng(0, &rng);
+    checker.set_rng(1, &rng);
+    checker.execs({{5, 7, 16}, {1, 1, 16}, {}})
+            .execs({{2, 700, 600}, {1, 1, 600}, {}})
+            .execs({{2, 700, 600}, {2, 700, 600}, {}})
+            .execs({{16, 16, 128}, {16, 16, 128}, {}})
+            .execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {}})
+            .execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {}});
+}
+
+TEST_F(ARM_COMMON, ELEMWISE_FMA3_UINT8xF32xF32xF32) {
+    Checker<ElemwiseMultiType> checker(handle());
+    checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32});
+    checker.set_dtype(0, dtype::Uint8());
+    checker.set_dtype(1, dtype::Float32());
+    checker.set_dtype(2, dtype::Float32());
+    UniformIntRNG rng{-100, 100};
+    checker.set_rng(0, &rng);
+    checker.set_rng(1, &rng);
+    checker.set_rng(2, &rng);
+    checker.execs({{5, 7, 16}, {1, 1, 16}, {1, 1, 16}, {}})
+            .execs({{2, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}})
+            .execs({{2, 700, 600}, {2, 700, 600}, {2, 700, 600}, {}})
+            .execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}})
+            .execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}})
+            .execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}});
+}
+
+TEST_F(ARM_COMMON, ELEMWISE_FMA3_UINT8xF32xF32xF32_RECORD) {
+    TaskRecordChecker<ElemwiseMultiType> checker(0);
+    checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32});
+    checker.set_dtype(0, dtype::Uint8());
+    checker.set_dtype(1, dtype::Float32());
+    checker.set_dtype(2, dtype::Float32());
+    UniformIntRNG rng{-100, 100};
+    checker.set_rng(0, &rng);
+    checker.set_rng(1, &rng);
+    checker.set_rng(2, &rng);
+    checker.execs({{5, 7, 16}, {1, 1, 16}, {1, 1, 16}, {}})
+            .execs({{2, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}})
+            .execs({{2, 700, 600}, {2, 700, 600}, {2, 700, 600}, {}})
+            .execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}})
+            .execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}})
+            .execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}});
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/dnn/test/common/elemwise_multi_type.cpp b/dnn/test/common/elemwise_multi_type.cpp
index 5263c535e..7c4844bd7 100644
--- a/dnn/test/common/elemwise_multi_type.cpp
+++ b/dnn/test/common/elemwise_multi_type.cpp
@@ -79,6 +79,73 @@ DEF_TEST(fuse_mul_add3_int16x32x32x32) {
             .execs({{102, 67, 71}, {1, 67, 1}, {1, 67, 1}, {}});
 }
 
+DEF_TEST(fuse_mul_add3_int16xf32xf32xf32) {
+    // This is not implemented on CUDA.
+    if (handle->type() == Handle::HandleType::CUDA) {
+        return;
+    }
+    Checker<ElemwiseMultiType> checker(handle);
+    checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32});
+    checker.set_dtype(0, dtype::Int16());
+    checker.set_dtype(1, dtype::Float32());
+    checker.set_dtype(2, dtype::Float32());
+    UniformIntRNG rng{-100, 100};
+    checker.set_rng(0, &rng);
+    checker.set_rng(1, &rng);
+    checker.set_rng(2, &rng);
+    checker.execs({{5, 7, 6}, {1, 1, 6}, {1, 1, 6}, {}})
+            .execs({{1, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}})
+            .execs({{1, 700, 600}, {1, 700, 600}, {1, 700, 600}, {}})
+            .execs({{102, 71, 67}, {1, 1, 67}, {1, 1, 67}, {}})
+            .execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}})
+            .execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}})
+            .execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}});
+}
+
+DEF_TEST(fuse_mul_add3_uint8xf32xf32xf32) {
+    // This is not implemented on CUDA.
+    if (handle->type() == Handle::HandleType::CUDA) {
+        return;
+    }
+    Checker<ElemwiseMultiType> checker(handle);
+    checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32});
+    checker.set_dtype(0, dtype::Uint8());
+    checker.set_dtype(1, dtype::Float32());
+    checker.set_dtype(2, dtype::Float32());
+    UniformIntRNG rng{-100, 100};
+    checker.set_rng(0, &rng);
+    checker.set_rng(1, &rng);
+    checker.set_rng(2, &rng);
+    checker.execs({{5, 7, 6}, {1, 1, 6}, {1, 1, 6}, {}})
+            .execs({{1, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}})
+            .execs({{1, 700, 600}, {1, 700, 600}, {1, 700, 600}, {}})
+            .execs({{102, 71, 67}, {1, 1, 67}, {1, 1, 67}, {}})
+            .execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}})
+            .execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}})
+            .execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}});
+}
+
+DEF_TEST(fuse_mul_add3_int16xf32xf32) {
+    // This is not implemented on CUDA.
+    if (handle->type() == Handle::HandleType::CUDA) {
+        return;
+    }
+    Checker<ElemwiseMultiType> checker(handle);
+    checker.set_param({ElemwiseMultiType::Mode::MUL_INT16xF32xF32});
+    checker.set_dtype(0, dtype::Int16());
+    checker.set_dtype(1, dtype::Float32());
+    UniformIntRNG rng{-100, 100};
+    checker.set_rng(0, &rng);
+    checker.set_rng(1, &rng);
+    checker.execs({{5, 7, 6}, {1, 1, 6}, {}})
+            .execs({{1, 700, 600}, {1, 1, 600}, {}})
+            .execs({{1, 700, 600}, {1, 700, 600}, {}})
+            .execs({{102, 71, 67}, {1, 1, 67}, {}})
+            .execs({{16, 16, 128}, {16, 16, 128}, {}})
+            .execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {}})
+            .execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {}});
+}
+
 DEF_TEST(fuse_mul_add3_iXxf32xf32xi8) {
     Checker<ElemwiseMultiType> checker(handle);
     checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_IXxF32xF32xI8});
diff --git a/dnn/test/common/elemwise_multi_type.h b/dnn/test/common/elemwise_multi_type.h
index f5374506c..632ce5fd0 100644
--- a/dnn/test/common/elemwise_multi_type.h
+++ b/dnn/test/common/elemwise_multi_type.h
@@ -20,10 +20,13 @@ namespace test {
 namespace elemwise_multi_type {
 
 #define FIRST_ELEMWISE_MULTI_TYPE_CASE fuse_mul_add3_int16x32x32x32
 
-#define FOREACH_ELEMWISE_MULTI_TYPE_NONFIRST_CASE(cb)               \
-    cb(fuse_mul_add3_iXxf32xf32xi8) cb(round_shr_saturate_iXxi8xi8) \
-            cb(fuse_add_rmulh_round_shr_saturate_int16)             \
-                    cb(fuse_add_rmulh_round_shr_saturate_int32)
+#define FOREACH_ELEMWISE_MULTI_TYPE_NONFIRST_CASE(cb)               \
+    cb(fuse_mul_add3_iXxf32xf32xi8) cb(round_shr_saturate_iXxi8xi8) \
+            cb(fuse_add_rmulh_round_shr_saturate_int16)             \
+                    cb(fuse_add_rmulh_round_shr_saturate_int32)     \
+                            cb(fuse_mul_add3_int16xf32xf32xf32)     \
+                                    cb(fuse_mul_add3_uint8xf32xf32xf32) \
+                                            cb(fuse_mul_add3_int16xf32xf32)
 
 #define FOREACH_ELEMWISE_MULTI_TYPE_CASE(cb) \
     cb(FIRST_ELEMWISE_MULTI_TYPE_CASE) FOREACH_ELEMWISE_MULTI_TYPE_NONFIRST_CASE(cb)
diff --git a/dnn/test/fallback/elemwise_multi_type.cpp b/dnn/test/fallback/elemwise_multi_type.cpp
index a51767ced..7e0566f6d 100644
--- a/dnn/test/fallback/elemwise_multi_type.cpp
+++ b/dnn/test/fallback/elemwise_multi_type.cpp
@@ -40,6 +40,24 @@ TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_RECORD_FMA3_INT16x32x32x32) {
     checker.execs({{A, B, C}, {1, B, 1}, {1, B, 1}, {}});
 }
 
+TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_RECORD_FMA3_INT16xF32xF32xF32) {
+    TaskRecordChecker<ElemwiseMultiType> checker{1};
+    checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32});
+    checker.set_dtype(0, dtype::Int16());
+    checker.set_dtype(1, dtype::Float32());
+    checker.set_dtype(2, dtype::Float32());
+    UniformIntRNG rng{-10, 10};
+    checker.set_rng(0, &rng);
+    checker.set_rng(1, &rng);
+    checker.set_rng(2, &rng);
+    checker.execs({{5, 7, 16}, {1, 1, 16}, {1, 1, 16}, {}})
+            .execs({{2, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}})
+            .execs({{2, 700, 600}, {2, 700, 600}, {2, 700, 600}, {}})
+            .execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}})
+            .execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}})
+            .execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}});
+}
+
 #if MEGDNN_WITH_BENCHMARK
 TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_BENCHMARK_FMA3_INT16x32x32x32) {
     Benchmarker<ElemwiseMultiType> bench{handle()};
-- 
GitLab
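
Editor's note: every specialized path in this patch (NEON, fallback, naive) must agree on one scalar semantics. The sketch below restates that semantics for the VEC_BCAST111C_BCAST111C case of FUSE_MUL_ADD3_INT16xF32xF32xF32, where src0 is a full contiguous tensor and src1/src2 are per-channel vectors broadcast over the leading dimensions (binfo.x/y/z in the dispatch code). It is illustrative only and not part of the patch; ref_fma3_int16xf32xf32xf32_bcast111c is a hypothetical name, and layout checks are assumed to have been done by the caller, as the real dispatchers do.

#include <cstddef>
#include <cstdint>

// Reference semantics: dst = float(a) * b + c, with b and c indexed by
// channel only (shape {1, 1, 1, channel}), broadcast over batch and stride.
void ref_fma3_int16xf32xf32xf32_bcast111c(
        size_t batch, size_t stride, size_t channel, const int16_t* src0,
        const float* src1, const float* src2, float* dst) {
    for (size_t b = 0; b < batch; ++b) {
        for (size_t s = 0; s < stride; ++s) {
            for (size_t c = 0; c < channel; ++c) {
                // src0/dst advance linearly; src1/src2 repeat per channel.
                *dst++ = static_cast<float>(*src0++) * src1[c] + src2[c];
            }
        }
    }
}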
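The NEON kernels in kernels.cpp all share a single widening idiom: uint8 lanes are widened with vmovl_u8 and reinterpreted as int16 via vreinterpretq_s16_u16, int16 lanes are widened to int32 with vmovl_s16, converted to float32 with vcvtq_f32_s32, and accumulated with vfmaq_f32 when __ARM_FEATURE_FMA is available, falling back to vmlaq_f32 otherwise (the Vfmaq_f32 macro). A minimal sketch of that chain for one 8-lane int16 block follows; it assumes an ARM target with NEON, and widen_fma_block8 is a hypothetical helper, not part of the patch:

#include <arm_neon.h>
#include <stdint.h>

// Computes out[0..7] = float(a[0..7]) * b[0..7] + c[0..7] using the same
// widening chain as the kernels above: s16 -> s32 -> f32, low and high
// halves handled separately.
void widen_fma_block8(const int16_t* a, const float* b, const float* c,
                      float* out) {
    int16x8_t a16 = vld1q_s16(a);
    float32x4_t a_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(a16)));
    float32x4_t a_hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(a16)));
#if defined(__ARM_FEATURE_FMA)
    // Fused multiply-add: acc + x * y in one rounding step.
    vst1q_f32(out, vfmaq_f32(vld1q_f32(c), a_lo, vld1q_f32(b)));
    vst1q_f32(out + 4, vfmaq_f32(vld1q_f32(c + 4), a_hi, vld1q_f32(b + 4)));
#else
    // Separate multiply-accumulate when FMA is unavailable.
    vst1q_f32(out, vmlaq_f32(vld1q_f32(c), a_lo, vld1q_f32(b)));
    vst1q_f32(out + 4, vmlaq_f32(vld1q_f32(c + 4), a_hi, vld1q_f32(b + 4)));
#endif
}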