gpu_info.cc 12.2 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/platform/gpu_info.h"
16
#include <algorithm>
S
sneaxiy 已提交
17 18
#include <cstdlib>
#include <string>
L
liaogang 已提交
19

20
#include "gflags/gflags.h"
Y
Yi Wang 已提交
21
#include "paddle/fluid/platform/enforce.h"
22
#include "paddle/fluid/string/split.h"
L
liaogang 已提交
23

24 25 26 27 28
DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
DECLARE_bool(enable_cublas_tensor_op_math);
DECLARE_string(selected_gpus);
29

Z
zhhsplendid 已提交
30 31
// Fraction of the currently free GPU memory that GpuAvailableMemToAlloc()
// holds back and never reports as allocatable.
static constexpr float fraction_reserve_gpu_memory = 0.05f;

L
liaogang 已提交
32 33 34
namespace paddle {
namespace platform {

35 36 37 38 39
/* CUDA "pro tip": device attributes below are queried with
cudaDeviceGetAttribute() rather than cudaGetDeviceProperties(), because
fetching a single attribute is much faster than filling the whole
cudaDeviceProp struct. For details, see
https://devblogs.nvidia.com/cuda-pro-tip-the-fast-way-to-query-device-properties/
*/

40 41 42 43 44 45
inline std::string CudaErrorWebsite() {
  return "Please see detail in https://docs.nvidia.com/cuda/cuda-runtime-api"
         "/group__CUDART__TYPES.html#group__CUDART__TYPES_1g3f51e3575c217824"
         "6db0a94a430e0038";
}

S
sneaxiy 已提交
46
static int GetCUDADeviceCountImpl() {
47 48 49 50 51 52 53 54
  int driverVersion = 0;
  cudaError_t status = cudaDriverGetVersion(&driverVersion);

  if (!(status == cudaSuccess && driverVersion != 0)) {
    // No GPU driver
    return 0;
  }

S
sneaxiy 已提交
55 56 57 58 59 60
  const auto *cuda_visible_devices = std::getenv("CUDA_VISIBLE_DEVICES");
  if (cuda_visible_devices != nullptr) {
    std::string cuda_visible_devices_str(cuda_visible_devices);
    if (std::all_of(cuda_visible_devices_str.begin(),
                    cuda_visible_devices_str.end(),
                    [](char ch) { return ch == ' '; })) {
S
sneaxiy 已提交
61
      VLOG(2) << "CUDA_VISIBLE_DEVICES is set to be empty. No GPU detected.";
S
sneaxiy 已提交
62 63 64 65
      return 0;
    }
  }

L
liaogang 已提交
66
  int count;
67
  auto error_code = cudaGetDeviceCount(&count);
L
liaogang 已提交
68
  PADDLE_ENFORCE(
69 70 71 72
      error_code,
      "cudaGetDeviceCount failed in "
      "paddle::platform::GetCUDADeviceCountImpl, error code : %d, %s",
      error_code, CudaErrorWebsite());
L
liaogang 已提交
73 74 75
  return count;
}

S
sneaxiy 已提交
76 77 78 79 80
// Returns the number of visible CUDA devices. The driver is queried only on
// the first call; the result is kept in a function-local static, whose
// initialization C++11 makes thread-safe.
int GetCUDADeviceCount() {
  static const int kCachedDeviceCount = GetCUDADeviceCountImpl();
  return kCachedDeviceCount;
}

81 82
// Returns the compute capability of device `id` encoded as
// major * 10 + minor (e.g. 7.5 -> 75).
// Throws if `id` is not below GetCUDADeviceCount() or if either attribute
// query fails. Both attributes are fetched before either result is checked.
int GetCUDAComputeCapability(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  int cc_major = 0;
  int cc_minor = 0;

  const auto major_status = cudaDeviceGetAttribute(
      &cc_major, cudaDevAttrComputeCapabilityMajor, id);
  const auto minor_status = cudaDeviceGetAttribute(
      &cc_minor, cudaDevAttrComputeCapabilityMinor, id);

  PADDLE_ENFORCE_EQ(
      major_status, 0,
      "cudaDevAttrComputeCapabilityMajor failed in "
      "paddle::platform::GetCUDAComputeCapability, error code : %d, %s",
      major_status, CudaErrorWebsite());
  PADDLE_ENFORCE_EQ(
      minor_status, 0,
      "cudaDevAttrComputeCapabilityMinor failed in "
      "paddle::platform::GetCUDAComputeCapability, error code : %d, %s",
      minor_status, CudaErrorWebsite());

  return cc_major * 10 + cc_minor;
}

102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
// Returns the maximum grid dimensions (x, y, z) supported by device `id`, as
// reported by cudaDeviceGetAttribute(cudaDevAttrMaxGridDim{X,Y,Z}).
// Throws if `id` is not below GetCUDADeviceCount() or if any query fails.
// The error messages name this function by its real name.
dim3 GetGpuMaxGridDimSize(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  dim3 ret;
  int size = 0;

  auto error_code_x = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimX, id);
  PADDLE_ENFORCE_EQ(
      error_code_x, 0,
      "cudaDevAttrMaxGridDimX failed in "
      "paddle::platform::GetGpuMaxGridDimSize, error code : %d, %s",
      error_code_x, CudaErrorWebsite());
  ret.x = size;

