/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/platform/gpu_info.h"
S
sneaxiy 已提交
16
#include <cstdlib>
L
liaogang 已提交
17

18
#include "gflags/gflags.h"
19
#include "paddle/fluid/platform/cuda_device_guard.h"
20 21 22
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/dynload/miopen.h"
#else
23
#include "paddle/fluid/platform/dynload/cudnn.h"
24
#endif
Y
Yi Wang 已提交
25
#include "paddle/fluid/platform/enforce.h"
26 27
#include "paddle/fluid/platform/lock_guard_ptr.h"
#include "paddle/fluid/platform/macros.h"
H
hutuxian 已提交
28
#include "paddle/fluid/platform/monitor.h"
29
#include "paddle/fluid/string/split.h"
L
liaogang 已提交
30

31 32 33 34 35
// Flags read by the allocation-size helpers in this file.
DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
// Per-device allocation cap (in MB) handed to RecordedCudaMallocHelper.
DECLARE_uint64(gpu_memory_limit_mb);
// Comma-separated device ids parsed by GetSelectedDevices().
DECLARE_string(selected_gpus);
// Declared here but not referenced in this file.
DECLARE_bool(enable_cublas_tensor_op_math);
37

Z
zhhsplendid 已提交
38 39
// Fraction of the currently available GPU memory that
// GpuAvailableMemToAlloc() holds back and never reports as allocatable.
constexpr static float fraction_reserve_gpu_memory = 0.05f;

H
hutuxian 已提交
40
// NOTE(review): presumably brings the per-device "STAT_gpu<N>_mem_size"
// counters from monitor.h into scope for STAT_INT_ADD / STAT_INT_SUB used
// further down -- confirm against monitor.h.
USE_GPU_MEM_STAT;
L
liaogang 已提交
41 42 43
namespace paddle {
namespace platform {

44 45
// Returns the version of the DNN library in use, or -1 when
// dynload::HasCUDNN() reports that it is not available.
// HIP builds fold the MIOpen triple into major * 100 + minor * 10 + patch;
// other builds return dynload::cudnnGetVersion() unchanged.
int CudnnVersion() {
  const bool dnn_available = dynload::HasCUDNN();
  if (!dnn_available) {
    return -1;
  }

#ifdef PADDLE_WITH_HIP
  size_t major_part, minor_part, patch_part;
  PADDLE_ENFORCE_CUDA_SUCCESS(
      dynload::miopenGetVersion(&major_part, &minor_part, &patch_part));
  return major_part * 100 + minor_part * 10 + patch_part;
#else
  return dynload::cudnnGetVersion();
#endif
}
S
sneaxiy 已提交
56
static int GetCUDADeviceCountImpl() {
57
  int driverVersion = 0;
58 59 60
#ifdef PADDLE_WITH_HIP
  hipError_t status = hipDriverGetVersion(&driverVersion);
#else
61
  cudaError_t status = cudaDriverGetVersion(&driverVersion);
62
#endif
63

64
  if (!(status == gpuSuccess && driverVersion != 0)) {
65
    // No GPU driver
66
    VLOG(2) << "GPU Driver Version can't be detected. No GPU driver!";
67 68 69
    return 0;
  }

70 71 72
#ifdef PADDLE_WITH_HIP
  const auto *cuda_visible_devices = std::getenv("HIP_VISIBLE_DEVICES");
#else
S
sneaxiy 已提交
73
  const auto *cuda_visible_devices = std::getenv("CUDA_VISIBLE_DEVICES");
74
#endif
S
sneaxiy 已提交
75 76
  if (cuda_visible_devices != nullptr) {
    std::string cuda_visible_devices_str(cuda_visible_devices);
77 78 79 80 81 82 83 84 85 86
    if (!cuda_visible_devices_str.empty()) {
      cuda_visible_devices_str.erase(
          0, cuda_visible_devices_str.find_first_not_of('\''));
      cuda_visible_devices_str.erase(
          cuda_visible_devices_str.find_last_not_of('\'') + 1);
      cuda_visible_devices_str.erase(
          0, cuda_visible_devices_str.find_first_not_of('\"'));
      cuda_visible_devices_str.erase(
          cuda_visible_devices_str.find_last_not_of('\"') + 1);
    }
S
sneaxiy 已提交
87 88 89
    if (std::all_of(cuda_visible_devices_str.begin(),
                    cuda_visible_devices_str.end(),
                    [](char ch) { return ch == ' '; })) {
90 91
      VLOG(2) << "CUDA_VISIBLE_DEVICES or HIP_VISIBLE_DEVICES is set to be "
                 "empty. No GPU detected.";
S
sneaxiy 已提交
92 93 94
      return 0;
    }
  }
L
liaogang 已提交
95
  int count;
96 97 98
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipGetDeviceCount(&count));
#else
99
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetDeviceCount(&count));
100
#endif
L
liaogang 已提交
101 102 103
  return count;
}

