未验证 提交 2d01cc85 编写于 作者: 石晓伟 提交者: GitHub

DeviceContext Split, test=develop (#23737)

* supports thread-binding stream, test=develop

* avoid using thread_local variables in dtor, test=develop

* modify the stream priority enum, test=develop
上级 8af85922
......@@ -44,6 +44,7 @@ cc_library(place SRCS place.cc DEPS enforce boost)
cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
add_subdirectory(dynload)
add_subdirectory(stream)
cc_library(cpu_helper SRCS cpu_helper.cc DEPS cblas enforce)
cc_test(cpu_helper_test SRCS cpu_helper_test.cc DEPS cpu_helper)
......@@ -54,7 +55,7 @@ IF(WITH_DGC)
ENDIF()
IF(WITH_GPU)
set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
set(GPU_CTX_DEPS dynload_cuda dynamic_loader cuda_stream)
ENDIF()
IF(WITH_MKLDNN)
......
......@@ -211,6 +211,33 @@ void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
}
// Per-thread map from a device context to the CUDAContext bound to it on the
// current thread. Entries are inserted by
// CUDADeviceContext::ResetThreadContext(); context() falls back to the
// default context when the current thread has no entry.
thread_local std::unordered_map<const CUDADeviceContext*,
std::shared_ptr<CUDAContext>>
CUDADeviceContext::thread_ctx_;
// NOTE(review): this mutex is itself thread_local, so every thread locks its
// own instance; it does not serialize access between threads.
thread_local std::mutex CUDADeviceContext::ctx_mtx_;
void CUDAContext::InitEigenContext() {
eigen_stream_.reset(new EigenCudaStreamDevice());
eigen_stream_->Reinitialize(&RawStream(), place_);
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}
// Creates a full per-stream CUDA context on the given place: a CUDA stream
// with the requested priority plus the Eigen, cuBLAS and cuDNN objects that
// are bound to that stream.
CUDAContext::CUDAContext(const CUDAPlace& place,
const enum stream::Priority& priority) {
place_ = place;
// Make the target device current for the duration of the constructor.
CUDADeviceGuard guard(place_.device);
// The stream has to exist first: each Init* helper below reads RawStream().
stream_.reset(new stream::CUDAStream(place, priority));
InitEigenContext();
InitCuBlasContext();
InitCuDNNContext();
}
// Releases the cuDNN handle and then the cuBLAS handle holders while the
// context's device is current. The remaining unique_ptr members (stream and
// Eigen objects) are released by their own destructors after this body runs.
CUDAContext::~CUDAContext() {
CUDADeviceGuard guard(place_.device);
DestoryCuDNNContext();
DestoryCuBlasContext();
}
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
CUDADeviceGuard guard(place_.device);
compute_capability_ = GetCUDAComputeCapability(place_.device);
......@@ -218,18 +245,6 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
max_grid_dim_size_ = GetGpuMaxGridDimSize(place_.device);
max_threads_per_block_ = GetCUDAMaxThreadsPerBlock(place_.device);
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream_));
eigen_stream_.reset(new EigenCudaStreamDevice());
eigen_stream_->Reinitialize(&stream_, place);
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
cublas_handle_.reset(new CublasHandleHolder(stream_, CUBLAS_DEFAULT_MATH));
if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000
cublas_tensor_core_handle_.reset(
new CublasHandleHolder(stream_, CUBLAS_TENSOR_OP_MATH));
#endif
}
driver_version_ = GetCUDADriverVersion(place_.device);
runtime_version_ = GetCUDARuntimeVersion(place_.device);
......@@ -263,48 +278,12 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
<< "Please recompile or reinstall Paddle with compatible CUDA "
"version.";
}
if (dynload::HasCUDNN()) {
auto local_cudnn_version = cudnn_dso_ver / 100;
auto compile_cudnn_version = CUDNN_VERSION / 100;
if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << place_.device
<< ". The installed Paddle is compiled with CUDNN "
<< compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
<< ", but CUDNN version in your machine is "
<< local_cudnn_version / 10 << "." << local_cudnn_version % 10
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible CUDNN "
"version.";
}
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::cudnnCreate(&cudnn_handle_),
"Failed to create Cudnn handle in DeviceContext");
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::cudnnSetStream(cudnn_handle_, stream_),
"Failed to set stream for Cudnn handle in DeviceContext");
} else {
cudnn_handle_ = nullptr;
}
}
callback_manager_.reset(new StreamCallbackManager(stream_));
default_ctx_.reset(new CUDAContext(place_));
}
CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId(place_.device);
Wait();
WaitStreamCallback();
cublas_handle_.reset();
cublas_tensor_core_handle_.reset();
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(stream_));
if (cudnn_handle_) {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroy(cudnn_handle_),
"Failed to destory Cudnn handle");
}
#if defined(PADDLE_WITH_NCCL)
if (nccl_comm_) {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
......@@ -314,22 +293,7 @@ CUDADeviceContext::~CUDADeviceContext() {
Place CUDADeviceContext::GetPlace() const { return place_; }
void CUDADeviceContext::Wait() const {
cudaError_t e_sync = cudaSuccess;
#if !defined(_WIN32)
e_sync = cudaStreamSynchronize(stream_);
#else
while (e_sync = cudaStreamQuery(stream_)) {
if (e_sync == cudaErrorNotReady) continue;
break;
}
#endif
PADDLE_ENFORCE_CUDA_SUCCESS(
e_sync, platform::errors::Fatal(
"cudaStreamSynchronize raises error: %s, errono: %d",
cudaGetErrorString(e_sync), static_cast<int>(e_sync)));
}
void CUDADeviceContext::Wait() const { context()->Stream()->Wait(); }
int CUDADeviceContext::GetComputeCapability() const {
return compute_capability_;
......@@ -346,24 +310,28 @@ int CUDADeviceContext::GetMaxThreadsPerBlock() const {
}
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return eigen_device_.get();
return context()->EigenDevice().get();
}
bool CUDADeviceContext::tensor_core_available() const {
return cublas_tensor_core_handle_ != nullptr;
return context()->CublasTensorCoreHandle() != nullptr;
}
dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const {
return max_grid_dim_size_;
}
cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return context()->CudnnHandle();
}
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
}
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
cudaStream_t CUDADeviceContext::stream() const {
return context()->RawStream();
}
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
eigen_device_.reset(new Eigen::DefaultDevice());
......
......@@ -38,7 +38,7 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/stream_callback_manager.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
#endif
#include "unsupported/Eigen/CXX11/Tensor"
......@@ -80,6 +80,125 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
class EigenCudaStreamDevice;
class CudnnWorkspaceHandle;
/*! \brief Bundle of per-stream CUDA resources: one CUDA stream plus the
    Eigen device, cuBLAS handle(s) and cuDNN handle that are bound to it.
    A CUDADeviceContext forwards its stream/handle accessors to one of
    these objects. */
