gpu_info.cc 16.1 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
L
liaogang 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/platform/gpu_info.h"
16
#include <algorithm>
S
sneaxiy 已提交
17
#include <cstdlib>
18
#include <memory>
L
liaogang 已提交
19

20
#include "gflags/gflags.h"
21
#include "paddle/fluid/platform/cuda_device_guard.h"
22
#include "paddle/fluid/platform/dynload/cudnn.h"
Y
Yi Wang 已提交
23
#include "paddle/fluid/platform/enforce.h"
24 25
#include "paddle/fluid/platform/lock_guard_ptr.h"
#include "paddle/fluid/platform/macros.h"
H
hutuxian 已提交
26
#include "paddle/fluid/platform/monitor.h"
27
#include "paddle/fluid/string/split.h"
L
liaogang 已提交
28

29 30 31 32 33
// gflags declarations: these flags are read in this file through FLAGS_<name>.
// Fraction applied to the allocatable memory in GpuAllocSize() when the
// matching *_in_mb flag is 0.
DECLARE_double(fraction_of_gpu_memory_to_use);
// Explicit size (shifted left by 20 bits, i.e. MiB) used by GpuInitAllocSize().
DECLARE_uint64(initial_gpu_memory_in_mb);
// Explicit size (shifted left by 20 bits, i.e. MiB) used by GpuReallocSize().
DECLARE_uint64(reallocate_gpu_memory_in_mb);
// Not referenced in the visible part of this file.
DECLARE_bool(enable_cublas_tensor_op_math);
// Comma-separated device ids parsed by GetSelectedDevices().
DECLARE_string(selected_gpus);
34
// Per-device allocation limit handed to RecordedCudaMallocHelper (shifted left
// by 20 bits, i.e. MiB); a value of 0 turns recording/limiting off.
DECLARE_uint64(gpu_memory_limit_mb);
35

Z
zhhsplendid 已提交
36 37
// Fraction of the currently available GPU memory that
// GpuAvailableMemToAlloc() holds back and does not offer for allocation.
constexpr static float fraction_reserve_gpu_memory = 0.05f;

H
hutuxian 已提交
38
// NOTE(review): presumably brings the per-GPU memory stat counters updated by
// STAT_INT_ADD / STAT_INT_SUB below into scope -- confirm against monitor.h.
USE_GPU_MEM_STAT;
L
liaogang 已提交
39 40 41
namespace paddle {
namespace platform {

42 43
// Returns the version reported by dynload::cudnnGetVersion(), or -1 when
// dynload::HasCUDNN() says cuDNN is not available.
int CudnnVersion() {
  if (dynload::HasCUDNN()) {
    return dynload::cudnnGetVersion();
  }
  return -1;
}
S
sneaxiy 已提交
47
static int GetCUDADeviceCountImpl() {
48 49 50 51 52
  int driverVersion = 0;
  cudaError_t status = cudaDriverGetVersion(&driverVersion);

  if (!(status == cudaSuccess && driverVersion != 0)) {
    // No GPU driver
53
    VLOG(2) << "GPU Driver Version can't be detected. No GPU driver!";
54 55 56
    return 0;
  }

S
sneaxiy 已提交
57 58 59 60 61 62
  const auto *cuda_visible_devices = std::getenv("CUDA_VISIBLE_DEVICES");
  if (cuda_visible_devices != nullptr) {
    std::string cuda_visible_devices_str(cuda_visible_devices);
    if (std::all_of(cuda_visible_devices_str.begin(),
                    cuda_visible_devices_str.end(),
                    [](char ch) { return ch == ' '; })) {
S
sneaxiy 已提交
63
      VLOG(2) << "CUDA_VISIBLE_DEVICES is set to be empty. No GPU detected.";
S
sneaxiy 已提交
64 65 66
      return 0;
    }
  }
L
liaogang 已提交
67
  int count;
68
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetDeviceCount(&count));
L
liaogang 已提交
69 70 71
  return count;
}

S
sneaxiy 已提交
72 73 74 75 76
// Returns the CUDA device count. The value is computed by
// GetCUDADeviceCountImpl() on the first call and stored in a function-local
// static, so later calls return the same number.
int GetCUDADeviceCount() {
  static const int cached_device_count = GetCUDADeviceCountImpl();
  return cached_device_count;
}

77 78 79 80
/* A simple CUDA "pro tip": cudaDeviceGetAttribute() is a much faster way to
query device properties. For details, see
https://devblogs.nvidia.com/cuda-pro-tip-the-fast-way-to-query-device-properties/
*/
81
int GetCUDAComputeCapability(int id) {
82 83 84 85 86
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
87 88 89 90 91 92
  int major, minor;

  auto major_error_code =
      cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, id);
  auto minor_error_code =
      cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, id);
