gpu_info.cc 7.3 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/platform/gpu_info.h"
L
liaogang 已提交
16

17
#include <algorithm>
L
liaogang 已提交
18

19
#include "gflags/gflags.h"
Y
Yi Wang 已提交
20
#include "paddle/fluid/platform/enforce.h"
L
liaogang 已提交
21

22
#ifndef _WIN32
P
peizhilin 已提交
23
constexpr static float fraction_of_gpu_memory_to_use = 0.92f;
24
#else
P
peizhilin 已提交
25 26 27
// fraction_of_gpu_memory_to_use cannot be too high on windows,
// since the win32 graphic sub-system can occupy some GPU memory
// which may lead to insufficient memory left for paddle
P
peizhilin 已提交
28
constexpr static float fraction_of_gpu_memory_to_use = 0.5f;
29 30 31
#endif

DEFINE_double(fraction_of_gpu_memory_to_use, fraction_of_gpu_memory_to_use,
X
Xin Pan 已提交
32 33 34 35 36
              "Allocate a trunk of gpu memory that is this fraction of the "
              "total gpu memory size. Future memory usage will be allocated "
              "from the trunk. If the trunk doesn't have enough gpu memory, "
              "additional trunks of the same size will be requested from gpu "
              "until the gpu has no memory left for another trunk.");
L
liaogang 已提交
37

38 39 40 41 42 43 44 45 46 47
DEFINE_bool(
    enable_cublas_tensor_op_math, false,
    "The enable_cublas_tensor_op_math indicate whether to use Tensor Core, "
    "but it may loss precision. Currently, There are two CUDA libraries that"
    " use Tensor Cores, cuBLAS and cuDNN. cuBLAS uses Tensor Cores to speed up"
    " GEMM computations(the matrices must be either half precision or single "
    "precision); cuDNN uses Tensor Cores to speed up both convolutions(the "
    "input and output must be half precision) and recurrent neural networks "
    "(RNNs).");

