gpu_info.cc 20.3 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/platform/gpu_info.h"
S
sneaxiy 已提交
16
#include <algorithm>
#include <atomic>
#include <cctype>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
L
liaogang 已提交
17

18
#include "gflags/gflags.h"
19
#include "paddle/fluid/platform/cuda_device_guard.h"
20 21 22
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/dynload/miopen.h"
#else
23
#include "paddle/fluid/platform/dynload/cudnn.h"
24
#endif
Y
Yi Wang 已提交
25
#include "paddle/fluid/platform/enforce.h"
26 27
#include "paddle/fluid/platform/lock_guard_ptr.h"
#include "paddle/fluid/platform/macros.h"
H
hutuxian 已提交
28
#include "paddle/fluid/platform/monitor.h"
29
#include "paddle/fluid/string/split.h"
L
liaogang 已提交
30

31 32 33 34 35
DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
DECLARE_bool(enable_cublas_tensor_op_math);
DECLARE_string(selected_gpus);
36
DECLARE_uint64(gpu_memory_limit_mb);
37

Z
zhhsplendid 已提交
38 39
constexpr static float fraction_reserve_gpu_memory = 0.05f;

H
hutuxian 已提交
40
USE_GPU_MEM_STAT;
L
liaogang 已提交
41 42 43
namespace paddle {
namespace platform {

44 45
//! Version of the loaded DNN library, or -1 when it is not available.
//! On ROCm the MIOpen version is encoded as major * 100 + minor * 10 + patch;
//! on CUDA the value of cudnnGetVersion() is returned unchanged.
int CudnnVersion() {
  if (!dynload::HasCUDNN()) {
    return -1;
  }
#ifdef PADDLE_WITH_HIP
  size_t major = 0;
  size_t minor = 0;
  size_t patch = 0;
  PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenGetVersion(&major, &minor, &patch));
  return static_cast<int>(major * 100 + minor * 10 + patch);
#else
  return static_cast<int>(dynload::cudnnGetVersion());
#endif
}
S
sneaxiy 已提交
56
static int GetCUDADeviceCountImpl() {
57
  int driverVersion = 0;
58 59 60
#ifdef PADDLE_WITH_HIP
  hipError_t status = hipDriverGetVersion(&driverVersion);
#else
61
  cudaError_t status = cudaDriverGetVersion(&driverVersion);
62
#endif
63

64
  if (!(status == gpuSuccess && driverVersion != 0)) {
65
    // No GPU driver
66
    VLOG(2) << "GPU Driver Version can't be detected. No GPU driver!";
67 68 69
    return 0;
  }

70 71 72
#ifdef PADDLE_WITH_HIP
  const auto *cuda_visible_devices = std::getenv("HIP_VISIBLE_DEVICES");
#else
S
sneaxiy 已提交
73
  const auto *cuda_visible_devices = std::getenv("CUDA_VISIBLE_DEVICES");
74
#endif
S
sneaxiy 已提交
75 76
  if (cuda_visible_devices != nullptr) {
    std::string cuda_visible_devices_str(cuda_visible_devices);
77 78 79 80 81 82 83 84 85 86
    if (!cuda_visible_devices_str.empty()) {
      cuda_visible_devices_str.erase(
          0, cuda_visible_devices_str.find_first_not_of('\''));
      cuda_visible_devices_str.erase(
          cuda_visible_devices_str.find_last_not_of('\'') + 1);
      cuda_visible_devices_str.erase(
          0, cuda_visible_devices_str.find_first_not_of('\"'));
      cuda_visible_devices_str.erase(
          cuda_visible_devices_str.find_last_not_of('\"') + 1);
    }
S
sneaxiy 已提交
87 88 89
    if (std::all_of(cuda_visible_devices_str.begin(),
                    cuda_visible_devices_str.end(),
                    [](char ch) { return ch == ' '; })) {
90 91
      VLOG(2) << "CUDA_VISIBLE_DEVICES or HIP_VISIBLE_DEVICES is set to be "
                 "empty. No GPU detected.";
S
sneaxiy 已提交
92 93 94
      return 0;
    }
  }
L
liaogang 已提交
95
  int count;
96 97 98
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipGetDeviceCount(&count));
#else
99
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetDeviceCount(&count));
100
#endif
L
liaogang 已提交
101 102 103
  return count;
}

