/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/platform/device_code.h"
#include <sys/stat.h>
#include <algorithm>
#include <set>
#include <utility>
#include "paddle/fluid/platform/enforce.h"

DECLARE_string(cuda_dir);

namespace paddle {
namespace platform {

// Process-wide singleton instance pointer; starts null until the pool is
// initialized elsewhere.
DeviceCodePool* DeviceCodePool::pool = nullptr;

void DeviceCodePool::Set(std::unique_ptr<DeviceCode>&& code) {
  Place place = code->GetPlace();
  std::string name = code->GetName();

  auto iter = device_codes_.find(place);
  if (iter == device_codes_.end()) {
    PADDLE_THROW(platform::errors::NotFound(
        "Place %s is not supported for runtime compiling.", place));
  }

  auto& codes_map = iter->second;
  codes_map.emplace(name, std::move(code));
}

// Looks up a previously registered device code by (place, name).
// Throws NotFound when either the place or the name is unknown; never
// returns nullptr.
platform::DeviceCode* DeviceCodePool::Get(const platform::Place& place,
                                          const std::string& name) {
  auto place_it = device_codes_.find(place);
  if (place_it == device_codes_.end()) {
    PADDLE_THROW(platform::errors::NotFound(
        "Place %s is not supported for runtime compiling.", place));
  }

  const auto& codes_map = place_it->second;
  auto code_it = codes_map.find(name);
  if (code_it == codes_map.end()) {
    PADDLE_THROW(platform::errors::NotFound(
        "Device code named %s for place %s does not exist.", name.c_str(),
        place));
  }

  return code_it->second.get();
}
// Builds the pool for a set of places. Only GPU places are registered;
// requesting a GPU place in a non-CUDA build is a hard error. After
// registration, probes whether NVRTC/driver JIT compilation is available.
// (Also removes web-scraper artifact lines that had corrupted this block.)
DeviceCodePool::DeviceCodePool(const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(places.size(), 0,
                    errors::InvalidArgument(
                        "Expected the number of places >= 1. But received %d.",
                        places.size()));
  // Remove the duplicated places
  std::set<Place> set(places.begin(), places.end());
  for (const auto& p : set) {
    if (is_gpu_place(p)) {
#ifdef PADDLE_WITH_CUDA
      device_codes_.emplace(p, DeviceCodeMap());
#else
      PADDLE_THROW(platform::errors::PreconditionNotMet(
          "CUDAPlace is not supported, please re-compile with WITH_GPU=ON."));
#endif
    }
  }

#ifdef PADDLE_WITH_CUDA
  CUDADeviceCode::CheckAvailableStatus();
#endif
}

#ifdef PADDLE_WITH_CUDA
// Returns true when a CUDA driver API call succeeded; otherwise logs a
// one-time warning (caller name, kernel name, driver error string) and
// returns false.
static bool CheckCUDADriverResult(CUresult result, std::string caller,
                                  std::string kernel_name = "") {
  if (result == CUDA_SUCCESS) {
    return true;
  }
  const char* error = nullptr;
  dynload::cuGetErrorString(result, &error);
  LOG_FIRST_N(WARNING, 1) << "Call " << caller << " for < " << kernel_name
                          << " > failed: " << error << " (" << result << ")";
  return false;
}

bool CUDADeviceCode::available_ = false;

// Probes whether JIT compilation can work in this process: NVRTC and the
// CUDA driver must be loadable, both version queries must succeed, and at
// least the device-count driver call must succeed. Sets available_.
// Fixes: local variable typo "dirver_major" -> "driver_major"; grammar in
// the warning message ("are need" -> "are needed").
void CUDADeviceCode::CheckAvailableStatus() {
  available_ = false;
  if (!dynload::HasNVRTC() || !dynload::HasCUDADriver()) {
    LOG_FIRST_N(WARNING, 1)
        << "NVRTC and CUDA driver are needed for JIT compiling of CUDA code.";
    return;
  }

  int nvrtc_major = 0;
  int nvrtc_minor = 0;
  nvrtcResult nvrtc_result = dynload::nvrtcVersion(&nvrtc_major, &nvrtc_minor);

  int driver_version = 0;
  int driver_major = 0;
  int driver_minor = 0;
  CUresult driver_result = dynload::cuDriverGetVersion(&driver_version);
  if (driver_result == CUDA_SUCCESS) {
    // Driver version is encoded as 1000 * major + 10 * minor.
    driver_major = driver_version / 1000;
    driver_minor = (driver_version % 1000) / 10;
  }

  LOG_FIRST_N(INFO, 1) << "CUDA Driver Version: " << driver_major << "."
                       << driver_minor << "; NVRTC Version: " << nvrtc_major
                       << "." << nvrtc_minor;
  if (nvrtc_result != NVRTC_SUCCESS || driver_result != CUDA_SUCCESS) {
    return;
  }

  int count = 0;
  if (CheckCUDADriverResult(dynload::cuDeviceGetCount(&count),
                            "cuDeviceGetCount")) {
    available_ = true;
  }
}

// Returns the CUDA include directory: FLAGS_cuda_dir (normalized by
// stripping a trailing "/" and a trailing "/lib" or "/lib64", then
// appending "/include") when it exists on disk, else /usr/local/cuda/include,
// else "" with a warning.
// Bug fix: the original called the single-iterator string::erase overload,
// which removes exactly ONE character, so "/lib"/"/lib64" suffixes were
// never actually stripped. Also: the suffix check now runs on the
// slash-stripped path (so "/opt/cuda/lib64/" is handled), and the EndWith
// lambda takes its arguments by const reference.
static std::string FindCUDAIncludePath() {
  auto EndWith = [](const std::string& str,
                    const std::string& substr) -> bool {
    size_t pos = str.rfind(substr);
    return pos != std::string::npos && pos == (str.length() - substr.length());
  };

  struct stat st;
  std::string cuda_include_path;
  if (!FLAGS_cuda_dir.empty()) {
    cuda_include_path = FLAGS_cuda_dir;
    if (EndWith(cuda_include_path, "/")) {
      cuda_include_path.erase(cuda_include_path.end() - 1);
    }
    for (std::string suffix : {"/lib", "/lib64"}) {
      if (EndWith(cuda_include_path, suffix)) {
        // Erase the whole suffix, not a single character.
        cuda_include_path.erase(cuda_include_path.end() - suffix.length(),
                                cuda_include_path.end());
        break;
      }
    }

    if (!EndWith(cuda_include_path, "include")) {
      cuda_include_path += "/include";
    }
    // Whether the cuda_include_path exists on the file system.
    if (stat(cuda_include_path.c_str(), &st) == 0) {
      return cuda_include_path;
    }
  }

  cuda_include_path = "/usr/local/cuda/include";
  if (stat(cuda_include_path.c_str(), &st) == 0) {
    return cuda_include_path;
  }
  LOG(WARNING) << "Cannot find CUDA include path."
               << "Please check whether CUDA is installed in the default "
                  "installation path, or specify it by export "
                  "FLAGS_cuda_dir=xxx.";
  return "";
}

// Stores the kernel source for later JIT compilation via Compile(). Only
// GPU places are accepted; any other place is a PermissionDenied error.
// (Removes web-scraper artifact lines that had corrupted this block.)
CUDADeviceCode::CUDADeviceCode(const Place& place, const std::string& name,
                               const std::string& kernel) {
  if (!is_gpu_place(place)) {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "CUDADeviceCode can only launch on GPU place."));
  }