L
liaogang 已提交
48 49 50
namespace paddle {
namespace platform {

51
int GetCUDADeviceCount() {
L
liaogang 已提交
52
  int count;
L
liaogang 已提交
53
  PADDLE_ENFORCE(
L
liaogang 已提交
54
      cudaGetDeviceCount(&count),
55
      "cudaGetDeviceCount failed in paddle::platform::GetCUDADeviceCount");
L
liaogang 已提交
56 57 58
  return count;
}

59 60 61 62 63 64 65 66 67
// Returns the compute capability of device `id` packed as major * 10 + minor
// (for example, major 7 / minor 0 yields 70).
// `id` must be below GetCUDADeviceCount().
int GetCUDAComputeCapability(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  cudaDeviceProp device_prop;
  PADDLE_ENFORCE(cudaGetDeviceProperties(&device_prop, id),
                 "cudaGetDeviceProperties failed in "
                 "paddle::platform::GetCUDAComputeCapability");
  const int major = device_prop.major;
  const int minor = device_prop.minor;
  return major * 10 + minor;
}

C
chengduo 已提交
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
// Returns the CUDA runtime version reported by cudaRuntimeGetVersion.
//
// `id` is only range-checked against GetCUDADeviceCount(); the runtime query
// itself takes no device argument, so every valid id yields the same value.
int GetCUDARuntimeVersion(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  int runtime_version = 0;
  // The message names this wrapper (previously it named the CUDA API itself
  // as if it lived in paddle::platform), matching the sibling helpers.
  PADDLE_ENFORCE(cudaRuntimeGetVersion(&runtime_version),
                 "cudaRuntimeGetVersion failed in "
                 "paddle::platform::GetCUDARuntimeVersion");
  return runtime_version;
}

// Returns the version reported by cudaDriverGetVersion.
// `id` is only range-checked against GetCUDADeviceCount(); the query itself
// does not take a device argument.
int GetCUDADriverVersion(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");

  int driver_version = 0;
  PADDLE_ENFORCE(cudaDriverGetVersion(&driver_version),
                 "cudaDriverGetVersion failed in "
                 "paddle::platform::GetCUDADriverVersion");

  return driver_version;
}

86 87 88 89 90 91 92 93 94 95
// Reports whether Tensor Cores can be used on the current device.
// Requires a build with CUDA_VERSION >= 9000 and a device whose packed compute
// capability (see GetCUDAComputeCapability) is at least 70; older CUDA builds
// always report false.
bool TensorCoreAvailable() {
#if CUDA_VERSION >= 9000
  const int compute_capability =
      GetCUDAComputeCapability(GetCurrentDeviceId());
  return compute_capability >= 70;
#else
  return false;
#endif
}

C
chengduoZH 已提交
96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
// Returns the value of cudaDevAttrMultiProcessorCount for device `id`.
// `id` must be below GetCUDADeviceCount().
int GetCUDAMultiProcessors(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  int count = 0;
  PADDLE_ENFORCE(
      cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id),
      "cudaDeviceGetAttribute failed in "
      "paddle::platform::GetCUDAMultiProcessors");
  return count;
}

// Returns the value of cudaDevAttrMaxThreadsPerMultiProcessor for device `id`.
// `id` must be below GetCUDADeviceCount().
int GetCUDAMaxThreadsPerMultiProcessor(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  int count = 0;
  PADDLE_ENFORCE(
      cudaDeviceGetAttribute(&count, cudaDevAttrMaxThreadsPerMultiProcessor,
                             id),
      "cudaDeviceGetAttribute failed in "
      "paddle::platform::GetCUDAMaxThreadsPerMultiProcessor");
  return count;
}

L
liaogang 已提交
116 117
int GetCurrentDeviceId() {
  int device_id;
L
liaogang 已提交
118
  PADDLE_ENFORCE(
L
liaogang 已提交
119 120 121 122 123 124
      cudaGetDevice(&device_id),
      "cudaGetDevice failed in paddle::platform::GetCurrentDeviceId");
  return device_id;
}

void SetDeviceId(int id) {
Q
qijun 已提交
125
  // TODO(qijun): find a better way to cache the cuda device count
Y
Yang Yang 已提交
126
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
L
liaogang 已提交
127
  PADDLE_ENFORCE(cudaSetDevice(id),
L
liaogang 已提交
128 129 130
                 "cudaSetDevice failed in paddle::platform::SetDeviceId");
}

131 132
// Writes the free (`*available`) and total (`*total`) device memory reported by
// cudaMemGetInfo. Both pointers are passed straight through to the runtime.
void GpuMemoryUsage(size_t *available, size_t *total) {
  // The message names this function; it previously referred to a
  // non-existent paddle::platform::GetMemoryUsage.
  PADDLE_ENFORCE(cudaMemGetInfo(available, total),
                 "cudaMemGetInfo failed in paddle::platform::GpuMemoryUsage");
}

size_t GpuMaxAllocSize() {
  size_t total = 0;
  size_t available = 0;

140
  GpuMemoryUsage(&available, &total);
L
liaogang 已提交
141

L
liaogang 已提交
142
  // Reserve the rest for page tables, etc.
L
liaogang 已提交
143
  return static_cast<size_t>(total * FLAGS_fraction_of_gpu_memory_to_use);
L
liaogang 已提交
144 145
}

L
liaogang 已提交
146 147 148 149 150 151 152
// Smallest chunk size the allocator may use: 256 bytes.
size_t GpuMinChunkSize() {
  constexpr size_t kMinChunkBytes = 256;
  return kMinChunkBytes;
}

size_t GpuMaxChunkSize() {
  size_t total = 0;
C
chenweihang 已提交
153
  size_t available = 0;
L
liaogang 已提交
154

C
chenweihang 已提交
155
  GpuMemoryUsage(&available, &total);
M
minqiyang 已提交
156 157
  VLOG(10) << "GPU Usage " << available / 1024 / 1024 << "M/"
           << total / 1024 / 1024 << "M";
158
  size_t reserving = static_cast<size_t>(0.05 * total);
L
liaogang 已提交
159
  // If available less than minimum chunk size, no usable memory exists.
C
chenweihang 已提交
160 161 162
  available =
      std::min(std::max(available, GpuMinChunkSize()) - GpuMinChunkSize(),
               total - reserving);
163 164

  // Reserving the rest memory for page tables, etc.
L
liaogang 已提交
165

C
chenweihang 已提交
166 167
  size_t allocating = static_cast<size_t>(FLAGS_fraction_of_gpu_memory_to_use *
                                          (total - reserving));
L
liaogang 已提交
168

C
chenweihang 已提交
169 170
  PADDLE_ENFORCE_LE(allocating, available,
                    "Insufficient GPU memory to allocation.");
171

C
chenweihang 已提交
172
  return allocating;
L
liaogang 已提交
173 174
}

L
liaogang 已提交
175 176
void GpuMemcpyAsync(void *dst, const void *src, size_t count,
                    enum cudaMemcpyKind kind, cudaStream_t stream) {
L
liaogang 已提交
177 178
  PADDLE_ENFORCE(cudaMemcpyAsync(dst, src, count, kind, stream),
                 "cudaMemcpyAsync failed in paddle::platform::GpuMemcpyAsync");
L
liaogang 已提交
179 180
}

181 182 183 184 185 186 187 188
// Copies `count` bytes from `src` to `dst` with the blocking cudaMemcpy,
// using transfer direction `kind`.
void GpuMemcpySync(void *dst, const void *src, size_t count,
                   enum cudaMemcpyKind kind) {
  PADDLE_ENFORCE(
      cudaMemcpy(dst, src, count, kind),
      "cudaMemcpy failed in paddle::platform::GpuMemcpySync");
}

void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,
                        int src_device, size_t count, cudaStream_t stream) {
L
liaogang 已提交
189
  PADDLE_ENFORCE(
L
liaogang 已提交
190
      cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream),
191 192 193 194 195 196 197 198
      "cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeerAsync");
}

// Copies `count` bytes from `src` on device `src_device` to `dst` on device
// `dst_device` using cudaMemcpyPeer (no stream argument).
void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src,
                       int src_device, size_t count) {
  PADDLE_ENFORCE(cudaMemcpyPeer(dst, dst_device, src, src_device, count),
                 "cudaMemcpyPeer failed in paddle::platform::GpuMemcpyPeerSync");
}
D
dzhwinter 已提交
200 201 202 203 204

// Enqueues cudaMemsetAsync on `stream`, setting `count` bytes at `dst` to
// `value`. Like cudaMemset, the fill is byte-wise.
void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream) {
  PADDLE_ENFORCE(
      cudaMemsetAsync(dst, value, count, stream),
      "cudaMemsetAsync failed in paddle::platform::GpuMemsetAsync");
}
L
liaogang 已提交
205 206
}  // namespace platform
}  // namespace paddle