S
sneaxiy 已提交
104 105 106 107 108
// Returns the GPU count. GetCUDADeviceCountImpl() runs once, on the first
// call, and its result is kept in a function-local static for later calls.
int GetCUDADeviceCount() {
  static const int cached_device_count = GetCUDADeviceCountImpl();
  return cached_device_count;
}

109 110 111 112
/* Here is a very simple CUDA “pro tip”: cudaDeviceGetAttribute() is a much
faster way to query device properties. You can see details in
https://devblogs.nvidia.com/cuda-pro-tip-the-fast-way-to-query-device-properties/
*/
113
int GetCUDAComputeCapability(int id) {
114 115 116 117 118
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
119 120
  int major, minor;

121 122 123 124 125 126
#ifdef PADDLE_WITH_HIP
  auto major_error_code = hipDeviceGetAttribute(
      &major, hipDeviceAttributeComputeCapabilityMajor, id);
  auto minor_error_code = hipDeviceGetAttribute(
      &minor, hipDeviceAttributeComputeCapabilityMinor, id);
#else
127 128 129 130
  auto major_error_code =
      cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, id);
  auto minor_error_code =
      cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, id);
131
#endif
132 133
  PADDLE_ENFORCE_CUDA_SUCCESS(major_error_code);
  PADDLE_ENFORCE_CUDA_SUCCESS(minor_error_code);
134 135 136
#ifdef PADDLE_WITH_HIP
  return major * 100 + minor;
#else
137
  return major * 10 + minor;
138
#endif
139 140
}

141
// Returns the maximum grid dimensions (x, y, z) of device `id` as a dim3.
// Raises InvalidArgument when `id` is not below the GPU count; each failed
// attribute query is raised through PADDLE_ENFORCE_CUDA_SUCCESS.
dim3 GetGpuMaxGridDimSize(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
  dim3 grid_limits;
  int attr_value;

  // X dimension.
#ifdef PADDLE_WITH_HIP
  auto x_status =
      hipDeviceGetAttribute(&attr_value, hipDeviceAttributeMaxGridDimX, id);
#else
  auto x_status =
      cudaDeviceGetAttribute(&attr_value, cudaDevAttrMaxGridDimX, id);
#endif
  PADDLE_ENFORCE_CUDA_SUCCESS(x_status);
  grid_limits.x = attr_value;

  // Y dimension.
#ifdef PADDLE_WITH_HIP
  auto y_status =
      hipDeviceGetAttribute(&attr_value, hipDeviceAttributeMaxGridDimY, id);
#else
  auto y_status =
      cudaDeviceGetAttribute(&attr_value, cudaDevAttrMaxGridDimY, id);
#endif
  PADDLE_ENFORCE_CUDA_SUCCESS(y_status);
  grid_limits.y = attr_value;

  // Z dimension.
#ifdef PADDLE_WITH_HIP
  auto z_status =
      hipDeviceGetAttribute(&attr_value, hipDeviceAttributeMaxGridDimZ, id);
#else
  auto z_status =
      cudaDeviceGetAttribute(&attr_value, cudaDevAttrMaxGridDimZ, id);
#endif
  PADDLE_ENFORCE_CUDA_SUCCESS(z_status);
  grid_limits.z = attr_value;

  return grid_limits;
}

