未验证 提交 90ae3533 编写于 作者: X xiaoxiaohehe001 提交者: GitHub

gpu_context (#43661)

上级 1aafc31b
......@@ -6,4 +6,5 @@ elseif(WITH_ROCM)
hip_library(phi_gpu_info SRCS gpu_info.cc DEPS phi_rocm_info gflags glog enforce phi_dynload_cuda)
endif()
cc_library(gpu_context SRCS gpu_context.cc DEPS phi_device_context phi_gpu_info eigen3)
cc_library(gpu_resources SRCS gpu_resources.cc DEPS phi_device_context phi_gpu_info)
cc_library(gpu_context SRCS gpu_context.cc DEPS phi_device_context phi_gpu_info eigen3 gpu_resources)
......@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_context.h"
#include <algorithm>
#include <array>
#include <functional>
......@@ -21,10 +22,11 @@ limitations under the License. */
#include <memory>
#include <mutex>
#include "glog/logging.h"
#include "paddle/phi/api/ext/exception.h"
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_resources.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/allocator.h"
......@@ -202,27 +204,31 @@ struct GPUContext::Impl {
void Init() {
owned_ = true;
backends::gpu::GPUDeviceGuard guard(place_.device);
InitGpuProperties();
InitStream();
phi::InitGpuProperties(place_,
&compute_capability_,
&runtime_version_,
&driver_version_,
&multi_process_,
&max_threads_per_mp_,
&max_threads_per_block_,
&max_grid_dim_size_);
phi::InitStream(&stream_);
InitEigenDevice();
InitBlasHandle();
InitBlasLtHandle();
InitDNNHandle();
InitSolverHandle();
InitSparseHandle();
InitDnnWorkspace();
}
void PartialInitWithoutAllocator() {
owned_ = true;
backends::gpu::GPUDeviceGuard guard(place_.device);
InitGpuProperties();
InitStream();
InitBlasHandle();
InitBlasLtHandle();
InitDNNHandle();
InitSolverHandle();
InitSparseHandle();
phi::InitGpuProperties(place_,
&compute_capability_,
&runtime_version_,
&driver_version_,
&multi_process_,
&max_threads_per_mp_,
&max_threads_per_block_,
&max_grid_dim_size_);
phi::InitStream(&stream_);
}
void PartialInitWithAllocator() {
......@@ -238,19 +244,23 @@ struct GPUContext::Impl {
~Impl() {
backends::gpu::GPUDeviceGuard guard(place_.device);
DestoryInternalWorkspace();
DestoryInternalEigenDevice();
DestroyInternalSparseHandle();
DestroyInternalSolverHandle();
DestroyInternalDnnHandle();
if (owned_) {
DestoryInternalWorkspace();
DestoryInternalEigenDevice();
phi::DestroySparseHandle(sparse_handle_);
phi::DestroySolverHandle(solver_handle_);
phi::DestroyDnnHandle(dnn_handle_);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if (nccl_comm_) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
}
if (nccl_comm_) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
}
#endif
DestroyInternalBlasHandle();
DestroyInternalBlasLtHandle();
DestoryInternalStream();
phi::DestroyBlasHandle(blas_handle_);
phi::DestroyBlasHandle(blas_tensor_core_handle_);
phi::DestroyBlasHandle(blas_tf32_tensor_core_handle_);
phi::DestroyBlasLtHandle(blaslt_handle_);
phi::DestoryStream(stream_);
}
}
const Place& GetPlace() const { return place_; }
......@@ -259,73 +269,6 @@ struct GPUContext::Impl {
return blas_tensor_core_handle_ != nullptr;
}
void InitGpuProperties() {
backends::gpu::GPUDeviceGuard guard(place_.GetDeviceId());
compute_capability_ =
backends::gpu::GetGPUComputeCapability(place_.GetDeviceId());
multi_process_ = backends::gpu::GetGPUMultiProcessors(place_.GetDeviceId());
max_threads_per_mp_ =
backends::gpu::GetGPUMaxThreadsPerMultiProcessor(place_.GetDeviceId());
max_grid_dim_size_ =
backends::gpu::GetGpuMaxGridDimSize(place_.GetDeviceId());
max_threads_per_block_ =
backends::gpu::GetGPUMaxThreadsPerBlock(place_.GetDeviceId());
driver_version_ = backends::gpu::GetGPUDriverVersion(place_.GetDeviceId());
runtime_version_ =
backends::gpu::GetGPURuntimeVersion(place_.GetDeviceId());
// TODO(wilber): glog may be replaced in the future?
LOG_FIRST_N(WARNING, 1)
<< "Please NOTE: device: " << static_cast<int>(place_.device)
<< ", GPU Compute Capability: " << compute_capability_ / 10 << "."
<< compute_capability_ % 10
<< ", Driver API Version: " << driver_version_ / 1000 << "."
<< (driver_version_ % 100) / 10
<< ", Runtime API Version: " << runtime_version_ / 1000 << "."
<< (runtime_version_ % 100) / 10;
#ifdef PADDLE_WITH_HIP
size_t miopen_major, miopen_minor, miopen_patch;
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::miopenGetVersion(&miopen_major, &miopen_minor, &miopen_patch));
auto cudnn_dso_ver =
(miopen_major * 1000 + miopen_minor * 10 + miopen_patch) / 10;
auto compile_miopen_version = MIOPEN_VERSION / 10;
if (cudnn_dso_ver < static_cast<size_t>(compile_miopen_version)) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << static_cast<int>(place_.device)
<< ". The installed Paddle is compiled with MIOPEN "
<< compile_miopen_version / 100 << "." << compile_miopen_version % 100
<< ", but MIOPEN version in your machine is " << cudnn_dso_ver / 100
<< "." << cudnn_dso_ver % 100
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible MIOPEN "
"version.";
}
#else
size_t cudnn_dso_ver = dynload::cudnnGetVersion();
LOG_FIRST_N(WARNING, 1) << "device: " << static_cast<int>(place_.device)
<< ", cuDNN Version: " << cudnn_dso_ver / 1000
<< "." << (cudnn_dso_ver % 1000) / 100 << ".";
// Check CUDA/CUDNN version compatiblity
auto local_cuda_version =
(driver_version_ / 1000) * 10 + (driver_version_ % 100) / 10;
auto compile_cuda_version =
(CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
if (local_cuda_version < compile_cuda_version) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << static_cast<int>(place_.device)
<< ". The installed Paddle is compiled with CUDA "
<< compile_cuda_version / 10 << "." << compile_cuda_version % 10
<< ", but CUDA runtime version in your machine is "
<< local_cuda_version / 10 << "." << local_cuda_version % 10
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible CUDA "
"version.";
}
#endif
}
void InitDnnWorkspace() {
PD_CHECK(allocator_ != nullptr,
"the device allocator for gpu context is nullptr.");
......@@ -350,27 +293,6 @@ struct GPUContext::Impl {
return DnnWorkspaceHandle(allocator_, stream_);
}
void InitStream() {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
hipStreamCreateWithPriority(&stream_, hipStreamDefault, 0));
#else
PADDLE_ENFORCE_GPU_SUCCESS(
cudaStreamCreateWithPriority(&stream_, cudaStreamDefault, 0));
#endif
}
void DestoryInternalStream() {
if (owned_ && stream_ != nullptr) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream_));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream_));
#endif
}
stream_ = nullptr;
}
void SetStream(gpuStream_t stream) { stream_ = stream; }
gpuStream_t GetStream() const {
......@@ -400,129 +322,56 @@ struct GPUContext::Impl {
return eigen_device_;
}
void InitBlasHandle() {
#ifdef PADDLE_WITH_HIP
phi::dynload::rocblas_create_handle(&blas_handle_);
phi::dynload::rocblas_set_stream(blas_handle_, stream_);
#else // PADDLE_WITH_CUDA
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasCreate(&blas_handle_));
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cublasSetStream(blas_handle_, stream_));
blasHandle_t GetBlasHandle() {
std::call_once(flag_blas_, [=]() {
if (!blas_handle_) {
phi::InitBlasHandle(&blas_handle_, stream_);
}
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 9000
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cublasCreate(&blas_tensor_core_handle_));
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cublasSetStream(blas_tensor_core_handle_, stream_));
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
if (!blas_tensor_core_handle_) {
phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
}
#endif
#if CUDA_VERSION >= 11000
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cublasCreate(&blas_tf32_tensor_core_handle_));
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cublasSetStream(blas_tf32_tensor_core_handle_, stream_));
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
#endif // CUDA_VERSION >= 11000
#endif // CUDA_VERSION >= 9000
#endif // PADDLE_WITH_HIP
}
void DestroyInternalBlasHandle() {
#ifdef PADDLE_WITH_HIP
if (owned_ && blas_handle_ != nullptr) {
phi::dynload::rocblas_destroy_handle(blas_handle_);
blas_handle_ = nullptr;
}
#else
if (owned_ && blas_handle_ != nullptr) {
phi::dynload::cublasDestroy(blas_handle_);
blas_handle_ = nullptr;
}
if (owned_ && blas_tensor_core_handle_ != nullptr) {
phi::dynload::cublasDestroy(blas_tensor_core_handle_);
blas_tensor_core_handle_ = nullptr;
}
if (owned_ && blas_tf32_tensor_core_handle_ != nullptr) {
phi::dynload::cublasDestroy(blas_tf32_tensor_core_handle_);
blas_tf32_tensor_core_handle_ = nullptr;
}
#endif // PADDLE_WITH_HIP
}
blasHandle_t GetBlasHandle() const {
if (!blas_tf32_tensor_core_handle_) {
phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
}
#endif
#endif
});
PD_CHECK(blas_handle_ != nullptr, "the gpu blas handle is nullptr.");
return blas_handle_;
}
void SetBlasHandle(blasHandle_t blas) { blas_handle_ = blas; }
void InitBlasLtHandle() {
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
phi::dynload::cublasLtCreate(&blaslt_handle_);
#endif
void SetBlasTensorCoreHandle(blasHandle_t handle) {
blas_tensor_core_handle_ = handle;
}
void DestroyInternalBlasLtHandle() {
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
phi::dynload::cublasLtDestroy(blaslt_handle_);
#endif
void SetBlasTF32Handle(blasHandle_t handle) {
blas_tf32_tensor_core_handle_ = handle;
}
void SetBlasLtHandle(blasLtHandle_t blaslt) { blaslt_handle_ = blaslt; }
blasLtHandle_t GetBlasLtHandle() const {
blasLtHandle_t GetBlasLtHandle() {
std::call_once(flag_blaslt_, [=]() {
if (!blaslt_handle_) phi::InitBlasLtHandle(&blaslt_handle_);
});
PD_CHECK(blaslt_handle_ != nullptr, "the gpu blasLt handle is nullptr.");
return blaslt_handle_;
}
void InitDNNHandle() {
if (phi::dynload::HasCUDNN()) {
#ifdef PADDLE_WITH_HIP
size_t miopen_major, miopen_minor, miopen_patch;
PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenGetVersion(
&miopen_major, &miopen_minor, &miopen_patch));
auto local_miopen_version =
(miopen_major * 1000 + miopen_minor * 10 + miopen_patch) / 10;
auto compile_miopen_version = MIOPEN_VERSION / 10;
if (local_miopen_version < static_cast<size_t>(compile_miopen_version)) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << place_.device
<< ". The installed Paddle is compiled with MIOPEN "
<< compile_miopen_version / 100 << "."
<< compile_miopen_version % 100
<< ", but MIOPEN version in your machine is "
<< local_miopen_version / 100 << "." << local_miopen_version % 100
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible MIOPEN "
"version.";
}
PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreate(&dnn_handle_));
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::miopenSetStream(dnn_handle_, stream_));
#else
auto local_cudnn_version = phi::dynload::cudnnGetVersion() / 100;
auto compile_cudnn_version = CUDNN_VERSION / 100;
if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << place_.device
<< ". The installed Paddle is compiled with CUDNN "
<< compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
<< ", but CUDNN version in your machine is "
<< local_cudnn_version / 10 << "." << local_cudnn_version % 10
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible CUDNN "
"version.";
}
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cudnnCreate(&dnn_handle_));
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cudnnSetStream(dnn_handle_, stream_));
#endif
} else {
dnn_handle_ = nullptr;
}
}
dnnHandle_t GetDnnHandle() {
std::call_once(flag_dnn_, [=]() {
if (!dnn_handle_) phi::InitDnnHandle(&dnn_handle_, stream_, place_);
});
PD_CHECK(dnn_handle_ != nullptr, "the gpu dnn handle is nullptr.");
return dnn_handle_;
}
......@@ -543,54 +392,16 @@ struct GPUContext::Impl {
void SetDnnHandle(dnnHandle_t handle) { dnn_handle_ = handle; }
void InitSolverHandle() {
#ifndef PADDLE_WITH_HIP
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cusolverDnCreate(&solver_handle_));
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cusolverDnSetStream(solver_handle_, stream_));
#endif
}
void DestroyInternalSolverHandle() {
#ifndef PADDLE_WITH_HIP
if (owned_ && solver_handle_ != nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cusolverDnDestroy(solver_handle_));
solver_handle_ = nullptr;
}
#endif
}
solverHandle_t GetSolverHandle() const {
solverHandle_t GetSolverHandle() {
std::call_once(flag_slover_, [=]() {
if (!solver_handle_) phi::InitSolverHandle(&solver_handle_, stream_);
});
PD_CHECK(solver_handle_ != nullptr, "the gpu solver handle is nullptr.");
return solver_handle_;
}
void SetSolverHandle(solverHandle_t handle) { solver_handle_ = handle; }
void InitSparseHandle() {
// ROCM is not yet supported
#if defined(PADDLE_WITH_CUDA)
// The generic APIs is supported from CUDA10.1
#if CUDA_VERSION >= 10010
PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseCreate(&sparse_handle_));
PADDLE_RETRY_CUDA_SUCCESS(
dynload::cusparseSetStream(sparse_handle_, stream_));
#endif
#endif
}
void DestroyInternalSparseHandle() {
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10010
if (owned_ && sparse_handle_ != nullptr) {
PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseDestroy(sparse_handle_));
sparse_handle_ = nullptr;
}
#endif
#endif
}
sparseHandle_t GetSparseHandle() const {
PD_CHECK(sparse_handle_ != nullptr, "the gpu sparse handle is nullptr.");
return sparse_handle_;
......@@ -646,8 +457,28 @@ struct GPUContext::Impl {
#endif
}
inline void CublasCall(
const std::function<void(blasHandle_t)>& callback) const {
inline void CublasCall(const std::function<void(blasHandle_t)>& callback) {
std::call_once(flag_cublas_, [=]() {
if (!blas_handle_) {
phi::InitBlasHandle(&blas_handle_, stream_);
}
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 9000
if (!blas_tensor_core_handle_) {
phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
}
#endif
#if CUDA_VERSION >= 11000
if (!blas_tf32_tensor_core_handle_) {
phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
}
#endif
#endif
});
if (blas_tf32_tensor_core_handle_ != nullptr) {
std::lock_guard<std::mutex> guard(blas_tf32_mtx_);
callback(blas_tf32_tensor_core_handle_);
......@@ -658,7 +489,26 @@ struct GPUContext::Impl {
}
inline void TensorCoreCublasCallIfAvailable(
const std::function<void(blasHandle_t)>& callback) const {
const std::function<void(blasHandle_t)>& callback) {
std::call_once(flag_tensorcore_cublas_, [=]() {
if (!blas_handle_) phi::InitBlasHandle(&blas_handle_, stream_);
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 9000
if (!blas_tensor_core_handle_) {
phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
}
#endif
#if CUDA_VERSION >= 11000
if (!blas_tf32_tensor_core_handle_) {
phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
}
#endif
#endif
});
if (blas_tensor_core_handle_ != nullptr) {
std::lock_guard<std::mutex> guard(blas_tensor_core_mtx_);
callback(blas_tensor_core_handle_);
......@@ -689,8 +539,7 @@ struct GPUContext::Impl {
void AddStreamCallback(const std::function<void()>& callback) const {
// NOTE(zhiqiu): better use threadpool here, otherwise "std::async" may
// launch too
// many threads and result in thread oversubscription.
// launch too many threads and result in thread oversubscription.
auto* callback_func = new std::function<void()>(std::move(callback));
auto* func = new std::function<void()>([this, callback_func] {
std::lock_guard<std::mutex> lock(stream_call_back_mtx_);
......@@ -749,6 +598,13 @@ struct GPUContext::Impl {
sparseHandle_t sparse_handle_{nullptr};
DnnWorkspaceHandle* workspace_{nullptr};
std::once_flag flag_blas_;
std::once_flag flag_blaslt_;
std::once_flag flag_dnn_;
std::once_flag flag_slover_;
std::once_flag flag_cublas_;
std::once_flag flag_tensorcore_cublas_;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
// NCCL communicator (single process version) for NCCL collective operations.
// NCCL collective operations provides fast collectives over multiple GPUs
......@@ -878,7 +734,10 @@ void GPUContext::Init() {
impl_->Init();
}
void GPUContext::SetStream(gpuStream_t stream) { impl_->SetStream(stream); }
void GPUContext::SetStream(gpuStream_t stream) {
impl_->allocator_ = const_cast<Allocator*>(&this->GetAllocator());
impl_->SetStream(stream);
}
void GPUContext::SetEigenDevice(Eigen::GpuDevice* device) {
impl_->SetEigenDevice(device);
......@@ -888,6 +747,14 @@ void GPUContext::SetBlasHandle(blasHandle_t blas) {
impl_->SetBlasHandle(blas);
}
void GPUContext::SetBlasTensorCoreHandle(blasHandle_t handle) {
impl_->SetBlasTensorCoreHandle(handle);
}
void GPUContext::SetBlasTF32Handle(blasHandle_t handle) {
impl_->SetBlasTF32Handle(handle);
}
void GPUContext::SetBlasLtHandle(blasLtHandle_t blaslt) {
impl_->SetBlasLtHandle(blaslt);
}
......
......@@ -199,6 +199,10 @@ class PADDLE_API GPUContext : public DeviceContext {
void SetBlasHandle(blasHandle_t);
void SetBlasTensorCoreHandle(blasHandle_t);
void SetBlasTF32Handle(blasHandle_t);
void SetBlasLtHandle(blasLtHandle_t);
void SetDnnHandle(dnnHandle_t);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_resources.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/allocator.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/phi/backends/dynload/cublas.h"
#include "paddle/phi/backends/dynload/cudnn.h"
#include "paddle/phi/backends/dynload/cusolver.h"
#include "paddle/phi/backends/dynload/cusparse.h"
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
#include "paddle/phi/backends/dynload/nccl.h"
#endif // !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
#endif // PADDLE_WITH_CUDA
#include "unsupported/Eigen/CXX11/Tensor"
// TODO(phi): remove fluid header.
#include "paddle/fluid/platform/enforce.h"
namespace phi {
void InitGpuProperties(Place place,
int* compute_capability,
int* runtime_version,
int* driver_version,
int* multi_process,
int* max_threads_per_mp,
int* max_threads_per_block,
std::array<int, 3>* max_grid_dim_size) {
backends::gpu::GPUDeviceGuard guard(place.GetDeviceId());
*compute_capability =
backends::gpu::GetGPUComputeCapability(place.GetDeviceId());
*multi_process = backends::gpu::GetGPUMultiProcessors(place.GetDeviceId());
*max_threads_per_mp =
backends::gpu::GetGPUMaxThreadsPerMultiProcessor(place.GetDeviceId());
*max_grid_dim_size = backends::gpu::GetGpuMaxGridDimSize(place.GetDeviceId());
*max_threads_per_block =
backends::gpu::GetGPUMaxThreadsPerBlock(place.GetDeviceId());
*driver_version = backends::gpu::GetGPUDriverVersion(place.GetDeviceId());
*runtime_version = backends::gpu::GetGPURuntimeVersion(place.GetDeviceId());
// TODO(wilber): glog may be replaced in the future?
LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: "
<< static_cast<int>(place.device)
<< ", GPU Compute Capability: "
<< *compute_capability / 10 << "."
<< *compute_capability % 10
<< ", Driver API Version: " << *driver_version / 1000
<< "." << (*driver_version % 100) / 10
<< ", Runtime API Version: "
<< *runtime_version / 1000 << "."
<< (*runtime_version % 100) / 10;
#ifdef PADDLE_WITH_HIP
size_t miopen_major, miopen_minor, miopen_patch;
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::miopenGetVersion(&miopen_major, &miopen_minor, &miopen_patch));
auto cudnn_dso_ver =
(miopen_major * 1000 + miopen_minor * 10 + miopen_patch) / 10;
auto compile_miopen_version = MIOPEN_VERSION / 10;
if (cudnn_dso_ver < static_cast<size_t>(compile_miopen_version)) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << static_cast<int>(place.device)
<< ". The installed Paddle is compiled with MIOPEN "
<< compile_miopen_version / 100 << "." << compile_miopen_version % 100
<< ", but MIOPEN version in your machine is " << cudnn_dso_ver / 100
<< "." << cudnn_dso_ver % 100
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible MIOPEN "
"version.";
}
#else
size_t cudnn_dso_ver = dynload::cudnnGetVersion();
LOG_FIRST_N(WARNING, 1) << "device: " << static_cast<int>(place.device)
<< ", cuDNN Version: " << cudnn_dso_ver / 1000 << "."
<< (cudnn_dso_ver % 1000) / 100 << ".";
// Check CUDA/CUDNN version compatiblity
auto local_cuda_version =
(*driver_version / 1000) * 10 + (*driver_version % 100) / 10;
auto compile_cuda_version =
(CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
if (local_cuda_version < compile_cuda_version) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << static_cast<int>(place.device)
<< ". The installed Paddle is compiled with CUDA "
<< compile_cuda_version / 10 << "." << compile_cuda_version % 10
<< ", but CUDA runtime version in your machine is "
<< local_cuda_version / 10 << "." << local_cuda_version % 10
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible CUDA "
"version.";
}
#endif
}
void InitStream(gpuStream_t* stream) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
hipStreamCreateWithPriority(stream, hipStreamDefault, 0));
#else
PADDLE_ENFORCE_GPU_SUCCESS(
cudaStreamCreateWithPriority(stream, cudaStreamDefault, 0));
#endif
}
void DestoryStream(gpuStream_t stream) {
if (stream != nullptr) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream));
#endif
}
stream = nullptr;
}
void InitBlasHandle(blasHandle_t* blas_handle, gpuStream_t stream) {
#ifdef PADDLE_WITH_HIP
phi::dynload::rocblas_create_handle(blas_handle);
phi::dynload::rocblas_set_stream(*blas_handle, stream);
#else // PADDLE_WITH_CUDA
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasCreate(blas_handle));
PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cublasSetStream(*blas_handle, stream));
#endif // PADDLE_WITH_HIP
}
void DestroyBlasHandle(blasHandle_t handle) {
#ifdef PADDLE_WITH_HIP
if (handle != nullptr) {
phi::dynload::rocblas_destroy_handle(handle);
handle = nullptr;
}
#else
if (handle != nullptr) {
phi::dynload::cublasDestroy(handle);
handle = nullptr;
}
#endif // PADDLE_WITH_HIP
}
void InitBlasLtHandle(blasLtHandle_t* blaslt_handle) {
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
phi::dynload::cublasLtCreate(blaslt_handle);
#endif
}
void DestroyBlasLtHandle(blasLtHandle_t handle) {
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
if (handle != nullptr) {
phi::dynload::cublasLtDestroy(handle);
handle = nullptr;
}
#endif
}
void InitDnnHandle(dnnHandle_t* handle, gpuStream_t stream, Place place) {
if (phi::dynload::HasCUDNN()) {
#ifdef PADDLE_WITH_HIP
size_t miopen_major, miopen_minor, miopen_patch;
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::miopenGetVersion(&miopen_major, &miopen_minor, &miopen_patch));
auto local_miopen_version =
(miopen_major * 1000 + miopen_minor * 10 + miopen_patch) / 10;
auto compile_miopen_version = MIOPEN_VERSION / 10;
if (local_miopen_version < static_cast<size_t>(compile_miopen_version)) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << place.device
<< ". The installed Paddle is compiled with MIOPEN "
<< compile_miopen_version / 100 << "." << compile_miopen_version % 100
<< ", but MIOPEN version in your machine is "
<< local_miopen_version / 100 << "." << local_miopen_version % 100
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible MIOPEN "
"version.";
}
PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreate(handle));
PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenSetStream(*handle, stream));
#else
auto local_cudnn_version = phi::dynload::cudnnGetVersion() / 100;
auto compile_cudnn_version = CUDNN_VERSION / 100;
if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << place.device
<< ". The installed Paddle is compiled with CUDNN "
<< compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
<< ", but CUDNN version in your machine is "
<< local_cudnn_version / 10 << "." << local_cudnn_version % 10
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible CUDNN "
"version.";
}
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cudnnCreate(handle));
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cudnnSetStream(*handle, stream));
#endif
} else {
*handle = nullptr;
}
}
void DestroyDnnHandle(dnnHandle_t handle) {
#ifdef PADDLE_WITH_HIP
if (handle != nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenDestroy(handle));
handle = nullptr;
}
#else
if (handle != nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnDestroy(handle));
handle = nullptr;
}
#endif // PADDLE_WITH_HIP
}
void InitSolverHandle(solverHandle_t* handle, gpuStream_t stream) {
#ifndef PADDLE_WITH_HIP
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cusolverDnCreate(handle));
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cusolverDnSetStream(*handle, stream));
#endif
}
void DestroySolverHandle(solverHandle_t solver_handle) {
#ifndef PADDLE_WITH_HIP
if (solver_handle != nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDestroy(solver_handle));
solver_handle = nullptr;
}
#endif
}
void InitSparseHandle(sparseHandle_t* handle, gpuStream_t stream) {
// ROCM is not yet supported
#if defined(PADDLE_WITH_CUDA)
// The generic APIs is supported from CUDA10.1
#if CUDA_VERSION >= 10010
PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseCreate(handle));
PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseSetStream(*handle, stream));
#endif
#endif
}
void DestroySparseHandle(sparseHandle_t handle) {
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10010
if (handle != nullptr) {
PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseDestroy(handle));
handle = nullptr;
}
#endif
#endif
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <array>
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/common/place.h"
namespace phi {
void InitGpuProperties(Place place,
int* compute_capability,
int* runtime_version,
int* driver_version,
int* multi_process,
int* max_threads_per_mp,
int* max_threads_per_block,
std::array<int, 3>* max_grid_dim_size);
void InitStream(gpuStream_t* stream);
void DestoryStream(gpuStream_t stream);
void InitBlasHandle(blasHandle_t* blas_handle, gpuStream_t stream);
void DestroyBlasHandle(blasHandle_t handle);
void InitBlasLtHandle(blasLtHandle_t* blaslt_handle);
void DestroyBlasLtHandle(blasLtHandle_t handle);
void InitDnnHandle(dnnHandle_t* handle, gpuStream_t stream, Place place);
void DestroyDnnHandle(dnnHandle_t handle);
void InitSolverHandle(solverHandle_t* handle, gpuStream_t stream);
void DestroySolverHandle(solverHandle_t solver_handle);
void InitSparseHandle(sparseHandle_t* handle, gpuStream_t stream);
void DestroySparseHandle(sparseHandle_t handle);
// void InitDnnWorkspace();
} // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册