  auto error_code_y = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimY, id);
  PADDLE_ENFORCE_EQ(
      error_code_y, 0,
      "cudaDevAttrMaxGridDimY failed in "
      "paddle::platform::GetGpuMaxGridDimSize, error code : %d, %s",
      error_code_y, CudaErrorWebsite());
  ret.y = size;

  auto error_code_z = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimZ, id);
  PADDLE_ENFORCE_EQ(
      error_code_z, 0,
      "cudaDevAttrMaxGridDimZ failed in "
      "paddle::platform::GetGpuMaxGridDimSize, error code : %d, %s",
      error_code_z, CudaErrorWebsite());
  ret.z = size;
  return ret;
}

C
chengduo 已提交
129 130 131
int GetCUDARuntimeVersion(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  int runtime_version = 0;
132 133
  auto error_code = cudaRuntimeGetVersion(&runtime_version);
  PADDLE_ENFORCE(error_code,
C
chengduo 已提交
134
                 "cudaRuntimeGetVersion failed in "
135 136
                 "paddle::platform::GetCUDARuntimeVersion, error code : %d, %s",
                 error_code, CudaErrorWebsite());
C
chengduo 已提交
137 138 139 140 141 142
  return runtime_version;
}

int GetCUDADriverVersion(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  int driver_version = 0;
143 144
  auto error_code = cudaDriverGetVersion(&driver_version);
  PADDLE_ENFORCE(error_code,
C
chengduo 已提交
145
                 "cudaDriverGetVersion failed in "
146 147
                 "paddle::platform::GetCUDADriverVersion, error code : %d, %s",
                 error_code, CudaErrorWebsite());
C
chengduo 已提交
148 149 150
  return driver_version;
}

151 152 153 154 155 156 157 158 159 160
// Returns true when Tensor Cores can be used on the current device. This
// requires both:
//   - a build against CUDA 9.0 or newer (CUDA_VERSION >= 9000), and
//   - a current device with compute capability 7.0 or higher.
// Always returns false when built against an older CUDA toolkit.
bool TensorCoreAvailable() {
#if CUDA_VERSION >= 9000
  int device = GetCurrentDeviceId();
  // GetCUDAComputeCapability encodes major.minor as major * 10 + minor. This
  // value is a compute capability, not a driver version.
  int compute_capability = GetCUDAComputeCapability(device);
  return compute_capability >= 70;
#else
  return false;
#endif
}

C
chengduoZH 已提交
161 162 163
int GetCUDAMultiProcessors(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  int count;
164 165 166 167 168 169
  auto error_code =
      cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id);
  PADDLE_ENFORCE(error_code,
                 "cudaDeviceGetAttribute failed in "
                 "paddle::platform::GetCUDAMultiProcess, error code : %d, %s",
                 error_code, CudaErrorWebsite());
C
chengduoZH 已提交
170 171 172 173 174 175
  return count;
}

int GetCUDAMaxThreadsPerMultiProcessor(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  int count;
176 177 178 179 180 181 182
  auto error_code = cudaDeviceGetAttribute(
      &count, cudaDevAttrMaxThreadsPerMultiProcessor, id);
  PADDLE_ENFORCE(
      error_code,
      "cudaDeviceGetAttribute failed in paddle::"
      "platform::GetCUDAMaxThreadsPerMultiProcessor, error code : %d, %s",
      error_code, CudaErrorWebsite());
C
chengduoZH 已提交
183 184 185
  return count;
}