93 94
  PADDLE_ENFORCE_CUDA_SUCCESS(major_error_code);
  PADDLE_ENFORCE_CUDA_SUCCESS(minor_error_code);
95
  return major * 10 + minor;
96 97
}

98
// Returns the maximum grid dimensions (x, y, z) of device `id`, read from the
// cudaDevAttrMaxGridDim{X,Y,Z} attributes.
// Raises InvalidArgument when `id` is not less than the device count, and a
// CUDA error as soon as one of the attribute queries fails.
dim3 GetGpuMaxGridDimSize(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
  dim3 grid_limit;

  int max_x;
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaDeviceGetAttribute(&max_x, cudaDevAttrMaxGridDimX, id));
  grid_limit.x = max_x;

  int max_y;
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaDeviceGetAttribute(&max_y, cudaDevAttrMaxGridDimY, id));
  grid_limit.y = max_y;

  int max_z;
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaDeviceGetAttribute(&max_z, cudaDevAttrMaxGridDimZ, id));
  grid_limit.z = max_z;

  return grid_limit;
}

C
chengduo 已提交
120
int GetCUDARuntimeVersion(int id) {
121 122 123 124 125
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduo 已提交
126
  int runtime_version = 0;
127
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaRuntimeGetVersion(&runtime_version));
C
chengduo 已提交
128 129 130 131
  return runtime_version;
}

int GetCUDADriverVersion(int id) {
132 133 134 135 136
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduo 已提交
137
  int driver_version = 0;
138
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaDriverGetVersion(&driver_version));
C
chengduo 已提交
139 140 141
  return driver_version;
}

142 143 144 145 146 147 148 149 150 151
// Returns true when built with CUDA_VERSION >= 9000 and the current device's
// compute capability (major * 10 + minor) is at least 70; false otherwise.
bool TensorCoreAvailable() {
#if CUDA_VERSION >= 9000
  const int current_device = GetCurrentDeviceId();
  const int compute_capability = GetCUDAComputeCapability(current_device);
  return compute_capability >= 70;
#else
  return false;
#endif
}

C
chengduoZH 已提交
152
int GetCUDAMultiProcessors(int id) {
153 154 155 156 157
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduoZH 已提交
158
  int count;
159 160
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id));
C
chengduoZH 已提交
161 162 163 164
  return count;
}

int GetCUDAMaxThreadsPerMultiProcessor(int id) {
165 166 167 168 169
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
C
chengduoZH 已提交
170
  int count;
171 172
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaDeviceGetAttribute(
      &count, cudaDevAttrMaxThreadsPerMultiProcessor, id));
C
chengduoZH 已提交
173 174 175
  return count;
}

176
int GetCUDAMaxThreadsPerBlock(int id) {
177 178 179 180 181
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
182
  int count;
183 184
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaDeviceGetAttribute(&count, cudaDevAttrMaxThreadsPerBlock, id));
185 186 187
  return count;
}

L
liaogang 已提交
188 189
int GetCurrentDeviceId() {
  int device_id;
190
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetDevice(&device_id));
L
liaogang 已提交
191 192 193
  return device_id;
}

194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211
//! Get the list of device ids to use.
//!
//! When FLAGS_selected_gpus is non-empty it is split on ',' and every token is
//! converted with atoi() (a token that is not a number therefore becomes 0).
//! Otherwise the ids 0 .. GetCUDADeviceCount() - 1 are returned.
std::vector<int> GetSelectedDevices() {
  // use user specified GPUs in single-node multi-process mode.
  std::vector<int> devices;
  if (!FLAGS_selected_gpus.empty()) {
    const auto devices_str = paddle::string::Split(FLAGS_selected_gpus, ',');
    devices.reserve(devices_str.size());
    // Iterate by const reference so the tokens are not copied one by one.
    for (const auto &id : devices_str) {
      devices.push_back(atoi(id.c_str()));
    }
  } else {
    const int count = GetCUDADeviceCount();
    if (count > 0) {
      devices.reserve(count);
    }
    for (int i = 0; i < count; ++i) {
      devices.push_back(i);
    }
  }
  return devices;
}

