/**
 * \file dnn/src/fallback/elemwise_helper/kimpl/relu.h
 * \brief ReLU element-wise operators (scalar and GI-SIMD) for the fallback backend.
 */
#pragma once

#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

//! Generic scalar ReLU shared by every ReluOp specialization: yields `src`
//! when it compares greater than zero, otherwise zero. Constructors are
//! inherited unchanged from UnaryOpBase.
template <typename src_ctype, typename dst_ctype = src_ctype>
struct ReluOpBase : UnaryOpBase<src_ctype, dst_ctype> {
    using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase;

    //! Return the ReLU of a single scalar.
    dst_ctype operator()(const src_ctype& src) const {
        const bool is_positive = src > 0;
        return is_positive ? src : 0;
    }

    //! Compute the ReLU of a single scalar and write it to `*dst`.
    void operator()(const src_ctype& src, dst_ctype* dst) const {
        *dst = (*this)(src);
    }
};

//! Element-wise ReLU functor. Only the explicit specializations that follow
//! are defined in this file; the primary template is declared but left
//! undefined, so unsupported type pairs fail at compile time.
template <typename src_ctype, typename dst_ctype = src_ctype>
struct ReluOp;
#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width, zero) \
    template <>                                                              \
    struct ReluOp<_ctype> : ReluOpBase<_ctype> {                             \
        using ReluOpBase::ReluOpBase;                                        \
        using ReluOpBase::operator();                                        \
        constexpr static size_t SIMD_WIDTH = _simd_width;                    \
        void operator()(const _simd_type2& src, _ctype* dst) const {         \
            auto vitem = operator()(src);                                    \
            GiStore##_func_suffix(dst, vitem.val[0]);                        \
            GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]);           \
        }                                                                    \
        _simd_type2 operator()(const _simd_type2& src) const {               \
            auto vitem0 = GiMaximum##_func_suffix(src.val[0], zero);         \
            auto vitem1 = GiMaximum##_func_suffix(src.val[1], zero);         \
            return {{vitem0, vitem1}};                                       \
        }                                                                    \
        void operator()(const _simd_type& src, _ctype* dst) const {          \
            auto vitem = operator()(src);                                    \
            GiStore##_func_suffix(dst, vitem);                               \
        }                                                                    \
        _simd_type operator()(const _simd_type& src) const {                 \
            return GiMaximum##_func_suffix(src, zero);                       \
        }                                                                    \
46 47
    };

48 49 50 51 52 53
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float),
   vfzero)
OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t),
   vzero)
OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t),
   vzero_int8)
54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
#undef OP

//! Scalar ReLU for quantized int8 -> quantized int8.
//! The raw int8 payload is multiplied by `scale` in float. Negative results
//! are clamped to zero, and the value is converted back with QConverter.
template <>
struct ReluOpBase<dt_qint8, dt_qint8> : UnaryOpBase<dt_qint8, dt_qint8> {
    using UnaryOpBase::UnaryOpBase;

    //! Return the requantized ReLU of one element.
    dt_qint8 operator()(const dt_qint8& src) const {
        const float rescaled = src.as_int8() * this->scale;
        return QConverter::convert<dt_qint8, float>(std::max<float>(rescaled, 0.f));
    }

    //! Compute the requantized ReLU of one element and store it in `*dst`.
    void operator()(const dt_qint8& src, dt_qint8* dst) const {
        *dst = (*this)(src);
    }
};

//! Quantized int8 -> quantized int8 ReLU.
//! Scalar overloads come from ReluOpBase<dt_qint8, dt_qint8>. The vector
//! overloads rescale by `vscale` in float, clamp at zero, and convert back to
//! int8 through QConverter.
template <>
struct ReluOp<dt_qint8, dt_qint8> : ReluOpBase<dt_qint8, dt_qint8> {
    using ReluOpBase::ReluOpBase;
    constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
    using ReluOpBase::operator();

    //! Vector input, stored to `dst`. The body is the shared
    //! OPERATOR_UNARY_QINT8_FALLBACK macro, which is defined outside this file.
    //! NOTE(review): presumably it widens `vsrc` to int32 lanes and dispatches
    //! to the GI_INT32_V2_t overload below; confirm against its definition.
    void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const {
        OPERATOR_UNARY_QINT8_FALLBACK;
    }

    //! Convert two int32 registers to float and multiply by `vscale`.
    //! Then apply max(x, 0) against `vfzero` and narrow the pair to one
    //! GI_INT8_t via QConverter.
    GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
        auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale);
        auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale);
        vitem0 = GiMaximumFloat32(vitem0, vfzero);
        vitem1 = GiMaximumFloat32(vitem1, vfzero);

        return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
    }
};

//! Scalar ReLU for quantized int32 -> quantized int8.
//! The raw int32 payload is multiplied by `scale` in float. Negative results
//! are clamped to zero, and the value is narrowed to qint8 by QConverter.
template <>
struct ReluOpBase<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> {
    using UnaryOpBase::UnaryOpBase;

    //! Return the requantized ReLU of one int32 element.
    dt_qint8 operator()(const dt_qint32& src) const {
        const float rescaled = src.as_int32() * this->scale;
        const float clamped = std::max<float>(rescaled, 0.f);
        return QConverter::convert<dt_qint8, float>(clamped);
    }

    //! Compute the requantized ReLU of one element and store it in `*dst`.
    void operator()(const dt_qint32& src, dt_qint8* dst) const {
        *dst = (*this)(src);
    }
};

