From 597cc0589b1ab7a6f5d4d465313d5401e802bc6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9F=B3=E6=99=93=E4=BC=9F?= <39303645+Shixiaowei02@users.noreply.github.com> Date: Thu, 23 Apr 2020 10:54:58 +0800 Subject: [PATCH] [cherry-pick] Thread-local Allocator, test=release/2.0 (#24061) * cherry-pick of DeviceContext Split, test=develop (#23737) * New feature: thread local allocator, test=develop (#23989) * add the thread_local_allocator, test=develop * refactor the thread_local_allocator, test=develop * provides option setting strategy, test=develop * add boost dependency to cuda_stream, test=develop * declare the stream::Priority as enum class, test=develop * deal with PADDLE_ENFORCE_CUDA_SUCCESS macro in pr #23816 --- paddle/fluid/memory/allocation/CMakeLists.txt | 4 +- .../memory/allocation/allocator_facade.cc | 17 ++ .../memory/allocation/allocator_strategy.cc | 4 + .../memory/allocation/allocator_strategy.h | 2 +- .../allocation/thread_local_allocator.cc | 76 ++++++++ .../allocation/thread_local_allocator.h | 100 +++++++++++ .../allocation/thread_local_allocator_test.cc | 93 ++++++++++ paddle/fluid/platform/CMakeLists.txt | 3 +- paddle/fluid/platform/device_context.cc | 99 ++++------- paddle/fluid/platform/device_context.h | 168 +++++++++++++++--- paddle/fluid/platform/flags.cc | 3 +- paddle/fluid/platform/stream/CMakeLists.txt | 3 + paddle/fluid/platform/stream/cuda_stream.cc | 70 ++++++++ paddle/fluid/platform/stream/cuda_stream.h | 85 +++++++++ 14 files changed, 637 insertions(+), 90 deletions(-) create mode 100644 paddle/fluid/memory/allocation/thread_local_allocator.cc create mode 100644 paddle/fluid/memory/allocation/thread_local_allocator.h create mode 100644 paddle/fluid/memory/allocation/thread_local_allocator_test.cc create mode 100644 paddle/fluid/platform/stream/CMakeLists.txt create mode 100644 paddle/fluid/platform/stream/cuda_stream.cc create mode 100644 paddle/fluid/platform/stream/cuda_stream.h diff --git a/paddle/fluid/memory/allocation/CMakeLists.txt b/paddle/fluid/memory/allocation/CMakeLists.txt index dc26c19cbc8..fdd6923a67b 100644 --- a/paddle/fluid/memory/allocation/CMakeLists.txt +++ b/paddle/fluid/memory/allocation/CMakeLists.txt @@ -14,13 +14,15 @@ endif() if (WITH_GPU) nv_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard) + nv_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator) + cc_test(thread_local_allocator_test SRCS thread_local_allocator_test.cc DEPS thread_local_allocator) endif() cc_library(retry_allocator SRCS retry_allocator.cc DEPS allocator) nv_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator) if (WITH_GPU) - set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator cuda_device_guard) + set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator cuda_device_guard thread_local_allocator) else () set(AllocatorFacadeDeps) endif() diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 63763acb64c..c851f1b10c9 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -32,6 +32,7 @@ #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/memory/allocation/cuda_allocator.h" #include "paddle/fluid/memory/allocation/pinned_allocator.h" +#include "paddle/fluid/memory/allocation/thread_local_allocator.h" #include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/gpu_info.h" #endif @@ -80,6 +81,18 @@ class AllocatorFacadePrivate { break; } + case AllocatorStrategy::kThreadLocal: { + InitNaiveBestFitCPUAllocator(); +#ifdef PADDLE_WITH_CUDA + for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); + ++dev_id) { + InitThreadLocalCUDAAllocator(platform::CUDAPlace(dev_id)); + } + InitNaiveBestFitCUDAPinnedAllocator(); +#endif + break; + } + default: { PADDLE_THROW("Unsupported allocator strategy: %d", static_cast(strategy)); @@ -136,6 +149,10 @@ class AllocatorFacadePrivate { allocators_[p] = std::make_shared(p); } + void InitThreadLocalCUDAAllocator(platform::CUDAPlace p) { + allocators_[p] = std::make_shared(p); + } + void InitAutoGrowthCUDAAllocator(platform::CUDAPlace p) { auto cuda_allocator = std::make_shared(p); allocators_[p] = std::make_shared( diff --git a/paddle/fluid/memory/allocation/allocator_strategy.cc b/paddle/fluid/memory/allocation/allocator_strategy.cc index 19b1380612b..74757439fd4 100644 --- a/paddle/fluid/memory/allocation/allocator_strategy.cc +++ b/paddle/fluid/memory/allocation/allocator_strategy.cc @@ -32,6 +32,10 @@ static AllocatorStrategy GetStrategyFromFlag() { return AllocatorStrategy::kAutoGrowth; } + if (FLAGS_allocator_strategy == "thread_local") { + return AllocatorStrategy::kThreadLocal; + } + PADDLE_THROW("Unsupported allocator strategy: %s", FLAGS_allocator_strategy); } diff --git a/paddle/fluid/memory/allocation/allocator_strategy.h b/paddle/fluid/memory/allocation/allocator_strategy.h index ff6e7839ff7..0db9d93e3e6 100644 --- a/paddle/fluid/memory/allocation/allocator_strategy.h +++ b/paddle/fluid/memory/allocation/allocator_strategy.h @@ -18,7 +18,7 @@ namespace paddle { namespace memory { namespace allocation { -enum class AllocatorStrategy { kNaiveBestFit, kAutoGrowth }; +enum class AllocatorStrategy { kNaiveBestFit, kAutoGrowth, kThreadLocal }; extern AllocatorStrategy GetAllocatorStrategy(); diff --git a/paddle/fluid/memory/allocation/thread_local_allocator.cc b/paddle/fluid/memory/allocation/thread_local_allocator.cc new file mode 100644 index 00000000000..96f22530135 --- /dev/null +++ b/paddle/fluid/memory/allocation/thread_local_allocator.cc @@ -0,0 +1,76 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/memory/allocation/thread_local_allocator.h" + +namespace paddle { +namespace memory { +namespace allocation { + +ThreadLocalAllocatorImpl::ThreadLocalAllocatorImpl(const platform::Place& p) + : place_(p) { + if (platform::is_gpu_place(place_)) { + buddy_allocator_.reset(new memory::detail::BuddyAllocator( + std::unique_ptr( + new memory::detail::GPUAllocator( + boost::get(place_).device)), + platform::GpuMinChunkSize(), platform::GpuMaxChunkSize())); + } else { + LOG(FATAL) << "Thread local allocator only supports CUDAPlace now."; + } +} + +std::shared_ptr ThreadLocalCUDAAllocatorPool::Get( + int gpu_id) { + auto pos = std::distance(devices_.begin(), + std::find(devices_.begin(), devices_.end(), gpu_id)); + PADDLE_ENFORCE_LT( + pos, devices_.size(), + platform::errors::InvalidArgument( + "The position of device should be less than the size of devices.")); + std::call_once(*init_flags_[pos], [this, pos, gpu_id] { + platform::SetDeviceId(devices_[pos]); + allocators_[pos].reset( + new ThreadLocalAllocatorImpl(platform::CUDAPlace(gpu_id))); + }); + return allocators_[pos]; +} + +ThreadLocalCUDAAllocatorPool::ThreadLocalCUDAAllocatorPool() + : devices_(platform::GetSelectedDevices()) { + auto gpu_num = devices_.size(); + allocators_.resize(gpu_num); + init_flags_.reserve(gpu_num); + for (size_t i = 0; i < gpu_num; ++i) { + init_flags_.emplace_back(new std::once_flag()); + } +} + +ThreadLocalAllocation* ThreadLocalAllocatorImpl::AllocateImpl(size_t size) { + VLOG(10) << "ThreadLocalAllocatorImpl::AllocateImpl " << size; + void* ptr = buddy_allocator_->Alloc(size); + auto* tl_allocation = new ThreadLocalAllocation(ptr, size, place_); + tl_allocation->SetThreadLocalAllocatorImpl(shared_from_this()); + return tl_allocation; +} + +void ThreadLocalAllocatorImpl::FreeImpl(ThreadLocalAllocation* allocation) { + VLOG(10) << "ThreadLocalAllocatorImpl::FreeImpl " << allocation; + buddy_allocator_->Free(allocation->ptr()); + delete allocation; +} + +} // namespace allocation +} // namespace memory +} // namespace paddle diff --git a/paddle/fluid/memory/allocation/thread_local_allocator.h b/paddle/fluid/memory/allocation/thread_local_allocator.h new file mode 100644 index 00000000000..bc07ad0c4dc --- /dev/null +++ b/paddle/fluid/memory/allocation/thread_local_allocator.h @@ -0,0 +1,100 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/memory/allocation/allocator.h" +#include "paddle/fluid/memory/detail/buddy_allocator.h" +#include "paddle/fluid/memory/detail/system_allocator.h" +#include "paddle/fluid/platform/gpu_info.h" + +namespace paddle { +namespace memory { +namespace allocation { + +class ThreadLocalAllocatorImpl; + +class ThreadLocalAllocation : public Allocation { + public: + ThreadLocalAllocation(void* ptr, size_t size, platform::Place place) + : Allocation(ptr, size, place) {} + + void SetThreadLocalAllocatorImpl( + std::shared_ptr allocator) { + allocator_ = allocator; + } + + std::shared_ptr GetAllocator() { + return allocator_; + } + + private: + std::shared_ptr allocator_; +}; + +class ThreadLocalAllocatorImpl + : public std::enable_shared_from_this { + public: + explicit ThreadLocalAllocatorImpl(const platform::Place& p); + ThreadLocalAllocation* AllocateImpl(size_t size); + void FreeImpl(ThreadLocalAllocation* allocation); + + private: + std::unique_ptr buddy_allocator_; + platform::Place place_; +}; + +class ThreadLocalCUDAAllocatorPool { + public: + static ThreadLocalCUDAAllocatorPool& Instance() { + static thread_local ThreadLocalCUDAAllocatorPool pool; + return pool; + } + + std::shared_ptr Get(int gpu_id); + + private: + ThreadLocalCUDAAllocatorPool(); + std::vector devices_; + std::vector> init_flags_; + std::vector> allocators_; +}; + +class ThreadLocalCUDAAllocator : public Allocator { + public: + explicit ThreadLocalCUDAAllocator(const platform::CUDAPlace& p) + : gpu_id_(p.device) {} + + bool IsAllocThreadSafe() const override { return true; } + + protected: + Allocation* AllocateImpl(size_t size) override { + return ThreadLocalCUDAAllocatorPool::Instance().Get(gpu_id_)->AllocateImpl( + size); + } + void FreeImpl(Allocation* allocation) override { + auto* tl_allocation = static_cast(allocation); + auto allocator_impl = tl_allocation->GetAllocator(); + allocator_impl->FreeImpl(tl_allocation); + } + + private: + int gpu_id_; +}; + +} // namespace allocation +} // namespace memory +} // namespace paddle diff --git a/paddle/fluid/memory/allocation/thread_local_allocator_test.cc b/paddle/fluid/memory/allocation/thread_local_allocator_test.cc new file mode 100644 index 00000000000..f9e2ea8c27a --- /dev/null +++ b/paddle/fluid/memory/allocation/thread_local_allocator_test.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/memory/allocation/thread_local_allocator.h" +#include +#include // NOLINT +#include +#include +#include // NOLINT +#include +#include "gtest/gtest.h" +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/platform/gpu_info.h" + +DECLARE_double(fraction_of_gpu_memory_to_use); +DECLARE_string(allocator_strategy); + +namespace paddle { +namespace memory { +namespace allocation { + +TEST(ThreadLocalAllocator, cross_scope_release) { + FLAGS_fraction_of_gpu_memory_to_use = 0.1; + FLAGS_allocator_strategy = "thread_local"; + + const size_t thread_num = 5; + const std::vector devices = platform::GetSelectedDevices(); + + std::vector> allocator_addresses(devices.size()); + std::vector> thread_allocations(devices.size()); + + for (size_t i = 0; i < devices.size(); ++i) { + allocator_addresses[i].resize(thread_num); + thread_allocations[i].resize(thread_num); + } + + std::vector threads(thread_num); + std::mutex mutex; + std::condition_variable cv; + bool flag = false; + + for (size_t i = 0; i < threads.size(); ++i) { + threads[i] = std::thread([&, i]() { + { + std::unique_lock lock(mutex); + cv.wait(lock, [&] { return flag; }); + } + for (size_t j = 0; j < devices.size(); ++j) { + thread_allocations[j][i] = + memory::Alloc(platform::CUDAPlace(devices[j]), 10); + auto tl_allocator_impl = + ThreadLocalCUDAAllocatorPool::Instance().Get(devices[j]); + allocator_addresses[j][i] = tl_allocator_impl.get(); + } + }); + } + + { + std::lock_guard lock(mutex); + flag = true; + cv.notify_all(); + } + + for (auto &th : threads) { + th.join(); + } + + for (auto &addresses : allocator_addresses) { + std::sort(addresses.begin(), addresses.end()); + ASSERT_EQ(std::adjacent_find(addresses.begin(), addresses.end(), + std::equal_to()), + addresses.end()); + } + + ::testing::FLAGS_gtest_death_test_style = "threadsafe"; + ASSERT_EXIT(([&]() { thread_allocations.clear(); }(), exit(0)), + ::testing::ExitedWithCode(0), ".*"); +} + +} // namespace allocation +} // namespace memory +} // namespace paddle diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index acc5a0f172d..d0d74f6ea87 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -44,6 +44,7 @@ cc_library(place SRCS place.cc DEPS enforce boost) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) add_subdirectory(dynload) +add_subdirectory(stream) cc_library(cpu_helper SRCS cpu_helper.cc DEPS cblas enforce) cc_test(cpu_helper_test SRCS cpu_helper_test.cc DEPS cpu_helper) @@ -54,7 +55,7 @@ IF(WITH_DGC) ENDIF() IF(WITH_GPU) - set(GPU_CTX_DEPS dynload_cuda dynamic_loader) + set(GPU_CTX_DEPS dynload_cuda dynamic_loader cuda_stream) ENDIF() IF(WITH_MKLDNN) diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 3bffa72bb92..d996ab55f1b 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -211,6 +211,33 @@ void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) { allocation_ = memory::Alloc(device_context_, required_workspace_bytes); } +thread_local std::unordered_map> + CUDADeviceContext::thread_ctx_; +thread_local std::mutex CUDADeviceContext::ctx_mtx_; + +void CUDAContext::InitEigenContext() { + eigen_stream_.reset(new EigenCudaStreamDevice()); + eigen_stream_->Reinitialize(&RawStream(), place_); + eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); +} + +CUDAContext::CUDAContext(const CUDAPlace& place, + const stream::Priority& priority) { + place_ = place; + CUDADeviceGuard guard(place_.device); + stream_.reset(new stream::CUDAStream(place, priority)); + InitEigenContext(); + InitCuBlasContext(); + InitCuDNNContext(); +} + +CUDAContext::~CUDAContext() { + CUDADeviceGuard guard(place_.device); + DestoryCuDNNContext(); + DestoryCuBlasContext(); +} + CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) { CUDADeviceGuard guard(place_.device); compute_capability_ = GetCUDAComputeCapability(place_.device); @@ -218,18 +245,6 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) { max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(place_.device); max_grid_dim_size_ = GetGpuMaxGridDimSize(place_.device); max_threads_per_block_ = GetCUDAMaxThreadsPerBlock(place_.device); - PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream_)); - eigen_stream_.reset(new EigenCudaStreamDevice()); - eigen_stream_->Reinitialize(&stream_, place); - eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); - cublas_handle_.reset(new CublasHandleHolder(stream_, CUBLAS_DEFAULT_MATH)); - - if (TensorCoreAvailable()) { -#if CUDA_VERSION >= 9000 - cublas_tensor_core_handle_.reset( - new CublasHandleHolder(stream_, CUBLAS_TENSOR_OP_MATH)); -#endif - } driver_version_ = GetCUDADriverVersion(place_.device); runtime_version_ = GetCUDARuntimeVersion(place_.device); @@ -263,44 +278,12 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) { << "Please recompile or reinstall Paddle with compatible CUDA " "version."; } - - if (dynload::HasCUDNN()) { - auto local_cudnn_version = cudnn_dso_ver / 100; - auto compile_cudnn_version = CUDNN_VERSION / 100; - if (local_cudnn_version < static_cast(compile_cudnn_version)) { - LOG_FIRST_N(WARNING, 1) - << "WARNING: device: " << place_.device - << ". The installed Paddle is compiled with CUDNN " - << compile_cudnn_version / 10 << "." << compile_cudnn_version % 10 - << ", but CUDNN version in your machine is " - << local_cudnn_version / 10 << "." << local_cudnn_version % 10 - << ", which may cause serious incompatible bug. " - << "Please recompile or reinstall Paddle with compatible CUDNN " - "version."; - } - PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnCreate(&cudnn_handle_)); - PADDLE_ENFORCE_CUDA_SUCCESS( - dynload::cudnnSetStream(cudnn_handle_, stream_)); - } else { - cudnn_handle_ = nullptr; - } } - - callback_manager_.reset(new StreamCallbackManager(stream_)); + default_ctx_.reset(new CUDAContext(place_)); } CUDADeviceContext::~CUDADeviceContext() { SetDeviceId(place_.device); - Wait(); - WaitStreamCallback(); - cublas_handle_.reset(); - cublas_tensor_core_handle_.reset(); - eigen_stream_.reset(); - eigen_device_.reset(); - PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(stream_)); - if (cudnn_handle_) { - PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroy(cudnn_handle_)); - } #if defined(PADDLE_WITH_NCCL) if (nccl_comm_) { PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclCommDestroy(nccl_comm_)); @@ -310,19 +293,7 @@ CUDADeviceContext::~CUDADeviceContext() { Place CUDADeviceContext::GetPlace() const { return place_; } -void CUDADeviceContext::Wait() const { - cudaError_t e_sync = cudaSuccess; -#if !defined(_WIN32) - e_sync = cudaStreamSynchronize(stream_); -#else - while (e_sync = cudaStreamQuery(stream_)) { - if (e_sync == cudaErrorNotReady) continue; - break; - } -#endif - - PADDLE_ENFORCE_CUDA_SUCCESS(e_sync); -} +void CUDADeviceContext::Wait() const { context()->Stream()->Wait(); } int CUDADeviceContext::GetComputeCapability() const { return compute_capability_; @@ -339,24 +310,28 @@ int CUDADeviceContext::GetMaxThreadsPerBlock() const { } Eigen::GpuDevice* CUDADeviceContext::eigen_device() const { - return eigen_device_.get(); + return context()->EigenDevice().get(); } bool CUDADeviceContext::tensor_core_available() const { - return cublas_tensor_core_handle_ != nullptr; + return context()->CublasTensorCoreHandle() != nullptr; } dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const { return max_grid_dim_size_; } -cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } +cudnnHandle_t CUDADeviceContext::cudnn_handle() const { + return context()->CudnnHandle(); +} CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const { return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_); } -cudaStream_t CUDADeviceContext::stream() const { return stream_; } +cudaStream_t CUDADeviceContext::stream() const { + return context()->RawStream(); +} CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index d6b8cda94f2..7615a0b7ea0 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -38,7 +38,7 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" #ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/platform/stream_callback_manager.h" +#include "paddle/fluid/platform/stream/cuda_stream.h" #endif #include "unsupported/Eigen/CXX11/Tensor" @@ -80,6 +80,118 @@ struct DefaultDeviceContextType { class EigenCudaStreamDevice; class CudnnWorkspaceHandle; +class CUDAContext { + public: + CUDAContext() = default; + explicit CUDAContext( + const CUDAPlace& place, + const stream::Priority& priority = stream::Priority::kNormal); + + ~CUDAContext(); + + const CUDAPlace& Place() const { return place_; } + + const std::unique_ptr& EigenDevice() const { + return eigen_device_; + } + + const std::unique_ptr& EigenStream() const { + return eigen_stream_; + } + + const std::unique_ptr& Stream() const { return stream_; } + + const cudaStream_t& RawStream() { return stream_->raw_stream(); } + + const cudnnHandle_t& CudnnHandle() const { return cudnn_handle_; } + + const std::unique_ptr& CublasHandle() const { + return cublas_handle_; + } + + const std::unique_ptr& CublasTensorCoreHandle() const { + return cublas_tensor_core_handle_; + } + + /*! \brief Call cublas function safely. */ + template + inline void CublasCall(Callback&& callback) const { + cublas_handle_->Call(std::forward(callback)); + } + + /*! \brief Check whether tensor core is supported */ + bool tensor_core_available() const; + + /*! \brief Call cublas function with Tensor Core safely. If + Tensor Core is not available, use DEFAULT_MATH instead. */ + template + inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const { + if (cublas_tensor_core_handle_) { + cublas_tensor_core_handle_->Call(std::forward(callback)); + } else { + cublas_handle_->Call(std::forward(callback)); + } + } + + private: + void InitEigenContext(); + + void InitCuBlasContext() { + cublas_handle_.reset( + new CublasHandleHolder(RawStream(), CUBLAS_DEFAULT_MATH)); + if (TensorCoreAvailable()) { +#if CUDA_VERSION >= 9000 + cublas_tensor_core_handle_.reset( + new CublasHandleHolder(RawStream(), CUBLAS_TENSOR_OP_MATH)); +#endif + } + } + + void InitCuDNNContext() { + if (dynload::HasCUDNN()) { + auto local_cudnn_version = dynload::cudnnGetVersion() / 100; + auto compile_cudnn_version = CUDNN_VERSION / 100; + if (local_cudnn_version < static_cast(compile_cudnn_version)) { + LOG_FIRST_N(WARNING, 1) + << "WARNING: device: " << place_.device + << ". The installed Paddle is compiled with CUDNN " + << compile_cudnn_version / 10 << "." << compile_cudnn_version % 10 + << ", but CUDNN version in your machine is " + << local_cudnn_version / 10 << "." << local_cudnn_version % 10 + << ", which may cause serious incompatible bug. " + << "Please recompile or reinstall Paddle with compatible CUDNN " + "version."; + } + PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnCreate(&cudnn_handle_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + dynload::cudnnSetStream(cudnn_handle_, RawStream())); + } else { + cudnn_handle_ = nullptr; + } + } + + void DestoryCuDNNContext() { + if (cudnn_handle_) { + PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroy(cudnn_handle_)); + } + cudnn_handle_ = nullptr; + } + + void DestoryCuBlasContext() { + cublas_handle_.reset(); + cublas_tensor_core_handle_.reset(); + } + + CUDAPlace place_; + std::unique_ptr eigen_device_; + std::unique_ptr eigen_stream_; + std::unique_ptr stream_; + cudnnHandle_t cudnn_handle_; + std::unique_ptr cublas_handle_; + std::unique_ptr cublas_tensor_core_handle_; + DISABLE_COPY_AND_ASSIGN(CUDAContext); +}; + class CUDADeviceContext : public DeviceContext { public: explicit CUDADeviceContext(CUDAPlace place); @@ -112,7 +224,7 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Call cublas function safely. */ template inline void CublasCall(Callback&& callback) const { - cublas_handle_->Call(std::forward(callback)); + return context()->CublasCall(callback); } /*! \brief Check whether tensor core is supported */ @@ -122,11 +234,7 @@ class CUDADeviceContext : public DeviceContext { Tensor Core is not available, use DEFAULT_MATH instead. */ template inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const { - if (cublas_tensor_core_handle_) { - cublas_tensor_core_handle_->Call(std::forward(callback)); - } else { - cublas_handle_->Call(std::forward(callback)); - } + return context()->TensorCoreCublasCallIfAvailable(callback); } /*! \brief Return cudnn handle in the device context. */ @@ -153,33 +261,48 @@ class CUDADeviceContext : public DeviceContext { #endif template - void RecordEvent(cudaEvent_t ev, Callback callback) { - callback(); - PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(ev, stream_)); + void RecordEvent(cudaEvent_t ev, Callback callback) const { + return context()->Stream()->RecordEvent(ev, callback); } template void AddStreamCallback(Callback&& callback) const { - callback_manager_->AddCallback(callback); + return context()->Stream()->AddCallback(callback); + } + + void WaitStreamCallback() const { + return context()->Stream()->WaitCallback(); } - void WaitStreamCallback() const { callback_manager_->Wait(); } + void ResetDefaultContext(const stream::Priority& priority) { + default_ctx_.reset(new CUDAContext(place_, priority)); + } + + void ResetThreadContext(const stream::Priority& priority) { + std::lock_guard guard(ctx_mtx_); + thread_ctx_[this].reset(new CUDAContext(place_, priority)); + } + + std::shared_ptr context() const { + if (!thread_ctx_.count(this)) { + return default_ctx_; + } + return thread_ctx_.at(this); + } private: CUDAPlace place_; + std::shared_ptr default_ctx_; - mutable std::once_flag init_cudnn_; + // The thread_local static variable will be released before the + // global static variable, so avoid using it in dtor. + static thread_local std::unordered_map> + thread_ctx_; + static thread_local std::mutex ctx_mtx_; - std::unique_ptr eigen_device_; - std::unique_ptr eigen_stream_; - cudaStream_t stream_; - - cudnnHandle_t cudnn_handle_; mutable std::mutex cudnn_handle_mtx_; - std::unique_ptr cublas_handle_; - std::unique_ptr cublas_tensor_core_handle_; - #if defined(PADDLE_WITH_NCCL) // NCCL communicator (single process version) for NCCL collective operations. // NCCL collective operations provides fast collectives over multiple GPUs @@ -197,9 +320,6 @@ class CUDADeviceContext : public DeviceContext { int max_threads_per_block_; dim3 max_grid_dim_size_; - // StreamCallbackManager is thread-safe - std::unique_ptr callback_manager_; - DISABLE_COPY_AND_ASSIGN(CUDADeviceContext); }; diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index 046fd16fb15..199c37d78ab 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -303,7 +303,8 @@ DEFINE_double(memory_fraction_of_eager_deletion, 1.0, * Allocator related FLAG * Name: FLAGS_allocator_strategy * Since Version: 1.2 - * Value Range: string, {naive_best_fit, auto_growth}, default=auto_growth + * Value Range: string, {naive_best_fit, auto_growth, thread_local}, + * default=auto_growth * Example: * Note: For selecting allocator policy of PaddlePaddle. */ diff --git a/paddle/fluid/platform/stream/CMakeLists.txt b/paddle/fluid/platform/stream/CMakeLists.txt new file mode 100644 index 00000000000..78a7313bded --- /dev/null +++ b/paddle/fluid/platform/stream/CMakeLists.txt @@ -0,0 +1,3 @@ +IF(WITH_GPU) +cc_library(cuda_stream SRCS cuda_stream.cc DEPS enforce boost) +ENDIF() diff --git a/paddle/fluid/platform/stream/cuda_stream.cc b/paddle/fluid/platform/stream/cuda_stream.cc new file mode 100644 index 00000000000..6d3b0650376 --- /dev/null +++ b/paddle/fluid/platform/stream/cuda_stream.cc @@ -0,0 +1,70 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/stream/cuda_stream.h" +#include "paddle/fluid/platform/cuda_device_guard.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { +namespace stream { + +constexpr unsigned int kDefaultFlag = cudaStreamDefault; + +bool CUDAStream::Init(const Place& place, const Priority& priority) { + PADDLE_ENFORCE_EQ(is_gpu_place(place), true, + platform::errors::InvalidArgument( + "Cuda stream must be created using cuda place.")); + place_ = place; + CUDADeviceGuard guard(boost::get(place_).device); + if (priority == Priority::kHigh) { + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaStreamCreateWithPriority(&stream_, kDefaultFlag, -1)); + } else if (priority == Priority::kNormal) { + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaStreamCreateWithPriority(&stream_, kDefaultFlag, 0)); + } + callback_manager_.reset(new StreamCallbackManager(stream_)); + VLOG(3) << "CUDAStream Init stream: " << stream_ + << ", priority: " << static_cast(priority); + return true; +} + +void CUDAStream::Destroy() { + CUDADeviceGuard guard(boost::get(place_).device); + Wait(); + WaitCallback(); + if (stream_) { + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(stream_)); + } + stream_ = nullptr; +} + +void CUDAStream::Wait() const { + cudaError_t e_sync = cudaSuccess; +#if !defined(_WIN32) + e_sync = cudaStreamSynchronize(stream_); +#else + while (e_sync = cudaStreamQuery(stream_)) { + if (e_sync == cudaErrorNotReady) continue; + break; + } +#endif + + PADDLE_ENFORCE_CUDA_SUCCESS(e_sync); +} + +} // namespace stream +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/stream/cuda_stream.h b/paddle/fluid/platform/stream/cuda_stream.h new file mode 100644 index 00000000000..4272d5fd0b1 --- /dev/null +++ b/paddle/fluid/platform/stream/cuda_stream.h @@ -0,0 +1,85 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/platform/gpu_info.h" +#include "paddle/fluid/platform/macros.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/platform/stream_callback_manager.h" + +namespace paddle { +namespace platform { +namespace stream { + +#ifdef PADDLE_WITH_CUDA + +enum class Priority : uint8_t { + kNull = 0x0, + kHigh = 0x1, + kNormal = 0x2, +}; + +class CUDAStream final { + public: + CUDAStream() = default; + explicit CUDAStream(const Place& place, + const Priority& priority = Priority::kNormal) { + Init(place, priority); + } + virtual ~CUDAStream() { Destroy(); } + + bool Init(const Place& place, const Priority& priority = Priority::kNormal); + + template + void AddCallback(Callback&& callback) const { + callback_manager_->AddCallback(callback); + } + + template + void RecordEvent(cudaEvent_t ev, Callback callback) const { + callback(); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(ev, stream_)); + } + + void RecordEvent(cudaEvent_t ev) const { + PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(ev, stream_)); + } + + void WaitEvent(cudaEvent_t ev) const { + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(stream_, ev, 0)); + } + + void Wait() const; + void WaitCallback() const { callback_manager_->Wait(); } + + const cudaStream_t& raw_stream() const { return stream_; } + void Destroy(); + + private: + Place place_; + cudaStream_t stream_{nullptr}; + Priority priority_{Priority::kNormal}; + std::unique_ptr callback_manager_; + + DISABLE_COPY_AND_ASSIGN(CUDAStream); +}; + +#endif + +} // namespace stream +} // namespace platform +} // namespace paddle -- GitLab