  place_ = place;
  name_ = name;
  kernel_ = kernel;
}

bool CUDADeviceCode::Compile(bool include_path) {
190 191
  is_compiled_ = false;
  if (!dynload::HasNVRTC() || !dynload::HasCUDADriver()) {
192
    LOG_FIRST_N(WARNING, 1)
193 194 195 196
        << "NVRTC and CUDA driver are need for JIT compiling of CUDA code.";
    return false;
  }

197
  nvrtcProgram program;
198 199 200 201 202 203 204 205 206
  if (!CheckNVRTCResult(dynload::nvrtcCreateProgram(&program,
                                                    kernel_.c_str(),  // buffer
                                                    name_.c_str(),    // name
                                                    0,         // numHeaders
                                                    nullptr,   // headers
                                                    nullptr),  // includeNames
                        "nvrtcCreateProgram")) {
    return false;
  }
207 208 209 210 211 212 213

  // Compile the program for specified compute_capability
  auto* dev_ctx = reinterpret_cast<CUDADeviceContext*>(
      DeviceContextPool::Instance().Get(place_));
  int compute_capability = dev_ctx->GetComputeCapability();
  std::string compute_flag =
      "--gpu-architecture=compute_" + std::to_string(compute_capability);
214
  std::vector<const char*> options = {"--std=c++11", compute_flag.c_str()};
215
  std::string include_option;
216 217 218
  if (include_path) {
    std::string cuda_include_path = FindCUDAIncludePath();
    if (!cuda_include_path.empty()) {
219
      include_option = "--include-path=" + cuda_include_path;
220 221 222
      options.push_back(include_option.c_str());
    }
  }
223 224 225 226 227 228 229
  nvrtcResult compile_result =
      dynload::nvrtcCompileProgram(program,          // program
                                   options.size(),   // numOptions
                                   options.data());  // options
  if (compile_result == NVRTC_ERROR_COMPILATION) {
    // Obtain compilation log from the program
    size_t log_size;
230 231 232 233
    if (!CheckNVRTCResult(dynload::nvrtcGetProgramLogSize(program, &log_size),
                          "nvrtcGetProgramLogSize")) {
      return false;
    }
234 235
    std::vector<char> log;
    log.resize(log_size + 1);
236 237 238 239 240 241 242 243 244
    if (!CheckNVRTCResult(dynload::nvrtcGetProgramLog(program, log.data()),
                          "nvrtcGetProgramLog")) {
      return false;
    }
    LOG(WARNING) << "JIT compiling of CUDA code failed:"
                 << "\n  Kernel name: " << name_ << "\n  Kernel body:\n"
                 << kernel_ << "\n  Compiling log: " << log.data();

    return false;
245 246 247 248
  }

  // Obtain PTX from the program
  size_t ptx_size;
249 250 251 252
  if (!CheckNVRTCResult(dynload::nvrtcGetPTXSize(program, &ptx_size),
                        "nvrtcGetPTXSize")) {
    return false;
  }
253
  ptx_.resize(ptx_size + 1);
254 255 256 257
  if (!CheckNVRTCResult(dynload::nvrtcGetPTX(program, ptx_.data()),
                        "nvrtcGetPTX")) {
    return false;
  }
258

259 260 261 262
  if (!CheckNVRTCResult(dynload::nvrtcDestroyProgram(&program),
                        "nvrtcDestroyProgram")) {
    return false;
  }
263

264
  if (!CheckCUDADriverResult(dynload::cuModuleLoadData(&module_, ptx_.data()),
265
                             "cuModuleLoadData", name_)) {
266 267 268 269 270
    return false;
  }

  if (!CheckCUDADriverResult(
          dynload::cuModuleGetFunction(&function_, module_, name_.c_str()),
271
          "cuModuleGetFunction", name_)) {
272 273
    return false;
  }
274 275

  max_threads_ = dev_ctx->GetMaxPhysicalThreadCount();
276 277
  is_compiled_ = true;
  return true;
278 279 280
}

