/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/gpu_info.h"

#include <algorithm>
#include <atomic>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

#include "gflags/gflags.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/dynload/miopen.h"
#else
#include "paddle/fluid/platform/cuda_graph.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#endif
#include "paddle/fluid/memory/malloc.h"
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10020
#include "paddle/fluid/platform/dynload/cuda_driver.h"
#endif
#endif
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/lock_guard_ptr.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/monitor.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
L
liaogang 已提交
40

41 42 43 44 45
DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
DECLARE_bool(enable_cublas_tensor_op_math);
DECLARE_string(selected_gpus);
46
DECLARE_uint64(gpu_memory_limit_mb);
47

Z
zhhsplendid 已提交
48 49
constexpr static float fraction_reserve_gpu_memory = 0.05f;

50 51 52 53
static std::once_flag g_device_props_size_init_flag;
static std::vector<std::unique_ptr<std::once_flag>> g_device_props_init_flags;
static std::vector<paddle::gpuDeviceProp> g_device_props;

H
hutuxian 已提交
54
USE_GPU_MEM_STAT;
L
liaogang 已提交
55 56 57
namespace paddle {
namespace platform {

58 59
int CudnnVersion() {
  if (!dynload::HasCUDNN()) return -1;
60

61 62 63 64 65 66
#ifdef PADDLE_WITH_HIP
  size_t version_major, version_minor, version_patch;
  PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenGetVersion(
      &version_major, &version_minor, &version_patch));
  return version_major * 100 + version_minor * 10 + version_patch;
#else
67
  return dynload::cudnnGetVersion();
68
#endif
69
}
S
sneaxiy 已提交
70
static int GetCUDADeviceCountImpl() {
71
  int driverVersion = 0;
72 73 74
#ifdef PADDLE_WITH_HIP
  hipError_t status = hipDriverGetVersion(&driverVersion);
#else
75
  cudaError_t status = cudaDriverGetVersion(&driverVersion);
76
#endif
77

78
  if (!(status == gpuSuccess && driverVersion != 0)) {
79
    // No GPU driver
80
    VLOG(2) << "GPU Driver Version can't be detected. No GPU driver!";
81 82 83
    return 0;
  }

84 85 86
#ifdef PADDLE_WITH_HIP
  const auto *cuda_visible_devices = std::getenv("HIP_VISIBLE_DEVICES");
#else
S
sneaxiy 已提交
87
  const auto *cuda_visible_devices = std::getenv("CUDA_VISIBLE_DEVICES");
88
#endif
S
sneaxiy 已提交
89 90
  if (cuda_visible_devices != nullptr) {
    std::string cuda_visible_devices_str(cuda_visible_devices);
91 92 93 94 95 96 97 98 99 100
    if (!cuda_visible_devices_str.empty()) {
      cuda_visible_devices_str.erase(
          0, cuda_visible_devices_str.find_first_not_of('\''));
      cuda_visible_devices_str.erase(
          cuda_visible_devices_str.find_last_not_of('\'') + 1);
      cuda_visible_devices_str.erase(
          0, cuda_visible_devices_str.find_first_not_of('\"'));
      cuda_visible_devices_str.erase(
          cuda_visible_devices_str.find_last_not_of('\"') + 1);
    }
S
sneaxiy 已提交
101 102 103
    if (std::all_of(cuda_visible_devices_str.begin(),
                    cuda_visible_devices_str.end(),
                    [](char ch) { return ch == ' '; })) {
104 105
      VLOG(2) << "CUDA_VISIBLE_DEVICES or HIP_VISIBLE_DEVICES is set to be "
                 "empty. No GPU detected.";
S
sneaxiy 已提交
106 107 108
      return 0;
    }
  }
L
liaogang 已提交
109
  int count;
110 111 112
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipGetDeviceCount(&count));
#else
113
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetDeviceCount(&count));
114
#endif
L
liaogang 已提交
115 116 117
  return count;
}

