Unverified commit 2a143f84, authored by Zeng Jinle, committed by GitHub

Try to fix CUDA Graph H2D copy bug (#36987)

* try to fix CUDA Graph H2D copy bug

* remove useless code

* fix ci

* fix ROCM CI

* fix CUDA_VERSION

* improve CI coverage
Parent: 819b9589
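For context, here is a minimal sketch of the pattern this commit fixes, mirroring the new `test_concat_and_split` unit test added at the end of this diff: capturing `paddle.concat`/`paddle.split` over many inputs, whose CUDA kernels stage pointer arrays to the device with a host-to-device copy (assumes a CUDA build of Paddle):

import numpy as np
import paddle
from paddle.device.cuda.graphs import CUDAGraph

# 100 inputs force ConcatFunctor/SplitFunctor down the generic path that
# copies an array of device pointers from the host inside the captured region.
xs = [paddle.to_tensor(np.random.random([1]).astype(np.float32)) for _ in range(100)]

g = CUDAGraph()
g.capture_begin()
y = paddle.concat(xs)
zs = paddle.split(y, len(xs))
g.capture_end()

g.replay()  # these H2D pointer copies are the capture problem reported in #36987
g.reset()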
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/math/concat_and_split.h"
+#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 #include "paddle/fluid/platform/float16.h"
@@ -286,10 +287,13 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     const T** dev_ins_data = nullptr;
     if (!has_same_shape || in_num < 2 || in_num > 4) {
       tmp_dev_ins_data = memory::Alloc(context, in_num * sizeof(T*));
-      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
-                   tmp_dev_ins_data->ptr(), platform::CPUPlace(),
-                   static_cast<void*>(inputs_data), in_num * sizeof(T*),
-                   context.stream());
+      {
+        platform::SkipCUDAGraphCaptureGuard guard;
+        memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
+                     tmp_dev_ins_data->ptr(), platform::CPUPlace(),
+                     static_cast<void*>(inputs_data), in_num * sizeof(T*),
+                     context.stream());
+      }
       dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
     }
@@ -313,10 +317,13 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     } else {
       auto tmp_dev_ins_col_data =
           memory::Alloc(context, inputs_col_num * sizeof(int64_t));
-      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
-                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
-                   static_cast<void*>(inputs_col),
-                   inputs_col_num * sizeof(int64_t), context.stream());
+      {
+        platform::SkipCUDAGraphCaptureGuard guard;
+        memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
+                     tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
+                     static_cast<void*>(inputs_col),
+                     inputs_col_num * sizeof(int64_t), context.stream());
+      }
       int64_t* dev_ins_col_data =
           static_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
@@ -415,10 +422,13 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
     T** dev_out_gpu_data = nullptr;
     if (!has_same_shape || o_num < 2 || o_num > 4) {
       tmp_dev_outs_data = memory::Alloc(context, o_num * sizeof(T*));
-      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
-                   tmp_dev_outs_data->ptr(), platform::CPUPlace(),
-                   reinterpret_cast<void*>(outputs_data), o_num * sizeof(T*),
-                   context.stream());
+      {
+        platform::SkipCUDAGraphCaptureGuard guard;
+        memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
+                     tmp_dev_outs_data->ptr(), platform::CPUPlace(),
+                     reinterpret_cast<void*>(outputs_data), o_num * sizeof(T*),
+                     context.stream());
+      }
       dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
     }
@@ -442,10 +452,13 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
     } else {
       auto tmp_dev_ins_col_data =
           memory::Alloc(context, outputs_cols_num * sizeof(int64_t));
-      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
-                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
-                   reinterpret_cast<void*>(outputs_cols),
-                   outputs_cols_num * sizeof(int64_t), context.stream());
+      {
+        platform::SkipCUDAGraphCaptureGuard guard;
+        memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
+                     tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
+                     reinterpret_cast<void*>(outputs_cols),
+                     outputs_cols_num * sizeof(int64_t), context.stream());
+      }
       int64_t* dev_outs_col_data =
           reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
...
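These `memory::Copy` calls move host-side pointer and offset arrays into scratch device buffers before the concat/split kernels launch. Each copy is now wrapped in a scoped `platform::SkipCUDAGraphCaptureGuard` (defined later in this commit) so that, while a CUDA Graph is being captured, the copy runs eagerly on the stream instead of being recorded into the graph; only the surrounding kernels end up inside the captured segments.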
@@ -22,14 +22,14 @@ std::unique_ptr<CUDAGraph> CUDAGraph::capturing_graph_{nullptr};
 void CUDAGraph::Reset() {
   if (is_reset_) return;
 #if CUDA_VERSION >= 10010
-  if (graph_) {
-    PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphDestroy(graph_));
-    graph_ = nullptr;
-  }
-  if (exec_graph_) {
-    PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphExecDestroy(exec_graph_));
-    exec_graph_ = nullptr;
-  }
+  for (auto graph : graphs_) {
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphDestroy(graph));
+  }
+  graphs_.clear();
+  for (auto exec_graph : exec_graphs_) {
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphExecDestroy(exec_graph));
+  }
+  exec_graphs_.clear();
 #endif
   // callback should be called in reverse order because the latter added
   // callback may rely on the former added callback.
@@ -45,16 +45,33 @@ void CUDAGraph::Replay() {
   PADDLE_ENFORCE_EQ(is_reset_, false,
                     errors::PermissionDenied(
                         "Cannot replay the CUDA Graph after reset is called."));
-  PADDLE_ENFORCE_NOT_NULL(exec_graph_,
-                          errors::PermissionDenied(
-                              "CUDA Graph must be captured before replaying."));
-  PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphLaunch(exec_graph_, stream_));
+  for (auto exec_graph : exec_graphs_) {
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphLaunch(exec_graph, stream_));
+  }
 #endif
 }
 
+void CUDAGraph::BeginSegmentCapture() {
+  ThrowErrorIfNotSupportCUDAGraph();
+#if CUDA_VERSION >= 10010
+  PADDLE_ENFORCE_EQ(
+      IsCapturing(), true,
+      errors::PermissionDenied("BeginSegmentCapture should be called when CUDA "
+                               "Graph is capturing."));
+  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamBeginCapture(
+      capturing_graph_->stream_, capturing_graph_->capture_mode_));
+  PADDLE_ENFORCE_EQ(IsValidCapturing(), true,
+                    platform::errors::PermissionDenied(
+                        "CUDA Graph should not be invalidated."));
+  VLOG(10) << "Begin to capture CUDA Graph with ID " << capturing_graph_->id_
+           << ", segment id " << capturing_graph_->graphs_.size();
+#endif
+}
+
 void CUDAGraph::BeginCapture(platform::CUDAPlace place, cudaStream_t stream,
                              cudaStreamCaptureMode mode) {
   ThrowErrorIfNotSupportCUDAGraph();
+#if CUDA_VERSION >= 10010
   PADDLE_ENFORCE_EQ(
       IsCapturing(), false,
       errors::PermissionDenied("CUDA Graph can only captured one by one."));
@@ -64,40 +81,87 @@ void CUDAGraph::BeginCapture(platform::CUDAPlace place, cudaStream_t stream,
   capturing_graph_.reset(new CUDAGraph());
   capturing_graph_->place_ = place;
   capturing_graph_->stream_ = stream;
+  capturing_graph_->capture_mode_ = mode;
-  PADDLE_ENFORCE_CUDA_SUCCESS(
-      cudaStreamBeginCapture(capturing_graph_->stream_, mode));
-  cudaStreamCaptureStatus status;
-  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamGetCaptureInfo(
-      capturing_graph_->stream_, &status, &(capturing_graph_->id_)));
-  PADDLE_ENFORCE_EQ(IsValidCapturing(), true,
-                    platform::errors::PermissionDenied(
-                        "CUDA Graph should not be invalidated."));
-  VLOG(10) << "Begin to capture CUDA Graph with ID " << capturing_graph_->id_;
+  BeginSegmentCapture();
+#endif
 }
 
-std::unique_ptr<CUDAGraph> CUDAGraph::EndCapture() {
+void CUDAGraph::EndSegmentCapture() {
   ThrowErrorIfNotSupportCUDAGraph();
 #if CUDA_VERSION >= 10010
   PADDLE_ENFORCE_EQ(IsCapturing(), true,
                     errors::PermissionDenied("No CUDA Graph is capturing."));
-  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamEndCapture(
-      capturing_graph_->stream_, &(capturing_graph_->graph_)));
-  PADDLE_ENFORCE_CUDA_SUCCESS(
-      cudaGraphInstantiate(&(capturing_graph_->exec_graph_),
-                           capturing_graph_->graph_, nullptr, nullptr, 0));
-  VLOG(10) << "End to capture CUDA Graph with ID " << capturing_graph_->id_;
-  return std::move(capturing_graph_);
+  cudaGraph_t graph;
+  PADDLE_ENFORCE_CUDA_SUCCESS(
+      cudaStreamEndCapture(capturing_graph_->stream_, &graph));
+  auto num_nodes = static_cast<size_t>(-1);
+  PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphGetNodes(graph, nullptr, &num_nodes));
+  if (num_nodes == 0) {
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphDestroy(graph));
+    VLOG(10) << "Skip empty CUDA Graph with ID " << capturing_graph_->id_
+             << ", segment id " << capturing_graph_->graphs_.size();
+    return;
+  }
+  cudaGraphExec_t exec_graph;
+  PADDLE_ENFORCE_CUDA_SUCCESS(
+      cudaGraphInstantiate(&exec_graph, graph, nullptr, nullptr, 0));
+  VLOG(10) << "End to capture CUDA Graph with ID " << capturing_graph_->id_
+           << ", segment id " << capturing_graph_->graphs_.size();
+  capturing_graph_->graphs_.emplace_back(graph);
+  capturing_graph_->exec_graphs_.emplace_back(exec_graph);
 #endif
 }
 
+std::unique_ptr<CUDAGraph> CUDAGraph::EndCapture() {
+  EndSegmentCapture();
+  return std::move(capturing_graph_);
+}
+
 bool CUDAGraph::IsValidCapturing() {
+#if CUDA_VERSION >= 10010
   if (!IsCapturing()) return false;
   cudaStreamCaptureStatus status;
   CUDAGraphID id;
   PADDLE_ENFORCE_CUDA_SUCCESS(
       cudaStreamGetCaptureInfo(capturing_graph_->stream_, &status, &id));
   return status == cudaStreamCaptureStatusActive;
+#else
+  return false;
+#endif
 }
 
+static std::string ConcatPath(const std::string &dirname,
+                              const std::string &filename) {
+#ifdef _WIN32
+  const char kFileSep[] = "\\";
+#else
+  const char kFileSep[] = "/";
+#endif
+  if (!dirname.empty() && dirname.back() == kFileSep[0]) {
+    return dirname + filename;
+  } else {
+    return dirname + kFileSep + filename;
+  }
+}
+
+void CUDAGraph::PrintToDotFiles(const std::string &dirname,
+                                unsigned int flags) {
+  ThrowErrorIfNotSupportCUDAGraph();
+#if CUDA_VERSION >= 11030
+  for (size_t i = 0; i < graphs_.size(); ++i) {
+    auto filename =
+        ConcatPath(dirname, "segment_" + std::to_string(i) + ".dot");
+    VLOG(10) << "Save the " << i << "-th segment of graph " << id_ << " to "
+             << filename;
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        cudaGraphDebugDotPrint(graphs_[i], filename.c_str(), flags));
+  }
+#else
+  PADDLE_THROW(platform::errors::Unimplemented(
+      "The print_to_dot_files() method is only supported when CUDA version >= "
+      "11.3."));
+#endif
+}
+
 }  // namespace platform
...
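With this change a single BeginCapture/EndCapture pair can record several graph segments: EndSegmentCapture instantiates and stores the segment just recorded (empty segments are skipped), BeginSegmentCapture resumes recording with the saved capture_mode_, Replay launches every stored cudaGraphExec_t in order on the original stream, and PrintToDotFiles writes each segment to a segment_<i>.dot file (CUDA 11.3+).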
@@ -14,9 +14,11 @@
 #pragma once
 
+#include <atomic>
 #include <functional>
 #include <memory>
 #include <mutex>
+#include <vector>
 #include "cuda.h"          // NOLINT
 #include "cuda_runtime.h"  // NOLINT
 #include "paddle/fluid/platform/type_defs.h"
@@ -51,7 +53,10 @@ class CUDAGraph {
   // Since the constructor would throw error is CUDA_VERSION < 10010.
   // The non-static method of CUDAGraph need not check CUDA_VERSION
   // again.
-  CUDAGraph() { ThrowErrorIfNotSupportCUDAGraph(); }
+  CUDAGraph() {
+    ThrowErrorIfNotSupportCUDAGraph();
+    id_ = UniqueID();
+  }
 
  public:
   ~CUDAGraph() { Reset(); }
@@ -67,9 +72,15 @@ class CUDAGraph {
     callbacks_.push_back(std::move(callback));
   }
 
+  void PrintToDotFiles(const std::string &dirname, unsigned int flags);
+
   static void BeginCapture(platform::CUDAPlace place, cudaStream_t stream,
                            cudaStreamCaptureMode mode);
   static std::unique_ptr<CUDAGraph> EndCapture();
+
+  static void BeginSegmentCapture();
+  static void EndSegmentCapture();
+
   static void AddResetCallbackDuringCapturing(std::function<void()> callback) {
     capturing_graph_->AddResetCallback(std::move(callback));
   }
@@ -88,14 +99,21 @@ class CUDAGraph {
   // supported during capturing CUDA Graph.
   static bool IsValidCapturing();
 
+ private:
+  static CUDAGraphID UniqueID() {
+    static std::atomic<CUDAGraphID> id;
+    return id.fetch_add(1);
+  }
+
  private:
 #if CUDA_VERSION >= 10010
-  cudaGraph_t graph_{nullptr};
-  cudaGraphExec_t exec_graph_{nullptr};
+  std::vector<cudaGraph_t> graphs_;
+  std::vector<cudaGraphExec_t> exec_graphs_;
+  cudaStreamCaptureMode capture_mode_;
 #endif
   cudaStream_t stream_{nullptr};
   platform::CUDAPlace place_;
-  CUDAGraphID id_{0};
+  CUDAGraphID id_;
   std::vector<std::function<void()>> callbacks_;
   bool is_reset_{false};
   std::mutex mtx_;
...
@@ -60,5 +60,30 @@ inline void AddResetCallbackIfCapturingCUDAGraph(Callback &&callback) {
   callback();
 }
 
+class SkipCUDAGraphCaptureGuard {
+  DISABLE_COPY_AND_ASSIGN(SkipCUDAGraphCaptureGuard);
+
+ public:
+  SkipCUDAGraphCaptureGuard() {
+#ifdef PADDLE_WITH_CUDA
+#if CUDA_VERSION >= 10010
+    if (UNLIKELY(CUDAGraph::IsCapturing())) {
+      CUDAGraph::EndSegmentCapture();
+    }
+#endif
+#endif
+  }
+
+  ~SkipCUDAGraphCaptureGuard() {
+#ifdef PADDLE_WITH_CUDA
+#if CUDA_VERSION >= 10010
+    if (UNLIKELY(CUDAGraph::IsCapturing())) {
+      CUDAGraph::BeginSegmentCapture();
+    }
+#endif
+#endif
+  }
+};
+
 }  // namespace platform
 }  // namespace paddle
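SkipCUDAGraphCaptureGuard is the RAII piece the operator changes above rely on: if a graph is currently being captured, construction ends the in-flight segment and destruction begins the next one, so the guarded code runs outside the capture; when no capture is active, both constructor and destructor are no-ops and the normal execution path pays no extra cost.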
@@ -562,7 +562,8 @@ PYBIND11_MODULE(core_noavx, m) {
       })
       .def_static("end_capture", &platform::EndCUDAGraphCapture)
       .def("replay", &platform::CUDAGraph::Replay)
-      .def("reset", &platform::CUDAGraph::Reset);
+      .def("reset", &platform::CUDAGraph::Reset)
+      .def("print_to_dot_files", &platform::CUDAGraph::PrintToDotFiles);
 #endif
   m.def("wait_device", [](const platform::Place &place) {
...
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDAPlace
 
 if is_compiled_with_cuda() and not is_compiled_with_rocm():
@@ -22,7 +23,8 @@ if is_compiled_with_cuda() and not is_compiled_with_rocm():
             ALL_MODES = ["global", "thread_local", "relaxed"]
             self._graph = None
             if place is None:
-                place = CUDAPlace(0)
+                device_id = int(os.environ.get('FLAGS_selected_gpus', 0))
+                place = CUDAPlace(device_id)
             self._place = place
             assert mode in ALL_MODES
             self._mode = ALL_MODES.index(mode)
@@ -38,6 +40,16 @@ if is_compiled_with_cuda() and not is_compiled_with_rocm():
         def reset(self):
             self._graph.reset()
 
+        def print_to_dot_files(self, dirname, flags=None):
+            if not isinstance(dirname, (str, bytes)):
+                dirname = dirname.name
+            os.makedirs(name=dirname, exist_ok=True)
+            assert os.path.isdir(
+                dirname), "The dirname {} should be a directory".format(dirname)
+            if flags is None:
+                flags = 2047  # only all information. It can be any integer inside [1, 2048)
+            self._graph.print_to_dot_files(dirname, flags)
+
 else:
 
     class CUDAGraph:
@@ -55,3 +67,6 @@ else:
 
         def reset(self):
             raise NotImplementedError()
+
+        def print_to_dot_files(self, dirname, flags=None):
+            raise NotImplementedError()
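A short usage sketch of the new Python surface (the directory name is illustrative; `print_to_dot_files` requires CUDA 11.3+, and the default `flags=2047` dumps all node information, as noted in the code above):

import os

# On a multi-card launch each worker sets FLAGS_selected_gpus, so CUDAGraph()
# now defaults to that card instead of always using CUDAPlace(0).
os.environ.setdefault('FLAGS_selected_gpus', '0')

from paddle.device.cuda.graphs import CUDAGraph

g = CUDAGraph()  # place defaults to CUDAPlace(int(FLAGS_selected_gpus))
g.capture_begin()
# ... run the GPU work to be captured ...
g.capture_end()
g.replay()

# Writes one Graphviz file per captured segment: segment_0.dot, segment_1.dot, ...
g.print_to_dot_files('./cuda_graph_dots')
g.reset()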
@@ -17,15 +17,21 @@ import paddle.fluid as fluid
 from paddle.device.cuda.graphs import CUDAGraph
 import unittest
 import numpy as np
+import os
+import pathlib
+import shutil
 from paddle.fluid.dygraph.base import switch_to_static_graph
 from simple_nets import simple_fc_net_with_inputs
 
 
+def can_use_cuda_graph():
+    return paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm()
+
+
 class TestCUDAGraph(unittest.TestCase):
     def setUp(self):
-        if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm(
-        ):
-            fluid.set_flags({
+        if can_use_cuda_graph():
+            paddle.set_flags({
                 'FLAGS_allocator_strategy': 'auto_growth',
                 'FLAGS_sync_nccl_allreduce': False,
                 'FLAGS_cudnn_deterministic': True
@@ -38,7 +44,7 @@ class TestCUDAGraph(unittest.TestCase):
 
     @switch_to_static_graph
     def test_cuda_graph_static_graph(self):
-        if not paddle.is_compiled_with_cuda() or paddle.is_compiled_with_rocm():
+        if not can_use_cuda_graph():
             return
 
         seed = 100
@@ -116,7 +122,7 @@ class TestCUDAGraph(unittest.TestCase):
         return np.array(loss_t)
 
     def test_cuda_graph_dynamic_graph(self):
-        if not paddle.is_compiled_with_cuda() or paddle.is_compiled_with_rocm():
+        if not can_use_cuda_graph():
             return
 
         shape = [2, 3]
@@ -142,6 +148,45 @@ class TestCUDAGraph(unittest.TestCase):
             g.reset()
 
+    def test_concat_and_split(self):
+        if not can_use_cuda_graph():
+            return
+
+        concat_num = 100
+        xs = []
+        xs_np = []
+
+        for i in range(concat_num):
+            x_np = np.random.random(size=[1]).astype(np.float32)
+            xs.append(paddle.to_tensor(x_np))
+            xs_np.append(x_np)
+
+        graph = CUDAGraph()
+        graph.capture_begin()
+        y = paddle.concat(xs)
+        zs = paddle.split(y, len(xs))
+        graph.capture_end()
+
+        graph.replay()
+        y_np = y.numpy()
+        y_np_expected = np.concatenate(xs_np)
+        self.assertTrue(np.array_equal(y_np, y_np_expected))
+        self.assertEqual(len(zs), len(xs_np))
+        for i, z in enumerate(zs):
+            self.assertTrue(np.array_equal(z.numpy(), xs_np[i]))
+
+        output_dir = 'cuda_graph_dot_{}'.format(os.getpid())
+        try:
+            graph.print_to_dot_files(pathlib.Path(output_dir))
+            graph.reset()
+            shutil.rmtree(output_dir)
+        except Exception as e:
+            msg = str(e)
+            sub_msg = "The print_to_dot_files() method is only supported when CUDA version >= 11.3"
+            self.assertTrue(sub_msg in msg)
+        finally:
+            graph.reset()
+
 
 if __name__ == "__main__":
     unittest.main()