/**
 * \file dnn/src/cuda/fp16_help.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include <cuda_runtime_api.h>
#include "cuda.h"
#include "cuda_fp16.h"

namespace megdnn {
namespace cuda {

/*!
 * \brief Computes a * b + c for single-precision floats.
 *
 * NOTE(review): this is written as a plain multiply-add. Whether it becomes a
 * single fused instruction depends on nvcc's floating-point contraction
 * setting (--fmad); callers that need a guaranteed single rounding should
 * call fmaf() directly.
 */
__device__ __forceinline__ float fma(const float a, const float b, const float c) {
    return a * b + c;
}

/*!
 * \brief Lane-wise a * b + c on both components of a float2.
 *
 * Each lane uses the same multiply-add expression as the scalar float fma().
 */
__device__ __forceinline__ float2 fma2(const float2 a, const float2 b, const float2 c) {
    return {a.x * b.x + c.x, a.y * b.y + c.y};
}

// The half-precision overloads below are only compiled with CUDA 9.0 or
// newer toolkits.
#if CUDA_VERSION >= 9000

/*!
 * \brief Computes a * b + c for a scalar __half.
 *
 * On sm_53 and newer this maps to the native __hfma intrinsic. Otherwise the
 * expression is evaluated in float and the result is converted back to half.
 */
__device__ __forceinline__ __half fma(const __half a, const __half b, const __half c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
    return __hfma(a, b, c);
#else
    const float result = __half2float(a) * __half2float(b) + __half2float(c);
    return __float2half(result);
#endif
}

/*!
 * \brief Lane-wise a * b + c on a packed __half2.
 *
 * On sm_53 and newer this maps to the native __hfma2 intrinsic. Otherwise
 * both lanes are unpacked to float, computed there, and repacked with
 * round-to-nearest-even (low lane stays low, high lane stays high).
 */
__device__ __forceinline__ __half2
fma2(const __half2 a, const __half2 b, const __half2 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
    return __hfma2(a, b, c);
#else
    // Unpack via intrinsics rather than touching __half2's .x/.y fields, so
    // the fallback does not depend on the struct layout.
    const float2 fa = __half22float2(a);
    const float2 fb = __half22float2(b);
    const float2 fc = __half22float2(c);
    return __floats2half2_rn(fa.x * fb.x + fc.x, fa.y * fb.y + fc.y);
#endif
}

/*!
 * \brief Lane-wise a + b on a packed __half2.
 *
 * On sm_53 and newer this maps to the native __hadd2 intrinsic. Otherwise
 * both lanes are added in float and repacked with round-to-nearest-even.
 */
__device__ __forceinline__ __half2 hadd2(const __half2 a, const __half2 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
    return __hadd2(a, b);
#else
    const float2 fa = __half22float2(a);
    const float2 fb = __half22float2(b);
    return __floats2half2_rn(fa.x + fb.x, fa.y + fb.y);
#endif
}

/*!
 * \brief Mixed-precision lane-wise a * b + c.
 *
 * The __half2 operands are widened to float and multiplied. The products are
 * accumulated into the float2 `c`, and the result stays in float; no rounding
 * back to half takes place.
 */
__device__ __forceinline__ float2
fma2(const __half2 a, const __half2 b, const float2 c) {
    const float2 fa = __half22float2(a);
    const float2 fb = __half22float2(b);
    return {fa.x * fb.x + c.x, fa.y * fb.y + c.y};
}

#endif  // CUDA_VERSION >= 9000

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen