/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/platform/gpu_info.h"
S
sneaxiy 已提交
16
#include <cstdlib>
L
liaogang 已提交
17

18
#include "gflags/gflags.h"
19
#include "paddle/fluid/platform/cuda_device_guard.h"
20 21 22
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/dynload/miopen.h"
#else
23
#include "paddle/fluid/platform/dynload/cudnn.h"
24
#endif
25
#include "paddle/fluid/memory/malloc.h"
Y
Yi Wang 已提交
26
#include "paddle/fluid/platform/enforce.h"
27 28
#include "paddle/fluid/platform/lock_guard_ptr.h"
#include "paddle/fluid/platform/macros.h"
H
hutuxian 已提交
29
#include "paddle/fluid/platform/monitor.h"
30
#include "paddle/fluid/platform/place.h"
31
#include "paddle/fluid/string/split.h"
L
liaogang 已提交
32

33 34 35 36 37
DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
DECLARE_bool(enable_cublas_tensor_op_math);
DECLARE_string(selected_gpus);
38
DECLARE_uint64(gpu_memory_limit_mb);
39

Z
zhhsplendid 已提交
40 41
constexpr static float fraction_reserve_gpu_memory = 0.05f;

H
hutuxian 已提交
42
USE_GPU_MEM_STAT;
L
liaogang 已提交
43 44 45
namespace paddle {
namespace platform {

46 47
// Returns the version of the DNN library used by this build (MIOpen on HIP
// builds, cuDNN otherwise), or -1 when dynload::HasCUDNN() reports false.
int CudnnVersion() {
  if (dynload::HasCUDNN()) {
#ifdef PADDLE_WITH_HIP
    // MIOpen reports its version in three parts; fold them into one integer.
    size_t major_part, minor_part, patch_part;
    PADDLE_ENFORCE_CUDA_SUCCESS(
        dynload::miopenGetVersion(&major_part, &minor_part, &patch_part));
    return major_part * 100 + minor_part * 10 + patch_part;
#else
    return dynload::cudnnGetVersion();
#endif
  }
  return -1;
}
S
sneaxiy 已提交
58
static int GetCUDADeviceCountImpl() {
59
  int driverVersion = 0;
60 61 62
#ifdef PADDLE_WITH_HIP
  hipError_t status = hipDriverGetVersion(&driverVersion);
#else
63
  cudaError_t status = cudaDriverGetVersion(&driverVersion);
64
#endif
65

66
  if (!(status == gpuSuccess && driverVersion != 0)) {
67
    // No GPU driver
68
    VLOG(2) << "GPU Driver Version can't be detected. No GPU driver!";
69 70 71
    return 0;
  }

72 73 74
#ifdef PADDLE_WITH_HIP
  const auto *cuda_visible_devices = std::getenv("HIP_VISIBLE_DEVICES");
#else
S
sneaxiy 已提交
75
  const auto *cuda_visible_devices = std::getenv("CUDA_VISIBLE_DEVICES");
76
#endif
S
sneaxiy 已提交
77 78
  if (cuda_visible_devices != nullptr) {
    std::string cuda_visible_devices_str(cuda_visible_devices);
79 80 81 82 83 84 85 86 87 88
    if (!cuda_visible_devices_str.empty()) {
      cuda_visible_devices_str.erase(
          0, cuda_visible_devices_str.find_first_not_of('\''));
      cuda_visible_devices_str.erase(
          cuda_visible_devices_str.find_last_not_of('\'') + 1);
      cuda_visible_devices_str.erase(
          0, cuda_visible_devices_str.find_first_not_of('\"'));
      cuda_visible_devices_str.erase(
          cuda_visible_devices_str.find_last_not_of('\"') + 1);
    }
S
sneaxiy 已提交
89 90 91
    if (std::all_of(cuda_visible_devices_str.begin(),
                    cuda_visible_devices_str.end(),
                    [](char ch) { return ch == ' '; })) {
92 93
      VLOG(2) << "CUDA_VISIBLE_DEVICES or HIP_VISIBLE_DEVICES is set to be "
                 "empty. No GPU detected.";
S
sneaxiy 已提交
94 95 96
      return 0;
    }
  }
L
liaogang 已提交
97
  int count;
98 99 100
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipGetDeviceCount(&count));
#else
101
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetDeviceCount(&count));
102
#endif
L
liaogang 已提交
103 104 105
  return count;
}

S
sneaxiy 已提交
106
int GetCUDADeviceCount() {
107
  // cache the count
S
sneaxiy 已提交
108 109 110 111
  static auto dev_cnt = GetCUDADeviceCountImpl();
  return dev_cnt;
}

112 113 114 115
/* Here is a very simple CUDA “pro tip”: cudaDeviceGetAttribute() is a much
faster way to query device properties. You can see details in
https://devblogs.nvidia.com/cuda-pro-tip-the-fast-way-to-query-device-properties/
*/
116
int GetCUDAComputeCapability(int id) {
117 118 119 120 121
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
122 123
  int major, minor;

124 125 126 127 128 129
#ifdef PADDLE_WITH_HIP
  auto major_error_code = hipDeviceGetAttribute(
      &major, hipDeviceAttributeComputeCapabilityMajor, id);
  auto minor_error_code = hipDeviceGetAttribute(
      &minor, hipDeviceAttributeComputeCapabilityMinor, id);
#else
130 131 132 133
  auto major_error_code =
      cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, id);
  auto minor_error_code =
      cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, id);
134
#endif
135 136
  PADDLE_ENFORCE_CUDA_SUCCESS(major_error_code);
  PADDLE_ENFORCE_CUDA_SUCCESS(minor_error_code);
137 138 139
#ifdef PADDLE_WITH_HIP
  return major * 100 + minor;
#else
140
  return major * 10 + minor;
141
#endif
142 143
}

144
// Returns the maximum grid dimensions (x, y, z) of device `id`.
// Raises InvalidArgument when `id` is outside [0, GetCUDADeviceCount()).
dim3 GetGpuMaxGridDimSize(int id) {
  // Reject negative ids up front so they are reported as a bad argument
  // instead of being forwarded to the runtime attribute query.
  PADDLE_ENFORCE_GE(id, 0,
                    platform::errors::InvalidArgument(
                        "Device id must be not less than 0, "
                        "but received id is: %d.",
                        id));
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
  dim3 ret;
  // Scratch slot reused for each of the three attribute queries.
  int size;
#ifdef PADDLE_WITH_HIP
  auto error_code_x =
      hipDeviceGetAttribute(&size, hipDeviceAttributeMaxGridDimX, id);
#else
  auto error_code_x = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimX, id);
#endif
  PADDLE_ENFORCE_CUDA_SUCCESS(error_code_x);
  ret.x = size;

#ifdef PADDLE_WITH_HIP
  auto error_code_y =
      hipDeviceGetAttribute(&size, hipDeviceAttributeMaxGridDimY, id);
#else
  auto error_code_y = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimY, id);
#endif
  PADDLE_ENFORCE_CUDA_SUCCESS(error_code_y);
  ret.y = size;

#ifdef PADDLE_WITH_HIP
  auto error_code_z =
      hipDeviceGetAttribute(&size, hipDeviceAttributeMaxGridDimZ, id);
#else
  auto error_code_z = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimZ, id);
#endif
  PADDLE_ENFORCE_CUDA_SUCCESS(error_code_z);
  ret.z = size;
  return ret;
}

C
chengduo 已提交
181
int GetCUDARuntimeVersion(int id) {
182 183 184 185 186
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduo 已提交
187
  int runtime_version = 0;
188 189 190
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipRuntimeGetVersion(&runtime_version));
#else
191
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaRuntimeGetVersion(&runtime_version));
192
#endif
C
chengduo 已提交
193 194 195 196
  return runtime_version;
}

int GetCUDADriverVersion(int id) {
197 198 199 200 201
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduo 已提交
202
  int driver_version = 0;
203 204 205
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipDriverGetVersion(&driver_version));
#else
206
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaDriverGetVersion(&driver_version));
207
#endif
C
chengduo 已提交
208 209 210
  return driver_version;
}

211
// Returns true when the current device has compute capability 70 or higher.
// Always false on HIP builds and when CUDA_VERSION is below 9000.
bool TensorCoreAvailable() {
#if !defined(PADDLE_WITH_HIP) && CUDA_VERSION >= 9000
  const int current_device = GetCurrentDeviceId();
  const int compute_capability = GetCUDAComputeCapability(current_device);
  return compute_capability >= 70;
#else
  return false;
#endif
}

C
chengduoZH 已提交
221
int GetCUDAMultiProcessors(int id) {
222 223 224 225 226
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduoZH 已提交
227
  int count;
228 229 230 231
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(
      hipDeviceGetAttribute(&count, hipDeviceAttributeMultiprocessorCount, id));
#else
232 233
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id));
234
#endif
C
chengduoZH 已提交
235 236 237 238
  return count;
}

int GetCUDAMaxThreadsPerMultiProcessor(int id) {
239 240 241 242 243
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduoZH 已提交
244
  int count;
245 246 247 248
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipDeviceGetAttribute(
      &count, hipDeviceAttributeMaxThreadsPerMultiProcessor, id));
#else
249 250
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaDeviceGetAttribute(
      &count, cudaDevAttrMaxThreadsPerMultiProcessor, id));
251
#endif
C
chengduoZH 已提交
252 253 254
  return count;
}

255
int GetCUDAMaxThreadsPerBlock(int id) {
256 257 258 259 260
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
261
  int count;
262 263 264 265
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(
      hipDeviceGetAttribute(&count, hipDeviceAttributeMaxThreadsPerBlock, id));
#else
266 267
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaDeviceGetAttribute(&count, cudaDevAttrMaxThreadsPerBlock, id));
268
#endif
269 270 271
  return count;
}

L
liaogang 已提交
272 273
int GetCurrentDeviceId() {
  int device_id;
274 275 276
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipGetDevice(&device_id));
#else
277
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetDevice(&device_id));
278
#endif
L
liaogang 已提交
279 280 281
  return device_id;
}

282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299
//! Get the list of device ids named in FLAGS_selected_gpus (comma separated),
//! or every visible device id when that flag is empty.
std::vector<int> GetSelectedDevices() {
  // use user specified GPUs in single-node multi-process mode.
  std::vector<int> devices;
  if (!FLAGS_selected_gpus.empty()) {
    const auto devices_str = paddle::string::Split(FLAGS_selected_gpus, ',');
    devices.reserve(devices_str.size());
    // Iterate by reference: each element is a string and need not be copied.
    for (const auto &id_str : devices_str) {
      // atoi yields 0 for text that does not start with a number.
      devices.push_back(atoi(id_str.c_str()));
    }
  } else {
    const int count = GetCUDADeviceCount();
    if (count > 0) {
      devices.reserve(count);
    }
    for (int i = 0; i < count; ++i) {
      devices.push_back(i);
    }
  }
  return devices;
}

L
liaogang 已提交
300
void SetDeviceId(int id) {
Q
qijun 已提交
301
  // TODO(qijun): find a better way to cache the cuda device count
302 303 304 305 306
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
307 308 309
#ifdef PADDLE_WITH_HIP
  PADDLE_RETRY_CUDA_SUCCESS(hipSetDevice(id));
#else
L
Leo Chen 已提交
310
  PADDLE_RETRY_CUDA_SUCCESS(cudaSetDevice(id));
311
#endif
L
liaogang 已提交
312 313
}

314
// Writes the available and total memory of the current device to `available`
// and `total`, as computed by RecordedCudaMemGetInfo. The raw ("actual")
// figures that call also produces are discarded here.
void GpuMemoryUsage(size_t *available, size_t *total) {
  // Zero-initialize the scratch outputs: when the underlying memory query
  // fails with an out-of-memory status the helper keeps going, and not every
  // output is guaranteed to have been written before it is read back.
  size_t actual_available = 0;
  size_t actual_total = 0;
  RecordedCudaMemGetInfo(available, total, &actual_available, &actual_total,
                         platform::GetCurrentDeviceId());
}

320
size_t GpuAvailableMemToAlloc() {
L
liaogang 已提交
321 322
  size_t total = 0;
  size_t available = 0;
323
  GpuMemoryUsage(&available, &total);
324 325
  size_t reserving =
      static_cast<size_t>(fraction_reserve_gpu_memory * available);
326
  // If available size is less than minimum chunk size, no usable memory exists
327
  size_t available_to_alloc = available - reserving;
328
  size_t min_chunk_size = GpuMinChunkSize();
329 330 331
  if (available_to_alloc < min_chunk_size) {
    available_to_alloc = 0;
  }
332 333 334
  VLOG(10) << "GPU usage " << (available >> 20) << "M/" << (total >> 20)
           << "M, " << (available_to_alloc >> 20) << "M available to allocate";
  return available_to_alloc;
Z
zhhsplendid 已提交
335 336
}

337 338 339
// Returns the larger of the initial allocation size and the reallocation
// size.
size_t GpuMaxAllocSize() {
  const size_t init_bytes = GpuInitAllocSize();
  const size_t realloc_bytes = GpuReallocSize();
  return std::max(init_bytes, realloc_bytes);
}
Z
zhhsplendid 已提交
340

341 342
static size_t GpuAllocSize(bool realloc) {
  size_t available_to_alloc = GpuAvailableMemToAlloc();
G
GaoWei8 已提交
343 344 345
  PADDLE_ENFORCE_GT(
      available_to_alloc, 0,
      platform::errors::ResourceExhausted("Not enough available GPU memory."));
346 347 348 349 350 351 352
  // If FLAGS_initial_gpu_memory_in_mb is 0, then initial memory will be
  // allocated by fraction
  size_t flag_mb = realloc ? FLAGS_reallocate_gpu_memory_in_mb
                           : FLAGS_initial_gpu_memory_in_mb;
  size_t alloc_bytes =
      (flag_mb > 0ul ? flag_mb << 20 : available_to_alloc *
                                           FLAGS_fraction_of_gpu_memory_to_use);
G
GaoWei8 已提交
353 354 355
  PADDLE_ENFORCE_GE(
      available_to_alloc, alloc_bytes,
      platform::errors::ResourceExhausted("Not enough available GPU memory."));
356 357 358 359
  VLOG(10) << "Alloc size is " << (alloc_bytes >> 20)
           << " MiB, is it Re-alloc: " << realloc;
  return alloc_bytes;
}
Z
zhhsplendid 已提交
360

361
// Size in bytes of the initial GPU memory chunk (GpuAllocSize in non-realloc
// mode).
size_t GpuInitAllocSize() {
  const bool is_realloc = false;
  return GpuAllocSize(is_realloc);
}
Z
zhhsplendid 已提交
362

363
// Size in bytes of a re-allocated GPU memory chunk (GpuAllocSize in realloc
// mode).
size_t GpuReallocSize() {
  const bool is_realloc = true;
  return GpuAllocSize(is_realloc);
}
L
liaogang 已提交
364

L
liaogang 已提交
365 366 367 368 369 370
// Returns the smallest chunk size the allocator may hand out, in bytes.
size_t GpuMinChunkSize() {
  // Allow to allocate the minimum chunk size is 256 bytes.
  const size_t min_chunk_bytes = 256;
  return min_chunk_bytes;
}