S
sneaxiy 已提交
104
int GetCUDADeviceCount() {
105
  // cache the count
S
sneaxiy 已提交
106 107 108 109
  static auto dev_cnt = GetCUDADeviceCountImpl();
  return dev_cnt;
}

110 111 112 113
/* CUDA "pro tip": querying one property with cudaDeviceGetAttribute() is much
faster than filling a whole cudaDeviceProp with cudaGetDeviceProperties(). See
https://devblogs.nvidia.com/cuda-pro-tip-the-fast-way-to-query-device-properties/
*/
114
int GetCUDAComputeCapability(int id) {
115 116 117 118 119
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
120 121
  int major, minor;

122 123 124 125 126 127
#ifdef PADDLE_WITH_HIP
  auto major_error_code = hipDeviceGetAttribute(
      &major, hipDeviceAttributeComputeCapabilityMajor, id);
  auto minor_error_code = hipDeviceGetAttribute(
      &minor, hipDeviceAttributeComputeCapabilityMinor, id);
#else
128 129 130 131
  auto major_error_code =
      cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, id);
  auto minor_error_code =
      cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, id);
132
#endif
133 134
  PADDLE_ENFORCE_CUDA_SUCCESS(major_error_code);
  PADDLE_ENFORCE_CUDA_SUCCESS(minor_error_code);
135 136 137
#ifdef PADDLE_WITH_HIP
  return major * 100 + minor;
#else
138
  return major * 10 + minor;
139
#endif
140 141
}

142
//! Maximum grid extent in x, y and z for device `id`. The three attributes
//! are queried in x, y, z order; any failure raises.
dim3 GetGpuMaxGridDimSize(int id) {
  const int device_count = GetCUDADeviceCount();
  PADDLE_ENFORCE_LT(id, device_count,
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, device_count));
#ifdef PADDLE_WITH_HIP
  auto query = [id](hipDeviceAttribute_t attr) {
    int value = 0;
    PADDLE_ENFORCE_CUDA_SUCCESS(hipDeviceGetAttribute(&value, attr, id));
    return value;
  };
  dim3 max_grid;
  max_grid.x = query(hipDeviceAttributeMaxGridDimX);
  max_grid.y = query(hipDeviceAttributeMaxGridDimY);
  max_grid.z = query(hipDeviceAttributeMaxGridDimZ);
#else
  auto query = [id](cudaDeviceAttr attr) {
    int value = 0;
    PADDLE_ENFORCE_CUDA_SUCCESS(cudaDeviceGetAttribute(&value, attr, id));
    return value;
  };
  dim3 max_grid;
  max_grid.x = query(cudaDevAttrMaxGridDimX);
  max_grid.y = query(cudaDevAttrMaxGridDimY);
  max_grid.z = query(cudaDevAttrMaxGridDimZ);
#endif
  return max_grid;
}

C
chengduo 已提交
179
int GetCUDARuntimeVersion(int id) {
180 181 182 183 184
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduo 已提交
185
  int runtime_version = 0;
186 187 188
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipRuntimeGetVersion(&runtime_version));
#else
189
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaRuntimeGetVersion(&runtime_version));
190
#endif
C
chengduo 已提交
191 192 193 194
  return runtime_version;
}

int GetCUDADriverVersion(int id) {
195 196 197 198 199
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduo 已提交
200
  int driver_version = 0;
201 202 203
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipDriverGetVersion(&driver_version));
#else
204
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaDriverGetVersion(&driver_version));
205
#endif
C
chengduo 已提交
206 207 208
  return driver_version;
}

209
//! Whether the current device has Tensor Cores (compute capability >= 7.0).
//! Always false on HIP builds and with CUDA toolkits older than 9.0.
bool TensorCoreAvailable() {
#if !defined(PADDLE_WITH_HIP) && CUDA_VERSION >= 9000
  const int device = GetCurrentDeviceId();
  // Compute capability encoded as major * 10 + minor on the CUDA path, so
  // 7.0 (Volta) becomes 70.
  const int compute_capability = GetCUDAComputeCapability(device);
  return compute_capability >= 70;
#else
  return false;
#endif
}

C
chengduoZH 已提交
219
int GetCUDAMultiProcessors(int id) {
220 221 222 223 224
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduoZH 已提交
225
  int count;
226 227 228 229
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(
      hipDeviceGetAttribute(&count, hipDeviceAttributeMultiprocessorCount, id));
#else
230 231
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id));
232
#endif
C
chengduoZH 已提交
233 234 235 236
  return count;
}

