gpu_info.cc 16.5 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
L
liaogang 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/platform/gpu_info.h"
16
#include <algorithm>
S
sneaxiy 已提交
17
#include <cstdlib>
18
#include <memory>
L
liaogang 已提交
19

20
#include "gflags/gflags.h"
21
#include "paddle/fluid/platform/cuda_device_guard.h"
22
#include "paddle/fluid/platform/dynload/cudnn.h"
Y
Yi Wang 已提交
23
#include "paddle/fluid/platform/enforce.h"
24 25
#include "paddle/fluid/platform/lock_guard_ptr.h"
#include "paddle/fluid/platform/macros.h"
H
hutuxian 已提交
26
#include "paddle/fluid/platform/monitor.h"
27
#include "paddle/fluid/string/split.h"
L
liaogang 已提交
28

29 30 31 32 33
DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
DECLARE_bool(enable_cublas_tensor_op_math);
DECLARE_string(selected_gpus);
34
DECLARE_uint64(gpu_memory_limit_mb);
35

Z
zhhsplendid 已提交
36 37
// Fraction of the currently free GPU memory that GpuAvailableMemToAlloc()
// holds back and does not report as allocatable.
constexpr static float fraction_reserve_gpu_memory = 0.05f;

H
hutuxian 已提交
38
USE_GPU_MEM_STAT;
L
liaogang 已提交
39 40 41
namespace paddle {
namespace platform {

42 43
// Returns the cuDNN version reported by cudnnGetVersion(), or -1 when
// dynload::HasCUDNN() reports that cuDNN is not available.
int CudnnVersion() {
  return dynload::HasCUDNN() ? static_cast<int>(dynload::cudnnGetVersion())
                             : -1;
}
S
sneaxiy 已提交
47
static int GetCUDADeviceCountImpl() {
48 49 50 51 52
  int driverVersion = 0;
  cudaError_t status = cudaDriverGetVersion(&driverVersion);

  if (!(status == cudaSuccess && driverVersion != 0)) {
    // No GPU driver
53
    VLOG(2) << "GPU Driver Version can't be detected. No GPU driver!";
54 55 56
    return 0;
  }

S
sneaxiy 已提交
57 58 59
  const auto *cuda_visible_devices = std::getenv("CUDA_VISIBLE_DEVICES");
  if (cuda_visible_devices != nullptr) {
    std::string cuda_visible_devices_str(cuda_visible_devices);
60 61 62 63 64 65 66 67 68 69
    if (!cuda_visible_devices_str.empty()) {
      cuda_visible_devices_str.erase(
          0, cuda_visible_devices_str.find_first_not_of('\''));
      cuda_visible_devices_str.erase(
          cuda_visible_devices_str.find_last_not_of('\'') + 1);
      cuda_visible_devices_str.erase(
          0, cuda_visible_devices_str.find_first_not_of('\"'));
      cuda_visible_devices_str.erase(
          cuda_visible_devices_str.find_last_not_of('\"') + 1);
    }
S
sneaxiy 已提交
70 71 72
    if (std::all_of(cuda_visible_devices_str.begin(),
                    cuda_visible_devices_str.end(),
                    [](char ch) { return ch == ' '; })) {
S
sneaxiy 已提交
73
      VLOG(2) << "CUDA_VISIBLE_DEVICES is set to be empty. No GPU detected.";
S
sneaxiy 已提交
74 75 76
      return 0;
    }
  }
L
liaogang 已提交
77
  int count;
78
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetDeviceCount(&count));
L
liaogang 已提交
79 80 81
  return count;
}

S
sneaxiy 已提交
82 83 84 85 86
// Returns the number of visible CUDA devices. The probe runs only once; later
// calls return the cached value (function-local static initialization is
// thread-safe since C++11).
int GetCUDADeviceCount() {
  static const int cached_device_count = GetCUDADeviceCountImpl();
  return cached_device_count;
}

87 88 89 90
/* CUDA "pro tip": querying individual attributes with cudaDeviceGetAttribute()
is much faster than fetching the whole cudaDeviceProp struct with
cudaGetDeviceProperties(). For details see
https://devblogs.nvidia.com/cuda-pro-tip-the-fast-way-to-query-device-properties/
*/
91
int GetCUDAComputeCapability(int id) {
92 93 94 95 96
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
97 98 99 100 101 102
  int major, minor;

  auto major_error_code =
      cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, id);
  auto minor_error_code =
      cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, id);