S
sneaxiy 已提交
118
int GetCUDADeviceCount() {
119
  // cache the count
S
sneaxiy 已提交
120 121 122 123
  static auto dev_cnt = GetCUDADeviceCountImpl();
  return dev_cnt;
}

124 125 126 127
/* Here is a very simple CUDA “pro tip”: cudaDeviceGetAttribute() is a much
faster way to query device properties. You can see details in
https://devblogs.nvidia.com/cuda-pro-tip-the-fast-way-to-query-device-properties/
*/
128
int GetCUDAComputeCapability(int id) {
129 130 131 132 133
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
134 135
  int major, minor;

136 137 138 139 140 141
#ifdef PADDLE_WITH_HIP
  auto major_error_code = hipDeviceGetAttribute(
      &major, hipDeviceAttributeComputeCapabilityMajor, id);
  auto minor_error_code = hipDeviceGetAttribute(
      &minor, hipDeviceAttributeComputeCapabilityMinor, id);
#else
142 143 144 145
  auto major_error_code =
      cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, id);
  auto minor_error_code =
      cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, id);
146
#endif
147 148
  PADDLE_ENFORCE_CUDA_SUCCESS(major_error_code);
  PADDLE_ENFORCE_CUDA_SUCCESS(minor_error_code);
149 150 151
#ifdef PADDLE_WITH_HIP
  return major * 100 + minor;
#else
152
  return major * 10 + minor;
153
#endif
154 155
}

156
dim3 GetGpuMaxGridDimSize(int id) {
157 158 159 160 161
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
162 163
  dim3 ret;
  int size;
164 165 166 167
#ifdef PADDLE_WITH_HIP
  auto error_code_x =
      hipDeviceGetAttribute(&size, hipDeviceAttributeMaxGridDimX, id);
#else
168
  auto error_code_x = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimX, id);
169
#endif
170
  PADDLE_ENFORCE_CUDA_SUCCESS(error_code_x);
171 172
  ret.x = size;

173 174 175 176
#ifdef PADDLE_WITH_HIP
  auto error_code_y =
      hipDeviceGetAttribute(&size, hipDeviceAttributeMaxGridDimY, id);
#else
177
  auto error_code_y = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimY, id);
178
#endif
179
  PADDLE_ENFORCE_CUDA_SUCCESS(error_code_y);
180 181
  ret.y = size;

182 183 184 185
#ifdef PADDLE_WITH_HIP
  auto error_code_z =
      hipDeviceGetAttribute(&size, hipDeviceAttributeMaxGridDimZ, id);
#else
186
  auto error_code_z = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimZ, id);
187
#endif
188
  PADDLE_ENFORCE_CUDA_SUCCESS(error_code_z);
189 190 191 192
  ret.z = size;
  return ret;
}

C
chengduo 已提交
193
int GetCUDARuntimeVersion(int id) {
194 195 196 197 198
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduo 已提交
199
  int runtime_version = 0;
200 201 202
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipRuntimeGetVersion(&runtime_version));
#else
203
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaRuntimeGetVersion(&runtime_version));
204
#endif
C
chengduo 已提交
205 206 207 208
  return runtime_version;
}

int GetCUDADriverVersion(int id) {
209 210 211 212 213
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduo 已提交
214
  int driver_version = 0;
215 216 217
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipDriverGetVersion(&driver_version));
#else
218
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaDriverGetVersion(&driver_version));
219
#endif
C
chengduo 已提交
220 221 222
  return driver_version;
}

// Tensor Cores require CUDA >= 9.0 and compute capability >= 7.0 (Volta+);
// they are never reported as available under HIP.
bool TensorCoreAvailable() {
#if !defined(PADDLE_WITH_HIP) && CUDA_VERSION >= 9000
  int device = GetCurrentDeviceId();
  int driver_version = GetCUDAComputeCapability(device);
  return driver_version >= 70;
#else
  return false;
#endif
}

