diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt
index 69435793a75a203533806a567c718e0af4d2e20c..0b3b96e82e4c33e691ed7a2de417fc9265cc61e2 100644
--- a/paddle/fluid/platform/CMakeLists.txt
+++ b/paddle/fluid/platform/CMakeLists.txt
@@ -118,7 +118,14 @@ cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor)
 nv_library(cuda_device_guard SRCS cuda_device_guard.cc DEPS gpu_info)
 
 if(WITH_GPU)
-nv_test(temporal_allocator_test SRCS temporary_allocator_test.cc DEPS temp_allocator tensor operator)
+  nv_test(temporal_allocator_test SRCS temporary_allocator_test.cc DEPS temp_allocator tensor operator)
 else()
-cc_test(temporal_allocator_test SRCS temporary_allocator_test.cc DEPS temp_allocator tensor operator)
+  cc_test(temporal_allocator_test SRCS temporary_allocator_test.cc DEPS temp_allocator tensor operator)
+endif()
+
+if(NOT APPLE AND NOT WIN32)
+  cc_library(device_code SRCS device_code.cc DEPS device_context)
+  if(WITH_GPU)
+    cc_test(device_code_test SRCS device_code_test.cc DEPS device_code lod_tensor)
+  endif()
 endif()
diff --git a/paddle/fluid/platform/device_code.cc b/paddle/fluid/platform/device_code.cc
new file mode 100644
index 0000000000000000000000000000000000000000..24421b5c3c99bd341c562f4c35df55ad749bdc50
--- /dev/null
+++ b/paddle/fluid/platform/device_code.cc
@@ -0,0 +1,123 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/platform/device_code.h"
+#include <algorithm>
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace platform {
+
+#ifdef PADDLE_WITH_CUDA
+inline bool is_error(nvrtcResult stat) { return stat != NVRTC_SUCCESS; }
+
+inline void throw_on_error(nvrtcResult stat, const std::string& msg) {
+#ifndef REPLACE_ENFORCE_GLOG
+  throw std::runtime_error(dynload::nvrtcGetErrorString(stat) + msg);
+#else
+  LOG(FATAL) << dynload::nvrtcGetErrorString(stat) << msg;
+#endif
+}
+
+CUDADeviceCode::CUDADeviceCode(const Place& place, const std::string& name,
+                               const std::string& kernel) {
+  if (!is_gpu_place(place)) {
+    PADDLE_THROW("CUDADeviceCode can only launch on GPU place.");
+  }
+
+  place_ = place;
+  name_ = name;
+  kernel_ = kernel;
+}
+
+void CUDADeviceCode::Compile() {
+  nvrtcProgram program;
+  PADDLE_ENFORCE_EQ(dynload::nvrtcCreateProgram(&program,
+                                                kernel_.c_str(),  // buffer
+                                                name_.c_str(),    // name
+                                                0,         // numHeaders
+                                                nullptr,   // headers
+                                                nullptr),  // includeNames
+                    NVRTC_SUCCESS,
+                    "nvrtcCreateProgram failed.");
+
+  // Compile the program for the specified compute capability
+  auto* dev_ctx = reinterpret_cast<CUDADeviceContext*>(
+      DeviceContextPool::Instance().Get(place_));
+  int compute_capability = dev_ctx->GetComputeCapability();
+  std::string compute_flag =
+      "--gpu-architecture=compute_" + std::to_string(compute_capability);
+  const std::vector<const char*> options = {"--std=c++11",
+                                            compute_flag.c_str()};
+  nvrtcResult compile_result =
+      dynload::nvrtcCompileProgram(program,          // program
+                                   options.size(),   // numOptions
+                                   options.data());  // options
+  if (compile_result == NVRTC_ERROR_COMPILATION) {
+    // Obtain the compilation log from the program
+    size_t log_size;
+    PADDLE_ENFORCE_EQ(dynload::nvrtcGetProgramLogSize(program, &log_size),
+                      NVRTC_SUCCESS, "nvrtcGetProgramLogSize failed.");
+    std::vector<char> log;
+    log.resize(log_size + 1);
+    PADDLE_ENFORCE_EQ(dynload::nvrtcGetProgramLog(program, log.data()),
+                      NVRTC_SUCCESS, "nvrtcGetProgramLog failed.");
+    LOG(FATAL) << "JIT compiling of CUDA code failed:\n" << log.data();
+  }
+
+  // Obtain PTX from the program
+  size_t ptx_size;
+  PADDLE_ENFORCE_EQ(dynload::nvrtcGetPTXSize(program, &ptx_size),
+                    NVRTC_SUCCESS, "nvrtcGetPTXSize failed.");
+  ptx_.resize(ptx_size + 1);
+  PADDLE_ENFORCE_EQ(dynload::nvrtcGetPTX(program, ptx_.data()), NVRTC_SUCCESS,
+                    "nvrtcGetPTX failed.");
+
+  PADDLE_ENFORCE_EQ(dynload::nvrtcDestroyProgram(&program), NVRTC_SUCCESS,
+                    "nvrtcDestroyProgram failed.");
+
+  PADDLE_ENFORCE_EQ(
+      dynload::cuModuleLoadData(&module_, ptx_.data()), CUDA_SUCCESS,
+      "Fail to load PTX of %s (in cuModuleLoadData.)", name_.c_str());
+  PADDLE_ENFORCE_EQ(
+      dynload::cuModuleGetFunction(&function_, module_, name_.c_str()),
+      CUDA_SUCCESS, "Fail to get function of %s (in cuModuleGetFunction.)",
+      name_.c_str());
+
+  max_threads_ = dev_ctx->GetMaxPhysicalThreadCount();
+}
+
+void CUDADeviceCode::Launch(const size_t n, std::vector<void*>* args) const {
+  int max_blocks = std::max(max_threads_ / num_threads_, 1);
+  int workload_per_block = workload_per_thread_ * num_threads_;
+  int num_blocks =
+      std::min(max_blocks, (static_cast<int>(n) + workload_per_block - 1) /
+                               workload_per_block);
+
+  auto* dev_ctx = reinterpret_cast<CUDADeviceContext*>(
+      DeviceContextPool::Instance().Get(place_));
+  PADDLE_ENFORCE_EQ(
+      dynload::cuLaunchKernel(function_, num_blocks, 1, 1,  // grid dim
+                              num_threads_, 1, 1,           // block dim
+                              0,                  // shared memory
+                              dev_ctx->stream(),  // stream
+                              args->data(),       // arguments
+                              nullptr),
+      CUDA_SUCCESS, "Fail to launch kernel %s (in cuLaunchKernel.)",
+      name_.c_str());
+}
+#endif
+
+}  // namespace platform
+}  // namespace paddle
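A note on the launch configuration: Launch() caps the grid at max_threads_ / num_threads_ blocks and otherwise divides the workload evenly, rounding up; the grid-stride loop in the kernel (see the test below) covers whatever the cap leaves over. A minimal standalone sketch of the same arithmetic, with assumed numbers (the 81920-thread capacity is hypothetical; the real value comes from dev_ctx->GetMaxPhysicalThreadCount()):

    // Standalone sketch of CUDADeviceCode::Launch's grid-size computation.
    #include <algorithm>
    #include <cstdio>
    int main() {
      int max_threads = 81920;      // hypothetical device capacity
      int num_threads = 1024;       // default num_threads_
      int workload_per_thread = 1;  // default workload_per_thread_
      int n = 256 * 1024;           // element count used by device_code_test
      int max_blocks = std::max(max_threads / num_threads, 1);     // 80
      int workload_per_block = workload_per_thread * num_threads;  // 1024
      int num_blocks = std::min(
          max_blocks, (n + workload_per_block - 1) / workload_per_block);
      std::printf("num_blocks = %d\n", num_blocks);  // min(80, 256) -> 80
      return 0;
    }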
diff --git a/paddle/fluid/platform/device_code.h b/paddle/fluid/platform/device_code.h
new file mode 100644
index 0000000000000000000000000000000000000000..19adb0707f1742e9a41c4eaec549f7ccd5101acb
--- /dev/null
+++ b/paddle/fluid/platform/device_code.h
@@ -0,0 +1,64 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include "paddle/fluid/platform/device_context.h"
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/dynload/cuda_driver.h"
+#include "paddle/fluid/platform/dynload/nvrtc.h"
+#endif
+
+namespace paddle {
+namespace platform {
+
+class DeviceCode {
+ public:
+  virtual ~DeviceCode() {}
+  virtual void Compile() = 0;
+  virtual void Launch(const size_t n, std::vector<void*>* args) const = 0;
+
+ protected:
+  Place place_;
+  std::string name_;
+  std::string kernel_;
+};
+
+#ifdef PADDLE_WITH_CUDA
+class CUDADeviceCode : public DeviceCode {
+ public:
+  explicit CUDADeviceCode(const Place& place, const std::string& name,
+                          const std::string& kernel);
+  void Compile() override;
+  void Launch(const size_t n, std::vector<void*>* args) const override;
+
+  void SetNumThreads(int num_threads) { num_threads_ = num_threads; }
+  void SetWorkloadPerThread(int workload_per_thread) {
+    workload_per_thread_ = workload_per_thread;
+  }
+
+ private:
+  int max_threads_{0};
+  int num_threads_{1024};
+  int workload_per_thread_{1};
+  std::vector<char> ptx_;
+  CUmodule module_;
+  CUfunction function_;
+};
+#endif
+
+}  // namespace platform
+}  // namespace paddle
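The intended call pattern for the new class is compile once, launch many times. A hedged usage sketch (it assumes a CUDA build and device 0; kernel_src, the kernel name "my_kernel", and the arguments are placeholders, not part of this patch):

    #include <vector>
    #include "paddle/fluid/platform/device_code.h"
    // Sketch only: "my_kernel" must match the __global__ function's name,
    // since cuModuleGetFunction() looks the symbol up by that string.
    void SketchLaunch(const char* kernel_src, float* dev_ptr, size_t n) {
      paddle::platform::CUDAPlace place(0);
      paddle::platform::CUDADeviceCode code(place, "my_kernel", kernel_src);
      code.Compile();           // NVRTC -> PTX -> cuModuleLoadData
      code.SetNumThreads(256);  // optional; defaults to 1024
      // cuLaunchKernel takes a pointer to each argument value, scalars too.
      std::vector<void*> args = {&dev_ptr, &n};
      code.Launch(n, &args);
    }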
diff --git a/paddle/fluid/platform/device_code_test.cc b/paddle/fluid/platform/device_code_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3b63ed4e369c7c9ccecf8a6b7e2272973a44e266
--- /dev/null
+++ b/paddle/fluid/platform/device_code_test.cc
@@ -0,0 +1,78 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/platform/device_code.h"
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/platform/init.h"
+
+constexpr auto saxpy_code = R"(
+extern "C" __global__
+void saxpy_kernel(float a, float* x, float* y, float* z, size_t n) {
+  for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < n;
+       tid += blockDim.x * gridDim.x) {
+    z[tid] = a * x[tid] + y[tid];
+  }
+}
+)";
+
+#ifdef PADDLE_WITH_CUDA
+TEST(device_code, cuda) {
+  paddle::framework::InitDevices(false, {0});
+  paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0);
+  paddle::platform::CUDADeviceCode code(place, "saxpy_kernel", saxpy_code);
+
+  paddle::framework::Tensor cpu_x;
+  paddle::framework::Tensor cpu_y;
+  paddle::framework::Tensor cpu_z;
+
+  float scale = 2;
+  auto dims = paddle::framework::make_ddim(
+      {static_cast<int64_t>(256), static_cast<int64_t>(1024)});
+  cpu_x.mutable_data<float>(dims, paddle::platform::CPUPlace());
+  cpu_y.mutable_data<float>(dims, paddle::platform::CPUPlace());
+
+  size_t n = cpu_x.numel();
+  for (size_t i = 0; i < n; ++i) {
+    cpu_x.data<float>()[i] = static_cast<float>(i);
+  }
+  for (size_t i = 0; i < n; ++i) {
+    cpu_y.data<float>()[i] = static_cast<float>(0.5);
+  }
+
+  paddle::framework::Tensor x;
+  paddle::framework::Tensor y;
+  paddle::framework::Tensor z;
+
+  float* x_data = x.mutable_data<float>(dims, place);
+  float* y_data = y.mutable_data<float>(dims, place);
+  float* z_data = z.mutable_data<float>(dims, place);
+
+  TensorCopySync(cpu_x, place, &x);
+  TensorCopySync(cpu_y, place, &y);
+
+  code.Compile();
+
+  std::vector<void*> args = {&scale, &x_data, &y_data, &z_data, &n};
+  code.SetNumThreads(1024);
+  code.SetWorkloadPerThread(1);
+  code.Launch(n, &args);
+
+  TensorCopySync(z, paddle::platform::CPUPlace(), &cpu_z);
+  for (size_t i = 0; i < n; i++) {
+    PADDLE_ENFORCE_EQ(cpu_z.data<float>()[i],
+                      static_cast<float>(i) * scale + 0.5);
+  }
+}
+#endif
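One detail worth flagging in the test: the extern "C" in saxpy_code is load-bearing, because cuModuleGetFunction() searches the module for the literal name passed to the constructor, and without C linkage NVRTC emits a C++-mangled symbol. If mangled (for example templated) kernels were ever wanted, NVRTC's name-expression API is the usual route; a hypothetical sketch against the raw NVRTC headers, not something this patch does (inside Paddle the two calls would first need entries in NVRTC_ROUTINE_EACH, defined later in this patch):

    #include <nvrtc.h>
    void CompileMangled(nvrtcProgram program) {
      // Register the name expression before nvrtcCompileProgram():
      nvrtcAddNameExpression(program, "saxpy_kernel<float>");
      // ... compile the program here ...
      const char* lowered = nullptr;  // storage owned by the program
      nvrtcGetLoweredName(program, "saxpy_kernel<float>", &lowered);
      // pass `lowered` to cuModuleGetFunction() instead of the plain name
    }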
diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt
index 07159d4a12ef4b628f7705ed206d3334be46dfc8..81312111aebac24d7c2854f0b192269860af0db1 100644
--- a/paddle/fluid/platform/dynload/CMakeLists.txt
+++ b/paddle/fluid/platform/dynload/CMakeLists.txt
@@ -3,8 +3,9 @@ cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce)
 list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc)
 
 # There is no macOS version of NCCL.
+# Disable the nvrtc and cuda_driver APIs on macOS and Windows, and only do an early test on Linux.
 if (NOT APPLE AND NOT WIN32)
-  list(APPEND CUDA_SRCS nccl.cc)
+  list(APPEND CUDA_SRCS nccl.cc nvrtc.cc cuda_driver.cc)
 endif()
 
 if (TENSORRT_FOUND)
diff --git a/paddle/fluid/platform/dynload/cublas.h b/paddle/fluid/platform/dynload/cublas.h
index ced789b90d067218c3b01d124cfd2c93dc94e528..ed9b9133c6a0d7597d73a7090c41b4dc56062e24 100644
--- a/paddle/fluid/platform/dynload/cublas.h
+++ b/paddle/fluid/platform/dynload/cublas.h
@@ -1,16 +1,16 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
-   http://www.apache.org/licenses/LICENSE-2.0
+    http://www.apache.org/licenses/LICENSE-2.0
 
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
 
 #pragma once
diff --git a/paddle/fluid/platform/dynload/cuda_driver.cc b/paddle/fluid/platform/dynload/cuda_driver.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2c2edb2ccef9720f0b31b3734c3a775337b5e1ce
--- /dev/null
+++ b/paddle/fluid/platform/dynload/cuda_driver.cc
@@ -0,0 +1,30 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/platform/dynload/cuda_driver.h"
+
+namespace paddle {
+namespace platform {
+namespace dynload {
+
+std::once_flag cuda_dso_flag;
+void* cuda_dso_handle = nullptr;
+
+#define DEFINE_WRAP(__name) DynLoad__##__name __name
+
+CUDA_ROUTINE_EACH(DEFINE_WRAP);
+
+}  // namespace dynload
+}  // namespace platform
+}  // namespace paddle
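cuda_driver.cc only instantiates the wrapper objects that the matching header (below) declares extern; for one entry the CUDA_ROUTINE_EACH(DEFINE_WRAP) expansion is simply:

    DynLoad__cuLaunchKernel cuLaunchKernel;

so paddle::platform::dynload::cuLaunchKernel is an ordinary global whose operator() performs the lazy dlopen/dlsym on first use.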
diff --git a/paddle/fluid/platform/dynload/cuda_driver.h b/paddle/fluid/platform/dynload/cuda_driver.h
new file mode 100644
index 0000000000000000000000000000000000000000..894797728bb1c3794082bc0ba3094a6748c5a0c4
--- /dev/null
+++ b/paddle/fluid/platform/dynload/cuda_driver.h
@@ -0,0 +1,79 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <cuda.h>
+#include <mutex>  // NOLINT
+#include "paddle/fluid/platform/dynload/dynamic_loader.h"
+#include "paddle/fluid/platform/port.h"
+
+namespace paddle {
+namespace platform {
+namespace dynload {
+
+extern std::once_flag cuda_dso_flag;
+extern void* cuda_dso_handle;
+
+#ifdef PADDLE_USE_DSO
+
+#define DECLARE_DYNAMIC_LOAD_CUDA_WRAP(__name)                            \
+  struct DynLoad__##__name {                                              \
+    template <typename... Args>                                           \
+    auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) {      \
+      using cuda_func = decltype(&::__name);                              \
+      std::call_once(cuda_dso_flag, []() {                                \
+        cuda_dso_handle = paddle::platform::dynload::GetCUDADsoHandle();  \
+      });                                                                 \
+      static void* p_##__name = dlsym(cuda_dso_handle, #__name);          \
+      return reinterpret_cast<cuda_func>(p_##__name)(args...);            \
+    }                                                                     \
+  };                                                                      \
+  extern struct DynLoad__##__name __name
+
+#else
+
+#define DECLARE_DYNAMIC_LOAD_CUDA_WRAP(__name) \
+  struct DynLoad__##__name {                   \
+    template <typename... Args>                \
+    inline auto operator()(Args... args) {     \
+      return ::__name(args...);                \
+    }                                          \
+  };                                           \
+  extern DynLoad__##__name __name
+
+#endif
+
+/**
+ * include all needed cuda driver functions
+ **/
+#define CUDA_ROUTINE_EACH(__macro)                      \
+  __macro(cuGetErrorString);                            \
+  __macro(cuModuleLoadData);                            \
+  __macro(cuModuleGetFunction);                         \
+  __macro(cuModuleUnload);                              \
+  __macro(cuOccupancyMaxActiveBlocksPerMultiprocessor); \
+  __macro(cuLaunchKernel);                              \
+  __macro(cuCtxCreate);                                 \
+  __macro(cuCtxGetCurrent);                             \
+  __macro(cuDeviceGet);                                 \
+  __macro(cuDevicePrimaryCtxGetState)
+
+CUDA_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDA_WRAP);
+
+#undef DECLARE_DYNAMIC_LOAD_CUDA_WRAP
+
+}  // namespace dynload
+}  // namespace platform
+}  // namespace paddle
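For reference, in the PADDLE_USE_DSO branch DECLARE_DYNAMIC_LOAD_CUDA_WRAP(cuLaunchKernel) expands to roughly the following (reformatted, with the line-continuation backslashes dropped):

    struct DynLoad__cuLaunchKernel {
      template <typename... Args>
      auto operator()(Args... args) -> DECLARE_TYPE(cuLaunchKernel, args...) {
        using cuda_func = decltype(&::cuLaunchKernel);
        std::call_once(cuda_dso_flag, []() {
          cuda_dso_handle = paddle::platform::dynload::GetCUDADsoHandle();
        });
        // Resolved once, then cached in a function-local static.
        static void* p_cuLaunchKernel =
            dlsym(cuda_dso_handle, "cuLaunchKernel");
        return reinterpret_cast<cuda_func>(p_cuLaunchKernel)(args...);
      }
    };
    extern struct DynLoad__cuLaunchKernel cuLaunchKernel;

The non-DSO branch simply forwards to the statically linked ::cuLaunchKernel.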
diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc
index c2d6abb346c7b906d769a9fed756b88441a76afc..4a1cd5a8db7fa9b8f8fdd9427c7a26e5c90cc95f 100644
--- a/paddle/fluid/platform/dynload/dynamic_loader.cc
+++ b/paddle/fluid/platform/dynload/dynamic_loader.cc
@@ -222,6 +222,22 @@ void* GetCurandDsoHandle() {
 #endif
 }
 
+void* GetNVRTCDsoHandle() {
+#if defined(__APPLE__) || defined(__OSX__)
+  return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvrtc.dylib");
+#else
+  return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvrtc.so");
+#endif
+}
+
+void* GetCUDADsoHandle() {
+#if defined(__APPLE__) || defined(__OSX__)
+  return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcuda.dylib");
+#else
+  return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcuda.so");
+#endif
+}
+
 void* GetWarpCTCDsoHandle() {
   std::string warpctc_dir = "";
   if (!s_py_site_pkg_path.path.empty()) {
diff --git a/paddle/fluid/platform/dynload/dynamic_loader.h b/paddle/fluid/platform/dynload/dynamic_loader.h
index d8bc884ee0987c24c6c6b08229c40b950c546a29..df101474aa4e158198baf92ca389d23239ba6f47 100644
--- a/paddle/fluid/platform/dynload/dynamic_loader.h
+++ b/paddle/fluid/platform/dynload/dynamic_loader.h
@@ -29,6 +29,8 @@ void* GetCublasDsoHandle();
 void* GetCUDNNDsoHandle();
 void* GetCUPTIDsoHandle();
 void* GetCurandDsoHandle();
+void* GetNVRTCDsoHandle();
+void* GetCUDADsoHandle();
 void* GetWarpCTCDsoHandle();
 void* GetNCCLDsoHandle();
 void* GetTensorRtDsoHandle();
diff --git a/paddle/fluid/platform/dynload/nvrtc.cc b/paddle/fluid/platform/dynload/nvrtc.cc
new file mode 100644
index 0000000000000000000000000000000000000000..793b5b8d149daa89d7a570e7d7519a3e9aebf584
--- /dev/null
+++ b/paddle/fluid/platform/dynload/nvrtc.cc
@@ -0,0 +1,30 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/platform/dynload/nvrtc.h"
+
+namespace paddle {
+namespace platform {
+namespace dynload {
+
+std::once_flag nvrtc_dso_flag;
+void* nvrtc_dso_handle = nullptr;
+
+#define DEFINE_WRAP(__name) DynLoad__##__name __name
+
+NVRTC_ROUTINE_EACH(DEFINE_WRAP);
+
+}  // namespace dynload
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/fluid/platform/dynload/nvrtc.h b/paddle/fluid/platform/dynload/nvrtc.h
new file mode 100644
index 0000000000000000000000000000000000000000..20647affabc807ed5a570f09daa241e4389007e4
--- /dev/null
+++ b/paddle/fluid/platform/dynload/nvrtc.h
@@ -0,0 +1,77 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <nvrtc.h>
+#include <mutex>  // NOLINT
+#include "paddle/fluid/platform/dynload/dynamic_loader.h"
+#include "paddle/fluid/platform/port.h"
+
+namespace paddle {
+namespace platform {
+namespace dynload {
+
+extern std::once_flag nvrtc_dso_flag;
+extern void* nvrtc_dso_handle;
+
+#ifdef PADDLE_USE_DSO
+
+#define DECLARE_DYNAMIC_LOAD_NVRTC_WRAP(__name)                             \
+  struct DynLoad__##__name {                                                \
+    template <typename... Args>                                             \
+    auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) {        \
+      using nvrtc_func = decltype(&::__name);                               \
+      std::call_once(nvrtc_dso_flag, []() {                                 \
+        nvrtc_dso_handle = paddle::platform::dynload::GetNVRTCDsoHandle();  \
+      });                                                                   \
+      static void* p_##__name = dlsym(nvrtc_dso_handle, #__name);           \
+      return reinterpret_cast<nvrtc_func>(p_##__name)(args...);             \
+    }                                                                       \
+  };                                                                        \
+  extern struct DynLoad__##__name __name
+
+#else
+
+#define DECLARE_DYNAMIC_LOAD_NVRTC_WRAP(__name) \
+  struct DynLoad__##__name {                    \
+    template <typename... Args>                 \
+    inline auto operator()(Args... args) {      \
+      return ::__name(args...);                 \
+    }                                           \
+  };                                            \
+  extern DynLoad__##__name __name
+
+#endif
+
+/**
+ * include all needed nvrtc functions
+ **/
+#define NVRTC_ROUTINE_EACH(__macro) \
+  __macro(nvrtcGetErrorString);     \
+  __macro(nvrtcCompileProgram);     \
+  __macro(nvrtcCreateProgram);      \
+  __macro(nvrtcDestroyProgram);     \
+  __macro(nvrtcGetPTX);             \
+  __macro(nvrtcGetPTXSize);         \
+  __macro(nvrtcGetProgramLog);      \
+  __macro(nvrtcGetProgramLogSize)
+
+NVRTC_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NVRTC_WRAP);
+
+#undef DECLARE_DYNAMIC_LOAD_NVRTC_WRAP
+
+}  // namespace dynload
+}  // namespace platform
+}  // namespace paddle
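With both wrappers in place, call sites read like direct NVRTC or driver-API calls, just namespaced under dynload. A small sketch in the style of device_code.cc (an illustration, not part of the patch; NvrtcErrorMessage is a hypothetical helper):

    #include <string>
    #include "paddle/fluid/platform/dynload/nvrtc.h"
    // Turn an nvrtcResult into a readable message via the dynamically
    // loaded entry point declared above.
    std::string NvrtcErrorMessage(nvrtcResult stat) {
      return std::string("NVRTC error: ") +
             paddle::platform::dynload::nvrtcGetErrorString(stat);
    }

Because the symbols are resolved lazily through dlsym in the PADDLE_USE_DSO configuration, libnvrtc.so and libcuda.so do not have to be present at link time, which is what lets these sources build on machines without the corresponding shared libraries installed.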