C
chengduo 已提交
178
int GetCUDARuntimeVersion(int id) {
179 180 181 182 183
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduo 已提交
184
  int runtime_version = 0;
185 186 187
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipRuntimeGetVersion(&runtime_version));
#else
188
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaRuntimeGetVersion(&runtime_version));
189
#endif
C
chengduo 已提交
190 191 192 193
  return runtime_version;
}

int GetCUDADriverVersion(int id) {
194 195 196 197 198
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduo 已提交
199
  int driver_version = 0;
200 201 202
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipDriverGetVersion(&driver_version));
#else
203
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaDriverGetVersion(&driver_version));
204
#endif
C
chengduo 已提交
205 206 207
  return driver_version;
}

208
// Returns true when the current device reports a compute capability of at
// least 70 (as packed by GetCUDAComputeCapability). Always false on HIP
// builds and when CUDA_VERSION is below 9000.
bool TensorCoreAvailable() {
#if !defined(PADDLE_WITH_HIP) && CUDA_VERSION >= 9000
  const int current_device = GetCurrentDeviceId();
  const int compute_capability = GetCUDAComputeCapability(current_device);
  return compute_capability >= 70;
#else
  return false;
#endif
}

C
chengduoZH 已提交
218
int GetCUDAMultiProcessors(int id) {
219 220 221 222 223
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduoZH 已提交
224
  int count;
225 226 227 228
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(
      hipDeviceGetAttribute(&count, hipDeviceAttributeMultiprocessorCount, id));
#else
229 230
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id));
231
#endif
C
chengduoZH 已提交
232 233 234 235
  return count;
}

int GetCUDAMaxThreadsPerMultiProcessor(int id) {
236 237 238 239 240
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduoZH 已提交
241
  int count;
242 243 244 245
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipDeviceGetAttribute(
      &count, hipDeviceAttributeMaxThreadsPerMultiProcessor, id));
#else
246 247
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaDeviceGetAttribute(
      &count, cudaDevAttrMaxThreadsPerMultiProcessor, id));
248
#endif
C
chengduoZH 已提交
249 250 251
  return count;
}

252
int GetCUDAMaxThreadsPerBlock(int id) {
253 254 255 256 257
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
258
  int count;
259 260 261 262
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(
      hipDeviceGetAttribute(&count, hipDeviceAttributeMaxThreadsPerBlock, id));
#else
263 264
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaDeviceGetAttribute(&count, cudaDevAttrMaxThreadsPerBlock, id));
265
#endif
266 267 268
  return count;
}

L
liaogang 已提交
269 270
int GetCurrentDeviceId() {
  int device_id;
271 272 273
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipGetDevice(&device_id));
#else
274
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetDevice(&device_id));
275
#endif
L
liaogang 已提交
276 277 278
  return device_id;
}

279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296
//! Get the list of device ids from FLAGS_selected_gpus, or use all devices.
//!
//! When FLAGS_selected_gpus is non-empty it is split on ',' and each piece is
//! converted with atoi(), so a piece that is not a number becomes device 0.
//! When the flag is empty, ids 0 .. GetCUDADeviceCount() - 1 are returned.
std::vector<int> GetSelectedDevices() {
  // use user specified GPUs in single-node multi-process mode.
  std::vector<int> devices;
  if (!FLAGS_selected_gpus.empty()) {
    const auto id_strs = paddle::string::Split(FLAGS_selected_gpus, ',');
    devices.reserve(id_strs.size());
    // Iterate by const reference: the pieces are only read, never copied.
    for (const auto &id_str : id_strs) {
      devices.push_back(atoi(id_str.c_str()));
    }
  } else {
    const int count = GetCUDADeviceCount();
    for (int i = 0; i < count; ++i) {
      devices.push_back(i);
    }
  }
  return devices;
}

