// relu.h
/**
 * \file dnn/src/fallback/elemwise_helper/kimpl/relu.h
 */
#pragma once

#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

/*!
 * \brief Scalar ReLU functor used as the base of the ReluOp specializations.
 *
 * The value form yields the input when it compares greater than zero and
 * zero otherwise; the store form writes that result through a pointer.
 * Constructors are inherited unchanged from UnaryOpBase.
 */
template <typename src_ctype, typename dst_ctype = src_ctype>
struct ReluOpBase : UnaryOpBase<src_ctype, dst_ctype> {
    using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase;

    //! Value form: \p src if it is greater than zero, otherwise 0.
    dst_ctype operator()(const src_ctype& src) const {
        return (src > 0) ? src : 0;
    }

    //! Store form: writes the rectified value of \p src to \p dst.
    void operator()(const src_ctype& src, dst_ctype* dst) const {
        *dst = (*this)(src);
    }
};

//! Primary ReluOp template. It is only declared here; the usable definitions
//! are the explicit specializations that follow in this file. The destination
//! type defaults to the source type.
template <typename src_ctype, typename dst_type = src_ctype>
struct ReluOp;

/*!
 * Generates the ReluOp<_ctype> specialization for a plain numeric type.
 *
 *  _ctype        scalar element type
 *  _simd_type    single-vector type
 *  _simd_type2   two-vector aggregate, accessed through .val[0] / .val[1]
 *  _func_suffix  suffix pasted onto GiBroadcast / GiMaximum / GiStore
 *  _simd_width   value assigned to SIMD_WIDTH; also the element offset at
 *                which the second half of a two-vector result is stored
 *
 * Each vector overload computes the element-wise maximum of the input and a
 * broadcast zero. The scalar overloads are pulled in from ReluOpBase.
 * Only block comments are used inside the macro body so the line
 * continuations stay intact.
 */
#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width)      \
    template <>                                                             \
    struct ReluOp<_ctype> : ReluOpBase<_ctype> {                            \
        using ReluOpBase::ReluOpBase;                                       \
        using ReluOpBase::operator();                                       \
        constexpr static size_t SIMD_WIDTH = _simd_width;                   \
        /* single vector, value form */                                     \
        _simd_type operator()(const _simd_type& src) const {                \
            return GiMaximum##_func_suffix(                                 \
                    src, GiBroadcast##_func_suffix(0));                     \
        }                                                                   \
        /* single vector, store form */                                     \
        void operator()(const _simd_type& src, _ctype* dst) const {         \
            GiStore##_func_suffix(dst, operator()(src));                    \
        }                                                                   \
        /* vector pair, value form */                                       \
        _simd_type2 operator()(const _simd_type2& src) const {              \
            auto zeros = GiBroadcast##_func_suffix(0);                      \
            auto first = GiMaximum##_func_suffix(src.val[0], zeros);        \
            auto second = GiMaximum##_func_suffix(src.val[1], zeros);       \
            return {{first, second}};                                       \
        }                                                                   \
        /* vector pair, store form: halves go to dst and dst + SIMD_WIDTH */ \
        void operator()(const _simd_type2& src, _ctype* dst) const {        \
            const _simd_type2 rectified = operator()(src);                  \
            GiStore##_func_suffix(dst, rectified.val[0]);                   \
            GiStore##_func_suffix(dst + SIMD_WIDTH, rectified.val[1]);      \
        }                                                                   \
    };

OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t))
OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t))
#undef OP

/*!
 * \brief Scalar ReLU for dt_qint8 input and dt_qint8 output.
 *
 * The raw int8 value is multiplied by the inherited `scale`, clamped from
 * below at 0.f in float, and converted to dt_qint8 through QConverter.
 */
template <>
struct ReluOpBase<dt_qint8, dt_qint8> : UnaryOpBase<dt_qint8, dt_qint8> {
    using UnaryOpBase::UnaryOpBase;

    //! Value form: returns the rectified, converted element.
    dt_qint8 operator()(const dt_qint8& src) const {
        float rectified = std::max<float>(src.as_int8() * this->scale, 0.f);
        return QConverter::convert<dt_qint8, float>(rectified);
    }