C
chengduoZH 已提交
233
int GetCUDAMultiProcessors(int id) {
234 235 236 237 238
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduoZH 已提交
239
  int count;
240 241 242 243
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(
      hipDeviceGetAttribute(&count, hipDeviceAttributeMultiprocessorCount, id));
#else
244 245
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id));
246
#endif
C
chengduoZH 已提交
247 248 249 250
  return count;
}

int GetCUDAMaxThreadsPerMultiProcessor(int id) {
251 252 253 254 255
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduoZH 已提交
256
  int count;
257 258 259 260
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipDeviceGetAttribute(
      &count, hipDeviceAttributeMaxThreadsPerMultiProcessor, id));
#else
261 262
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaDeviceGetAttribute(
      &count, cudaDevAttrMaxThreadsPerMultiProcessor, id));
263
#endif
C
chengduoZH 已提交
264 265 266
  return count;
}

267
int GetCUDAMaxThreadsPerBlock(int id) {
268 269 270 271 272
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
273
  int count;
274 275 276 277
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(
      hipDeviceGetAttribute(&count, hipDeviceAttributeMaxThreadsPerBlock, id));
#else
278 279
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaDeviceGetAttribute(&count, cudaDevAttrMaxThreadsPerBlock, id));
280
#endif
281 282 283
  return count;
}

L
liaogang 已提交
284 285
int GetCurrentDeviceId() {
  int device_id;
286 287 288
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipGetDevice(&device_id));
#else
289
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetDevice(&device_id));
290
#endif
L
liaogang 已提交
291 292 293
  return device_id;
}

294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311
//! Get a list of device ids from environment variable or use all.
std::vector<int> GetSelectedDevices() {
  // use user specified GPUs in single-node multi-process mode.
  std::vector<int> devices;
  if (!FLAGS_selected_gpus.empty()) {
    auto devices_str = paddle::string::Split(FLAGS_selected_gpus, ',');
    for (auto id : devices_str) {
      devices.push_back(atoi(id.c_str()));
    }
  } else {
    int count = GetCUDADeviceCount();
    for (int i = 0; i < count; ++i) {
      devices.push_back(i);
    }
  }
  return devices;
}

312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349
const gpuDeviceProp &GetDeviceProperties(int id) {
  std::call_once(g_device_props_size_init_flag, [&] {
    int gpu_num = 0;
    gpu_num = platform::GetCUDADeviceCount();
    g_device_props_init_flags.resize(gpu_num);
    g_device_props.resize(gpu_num);
    for (int i = 0; i < gpu_num; ++i) {
      g_device_props_init_flags[i] = std::make_unique<std::once_flag>();
    }
  });

  if (id == -1) {
    id = platform::GetCurrentDeviceId();
  }

  if (id < 0 || id >= static_cast<int>(g_device_props.size())) {
    PADDLE_THROW(platform::errors::OutOfRange(
        "The device id %d is out of range [0, %d), where %d is the number of "
        "devices on this machine. Because the device id should be greater than "
        "or equal to zero and smaller than the number of gpus. Please input "
        "appropriate device again!",
        id, static_cast<int>(g_device_props.size()),
        static_cast<int>(g_device_props.size())));
  }

  std::call_once(*(g_device_props_init_flags[id]), [&] {
#ifdef PADDLE_WITH_CUDA
    PADDLE_ENFORCE_CUDA_SUCCESS(
        cudaGetDeviceProperties(&g_device_props[id], id));
#else
    PADDLE_ENFORCE_CUDA_SUCCESS(
      hipGetDeviceProperties(&g_device_props[id], id));
#endif
  });

  return g_device_props[id];
}

L
liaogang 已提交
350
void SetDeviceId(int id) {
Q
qijun 已提交
351
  // TODO(qijun): find a better way to cache the cuda device count
352 353 354 355 356
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
357 358 359
#ifdef PADDLE_WITH_HIP
  PADDLE_RETRY_CUDA_SUCCESS(hipSetDevice(id));
#else
L
Leo Chen 已提交
360
  PADDLE_RETRY_CUDA_SUCCESS(cudaSetDevice(id));
361
#endif
L
liaogang 已提交
362 363
}

