未验证 提交 1892a441 编写于 作者: W Wilber 提交者: GitHub

inference multi stream support handle lazy init. (#44563)

* multi stream support handle lazy init.

* support eigen lazy init

* update

* fix ci problem
上级 65ad58b6
...@@ -415,14 +415,16 @@ void AnalysisPredictor::InitDeviceContexts() { ...@@ -415,14 +415,16 @@ void AnalysisPredictor::InitDeviceContexts() {
gpu_context->SetHostGenerator(framework::DefaultCPUGenerator().get()); gpu_context->SetHostGenerator(framework::DefaultCPUGenerator().get());
gpu_context->SetStream(gpu_resource->GetStream()); gpu_context->SetStream(gpu_resource->GetStream());
gpu_context->SetBlasHandle(gpu_resource->GetBlasHandle()); gpu_context->SetBlasHandle(gpu_resource->GetBlasHandleCreator());
gpu_context->SetBlasTensorCoreHandle( gpu_context->SetBlasTensorCoreHandle(
gpu_resource->GetBlasTensorCoreHandle()); gpu_resource->GetBlasTensorCoreHandleCreator());
gpu_context->SetBlasTF32Handle(gpu_resource->GetBlasTF32Handle()); gpu_context->SetBlasTF32Handle(
gpu_context->SetDnnHandle(gpu_resource->GetDnnHandle()); gpu_resource->GetBlasTF32TensorCoreHandleCreator());
gpu_context->SetSolverHandle(gpu_resource->GetSolverDnHandle()); gpu_context->SetDnnHandle(gpu_resource->GetDnnHandleCreator());
gpu_context->SetSparseHandle(gpu_resource->GetSparseHandle()); gpu_context->SetSolverHandle(
gpu_context->SetEigenDevice(gpu_resource->GetGpuEigenDevice()); gpu_resource->GetSolverDnHandleCreator());
gpu_context->SetSparseHandle(gpu_resource->GetSparseHandleCreator());
gpu_context->SetEigenDevice(gpu_resource->GetGpuEigenDeviceCreator());
gpu_context->SetComputeCapability( gpu_context->SetComputeCapability(
gpu_resource->GetGpuComputeCapability()); gpu_resource->GetGpuComputeCapability());
gpu_context->SetMaxThreadsPerBlock( gpu_context->SetMaxThreadsPerBlock(
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/api/resource_manager.h" #include "paddle/fluid/inference/api/resource_manager.h"
#include <functional>
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <unordered_map> #include <unordered_map>
...@@ -150,12 +151,6 @@ void GPUContextResource::InitGPUResource(void* stream) { ...@@ -150,12 +151,6 @@ void GPUContextResource::InitGPUResource(void* stream) {
} }
InitGpuProperties(); InitGpuProperties();
InitGpuEigenDevice();
InitDnnHanlde();
InitBlasHandle();
InitBlasLtHandle();
InitSolverHandle();
InitSparseHandle();
} }
void GPUContextResource::DestroyGPUResource() { void GPUContextResource::DestroyGPUResource() {
...@@ -203,22 +198,6 @@ void GPUContextResource::DestroyDnnHandle() { ...@@ -203,22 +198,6 @@ void GPUContextResource::DestroyDnnHandle() {
phi::DestroyDnnHandle(dnn_handle_); phi::DestroyDnnHandle(dnn_handle_);
} }
// Eagerly creates every BLAS handle held by this resource, all bound to
// stream_:
//   - blas_handle_ unconditionally;
//   - blas_tensor_core_handle_ on CUDA builds with CUDA_VERSION >= 9000,
//     switched to CUBLAS_TENSOR_OP_MATH;
//   - blas_tf32_tensor_core_handle_ on CUDA builds with CUDA_VERSION >= 11000,
//     switched to CUBLAS_TF32_TENSOR_OP_MATH.
// On non-CUDA builds only the default handle is initialized.
void GPUContextResource::InitBlasHandle() {
phi::InitBlasHandle(&blas_handle_, stream_);
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 9000
// Dedicated handle so the tensor-op math mode does not affect blas_handle_.
phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
#endif
#if CUDA_VERSION >= 11000
// Separate handle configured for the TF32 tensor-op math mode.
phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
#endif
#endif
}
void GPUContextResource::DestroyBlasHandle() { void GPUContextResource::DestroyBlasHandle() {
phi::DestroyBlasHandle(blas_handle_); phi::DestroyBlasHandle(blas_handle_);
phi::DestroyBlasHandle(blas_tensor_core_handle_); phi::DestroyBlasHandle(blas_tensor_core_handle_);
...@@ -255,32 +234,106 @@ gpuStream_t GPUContextResource::GetStream() const { return stream_; } ...@@ -255,32 +234,106 @@ gpuStream_t GPUContextResource::GetStream() const { return stream_; }
dnnHandle_t GPUContextResource::GetDnnHandle() const { return dnn_handle_; } dnnHandle_t GPUContextResource::GetDnnHandle() const { return dnn_handle_; }
std::function<phi::dnnHandle_t()> GPUContextResource::GetDnnHandleCreator() {
return [&]() -> phi::dnnHandle_t {
InitDnnHanlde();
return dnn_handle_;
};
}
blasHandle_t GPUContextResource::GetBlasHandle() const { return blas_handle_; } blasHandle_t GPUContextResource::GetBlasHandle() const { return blas_handle_; }
std::function<phi::blasHandle_t()> GPUContextResource::GetBlasHandleCreator() {
return [&]() -> phi::blasHandle_t {
phi::InitBlasHandle(&blas_handle_, stream_);
return blas_handle_;
};
}
blasHandle_t GPUContextResource::GetBlasTensorCoreHandle() const { blasHandle_t GPUContextResource::GetBlasTensorCoreHandle() const {
return blas_tensor_core_handle_; return blas_tensor_core_handle_;
} }
std::function<phi::blasHandle_t()>
GPUContextResource::GetBlasTensorCoreHandleCreator() {
return [&]() -> phi::blasHandle_t {
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 9000
phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
#endif
#endif
return blas_tensor_core_handle_;
};
}
blasHandle_t GPUContextResource::GetBlasTF32Handle() const { blasHandle_t GPUContextResource::GetBlasTF32Handle() const {
return blas_tf32_tensor_core_handle_; return blas_tf32_tensor_core_handle_;
} }
std::function<phi::blasHandle_t()>
GPUContextResource::GetBlasTF32TensorCoreHandleCreator() {
return [&]() -> phi::blasHandle_t {
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 11000
phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
#endif
#endif
return blas_tf32_tensor_core_handle_;
};
}
blasLtHandle_t GPUContextResource::GetBlasLtHandle() const { blasLtHandle_t GPUContextResource::GetBlasLtHandle() const {
return blaslt_handle_; return blaslt_handle_;
} }
std::function<phi::blasLtHandle_t()>
GPUContextResource::GetBlasLtHandleCreator() {
return [&]() {
InitBlasLtHandle();
return blaslt_handle_;
};
}
phi::solverHandle_t GPUContextResource::GetSolverDnHandle() const { phi::solverHandle_t GPUContextResource::GetSolverDnHandle() const {
return solver_handle_; return solver_handle_;
} }
std::function<phi::solverHandle_t()>
GPUContextResource::GetSolverDnHandleCreator() {
return [&]() {
InitSolverHandle();
return solver_handle_;
};
}
phi::sparseHandle_t GPUContextResource::GetSparseHandle() const { phi::sparseHandle_t GPUContextResource::GetSparseHandle() const {
return sparse_handle_; return sparse_handle_;
} }
std::function<phi::sparseHandle_t()>
GPUContextResource::GetSparseHandleCreator() {
return [&]() {
InitSparseHandle();
return sparse_handle_;
};
}
Eigen::GpuDevice* GPUContextResource::GetGpuEigenDevice() const { Eigen::GpuDevice* GPUContextResource::GetGpuEigenDevice() const {
return gpu_eigen_device_.get(); return gpu_eigen_device_.get();
} }
std::function<Eigen::GpuDevice*()>
GPUContextResource::GetGpuEigenDeviceCreator() {
return [&]() {
InitGpuEigenDevice();
return gpu_eigen_device_.get();
};
}
int GPUContextResource::GetGpuComputeCapability() const { int GPUContextResource::GetGpuComputeCapability() const {
return compute_capability_; return compute_capability_;
} }
...@@ -311,67 +364,82 @@ void GPUContextResource::ReBindStream(gpuStream_t stream) { ...@@ -311,67 +364,82 @@ void GPUContextResource::ReBindStream(gpuStream_t stream) {
} }
void GPUContextResource::ReBindDnnHandle(gpuStream_t stream) const { void GPUContextResource::ReBindDnnHandle(gpuStream_t stream) const {
if (dnn_handle_) {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::miopenSetStream(dnn_handle_, stream)); phi::dynload::miopenSetStream(dnn_handle_, stream));
#else #else
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cudnnSetStream(dnn_handle_, stream)); PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cudnnSetStream(dnn_handle_, stream));
#endif #endif
}
} }
void GPUContextResource::ReBindBlasHandle(gpuStream_t stream) const { void GPUContextResource::ReBindBlasHandle(gpuStream_t stream) const {
if (blas_handle_) {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::rocblas_set_stream(blas_handle_, stream)); phi::dynload::rocblas_set_stream(blas_handle_, stream));
#else #else
PADDLE_RETRY_CUDA_SUCCESS( PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cublasSetStream(blas_handle_, stream)); phi::dynload::cublasSetStream(blas_handle_, stream));
#endif #endif
}
} }
void GPUContextResource::ReBindBlasTensorCoreHandle(gpuStream_t stream) const { void GPUContextResource::ReBindBlasTensorCoreHandle(gpuStream_t stream) const {
if (blas_tensor_core_handle_) {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::rocblas_set_stream(blas_tensor_core_handle_, stream)); phi::dynload::rocblas_set_stream(blas_tensor_core_handle_, stream));
#else #else
PADDLE_RETRY_CUDA_SUCCESS( PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cublasSetStream(blas_tensor_core_handle_, stream)); phi::dynload::cublasSetStream(blas_tensor_core_handle_, stream));
#endif #endif
}
} }
void GPUContextResource::ReBindBlasTF32Handle(gpuStream_t stream) const { void GPUContextResource::ReBindBlasTF32Handle(gpuStream_t stream) const {
if (blas_tf32_tensor_core_handle_) {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::rocblas_set_stream(
phi::dynload::rocblas_set_stream(blas_tf32_tensor_core_handle_, stream)); blas_tf32_tensor_core_handle_, stream));
#else #else
PADDLE_RETRY_CUDA_SUCCESS( PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cublasSetStream(blas_tf32_tensor_core_handle_, stream)); phi::dynload::cublasSetStream(blas_tf32_tensor_core_handle_, stream));
#endif #endif
}
} }
void GPUContextResource::ReBindSolverDnHandle(gpuStream_t stream) const { void GPUContextResource::ReBindSolverDnHandle(gpuStream_t stream) const {
if (solver_handle_) {
#ifndef PADDLE_WITH_HIP #ifndef PADDLE_WITH_HIP
PADDLE_RETRY_CUDA_SUCCESS( PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cusolverDnSetStream(solver_handle_, stream)); phi::dynload::cusolverDnSetStream(solver_handle_, stream));
#endif #endif
}
} }
void GPUContextResource::ReBindSparseHandle(gpuStream_t stream) const { void GPUContextResource::ReBindSparseHandle(gpuStream_t stream) const {
if (sparse_handle_) {
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA)
// The generic APIs is supported from CUDA10.1 // The generic APIs is supported from CUDA10.1
#if CUDA_VERSION >= 11000 #if CUDA_VERSION >= 11000
PADDLE_RETRY_CUDA_SUCCESS( PADDLE_RETRY_CUDA_SUCCESS(
phi::dynload::cusparseSetStream(sparse_handle_, stream)); phi::dynload::cusparseSetStream(sparse_handle_, stream));
#endif #endif
#endif #endif
}
} }
void GPUContextResource::ReBindEigenDevice(gpuStream_t stream, void GPUContextResource::ReBindEigenDevice(gpuStream_t stream,
GPUPlace place) const { GPUPlace place) const {
auto* allocator = paddle::memory::allocation::AllocatorFacade::Instance() if (eigen_stream_) {
.GetAllocator(place_) auto* allocator = paddle::memory::allocation::AllocatorFacade::Instance()
.get(); .GetAllocator(place_)
eigen_stream_->Reinitialize(stream, allocator, place); .get();
eigen_stream_->Reinitialize(stream, allocator, place);
}
} }
#endif #endif
......
...@@ -55,6 +55,15 @@ class GPUContextResource { ...@@ -55,6 +55,15 @@ class GPUContextResource {
~GPUContextResource(); ~GPUContextResource();
phi::Place Place() const; phi::Place Place() const;
std::function<phi::dnnHandle_t()> GetDnnHandleCreator();
std::function<phi::blasHandle_t()> GetBlasHandleCreator();
std::function<phi::blasHandle_t()> GetBlasTensorCoreHandleCreator();
std::function<phi::blasHandle_t()> GetBlasTF32TensorCoreHandleCreator();
std::function<phi::blasLtHandle_t()> GetBlasLtHandleCreator();
std::function<phi::solverHandle_t()> GetSolverDnHandleCreator();
std::function<phi::sparseHandle_t()> GetSparseHandleCreator();
std::function<Eigen::GpuDevice*()> GetGpuEigenDeviceCreator();
gpuStream_t GetStream() const; gpuStream_t GetStream() const;
dnnHandle_t GetDnnHandle() const; dnnHandle_t GetDnnHandle() const;
blasHandle_t GetBlasHandle() const; blasHandle_t GetBlasHandle() const;
...@@ -89,7 +98,6 @@ class GPUContextResource { ...@@ -89,7 +98,6 @@ class GPUContextResource {
void InitGpuEigenDevice(); void InitGpuEigenDevice();
void InitDnnHanlde(); void InitDnnHanlde();
void DestroyDnnHandle(); void DestroyDnnHandle();
void InitBlasHandle();
void DestroyBlasHandle(); void DestroyBlasHandle();
void InitBlasLtHandle(); void InitBlasLtHandle();
void DestroyBlasLtHandle(); void DestroyBlasLtHandle();
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/api/all.h" #include "paddle/phi/core/ddim.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -213,7 +213,6 @@ struct GPUContext::Impl { ...@@ -213,7 +213,6 @@ struct GPUContext::Impl {
&max_threads_per_block_, &max_threads_per_block_,
&max_grid_dim_size_); &max_grid_dim_size_);
phi::InitStream(&stream_); phi::InitStream(&stream_);
InitEigenDevice();
InitDnnWorkspace(); InitDnnWorkspace();
} }
...@@ -234,7 +233,6 @@ struct GPUContext::Impl { ...@@ -234,7 +233,6 @@ struct GPUContext::Impl {
void PartialInitWithAllocator() { void PartialInitWithAllocator() {
owned_ = true; owned_ = true;
backends::gpu::GPUDeviceGuard guard(place_.device); backends::gpu::GPUDeviceGuard guard(place_.device);
InitEigenDevice();
InitDnnWorkspace(); InitDnnWorkspace();
} }
...@@ -317,27 +315,49 @@ struct GPUContext::Impl { ...@@ -317,27 +315,49 @@ struct GPUContext::Impl {
void SetEigenDevice(Eigen::GpuDevice* device) { eigen_device_ = device; } void SetEigenDevice(Eigen::GpuDevice* device) { eigen_device_ = device; }
Eigen::GpuDevice* eigen_device() const { void SetEigenDevice(std::function<Eigen::GpuDevice*()>&& creator) {
eigen_device_creator_ = std::move(creator);
}
Eigen::GpuDevice* eigen_device() {
std::call_once(flag_eigen_device_, [&]() {
if (!eigen_device_) {
if (!eigen_device_creator_)
InitEigenDevice();
else
eigen_device_ = eigen_device_creator_();
}
});
PD_CHECK(eigen_device_ != nullptr, "the gpu eigen_device is nullptr."); PD_CHECK(eigen_device_ != nullptr, "the gpu eigen_device is nullptr.");
return eigen_device_; return eigen_device_;
} }
blasHandle_t GetBlasHandle() { blasHandle_t GetBlasHandle() {
std::call_once(flag_blas_, [=]() { std::call_once(flag_blas_, [&]() {
if (!blas_handle_) { if (!blas_handle_) {
phi::InitBlasHandle(&blas_handle_, stream_); if (!blas_handle_creator_)
phi::InitBlasHandle(&blas_handle_, stream_);
else
blas_handle_ = blas_handle_creator_();
} }
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 9000 #if CUDA_VERSION >= 9000
if (!blas_tensor_core_handle_) { if (!blas_tensor_core_handle_) {
phi::InitBlasHandle(&blas_tensor_core_handle_, stream_); if (!blas_tensor_core_handle_creator_)
phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
else
blas_tensor_core_handle_ = blas_tensor_core_handle_creator_();
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH)); blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
} }
#endif #endif
#if CUDA_VERSION >= 11000 #if CUDA_VERSION >= 11000
if (!blas_tf32_tensor_core_handle_) { if (!blas_tf32_tensor_core_handle_) {
phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_); if (!blas_tf32_tensor_core_handle_creator_)
phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
else
blas_tf32_tensor_core_handle_ =
blas_tf32_tensor_core_handle_creator_();
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH)); blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
} }
...@@ -350,27 +370,53 @@ struct GPUContext::Impl { ...@@ -350,27 +370,53 @@ struct GPUContext::Impl {
void SetBlasHandle(blasHandle_t blas) { blas_handle_ = blas; } void SetBlasHandle(blasHandle_t blas) { blas_handle_ = blas; }
// Registers a factory for the default BLAS handle. The factory is taken by
// move and is only consulted later, while blas_handle_ is still null.
void SetBlasHandle(std::function<blasHandle_t()>&& handle_creator) {
  blas_handle_creator_ = std::move(handle_creator);
}
void SetBlasTensorCoreHandle(blasHandle_t handle) { void SetBlasTensorCoreHandle(blasHandle_t handle) {
blas_tensor_core_handle_ = handle; blas_tensor_core_handle_ = handle;
} }
// Registers a factory for the tensor-core BLAS handle. The factory is taken
// by move and is only consulted later, while blas_tensor_core_handle_ is
// still null.
void SetBlasTensorCoreHandle(std::function<blasHandle_t()>&& handle_creator) {
  blas_tensor_core_handle_creator_ = std::move(handle_creator);
}
void SetBlasTF32Handle(blasHandle_t handle) { void SetBlasTF32Handle(blasHandle_t handle) {
blas_tf32_tensor_core_handle_ = handle; blas_tf32_tensor_core_handle_ = handle;
} }
// Registers a factory for the TF32 tensor-core BLAS handle. The factory is
// taken by move and is only consulted later, while
// blas_tf32_tensor_core_handle_ is still null.
void SetBlasTF32Handle(std::function<blasHandle_t()>&& handle_creator) {
  blas_tf32_tensor_core_handle_creator_ = std::move(handle_creator);
}
void SetBlasLtHandle(blasLtHandle_t blaslt) { blaslt_handle_ = blaslt; } void SetBlasLtHandle(blasLtHandle_t blaslt) { blaslt_handle_ = blaslt; }
// Registers a factory for the BLASLt handle. The factory is taken by move and
// is only consulted by GetBlasLtHandle() while blaslt_handle_ is still null.
void SetBlasLtHandle(std::function<blasLtHandle_t()>&& handle_creator) {
  blaslt_handle_creator_ = std::move(handle_creator);
}
blasLtHandle_t GetBlasLtHandle() { blasLtHandle_t GetBlasLtHandle() {
std::call_once(flag_blaslt_, [=]() { std::call_once(flag_blaslt_, [&]() {
if (!blaslt_handle_) phi::InitBlasLtHandle(&blaslt_handle_); if (!blaslt_handle_) {
if (!blaslt_handle_creator_)
phi::InitBlasLtHandle(&blaslt_handle_);
else
blaslt_handle_ = blaslt_handle_creator_();
}
}); });
PD_CHECK(blaslt_handle_ != nullptr, "the gpu blasLt handle is nullptr."); PD_CHECK(blaslt_handle_ != nullptr, "the gpu blasLt handle is nullptr.");
return blaslt_handle_; return blaslt_handle_;
} }
dnnHandle_t GetDnnHandle() { dnnHandle_t GetDnnHandle() {
std::call_once(flag_dnn_, [=]() { std::call_once(flag_dnn_, [&]() {
if (!dnn_handle_) phi::InitDnnHandle(&dnn_handle_, stream_, place_); if (!dnn_handle_) {
if (!dnn_handle_creator_)
phi::InitDnnHandle(&dnn_handle_, stream_, place_);
else
dnn_handle_ = dnn_handle_creator_();
}
}); });
PD_CHECK(dnn_handle_ != nullptr, "the gpu dnn handle is nullptr."); PD_CHECK(dnn_handle_ != nullptr, "the gpu dnn handle is nullptr.");
return dnn_handle_; return dnn_handle_;
...@@ -392,9 +438,18 @@ struct GPUContext::Impl { ...@@ -392,9 +438,18 @@ struct GPUContext::Impl {
void SetDnnHandle(dnnHandle_t handle) { dnn_handle_ = handle; } void SetDnnHandle(dnnHandle_t handle) { dnn_handle_ = handle; }
// Registers a factory for the DNN handle. The factory is taken by move and is
// only consulted by GetDnnHandle() while dnn_handle_ is still null.
void SetDnnHandle(std::function<dnnHandle_t()>&& handle_creator) {
  dnn_handle_creator_ = std::move(handle_creator);
}
solverHandle_t GetSolverHandle() { solverHandle_t GetSolverHandle() {
std::call_once(flag_slover_, [=]() { std::call_once(flag_slover_, [&]() {
if (!solver_handle_) phi::InitSolverHandle(&solver_handle_, stream_); if (!solver_handle_) {
if (!solver_handle_creator_)
phi::InitSolverHandle(&solver_handle_, stream_);
else
solver_handle_ = solver_handle_creator_();
}
}); });
PD_CHECK(solver_handle_ != nullptr, "the gpu solver handle is nullptr."); PD_CHECK(solver_handle_ != nullptr, "the gpu solver handle is nullptr.");
return solver_handle_; return solver_handle_;
...@@ -402,9 +457,18 @@ struct GPUContext::Impl { ...@@ -402,9 +457,18 @@ struct GPUContext::Impl {
void SetSolverHandle(solverHandle_t handle) { solver_handle_ = handle; } void SetSolverHandle(solverHandle_t handle) { solver_handle_ = handle; }
// Registers a factory for the solver handle. The factory is taken by move and
// is only consulted by GetSolverHandle() while solver_handle_ is still null.
void SetSolverHandle(std::function<solverHandle_t()>&& handle_creator) {
  solver_handle_creator_ = std::move(handle_creator);
}
sparseHandle_t GetSparseHandle() { sparseHandle_t GetSparseHandle() {
std::call_once(flag_sparse_, [=]() { std::call_once(flag_sparse_, [&]() {
if (!sparse_handle_) phi::InitSparseHandle(&sparse_handle_, stream_); if (!sparse_handle_) {
if (!sparse_handle_creator_)
phi::InitSparseHandle(&sparse_handle_, stream_);
else
sparse_handle_ = sparse_handle_creator_();
}
}); });
PD_CHECK(sparse_handle_ != nullptr, "the gpu sparse handle is nullptr."); PD_CHECK(sparse_handle_ != nullptr, "the gpu sparse handle is nullptr.");
return sparse_handle_; return sparse_handle_;
...@@ -412,6 +476,10 @@ struct GPUContext::Impl { ...@@ -412,6 +476,10 @@ struct GPUContext::Impl {
void SetSparseHandle(sparseHandle_t handle) { sparse_handle_ = handle; } void SetSparseHandle(sparseHandle_t handle) { sparse_handle_ = handle; }
// Registers a factory for the sparse handle. The factory is taken by move and
// is only consulted later, while sparse_handle_ is still null.
void SetSparseHandle(std::function<sparseHandle_t()>&& handle_creator) {
  sparse_handle_creator_ = std::move(handle_creator);
}
void Wait() const { void Wait() const {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
hipError_t e_sync = hipSuccess; hipError_t e_sync = hipSuccess;
...@@ -461,21 +529,31 @@ struct GPUContext::Impl { ...@@ -461,21 +529,31 @@ struct GPUContext::Impl {
} }
inline void CublasCall(const std::function<void(blasHandle_t)>& callback) { inline void CublasCall(const std::function<void(blasHandle_t)>& callback) {
std::call_once(flag_cublas_, [=]() { std::call_once(flag_cublas_, [&]() {
if (!blas_handle_) { if (!blas_handle_) {
phi::InitBlasHandle(&blas_handle_, stream_); if (!blas_handle_creator_)
phi::InitBlasHandle(&blas_handle_, stream_);
else
blas_handle_ = blas_handle_creator_();
} }
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 9000 #if CUDA_VERSION >= 9000
if (!blas_tensor_core_handle_) { if (!blas_tensor_core_handle_) {
phi::InitBlasHandle(&blas_tensor_core_handle_, stream_); if (!blas_tensor_core_handle_creator_)
phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
else
blas_tensor_core_handle_ = blas_tensor_core_handle_creator_();
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH)); blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
} }
#endif #endif
#if CUDA_VERSION >= 11000 #if CUDA_VERSION >= 11000
if (!blas_tf32_tensor_core_handle_) { if (!blas_tf32_tensor_core_handle_) {
phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_); if (!blas_tf32_tensor_core_handle_creator_)
phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
else
blas_tf32_tensor_core_handle_ =
blas_tf32_tensor_core_handle_creator_();
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH)); blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
} }
...@@ -493,19 +571,31 @@ struct GPUContext::Impl { ...@@ -493,19 +571,31 @@ struct GPUContext::Impl {
inline void TensorCoreCublasCallIfAvailable( inline void TensorCoreCublasCallIfAvailable(
const std::function<void(blasHandle_t)>& callback) { const std::function<void(blasHandle_t)>& callback) {
std::call_once(flag_tensorcore_cublas_, [=]() { std::call_once(flag_tensorcore_cublas_, [&]() {
if (!blas_handle_) phi::InitBlasHandle(&blas_handle_, stream_); if (!blas_handle_) {
if (!blas_handle_creator_)
phi::InitBlasHandle(&blas_handle_, stream_);
else
blas_handle_ = blas_handle_creator_();
}
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 9000 #if CUDA_VERSION >= 9000
if (!blas_tensor_core_handle_) { if (!blas_tensor_core_handle_) {
phi::InitBlasHandle(&blas_tensor_core_handle_, stream_); if (!blas_tensor_core_handle_creator_)
phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
else
blas_tensor_core_handle_ = blas_tensor_core_handle_creator_();
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH)); blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
} }
#endif #endif
#if CUDA_VERSION >= 11000 #if CUDA_VERSION >= 11000
if (!blas_tf32_tensor_core_handle_) { if (!blas_tf32_tensor_core_handle_) {
phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_); if (!blas_tf32_tensor_core_handle_creator_)
phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
else
blas_tf32_tensor_core_handle_ =
blas_tf32_tensor_core_handle_creator_();
PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode( PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH)); blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
} }
...@@ -523,9 +613,12 @@ struct GPUContext::Impl { ...@@ -523,9 +613,12 @@ struct GPUContext::Impl {
inline void CusparseCall( inline void CusparseCall(
const std::function<void(sparseHandle_t)>& callback) { const std::function<void(sparseHandle_t)>& callback) {
std::call_once(flag_sparse_, [=]() { std::call_once(flag_sparse_, [&]() {
if (!sparse_handle_) { if (!sparse_handle_) {
phi::InitSparseHandle(&sparse_handle_, stream_); if (!sparse_handle_creator_)
phi::InitSparseHandle(&sparse_handle_, stream_);
else
sparse_handle_ = sparse_handle_creator_();
} }
}); });
std::lock_guard<std::mutex> guard(sparse_mtx_); std::lock_guard<std::mutex> guard(sparse_mtx_);
...@@ -597,13 +690,21 @@ struct GPUContext::Impl { ...@@ -597,13 +690,21 @@ struct GPUContext::Impl {
gpuStream_t stream_{nullptr}; gpuStream_t stream_{nullptr};
Eigen::GpuDevice* eigen_device_{nullptr}; Eigen::GpuDevice* eigen_device_{nullptr};
std::function<Eigen::GpuDevice*()> eigen_device_creator_{nullptr};
blasHandle_t blas_handle_{nullptr}; blasHandle_t blas_handle_{nullptr};
std::function<blasHandle_t()> blas_handle_creator_{nullptr};
blasHandle_t blas_tensor_core_handle_{nullptr}; blasHandle_t blas_tensor_core_handle_{nullptr};
std::function<blasHandle_t()> blas_tensor_core_handle_creator_{nullptr};
blasHandle_t blas_tf32_tensor_core_handle_{nullptr}; blasHandle_t blas_tf32_tensor_core_handle_{nullptr};
std::function<blasHandle_t()> blas_tf32_tensor_core_handle_creator_{nullptr};
blasLtHandle_t blaslt_handle_{nullptr}; blasLtHandle_t blaslt_handle_{nullptr};
std::function<blasLtHandle_t()> blaslt_handle_creator_{nullptr};
dnnHandle_t dnn_handle_{nullptr}; dnnHandle_t dnn_handle_{nullptr};
std::function<dnnHandle_t()> dnn_handle_creator_{nullptr};
solverHandle_t solver_handle_{nullptr}; solverHandle_t solver_handle_{nullptr};
std::function<solverHandle_t()> solver_handle_creator_{nullptr};
sparseHandle_t sparse_handle_{nullptr}; sparseHandle_t sparse_handle_{nullptr};
std::function<sparseHandle_t()> sparse_handle_creator_{nullptr};
DnnWorkspaceHandle* workspace_{nullptr}; DnnWorkspaceHandle* workspace_{nullptr};
std::once_flag flag_sparse_; std::once_flag flag_sparse_;
...@@ -613,6 +714,7 @@ struct GPUContext::Impl { ...@@ -613,6 +714,7 @@ struct GPUContext::Impl {
std::once_flag flag_slover_; std::once_flag flag_slover_;
std::once_flag flag_cublas_; std::once_flag flag_cublas_;
std::once_flag flag_tensorcore_cublas_; std::once_flag flag_tensorcore_cublas_;
std::once_flag flag_eigen_device_;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
// NCCL communicator (single process version) for NCCL collective operations. // NCCL communicator (single process version) for NCCL collective operations.
...@@ -752,34 +854,66 @@ void GPUContext::SetEigenDevice(Eigen::GpuDevice* device) { ...@@ -752,34 +854,66 @@ void GPUContext::SetEigenDevice(Eigen::GpuDevice* device) {
impl_->SetEigenDevice(device); impl_->SetEigenDevice(device);
} }
// Overload taking a factory instead of a ready device: hands the callable
// over to the implementation object, which stores it for later use.
void GPUContext::SetEigenDevice(std::function<Eigen::GpuDevice*()>&& creator) {
  impl_->SetEigenDevice(std::move(creator));
}
void GPUContext::SetBlasHandle(blasHandle_t blas) { void GPUContext::SetBlasHandle(blasHandle_t blas) {
impl_->SetBlasHandle(blas); impl_->SetBlasHandle(blas);
} }
// Overload taking a factory instead of a ready handle: hands the callable
// over to the implementation object, which stores it for later use.
void GPUContext::SetBlasHandle(std::function<blasHandle_t()>&& func) {
  impl_->SetBlasHandle(std::move(func));
}
void GPUContext::SetBlasTensorCoreHandle(blasHandle_t handle) { void GPUContext::SetBlasTensorCoreHandle(blasHandle_t handle) {
impl_->SetBlasTensorCoreHandle(handle); impl_->SetBlasTensorCoreHandle(handle);
} }
// Overload taking a factory instead of a ready handle: hands the callable
// over to the implementation object, which stores it for later use.
void GPUContext::SetBlasTensorCoreHandle(std::function<blasHandle_t()>&& func) {
  impl_->SetBlasTensorCoreHandle(std::move(func));
}
void GPUContext::SetBlasTF32Handle(blasHandle_t handle) { void GPUContext::SetBlasTF32Handle(blasHandle_t handle) {
impl_->SetBlasTF32Handle(handle); impl_->SetBlasTF32Handle(handle);
} }
// Overload taking a factory instead of a ready handle: hands the callable
// over to the implementation object, which stores it for later use.
void GPUContext::SetBlasTF32Handle(std::function<blasHandle_t()>&& func) {
  impl_->SetBlasTF32Handle(std::move(func));
}
void GPUContext::SetBlasLtHandle(blasLtHandle_t blaslt) { void GPUContext::SetBlasLtHandle(blasLtHandle_t blaslt) {
impl_->SetBlasLtHandle(blaslt); impl_->SetBlasLtHandle(blaslt);
} }
// Overload taking a factory instead of a ready handle: hands the callable
// over to the implementation object, which stores it for later use.
void GPUContext::SetBlasLtHandle(std::function<blasLtHandle_t()>&& func) {
  impl_->SetBlasLtHandle(std::move(func));
}
void GPUContext::SetDnnHandle(dnnHandle_t handle) { void GPUContext::SetDnnHandle(dnnHandle_t handle) {
impl_->SetDnnHandle(handle); impl_->SetDnnHandle(handle);
} }
// Overload taking a factory instead of a ready handle: hands the callable
// over to the implementation object, which stores it for later use.
void GPUContext::SetDnnHandle(std::function<dnnHandle_t()>&& func) {
  impl_->SetDnnHandle(std::move(func));
}
void GPUContext::SetSolverHandle(solverHandle_t handle) { void GPUContext::SetSolverHandle(solverHandle_t handle) {
impl_->SetSolverHandle(handle); impl_->SetSolverHandle(handle);
} }
// Overload taking a factory instead of a ready handle: hands the callable
// over to the implementation object, which stores it for later use.
void GPUContext::SetSolverHandle(std::function<solverHandle_t()>&& func) {
  impl_->SetSolverHandle(std::move(func));
}
void GPUContext::SetSparseHandle(sparseHandle_t handle) { void GPUContext::SetSparseHandle(sparseHandle_t handle) {
impl_->SetSparseHandle(handle); impl_->SetSparseHandle(handle);
} }
// Overload taking a factory instead of a ready handle: hands the callable
// over to the implementation object, which stores it for later use.
void GPUContext::SetSparseHandle(std::function<sparseHandle_t()>&& func) {
  impl_->SetSparseHandle(std::move(func));
}
void GPUContext::SetDnnWorkspaceHandle(DnnWorkspaceHandle* handle) { void GPUContext::SetDnnWorkspaceHandle(DnnWorkspaceHandle* handle) {
impl_->workspace_ = handle; impl_->workspace_ = handle;
} }
......
...@@ -197,20 +197,28 @@ class PADDLE_API GPUContext : public DeviceContext { ...@@ -197,20 +197,28 @@ class PADDLE_API GPUContext : public DeviceContext {
void SetStream(gpuStream_t); void SetStream(gpuStream_t);
void SetEigenDevice(Eigen::GpuDevice*); void SetEigenDevice(Eigen::GpuDevice*);
void SetEigenDevice(std::function<Eigen::GpuDevice*()>&&);
void SetBlasHandle(blasHandle_t); void SetBlasHandle(blasHandle_t);
void SetBlasHandle(std::function<blasHandle_t()>&&);
void SetBlasTensorCoreHandle(blasHandle_t); void SetBlasTensorCoreHandle(blasHandle_t);
void SetBlasTensorCoreHandle(std::function<blasHandle_t()>&&);
void SetBlasTF32Handle(blasHandle_t); void SetBlasTF32Handle(blasHandle_t);
void SetBlasTF32Handle(std::function<blasHandle_t()>&&);
void SetBlasLtHandle(blasLtHandle_t); void SetBlasLtHandle(blasLtHandle_t);
void SetBlasLtHandle(std::function<blasLtHandle_t()>&&);
void SetDnnHandle(dnnHandle_t); void SetDnnHandle(dnnHandle_t);
void SetDnnHandle(std::function<dnnHandle_t()>&&);
void SetSolverHandle(solverHandle_t); void SetSolverHandle(solverHandle_t);
void SetSolverHandle(std::function<solverHandle_t()>&&);
void SetSparseHandle(sparseHandle_t); void SetSparseHandle(sparseHandle_t);
void SetSparseHandle(std::function<sparseHandle_t()>&&);
void SetDnnWorkspaceHandle(DnnWorkspaceHandle*); void SetDnnWorkspaceHandle(DnnWorkspaceHandle*);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册