// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/pten/backends/gpu/gpu_info.h"

// TODO(pten): remove fluid headers.
#include "paddle/fluid/platform/enforce.h"

static std::once_flag g_device_props_size_init_flag;
static std::vector<std::unique_ptr<std::once_flag>> g_device_props_init_flags;
W
Wilber 已提交
22 23 24 25 26
static std::vector<pten::gpuDeviceProp> g_device_props;

namespace pten {
namespace backends {
namespace gpu {
// Returns the value of dynload::cudnnGetVersion(), or -1 when
// dynload::HasCUDNN() reports that cuDNN is not available.
int DnnVersion() {
  if (dynload::HasCUDNN()) {
    return dynload::cudnnGetVersion();
  }
  return -1;
}

static int GetGPUDeviceCountImpl() {
  int driverVersion = 0;
  cudaError_t status = cudaDriverGetVersion(&driverVersion);

  if (!(status == gpuSuccess && driverVersion != 0)) {
    // No GPU driver
    VLOG(2) << "GPU Driver Version can't be detected. No GPU driver!";
    return 0;
  }

  const auto *cuda_visible_devices = std::getenv("CUDA_VISIBLE_DEVICES");

  if (cuda_visible_devices != nullptr) {
    std::string cuda_visible_devices_str(cuda_visible_devices);
    if (!cuda_visible_devices_str.empty()) {
      cuda_visible_devices_str.erase(
          0, cuda_visible_devices_str.find_first_not_of('\''));
      cuda_visible_devices_str.erase(
          cuda_visible_devices_str.find_last_not_of('\'') + 1);
      cuda_visible_devices_str.erase(
          0, cuda_visible_devices_str.find_first_not_of('\"'));
      cuda_visible_devices_str.erase(
          cuda_visible_devices_str.find_last_not_of('\"') + 1);
    }
    if (std::all_of(cuda_visible_devices_str.begin(),
                    cuda_visible_devices_str.end(),
                    [](char ch) { return ch == ' '; })) {
      VLOG(2) << "CUDA_VISIBLE_DEVICES is set to be "
                 "empty. No GPU detected.";
      return 0;
    }
  }
  int count;
  PADDLE_ENFORCE_GPU_SUCCESS(cudaGetDeviceCount(&count));
  return count;
}

// Returns the number of usable GPUs.
int GetGPUDeviceCount() {
  // A function-local static is initialized only once, so
  // GetGPUDeviceCountImpl() runs on the first call and later calls reuse
  // the stored value.
  static const int cached_device_count = GetGPUDeviceCountImpl();
  return cached_device_count;
}

int GetGPUComputeCapability(int id) {
W
Wilber 已提交
77 78 79
  PADDLE_ENFORCE_LT(id,
                    GetGPUDeviceCount(),
                    paddle::platform::errors::InvalidArgument(
80 81
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
W
Wilber 已提交
82 83
                        id,
                        GetGPUDeviceCount()));
84 85 86 87 88 89 90 91 92 93 94 95
  int major, minor;
  auto major_error_code =
      cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, id);
  auto minor_error_code =
      cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, id);

  PADDLE_ENFORCE_GPU_SUCCESS(major_error_code);
  PADDLE_ENFORCE_GPU_SUCCESS(minor_error_code);
  return major * 10 + minor;
}

int GetGPURuntimeVersion(int id) {
W
Wilber 已提交
96 97 98
  PADDLE_ENFORCE_LT(id,
                    GetGPUDeviceCount(),
                    paddle::platform::errors::InvalidArgument(
99 100
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
W
Wilber 已提交
101 102
                        id,
                        GetGPUDeviceCount()));
103 104 105 106 107 108
  int runtime_version = 0;
  PADDLE_ENFORCE_GPU_SUCCESS(cudaRuntimeGetVersion(&runtime_version));
  return runtime_version;
}

int GetGPUDriverVersion(int id) {
W
Wilber 已提交
109 110 111
  PADDLE_ENFORCE_LT(id,
                    GetGPUDeviceCount(),
                    paddle::platform::errors::InvalidArgument(
112 113
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
W
Wilber 已提交
114 115
                        id,
                        GetGPUDeviceCount()));
116 117 118 119 120 121 122 123 124 125 126 127
  int driver_version = 0;
  PADDLE_ENFORCE_GPU_SUCCESS(cudaDriverGetVersion(&driver_version));
  return driver_version;
}

bool TensorCoreAvailable() {
  int device = GetCurrentDeviceId();
  int driver_version = GetGPUComputeCapability(device);
  return driver_version >= 70;
}

int GetGPUMultiProcessors(int id) {
W
Wilber 已提交
128 129 130
  PADDLE_ENFORCE_LT(id,
                    GetGPUDeviceCount(),
                    paddle::platform::errors::InvalidArgument(
131 132
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
W
Wilber 已提交
133 134
                        id,
                        GetGPUDeviceCount()));
135 136 137 138 139 140 141
  int count;
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id));
  return count;
}

