From 40cfa5122da3a4dc1cf3fce6d46204ed9cdba46a Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Fri, 17 Sep 2021 13:08:54 +0800 Subject: [PATCH] expose cuda stream to users (#35813) * expose cuda stream to users * add ut --- paddle/fluid/pybind/cuda_streams_py.cc | 22 +++++++++++++++++++ .../tests/unittests/test_cuda_stream_event.py | 10 +++++++++ 2 files changed, 32 insertions(+) diff --git a/paddle/fluid/pybind/cuda_streams_py.cc b/paddle/fluid/pybind/cuda_streams_py.cc index 706012f4a44..311fb872ac1 100644 --- a/paddle/fluid/pybind/cuda_streams_py.cc +++ b/paddle/fluid/pybind/cuda_streams_py.cc @@ -202,6 +202,28 @@ void BindCudaStream(py::module *m_ptr) { )DOC", py::arg("event") = nullptr) + .def_property_readonly( + "cuda_stream", + [](paddle::platform::stream::CUDAStream &self) { + VLOG(10) << self.raw_stream(); + return reinterpret_cast(self.raw_stream()); + }, + R"DOC( + retrun the raw cuda stream of type cudaStream_t as type int. + + Examples: + .. code-block:: python + + # required: gpu + import paddle + import ctypes + cuda_stream = paddle.device.cuda.current_stream().cuda_stream + print(cuda_stream) + + ptr = ctypes.c_void_p(cuda_stream) # convert back to void* + print(ptr) + + )DOC") #endif .def("__init__", [](paddle::platform::stream::CUDAStream &self, diff --git a/python/paddle/fluid/tests/unittests/test_cuda_stream_event.py b/python/paddle/fluid/tests/unittests/test_cuda_stream_event.py index ec024105f89..30bc00c9d94 100644 --- a/python/paddle/fluid/tests/unittests/test_cuda_stream_event.py +++ b/python/paddle/fluid/tests/unittests/test_cuda_stream_event.py @@ -14,6 +14,7 @@ from paddle.device import cuda import paddle +import ctypes import unittest import numpy as np @@ -156,5 +157,14 @@ class TestStreamGuard(unittest.TestCase): None) +class TestRawStream(unittest.TestCase): + def test_cuda_stream(self): + if paddle.is_compiled_with_cuda(): + cuda_stream = paddle.device.cuda.current_stream().cuda_stream + print(cuda_stream) + self.assertTrue(type(cuda_stream) is int) + ptr = ctypes.c_void_p(cuda_stream) + + if __name__ == "__main__": unittest.main() -- GitLab