L
liaogang 已提交
212
void SetDeviceId(int id) {
Q
qijun 已提交
213
  // TODO(qijun): find a better way to cache the cuda device count
214 215 216 217 218 219
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
                        id, GetCUDADeviceCount()));
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(id));
L
liaogang 已提交
220 221
}

222
// Writes the available and total memory of the current device to `*available`
// and `*total`, as reported by RecordedCudaMemGetInfo(). The "actual" values
// that call also produces are discarded here.
void GpuMemoryUsage(size_t *available, size_t *total) {
  size_t unused_actual_available;
  size_t unused_actual_total;
  const int current_device = platform::GetCurrentDeviceId();
  RecordedCudaMemGetInfo(available, total, &unused_actual_available,
                         &unused_actual_total, current_device);
}

228
size_t GpuAvailableMemToAlloc() {
L
liaogang 已提交
229 230
  size_t total = 0;
  size_t available = 0;
231
  GpuMemoryUsage(&available, &total);
232 233
  size_t reserving =
      static_cast<size_t>(fraction_reserve_gpu_memory * available);
234
  // If available size is less than minimum chunk size, no usable memory exists
235
  size_t available_to_alloc = available - reserving;
236
  size_t min_chunk_size = GpuMinChunkSize();
237 238 239
  if (available_to_alloc < min_chunk_size) {
    available_to_alloc = 0;
  }
240 241 242
  VLOG(10) << "GPU usage " << (available >> 20) << "M/" << (total >> 20)
           << "M, " << (available_to_alloc >> 20) << "M available to allocate";
  return available_to_alloc;
Z
zhhsplendid 已提交
243 244
}

245 246 247
// Returns the larger of GpuInitAllocSize() and GpuReallocSize().
size_t GpuMaxAllocSize() {
  const size_t init_bytes = GpuInitAllocSize();
  const size_t realloc_bytes = GpuReallocSize();
  return std::max(init_bytes, realloc_bytes);
}
Z
zhhsplendid 已提交
248

249 250
static size_t GpuAllocSize(bool realloc) {
  size_t available_to_alloc = GpuAvailableMemToAlloc();
G
GaoWei8 已提交
251 252 253
  PADDLE_ENFORCE_GT(
      available_to_alloc, 0,
      platform::errors::ResourceExhausted("Not enough available GPU memory."));
254 255 256 257 258 259 260
  // If FLAGS_initial_gpu_memory_in_mb is 0, then initial memory will be
  // allocated by fraction
  size_t flag_mb = realloc ? FLAGS_reallocate_gpu_memory_in_mb
                           : FLAGS_initial_gpu_memory_in_mb;
  size_t alloc_bytes =
      (flag_mb > 0ul ? flag_mb << 20 : available_to_alloc *
                                           FLAGS_fraction_of_gpu_memory_to_use);
G
GaoWei8 已提交
261 262 263
  PADDLE_ENFORCE_GE(
      available_to_alloc, alloc_bytes,
      platform::errors::ResourceExhausted("Not enough available GPU memory."));
264 265 266 267
  VLOG(10) << "Alloc size is " << (alloc_bytes >> 20)
           << " MiB, is it Re-alloc: " << realloc;
  return alloc_bytes;
}
Z
zhhsplendid 已提交
268

269
// Returns the size in bytes of the initial GPU memory request
// (GpuAllocSize with realloc == false).
size_t GpuInitAllocSize() {
  constexpr bool kIsRealloc = false;
  return GpuAllocSize(kIsRealloc);
}
Z
zhhsplendid 已提交
270

271
// Returns the size in bytes of a GPU memory re-allocation request
// (GpuAllocSize with realloc == true).
size_t GpuReallocSize() {
  constexpr bool kIsRealloc = true;
  return GpuAllocSize(kIsRealloc);
}
L
liaogang 已提交
272

L
liaogang 已提交
273 274 275 276 277 278
// Returns the minimum chunk size that may be allocated: 256 bytes.
size_t GpuMinChunkSize() {
  constexpr size_t kMinChunkBytes = 256;
  return kMinChunkBytes;
}

size_t GpuMaxChunkSize() {
279 280 281
  size_t max_chunk_size = GpuMaxAllocSize();
  VLOG(10) << "Max chunk size " << (max_chunk_size >> 20) << "M";
  return max_chunk_size;
L
liaogang 已提交
282 283
}