//! qint32 -> qint8 ReLU.
//! On pre-ARMv8 targets with NEON, the int32x4x2_t and float32x4_t overloads
//! use the fixed-point rescale from FixupBase (vmultiplier / vshift) instead of
//! float math. The NEON check matters because the intrinsics below do not
//! exist without NEON; such builds take the portable GI branch in #else.
#if defined(__ARM_ARCH) && __ARM_ARCH < 8 && \
        (defined(__ARM_NEON) || defined(__ARM_NEON__))
template <>
struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase {
    using ReluOpBase::operator();
    constexpr static size_t SIMD_WIDTH = 4;

    //! FixupBase is built from `scale`, the member inherited through
    //! ReluOpBase. It is already initialized here because bases are
    //! constructed in declaration order and ReluOpBase is listed first.
    ReluOp(DType src_dtype, DType dst_dtype)
            : ReluOpBase(src_dtype, dst_dtype), FixupBase(scale) {}

    ReluOp(float src_scale, float dst_scale)
            : ReluOpBase(src_scale, dst_scale), FixupBase(scale) {}

    //! 8 int32 inputs -> 8 int8 outputs (low half of the returned register).
    void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const {
        vst1_s8(reinterpret_cast<int8_t*>(dst), vget_low_s8(operator()(vsrc)));
    }
    //! Fixed-point path, in order:
    //!   1. saturating rounding doubling multiply-high by `vmultiplier`;
    //!   2. clamp at `vzero`;
    //!   3. rounding shift by `vshift` (presumably non-positive, i.e. a right
    //!      shift; set by FixupBase);
    //!   4. saturating narrow to int8.
    //! The 8 results are duplicated into both halves of the returned register.
    int8x16_t operator()(const int32x4x2_t& vsrc) const {
        int32x4_t vitem0 = vqrdmulhq_s32(vsrc.val[0], vmultiplier);
        int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier);
        vitem0 = vmaxq_s32(vitem0, vzero);
        vitem1 = vmaxq_s32(vitem1, vzero);
        auto tmp = vqmovn_s16(vcombine_s16(
                vqmovn_s32(vrshlq_s32(vitem0, vshift)),
                vqmovn_s32(vrshlq_s32(vitem1, vshift))));
        return vcombine_s8(tmp, tmp);
    }
    //! Same fixed-point path for 4 lanes. The float input is first converted
    //! to int32 with vcvtq_s32_f32, and the 4 results are replicated across
    //! the returned register.
    int8x16_t operator()(const float32x4_t& vsrc) const {
        int32x4_t vitem0 = vqrdmulhq_s32(vcvtq_s32_f32(vsrc), vmultiplier);
        vitem0 = vmaxq_s32(vitem0, vzero);
        vitem0 = vrshlq_s32(vitem0, vshift);
        int16x4_t vitem = vqmovn_s32(vitem0);
        auto tmp = vqmovn_s16(vcombine_s16(vitem, vitem));
        return vcombine_s8(tmp, tmp);
    }
    //! 4 int32 inputs -> 4 int8 outputs, using the float path
    //! (multiply by `vscale`, clamp at `vfzero`, QConverter). Lane 0 of the
    //! result, viewed as int32, holds 4 packed bytes.
    //! NOTE(review): assumes QConverter places the 4 results in the low bytes.
    void operator()(const int32x4_t& src, dt_qint8* dst) const {
        auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale);
        vitem0 = vmaxq_f32(vitem0, vfzero);
        auto result = QConverter::convert<int8x16_t, float32x4_t>(vitem0);
        vst1q_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x4_t)result, 0);
    }
    //! Float variant of the overload above: 4 float inputs -> 4 int8 outputs.
    void operator()(const float32x4_t& src, dt_qint8* dst) const {
        auto vitem0 = vmulq_f32(src, this->vscale);
        vitem0 = vmaxq_f32(vitem0, vfzero);
        auto result = QConverter::convert<int8x16_t, float32x4_t>(vitem0);
        vst1q_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x4_t)result, 0);
    }
};

#else
//! Portable GI implementation: every vector overload rescales by `vscale` in
//! float, clamps at `vfzero`, and narrows to int8 through QConverter.
template <>
struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8> {
    using ReluOpBase::ReluOpBase;
    using ReluOpBase::operator();
    constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t);

    //! Two int32 registers in; the int8 results are written with GiStoreLowInt8.
    void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const {
        GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc));
    }
    //! One int32 register in; lane 0 of the result (viewed as int32) is stored.
    void operator()(const GI_INT32_t& src, dt_qint8* dst) const {
        GiStoreLane0Int32(
                reinterpret_cast<int32_t*>(dst), (GI_INT32_t)(operator()(src)));
    }

    GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
        auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale);
        auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale);
        vitem0 = GiMaximumFloat32(vitem0, vfzero);
        vitem1 = GiMaximumFloat32(vitem1, vfzero);

        return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
    }
    GI_INT8_t operator()(const GI_INT32_t& src) const {
        auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(src), this->vscale);
        vitem0 = GiMaximumFloat32(vitem0, vfzero);
        return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(vitem0);
    }
    GI_INT8_t operator()(const GI_FLOAT32_t& src) const {
        auto vitem0 = GiMultiplyFloat32(src, this->vscale);
        vitem0 = GiMaximumFloat32(vitem0, vfzero);
        return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(vitem0);
    }
};

#endif

}  // namespace fallback
}  // namespace megdnn

// vim: syntax=cpp.doxygen