int GetCUDAMaxThreadsPerMultiProcessor(int id) {
237 238 239 240 241
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduoZH 已提交
242
  int count;
243 244 245 246
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipDeviceGetAttribute(
      &count, hipDeviceAttributeMaxThreadsPerMultiProcessor, id));
#else
247 248
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaDeviceGetAttribute(
      &count, cudaDevAttrMaxThreadsPerMultiProcessor, id));
249
#endif
C
chengduoZH 已提交
250 251 252
  return count;
}

253
int GetCUDAMaxThreadsPerBlock(int id) {
254 255 256 257 258
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
259
  int count;
260 261 262 263
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(
      hipDeviceGetAttribute(&count, hipDeviceAttributeMaxThreadsPerBlock, id));
#else
264 265
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaDeviceGetAttribute(&count, cudaDevAttrMaxThreadsPerBlock, id));
266
#endif
267 268 269
  return count;
}

L
liaogang 已提交
270 271
int GetCurrentDeviceId() {
  int device_id;
272 273 274
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipGetDevice(&device_id));
#else
275
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetDevice(&device_id));
276
#endif
L
liaogang 已提交
277 278 279
  return device_id;
}

280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297
//! Get a list of device ids from environment variable or use all.
std::vector<int> GetSelectedDevices() {
  // use user specified GPUs in single-node multi-process mode.
  std::vector<int> devices;
  if (!FLAGS_selected_gpus.empty()) {
    auto devices_str = paddle::string::Split(FLAGS_selected_gpus, ',');
    for (auto id : devices_str) {
      devices.push_back(atoi(id.c_str()));
    }
  } else {
    int count = GetCUDADeviceCount();
    for (int i = 0; i < count; ++i) {
      devices.push_back(i);
    }
  }
  return devices;
}

L
liaogang 已提交
298
void SetDeviceId(int id) {
Q
qijun 已提交
299
  // TODO(qijun): find a better way to cache the cuda device count
300 301 302 303 304
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
305 306 307
#ifdef PADDLE_WITH_HIP
  PADDLE_RETRY_CUDA_SUCCESS(hipSetDevice(id));
#else
L
Leo Chen 已提交
308
  PADDLE_RETRY_CUDA_SUCCESS(cudaSetDevice(id));
309
#endif
L
liaogang 已提交
310 311
}

312
//! Free (`available`) and total memory of the current device, as reported by
//! RecordedCudaMemGetInfo (i.e. clamped to the recording limit when one is
//! set). The unclamped driver figures are fetched but not returned.
void GpuMemoryUsage(size_t *available, size_t *total) {
  // Zero-initialized: the underlying query may leave these unwritten when the
  // driver call fails, and `*total` is derived from actual_total.
  size_t actual_available = 0;
  size_t actual_total = 0;
  RecordedCudaMemGetInfo(available, total, &actual_available, &actual_total,
                         platform::GetCurrentDeviceId());
}

318
size_t GpuAvailableMemToAlloc() {
L
liaogang 已提交
319 320
  size_t total = 0;
  size_t available = 0;
321
  GpuMemoryUsage(&available, &total);
322 323
  size_t reserving =
      static_cast<size_t>(fraction_reserve_gpu_memory * available);
324
  // If available size is less than minimum chunk size, no usable memory exists
325
  size_t available_to_alloc = available - reserving;
326
  size_t min_chunk_size = GpuMinChunkSize();
327 328 329
  if (available_to_alloc < min_chunk_size) {
    available_to_alloc = 0;
  }
330 331 332
  VLOG(10) << "GPU usage " << (available >> 20) << "M/" << (total >> 20)
           << "M, " << (available_to_alloc >> 20) << "M available to allocate";
  return available_to_alloc;
Z
zhhsplendid 已提交
333 334
}

335 336 337
//! Larger of the initial and the re-allocation chunk sizes.
size_t GpuMaxAllocSize() {
  const size_t init_bytes = GpuInitAllocSize();
  const size_t realloc_bytes = GpuReallocSize();
  return std::max(init_bytes, realloc_bytes);
}
Z
zhhsplendid 已提交
338

