rounding_converter.cuh 2.0 KB
Newer Older
1 2 3 4
/**
 * \file dnn/src/common/rounding_converter.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once
#include "megdnn/dtype.h"

#if MEGDNN_CC_HOST && !defined(__host__)
// Host-only compilation without CUDA's execution-space qualifiers in scope
// (MEGDNN_CC_HOST is presumably set by the build for host compilers -- confirm):
// define __host__/__device__ as nothing and __forceinline__ as an inline
// specifier so the converters below compile as ordinary C++.
#define MEGDNN_HOST_DEVICE_SELF_DEFINE
#define __host__
#define __device__
// `__has_attribute(x)` may only appear in an `#if` expression when the
// preprocessor knows `__has_attribute`; otherwise it parses as `0(x)`, which is
// a hard error even on the unevaluated side of `||`. Probe it in a nested
// conditional so preprocessors without it (e.g. older GCC, MSVC) still work.
#if defined(__GNUC__)
#define __forceinline__ inline __attribute__((always_inline))
#elif defined(__has_attribute)
#if __has_attribute(always_inline)
#define __forceinline__ inline __attribute__((always_inline))
#else
#define __forceinline__ inline
#endif
#else
#define __forceinline__ inline
#endif
#endif

namespace megdnn {
namespace rounding {

/**
 * Function object that converts a float into the destination type `DType`,
 * applying whatever rounding that type needs. Only explicitly specialized
 * destination types are usable; any other `DType` names an incomplete type.
 */
template <class DType>
struct RoundingConverter;

/// Identity conversion: a float destination needs no rounding, so the input
/// value is returned unchanged.
template <>
struct RoundingConverter<float> {
    __host__ __device__ __forceinline__ float operator()(float value) const {
        // Nothing to round -- hand the input straight back.
        float result = value;
        return result;
    }
};

#ifndef MEGDNN_DISABLE_FLOAT16

/// Narrows a float to half precision by direct-initializing a
/// half_float::half from it; the rounding applied is whatever that
/// constructor implements.
template <>
struct RoundingConverter<half_float::half> {
    __host__ __device__ __forceinline__ half_float::half operator()(
            float value) const {
        half_float::half narrowed(value);
        return narrowed;
    }
};

49 50 51 52 53 54 55 56
/// Narrows a float to bfloat16 by direct-initializing a
/// half_bfloat16::bfloat16 from it; the rounding applied is whatever that
/// constructor implements.
template <>
struct RoundingConverter<half_bfloat16::bfloat16> {
    __host__ __device__ __forceinline__ half_bfloat16::bfloat16 operator()(
            float value) const {
        half_bfloat16::bfloat16 narrowed(value);
        return narrowed;
    }
};

57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
#endif  // #ifndef MEGDNN_DISABLE_FLOAT16

/**
 * Converts a float to int8_t: rounds to the nearest integer (halfway cases
 * away from zero, as `round` does) and saturates the result to [-128, 127].
 *
 * Converting a float whose value does not fit in the destination integer type
 * (or NaN) is undefined behaviour in C++, so results of a bare cast could
 * differ between compilers/targets. The range is clamped before casting; NaN
 * maps to 0. Values already in range convert exactly as before.
 */
template <>
struct RoundingConverter<int8_t> {
    __host__ __device__ __forceinline__ int8_t operator()(float x) const {
#if MEGDNN_CC_HOST
        using std::round;
#endif
        const float rounded = round(x);
        // NaN is the only value that compares unequal to itself.
        if (rounded != rounded) {
            return static_cast<int8_t>(0);
        }
        if (rounded > 127.f) {
            return static_cast<int8_t>(127);
        }
        if (rounded < -128.f) {
            return static_cast<int8_t>(-128);
        }
        return static_cast<int8_t>(rounded);
    }
};

/**
 * Converts a float to uint8_t: rounds to the nearest integer (halfway cases
 * away from zero, as `round` does) and saturates the result to [0, 255].
 *
 * Converting a float whose value does not fit in the destination integer type
 * (a negative value, one above 255, or NaN) is undefined behaviour in C++, so
 * the range is clamped before casting; NaN maps to 0. Values already in range
 * convert exactly as before.
 */
template <>
struct RoundingConverter<uint8_t> {
    __host__ __device__ __forceinline__ uint8_t operator()(float x) const {
#if MEGDNN_CC_HOST
        using std::round;
#endif
        const float rounded = round(x);
        // `!(rounded >= 0)` is true both for negatives and for NaN (every
        // comparison with NaN is false), so both saturate to 0. A rounded
        // -0.0f compares equal to 0 and falls through to the exact cast.
        if (!(rounded >= 0.f)) {
            return static_cast<uint8_t>(0);
        }
        if (rounded > 255.f) {
            return static_cast<uint8_t>(255);
        }
        return static_cast<uint8_t>(rounded);
    }
};

}  // namespace rounding
}  // namespace megdnn

/* vim: set ft=cpp: */