size_t GpuMaxChunkSize() {
371 372 373
  size_t max_chunk_size = GpuMaxAllocSize();
  VLOG(10) << "Max chunk size " << (max_chunk_size >> 20) << "M";
  return max_chunk_size;
L
liaogang 已提交
374 375
}

376 377 378 379 380 381
#ifdef PADDLE_WITH_HIP
void GpuMemcpyAsync(void *dst, const void *src, size_t count,
                    enum hipMemcpyKind kind, hipStream_t stream) {
  PADDLE_ENFORCE_CUDA_SUCCESS(hipMemcpyAsync(dst, src, count, kind, stream));
}
#else
L
liaogang 已提交
382 383
void GpuMemcpyAsync(void *dst, const void *src, size_t count,
                    enum cudaMemcpyKind kind, cudaStream_t stream) {
384
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync(dst, src, count, kind, stream));
L
liaogang 已提交
385
}
386
#endif
L
liaogang 已提交
387

388 389 390 391 392 393
#ifdef PADDLE_WITH_HIP
void GpuMemcpySync(void *dst, const void *src, size_t count,
                   enum hipMemcpyKind kind) {
  PADDLE_ENFORCE_CUDA_SUCCESS(hipMemcpy(dst, src, count, kind));
}
#else
394 395
void GpuMemcpySync(void *dst, const void *src, size_t count,
                   enum cudaMemcpyKind kind) {
396
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpy(dst, src, count, kind));
397
}
398
#endif
399 400

void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,
401 402 403 404 405
                        int src_device, size_t count, gpuStream_t stream) {
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(
      hipMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream));
#else
406 407
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream));
408
#endif
409 410 411 412
}

void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src,
                       int src_device, size_t count) {
413 414 415 416
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(
      hipMemcpyPeer(dst, dst_device, src, src_device, count));
#else
417 418
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaMemcpyPeer(dst, dst_device, src, src_device, count));
419
#endif
L
liaogang 已提交
420
}
D
dzhwinter 已提交
421

422 423 424 425
void GpuMemsetAsync(void *dst, int value, size_t count, gpuStream_t stream) {
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipMemsetAsync(dst, value, count, stream));
#else
426
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync(dst, value, count, stream));
427
#endif
D
dzhwinter 已提交
428
}
429

430 431 432 433
// Calls the runtime's stream-synchronize function on `stream` and raises if
// it reports an error.
void GpuStreamSync(gpuStream_t stream) {
#ifdef PADDLE_WITH_HIP
  const auto sync_status = hipStreamSynchronize(stream);
#else
  const auto sync_status = cudaStreamSynchronize(stream);
#endif
  PADDLE_ENFORCE_CUDA_SUCCESS(sync_status);
}

438 439 440 441 442 443
// Raises for any error in `*status` other than out-of-memory, then does the
// same for the runtime's last-error value. An out-of-memory status is
// rewritten to success in both steps, so on return `*status` holds success.
static void RaiseNonOutOfMemoryError(gpuError_t *status) {
#ifdef PADDLE_WITH_HIP
  const gpuError_t oom_code = hipErrorOutOfMemory;
#else
  const gpuError_t oom_code = cudaErrorMemoryAllocation;
#endif

  // Step 1: the status handed in by the caller.
  if (*status == oom_code) {
    *status = gpuSuccess;
  }
  PADDLE_ENFORCE_CUDA_SUCCESS(*status);

  // Step 2: the runtime's own last-error value, fetched via GetLastError.
#ifdef PADDLE_WITH_HIP
  *status = hipGetLastError();
#else
  *status = cudaGetLastError();
#endif
  if (*status == oom_code) {
    *status = gpuSuccess;
  }
  PADDLE_ENFORCE_CUDA_SUCCESS(*status);
}
石晓伟 已提交
463

