From 1280f2947d7920676267df3f3ed5354d05edfdae Mon Sep 17 00:00:00 2001 From: Wilber Date: Fri, 13 May 2022 17:05:20 +0800 Subject: [PATCH] add gpu resources. (#42723) --- paddle/fluid/inference/api/CMakeLists.txt | 4 +- paddle/fluid/inference/api/infer_context.cc | 17 + paddle/fluid/inference/api/infer_context.h | 46 +++ .../fluid/inference/api/resource_manager.cc | 290 +++++++++++++++ paddle/fluid/inference/api/resource_manager.h | 109 ++++++ paddle/phi/backends/gpu/CMakeLists.txt | 3 +- paddle/phi/backends/gpu/gpu_context.cc | 343 +++++------------- paddle/phi/backends/gpu/gpu_context.h | 4 + paddle/phi/backends/gpu/gpu_resources.cc | 271 ++++++++++++++ paddle/phi/backends/gpu/gpu_resources.h | 51 +++ 10 files changed, 876 insertions(+), 262 deletions(-) create mode 100644 paddle/fluid/inference/api/infer_context.cc create mode 100644 paddle/fluid/inference/api/infer_context.h create mode 100644 paddle/fluid/inference/api/resource_manager.cc create mode 100644 paddle/fluid/inference/api/resource_manager.h create mode 100644 paddle/phi/backends/gpu/gpu_resources.cc create mode 100644 paddle/phi/backends/gpu/gpu_resources.h diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index edec1b1c7d0..56cc4aa755b 100755 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -50,10 +50,10 @@ if(WITH_GPU AND TENSORRT_FOUND) endif() if (WITH_ONNXRUNTIME) - cc_library(analysis_predictor SRCS analysis_predictor.cc onnxruntime_predictor.cc ${mkldnn_quantizer_src} DEPS ${inference_deps} + cc_library(analysis_predictor SRCS analysis_predictor.cc onnxruntime_predictor.cc resource_manager.cc infer_context.cc ${mkldnn_quantizer_src} DEPS ${inference_deps} zero_copy_tensor ir_pass_manager op_compatible_info infer_io_utils onnxruntime paddle2onnx) else (WITH_ONNXRUNTIME) - cc_library(analysis_predictor SRCS analysis_predictor.cc ${mkldnn_quantizer_src} DEPS ${inference_deps} + cc_library(analysis_predictor SRCS analysis_predictor.cc resource_manager.cc infer_context.cc ${mkldnn_quantizer_src} DEPS ${inference_deps} zero_copy_tensor ir_pass_manager op_compatible_info infer_io_utils) endif (WITH_ONNXRUNTIME) diff --git a/paddle/fluid/inference/api/infer_context.cc b/paddle/fluid/inference/api/infer_context.cc new file mode 100644 index 00000000000..7706f2d0824 --- /dev/null +++ b/paddle/fluid/inference/api/infer_context.cc @@ -0,0 +1,17 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/api/infer_context.h" + +namespace paddle {} // namespace paddle diff --git a/paddle/fluid/inference/api/infer_context.h b/paddle/fluid/inference/api/infer_context.h new file mode 100644 index 00000000000..b7a8bf637d8 --- /dev/null +++ b/paddle/fluid/inference/api/infer_context.h @@ -0,0 +1,46 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/backends/all_context.h" + +namespace paddle { + +class InferCPUContext : public phi::CPUContext { + public: + using phi::CPUContext::SetEigenDevice; +}; + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +class InferGPUContext : public phi::GPUContext { + public: + using phi::GPUContext::SetStream; + using phi::GPUContext::SetEigenDevice; + using phi::GPUContext::SetBlasHandle; + using phi::GPUContext::SetBlasTensorCoreHandle; + using phi::GPUContext::SetBlasTF32Handle; + using phi::GPUContext::SetDnnHandle; + using phi::GPUContext::SetSolverHandle; + using phi::GPUContext::SetSparseHandle; + // using phi::GPUContext::SetDnnWorkspaceHandle; + using phi::GPUContext::SetComputeCapability; + using phi::GPUContext::SetMaxThreadsPerMultiProcessor; + using phi::GPUContext::SetMultiProcessors; + using phi::GPUContext::SetMaxThreadsPerBlock; + using phi::GPUContext::SetMaxGridDimSize; + using phi::GPUContext::SetDriverVersion; + using phi::GPUContext::SetRuntimeVersion; +}; +#endif +} // namespace paddle diff --git a/paddle/fluid/inference/api/resource_manager.cc b/paddle/fluid/inference/api/resource_manager.cc new file mode 100644 index 00000000000..d88f282ce7a --- /dev/null +++ b/paddle/fluid/inference/api/resource_manager.cc @@ -0,0 +1,290 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/api/resource_manager.h" + +#include + +#include "paddle/fluid/memory/allocation/allocator_facade.h" +#include "paddle/phi/backends/gpu/forwards.h" +#include "paddle/phi/backends/gpu/gpu_decls.h" +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/backends/gpu/gpu_resources.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/allocator.h" +#include "paddle/phi/core/generator.h" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace paddle { +namespace internal { + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +class EigenGpuStreamDevice : public Eigen::StreamInterface { + public: + EigenGpuStreamDevice() : scratch_(nullptr), semaphore_(nullptr) { + Eigen::initializeDeviceProp(); + } + ~EigenGpuStreamDevice() override {} + + void Reinitialize(gpuStream_t cuda_stream, phi::Allocator* allocator, + GPUPlace place) { + stream_ = cuda_stream; + allocator_ = allocator; + device_prop_ = &Eigen::m_deviceProperties[place.device]; + } + + const gpuStream_t& stream() const override { return stream_; } + + const gpuDeviceProp& deviceProperties() const override { + return *device_prop_; + } + + void* allocate(size_t num_bytes) const override { + if (UNLIKELY(num_bytes == 0)) { + return nullptr; + } + auto buf = allocator_->Allocate(num_bytes); + VLOG(4) << "Eigen allocated at " << buf->ptr() << " requested " + << num_bytes; + void* retv = buf->ptr(); + { + std::lock_guard lock(mtx_); + allocations_.emplace(retv, std::move(buf)); + } + return retv; + } + + void deallocate(void* buffer) const override { + if (LIKELY(buffer)) { + std::lock_guard lock(mtx_); + allocations_.erase(buffer); + } + } + + void* scratchpad() const override { + if (scratch_ == NULL) { + scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int)); + } + return scratch_; + } + + unsigned int* semaphore() const override { + if (semaphore_ == NULL) { + char* scratch = static_cast(scratchpad()) + Eigen::kGpuScratchSize; + semaphore_ = reinterpret_cast(scratch); +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS( + hipMemsetAsync(semaphore_, 0, sizeof(unsigned int), stream_)); +#else + PADDLE_ENFORCE_GPU_SUCCESS( + cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), stream_)); +#endif + } + return semaphore_; + } + + private: + gpuStream_t stream_; // not owned; + phi::Allocator* allocator_; // not owned; + const gpuDeviceProp* device_prop_; // not owned; + mutable void* scratch_; + mutable unsigned int* semaphore_; + mutable std::mutex mtx_; // to protect allocations_ + mutable std::unordered_map allocations_; +}; +#endif +} // namespace internal + +ResourceManager::ResourceManager(const phi::Place& place, void* stream) + : place_(place) { + InitCPUResource(); + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + InitGPUResource(stream); +#endif +} + +ResourceManager::~ResourceManager() { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + DestroyGPUResource(); +#endif +} + +void ResourceManager::InitCPUResource() { + cpu_eigen_device_.reset(new Eigen::DefaultDevice()); +} + +Eigen::DefaultDevice* ResourceManager::GetCpuEigenDevice() { + return cpu_eigen_device_.get(); +} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +void ResourceManager::InitGPUResource(void* stream) { + if (stream == nullptr) { + owned_stream_ = true; + phi::InitStream(&stream_); + } else { + owned_stream_ = false; + stream_ = reinterpret_cast(stream); + } + + InitGpuProperties(); + InitGpuEigenDevice(); + InitDnnHanlde(); + InitBlasHandle(); + InitBlasLtHandle(); + InitSolverHandle(); + InitSparseHandle(); +} + +void ResourceManager::DestroyGPUResource() { + if (owned_stream_) { +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream_)); +#else + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream_)); +#endif + stream_ = nullptr; + } + + DestroyDnnHandle(); + DestroyBlasHandle(); + DestroyBlasLtHandle(); + DestroySolverHandle(); + DestroySparseHandle(); +} + +void ResourceManager::InitGpuProperties() { + phi::backends::gpu::GPUDeviceGuard guard(place_.device); + phi::InitGpuProperties(place_, &compute_capability_, &runtime_version_, + &driver_version_, &multi_process_, + &max_threads_per_mp_, &max_threads_per_block_, + &max_grid_dim_size_); +} + +void ResourceManager::InitGpuEigenDevice() { + auto* allocator = paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(place_) + .get(); + eigen_stream_.reset(new internal::EigenGpuStreamDevice()); + eigen_stream_->Reinitialize(stream_, allocator, place_); + gpu_eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); +} + +void ResourceManager::InitDnnHanlde() { + phi::InitDnnHandle(&dnn_handle_, stream_, place_); +} + +void ResourceManager::DestroyDnnHandle() { phi::DestroyDnnHandle(dnn_handle_); } + +void ResourceManager::InitBlasHandle() { + phi::InitBlasHandle(&blas_handle_, stream_); +#ifdef PADDLE_WITH_CUDA +#if CUDA_VERSION >= 9000 + phi::InitBlasHandle(&blas_tensor_core_handle_, stream_); + PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( + blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH)); +#endif +#if CUDA_VERSION >= 11000 + phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_); + PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( + blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH)); +#endif +#endif +} + +void ResourceManager::DestroyBlasHandle() { + phi::DestroyBlasHandle(blas_handle_); + phi::DestroyBlasHandle(blas_tensor_core_handle_); + phi::DestroyBlasHandle(blas_tf32_tensor_core_handle_); +} + +void ResourceManager::InitBlasLtHandle() { + phi::InitBlasLtHandle(&blaslt_handle_); +} + +void ResourceManager::DestroyBlasLtHandle() { + phi::DestroyBlasLtHandle(blaslt_handle_); +} + +void ResourceManager::InitSolverHandle() { + phi::InitSolverHandle(&solver_handle_, stream_); +} + +void ResourceManager::DestroySolverHandle() { + phi::DestroySolverHandle(solver_handle_); +} + +void ResourceManager::InitSparseHandle() { + phi::InitSparseHandle(&sparse_handle_, stream_); +} + +void ResourceManager::DestroySparseHandle() { + phi::DestroySparseHandle(sparse_handle_); +} + +gpuStream_t ResourceManager::GetStream() const { return stream_; } + +dnnHandle_t ResourceManager::GetDnnHandle() const { return dnn_handle_; } + +blasHandle_t ResourceManager::GetBlasHandle() const { return blas_handle_; } + +blasHandle_t ResourceManager::GetBlasTensorCoreHandle() const { + return blas_tensor_core_handle_; +} + +blasHandle_t ResourceManager::GetBlasTF32Handle() const { + return blas_tf32_tensor_core_handle_; +} + +blasLtHandle_t ResourceManager::GetBlasLtHandle() const { + return blaslt_handle_; +} + +phi::solverHandle_t ResourceManager::GetSolverDnHandle() const { + return solver_handle_; +} + +phi::sparseHandle_t ResourceManager::GetSparseHandle() const { + return sparse_handle_; +} + +Eigen::GpuDevice* ResourceManager::GetGpuEigenDevice() const { + return gpu_eigen_device_.get(); +} + +int ResourceManager::GetGpuComputeCapability() const { + return compute_capability_; +} + +int ResourceManager::GetGpuRuntimeVersion() const { return runtime_version_; } + +int ResourceManager::GetGpuDriverVersion() const { return driver_version_; } + +int ResourceManager::GetGPUMultiProcessors() const { return multi_process_; } + +int ResourceManager::GetGpuMaxThreadsPerMp() const { + return max_threads_per_mp_; +} + +int ResourceManager::GetGpuMaxThreadsPerBlock() const { + return max_threads_per_block_; +} + +std::array ResourceManager::GetGpuMaxGridDimSize() const { + return max_grid_dim_size_; +} + +#endif +} // namespace paddle diff --git a/paddle/fluid/inference/api/resource_manager.h b/paddle/fluid/inference/api/resource_manager.h new file mode 100644 index 00000000000..c41968dc585 --- /dev/null +++ b/paddle/fluid/inference/api/resource_manager.h @@ -0,0 +1,109 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/backends/cpu/forwards.h" + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#include "paddle/fluid/platform/device/gpu/gpu_types.h" +#include "paddle/phi/backends/gpu/forwards.h" +#include "paddle/phi/backends/gpu/gpu_decls.h" +#include "paddle/phi/backends/gpu/gpu_resources.h" +#endif + +namespace paddle { +namespace internal { +class EigenGpuStreamDevice; +} // namespace internal + +class ResourceManager { + public: + explicit ResourceManager(const phi::Place& place, void* stream); + ~ResourceManager(); + + public: + Eigen::DefaultDevice* GetCpuEigenDevice(); + + private: + void InitCPUResource(); + + private: + phi::Place place_; + std::unique_ptr cpu_eigen_device_; + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + + public: + gpuStream_t GetStream() const; + dnnHandle_t GetDnnHandle() const; + blasHandle_t GetBlasHandle() const; + blasHandle_t GetBlasTensorCoreHandle() const; + blasHandle_t GetBlasTF32Handle() const; + blasLtHandle_t GetBlasLtHandle() const; + phi::solverHandle_t GetSolverDnHandle() const; + phi::sparseHandle_t GetSparseHandle() const; + Eigen::GpuDevice* GetGpuEigenDevice() const; + int GetGpuComputeCapability() const; + int GetGpuRuntimeVersion() const; + int GetGpuDriverVersion() const; + int GetGPUMultiProcessors() const; + int GetGpuMaxThreadsPerMp() const; + int GetGpuMaxThreadsPerBlock() const; + std::array GetGpuMaxGridDimSize() const; + + private: + void InitGPUResource(void* stream); + void DestroyGPUResource(); + void InitGpuProperties(); + void InitGpuEigenDevice(); + void InitDnnHanlde(); + void DestroyDnnHandle(); + void InitBlasHandle(); + void DestroyBlasHandle(); + void InitBlasLtHandle(); + void DestroyBlasLtHandle(); + void InitSolverHandle(); + void DestroySolverHandle(); + void InitSparseHandle(); + void DestroySparseHandle(); + + private: + int compute_capability_; + int runtime_version_; + int driver_version_; + int multi_process_; + int max_threads_per_mp_; + int max_threads_per_block_; + std::array max_grid_dim_size_; + + bool owned_stream_{true}; + gpuStream_t stream_; + std::unique_ptr gpu_eigen_device_; + std::unique_ptr eigen_stream_; + + blasHandle_t blas_handle_{nullptr}; + blasHandle_t blas_tensor_core_handle_{nullptr}; + blasHandle_t blas_tf32_tensor_core_handle_{nullptr}; + blasLtHandle_t blaslt_handle_{nullptr}; + dnnHandle_t dnn_handle_{nullptr}; + phi::solverHandle_t solver_handle_{nullptr}; + phi::sparseHandle_t sparse_handle_{nullptr}; +// DnnWorkspaceHandle +#endif +}; + +} // namespace paddle diff --git a/paddle/phi/backends/gpu/CMakeLists.txt b/paddle/phi/backends/gpu/CMakeLists.txt index d14e94024f9..ebe8f1ca4c1 100644 --- a/paddle/phi/backends/gpu/CMakeLists.txt +++ b/paddle/phi/backends/gpu/CMakeLists.txt @@ -6,4 +6,5 @@ elseif(WITH_ROCM) hip_library(phi_gpu_info SRCS gpu_info.cc DEPS phi_rocm_info gflags glog enforce phi_dynload_cuda) endif() -cc_library(gpu_context SRCS gpu_context.cc DEPS phi_device_context phi_gpu_info eigen3) +cc_library(gpu_resources SRCS gpu_resources.cc DEPS phi_device_context phi_gpu_info) +cc_library(gpu_context SRCS gpu_context.cc DEPS phi_device_context phi_gpu_info eigen3 gpu_resources) diff --git a/paddle/phi/backends/gpu/gpu_context.cc b/paddle/phi/backends/gpu/gpu_context.cc index ff238b79978..e5d34376834 100644 --- a/paddle/phi/backends/gpu/gpu_context.cc +++ b/paddle/phi/backends/gpu/gpu_context.cc @@ -25,6 +25,7 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_decls.h" #include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/backends/gpu/gpu_resources.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/allocator.h" @@ -202,27 +203,65 @@ struct GPUContext::Impl { void Init() { owned_ = true; backends::gpu::GPUDeviceGuard guard(place_.device); - InitGpuProperties(); - InitStream(); + phi::InitGpuProperties(place_, + &compute_capability_, + &runtime_version_, + &driver_version_, + &multi_process_, + &max_threads_per_mp_, + &max_threads_per_block_, + &max_grid_dim_size_); + phi::InitStream(&stream_); InitEigenDevice(); - InitBlasHandle(); - InitBlasLtHandle(); - InitDNNHandle(); - InitSolverHandle(); - InitSparseHandle(); + phi::InitBlasHandle(&blas_handle_, stream_); +#ifdef PADDLE_WITH_CUDA +#if CUDA_VERSION >= 9000 + phi::InitBlasHandle(&blas_tensor_core_handle_, stream_); + PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( + blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH)); +#endif +#if CUDA_VERSION >= 11000 + phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_); + PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( + blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH)); +#endif +#endif + phi::InitBlasLtHandle(&blaslt_handle_); + phi::InitDnnHandle(&dnn_handle_, stream_, place_); + phi::InitSolverHandle(&solver_handle_, stream_); + phi::InitSparseHandle(&sparse_handle_, stream_); InitDnnWorkspace(); } void PartialInitWithoutAllocator() { owned_ = true; backends::gpu::GPUDeviceGuard guard(place_.device); - InitGpuProperties(); - InitStream(); - InitBlasHandle(); - InitBlasLtHandle(); - InitDNNHandle(); - InitSolverHandle(); - InitSparseHandle(); + phi::InitGpuProperties(place_, + &compute_capability_, + &runtime_version_, + &driver_version_, + &multi_process_, + &max_threads_per_mp_, + &max_threads_per_block_, + &max_grid_dim_size_); + phi::InitStream(&stream_); + phi::InitBlasHandle(&blas_handle_, stream_); +#ifdef PADDLE_WITH_CUDA +#if CUDA_VERSION >= 9000 + phi::InitBlasHandle(&blas_tensor_core_handle_, stream_); + PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( + blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH)); +#endif +#if CUDA_VERSION >= 11000 + phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_); + PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( + blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH)); +#endif +#endif + phi::InitBlasLtHandle(&blaslt_handle_); + phi::InitDnnHandle(&dnn_handle_, stream_, place_); + phi::InitSolverHandle(&solver_handle_, stream_); + phi::InitSparseHandle(&sparse_handle_, stream_); } void PartialInitWithAllocator() { @@ -238,19 +277,23 @@ struct GPUContext::Impl { ~Impl() { backends::gpu::GPUDeviceGuard guard(place_.device); - DestoryInternalWorkspace(); - DestoryInternalEigenDevice(); - DestroyInternalSparseHandle(); - DestroyInternalSolverHandle(); - DestroyInternalDnnHandle(); + if (owned_) { + DestoryInternalWorkspace(); + DestoryInternalEigenDevice(); + phi::DestroySparseHandle(sparse_handle_); + phi::DestroySolverHandle(solver_handle_); + phi::DestroyDnnHandle(dnn_handle_); #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - if (nccl_comm_) { - PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclCommDestroy(nccl_comm_)); - } + if (nccl_comm_) { + PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclCommDestroy(nccl_comm_)); + } #endif - DestroyInternalBlasHandle(); - DestroyInternalBlasLtHandle(); - DestoryInternalStream(); + phi::DestroyBlasHandle(blas_handle_); + phi::DestroyBlasHandle(blas_tensor_core_handle_); + phi::DestroyBlasHandle(blas_tf32_tensor_core_handle_); + phi::DestroyBlasLtHandle(blaslt_handle_); + phi::DestoryStream(stream_); + } } const Place& GetPlace() const { return place_; } @@ -259,73 +302,6 @@ struct GPUContext::Impl { return blas_tensor_core_handle_ != nullptr; } - void InitGpuProperties() { - backends::gpu::GPUDeviceGuard guard(place_.GetDeviceId()); - compute_capability_ = - backends::gpu::GetGPUComputeCapability(place_.GetDeviceId()); - multi_process_ = backends::gpu::GetGPUMultiProcessors(place_.GetDeviceId()); - max_threads_per_mp_ = - backends::gpu::GetGPUMaxThreadsPerMultiProcessor(place_.GetDeviceId()); - max_grid_dim_size_ = - backends::gpu::GetGpuMaxGridDimSize(place_.GetDeviceId()); - max_threads_per_block_ = - backends::gpu::GetGPUMaxThreadsPerBlock(place_.GetDeviceId()); - driver_version_ = backends::gpu::GetGPUDriverVersion(place_.GetDeviceId()); - runtime_version_ = - backends::gpu::GetGPURuntimeVersion(place_.GetDeviceId()); - - // TODO(wilber): glog may be replaced in the future? - LOG_FIRST_N(WARNING, 1) - << "Please NOTE: device: " << static_cast(place_.device) - << ", GPU Compute Capability: " << compute_capability_ / 10 << "." - << compute_capability_ % 10 - << ", Driver API Version: " << driver_version_ / 1000 << "." - << (driver_version_ % 100) / 10 - << ", Runtime API Version: " << runtime_version_ / 1000 << "." - << (runtime_version_ % 100) / 10; -#ifdef PADDLE_WITH_HIP - size_t miopen_major, miopen_minor, miopen_patch; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::miopenGetVersion(&miopen_major, &miopen_minor, &miopen_patch)); - auto cudnn_dso_ver = - (miopen_major * 1000 + miopen_minor * 10 + miopen_patch) / 10; - auto compile_miopen_version = MIOPEN_VERSION / 10; - if (cudnn_dso_ver < static_cast(compile_miopen_version)) { - LOG_FIRST_N(WARNING, 1) - << "WARNING: device: " << static_cast(place_.device) - << ". The installed Paddle is compiled with MIOPEN " - << compile_miopen_version / 100 << "." << compile_miopen_version % 100 - << ", but MIOPEN version in your machine is " << cudnn_dso_ver / 100 - << "." << cudnn_dso_ver % 100 - << ", which may cause serious incompatible bug. " - << "Please recompile or reinstall Paddle with compatible MIOPEN " - "version."; - } -#else - size_t cudnn_dso_ver = dynload::cudnnGetVersion(); - LOG_FIRST_N(WARNING, 1) << "device: " << static_cast(place_.device) - << ", cuDNN Version: " << cudnn_dso_ver / 1000 - << "." << (cudnn_dso_ver % 1000) / 100 << "."; - - // Check CUDA/CUDNN version compatiblity - auto local_cuda_version = - (driver_version_ / 1000) * 10 + (driver_version_ % 100) / 10; - auto compile_cuda_version = - (CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10; - if (local_cuda_version < compile_cuda_version) { - LOG_FIRST_N(WARNING, 1) - << "WARNING: device: " << static_cast(place_.device) - << ". The installed Paddle is compiled with CUDA " - << compile_cuda_version / 10 << "." << compile_cuda_version % 10 - << ", but CUDA runtime version in your machine is " - << local_cuda_version / 10 << "." << local_cuda_version % 10 - << ", which may cause serious incompatible bug. " - << "Please recompile or reinstall Paddle with compatible CUDA " - "version."; - } -#endif - } - void InitDnnWorkspace() { PD_CHECK(allocator_ != nullptr, "the device allocator for gpu context is nullptr."); @@ -350,27 +326,6 @@ struct GPUContext::Impl { return DnnWorkspaceHandle(allocator_, stream_); } - void InitStream() { -#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS( - hipStreamCreateWithPriority(&stream_, hipStreamDefault, 0)); -#else - PADDLE_ENFORCE_GPU_SUCCESS( - cudaStreamCreateWithPriority(&stream_, cudaStreamDefault, 0)); -#endif - } - - void DestoryInternalStream() { - if (owned_ && stream_ != nullptr) { -#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream_)); -#else - PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream_)); -#endif - } - stream_ = nullptr; - } - void SetStream(gpuStream_t stream) { stream_ = stream; } gpuStream_t GetStream() const { @@ -400,55 +355,6 @@ struct GPUContext::Impl { return eigen_device_; } - void InitBlasHandle() { -#ifdef PADDLE_WITH_HIP - phi::dynload::rocblas_create_handle(&blas_handle_); - phi::dynload::rocblas_set_stream(blas_handle_, stream_); -#else // PADDLE_WITH_CUDA - PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasCreate(&blas_handle_)); - PADDLE_RETRY_CUDA_SUCCESS( - phi::dynload::cublasSetStream(blas_handle_, stream_)); -#if CUDA_VERSION >= 9000 - PADDLE_RETRY_CUDA_SUCCESS( - phi::dynload::cublasCreate(&blas_tensor_core_handle_)); - PADDLE_RETRY_CUDA_SUCCESS( - phi::dynload::cublasSetStream(blas_tensor_core_handle_, stream_)); - PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( - blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH)); -#if CUDA_VERSION >= 11000 - PADDLE_RETRY_CUDA_SUCCESS( - phi::dynload::cublasCreate(&blas_tf32_tensor_core_handle_)); - PADDLE_RETRY_CUDA_SUCCESS( - phi::dynload::cublasSetStream(blas_tf32_tensor_core_handle_, stream_)); - PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( - blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH)); -#endif // CUDA_VERSION >= 11000 -#endif // CUDA_VERSION >= 9000 -#endif // PADDLE_WITH_HIP - } - - void DestroyInternalBlasHandle() { -#ifdef PADDLE_WITH_HIP - if (owned_ && blas_handle_ != nullptr) { - phi::dynload::rocblas_destroy_handle(blas_handle_); - blas_handle_ = nullptr; - } -#else - if (owned_ && blas_handle_ != nullptr) { - phi::dynload::cublasDestroy(blas_handle_); - blas_handle_ = nullptr; - } - if (owned_ && blas_tensor_core_handle_ != nullptr) { - phi::dynload::cublasDestroy(blas_tensor_core_handle_); - blas_tensor_core_handle_ = nullptr; - } - if (owned_ && blas_tf32_tensor_core_handle_ != nullptr) { - phi::dynload::cublasDestroy(blas_tf32_tensor_core_handle_); - blas_tf32_tensor_core_handle_ = nullptr; - } -#endif // PADDLE_WITH_HIP - } - blasHandle_t GetBlasHandle() const { PD_CHECK(blas_handle_ != nullptr, "the gpu blas handle is nullptr."); return blas_handle_; @@ -456,16 +362,12 @@ struct GPUContext::Impl { void SetBlasHandle(blasHandle_t blas) { blas_handle_ = blas; } - void InitBlasLtHandle() { -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 - phi::dynload::cublasLtCreate(&blaslt_handle_); -#endif + void SetBlasTensorCoreHandle(blasHandle_t handle) { + blas_tensor_core_handle_ = handle; } - void DestroyInternalBlasLtHandle() { -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 - phi::dynload::cublasLtDestroy(blaslt_handle_); -#endif + void SetBlasTF32Handle(blasHandle_t handle) { + blas_tf32_tensor_core_handle_ = handle; } void SetBlasLtHandle(blasLtHandle_t blaslt) { blaslt_handle_ = blaslt; } @@ -475,53 +377,6 @@ struct GPUContext::Impl { return blaslt_handle_; } - void InitDNNHandle() { - if (phi::dynload::HasCUDNN()) { -#ifdef PADDLE_WITH_HIP - size_t miopen_major, miopen_minor, miopen_patch; - PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenGetVersion( - &miopen_major, &miopen_minor, &miopen_patch)); - auto local_miopen_version = - (miopen_major * 1000 + miopen_minor * 10 + miopen_patch) / 10; - auto compile_miopen_version = MIOPEN_VERSION / 10; - if (local_miopen_version < static_cast(compile_miopen_version)) { - LOG_FIRST_N(WARNING, 1) - << "WARNING: device: " << place_.device - << ". The installed Paddle is compiled with MIOPEN " - << compile_miopen_version / 100 << "." - << compile_miopen_version % 100 - << ", but MIOPEN version in your machine is " - << local_miopen_version / 100 << "." << local_miopen_version % 100 - << ", which may cause serious incompatible bug. " - << "Please recompile or reinstall Paddle with compatible MIOPEN " - "version."; - } - PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreate(&dnn_handle_)); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::miopenSetStream(dnn_handle_, stream_)); -#else - auto local_cudnn_version = phi::dynload::cudnnGetVersion() / 100; - auto compile_cudnn_version = CUDNN_VERSION / 100; - if (local_cudnn_version < static_cast(compile_cudnn_version)) { - LOG_FIRST_N(WARNING, 1) - << "WARNING: device: " << place_.device - << ". The installed Paddle is compiled with CUDNN " - << compile_cudnn_version / 10 << "." << compile_cudnn_version % 10 - << ", but CUDNN version in your machine is " - << local_cudnn_version / 10 << "." << local_cudnn_version % 10 - << ", which may cause serious incompatible bug. " - << "Please recompile or reinstall Paddle with compatible CUDNN " - "version."; - } - PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cudnnCreate(&dnn_handle_)); - PADDLE_RETRY_CUDA_SUCCESS( - phi::dynload::cudnnSetStream(dnn_handle_, stream_)); -#endif - } else { - dnn_handle_ = nullptr; - } - } - dnnHandle_t GetDnnHandle() { PD_CHECK(dnn_handle_ != nullptr, "the gpu dnn handle is nullptr."); return dnn_handle_; @@ -543,24 +398,6 @@ struct GPUContext::Impl { void SetDnnHandle(dnnHandle_t handle) { dnn_handle_ = handle; } - void InitSolverHandle() { -#ifndef PADDLE_WITH_HIP - PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cusolverDnCreate(&solver_handle_)); - PADDLE_RETRY_CUDA_SUCCESS( - phi::dynload::cusolverDnSetStream(solver_handle_, stream_)); -#endif - } - - void DestroyInternalSolverHandle() { -#ifndef PADDLE_WITH_HIP - if (owned_ && solver_handle_ != nullptr) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cusolverDnDestroy(solver_handle_)); - solver_handle_ = nullptr; - } -#endif - } - solverHandle_t GetSolverHandle() const { PD_CHECK(solver_handle_ != nullptr, "the gpu solver handle is nullptr."); return solver_handle_; @@ -568,29 +405,6 @@ struct GPUContext::Impl { void SetSolverHandle(solverHandle_t handle) { solver_handle_ = handle; } - void InitSparseHandle() { -// ROCM is not yet supported -#if defined(PADDLE_WITH_CUDA) -// The generic APIs is supported from CUDA10.1 -#if CUDA_VERSION >= 10010 - PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseCreate(&sparse_handle_)); - PADDLE_RETRY_CUDA_SUCCESS( - dynload::cusparseSetStream(sparse_handle_, stream_)); -#endif -#endif - } - - void DestroyInternalSparseHandle() { -#ifdef PADDLE_WITH_CUDA -#if CUDA_VERSION >= 10010 - if (owned_ && sparse_handle_ != nullptr) { - PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseDestroy(sparse_handle_)); - sparse_handle_ = nullptr; - } -#endif -#endif - } - sparseHandle_t GetSparseHandle() const { PD_CHECK(sparse_handle_ != nullptr, "the gpu sparse handle is nullptr."); return sparse_handle_; @@ -878,7 +692,10 @@ void GPUContext::Init() { impl_->Init(); } -void GPUContext::SetStream(gpuStream_t stream) { impl_->SetStream(stream); } +void GPUContext::SetStream(gpuStream_t stream) { + impl_->allocator_ = const_cast(&this->GetAllocator()); + impl_->SetStream(stream); +} void GPUContext::SetEigenDevice(Eigen::GpuDevice* device) { impl_->SetEigenDevice(device); @@ -888,6 +705,14 @@ void GPUContext::SetBlasHandle(blasHandle_t blas) { impl_->SetBlasHandle(blas); } +void GPUContext::SetBlasTensorCoreHandle(blasHandle_t handle) { + impl_->SetBlasTensorCoreHandle(handle); +} + +void GPUContext::SetBlasTF32Handle(blasHandle_t handle) { + impl_->SetBlasTF32Handle(handle); +} + void GPUContext::SetBlasLtHandle(blasLtHandle_t blaslt) { impl_->SetBlasLtHandle(blaslt); } diff --git a/paddle/phi/backends/gpu/gpu_context.h b/paddle/phi/backends/gpu/gpu_context.h index 8d44acaa4a0..db9f287041d 100644 --- a/paddle/phi/backends/gpu/gpu_context.h +++ b/paddle/phi/backends/gpu/gpu_context.h @@ -199,6 +199,10 @@ class PADDLE_API GPUContext : public DeviceContext { void SetBlasHandle(blasHandle_t); + void SetBlasTensorCoreHandle(blasHandle_t); + + void SetBlasTF32Handle(blasHandle_t); + void SetBlasLtHandle(blasLtHandle_t); void SetDnnHandle(dnnHandle_t); diff --git a/paddle/phi/backends/gpu/gpu_resources.cc b/paddle/phi/backends/gpu/gpu_resources.cc new file mode 100644 index 00000000000..268024eb259 --- /dev/null +++ b/paddle/phi/backends/gpu/gpu_resources.cc @@ -0,0 +1,271 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/gpu/gpu_resources.h" + +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/backends/gpu/gpu_decls.h" +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/allocator.h" + +#ifdef PADDLE_WITH_CUDA +#include "paddle/phi/backends/dynload/cublas.h" +#include "paddle/phi/backends/dynload/cudnn.h" +#include "paddle/phi/backends/dynload/cusolver.h" +#include "paddle/phi/backends/dynload/cusparse.h" +#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) +#include "paddle/phi/backends/dynload/nccl.h" +#endif // !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) +#endif // PADDLE_WITH_CUDA + +#include "unsupported/Eigen/CXX11/Tensor" + +// TODO(phi): remove fluid header. +#include "paddle/fluid/platform/enforce.h" + +namespace phi { + +void InitGpuProperties(Place place, + int* compute_capability, + int* runtime_version, + int* driver_version, + int* multi_process, + int* max_threads_per_mp, + int* max_threads_per_block, + std::array* max_grid_dim_size) { + backends::gpu::GPUDeviceGuard guard(place.GetDeviceId()); + *compute_capability = + backends::gpu::GetGPUComputeCapability(place.GetDeviceId()); + *multi_process = backends::gpu::GetGPUMultiProcessors(place.GetDeviceId()); + *max_threads_per_mp = + backends::gpu::GetGPUMaxThreadsPerMultiProcessor(place.GetDeviceId()); + *max_grid_dim_size = backends::gpu::GetGpuMaxGridDimSize(place.GetDeviceId()); + *max_threads_per_block = + backends::gpu::GetGPUMaxThreadsPerBlock(place.GetDeviceId()); + *driver_version = backends::gpu::GetGPUDriverVersion(place.GetDeviceId()); + *runtime_version = backends::gpu::GetGPURuntimeVersion(place.GetDeviceId()); + + // TODO(wilber): glog may be replaced in the future? + LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " + << static_cast(place.device) + << ", GPU Compute Capability: " + << *compute_capability / 10 << "." + << *compute_capability % 10 + << ", Driver API Version: " << *driver_version / 1000 + << "." << (*driver_version % 100) / 10 + << ", Runtime API Version: " + << *runtime_version / 1000 << "." + << (*runtime_version % 100) / 10; +#ifdef PADDLE_WITH_HIP + size_t miopen_major, miopen_minor, miopen_patch; + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::miopenGetVersion(&miopen_major, &miopen_minor, &miopen_patch)); + auto cudnn_dso_ver = + (miopen_major * 1000 + miopen_minor * 10 + miopen_patch) / 10; + auto compile_miopen_version = MIOPEN_VERSION / 10; + if (cudnn_dso_ver < static_cast(compile_miopen_version)) { + LOG_FIRST_N(WARNING, 1) + << "WARNING: device: " << static_cast(place.device) + << ". The installed Paddle is compiled with MIOPEN " + << compile_miopen_version / 100 << "." << compile_miopen_version % 100 + << ", but MIOPEN version in your machine is " << cudnn_dso_ver / 100 + << "." << cudnn_dso_ver % 100 + << ", which may cause serious incompatible bug. " + << "Please recompile or reinstall Paddle with compatible MIOPEN " + "version."; + } +#else + size_t cudnn_dso_ver = dynload::cudnnGetVersion(); + LOG_FIRST_N(WARNING, 1) << "device: " << static_cast(place.device) + << ", cuDNN Version: " << cudnn_dso_ver / 1000 << "." + << (cudnn_dso_ver % 1000) / 100 << "."; + + // Check CUDA/CUDNN version compatiblity + auto local_cuda_version = + (*driver_version / 1000) * 10 + (*driver_version % 100) / 10; + auto compile_cuda_version = + (CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10; + if (local_cuda_version < compile_cuda_version) { + LOG_FIRST_N(WARNING, 1) + << "WARNING: device: " << static_cast(place.device) + << ". The installed Paddle is compiled with CUDA " + << compile_cuda_version / 10 << "." << compile_cuda_version % 10 + << ", but CUDA runtime version in your machine is " + << local_cuda_version / 10 << "." << local_cuda_version % 10 + << ", which may cause serious incompatible bug. " + << "Please recompile or reinstall Paddle with compatible CUDA " + "version."; + } +#endif +} + +void InitStream(gpuStream_t* stream) { +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS( + hipStreamCreateWithPriority(stream, hipStreamDefault, 0)); +#else + PADDLE_ENFORCE_GPU_SUCCESS( + cudaStreamCreateWithPriority(stream, cudaStreamDefault, 0)); +#endif +} + +void DestoryStream(gpuStream_t stream) { + if (stream != nullptr) { +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream)); +#else + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream)); +#endif + } + stream = nullptr; +} + +void InitBlasHandle(blasHandle_t* blas_handle, gpuStream_t stream) { +#ifdef PADDLE_WITH_HIP + phi::dynload::rocblas_create_handle(blas_handle); + phi::dynload::rocblas_set_stream(*blas_handle, stream); +#else // PADDLE_WITH_CUDA + PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasCreate(blas_handle)); + PADDLE_RETRY_CUDA_SUCCESS( + phi::dynload::cublasSetStream(*blas_handle, stream)); +#endif // PADDLE_WITH_HIP +} + +void DestroyBlasHandle(blasHandle_t handle) { +#ifdef PADDLE_WITH_HIP + if (handle != nullptr) { + phi::dynload::rocblas_destroy_handle(handle); + handle = nullptr; + } +#else + if (handle != nullptr) { + phi::dynload::cublasDestroy(handle); + handle = nullptr; + } +#endif // PADDLE_WITH_HIP +} + +void InitBlasLtHandle(blasLtHandle_t* blaslt_handle) { +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 + phi::dynload::cublasLtCreate(blaslt_handle); +#endif +} + +void DestroyBlasLtHandle(blasLtHandle_t handle) { +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 + if (handle != nullptr) { + phi::dynload::cublasLtDestroy(handle); + handle = nullptr; + } +#endif +} + +void InitDnnHandle(dnnHandle_t* handle, gpuStream_t stream, Place place) { + if (phi::dynload::HasCUDNN()) { +#ifdef PADDLE_WITH_HIP + size_t miopen_major, miopen_minor, miopen_patch; + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::miopenGetVersion(&miopen_major, &miopen_minor, &miopen_patch)); + auto local_miopen_version = + (miopen_major * 1000 + miopen_minor * 10 + miopen_patch) / 10; + auto compile_miopen_version = MIOPEN_VERSION / 10; + if (local_miopen_version < static_cast(compile_miopen_version)) { + LOG_FIRST_N(WARNING, 1) + << "WARNING: device: " << place.device + << ". The installed Paddle is compiled with MIOPEN " + << compile_miopen_version / 100 << "." << compile_miopen_version % 100 + << ", but MIOPEN version in your machine is " + << local_miopen_version / 100 << "." << local_miopen_version % 100 + << ", which may cause serious incompatible bug. " + << "Please recompile or reinstall Paddle with compatible MIOPEN " + "version."; + } + PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreate(handle)); + PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenSetStream(*handle, stream)); +#else + auto local_cudnn_version = phi::dynload::cudnnGetVersion() / 100; + auto compile_cudnn_version = CUDNN_VERSION / 100; + if (local_cudnn_version < static_cast(compile_cudnn_version)) { + LOG_FIRST_N(WARNING, 1) + << "WARNING: device: " << place.device + << ". The installed Paddle is compiled with CUDNN " + << compile_cudnn_version / 10 << "." << compile_cudnn_version % 10 + << ", but CUDNN version in your machine is " + << local_cudnn_version / 10 << "." << local_cudnn_version % 10 + << ", which may cause serious incompatible bug. " + << "Please recompile or reinstall Paddle with compatible CUDNN " + "version."; + } + PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cudnnCreate(handle)); + PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cudnnSetStream(*handle, stream)); +#endif + } else { + *handle = nullptr; + } +} + +void DestroyDnnHandle(dnnHandle_t handle) { +#ifdef PADDLE_WITH_HIP + if (handle != nullptr) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenDestroy(handle)); + handle = nullptr; + } +#else + if (handle != nullptr) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnDestroy(handle)); + handle = nullptr; + } +#endif // PADDLE_WITH_HIP +} + +void InitSolverHandle(solverHandle_t* handle, gpuStream_t stream) { +#ifndef PADDLE_WITH_HIP + PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cusolverDnCreate(handle)); + PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cusolverDnSetStream(*handle, stream)); +#endif +} + +void DestroySolverHandle(solverHandle_t solver_handle) { +#ifndef PADDLE_WITH_HIP + if (solver_handle != nullptr) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDestroy(solver_handle)); + solver_handle = nullptr; + } +#endif +} + +void InitSparseHandle(sparseHandle_t* handle, gpuStream_t stream) { +// ROCM is not yet supported +#if defined(PADDLE_WITH_CUDA) +// The generic APIs is supported from CUDA10.1 +#if CUDA_VERSION >= 10010 + PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseCreate(handle)); + PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseSetStream(*handle, stream)); +#endif +#endif +} + +void DestroySparseHandle(sparseHandle_t handle) { +#ifdef PADDLE_WITH_CUDA +#if CUDA_VERSION >= 10010 + if (handle != nullptr) { + PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseDestroy(handle)); + handle = nullptr; + } +#endif +#endif +} + +} // namespace phi diff --git a/paddle/phi/backends/gpu/gpu_resources.h b/paddle/phi/backends/gpu/gpu_resources.h new file mode 100644 index 00000000000..07ccb621540 --- /dev/null +++ b/paddle/phi/backends/gpu/gpu_resources.h @@ -0,0 +1,51 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "paddle/phi/backends/gpu/gpu_decls.h" +#include "paddle/phi/common/place.h" + +namespace phi { + +void InitGpuProperties(Place place, + int* compute_capability, + int* runtime_version, + int* driver_version, + int* multi_process, + int* max_threads_per_mp, + int* max_threads_per_block, + std::array* max_grid_dim_size); + +void InitStream(gpuStream_t* stream); +void DestoryStream(gpuStream_t stream); + +void InitBlasHandle(blasHandle_t* blas_handle, gpuStream_t stream); +void DestroyBlasHandle(blasHandle_t handle); + +void InitBlasLtHandle(blasLtHandle_t* blaslt_handle); +void DestroyBlasLtHandle(blasLtHandle_t handle); + +void InitDnnHandle(dnnHandle_t* handle, gpuStream_t stream, Place place); +void DestroyDnnHandle(dnnHandle_t handle); + +void InitSolverHandle(solverHandle_t* handle, gpuStream_t stream); +void DestroySolverHandle(solverHandle_t solver_handle); + +void InitSparseHandle(sparseHandle_t* handle, gpuStream_t stream); +void DestroySparseHandle(sparseHandle_t handle); + +// void InitDnnWorkspace(); + +} // namespace phi -- GitLab