103 104
  PADDLE_ENFORCE_CUDA_SUCCESS(major_error_code);
  PADDLE_ENFORCE_CUDA_SUCCESS(minor_error_code);
105
  return major * 10 + minor;
106 107
}

108
// Returns the maximum grid dimensions (x, y, z) supported by device `id`.
// Raises InvalidArgument for an out-of-range id and raises on any CUDA error.
dim3 GetGpuMaxGridDimSize(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
  // Reads a single device attribute of `id`, raising on failure.
  const auto query_limit = [id](cudaDeviceAttr attr) {
    int limit = 0;
    const cudaError_t status = cudaDeviceGetAttribute(&limit, attr, id);
    PADDLE_ENFORCE_CUDA_SUCCESS(status);
    return limit;
  };
  dim3 max_grid;
  max_grid.x = query_limit(cudaDevAttrMaxGridDimX);
  max_grid.y = query_limit(cudaDevAttrMaxGridDimY);
  max_grid.z = query_limit(cudaDevAttrMaxGridDimZ);
  return max_grid;
}

C
chengduo 已提交
130
int GetCUDARuntimeVersion(int id) {
131 132 133 134 135
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduo 已提交
136
  int runtime_version = 0;
137
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaRuntimeGetVersion(&runtime_version));
C
chengduo 已提交
138 139 140 141
  return runtime_version;
}

int GetCUDADriverVersion(int id) {
142 143 144 145 146
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduo 已提交
147
  int driver_version = 0;
148
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaDriverGetVersion(&driver_version));
C
chengduo 已提交
149 150 151
  return driver_version;
}

152 153 154 155 156 157 158 159 160 161
// Reports whether the current device can use Tensor Cores: requires building
// against CUDA 9.0+ and a device of compute capability 7.0 or newer.
bool TensorCoreAvailable() {
#if CUDA_VERSION >= 9000
  const int compute_capability =
      GetCUDAComputeCapability(GetCurrentDeviceId());
  return compute_capability >= 70;
#else
  return false;
#endif
}

C
chengduoZH 已提交
162
int GetCUDAMultiProcessors(int id) {
163 164 165 166 167
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduoZH 已提交
168
  int count;
169 170
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id));
C
chengduoZH 已提交
171 172 173 174
  return count;
}

int GetCUDAMaxThreadsPerMultiProcessor(int id) {
175 176 177 178 179
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduoZH 已提交
180
  int count;
181 182
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaDeviceGetAttribute(
      &count, cudaDevAttrMaxThreadsPerMultiProcessor, id));
C
chengduoZH 已提交
183 184 185
  return count;
}

186
int GetCUDAMaxThreadsPerBlock(int id) {
187 188 189 190 191
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
192
  int count;
193 194
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaDeviceGetAttribute(&count, cudaDevAttrMaxThreadsPerBlock, id));
195 196 197
  return count;
}

L
liaogang 已提交
198 199
int GetCurrentDeviceId() {
  int device_id;
200
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetDevice(&device_id));
L
liaogang 已提交
201 202 203
  return device_id;
}

204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221
//! Returns the device ids named in FLAGS_selected_gpus (comma-separated, used
//! for single-node multi-process mode), or every visible device id
//! 0..GetCUDADeviceCount()-1 when the flag is empty. Entries are parsed with
//! atoi, so unparsable entries become 0.
std::vector<int> GetSelectedDevices() {
  std::vector<int> device_ids;
  if (FLAGS_selected_gpus.empty()) {
    const int total = GetCUDADeviceCount();
    for (int dev = 0; dev < total; ++dev) {
      device_ids.push_back(dev);
    }
    return device_ids;
  }
  for (const auto &token : paddle::string::Split(FLAGS_selected_gpus, ',')) {
    device_ids.push_back(atoi(token.c_str()));
  }
  return device_ids;
}