L
liaogang 已提交
284 285
void GpuMemcpyAsync(void *dst, const void *src, size_t count,
                    enum cudaMemcpyKind kind, cudaStream_t stream) {
286
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync(dst, src, count, kind, stream));
L
liaogang 已提交
287 288
}

289 290
void GpuMemcpySync(void *dst, const void *src, size_t count,
                   enum cudaMemcpyKind kind) {
291
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpy(dst, src, count, kind));
292 293 294 295
}

void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,
                        int src_device, size_t count, cudaStream_t stream) {
296 297
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream));
298 299 300 301
}

void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src,
                       int src_device, size_t count) {
302 303
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaMemcpyPeer(dst, dst_device, src, src_device, count));
L
liaogang 已提交
304
}
D
dzhwinter 已提交
305 306

void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream) {
307
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync(dst, value, count, stream));
D
dzhwinter 已提交
308
}
309

石晓伟 已提交
310
// Forwards to cudaStreamSynchronize(stream) and raises when the call does not
// return cudaSuccess.
void GpuStreamSync(cudaStream_t stream) {
  cudaError_t sync_status = cudaStreamSynchronize(stream);
  PADDLE_ENFORCE_CUDA_SUCCESS(sync_status);
}

314
// Raises for any CUDA error other than cudaErrorMemoryAllocation.
//
// First `*status` is examined, then the value returned by cudaGetLastError()
// (which is stored back into `*status`). In both steps an out-of-memory code
// is replaced by cudaSuccess before the check, so `*status` is cudaSuccess
// whenever this function returns normally.
static void RaiseNonOutOfMemoryError(cudaError_t *status) {
  *status = (*status == cudaErrorMemoryAllocation) ? cudaSuccess : *status;
  PADDLE_ENFORCE_CUDA_SUCCESS(*status);

  const cudaError_t last_error = cudaGetLastError();
  *status = (last_error == cudaErrorMemoryAllocation) ? cudaSuccess
                                                      : last_error;
  PADDLE_ENFORCE_CUDA_SUCCESS(*status);
}
石晓伟 已提交
326

327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351
class RecordedCudaMallocHelper {
 private:
  explicit RecordedCudaMallocHelper(int dev_id, uint64_t limit_size = 0)
      : dev_id_(dev_id), limit_size_(limit_size) {
    if (NeedRecord()) {
      mtx_.reset(new std::mutex());
    }
  }

  DISABLE_COPY_AND_ASSIGN(RecordedCudaMallocHelper);

 public:
  static RecordedCudaMallocHelper *Instance(int dev_id) {
    std::call_once(once_flag_, [] {
      int dev_cnt = GetCUDADeviceCount();
      instances_.reserve(dev_cnt);
      for (int i = 0; i < dev_cnt; ++i) {
        instances_.emplace_back(
            new RecordedCudaMallocHelper(i, FLAGS_gpu_memory_limit_mb << 20));
      }
    });

    PADDLE_ENFORCE_GE(
        dev_id, 0,
        platform::errors::OutOfRange(
G
GaoWei8 已提交
352
            "Device id must be not less than 0, but got %d.", dev_id));
353 354
    PADDLE_ENFORCE_LT(
        dev_id, instances_.size(),
G
GaoWei8 已提交
355
        platform::errors::OutOfRange("Device id %d exceeds gpu card number %d.",
356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376
                                     dev_id, instances_.size()));
    return instances_[dev_id].get();
  }

  /**
   * Try to allocate `size` gpu memory. Only cudaErrorMemoryAllocation
   * or cudaSuccess would be returned, and the cudaGetLastError() flag
   * would be clear.
   */
  cudaError_t Malloc(void **ptr, size_t size) {
    LockGuardPtr<std::mutex> lock(mtx_);
    if (UNLIKELY(NeedRecord() && cur_size_ + size > limit_size_)) {
      return cudaErrorMemoryAllocation;
    }

    CUDADeviceGuard guard(dev_id_);
    auto result = cudaMalloc(ptr, size);
    if (result == cudaSuccess) {
      if (NeedRecord()) {
        cur_size_ += size;
      }
H
hutuxian 已提交
377
      STAT_INT_ADD("STAT_gpu" + std::to_string(dev_id_) + "_mem_size", size);
378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400
      return cudaSuccess;
    } else {
      RaiseNonOutOfMemoryError(&result);
      // Non out of memory error would be raised inside
      // RaiseNonOutOfMemoryError. Therefore, we can
      // return cudaErrorMemoryAllocation directly here.
      return cudaErrorMemoryAllocation;
    }
  }