class CUDAContext {
 public:
  /*! \brief Creates an empty context that owns no stream and no handles. */
  CUDAContext() = default;

  /*! \brief Creates the stream with the given priority on `place` and
      initializes the Eigen, cuBLAS and cuDNN objects on top of it. */
  explicit CUDAContext(
      const CUDAPlace& place,
      const enum stream::Priority& priority = stream::Priority::kNormal);

  ~CUDAContext();

  /*! \brief Place (device) this context was created on. */
  const CUDAPlace& Place() const { return place_; }

  /*! \brief Eigen GPU device bound to this context's stream. */
  const std::unique_ptr<Eigen::GpuDevice>& EigenDevice() const {
    return eigen_device_;
  }

  /*! \brief Eigen stream adapter that wraps the raw CUDA stream. */
  const std::unique_ptr<EigenCudaStreamDevice>& EigenStream() const {
    return eigen_stream_;
  }

  /*! \brief Owning wrapper of the CUDA stream. */
  const std::unique_ptr<stream::CUDAStream>& Stream() const { return stream_; }

  /*! \brief Raw cudaStream_t of the owned stream. Dereferences stream_, so
      it must not be called on a default-constructed context. Declared const
      to match the other read-only accessors. */
  const cudaStream_t& RawStream() const { return stream_->raw_stream(); }

  /*! \brief cuDNN handle, or nullptr when cuDNN is not available or the
      context was default-constructed. */
  const cudnnHandle_t& CudnnHandle() const { return cudnn_handle_; }

  /*! \brief cuBLAS handle configured with CUBLAS_DEFAULT_MATH. */
  const std::unique_ptr<CublasHandleHolder>& CublasHandle() const {
    return cublas_handle_;
  }

  /*! \brief cuBLAS handle configured with CUBLAS_TENSOR_OP_MATH; null when
      it was not created. */
  const std::unique_ptr<CublasHandleHolder>& CublasTensorCoreHandle() const {
    return cublas_tensor_core_handle_;
  }