L
liaogang 已提交
222
void SetDeviceId(int id) {
Q
qijun 已提交
223
  // TODO(qijun): find a better way to cache the cuda device count
224 225 226 227 228 229
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(id));
L
liaogang 已提交
230 231
}

232
// Writes the free and total memory of the current device into `*available`
// and `*total`, clamped to the configured recording limit (if any) by
// RecordedCudaMemGetInfo.
void GpuMemoryUsage(size_t *available, size_t *total) {
  // The unclamped device figures are required by the callee but unused here.
  // Zero-initialize them: if cudaMemGetInfo fails with an out-of-memory error,
  // that error is swallowed and these values may not be written, so `*total`
  // would otherwise be derived from an uninitialized value.
  size_t actual_available = 0;
  size_t actual_total = 0;
  RecordedCudaMemGetInfo(available, total, &actual_available, &actual_total,
                         platform::GetCurrentDeviceId());
}

238
size_t GpuAvailableMemToAlloc() {
L
liaogang 已提交
239 240
  size_t total = 0;
  size_t available = 0;
241
  GpuMemoryUsage(&available, &total);
242 243
  size_t reserving =
      static_cast<size_t>(fraction_reserve_gpu_memory * available);
244
  // If available size is less than minimum chunk size, no usable memory exists
245
  size_t available_to_alloc = available - reserving;
246
  size_t min_chunk_size = GpuMinChunkSize();
247 248 249
  if (available_to_alloc < min_chunk_size) {
    available_to_alloc = 0;
  }
250 251 252
  VLOG(10) << "GPU usage " << (available >> 20) << "M/" << (total >> 20)
           << "M, " << (available_to_alloc >> 20) << "M available to allocate";
  return available_to_alloc;
Z
zhhsplendid 已提交
253 254
}

255 256 257
// Returns the larger of the initial and the re-allocation chunk sizes.
size_t GpuMaxAllocSize() {
  const size_t init_bytes = GpuInitAllocSize();
  const size_t realloc_bytes = GpuReallocSize();
  return init_bytes < realloc_bytes ? realloc_bytes : init_bytes;
}
Z
zhhsplendid 已提交
258

259 260
static size_t GpuAllocSize(bool realloc) {
  size_t available_to_alloc = GpuAvailableMemToAlloc();
G
GaoWei8 已提交
261 262 263
  PADDLE_ENFORCE_GT(
      available_to_alloc, 0,
      platform::errors::ResourceExhausted("Not enough available GPU memory."));
264 265 266 267 268 269 270
  // If FLAGS_initial_gpu_memory_in_mb is 0, then initial memory will be
  // allocated by fraction
  size_t flag_mb = realloc ? FLAGS_reallocate_gpu_memory_in_mb
                           : FLAGS_initial_gpu_memory_in_mb;
  size_t alloc_bytes =
      (flag_mb > 0ul ? flag_mb << 20 : available_to_alloc *
                                           FLAGS_fraction_of_gpu_memory_to_use);
G
GaoWei8 已提交
271 272 273
  PADDLE_ENFORCE_GE(
      available_to_alloc, alloc_bytes,
      platform::errors::ResourceExhausted("Not enough available GPU memory."));
274 275 276 277
  VLOG(10) << "Alloc size is " << (alloc_bytes >> 20)
           << " MiB, is it Re-alloc: " << realloc;
  return alloc_bytes;
}
Z
zhhsplendid 已提交
278

279
// Size in bytes of the first chunk requested from the device.
size_t GpuInitAllocSize() {
  constexpr bool kIsRealloc = false;
  return GpuAllocSize(kIsRealloc);
}
Z
zhhsplendid 已提交
280

281
// Size in bytes of each chunk requested after the first one.
size_t GpuReallocSize() {
  constexpr bool kIsRealloc = true;
  return GpuAllocSize(kIsRealloc);
}
L
liaogang 已提交
282

L
liaogang 已提交
283 284 285 286 287 288
// Smallest chunk the GPU allocator is allowed to hand out: 256 bytes.
size_t GpuMinChunkSize() {
  constexpr size_t kMinChunkBytes = 256;
  return kMinChunkBytes;
}