L
liaogang 已提交
186 187
int GetCurrentDeviceId() {
  int device_id;
188 189 190 191 192
  auto error_code = cudaGetDevice(&device_id);
  PADDLE_ENFORCE(error_code,
                 "cudaGetDevice failed in "
                 "paddle::platform::GetCurrentDeviceId, error code : %d, %s",
                 error_code, CudaErrorWebsite());
L
liaogang 已提交
193 194 195
  return device_id;
}

196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213
//! Returns the GPU ids to use:
//!   - the comma-separated ids in FLAGS_selected_gpus when that flag is
//!     non-empty (user-specified GPUs, e.g. single-node multi-process mode);
//!   - otherwise every visible device, 0 .. GetCUDADeviceCount() - 1.
std::vector<int> GetSelectedDevices() {
  std::vector<int> devices;
  if (!FLAGS_selected_gpus.empty()) {
    auto devices_str = paddle::string::Split(FLAGS_selected_gpus, ',');
    // Bind by const reference so each id string is not copied.
    for (const auto &id : devices_str) {
      // NOTE(review): atoi yields 0 for empty/non-numeric entries; ids are
      // not validated here.
      devices.push_back(std::atoi(id.c_str()));
    }
  } else {
    int count = GetCUDADeviceCount();
    devices.reserve(count);
    for (int i = 0; i < count; ++i) {
      devices.push_back(i);
    }
  }
  return devices;
}

L
liaogang 已提交
214
void SetDeviceId(int id) {
Q
qijun 已提交
215
  // TODO(qijun): find a better way to cache the cuda device count
Y
Yang Yang 已提交
216
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
217 218 219 220 221
  auto error_code = cudaSetDevice(id);
  PADDLE_ENFORCE(error_code,
                 "cudaSetDevice failed in "
                 "paddle::platform::SetDeviced, error code : %d, %s",
                 error_code, CudaErrorWebsite());
L
liaogang 已提交
222 223
}

224
void GpuMemoryUsage(size_t *available, size_t *total) {
225 226 227 228 229
  auto error_code = cudaMemGetInfo(available, total);
  PADDLE_ENFORCE(error_code,
                 "cudaMemGetInfo failed in "
                 "paddle::platform::GetMemoryUsage, error code : %d, %s",
                 error_code, CudaErrorWebsite());
L
liaogang 已提交
230 231
}

232
size_t GpuAvailableMemToAlloc() {
L
liaogang 已提交
233 234
  size_t total = 0;
  size_t available = 0;
235
  GpuMemoryUsage(&available, &total);
236 237
  size_t reserving =
      static_cast<size_t>(fraction_reserve_gpu_memory * available);
238
  // If available size is less than minimum chunk size, no usable memory exists
239
  size_t available_to_alloc = available - reserving;
240
  size_t min_chunk_size = GpuMinChunkSize();
241 242 243
  if (available_to_alloc < min_chunk_size) {
    available_to_alloc = 0;
  }
244 245 246
  VLOG(10) << "GPU usage " << (available >> 20) << "M/" << (total >> 20)
           << "M, " << (available_to_alloc >> 20) << "M available to allocate";
  return available_to_alloc;
Z
zhhsplendid 已提交
247 248
}

249 250 251
// Returns the larger of the initial and the re-allocation chunk sizes.
size_t GpuMaxAllocSize() {
  const size_t init_bytes = GpuInitAllocSize();
  const size_t realloc_bytes = GpuReallocSize();
  return std::max(init_bytes, realloc_bytes);
}
Z
zhhsplendid 已提交
252

253 254 255 256 257 258 259 260 261 262
static size_t GpuAllocSize(bool realloc) {
  size_t available_to_alloc = GpuAvailableMemToAlloc();
  PADDLE_ENFORCE_GT(available_to_alloc, 0, "No enough available GPU memory");
  // If FLAGS_initial_gpu_memory_in_mb is 0, then initial memory will be
  // allocated by fraction
  size_t flag_mb = realloc ? FLAGS_reallocate_gpu_memory_in_mb
                           : FLAGS_initial_gpu_memory_in_mb;
  size_t alloc_bytes =
      (flag_mb > 0ul ? flag_mb << 20 : available_to_alloc *
                                           FLAGS_fraction_of_gpu_memory_to_use);
263
  PADDLE_ENFORCE_GE(available_to_alloc, alloc_bytes,
264 265 266 267 268
                    "No enough available GPU memory");
  VLOG(10) << "Alloc size is " << (alloc_bytes >> 20)
           << " MiB, is it Re-alloc: " << realloc;
  return alloc_bytes;
}
Z
zhhsplendid 已提交
269

