mlu_info.h 5.6 KB
Newer Older
F
fwenguang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifdef PADDLE_WITH_MLU
#include <cn_api.h>
#include <cnnl.h>
F
fwenguang 已提交
20
#include <cnpapi.h>
C
cifar10 已提交
21
#include <cnpapi_cndrv_id.h>
F
fwenguang 已提交
22
#include <cnrt.h>
C
cifar10 已提交
23
#include <mlu_op.h>
24 25 26
#ifdef PADDLE_WITH_CNCL
#include <cncl.h>
#endif
F
fwenguang 已提交
27 28 29 30 31 32 33
#include <vector>

namespace paddle {

using cnStatus = CNresult;
using cnrtStatus = cnrtRet_t;
using cnnlStatus = cnnlStatus_t;
C
cifar10 已提交
34
using mluOpStatus = mluOpStatus_t;
35 36 37
#ifdef PADDLE_WITH_CNCL
using cnclStatus = cnclResult_t;
#endif
F
fwenguang 已提交
38 39
using mluStream = cnrtQueue_t;
using mluCnnlHandle = cnnlHandle_t;
C
cifar10 已提交
40
using mluOpHandle = mluOpHandle_t;
F
fwenguang 已提交
41
using mluEventHandle = cnrtNotifier_t;
F
fwenguang 已提交
42 43 44 45 46 47 48 49 50 51
using mluDeviceHandle = CNdev;

namespace platform {

//! Get the driver version of the ith MLU.
int GetMLUDriverVersion(int id);

//! Get the runtime version of the ith MLU.
int GetMLURuntimeVersion(int id);

Q
qipengh 已提交
52 53 54
//! Get the cnnl version of the ith MLU.
int GetMLUCnnlVersion(int id);

C
cifar10 已提交
55 56 57
//! Get the mluOp version of the ith MLU.
int GetMLUOpVersion(int id);

F
fwenguang 已提交
58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
//! Get the total number of MLU devices in system.
int GetMLUDeviceCount();

//! Get a list of device ids from environment variable or use all.
std::vector<int> GetMLUSelectedDevices();

//! Get the current MLU device id in system.
int GetMLUCurrentDeviceId();

//! Set the MLU device id for next execution.
void SetMLUDeviceId(int device_id);

//! Get the device handle of the MLU with ordinal device_ordinal.
void GetMLUDeviceHandle(int device_ordinal, mluDeviceHandle* device);

//! Get the compute capability of the ith MLU (format: major * 10 + minor)
int GetMLUComputeCapability(int id);

//! Get the memory usage of current MLU device.
void MLUMemoryUsage(size_t* available, size_t* total);

//! Get the memory available to allocate, i.e. the available MLU memory
//! minus the reserved amount.
size_t MLUAvailableMemToAlloc();

//! Get the maximum allocation size of current MLU device.
size_t MLUMaxAllocSize();

//! Get the initial allocation size of current MLU device.
size_t MLUInitAllocSize();

//! Get the re-allocation size of current MLU device.
size_t MLUReallocSize();

//! Get the minimum chunk size for MLU buddy allocator.
size_t MLUMinChunkSize();

//! Get the maximum chunk size for MLU buddy allocator.
size_t MLUMaxChunkSize();

//! Copy memory from a device address to a host address asynchronously.
99 100 101
void MLUMemcpyD2HAsync(void* dst,
                       const void* src,
                       size_t num,
F
fwenguang 已提交
102 103 104 105 106 107
                       mluStream stream);

//! Copy memory from a device address to a host address synchronously.
void MLUMemcpyD2HSync(void* dst, const void* src, size_t num);

//! Copy memory from a host address to a device address asynchronously.
108 109 110
void MLUMemcpyH2DAsync(void* dst,
                       const void* src,
                       size_t num,
F
fwenguang 已提交
111 112 113 114 115 116
                       mluStream stream);

//! Copy memory from a host address to a device address synchronously.
void MLUMemcpyH2DSync(void* dst, const void* src, size_t num);

//! Copy memory between two addresses on the same device asynchronously.
117 118 119
void MLUMemcpyD2DAsync(void* dst,
                       const void* src,
                       size_t num,
F
fwenguang 已提交
120 121 122 123 124 125
                       mluStream stream);

//! Copy memory between two addresses on the same device synchronously.
void MLUMemcpyD2DSync(void* dst, const void* src, size_t num);

//! Copy memory from one device to another device asynchronously.
126 127 128 129 130 131
void MLUMemcpyPeerAsync(void* dst,
                        int dst_place,
                        const void* src,
                        int src_place,
                        size_t num,
                        mluStream stream);
F
fwenguang 已提交
132 133

//! Copy memory from one device to another device synchronously.
134 135
void MLUMemcpyPeerSync(
    void* dst, int dst_place, const void* src, int src_place, size_t num);
F
fwenguang 已提交
136 137 138 139 140 141 142 143 144 145 146 147 148 149

//! Asynchronously fill memory at dst with value; count gives the size to set.
void MLUMemsetAsync(void* dst, int value, size_t count, mluStream stream);

//! Blocks until stream has completed all operations.
void MLUStreamSync(mluStream stream);

//! MLUMalloc with recorded info
cnrtStatus RecordedMLUMalloc(void** ptr, size_t size, int dev_id);

//! MLUFree with recorded info
void RecordedMLUFree(void* p, size_t size, int dev_id);

//! Get available and total MLU memory, taking any limitation into account.
150 151 152 153 154
bool RecordedMLUMemGetInfo(size_t* avail,
                           size_t* total,
                           size_t* actual_avail,
                           size_t* actual_total,
                           int dev_id);
F
fwenguang 已提交
155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190

//! Get recorded mluMalloc size. If record is disabled, return 0.
uint64_t RecordedMLUMallocSize(int dev_id);

bool IsMLUMallocRecorded(int dev_id);

//! Empty idle cached memory held by the allocator.
void EmptyCache(void);

//! Scoped device switch: makes dev_id the current MLU device on construction
//! and restores the previously current device on destruction. If dev_id is
//! already the current device, neither the constructor nor the destructor
//! changes anything.
class MLUDeviceGuard {
 public:
  //! Switch to dev_id, remembering the device that was current beforehand.
  explicit MLUDeviceGuard(int dev_id) {
    const int current_id = platform::GetMLUCurrentDeviceId();
    if (current_id == dev_id) {
      // Already on the requested device: prev_id_ keeps its -1 sentinel so
      // the destructor has nothing to restore.
      return;
    }
    prev_id_ = current_id;
    platform::SetMLUDeviceId(dev_id);
  }

  //! Switch back to the remembered device, if the constructor recorded one.
  ~MLUDeviceGuard() {
    if (prev_id_ == -1) {
      return;
    }
    platform::SetMLUDeviceId(prev_id_);
  }

  // The guard is not copyable.
  MLUDeviceGuard(const MLUDeviceGuard& o) = delete;
  MLUDeviceGuard& operator=(const MLUDeviceGuard& o) = delete;

 private:
  //! Device id to restore on destruction; -1 means nothing to restore.
  int prev_id_{-1};
};

}  // namespace platform
}  // namespace paddle

#endif