  /*! \brief Call cublas function safely. */
  template <typename Callback>
  inline void CublasCall(Callback&& callback) const {
    cublas_handle_->Call(std::forward<Callback>(callback));
  }

  /*! \brief Check whether tensor core is supported */
  bool tensor_core_available() const;

  /*! \brief Call cublas function with Tensor Core safely. If
      Tensor Core is not available, use DEFAULT_MATH instead. */
  template <typename Callback>
  inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
    if (cublas_tensor_core_handle_) {
      cublas_tensor_core_handle_->Call(std::forward<Callback>(callback));
    } else {
      cublas_handle_->Call(std::forward<Callback>(callback));
    }
  }

 private:
  void InitEigenContext();

  // Creates the default-math cuBLAS handle and, when tensor cores are
  // reported available and CUDA >= 9.0, the tensor-op handle as well.
  void InitCuBlasContext() {
    cublas_handle_.reset(
        new CublasHandleHolder(RawStream(), CUBLAS_DEFAULT_MATH));
    if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000
      cublas_tensor_core_handle_.reset(
          new CublasHandleHolder(RawStream(), CUBLAS_TENSOR_OP_MATH));
#endif
    }
  }

  // Creates the cuDNN handle and binds it to this context's stream. Warns
  // once if the runtime cuDNN is older than the one Paddle was built with.
  // Leaves the handle null when cuDNN cannot be loaded.
  void InitCuDNNContext() {
    if (dynload::HasCUDNN()) {
      // Compare versions at major.minor granularity (patch level dropped).
      auto local_cudnn_version = dynload::cudnnGetVersion() / 100;
      auto compile_cudnn_version = CUDNN_VERSION / 100;
      if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
        LOG_FIRST_N(WARNING, 1)
            << "WARNING: device: " << place_.device
            << ". The installed Paddle is compiled with CUDNN "
            << compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
            << ", but CUDNN version in your machine is "
            << local_cudnn_version / 10 << "." << local_cudnn_version % 10
            << ", which may cause serious incompatible bug. "
            << "Please recompile or reinstall Paddle with compatible CUDNN "
               "version.";
      }
      PADDLE_ENFORCE_CUDA_SUCCESS(
          dynload::cudnnCreate(&cudnn_handle_),
          platform::errors::Fatal(
              "Failed to create Cudnn handle in DeviceContext"));
      PADDLE_ENFORCE_CUDA_SUCCESS(
          dynload::cudnnSetStream(cudnn_handle_, RawStream()),
          platform::errors::Fatal(
              "Failed to set stream for Cudnn handle in DeviceContext"));
    } else {
      cudnn_handle_ = nullptr;
    }
  }

  // Destroys the cuDNN handle if one exists and clears it, so a repeated
  // call is a no-op.
  void DestoryCuDNNContext() {
    if (cudnn_handle_) {
      PADDLE_ENFORCE_CUDA_SUCCESS(
          dynload::cudnnDestroy(cudnn_handle_),
          platform::errors::Fatal("Failed to destory Cudnn handle"));
    }
    cudnn_handle_ = nullptr;
  }

  // Drops both cuBLAS handle holders.
  void DestoryCuBlasContext() {
    cublas_handle_.reset();
    cublas_tensor_core_handle_.reset();
  }

  CUDAPlace place_;
  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
  std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
  std::unique_ptr<stream::CUDAStream> stream_;
  // Must start out null: the defaulted constructor never runs
  // InitCuDNNContext(), yet the destructor still calls
  // DestoryCuDNNContext(), which would otherwise hand an indeterminate
  // handle to cudnnDestroy.
  cudnnHandle_t cudnn_handle_{nullptr};
  std::unique_ptr<CublasHandleHolder> cublas_handle_;
  std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;

  DISABLE_COPY_AND_ASSIGN(CUDAContext);
};
class CUDADeviceContext : public DeviceContext {
public:
explicit CUDADeviceContext(CUDAPlace place);
......@@ -112,7 +231,7 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Call cublas function safely. */
template <typename Callback>
inline void CublasCall(Callback&& callback) const {
cublas_handle_->Call(std::forward<Callback>(callback));
return context()->CublasCall(callback);
}
/*! \brief Check whether tensor core is supported */
......@@ -122,11 +241,7 @@ class CUDADeviceContext : public DeviceContext {
Tensor Core is not available, use DEFAULT_MATH instead. */
template <typename Callback>
inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
if (cublas_tensor_core_handle_) {
cublas_tensor_core_handle_->Call(std::forward<Callback>(callback));
} else {
cublas_handle_->Call(std::forward<Callback>(callback));
}
return context()->TensorCoreCublasCallIfAvailable(callback);
}
/*! \brief Return cudnn handle in the device context. */
......@@ -153,33 +268,48 @@ class CUDADeviceContext : public DeviceContext {
#endif
template <typename Callback>
void RecordEvent(cudaEvent_t ev, Callback callback) {
callback();
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(ev, stream_));
void RecordEvent(cudaEvent_t ev, Callback callback) const {
return context()->Stream()->RecordEvent(ev, callback);
}
template <typename Callback>
void AddStreamCallback(Callback&& callback) const {
callback_manager_->AddCallback(callback);
return context()->Stream()->AddCallback(callback);
}
void WaitStreamCallback() const {
return context()->Stream()->WaitCallback();
}
void WaitStreamCallback() const { callback_manager_->Wait(); }
/*! \brief Replace the default context with a freshly created CUDAContext on
    this device context's place, using the given stream priority. The
    previous default context is released by the shared_ptr reset. */
void ResetDefaultContext(const enum stream::Priority& priority) {
default_ctx_.reset(new CUDAContext(place_, priority));
}
/*! \brief Bind a freshly created CUDAContext (with the given stream
    priority) to this device context for the calling thread. Afterwards
    context() returns it on this thread instead of the default context. */
void ResetThreadContext(const enum stream::Priority& priority) {
// NOTE(review): ctx_mtx_ and thread_ctx_ are both declared thread_local, so
// this lock is only ever taken by the owning thread -- confirm it is needed.
std::lock_guard<std::mutex> guard(ctx_mtx_);
thread_ctx_[this].reset(new CUDAContext(place_, priority));
}
/*! \brief Return the CUDAContext in effect for the calling thread: the
    context bound to this device context on the current thread if one was
    registered, otherwise the default context. */
std::shared_ptr<CUDAContext> context() const {
  auto bound = thread_ctx_.find(this);
  if (bound != thread_ctx_.end()) {
    return bound->second;
  }
  return default_ctx_;
}
private:
CUDAPlace place_;
std::shared_ptr<CUDAContext> default_ctx_;
mutable std::once_flag init_cudnn_;
// The thread_local static variable will be released before the
// global static variable, so avoid using it in dtor.
static thread_local std::unordered_map<const CUDADeviceContext*,
std::shared_ptr<CUDAContext>>
thread_ctx_;
static thread_local std::mutex ctx_mtx_;
std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
cudaStream_t stream_;
cudnnHandle_t cudnn_handle_;
mutable std::mutex cudnn_handle_mtx_;
std::unique_ptr<CublasHandleHolder> cublas_handle_;
std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
#if defined(PADDLE_WITH_NCCL)
// NCCL communicator (single process version) for NCCL collective operations.
// NCCL collective operations provides fast collectives over multiple GPUs
......@@ -197,9 +327,6 @@ class CUDADeviceContext : public DeviceContext {
int max_threads_per_block_;
dim3 max_grid_dim_size_;
// StreamCallbackManager is thread-safe
std::unique_ptr<StreamCallbackManager> callback_manager_;
DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
};
......
IF(WITH_GPU)
cc_library(cuda_stream SRCS cuda_stream.cc DEPS enforce)
ENDIF()
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/stream/cuda_stream.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
namespace stream {
constexpr unsigned int kDefaultFlag = cudaStreamDefault;
bool CUDAStream::Init(const Place& place, const enum Priority& priority) {
PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
platform::errors::InvalidArgument(
"Cuda stream must be created using cuda place."));
place_ = place;
CUDADeviceGuard guard(boost::get<CUDAPlace>(place_).device);
if (priority == Priority::kHigh) {
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamCreateWithPriority(&stream_, kDefaultFlag, -1),
platform::errors::Fatal("High priority cuda stream creation failed."));
} else if (priority == Priority::kNormal) {
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamCreateWithPriority(&stream_, kDefaultFlag, 0),
platform::errors::Fatal(
"Normal priority cuda stream creation failed."));
}
callback_manager_.reset(new StreamCallbackManager(stream_));
VLOG(3) << "CUDAStream Init stream: " << stream_
<< ", priority: " << static_cast<int>(priority);
return true;
}
void CUDAStream::Destroy() {
CUDADeviceGuard guard(boost::get<CUDAPlace>(place_).device);
Wait();
WaitCallback();
if (stream_) {
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamDestroy(stream_),
platform::errors::Fatal("Cuda stream destruction failed."));
}
stream_ = nullptr;
}
// Blocks the calling host thread until the stream has finished its work.
// Non-Windows builds use cudaStreamSynchronize; Windows builds poll
// cudaStreamQuery until it stops reporting cudaErrorNotReady. Any resulting
// error is raised as Fatal.
void CUDAStream::Wait() const {
#if !defined(_WIN32)
  cudaError_t e_sync = cudaStreamSynchronize(stream_);
#else
  cudaError_t e_sync = cudaSuccess;
  do {
    e_sync = cudaStreamQuery(stream_);
  } while (e_sync == cudaErrorNotReady);
#endif

  PADDLE_ENFORCE_CUDA_SUCCESS(
      e_sync, platform::errors::Fatal(
                  "cudaStreamSynchronize raises error: %s, errono: %d",
                  cudaGetErrorString(e_sync), static_cast<int>(e_sync)));
}
} // namespace stream
} // namespace platform
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cstdint>
#include <memory>
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream_callback_manager.h"
namespace paddle {
namespace platform {
namespace stream {
#ifdef PADDLE_WITH_CUDA
// Scheduling priority requested for a CUDAStream.
enum class Priority : uint8_t {
// No priority selected. CUDAStream::Init creates no stream for this value.
kNull = 0x0,
// CUDAStream::Init passes priority -1 to cudaStreamCreateWithPriority.
kHigh = 0x1,
// CUDAStream::Init passes priority 0 to cudaStreamCreateWithPriority; this
// is the default used by the CUDAStream and CUDAContext constructors.
kNormal = 0x2,
};
/*! \brief Owning wrapper around a cudaStream_t together with the
    StreamCallbackManager attached to it. The stream is created by Init()
    and released by Destroy(), which the destructor calls. */
class CUDAStream final {
 public:
  /*! \brief Creates an empty wrapper; call Init() before using it. */
  CUDAStream() = default;

