未验证 提交 40cfa512 编写于 作者: L Leo Chen 提交者: GitHub

expose cuda stream to users (#35813)

* expose cuda stream to users

* add ut
上级 05275010
...@@ -202,6 +202,28 @@ void BindCudaStream(py::module *m_ptr) { ...@@ -202,6 +202,28 @@ void BindCudaStream(py::module *m_ptr) {
)DOC", )DOC",
py::arg("event") = nullptr) py::arg("event") = nullptr)
.def_property_readonly(
"cuda_stream",
[](paddle::platform::stream::CUDAStream &self) {
VLOG(10) << self.raw_stream();
return reinterpret_cast<std::uintptr_t>(self.raw_stream());
},
R"DOC(
retrun the raw cuda stream of type cudaStream_t as type int.
Examples:
.. code-block:: python
# required: gpu
import paddle
import ctypes
cuda_stream = paddle.device.cuda.current_stream().cuda_stream
print(cuda_stream)
ptr = ctypes.c_void_p(cuda_stream) # convert back to void*
print(ptr)
)DOC")
#endif #endif
.def("__init__", .def("__init__",
[](paddle::platform::stream::CUDAStream &self, [](paddle::platform::stream::CUDAStream &self,
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from paddle.device import cuda from paddle.device import cuda
import paddle import paddle
import ctypes
import unittest import unittest
import numpy as np import numpy as np
...@@ -156,5 +157,14 @@ class TestStreamGuard(unittest.TestCase): ...@@ -156,5 +157,14 @@ class TestStreamGuard(unittest.TestCase):
None) None)
class TestRawStream(unittest.TestCase):
def test_cuda_stream(self):
if paddle.is_compiled_with_cuda():
cuda_stream = paddle.device.cuda.current_stream().cuda_stream
print(cuda_stream)
self.assertTrue(type(cuda_stream) is int)
ptr = ctypes.c_void_p(cuda_stream)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册