// rounding_converter.cuh: functors that convert a float to a target dtype.
#pragma once
#include "megdnn/dtype.h"

namespace megdnn {
namespace rounding {

//! Functor template mapping a float to a value of type \p T. Only declared
//! here; the conversions are supplied by the explicit specializations.
template <typename T>
struct RoundingConverter;

template <>
struct RoundingConverter<float> {
M
Megvii Engine Team 已提交
12
    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE float operator()(float x) const {
13 14 15 16
        return x;
    }
};

#if !MEGDNN_DISABLE_FLOAT16

template <>
struct RoundingConverter<half_float::half> {
21
    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE half_float::half operator()(
22 23 24 25 26
            float x) const {
        return static_cast<half_float::half>(x);
    }
};

27 28
template <>
struct RoundingConverter<half_bfloat16::bfloat16> {
M
Megvii Engine Team 已提交
29 30
    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE half_bfloat16::bfloat16 operator()(
            float x) const {
31 32 33 34
        return static_cast<half_bfloat16::bfloat16>(x);
    }
};

#endif  // #if !MEGDNN_DISABLE_FLOAT16

template <>
struct RoundingConverter<int8_t> {
M
Megvii Engine Team 已提交
39
    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE int8_t operator()(float x) const {
40 41 42 43 44 45 46 47 48
#if MEGDNN_CC_HOST
        using std::round;
#endif
        return static_cast<int8_t>(round(x));
    }
};

template <>
struct RoundingConverter<uint8_t> {
M
Megvii Engine Team 已提交
49
    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE uint8_t operator()(float x) const {
50
#if MEGDNN_CC_HOST
51 52
        using std::max;
        using std::min;
53
        using std::round;
54
#endif
55
        x = min(255.0f, max(0.0f, x));  //! FIXME!!! check other places
56 57 58 59
        return static_cast<uint8_t>(round(x));
    }
};

60 61
template <>
struct RoundingConverter<dt_qint4> {
M
Megvii Engine Team 已提交
62
    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE dt_qint4 operator()(float x) const {
63 64 65 66 67 68 69
#if MEGDNN_CC_HOST
        using std::round;
#endif
        return static_cast<dt_qint4>(round(x));
    }
};

70 71
template <>
struct RoundingConverter<dt_quint4> {
M
Megvii Engine Team 已提交
72
    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE dt_quint4 operator()(float x) const {
73 74 75 76 77 78 79
#if MEGDNN_CC_HOST
        using std::round;
#endif
        return static_cast<dt_quint4>(round(x));
    }
};

}  // namespace rounding
}  // namespace megdnn

/* vim: set ft=cpp: */