    //! Store form: writes the rectified element to \p dst.
    void operator()(const dt_qint8& src, dt_qint8* dst) const {
        *dst = (*this)(src);
    }
};

/*!
 * \brief Vectorized ReLU for dt_qint8 input and dt_qint8 output.
 *
 * Scalar overloads and constructors come from ReluOpBase<dt_qint8, dt_qint8>.
 */
template <>
struct ReluOp<dt_qint8, dt_qint8> : ReluOpBase<dt_qint8, dt_qint8> {
    using ReluOpBase::ReluOpBase;
    using ReluOpBase::operator();
    constexpr static size_t SIMD_WIDTH = 16;

    //! Takes a pair of int32 vectors: each half is cast to float, multiplied
    //! by `vscale`, clamped from below at zero, and the pair is converted to
    //! one int8 vector through QConverter.
    GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
        auto zeros = GiBroadcastFloat32(0.f);
        auto lower = GiMaximumFloat32(
                GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale), zeros);
        auto upper = GiMaximumFloat32(
                GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale), zeros);
        return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{lower, upper}});
    }

    //! Store form for a pair of int8 vectors. The body is supplied entirely by
    //! OPERATOR_UNARY_QINT8_FALLBACK, which is defined outside this file.
    //! NOTE(review): the macro presumably refers to the parameters by the
    //! names `vsrc` and `dst`, so they must not be renamed - confirm.
    void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const {
        OPERATOR_UNARY_QINT8_FALLBACK;
    }
};

/*!
 * \brief Scalar ReLU for dt_qint32 input and dt_qint8 output.
 *
 * The raw int32 value is multiplied by the inherited `scale`, clamped from
 * below at 0.f in float, and converted to dt_qint8 through QConverter.
 */
template <>
struct ReluOpBase<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> {
    using UnaryOpBase::UnaryOpBase;

    //! Value form: returns the rectified, converted element.
    dt_qint8 operator()(const dt_qint32& src) const {
        float rectified = std::max<float>(src.as_int32() * this->scale, 0.f);
        return QConverter::convert<dt_qint8, float>(rectified);
    }

    //! Store form: writes the rectified element to \p dst.
    void operator()(const dt_qint32& src, dt_qint8* dst) const {
        *dst = (*this)(src);
    }
};

//! On ARMv7 and older (__ARM_ARCH < 8), specialize ReLU with a fixed-point fixup path.
#if defined(__ARM_ARCH) && __ARM_ARCH < 8
/*!
 * \brief ReLU from dt_qint32 to dt_qint8 for ARM targets with __ARM_ARCH < 8.
 *
 * The value overloads use integer arithmetic: a vqrdmulhq_s32 multiply by
 * `vmultiplier`, a clamp at zero, a vrshlq_s32 shift by `vshift`, and two
 * saturating narrows down to int8. The four-lane store overloads use a float
 * multiply by `vscale` and QConverter instead.
 * NOTE(review): `vmultiplier` and `vshift` are not defined in this file;
 * presumably they are members of FixupBase - confirm.
 */
