Commit f6d99094 authored by Megvii Engine Team

feat(dnn): add elemwise multi type support i16xf32 and u8xf32

GitOrigin-RevId: 2fe469bb4ec9a0b7d20f88a54d2a87e7ad42385b
Parent d9a46ea4
@@ -497,7 +497,16 @@ pdef('ElemwiseMultiType').add_enum(
Doc('QCOND_LEQ_MOV = 50', 'quantized cond_leq_mov'),
Doc('QH_SWISH = 51', 'quantized h_swish'),
Doc('QFUSE_ADD_H_SWISH = 52', 'quantized h_swish(x+y)'),
Doc('QH_SWISH_GRAD = 53', 'quantized h_swish_grad')
Doc('QH_SWISH_GRAD = 53', 'quantized h_swish_grad'),
Doc('FUSE_MUL_ADD3_INT16xF32xF32xF32 = 54',
'compute ``a * b + c`` requiring that ``a`` be int16 and ``b`` and '
'``c`` float32, and the result is float32.'),
Doc('MUL_INT16xF32xF32 = 55',
'compute ``a * b`` requiring that ``a`` be int16 and ``b`` float32, '
'and the result is float32.'),
Doc('FUSE_MUL_ADD3_UINT8xF32xF32xF32 = 56',
'compute ``a * b + c`` requiring that ``a`` be uint8 and ``b`` and '
'``c`` float32, and the result is float32.')
)
pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
......
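For reference, the three new modes reduce to simple mixed-dtype scalar arithmetic. A minimal sketch (illustrative only; the helper names below are not part of the patch):

#include <cstdint>

// Scalar reference semantics for the new modes (illustrative sketch).
static inline float fuse_mul_add3_i16xf32(int16_t a, float b, float c) {
    return static_cast<float>(a) * b + c;  // FUSE_MUL_ADD3_INT16xF32xF32xF32
}
static inline float mul_i16xf32(int16_t a, float b) {
    return static_cast<float>(a) * b;      // MUL_INT16xF32xF32
}
static inline float fuse_mul_add3_u8xf32(uint8_t a, float b, float c) {
    return static_cast<float>(a) * b + c;  // FUSE_MUL_ADD3_UINT8xF32xF32xF32
}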
/**
* \file dnn/src/arm_common/elemwise_multi_type/kernels.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "kernels.h"
#include "src/arm_common/simd_macro/marm_neon.h"
namespace megdnn {
namespace arm_common {
#if defined(__ARM_FEATURE_FMA)
#define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m)
#else
#define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m)
#endif
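// Vfmaq_f32 uses a true fused multiply-add (vfmaq_f32) when the target
// advertises __ARM_FEATURE_FMA, and falls back to multiply-accumulate
// (vmlaq_f32) otherwise.
//
// Naming convention of the kernels below (mirrored by the dispatcher in
// opr_impl.cpp):
//   vec       - contiguous operand, one element per output element
//   bcast111c - per-channel operand broadcast over batch/spatial dims,
//               channel-last (NHWC-like) layout
//   bcast101  - per-channel operand broadcast over batch/spatial dims,
//               channel-first (NCHW-like) layout
//   scaler    - single scalar operand broadcast to every element
// Each kernel processes 16/8/(4) elements per iteration and finishes with a
// scalar tail loop.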
void neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast111c_bcast111c(
size_t batch_size, size_t channel_stride, size_t channel_size,
const int16_t* src0, const float* src1, const float* src2, float* dst) {
const int16_t* __restrict sptr0 = src0;
const float* __restrict sptr1 = src1;
const float* __restrict sptr2 = src2;
float* __restrict dst_ptr = dst;
for (size_t batch = 0; batch < batch_size; ++batch) {
for (size_t s = 0; s < channel_stride; ++s) {
size_t i = 0;
for (; i + 15 < channel_size; i += 16, sptr0 += 16, dst_ptr += 16) {
auto vec0_01 = vld1q_s16(sptr0);
auto vec0_23 = vld1q_s16(sptr0 + 8);
auto vec1_0 = vld1q_f32(sptr1 + i);
auto vec1_1 = vld1q_f32(sptr1 + i + 4);
auto vec1_2 = vld1q_f32(sptr1 + i + 8);
auto vec1_3 = vld1q_f32(sptr1 + i + 12);
auto vec2_0 = vld1q_f32(sptr2 + i);
auto vec2_1 = vld1q_f32(sptr2 + i + 4);
auto vec2_2 = vld1q_f32(sptr2 + i + 8);
auto vec2_3 = vld1q_f32(sptr2 + i + 12);
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0);
auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1);
auto dst_vec_2 = Vfmaq_f32(vec2_2, vec0_2, vec1_2);
auto dst_vec_3 = Vfmaq_f32(vec2_3, vec0_3, vec1_3);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
vst1q_f32(dst_ptr + 8, dst_vec_2);
vst1q_f32(dst_ptr + 12, dst_vec_3);
}
for (; i + 7 < channel_size; i += 8, sptr0 += 8, dst_ptr += 8) {
auto vec0_01 = vld1q_s16(sptr0);
auto vec1_0 = vld1q_f32(sptr1 + i);
auto vec1_1 = vld1q_f32(sptr1 + i + 4);
auto vec2_0 = vld1q_f32(sptr2 + i);
auto vec2_1 = vld1q_f32(sptr2 + i + 4);
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0);
auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
}
for (; i + 3 < channel_size; i += 4, sptr0 += 4, dst_ptr += 4) {
auto vec0_0 = vld1_s16(sptr0);
auto vec1_0 = vld1q_f32(sptr1 + i);
auto vec2_0 = vld1q_f32(sptr2 + i);
auto vec0_0_f32 = vcvtq_f32_s32(vmovl_s16(vec0_0));
auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0_f32, vec1_0);
vst1q_f32(dst_ptr, dst_vec_0);
}
for (; i < channel_size; ++i, ++sptr0, ++dst_ptr) {
*dst_ptr = (float)(*sptr0) * sptr1[i] + sptr2[i];
}
}
}
}
void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast111c_bcast111c(
size_t batch_size, size_t channel_stride, size_t channel_size,
const uint8_t* src0, const float* src1, const float* src2, float* dst) {
const uint8_t* __restrict sptr0 = src0;
const float* __restrict sptr1 = src1;
const float* __restrict sptr2 = src2;
float* __restrict dst_ptr = dst;
for (size_t batch = 0; batch < batch_size; ++batch) {
for (size_t s = 0; s < channel_stride; ++s) {
size_t i = 0;
for (; i + 15 < channel_size; i += 16, sptr0 += 16, dst_ptr += 16) {
auto vec0_0123_u8 = vld1q_u8(sptr0);
auto vec1_0 = vld1q_f32(sptr1 + i);
auto vec1_1 = vld1q_f32(sptr1 + i + 4);
auto vec1_2 = vld1q_f32(sptr1 + i + 8);
auto vec1_3 = vld1q_f32(sptr1 + i + 12);
auto vec2_0 = vld1q_f32(sptr2 + i);
auto vec2_1 = vld1q_f32(sptr2 + i + 4);
auto vec2_2 = vld1q_f32(sptr2 + i + 8);
auto vec2_3 = vld1q_f32(sptr2 + i + 12);
auto vec0_01 =
vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(vec0_0123_u8)));
auto vec0_23 =
vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(vec0_0123_u8)));
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0);
auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1);
auto dst_vec_2 = Vfmaq_f32(vec2_2, vec0_2, vec1_2);
auto dst_vec_3 = Vfmaq_f32(vec2_3, vec0_3, vec1_3);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
vst1q_f32(dst_ptr + 8, dst_vec_2);
vst1q_f32(dst_ptr + 12, dst_vec_3);
}
for (; i + 7 < channel_size; i += 8, sptr0 += 8, dst_ptr += 8) {
auto vec0_01_u8 = vld1_u8(sptr0);
auto vec1_0 = vld1q_f32(sptr1 + i);
auto vec1_1 = vld1q_f32(sptr1 + i + 4);
auto vec2_0 = vld1q_f32(sptr2 + i);
auto vec2_1 = vld1q_f32(sptr2 + i + 4);
auto vec0_01 = vreinterpretq_s16_u16(vmovl_u8(vec0_01_u8));
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0);
auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
}
for (; i < channel_size; ++i, ++sptr0, ++dst_ptr) {
*dst_ptr = (float)(*sptr0) * sptr1[i] + sptr2[i];
}
}
}
}
void neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast101_bcast101(
size_t batch_size, size_t channel_size, size_t channel_stride,
const int16_t* src0, const float* src1, const float* src2, float* dst) {
const int16_t* __restrict sptr0 = src0;
const float* __restrict sptr1 = src1;
const float* __restrict sptr2 = src2;
float* __restrict dst_ptr = dst;
for (size_t batch = 0; batch < batch_size; ++batch) {
for (size_t chan = 0; chan < channel_size; ++chan) {
auto vec1 = vdupq_n_f32(sptr1[chan]);
auto vec2 = vdupq_n_f32(sptr2[chan]);
size_t i = 0;
for (; i + 15 < channel_stride; i += 16, sptr0 += 16, dst_ptr += 16) {
auto vec0_01 = vld1q_s16(sptr0);
auto vec0_23 = vld1q_s16(sptr0 + 8);
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1);
auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1);
auto dst_vec_2 = Vfmaq_f32(vec2, vec0_2, vec1);
auto dst_vec_3 = Vfmaq_f32(vec2, vec0_3, vec1);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
vst1q_f32(dst_ptr + 8, dst_vec_2);
vst1q_f32(dst_ptr + 12, dst_vec_3);
}
for (; i + 7 < channel_stride; i += 8, sptr0 += 8, dst_ptr += 8) {
auto vec0_01 = vld1q_s16(sptr0);
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1);
auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
}
for (; i + 3 < channel_stride; i += 4, sptr0 += 4, dst_ptr += 4) {
auto vec0_0 = vld1_s16(sptr0);
auto vec0_0_f32 = vcvtq_f32_s32(vmovl_s16(vec0_0));
auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0_f32, vec1);
vst1q_f32(dst_ptr, dst_vec_0);
}
for (; i < channel_stride; ++i, ++sptr0, ++dst_ptr) {
*dst_ptr = (float)(*sptr0) * sptr1[chan] + sptr2[chan];
}
}
}
}
void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast101_bcast101(
size_t batch_size, size_t channel_size, size_t channel_stride,
const uint8_t* src0, const float* src1, const float* src2, float* dst) {
const uint8_t* __restrict sptr0 = src0;
const float* __restrict sptr1 = src1;
const float* __restrict sptr2 = src2;
float* __restrict dst_ptr = dst;
for (size_t batch = 0; batch < batch_size; ++batch) {
for (size_t chan = 0; chan < channel_size; ++chan) {
auto vec1 = vdupq_n_f32(sptr1[chan]);
auto vec2 = vdupq_n_f32(sptr2[chan]);
size_t i = 0;
for (; i + 15 < channel_stride; i += 16, sptr0 += 16, dst_ptr += 16) {
auto vec0_0123_u8 = vld1q_u8(sptr0);
auto vec0_01 =
vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(vec0_0123_u8)));
auto vec0_23 =
vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(vec0_0123_u8)));
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1);
auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1);
auto dst_vec_2 = Vfmaq_f32(vec2, vec0_2, vec1);
auto dst_vec_3 = Vfmaq_f32(vec2, vec0_3, vec1);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
vst1q_f32(dst_ptr + 8, dst_vec_2);
vst1q_f32(dst_ptr + 12, dst_vec_3);
}
for (; i + 7 < channel_stride; i += 8, sptr0 += 8, dst_ptr += 8) {
auto vec0_01_u8 = vld1_u8(sptr0);
auto vec0_01 = vreinterpretq_s16_u16(vmovl_u8(vec0_01_u8));
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1);
auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
}
for (; i < channel_stride; ++i, ++sptr0, ++dst_ptr) {
*dst_ptr = (float)(*sptr0) * sptr1[chan] + sptr2[chan];
}
}
}
}
void neon_fuse_mul_add3_int16xf32xf32xf32_vec_vec_vec(
size_t size, const int16_t* src0, const float* src1, const float* src2,
float* dst) {
const int16_t* __restrict sptr0 = src0;
const float* __restrict sptr1 = src1;
const float* __restrict sptr2 = src2;
float* __restrict dst_ptr = dst;
size_t i = 0;
for (; i + 15 < size;
i += 16, sptr0 += 16, sptr1 += 16, sptr2 += 16, dst_ptr += 16) {
auto vec0_01 = vld1q_s16(sptr0);
auto vec0_23 = vld1q_s16(sptr0 + 8);
auto vec1_0 = vld1q_f32(sptr1);
auto vec1_1 = vld1q_f32(sptr1 + 4);
auto vec1_2 = vld1q_f32(sptr1 + 8);
auto vec1_3 = vld1q_f32(sptr1 + 12);
auto vec2_0 = vld1q_f32(sptr2);
auto vec2_1 = vld1q_f32(sptr2 + 4);
auto vec2_2 = vld1q_f32(sptr2 + 8);
auto vec2_3 = vld1q_f32(sptr2 + 12);
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0);
auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1);
auto dst_vec_2 = Vfmaq_f32(vec2_2, vec0_2, vec1_2);
auto dst_vec_3 = Vfmaq_f32(vec2_3, vec0_3, vec1_3);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
vst1q_f32(dst_ptr + 8, dst_vec_2);
vst1q_f32(dst_ptr + 12, dst_vec_3);
}
for (; i + 7 < size; i += 8, sptr0 += 8, sptr1 += 8, sptr2 += 8, dst_ptr += 8) {
auto vec0_01 = vld1q_s16(sptr0);
auto vec1_0 = vld1q_f32(sptr1);
auto vec1_1 = vld1q_f32(sptr1 + 4);
auto vec2_0 = vld1q_f32(sptr2);
auto vec2_1 = vld1q_f32(sptr2 + 4);
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0);
auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
}
for (; i + 3 < size; i += 4, sptr0 += 4, sptr1 += 4, sptr2 += 4, dst_ptr += 4) {
auto vec0_0 = vld1_s16(sptr0);
auto vec1_0 = vld1q_f32(sptr1);
auto vec2_0 = vld1q_f32(sptr2);
auto vec0_0_f32 = vcvtq_f32_s32(vmovl_s16(vec0_0));
auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0_f32, vec1_0);
vst1q_f32(dst_ptr, dst_vec_0);
}
for (; i < size; ++i, ++sptr0, ++sptr1, ++sptr2, ++dst_ptr) {
*dst_ptr = (float)(*sptr0) * (*sptr1) + (*sptr2);
}
}
void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_vec_vec(
size_t size, const uint8_t* src0, const float* src1, const float* src2,
float* dst) {
const uint8_t* __restrict sptr0 = src0;
const float* __restrict sptr1 = src1;
const float* __restrict sptr2 = src2;
float* __restrict dst_ptr = dst;
size_t i = 0;
for (; i + 15 < size;
i += 16, sptr0 += 16, sptr1 += 16, sptr2 += 16, dst_ptr += 16) {
auto vec0_0123 = vld1q_u8(sptr0);
auto vec1_0 = vld1q_f32(sptr1);
auto vec1_1 = vld1q_f32(sptr1 + 4);
auto vec1_2 = vld1q_f32(sptr1 + 8);
auto vec1_3 = vld1q_f32(sptr1 + 12);
auto vec2_0 = vld1q_f32(sptr2);
auto vec2_1 = vld1q_f32(sptr2 + 4);
auto vec2_2 = vld1q_f32(sptr2 + 8);
auto vec2_3 = vld1q_f32(sptr2 + 12);
auto vec0_01 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(vec0_0123)));
auto vec0_23 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(vec0_0123)));
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0);
auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1);
auto dst_vec_2 = Vfmaq_f32(vec2_2, vec0_2, vec1_2);
auto dst_vec_3 = Vfmaq_f32(vec2_3, vec0_3, vec1_3);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
vst1q_f32(dst_ptr + 8, dst_vec_2);
vst1q_f32(dst_ptr + 12, dst_vec_3);
}
for (; i + 7 < size; i += 8, sptr0 += 8, sptr1 += 8, sptr2 += 8, dst_ptr += 8) {
auto vec0_01_u8 = vld1_u8(sptr0);
auto vec1_0 = vld1q_f32(sptr1);
auto vec1_1 = vld1q_f32(sptr1 + 4);
auto vec2_0 = vld1q_f32(sptr2);
auto vec2_1 = vld1q_f32(sptr2 + 4);
auto vec0_01 = vreinterpretq_s16_u16(vmovl_u8(vec0_01_u8));
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0);
auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
}
for (; i < size; ++i, ++sptr0, ++sptr1, ++sptr2, ++dst_ptr) {
*dst_ptr = (float)(*sptr0) * (*sptr1) + (*sptr2);
}
}
void neon_fuse_mul_add3_int16xf32xf32xf32_vec_scaler_scaler(
size_t size, const int16_t* src0, const float* src1, const float* src2,
float* dst) {
const int16_t* __restrict sptr0 = src0;
const float* __restrict sptr1 = src1;
const float* __restrict sptr2 = src2;
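// src1 and src2 are broadcast scalars: only their first element is read,
// so these pointers are never advanced.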
auto vec1 = vdupq_n_f32(sptr1[0]);
auto vec2 = vdupq_n_f32(sptr2[0]);
float* __restrict dst_ptr = dst;
size_t i = 0;
for (; i + 15 < size; i += 16, sptr0 += 16, dst_ptr += 16) {
auto vec0_01 = vld1q_s16(sptr0);
auto vec0_23 = vld1q_s16(sptr0 + 8);
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1);
auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1);
auto dst_vec_2 = Vfmaq_f32(vec2, vec0_2, vec1);
auto dst_vec_3 = Vfmaq_f32(vec2, vec0_3, vec1);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
vst1q_f32(dst_ptr + 8, dst_vec_2);
vst1q_f32(dst_ptr + 12, dst_vec_3);
}
for (; i + 7 < size; i += 8, sptr0 += 8, dst_ptr += 8) {
auto vec0_01 = vld1q_s16(sptr0);
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1);
auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
}
for (; i + 3 < size; i += 4, sptr0 += 4, dst_ptr += 4) {
auto vec0_0 = vld1_s16(sptr0);
auto vec0_0_f32 = vcvtq_f32_s32(vmovl_s16(vec0_0));
auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0_f32, vec1);
vst1q_f32(dst_ptr, dst_vec_0);
}
for (; i < size; ++i, ++sptr0, ++dst_ptr) {
*dst_ptr = (float)(*sptr0) * (*sptr1) + (*sptr2);
}
}
void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_scaler_scaler(
size_t size, const uint8_t* src0, const float* src1, const float* src2,
float* dst) {
const uint8_t* __restrict sptr0 = src0;
const float* __restrict sptr1 = src1;
const float* __restrict sptr2 = src2;
auto vec1 = vdupq_n_f32(sptr1[0]);
auto vec2 = vdupq_n_f32(sptr2[0]);
float* __restrict dst_ptr = dst;
size_t i = 0;
for (; i + 15 < size; i += 16, sptr0 += 16, dst_ptr += 16) {
auto vec0_0123 = vld1q_u8(sptr0);
auto vec0_01 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(vec0_0123)));
auto vec0_23 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(vec0_0123)));
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1);
auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1);
auto dst_vec_2 = Vfmaq_f32(vec2, vec0_2, vec1);
auto dst_vec_3 = Vfmaq_f32(vec2, vec0_3, vec1);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
vst1q_f32(dst_ptr + 8, dst_vec_2);
vst1q_f32(dst_ptr + 12, dst_vec_3);
}
for (; i + 7 < size; i += 8, sptr0 += 8, dst_ptr += 8) {
auto vec0_01_u8 = vld1_u8(sptr0);
auto vec0_01 = vreinterpretq_s16_u16(vmovl_u8(vec0_01_u8));
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1);
auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
}
for (; i < size; ++i, ++sptr0, ++dst_ptr) {
*dst_ptr = (float)(*sptr0) * (*sptr1) + (*sptr2);
}
}
void neon_mul_int16xf32xf32_vec_bcast111c(
size_t batch_size, size_t channel_stride, size_t channel_size,
const int16_t* src0, const float* src1, float* dst) {
const int16_t* __restrict sptr0 = src0;
const float* __restrict sptr1 = src1;
float* __restrict dst_ptr = dst;
for (size_t batch = 0; batch < batch_size; ++batch) {
for (size_t s = 0; s < channel_stride; ++s) {
size_t i = 0;
for (; i + 15 < channel_size; i += 16, sptr0 += 16, dst_ptr += 16) {
auto vec0_01 = vld1q_s16(sptr0);
auto vec0_23 = vld1q_s16(sptr0 + 8);
auto vec1_0 = vld1q_f32(sptr1 + i);
auto vec1_1 = vld1q_f32(sptr1 + i + 4);
auto vec1_2 = vld1q_f32(sptr1 + i + 8);
auto vec1_3 = vld1q_f32(sptr1 + i + 12);
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
auto dst_vec_0 = vmulq_f32(vec0_0, vec1_0);
auto dst_vec_1 = vmulq_f32(vec0_1, vec1_1);
auto dst_vec_2 = vmulq_f32(vec0_2, vec1_2);
auto dst_vec_3 = vmulq_f32(vec0_3, vec1_3);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
vst1q_f32(dst_ptr + 8, dst_vec_2);
vst1q_f32(dst_ptr + 12, dst_vec_3);
}
for (; i + 7 < channel_size; i += 8, sptr0 += 8, dst_ptr += 8) {
auto vec0_01 = vld1q_s16(sptr0);
auto vec1_0 = vld1q_f32(sptr1 + i);
auto vec1_1 = vld1q_f32(sptr1 + i + 4);
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto dst_vec_0 = vmulq_f32(vec0_0, vec1_0);
auto dst_vec_1 = vmulq_f32(vec0_1, vec1_1);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
}
for (; i < channel_size; ++i, ++sptr0, ++dst_ptr) {
*dst_ptr = (float)(*sptr0) * sptr1[i];
}
}
}
}
void neon_mul_int16xf32xf32_vec_bcast101(
size_t batch_size, size_t channel_size, size_t channel_stride,
const int16_t* src0, const float* src1, float* dst) {
const int16_t* __restrict sptr0 = src0;
const float* __restrict sptr1 = src1;
float* __restrict dst_ptr = dst;
for (size_t batch = 0; batch < batch_size; ++batch) {
for (size_t chan = 0; chan < channel_size; ++chan) {
auto vec1 = vdupq_n_f32(sptr1[chan]);
size_t i = 0;
for (; i + 15 < channel_stride; i += 16, sptr0 += 16, dst_ptr += 16) {
auto vec0_01 = vld1q_s16(sptr0);
auto vec0_23 = vld1q_s16(sptr0 + 8);
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
auto dst_vec_0 = vmulq_f32(vec0_0, vec1);
auto dst_vec_1 = vmulq_f32(vec0_1, vec1);
auto dst_vec_2 = vmulq_f32(vec0_2, vec1);
auto dst_vec_3 = vmulq_f32(vec0_3, vec1);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
vst1q_f32(dst_ptr + 8, dst_vec_2);
vst1q_f32(dst_ptr + 12, dst_vec_3);
}
for (; i + 7 < channel_stride; i += 8, sptr0 += 8, dst_ptr += 8) {
auto vec0_01 = vld1q_s16(sptr0);
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto dst_vec_0 = vmulq_f32(vec0_0, vec1);
auto dst_vec_1 = vmulq_f32(vec0_1, vec1);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
}
for (; i < channel_stride; ++i, ++sptr0, ++dst_ptr) {
*dst_ptr = (float)(*sptr0) * sptr1[chan];
}
}
}
}
void neon_mul_int16xf32xf32_vec_vec(
size_t size, const int16_t* src0, const float* src1, float* dst) {
const int16_t* __restrict sptr0 = src0;
const float* __restrict sptr1 = src1;
float* __restrict dst_ptr = dst;
size_t i = 0;
for (; i + 15 < size; i += 16, sptr0 += 16, sptr1 += 16, dst_ptr += 16) {
auto vec0_01 = vld1q_s16(sptr0);
auto vec0_23 = vld1q_s16(sptr0 + 8);
auto vec1_0 = vld1q_f32(sptr1);
auto vec1_1 = vld1q_f32(sptr1 + 4);
auto vec1_2 = vld1q_f32(sptr1 + 8);
auto vec1_3 = vld1q_f32(sptr1 + 12);
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
auto dst_vec_0 = vmulq_f32(vec0_0, vec1_0);
auto dst_vec_1 = vmulq_f32(vec0_1, vec1_1);
auto dst_vec_2 = vmulq_f32(vec0_2, vec1_2);
auto dst_vec_3 = vmulq_f32(vec0_3, vec1_3);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
vst1q_f32(dst_ptr + 8, dst_vec_2);
vst1q_f32(dst_ptr + 12, dst_vec_3);
}
for (; i + 7 < size; i += 8, sptr0 += 8, sptr1 += 8, dst_ptr += 8) {
auto vec0_01 = vld1q_s16(sptr0);
auto vec1_0 = vld1q_f32(sptr1);
auto vec1_1 = vld1q_f32(sptr1 + 4);
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto dst_vec_0 = vmulq_f32(vec0_0, vec1_0);
auto dst_vec_1 = vmulq_f32(vec0_1, vec1_1);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
}
for (; i < size; ++i, ++sptr0, ++sptr1, ++dst_ptr) {
*dst_ptr = (float)(*sptr0) * (*sptr1);
}
}
void neon_mul_int16xf32xf32_vec_scaler(
size_t size, const int16_t* src0, const float* src1, float* dst) {
const int16_t* __restrict sptr0 = src0;
const float* __restrict sptr1 = src1;
auto vec1 = vdupq_n_f32(sptr1[0]);
float* __restrict dst_ptr = dst;
size_t i = 0;
for (; i + 15 < size; i += 16, sptr0 += 16, dst_ptr += 16) {
auto vec0_01 = vld1q_s16(sptr0);
auto vec0_23 = vld1q_s16(sptr0 + 8);
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
auto dst_vec_0 = vmulq_f32(vec0_0, vec1);
auto dst_vec_1 = vmulq_f32(vec0_1, vec1);
auto dst_vec_2 = vmulq_f32(vec0_2, vec1);
auto dst_vec_3 = vmulq_f32(vec0_3, vec1);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
vst1q_f32(dst_ptr + 8, dst_vec_2);
vst1q_f32(dst_ptr + 12, dst_vec_3);
}
for (; i + 7 < size; i += 8, sptr0 += 8, dst_ptr += 8) {
auto vec0_01 = vld1q_s16(sptr0);
auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
auto dst_vec_0 = vmulq_f32(vec0_0, vec1);
auto dst_vec_1 = vmulq_f32(vec0_1, vec1);
vst1q_f32(dst_ptr, dst_vec_0);
vst1q_f32(dst_ptr + 4, dst_vec_1);
}
for (; i < size; ++i, ++sptr0, ++dst_ptr) {
*dst_ptr = (float)(*sptr0) * (*sptr1);
}
}
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/elemwise_multi_type/kernels.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "stddef.h"
#include "stdint.h"
namespace megdnn {
namespace arm_common {
void neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast111c_bcast111c(
size_t batch_size, size_t channel_stride, size_t channel_size,
const int16_t* src0, const float* src1, const float* src2, float* dst);
void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast111c_bcast111c(
size_t batch_size, size_t channel_stride, size_t channel_size,
const uint8_t* src0, const float* src1, const float* src2, float* dst);
void neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast101_bcast101(
size_t batch_size, size_t channel_size, size_t channel_stride,
const int16_t* src0, const float* src1, const float* src2, float* dst);
void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast101_bcast101(
size_t batch_size, size_t channel_size, size_t channel_stride,
const uint8_t* src0, const float* src1, const float* src2, float* dst);
void neon_fuse_mul_add3_int16xf32xf32xf32_vec_vec_vec(
size_t size, const int16_t* src0, const float* src1, const float* src2,
float* dst);
void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_vec_vec(
size_t size, const uint8_t* src0, const float* src1, const float* src2,
float* dst);
void neon_fuse_mul_add3_int16xf32xf32xf32_vec_b1x_b1x(
size_t size, size_t vec, const int16_t* src0, const float* src1,
const float* src2, float* dst);
void neon_fuse_mul_add3_int16xf32xf32xf32_vec_scaler_scaler(
size_t size, const int16_t* src0, const float* src1, const float* src2,
float* dst);
void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_scaler_scaler(
size_t size, const uint8_t* src0, const float* src1, const float* src2,
float* dst);
void neon_mul_int16xf32xf32_vec_bcast111c(
size_t batch_size, size_t channel_stride, size_t channel_size,
const int16_t* src0, const float* src1, float* dst);
void neon_mul_int16xf32xf32_vec_bcast101(
size_t batch_size, size_t channel_size, size_t channel_stride,
const int16_t* src0, const float* src1, float* dst);
void neon_mul_int16xf32xf32_vec_vec(
size_t size, const int16_t* src0, const float* src1, float* dst);
void neon_mul_int16xf32xf32_vec_scaler(
size_t size, const int16_t* src0, const float* src1, float* dst);
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
@@ -11,6 +11,7 @@
*/
#include "./opr_impl.h"
#include "kernels.h"
#include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "src/naive/handle.h"
@@ -851,6 +852,154 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
#undef DISPATCH_QUANTIZED_MODE
}
void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) {
auto &src0 = param[0], &src1 = param[1], &src2 = param[2];
BroadcastChannelInfo binfo;
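// binfo fields are forwarded to the kernels positionally: for the
// bcast111c (NHWC-like) branch they map to (batch, channel_stride,
// channel_size); for the bcast101 (NCHW-like) branch they map to
// (batch, channel_size, channel_stride), matching the kernel signatures
// declared in kernels.h.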
if (is_vector(src0.layout) &&
is_NHWC_broadcasted_channel_like(src1.layout, binfo) &&
src1.layout.eq_layout(src2.layout)) {
// VEC_BCAST111C_BCAST111C
MEGDNN_DISPATCH_CPU_KERN_OPR(
neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast111c_bcast111c(
binfo.x, binfo.y, binfo.z,
static_cast<dt_int16*>(src0.raw_ptr()),
static_cast<dt_float32*>(src1.raw_ptr()),
static_cast<dt_float32*>(src2.raw_ptr()),
dst.ptr<dt_float32>()));
return;
} else if (
is_vector(src0.layout) && is_broadcasted_channel_like(src1.layout, binfo) &&
src1.layout.eq_layout(src2.layout)) {
// VEC_BCAST101_BCAST101
MEGDNN_DISPATCH_CPU_KERN_OPR(
neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast101_bcast101(
binfo.x, binfo.y, binfo.z,
static_cast<dt_int16*>(src0.raw_ptr()),
static_cast<dt_float32*>(src1.raw_ptr()),
static_cast<dt_float32*>(src2.raw_ptr()),
dst.ptr<dt_float32>()));
return;
} else if (
is_vector(src0.layout) && is_vector(src1.layout) &&
is_vector(src2.layout)) {
// VEC_VEC_VEC
auto size = param.size;
MEGDNN_DISPATCH_CPU_KERN_OPR(neon_fuse_mul_add3_int16xf32xf32xf32_vec_vec_vec(
size, static_cast<dt_int16*>(src0.raw_ptr()),
static_cast<dt_float32*>(src1.raw_ptr()),
static_cast<dt_float32*>(src2.raw_ptr()), dst.ptr<dt_float32>()));
return;
} else if (
is_vector(src0.layout) && is_broadcasted_scalar(src1.layout) &&
is_broadcasted_scalar(src2.layout)) {
// VEC_SCALAR_SCALAR
auto size = param.size;
MEGDNN_DISPATCH_CPU_KERN_OPR(
neon_fuse_mul_add3_int16xf32xf32xf32_vec_scaler_scaler(
size, static_cast<dt_int16*>(src0.raw_ptr()),
static_cast<dt_float32*>(src1.raw_ptr()),
static_cast<dt_float32*>(src2.raw_ptr()),
dst.ptr<dt_float32>()));
return;
}
naive::ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32(param, dst);
}
void ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) {
auto &src0 = param[0], &src1 = param[1], &src2 = param[2];
BroadcastChannelInfo binfo;
if (is_vector(src0.layout) &&
is_NHWC_broadcasted_channel_like(src1.layout, binfo) &&
src1.layout.eq_layout(src2.layout)) {
// VEC_BCAST111C_BCAST111C
MEGDNN_DISPATCH_CPU_KERN_OPR(
neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast111c_bcast111c(
binfo.x, binfo.y, binfo.z,
static_cast<dt_uint8*>(src0.raw_ptr()),
static_cast<dt_float32*>(src1.raw_ptr()),
static_cast<dt_float32*>(src2.raw_ptr()),
dst.ptr<dt_float32>()));
return;
} else if (
is_vector(src0.layout) && is_broadcasted_channel_like(src1.layout, binfo) &&
src1.layout.eq_layout(src2.layout)) {
// VEC_BCAST101_BCAST101
MEGDNN_DISPATCH_CPU_KERN_OPR(
neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast101_bcast101(
binfo.x, binfo.y, binfo.z,
static_cast<dt_uint8*>(src0.raw_ptr()),
static_cast<dt_float32*>(src1.raw_ptr()),
static_cast<dt_float32*>(src2.raw_ptr()),
dst.ptr<dt_float32>()));
return;
} else if (
is_vector(src0.layout) && is_vector(src1.layout) &&
is_vector(src2.layout)) {
// VEC_VEC_VEC
auto size = param.size;
MEGDNN_DISPATCH_CPU_KERN_OPR(neon_fuse_mul_add3_uint8xf32xf32xf32_vec_vec_vec(
size, static_cast<dt_uint8*>(src0.raw_ptr()),
static_cast<dt_float32*>(src1.raw_ptr()),
static_cast<dt_float32*>(src2.raw_ptr()), dst.ptr<dt_float32>()));
return;
} else if (
is_vector(src0.layout) && is_broadcasted_scalar(src1.layout) &&
is_broadcasted_scalar(src2.layout)) {
// VEC_SCALAR_SCALAR
auto size = param.size;
MEGDNN_DISPATCH_CPU_KERN_OPR(
neon_fuse_mul_add3_uint8xf32xf32xf32_vec_scaler_scaler(
size, static_cast<dt_uint8*>(src0.raw_ptr()),
static_cast<dt_float32*>(src1.raw_ptr()),
static_cast<dt_float32*>(src2.raw_ptr()),
dst.ptr<dt_float32>()));
return;
}
naive::ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32(param, dst);
}
void ElemwiseMultiTypeImpl::on_mul_int16xf32xf32(
const ElemwiseOpParamN<2>& param, const TensorND& dst) {
auto &src0 = param[0], &src1 = param[1];
BroadcastChannelInfo binfo;
if (is_vector(src0.layout) &&
is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
// VEC_BCAST111C
MEGDNN_DISPATCH_CPU_KERN_OPR(neon_mul_int16xf32xf32_vec_bcast111c(
binfo.x, binfo.y, binfo.z, static_cast<dt_int16*>(src0.raw_ptr()),
static_cast<dt_float32*>(src1.raw_ptr()), dst.ptr<dt_float32>()));
return;
} else if (
is_vector(src0.layout) && is_broadcasted_channel_like(src1.layout, binfo)) {
// VEC_BCAST101
MEGDNN_DISPATCH_CPU_KERN_OPR(neon_mul_int16xf32xf32_vec_bcast101(
binfo.x, binfo.y, binfo.z, static_cast<dt_int16*>(src0.raw_ptr()),
static_cast<dt_float32*>(src1.raw_ptr()), dst.ptr<dt_float32>()));
return;
} else if (is_vector(src0.layout) && is_vector(src1.layout)) {
// VEC_VEC
auto size = param.size;
MEGDNN_DISPATCH_CPU_KERN_OPR(neon_mul_int16xf32xf32_vec_vec(
size, static_cast<dt_int16*>(src0.raw_ptr()),
static_cast<dt_float32*>(src1.raw_ptr()), dst.ptr<dt_float32>()));
return;
} else if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout)) {
auto size = param.size;
MEGDNN_DISPATCH_CPU_KERN_OPR(neon_mul_int16xf32xf32_vec_scaler(
size, static_cast<dt_int16*>(src0.raw_ptr()),
static_cast<dt_float32*>(src1.raw_ptr()), dst.ptr<dt_float32>()));
return;
}
naive::ElemwiseMultiTypeImpl::on_mul_int16xf32xf32(param, dst);
}
} // namespace arm_common
} // namespace megdnn
......
@@ -48,6 +48,15 @@ protected:
const ElemwiseOpParamN<3>& param, const TensorND& dst,
Elemwise::Mode mode) override;
void on_fuse_mul_add3_int16xf32xf32xf32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
void on_mul_int16xf32xf32(
const ElemwiseOpParamN<2>& param, const TensorND& dst) override;
void on_fuse_mul_add3_uint8xf32xf32xf32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
public:
using fallback::ElemwiseMultiTypeImpl::ElemwiseMultiTypeImpl;
};
......
@@ -155,6 +155,29 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
dst.name = name;
dst.need_specify_out_dtype = true;
};
auto init_fma3_int16xf32xf32xf32 = [&](ModeTrait& dst, const char* name) {
dst.arity = 3;
dst.check_inp[0] = make_check_dtype_func(dtype::Int16());
dst.check_inp[1] = make_check_dtype_func(dtype::Float32());
dst.check_inp[2] = make_check_dtype_func(dtype::Float32());
dst.check_out = make_out_dtype_func(dtype::Float32());
dst.name = name;
};
auto init_mul_int16xf32xf32 = [&](ModeTrait& dst, const char* name) {
dst.arity = 2;
dst.check_inp[0] = make_check_dtype_func(dtype::Int16());
dst.check_inp[1] = make_check_dtype_func(dtype::Float32());
dst.check_out = make_out_dtype_func(dtype::Float32());
dst.name = name;
};
auto init_fma3_uint8xf32xf32xf32 = [&](ModeTrait& dst, const char* name) {
dst.arity = 3;
dst.check_inp[0] = make_check_dtype_func(dtype::Uint8());
dst.check_inp[1] = make_check_dtype_func(dtype::Float32());
dst.check_inp[2] = make_check_dtype_func(dtype::Float32());
dst.check_out = make_out_dtype_func(dtype::Float32());
dst.name = name;
};
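//! The three mixed-dtype modes above always produce a Float32 output (see
//! check_out), so unlike the quantized modes they do not set
//! need_specify_out_dtype.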
#define SET(f, m) \
MIDOUT_BEGIN(megdnn_common_elemwise_multi_type, midout_iv(Mode::m)) { \
@@ -169,6 +192,9 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
SET(init_fuse_add_rmulh_rshr_int32x32x32x8,
FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8);
SET(init_rshrs_iXxi8xi16, ROUND_SHR_SATURATE_IXxI8xI16);
SET(init_fma3_int16xf32xf32xf32, FUSE_MUL_ADD3_INT16xF32xF32xF32);
SET(init_mul_int16xf32xf32, MUL_INT16xF32xF32);
SET(init_fma3_uint8xf32xf32xf32, FUSE_MUL_ADD3_UINT8xF32xF32xF32);
//! quantized opr, with specified dtype.
//! dispatch elemwise mode internally
......
@@ -43,6 +43,17 @@ void ElemwiseMultiTypeImplHelper::exec(
case Mode::ROUND_SHR_SATURATE_IXxI8xI16:
on_round_shr_saturate_iXxi8xi16(make_elemwise_op_param<2>(src, dst), dst);
break;
case Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32:
on_fuse_mul_add3_int16xf32xf32xf32(
make_elemwise_op_param<3>(src, dst), dst);
break;
case Mode::MUL_INT16xF32xF32:
on_mul_int16xf32xf32(make_elemwise_op_param<2>(src, dst), dst);
break;
case Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32:
on_fuse_mul_add3_uint8xf32xf32xf32(
make_elemwise_op_param<3>(src, dst), dst);
break;
ON_QUANTIZED_MODE(RELU, 1);
ON_QUANTIZED_MODE(ABS, 1);
ON_QUANTIZED_MODE(ACOS, 1);
......
@@ -50,6 +50,27 @@ protected:
virtual void on_round_shr_saturate_iXxi8xi16(
const ElemwiseOpParamN<2>& param, const TensorND& dst) = 0;
virtual void on_fuse_mul_add3_int16xf32xf32xf32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) {
MEGDNN_MARK_USED_VAR(param);
MEGDNN_MARK_USED_VAR(dst);
megdnn_throw("unsupported ElemwiseMultiType fma3 int16xf32xf32xf32.");
}
virtual void on_mul_int16xf32xf32(
const ElemwiseOpParamN<2>& param, const TensorND& dst) {
MEGDNN_MARK_USED_VAR(param);
MEGDNN_MARK_USED_VAR(dst);
megdnn_throw("unsupported ElemwiseMultiType fma3 int16xf32xf32.");
}
virtual void on_fuse_mul_add3_uint8xf32xf32xf32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) {
MEGDNN_MARK_USED_VAR(param);
MEGDNN_MARK_USED_VAR(dst);
megdnn_throw("unsupported ElemwiseMultiType fma3 uint8xf32xf32xf32.");
}
virtual void on_quantized_mode(
const ElemwiseOpParamN<1>& param, const TensorND& dst,
Elemwise::Mode mode) {
......
@@ -56,6 +56,216 @@ void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32(
naive::ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32(param, dst);
}
void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) {
BroadcastChannelInfo binfo0, binfo1;
if (is_vector(param[0].layout) &&
is_NHWC_broadcasted_channel_like(param[1].layout, binfo0) &&
is_NHWC_broadcasted_channel_like(param[2].layout, binfo1) && binfo0 == binfo1) {
auto x = binfo0.x, y = binfo0.y, z = binfo0.z;
auto src0 = param[0];
auto src1 = param[1];
auto src2 = param[2];
auto work = [=]() {
const dt_int16* __restrict__ a = static_cast<dt_int16*>(src0.raw_ptr());
const dt_float32* __restrict__ b = static_cast<dt_float32*>(src1.raw_ptr());
const dt_float32* __restrict__ c = static_cast<dt_float32*>(src2.raw_ptr());
dt_float32* __restrict__ d = dst.ptr<dt_float32>();
for (size_t i = 0; i < x; ++i) {
for (size_t j = 0; j < y; ++j) {
auto off = i * (y * z) + j * z;
size_t k = 0;
for (; k + 4 <= z; k += 4) {
d[off + k + 0] = a[off + k + 0] * b[k + 0] + c[k + 0];
d[off + k + 1] = a[off + k + 1] * b[k + 1] + c[k + 1];
d[off + k + 2] = a[off + k + 2] * b[k + 2] + c[k + 2];
d[off + k + 3] = a[off + k + 3] * b[k + 3] + c[k + 3];
}
for (; k < z; ++k) {
d[off + k] = a[off + k] * b[k] + c[k];
}
}
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
return;
} else if (
is_vector(param[0].layout) &&
is_broadcasted_channel_like(param[1].layout, binfo0) &&
is_broadcasted_channel_like(param[2].layout, binfo1) && binfo0 == binfo1) {
auto x = binfo0.x, y = binfo0.y, z = binfo0.z;
auto src0 = param[0];
auto src1 = param[1];
auto src2 = param[2];
auto work = [=]() {
const dt_int16* __restrict__ a = static_cast<dt_int16*>(src0.raw_ptr());
const dt_float32* __restrict__ b = static_cast<dt_float32*>(src1.raw_ptr());
const dt_float32* __restrict__ c = static_cast<dt_float32*>(src2.raw_ptr());
dt_float32* __restrict__ d = dst.ptr<dt_float32>();
for (size_t j = 0; j < y; ++j) {
auto bv = b[j], cv = c[j];
for (size_t i = 0; i < x; ++i) {
auto off = i * (y * z) + j * z, offt = off + z;
for (; off + 4 <= offt; off += 4) {
d[off + 0] = a[off + 0] * bv + cv;
d[off + 1] = a[off + 1] * bv + cv;
d[off + 2] = a[off + 2] * bv + cv;
d[off + 3] = a[off + 3] * bv + cv;
}
for (; off < offt; ++off) {
d[off] = a[off] * bv + cv;
}
}
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
return;
}
naive::ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32(param, dst);
}
void ElemwiseMultiTypeImpl::on_mul_int16xf32xf32(
const ElemwiseOpParamN<2>& param, const TensorND& dst) {
BroadcastChannelInfo binfo;
if (is_vector(param[0].layout) &&
is_NHWC_broadcasted_channel_like(param[1].layout, binfo)) {
auto x = binfo.x, y = binfo.y, z = binfo.z;
auto src0 = param[0];
auto src1 = param[1];
auto work = [=]() {
const dt_int16* __restrict__ a = static_cast<dt_int16*>(src0.raw_ptr());
const dt_float32* __restrict__ b = static_cast<dt_float32*>(src1.raw_ptr());
dt_float32* __restrict__ d = dst.ptr<dt_float32>();
for (size_t i = 0; i < x; ++i) {
for (size_t j = 0; j < y; ++j) {
auto off = i * (y * z) + j * z;
size_t k = 0;
for (; k + 4 <= z; k += 4) {
d[off + k + 0] = a[off + k + 0] * b[k + 0];
d[off + k + 1] = a[off + k + 1] * b[k + 1];
d[off + k + 2] = a[off + k + 2] * b[k + 2];
d[off + k + 3] = a[off + k + 3] * b[k + 3];
}
for (; k < z; ++k) {
d[off + k] = a[off + k] * b[k];
}
}
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
return;
} else if (
is_vector(param[0].layout) &&
is_broadcasted_channel_like(param[1].layout, binfo)) {
auto x = binfo.x, y = binfo.y, z = binfo.z;
auto src0 = param[0];
auto src1 = param[1];
auto work = [=]() {
const dt_int16* __restrict__ a = static_cast<dt_int16*>(src0.raw_ptr());
const dt_float32* __restrict__ b = static_cast<dt_float32*>(src1.raw_ptr());
dt_float32* __restrict__ d = dst.ptr<dt_float32>();
for (size_t j = 0; j < y; ++j) {
auto bv = b[j];
for (size_t i = 0; i < x; ++i) {
auto off = i * (y * z) + j * z, offt = off + z;
for (; off + 4 <= offt; off += 4) {
d[off + 0] = a[off + 0] * bv;
d[off + 1] = a[off + 1] * bv;
d[off + 2] = a[off + 2] * bv;
d[off + 3] = a[off + 3] * bv;
}
for (; off < offt; ++off) {
d[off] = a[off] * bv;
}
}
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
return;
}
naive::ElemwiseMultiTypeImpl::on_mul_int16xf32xf32(param, dst);
}
void ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) {
BroadcastChannelInfo binfo0, binfo1;
if (is_vector(param[0].layout) &&
is_NHWC_broadcasted_channel_like(param[1].layout, binfo0) &&
is_NHWC_broadcasted_channel_like(param[2].layout, binfo1) && binfo0 == binfo1) {
auto x = binfo0.x, y = binfo0.y, z = binfo0.z;
auto src0 = param[0];
auto src1 = param[1];
auto src2 = param[2];
auto work = [=]() {
const dt_uint8* __restrict__ a = static_cast<dt_uint8*>(src0.raw_ptr());
const dt_float32* __restrict__ b = static_cast<dt_float32*>(src1.raw_ptr());
const dt_float32* __restrict__ c = static_cast<dt_float32*>(src2.raw_ptr());
dt_float32* __restrict__ d = dst.ptr<dt_float32>();
for (size_t i = 0; i < x; ++i) {
for (size_t j = 0; j < y; ++j) {
auto off = i * (y * z) + j * z;
size_t k = 0;
for (; k + 4 <= z; k += 4) {
d[off + k + 0] = a[off + k + 0] * b[k + 0] + c[k + 0];
d[off + k + 1] = a[off + k + 1] * b[k + 1] + c[k + 1];
d[off + k + 2] = a[off + k + 2] * b[k + 2] + c[k + 2];
d[off + k + 3] = a[off + k + 3] * b[k + 3] + c[k + 3];
}
for (; k < z; ++k) {
d[off + k] = a[off + k] * b[k] + c[k];
}
}
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
return;
} else if (
is_vector(param[0].layout) &&
is_broadcasted_channel_like(param[1].layout, binfo0) &&
is_broadcasted_channel_like(param[2].layout, binfo1) && binfo0 == binfo1) {
auto x = binfo0.x, y = binfo0.y, z = binfo0.z;
auto src0 = param[0];
auto src1 = param[1];
auto src2 = param[2];
auto work = [=]() {
const dt_uint8* __restrict__ a = static_cast<dt_uint8*>(src0.raw_ptr());
const dt_float32* __restrict__ b = static_cast<dt_float32*>(src1.raw_ptr());
const dt_float32* __restrict__ c = static_cast<dt_float32*>(src2.raw_ptr());
dt_float32* __restrict__ d = dst.ptr<dt_float32>();
for (size_t j = 0; j < y; ++j) {
auto bv = b[j], cv = c[j];
for (size_t i = 0; i < x; ++i) {
auto off = i * (y * z) + j * z, offt = off + z;
for (; off + 4 <= offt; off += 4) {
d[off + 0] = a[off + 0] * bv + cv;
d[off + 1] = a[off + 1] * bv + cv;
d[off + 2] = a[off + 2] * bv + cv;
d[off + 3] = a[off + 3] * bv + cv;
}
for (; off < offt; ++off) {
d[off] = a[off] * bv + cv;
}
}
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
return;
}
naive::ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32(param, dst);
}
template <typename ctype>
void ElemwiseMultiTypeImpl::dispatch_fma3_iXxf32xf32xi8_bcast_1x(
const ElemwiseOpParamN<3>& param, const Broadcast1xInfo& binfo,
......
@@ -43,6 +43,12 @@ protected:
const ElemwiseOpParamN<6>& param, const TensorND& dst) override;
void on_round_shr_saturate_iXxi8xi16(
const ElemwiseOpParamN<2>& param, const TensorND& dst) override;
void on_fuse_mul_add3_int16xf32xf32xf32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
void on_mul_int16xf32xf32(
const ElemwiseOpParamN<2>& param, const TensorND& dst) override;
void on_fuse_mul_add3_uint8xf32xf32xf32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
public:
using naive::ElemwiseMultiTypeImpl::ElemwiseMultiTypeImpl;
......
@@ -39,6 +39,66 @@ void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32(
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}
void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) {
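// Generic reference path: tensor_iter_valonly walks each operand according
// to its (possibly broadcast) layout, so any layout combination accepted by
// the mode is handled here, element by element.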
auto size = param.size;
auto src0 = param[0];
auto src1 = param[1];
auto src2 = param[2];
auto work = [src0, src1, src2, size, dst]() {
auto i0 = tensor_iter_valonly<dt_int16>(src0).begin();
auto i1 = tensor_iter_valonly<dt_float32>(src1).begin();
auto i2 = tensor_iter_valonly<dt_float32>(src2).begin();
auto dst_ptr = dst.ptr<dt_float32>();
for (size_t i = 0; i < size; ++i) {
dst_ptr[i] = (*i0) * (*i1) + (*i2);
++i0;
++i1;
++i2;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}
void ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) {
auto size = param.size;
auto src0 = param[0];
auto src1 = param[1];
auto src2 = param[2];
auto work = [src0, src1, src2, size, dst]() {
auto i0 = tensor_iter_valonly<dt_uint8>(src0).begin();
auto i1 = tensor_iter_valonly<dt_float32>(src1).begin();
auto i2 = tensor_iter_valonly<dt_float32>(src2).begin();
auto dst_ptr = dst.ptr<dt_float32>();
for (size_t i = 0; i < size; ++i) {
dst_ptr[i] = (*i0) * (*i1) + (*i2);
++i0;
++i1;
++i2;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}
void ElemwiseMultiTypeImpl::on_mul_int16xf32xf32(
const ElemwiseOpParamN<2>& param, const TensorND& dst) {
auto size = param.size;
auto src0 = param[0];
auto src1 = param[1];
auto work = [src0, src1, size, dst]() {
auto i0 = tensor_iter_valonly<dt_int16>(src0).begin();
auto i1 = tensor_iter_valonly<dt_float32>(src1).begin();
auto dst_ptr = dst.ptr<dt_float32>();
for (size_t i = 0; i < size; ++i) {
dst_ptr[i] = (*i0) * (*i1);
++i0;
++i1;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}
void ElemwiseMultiTypeImpl::on_fuse_mul_add3_iXxf32xf32xi8(
const ElemwiseOpParamN<3>& param, const TensorND& dst) {
switch (param[0].layout.dtype.enumv()) {
......
@@ -60,6 +60,12 @@ protected:
const ElemwiseOpParamN<6>& param, const TensorND& dst) override;
void on_round_shr_saturate_iXxi8xi16(
const ElemwiseOpParamN<2>& param, const TensorND& dst) override;
void on_fuse_mul_add3_int16xf32xf32xf32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
void on_mul_int16xf32xf32(
const ElemwiseOpParamN<2>& param, const TensorND& dst) override;
void on_fuse_mul_add3_uint8xf32xf32xf32(
const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
void on_quantized_mode(
const ElemwiseOpParamN<1>& param, const TensorND& dst,
......
@@ -456,4 +456,107 @@ TEST_F(ARM_COMMON, ELEMWISE_QUANTIZED_MODE_TERNARY_RECORD) {
}
}
TEST_F(ARM_COMMON, ELEMWISE_FMA3_INT16xF32xF32xF32) {
Checker<ElemwiseMultiType> checker(handle());
checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32});
checker.set_dtype(0, dtype::Int16());
checker.set_dtype(1, dtype::Float32());
checker.set_dtype(2, dtype::Float32());
UniformIntRNG rng{-100, 100};
checker.set_rng(0, &rng);
checker.set_rng(1, &rng);
checker.set_rng(2, &rng);
checker.execs({{5, 7, 16}, {1, 1, 16}, {1, 1, 16}, {}})
.execs({{2, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}})
.execs({{2, 700, 600}, {2, 700, 600}, {2, 700, 600}, {}})
.execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}})
.execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}});
}
TEST_F(ARM_COMMON, ELEMWISE_FMA3_INT16xF32xF32xF32_RECORD) {
TaskRecordChecker<ElemwiseMultiType> checker(0);
checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32});
checker.set_dtype(0, dtype::Int16());
checker.set_dtype(1, dtype::Float32());
checker.set_dtype(2, dtype::Float32());
UniformIntRNG rng{-100, 100};
checker.set_rng(0, &rng);
checker.set_rng(1, &rng);
checker.set_rng(2, &rng);
checker.execs({{5, 7, 16}, {1, 1, 16}, {1, 1, 16}, {}})
.execs({{2, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}})
.execs({{2, 700, 600}, {2, 700, 600}, {2, 700, 600}, {}})
.execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}})
.execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}})
.execs({{16, 128, 16, 18}, {1, 1, 1, 18}, {1, 1, 1, 18}, {}})
.execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}});
}
TEST_F(ARM_COMMON, ELEMWISE_MUL_INT16xF32xF32) {
Checker<ElemwiseMultiType> checker(handle());
checker.set_param({ElemwiseMultiType::Mode::MUL_INT16xF32xF32});
checker.set_dtype(0, dtype::Int16());
checker.set_dtype(1, dtype::Float32());
UniformIntRNG rng{-100, 100};
checker.set_rng(0, &rng);
checker.set_rng(1, &rng);
checker.execs({{5, 7, 16}, {1, 1, 16}, {}})
.execs({{2, 700, 600}, {1, 1, 600}, {}})
.execs({{2, 700, 600}, {2, 700, 600}, {}})
.execs({{16, 16, 128}, {16, 16, 128}, {}})
.execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {}});
}
TEST_F(ARM_COMMON, ELEMWISE_MUL_INT16xF32xF32_RECORD) {
TaskRecordChecker<ElemwiseMultiType> checker(0);
checker.set_param({ElemwiseMultiType::Mode::MUL_INT16xF32xF32});
checker.set_dtype(0, dtype::Int16());
checker.set_dtype(1, dtype::Float32());
UniformIntRNG rng{-100, 100};
checker.set_rng(0, &rng);
checker.set_rng(1, &rng);
checker.execs({{5, 7, 16}, {1, 1, 16}, {}})
.execs({{2, 700, 600}, {1, 1, 600}, {}})
.execs({{2, 700, 600}, {2, 700, 600}, {}})
.execs({{16, 16, 128}, {16, 16, 128}, {}})
.execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {}})
.execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {}});
}
TEST_F(ARM_COMMON, ELEMWISE_FMA3_UINT8xF32xF32xF32) {
Checker<ElemwiseMultiType> checker(handle());
checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32});
checker.set_dtype(0, dtype::Uint8());
checker.set_dtype(1, dtype::Float32());
checker.set_dtype(2, dtype::Float32());
UniformIntRNG rng{-100, 100};
checker.set_rng(0, &rng);
checker.set_rng(1, &rng);
checker.set_rng(2, &rng);
checker.execs({{5, 7, 16}, {1, 1, 16}, {1, 1, 16}, {}})
.execs({{2, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}})
.execs({{2, 700, 600}, {2, 700, 600}, {2, 700, 600}, {}})
.execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}})
.execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}})
.execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}});
}
TEST_F(ARM_COMMON, ELEMWISE_FMA3_UINT8xF32xF32xF32_RECORD) {
TaskRecordChecker<ElemwiseMultiType> checker(0);
checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32});
checker.set_dtype(0, dtype::Uint8());
checker.set_dtype(1, dtype::Float32());
checker.set_dtype(2, dtype::Float32());
UniformIntRNG rng{-100, 100};
checker.set_rng(0, &rng);
checker.set_rng(1, &rng);
checker.set_rng(2, &rng);
checker.execs({{5, 7, 16}, {1, 1, 16}, {1, 1, 16}, {}})
.execs({{2, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}})
.execs({{2, 700, 600}, {2, 700, 600}, {2, 700, 600}, {}})
.execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}})
.execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}})
.execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}});
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -79,6 +79,73 @@ DEF_TEST(fuse_mul_add3_int16x32x32x32) {
.execs({{102, 67, 71}, {1, 67, 1}, {1, 67, 1}, {}});
}
DEF_TEST(fuse_mul_add3_int16xf32xf32xf32) {
// This is not implemented on CUDA.
if (handle->type() == Handle::HandleType::CUDA) {
return;
}
Checker<ElemwiseMultiType> checker(handle);
checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32});
checker.set_dtype(0, dtype::Int16());
checker.set_dtype(1, dtype::Float32());
checker.set_dtype(2, dtype::Float32());
UniformIntRNG rng{-100, 100};
checker.set_rng(0, &rng);
checker.set_rng(1, &rng);
checker.set_rng(2, &rng);
checker.execs({{5, 7, 6}, {1, 1, 6}, {1, 1, 6}, {}})
.execs({{1, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}})
.execs({{1, 700, 600}, {1, 700, 600}, {1, 700, 600}, {}})
.execs({{102, 71, 67}, {1, 1, 67}, {1, 1, 67}, {}})
.execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}})
.execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}})
.execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}});
}
DEF_TEST(fuse_mul_add3_uint8xf32xf32xf32) {
// This is not implemented on CUDA.
if (handle->type() == Handle::HandleType::CUDA) {
return;
}
Checker<ElemwiseMultiType> checker(handle);
checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32});
checker.set_dtype(0, dtype::Uint8());
checker.set_dtype(1, dtype::Float32());
checker.set_dtype(2, dtype::Float32());
UniformIntRNG rng{-100, 100};
checker.set_rng(0, &rng);
checker.set_rng(1, &rng);
checker.set_rng(2, &rng);
checker.execs({{5, 7, 6}, {1, 1, 6}, {1, 1, 6}, {}})
.execs({{1, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}})
.execs({{1, 700, 600}, {1, 700, 600}, {1, 700, 600}, {}})
.execs({{102, 71, 67}, {1, 1, 67}, {1, 1, 67}, {}})
.execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}})
.execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}})
.execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}});
}
DEF_TEST(fuse_mul_add3_int16xf32xf32) {
// This is not implemented on CUDA.
if (handle->type() == Handle::HandleType::CUDA) {
return;
}
Checker<ElemwiseMultiType> checker(handle);
checker.set_param({ElemwiseMultiType::Mode::MUL_INT16xF32xF32});
checker.set_dtype(0, dtype::Int16());
checker.set_dtype(1, dtype::Float32());
UniformIntRNG rng{-100, 100};
checker.set_rng(0, &rng);
checker.set_rng(1, &rng);
checker.execs({{5, 7, 6}, {1, 1, 6}, {}})
.execs({{1, 700, 600}, {1, 1, 600}, {}})
.execs({{1, 700, 600}, {1, 700, 600}, {}})
.execs({{102, 71, 67}, {1, 1, 67}, {}})
.execs({{16, 16, 128}, {16, 16, 128}, {}})
.execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {}})
.execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {}});
}
DEF_TEST(fuse_mul_add3_iXxf32xf32xi8) {
Checker<ElemwiseMultiType> checker(handle);
checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_IXxF32xF32xI8});
......
@@ -20,10 +20,13 @@ namespace test {
namespace elemwise_multi_type {
#define FIRST_ELEMWISE_MULTI_TYPE_CASE fuse_mul_add3_int16x32x32x32
#define FOREACH_ELEMWISE_MULTI_TYPE_NONFIRST_CASE(cb) \
cb(fuse_mul_add3_iXxf32xf32xi8) cb(round_shr_saturate_iXxi8xi8) \
cb(fuse_add_rmulh_round_shr_saturate_int16) \
cb(fuse_add_rmulh_round_shr_saturate_int32)
#define FOREACH_ELEMWISE_MULTI_TYPE_NONFIRST_CASE(cb) \
cb(fuse_mul_add3_iXxf32xf32xi8) cb(round_shr_saturate_iXxi8xi8) \
cb(fuse_add_rmulh_round_shr_saturate_int16) \
cb(fuse_add_rmulh_round_shr_saturate_int32) \
cb(fuse_mul_add3_int16xf32xf32xf32) \
cb(fuse_mul_add3_uint8xf32xf32xf32) \
cb(fuse_mul_add3_int16xf32xf32)
#define FOREACH_ELEMWISE_MULTI_TYPE_CASE(cb) \
cb(FIRST_ELEMWISE_MULTI_TYPE_CASE) FOREACH_ELEMWISE_MULTI_TYPE_NONFIRST_CASE(cb)
......
@@ -40,6 +40,24 @@ TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_RECORD_FMA3_INT16x32x32x32) {
checker.execs({{A, B, C}, {1, B, 1}, {1, B, 1}, {}});
}
TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_RECORD_FMA3_INT16xF32xF32xF32) {
TaskRecordChecker<ElemwiseMultiType> checker{1};
checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32});
checker.set_dtype(0, dtype::Int16());
checker.set_dtype(1, dtype::Float32());
checker.set_dtype(2, dtype::Float32());
UniformIntRNG rng{-10, 10};
checker.set_rng(0, &rng);
checker.set_rng(1, &rng);
checker.set_rng(2, &rng);
checker.execs({{5, 7, 16}, {1, 1, 16}, {1, 1, 16}, {}})
.execs({{2, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}})
.execs({{2, 700, 600}, {2, 700, 600}, {2, 700, 600}, {}})
.execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}})
.execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}})
.execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}});
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_BENCHMARK_FMA3_INT16x32x32x32) {
Benchmarker<ElemwiseMultiType> bench{handle()};
......