464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488
class RecordedCudaMallocHelper {
 private:
  explicit RecordedCudaMallocHelper(int dev_id, uint64_t limit_size = 0)
      : dev_id_(dev_id), limit_size_(limit_size) {
    if (NeedRecord()) {
      mtx_.reset(new std::mutex());
    }
  }

  DISABLE_COPY_AND_ASSIGN(RecordedCudaMallocHelper);

 public:
  static RecordedCudaMallocHelper *Instance(int dev_id) {
    std::call_once(once_flag_, [] {
      int dev_cnt = GetCUDADeviceCount();
      instances_.reserve(dev_cnt);
      for (int i = 0; i < dev_cnt; ++i) {
        instances_.emplace_back(
            new RecordedCudaMallocHelper(i, FLAGS_gpu_memory_limit_mb << 20));
      }
    });

    PADDLE_ENFORCE_GE(
        dev_id, 0,
        platform::errors::OutOfRange(
G
GaoWei8 已提交
489
            "Device id must be not less than 0, but got %d.", dev_id));
490 491
    PADDLE_ENFORCE_LT(
        dev_id, instances_.size(),
G
GaoWei8 已提交
492
        platform::errors::OutOfRange("Device id %d exceeds gpu card number %d.",
493 494 495 496 497 498 499 500 501
                                     dev_id, instances_.size()));
    return instances_[dev_id].get();
  }

  /**
   * Try to allocate `size` gpu memory. Only cudaErrorMemoryAllocation
   * or cudaSuccess would be returned, and the cudaGetLastError() flag
   * would be clear.
   */
502
  gpuError_t Malloc(void **ptr, size_t size) {
503
    LockGuardPtr<std::mutex> lock(mtx_);
504
    if (UNLIKELY(NeedRecord() && cur_size_.load() + size > limit_size_)) {
505 506 507
#ifdef PADDLE_WITH_HIP
      return hipErrorOutOfMemory;
#else
508
      return cudaErrorMemoryAllocation;
509
#endif
510 511 512
    }

    CUDADeviceGuard guard(dev_id_);
513 514 515
#ifdef PADDLE_WITH_HIP
    auto result = hipMalloc(ptr, size);
#else
516
    auto result = cudaMalloc(ptr, size);
517 518
#endif
    if (result == gpuSuccess) {
519
      cur_size_.fetch_add(size);
H
hutuxian 已提交
520
      STAT_INT_ADD("STAT_gpu" + std::to_string(dev_id_) + "_mem_size", size);
521
      return gpuSuccess;
522 523
    } else {
      RaiseNonOutOfMemoryError(&result);
524 525 526 527 528 529
// Non out of memory error would be raised inside
// RaiseNonOutOfMemoryError. Therefore, we can
// return cudaErrorMemoryAllocation directly here.
#ifdef PADDLE_WITH_HIP
      return hipErrorOutOfMemory;
#else
530
      return cudaErrorMemoryAllocation;
531
#endif
532 533 534 535 536 537 538 539 540 541 542 543 544 545
    }
  }