  /*! \brief Creates the stream on `place` with the given priority. */
  CUDAStream(const Place& place,
             const enum Priority& priority = Priority::kNormal) {
    Init(place, priority);
  }

  // The class is final and has no base class, so the destructor does not
  // need to be virtual.
  ~CUDAStream() { Destroy(); }

  /*! \brief Creates the CUDA stream and its callback manager. */
  bool Init(const Place& place,
            const enum Priority& priority = Priority::kNormal);

  /*! \brief Hand `callback` to the stream's callback manager. Requires a
      prior Init(). */
  template <typename Callback>
  void AddCallback(Callback&& callback) const {
    callback_manager_->AddCallback(callback);
  }

  /*! \brief Run `callback` on the host, then record `ev` on this stream. */
  template <typename Callback>
  void RecordEvent(cudaEvent_t ev, Callback callback) const {
    callback();
    PADDLE_ENFORCE_CUDA_SUCCESS(
        cudaEventRecord(ev, stream_),
        platform::errors::Fatal("CUDA event recording failed."));
  }

  /*! \brief Record `ev` on this stream. */
  void RecordEvent(cudaEvent_t ev) const {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        cudaEventRecord(ev, stream_),
        platform::errors::Fatal("CUDA event recording failed."));
  }

  /*! \brief Make later work on this stream wait for `ev`. */
  void WaitEvent(cudaEvent_t ev) const {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        cudaStreamWaitEvent(stream_, ev, 0),
        platform::errors::Fatal("Failed to wait event."));
  }

  /*! \brief Block the host until the stream has finished its work. */
  void Wait() const;

  /*! \brief Wait for the callbacks queued on the callback manager. This is
      a no-op when no manager exists yet (default-constructed stream), so
      that Destroy() -- and therefore the destructor -- cannot dereference a
      null manager. */
  void WaitCallback() const {
    if (callback_manager_) {
      callback_manager_->Wait();
    }
  }

  /*! \brief Underlying CUDA stream handle; null before Init() and after
      Destroy(). */
  const cudaStream_t& raw_stream() const { return stream_; }

  /*! \brief Drain the stream, destroy it and clear the handle. */
  void Destroy();

 private:
  Place place_;
  cudaStream_t stream_{nullptr};
  Priority priority_{Priority::kNormal};
  std::unique_ptr<StreamCallbackManager> callback_manager_;

  DISABLE_COPY_AND_ASSIGN(CUDAStream);
};
#endif
} // namespace stream
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册