Unverified commit 42b5bec6, authored by Yiqun Liu, committed by GitHub

Integrate NVRTC to support compiling CUDA kernel at runtime (#19422)

* Add dynamic loading of nvrtc, and support runtime compilation of CUDA kernels using nvrtc.
test=develop

* Call the CUDA driver API to launch the kernel compiled by nvrtc.
test=develop

* Disable on macOS and Windows.
test=develop

* Refine the code to support manually specified num_threads and workload_per_thread.
test=develop

* Refine the CUDA kernel to support large dims.
test=develop
Parent 3ae939e4
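The new CUDADeviceCode class wraps the whole flow: NVRTC compiles a CUDA source string to PTX, then the CUDA driver API loads that PTX and launches the resulting function. A condensed usage sketch, taken from the unit test added in this commit (tensor setup and result checks omitted):

// JIT-compile and launch a saxpy kernel (condensed from device_code_test.cc).
paddle::platform::CUDAPlace place(0);
paddle::platform::CUDADeviceCode code(place, "saxpy_kernel", saxpy_code);
code.Compile();                // NVRTC: source -> PTX -> CUmodule/CUfunction

std::vector<void*> args = {&scale, &x_data, &y_data, &z_data, &n};
code.SetNumThreads(1024);      // threads per block (default is 1024)
code.SetWorkloadPerThread(1);  // elements per thread (default is 1)
code.Launch(n, &args);         // grid size is derived from n inside Launch()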
...@@ -118,7 +118,14 @@ cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor)
nv_library(cuda_device_guard SRCS cuda_device_guard.cc DEPS gpu_info)
if(WITH_GPU)
  nv_test(temporal_allocator_test SRCS temporary_allocator_test.cc DEPS temp_allocator tensor operator)
else()
  cc_test(temporal_allocator_test SRCS temporary_allocator_test.cc DEPS temp_allocator tensor operator)
endif()
if(NOT APPLE AND NOT WIN32)
  cc_library(device_code SRCS device_code.cc DEPS device_context)
  if(WITH_GPU)
    cc_test(device_code_test SRCS device_code_test.cc DEPS device_code lod_tensor)
  endif()
endif()
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_code.h"
#include <algorithm>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
#ifdef PADDLE_WITH_CUDA
inline bool is_error(nvrtcResult stat) { return stat != NVRTC_SUCCESS; }
inline void throw_on_error(nvrtcResult stat, const std::string& msg) {
#ifndef REPLACE_ENFORCE_GLOG
throw std::runtime_error(dynload::nvrtcGetErrorString(stat) + msg);
#else
LOG(FATAL) << dynload::nvrtcGetErrorString(stat) << msg;
#endif
}
CUDADeviceCode::CUDADeviceCode(const Place& place, const std::string& name,
const std::string& kernel) {
if (!is_gpu_place(place)) {
PADDLE_THROW("CUDADeviceCode can only launch on GPU place.");
}
place_ = place;
name_ = name;
kernel_ = kernel;
}
void CUDADeviceCode::Compile() {
nvrtcProgram program;
PADDLE_ENFORCE_EQ(dynload::nvrtcCreateProgram(&program,
kernel_.c_str(), // buffer
name_.c_str(), // name
0, // numHeaders
nullptr, // headers
nullptr), // includeNames
NVRTC_SUCCESS,
"nvrtcCreateProgram failed.");
// Compile the program for specified compute_capability
auto* dev_ctx = reinterpret_cast<CUDADeviceContext*>(
DeviceContextPool::Instance().Get(place_));
int compute_capability = dev_ctx->GetComputeCapability();
std::string compute_flag =
"--gpu-architecture=compute_" + std::to_string(compute_capability);
const std::vector<const char*> options = {"--std=c++11",
compute_flag.c_str()};
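// NVRTC emits PTX for the virtual architecture compute_XX; the driver
// JIT-compiles that PTX to machine code later, in cuModuleLoadData.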
nvrtcResult compile_result =
dynload::nvrtcCompileProgram(program, // program
options.size(), // numOptions
options.data()); // options
if (compile_result == NVRTC_ERROR_COMPILATION) {
// Obtain compilation log from the program
size_t log_size;
PADDLE_ENFORCE_EQ(dynload::nvrtcGetProgramLogSize(program, &log_size),
NVRTC_SUCCESS, "nvrtcGetProgramLogSize failed.");
std::vector<char> log;
log.resize(log_size + 1);
PADDLE_ENFORCE_EQ(dynload::nvrtcGetProgramLog(program, log.data()),
NVRTC_SUCCESS, "nvrtcGetProgramLog failed.");
LOG(FATAL) << "JIT compiling of CUDA code failed:\n" << log.data();
}
// Obtain PTX from the program
size_t ptx_size;
PADDLE_ENFORCE_EQ(dynload::nvrtcGetPTXSize(program, &ptx_size), NVRTC_SUCCESS,
"nvrtcGetPTXSize failed.");
ptx_.resize(ptx_size + 1);
PADDLE_ENFORCE_EQ(dynload::nvrtcGetPTX(program, ptx_.data()), NVRTC_SUCCESS,
"nvrtcGetPTX failed.");
PADDLE_ENFORCE_EQ(dynload::nvrtcDestroyProgram(&program), NVRTC_SUCCESS,
"nvrtcDestroyProgram failed.");
PADDLE_ENFORCE_EQ(
    dynload::cuModuleLoadData(&module_, ptx_.data()), CUDA_SUCCESS,
    "Failed to load PTX of %s (in cuModuleLoadData).", name_.c_str());
PADDLE_ENFORCE_EQ(
    dynload::cuModuleGetFunction(&function_, module_, name_.c_str()),
    CUDA_SUCCESS, "Failed to get function %s (in cuModuleGetFunction).",
    name_.c_str());
max_threads_ = dev_ctx->GetMaxPhysicalThreadCount();
}
void CUDADeviceCode::Launch(const size_t n, std::vector<void*>* args) const {
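// Pick the grid size so that num_blocks * num_threads_ * workload_per_thread_
// covers n, capped by the device's maximum number of resident threads.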
int max_blocks = std::max(max_threads_ / num_threads_, 1);
int workload_per_block = workload_per_thread_ * num_threads_;
int num_blocks =
std::min(max_blocks, (static_cast<int>(n) + workload_per_block - 1) /
workload_per_block);
auto* dev_ctx = reinterpret_cast<CUDADeviceContext*>(
DeviceContextPool::Instance().Get(place_));
PADDLE_ENFORCE_EQ(
dynload::cuLaunchKernel(function_, num_blocks, 1, 1, // grid dim
num_threads_, 1, 1, // block dim
0, // shared memory
dev_ctx->stream(), // stream
args->data(), // arguments
nullptr),
CUDA_SUCCESS, "Fail to launch kernel %s (in cuLaunchKernel.)",
name_.c_str());
}
#endif
} // namespace platform
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/cuda_driver.h"
#include "paddle/fluid/platform/dynload/nvrtc.h"
#endif
namespace paddle {
namespace platform {
class DeviceCode {
public:
virtual ~DeviceCode() {}
virtual void Compile() = 0;
virtual void Launch(const size_t n, std::vector<void*>* args) const = 0;
protected:
Place place_;
std::string name_;
std::string kernel_;
};
#ifdef PADDLE_WITH_CUDA
class CUDADeviceCode : public DeviceCode {
public:
explicit CUDADeviceCode(const Place& place, const std::string& name,
const std::string& kernel);
void Compile() override;
void Launch(const size_t n, std::vector<void*>* args) const override;
void SetNumThreads(int num_threads) { num_threads_ = num_threads; }
void SetWorkloadPerThread(int workload_per_thread) {
workload_per_thread_ = workload_per_thread;
}
private:
int max_threads_{0};
int num_threads_{1024};
int workload_per_thread_{1};
std::vector<char> ptx_;
CUmodule module_;
CUfunction function_;
};
#endif
} // namespace platform
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_code.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/init.h"
constexpr auto saxpy_code = R"(
extern "C" __global__
void saxpy_kernel(float a, float *x, float* y, float* z, size_t n) {
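  // Grid-stride loop: each thread strides over the array, so the kernel
  // stays correct when n exceeds the total number of launched threads.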
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < n;
tid += blockDim.x * gridDim.x) {
z[tid] = a * x[tid] + y[tid];
}
}
)";
#ifdef PADDLE_WITH_CUDA
TEST(device_code, cuda) {
paddle::framework::InitDevices(false, {0});
paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0);
paddle::platform::CUDADeviceCode code(place, "saxpy_kernel", saxpy_code);
paddle::framework::Tensor cpu_x;
paddle::framework::Tensor cpu_y;
paddle::framework::Tensor cpu_z;
float scale = 2;
auto dims = paddle::framework::make_ddim(
{static_cast<int64_t>(256), static_cast<int64_t>(1024)});
cpu_x.mutable_data<float>(dims, paddle::platform::CPUPlace());
cpu_y.mutable_data<float>(dims, paddle::platform::CPUPlace());
size_t n = cpu_x.numel();
for (size_t i = 0; i < n; ++i) {
cpu_x.data<float>()[i] = static_cast<float>(i);
}
for (size_t i = 0; i < n; ++i) {
cpu_y.data<float>()[i] = static_cast<float>(0.5);
}
paddle::framework::Tensor x;
paddle::framework::Tensor y;
paddle::framework::Tensor z;
float* x_data = x.mutable_data<float>(dims, place);
float* y_data = y.mutable_data<float>(dims, place);
float* z_data = z.mutable_data<float>(dims, place);
TensorCopySync(cpu_x, place, &x);
TensorCopySync(cpu_y, place, &y);
code.Compile();
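// cuLaunchKernel takes an array of pointers to the kernel arguments,
// in the order they appear in the kernel signature.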
std::vector<void*> args = {&scale, &x_data, &y_data, &z_data, &n};
code.SetNumThreads(1024);
code.SetWorkloadPerThread(1);
code.Launch(n, &args);
TensorCopySync(z, paddle::platform::CPUPlace(), &cpu_z);
for (size_t i = 0; i < n; i++) {
PADDLE_ENFORCE_EQ(cpu_z.data<float>()[i],
static_cast<float>(i) * scale + 0.5);
}
}
#endif
...@@ -3,8 +3,9 @@ cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce)
list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc)
# There is no macOS version of NCCL.
# Disable the nvrtc and cuda_driver APIs on macOS and Windows; only do an early test on Linux.
if (NOT APPLE AND NOT WIN32)
  list(APPEND CUDA_SRCS nccl.cc nvrtc.cc cuda_driver.cc)
endif()
if (TENSORRT_FOUND)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/dynload/cuda_driver.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag cuda_dso_flag;
void* cuda_dso_handle = nullptr;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
CUDA_ROUTINE_EACH(DEFINE_WRAP);
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda.h>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
namespace paddle {
namespace platform {
namespace dynload {
extern std::once_flag cuda_dso_flag;
extern void* cuda_dso_handle;
#ifdef PADDLE_USE_DSO
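// Each wrapper loads libcuda once (guarded by std::call_once), resolves the
// symbol with dlsym, and forwards the call with the driver API's original
// signature.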
#define DECLARE_DYNAMIC_LOAD_CUDA_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using cuda_func = decltype(&::__name); \
std::call_once(cuda_dso_flag, []() { \
cuda_dso_handle = paddle::platform::dynload::GetCUDADsoHandle(); \
}); \
static void* p_##__name = dlsym(cuda_dso_handle, #__name); \
return reinterpret_cast<cuda_func>(p_##__name)(args...); \
} \
}; \
extern struct DynLoad__##__name __name
#else
#define DECLARE_DYNAMIC_LOAD_CUDA_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
inline auto operator()(Args... args) { \
return ::__name(args...); \
} \
}; \
extern DynLoad__##__name __name
#endif
/**
* include all needed cuda driver functions
**/
#define CUDA_ROUTINE_EACH(__macro) \
__macro(cuGetErrorString); \
__macro(cuModuleLoadData); \
__macro(cuModuleGetFunction); \
__macro(cuModuleUnload); \
__macro(cuOccupancyMaxActiveBlocksPerMultiprocessor); \
__macro(cuLaunchKernel); \
__macro(cuCtxCreate); \
__macro(cuCtxGetCurrent); \
__macro(cuDeviceGet); \
__macro(cuDevicePrimaryCtxGetState)
CUDA_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDA_WRAP);
#undef DECLARE_DYNAMIC_LOAD_CUDA_WRAP
} // namespace dynload
} // namespace platform
} // namespace paddle
...@@ -222,6 +222,22 @@ void* GetCurandDsoHandle() {
#endif
}
void* GetNVRTCDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvrtc.dylib");
#else
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvrtc.so");
#endif
}
void* GetCUDADsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcuda.dylib");
#else
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcuda.so");
#endif
}
void* GetWarpCTCDsoHandle() {
  std::string warpctc_dir = "";
  if (!s_py_site_pkg_path.path.empty()) {
......
...@@ -29,6 +29,8 @@ void* GetCublasDsoHandle();
void* GetCUDNNDsoHandle();
void* GetCUPTIDsoHandle();
void* GetCurandDsoHandle();
void* GetNVRTCDsoHandle();
void* GetCUDADsoHandle();
void* GetWarpCTCDsoHandle();
void* GetNCCLDsoHandle();
void* GetTensorRtDsoHandle();
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/dynload/nvrtc.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag nvrtc_dso_flag;
void* nvrtc_dso_handle = nullptr;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
NVRTC_ROUTINE_EACH(DEFINE_WRAP);
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <nvrtc.h>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
namespace paddle {
namespace platform {
namespace dynload {
extern std::once_flag nvrtc_dso_flag;
extern void* nvrtc_dso_handle;
#ifdef PADDLE_USE_DSO
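// Same lazy dlopen/dlsym pattern as cuda_driver.h, applied to libnvrtc.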
#define DECLARE_DYNAMIC_LOAD_NVRTC_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using nvrtc_func = decltype(&::__name); \
std::call_once(nvrtc_dso_flag, []() { \
nvrtc_dso_handle = paddle::platform::dynload::GetNVRTCDsoHandle(); \
}); \
static void* p_##__name = dlsym(nvrtc_dso_handle, #__name); \
return reinterpret_cast<nvrtc_func>(p_##__name)(args...); \
} \
}; \
extern struct DynLoad__##__name __name
#else
#define DECLARE_DYNAMIC_LOAD_NVRTC_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
inline auto operator()(Args... args) { \
return ::__name(args...); \
} \
}; \
extern DynLoad__##__name __name
#endif
/**
* include all needed nvrtc functions
**/
#define NVRTC_ROUTINE_EACH(__macro) \
__macro(nvrtcGetErrorString); \
__macro(nvrtcCompileProgram); \
__macro(nvrtcCreateProgram); \
__macro(nvrtcDestroyProgram); \
__macro(nvrtcGetPTX); \
__macro(nvrtcGetPTXSize); \
__macro(nvrtcGetProgramLog); \
__macro(nvrtcGetProgramLogSize)
NVRTC_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NVRTC_WRAP);
#undef DECLARE_DYNAMIC_LOAD_NVRTC_WRAP
} // namespace dynload
} // namespace platform
} // namespace paddle