Unverified commit 53c73c77, authored by pangyoki, committed by GitHub

fix cuda graph (#51648)

Parent commit: 4283e19e
...@@ -545,7 +545,7 @@ void InterpreterCore::PrepareForCUDAGraphCapture() { ...@@ -545,7 +545,7 @@ void InterpreterCore::PrepareForCUDAGraphCapture() {
platform::IsCUDAGraphCapturing(), platform::IsCUDAGraphCapturing(),
false, false,
platform::errors::PermissionDenied("CUDA Graph is not allowed to capture " platform::errors::PermissionDenied("CUDA Graph is not allowed to capture "
"when running the first batch.")); "before prepare."));
PADDLE_ENFORCE_EQ(platform::is_gpu_place(place_), PADDLE_ENFORCE_EQ(platform::is_gpu_place(place_),
true, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -684,8 +684,16 @@ void InterpreterCore::Convert( ...@@ -684,8 +684,16 @@ void InterpreterCore::Convert(
if (op_type == interpreter::kMemcpyD2H || if (op_type == interpreter::kMemcpyD2H ||
op_type == interpreter::kMemcpyH2D) { op_type == interpreter::kMemcpyH2D) {
PADDLE_THROW(paddle::platform::errors::Fatal( PADDLE_THROW(paddle::platform::errors::Fatal(
"op_type can't be memcpy d2h or h2d while using cuda graph.")); "Cuda memory copy d2h/h2d is not allowed while using cuda graph."));
} }
PADDLE_ENFORCE_EQ(typeid(*dev_ctx_) == typeid(phi::GPUContext),
true,
platform::errors::InvalidArgument(
"Device context of op %s must be [%s] while using "
"cuda graph, but got [%s].",
op_type,
typeid(phi::GPUContext).name(),
typeid(*dev_ctx_).name()));
// cuda graph needs to record all stream // cuda graph needs to record all stream
phi::backends::gpu::CUDAGraphContextManager::Instance() phi::backends::gpu::CUDAGraphContextManager::Instance()
.RecordCapturingDeviceContext(dev_ctx_); .RecordCapturingDeviceContext(dev_ctx_);
......
...@@ -40,32 +40,58 @@ void InitCUDNNRelatedHandle(phi::GPUContext* dev_ctx) { ...@@ -40,32 +40,58 @@ void InitCUDNNRelatedHandle(phi::GPUContext* dev_ctx) {
dev_ctx->cusolver_dn_handle(); dev_ctx->cusolver_dn_handle();
} }
// Selects the device context (and hence the CUDA stream) used to capture a
// CUDA graph on `place`:
//  - no capturing contexts recorded: fall back to the default-stream context
//    from the global DeviceContextPool;
//  - exactly one recorded context: reuse it (single-stream new executor);
//  - more than one: fetch a dedicated capture context keyed by *pool_id from
//    CUDAGraphContextManager, first allocating a fresh memory-pool id in
//    *pool_id if the caller passed an invalid one (multi-stream new executor).
// `pool_id` is an in/out parameter and may be updated in place.
phi::DeviceContext* SelectCUDAGraphDeviceContext(phi::GPUPlace place,
                                                 int64_t* pool_id) {
  // Initialize defensively: the original left this indeterminate and relied
  // on the branch structure below being exhaustive, which the compiler cannot
  // prove (maybe-uninitialized warning / latent UB if the logic drifts).
  phi::DeviceContext* mutable_dev_ctx = nullptr;
  auto all_capturing_dev_ctxs =
      phi::backends::gpu::CUDAGraphContextManager::Instance()
          .GetAllCapturingDeviceContexts();
  auto num_stream = all_capturing_dev_ctxs.size();
  if (num_stream > 0) {
    // Capturing device contexts will only be recorded in new
    // executor in temporary, that is,
    // FLAGS_new_executor_use_cuda_graph needs to be set to True.
    // This restriction can be removed if device context is
    // recorded in other modes.
    // Record method: RecordCapturingDeviceContext.
    PADDLE_ENFORCE_EQ(FLAGS_new_executor_use_cuda_graph,
                      true,
                      platform::errors::InvalidArgument(
                          "FLAGS_new_executor_use_cuda_graph must be True when "
                          "capturing stream is recorded."));
    if (num_stream > 1) {
      VLOG(4) << "Use a new stream to capture cuda graph. Used in multi-stream "
                 "scenarios with new executor.";
      if (*pool_id <= CUDAGraph::kInvalidPoolID) {
        *pool_id = CUDAGraph::UniqueMemoryPoolID();
      }
      mutable_dev_ctx =
          phi::backends::gpu::CUDAGraphContextManager::Instance().Get(
              *pool_id, place, 0);
    } else {
      // num_stream == 1 here; a plain else (rather than the redundant
      // `else if (num_stream == 1)`) makes the assignment provably
      // exhaustive within this branch.
      VLOG(4) << "Use recorded stream to capture cuda graph. Used in "
                 "single-stream scenarios with new executor.";
      mutable_dev_ctx = *(all_capturing_dev_ctxs.begin());
    }
  } else {
    VLOG(4) << "Use default stream to capture cuda graph.";
    mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place);
  }
  return mutable_dev_ctx;
}
void BeginCUDAGraphCapture(phi::GPUPlace place, void BeginCUDAGraphCapture(phi::GPUPlace place,
cudaStreamCaptureMode mode, cudaStreamCaptureMode mode,
int64_t pool_id) { int64_t pool_id) {
auto* mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place); auto* mutable_dev_ctx = SelectCUDAGraphDeviceContext(place, &pool_id);
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
InitCUDNNRelatedHandle(dev_ctx);
auto all_capturing_dev_ctxs = auto all_capturing_dev_ctxs =
phi::backends::gpu::CUDAGraphContextManager::Instance() phi::backends::gpu::CUDAGraphContextManager::Instance()
.GetAllCapturingDeviceContexts(); .GetAllCapturingDeviceContexts();
// create_cuda_graph_stream: Whether to create a new stream to auto num_stream = all_capturing_dev_ctxs.size();
// capture cuda graph, usually used in multi-stream scenarios. if (num_stream > 1) {
// Can only be used for new executor in static mode, that is,
// FLAGS_new_executor_use_cuda_graph needs to be set to True.
bool create_cuda_graph_stream = false;
if (FLAGS_new_executor_use_cuda_graph &&
(all_capturing_dev_ctxs.size() > 1 ||
(all_capturing_dev_ctxs.size() == 1 &&
(*(all_capturing_dev_ctxs.begin()) != mutable_dev_ctx)))) {
create_cuda_graph_stream = true;
}
if (create_cuda_graph_stream) {
VLOG(4) << "create a new stream to capture cuda graph.";
if (pool_id <= CUDAGraph::kInvalidPoolID) {
pool_id = CUDAGraph::UniqueMemoryPoolID();
}
mutable_dev_ctx =
phi::backends::gpu::CUDAGraphContextManager::Instance().Get(
pool_id, place, 0);
for (auto iter = all_capturing_dev_ctxs.begin(); for (auto iter = all_capturing_dev_ctxs.begin();
iter != all_capturing_dev_ctxs.end(); iter != all_capturing_dev_ctxs.end();
++iter) { ++iter) {
...@@ -73,12 +99,9 @@ void BeginCUDAGraphCapture(phi::GPUPlace place, ...@@ -73,12 +99,9 @@ void BeginCUDAGraphCapture(phi::GPUPlace place,
InitCUDNNRelatedHandle(capturing_dev_ctx); InitCUDNNRelatedHandle(capturing_dev_ctx);
} }
} }
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
InitCUDNNRelatedHandle(dev_ctx);
auto stream = dev_ctx->stream(); auto stream = dev_ctx->stream();
CUDAGraph::BeginCapture(place, stream, mode); CUDAGraph::BeginCapture(place, stream, mode);
CUDAGraph::SetIsCUDAGraphStreamCreated(create_cuda_graph_stream);
// When using cuda graph in new executor, fast GC must be used. // When using cuda graph in new executor, fast GC must be used.
// FLAGS_use_stream_safe_cuda_allocator should be true. // FLAGS_use_stream_safe_cuda_allocator should be true.
...@@ -96,7 +119,7 @@ void BeginCUDAGraphCapture(phi::GPUPlace place, ...@@ -96,7 +119,7 @@ void BeginCUDAGraphCapture(phi::GPUPlace place,
if (old_value) { if (old_value) {
FLAGS_use_stream_safe_cuda_allocator = true; FLAGS_use_stream_safe_cuda_allocator = true;
} }
if (create_cuda_graph_stream) { if (num_stream > 1) {
// Set cuda graph allocator for all streams. // Set cuda graph allocator for all streams.
// Establish dependencies between cuda graph stream and all other streams // Establish dependencies between cuda graph stream and all other streams
// using eventWait, so that all streams will be captured. // using eventWait, so that all streams will be captured.
...@@ -129,20 +152,17 @@ void BeginCUDAGraphCapture(phi::GPUPlace place, ...@@ -129,20 +152,17 @@ void BeginCUDAGraphCapture(phi::GPUPlace place,
} }
std::unique_ptr<CUDAGraph> EndCUDAGraphCapture() { std::unique_ptr<CUDAGraph> EndCUDAGraphCapture() {
phi::DeviceContext* mutable_dev_ctx;
auto place = CUDAGraph::CapturingPlace(); auto place = CUDAGraph::CapturingPlace();
bool create_cuda_graph_stream = CUDAGraph::IsCUDAGraphStreamCreated(); auto pool_id = CUDAGraph::CapturingPoolID();
if (create_cuda_graph_stream) { auto* mutable_dev_ctx = SelectCUDAGraphDeviceContext(place, &pool_id);
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
auto all_capturing_dev_ctxs =
phi::backends::gpu::CUDAGraphContextManager::Instance()
.GetAllCapturingDeviceContexts();
auto num_stream = all_capturing_dev_ctxs.size();
if (num_stream > 1) {
// join all other streams back to origin cuda graph stream. // join all other streams back to origin cuda graph stream.
int64_t pool_id = CUDAGraph::CapturingPoolID();
mutable_dev_ctx =
phi::backends::gpu::CUDAGraphContextManager::Instance().Get(
pool_id, place, 0);
auto* cuda_graph_dev_ctx =
reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
auto all_capturing_dev_ctxs =
phi::backends::gpu::CUDAGraphContextManager::Instance()
.GetAllCapturingDeviceContexts();
for (auto iter = all_capturing_dev_ctxs.begin(); for (auto iter = all_capturing_dev_ctxs.begin();
iter != all_capturing_dev_ctxs.end(); iter != all_capturing_dev_ctxs.end();
++iter) { ++iter) {
...@@ -152,19 +172,16 @@ std::unique_ptr<CUDAGraph> EndCUDAGraphCapture() { ...@@ -152,19 +172,16 @@ std::unique_ptr<CUDAGraph> EndCUDAGraphCapture() {
capturing_dev_ctx->GetPlace(), capturing_dev_ctx->GetPlace(),
platform::GenerateDeviceEventFlag()); platform::GenerateDeviceEventFlag());
capturing_event->Record(capturing_dev_ctx); capturing_event->Record(capturing_dev_ctx);
capturing_event->Wait(platform::kCUDA, cuda_graph_dev_ctx); capturing_event->Wait(platform::kCUDA, dev_ctx);
VLOG(4) << "CUDA Graph stream eventWait. cuda graph dev_ctx: " VLOG(4) << "CUDA Graph stream eventWait. cuda graph dev_ctx: " << dev_ctx
<< cuda_graph_dev_ctx
<< " wait for capturing dev_ctx: " << capturing_dev_ctx; << " wait for capturing dev_ctx: " << capturing_dev_ctx;
capturing_dev_ctx->cudnn_workspace_handle().ResetWorkspace(); capturing_dev_ctx->cudnn_workspace_handle().ResetWorkspace();
capturing_dev_ctx->SetCUDAGraphAllocator(nullptr); capturing_dev_ctx->SetCUDAGraphAllocator(nullptr);
} }
phi::backends::gpu::CUDAGraphContextManager::Instance()
.ClearDeviceContextsRecords();
} else {
mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place);
} }
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
phi::backends::gpu::CUDAGraphContextManager::Instance()
.ClearDeviceContextsRecords();
dev_ctx->cudnn_workspace_handle().ResetWorkspace(); dev_ctx->cudnn_workspace_handle().ResetWorkspace();
dev_ctx->SetCUDAGraphAllocator(nullptr); dev_ctx->SetCUDAGraphAllocator(nullptr);
return CUDAGraph::EndCapture(); return CUDAGraph::EndCapture();
......
...@@ -196,14 +196,6 @@ class CUDAGraph { ...@@ -196,14 +196,6 @@ class CUDAGraph {
// supported during capturing CUDA Graph. // supported during capturing CUDA Graph.
static bool IsValidCapturing(); static bool IsValidCapturing();
static void SetIsCUDAGraphStreamCreated(bool create_cuda_graph_stream) {
capturing_graph_->is_cuda_graph_stream_created_ = create_cuda_graph_stream;
}
static bool IsCUDAGraphStreamCreated() {
return capturing_graph_->is_cuda_graph_stream_created_;
}
static bool IsThreadLocalCapturing() { static bool IsThreadLocalCapturing() {
#if CUDA_VERSION >= 10010 #if CUDA_VERSION >= 10010
return IsCapturing() && return IsCapturing() &&
...@@ -254,8 +246,6 @@ class CUDAGraph { ...@@ -254,8 +246,6 @@ class CUDAGraph {
bool is_first_run_{true}; bool is_first_run_{true};
bool is_cuda_graph_stream_created_{false};
static paddle::optional<std::thread::id> capturing_thread_id_; static paddle::optional<std::thread::id> capturing_thread_id_;
static std::unique_ptr<CUDAGraph> capturing_graph_; static std::unique_ptr<CUDAGraph> capturing_graph_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register or sign in.