未验证 提交 62b1f38c 编写于 作者: S sneaxiy 提交者: GitHub

make cuda graph thread local allocator (#37814)

上级 c732c831
...@@ -348,13 +348,14 @@ class AllocatorFacadePrivate { ...@@ -348,13 +348,14 @@ class AllocatorFacadePrivate {
const AllocatorMap& GetAllocatorMap() { const AllocatorMap& GetAllocatorMap() {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (UNLIKELY(platform::CUDAGraph::IsCapturing())) { if (UNLIKELY(platform::CUDAGraph::IsThisThreadCapturing())) {
auto id = platform::CUDAGraph::CapturingID(); auto id = platform::CUDAGraph::CapturingID();
auto iter = cuda_graph_allocator_map_.find(id); auto iter = cuda_graph_allocator_map_.find(id);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
iter, cuda_graph_allocator_map_.end(), iter, cuda_graph_allocator_map_.end(),
platform::errors::PermissionDenied( platform::errors::PermissionDenied(
"No memory pool is prepared for CUDA Graph capturing.")); "No memory pool is prepared for CUDA Graph capturing."));
VLOG(10) << "Choose CUDA Graph memory pool to allocate memory";
return iter->second->allocators_; return iter->second->allocators_;
} else { } else {
return allocators_; return allocators_;
...@@ -405,7 +406,7 @@ class AllocatorFacadePrivate { ...@@ -405,7 +406,7 @@ class AllocatorFacadePrivate {
#if defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_HIP)
auto cuda_allocator = std::make_shared<CUDAAllocator>(p); auto cuda_allocator = std::make_shared<CUDAAllocator>(p);
cuda_allocators_[p][stream] = std::make_shared<AutoGrowthBestFitAllocator>( cuda_allocators_[p][stream] = std::make_shared<AutoGrowthBestFitAllocator>(
cuda_allocator, platform::GpuMinChunkSize(), allow_free_idle_chunk_); cuda_allocator, platform::GpuMinChunkSize(), 0, allow_free_idle_chunk_);
#endif #endif
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA)
......
...@@ -18,6 +18,7 @@ namespace paddle { ...@@ -18,6 +18,7 @@ namespace paddle {
namespace platform { namespace platform {
std::unique_ptr<CUDAGraph> CUDAGraph::capturing_graph_{nullptr}; std::unique_ptr<CUDAGraph> CUDAGraph::capturing_graph_{nullptr};
paddle::optional<std::thread::id> CUDAGraph::capturing_thread_id_{paddle::none};
void CUDAGraph::Reset() { void CUDAGraph::Reset() {
if (is_reset_) return; if (is_reset_) return;
...@@ -58,6 +59,13 @@ void CUDAGraph::BeginSegmentCapture() { ...@@ -58,6 +59,13 @@ void CUDAGraph::BeginSegmentCapture() {
IsCapturing(), true, IsCapturing(), true,
errors::PermissionDenied("BeginSegmentCapture should be called when CUDA " errors::PermissionDenied("BeginSegmentCapture should be called when CUDA "
"Graph is capturing.")); "Graph is capturing."));
if (IsThreadLocalCapturing()) {
PADDLE_ENFORCE_EQ(IsThisThreadCapturing(), true,
platform::errors::PermissionDenied(
"When capturing CUDA Graph in the thread local mode, "
"you cannot begin segmented capturing in the thread "
"which is not the one that starts the capturing."));
}
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamBeginCapture( PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamBeginCapture(
capturing_graph_->stream_, capturing_graph_->capture_mode_)); capturing_graph_->stream_, capturing_graph_->capture_mode_));
PADDLE_ENFORCE_EQ(IsValidCapturing(), true, PADDLE_ENFORCE_EQ(IsValidCapturing(), true,
...@@ -82,6 +90,11 @@ void CUDAGraph::BeginCapture(platform::CUDAPlace place, cudaStream_t stream, ...@@ -82,6 +90,11 @@ void CUDAGraph::BeginCapture(platform::CUDAPlace place, cudaStream_t stream,
capturing_graph_->place_ = place; capturing_graph_->place_ = place;
capturing_graph_->stream_ = stream; capturing_graph_->stream_ = stream;
capturing_graph_->capture_mode_ = mode; capturing_graph_->capture_mode_ = mode;
if (mode == cudaStreamCaptureModeThreadLocal) {
capturing_thread_id_ = std::this_thread::get_id();
VLOG(10) << "Capturing CUDA Graph in thread local mode, thread id: "
<< capturing_thread_id_;
}
BeginSegmentCapture(); BeginSegmentCapture();
#endif #endif
} }
...@@ -115,6 +128,7 @@ void CUDAGraph::EndSegmentCapture() { ...@@ -115,6 +128,7 @@ void CUDAGraph::EndSegmentCapture() {
std::unique_ptr<CUDAGraph> CUDAGraph::EndCapture() { std::unique_ptr<CUDAGraph> CUDAGraph::EndCapture() {
EndSegmentCapture(); EndSegmentCapture();
capturing_thread_id_ = paddle::none;
return std::move(capturing_graph_); return std::move(capturing_graph_);
} }
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <thread>
#include <vector> #include <vector>
#include "cuda.h" // NOLINT #include "cuda.h" // NOLINT
#include "cuda_runtime.h" // NOLINT #include "cuda_runtime.h" // NOLINT
...@@ -26,6 +27,7 @@ ...@@ -26,6 +27,7 @@
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/utils/optional.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -99,6 +101,25 @@ class CUDAGraph { ...@@ -99,6 +101,25 @@ class CUDAGraph {
// supported during capturing CUDA Graph. // supported during capturing CUDA Graph.
static bool IsValidCapturing(); static bool IsValidCapturing();
static bool IsThreadLocalCapturing() {
#if CUDA_VERSION >= 10010
return IsCapturing() &&
capturing_graph_->capture_mode_ == cudaStreamCaptureModeThreadLocal;
#else
return false;
#endif
}
static bool IsThisThreadCapturing() {
if (UNLIKELY(IsCapturing())) {
return IsThreadLocalCapturing()
? capturing_thread_id_.get() == std::this_thread::get_id()
: true;
} else {
return false;
}
}
private: private:
static CUDAGraphID UniqueID() { static CUDAGraphID UniqueID() {
static std::atomic<CUDAGraphID> id; static std::atomic<CUDAGraphID> id;
...@@ -118,6 +139,7 @@ class CUDAGraph { ...@@ -118,6 +139,7 @@ class CUDAGraph {
bool is_reset_{false}; bool is_reset_{false};
std::mutex mtx_; std::mutex mtx_;
static paddle::optional<std::thread::id> capturing_thread_id_;
static std::unique_ptr<CUDAGraph> capturing_graph_; static std::unique_ptr<CUDAGraph> capturing_graph_;
}; };
......
...@@ -34,7 +34,8 @@ class TestCUDAGraph(unittest.TestCase): ...@@ -34,7 +34,8 @@ class TestCUDAGraph(unittest.TestCase):
paddle.set_flags({ paddle.set_flags({
'FLAGS_allocator_strategy': 'auto_growth', 'FLAGS_allocator_strategy': 'auto_growth',
'FLAGS_sync_nccl_allreduce': False, 'FLAGS_sync_nccl_allreduce': False,
'FLAGS_cudnn_deterministic': True 'FLAGS_cudnn_deterministic': True,
'FLAGS_use_stream_safe_cuda_allocator': False,
}) })
def random_tensor(self, shape): def random_tensor(self, shape):
...@@ -187,6 +188,48 @@ class TestCUDAGraph(unittest.TestCase): ...@@ -187,6 +188,48 @@ class TestCUDAGraph(unittest.TestCase):
finally: finally:
graph.reset() graph.reset()
def test_dataloader(self):
if not can_use_cuda_graph():
return
class AutoIncDataset(paddle.io.Dataset):
def __init__(self, n, dtype):
self.n = n
self.dtype = dtype
def __len__(self):
return self.n
def __getitem__(self, idx):
return np.array([idx]).astype(self.dtype)
n = 100
dtype = 'int64'
dataset = AutoIncDataset(n, dtype)
data_loader = paddle.io.DataLoader(
dataset, batch_size=1, num_workers=2, use_buffer_reader=True)
x = None
y = None
graph = None
for i, data in enumerate(data_loader):
if graph is None:
x = data
x = x.cuda()
graph = CUDAGraph()
graph.capture_begin()
y = x * x
graph.capture_end()
else:
x.copy_(data, False)
x = x.cuda()
graph.replay()
actual_x = np.array([[i]]).astype(dtype)
actual_y = np.array([[i * i]]).astype(dtype)
self.assertTrue(np.array_equal(actual_x, x.numpy()))
self.assertTrue(np.array_equal(actual_y, y.numpy()))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册