From 2a143f842a5077411bab2c081c7d92515844655e Mon Sep 17 00:00:00 2001
From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com>
Date: Tue, 9 Nov 2021 10:50:06 +0800
Subject: [PATCH] Try to fix CUDA Graph H2D copy bug (#36987)

* try to fix CUDA Graph H2D copy bug

* remove useless code

* fix ci

* fix ROCM CI

* fix CUDA_VERSION

* improve CI coverage
---
 .../fluid/operators/math/concat_and_split.cu |  45 ++++---
 paddle/fluid/platform/cuda_graph.cc          | 118 ++++++++++++++----
 paddle/fluid/platform/cuda_graph.h           |  26 +++-
 .../platform/cuda_graph_with_memory_pool.h   |  25 ++++
 paddle/fluid/pybind/pybind.cc                |   3 +-
 python/paddle/device/cuda/graphs.py          |  17 ++-
 .../fluid/tests/unittests/test_cuda_graph.py |  55 +++++++-
 7 files changed, 235 insertions(+), 54 deletions(-)

diff --git a/paddle/fluid/operators/math/concat_and_split.cu b/paddle/fluid/operators/math/concat_and_split.cu
index b9481f1c8e4..614ae93d9fa 100644
--- a/paddle/fluid/operators/math/concat_and_split.cu
+++ b/paddle/fluid/operators/math/concat_and_split.cu
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/math/concat_and_split.h"
+#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 #include "paddle/fluid/platform/float16.h"
 
@@ -286,10 +287,13 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     const T** dev_ins_data = nullptr;
     if (!has_same_shape || in_num < 2 || in_num > 4) {
       tmp_dev_ins_data = memory::Alloc(context, in_num * sizeof(T*));
-      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
-                   tmp_dev_ins_data->ptr(), platform::CPUPlace(),
-                   static_cast<void*>(inputs_data), in_num * sizeof(T*),
-                   context.stream());
+      {
+        platform::SkipCUDAGraphCaptureGuard guard;
+        memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
+                     tmp_dev_ins_data->ptr(), platform::CPUPlace(),
+                     static_cast<void*>(inputs_data), in_num * sizeof(T*),
+                     context.stream());
+      }
       dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
     }
 
@@ -313,10 +317,13 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     } else {
       auto tmp_dev_ins_col_data =
           memory::Alloc(context, inputs_col_num * sizeof(int64_t));
-      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
-                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
-                   static_cast<void*>(inputs_col),
-                   inputs_col_num * sizeof(int64_t), context.stream());
+      {
+        platform::SkipCUDAGraphCaptureGuard guard;
+        memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
+                     tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
+                     static_cast<void*>(inputs_col),
+                     inputs_col_num * sizeof(int64_t), context.stream());
+      }
       int64_t* dev_ins_col_data =
           static_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
 
@@ -415,10 +422,13 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
     T** dev_out_gpu_data = nullptr;
     if (!has_same_shape || o_num < 2 || o_num > 4) {
       tmp_dev_outs_data = memory::Alloc(context, o_num * sizeof(T*));
-      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
-                   tmp_dev_outs_data->ptr(), platform::CPUPlace(),
-                   reinterpret_cast<void*>(outputs_data), o_num * sizeof(T*),
-                   context.stream());
+      {
+        platform::SkipCUDAGraphCaptureGuard guard;
+        memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
+                     tmp_dev_outs_data->ptr(), platform::CPUPlace(),
+                     reinterpret_cast<void*>(outputs_data), o_num * sizeof(T*),
+                     context.stream());
+      }
       dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
     }
 
@@ -442,10 +452,13 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
     } else {
       auto tmp_dev_ins_col_data =
           memory::Alloc(context, outputs_cols_num * sizeof(int64_t));
-      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
-                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
-                   reinterpret_cast<void*>(outputs_cols),
-                   outputs_cols_num * sizeof(int64_t), context.stream());
+      {
+        platform::SkipCUDAGraphCaptureGuard guard;
+        memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
+                     tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
+                     reinterpret_cast<void*>(outputs_cols),
+                     outputs_cols_num * sizeof(int64_t), context.stream());
+      }
       int64_t* dev_outs_col_data =
           reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
 
diff --git a/paddle/fluid/platform/cuda_graph.cc b/paddle/fluid/platform/cuda_graph.cc
index 693a5927990..6f3d452ef5c 100644
--- a/paddle/fluid/platform/cuda_graph.cc
+++ b/paddle/fluid/platform/cuda_graph.cc
@@ -22,14 +22,14 @@ std::unique_ptr<CUDAGraph> CUDAGraph::capturing_graph_{nullptr};
 void CUDAGraph::Reset() {
   if (is_reset_) return;
 #if CUDA_VERSION >= 10010
-  if (graph_) {
-    PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphDestroy(graph_));
-    graph_ = nullptr;
+  for (auto graph : graphs_) {
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphDestroy(graph));
   }
-  if (exec_graph_) {
-    PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphExecDestroy(exec_graph_));
-    exec_graph_ = nullptr;
+  graphs_.clear();
+  for (auto exec_graph : exec_graphs_) {
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphExecDestroy(exec_graph));
   }
+  exec_graphs_.clear();
 #endif
   // callback should be called in reverse order because the latter added
   // callback may rely on the former added callback.
@@ -45,16 +45,33 @@ void CUDAGraph::Replay() {
   PADDLE_ENFORCE_EQ(is_reset_, false,
                     errors::PermissionDenied(
                         "Cannot replay the CUDA Graph after reset is called."));
-  PADDLE_ENFORCE_NOT_NULL(exec_graph_,
-                          errors::PermissionDenied(
-                              "CUDA Graph must be captured before replaying."));
-  PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphLaunch(exec_graph_, stream_));
+  for (auto exec_graph : exec_graphs_) {
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphLaunch(exec_graph, stream_));
+  }
+#endif
+}
+
+void CUDAGraph::BeginSegmentCapture() {
+  ThrowErrorIfNotSupportCUDAGraph();
+#if CUDA_VERSION >= 10010
+  PADDLE_ENFORCE_EQ(
+      IsCapturing(), true,
+      errors::PermissionDenied("BeginSegmentCapture should be called when CUDA "
+                               "Graph is capturing."));
+  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamBeginCapture(
+      capturing_graph_->stream_, capturing_graph_->capture_mode_));
+  PADDLE_ENFORCE_EQ(IsValidCapturing(), true,
+                    platform::errors::PermissionDenied(
+                        "CUDA Graph should not be invalidated."));
+  VLOG(10) << "Begin to capture CUDA Graph with ID " << capturing_graph_->id_
+           << ", segment id " << capturing_graph_->graphs_.size();
 #endif
 }
 
 void CUDAGraph::BeginCapture(platform::CUDAPlace place, cudaStream_t stream,
                              cudaStreamCaptureMode mode) {
   ThrowErrorIfNotSupportCUDAGraph();
+#if CUDA_VERSION >= 10010
   PADDLE_ENFORCE_EQ(
       IsCapturing(), false,
       errors::PermissionDenied("CUDA Graph can only captured one by one."));
@@ -64,40 +81,87 @@ void CUDAGraph::BeginCapture(platform::CUDAPlace place, cudaStream_t stream,
   capturing_graph_.reset(new CUDAGraph());
   capturing_graph_->place_ = place;
   capturing_graph_->stream_ = stream;
-
-  PADDLE_ENFORCE_CUDA_SUCCESS(
-      cudaStreamBeginCapture(capturing_graph_->stream_, mode));
-  cudaStreamCaptureStatus status;
-  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamGetCaptureInfo(
-      capturing_graph_->stream_, &status, &(capturing_graph_->id_)));
-  PADDLE_ENFORCE_EQ(IsValidCapturing(), true,
-                    platform::errors::PermissionDenied(
-                        "CUDA Graph should not be invalidated."));
-  VLOG(10) << "Begin to capture CUDA Graph with ID " << capturing_graph_->id_;
+  capturing_graph_->capture_mode_ = mode;
+  BeginSegmentCapture();
+#endif
 }
 
-std::unique_ptr<CUDAGraph> CUDAGraph::EndCapture() {
+void CUDAGraph::EndSegmentCapture() {
   ThrowErrorIfNotSupportCUDAGraph();
 #if CUDA_VERSION >= 10010
   PADDLE_ENFORCE_EQ(IsCapturing(), true,
                     errors::PermissionDenied("No CUDA Graph is capturing."));
-  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamEndCapture(
-      capturing_graph_->stream_, &(capturing_graph_->graph_)));
+  cudaGraph_t graph;
   PADDLE_ENFORCE_CUDA_SUCCESS(
-      cudaGraphInstantiate(&(capturing_graph_->exec_graph_),
-                           capturing_graph_->graph_, nullptr, nullptr, 0));
-  VLOG(10) << "End to capture CUDA Graph with ID " << capturing_graph_->id_;
-  return std::move(capturing_graph_);
+      cudaStreamEndCapture(capturing_graph_->stream_, &graph));
+  auto num_nodes = static_cast<size_t>(-1);
+  PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphGetNodes(graph, nullptr, &num_nodes));
+  if (num_nodes == 0) {
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphDestroy(graph));
+    VLOG(10) << "Skip empty CUDA Graph with ID " << capturing_graph_->id_
+             << ", segment id " << capturing_graph_->graphs_.size();
+    return;
+  }
+
+  cudaGraphExec_t exec_graph;
+  PADDLE_ENFORCE_CUDA_SUCCESS(
+      cudaGraphInstantiate(&exec_graph, graph, nullptr, nullptr, 0));
+  VLOG(10) << "End to capture CUDA Graph with ID " << capturing_graph_->id_
+           << ", segment id " << capturing_graph_->graphs_.size();
+  capturing_graph_->graphs_.emplace_back(graph);
+  capturing_graph_->exec_graphs_.emplace_back(exec_graph);
 #endif
 }
 
+std::unique_ptr<CUDAGraph> CUDAGraph::EndCapture() {
+  EndSegmentCapture();
+  return std::move(capturing_graph_);
+}
+
 bool CUDAGraph::IsValidCapturing() {
+#if CUDA_VERSION >= 10010
   if (!IsCapturing()) return false;
   cudaStreamCaptureStatus status;
   CUDAGraphID id;
   PADDLE_ENFORCE_CUDA_SUCCESS(
       cudaStreamGetCaptureInfo(capturing_graph_->stream_, &status, &id));
   return status == cudaStreamCaptureStatusActive;
+#else
+  return false;
+#endif
+}
+
+static std::string ConcatPath(const std::string &dirname,
+                              const std::string &filename) {
+#ifdef _WIN32
+  const char kFileSep[] = "\\";
+#else
+  const char kFileSep[] = "/";
+#endif
+  if (!dirname.empty() && dirname.back() == kFileSep[0]) {
+    return dirname + filename;
+  } else {
+    return dirname + kFileSep + filename;
+  }
+}
+
+void CUDAGraph::PrintToDotFiles(const std::string &dirname,
+                                unsigned int flags) {
+  ThrowErrorIfNotSupportCUDAGraph();
+#if CUDA_VERSION >= 11030
+  for (size_t i = 0; i < graphs_.size(); ++i) {
+    auto filename =
+        ConcatPath(dirname, "segment_" + std::to_string(i) + ".dot");
+    VLOG(10) << "Save the " << i << "-th segment of graph " << id_ << " to "
+             << filename;
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        cudaGraphDebugDotPrint(graphs_[i], filename.c_str(), flags));
+  }
+#else
+  PADDLE_THROW(platform::errors::Unimplemented(
+      "The print_to_dot_files() method is only supported when CUDA version >= "
+      "11.3."));
+#endif
 }
 
 }  // namespace platform
diff --git a/paddle/fluid/platform/cuda_graph.h b/paddle/fluid/platform/cuda_graph.h
index 55ec463556b..f70a66f7624 100644
--- a/paddle/fluid/platform/cuda_graph.h
+++ b/paddle/fluid/platform/cuda_graph.h
@@ -14,9 +14,11 @@
 
 #pragma once
 
+#include <atomic>
 #include <functional>
 #include <memory>
 #include <mutex>
+#include <vector>
 #include "cuda.h"          // NOLINT
 #include "cuda_runtime.h"  // NOLINT
 #include "paddle/fluid/platform/type_defs.h"
@@ -51,7 +53,10 @@ class CUDAGraph {
   // Since the constructor would throw error is CUDA_VERSION < 10010.
   // The non-static method of CUDAGraph need not check CUDA_VERSION
   // again.
-  CUDAGraph() { ThrowErrorIfNotSupportCUDAGraph(); }
+  CUDAGraph() {
+    ThrowErrorIfNotSupportCUDAGraph();
+    id_ = UniqueID();
+  }
 
  public:
   ~CUDAGraph() { Reset(); }
@@ -67,9 +72,15 @@ class CUDAGraph {
     callbacks_.push_back(std::move(callback));
   }
 
+  void PrintToDotFiles(const std::string &dirname, unsigned int flags);
+
   static void BeginCapture(platform::CUDAPlace place, cudaStream_t stream,
                            cudaStreamCaptureMode mode);
   static std::unique_ptr<CUDAGraph> EndCapture();
+
+  static void BeginSegmentCapture();
+  static void EndSegmentCapture();
+
   static void AddResetCallbackDuringCapturing(std::function<void()> callback) {
     capturing_graph_->AddResetCallback(std::move(callback));
   }
@@ -88,14 +99,21 @@ class CUDAGraph {
   // supported during capturing CUDA Graph.
   static bool IsValidCapturing();
 
+ private:
+  static CUDAGraphID UniqueID() {
+    static std::atomic<CUDAGraphID> id;
+    return id.fetch_add(1);
+  }
+
  private:
 #if CUDA_VERSION >= 10010
-  cudaGraph_t graph_{nullptr};
-  cudaGraphExec_t exec_graph_{nullptr};
+  std::vector<cudaGraph_t> graphs_;
+  std::vector<cudaGraphExec_t> exec_graphs_;
+  cudaStreamCaptureMode capture_mode_;
 #endif
   cudaStream_t stream_{nullptr};
   platform::CUDAPlace place_;
-  CUDAGraphID id_{0};
+  CUDAGraphID id_;
   std::vector<std::function<void()>> callbacks_;
   bool is_reset_{false};
   std::mutex mtx_;
diff --git a/paddle/fluid/platform/cuda_graph_with_memory_pool.h b/paddle/fluid/platform/cuda_graph_with_memory_pool.h
index f9f0248e515..6586146c5ae 100644
--- a/paddle/fluid/platform/cuda_graph_with_memory_pool.h
+++ b/paddle/fluid/platform/cuda_graph_with_memory_pool.h
@@ -60,5 +60,30 @@ inline void AddResetCallbackIfCapturingCUDAGraph(Callback &&callback) {
   callback();
 }
 
+class SkipCUDAGraphCaptureGuard {
+  DISABLE_COPY_AND_ASSIGN(SkipCUDAGraphCaptureGuard);
+
+ public:
+  SkipCUDAGraphCaptureGuard() {
+#ifdef PADDLE_WITH_CUDA
+#if CUDA_VERSION >= 10010
+    if (UNLIKELY(CUDAGraph::IsCapturing())) {
+      CUDAGraph::EndSegmentCapture();
+    }
+#endif
+#endif
+  }
+
+  ~SkipCUDAGraphCaptureGuard() {
+#ifdef PADDLE_WITH_CUDA
+#if CUDA_VERSION >= 10010
+    if (UNLIKELY(CUDAGraph::IsCapturing())) {
+      CUDAGraph::BeginSegmentCapture();
+    }
+#endif
+#endif
+  }
+};
+
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index c3b25671468..7a0930ddde0 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -562,7 +562,8 @@ PYBIND11_MODULE(core_noavx, m) {
       })
       .def_static("end_capture", &platform::EndCUDAGraphCapture)
       .def("replay", &platform::CUDAGraph::Replay)
-      .def("reset", &platform::CUDAGraph::Reset);
+      .def("reset", &platform::CUDAGraph::Reset)
+      .def("print_to_dot_files", &platform::CUDAGraph::PrintToDotFiles);
 #endif
 
   m.def("wait_device", [](const platform::Place &place) {
diff --git a/python/paddle/device/cuda/graphs.py b/python/paddle/device/cuda/graphs.py
index 612f4d2c8ce..2a60aad2fd2 100644
--- a/python/paddle/device/cuda/graphs.py
+++ b/python/paddle/device/cuda/graphs.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDAPlace
 
 if is_compiled_with_cuda() and not is_compiled_with_rocm():
@@ -22,7 +23,8 @@ if is_compiled_with_cuda() and not is_compiled_with_rocm():
         ALL_MODES = ["global", "thread_local", "relaxed"]
         self._graph = None
         if place is None:
-            place = CUDAPlace(0)
+            device_id = int(os.environ.get('FLAGS_selected_gpus', 0))
+            place = CUDAPlace(device_id)
         self._place = place
         assert mode in ALL_MODES
         self._mode = ALL_MODES.index(mode)
@@ -38,6 +40,16 @@ if is_compiled_with_cuda() and not is_compiled_with_rocm():
 
         def reset(self):
             self._graph.reset()
+
+        def print_to_dot_files(self, dirname, flags=None):
+            if not isinstance(dirname, (str, bytes)):
+                dirname = dirname.name
+            os.makedirs(name=dirname, exist_ok=True)
+            assert os.path.isdir(
+                dirname), "The dirname {} should be a directory".format(dirname)
+            if flags is None:
+                flags = 2047  # only all information. It can be any integer inside [1, 2048)
+            self._graph.print_to_dot_files(dirname, flags)
 else:
 
     class CUDAGraph:
@@ -55,3 +67,6 @@ else:
 
         def reset(self):
             raise NotImplementedError()
+
+        def print_to_dot_files(self, dirname, flags=None):
+            raise NotImplementedError()
diff --git a/python/paddle/fluid/tests/unittests/test_cuda_graph.py b/python/paddle/fluid/tests/unittests/test_cuda_graph.py
index 7d131747353..8b4eae8ada4 100644
--- a/python/paddle/fluid/tests/unittests/test_cuda_graph.py
+++ b/python/paddle/fluid/tests/unittests/test_cuda_graph.py
@@ -17,15 +17,21 @@ import paddle.fluid as fluid
 from paddle.device.cuda.graphs import CUDAGraph
 import unittest
 import numpy as np
+import os
+import pathlib
+import shutil
 from paddle.fluid.dygraph.base import switch_to_static_graph
 from simple_nets import simple_fc_net_with_inputs
 
 
+def can_use_cuda_graph():
+    return paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm()
+
+
 class TestCUDAGraph(unittest.TestCase):
     def setUp(self):
-        if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm(
-        ):
-            fluid.set_flags({
+        if can_use_cuda_graph():
+            paddle.set_flags({
                 'FLAGS_allocator_strategy': 'auto_growth',
                 'FLAGS_sync_nccl_allreduce': False,
                 'FLAGS_cudnn_deterministic': True
@@ -38,7 +44,7 @@ class TestCUDAGraph(unittest.TestCase):
 
     @switch_to_static_graph
     def test_cuda_graph_static_graph(self):
-        if not paddle.is_compiled_with_cuda() or paddle.is_compiled_with_rocm():
+        if not can_use_cuda_graph():
             return
 
         seed = 100
@@ -116,7 +122,7 @@ class TestCUDAGraph(unittest.TestCase):
         return np.array(loss_t)
 
     def test_cuda_graph_dynamic_graph(self):
-        if not paddle.is_compiled_with_cuda() or paddle.is_compiled_with_rocm():
+        if not can_use_cuda_graph():
             return
 
         shape = [2, 3]
@@ -142,6 +148,45 @@ class TestCUDAGraph(unittest.TestCase):
 
         g.reset()
 
+    def test_concat_and_split(self):
+        if not can_use_cuda_graph():
+            return
+
+        concat_num = 100
+        xs = []
+        xs_np = []
+
+        for i in range(concat_num):
+            x_np = np.random.random(size=[1]).astype(np.float32)
+            xs.append(paddle.to_tensor(x_np))
+            xs_np.append(x_np)
+
+        graph = CUDAGraph()
+        graph.capture_begin()
+        y = paddle.concat(xs)
+        zs = paddle.split(y, len(xs))
+        graph.capture_end()
+        graph.replay()
+
+        y_np = y.numpy()
+        y_np_expected = np.concatenate(xs_np)
+        self.assertTrue(np.array_equal(y_np, y_np_expected))
+        self.assertEqual(len(zs), len(xs_np))
+        for i, z in enumerate(zs):
+            self.assertTrue(np.array_equal(z.numpy(), xs_np[i]))
+
+        output_dir = 'cuda_graph_dot_{}'.format(os.getpid())
+        try:
+            graph.print_to_dot_files(pathlib.Path(output_dir))
+            graph.reset()
+            shutil.rmtree(output_dir)
+        except Exception as e:
+            msg = str(e)
+            sub_msg = "The print_to_dot_files() method is only supported when CUDA version >= 11.3"
+            self.assertTrue(sub_msg in msg)
+        finally:
+            graph.reset()
+
 
 if __name__ == "__main__":
     unittest.main()
--
GitLab
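
Editor's note (not part of the patch): the unit test above already exercises the new path end to end. The sketch below is only a hedged usage illustration of the Python API this change touches (CUDAGraph.capture_begin/capture_end, replay, reset, and the new print_to_dot_files); it assumes a CUDA, non-ROCm build of Paddle that includes this commit, and print_to_dot_files additionally requires CUDA 11.3 or newer, so it is wrapped in a try/except just as the test does.

    # Illustrative sketch only; not part of the committed patch.
    import os
    import shutil

    import numpy as np
    import paddle
    from paddle.device.cuda.graphs import CUDAGraph

    # Small tensors whose concat/split trigger host-to-device metadata copies.
    xs = [paddle.to_tensor(np.random.random([1]).astype(np.float32)) for _ in range(10)]

    graph = CUDAGraph()
    graph.capture_begin()
    # With this patch, the H2D copies inside concat/split run outside the
    # captured segments instead of baking stale host pointers into the graph.
    y = paddle.concat(xs)
    zs = paddle.split(y, len(xs))
    graph.capture_end()

    graph.replay()  # launches every captured segment in order

    output_dir = "cuda_graph_dot_{}".format(os.getpid())
    try:
        # Writes one segment_<i>.dot file per captured segment; needs CUDA >= 11.3.
        graph.print_to_dot_files(output_dir)
        shutil.rmtree(output_dir)
    except Exception:
        pass  # older CUDA raises an Unimplemented error, as the unit test checks
    finally:
        graph.reset()

The mechanism behind this is SkipCUDAGraphCaptureGuard: it ends the current capture segment before the H2D copy and begins a new one afterwards, so one logical CUDAGraph now holds a list of cudaGraphExec_t segments that Replay() launches in order, and empty segments are dropped in EndSegmentCapture().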