Unverified commit 579fb5fd authored by pangyoki, committed by GitHub

cuda graph support multi-stream for new executor (#51389)

* cuda graph support multi-stream for new executor

* fix windows compile error

* delete create_cuda_graph_stream
Parent 26007b1d
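
Reviewer note, not part of the diff: for orientation, this change makes the new executor's multi-stream programs capturable by tying every extra stream to the capture-origin stream with events. Below is a minimal standalone sketch of that fork/join pattern in the plain CUDA runtime API. It is illustrative only, not Paddle code, and assumes the CUDA 11-style cudaGraphInstantiate signature.

#include <cstdio>
#include <cuda_runtime.h>

__global__ void add_one(float* x) { x[threadIdx.x] += 1.0f; }
__global__ void times_two(float* x) { x[threadIdx.x] *= 2.0f; }

int main() {
  float* data;
  cudaMalloc(&data, 32 * sizeof(float));
  cudaMemset(data, 0, 32 * sizeof(float));
  cudaStream_t origin, extra;
  cudaStreamCreate(&origin);
  cudaStreamCreate(&extra);
  cudaEvent_t fork, join;
  cudaEventCreate(&fork);
  cudaEventCreate(&join);

  cudaStreamBeginCapture(origin, cudaStreamCaptureModeGlobal);
  // Fork: the extra stream enters the capture by waiting on an event
  // recorded in the origin stream (this is what BeginCUDAGraphCapture
  // below does for every recorded device context).
  cudaEventRecord(fork, origin);
  cudaStreamWaitEvent(extra, fork, 0);
  add_one<<<1, 32, 0, origin>>>(data);
  times_two<<<1, 32, 0, extra>>>(data);
  // Join: the origin stream waits on the extra stream before ending the
  // capture (this is what EndCUDAGraphCapture below does); otherwise
  // cudaStreamEndCapture fails with an unjoined-stream error.
  cudaEventRecord(join, extra);
  cudaStreamWaitEvent(origin, join, 0);
  cudaGraph_t graph;
  cudaStreamEndCapture(origin, &graph);

  cudaGraphExec_t exec;
  cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);  // CUDA 11 signature
  cudaGraphLaunch(exec, origin);
  cudaStreamSynchronize(origin);
  printf("replayed a graph captured across two streams\n");

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaFree(data);
  return 0;
}
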
......@@ -541,6 +541,11 @@ void InterpreterCore::BuildInplace() {
void InterpreterCore::PrepareForCUDAGraphCapture() {
if (!FLAGS_new_executor_use_cuda_graph) return;
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_EQ(
platform::IsCUDAGraphCapturing(),
false,
platform::errors::PermissionDenied("CUDA Graph is not allowed to capture "
"when running the first batch."));
PADDLE_ENFORCE_EQ(platform::is_gpu_place(place_),
true,
platform::errors::InvalidArgument(
......@@ -672,6 +677,20 @@ void InterpreterCore::Convert(
auto& op_func_node = nodes[op_idx];
auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_);
#ifdef PADDLE_WITH_CUDA
if (FLAGS_new_executor_use_cuda_graph) {
auto& op = op_func_node.operator_base_;
auto& op_type = op->Type();
if (op_type == interpreter::kMemcpyD2H ||
op_type == interpreter::kMemcpyH2D) {
PADDLE_THROW(paddle::platform::errors::Fatal(
"op_type can't be memcpy d2h or h2d while using cuda graph."));
}
// cuda graph needs to record all streams
phi::backends::gpu::CUDAGraphContextManager::Instance()
.RecordCapturingDeviceContext(dev_ctx_);
}
#endif
}
BuildOperatorDependences();
......
......@@ -32,18 +32,6 @@ cc_test(
SRCS os_info_test.cc
DEPS phi_os_info)
if(WITH_GPU)
nv_library(
cuda_graph_with_memory_pool
SRCS cuda_graph_with_memory_pool.cc
DEPS device_context allocator phi_backends)
else()
cc_library(
cuda_graph_with_memory_pool
SRCS cuda_graph_with_memory_pool.cc
DEPS device_context allocator)
endif()
cc_library(
place
SRCS place.cc
......@@ -239,6 +227,10 @@ if(WITH_GPU)
SRCS device_event_test.cc
DEPS device_event_gpu)
endif()
nv_library(
cuda_graph_with_memory_pool
SRCS cuda_graph_with_memory_pool.cc
DEPS ${DEVICE_EVENT_LIBS} device_context allocator phi_backends)
nv_test(
device_context_test
SRCS device_context_test.cu
......@@ -247,6 +239,11 @@ if(WITH_GPU)
device_context_test_cuda_graph
SRCS device_context_test_cuda_graph.cu
DEPS device_context gpu_info cuda_graph_with_memory_pool)
else()
cc_library(
cuda_graph_with_memory_pool
SRCS cuda_graph_with_memory_pool.cc
DEPS device_context allocator)
endif()
if(WITH_ROCM)
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/device_event.h"
#include "paddle/phi/backends/context_pool.h"
DECLARE_bool(use_stream_safe_cuda_allocator);
......@@ -24,25 +25,60 @@ namespace paddle {
namespace platform {
#ifdef PADDLE_WITH_CUDA
void BeginCUDAGraphCapture(phi::GPUPlace place,
cudaStreamCaptureMode mode,
int64_t pool_id) {
auto* mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place);
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
void InitCUDNNRelatedHandle(phi::GPUContext* dev_ctx) {
dev_ctx->cudnn_workspace_handle().ResetWorkspace();
// After PR(#43206), cudnn related initializations will change to lazy mode.
// They will only be initialized when an op calls them. But cuda graph does
// not support capturing such initialization, so all these handles need to
// be initialized before the cuda graph capture begins.
dev_ctx->cublas_handle();
#if CUDA_VERSION >= 11060
dev_ctx->cublaslt_handle();
#endif
dev_ctx->cudnn_handle();
dev_ctx->cusolver_dn_handle();
}
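
Reviewer note, not part of the diff: the reason the handles above must exist before capture is that creating them allocates device memory with cudaMalloc, which is not permitted inside a stream capture. A hedged standalone illustration (plain CUDA/cuBLAS, behavior assumed from the CUDA documentation, not Paddle code):

#include <cublas_v2.h>
#include <cuda_runtime.h>

void sketch_handle_must_preexist(cudaStream_t stream) {
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  cublasHandle_t handle;
  // Typically fails and invalidates the capture: cublasCreate allocates
  // device memory (cudaMalloc), which is disallowed while capturing.
  cublasStatus_t status = cublasCreate(&handle);
  (void)status;
  cudaGraph_t graph;
  // EndCapture is then expected to report the capture was invalidated.
  cudaStreamEndCapture(stream, &graph);
}
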
void BeginCUDAGraphCapture(phi::GPUPlace place,
cudaStreamCaptureMode mode,
int64_t pool_id) {
auto* mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place);
auto all_capturing_dev_ctxs =
phi::backends::gpu::CUDAGraphContextManager::Instance()
.GetAllCapturingDeviceContexts();
// create_cuda_graph_stream: whether to create a new stream to
// capture the cuda graph, usually used in multi-stream scenarios.
// It can only be used with the new executor in static mode, i.e.,
// FLAGS_new_executor_use_cuda_graph needs to be set to True.
bool create_cuda_graph_stream = false;
if (FLAGS_new_executor_use_cuda_graph &&
(all_capturing_dev_ctxs.size() > 1 ||
(all_capturing_dev_ctxs.size() == 1 &&
(*(all_capturing_dev_ctxs.begin()) != mutable_dev_ctx)))) {
create_cuda_graph_stream = true;
}
if (create_cuda_graph_stream) {
VLOG(4) << "create a new stream to capture cuda graph.";
if (pool_id <= CUDAGraph::kInvalidPoolID) {
pool_id = CUDAGraph::UniqueMemoryPoolID();
}
mutable_dev_ctx =
phi::backends::gpu::CUDAGraphContextManager::Instance().Get(
pool_id, place, 0);
for (auto iter = all_capturing_dev_ctxs.begin();
iter != all_capturing_dev_ctxs.end();
++iter) {
auto* capturing_dev_ctx = reinterpret_cast<phi::GPUContext*>(*iter);
InitCUDNNRelatedHandle(capturing_dev_ctx);
}
}
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
InitCUDNNRelatedHandle(dev_ctx);
auto stream = dev_ctx->stream();
CUDAGraph::BeginCapture(place, stream, mode);
CUDAGraph::SetIsCUDAGraphStreamCreated(create_cuda_graph_stream);
// When using cuda graph in the new executor, fast GC must be used, so
// FLAGS_use_stream_safe_cuda_allocator should be true.
......@@ -60,6 +96,32 @@ void BeginCUDAGraphCapture(phi::GPUPlace place,
if (old_value) {
FLAGS_use_stream_safe_cuda_allocator = true;
}
if (create_cuda_graph_stream) {
// Set the cuda graph allocator for all streams.
// Establish dependencies between the cuda graph stream and all other
// streams using eventWait, so that all streams will be captured.
std::shared_ptr<platform::DeviceEvent> cuda_graph_event =
std::make_shared<platform::DeviceEvent>(
dev_ctx->GetPlace(), platform::GenerateDeviceEventFlag());
cuda_graph_event->Record(dev_ctx);
for (auto iter = all_capturing_dev_ctxs.begin();
iter != all_capturing_dev_ctxs.end();
++iter) {
auto* capturing_dev_ctx = reinterpret_cast<phi::GPUContext*>(*iter);
auto capturing_stream = capturing_dev_ctx->stream();
capturing_dev_ctx->SetCUDAGraphAllocator(
memory::allocation::AllocatorFacade::Instance()
.GetAllocator(place, capturing_stream)
.get());
VLOG(4) << "set CUDAGraphAllocator for dev_ctx: " << capturing_dev_ctx
<< " with stream: " << capturing_stream;
cuda_graph_event->Wait(platform::kCUDA, capturing_dev_ctx);
VLOG(4) << "CUDA Graph stream eventWait. Capturing dev_ctx: "
<< capturing_dev_ctx
<< " wait for cuda graph dev_ctx: " << dev_ctx;
}
}
AddResetCallbackIfCapturingCUDAGraph([pool_id] {
memory::allocation::AllocatorFacade::Instance().RemoveMemoryPoolOfCUDAGraph(
pool_id);
......@@ -67,8 +129,41 @@ void BeginCUDAGraphCapture(phi::GPUPlace place,
}
std::unique_ptr<CUDAGraph> EndCUDAGraphCapture() {
phi::DeviceContext* mutable_dev_ctx;
auto place = CUDAGraph::CapturingPlace();
bool create_cuda_graph_stream = CUDAGraph::IsCUDAGraphStreamCreated();
if (create_cuda_graph_stream) {
// join all other streams back to origin cuda graph stream.
int64_t pool_id = CUDAGraph::CapturingPoolID();
mutable_dev_ctx =
phi::backends::gpu::CUDAGraphContextManager::Instance().Get(
pool_id, place, 0);
auto* cuda_graph_dev_ctx =
reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
auto all_capturing_dev_ctxs =
phi::backends::gpu::CUDAGraphContextManager::Instance()
.GetAllCapturingDeviceContexts();
for (auto iter = all_capturing_dev_ctxs.begin();
iter != all_capturing_dev_ctxs.end();
++iter) {
auto* capturing_dev_ctx = reinterpret_cast<phi::GPUContext*>(*iter);
std::shared_ptr<platform::DeviceEvent> capturing_event =
std::make_shared<platform::DeviceEvent>(
capturing_dev_ctx->GetPlace(),
platform::GenerateDeviceEventFlag());
capturing_event->Record(capturing_dev_ctx);
capturing_event->Wait(platform::kCUDA, cuda_graph_dev_ctx);
VLOG(4) << "CUDA Graph stream eventWait. cuda graph dev_ctx: "
<< cuda_graph_dev_ctx
<< " wait for capturing dev_ctx: " << capturing_dev_ctx;
capturing_dev_ctx->cudnn_workspace_handle().ResetWorkspace();
capturing_dev_ctx->SetCUDAGraphAllocator(nullptr);
}
phi::backends::gpu::CUDAGraphContextManager::Instance()
.ClearDeviceContextsRecords();
} else {
mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place);
}
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
dev_ctx->cudnn_workspace_handle().ResetWorkspace();
dev_ctx->SetCUDAGraphAllocator(nullptr);
......
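
Reviewer note, not part of the diff: a hedged sketch of how a caller is expected to drive this Begin/End pair. The function signatures come from this file; Replay() and kInvalidPoolID are assumed from the existing CUDAGraph interface.

void SketchCaptureLifecycle(phi::GPUPlace place) {
  // Run the warm-up batch before this point; PrepareForCUDAGraphCapture
  // enforces that capture has not already started during the first batch.
  paddle::platform::BeginCUDAGraphCapture(
      place, cudaStreamCaptureModeGlobal,
      paddle::platform::CUDAGraph::kInvalidPoolID);
  // ... execute the program once; each op's device context was recorded,
  // so BeginCUDAGraphCapture forked all of their streams for capture ...
  auto graph = paddle::platform::EndCUDAGraphCapture();
  graph->Replay();  // later batches replay the captured work
}
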
......@@ -16,14 +16,18 @@
#include <atomic>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <set>
#include <thread>
#include <vector>
#include "cuda.h" // NOLINT
#include "cuda_runtime.h" // NOLINT
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
......@@ -34,6 +38,51 @@ namespace phi {
namespace backends {
namespace gpu {
class CUDAGraphContextManager {
public:
using DeviceContextMap =
std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>;
static CUDAGraphContextManager &Instance() {
static CUDAGraphContextManager *cuda_graph_ctx_manager =
new CUDAGraphContextManager;
return *cuda_graph_ctx_manager;
}
DeviceContext *Get(int64_t pool_id, const Place &place, int stream_priority) {
std::lock_guard<std::mutex> lk(ctx_mtx_);
VLOG(6) << "Get cuda graph device context for " << place;
DeviceContextMap &ctxs = cuda_graph_ctx_pool_[pool_id];
if (ctxs.find(place) == ctxs.end()) {
EmplaceDeviceContexts(
&ctxs,
{place},
/*disable_setting_default_stream_for_allocator=*/true,
stream_priority);
}
return ctxs[place].get().get();
}
void RecordCapturingDeviceContext(DeviceContext *dev_ctx) {
capturing_ctxs_.insert(dev_ctx);
}
std::set<DeviceContext *> GetAllCapturingDeviceContexts() const {
return capturing_ctxs_;
}
void ClearDeviceContextsRecords() { capturing_ctxs_.clear(); }
private:
CUDAGraphContextManager() {}
DISABLE_COPY_AND_ASSIGN(CUDAGraphContextManager);
std::mutex ctx_mtx_;
std::unordered_map<int64_t, DeviceContextMap> cuda_graph_ctx_pool_;
std::set<DeviceContext *> capturing_ctxs_;
};
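
Reviewer note, not part of the diff: a hedged sketch of the manager's intended flow, using only the members declared above; the pool id and stream priority values are placeholders.

void SketchManagerFlow(phi::DeviceContext* op_dev_ctx,
                       const phi::Place& place) {
  auto& mgr = phi::backends::gpu::CUDAGraphContextManager::Instance();
  // During InterpreterCore::Convert(): remember every context an op uses.
  mgr.RecordCapturingDeviceContext(op_dev_ctx);
  // During BeginCUDAGraphCapture(): fetch a dedicated capture context.
  phi::DeviceContext* graph_ctx =
      mgr.Get(/*pool_id=*/1, place, /*stream_priority=*/0);
  (void)graph_ctx;
  // Fork/join every recorded context against graph_ctx with events, then:
  mgr.ClearDeviceContextsRecords();  // after EndCUDAGraphCapture()
}
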
class CUDAKernelParams {
public:
explicit CUDAKernelParams(const cudaKernelNodeParams *params)
......@@ -147,6 +196,14 @@ class CUDAGraph {
// supported during capturing CUDA Graph.
static bool IsValidCapturing();
static void SetIsCUDAGraphStreamCreated(bool create_cuda_graph_stream) {
capturing_graph_->is_cuda_graph_stream_created_ = create_cuda_graph_stream;
}
static bool IsCUDAGraphStreamCreated() {
return capturing_graph_->is_cuda_graph_stream_created_;
}
static bool IsThreadLocalCapturing() {
#if CUDA_VERSION >= 10010
return IsCapturing() &&
......@@ -197,6 +254,8 @@ class CUDAGraph {
bool is_first_run_{true};
bool is_cuda_graph_stream_created_{false};
static paddle::optional<std::thread::id> capturing_thread_id_;
static std::unique_ptr<CUDAGraph> capturing_graph_;
};
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
import paddle
from paddle.device.cuda.graphs import CUDAGraph
sys.path.append("..")
from test_cuda_graph_static_mode import build_program
paddle.enable_static()
def can_use_cuda_graph():
return paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm()
class TestCustomStream(unittest.TestCase):
def setUp(self):
self.steps = 10
if can_use_cuda_graph():
paddle.set_flags(
{
'FLAGS_allocator_strategy': 'auto_growth',
'FLAGS_sync_nccl_allreduce': False,
'FLAGS_cudnn_deterministic': True,
'FLAGS_use_stream_safe_cuda_allocator': True,
'FLAGS_new_executor_use_cuda_graph': True,
}
)
def set_custom_stream(self, prog):
op_index_for_stream1 = [2, 4, 9]
op_index_for_stream2 = [7, 8, 10, 11]
ops = prog.global_block().ops
for op_index in op_index_for_stream1:
ops[op_index].dist_attr.execution_stream = "s1"
ops[op_index].dist_attr.stream_priority = 0
for op_index in op_index_for_stream2:
ops[op_index].dist_attr.execution_stream = "s2"
ops[op_index].dist_attr.stream_priority = -1
def run_program(self, use_cuda_graph=False, apply_custom_stream=False):
seed = 100
batch_size = 1
class_num = 10
image_shape = [batch_size, 784]
label_shape = [batch_size, 1]
paddle.seed(seed)
np.random.seed(seed)
startup = paddle.static.Program()
main = paddle.static.Program()
image, label, loss, lr = build_program(
main, startup, batch_size, class_num
)
if apply_custom_stream:
self.set_custom_stream(main)
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
exe.run(startup)
image_t = scope.var(image.name).get_tensor()
label_t = scope.var(label.name).get_tensor()
loss_t = scope.var(loss.name).get_tensor()
lr_var = main.global_block().var(lr._var_name)
self.assertTrue(lr_var.persistable)
lr_t = scope.var(lr_var.name).get_tensor()
cuda_graph = None
outs = []
for batch_id in range(20):
image_np = np.random.rand(*image_shape).astype('float32')
label_np = np.random.randint(
low=0, high=class_num, size=label_shape, dtype='int64'
)
image_t.set(image_np, place)
label_t.set(label_np, place)
if batch_id == 1 and use_cuda_graph:
cuda_graph = CUDAGraph(place, mode="global")
cuda_graph.capture_begin()
exe.run(main)
cuda_graph.capture_end()
if cuda_graph:
lr_t.set(np.array([lr()], dtype='float32'), place)
cuda_graph.replay()
else:
exe.run(main)
outs.append(np.array(loss_t))
lr.step()
if cuda_graph:
cuda_graph.reset()
return outs
def test_result(self):
if not can_use_cuda_graph():
return
outs = []
for use_cuda_graph in [False, True]:
for apply_custom_stream in [False, True]:
out = self.run_program(use_cuda_graph, apply_custom_stream)
outs.append(out)
for out in outs:
for baseline, result in zip(outs[0], out):
self.assertEqual(baseline[0], result[0])
if __name__ == "__main__":
unittest.main()
......@@ -50,7 +50,6 @@ def build_program(main, startup, batch_size, class_num):
class TestCUDAGraphInStaticMode(unittest.TestCase):
def setUp(self):
self.init_data()
if can_use_cuda_graph():
# The behavior of `FLAGS_use_stream_safe_cuda_allocator` in static
# mode is inconsistent with that in dygraph mode.
......@@ -69,9 +68,6 @@ class TestCUDAGraphInStaticMode(unittest.TestCase):
}
)
def init_data(self):
self.use_feed_data = False
@switch_to_static_graph
def test_cuda_graph_static_graph(self):
if not can_use_cuda_graph():
......@@ -121,16 +117,12 @@ class TestCUDAGraphInStaticMode(unittest.TestCase):
lr_t = scope.var(lr_var.name).get_tensor()
cuda_graph = None
for batch_id in range(20):
use_feed_data = (
True if batch_id == 0 and self.use_feed_data else False
)
image_np = np.random.rand(*image_shape).astype('float32')
label_np = np.random.randint(
low=0, high=class_num, size=label_shape, dtype='int64'
)
image_t.set(image_np, place)
label_t.set(label_np, place)
if batch_id == 1 and use_cuda_graph:
cuda_graph = CUDAGraph(place, mode="global")
......@@ -142,27 +134,12 @@ class TestCUDAGraphInStaticMode(unittest.TestCase):
lr_t.set(np.array([lr()], dtype='float32'), place)
cuda_graph.replay()
else:
if use_feed_data:
exe.run(
compiled_program,
feed={'image': image_np, 'label': label_np},
)
else:
exe.run(compiled_program)
exe.run(compiled_program)
lr.step()
if cuda_graph:
cuda_graph.reset()
return np.array(loss_t)
class TestCUDAGraphWhenFeedDataChanges(TestCUDAGraphInStaticMode):
def init_data(self):
# When the feed/fetch vars of the new executor change, a new
# StandaloneExecutor will be created, and the behavior of
# capturing the cuda graph will change. Add a test for this case.
self.use_feed_data = True
if __name__ == "__main__":
unittest.main()