From 579fb5fde1b3545615731689bceddd8d7c5003e4 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Tue, 14 Mar 2023 15:34:04 +0800 Subject: [PATCH] cuda graph support multi-stream for new executor (#51389) * cuda graph support multi-stream for new executor * fix windows compile error * delete create_cuda_graph_stream --- .../framework/new_executor/interpretercore.cc | 19 +++ paddle/fluid/platform/CMakeLists.txt | 21 ++- .../platform/cuda_graph_with_memory_pool.cc | 111 +++++++++++++-- paddle/phi/backends/gpu/cuda/cuda_graph.h | 59 ++++++++ ...test_standalone_cuda_graph_multi_stream.py | 131 ++++++++++++++++++ .../unittests/test_cuda_graph_static_mode.py | 29 +--- 6 files changed, 324 insertions(+), 46 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_cuda_graph_multi_stream.py diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index bdad36b3e9d..3302dbe79b9 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -541,6 +541,11 @@ void InterpreterCore::BuildInplace() { void InterpreterCore::PrepareForCUDAGraphCapture() { if (!FLAGS_new_executor_use_cuda_graph) return; #ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_EQ( + platform::IsCUDAGraphCapturing(), + false, + platform::errors::PermissionDenied("CUDA Graph is not allowed to capture " + "when running the first batch.")); PADDLE_ENFORCE_EQ(platform::is_gpu_place(place_), true, platform::errors::InvalidArgument( @@ -672,6 +677,20 @@ void InterpreterCore::Convert( auto& op_func_node = nodes[op_idx]; auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node); vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_); +#ifdef PADDLE_WITH_CUDA + if (FLAGS_new_executor_use_cuda_graph) { + auto& op = op_func_node.operator_base_; + auto& op_type = op->Type(); + if (op_type == interpreter::kMemcpyD2H || + op_type == interpreter::kMemcpyH2D) { + PADDLE_THROW(paddle::platform::errors::Fatal( + "op_type can't be memcpy d2h or h2d while using cuda graph.")); + } + // cuda graph needs to record all stream + phi::backends::gpu::CUDAGraphContextManager::Instance() + .RecordCapturingDeviceContext(dev_ctx_); + } +#endif } BuildOperatorDependences(); diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index a5b924b40ac..312d9a84e03 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -32,18 +32,6 @@ cc_test( SRCS os_info_test.cc DEPS phi_os_info) -if(WITH_GPU) - nv_library( - cuda_graph_with_memory_pool - SRCS cuda_graph_with_memory_pool.cc - DEPS device_context allocator phi_backends) -else() - cc_library( - cuda_graph_with_memory_pool - SRCS cuda_graph_with_memory_pool.cc - DEPS device_context allocator) -endif() - cc_library( place SRCS place.cc @@ -239,6 +227,10 @@ if(WITH_GPU) SRCS device_event_test.cc DEPS device_event_gpu) endif() + nv_library( + cuda_graph_with_memory_pool + SRCS cuda_graph_with_memory_pool.cc + DEPS ${DEVICE_EVENT_LIBS} device_context allocator phi_backends) nv_test( device_context_test SRCS device_context_test.cu @@ -247,6 +239,11 @@ if(WITH_GPU) device_context_test_cuda_graph SRCS device_context_test_cuda_graph.cu DEPS device_context gpu_info cuda_graph_with_memory_pool) +else() + cc_library( + cuda_graph_with_memory_pool + SRCS cuda_graph_with_memory_pool.cc + DEPS device_context allocator) endif() if(WITH_ROCM) diff --git a/paddle/fluid/platform/cuda_graph_with_memory_pool.cc b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc index c965045623a..a3df4301ad4 100644 --- a/paddle/fluid/platform/cuda_graph_with_memory_pool.cc +++ b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/platform/cuda_graph_with_memory_pool.h" #include "paddle/fluid/memory/allocation/allocator_facade.h" +#include "paddle/fluid/platform/device_event.h" #include "paddle/phi/backends/context_pool.h" DECLARE_bool(use_stream_safe_cuda_allocator); @@ -24,25 +25,60 @@ namespace paddle { namespace platform { #ifdef PADDLE_WITH_CUDA -void BeginCUDAGraphCapture(phi::GPUPlace place, - cudaStreamCaptureMode mode, - int64_t pool_id) { - auto* mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place); - auto* dev_ctx = reinterpret_cast(mutable_dev_ctx); +void InitCUDNNRelatedHandle(phi::GPUContext* dev_ctx) { dev_ctx->cudnn_workspace_handle().ResetWorkspace(); // After PR(#43206), cudnn related initializations will change to lazy mode. - // It will only be initialized when op calls them. But cuda graph not support - // capture such kind of init, need to init all these handle before cuda graph. + // It will only be initialized when op calls them. But cuda graph not + // support capture such kind of init, need to init all these handle before + // cuda graph. dev_ctx->cublas_handle(); #if CUDA_VERSION >= 11060 dev_ctx->cublaslt_handle(); #endif dev_ctx->cudnn_handle(); dev_ctx->cusolver_dn_handle(); +} + +void BeginCUDAGraphCapture(phi::GPUPlace place, + cudaStreamCaptureMode mode, + int64_t pool_id) { + auto* mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place); + auto all_capturing_dev_ctxs = + phi::backends::gpu::CUDAGraphContextManager::Instance() + .GetAllCapturingDeviceContexts(); + // create_cuda_graph_stream: Whether to create a new stream to + // capture cuda graph, usually used in multi-stream scenarios. + // Can only be used for new executor in static mode, that is, + // FLAGS_new_executor_use_cuda_graph needs to be set to True. + bool create_cuda_graph_stream = false; + if (FLAGS_new_executor_use_cuda_graph && + (all_capturing_dev_ctxs.size() > 1 || + (all_capturing_dev_ctxs.size() == 1 && + (*(all_capturing_dev_ctxs.begin()) != mutable_dev_ctx)))) { + create_cuda_graph_stream = true; + } + if (create_cuda_graph_stream) { + VLOG(4) << "create a new stream to capture cuda graph."; + if (pool_id <= CUDAGraph::kInvalidPoolID) { + pool_id = CUDAGraph::UniqueMemoryPoolID(); + } + mutable_dev_ctx = + phi::backends::gpu::CUDAGraphContextManager::Instance().Get( + pool_id, place, 0); + for (auto iter = all_capturing_dev_ctxs.begin(); + iter != all_capturing_dev_ctxs.end(); + ++iter) { + auto* capturing_dev_ctx = reinterpret_cast(*iter); + InitCUDNNRelatedHandle(capturing_dev_ctx); + } + } + auto* dev_ctx = reinterpret_cast(mutable_dev_ctx); + InitCUDNNRelatedHandle(dev_ctx); auto stream = dev_ctx->stream(); CUDAGraph::BeginCapture(place, stream, mode); + CUDAGraph::SetIsCUDAGraphStreamCreated(create_cuda_graph_stream); // When using cuda graph in new executor, fast GC must be used. // FLAGS_use_stream_safe_cuda_allocator should be true. @@ -60,6 +96,32 @@ void BeginCUDAGraphCapture(phi::GPUPlace place, if (old_value) { FLAGS_use_stream_safe_cuda_allocator = true; } + if (create_cuda_graph_stream) { + // Set cuda graph allocator for all streams. + // Establish dependencies between cuda graph stream and all other streams + // using eventWait, so that all streams will be captured. + std::shared_ptr cuda_graph_event = + std::make_shared( + dev_ctx->GetPlace(), platform::GenerateDeviceEventFlag()); + cuda_graph_event->Record(dev_ctx); + + for (auto iter = all_capturing_dev_ctxs.begin(); + iter != all_capturing_dev_ctxs.end(); + ++iter) { + auto* capturing_dev_ctx = reinterpret_cast(*iter); + auto capturing_stream = capturing_dev_ctx->stream(); + capturing_dev_ctx->SetCUDAGraphAllocator( + memory::allocation::AllocatorFacade::Instance() + .GetAllocator(place, capturing_stream) + .get()); + VLOG(4) << "set CUDAGraphAllocator for dev_ctx: " << capturing_dev_ctx + << " with stream: " << capturing_stream; + cuda_graph_event->Wait(platform::kCUDA, capturing_dev_ctx); + VLOG(4) << "CUDA Graph stream eventWait. Capturing dev_ctx: " + << capturing_dev_ctx + << " wait for cuda graph dev_ctx: " << dev_ctx; + } + } AddResetCallbackIfCapturingCUDAGraph([pool_id] { memory::allocation::AllocatorFacade::Instance().RemoveMemoryPoolOfCUDAGraph( pool_id); @@ -67,8 +129,41 @@ void BeginCUDAGraphCapture(phi::GPUPlace place, } std::unique_ptr EndCUDAGraphCapture() { + phi::DeviceContext* mutable_dev_ctx; auto place = CUDAGraph::CapturingPlace(); - auto* mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place); + bool create_cuda_graph_stream = CUDAGraph::IsCUDAGraphStreamCreated(); + if (create_cuda_graph_stream) { + // join all other streams back to origin cuda graph stream. + int64_t pool_id = CUDAGraph::CapturingPoolID(); + mutable_dev_ctx = + phi::backends::gpu::CUDAGraphContextManager::Instance().Get( + pool_id, place, 0); + auto* cuda_graph_dev_ctx = + reinterpret_cast(mutable_dev_ctx); + auto all_capturing_dev_ctxs = + phi::backends::gpu::CUDAGraphContextManager::Instance() + .GetAllCapturingDeviceContexts(); + for (auto iter = all_capturing_dev_ctxs.begin(); + iter != all_capturing_dev_ctxs.end(); + ++iter) { + auto* capturing_dev_ctx = reinterpret_cast(*iter); + std::shared_ptr capturing_event = + std::make_shared( + capturing_dev_ctx->GetPlace(), + platform::GenerateDeviceEventFlag()); + capturing_event->Record(capturing_dev_ctx); + capturing_event->Wait(platform::kCUDA, cuda_graph_dev_ctx); + VLOG(4) << "CUDA Graph stream eventWait. cuda graph dev_ctx: " + << cuda_graph_dev_ctx + << " wait for capturing dev_ctx: " << capturing_dev_ctx; + capturing_dev_ctx->cudnn_workspace_handle().ResetWorkspace(); + capturing_dev_ctx->SetCUDAGraphAllocator(nullptr); + } + phi::backends::gpu::CUDAGraphContextManager::Instance() + .ClearDeviceContextsRecords(); + } else { + mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place); + } auto* dev_ctx = reinterpret_cast(mutable_dev_ctx); dev_ctx->cudnn_workspace_handle().ResetWorkspace(); dev_ctx->SetCUDAGraphAllocator(nullptr); diff --git a/paddle/phi/backends/gpu/cuda/cuda_graph.h b/paddle/phi/backends/gpu/cuda/cuda_graph.h index 13054c347ef..a9cbe2537ad 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_graph.h +++ b/paddle/phi/backends/gpu/cuda/cuda_graph.h @@ -16,14 +16,18 @@ #include #include +#include #include #include +#include #include #include #include "cuda.h" // NOLINT #include "cuda_runtime.h" // NOLINT +#include "paddle/phi/backends/context_pool.h" +#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/errors.h" @@ -34,6 +38,51 @@ namespace phi { namespace backends { namespace gpu { +class CUDAGraphContextManager { + public: + using DeviceContextMap = + std::map>>; + + static CUDAGraphContextManager &Instance() { + static CUDAGraphContextManager *cuda_graph_ctx_manager = + new CUDAGraphContextManager; + return *cuda_graph_ctx_manager; + } + + DeviceContext *Get(int64_t pool_id, const Place &place, int stream_priority) { + std::lock_guard lk(ctx_mtx_); + VLOG(6) << "Get cuda graph device context for " << place; + + DeviceContextMap &ctxs = cuda_graph_ctx_pool_[pool_id]; + if (ctxs.find(place) == ctxs.end()) { + EmplaceDeviceContexts( + &ctxs, + {place}, + /*disable_setting_default_stream_for_allocator=*/true, + stream_priority); + } + return ctxs[place].get().get(); + } + + void RecordCapturingDeviceContext(DeviceContext *dev_ctx) { + capturing_ctxs_.insert(dev_ctx); + } + + std::set GetAllCapturingDeviceContexts() const { + return capturing_ctxs_; + } + + void ClearDeviceContextsRecords() { capturing_ctxs_.clear(); } + + private: + CUDAGraphContextManager() {} + DISABLE_COPY_AND_ASSIGN(CUDAGraphContextManager); + + std::mutex ctx_mtx_; + std::unordered_map cuda_graph_ctx_pool_; + std::set capturing_ctxs_; +}; + class CUDAKernelParams { public: explicit CUDAKernelParams(const cudaKernelNodeParams *params) @@ -147,6 +196,14 @@ class CUDAGraph { // supported during capturing CUDA Graph. static bool IsValidCapturing(); + static void SetIsCUDAGraphStreamCreated(bool create_cuda_graph_stream) { + capturing_graph_->is_cuda_graph_stream_created_ = create_cuda_graph_stream; + } + + static bool IsCUDAGraphStreamCreated() { + return capturing_graph_->is_cuda_graph_stream_created_; + } + static bool IsThreadLocalCapturing() { #if CUDA_VERSION >= 10010 return IsCapturing() && @@ -197,6 +254,8 @@ class CUDAGraph { bool is_first_run_{true}; + bool is_cuda_graph_stream_created_{false}; + static paddle::optional capturing_thread_id_; static std::unique_ptr capturing_graph_; }; diff --git a/python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_cuda_graph_multi_stream.py b/python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_cuda_graph_multi_stream.py new file mode 100644 index 00000000000..51881410756 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_cuda_graph_multi_stream.py @@ -0,0 +1,131 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import numpy as np + +import paddle +from paddle.device.cuda.graphs import CUDAGraph + +sys.path.append("..") +from test_cuda_graph_static_mode import build_program + +paddle.enable_static() + + +def can_use_cuda_graph(): + return paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm() + + +class TestCustomStream(unittest.TestCase): + def setUp(self): + self.steps = 10 + if can_use_cuda_graph(): + paddle.set_flags( + { + 'FLAGS_allocator_strategy': 'auto_growth', + 'FLAGS_sync_nccl_allreduce': False, + 'FLAGS_cudnn_deterministic': True, + 'FLAGS_use_stream_safe_cuda_allocator': True, + 'FLAGS_new_executor_use_cuda_graph': True, + } + ) + + def set_custom_stream(self, prog): + op_index_for_stream1 = [2, 4, 9] + op_index_for_stream2 = [7, 8, 10, 11] + ops = prog.global_block().ops + for op_index in op_index_for_stream1: + ops[op_index].dist_attr.execution_stream = "s1" + ops[op_index].dist_attr.stream_priority = 0 + for op_index in op_index_for_stream2: + ops[op_index].dist_attr.execution_stream = "s2" + ops[op_index].dist_attr.stream_priority = -1 + + def run_program(self, use_cuda_graph=False, apply_custom_stream=False): + seed = 100 + + batch_size = 1 + class_num = 10 + image_shape = [batch_size, 784] + label_shape = [batch_size, 1] + + paddle.seed(seed) + np.random.seed(seed) + startup = paddle.static.Program() + main = paddle.static.Program() + image, label, loss, lr = build_program( + main, startup, batch_size, class_num + ) + + if apply_custom_stream: + self.set_custom_stream(main) + + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + scope = paddle.static.Scope() + with paddle.static.scope_guard(scope): + exe.run(startup) + image_t = scope.var(image.name).get_tensor() + label_t = scope.var(label.name).get_tensor() + loss_t = scope.var(loss.name).get_tensor() + lr_var = main.global_block().var(lr._var_name) + self.assertTrue(lr_var.persistable) + lr_t = scope.var(lr_var.name).get_tensor() + cuda_graph = None + outs = [] + for batch_id in range(20): + image_np = np.random.rand(*image_shape).astype('float32') + label_np = np.random.randint( + low=0, high=class_num, size=label_shape, dtype='int64' + ) + image_t.set(image_np, place) + label_t.set(label_np, place) + + if batch_id == 1 and use_cuda_graph: + cuda_graph = CUDAGraph(place, mode="global") + cuda_graph.capture_begin() + exe.run(main) + cuda_graph.capture_end() + + if cuda_graph: + lr_t.set(np.array([lr()], dtype='float32'), place) + cuda_graph.replay() + else: + exe.run(main) + outs.append(np.array(loss_t)) + lr.step() + if cuda_graph: + cuda_graph.reset() + return outs + + def test_result(self): + if not can_use_cuda_graph(): + return + + outs = [] + for use_cuda_graph in [False, True]: + for apply_custom_stream in [False, True]: + out = self.run_program(use_cuda_graph, apply_custom_stream) + outs.append(out) + + for out in outs: + for baseline, result in zip(outs[0], out): + self.assertEqual(baseline[0], result[0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_cuda_graph_static_mode.py b/python/paddle/fluid/tests/unittests/test_cuda_graph_static_mode.py index cf3b60e490c..3f433e46911 100644 --- a/python/paddle/fluid/tests/unittests/test_cuda_graph_static_mode.py +++ b/python/paddle/fluid/tests/unittests/test_cuda_graph_static_mode.py @@ -50,7 +50,6 @@ def build_program(main, startup, batch_size, class_num): class TestCUDAGraphInStaticMode(unittest.TestCase): def setUp(self): - self.init_data() if can_use_cuda_graph(): # The behavior of `FLAGS_use_stream_safe_cuda_allocator` in static # mode is inconsistent with that in dygraph mode. @@ -69,9 +68,6 @@ class TestCUDAGraphInStaticMode(unittest.TestCase): } ) - def init_data(self): - self.use_feed_data = False - @switch_to_static_graph def test_cuda_graph_static_graph(self): if not can_use_cuda_graph(): @@ -121,16 +117,12 @@ class TestCUDAGraphInStaticMode(unittest.TestCase): lr_t = scope.var(lr_var.name).get_tensor() cuda_graph = None for batch_id in range(20): - use_feed_data = ( - True if batch_id == 0 and self.use_feed_data else False - ) image_np = np.random.rand(*image_shape).astype('float32') label_np = np.random.randint( low=0, high=class_num, size=label_shape, dtype='int64' ) - if not use_feed_data: - image_t.set(image_np, place) - label_t.set(label_np, place) + image_t.set(image_np, place) + label_t.set(label_np, place) if batch_id == 1 and use_cuda_graph: cuda_graph = CUDAGraph(place, mode="global") @@ -142,27 +134,12 @@ class TestCUDAGraphInStaticMode(unittest.TestCase): lr_t.set(np.array([lr()], dtype='float32'), place) cuda_graph.replay() else: - if use_feed_data: - exe.run( - compiled_program, - feed={'image': image_np, 'label': label_np}, - ) - else: - exe.run(compiled_program) + exe.run(compiled_program) lr.step() if cuda_graph: cuda_graph.reset() return np.array(loss_t) -class TestCUDAGraphWhenFeedDataChanges(TestCUDAGraphInStaticMode): - def init_data(self): - # When feed fetch var of new executor changes, a new - # StandaloneExecutor will be newly created. And the - # behavior of capturing cuda graph will change. - # Add test for this case. - self.use_feed_data = True - - if __name__ == "__main__": unittest.main() -- GitLab