Unverified commit cad139a7, authored by xiaoxiaohehe001, committed by GitHub

call_once (#43206)

Parent: 8dab2690
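Editor's note: the diff removes eager handle creation for cuBLAS (plain, tensor-core, and TF32 variants), cuBLASLt, cuDNN, cuSOLVER, and cuSPARSE from the two Init paths and, for the accessors shown below, defers it to first use behind a std::once_flag. The `if (!handle_)` guards preserve any handle injected earlier through a Set*Handle setter. A minimal, self-contained sketch of the same lazy-initialization pattern follows; Handle, CreateHandle, and Context are stand-ins invented for illustration, not phi types.

#include <cstdio>
#include <mutex>
#include <thread>

struct Handle { int value; };

// Stand-in for phi::InitBlasHandle and the other phi::Init*Handle helpers.
void CreateHandle(Handle** h) { *h = new Handle{42}; }

struct Context {
  Handle* GetHandle() {
    // The lambda runs at most once even with concurrent callers; every
    // other thread blocks until the winning caller finishes initializing.
    std::call_once(flag_, [this]() {
      if (!handle_) CreateHandle(&handle_);  // keep a pre-injected handle
    });
    return handle_;
  }

 private:
  Handle* handle_{nullptr};  // leaked here to keep the sketch short
  std::once_flag flag_;      // one flag per lazily created handle
};

int main() {
  Context ctx;
  // Both threads race to GetHandle(); CreateHandle still runs exactly once.
  std::thread t1([&] { std::printf("%d\n", ctx.GetHandle()->value); });
  std::thread t2([&] { std::printf("%d\n", ctx.GetHandle()->value); });
  t1.join();
  t2.join();
  return 0;
}

One std::once_flag is added per lazy entry point, hence the six new members near the bottom of the diff: the four handle getters plus the two cuBLAS dispatch helpers.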
@@ -214,23 +214,6 @@ struct GPUContext::Impl {
                       &max_grid_dim_size_);
     phi::InitStream(&stream_);
     InitEigenDevice();
-    phi::InitBlasHandle(&blas_handle_, stream_);
-#ifdef PADDLE_WITH_CUDA
-#if CUDA_VERSION >= 9000
-    phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
-    PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
-        blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
-#endif
-#if CUDA_VERSION >= 11000
-    phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
-    PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
-        blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
-#endif
-#endif
-    phi::InitBlasLtHandle(&blaslt_handle_);
-    phi::InitDnnHandle(&dnn_handle_, stream_, place_);
-    phi::InitSolverHandle(&solver_handle_, stream_);
-    phi::InitSparseHandle(&sparse_handle_, stream_);
     InitDnnWorkspace();
   }
@@ -246,23 +229,6 @@ struct GPUContext::Impl {
                       &max_threads_per_block_,
                       &max_grid_dim_size_);
     phi::InitStream(&stream_);
-    phi::InitBlasHandle(&blas_handle_, stream_);
-#ifdef PADDLE_WITH_CUDA
-#if CUDA_VERSION >= 9000
-    phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
-    PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
-        blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
-#endif
-#if CUDA_VERSION >= 11000
-    phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
-    PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
-        blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
-#endif
-#endif
-    phi::InitBlasLtHandle(&blaslt_handle_);
-    phi::InitDnnHandle(&dnn_handle_, stream_, place_);
-    phi::InitSolverHandle(&solver_handle_, stream_);
-    phi::InitSparseHandle(&sparse_handle_, stream_);
   }

   void PartialInitWithAllocator() {
@@ -356,7 +322,28 @@ struct GPUContext::Impl {
     return eigen_device_;
   }

-  blasHandle_t GetBlasHandle() const {
+  blasHandle_t GetBlasHandle() {
+    std::call_once(flag_blas_, [=]() {
+      if (!blas_handle_) {
+        phi::InitBlasHandle(&blas_handle_, stream_);
+      }
+#ifdef PADDLE_WITH_CUDA
+#if CUDA_VERSION >= 9000
+      if (!blas_tensor_core_handle_) {
+        phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
+        PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
+            blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
+      }
+#endif
+#if CUDA_VERSION >= 11000
+      if (!blas_tf32_tensor_core_handle_) {
+        phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
+        PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
+            blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
+      }
+#endif
+#endif
+    });
     PD_CHECK(blas_handle_ != nullptr, "the gpu blas handle is nullptr.");
     return blas_handle_;
   }
@@ -373,12 +360,18 @@ struct GPUContext::Impl {
   void SetBlasLtHandle(blasLtHandle_t blaslt) { blaslt_handle_ = blaslt; }

-  blasLtHandle_t GetBlasLtHandle() const {
+  blasLtHandle_t GetBlasLtHandle() {
+    std::call_once(flag_blaslt_, [=]() {
+      if (!blaslt_handle_) phi::InitBlasLtHandle(&blaslt_handle_);
+    });
     PD_CHECK(blaslt_handle_ != nullptr, "the gpu blasLt handle is nullptr.");
     return blaslt_handle_;
   }

   dnnHandle_t GetDnnHandle() {
+    std::call_once(flag_dnn_, [=]() {
+      if (!dnn_handle_) phi::InitDnnHandle(&dnn_handle_, stream_, place_);
+    });
     PD_CHECK(dnn_handle_ != nullptr, "the gpu dnn handle is nullptr.");
     return dnn_handle_;
   }
@@ -399,7 +392,10 @@ struct GPUContext::Impl {
   void SetDnnHandle(dnnHandle_t handle) { dnn_handle_ = handle; }

-  solverHandle_t GetSolverHandle() const {
+  solverHandle_t GetSolverHandle() {
+    std::call_once(flag_solver_, [=]() {
+      if (!solver_handle_) phi::InitSolverHandle(&solver_handle_, stream_);
+    });
     PD_CHECK(solver_handle_ != nullptr, "the gpu solver handle is nullptr.");
     return solver_handle_;
   }
@@ -461,8 +457,28 @@ struct GPUContext::Impl {
 #endif
   }

-  inline void CublasCall(
-      const std::function<void(blasHandle_t)>& callback) const {
+  inline void CublasCall(const std::function<void(blasHandle_t)>& callback) {
+    std::call_once(flag_cublas_, [=]() {
+      if (!blas_handle_) {
+        phi::InitBlasHandle(&blas_handle_, stream_);
+      }
+#ifdef PADDLE_WITH_CUDA
+#if CUDA_VERSION >= 9000
+      if (!blas_tensor_core_handle_) {
+        phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
+        PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
+            blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
+      }
+#endif
+#if CUDA_VERSION >= 11000
+      if (!blas_tf32_tensor_core_handle_) {
+        phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
+        PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
+            blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
+      }
+#endif
+#endif
+    });
     if (blas_tf32_tensor_core_handle_ != nullptr) {
       std::lock_guard<std::mutex> guard(blas_tf32_mtx_);
       callback(blas_tf32_tensor_core_handle_);
@@ -473,7 +489,26 @@ struct GPUContext::Impl {
   }

   inline void TensorCoreCublasCallIfAvailable(
-      const std::function<void(blasHandle_t)>& callback) const {
+      const std::function<void(blasHandle_t)>& callback) {
+    std::call_once(flag_tensorcore_cublas_, [=]() {
+      if (!blas_handle_) phi::InitBlasHandle(&blas_handle_, stream_);
+#ifdef PADDLE_WITH_CUDA
+#if CUDA_VERSION >= 9000
+      if (!blas_tensor_core_handle_) {
+        phi::InitBlasHandle(&blas_tensor_core_handle_, stream_);
+        PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
+            blas_tensor_core_handle_, CUBLAS_TENSOR_OP_MATH));
+      }
+#endif
+#if CUDA_VERSION >= 11000
+      if (!blas_tf32_tensor_core_handle_) {
+        phi::InitBlasHandle(&blas_tf32_tensor_core_handle_, stream_);
+        PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasSetMathMode(
+            blas_tf32_tensor_core_handle_, CUBLAS_TF32_TENSOR_OP_MATH));
+      }
+#endif
+#endif
+    });
     if (blas_tensor_core_handle_ != nullptr) {
       std::lock_guard<std::mutex> guard(blas_tensor_core_mtx_);
       callback(blas_tensor_core_handle_);
@@ -563,6 +598,13 @@ struct GPUContext::Impl {
   sparseHandle_t sparse_handle_{nullptr};
   DnnWorkspaceHandle* workspace_{nullptr};

+  std::once_flag flag_blas_;
+  std::once_flag flag_blaslt_;
+  std::once_flag flag_dnn_;
+  std::once_flag flag_solver_;
+  std::once_flag flag_cublas_;
+  std::once_flag flag_tensorcore_cublas_;
+
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   // NCCL communicator (single process version) for NCCL collective operations.
   // NCCL collective operations provides fast collectives over multiple GPUs
...
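Editor's note: two consequences of the lazy scheme are visible above. The accessors and the cuBLAS helpers drop their const qualifier, since first use now mutates handle members. And the dispatch logic in CublasCall and TensorCoreCublasCallIfAvailable is unchanged: when a tensor-core or TF32 handle exists, calls through it stay serialized by a mutex because those handles are shared across callers. A stand-in sketch of that dispatch; blasHandle_t and Impl below are illustrative, not phi's declarations.

#include <functional>
#include <mutex>

using blasHandle_t = void*;  // stand-in for the real cuBLAS handle type

struct Impl {
  void CublasCall(const std::function<void(blasHandle_t)>& callback) {
    if (tf32_handle_ != nullptr) {
      // Shared handles are not safe for unsynchronized concurrent cuBLAS
      // calls, so TF32 dispatch is serialized, as in the diff above.
      std::lock_guard<std::mutex> guard(tf32_mtx_);
      callback(tf32_handle_);
    } else {
      callback(handle_);  // fall back to the plain handle
    }
  }
  blasHandle_t handle_{nullptr};
  blasHandle_t tf32_handle_{nullptr};
  std::mutex tf32_mtx_;
};

int main() {
  Impl impl;
  impl.handle_ = &impl;  // any non-null token serves the sketch
  impl.CublasCall([](blasHandle_t) { /* issue cuBLAS calls here */ });
  return 0;
}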