  /**
   * Free gpu memory. Usually, free is not allowed to raise error.
   * If it does raise error, the process should be crashed.
   */
  void Free(void *ptr, size_t size) {
    // Purposefully allow cudaErrorCudartUnloading, because
    // that is returned if you ever call cudaFree after the
    // driver has already shutdown. This happens only if the
    // process is terminating, in which case we don't care if
    // cudaFree succeeds.
    CUDADeviceGuard guard(dev_id_);
    auto err = cudaFree(ptr);
    if (err != cudaErrorCudartUnloading) {
401
      PADDLE_ENFORCE_CUDA_SUCCESS(err);
402 403 404 405
      if (NeedRecord()) {
        std::lock_guard<std::mutex> guard(*mtx_);
        cur_size_ -= size;
      }
H
hutuxian 已提交
406
      STAT_INT_SUB("STAT_gpu" + std::to_string(dev_id_) + "_mem_size", size);
407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480
    } else {
      cudaGetLastError();  // clear the error flag when cudaErrorCudartUnloading
    }
  }

  bool GetMemInfo(size_t *avail, size_t *total, size_t *actual_avail,
                  size_t *actual_total) {
    {
      CUDADeviceGuard guard(dev_id_);
      auto result = cudaMemGetInfo(actual_avail, actual_total);
      if (result != cudaSuccess) {
        *actual_avail = 0;
      }
      RaiseNonOutOfMemoryError(&result);
    }

    if (NeedRecord()) {
      std::lock_guard<std::mutex> guard(*mtx_);
      *avail = std::min(*actual_avail, limit_size_ - cur_size_);
      *total = std::min(*actual_total, limit_size_);
      return *total < *actual_total;
    } else {
      *avail = *actual_avail;
      *total = *actual_total;
      return false;
    }
  }

  inline bool NeedRecord() const { return limit_size_ != 0; }

  uint64_t RecordedSize() const {
    LockGuardPtr<std::mutex> lock(mtx_);
    return NeedRecord() ? cur_size_ : 0;
  }

  uint64_t LimitSize() const { return limit_size_; }

 private:
  const int dev_id_;
  const uint64_t limit_size_;
  uint64_t cur_size_{0};

  mutable std::unique_ptr<std::mutex> mtx_;

  static std::once_flag once_flag_;
  static std::vector<std::unique_ptr<RecordedCudaMallocHelper>> instances_;
};

std::once_flag RecordedCudaMallocHelper::once_flag_;
std::vector<std::unique_ptr<RecordedCudaMallocHelper>>
    RecordedCudaMallocHelper::instances_;

// Allocates `size` bytes on device `dev_id` through that device's
// RecordedCudaMallocHelper and returns the helper's result.
cudaError_t RecordedCudaMalloc(void **ptr, size_t size, int dev_id) {
  auto *helper = RecordedCudaMallocHelper::Instance(dev_id);
  return helper->Malloc(ptr, size);
}

// Frees `p` (recorded as `size` bytes) on device `dev_id` through that
// device's RecordedCudaMallocHelper.
void RecordedCudaFree(void *p, size_t size, int dev_id) {
  auto *helper = RecordedCudaMallocHelper::Instance(dev_id);
  helper->Free(p, size);
}

// Fills the four memory figures for device `dev_id` via its
// RecordedCudaMallocHelper::GetMemInfo() and returns that call's result.
bool RecordedCudaMemGetInfo(size_t *avail, size_t *total, size_t *actual_avail,
                            size_t *actual_total, int dev_id) {
  auto *helper = RecordedCudaMallocHelper::Instance(dev_id);
  return helper->GetMemInfo(avail, total, actual_avail, actual_total);
}

// Returns the bytes currently recorded as allocated on device `dev_id`
// (0 when recording is disabled for that device).
uint64_t RecordedCudaMallocSize(int dev_id) {
  auto *helper = RecordedCudaMallocHelper::Instance(dev_id);
  return helper->RecordedSize();
}

// Returns whether allocations on device `dev_id` are being recorded, i.e.
// whether its helper was created with a non-zero limit.
bool IsCudaMallocRecorded(int dev_id) {
  auto *helper = RecordedCudaMallocHelper::Instance(dev_id);
  return helper->NeedRecord();
}

L
liaogang 已提交
481 482
}  // namespace platform
}  // namespace paddle