diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index cc169c930768dcd9762b19328bf42b7d1a8803b2..39c83d24a1522c1fafdc6036b9273981c9dc7c02 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -415,14 +415,16 @@ void AnalysisPredictor::InitDeviceContexts() { gpu_context->SetHostGenerator(framework::DefaultCPUGenerator().get()); gpu_context->SetStream(gpu_resource->GetStream()); - gpu_context->SetBlasHandle(gpu_resource->GetBlasHandle()); + gpu_context->SetBlasHandle(gpu_resource->GetBlasHandleCreator()); gpu_context->SetBlasTensorCoreHandle( - gpu_resource->GetBlasTensorCoreHandle()); - gpu_context->SetBlasTF32Handle(gpu_resource->GetBlasTF32Handle()); - gpu_context->SetDnnHandle(gpu_resource->GetDnnHandle()); - gpu_context->SetSolverHandle(gpu_resource->GetSolverDnHandle()); - gpu_context->SetSparseHandle(gpu_resource->GetSparseHandle()); - gpu_context->SetEigenDevice(gpu_resource->GetGpuEigenDevice()); + gpu_resource->GetBlasTensorCoreHandleCreator()); + gpu_context->SetBlasTF32Handle( + gpu_resource->GetBlasTF32TensorCoreHandleCreator()); + gpu_context->SetDnnHandle(gpu_resource->GetDnnHandleCreator()); + gpu_context->SetSolverHandle( + gpu_resource->GetSolverDnHandleCreator()); + gpu_context->SetSparseHandle(gpu_resource->GetSparseHandleCreator()); + gpu_context->SetEigenDevice(gpu_resource->GetGpuEigenDeviceCreator()); gpu_context->SetComputeCapability( gpu_resource->GetGpuComputeCapability()); gpu_context->SetMaxThreadsPerBlock( diff --git a/paddle/fluid/inference/api/resource_manager.cc b/paddle/fluid/inference/api/resource_manager.cc index 6b3be72749d7ec0d7f81fe55ccd05f9922b936df..49d33d6750abe8f61cd44f175a71da0d74bf30ea 100644 --- a/paddle/fluid/inference/api/resource_manager.cc +++ b/paddle/fluid/inference/api/resource_manager.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/inference/api/resource_manager.h" +#include #include #include #include @@ -150,12 +151,6 @@ void GPUContextResource::InitGPUResource(void* stream) { } InitGpuProperties(); - InitGpuEigenDevice(); - InitDnnHanlde(); - InitBlasHandle(); - InitBlasLtHandle(); - InitSolverHandle(); - InitSparseHandle(); } void GPUContextResource::DestroyGPUResource() { @@ -203,22 +198,6 @@ void GPUContextResource::DestroyDnnHandle() { phi::DestroyDnnHandle(dnn_handle_); } -void GPUContextResource::InitBlasHandle() { - phi::InitBlasHandle(&blas_handle_, stream_); -#ifdef PADDLE_WITH_CUDA -#if CUDA_VERSION >= 9000 - phi::InitBlasHandle(&blas_tensor_core_handle_, stream_); - PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( - blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH)); -#endif -#if CUDA_VERSION >= 11000 - phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_); - PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( - blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH)); -#endif -#endif -} - void GPUContextResource::DestroyBlasHandle() { phi::DestroyBlasHandle(blas_handle_); phi::DestroyBlasHandle(blas_tensor_core_handle_); @@ -255,32 +234,106 @@ gpuStream_t GPUContextResource::GetStream() const { return stream_; } dnnHandle_t GPUContextResource::GetDnnHandle() const { return dnn_handle_; } +std::function GPUContextResource::GetDnnHandleCreator() { + return [&]() -> phi::dnnHandle_t { + InitDnnHanlde(); + return dnn_handle_; + }; +} + blasHandle_t GPUContextResource::GetBlasHandle() const { return blas_handle_; } +std::function GPUContextResource::GetBlasHandleCreator() { + return [&]() -> phi::blasHandle_t { + phi::InitBlasHandle(&blas_handle_, stream_); + return blas_handle_; + }; +} + blasHandle_t GPUContextResource::GetBlasTensorCoreHandle() const { return blas_tensor_core_handle_; } +std::function +GPUContextResource::GetBlasTensorCoreHandleCreator() { + return [&]() -> phi::blasHandle_t { +#ifdef PADDLE_WITH_CUDA +#if CUDA_VERSION >= 9000 + phi::InitBlasHandle(&blas_tensor_core_handle_, stream_); + PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( + blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH)); +#endif +#endif + return blas_tensor_core_handle_; + }; +} + blasHandle_t GPUContextResource::GetBlasTF32Handle() const { return blas_tf32_tensor_core_handle_; } +std::function +GPUContextResource::GetBlasTF32TensorCoreHandleCreator() { + return [&]() -> phi::blasHandle_t { +#ifdef PADDLE_WITH_CUDA +#if CUDA_VERSION >= 11000 + phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_); + PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( + blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH)); +#endif +#endif + return blas_tf32_tensor_core_handle_; + }; +} + blasLtHandle_t GPUContextResource::GetBlasLtHandle() const { return blaslt_handle_; } +std::function +GPUContextResource::GetBlasLtHandleCreator() { + return [&]() { + InitBlasLtHandle(); + return blaslt_handle_; + }; +} + phi::solverHandle_t GPUContextResource::GetSolverDnHandle() const { return solver_handle_; } +std::function +GPUContextResource::GetSolverDnHandleCreator() { + return [&]() { + InitSolverHandle(); + return solver_handle_; + }; +} + phi::sparseHandle_t GPUContextResource::GetSparseHandle() const { return sparse_handle_; } +std::function +GPUContextResource::GetSparseHandleCreator() { + return [&]() { + InitSparseHandle(); + return sparse_handle_; + }; +} + Eigen::GpuDevice* GPUContextResource::GetGpuEigenDevice() const { return gpu_eigen_device_.get(); } +std::function +GPUContextResource::GetGpuEigenDeviceCreator() { + return [&]() { + InitGpuEigenDevice(); + return gpu_eigen_device_.get(); + }; +} + int GPUContextResource::GetGpuComputeCapability() const { return compute_capability_; } @@ -311,67 +364,82 @@ void GPUContextResource::ReBindStream(gpuStream_t stream) { } void GPUContextResource::ReBindDnnHandle(gpuStream_t stream) const { + if (dnn_handle_) { #ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::miopenSetStream(dnn_handle_, stream)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::miopenSetStream(dnn_handle_, stream)); #else - PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cudnnSetStream(dnn_handle_, stream)); + PADDLE_RETRY_CUDA_SUCCESS( + phi::dynload::cudnnSetStream(dnn_handle_, stream)); #endif + } } void GPUContextResource::ReBindBlasHandle(gpuStream_t stream) const { + if (blas_handle_) { #ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::rocblas_set_stream(blas_handle_, stream)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::rocblas_set_stream(blas_handle_, stream)); #else - PADDLE_RETRY_CUDA_SUCCESS( - phi::dynload::cublasSetStream(blas_handle_, stream)); + PADDLE_RETRY_CUDA_SUCCESS( + phi::dynload::cublasSetStream(blas_handle_, stream)); #endif + } } void GPUContextResource::ReBindBlasTensorCoreHandle(gpuStream_t stream) const { + if (blas_tensor_core_handle_) { #ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::rocblas_set_stream(blas_tensor_core_handle_, stream)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::rocblas_set_stream(blas_tensor_core_handle_, stream)); #else - PADDLE_RETRY_CUDA_SUCCESS( - phi::dynload::cublasSetStream(blas_tensor_core_handle_, stream)); + PADDLE_RETRY_CUDA_SUCCESS( + phi::dynload::cublasSetStream(blas_tensor_core_handle_, stream)); #endif + } } void GPUContextResource::ReBindBlasTF32Handle(gpuStream_t stream) const { + if (blas_tf32_tensor_core_handle_) { #ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::rocblas_set_stream(blas_tf32_tensor_core_handle_, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::rocblas_set_stream( + blas_tf32_tensor_core_handle_, stream)); #else - PADDLE_RETRY_CUDA_SUCCESS( - phi::dynload::cublasSetStream(blas_tf32_tensor_core_handle_, stream)); + PADDLE_RETRY_CUDA_SUCCESS( + phi::dynload::cublasSetStream(blas_tf32_tensor_core_handle_, stream)); #endif + } } void GPUContextResource::ReBindSolverDnHandle(gpuStream_t stream) const { + if (solver_handle_) { #ifndef PADDLE_WITH_HIP - PADDLE_RETRY_CUDA_SUCCESS( - phi::dynload::cusolverDnSetStream(solver_handle_, stream)); + PADDLE_RETRY_CUDA_SUCCESS( + phi::dynload::cusolverDnSetStream(solver_handle_, stream)); #endif + } } void GPUContextResource::ReBindSparseHandle(gpuStream_t stream) const { + if (sparse_handle_) { #if defined(PADDLE_WITH_CUDA) // The generic APIs is supported from CUDA10.1 #if CUDA_VERSION >= 11000 - PADDLE_RETRY_CUDA_SUCCESS( - phi::dynload::cusparseSetStream(sparse_handle_, stream)); + PADDLE_RETRY_CUDA_SUCCESS( + phi::dynload::cusparseSetStream(sparse_handle_, stream)); #endif #endif + } } void GPUContextResource::ReBindEigenDevice(gpuStream_t stream, GPUPlace place) const { - auto* allocator = paddle::memory::allocation::AllocatorFacade::Instance() - .GetAllocator(place_) - .get(); - eigen_stream_->Reinitialize(stream, allocator, place); + if (eigen_stream_) { + auto* allocator = paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(place_) + .get(); + eigen_stream_->Reinitialize(stream, allocator, place); + } } #endif diff --git a/paddle/fluid/inference/api/resource_manager.h b/paddle/fluid/inference/api/resource_manager.h index 359b8f8973281e512b75d764623f4e669b3d3fbe..1b375efaf3b5f976d8b7c8323585df0965346e6f 100644 --- a/paddle/fluid/inference/api/resource_manager.h +++ b/paddle/fluid/inference/api/resource_manager.h @@ -55,6 +55,15 @@ class GPUContextResource { ~GPUContextResource(); phi::Place Place() const; + std::function GetDnnHandleCreator(); + std::function GetBlasHandleCreator(); + std::function GetBlasTensorCoreHandleCreator(); + std::function GetBlasTF32TensorCoreHandleCreator(); + std::function GetBlasLtHandleCreator(); + std::function GetSolverDnHandleCreator(); + std::function GetSparseHandleCreator(); + std::function GetGpuEigenDeviceCreator(); + gpuStream_t GetStream() const; dnnHandle_t GetDnnHandle() const; blasHandle_t GetBlasHandle() const; @@ -89,7 +98,6 @@ class GPUContextResource { void InitGpuEigenDevice(); void InitDnnHanlde(); void DestroyDnnHandle(); - void InitBlasHandle(); void DestroyBlasHandle(); void InitBlasLtHandle(); void DestroyBlasLtHandle(); diff --git a/paddle/fluid/operators/fused/resnet_basic_block_op.cc b/paddle/fluid/operators/fused/resnet_basic_block_op.cc index 5990db8147be42f3588dfd76bebc5e8e53274591..af5b76911692d6244d151cc5fa9a1d582b32325e 100644 --- a/paddle/fluid/operators/fused/resnet_basic_block_op.cc +++ b/paddle/fluid/operators/fused/resnet_basic_block_op.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/api/all.h" +#include "paddle/phi/core/ddim.h" namespace paddle { namespace operators { diff --git a/paddle/phi/backends/gpu/gpu_context.cc b/paddle/phi/backends/gpu/gpu_context.cc index 92c1fedae44af4ef776e13f9b24a8efb859e9559..12b32107eca10b654f44aeaea08ba8de953e5d02 100644 --- a/paddle/phi/backends/gpu/gpu_context.cc +++ b/paddle/phi/backends/gpu/gpu_context.cc @@ -213,7 +213,6 @@ struct GPUContext::Impl { &max_threads_per_block_, &max_grid_dim_size_); phi::InitStream(&stream_); - InitEigenDevice(); InitDnnWorkspace(); } @@ -234,7 +233,6 @@ struct GPUContext::Impl { void PartialInitWithAllocator() { owned_ = true; backends::gpu::GPUDeviceGuard guard(place_.device); - InitEigenDevice(); InitDnnWorkspace(); } @@ -317,27 +315,49 @@ struct GPUContext::Impl { void SetEigenDevice(Eigen::GpuDevice* device) { eigen_device_ = device; } - Eigen::GpuDevice* eigen_device() const { + void SetEigenDevice(std::function&& creator) { + eigen_device_creator_ = std::move(creator); + } + + Eigen::GpuDevice* eigen_device() { + std::call_once(flag_eigen_device_, [&]() { + if (!eigen_device_) { + if (!eigen_device_creator_) + InitEigenDevice(); + else + eigen_device_ = eigen_device_creator_(); + } + }); PD_CHECK(eigen_device_ != nullptr, "the gpu eigen_device is nullptr."); return eigen_device_; } blasHandle_t GetBlasHandle() { - std::call_once(flag_blas_, [=]() { + std::call_once(flag_blas_, [&]() { if (!blas_handle_) { - phi::InitBlasHandle(&blas_handle_, stream_); + if (!blas_handle_creator_) + phi::InitBlasHandle(&blas_handle_, stream_); + else + blas_handle_ = blas_handle_creator_(); } #ifdef PADDLE_WITH_CUDA #if CUDA_VERSION >= 9000 if (!blas_tensor_core_handle_) { - phi::InitBlasHandle(&blas_tensor_core_handle_, stream_); + if (!blas_tensor_core_handle_creator_) + phi::InitBlasHandle(&blas_tensor_core_handle_, stream_); + else + blas_tensor_core_handle_ = blas_tensor_core_handle_creator_(); PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH)); } #endif #if CUDA_VERSION >= 11000 if (!blas_tf32_tensor_core_handle_) { - phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_); + if (!blas_tf32_tensor_core_handle_creator_) + phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_); + else + blas_tf32_tensor_core_handle_ = + blas_tf32_tensor_core_handle_creator_(); PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH)); } @@ -350,27 +370,53 @@ struct GPUContext::Impl { void SetBlasHandle(blasHandle_t blas) { blas_handle_ = blas; } + void SetBlasHandle(std::function&& handle_creator) { + blas_handle_creator_ = std::move(handle_creator); + } + void SetBlasTensorCoreHandle(blasHandle_t handle) { blas_tensor_core_handle_ = handle; } + void SetBlasTensorCoreHandle(std::function&& handle_creator) { + blas_tensor_core_handle_creator_ = std::move(handle_creator); + } + void SetBlasTF32Handle(blasHandle_t handle) { blas_tf32_tensor_core_handle_ = handle; } + void SetBlasTF32Handle(std::function&& handle_creator) { + blas_tf32_tensor_core_handle_creator_ = std::move(handle_creator); + } + void SetBlasLtHandle(blasLtHandle_t blaslt) { blaslt_handle_ = blaslt; } + void SetBlasLtHandle(std::function&& handle_creator) { + blaslt_handle_creator_ = std::move(handle_creator); + } + blasLtHandle_t GetBlasLtHandle() { - std::call_once(flag_blaslt_, [=]() { - if (!blaslt_handle_) phi::InitBlasLtHandle(&blaslt_handle_); + std::call_once(flag_blaslt_, [&]() { + if (!blaslt_handle_) { + if (!blaslt_handle_creator_) + phi::InitBlasLtHandle(&blaslt_handle_); + else + blaslt_handle_ = blaslt_handle_creator_(); + } }); PD_CHECK(blaslt_handle_ != nullptr, "the gpu blasLt handle is nullptr."); return blaslt_handle_; } dnnHandle_t GetDnnHandle() { - std::call_once(flag_dnn_, [=]() { - if (!dnn_handle_) phi::InitDnnHandle(&dnn_handle_, stream_, place_); + std::call_once(flag_dnn_, [&]() { + if (!dnn_handle_) { + if (!dnn_handle_creator_) + phi::InitDnnHandle(&dnn_handle_, stream_, place_); + else + dnn_handle_ = dnn_handle_creator_(); + } }); PD_CHECK(dnn_handle_ != nullptr, "the gpu dnn handle is nullptr."); return dnn_handle_; @@ -392,9 +438,18 @@ struct GPUContext::Impl { void SetDnnHandle(dnnHandle_t handle) { dnn_handle_ = handle; } + void SetDnnHandle(std::function&& handle_creator) { + dnn_handle_creator_ = std::move(handle_creator); + } + solverHandle_t GetSolverHandle() { - std::call_once(flag_slover_, [=]() { - if (!solver_handle_) phi::InitSolverHandle(&solver_handle_, stream_); + std::call_once(flag_slover_, [&]() { + if (!solver_handle_) { + if (!solver_handle_creator_) + phi::InitSolverHandle(&solver_handle_, stream_); + else + solver_handle_ = solver_handle_creator_(); + } }); PD_CHECK(solver_handle_ != nullptr, "the gpu solver handle is nullptr."); return solver_handle_; @@ -402,9 +457,18 @@ struct GPUContext::Impl { void SetSolverHandle(solverHandle_t handle) { solver_handle_ = handle; } + void SetSolverHandle(std::function&& handle_creator) { + solver_handle_creator_ = std::move(handle_creator); + } + sparseHandle_t GetSparseHandle() { - std::call_once(flag_sparse_, [=]() { - if (!sparse_handle_) phi::InitSparseHandle(&sparse_handle_, stream_); + std::call_once(flag_sparse_, [&]() { + if (!sparse_handle_) { + if (!sparse_handle_creator_) + phi::InitSparseHandle(&sparse_handle_, stream_); + else + sparse_handle_ = sparse_handle_creator_(); + } }); PD_CHECK(sparse_handle_ != nullptr, "the gpu sparse handle is nullptr."); return sparse_handle_; @@ -412,6 +476,10 @@ struct GPUContext::Impl { void SetSparseHandle(sparseHandle_t handle) { sparse_handle_ = handle; } + void SetSparseHandle(std::function&& handle_creator) { + sparse_handle_creator_ = std::move(handle_creator); + } + void Wait() const { #ifdef PADDLE_WITH_HIP hipError_t e_sync = hipSuccess; @@ -461,21 +529,31 @@ struct GPUContext::Impl { } inline void CublasCall(const std::function& callback) { - std::call_once(flag_cublas_, [=]() { + std::call_once(flag_cublas_, [&]() { if (!blas_handle_) { - phi::InitBlasHandle(&blas_handle_, stream_); + if (!blas_handle_creator_) + phi::InitBlasHandle(&blas_handle_, stream_); + else + blas_handle_ = blas_handle_creator_(); } #ifdef PADDLE_WITH_CUDA #if CUDA_VERSION >= 9000 if (!blas_tensor_core_handle_) { - phi::InitBlasHandle(&blas_tensor_core_handle_, stream_); + if (!blas_tensor_core_handle_creator_) + phi::InitBlasHandle(&blas_tensor_core_handle_, stream_); + else + blas_tensor_core_handle_ = blas_tensor_core_handle_creator_(); PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH)); } #endif #if CUDA_VERSION >= 11000 if (!blas_tf32_tensor_core_handle_) { - phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_); + if (!blas_tf32_tensor_core_handle_creator_) + phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_); + else + blas_tf32_tensor_core_handle_ = + blas_tf32_tensor_core_handle_creator_(); PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH)); } @@ -493,19 +571,31 @@ struct GPUContext::Impl { inline void TensorCoreCublasCallIfAvailable( const std::function& callback) { - std::call_once(flag_tensorcore_cublas_, [=]() { - if (!blas_handle_) phi::InitBlasHandle(&blas_handle_, stream_); + std::call_once(flag_tensorcore_cublas_, [&]() { + if (!blas_handle_) { + if (!blas_handle_creator_) + phi::InitBlasHandle(&blas_handle_, stream_); + else + blas_handle_ = blas_handle_creator_(); + } #ifdef PADDLE_WITH_CUDA #if CUDA_VERSION >= 9000 if (!blas_tensor_core_handle_) { - phi::InitBlasHandle(&blas_tensor_core_handle_, stream_); + if (!blas_tensor_core_handle_creator_) + phi::InitBlasHandle(&blas_tensor_core_handle_, stream_); + else + blas_tensor_core_handle_ = blas_tensor_core_handle_creator_(); PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH)); } #endif #if CUDA_VERSION >= 11000 if (!blas_tf32_tensor_core_handle_) { - phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_); + if (!blas_tf32_tensor_core_handle_creator_) + phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_); + else + blas_tf32_tensor_core_handle_ = + blas_tf32_tensor_core_handle_creator_(); PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH)); } @@ -523,9 +613,12 @@ struct GPUContext::Impl { inline void CusparseCall( const std::function& callback) { - std::call_once(flag_sparse_, [=]() { + std::call_once(flag_sparse_, [&]() { if (!sparse_handle_) { - phi::InitSparseHandle(&sparse_handle_, stream_); + if (!sparse_handle_creator_) + phi::InitSparseHandle(&sparse_handle_, stream_); + else + sparse_handle_ = sparse_handle_creator_(); } }); std::lock_guard guard(sparse_mtx_); @@ -597,13 +690,21 @@ struct GPUContext::Impl { gpuStream_t stream_{nullptr}; Eigen::GpuDevice* eigen_device_{nullptr}; + std::function eigen_device_creator_{nullptr}; blasHandle_t blas_handle_{nullptr}; + std::function blas_handle_creator_{nullptr}; blasHandle_t blas_tensor_core_handle_{nullptr}; + std::function blas_tensor_core_handle_creator_{nullptr}; blasHandle_t blas_tf32_tensor_core_handle_{nullptr}; + std::function blas_tf32_tensor_core_handle_creator_{nullptr}; blasLtHandle_t blaslt_handle_{nullptr}; + std::function blaslt_handle_creator_{nullptr}; dnnHandle_t dnn_handle_{nullptr}; + std::function dnn_handle_creator_{nullptr}; solverHandle_t solver_handle_{nullptr}; + std::function solver_handle_creator_{nullptr}; sparseHandle_t sparse_handle_{nullptr}; + std::function sparse_handle_creator_{nullptr}; DnnWorkspaceHandle* workspace_{nullptr}; std::once_flag flag_sparse_; @@ -613,6 +714,7 @@ struct GPUContext::Impl { std::once_flag flag_slover_; std::once_flag flag_cublas_; std::once_flag flag_tensorcore_cublas_; + std::once_flag flag_eigen_device_; #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) // NCCL communicator (single process version) for NCCL collective operations. @@ -752,34 +854,66 @@ void GPUContext::SetEigenDevice(Eigen::GpuDevice* device) { impl_->SetEigenDevice(device); } +void GPUContext::SetEigenDevice(std::function&& creator) { + impl_->SetEigenDevice(std::move(creator)); +} + void GPUContext::SetBlasHandle(blasHandle_t blas) { impl_->SetBlasHandle(blas); } +void GPUContext::SetBlasHandle(std::function&& func) { + impl_->SetBlasHandle(std::move(func)); +} + void GPUContext::SetBlasTensorCoreHandle(blasHandle_t handle) { impl_->SetBlasTensorCoreHandle(handle); } +void GPUContext::SetBlasTensorCoreHandle(std::function&& func) { + impl_->SetBlasTensorCoreHandle(std::move(func)); +} + void GPUContext::SetBlasTF32Handle(blasHandle_t handle) { impl_->SetBlasTF32Handle(handle); } +void GPUContext::SetBlasTF32Handle(std::function&& func) { + impl_->SetBlasTF32Handle(std::move(func)); +} + void GPUContext::SetBlasLtHandle(blasLtHandle_t blaslt) { impl_->SetBlasLtHandle(blaslt); } +void GPUContext::SetBlasLtHandle(std::function&& func) { + impl_->SetBlasLtHandle(std::move(func)); +} + void GPUContext::SetDnnHandle(dnnHandle_t handle) { impl_->SetDnnHandle(handle); } +void GPUContext::SetDnnHandle(std::function&& func) { + impl_->SetDnnHandle(std::move(func)); +} + void GPUContext::SetSolverHandle(solverHandle_t handle) { impl_->SetSolverHandle(handle); } +void GPUContext::SetSolverHandle(std::function&& func) { + impl_->SetSolverHandle(std::move(func)); +} + void GPUContext::SetSparseHandle(sparseHandle_t handle) { impl_->SetSparseHandle(handle); } +void GPUContext::SetSparseHandle(std::function&& func) { + impl_->SetSparseHandle(std::move(func)); +} + void GPUContext::SetDnnWorkspaceHandle(DnnWorkspaceHandle* handle) { impl_->workspace_ = handle; } diff --git a/paddle/phi/backends/gpu/gpu_context.h b/paddle/phi/backends/gpu/gpu_context.h index 5246155131dbeace56d9eff26fb51e0139e2732a..a23ab611101c50021e99430b65c8184bdca5f3e0 100644 --- a/paddle/phi/backends/gpu/gpu_context.h +++ b/paddle/phi/backends/gpu/gpu_context.h @@ -197,20 +197,28 @@ class PADDLE_API GPUContext : public DeviceContext { void SetStream(gpuStream_t); void SetEigenDevice(Eigen::GpuDevice*); + void SetEigenDevice(std::function&&); void SetBlasHandle(blasHandle_t); + void SetBlasHandle(std::function&&); void SetBlasTensorCoreHandle(blasHandle_t); + void SetBlasTensorCoreHandle(std::function&&); void SetBlasTF32Handle(blasHandle_t); + void SetBlasTF32Handle(std::function&&); void SetBlasLtHandle(blasLtHandle_t); + void SetBlasLtHandle(std::function&&); void SetDnnHandle(dnnHandle_t); + void SetDnnHandle(std::function&&); void SetSolverHandle(solverHandle_t); + void SetSolverHandle(std::function&&); void SetSparseHandle(sparseHandle_t); + void SetSparseHandle(std::function&&); void SetDnnWorkspaceHandle(DnnWorkspaceHandle*);