gpu_info.cc 12.7 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/platform/gpu_info.h"
16
#include <algorithm>
S
sneaxiy 已提交
17 18
#include <cstdlib>
#include <string>
L
liaogang 已提交
19

20
#include "gflags/gflags.h"
Y
Yi Wang 已提交
21
#include "paddle/fluid/platform/enforce.h"
22
#include "paddle/fluid/string/split.h"
L
liaogang 已提交
23

24
#ifndef _WIN32
// Default for FLAGS_fraction_of_gpu_memory_to_use on non-Windows platforms.
constexpr static float fraction_of_gpu_memory_to_use = 0.92f;
#else
// fraction_of_gpu_memory_to_use cannot be too high on windows,
// since the win32 graphic sub-system can occupy some GPU memory
// which may lead to insufficient memory left for paddle
constexpr static float fraction_of_gpu_memory_to_use = 0.5f;
#endif

// Share of the total GPU memory that is subtracted before
// FLAGS_fraction_of_gpu_memory_to_use is applied when chunk sizes are derived
// from a fraction; it also caps the usable memory in GpuMaxChunkSize.
constexpr static float fraction_reserve_gpu_memory = 0.05f;

DEFINE_double(fraction_of_gpu_memory_to_use, fraction_of_gpu_memory_to_use,
              "Allocate a trunk of gpu memory that is this fraction of the "
              "total gpu memory size. Future memory usage will be allocated "
              "from the trunk. If the trunk doesn't have enough gpu memory, "
              "additional trunks of the same size will be requested from gpu "
              "until the gpu has no memory left for another trunk.");

// The value is in MB: it is shifted left by 20 bits to obtain bytes.
DEFINE_uint64(
    initial_gpu_memory_in_mb, 0ul,
    "Allocate a trunk of gpu memory whose size (in MB) is specified by "
    "the flag. Future memory usage will be allocated from the "
    "trunk. If the trunk doesn't have enough gpu memory, additional "
    "trunks of the gpu memory will be requested from gpu with size "
    "specified by FLAGS_reallocate_gpu_memory_in_mb until the gpu has "
    "no memory left for the additional trunk. Note: if you set this "
    "flag, the memory size set by "
    "FLAGS_fraction_of_gpu_memory_to_use will be overridden by this "
    "flag. If you don't set this flag, PaddlePaddle will use "
    "FLAGS_fraction_of_gpu_memory_to_use to allocate gpu memory");

// The value is in MB: it is shifted left by 20 bits to obtain bytes.
DEFINE_uint64(reallocate_gpu_memory_in_mb, 0ul,
              "If this flag is set, Paddle will reallocate the gpu memory with "
              "size (in MB) specified by this flag. Else Paddle will "
              "reallocate by FLAGS_fraction_of_gpu_memory_to_use");

DEFINE_bool(
    enable_cublas_tensor_op_math, false,
    "The enable_cublas_tensor_op_math indicates whether to use Tensor Core, "
    "but it may lose precision. Currently, there are two CUDA libraries that"
    " use Tensor Cores, cuBLAS and cuDNN. cuBLAS uses Tensor Cores to speed up"
    " GEMM computations (the matrices must be either half precision or single "
    "precision); cuDNN uses Tensor Cores to speed up both convolutions (the "
    "input and output must be half precision) and recurrent neural networks "
    "(RNNs).");

// Each adjacent string literal below must end (or the next begin) with a
// space, otherwise words are glued together in the concatenated help text.
DEFINE_string(selected_gpus, "",
              "A list of device ids separated by comma, like: 0,1,2,3. "
              "This option is useful when doing multi process training and "
              "each process has only one device (GPU). If you want to use "
              "all visible devices, set this to empty string. NOTE: the "
              "reason of doing this is that we want to use P2P communication "
              "between GPU devices, use CUDA_VISIBLE_DEVICES can only use "
              "share-memory only.");