size_t GpuMaxChunkSize() {
289 290 291
  size_t max_chunk_size = GpuMaxAllocSize();
  VLOG(10) << "Max chunk size " << (max_chunk_size >> 20) << "M";
  return max_chunk_size;
L
liaogang 已提交
292 293
}

L
liaogang 已提交
294 295
void GpuMemcpyAsync(void *dst, const void *src, size_t count,
                    enum cudaMemcpyKind kind, cudaStream_t stream) {
296
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync(dst, src, count, kind, stream));
L
liaogang 已提交
297 298
}

299 300
void GpuMemcpySync(void *dst, const void *src, size_t count,
                   enum cudaMemcpyKind kind) {
301
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpy(dst, src, count, kind));
302 303 304 305
}

void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,
                        int src_device, size_t count, cudaStream_t stream) {
306 307
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream));
308 309 310 311
}

void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src,
                       int src_device, size_t count) {
312 313
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaMemcpyPeer(dst, dst_device, src, src_device, count));
L
liaogang 已提交
314
}
D
dzhwinter 已提交
315 316

void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream) {
317
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync(dst, value, count, stream));
D
dzhwinter 已提交
318
}
319

石晓伟 已提交
320
// Blocks until all work queued on `stream` has finished; raises on error.
void GpuStreamSync(cudaStream_t stream) {
  const cudaError_t status = cudaStreamSynchronize(stream);
  PADDLE_ENFORCE_CUDA_SUCCESS(status);
}

324
// Raises on any CUDA error other than cudaErrorMemoryAllocation.
// First `*status` is checked, then the runtime's last error (which
// cudaGetLastError() also resets). An out-of-memory code is mapped to
// cudaSuccess in both steps, so on normal return `*status` is cudaSuccess.
static void RaiseNonOutOfMemoryError(cudaError_t *status) {
  const auto check_ignoring_oom = [status]() {
    if (*status == cudaErrorMemoryAllocation) {
      *status = cudaSuccess;
    }
    PADDLE_ENFORCE_CUDA_SUCCESS(*status);
  };
  check_ignoring_oom();
  *status = cudaGetLastError();
  check_ignoring_oom();
}
石晓伟 已提交
336

337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361
class RecordedCudaMallocHelper {
 private:
  explicit RecordedCudaMallocHelper(int dev_id, uint64_t limit_size = 0)
      : dev_id_(dev_id), limit_size_(limit_size) {
    if (NeedRecord()) {
      mtx_.reset(new std::mutex());
    }
  }

  DISABLE_COPY_AND_ASSIGN(RecordedCudaMallocHelper);

 public:
  static RecordedCudaMallocHelper *Instance(int dev_id) {
    std::call_once(once_flag_, [] {
      int dev_cnt = GetCUDADeviceCount();
      instances_.reserve(dev_cnt);
      for (int i = 0; i < dev_cnt; ++i) {
        instances_.emplace_back(
            new RecordedCudaMallocHelper(i, FLAGS_gpu_memory_limit_mb << 20));
      }
    });

    PADDLE_ENFORCE_GE(
        dev_id, 0,
        platform::errors::OutOfRange(
G
GaoWei8 已提交
362
            "Device id must be not less than 0, but got %d.", dev_id));
363 364
    PADDLE_ENFORCE_LT(
        dev_id, instances_.size(),
G
GaoWei8 已提交
365
        platform::errors::OutOfRange("Device id %d exceeds gpu card number %d.",
366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386
                                     dev_id, instances_.size()));
    return instances_[dev_id].get();
  }

  /**
   * Try to allocate `size` gpu memory. Only cudaErrorMemoryAllocation
   * or cudaSuccess would be returned, and the cudaGetLastError() flag
   * would be clear.
   */
  cudaError_t Malloc(void **ptr, size_t size) {
    LockGuardPtr<std::mutex> lock(mtx_);
    if (UNLIKELY(NeedRecord() && cur_size_ + size > limit_size_)) {
      return cudaErrorMemoryAllocation;
    }

    CUDADeviceGuard guard(dev_id_);
    auto result = cudaMalloc(ptr, size);
    if (result == cudaSuccess) {
      if (NeedRecord()) {
        cur_size_ += size;
      }
H
hutuxian 已提交
387
      STAT_INT_ADD("STAT_gpu" + std::to_string(dev_id_) + "_mem_size", size);
388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410
      return cudaSuccess;
    } else {
      RaiseNonOutOfMemoryError(&result);
      // Non out of memory error would be raised inside
      // RaiseNonOutOfMemoryError. Therefore, we can
      // return cudaErrorMemoryAllocation directly here.
      return cudaErrorMemoryAllocation;
    }
  }