  /**
   * Free gpu memory. Usually, free is not allowed to raise error.
   * If it does raise error, the process should be crashed.
   */
  void Free(void *ptr, size_t size) {
    // Purposefully allow cudaErrorCudartUnloading, because
    // that is returned if you ever call cudaFree after the
    // driver has already shutdown. This happens only if the
    // process is terminating, in which case we don't care if
    // cudaFree succeeds.
    CUDADeviceGuard guard(dev_id_);
546 547 548 549
#ifdef PADDLE_WITH_HIP
    auto err = hipFree(ptr);
    if (err != hipErrorDeinitialized) {
#else
550 551
    auto err = cudaFree(ptr);
    if (err != cudaErrorCudartUnloading) {
552
#endif
553
      PADDLE_ENFORCE_CUDA_SUCCESS(err);
554
      cur_size_.fetch_sub(size);
H
hutuxian 已提交
555
      STAT_INT_SUB("STAT_gpu" + std::to_string(dev_id_) + "_mem_size", size);
556
    } else {
557 558 559
#ifdef PADDLE_WITH_HIP
      hipGetLastError();  // clear the error flag when hipErrorDeinitialized
#else
560
      cudaGetLastError();  // clear the error flag when cudaErrorCudartUnloading
561
#endif
562 563 564 565 566 567 568
    }
  }

  bool GetMemInfo(size_t *avail, size_t *total, size_t *actual_avail,
                  size_t *actual_total) {
    {
      CUDADeviceGuard guard(dev_id_);
569 570 571
#ifdef PADDLE_WITH_HIP
      auto result = hipMemGetInfo(actual_avail, actual_total);
#else
572
      auto result = cudaMemGetInfo(actual_avail, actual_total);
573 574
#endif
      if (result != gpuSuccess) {
575 576 577 578 579 580 581
        *actual_avail = 0;
      }
      RaiseNonOutOfMemoryError(&result);
    }

    if (NeedRecord()) {
      std::lock_guard<std::mutex> guard(*mtx_);
582
      *avail = std::min(*actual_avail, limit_size_ - cur_size_.load());
583 584 585 586 587 588 589 590 591 592 593
      *total = std::min(*actual_total, limit_size_);
      return *total < *actual_total;
    } else {
      *avail = *actual_avail;
      *total = *actual_total;
      return false;
    }
  }

  inline bool NeedRecord() const { return limit_size_ != 0; }

594
  uint64_t RecordedSize() const { return cur_size_.load(); }
595 596 597 598 599 600

  uint64_t LimitSize() const { return limit_size_; }

 private:
  const int dev_id_;
  const uint64_t limit_size_;
601
  std::atomic<uint64_t> cur_size_{0};
602 603 604 605 606

  mutable std::unique_ptr<std::mutex> mtx_;

  static std::once_flag once_flag_;
  static std::vector<std::unique_ptr<RecordedCudaMallocHelper>> instances_;
607
};  // NOLINT
608 609 610 611 612

// Out-of-class definitions of the static members declared in
// RecordedCudaMallocHelper: the flag passed to std::call_once in Instance()
// and the vector of per-device helpers that call fills.
std::once_flag RecordedCudaMallocHelper::once_flag_;
std::vector<std::unique_ptr<RecordedCudaMallocHelper>>
    RecordedCudaMallocHelper::instances_;

613
// Allocates `size` bytes on device `dev_id` through that device's
// RecordedCudaMallocHelper and returns the helper's status.
gpuError_t RecordedCudaMalloc(void **ptr, size_t size, int dev_id) {
  RecordedCudaMallocHelper *helper = RecordedCudaMallocHelper::Instance(dev_id);
  return helper->Malloc(ptr, size);
}

void RecordedCudaFree(void *p, size_t size, int dev_id) {
  return RecordedCudaMallocHelper::Instance(dev_id)->Free(p, size);
}

bool RecordedCudaMemGetInfo(size_t *avail, size_t *total, size_t *actual_avail,
                            size_t *actual_total, int dev_id) {
  return RecordedCudaMallocHelper::Instance(dev_id)->GetMemInfo(
      avail, total, actual_avail, actual_total);
}

uint64_t RecordedCudaMallocSize(int dev_id) {
  return RecordedCudaMallocHelper::Instance(dev_id)->RecordedSize();
}

bool IsCudaMallocRecorded(int dev_id) {
  return RecordedCudaMallocHelper::Instance(dev_id)->NeedRecord();
}

635 636 637 638 639 640 641
// Calls memory::Release on the CUDAPlace of every device returned by
// GetSelectedDevices().
void EmptyCache(void) {
  const std::vector<int> selected = GetSelectedDevices();
  for (const int device_id : selected) {
    memory::Release(CUDAPlace(device_id));
  }
}

L
liaogang 已提交
642 643
}  // namespace platform
}  // namespace paddle