339 340
static size_t GpuAllocSize(bool realloc) {
  size_t available_to_alloc = GpuAvailableMemToAlloc();
G
GaoWei8 已提交
341 342 343
  PADDLE_ENFORCE_GT(
      available_to_alloc, 0,
      platform::errors::ResourceExhausted("Not enough available GPU memory."));
344 345 346 347 348 349 350
  // If FLAGS_initial_gpu_memory_in_mb is 0, then initial memory will be
  // allocated by fraction
  size_t flag_mb = realloc ? FLAGS_reallocate_gpu_memory_in_mb
                           : FLAGS_initial_gpu_memory_in_mb;
  size_t alloc_bytes =
      (flag_mb > 0ul ? flag_mb << 20 : available_to_alloc *
                                           FLAGS_fraction_of_gpu_memory_to_use);
G
GaoWei8 已提交
351 352 353
  PADDLE_ENFORCE_GE(
      available_to_alloc, alloc_bytes,
      platform::errors::ResourceExhausted("Not enough available GPU memory."));
354 355 356 357
  VLOG(10) << "Alloc size is " << (alloc_bytes >> 20)
           << " MiB, is it Re-alloc: " << realloc;
  return alloc_bytes;
}
Z
zhhsplendid 已提交
358

359
//! First chunk size (driven by FLAGS_initial_gpu_memory_in_mb).
size_t GpuInitAllocSize() {
  const bool is_realloc = false;
  return GpuAllocSize(is_realloc);
}
Z
zhhsplendid 已提交
360

361
//! Follow-up chunk size (driven by FLAGS_reallocate_gpu_memory_in_mb).
size_t GpuReallocSize() {
  const bool is_realloc = true;
  return GpuAllocSize(is_realloc);
}
L
liaogang 已提交
362

L
liaogang 已提交
363 364 365 366 367 368
//! Smallest chunk the allocator hands out: 256 bytes.
size_t GpuMinChunkSize() {
  constexpr size_t kMinChunkBytes = 256;
  return kMinChunkBytes;
}

size_t GpuMaxChunkSize() {
369 370 371
  size_t max_chunk_size = GpuMaxAllocSize();
  VLOG(10) << "Max chunk size " << (max_chunk_size >> 20) << "M";
  return max_chunk_size;
L
liaogang 已提交
372 373
}

374 375 376 377 378 379
#ifdef PADDLE_WITH_HIP
void GpuMemcpyAsync(void *dst, const void *src, size_t count,
                    enum hipMemcpyKind kind, hipStream_t stream) {
  PADDLE_ENFORCE_CUDA_SUCCESS(hipMemcpyAsync(dst, src, count, kind, stream));
}
#else
L
liaogang 已提交
380 381
void GpuMemcpyAsync(void *dst, const void *src, size_t count,
                    enum cudaMemcpyKind kind, cudaStream_t stream) {
382
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync(dst, src, count, kind, stream));
L
liaogang 已提交
383
}
384
#endif
L
liaogang 已提交
385

386 387 388 389 390 391
#ifdef PADDLE_WITH_HIP
void GpuMemcpySync(void *dst, const void *src, size_t count,
                   enum hipMemcpyKind kind) {
  PADDLE_ENFORCE_CUDA_SUCCESS(hipMemcpy(dst, src, count, kind));
}
#else
392 393
void GpuMemcpySync(void *dst, const void *src, size_t count,
                   enum cudaMemcpyKind kind) {
394
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpy(dst, src, count, kind));
395
}
396
#endif
397 398

void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,
399 400 401 402 403
                        int src_device, size_t count, gpuStream_t stream) {
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(
      hipMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream));
#else
404 405
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream));
406
#endif
407 408 409 410
}

void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src,
                       int src_device, size_t count) {
411 412 413 414
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(
      hipMemcpyPeer(dst, dst_device, src, src_device, count));
#else
415 416
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaMemcpyPeer(dst, dst_device, src, src_device, count));
417
#endif
L
liaogang 已提交
418
}
D
dzhwinter 已提交
419

420 421 422 423
void GpuMemsetAsync(void *dst, int value, size_t count, gpuStream_t stream) {
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipMemsetAsync(dst, value, count, stream));
#else
424
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync(dst, value, count, stream));
425
#endif
D
dzhwinter 已提交
426
}
427

428 429 430 431
//! Block the host until all work queued on `stream` has finished.
void GpuStreamSync(gpuStream_t stream) {
#ifdef PADDLE_WITH_HIP
  const auto status = hipStreamSynchronize(stream);
#else
  const auto status = cudaStreamSynchronize(stream);
#endif
  PADDLE_ENFORCE_CUDA_SUCCESS(status);
}