L
liaogang 已提交
297
void SetDeviceId(int id) {
Q
qijun 已提交
298
  // TODO(qijun): find a better way to cache the cuda device count
299 300 301 302 303
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
304 305 306
#ifdef PADDLE_WITH_HIP
  PADDLE_RETRY_CUDA_SUCCESS(hipSetDevice(id));
#else
L
Leo Chen 已提交
307
  PADDLE_RETRY_CUDA_SUCCESS(cudaSetDevice(id));
308
#endif
L
liaogang 已提交
309 310
}

311
// Writes the available and total memory of the current device to
// `*available` and `*total`, as reported by RecordedCudaMemGetInfo (so any
// configured per-device limit is already applied).
void GpuMemoryUsage(size_t *available, size_t *total) {
  // Zero-initialised on purpose: when the underlying query fails with
  // out-of-memory, GetMemInfo only resets the "available" scratch value and
  // then still reads the "total" one, which must not be indeterminate.
  size_t actual_available = 0;
  size_t actual_total = 0;
  RecordedCudaMemGetInfo(available, total, &actual_available, &actual_total,
                         platform::GetCurrentDeviceId());
}

317
size_t GpuAvailableMemToAlloc() {
L
liaogang 已提交
318 319
  size_t total = 0;
  size_t available = 0;
320
  GpuMemoryUsage(&available, &total);
321 322
  size_t reserving =
      static_cast<size_t>(fraction_reserve_gpu_memory * available);
323
  // If available size is less than minimum chunk size, no usable memory exists
324
  size_t available_to_alloc = available - reserving;
325
  size_t min_chunk_size = GpuMinChunkSize();
326 327 328
  if (available_to_alloc < min_chunk_size) {
    available_to_alloc = 0;
  }
329 330 331
  VLOG(10) << "GPU usage " << (available >> 20) << "M/" << (total >> 20)
           << "M, " << (available_to_alloc >> 20) << "M available to allocate";
  return available_to_alloc;
Z
zhhsplendid 已提交
332 333
}

334 335 336
// Returns the larger of GpuInitAllocSize() and GpuReallocSize().
size_t GpuMaxAllocSize() {
  const size_t init_bytes = GpuInitAllocSize();
  const size_t realloc_bytes = GpuReallocSize();
  return std::max(init_bytes, realloc_bytes);
}
Z
zhhsplendid 已提交
337

338 339
static size_t GpuAllocSize(bool realloc) {
  size_t available_to_alloc = GpuAvailableMemToAlloc();
G
GaoWei8 已提交
340 341 342
  PADDLE_ENFORCE_GT(
      available_to_alloc, 0,
      platform::errors::ResourceExhausted("Not enough available GPU memory."));
343 344 345 346 347 348 349
  // If FLAGS_initial_gpu_memory_in_mb is 0, then initial memory will be
  // allocated by fraction
  size_t flag_mb = realloc ? FLAGS_reallocate_gpu_memory_in_mb
                           : FLAGS_initial_gpu_memory_in_mb;
  size_t alloc_bytes =
      (flag_mb > 0ul ? flag_mb << 20 : available_to_alloc *
                                           FLAGS_fraction_of_gpu_memory_to_use);
G
GaoWei8 已提交
350 351 352
  PADDLE_ENFORCE_GE(
      available_to_alloc, alloc_bytes,
      platform::errors::ResourceExhausted("Not enough available GPU memory."));
353 354 355 356
  VLOG(10) << "Alloc size is " << (alloc_bytes >> 20)
           << " MiB, is it Re-alloc: " << realloc;
  return alloc_bytes;
}
Z
zhhsplendid 已提交
357

358
// Size in bytes of the initial allocation (GpuAllocSize with realloc off).
size_t GpuInitAllocSize() {
  constexpr bool kIsRealloc = false;
  return GpuAllocSize(kIsRealloc);
}
Z
zhhsplendid 已提交
359

360
// Size in bytes of a follow-up allocation (GpuAllocSize with realloc on).
size_t GpuReallocSize() {
  constexpr bool kIsRealloc = true;
  return GpuAllocSize(kIsRealloc);
}
L
liaogang 已提交
361

