#pragma once #include "megdnn/dtype.h" namespace megdnn { namespace rounding { template struct RoundingConverter; template <> struct RoundingConverter { MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE float operator()(float x) const { return x; } }; #if !MEGDNN_DISABLE_FLOAT16 template <> struct RoundingConverter { MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE half_float::half operator()( float x) const { return static_cast(x); } }; template <> struct RoundingConverter { MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE half_bfloat16::bfloat16 operator()( float x) const { return static_cast(x); } }; #endif // #if !MEGDNN_DISABLE_FLOAT16 template <> struct RoundingConverter { MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE int8_t operator()(float x) const { #if MEGDNN_CC_HOST using std::round; #endif return static_cast(round(x)); } }; template <> struct RoundingConverter { MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE uint8_t operator()(float x) const { #if MEGDNN_CC_HOST using std::max; using std::min; using std::round; #endif x = min(255.0f, max(0.0f, x)); //! FIXME!!! check other places return static_cast(round(x)); } }; template <> struct RoundingConverter { MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE dt_qint4 operator()(float x) const { #if MEGDNN_CC_HOST using std::round; #endif return static_cast(round(x)); } }; template <> struct RoundingConverter { MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE dt_quint4 operator()(float x) const { #if MEGDNN_CC_HOST using std::round; #endif return static_cast(round(x)); } }; } // namespace rounding } // namespace megdnn /* vim: set ft=cpp: */