436 437 438 439 440 441
//! Raise any error in `*status` or in the runtime's last-error flag, except
//! out-of-memory, which is tolerated in both places. Reading the last-error
//! flag also clears it. On normal return `*status` is gpuSuccess.
static void RaiseNonOutOfMemoryError(gpuError_t *status) {
#ifdef PADDLE_WITH_HIP
  const gpuError_t oom_code = hipErrorOutOfMemory;
#else
  const gpuError_t oom_code = cudaErrorMemoryAllocation;
#endif

  // 1) The error the caller observed.
  if (*status == oom_code) *status = gpuSuccess;
  PADDLE_ENFORCE_CUDA_SUCCESS(*status);

  // 2) The sticky last error.
#ifdef PADDLE_WITH_HIP
  *status = hipGetLastError();
#else
  *status = cudaGetLastError();
#endif
  if (*status == oom_code) *status = gpuSuccess;
  PADDLE_ENFORCE_CUDA_SUCCESS(*status);
}
石晓伟 已提交
461

462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486
class RecordedCudaMallocHelper {
 private:
  explicit RecordedCudaMallocHelper(int dev_id, uint64_t limit_size = 0)
      : dev_id_(dev_id), limit_size_(limit_size) {
    if (NeedRecord()) {
      mtx_.reset(new std::mutex());
    }
  }

  DISABLE_COPY_AND_ASSIGN(RecordedCudaMallocHelper);

 public:
  static RecordedCudaMallocHelper *Instance(int dev_id) {
    std::call_once(once_flag_, [] {
      int dev_cnt = GetCUDADeviceCount();
      instances_.reserve(dev_cnt);
      for (int i = 0; i < dev_cnt; ++i) {
        instances_.emplace_back(
            new RecordedCudaMallocHelper(i, FLAGS_gpu_memory_limit_mb << 20));
      }
    });

    PADDLE_ENFORCE_GE(
        dev_id, 0,
        platform::errors::OutOfRange(
G
GaoWei8 已提交
487
            "Device id must be not less than 0, but got %d.", dev_id));
488 489
    PADDLE_ENFORCE_LT(
        dev_id, instances_.size(),
G
GaoWei8 已提交
490
        platform::errors::OutOfRange("Device id %d exceeds gpu card number %d.",
491 492 493 494 495 496 497 498 499
                                     dev_id, instances_.size()));
    return instances_[dev_id].get();
  }

  /**
   * Try to allocate `size` gpu memory. Only cudaErrorMemoryAllocation
   * or cudaSuccess would be returned, and the cudaGetLastError() flag
   * would be clear.
   */
500
  gpuError_t Malloc(void **ptr, size_t size) {
501
    LockGuardPtr<std::mutex> lock(mtx_);
502
    if (UNLIKELY(NeedRecord() && cur_size_.load() + size > limit_size_)) {
503 504 505
#ifdef PADDLE_WITH_HIP
      return hipErrorOutOfMemory;
#else
506
      return cudaErrorMemoryAllocation;
507
#endif
508 509 510
    }

    CUDADeviceGuard guard(dev_id_);
511 512 513
#ifdef PADDLE_WITH_HIP
    auto result = hipMalloc(ptr, size);
#else
514
    auto result = cudaMalloc(ptr, size);
515 516
#endif
    if (result == gpuSuccess) {
517
      cur_size_.fetch_add(size);
H
hutuxian 已提交
518
      STAT_INT_ADD("STAT_gpu" + std::to_string(dev_id_) + "_mem_size", size);
519
      return gpuSuccess;
520 521
    } else {
      RaiseNonOutOfMemoryError(&result);
522 523 524 525 526 527
// Non out of memory error would be raised inside
// RaiseNonOutOfMemoryError. Therefore, we can
// return cudaErrorMemoryAllocation directly here.
#ifdef PADDLE_WITH_HIP
      return hipErrorOutOfMemory;
#else
528
      return cudaErrorMemoryAllocation;
529
#endif
530 531 532 533 534 535 536 537 538 539 540 541 542 543
    }
  }

