/**
 * \file dnn/src/cuda/handle.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "megcore_cuda.h"
#include "megdnn/basic_types.h"
#include "megdnn/handle.h"
#include "megdnn/oprs/general.h"

#include "src/common/utils.h"
#include "src/common/handle_impl.h"
#include "src/cuda/cudnn_with_check.h"

#include <atomic>
#include <mutex>
#include <cuda_runtime_api.h>
#include <cublas_v2.h>
#include <cusolverDn.h>

#include <cuda.h>
#if CUDA_VERSION >= 10010
#include <cublasLt.h>
#endif

namespace megdnn {
namespace cuda {

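/*!
 * CUDA Handle implementation: one instance per megcore CUDA computing
 * handle. It owns the cuDNN / cuBLAS / (CUDA >= 10.1) cuBLASLt library
 * handles, lazily creates a cuSOLVER handle, and exposes a few shared
 * helper operators; all of them are associated with the device and stream
 * of the underlying computing handle.
 */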
class HandleImpl: public HandleImplHelper {
    public:
        HandleImpl(megcoreComputingHandle_t computing_handle);
        ~HandleImpl() noexcept;

        size_t alignment_requirement() const override;

        bool check_cross_dev_copy_constraint(const TensorLayout &src) override;

        const cudaDeviceProp& device_prop() const {
            return *m_device_prop;
        }

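        //! create a megdnn operator whose computations run on this handle;
        //! presumably instantiated per operator type in the implementation
        //! file, as is usual for megdnn handle backends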
        template <typename Opr>
        std::unique_ptr<Opr> create_operator();

        const megcore::CudaContext& megcore_context() const {
            return m_megcore_context;
        }

        int device_id() const { return m_device_id; }

        cudaStream_t stream() const {
            return megcore_context().stream;
        }
        cudnnHandle_t cudnn_handle() {
            return m_cudnn_handle;
        }
        cublasHandle_t cublas_handle() {
            return m_cublas_handle;
        }
#if CUDA_VERSION >= 10010
        cublasLtHandle_t cublasLt_handle() {
            return m_cublasLt_handle;
        }
#endif
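        //! the cuSOLVER handle is created lazily on first use, guarded by
        //! std::call_once on m_cusolver_initialized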
        cusolverDnHandle_t cusolver_handle() {
            std::call_once(m_cusolver_initialized,
                           [this] { initialize_cusolver(); });
            return m_cusolver_handle;
        }
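        //! device pointers to constant scalars 0 and 1 (see ConstScalars
        //! below); useful e.g. as alpha/beta arguments for cuBLAS/cuDNN
        //! calls operating in device-pointer mode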
        dt_float32 *zero_device() {
            return &m_const_scalars->f32[0];
        }
        dt_float32 *one_device() {
            return &m_const_scalars->f32[1];
        }
        __half* zero_device_h() {
            return &m_const_scalars->f16[0].cuda_x;
        }
        __half* one_device_h() {
            return &m_const_scalars->f16[1].cuda_x;
        }
        dt_int32 *zero_device_i32() {
            return &m_const_scalars->i32[0];
        }
        dt_int32 *one_device_i32() {
            return &m_const_scalars->i32[1];
        }

        bool is_tegra_k1() const {
            return m_is_tegra_k1;
        }

        //! global matmul opr
        MatrixMul* matmul_opr() override final {
            return get_helper_opr<MatrixMul, 0>(this);
        }

        //! global matmul opr with first operand transposed
        MatrixMul* matmul_aT_opr() override final {
            return get_helper_opr<MatrixMul, 1>(this, {true, false});
        }

        //! global matmul opr with second operand transposed
        MatrixMul* matmul_bT_opr() override final {
            return get_helper_opr<MatrixMul, 2>(this, {false, true});
        }

        //! global relayout opr
        Relayout* relayout_opr() override final {
            return get_helper_opr<Relayout, 3>(this);
        }

        BatchedMatrixMulForward* batched_matrix_mul() {
            return get_helper_opr<BatchedMatrixMulForward, 4>(this);
        }

        TypeCvt* typecvt_opr() { return get_helper_opr<TypeCvt, 0>(this); }

        size_t image2d_pitch_alignment() const override;
        HandleVendorType vendor_type() const override;

        class CUDNN;

        CUDNN& cudnn();
    private:
        bool m_is_tegra_k1;
        int m_device_id;
        //! MegDNN handle does not manage the lifetime of CUDA stream.
        megcore::CudaContext m_megcore_context;

        cudnnHandle_t m_cudnn_handle;
        cublasHandle_t m_cublas_handle;
#if CUDA_VERSION >= 10010
        cublasLtHandle_t m_cublasLt_handle;
#endif
        cusolverDnHandle_t m_cusolver_handle;
        std::once_flag m_cusolver_initialized;

        const cudaDeviceProp* m_device_prop;

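        //! layout of the device-side constant scalars returned by
        //! zero_device() / one_device() and friends; the FP16 union
        //! type-puns between CUDA's __half and megdnn's dt_float16 so the
        //! same storage serves both type systems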
        struct ConstScalars {
            union FP16 {
                __half cuda_x;
                dt_float16 megdnn_x;
                FP16() {}
            };
            static_assert(sizeof(FP16) == 2, "bad FP16 size");
            FP16 f16[2];
            dt_float32 f32[2];
            dt_int32 i32[2];
            void init();
        };

        //! device ptr to const scalars
        ConstScalars* m_const_scalars;

        std::unique_ptr<CUDNN> m_cudnn_api_cache;

        void initialize_cusolver();
};

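/*!
 * Thin wrappers around a subset of the cuDNN algorithm/workspace query API.
 * Each WRAP_CUDNN_API member is a thin_function with the signature of the
 * corresponding cudnn* entry point; judging by the member name
 * m_cudnn_api_cache above, HandleImpl uses this to memoize query results
 * (the actual behavior lives in the implementation file).
 */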
class HandleImpl::CUDNN {
    cudnnHandle_t m_handle;
public:
    CUDNN(cudnnHandle_t handle);
#define WRAP_CUDNN_API(NAME) thin_function<decltype(cudnn##NAME)> NAME;
    WRAP_CUDNN_API(GetConvolutionForwardWorkspaceSize);
#if CUDNN_MAJOR >= 7
    WRAP_CUDNN_API(GetConvolutionForwardAlgorithm_v7);
    WRAP_CUDNN_API(GetConvolutionForwardAlgorithmMaxCount);
    WRAP_CUDNN_API(GetConvolutionBackwardDataAlgorithm_v7);
    WRAP_CUDNN_API(GetConvolutionBackwardDataAlgorithmMaxCount);
#endif
    WRAP_CUDNN_API(GetConvolutionBackwardDataWorkspaceSize);
#if CUDNN_MAJOR >= 7
    WRAP_CUDNN_API(GetConvolutionBackwardFilterAlgorithmMaxCount);
    WRAP_CUDNN_API(GetConvolutionBackwardFilterAlgorithm_v7);
#endif
    WRAP_CUDNN_API(GetConvolutionBackwardFilterWorkspaceSize);
#undef WRAP_CUDNN_API
};
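
/*
 * Usage sketch (illustrative only, not part of this header): `comp_handle`
 * below stands for a megcoreComputingHandle_t already created for a CUDA
 * device via the megcore API.
 *
 *     megdnn::cuda::HandleImpl handle{comp_handle};
 *     cudnnHandle_t cudnn = handle.cudnn_handle();  // raw cuDNN handle
 *     cudaStream_t stream = handle.stream();        // associated stream
 *     auto matmul = handle.create_operator<megdnn::MatrixMul>();
 */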

} // namespace cuda
} // namespace megdnn

// vim: syntax=cpp.doxygen