/** * \file dnn/src/fallback/elemwise_helper/op_common.h */ #pragma once #include "megdnn/dtype.h" #include "src/fallback/elemwise_helper/kimpl/pow.h" namespace megdnn { namespace elemwise { /*! * \brief broadcast type * BCAST_x[0]x[1]...: x[i] == !stride[i] */ enum BcastType { VEC, VEC_VEC, VEC_BCAST101, VEC_BCASTX0X, VEC_BCAST111C, VEC_BCAST101xX, VEC_SCALAR, SCALAR_VEC, BCAST101_VEC, BCASTX0X_VEC, BCAST111C_VEC, BCAST101xX_VEC, VEC_VEC_VEC, VEC_VEC_SCALAR, BCAST101_VEC_BCAST101, BCAST111C_VEC_BCAST111C, BCAST101xX_VEC_BCAST101xX, VEC_BCAST101_VEC, VEC_BCAST111C_VEC, VEC_BCAST101xX_VEC, VEC_SCALAR_VEC, VEC_SCALAR_SCALAR, UNKNOWN_BCAST_TYPE }; ///////////////////////////////// ParamElemVistor /////////////////////////// template struct ParamElemVisitor; //! visitor single elemwise, and dup to vector template struct ParamElemVisitorDup; template struct ParamElemVisitorBcast101x4; #define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ template <> \ struct ParamElemVisitor<_ctype> { \ _simd_type operator()(const _ctype* src) const { \ return GiLoad##_fun_suffix(src); \ } \ }; \ template <> \ struct ParamElemVisitorDup<_ctype> { \ _simd_type operator()(const _ctype* src) const { \ return GiBroadcast##_fun_suffix( \ *reinterpret_cast(src)); \ } \ } cb(dt_qint32, int32_t, GI_INT32_t, Int32); cb(dt_qint8, int8_t, GI_INT8_t, Int8); cb(dt_float32, float, GI_FLOAT32_t, Float32); cb(dt_int32, int32_t, GI_INT32_t, Int32); cb(dt_int8, int8_t, GI_INT8_t, Int8); #undef cb template struct ParamElemVisitorBcast101x4; #define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix) \ template <> \ struct ParamElemVisitorBcast101x4<_ctype> { \ _simd_type operator()(const _ctype* src) const { \ return GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix( \ *reinterpret_cast(src))); \ } \ } cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32); cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32); #undef cb #define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ template <> \ struct ParamElemVisitorBcast101x4<_ctype> { \ _simd_type operator()(const _ctype* src) const { \ return GiLoad##_fun_suffix(src); \ } \ } cb(dt_qint32, int32_t, GI_INT32_t, Int32); cb(dt_float32, float, GI_FLOAT32_t, Float32); cb(dt_int32, int32_t, GI_INT32_t, Int32); #undef cb ///////////////////////////////// OpCaller ///////////////////////////// template struct OpCallerUnary; template struct OpCallerUnary { static void run( const typename Op::src_ctype* src, typename Op::dst_ctype* dst, DType src_dtype, DType dst_dtype, size_t nr_elems) { Op op(src_dtype, dst_dtype); ParamElemVisitor vis; size_t i = 0; for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { op({{vis(src), vis(src + Op::SIMD_WIDTH)}}, dst); src += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < nr_elems; i++) { op(*src, dst); src++; dst++; } } }; template struct OpCallerBinary; ///////////////////////// Pow //////////////////////////////// template struct OpCallerBinary, VEC_VEC> { using Op = fallback::PowOp; static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t nr_elems) { Op op(src0_dtype, src1_dtype, dst_dtype); size_t i = 0; #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < nr_elems; i++) { op(*src0, *src1, dst); src0++; src1++; dst++; } } }; template struct OpCallerBinary, VEC_SCALAR> { using Op = fallback::PowOp; static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t nr_elems) { Op op(src0_dtype, src1_dtype, dst_dtype); size_t i = 0; #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < nr_elems; i++) { op(*src0, src1, dst); src0++; dst++; } } }; template struct OpCallerBinary, VEC_BCAST101> { using Op = fallback::PowOp; static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { Op op(src0_dtype, src1_dtype, dst_dtype); for (size_t b = 0; b < batch; b++) { const typename Op::src_ctype* src1_ptr = src1; for (size_t c = 0; c < channel; c++) { size_t i = 0; #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < channel_stride; i++) { op(*src0, *src1_ptr, dst); src0++; dst++; } src1_ptr++; } } } }; template struct OpCallerBinary, VEC_BCASTX0X> { using Op = fallback::PowOp; static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { Op op(src0_dtype, src1_dtype, dst_dtype); for (size_t b = 0; b < batch; b++) { const typename Op::src_ctype* src1_ptr_base = src1 + b * channel_stride; for (size_t c = 0; c < channel; c++) { size_t i = 0; auto src1_ptr = src1_ptr_base; #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < channel_stride; i++) { op(*src0, *src1_ptr, dst); src0++; src1_ptr++; dst++; } } } } }; template struct OpCallerBinary, VEC_BCAST111C> { using Op = fallback::PowOp; static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { Op op(src0_dtype, src1_dtype, dst_dtype); for (size_t b = 0; b < batch; b++) { for (size_t c = 0; c < channel; c++) { size_t i = 0; const typename Op::src_ctype* src1_ptr = src1; #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < channel_stride; i++) { op(*src0, *src1_ptr, dst); src0++; src1_ptr++; dst++; } } } } }; template struct OpCallerBinary, BCAST111C_VEC> { using Op = fallback::PowOp; static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { Op op(src0_dtype, src1_dtype, dst_dtype); for (size_t b = 0; b < batch; b++) { for (size_t c = 0; c < channel; c++) { size_t i = 0; const typename Op::src_ctype* src0_ptr = src0; #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < channel_stride; i++) { op(*src0_ptr, *src1, dst); src0_ptr++; src1++; dst++; } } } } }; template struct OpCallerBinary, SCALAR_VEC> { using Op = fallback::PowOp; static void run( const typename Op::src_ctype src0, const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t nr_elems) { Op op(src0_dtype, src1_dtype, dst_dtype); size_t i = 0; #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < nr_elems; i++) { op(src0, *src1, dst); src1++; dst++; } } }; template struct OpCallerBinary, BCAST101_VEC> { using Op = fallback::PowOp; static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { Op op(src0_dtype, src1_dtype, dst_dtype); for (size_t b = 0; b < batch; b++) { auto src0_ptr = src0; for (size_t c = 0; c < channel; c++) { size_t i = 0; #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < channel_stride; i++) { op(*src0_ptr, *src1, dst); src1++; dst++; } src0_ptr++; } } } }; template struct OpCallerBinary, BCASTX0X_VEC> { using Op = fallback::PowOp; static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { Op op(src0_dtype, src1_dtype, dst_dtype); for (size_t b = 0; b < batch; b++) { auto src0_ptr_base = src0 + b * channel_stride; for (size_t c = 0; c < channel; c++) { size_t i = 0; auto src0_ptr = src0_ptr_base; #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < channel_stride; i++) { op(*src0_ptr, *src1, dst); src0_ptr++; src1++; dst++; } } } } }; template struct OpCallerBinary { static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t nr_elems) { Op op(src0_dtype, src1_dtype, dst_dtype); ParamElemVisitor vis0; ParamElemVisitor vis1; size_t i = 0; for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); src0 += Op::SIMD_WIDTH * 2; src1 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < nr_elems; i++) { op(*src0, *src1, dst); src0++; src1++; dst++; } } }; template struct OpCallerBinary { static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { Op op(src0_dtype, src1_dtype, dst_dtype); ParamElemVisitor vis0; ParamElemVisitorDup vis1; for (size_t b = 0; b < batch; b++) { const typename Op::src_ctype* src1_ptr = src1; for (size_t c = 0; c < channel; c++) { size_t i = 0; auto src1_simd = vis1(src1_ptr); for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; i += Op::SIMD_WIDTH * 2) { op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{src1_simd, src1_simd}}, dst); src0 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < channel_stride; i++) { op(*src0, *src1_ptr, dst); src0++; dst++; } src1_ptr++; } } } }; template struct OpCallerBinary { static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { Op op(src0_dtype, src1_dtype, dst_dtype); ParamElemVisitor vis; for (size_t b = 0; b < batch; b++) { const typename Op::src_ctype* src1_ptr_base = src1 + b * channel_stride; for (size_t c = 0; c < channel; c++) { size_t i = 0; auto src1_ptr = src1_ptr_base; for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; i += Op::SIMD_WIDTH * 2) { auto src0_simd0 = vis(src0); auto src0_simd1 = vis(src0 + Op::SIMD_WIDTH); auto src1_simd0 = vis(src1_ptr); auto src1_simd1 = vis(src1_ptr + Op::SIMD_WIDTH); op({{src0_simd0, src0_simd1}}, {{src1_simd0, src1_simd1}}, dst); src0 += Op::SIMD_WIDTH * 2; src1_ptr += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < channel_stride; i++) { op(*src0, *src1_ptr, dst); src0++; src1_ptr++; dst++; } } } } }; template struct OpCallerBinary { static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { Op op(src0_dtype, src1_dtype, dst_dtype); ParamElemVisitor vis; for (size_t b = 0; b < batch; b++) { for (size_t c = 0; c < channel; c++) { size_t rest = channel_stride; const typename Op::src_ctype* src1_ptr = src1; while (rest >= Op::SIMD_WIDTH * 2) { auto src0_simd0 = vis(src0); auto src0_simd1 = vis(src0 + Op::SIMD_WIDTH); auto src1_simd0 = vis(src1_ptr); auto src1_simd1 = vis(src1_ptr + Op::SIMD_WIDTH); src0 += Op::SIMD_WIDTH * 2; src1_ptr += Op::SIMD_WIDTH * 2; op({{src0_simd0, src0_simd1}}, {{src1_simd0, src1_simd1}}, dst); dst += Op::SIMD_WIDTH * 2; rest -= Op::SIMD_WIDTH * 2; } #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif while (rest > 0) { op(*src0, *src1_ptr, dst); dst++; src0++; src1_ptr++; rest--; } } } } }; template struct OpCallerBinary { static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { Op op(src0_dtype, src1_dtype, dst_dtype); ParamElemVisitor vis; for (size_t b = 0; b < batch; b++) { for (size_t c = 0; c < channel; c++) { size_t rest = channel_stride; const typename Op::src_ctype* src0_ptr = src0; while (rest >= Op::SIMD_WIDTH * 2) { auto src0_simd0 = vis(src0_ptr); auto src0_simd1 = vis(src0_ptr + Op::SIMD_WIDTH); auto src1_simd0 = vis(src1); auto src1_simd1 = vis(src1 + Op::SIMD_WIDTH); src0_ptr += Op::SIMD_WIDTH * 2; src1 += Op::SIMD_WIDTH * 2; op({{src0_simd0, src0_simd1}}, {{src1_simd0, src1_simd1}}, dst); dst += Op::SIMD_WIDTH * 2; rest -= Op::SIMD_WIDTH * 2; } #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif while (rest > 0) { op(*src0_ptr, *src1, dst); dst++; src0_ptr++; src1++; rest--; } } } } }; template struct OpCallerBinary, BCAST101xX_VEC> { using Op = fallback::PowOp; static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t batch, size_t nr_channel_blocks, size_t channel_stride, size_t channel_block_dim) { Op op(src0_dtype, src1_dtype, dst_dtype); for (size_t b = 0; b < batch; b++) { auto src0_ptr = src0; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { auto src0_block_ptr = src0_ptr + cb * channel_block_dim; for (size_t i = 0; i < channel_stride; i++) { for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { op(*(src0_block_ptr + c_iter), *src1, dst); src1++; dst++; } } } } } }; template struct OpCallerBinaryBcast101xXVec { template static void run( const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, const Op& op, size_t batch, size_t nr_channel_blocks, size_t channel_stride) { for (size_t b = 0; b < batch; b++) { auto src0_ptr = src0; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { auto src0_block_ptr = src0_ptr + cb * channel_block_dim; for (size_t img_index = 0; img_index < channel_stride; img_index++) { for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { op(*(src0_block_ptr + c_iter), *src1, dst); src1++; dst++; } } } } } }; template struct OpCallerBinaryBcast101xDVec { template static void run( const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0, const Vis1& vis1, size_t batch, size_t nr_channel_blocks, size_t channel_stride) { for (size_t b = 0; b < batch; b++) { auto src0_ptr = src0; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { auto src0_block_ptr = src0_ptr + cb * channel_block_dim; auto channel_block_vec = vis0(src0_block_ptr); size_t img_index = 0; auto src1_offset = Op::SIMD_WIDTH / channel_block_dim; for (; img_index + 2 * src1_offset <= channel_stride; img_index += 2 * src1_offset) { op({{channel_block_vec, channel_block_vec}}, {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); src1 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } // TODO:all elemwise_multi_type op imp one simd mode for (; img_index < channel_stride; img_index++) { for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { op(*(src0_block_ptr + c_iter), *src1, dst); src1++; dst++; } } } } } }; template struct OpCallerBinaryBcast101xXVec { template static void run( const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, const Op& op, size_t batch, size_t nr_channel_blocks, size_t channel_stride) { ParamElemVisitorBcast101x4 vis0; ParamElemVisitor vis1; OpCallerBinaryBcast101xDVec::run( src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, channel_stride); } }; template struct OpCallerBinary { static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t batch, size_t nr_channel_blocks, size_t channel_stride, size_t channel_block_dim) { megdnn_assert( channel_block_dim == 4 || channel_block_dim == 8, "only imp for nchw44/nchw88"); Op op(src0_dtype, src1_dtype, dst_dtype); if (channel_block_dim == 4) { OpCallerBinaryBcast101xXVec::run( src0, src1, dst, op, batch, nr_channel_blocks, channel_stride); } else { OpCallerBinaryBcast101xXVec::run( src0, src1, dst, op, batch, nr_channel_blocks, channel_stride); } } }; template struct OpCallerBinary, VEC_BCAST101xX> { using Op = fallback::PowOp; static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t batch, size_t nr_channel_blocks, size_t channel_stride, size_t channel_block_dim) { Op op(src0_dtype, src1_dtype, dst_dtype); for (size_t b = 0; b < batch; b++) { auto src1_ptr = src1; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { auto src1_block_ptr = src1_ptr + cb * channel_block_dim; for (size_t i = 0; i < channel_stride; i++) { for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { op(*(src0), *(src1_block_ptr + c_iter), dst); src0++; dst++; } } } } } }; template struct OpCallerBinaryVecBcast101xX { template static void run( const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, const Op& op, size_t batch, size_t nr_channel_blocks, size_t channel_stride) { for (size_t b = 0; b < batch; b++) { auto src1_ptr = src1; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { auto src1_block_ptr = src1_ptr + cb * channel_block_dim; for (size_t img_index = 0; img_index < channel_stride; img_index++) { for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { op(*src0, *(src1_block_ptr + c_iter), dst); src0++; dst++; } } } } } }; template struct OpCallerBinaryVecBcast101xD { template static void run( const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0, const Vis1& vis1, size_t batch, size_t nr_channel_blocks, size_t channel_stride) { for (size_t b = 0; b < batch; b++) { auto src1_ptr = src1; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { auto src1_block_ptr = src1_ptr + cb * channel_block_dim; auto channel_block_vec = vis1(src1_block_ptr); size_t img_index = 0; auto src0_offset = Op::SIMD_WIDTH / channel_block_dim; for (; img_index + 2 * src0_offset <= channel_stride; img_index += 2 * src0_offset) { op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{channel_block_vec, channel_block_vec}}, dst); src0 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } // TODO:all elemwise_multi_type op imp one simd mode for (; img_index < channel_stride; img_index++) { for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { op(*src0, *(src1_block_ptr + c_iter), dst); src0++; dst++; } } } } } }; template struct OpCallerBinaryVecBcast101xX { template static void run( const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, const Op& op, size_t batch, size_t nr_channel_blocks, size_t channel_stride) { ParamElemVisitor vis0; ParamElemVisitorBcast101x4 vis1; OpCallerBinaryVecBcast101xD::run( src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, channel_stride); } }; template struct OpCallerBinary { static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t batch, size_t nr_channel_blocks, size_t channel_stride, size_t channel_block_dim) { megdnn_assert( channel_block_dim == 4 || channel_block_dim == 8, "only imp for nchw44/nchw88"); Op op(src0_dtype, src1_dtype, dst_dtype); if (channel_block_dim == 4) { OpCallerBinaryVecBcast101xX::run( src0, src1, dst, op, batch, nr_channel_blocks, channel_stride); } else { OpCallerBinaryVecBcast101xX::run( src0, src1, dst, op, batch, nr_channel_blocks, channel_stride); } } }; template struct OpCallerBinary { static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t nr_elems) { Op op(src0_dtype, src1_dtype, dst_dtype); ParamElemVisitor vis0; ParamElemVisitorDup vis1; auto vis1_simd = vis1(&src1); size_t i = 0; for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1_simd, vis1_simd}}, dst); src0 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < nr_elems; i++) { op(*src0, src1, dst); src0++; dst++; } } }; //! this only for nonswap op, like SUB and DIV template struct OpCallerBinary { static void run( const typename Op::src_ctype src0, const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t nr_elems) { Op op(src0_dtype, src1_dtype, dst_dtype); ParamElemVisitorDup vis0; ParamElemVisitor vis1; auto vis0_simd = vis0(&src0); size_t i = 0; for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { op({{vis0_simd, vis0_simd}}, {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); src1 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < nr_elems; i++) { op(src0, *src1, dst); src1++; dst++; } } }; template struct OpCallerBinary { static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { Op op(src0_dtype, src1_dtype, dst_dtype); ParamElemVisitorDup vis0; ParamElemVisitor vis1; for (size_t b = 0; b < batch; b++) { auto src0_ptr = src0; for (size_t c = 0; c < channel; c++) { auto vis0_simd = vis0(src0_ptr); size_t i = 0; for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; i += Op::SIMD_WIDTH * 2) { op({{vis0_simd, vis0_simd}}, {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); src1 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < channel_stride; i++) { op(*src0_ptr, *src1, dst); src1++; dst++; } src0_ptr++; } } } }; template struct OpCallerBinary { static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { Op op(src0_dtype, src1_dtype, dst_dtype); ParamElemVisitor vis; for (size_t b = 0; b < batch; b++) { auto src0_ptr_base = src0 + b * channel_stride; for (size_t c = 0; c < channel; c++) { auto src0_ptr = src0_ptr_base; size_t i = 0; for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; i += Op::SIMD_WIDTH * 2) { auto src0_simd0 = vis(src0_ptr); auto src0_simd1 = vis(src0_ptr + Op::SIMD_WIDTH); auto src1_simd0 = vis(src1); auto src1_simd1 = vis(src1 + Op::SIMD_WIDTH); op({{src0_simd0, src0_simd1}}, {{src1_simd0, src1_simd1}}, dst); src0_ptr += Op::SIMD_WIDTH * 2; src1 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < channel_stride; i++) { op(*src0_ptr, *src1, dst); src0_ptr++; src1++; dst++; } } } } }; template struct OpCallerTernary; template struct OpCallerTernary { static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t nr_elems) { Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); ParamElemVisitor vis0; ParamElemVisitor vis1; ParamElemVisitor vis2; size_t i = 0; for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); src0 += Op::SIMD_WIDTH * 2; src1 += Op::SIMD_WIDTH * 2; src2 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < nr_elems; i++) { op(*src0, *src1, *src2, dst); src0++; src1++; src2++; dst++; } } }; //! src0: vector, src1: vector, src2: scalar template struct OpCallerTernary { static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, const typename Op::src_ctype src2, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t nr_elems) { Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); ParamElemVisitor vis0; ParamElemVisitor vis1; ParamElemVisitorDup vis2; auto vis2_simd = vis2(&src2); size_t i = 0; for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, {{vis2_simd, vis2_simd}}, dst); src0 += Op::SIMD_WIDTH * 2; src1 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < nr_elems; i++) { op(*src0, *src1, src2, dst); src0++; src1++; dst++; } } }; //! src0: 1C11, src1: vector, src2: 1C11 template struct OpCallerTernary { static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t batch_size, size_t channel_size, size_t channel_stride, size_t batch_offset) { Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); ParamElemVisitor vis1; ParamElemVisitorDup vis0; ParamElemVisitorDup vis2; for (size_t batch = 0; batch < batch_size; batch++) { auto src0_ptr = src0; auto src2_ptr = src2; auto b_offset = batch_offset; for (size_t channel = 0; channel < channel_size; channel++) { size_t i = 0; auto src0_simd = vis0(src0_ptr); auto src2_simd = vis2(src2_ptr); for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; i += Op::SIMD_WIDTH * 2) { op({{src0_simd, src0_simd}}, {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, {{src2_simd, src2_simd}}, dst); src1 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; b_offset -= Op::SIMD_WIDTH * 2; } #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < channel_stride; i++) { op(*src0_ptr, *src1, *src2_ptr, dst); src1++; dst++; b_offset--; } src0_ptr++; src2_ptr++; } src1 += b_offset; dst += b_offset; } } }; //! src0: 111C, src1: vector, src2: 111C, src1 may not be contig template struct OpCallerTernary { static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, size_t src1_offset, const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t batch_size, size_t channel_size, size_t channel_stride, size_t batch_offset) { Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); ParamElemVisitor vis; for (size_t batch = 0; batch < batch_size; batch++) { auto b_offset = batch_offset; for (size_t channel = 0; channel < channel_size; channel++) { auto src0_ptr = src0; auto src2_ptr = src2; size_t i = 0; for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; i += Op::SIMD_WIDTH * 2) { auto src0_simd0 = vis(src0_ptr); auto src0_simd1 = vis(src0_ptr + Op::SIMD_WIDTH); auto src1_simd0 = vis(src1); auto src1_simd1 = vis(src1 + Op::SIMD_WIDTH); auto src2_simd0 = vis(src2_ptr); auto src2_simd1 = vis(src2_ptr + Op::SIMD_WIDTH); op({{src0_simd0, src0_simd1}}, {{src1_simd0, src1_simd1}}, {{src2_simd0, src2_simd1}}, dst); src0_ptr += Op::SIMD_WIDTH * 2; src1 += Op::SIMD_WIDTH * 2; src2_ptr += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; b_offset -= Op::SIMD_WIDTH * 2; } #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < channel_stride; i++) { op(*src0_ptr, *src1, *src2_ptr, dst); src0_ptr++; src1++; src2_ptr++; dst++; b_offset--; } src1 += src1_offset; } src1 += b_offset; dst += b_offset; } } }; template struct OpCallerTernaryBcast101xXVecBcast101xX { template static void run( const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, typename Op::dst_ctype* dst, const Op& op, size_t batch, size_t nr_channel_blocks, size_t channel_stride) { for (size_t b = 0; b < batch; b++) { auto src0_ptr = src0; auto src2_ptr = src2; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { auto src0_block_ptr = src0_ptr + cb * channel_block_dim; auto src2_block_ptr = src2_ptr + cb * channel_block_dim; for (size_t img_index = 0; img_index < channel_stride; img_index++) { for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { op(*(src0_block_ptr + c_iter), *src1, *(src2_block_ptr + c_iter), dst); src1++; dst++; } } } } } }; template struct OpCallerTernaryBcast101xDVecBcast101xD { template static void run( const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0, const Vis1& vis1, const Vis2& vis2, size_t batch, size_t nr_channel_blocks, size_t channel_stride) { for (size_t b = 0; b < batch; b++) { auto src0_ptr = src0; auto src2_ptr = src2; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { auto src0_block_ptr = src0_ptr + cb * channel_block_dim; auto src2_block_ptr = src2_ptr + cb * channel_block_dim; auto channel_block_vec0 = vis0(src0_block_ptr); auto channel_block_vec2 = vis2(src2_block_ptr); size_t img_index = 0; auto src1_offset = Op::SIMD_WIDTH / channel_block_dim; for (; img_index + 2 * src1_offset <= channel_stride; img_index += 2 * src1_offset) { op({{channel_block_vec0, channel_block_vec0}}, {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, {{channel_block_vec2, channel_block_vec2}}, dst); src1 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } // TODO:all elemwise_multi_type op imp one simd mode for (; img_index < channel_stride; img_index++) { for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { op(*(src0_block_ptr + c_iter), *src1, *(src2_block_ptr + c_iter), dst); src1++; dst++; } } } } } }; //! src0: CHW44, src1: vector, src2: CHW44 template struct OpCallerTernaryBcast101xXVecBcast101xX { template static void run( const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, typename Op::dst_ctype* dst, const Op& op, size_t batch, size_t nr_channel_blocks, size_t channel_stride) { ParamElemVisitorBcast101x4 vis0; ParamElemVisitor vis1; ParamElemVisitorBcast101x4 vis2; OpCallerTernaryBcast101xDVecBcast101xD::run( src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, channel_stride); } }; template struct OpCallerTernary { static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t batch, size_t nr_channel_blocks, size_t channel_stride, size_t channel_block_dim) { megdnn_assert( channel_block_dim == 4 || channel_block_dim == 8, "only imp for nchw44/nchw88"); Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); if (channel_block_dim == 4) { OpCallerTernaryBcast101xXVecBcast101xX::run( src0, src1, src2, dst, op, batch, nr_channel_blocks, channel_stride); } else { OpCallerTernaryBcast101xXVecBcast101xX::run( src0, src1, src2, dst, op, batch, nr_channel_blocks, channel_stride); } } }; //! src1: 1C11, src0 and src2 are contig template struct OpCallerTernary { static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t batch_size, size_t channel_size, size_t channel_stride) { Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); ParamElemVisitor vis0; ParamElemVisitorDup vis1; ParamElemVisitor vis2; for (size_t batch = 0; batch < batch_size; batch++) { auto src1_ptr = src1; for (size_t channel = 0; channel < channel_size; channel++) { size_t i = 0; auto src1_simd = vis1(src1_ptr); for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; i += Op::SIMD_WIDTH * 2) { op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{src1_simd, src1_simd}}, {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); src0 += Op::SIMD_WIDTH * 2; src2 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < channel_stride; i++) { op(*src0, *src1_ptr, *src2, dst); src0++; src2++; dst++; } src1_ptr++; } } } }; //! src1: 111C, src0 and src2 may not be contig template struct OpCallerTernary { static void run( const typename Op::src_ctype* src0, size_t src0_offset, const typename Op::src_ctype* src1, const typename Op::src_ctype* src2, size_t src2_offset, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t batch_size, size_t channel_size, size_t channel_stride) { Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); ParamElemVisitor vis0; ParamElemVisitor vis1; ParamElemVisitor vis2; for (size_t batch = 0; batch < batch_size; batch++) { for (size_t channel = 0; channel < channel_size; channel++) { auto src1_ptr = src1; size_t i = 0; for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; i += Op::SIMD_WIDTH * 2) { op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1(src1_ptr), vis1(src1_ptr + Op::SIMD_WIDTH)}}, {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); src0 += Op::SIMD_WIDTH * 2; src1_ptr += Op::SIMD_WIDTH * 2; src2 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < channel_stride; i++) { op(*src0, *src1_ptr, *src2, dst); src0++; src1_ptr++; src2++; dst++; } src0 += src0_offset; src2 += src2_offset; } } } }; template struct OpCallerTernaryVecBcast101xXVec { template static void run( const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, typename Op::dst_ctype* dst, const Op& op, size_t batch, size_t nr_channel_blocks, size_t channel_stride) { for (size_t b = 0; b < batch; b++) { auto src1_ptr = src1; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { auto src1_block_ptr = src1_ptr + cb * channel_block_dim; for (size_t img_index = 0; img_index < channel_stride; img_index++) { for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { op(*src0, *(src1_block_ptr + c_iter), *src2, dst); src0++; src2++; dst++; } } } } } }; //! src1: CHW44, src0 and src2 are contig template struct OpCallerTernaryVecBcast101xDVec { template static void run( const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0, const Vis1& vis1, const Vis2& vis2, size_t batch, size_t nr_channel_blocks, size_t channel_stride) { for (size_t b = 0; b < batch; b++) { auto src1_ptr = src1; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { auto src1_block_ptr = src1_ptr + cb * channel_block_dim; auto channel_block_vec = vis1(src1_block_ptr); size_t img_index = 0; auto offset = Op::SIMD_WIDTH / channel_block_dim; for (; img_index + 2 * offset <= channel_stride; img_index += 2 * offset) { op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{channel_block_vec, channel_block_vec}}, {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); src0 += Op::SIMD_WIDTH * 2; src2 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } // TODO:all elemwise_multi_type op imp one simd mode for (; img_index < channel_stride; img_index++) { for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { op(*src0, *(src1_block_ptr + c_iter), *src2, dst); src0++; src2++; dst++; } } } } } }; template struct OpCallerTernaryVecBcast101xXVec { template static void run( const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, typename Op::dst_ctype* dst, const Op& op, size_t batch, size_t nr_channel_blocks, size_t channel_stride) { ParamElemVisitor vis0; ParamElemVisitorBcast101x4 vis1; ParamElemVisitor vis2; OpCallerTernaryVecBcast101xDVec::run( src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, channel_stride); } }; template struct OpCallerTernary { static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t batch, size_t nr_channel_blocks, size_t channel_stride, size_t channel_block_dim) { megdnn_assert( channel_block_dim == 4 || channel_block_dim == 8, "only imp for nchw44/nchw88"); Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); if (channel_block_dim == 4) { OpCallerTernaryVecBcast101xXVec::run( src0, src1, src2, dst, op, batch, nr_channel_blocks, channel_stride); } else { OpCallerTernaryVecBcast101xXVec::run( src0, src1, src2, dst, op, batch, nr_channel_blocks, channel_stride); } } }; //! src1: scalar, src0 and src2 has the same shape template struct OpCallerTernary { static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype src1, const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t nr_elems) { Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); ParamElemVisitor vis0; ParamElemVisitorDup vis1; ParamElemVisitor vis2; auto vis1_simd = vis1(&src1); size_t i = 0; for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1_simd, vis1_simd}}, {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); src0 += Op::SIMD_WIDTH * 2; src2 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < nr_elems; i++) { op(*src0, src1, *src2, dst); src0++; src2++; dst++; } } }; //! src1, src2: scalar, src0 is vector template struct OpCallerTernary { static void run( const typename Op::src_ctype* src0, const typename Op::src_ctype src1, const typename Op::src_ctype src2, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t nr_elems) { Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); ParamElemVisitor vis0; ParamElemVisitorDup vis1; ParamElemVisitorDup vis2; auto vis1_simd = vis1(&src1); auto vis2_simd = vis2(&src2); size_t i = 0; for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1_simd, vis1_simd}}, {{vis2_simd, vis2_simd}}, dst); src0 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } #if MEGDNN_FIX_AARCH32_BUG // FIXME: as llvm may cause cannot select error if enable vectorize #pragma clang loop vectorize(disable) #endif for (; i < nr_elems; i++) { op(*src0, src1, src2, dst); src0++; dst++; } } }; } // namespace elemwise } // namespace megdnn // vim: syntax=cpp.doxygen