int GetGPUMaxThreadsPerMultiProcessor(int id) {
W
Wilber 已提交
142 143 144
  PADDLE_ENFORCE_LT(id,
                    GetGPUDeviceCount(),
                    paddle::platform::errors::InvalidArgument(
145 146
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
W
Wilber 已提交
147 148
                        id,
                        GetGPUDeviceCount()));
149 150 151 152 153 154 155 156
  int count;
  PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute(
      &count, cudaDevAttrMaxThreadsPerMultiProcessor, id));

  return count;
}

int GetGPUMaxThreadsPerBlock(int id) {
W
Wilber 已提交
157 158 159
  PADDLE_ENFORCE_LT(id,
                    GetGPUDeviceCount(),
                    paddle::platform::errors::InvalidArgument(
160 161
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
W
Wilber 已提交
162 163
                        id,
                        GetGPUDeviceCount()));
164 165 166 167 168 169 170 171 172 173 174 175
  int count;
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaDeviceGetAttribute(&count, cudaDevAttrMaxThreadsPerBlock, id));
  return count;
}

// Returns the device id written by cudaGetDevice(); the call's status is
// checked with PADDLE_ENFORCE_GPU_SUCCESS.
int GetCurrentDeviceId() {
  int current_id;
  const auto status = cudaGetDevice(&current_id);
  PADDLE_ENFORCE_GPU_SUCCESS(status);
  return current_id;
}

std::array<int, 3> GetGpuMaxGridDimSize(int id) {
  PADDLE_ENFORCE_LT(id,
                    GetGPUDeviceCount(),
                    paddle::platform::errors::InvalidArgument(
180 181
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
W
Wilber 已提交
182 183 184
                        id,
                        GetGPUDeviceCount()));
  std::array<int, 3> ret;
185 186 187
  int size;
  auto error_code_x = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimX, id);
  PADDLE_ENFORCE_GPU_SUCCESS(error_code_x);
W
Wilber 已提交
188
  ret[0] = size;
189 190 191

  auto error_code_y = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimY, id);
  PADDLE_ENFORCE_GPU_SUCCESS(error_code_y);
W
Wilber 已提交
192
  ret[1] = size;
193 194 195

  auto error_code_z = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimZ, id);
  PADDLE_ENFORCE_GPU_SUCCESS(error_code_z);
W
Wilber 已提交
196
  ret[2] = size;
197 198 199 200 201 202
  return ret;
}