void CUDADeviceCode::Launch(const size_t n, std::vector<void*>* args) const {
281 282 283 284 285
  PADDLE_ENFORCE_EQ(
      is_compiled_, true,
      errors::PreconditionNotMet(
          "Please compile the code before launching the kernel."));

286 287 288 289 290 291 292 293 294 295 296 297 298 299 300
  int max_blocks = std::max(max_threads_ / num_threads_, 1);
  int workload_per_block = workload_per_thread_ * num_threads_;
  int num_blocks =
      std::min(max_blocks, (static_cast<int>(n) + workload_per_block - 1) /
                               workload_per_block);

  auto* dev_ctx = reinterpret_cast<CUDADeviceContext*>(
      DeviceContextPool::Instance().Get(place_));
  PADDLE_ENFORCE_EQ(
      dynload::cuLaunchKernel(function_, num_blocks, 1, 1,  // grid dim
                              num_threads_, 1, 1,           // block dim
                              0,                            // shared memory
                              dev_ctx->stream(),            // stream
                              args->data(),                 // arguments
                              nullptr),
301 302 303 304 305 306 307 308
      CUDA_SUCCESS,
      errors::External("Fail to launch kernel %s (in cuLaunchKernel.)",
                       name_.c_str()));
}

// Returns true when an NVRTC API call succeeded; otherwise logs a one-time
// warning with the calling function, kernel name, and NVRTC error string,
// and returns false.
// (Removes web-scraper artifact lines that had corrupted this block.)
bool CUDADeviceCode::CheckNVRTCResult(nvrtcResult result,
                                      std::string function) {
  if (result != NVRTC_SUCCESS) {
    LOG_FIRST_N(WARNING, 1)
        << "Call " << function << " for < " << name_
        << " > failed: " << dynload::nvrtcGetErrorString(result);
    return false;
  }
  return true;
}
#endif

}  // namespace platform
}  // namespace paddle