364
void GpuMemoryUsage(size_t *available, size_t *total) {
365 366 367
  size_t actual_available, actual_total;
  RecordedCudaMemGetInfo(available, total, &actual_available, &actual_total,
                         platform::GetCurrentDeviceId());
L
liaogang 已提交
368 369
}

370
size_t GpuAvailableMemToAlloc() {
L
liaogang 已提交
371 372
  size_t total = 0;
  size_t available = 0;
373
  GpuMemoryUsage(&available, &total);
374 375
  size_t reserving =
      static_cast<size_t>(fraction_reserve_gpu_memory * available);
376
  // If available size is less than minimum chunk size, no usable memory exists
377
  size_t available_to_alloc = available - reserving;
378
  size_t min_chunk_size = GpuMinChunkSize();
379 380 381
  if (available_to_alloc < min_chunk_size) {
    available_to_alloc = 0;
  }
382 383 384
  VLOG(10) << "GPU usage " << (available >> 20) << "M/" << (total >> 20)
           << "M, " << (available_to_alloc >> 20) << "M available to allocate";
  return available_to_alloc;
Z
zhhsplendid 已提交
385 386
}

387 388 389
size_t GpuMaxAllocSize() {
  return std::max(GpuInitAllocSize(), GpuReallocSize());
}
Z
zhhsplendid 已提交
390

391 392
static size_t GpuAllocSize(bool realloc) {
  size_t available_to_alloc = GpuAvailableMemToAlloc();
G
GaoWei8 已提交
393 394 395
  PADDLE_ENFORCE_GT(
      available_to_alloc, 0,
      platform::errors::ResourceExhausted("Not enough available GPU memory."));
396 397 398 399 400 401 402
  // If FLAGS_initial_gpu_memory_in_mb is 0, then initial memory will be
  // allocated by fraction
  size_t flag_mb = realloc ? FLAGS_reallocate_gpu_memory_in_mb
                           : FLAGS_initial_gpu_memory_in_mb;
  size_t alloc_bytes =
      (flag_mb > 0ul ? flag_mb << 20 : available_to_alloc *
                                           FLAGS_fraction_of_gpu_memory_to_use);
G
GaoWei8 已提交
403 404 405
  PADDLE_ENFORCE_GE(
      available_to_alloc, alloc_bytes,
      platform::errors::ResourceExhausted("Not enough available GPU memory."));
406 407 408 409
  VLOG(10) << "Alloc size is " << (alloc_bytes >> 20)
           << " MiB, is it Re-alloc: " << realloc;
  return alloc_bytes;
}
Z
zhhsplendid 已提交
410

411
size_t GpuInitAllocSize() { return GpuAllocSize(/* realloc = */ false); }
Z
zhhsplendid 已提交
412

413
size_t GpuReallocSize() { return GpuAllocSize(/* realloc = */ true); }
L
liaogang 已提交
414

L
liaogang 已提交
415 416 417 418 419 420
size_t GpuMinChunkSize() {
  // Allow to allocate the minimum chunk size is 256 bytes.
  return 1 << 8;
}

size_t GpuMaxChunkSize() {
421 422 423
  size_t max_chunk_size = GpuMaxAllocSize();
  VLOG(10) << "Max chunk size " << (max_chunk_size >> 20) << "M";
  return max_chunk_size;
L
liaogang 已提交
424 425
}

426 427 428 429 430 431
#ifdef PADDLE_WITH_HIP
void GpuMemcpyAsync(void *dst, const void *src, size_t count,
                    enum hipMemcpyKind kind, hipStream_t stream) {
  PADDLE_ENFORCE_CUDA_SUCCESS(hipMemcpyAsync(dst, src, count, kind, stream));
}
#else
L
liaogang 已提交
432 433
void GpuMemcpyAsync(void *dst, const void *src, size_t count,
                    enum cudaMemcpyKind kind, cudaStream_t stream) {
434
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync(dst, src, count, kind, stream));
L
liaogang 已提交
435
}
436
#endif
L
liaogang 已提交
437