// Returns the cached properties of GPU `id`, querying them on first use.
//
// id: device index; -1 selects the device returned by GetCurrentDeviceId().
// Throws OutOfRange (via PADDLE_THROW) when the resolved id is negative or
// not smaller than the number of cached entries, which is sized from
// GetGPUDeviceCount().
// The returned reference points into the file-level g_device_props vector.
const gpuDeviceProp &GetDeviceProperties(int id) {
  // Size the per-device containers exactly once, and give every device its
  // own once_flag for the lazy property query below.
  std::call_once(g_device_props_size_init_flag, [&] {
    const int gpu_num = GetGPUDeviceCount();
    g_device_props_init_flags.resize(gpu_num);
    g_device_props.resize(gpu_num);
    for (int i = 0; i < gpu_num; ++i) {
      g_device_props_init_flags[i] = std::make_unique<std::once_flag>();
    }
  });

  if (id == -1) {
    id = GetCurrentDeviceId();
  }

  if (id < 0 || id >= static_cast<int>(g_device_props.size())) {
    PADDLE_THROW(paddle::platform::errors::OutOfRange(
        "The device id %d is out of range [0, %d), where %d is the number of "
        "devices on this machine. Because the device id should be greater than "
        "or equal to zero and smaller than the number of gpus. Please input "
        "appropriate device again!",
        id,
        static_cast<int>(g_device_props.size()),
        static_cast<int>(g_device_props.size())));
  }

  // cudaGetDeviceProperties() runs at most once per device id.
  std::call_once(*(g_device_props_init_flags[id]), [&] {
    PADDLE_ENFORCE_GPU_SUCCESS(
        cudaGetDeviceProperties(&g_device_props[id], id));
  });

  return g_device_props[id];
}

void SetDeviceId(int id) {
  // TODO(qijun): find a better way to cache the cuda device count
W
Wilber 已提交
236 237 238
  PADDLE_ENFORCE_LT(id,
                    GetGPUDeviceCount(),
                    paddle::platform::errors::InvalidArgument(
239 240
                        "Device id must be less than GPU count, "
                        "but received id is: %d. GPU count is: %d.",
W
Wilber 已提交
241 242
                        id,
                        GetGPUDeviceCount()));
243 244 245
  PADDLE_RETRY_CUDA_SUCCESS(cudaSetDevice(id));
}

void GpuMemcpyAsync(void *dst,
                    const void *src,
                    size_t count,
                    gpuMemcpyKind kind,
                    gpuStream_t stream) {
251 252 253
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(dst, src, count, kind, stream));
}

void GpuMemcpySync(void *dst,
                   const void *src,
                   size_t count,
257 258 259 260
                   gpuMemcpyKind kind) {
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(dst, src, count, kind));
}

void GpuMemcpyPeerAsync(void *dst,
                        int dst_device,
                        const void *src,
                        int src_device,
                        size_t count,
                        gpuStream_t stream) {
267 268 269 270
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream));
}

void GpuMemcpyPeerSync(
    void *dst, int dst_device, const void *src, int src_device, size_t count) {
273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaMemcpyPeer(dst, dst_device, src, src_device, count));
}

// Forwards to cudaMemsetAsync(dst, value, count, stream) and checks the
// returned status with PADDLE_ENFORCE_GPU_SUCCESS.
void GpuMemsetAsync(void *dst, int value, size_t count, gpuStream_t stream) {
  const auto status = cudaMemsetAsync(dst, value, count, stream);
  PADDLE_ENFORCE_GPU_SUCCESS(status);
}

// Forwards to cudaStreamSynchronize(stream) and checks the returned status
// with PADDLE_ENFORCE_GPU_SUCCESS.
void GpuStreamSync(gpuStream_t stream) {
  const auto status = cudaStreamSynchronize(stream);
  PADDLE_ENFORCE_GPU_SUCCESS(status);
}

// Forwards to cudaStreamDestroy(stream) and checks the returned status with
// PADDLE_ENFORCE_GPU_SUCCESS.
void GpuDestroyStream(gpuStream_t stream) {
  const auto status = cudaStreamDestroy(stream);
  PADDLE_ENFORCE_GPU_SUCCESS(status);
}

// Forwards to cudaDeviceSynchronize() and checks the returned status with
// PADDLE_ENFORCE_GPU_SUCCESS.
void GpuDeviceSync() {
  const auto status = cudaDeviceSynchronize();
  PADDLE_ENFORCE_GPU_SUCCESS(status);
}

// Returns the status from cudaGetLastError() unchanged; unlike the other
// wrappers in this file, the status is handed to the caller instead of being
// checked here.
gpuError_t GpuGetLastError() {
  const gpuError_t last_error = cudaGetLastError();
  return last_error;
}
}  // namespace gpu
}  // namespace backends
}  // namespace pten