  /**
   * Free gpu memory. Usually, free is not allowed to raise error.
   * If it does raise error, the process should be crashed.
   */
  void Free(void *ptr, size_t size) {
    // Purposefully allow cudaErrorCudartUnloading, because
    // that is returned if you ever call cudaFree after the
    // driver has already shutdown. This happens only if the
    // process is terminating, in which case we don't care if
    // cudaFree succeeds.
    CUDADeviceGuard guard(dev_id_);
    auto err = cudaFree(ptr);
    if (err != cudaErrorCudartUnloading) {
411
      PADDLE_ENFORCE_CUDA_SUCCESS(err);
412 413 414 415
      if (NeedRecord()) {
        std::lock_guard<std::mutex> guard(*mtx_);
        cur_size_ -= size;
      }
H
hutuxian 已提交
416
      STAT_INT_SUB("STAT_gpu" + std::to_string(dev_id_) + "_mem_size", size);
417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490
    } else {
      cudaGetLastError();  // clear the error flag when cudaErrorCudartUnloading
    }
  }

  bool GetMemInfo(size_t *avail, size_t *total, size_t *actual_avail,
                  size_t *actual_total) {
    {
      CUDADeviceGuard guard(dev_id_);
      auto result = cudaMemGetInfo(actual_avail, actual_total);
      if (result != cudaSuccess) {
        *actual_avail = 0;
      }
      RaiseNonOutOfMemoryError(&result);
    }

    if (NeedRecord()) {
      std::lock_guard<std::mutex> guard(*mtx_);
      *avail = std::min(*actual_avail, limit_size_ - cur_size_);
      *total = std::min(*actual_total, limit_size_);
      return *total < *actual_total;
    } else {
      *avail = *actual_avail;
      *total = *actual_total;
      return false;
    }
  }

  inline bool NeedRecord() const { return limit_size_ != 0; }

  uint64_t RecordedSize() const {
    LockGuardPtr<std::mutex> lock(mtx_);
    return NeedRecord() ? cur_size_ : 0;
  }

  uint64_t LimitSize() const { return limit_size_; }

 private:
  const int dev_id_;
  const uint64_t limit_size_;
  uint64_t cur_size_{0};

  mutable std::unique_ptr<std::mutex> mtx_;

  static std::once_flag once_flag_;
  static std::vector<std::unique_ptr<RecordedCudaMallocHelper>> instances_;
};

std::once_flag RecordedCudaMallocHelper::once_flag_;
std::vector<std::unique_ptr<RecordedCudaMallocHelper>>
    RecordedCudaMallocHelper::instances_;

// Allocates `size` bytes on device `dev_id` through its recording helper.
// Returns cudaSuccess or cudaErrorMemoryAllocation; other errors are raised.
cudaError_t RecordedCudaMalloc(void **ptr, size_t size, int dev_id) {
  RecordedCudaMallocHelper *helper = RecordedCudaMallocHelper::Instance(dev_id);
  return helper->Malloc(ptr, size);
}

void RecordedCudaFree(void *p, size_t size, int dev_id) {
  return RecordedCudaMallocHelper::Instance(dev_id)->Free(p, size);
}

bool RecordedCudaMemGetInfo(size_t *avail, size_t *total, size_t *actual_avail,
                            size_t *actual_total, int dev_id) {
  return RecordedCudaMallocHelper::Instance(dev_id)->GetMemInfo(
      avail, total, actual_avail, actual_total);
}

uint64_t RecordedCudaMallocSize(int dev_id) {
  return RecordedCudaMallocHelper::Instance(dev_id)->RecordedSize();
}

bool IsCudaMallocRecorded(int dev_id) {
  return RecordedCudaMallocHelper::Instance(dev_id)->NeedRecord();
}

L
liaogang 已提交
491 492
}  // namespace platform
}  // namespace paddle