438 439 440 441 442 443
#ifdef PADDLE_WITH_HIP
void GpuMemcpySync(void *dst, const void *src, size_t count,
                   enum hipMemcpyKind kind) {
  PADDLE_ENFORCE_CUDA_SUCCESS(hipMemcpy(dst, src, count, kind));
}
#else
444 445
void GpuMemcpySync(void *dst, const void *src, size_t count,
                   enum cudaMemcpyKind kind) {
446
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpy(dst, src, count, kind));
447
}
448
#endif
449 450

void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,
451 452 453 454 455
                        int src_device, size_t count, gpuStream_t stream) {
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(
      hipMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream));
#else
456 457
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream));
458
#endif
459 460 461 462
}

void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src,
                       int src_device, size_t count) {
463 464 465 466
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(
      hipMemcpyPeer(dst, dst_device, src, src_device, count));
#else
467 468
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaMemcpyPeer(dst, dst_device, src, src_device, count));
469
#endif
L
liaogang 已提交
470
}
D
dzhwinter 已提交
471

472 473 474 475
void GpuMemsetAsync(void *dst, int value, size_t count, gpuStream_t stream) {
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipMemsetAsync(dst, value, count, stream));
#else
476
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync(dst, value, count, stream));
477
#endif
D
dzhwinter 已提交
478
}
479

480 481 482 483
void GpuStreamSync(gpuStream_t stream) {
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
#else
484
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
485
#endif
石晓伟 已提交
486 487
}

488 489 490 491 492 493
static void RaiseNonOutOfMemoryError(gpuError_t *status) {
#ifdef PADDLE_WITH_HIP
  if (*status == hipErrorOutOfMemory) {
    *status = hipSuccess;
  }
#else
494 495 496
  if (*status == cudaErrorMemoryAllocation) {
    *status = cudaSuccess;
  }
497
#endif
498 499
  PADDLE_ENFORCE_CUDA_SUCCESS(*status);

500 501 502 503 504 505
#ifdef PADDLE_WITH_HIP
  *status = hipGetLastError();
  if (*status == hipErrorOutOfMemory) {
    *status = hipSuccess;
  }
#else
506 507 508 509
  *status = cudaGetLastError();
  if (*status == cudaErrorMemoryAllocation) {
    *status = cudaSuccess;
  }
510
#endif
511 512
  PADDLE_ENFORCE_CUDA_SUCCESS(*status);
}
石晓伟 已提交
513

514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538
class RecordedCudaMallocHelper {
 private:
  explicit RecordedCudaMallocHelper(int dev_id, uint64_t limit_size = 0)
      : dev_id_(dev_id), limit_size_(limit_size) {
    if (NeedRecord()) {
      mtx_.reset(new std::mutex());
    }
  }

  DISABLE_COPY_AND_ASSIGN(RecordedCudaMallocHelper);

 public:
  static RecordedCudaMallocHelper *Instance(int dev_id) {
    std::call_once(once_flag_, [] {
      int dev_cnt = GetCUDADeviceCount();
      instances_.reserve(dev_cnt);
      for (int i = 0; i < dev_cnt; ++i) {
        instances_.emplace_back(
            new RecordedCudaMallocHelper(i, FLAGS_gpu_memory_limit_mb << 20));
      }
    });

    PADDLE_ENFORCE_GE(
        dev_id, 0,
        platform::errors::OutOfRange(
G
GaoWei8 已提交
539
            "Device id must be not less than 0, but got %d.", dev_id));
540 541
    PADDLE_ENFORCE_LT(
        dev_id, instances_.size(),
G
GaoWei8 已提交
542
        platform::errors::OutOfRange("Device id %d exceeds gpu card number %d.",
543 544 545 546 547 548 549 550 551
                                     dev_id, instances_.size()));
    return instances_[dev_id].get();
  }