template <>
struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase {
    using ReluOpBase::operator();
    constexpr static size_t SIMD_WIDTH = 4;

    //! Both constructors forward their arguments to ReluOpBase and hand the
    //! resulting `scale` to FixupBase.
    ReluOp(DType src_dtype, DType dst_dtype)
            : ReluOpBase(src_dtype, dst_dtype), FixupBase(scale) {}

    ReluOp(float src_scale, float dst_scale)
            : ReluOpBase(src_scale, dst_scale), FixupBase(scale) {}

    //! Eight int32 lanes in, eight int8 lanes out (integer path).
    int8x8_t operator()(const int32x4x2_t& vsrc) const {
        const int32x4_t relu_lo = vmaxq_s32(
                vqrdmulhq_s32(vsrc.val[0], vmultiplier), QConverterBase::vzero());
        const int32x4_t relu_hi = vmaxq_s32(
                vqrdmulhq_s32(vsrc.val[1], vmultiplier), QConverterBase::vzero());
        const int16x4_t narrow_lo = vqmovn_s32(vrshlq_s32(relu_lo, vshift));
        const int16x4_t narrow_hi = vqmovn_s32(vrshlq_s32(relu_hi, vshift));
        return vqmovn_s16(vcombine_s16(narrow_lo, narrow_hi));
    }

    //! Four float lanes in; they are converted to int32 first and then take
    //! the integer path. The four narrowed results are duplicated into both
    //! halves of the returned vector.
    int8x8_t operator()(const float32x4_t& vsrc) const {
        int32x4_t fixed = vqrdmulhq_s32(vcvtq_s32_f32(vsrc), vmultiplier);
        fixed = vrshlq_s32(vmaxq_s32(fixed, QConverterBase::vzero()), vshift);
        const int16x4_t half = vqmovn_s32(fixed);
        return vqmovn_s16(vcombine_s16(half, half));
    }

    //! Store form for eight lanes: writes the whole int8x8_t result to dst.
    void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const {
        int8_t* out = reinterpret_cast<int8_t*>(dst);
        vst1_s8(out, (*this)(vsrc));
    }

    //! Store form for four int32 lanes (float path). Only 32-bit lane 0 of
    //! the converted vector is written to dst.
    void operator()(const int32x4_t& src, dt_qint8* dst) const {
        float32x4_t rectified = vmaxq_f32(
                vmulq_f32(vcvtq_f32_s32(src), this->vscale), QConverterBase::vfzero());
        auto result = QConverter::convert<int8x8_t, float32x4_t>(rectified);
        vst1_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x2_t)result, 0);
    }

    //! Store form for four float lanes (float path). Only 32-bit lane 0 of
    //! the converted vector is written to dst.
    void operator()(const float32x4_t& src, dt_qint8* dst) const {
        float32x4_t rectified =
                vmaxq_f32(vmulq_f32(src, this->vscale), QConverterBase::vfzero());
        auto result = QConverter::convert<int8x8_t, float32x4_t>(rectified);
        vst1_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x2_t)result, 0);
    }
};

#else
/*!
 * \brief ReLU from dt_qint32 to dt_qint8 built on the GI vector helpers.
 *
 * Every value overload multiplies by `vscale` in float, clamps from below
 * with QConverterBase::vfzero(), and converts to an int8 vector through
 * QConverter. Scalar overloads and constructors come from ReluOpBase.
 */
template <>
struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8> {
    using ReluOpBase::ReluOpBase;
    using ReluOpBase::operator();
    constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t);

    //! Pair of int32 vectors in, one int8 vector out.
    GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
        auto lower = GiMaximumFloat32(
                GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale),
                QConverterBase::vfzero());
        auto upper = GiMaximumFloat32(
                GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale),
                QConverterBase::vfzero());
        return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{lower, upper}});
    }

    //! One int32 vector in, one int8 vector out.
    GI_INT8_t operator()(const GI_INT32_t& src) const {
        auto rectified = GiMaximumFloat32(
                GiMultiplyFloat32(GiCastToFloat32(src), this->vscale),
                QConverterBase::vfzero());
        return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(rectified);
    }

    //! One float vector in, one int8 vector out.
    GI_INT8_t operator()(const GI_FLOAT32_t& src) const {
        auto rectified = GiMaximumFloat32(
                GiMultiplyFloat32(src, this->vscale), QConverterBase::vfzero());
        return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(rectified);
    }

    //! Store form for a vector pair: passes the result to GiStoreLowInt8.
    void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const {
        int8_t* out = reinterpret_cast<int8_t*>(dst);
        GiStoreLowInt8(out, (*this)(vsrc));
    }

    //! Store form for one vector: the int8 result is cast to GI_INT32_t and
    //! passed to GiStoreLane0Int32.
    void operator()(const GI_INT32_t& src, dt_qint8* dst) const {
        int32_t* out = reinterpret_cast<int32_t*>(dst);
        GiStoreLane0Int32(out, (GI_INT32_t)((*this)(src)));
    }
};

#endif

}  // namespace fallback
}  // namespace megdnn

// vim: syntax=cpp.doxygen