L
liaogang 已提交
362 363 364 365 366 367
// Returns the smallest chunk size that may be allocated: 256 bytes.
size_t GpuMinChunkSize() {
  constexpr size_t kMinChunkBytes = 256;
  return kMinChunkBytes;
}

size_t GpuMaxChunkSize() {
368 369 370
  size_t max_chunk_size = GpuMaxAllocSize();
  VLOG(10) << "Max chunk size " << (max_chunk_size >> 20) << "M";
  return max_chunk_size;
L
liaogang 已提交
371 372
}

373 374 375 376 377 378
#ifdef PADDLE_WITH_HIP
void GpuMemcpyAsync(void *dst, const void *src, size_t count,
                    enum hipMemcpyKind kind, hipStream_t stream) {
  PADDLE_ENFORCE_CUDA_SUCCESS(hipMemcpyAsync(dst, src, count, kind, stream));
}
#else
L
liaogang 已提交
379 380
void GpuMemcpyAsync(void *dst, const void *src, size_t count,
                    enum cudaMemcpyKind kind, cudaStream_t stream) {
381
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync(dst, src, count, kind, stream));
L
liaogang 已提交
382
}
383
#endif
L
liaogang 已提交
384

385 386 387 388 389 390
#ifdef PADDLE_WITH_HIP
void GpuMemcpySync(void *dst, const void *src, size_t count,
                   enum hipMemcpyKind kind) {
  PADDLE_ENFORCE_CUDA_SUCCESS(hipMemcpy(dst, src, count, kind));
}
#else
391 392
void GpuMemcpySync(void *dst, const void *src, size_t count,
                   enum cudaMemcpyKind kind) {
393
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpy(dst, src, count, kind));
394
}
395
#endif
396 397

void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,
398 399 400 401 402
                        int src_device, size_t count, gpuStream_t stream) {
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(
      hipMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream));
#else
403 404
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream));
405
#endif
406 407 408 409
}

void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src,
                       int src_device, size_t count) {
410 411 412 413
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(
      hipMemcpyPeer(dst, dst_device, src, src_device, count));
#else
414 415
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaMemcpyPeer(dst, dst_device, src, src_device, count));
416
#endif
L
liaogang 已提交
417
}
D
dzhwinter 已提交
418

419 420 421 422
void GpuMemsetAsync(void *dst, int value, size_t count, gpuStream_t stream) {
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipMemsetAsync(dst, value, count, stream));
#else
423
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync(dst, value, count, stream));
424
#endif
D
dzhwinter 已提交
425
}
426

427 428 429 430
// Calls cudaStreamSynchronize / hipStreamSynchronize on `stream`; a
// non-success status is raised via PADDLE_ENFORCE_CUDA_SUCCESS.
void GpuStreamSync(gpuStream_t stream) {
#ifdef PADDLE_WITH_HIP
  auto sync_status = hipStreamSynchronize(stream);
#else
  auto sync_status = cudaStreamSynchronize(stream);
#endif
  PADDLE_ENFORCE_CUDA_SUCCESS(sync_status);
}

435 436 437 438 439 440
// Raises (via PADDLE_ENFORCE_CUDA_SUCCESS) any error other than
// out-of-memory, first for `*status` and then for the runtime's last-error
// value. An out-of-memory code is rewritten to success in both steps, so on
// return `*status` holds success and the last-error value has been read.
static void RaiseNonOutOfMemoryError(gpuError_t *status) {
#ifdef PADDLE_WITH_HIP
  const gpuError_t kOutOfMemory = hipErrorOutOfMemory;
#else
  const gpuError_t kOutOfMemory = cudaErrorMemoryAllocation;
#endif

  // Step 1: the status handed in by the caller.
  if (*status == kOutOfMemory) {
    *status = gpuSuccess;
  }
  PADDLE_ENFORCE_CUDA_SUCCESS(*status);

  // Step 2: whatever the runtime still records as its last error.
#ifdef PADDLE_WITH_HIP
  *status = hipGetLastError();
#else
  *status = cudaGetLastError();
#endif
  if (*status == kOutOfMemory) {
    *status = gpuSuccess;
  }
  PADDLE_ENFORCE_CUDA_SUCCESS(*status);
}
石晓伟 已提交
460

