未验证 提交 597cc058 编写于 作者: 石晓伟 提交者: GitHub

[cherry-pick] Thread-local Allocator, test=release/2.0 (#24061)

* cherry-pick of DeviceContext Split, test=develop (#23737)

* New feature: thread local allocator, test=develop (#23989)

* add the thread_local_allocator, test=develop

* refactor the thread_local_allocator, test=develop

* provides option setting strategy, test=develop

* add boost dependency to cuda_stream, test=develop

* declare the stream::Priority as enum class, test=develop

* deal with PADDLE_ENFORCE_CUDA_SUCCESS macro in pr #23816
上级 c1b4d1c1
......@@ -14,13 +14,15 @@ endif()
if (WITH_GPU)
nv_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard)
nv_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator)
cc_test(thread_local_allocator_test SRCS thread_local_allocator_test.cc DEPS thread_local_allocator)
endif()
cc_library(retry_allocator SRCS retry_allocator.cc DEPS allocator)
nv_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator)
if (WITH_GPU)
set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator cuda_device_guard)
set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator cuda_device_guard thread_local_allocator)
else ()
set(AllocatorFacadeDeps)
endif()
......
......@@ -32,6 +32,7 @@
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/memory/allocation/cuda_allocator.h"
#include "paddle/fluid/memory/allocation/pinned_allocator.h"
#include "paddle/fluid/memory/allocation/thread_local_allocator.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/gpu_info.h"
#endif
......@@ -80,6 +81,18 @@ class AllocatorFacadePrivate {
break;
}
case AllocatorStrategy::kThreadLocal: {
InitNaiveBestFitCPUAllocator();
#ifdef PADDLE_WITH_CUDA
for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount();
++dev_id) {
InitThreadLocalCUDAAllocator(platform::CUDAPlace(dev_id));
}
InitNaiveBestFitCUDAPinnedAllocator();
#endif
break;
}
default: {
PADDLE_THROW("Unsupported allocator strategy: %d",
static_cast<int>(strategy));
......@@ -136,6 +149,10 @@ class AllocatorFacadePrivate {
allocators_[p] = std::make_shared<NaiveBestFitAllocator>(p);
}
void InitThreadLocalCUDAAllocator(platform::CUDAPlace p) {
allocators_[p] = std::make_shared<ThreadLocalCUDAAllocator>(p);
}
void InitAutoGrowthCUDAAllocator(platform::CUDAPlace p) {
auto cuda_allocator = std::make_shared<CUDAAllocator>(p);
allocators_[p] = std::make_shared<AutoGrowthBestFitAllocator>(
......
......@@ -32,6 +32,10 @@ static AllocatorStrategy GetStrategyFromFlag() {
return AllocatorStrategy::kAutoGrowth;
}
if (FLAGS_allocator_strategy == "thread_local") {
return AllocatorStrategy::kThreadLocal;
}
PADDLE_THROW("Unsupported allocator strategy: %s", FLAGS_allocator_strategy);
}
......
......@@ -18,7 +18,7 @@ namespace paddle {
namespace memory {
namespace allocation {
enum class AllocatorStrategy { kNaiveBestFit, kAutoGrowth };
enum class AllocatorStrategy { kNaiveBestFit, kAutoGrowth, kThreadLocal };
extern AllocatorStrategy GetAllocatorStrategy();
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/thread_local_allocator.h"
namespace paddle {
namespace memory {
namespace allocation {
ThreadLocalAllocatorImpl::ThreadLocalAllocatorImpl(const platform::Place& p)
: place_(p) {
if (platform::is_gpu_place(place_)) {
buddy_allocator_.reset(new memory::detail::BuddyAllocator(
std::unique_ptr<memory::detail::SystemAllocator>(
new memory::detail::GPUAllocator(
boost::get<platform::CUDAPlace>(place_).device)),
platform::GpuMinChunkSize(), platform::GpuMaxChunkSize()));
} else {
LOG(FATAL) << "Thread local allocator only supports CUDAPlace now.";
}
}
std::shared_ptr<ThreadLocalAllocatorImpl> ThreadLocalCUDAAllocatorPool::Get(
int gpu_id) {
auto pos = std::distance(devices_.begin(),
std::find(devices_.begin(), devices_.end(), gpu_id));
PADDLE_ENFORCE_LT(
pos, devices_.size(),
platform::errors::InvalidArgument(
"The position of device should be less than the size of devices."));
std::call_once(*init_flags_[pos], [this, pos, gpu_id] {
platform::SetDeviceId(devices_[pos]);
allocators_[pos].reset(
new ThreadLocalAllocatorImpl(platform::CUDAPlace(gpu_id)));
});
return allocators_[pos];
}
ThreadLocalCUDAAllocatorPool::ThreadLocalCUDAAllocatorPool()
: devices_(platform::GetSelectedDevices()) {
auto gpu_num = devices_.size();
allocators_.resize(gpu_num);
init_flags_.reserve(gpu_num);
for (size_t i = 0; i < gpu_num; ++i) {
init_flags_.emplace_back(new std::once_flag());
}
}
ThreadLocalAllocation* ThreadLocalAllocatorImpl::AllocateImpl(size_t size) {
VLOG(10) << "ThreadLocalAllocatorImpl::AllocateImpl " << size;
void* ptr = buddy_allocator_->Alloc(size);
auto* tl_allocation = new ThreadLocalAllocation(ptr, size, place_);
tl_allocation->SetThreadLocalAllocatorImpl(shared_from_this());
return tl_allocation;
}
void ThreadLocalAllocatorImpl::FreeImpl(ThreadLocalAllocation* allocation) {
VLOG(10) << "ThreadLocalAllocatorImpl::FreeImpl " << allocation;
buddy_allocator_->Free(allocation->ptr());
delete allocation;
}
} // namespace allocation
} // namespace memory
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/detail/buddy_allocator.h"
#include "paddle/fluid/memory/detail/system_allocator.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace memory {
namespace allocation {
class ThreadLocalAllocatorImpl;
class ThreadLocalAllocation : public Allocation {
public:
ThreadLocalAllocation(void* ptr, size_t size, platform::Place place)
: Allocation(ptr, size, place) {}
void SetThreadLocalAllocatorImpl(
std::shared_ptr<ThreadLocalAllocatorImpl> allocator) {
allocator_ = allocator;
}
std::shared_ptr<ThreadLocalAllocatorImpl> GetAllocator() {
return allocator_;
}
private:
std::shared_ptr<ThreadLocalAllocatorImpl> allocator_;
};
class ThreadLocalAllocatorImpl
: public std::enable_shared_from_this<ThreadLocalAllocatorImpl> {
public:
explicit ThreadLocalAllocatorImpl(const platform::Place& p);
ThreadLocalAllocation* AllocateImpl(size_t size);
void FreeImpl(ThreadLocalAllocation* allocation);
private:
std::unique_ptr<memory::detail::BuddyAllocator> buddy_allocator_;
platform::Place place_;
};
class ThreadLocalCUDAAllocatorPool {
public:
static ThreadLocalCUDAAllocatorPool& Instance() {
static thread_local ThreadLocalCUDAAllocatorPool pool;
return pool;
}
std::shared_ptr<ThreadLocalAllocatorImpl> Get(int gpu_id);
private:
ThreadLocalCUDAAllocatorPool();
std::vector<int> devices_;
std::vector<std::unique_ptr<std::once_flag>> init_flags_;
std::vector<std::shared_ptr<ThreadLocalAllocatorImpl>> allocators_;
};
class ThreadLocalCUDAAllocator : public Allocator {
public:
explicit ThreadLocalCUDAAllocator(const platform::CUDAPlace& p)
: gpu_id_(p.device) {}
bool IsAllocThreadSafe() const override { return true; }
protected:
Allocation* AllocateImpl(size_t size) override {
return ThreadLocalCUDAAllocatorPool::Instance().Get(gpu_id_)->AllocateImpl(
size);
}
void FreeImpl(Allocation* allocation) override {
auto* tl_allocation = static_cast<ThreadLocalAllocation*>(allocation);
auto allocator_impl = tl_allocation->GetAllocator();
allocator_impl->FreeImpl(tl_allocation);
}
private:
int gpu_id_;
};
} // namespace allocation
} // namespace memory
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/thread_local_allocator.h"
#include <algorithm>
#include <condition_variable> // NOLINT
#include <functional>
#include <iostream>
#include <thread> // NOLINT
#include <utility>
#include "gtest/gtest.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/gpu_info.h"
DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_string(allocator_strategy);
namespace paddle {
namespace memory {
namespace allocation {
TEST(ThreadLocalAllocator, cross_scope_release) {
FLAGS_fraction_of_gpu_memory_to_use = 0.1;
FLAGS_allocator_strategy = "thread_local";
const size_t thread_num = 5;
const std::vector<int> devices = platform::GetSelectedDevices();
std::vector<std::vector<void *>> allocator_addresses(devices.size());
std::vector<std::vector<AllocationPtr>> thread_allocations(devices.size());
for (size_t i = 0; i < devices.size(); ++i) {
allocator_addresses[i].resize(thread_num);
thread_allocations[i].resize(thread_num);
}
std::vector<std::thread> threads(thread_num);
std::mutex mutex;
std::condition_variable cv;
bool flag = false;
for (size_t i = 0; i < threads.size(); ++i) {
threads[i] = std::thread([&, i]() {
{
std::unique_lock<std::mutex> lock(mutex);
cv.wait(lock, [&] { return flag; });
}
for (size_t j = 0; j < devices.size(); ++j) {
thread_allocations[j][i] =
memory::Alloc(platform::CUDAPlace(devices[j]), 10);
auto tl_allocator_impl =
ThreadLocalCUDAAllocatorPool::Instance().Get(devices[j]);
allocator_addresses[j][i] = tl_allocator_impl.get();
}
});
}
{
std::lock_guard<std::mutex> lock(mutex);
flag = true;
cv.notify_all();
}
for (auto &th : threads) {
th.join();
}
for (auto &addresses : allocator_addresses) {
std::sort(addresses.begin(), addresses.end());
ASSERT_EQ(std::adjacent_find(addresses.begin(), addresses.end(),
std::equal_to<void *>()),
addresses.end());
}
::testing::FLAGS_gtest_death_test_style = "threadsafe";
ASSERT_EXIT(([&]() { thread_allocations.clear(); }(), exit(0)),
::testing::ExitedWithCode(0), ".*");
}
} // namespace allocation
} // namespace memory
} // namespace paddle
......@@ -44,6 +44,7 @@ cc_library(place SRCS place.cc DEPS enforce boost)
cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
add_subdirectory(dynload)
add_subdirectory(stream)
cc_library(cpu_helper SRCS cpu_helper.cc DEPS cblas enforce)
cc_test(cpu_helper_test SRCS cpu_helper_test.cc DEPS cpu_helper)
......@@ -54,7 +55,7 @@ IF(WITH_DGC)
ENDIF()
IF(WITH_GPU)
set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
set(GPU_CTX_DEPS dynload_cuda dynamic_loader cuda_stream)
ENDIF()
IF(WITH_MKLDNN)
......
......@@ -211,6 +211,33 @@ void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
}
thread_local std::unordered_map<const CUDADeviceContext*,
std::shared_ptr<CUDAContext>>
CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;
void CUDAContext::InitEigenContext() {
eigen_stream_.reset(new EigenCudaStreamDevice());
eigen_stream_->Reinitialize(&RawStream(), place_);
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}
CUDAContext::CUDAContext(const CUDAPlace& place,
const stream::Priority& priority) {
place_ = place;
CUDADeviceGuard guard(place_.device);
stream_.reset(new stream::CUDAStream(place, priority));
InitEigenContext();
InitCuBlasContext();
InitCuDNNContext();
}
CUDAContext::~CUDAContext() {
CUDADeviceGuard guard(place_.device);
DestoryCuDNNContext();
DestoryCuBlasContext();
}
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
CUDADeviceGuard guard(place_.device);
compute_capability_ = GetCUDAComputeCapability(place_.device);
......@@ -218,18 +245,6 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
max_grid_dim_size_ = GetGpuMaxGridDimSize(place_.device);
max_threads_per_block_ = GetCUDAMaxThreadsPerBlock(place_.device);
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream_));
eigen_stream_.reset(new EigenCudaStreamDevice());
eigen_stream_->Reinitialize(&stream_, place);
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
cublas_handle_.reset(new CublasHandleHolder(stream_, CUBLAS_DEFAULT_MATH));
if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000
cublas_tensor_core_handle_.reset(
new CublasHandleHolder(stream_, CUBLAS_TENSOR_OP_MATH));
#endif
}
driver_version_ = GetCUDADriverVersion(place_.device);
runtime_version_ = GetCUDARuntimeVersion(place_.device);
......@@ -263,44 +278,12 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
<< "Please recompile or reinstall Paddle with compatible CUDA "
"version.";
}
if (dynload::HasCUDNN()) {
auto local_cudnn_version = cudnn_dso_ver / 100;
auto compile_cudnn_version = CUDNN_VERSION / 100;
if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << place_.device
<< ". The installed Paddle is compiled with CUDNN "
<< compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
<< ", but CUDNN version in your machine is "
<< local_cudnn_version / 10 << "." << local_cudnn_version % 10
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible CUDNN "
"version.";
}
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::cudnnSetStream(cudnn_handle_, stream_));
} else {
cudnn_handle_ = nullptr;
}
}
callback_manager_.reset(new StreamCallbackManager(stream_));
default_ctx_.reset(new CUDAContext(place_));
}
CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId(place_.device);
Wait();
WaitStreamCallback();
cublas_handle_.reset();
cublas_tensor_core_handle_.reset();
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(stream_));
if (cudnn_handle_) {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroy(cudnn_handle_));
}
#if defined(PADDLE_WITH_NCCL)
if (nccl_comm_) {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
......@@ -310,19 +293,7 @@ CUDADeviceContext::~CUDADeviceContext() {
Place CUDADeviceContext::GetPlace() const { return place_; }
void CUDADeviceContext::Wait() const {
cudaError_t e_sync = cudaSuccess;
#if !defined(_WIN32)
e_sync = cudaStreamSynchronize(stream_);
#else
while (e_sync = cudaStreamQuery(stream_)) {
if (e_sync == cudaErrorNotReady) continue;
break;
}
#endif
PADDLE_ENFORCE_CUDA_SUCCESS(e_sync);
}
void CUDADeviceContext::Wait() const { context()->Stream()->Wait(); }
int CUDADeviceContext::GetComputeCapability() const {
return compute_capability_;
......@@ -339,24 +310,28 @@ int CUDADeviceContext::GetMaxThreadsPerBlock() const {
}
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return eigen_device_.get();
return context()->EigenDevice().get();
}
bool CUDADeviceContext::tensor_core_available() const {
return cublas_tensor_core_handle_ != nullptr;
return context()->CublasTensorCoreHandle() != nullptr;
}
dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const {
return max_grid_dim_size_;
}
cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return context()->CudnnHandle();
}
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
}
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
cudaStream_t CUDADeviceContext::stream() const {
return context()->RawStream();
}
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
eigen_device_.reset(new Eigen::DefaultDevice());
......
......@@ -38,7 +38,7 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/stream_callback_manager.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
#endif
#include "unsupported/Eigen/CXX11/Tensor"
......@@ -80,6 +80,118 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
class EigenCudaStreamDevice;
class CudnnWorkspaceHandle;
class CUDAContext {
public:
CUDAContext() = default;
explicit CUDAContext(
const CUDAPlace& place,
const stream::Priority& priority = stream::Priority::kNormal);
~CUDAContext();
const CUDAPlace& Place() const { return place_; }
const std::unique_ptr<Eigen::GpuDevice>& EigenDevice() const {
return eigen_device_;
}
const std::unique_ptr<EigenCudaStreamDevice>& EigenStream() const {
return eigen_stream_;
}
const std::unique_ptr<stream::CUDAStream>& Stream() const { return stream_; }
const cudaStream_t& RawStream() { return stream_->raw_stream(); }
const cudnnHandle_t& CudnnHandle() const { return cudnn_handle_; }
const std::unique_ptr<CublasHandleHolder>& CublasHandle() const {
return cublas_handle_;
}
const std::unique_ptr<CublasHandleHolder>& CublasTensorCoreHandle() const {
return cublas_tensor_core_handle_;
}
/*! \brief Call cublas function safely. */
template <typename Callback>
inline void CublasCall(Callback&& callback) const {
cublas_handle_->Call(std::forward<Callback>(callback));
}
/*! \brief Check whether tensor core is supported */
bool tensor_core_available() const;
/*! \brief Call cublas function with Tensor Core safely. If
Tensor Core is not available, use DEFAULT_MATH instead. */
template <typename Callback>
inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
if (cublas_tensor_core_handle_) {
cublas_tensor_core_handle_->Call(std::forward<Callback>(callback));
} else {
cublas_handle_->Call(std::forward<Callback>(callback));
}
}
private:
void InitEigenContext();
void InitCuBlasContext() {
cublas_handle_.reset(
new CublasHandleHolder(RawStream(), CUBLAS_DEFAULT_MATH));
if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000
cublas_tensor_core_handle_.reset(
new CublasHandleHolder(RawStream(), CUBLAS_TENSOR_OP_MATH));
#endif
}
}
void InitCuDNNContext() {
if (dynload::HasCUDNN()) {
auto local_cudnn_version = dynload::cudnnGetVersion() / 100;
auto compile_cudnn_version = CUDNN_VERSION / 100;
if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
LOG_FIRST_N(WARNING, 1)
<< "WARNING: device: " << place_.device
<< ". The installed Paddle is compiled with CUDNN "
<< compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
<< ", but CUDNN version in your machine is "
<< local_cudnn_version / 10 << "." << local_cudnn_version % 10
<< ", which may cause serious incompatible bug. "
<< "Please recompile or reinstall Paddle with compatible CUDNN "
"version.";
}
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::cudnnSetStream(cudnn_handle_, RawStream()));
} else {
cudnn_handle_ = nullptr;
}
}
void DestoryCuDNNContext() {
if (cudnn_handle_) {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroy(cudnn_handle_));
}
cudnn_handle_ = nullptr;
}
void DestoryCuBlasContext() {
cublas_handle_.reset();
cublas_tensor_core_handle_.reset();
}
CUDAPlace place_;
std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
std::unique_ptr<stream::CUDAStream> stream_;
cudnnHandle_t cudnn_handle_;
std::unique_ptr<CublasHandleHolder> cublas_handle_;
std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
DISABLE_COPY_AND_ASSIGN(CUDAContext);
};
class CUDADeviceContext : public DeviceContext {
public:
explicit CUDADeviceContext(CUDAPlace place);
......@@ -112,7 +224,7 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Call cublas function safely. */
template <typename Callback>
inline void CublasCall(Callback&& callback) const {
cublas_handle_->Call(std::forward<Callback>(callback));
return context()->CublasCall(callback);
}
/*! \brief Check whether tensor core is supported */
......@@ -122,11 +234,7 @@ class CUDADeviceContext : public DeviceContext {
Tensor Core is not available, use DEFAULT_MATH instead. */
template <typename Callback>
inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
if (cublas_tensor_core_handle_) {
cublas_tensor_core_handle_->Call(std::forward<Callback>(callback));
} else {
cublas_handle_->Call(std::forward<Callback>(callback));
}
return context()->TensorCoreCublasCallIfAvailable(callback);
}
/*! \brief Return cudnn handle in the device context. */
......@@ -153,33 +261,48 @@ class CUDADeviceContext : public DeviceContext {
#endif
template <typename Callback>
void RecordEvent(cudaEvent_t ev, Callback callback) {
callback();
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(ev, stream_));
void RecordEvent(cudaEvent_t ev, Callback callback) const {
return context()->Stream()->RecordEvent(ev, callback);
}
template <typename Callback>
void AddStreamCallback(Callback&& callback) const {
callback_manager_->AddCallback(callback);
return context()->Stream()->AddCallback(callback);
}
void WaitStreamCallback() const {
return context()->Stream()->WaitCallback();
}
void WaitStreamCallback() const { callback_manager_->Wait(); }
void ResetDefaultContext(const stream::Priority& priority) {
default_ctx_.reset(new CUDAContext(place_, priority));
}
void ResetThreadContext(const stream::Priority& priority) {
std::lock_guard<std::mutex> guard(ctx_mtx_);
thread_ctx_[this].reset(new CUDAContext(place_, priority));
}
std::shared_ptr<CUDAContext> context() const {
if (!thread_ctx_.count(this)) {
return default_ctx_;
}
return thread_ctx_.at(this);
}
private:
CUDAPlace place_;
std::shared_ptr<CUDAContext> default_ctx_;
mutable std::once_flag init_cudnn_;
// The thread_local static variable will be released before the
// global static variable, so avoid using it in dtor.
static thread_local std::unordered_map<const CUDADeviceContext*,
std::shared_ptr<CUDAContext>>
thread_ctx_;
static thread_local std::mutex ctx_mtx_;
std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
cudaStream_t stream_;
cudnnHandle_t cudnn_handle_;
mutable std::mutex cudnn_handle_mtx_;
std::unique_ptr<CublasHandleHolder> cublas_handle_;
std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
#if defined(PADDLE_WITH_NCCL)
// NCCL communicator (single process version) for NCCL collective operations.
// NCCL collective operations provides fast collectives over multiple GPUs
......@@ -197,9 +320,6 @@ class CUDADeviceContext : public DeviceContext {
int max_threads_per_block_;
dim3 max_grid_dim_size_;
// StreamCallbackManager is thread-safe
std::unique_ptr<StreamCallbackManager> callback_manager_;
DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
};
......
......@@ -303,7 +303,8 @@ DEFINE_double(memory_fraction_of_eager_deletion, 1.0,
* Allocator related FLAG
* Name: FLAGS_allocator_strategy
* Since Version: 1.2
* Value Range: string, {naive_best_fit, auto_growth}, default=auto_growth
* Value Range: string, {naive_best_fit, auto_growth, thread_local},
* default=auto_growth
* Example:
* Note: For selecting allocator policy of PaddlePaddle.
*/
......
IF(WITH_GPU)
cc_library(cuda_stream SRCS cuda_stream.cc DEPS enforce boost)
ENDIF()
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/stream/cuda_stream.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
namespace stream {
constexpr unsigned int kDefaultFlag = cudaStreamDefault;
bool CUDAStream::Init(const Place& place, const Priority& priority) {
PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
platform::errors::InvalidArgument(
"Cuda stream must be created using cuda place."));
place_ = place;
CUDADeviceGuard guard(boost::get<CUDAPlace>(place_).device);
if (priority == Priority::kHigh) {
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamCreateWithPriority(&stream_, kDefaultFlag, -1));
} else if (priority == Priority::kNormal) {
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamCreateWithPriority(&stream_, kDefaultFlag, 0));
}
callback_manager_.reset(new StreamCallbackManager(stream_));
VLOG(3) << "CUDAStream Init stream: " << stream_
<< ", priority: " << static_cast<int>(priority);
return true;
}
void CUDAStream::Destroy() {
CUDADeviceGuard guard(boost::get<CUDAPlace>(place_).device);
Wait();
WaitCallback();
if (stream_) {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(stream_));
}
stream_ = nullptr;
}
void CUDAStream::Wait() const {
cudaError_t e_sync = cudaSuccess;
#if !defined(_WIN32)
e_sync = cudaStreamSynchronize(stream_);
#else
while (e_sync = cudaStreamQuery(stream_)) {
if (e_sync == cudaErrorNotReady) continue;
break;
}
#endif
PADDLE_ENFORCE_CUDA_SUCCESS(e_sync);
}
} // namespace stream
} // namespace platform
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cstdint>
#include <memory>
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream_callback_manager.h"
namespace paddle {
namespace platform {
namespace stream {
#ifdef PADDLE_WITH_CUDA
enum class Priority : uint8_t {
kNull = 0x0,
kHigh = 0x1,
kNormal = 0x2,
};
class CUDAStream final {
public:
CUDAStream() = default;
explicit CUDAStream(const Place& place,
const Priority& priority = Priority::kNormal) {
Init(place, priority);
}
virtual ~CUDAStream() { Destroy(); }
bool Init(const Place& place, const Priority& priority = Priority::kNormal);
template <typename Callback>
void AddCallback(Callback&& callback) const {
callback_manager_->AddCallback(callback);
}
template <typename Callback>
void RecordEvent(cudaEvent_t ev, Callback callback) const {
callback();
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(ev, stream_));
}
void RecordEvent(cudaEvent_t ev) const {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(ev, stream_));
}
void WaitEvent(cudaEvent_t ev) const {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(stream_, ev, 0));
}
void Wait() const;
void WaitCallback() const { callback_manager_->Wait(); }
const cudaStream_t& raw_stream() const { return stream_; }
void Destroy();
private:
Place place_;
cudaStream_t stream_{nullptr};
Priority priority_{Priority::kNormal};
std::unique_ptr<StreamCallbackManager> callback_manager_;
DISABLE_COPY_AND_ASSIGN(CUDAStream);
};
#endif
} // namespace stream
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册