Unverified commit 2a143f84 authored by Zeng Jinle, committed by GitHub

Try to fix CUDA Graph H2D copy bug (#36987)

* try to fix CUDA Graph H2D copy bug

* remove useless code

* fix ci

* fix ROCM CI

* fix CUDA_VERSION

* improve CI coverage
Parent 819b9589
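Background on the fix, restated here as a standalone sketch rather than the patch itself: a pageable host-to-device memcpy issued on a stream that is currently capturing a CUDA Graph is either rejected or recorded against host memory that may be freed or overwritten before replay. The change below therefore splits one logical capture into multiple segments: SkipCUDAGraphCaptureGuard ends the current segment, the copy runs eagerly on the stream, and a new segment begins; Replay() then launches every segment's executable graph in order. The sketch below shows the same scheme with the plain CUDA runtime API; the kernel AddOne, the buffer names, and the omission of error checking are assumptions for illustration only, and it uses the pre-CUDA-12 five-argument cudaGraphInstantiate signature that the patch also uses.

// Standalone sketch of segmented CUDA Graph capture around an eager H2D copy.
// Error checking is omitted for brevity.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

__global__ void AddOne(float *x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] += 1.0f;
}

int main() {
  constexpr int kN = 256;
  float h_buf[kN] = {0.0f};
  float *d_buf = nullptr;
  cudaMalloc(reinterpret_cast<void **>(&d_buf), kN * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  std::vector<cudaGraph_t> graphs;
  std::vector<cudaGraphExec_t> exec_graphs;  // one executable graph per segment

  // Segment 1: the kernel is recorded into a graph, not executed.
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal);
  AddOne<<<1, kN, 0, stream>>>(d_buf, kN);
  cudaGraph_t g1;
  cudaStreamEndCapture(stream, &g1);
  cudaGraphExec_t e1;
  cudaGraphInstantiate(&e1, g1, nullptr, nullptr, 0);
  graphs.push_back(g1);
  exec_graphs.push_back(e1);

  // Between segments the stream is not capturing, so this H2D copy executes
  // immediately instead of being baked into a graph.
  cudaMemcpyAsync(d_buf, h_buf, kN * sizeof(float), cudaMemcpyHostToDevice,
                  stream);

  // Segment 2: capture resumes on the same stream.
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal);
  AddOne<<<1, kN, 0, stream>>>(d_buf, kN);
  cudaGraph_t g2;
  cudaStreamEndCapture(stream, &g2);
  cudaGraphExec_t e2;
  cudaGraphInstantiate(&e2, g2, nullptr, nullptr, 0);
  graphs.push_back(g2);
  exec_graphs.push_back(e2);

  // Replay: launch every segment in capture order (cf. CUDAGraph::Replay()).
  for (auto eg : exec_graphs) cudaGraphLaunch(eg, stream);
  cudaStreamSynchronize(stream);

  cudaMemcpy(h_buf, d_buf, kN * sizeof(float), cudaMemcpyDeviceToHost);
  printf("h_buf[0] = %f\n", h_buf[0]);  // expected 2.000000: zeroed by the eager
                                        // copy, then incremented by both segments

  for (auto eg : exec_graphs) cudaGraphExecDestroy(eg);
  for (auto g : graphs) cudaGraphDestroy(g);
  cudaStreamDestroy(stream);
  cudaFree(d_buf);
  return 0;
}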
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
......@@ -286,10 +287,13 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
const T** dev_ins_data = nullptr;
if (!has_same_shape || in_num < 2 || in_num > 4) {
tmp_dev_ins_data = memory::Alloc(context, in_num * sizeof(T*));
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
tmp_dev_ins_data->ptr(), platform::CPUPlace(),
static_cast<void*>(inputs_data), in_num * sizeof(T*),
context.stream());
{
platform::SkipCUDAGraphCaptureGuard guard;
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
tmp_dev_ins_data->ptr(), platform::CPUPlace(),
static_cast<void*>(inputs_data), in_num * sizeof(T*),
context.stream());
}
dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
}
......@@ -313,10 +317,13 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
} else {
auto tmp_dev_ins_col_data =
memory::Alloc(context, inputs_col_num * sizeof(int64_t));
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
static_cast<void*>(inputs_col),
inputs_col_num * sizeof(int64_t), context.stream());
{
platform::SkipCUDAGraphCaptureGuard guard;
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
static_cast<void*>(inputs_col),
inputs_col_num * sizeof(int64_t), context.stream());
}
int64_t* dev_ins_col_data =
static_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
......@@ -415,10 +422,13 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
T** dev_out_gpu_data = nullptr;
if (!has_same_shape || o_num < 2 || o_num > 4) {
tmp_dev_outs_data = memory::Alloc(context, o_num * sizeof(T*));
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
tmp_dev_outs_data->ptr(), platform::CPUPlace(),
reinterpret_cast<void*>(outputs_data), o_num * sizeof(T*),
context.stream());
{
platform::SkipCUDAGraphCaptureGuard guard;
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
tmp_dev_outs_data->ptr(), platform::CPUPlace(),
reinterpret_cast<void*>(outputs_data), o_num * sizeof(T*),
context.stream());
}
dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
}
......@@ -442,10 +452,13 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
} else {
auto tmp_dev_ins_col_data =
memory::Alloc(context, outputs_cols_num * sizeof(int64_t));
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
reinterpret_cast<void*>(outputs_cols),
outputs_cols_num * sizeof(int64_t), context.stream());
{
platform::SkipCUDAGraphCaptureGuard guard;
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
reinterpret_cast<void*>(outputs_cols),
outputs_cols_num * sizeof(int64_t), context.stream());
}
int64_t* dev_outs_col_data =
reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
......
......@@ -22,14 +22,14 @@ std::unique_ptr<CUDAGraph> CUDAGraph::capturing_graph_{nullptr};
void CUDAGraph::Reset() {
if (is_reset_) return;
#if CUDA_VERSION >= 10010
if (graph_) {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphDestroy(graph_));
graph_ = nullptr;
for (auto graph : graphs_) {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphDestroy(graph));
}
if (exec_graph_) {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphExecDestroy(exec_graph_));
exec_graph_ = nullptr;
graphs_.clear();
for (auto exec_graph : exec_graphs_) {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphExecDestroy(exec_graph));
}
exec_graphs_.clear();
#endif
// Callbacks should be called in reverse order because a callback added
// later may rely on one added earlier.
......@@ -45,16 +45,33 @@ void CUDAGraph::Replay() {
PADDLE_ENFORCE_EQ(is_reset_, false,
errors::PermissionDenied(
"Cannot replay the CUDA Graph after reset is called."));
PADDLE_ENFORCE_NOT_NULL(exec_graph_,
errors::PermissionDenied(
"CUDA Graph must be captured before replaying."));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphLaunch(exec_graph_, stream_));
for (auto exec_graph : exec_graphs_) {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphLaunch(exec_graph, stream_));
}
#endif
}
void CUDAGraph::BeginSegmentCapture() {
ThrowErrorIfNotSupportCUDAGraph();
#if CUDA_VERSION >= 10010
PADDLE_ENFORCE_EQ(
IsCapturing(), true,
errors::PermissionDenied("BeginSegmentCapture should be called when CUDA "
"Graph is capturing."));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamBeginCapture(
capturing_graph_->stream_, capturing_graph_->capture_mode_));
PADDLE_ENFORCE_EQ(IsValidCapturing(), true,
platform::errors::PermissionDenied(
"CUDA Graph should not be invalidated."));
VLOG(10) << "Begin to capture CUDA Graph with ID " << capturing_graph_->id_
<< ", segment id " << capturing_graph_->graphs_.size();
#endif
}
void CUDAGraph::BeginCapture(platform::CUDAPlace place, cudaStream_t stream,
cudaStreamCaptureMode mode) {
ThrowErrorIfNotSupportCUDAGraph();
#if CUDA_VERSION >= 10010
PADDLE_ENFORCE_EQ(
IsCapturing(), false,
errors::PermissionDenied("CUDA Graph can only captured one by one."));
......@@ -64,40 +81,87 @@ void CUDAGraph::BeginCapture(platform::CUDAPlace place, cudaStream_t stream,
capturing_graph_.reset(new CUDAGraph());
capturing_graph_->place_ = place;
capturing_graph_->stream_ = stream;
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamBeginCapture(capturing_graph_->stream_, mode));
cudaStreamCaptureStatus status;
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamGetCaptureInfo(
capturing_graph_->stream_, &status, &(capturing_graph_->id_)));
PADDLE_ENFORCE_EQ(IsValidCapturing(), true,
platform::errors::PermissionDenied(
"CUDA Graph should not be invalidated."));
VLOG(10) << "Begin to capture CUDA Graph with ID " << capturing_graph_->id_;
capturing_graph_->capture_mode_ = mode;
BeginSegmentCapture();
#endif
}
std::unique_ptr<CUDAGraph> CUDAGraph::EndCapture() {
void CUDAGraph::EndSegmentCapture() {
ThrowErrorIfNotSupportCUDAGraph();
#if CUDA_VERSION >= 10010
PADDLE_ENFORCE_EQ(IsCapturing(), true,
errors::PermissionDenied("No CUDA Graph is capturing."));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamEndCapture(
capturing_graph_->stream_, &(capturing_graph_->graph_)));
cudaGraph_t graph;
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaGraphInstantiate(&(capturing_graph_->exec_graph_),
capturing_graph_->graph_, nullptr, nullptr, 0));
VLOG(10) << "End to capture CUDA Graph with ID " << capturing_graph_->id_;
return std::move(capturing_graph_);
cudaStreamEndCapture(capturing_graph_->stream_, &graph));
auto num_nodes = static_cast<size_t>(-1);
PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphGetNodes(graph, nullptr, &num_nodes));
if (num_nodes == 0) {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphDestroy(graph));
VLOG(10) << "Skip empty CUDA Graph with ID " << capturing_graph_->id_
<< ", segment id " << capturing_graph_->graphs_.size();
return;
}
cudaGraphExec_t exec_graph;
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaGraphInstantiate(&exec_graph, graph, nullptr, nullptr, 0));
VLOG(10) << "End to capture CUDA Graph with ID " << capturing_graph_->id_
<< ", segment id " << capturing_graph_->graphs_.size();
capturing_graph_->graphs_.emplace_back(graph);
capturing_graph_->exec_graphs_.emplace_back(exec_graph);
#endif
}
std::unique_ptr<CUDAGraph> CUDAGraph::EndCapture() {
EndSegmentCapture();
return std::move(capturing_graph_);
}
bool CUDAGraph::IsValidCapturing() {
#if CUDA_VERSION >= 10010
if (!IsCapturing()) return false;
cudaStreamCaptureStatus status;
CUDAGraphID id;
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamGetCaptureInfo(capturing_graph_->stream_, &status, &id));
return status == cudaStreamCaptureStatusActive;
#else
return false;
#endif
}
static std::string ConcatPath(const std::string &dirname,
const std::string &filename) {
#ifdef _WIN32
const char kFileSep[] = "\\";
#else
const char kFileSep[] = "/";
#endif
if (!dirname.empty() && dirname.back() == kFileSep[0]) {
return dirname + filename;
} else {
return dirname + kFileSep + filename;
}
}
void CUDAGraph::PrintToDotFiles(const std::string &dirname,
unsigned int flags) {
ThrowErrorIfNotSupportCUDAGraph();
#if CUDA_VERSION >= 11030
for (size_t i = 0; i < graphs_.size(); ++i) {
auto filename =
ConcatPath(dirname, "segment_" + std::to_string(i) + ".dot");
VLOG(10) << "Save the " << i << "-th segment of graph " << id_ << " to "
<< filename;
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaGraphDebugDotPrint(graphs_[i], filename.c_str(), flags));
}
#else
PADDLE_THROW(platform::errors::Unimplemented(
"The print_to_dot_files() method is only supported when CUDA version >= "
"11.3."));
#endif
}
} // namespace platform
......
......@@ -14,9 +14,11 @@
#pragma once
#include <atomic>
#include <functional>
#include <memory>
#include <mutex>
#include <vector>
#include "cuda.h" // NOLINT
#include "cuda_runtime.h" // NOLINT
#include "paddle/fluid/platform/type_defs.h"
......@@ -51,7 +53,10 @@ class CUDAGraph {
// Since the constructor would throw an error if CUDA_VERSION < 10010,
// the non-static methods of CUDAGraph need not check CUDA_VERSION
// again.
CUDAGraph() { ThrowErrorIfNotSupportCUDAGraph(); }
CUDAGraph() {
ThrowErrorIfNotSupportCUDAGraph();
id_ = UniqueID();
}
public:
~CUDAGraph() { Reset(); }
......@@ -67,9 +72,15 @@ class CUDAGraph {
callbacks_.push_back(std::move(callback));
}
void PrintToDotFiles(const std::string &dirname, unsigned int flags);
static void BeginCapture(platform::CUDAPlace place, cudaStream_t stream,
cudaStreamCaptureMode mode);
static std::unique_ptr<CUDAGraph> EndCapture();
static void BeginSegmentCapture();
static void EndSegmentCapture();
static void AddResetCallbackDuringCapturing(std::function<void()> callback) {
capturing_graph_->AddResetCallback(std::move(callback));
}
......@@ -88,14 +99,21 @@ class CUDAGraph {
// supported during capturing CUDA Graph.
static bool IsValidCapturing();
private:
static CUDAGraphID UniqueID() {
static std::atomic<CUDAGraphID> id;
return id.fetch_add(1);
}
private:
#if CUDA_VERSION >= 10010
cudaGraph_t graph_{nullptr};
cudaGraphExec_t exec_graph_{nullptr};
std::vector<cudaGraph_t> graphs_;
std::vector<cudaGraphExec_t> exec_graphs_;
cudaStreamCaptureMode capture_mode_;
#endif
cudaStream_t stream_{nullptr};
platform::CUDAPlace place_;
CUDAGraphID id_{0};
CUDAGraphID id_;
std::vector<std::function<void()>> callbacks_;
bool is_reset_{false};
std::mutex mtx_;
......
......@@ -60,5 +60,30 @@ inline void AddResetCallbackIfCapturingCUDAGraph(Callback &&callback) {
callback();
}
class SkipCUDAGraphCaptureGuard {
DISABLE_COPY_AND_ASSIGN(SkipCUDAGraphCaptureGuard);
public:
SkipCUDAGraphCaptureGuard() {
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10010
if (UNLIKELY(CUDAGraph::IsCapturing())) {
CUDAGraph::EndSegmentCapture();
}
#endif
#endif
}
~SkipCUDAGraphCaptureGuard() {
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10010
if (UNLIKELY(CUDAGraph::IsCapturing())) {
CUDAGraph::BeginSegmentCapture();
}
#endif
#endif
}
};
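A usage note on the guard (hypothetical caller sketch, not taken from this diff): while a graph is being captured, any work that must run eagerly on the stream — such as the pageable H2D copies patched in concat_and_split.cu above — is scoped with the guard; the constructor closes the current capture segment and the destructor opens the next one. The names dst_place, dst_ptr, src_ptr, nbytes, and ctx below are placeholders.

// Hypothetical caller sketch; placeholder names, not code from the patch.
{
  platform::SkipCUDAGraphCaptureGuard guard;  // ends the current segment if capturing
  memory::Copy(dst_place, dst_ptr, platform::CPUPlace(), src_ptr, nbytes,
               ctx.stream());                 // executes immediately on the stream
}  // the guard's destructor begins the next capture segment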
} // namespace platform
} // namespace paddle
......@@ -562,7 +562,8 @@ PYBIND11_MODULE(core_noavx, m) {
})
.def_static("end_capture", &platform::EndCUDAGraphCapture)
.def("replay", &platform::CUDAGraph::Replay)
.def("reset", &platform::CUDAGraph::Reset);
.def("reset", &platform::CUDAGraph::Reset)
.def("print_to_dot_files", &platform::CUDAGraph::PrintToDotFiles);
#endif
m.def("wait_device", [](const platform::Place &place) {
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDAPlace
if is_compiled_with_cuda() and not is_compiled_with_rocm():
......@@ -22,7 +23,8 @@ if is_compiled_with_cuda() and not is_compiled_with_rocm():
ALL_MODES = ["global", "thread_local", "relaxed"]
self._graph = None
if place is None:
place = CUDAPlace(0)
device_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = CUDAPlace(device_id)
self._place = place
assert mode in ALL_MODES
self._mode = ALL_MODES.index(mode)
......@@ -38,6 +40,16 @@ if is_compiled_with_cuda() and not is_compiled_with_rocm():
def reset(self):
self._graph.reset()
def print_to_dot_files(self, dirname, flags=None):
if not isinstance(dirname, (str, bytes)):
dirname = dirname.name
os.makedirs(name=dirname, exist_ok=True)
assert os.path.isdir(
dirname), "The dirname {} should be a directory".format(dirname)
if flags is None:
flags = 2047  # output all information; any integer inside [1, 2048) is valid
self._graph.print_to_dot_files(dirname, flags)
else:
class CUDAGraph:
......@@ -55,3 +67,6 @@ else:
def reset(self):
raise NotImplementedError()
def print_to_dot_files(self, dirname, flags=None):
raise NotImplementedError()
......@@ -17,15 +17,21 @@ import paddle.fluid as fluid
from paddle.device.cuda.graphs import CUDAGraph
import unittest
import numpy as np
import os
import pathlib
import shutil
from paddle.fluid.dygraph.base import switch_to_static_graph
from simple_nets import simple_fc_net_with_inputs
def can_use_cuda_graph():
return paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm()
class TestCUDAGraph(unittest.TestCase):
def setUp(self):
if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm(
):
fluid.set_flags({
if can_use_cuda_graph():
paddle.set_flags({
'FLAGS_allocator_strategy': 'auto_growth',
'FLAGS_sync_nccl_allreduce': False,
'FLAGS_cudnn_deterministic': True
......@@ -38,7 +44,7 @@ class TestCUDAGraph(unittest.TestCase):
@switch_to_static_graph
def test_cuda_graph_static_graph(self):
if not paddle.is_compiled_with_cuda() or paddle.is_compiled_with_rocm():
if not can_use_cuda_graph():
return
seed = 100
......@@ -116,7 +122,7 @@ class TestCUDAGraph(unittest.TestCase):
return np.array(loss_t)
def test_cuda_graph_dynamic_graph(self):
if not paddle.is_compiled_with_cuda() or paddle.is_compiled_with_rocm():
if not can_use_cuda_graph():
return
shape = [2, 3]
......@@ -142,6 +148,45 @@ class TestCUDAGraph(unittest.TestCase):
g.reset()
def test_concat_and_split(self):
if not can_use_cuda_graph():
return
concat_num = 100
xs = []
xs_np = []
for i in range(concat_num):
x_np = np.random.random(size=[1]).astype(np.float32)
xs.append(paddle.to_tensor(x_np))
xs_np.append(x_np)
graph = CUDAGraph()
graph.capture_begin()
y = paddle.concat(xs)
zs = paddle.split(y, len(xs))
graph.capture_end()
graph.replay()
y_np = y.numpy()
y_np_expected = np.concatenate(xs_np)
self.assertTrue(np.array_equal(y_np, y_np_expected))
self.assertEqual(len(zs), len(xs_np))
for i, z in enumerate(zs):
self.assertTrue(np.array_equal(z.numpy(), xs_np[i]))
output_dir = 'cuda_graph_dot_{}'.format(os.getpid())
try:
graph.print_to_dot_files(pathlib.Path(output_dir))
graph.reset()
shutil.rmtree(output_dir)
except Exception as e:
msg = str(e)
sub_msg = "The print_to_dot_files() method is only supported when CUDA version >= 11.3"
self.assertTrue(sub_msg in msg)
finally:
graph.reset()
if __name__ == "__main__":
unittest.main()