461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485
class RecordedCudaMallocHelper {
 private:
  explicit RecordedCudaMallocHelper(int dev_id, uint64_t limit_size = 0)
      : dev_id_(dev_id), limit_size_(limit_size) {
    if (NeedRecord()) {
      mtx_.reset(new std::mutex());
    }
  }

  DISABLE_COPY_AND_ASSIGN(RecordedCudaMallocHelper);

 public:
  static RecordedCudaMallocHelper *Instance(int dev_id) {
    std::call_once(once_flag_, [] {
      int dev_cnt = GetCUDADeviceCount();
      instances_.reserve(dev_cnt);
      for (int i = 0; i < dev_cnt; ++i) {
        instances_.emplace_back(
            new RecordedCudaMallocHelper(i, FLAGS_gpu_memory_limit_mb << 20));
      }
    });

    PADDLE_ENFORCE_GE(
        dev_id, 0,
        platform::errors::OutOfRange(
G
GaoWei8 已提交
486
            "Device id must be not less than 0, but got %d.", dev_id));
487 488
    PADDLE_ENFORCE_LT(
        dev_id, instances_.size(),
G
GaoWei8 已提交
489
        platform::errors::OutOfRange("Device id %d exceeds gpu card number %d.",
490 491 492 493 494 495 496 497 498
                                     dev_id, instances_.size()));
    return instances_[dev_id].get();
  }

  /**
   * Try to allocate `size` gpu memory. Only cudaErrorMemoryAllocation
   * or cudaSuccess would be returned, and the cudaGetLastError() flag
   * would be clear.
   */
499
  gpuError_t Malloc(void **ptr, size_t size) {
500 501
    LockGuardPtr<std::mutex> lock(mtx_);
    if (UNLIKELY(NeedRecord() && cur_size_ + size > limit_size_)) {
502 503 504
#ifdef PADDLE_WITH_HIP
      return hipErrorOutOfMemory;
#else
505
      return cudaErrorMemoryAllocation;
506
#endif
507 508 509
    }

    CUDADeviceGuard guard(dev_id_);
510 511 512
#ifdef PADDLE_WITH_HIP
    auto result = hipMalloc(ptr, size);
#else
513
    auto result = cudaMalloc(ptr, size);
514 515
#endif
    if (result == gpuSuccess) {
516 517 518
      if (NeedRecord()) {
        cur_size_ += size;
      }
H
hutuxian 已提交
519
      STAT_INT_ADD("STAT_gpu" + std::to_string(dev_id_) + "_mem_size", size);
520
      return gpuSuccess;
521 522
    } else {
      RaiseNonOutOfMemoryError(&result);
523 524 525 526 527 528
// Non out of memory error would be raised inside
// RaiseNonOutOfMemoryError. Therefore, we can
// return cudaErrorMemoryAllocation directly here.
#ifdef PADDLE_WITH_HIP
      return hipErrorOutOfMemory;
#else
529
      return cudaErrorMemoryAllocation;
530
#endif
531 532 533 534 535 536 537 538 539 540 541 542 543 544
    }
  }

