From d5f0ed4b2a0d870c7aa71baa639483ca8cd68e64 Mon Sep 17 00:00:00 2001
From: Wilber
Date: Tue, 19 Jul 2022 11:31:23 +0800
Subject: [PATCH] update (#44418)

---
 .../fluid/inference/api/analysis_predictor.cc |  17 +-
 paddle/fluid/platform/device_context.cc       | 185 ----------
 paddle/fluid/platform/device_context.h        | 349 ------------------
 3 files changed, 4 insertions(+), 547 deletions(-)

diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 541c53c8da..86224fcc7e 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -302,11 +302,8 @@ void AnalysisPredictor::InitPlace() {
     place_ = paddle::platform::CUDAPlace(config_.gpu_device_id());
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     if (config_.thread_local_stream_enabled()) {
-      auto *ctx = static_cast<platform::CUDADeviceContext *>(
-          platform::DeviceContextPool::Instance().Get(place_));
-      VLOG(3) << "The prediction process will be completed using a separate "
-                 "normal-priority stream on each thread.";
-      ctx->ResetThreadContext(platform::stream::Priority::kNormal);
+      LOG_FIRST_N(WARNING, 1) << "We will remove this interface in the future. "
+                                 "Please use config.SetExecStream instead.";
     }
 #endif
   } else if (config_.use_xpu()) {
@@ -1621,14 +1618,8 @@ bool AnalysisPredictor::ZeroCopyRun() {
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 bool AnalysisPredictor::ExpRunWithExternalStream(const gpuStream_t stream) {
-  if (stream != nullptr) {
-    paddle::platform::DeviceContextPool &pool =
-        paddle::platform::DeviceContextPool::Instance();
-    auto gpu_place = place_;
-    auto *dev_ctx = reinterpret_cast<platform::CUDADeviceContext *>(
-        pool.Get(gpu_place));
-    dev_ctx->SetThreadLocalStream(stream);
-  }
+  LOG_FIRST_N(WARNING, 1) << "We will remove this interface in the future. "
+                             "Please use config.SetExecStream instead.";
   return ZeroCopyRun();
 }
 #endif
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 1e978f078d..df705d4a10 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -534,74 +534,6 @@ void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
   allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
 }
 
-thread_local std::unordered_map<const CUDADeviceContext*,
-                                std::shared_ptr<CUDAContext>>
-    CUDADeviceContext::thread_ctx_;
-thread_local std::mutex CUDADeviceContext::ctx_mtx_;
-
-void CUDAContext::InitEigenContext() {
-  eigen_stream_.reset(new EigenCudaStreamDevice());
-  eigen_stream_->Reinitialize(&RawStream(), place_);
-  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
-}
-
-CUDAContext::CUDAContext(const CUDAPlace& place,
-                         const stream::Priority& priority,
-                         const stream::StreamFlag& flag) {
-  place_ = place;
-  CUDADeviceGuard guard(place_.device);
-  stream_.reset(new stream::CUDAStream(place, priority, flag));
-  InitEigenContext();
-  InitCuBlasContext();
-  InitCuDNNContext();
-#ifndef PADDLE_WITH_HIP
-#if CUDA_VERSION >= 11060
-  InitCuBlasLtContext();
-#endif
-  InitCuSparseContext();
-  InitCuSolverContext();
-#endif
-}
-
-void CUDAContext::SetStream(gpuStream_t stream) {
-  if (stream_->raw_stream() != stream) {
-    CUDADeviceGuard guard(place_.device);
-    DestoryCuDNNContext();
-    DestoryCuBlasContext();
-#ifndef PADDLE_WITH_HIP
-#if CUDA_VERSION >= 11060
-    DestoryCuBlasLtContext();
-#endif
-    DestoryCuSolverContext();
-#endif
-
-    stream_->SetStream(stream);
-
-    InitEigenContext();
-    InitCuBlasContext();
-    InitCuDNNContext();
-#ifndef PADDLE_WITH_HIP
-#if CUDA_VERSION >= 11060
-    InitCuBlasLtContext();
-#endif
-    InitCuSolverContext();
-#endif
-  }
-}
-
-CUDAContext::~CUDAContext() {
-  CUDADeviceGuard guard(place_.device);
-  DestoryCuDNNContext();
-  DestoryCuBlasContext();
-#ifndef PADDLE_WITH_HIP
-#if CUDA_VERSION >= 11060
-  InitCuBlasLtContext();
-#endif
-  DestoryCuSparseContext();
-  DestoryCuSolverContext();
-#endif
-}
-
 CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : phi::GPUContext(place) {
   phi::GPUContext::PartialInitWithoutAllocator();
   cuda_stream_.reset(new stream::CUDAStream(phi::GPUContext::stream(), place));
@@ -609,123 +541,6 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : phi::GPUContext(place) {
 
 CUDADeviceContext::~CUDADeviceContext() = default;
 
-Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
-  if (thread_ctx_.count(this)) {
-    return context()->EigenDevice().get();
-  }
-  return phi::GPUContext::eigen_device();
-}
-
-void CUDADeviceContext::Wait() const {
-  VLOG(4) << "CUDA context(" << this << ") Wait";
-  if (thread_ctx_.count(this)) {
-    context()->Stream()->Wait();
-    return;
-  }
-  phi::GPUContext::Wait();
-}
-
-#ifdef PADDLE_WITH_HIP
-miopenHandle_t CUDADeviceContext::cudnn_handle() const {
-#else
-cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
-#endif
-  if (thread_ctx_.count(this)) {
-    return context()->CudnnHandle();
-  }
-  return phi::GPUContext::cudnn_handle();
-}
-
-#ifdef PADDLE_WITH_HIP
-rocblas_handle CUDADeviceContext::cublas_handle() const {
-  if (thread_ctx_.count(this)) {
-    return context()->CublasHandle()->GetCublasHandle();
-  }
-  return phi::GPUContext::cublas_handle();
-}
-#else
-cublasHandle_t CUDADeviceContext::cublas_handle() const {
-  if (thread_ctx_.count(this)) {
-    return context()->CublasHandle()->GetCublasHandle();
-  }
-  return phi::GPUContext::cublas_handle();
-}
-#if CUDA_VERSION >= 11060
-cublasLtHandle_t CUDADeviceContext::cublaslt_handle() const {
-  if (thread_ctx_.count(this)) {
-    return context()->CublasLtHandle()->GetCublasLtHandle();
-  }
-  return phi::GPUContext::cublaslt_handle();
-}
-#endif
-cusparseHandle_t CUDADeviceContext::cusparse_handle() const {
-  if (thread_ctx_.count(this)) {
-    return context()->CusparseHandle()->GetCusparseHandle();
-  }
-  return phi::GPUContext::cusparse_handle();
-}
-cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
-  if (thread_ctx_.count(this)) {
-    return context()->CusolverDnHandle();
-  }
-  return phi::GPUContext::cusolver_dn_handle();
-}
-#endif
-
-void CUDADeviceContext::RecordEvent(
-    gpuEvent_t ev, const std::function<void()>& callback) const {
-  if (thread_ctx_.count(this)) {
-    context()->Stream()->RecordEvent(ev, callback);
-    return;
-  }
-  phi::GPUContext::RecordEvent(ev, callback);
-}
-
-void CUDADeviceContext::AddStreamCallback(
-    const std::function<void()>& callback) const {
-  if (thread_ctx_.count(this)) {
-    context()->Stream()->AddCallback(callback);
-    return;
-  }
-  phi::GPUContext::AddStreamCallback(callback);
-}
-
-void CUDADeviceContext::WaitStreamCallback() const {
-  if (thread_ctx_.count(this)) {
-    context()->Stream()->WaitCallback();
-    return;
-  }
-  phi::GPUContext::WaitStreamCallback();
-}
-
-phi::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
-  if (thread_ctx_.count(this)) {
-    // return workspace_.get();
-    return phi::DnnWorkspaceHandle(
-        memory::allocation::AllocatorFacade::Instance()
-            .GetAllocator(GetPlace())
-            .get(),
-        stream());
-  }
-  return phi::GPUContext::cudnn_workspace_handle();
-}
-
-gpuStream_t CUDADeviceContext::stream() const {
-  if (thread_ctx_.count(this)) {
-    return context()->RawStream();
-  }
-  return phi::GPUContext::stream();
-}
-
-std::shared_ptr<CUDAContext> CUDADeviceContext::context() const {
-  if (!thread_ctx_.count(this)) {
-    PADDLE_THROW(platform::errors::PermissionDenied(
-        "CUDADeviceContext call context() failed, make sure in the "
-        "thread_local semantic."));
-  }
-  return thread_ctx_.at(this);
-}
-
 stream::CUDAStream* CUDADeviceContext::GetCudaStream() const {
   return cuda_stream_.get();
 }
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index d0443e30cf..6838cd9509 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -274,366 +274,17 @@ struct DefaultDeviceContextType {
 class CudnnWorkspaceHandle;
 class EigenCudaStreamDevice;
 
-class CUDAContext {
- public:
-  CUDAContext() = default;
-  explicit CUDAContext(
-      const CUDAPlace& place,
-      const stream::Priority& priority = stream::Priority::kNormal,
-      const stream::StreamFlag& flag = stream::StreamFlag::kDefaultFlag);
-
-  ~CUDAContext();
-
-  const CUDAPlace& Place() const { return place_; }
-
-  const std::unique_ptr<Eigen::GpuDevice>& EigenDevice() const {
-    return eigen_device_;
-  }
-
-  const std::unique_ptr<EigenCudaStreamDevice>& EigenStream() const {
-    return eigen_stream_;
-  }
-
-  const std::unique_ptr<stream::CUDAStream>& Stream() const { return stream_; }
-
-  stream::CUDAStream* SetStream(stream::CUDAStream* new_stream_ptr) {
-    auto* old_stream_ptr = stream_.release();
-    stream_.reset(new_stream_ptr);
-    return old_stream_ptr;
-  }
-
-  void SetStream(gpuStream_t stream);
-
-  const gpuStream_t& RawStream() { return stream_->raw_stream(); }
-
-#ifdef PADDLE_WITH_HIP
-  const miopenHandle_t& CudnnHandle() const { return cudnn_handle_; }
-#else
-  const cudnnHandle_t& CudnnHandle() const { return cudnn_handle_; }
-#endif
-
-#ifndef PADDLE_WITH_HIP
-  const cusolverDnHandle_t& CusolverDnHandle() const {
-    return cusolver_dn_handle_;
-  }
-#endif
-
-  const std::unique_ptr<CublasHandleHolder>& CublasHandle() const {
-    return cublas_handle_;
-  }
-
-  const std::unique_ptr<CublasHandleHolder>& CublasTensorCoreHandle() const {
-    return cublas_tensor_core_handle_;
-  }
-
-#ifndef PADDLE_WITH_HIP
-#if CUDA_VERSION >= 11060
-  const std::unique_ptr<CublasLtHandleHolder>& CublasLtHandle() const {
-    return cublaslt_handle_;
-  }
-#endif
-
-  const std::unique_ptr<CusparseHandleHolder>& CusparseHandle() const {
-    return cusparse_handle_;
-  }
-#endif
-
-  /*! \brief  Call cublas function safely. */
-  inline void CublasCall(
-      const std::function<void(blasHandle_t)>& callback) const {
-    if (cublas_tf32_tensor_core_handle_) {
-      cublas_tf32_tensor_core_handle_->Call(callback);
-    } else {
-      cublas_handle_->Call(callback);
-    }
-  }
-
-#ifndef PADDLE_WITH_HIP
-#if CUDA_VERSION >= 11060
-  /*! \brief  Call cublasLt function safely. */
-  inline void CublasLtCall(
-      const std::function<void(blasLtHandle_t)>& callback) const {
-    cublaslt_handle_->Call(callback);
-  }
-#endif
-
-  /*! \brief  Call cusparse function safely. */
-  inline void CusparseCall(
-      const std::function<void(phi::sparseHandle_t)>& callback) const {
-    cusparse_handle_->Call(callback);
-  }
-#endif
-
-  /*! \brief  Check whether tensor core is supported */
-  bool tensor_core_available() const;
-
-  /*! \brief  Call cublas function with Tensor Core safely. If
-      Tensor Core is not available, use DEFAULT_MATH instead. */
-  inline void TensorCoreCublasCallIfAvailable(
-      const std::function<void(blasHandle_t)>& callback) const {
-    if (cublas_tensor_core_handle_) {
-      cublas_tensor_core_handle_->Call(callback);
-    } else {
-      cublas_handle_->Call(callback);
-    }
-  }
-
- private:
-  void InitEigenContext();
-
-#ifdef PADDLE_WITH_HIP
-  void InitCuBlasContext() {
-    cublas_handle_.reset(new CublasHandleHolder(RawStream()));
-  }
-#else
-  void InitCuBlasContext() {
-    cublas_handle_.reset(
-        new CublasHandleHolder(RawStream(), CUBLAS_DEFAULT_MATH));
-    if (TensorCoreAvailable()) {
-#if CUDA_VERSION >= 9000
-      cublas_tensor_core_handle_.reset(
-          new CublasHandleHolder(RawStream(), CUBLAS_TENSOR_OP_MATH));
-#if CUDA_VERSION >= 11000
-      cublas_tf32_tensor_core_handle_.reset(
-          new CublasHandleHolder(RawStream(), CUBLAS_TF32_TENSOR_OP_MATH));
-#endif  // CUDA_VERSION >= 11000
-#endif  // CUDA_VERSION >= 9000
-    }
-  }
-#endif
-
-#ifndef PADDLE_WITH_HIP
-#if CUDA_VERSION >= 11060
-  void InitCuBlasLtContext() {
-    cublaslt_handle_.reset(new CublasLtHandleHolder());
-  }
-#endif
-
-  void InitCuSparseContext() {
-    cusparse_handle_.reset(new CusparseHandleHolder(RawStream()));
-  }
-#endif
-
-  void InitCuDNNContext() {
-    if (dynload::HasCUDNN()) {
-#ifdef PADDLE_WITH_HIP
-      size_t miopen_major, miopen_minor, miopen_patch;
-      PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenGetVersion(
-          &miopen_major, &miopen_minor, &miopen_patch));
-      auto local_miopen_version =
-          (miopen_major * 1000 + miopen_minor * 10 + miopen_patch) / 10;
-      auto compile_miopen_version = MIOPEN_VERSION / 10;
-      if (local_miopen_version < static_cast<size_t>(compile_miopen_version)) {
-        LOG_FIRST_N(WARNING, 1)
-            << "WARNING: device: " << place_.device
-            << ". The installed Paddle is compiled with MIOPEN "
-            << compile_miopen_version / 100 << "."
-            << compile_miopen_version % 100
-            << ", but MIOPEN version in your machine is "
-            << local_miopen_version / 100 << "." << local_miopen_version % 100
-            << ", which may cause serious incompatible bug. "
-            << "Please recompile or reinstall Paddle with compatible MIOPEN "
-               "version.";
-      }
-      PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreate(&cudnn_handle_));
-      PADDLE_ENFORCE_GPU_SUCCESS(
-          dynload::miopenSetStream(cudnn_handle_, RawStream()));
-#else
-      auto local_cudnn_version = dynload::cudnnGetVersion() / 100;
-      auto compile_cudnn_version = CUDNN_VERSION / 100;
-      if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
-        LOG_FIRST_N(WARNING, 1)
-            << "WARNING: device: " << place_.device
-            << ". The installed Paddle is compiled with CUDNN "
-            << compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
-            << ", but CUDNN version in your machine is "
-            << local_cudnn_version / 10 << "." << local_cudnn_version % 10
-            << ", which may cause serious incompatible bug. "
-            << "Please recompile or reinstall Paddle with compatible CUDNN "
-               "version.";
-      }
-      PADDLE_RETRY_CUDA_SUCCESS(dynload::cudnnCreate(&cudnn_handle_));
-      PADDLE_RETRY_CUDA_SUCCESS(
-          dynload::cudnnSetStream(cudnn_handle_, RawStream()));
-#endif
-    } else {
-      cudnn_handle_ = nullptr;
-    }
-  }
-
-#ifndef PADDLE_WITH_HIP
-  void InitCuSolverContext() {
-    PADDLE_RETRY_CUDA_SUCCESS(dynload::cusolverDnCreate(&cusolver_dn_handle_));
-    PADDLE_RETRY_CUDA_SUCCESS(
-        dynload::cusolverDnSetStream(cusolver_dn_handle_, RawStream()));
-  }
-#endif
-
-  void DestoryCuDNNContext() {
-    if (cudnn_handle_) {
-#ifdef PADDLE_WITH_HIP
-      PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenDestroy(cudnn_handle_));
-#else
-      PADDLE_ENFORCE_GPU_SUCCESS(dynload::cudnnDestroy(cudnn_handle_));
-#endif
-    }
-    cudnn_handle_ = nullptr;
-  }
-
-  void DestoryCuBlasContext() {
-    cublas_handle_.reset();
-    cublas_tensor_core_handle_.reset();
-    cublas_tf32_tensor_core_handle_.reset();
-  }
-
-#ifndef PADDLE_WITH_HIP
-#if CUDA_VERSION >= 11060
-  void DestoryCuBlasLtContext() { cublaslt_handle_.reset(); }
-#endif
-
-  void DestoryCuSparseContext() { cusparse_handle_.reset(); }
-#endif
-
-#ifndef PADDLE_WITH_HIP
-  void DestoryCuSolverContext() {
-    if (cusolver_dn_handle_) {
-      PADDLE_ENFORCE_GPU_SUCCESS(
-          dynload::cusolverDnDestroy(cusolver_dn_handle_));
-    }
-  }
-#endif
-
-  CUDAPlace place_;
-  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
-  std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
-  std::unique_ptr<stream::CUDAStream> stream_;
-#ifdef PADDLE_WITH_HIP
-  miopenHandle_t cudnn_handle_;
-#else
-  cudnnHandle_t cudnn_handle_;
-#endif
-  std::unique_ptr<CublasHandleHolder> cublas_handle_;
-  std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
-  std::unique_ptr<CublasHandleHolder> cublas_tf32_tensor_core_handle_;
-#ifndef PADDLE_WITH_HIP
-#if CUDA_VERSION >= 11060
-  std::unique_ptr<CublasLtHandleHolder> cublaslt_handle_;
-#endif
-  cusolverDnHandle_t cusolver_dn_handle_;
-  std::unique_ptr<CusparseHandleHolder> cusparse_handle_;
-#endif
-  DISABLE_COPY_AND_ASSIGN(CUDAContext);
-};
-
 class CUDADeviceContext : public phi::GPUContext {
  public:
   explicit CUDADeviceContext(CUDAPlace place);
   virtual ~CUDADeviceContext();
 
-  /*! \brief  Wait for all operations completion in the stream. */
-  void Wait() const override;
-
-  /*! \brief  Return eigen device in the device context. */
-  Eigen::GpuDevice* eigen_device() const;
-
-  /*! \brief  Call cublas function safely. */
-  inline void CublasCall(
-      const std::function<void(blasHandle_t)>& callback) const {
-    if (!thread_ctx_.count(this)) {
-      phi::GPUContext::CublasCall(callback);
-      return;
-    }
-    return context()->CublasCall(callback);
-  }
-
-#ifndef PADDLE_WITH_HIP
-  /*! \brief  Call cusparse function safely. */
-  inline void CusparseCall(
-      const std::function<void(phi::sparseHandle_t)>& callback) const {
-    if (!thread_ctx_.count(this)) {
-      phi::GPUContext::CusparseCall(callback);
-      return;
-    }
-    context()->CusparseCall(callback);
-  }
-#endif
-
-  /*! \brief  Call cublas function with Tensor Core safely. If
-      Tensor Core is not available, use DEFAULT_MATH instead. */
-  inline void TensorCoreCublasCallIfAvailable(
-      const std::function<void(blasHandle_t)>& callback) const {
-    if (!thread_ctx_.count(this)) {
-      phi::GPUContext::TensorCoreCublasCallIfAvailable(callback);
-      return;
-    }
-    context()->TensorCoreCublasCallIfAvailable(callback);
-  }
-
-/*! \brief  Return cudnn handle in the device context. */
-#ifdef PADDLE_WITH_HIP
-  miopenHandle_t cudnn_handle() const;
-#else
-  cudnnHandle_t cudnn_handle() const;
-#endif
-
-/*! \brief  Return cublas handle in the device context. */
-#ifdef PADDLE_WITH_HIP
-  rocblas_handle cublas_handle() const;
-#else
-  cublasHandle_t cublas_handle() const;
-  cublasLtHandle_t cublaslt_handle() const;
-  cusparseHandle_t cusparse_handle() const;
-#endif
-
-#ifndef PADDLE_WITH_HIP
-  cusolverDnHandle_t cusolver_dn_handle() const;
-#endif
-
-  /*! \brief  Return a cudnn workspace handle to call multiple cudnn
-   *  functions without interrupting by other threads.
-   *  Once the first cudnn function is called by the handle, a lock
-   *  would be acquired to prevent other threads from accessing the
-   *  workspace. Once the handle is destructed, the lock would be released.
-   *  CudnnWorkspaceHandle is an RAII object to implement thread-safe
-   *  sequential cudnn function calls. */
-  phi::DnnWorkspaceHandle cudnn_workspace_handle() const;
-
-  /*! \brief  Return cuda stream in the device context. */
-  gpuStream_t stream() const;
-
-  void RecordEvent(gpuEvent_t ev, const std::function<void()>& callback) const;
-
-  void AddStreamCallback(const std::function<void()>& callback) const;
-
-  void WaitStreamCallback() const;
-
-  void ResetThreadContext(const stream::Priority& priority) {
-    std::lock_guard<std::mutex> guard(ctx_mtx_);
-    thread_ctx_[this].reset(new CUDAContext(this->GetPlace(), priority));
-  }
-
-  std::shared_ptr<CUDAContext> context() const;
-
-  // Note: Can only be used under thread_local semantics.
-  void SetThreadLocalStream(const gpuStream_t stream) {
-    thread_ctx_.at(this)->SetStream(stream);
-  }
-
   // NOTE: Just for compatibility with the past, please delete if there is an
   // elegant way.
   stream::CUDAStream* GetCudaStream() const;
   stream::CUDAStream* SetCudaStream(stream::CUDAStream*);
 
  private:
-  // The thread_local static variable will be released before the
-  // global static variable, so avoid using it in dtor.
-  static thread_local std::unordered_map<const CUDADeviceContext*,
-                                         std::shared_ptr<CUDAContext>>
-      thread_ctx_;
-  static thread_local std::mutex ctx_mtx_;
-
-  mutable std::mutex cudnn_handle_mtx_;
-
   // NOTE: Just for compatibility with the past, please delete if there is an
   // elegant way.
   std::unique_ptr<stream::CUDAStream> cuda_stream_;
--
GitLab
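
Migration note for downstream users: both deprecated paths above now only emit a
one-time warning and fall through to the ordinary run path, so code that relied on
config.EnableGpuMultiStream() or ExpRunWithExternalStream() should bind an external
stream at config time instead. The sketch below shows that replacement flow; it is a
minimal sketch assuming a Paddle Inference build that provides
Config::SetExecStream(void*) (the method the new warning message recommends), and the
model path and memory pool size are hypothetical placeholders.

    #include <cuda_runtime.h>
    #include "paddle_inference_api.h"

    int main() {
      // External stream owned by the application; inference work runs on it.
      cudaStream_t stream;
      cudaStreamCreate(&stream);

      paddle_infer::Config config("./model_dir");  // hypothetical model path
      config.EnableUseGpu(/*memory_pool_init_size_mb=*/100, /*device_id=*/0);
      // Bind the stream once here, replacing the removed per-run
      // ExpRunWithExternalStream(stream) interface.
      config.SetExecStream(stream);

      auto predictor = paddle_infer::CreatePredictor(config);
      // Feeding inputs via predictor->GetInputHandle(...) is omitted for brevity.
      predictor->Run();  // kernels are launched on `stream`
      cudaStreamSynchronize(stream);

      cudaStreamDestroy(stream);
      return 0;
    }

Binding the stream through the config pins one predictor to one stream for its whole
lifetime, which is simpler to reason about than the removed thread-local rebinding
that this patch deletes from CUDADeviceContext.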