L
liaogang 已提交
79 80 81
namespace paddle {
namespace platform {

82 83 84 85 86 87
// Hint pointing to NVIDIA's cudaError_t documentation; every CUDA failure
// message raised in this file ends with it.
inline std::string CudaErrorWebsite() {
  static const char kCudaErrorDocHint[] =
      "Please see detail in "
      "https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html"
      "#group__CUDART__TYPES_1g3f51e3575c2178246db0a94a430e0038";
  return std::string(kCudaErrorDocHint);
}

S
sneaxiy 已提交
88 89 90 91 92 93 94
static int GetCUDADeviceCountImpl() {
  const auto *cuda_visible_devices = std::getenv("CUDA_VISIBLE_DEVICES");
  if (cuda_visible_devices != nullptr) {
    std::string cuda_visible_devices_str(cuda_visible_devices);
    if (std::all_of(cuda_visible_devices_str.begin(),
                    cuda_visible_devices_str.end(),
                    [](char ch) { return ch == ' '; })) {
S
sneaxiy 已提交
95
      VLOG(2) << "CUDA_VISIBLE_DEVICES is set to be empty. No GPU detected.";
S
sneaxiy 已提交
96 97 98 99
      return 0;
    }
  }

L
liaogang 已提交
100
  int count;
101
  auto error_code = cudaGetDeviceCount(&count);
L
liaogang 已提交
102
  PADDLE_ENFORCE(
103 104 105 106
      error_code,
      "cudaGetDeviceCount failed in "
      "paddle::platform::GetCUDADeviceCountImpl, error code : %d, %s",
      error_code, CudaErrorWebsite());
L
liaogang 已提交
107 108 109
  return count;
}

S
sneaxiy 已提交
110 111 112 113 114
// Number of CUDA devices visible to this process. Computed by
// GetCUDADeviceCountImpl on the first call and cached in a function-local
// static afterwards.
int GetCUDADeviceCount() {
  static const int cached_device_count = GetCUDADeviceCountImpl();
  return cached_device_count;
}

115 116 117
int GetCUDAComputeCapability(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  cudaDeviceProp device_prop;
118
  auto error_code = cudaGetDeviceProperties(&device_prop, id);
119 120 121 122 123
  PADDLE_ENFORCE(
      error_code,
      "cudaGetDeviceProperties failed in "
      "paddle::platform::GetCUDAComputeCapability, error code : %d, %s",
      error_code, CudaErrorWebsite());
124 125 126
  return device_prop.major * 10 + device_prop.minor;
}

C
chengduo 已提交
127 128 129
int GetCUDARuntimeVersion(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  int runtime_version = 0;
130 131
  auto error_code = cudaRuntimeGetVersion(&runtime_version);
  PADDLE_ENFORCE(error_code,
C
chengduo 已提交
132
                 "cudaRuntimeGetVersion failed in "
133 134
                 "paddle::platform::GetCUDARuntimeVersion, error code : %d, %s",
                 error_code, CudaErrorWebsite());
C
chengduo 已提交
135 136 137 138 139 140
  return runtime_version;
}

int GetCUDADriverVersion(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  int driver_version = 0;
141 142
  auto error_code = cudaDriverGetVersion(&driver_version);
  PADDLE_ENFORCE(error_code,
C
chengduo 已提交
143
                 "cudaDriverGetVersion failed in "
144 145
                 "paddle::platform::GetCUDADriverVersion, error code : %d, %s",
                 error_code, CudaErrorWebsite());
C
chengduo 已提交
146 147 148
  return driver_version;
}

149 150 151 152 153 154 155 156 157 158
// Reports whether Tensor Core math may be used: requires building with
// CUDA_VERSION >= 9000 and a current device whose compute capability
// (major * 10 + minor) is at least 70.
// NOTE(review): an undefined CUDA_VERSION evaluates to 0 in #if, which makes
// this silently return false; confirm cuda.h is reachable from this TU.
bool TensorCoreAvailable() {
#if CUDA_VERSION >= 9000
  int device = GetCurrentDeviceId();
  int compute_capability = GetCUDAComputeCapability(device);
  return compute_capability >= 70;
#else
  return false;
#endif
}

C
chengduoZH 已提交
159 160 161
int GetCUDAMultiProcessors(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  int count;
162 163 164 165 166 167
  auto error_code =
      cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id);
  PADDLE_ENFORCE(error_code,
                 "cudaDeviceGetAttribute failed in "
                 "paddle::platform::GetCUDAMultiProcess, error code : %d, %s",
                 error_code, CudaErrorWebsite());
C
chengduoZH 已提交
168 169 170 171 172 173
  return count;
}

int GetCUDAMaxThreadsPerMultiProcessor(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  int count;
174 175 176 177 178 179 180
  auto error_code = cudaDeviceGetAttribute(
      &count, cudaDevAttrMaxThreadsPerMultiProcessor, id);
  PADDLE_ENFORCE(
      error_code,
      "cudaDeviceGetAttribute failed in paddle::"
      "platform::GetCUDAMaxThreadsPerMultiProcessor, error code : %d, %s",
      error_code, CudaErrorWebsite());
C
chengduoZH 已提交
181 182 183
  return count;
}

L
liaogang 已提交
184 185
int GetCurrentDeviceId() {
  int device_id;
186 187 188 189 190
  auto error_code = cudaGetDevice(&device_id);
  PADDLE_ENFORCE(error_code,
                 "cudaGetDevice failed in "
                 "paddle::platform::GetCurrentDeviceId, error code : %d, %s",
                 error_code, CudaErrorWebsite());
L
liaogang 已提交
191 192 193
  return device_id;
}

194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211
//! Get a list of device ids from environment variable or use all.
std::vector<int> GetSelectedDevices() {
  // use user specified GPUs in single-node multi-process mode.
  std::vector<int> devices;
  if (!FLAGS_selected_gpus.empty()) {
    auto devices_str = paddle::string::Split(FLAGS_selected_gpus, ',');
    for (auto id : devices_str) {
      devices.push_back(atoi(id.c_str()));
    }
  } else {
    int count = GetCUDADeviceCount();
    for (int i = 0; i < count; ++i) {
      devices.push_back(i);
    }
  }
  return devices;
}

L
liaogang 已提交
212
void SetDeviceId(int id) {
Q
qijun 已提交
213
  // TODO(qijun): find a better way to cache the cuda device count
Y
Yang Yang 已提交
214
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
215 216 217 218 219
  auto error_code = cudaSetDevice(id);
  PADDLE_ENFORCE(error_code,
                 "cudaSetDevice failed in "
                 "paddle::platform::SetDeviced, error code : %d, %s",
                 error_code, CudaErrorWebsite());
L
liaogang 已提交
220 221
}

222
// Stores the free and total device memory reported by cudaMemGetInfo into
// *available and *total. Raises if the call fails.
void GpuMemoryUsage(size_t *available, size_t *total) {
  auto error_code = cudaMemGetInfo(available, total);
  // The message names this function exactly so failures can be traced.
  PADDLE_ENFORCE(error_code,
                 "cudaMemGetInfo failed in "
                 "paddle::platform::GpuMemoryUsage, error code : %d, %s",
                 error_code, CudaErrorWebsite());
}

size_t GpuMaxAllocSize() {
Z
zhhsplendid 已提交
231 232 233 234
  return std::max(GpuInitAllocSize(), GpuReallocSize());
}

size_t GpuInitAllocSize() {
235 236 237
  if (FLAGS_initial_gpu_memory_in_mb > 0ul) {
    // Initial memory will be allocated by FLAGS_initial_gpu_memory_in_mb
    return static_cast<size_t>(FLAGS_initial_gpu_memory_in_mb << 20);
Z
zhhsplendid 已提交
238 239
  }

240
  // FLAGS_initial_gpu_memory_in_mb is 0, initial memory will be allocated by
Z
zhhsplendid 已提交
241
  // fraction
L
liaogang 已提交
242 243 244
  size_t total = 0;
  size_t available = 0;

245
  GpuMemoryUsage(&available, &total);
Z
zhhsplendid 已提交
246
  size_t reserving = static_cast<size_t>(fraction_reserve_gpu_memory * total);
L
liaogang 已提交
247

Z
zhhsplendid 已提交
248 249 250 251 252
  return static_cast<size_t>((total - reserving) *
                             FLAGS_fraction_of_gpu_memory_to_use);
}

size_t GpuReallocSize() {
253
  if (FLAGS_reallocate_gpu_memory_in_mb > 0ul) {
254 255
    // Additional memory will be allocated by
    // FLAGS_reallocate_gpu_memory_in_mb
256
    return static_cast<size_t>(FLAGS_reallocate_gpu_memory_in_mb << 20);
Z
zhhsplendid 已提交
257 258
  }

259 260
  // FLAGS_reallocate_gpu_memory_in_mb is 0, additional memory will be
  // allocated
Z
zhhsplendid 已提交
261 262 263 264 265 266 267 268 269
  // by fraction
  size_t total = 0;
  size_t available = 0;

  GpuMemoryUsage(&available, &total);
  size_t reserving = static_cast<size_t>(fraction_reserve_gpu_memory * total);

  return static_cast<size_t>((total - reserving) *
                             FLAGS_fraction_of_gpu_memory_to_use);
L
liaogang 已提交
270 271
}

L
liaogang 已提交
272 273 274 275 276 277 278
// Smallest chunk size the allocator may hand out, in bytes (256).
size_t GpuMinChunkSize() {
  constexpr size_t kMinChunkBytes = 256;
  return kMinChunkBytes;
}

size_t GpuMaxChunkSize() {
  size_t total = 0;
C
chenweihang 已提交
279
  size_t available = 0;
L
liaogang 已提交
280

C
chenweihang 已提交
281
  GpuMemoryUsage(&available, &total);
M
minqiyang 已提交
282 283
  VLOG(10) << "GPU Usage " << available / 1024 / 1024 << "M/"
           << total / 1024 / 1024 << "M";
Z
zhhsplendid 已提交
284
  size_t reserving = static_cast<size_t>(fraction_reserve_gpu_memory * total);
L
liaogang 已提交
285
  // If available less than minimum chunk size, no usable memory exists.
C
chenweihang 已提交
286 287 288
  available =
      std::min(std::max(available, GpuMinChunkSize()) - GpuMinChunkSize(),
               total - reserving);
289

Z
zhhsplendid 已提交
290
  size_t allocating = GpuMaxAllocSize();
L
liaogang 已提交
291

C
chenweihang 已提交
292 293
  PADDLE_ENFORCE_LE(allocating, available,
                    "Insufficient GPU memory to allocation.");
294

C
chenweihang 已提交
295
  return allocating;
L
liaogang 已提交
296 297
}

L
liaogang 已提交
298 299
void GpuMemcpyAsync(void *dst, const void *src, size_t count,
                    enum cudaMemcpyKind kind, cudaStream_t stream) {
300 301
  auto error_code = cudaMemcpyAsync(dst, src, count, kind, stream);
  PADDLE_ENFORCE(error_code,
302
                 "cudaMemcpyAsync failed in paddle::platform::GpuMemcpyAsync "
303 304 305
                 "(%p -> %p, length: %d) error code : %d, %s",
                 src, dst, static_cast<int>(count), error_code,
                 CudaErrorWebsite());
L
liaogang 已提交
306 307
}

308 309
void GpuMemcpySync(void *dst, const void *src, size_t count,
                   enum cudaMemcpyKind kind) {
310 311 312 313 314 315
  auto error_code = cudaMemcpy(dst, src, count, kind);
  PADDLE_ENFORCE(error_code,
                 "cudaMemcpy failed in paddle::platform::GpuMemcpySync "
                 "(%p -> %p, length: %d) error code : %d, %s",
                 src, dst, static_cast<int>(count), error_code,
                 CudaErrorWebsite());
316 317 318 319
}

void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,
                        int src_device, size_t count, cudaStream_t stream) {
320 321
  auto error_code =
      cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream);
L
liaogang 已提交
322
  PADDLE_ENFORCE(
323 324 325 326
      error_code,
      "cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeerAsync "
      "error code : %d, %s",
      error_code, CudaErrorWebsite());
327 328 329 330
}

void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src,
                       int src_device, size_t count) {
331 332 333 334 335
  auto error_code = cudaMemcpyPeer(dst, dst_device, src, src_device, count);
  PADDLE_ENFORCE(error_code,
                 "cudaMemcpyPeer failed in paddle::platform::GpuMemcpyPeerSync "
                 "error code : %d, %s",
                 error_code, CudaErrorWebsite());
L
liaogang 已提交
336
}
D
dzhwinter 已提交
337 338

void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream) {
339 340 341 342 343
  auto error_code = cudaMemsetAsync(dst, value, count, stream);
  PADDLE_ENFORCE(error_code,
                 "cudaMemsetAsync failed in paddle::platform::GpuMemsetAsync "
                 "error code : %d, %s",
                 error_code, CudaErrorWebsite());
D
dzhwinter 已提交
344
}
L
liaogang 已提交
345 346
}  // namespace platform
}  // namespace paddle