  /**
   * Free gpu memory. Usually, free is not allowed to raise error.
   * If it does raise error, the process should be crashed.
   */
  void Free(void *ptr, size_t size) {
    // Purposefully allow cudaErrorCudartUnloading, because
    // that is returned if you ever call cudaFree after the
    // driver has already shutdown. This happens only if the
    // process is terminating, in which case we don't care if
    // cudaFree succeeds.
    CUDADeviceGuard guard(dev_id_);
544 545 546 547
#ifdef PADDLE_WITH_HIP
    auto err = hipFree(ptr);
    if (err != hipErrorDeinitialized) {
#else
548 549
    auto err = cudaFree(ptr);
    if (err != cudaErrorCudartUnloading) {
550
#endif
551
      PADDLE_ENFORCE_CUDA_SUCCESS(err);
552
      cur_size_.fetch_sub(size);
H
hutuxian 已提交
553
      STAT_INT_SUB("STAT_gpu" + std::to_string(dev_id_) + "_mem_size", size);
554
    } else {
555 556 557
#ifdef PADDLE_WITH_HIP
      hipGetLastError();  // clear the error flag when hipErrorDeinitialized
#else
558
      cudaGetLastError();  // clear the error flag when cudaErrorCudartUnloading
559
#endif
560 561 562 563 564 565 566
    }
  }

  bool GetMemInfo(size_t *avail, size_t *total, size_t *actual_avail,
                  size_t *actual_total) {
    {
      CUDADeviceGuard guard(dev_id_);
567 568 569
#ifdef PADDLE_WITH_HIP
      auto result = hipMemGetInfo(actual_avail, actual_total);
#else
570
      auto result = cudaMemGetInfo(actual_avail, actual_total);
571 572
#endif
      if (result != gpuSuccess) {
573 574 575 576 577 578 579
        *actual_avail = 0;
      }
      RaiseNonOutOfMemoryError(&result);
    }

    if (NeedRecord()) {
      std::lock_guard<std::mutex> guard(*mtx_);
580
      *avail = std::min(*actual_avail, limit_size_ - cur_size_.load());
581 582 583 584 585 586 587 588 589 590 591
      *total = std::min(*actual_total, limit_size_);
      return *total < *actual_total;
    } else {
      *avail = *actual_avail;
      *total = *actual_total;
      return false;
    }
  }

  inline bool NeedRecord() const { return limit_size_ != 0; }

592
  uint64_t RecordedSize() const { return cur_size_.load(); }
593 594 595 596 597 598

  uint64_t LimitSize() const { return limit_size_; }

 private:
  const int dev_id_;
  const uint64_t limit_size_;
599
  std::atomic<uint64_t> cur_size_{0};
600 601 602 603 604

  mutable std::unique_ptr<std::mutex> mtx_;

  static std::once_flag once_flag_;
  static std::vector<std::unique_ptr<RecordedCudaMallocHelper>> instances_;
605
};  // NOLINT
606 607 608 609 610

// Static storage for RecordedCudaMallocHelper: the once-flag guarding
// construction and the per-device helpers that Instance() fills.
std::once_flag RecordedCudaMallocHelper::once_flag_{};
std::vector<std::unique_ptr<RecordedCudaMallocHelper>>
    RecordedCudaMallocHelper::instances_{};

611
//! Allocate through the recording helper of `dev_id`; see its Malloc().
gpuError_t RecordedCudaMalloc(void **ptr, size_t size, int dev_id) {
  auto *helper = RecordedCudaMallocHelper::Instance(dev_id);
  return helper->Malloc(ptr, size);
}

//! Release memory obtained from RecordedCudaMalloc on `dev_id`.
void RecordedCudaFree(void *p, size_t size, int dev_id) {
  auto *helper = RecordedCudaMallocHelper::Instance(dev_id);
  helper->Free(p, size);
}

//! Memory info for `dev_id`, clamped to its recording limit if one is set.
//! Returns true iff the limit reduced the reported total.
bool RecordedCudaMemGetInfo(size_t *avail, size_t *total, size_t *actual_avail,
                            size_t *actual_total, int dev_id) {
  auto *helper = RecordedCudaMallocHelper::Instance(dev_id);
  return helper->GetMemInfo(avail, total, actual_avail, actual_total);
}

//! Bytes currently allocated through the recording helper of `dev_id`.
uint64_t RecordedCudaMallocSize(int dev_id) {
  auto *helper = RecordedCudaMallocHelper::Instance(dev_id);
  return helper->RecordedSize();
}

//! Whether a memory limit is enforced for `dev_id`.
bool IsCudaMallocRecorded(int dev_id) {
  auto *helper = RecordedCudaMallocHelper::Instance(dev_id);
  return helper->NeedRecord();
}

L
liaogang 已提交
633 634
}  // namespace platform
}  // namespace paddle