  /**
   * Try to allocate `size` gpu memory. Only cudaErrorMemoryAllocation
   * or cudaSuccess would be returned, and the cudaGetLastError() flag
   * would be clear.
   */
552
  gpuError_t Malloc(void **ptr, size_t size) {
553
    LockGuardPtr<std::mutex> lock(mtx_);
554
    if (UNLIKELY(NeedRecord() && cur_size_.load() + size > limit_size_)) {
555 556 557
#ifdef PADDLE_WITH_HIP
      return hipErrorOutOfMemory;
#else
558
      return cudaErrorMemoryAllocation;
559
#endif
560 561 562
    }

    CUDADeviceGuard guard(dev_id_);
563 564 565
#ifdef PADDLE_WITH_HIP
    auto result = hipMalloc(ptr, size);
#else
566
    CUDAGraphCaptureModeGuard capture_mode_guard;
567
    auto result = cudaMalloc(ptr, size);
568 569
#endif
    if (result == gpuSuccess) {
570
      cur_size_.fetch_add(size);
H
hutuxian 已提交
571
      STAT_INT_ADD("STAT_gpu" + std::to_string(dev_id_) + "_mem_size", size);
572
      return gpuSuccess;
573 574
    } else {
      RaiseNonOutOfMemoryError(&result);
575 576 577 578 579 580
// Non out of memory error would be raised inside
// RaiseNonOutOfMemoryError. Therefore, we can
// return cudaErrorMemoryAllocation directly here.
#ifdef PADDLE_WITH_HIP
      return hipErrorOutOfMemory;
#else
581
      return cudaErrorMemoryAllocation;
582
#endif
583 584 585 586 587 588 589 590 591 592 593 594 595 596
    }
  }