  /**
   * Free gpu memory. Usually, free is not allowed to raise error.
   * If it does raise error, the process should be crashed.
   */
  void Free(void *ptr, size_t size) {
    // Purposefully allow cudaErrorCudartUnloading, because
    // that is returned if you ever call cudaFree after the
    // driver has already shutdown. This happens only if the
    // process is terminating, in which case we don't care if
    // cudaFree succeeds.
    CUDADeviceGuard guard(dev_id_);
545 546 547 548
#ifdef PADDLE_WITH_HIP
    auto err = hipFree(ptr);
    if (err != hipErrorDeinitialized) {
#else
549 550
    auto err = cudaFree(ptr);
    if (err != cudaErrorCudartUnloading) {
551
#endif
552
      PADDLE_ENFORCE_CUDA_SUCCESS(err);
553 554 555 556
      if (NeedRecord()) {
        std::lock_guard<std::mutex> guard(*mtx_);
        cur_size_ -= size;
      }
H
hutuxian 已提交
557
      STAT_INT_SUB("STAT_gpu" + std::to_string(dev_id_) + "_mem_size", size);
558
    } else {
559 560 561
#ifdef PADDLE_WITH_HIP
      hipGetLastError();  // clear the error flag when hipErrorDeinitialized
#else
562
      cudaGetLastError();  // clear the error flag when cudaErrorCudartUnloading
563
#endif
564 565 566 567 568 569 570
    }
  }

  bool GetMemInfo(size_t *avail, size_t *total, size_t *actual_avail,
                  size_t *actual_total) {
    {
      CUDADeviceGuard guard(dev_id_);
571 572 573
#ifdef PADDLE_WITH_HIP
      auto result = hipMemGetInfo(actual_avail, actual_total);
#else
574
      auto result = cudaMemGetInfo(actual_avail, actual_total);
575 576
#endif
      if (result != gpuSuccess) {
577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611
        *actual_avail = 0;
      }
      RaiseNonOutOfMemoryError(&result);
    }

    if (NeedRecord()) {
      std::lock_guard<std::mutex> guard(*mtx_);
      *avail = std::min(*actual_avail, limit_size_ - cur_size_);
      *total = std::min(*actual_total, limit_size_);
      return *total < *actual_total;
    } else {
      *avail = *actual_avail;
      *total = *actual_total;
      return false;
    }
  }

  inline bool NeedRecord() const { return limit_size_ != 0; }

  uint64_t RecordedSize() const {
    LockGuardPtr<std::mutex> lock(mtx_);
    return NeedRecord() ? cur_size_ : 0;
  }

  uint64_t LimitSize() const { return limit_size_; }

 private:
  const int dev_id_;
  const uint64_t limit_size_;
  uint64_t cur_size_{0};

  mutable std::unique_ptr<std::mutex> mtx_;

  static std::once_flag once_flag_;
  static std::vector<std::unique_ptr<RecordedCudaMallocHelper>> instances_;
612
};  // NOLINT
613 614 615 616 617

// Out-of-class definitions of RecordedCudaMallocHelper's static members:
// the once-flag guarding creation of the helpers, and the helpers themselves
// (indexed by device id in Instance()).
std::once_flag RecordedCudaMallocHelper::once_flag_;
std::vector<std::unique_ptr<RecordedCudaMallocHelper>>
    RecordedCudaMallocHelper::instances_;

618
// Allocates `size` bytes on device `dev_id` through that device's
// RecordedCudaMallocHelper and returns the helper's status.
gpuError_t RecordedCudaMalloc(void **ptr, size_t size, int dev_id) {
  auto *helper = RecordedCudaMallocHelper::Instance(dev_id);
  return helper->Malloc(ptr, size);
}

// Frees `p` (of `size` bytes) on device `dev_id` through that device's
// RecordedCudaMallocHelper.
void RecordedCudaFree(void *p, size_t size, int dev_id) {
  auto *helper = RecordedCudaMallocHelper::Instance(dev_id);
  helper->Free(p, size);
}

// Forwards to RecordedCudaMallocHelper::GetMemInfo for device `dev_id` and
// returns its result.
bool RecordedCudaMemGetInfo(size_t *avail, size_t *total, size_t *actual_avail,
                            size_t *actual_total, int dev_id) {
  auto *helper = RecordedCudaMallocHelper::Instance(dev_id);
  return helper->GetMemInfo(avail, total, actual_avail, actual_total);
}

// Returns the bytes recorded as allocated on device `dev_id`.
uint64_t RecordedCudaMallocSize(int dev_id) {
  auto *helper = RecordedCudaMallocHelper::Instance(dev_id);
  return helper->RecordedSize();
}

// Returns whether allocations on device `dev_id` are being recorded.
bool IsCudaMallocRecorded(int dev_id) {
  auto *helper = RecordedCudaMallocHelper::Instance(dev_id);
  return helper->NeedRecord();
}

L
liaogang 已提交
640 641
}  // namespace platform
}  // namespace paddle