gpu_info.cc 5.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/platform/gpu_info.h"
L
liaogang 已提交
16

17
#include <algorithm>
L
liaogang 已提交
18

19
#include "gflags/gflags.h"
Y
Yi Wang 已提交
20
#include "paddle/fluid/platform/enforce.h"
L
liaogang 已提交
21

22
// Command-line flag (default 0.92) read in this file as
// FLAGS_fraction_of_gpu_memory_to_use by GpuMaxAllocSize() and
// GpuMaxChunkSize() to scale the amount of GPU memory they report.
DEFINE_double(fraction_of_gpu_memory_to_use, 0.92,
              "Allocate a trunk of gpu memory that is this fraction of the "
              "total gpu memory size. Future memory usage will be allocated "
              "from the trunk. If the trunk doesn't have enough gpu memory, "
              "additional trunks of the same size will be requested from gpu "
              "until the gpu has no memory left for another trunk.");
L
liaogang 已提交
28 29 30 31

namespace paddle {
namespace platform {

32
int GetCUDADeviceCount() {
L
liaogang 已提交
33
  int count;
L
liaogang 已提交
34
  PADDLE_ENFORCE(
L
liaogang 已提交
35
      cudaGetDeviceCount(&count),
36
      "cudaGetDeviceCount failed in paddle::platform::GetCUDADeviceCount");
L
liaogang 已提交
37 38 39
  return count;
}

40 41 42 43 44 45 46 47 48
int GetCUDAComputeCapability(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  cudaDeviceProp device_prop;
  PADDLE_ENFORCE(cudaGetDeviceProperties(&device_prop, id),
                 "cudaGetDeviceProperties failed in "
                 "paddle::platform::GetCUDAComputeCapability");
  return device_prop.major * 10 + device_prop.minor;
}

C
chengduoZH 已提交
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
int GetCUDAMultiProcessors(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  int count;
  PADDLE_ENFORCE(
      cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id),
      "cudaDeviceGetAttribute failed in "
      "paddle::platform::GetCUDAMultiProcessors");
  return count;
}

int GetCUDAMaxThreadsPerMultiProcessor(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  int count;
  PADDLE_ENFORCE(cudaDeviceGetAttribute(
                     &count, cudaDevAttrMaxThreadsPerMultiProcessor, id),
                 "cudaDeviceGetAttribute failed in "
                 "paddle::platform::GetCUDAMaxThreadsPerMultiProcessor");
  return count;
}

L
liaogang 已提交
69 70
int GetCurrentDeviceId() {
  int device_id;
L
liaogang 已提交
71
  PADDLE_ENFORCE(
L
liaogang 已提交
72 73 74 75 76 77
      cudaGetDevice(&device_id),
      "cudaGetDevice failed in paddle::platform::GetCurrentDeviceId");
  return device_id;
}

void SetDeviceId(int id) {
Q
qijun 已提交
78
  // TODO(qijun): find a better way to cache the cuda device count
Y
Yang Yang 已提交
79
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
L
liaogang 已提交
80
  PADDLE_ENFORCE(cudaSetDevice(id),
L
liaogang 已提交
81 82 83
                 "cudaSetDevice failed in paddle::platform::SetDeviceId");
}

84 85
// Fills `*available` and `*total` with the two values written by
// cudaMemGetInfo(). The runtime call's status is checked with
// PADDLE_ENFORCE.
void GpuMemoryUsage(size_t *available, size_t *total) {
  // The message names this function (it previously said "GetMemoryUsage",
  // which does not exist in this file).
  PADDLE_ENFORCE(cudaMemGetInfo(available, total),
                 "cudaMemGetInfo failed in paddle::platform::GpuMemoryUsage");
}

size_t GpuMaxAllocSize() {
  size_t total = 0;
  size_t available = 0;

93
  GpuMemoryUsage(&available, &total);
L
liaogang 已提交
94

L
liaogang 已提交
95
  // Reserve the rest for page tables, etc.
L
liaogang 已提交
96
  return static_cast<size_t>(total * FLAGS_fraction_of_gpu_memory_to_use);
L
liaogang 已提交
97 98
}

L
liaogang 已提交
99 100 101 102 103 104 105
// Returns the smallest chunk size, in bytes, that may be allocated: 256.
size_t GpuMinChunkSize() {
  constexpr size_t kMinChunkBytes = 256;  // == 1 << 8
  return kMinChunkBytes;
}

size_t GpuMaxChunkSize() {
  size_t total = 0;
C
chenweihang 已提交
106
  size_t available = 0;
L
liaogang 已提交
107

C
chenweihang 已提交
108 109
  GpuMemoryUsage(&available, &total);
  VLOG(10) << "GPU Usage " << available / 1024 / 1024 << "M/"
110 111
           << total / 1024 / 1024 << "M";
  size_t reserving = static_cast<size_t>(0.05 * total);
L
liaogang 已提交
112
  // If available less than minimum chunk size, no usable memory exists.
C
chenweihang 已提交
113 114 115
  available =
      std::min(std::max(available, GpuMinChunkSize()) - GpuMinChunkSize(),
               total - reserving);
116 117

  // Reserving the rest memory for page tables, etc.
L
liaogang 已提交
118

C
chenweihang 已提交
119 120
  size_t allocating = static_cast<size_t>(FLAGS_fraction_of_gpu_memory_to_use *
                                          (total - reserving));
L
liaogang 已提交
121

C
chenweihang 已提交
122 123
  PADDLE_ENFORCE_LE(allocating, available,
                    "Insufficient GPU memory to allocation.");
124

C
chenweihang 已提交
125
  return allocating;
L
liaogang 已提交
126 127
}

L
liaogang 已提交
128 129
void GpuMemcpyAsync(void *dst, const void *src, size_t count,
                    enum cudaMemcpyKind kind, cudaStream_t stream) {
L
liaogang 已提交
130 131
  PADDLE_ENFORCE(cudaMemcpyAsync(dst, src, count, kind, stream),
                 "cudaMemcpyAsync failed in paddle::platform::GpuMemcpyAsync");
L
liaogang 已提交
132 133
}

134 135 136 137 138 139 140 141
void GpuMemcpySync(void *dst, const void *src, size_t count,
                   enum cudaMemcpyKind kind) {
  PADDLE_ENFORCE(cudaMemcpy(dst, src, count, kind),
                 "cudaMemcpy failed in paddle::platform::GpuMemcpySync");
}

void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,
                        int src_device, size_t count, cudaStream_t stream) {
L
liaogang 已提交
142
  PADDLE_ENFORCE(
L
liaogang 已提交
143
      cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream),
144 145 146 147 148 149 150 151
      "cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeerAsync");
}

// Forwards all arguments unchanged to cudaMemcpyPeer() and checks the
// returned status with PADDLE_ENFORCE.
void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src,
                       int src_device, size_t count) {
  PADDLE_ENFORCE(
      cudaMemcpyPeer(dst, dst_device, src, src_device, count),
      "cudaMemcpyPeer failed in paddle::platform::GpuMemcpyPeerSync");
}
D
dzhwinter 已提交
153 154 155 156 157

void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream) {
  PADDLE_ENFORCE(cudaMemsetAsync(dst, value, count, stream),
                 "cudaMemsetAsync failed in paddle::platform::GpuMemsetAsync");
}
L
liaogang 已提交
158 159
}  // namespace platform
}  // namespace paddle