270
// Size of the first GPU chunk. It is sized from
// FLAGS_initial_gpu_memory_in_mb, or by fraction when that flag is 0.
size_t GpuInitAllocSize() {
  const bool is_realloc = false;
  return GpuAllocSize(is_realloc);
}
Z
zhhsplendid 已提交
271

272
// Size of each subsequent GPU chunk. It is sized from
// FLAGS_reallocate_gpu_memory_in_mb, or by fraction when that flag is 0.
size_t GpuReallocSize() {
  const bool is_realloc = true;
  return GpuAllocSize(is_realloc);
}
L
liaogang 已提交
273

L
liaogang 已提交
274 275 276 277 278 279
// Smallest chunk size, in bytes, that may be allocated: 256 bytes.
size_t GpuMinChunkSize() {
  constexpr size_t kMinChunkBytes = 256;
  return kMinChunkBytes;
}

size_t GpuMaxChunkSize() {
280 281 282
  size_t max_chunk_size = GpuMaxAllocSize();
  VLOG(10) << "Max chunk size " << (max_chunk_size >> 20) << "M";
  return max_chunk_size;
L
liaogang 已提交
283 284
}

L
liaogang 已提交
285 286
void GpuMemcpyAsync(void *dst, const void *src, size_t count,
                    enum cudaMemcpyKind kind, cudaStream_t stream) {
287 288
  auto error_code = cudaMemcpyAsync(dst, src, count, kind, stream);
  PADDLE_ENFORCE(error_code,
289
                 "cudaMemcpyAsync failed in paddle::platform::GpuMemcpyAsync "
290 291 292
                 "(%p -> %p, length: %d) error code : %d, %s",
                 src, dst, static_cast<int>(count), error_code,
                 CudaErrorWebsite());
L
liaogang 已提交
293 294
}

295 296
void GpuMemcpySync(void *dst, const void *src, size_t count,
                   enum cudaMemcpyKind kind) {
297 298 299 300 301 302
  auto error_code = cudaMemcpy(dst, src, count, kind);
  PADDLE_ENFORCE(error_code,
                 "cudaMemcpy failed in paddle::platform::GpuMemcpySync "
                 "(%p -> %p, length: %d) error code : %d, %s",
                 src, dst, static_cast<int>(count), error_code,
                 CudaErrorWebsite());
303 304 305 306
}

void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,
                        int src_device, size_t count, cudaStream_t stream) {
307 308
  auto error_code =
      cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream);
L
liaogang 已提交
309
  PADDLE_ENFORCE(
310 311 312 313
      error_code,
      "cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeerAsync "
      "error code : %d, %s",
      error_code, CudaErrorWebsite());
314 315 316 317
}

void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src,
                       int src_device, size_t count) {
318 319 320 321 322
  auto error_code = cudaMemcpyPeer(dst, dst_device, src, src_device, count);
  PADDLE_ENFORCE(error_code,
                 "cudaMemcpyPeer failed in paddle::platform::GpuMemcpyPeerSync "
                 "error code : %d, %s",
                 error_code, CudaErrorWebsite());
L
liaogang 已提交
323
}
D
dzhwinter 已提交
324 325

void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream) {
326 327 328 329 330
  auto error_code = cudaMemsetAsync(dst, value, count, stream);
  PADDLE_ENFORCE(error_code,
                 "cudaMemsetAsync failed in paddle::platform::GpuMemsetAsync "
                 "error code : %d, %s",
                 error_code, CudaErrorWebsite());
D
dzhwinter 已提交
331
}
332 333 334 335 336 337 338 339 340 341 342 343 344 345 346

// Raises any CUDA error except out-of-memory (cudaErrorMemoryAllocation),
// which is silently mapped to cudaSuccess.
// Checks first the status passed in by the caller, then the runtime's last
// recorded error. If the function returns normally, *status is cudaSuccess.
void RaiseNonOutOfMemoryError(cudaError_t *status) {
  auto forgive_oom = [](cudaError_t *err) {
    if (*err == cudaErrorMemoryAllocation) {
      *err = cudaSuccess;
    }
  };

  // The status handed in by the caller.
  forgive_oom(status);
  PADDLE_ENFORCE_CUDA_SUCCESS(*status);

  // The runtime's last error; cudaGetLastError also resets it.
  *status = cudaGetLastError();
  forgive_oom(status);
  PADDLE_ENFORCE_CUDA_SUCCESS(*status);
}
L
liaogang 已提交
347 348
}  // namespace platform
}  // namespace paddle