  /**
   * Free gpu memory. Usually, free is not allowed to raise error.
   * If it does raise error, the process should be crashed.
   */
  void Free(void *ptr, size_t size) {
    // Purposefully allow cudaErrorCudartUnloading, because
    // that is returned if you ever call cudaFree after the
    // driver has already shutdown. This happens only if the
    // process is terminating, in which case we don't care if
    // cudaFree succeeds.
    CUDADeviceGuard guard(dev_id_);
597 598 599 600
#ifdef PADDLE_WITH_HIP
    auto err = hipFree(ptr);
    if (err != hipErrorDeinitialized) {
#else
601 602
    auto err = cudaFree(ptr);
    if (err != cudaErrorCudartUnloading) {
603
#endif
604
      PADDLE_ENFORCE_CUDA_SUCCESS(err);
605
      cur_size_.fetch_sub(size);
H
hutuxian 已提交
606
      STAT_INT_SUB("STAT_gpu" + std::to_string(dev_id_) + "_mem_size", size);
607
    } else {
608 609 610
#ifdef PADDLE_WITH_HIP
      hipGetLastError();  // clear the error flag when hipErrorDeinitialized
#else
611
      cudaGetLastError();  // clear the error flag when cudaErrorCudartUnloading
612
#endif
613 614 615 616 617 618 619
    }
  }

  bool GetMemInfo(size_t *avail, size_t *total, size_t *actual_avail,
                  size_t *actual_total) {
    {
      CUDADeviceGuard guard(dev_id_);
620 621 622
#ifdef PADDLE_WITH_HIP
      auto result = hipMemGetInfo(actual_avail, actual_total);
#else
623
      auto result = cudaMemGetInfo(actual_avail, actual_total);
624 625
#endif
      if (result != gpuSuccess) {
626 627 628 629 630 631 632
        *actual_avail = 0;
      }
      RaiseNonOutOfMemoryError(&result);
    }

    if (NeedRecord()) {
      std::lock_guard<std::mutex> guard(*mtx_);
633
      *avail = std::min(*actual_avail, limit_size_ - cur_size_.load());
634 635 636 637 638 639 640 641 642 643 644
      *total = std::min(*actual_total, limit_size_);
      return *total < *actual_total;
    } else {
      *avail = *actual_avail;
      *total = *actual_total;
      return false;
    }
  }

  inline bool NeedRecord() const { return limit_size_ != 0; }

645
  uint64_t RecordedSize() const { return cur_size_.load(); }
646 647 648

  uint64_t LimitSize() const { return limit_size_; }

649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10020
  CUresult MemCreate(CUmemGenericAllocationHandle *handle, size_t size,
                     const CUmemAllocationProp *prop,
                     unsigned long long flags) {  // NOLINT
    auto result =
        paddle::platform::dynload::cuMemCreate(handle, size, prop, flags);
    if (result == CUDA_SUCCESS) {
      cur_size_.fetch_add(size);
    }
    return result;
  }

  CUresult MemRelease(CUmemGenericAllocationHandle handle, size_t size) {
    auto result = paddle::platform::dynload::cuMemRelease(handle);
    if (result == CUDA_SUCCESS) {
      cur_size_.fetch_sub(size);
    }
    return result;
  }

#endif
#endif

673 674 675
 private:
  const int dev_id_;
  const uint64_t limit_size_;
676
  std::atomic<uint64_t> cur_size_{0};
677 678 679 680 681

  mutable std::unique_ptr<std::mutex> mtx_;

  static std::once_flag once_flag_;
  static std::vector<std::unique_ptr<RecordedCudaMallocHelper>> instances_;
682
};  // NOLINT

// Out-of-line definitions of RecordedCudaMallocHelper's static members.
std::once_flag RecordedCudaMallocHelper::once_flag_;
std::vector<std::unique_ptr<RecordedCudaMallocHelper>>
    RecordedCudaMallocHelper::instances_;


688
gpuError_t RecordedCudaMalloc(void **ptr, size_t size, int dev_id) {
689 690 691 692 693 694 695
  return RecordedCudaMallocHelper::Instance(dev_id)->Malloc(ptr, size);
}

void RecordedCudaFree(void *p, size_t size, int dev_id) {
  return RecordedCudaMallocHelper::Instance(dev_id)->Free(p, size);
}

#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10020
// Recorded wrapper of the CUDA driver API cuMemCreate.
CUresult RecordedCuMemCreate(CUmemGenericAllocationHandle *handle, size_t size,
                             const CUmemAllocationProp *prop,
                             unsigned long long flags, int dev_id) {  // NOLINT
  return RecordedCudaMallocHelper::Instance(dev_id)->MemCreate(handle, size,
                                                               prop, flags);
}

// Recorded wrapper of the CUDA driver API cuMemRelease.
CUresult RecordedCuMemRelease(CUmemGenericAllocationHandle handle, size_t size,
                              int dev_id) {
  return RecordedCudaMallocHelper::Instance(dev_id)->MemRelease(handle, size);
}
#endif
#endif

// Limit-aware wrapper of cudaMemGetInfo/hipMemGetInfo for device `dev_id`;
// returns true iff the configured limit clipped the reported totals.
bool RecordedCudaMemGetInfo(size_t *avail, size_t *total, size_t *actual_avail,
                            size_t *actual_total, int dev_id) {
  return RecordedCudaMallocHelper::Instance(dev_id)->GetMemInfo(
      avail, total, actual_avail, actual_total);
}

// Bytes currently recorded as allocated on `dev_id`.
uint64_t RecordedCudaMallocSize(int dev_id) {
  return RecordedCudaMallocHelper::Instance(dev_id)->RecordedSize();
}

// Whether allocation recording (i.e. a non-zero memory limit) is active.
bool IsCudaMallocRecorded(int dev_id) {
  return RecordedCudaMallocHelper::Instance(dev_id)->NeedRecord();
}

726 727 728 729 730 731 732
void EmptyCache(void) {
  std::vector<int> devices = GetSelectedDevices();
  for (auto device : devices) {
    memory::Release(CUDAPlace(device));
  }
}

L
liaogang 已提交
733 734
}  // namespace platform
}  // namespace paddle