未验证 提交 3218075d 编写于 作者: S Siming Dai 提交者: GitHub

Add paddle.cuda.device.stream_guard API (#35623)

Add paddle.cuda.device.stream_guard API 
上级 a9577347
...@@ -411,10 +411,11 @@ void CUDAContext::InitEigenContext() { ...@@ -411,10 +411,11 @@ void CUDAContext::InitEigenContext() {
} }
CUDAContext::CUDAContext(const CUDAPlace& place, CUDAContext::CUDAContext(const CUDAPlace& place,
const stream::Priority& priority) { const stream::Priority& priority,
const stream::StreamFlag& flag) {
place_ = place; place_ = place;
CUDADeviceGuard guard(place_.device); CUDADeviceGuard guard(place_.device);
stream_.reset(new stream::CUDAStream(place, priority)); stream_.reset(new stream::CUDAStream(place, priority, flag));
InitEigenContext(); InitEigenContext();
InitCuBlasContext(); InitCuBlasContext();
InitCuDNNContext(); InitCuDNNContext();
......
...@@ -272,7 +272,8 @@ class CUDAContext { ...@@ -272,7 +272,8 @@ class CUDAContext {
CUDAContext() = default; CUDAContext() = default;
explicit CUDAContext( explicit CUDAContext(
const CUDAPlace& place, const CUDAPlace& place,
const stream::Priority& priority = stream::Priority::kNormal); const stream::Priority& priority = stream::Priority::kNormal,
const stream::StreamFlag& flag = stream::StreamFlag::kDefaultFlag);
~CUDAContext(); ~CUDAContext();
...@@ -288,6 +289,12 @@ class CUDAContext { ...@@ -288,6 +289,12 @@ class CUDAContext {
const std::unique_ptr<stream::CUDAStream>& Stream() const { return stream_; } const std::unique_ptr<stream::CUDAStream>& Stream() const { return stream_; }
// Install `new_stream_ptr` as this context's stream and hand the previously
// owned stream back to the caller, who assumes ownership of it.
stream::CUDAStream* SetStream(stream::CUDAStream* new_stream_ptr) {
  stream::CUDAStream* prev_stream_ptr = stream_.release();
  stream_.reset(new_stream_ptr);
  return prev_stream_ptr;
}
const gpuStream_t& RawStream() { return stream_->raw_stream(); } const gpuStream_t& RawStream() { return stream_->raw_stream(); }
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
......
...@@ -21,13 +21,8 @@ namespace paddle { ...@@ -21,13 +21,8 @@ namespace paddle {
namespace platform { namespace platform {
namespace stream { namespace stream {
#ifdef PADDLE_WITH_HIP bool CUDAStream::Init(const Place& place, const Priority& priority,
constexpr unsigned int kDefaultFlag = hipStreamDefault; const StreamFlag& flag) {
#else
constexpr unsigned int kDefaultFlag = cudaStreamDefault;
#endif
bool CUDAStream::Init(const Place& place, const Priority& priority) {
PADDLE_ENFORCE_EQ(is_gpu_place(place), true, PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Cuda stream must be created using cuda place.")); "Cuda stream must be created using cuda place."));
...@@ -35,24 +30,25 @@ bool CUDAStream::Init(const Place& place, const Priority& priority) { ...@@ -35,24 +30,25 @@ bool CUDAStream::Init(const Place& place, const Priority& priority) {
CUDADeviceGuard guard(BOOST_GET_CONST(CUDAPlace, place_).device); CUDADeviceGuard guard(BOOST_GET_CONST(CUDAPlace, place_).device);
if (priority == Priority::kHigh) { if (priority == Priority::kHigh) {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamCreateWithPriority(
hipStreamCreateWithPriority(&stream_, kDefaultFlag, -1)); &stream_, static_cast<unsigned int>(flag), -1));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreateWithPriority(
cudaStreamCreateWithPriority(&stream_, kDefaultFlag, -1)); &stream_, static_cast<unsigned int>(flag), -1));
#endif #endif
} else if (priority == Priority::kNormal) { } else if (priority == Priority::kNormal) {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamCreateWithPriority(
hipStreamCreateWithPriority(&stream_, kDefaultFlag, 0)); &stream_, static_cast<unsigned int>(flag), 0));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreateWithPriority(
cudaStreamCreateWithPriority(&stream_, kDefaultFlag, 0)); &stream_, static_cast<unsigned int>(flag), 0));
#endif #endif
} }
callback_manager_.reset(new StreamCallbackManager<gpuStream_t>(stream_)); callback_manager_.reset(new StreamCallbackManager<gpuStream_t>(stream_));
VLOG(3) << "GPUStream Init stream: " << stream_ VLOG(3) << "GPUStream Init stream: " << stream_
<< ", priority: " << static_cast<int>(priority); << ", priority: " << static_cast<int>(priority)
<< ", flag:" << static_cast<int>(flag);
return true; return true;
} }
...@@ -118,6 +114,19 @@ CUDAStream* get_current_stream(int deviceId) { ...@@ -118,6 +114,19 @@ CUDAStream* get_current_stream(int deviceId) {
#endif #endif
} }
// Replaces the current stream of the device context owning `stream`'s place
// and returns the previously installed stream. The caller takes ownership of
// the returned pointer (CUDAContext::SetStream releases it from its
// unique_ptr before handing it out).
CUDAStream* set_current_stream(CUDAStream* stream) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  auto& device = stream->GetPlace();
  auto& pool = platform::DeviceContextPool::Instance();
  return static_cast<platform::CUDADeviceContext*>(pool.Get(device))
      ->context()
      ->SetStream(stream);
#else
  // Fixed copy-pasted message from get_current_stream ("visit" -> "set").
  PADDLE_THROW(platform::errors::Unavailable(
      "Paddle is not compiled with CUDA. Cannot set cuda current stream."));
  return nullptr;
#endif
}
} // namespace stream } // namespace stream
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -33,18 +33,27 @@ enum class Priority : uint8_t { ...@@ -33,18 +33,27 @@ enum class Priority : uint8_t {
kHigh = 0x1, kHigh = 0x1,
kNormal = 0x2, kNormal = 0x2,
}; };
// Stream creation flags; numeric values match the CUDA/HIP driver constants
// passed straight through in CUDAStream::Init via static_cast<unsigned int>:
// kDefaultFlag (0x0) == cudaStreamDefault / hipStreamDefault,
// kStreamNonBlocking (0x1) == cudaStreamNonBlocking / hipStreamNonBlocking.
enum class StreamFlag : uint8_t {
kDefaultFlag = 0x0,
kStreamNonBlocking = 0x1,
};
#endif #endif
class CUDAStream final { class CUDAStream final {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
public: public:
CUDAStream() = default; CUDAStream() = default;
explicit CUDAStream(const Place& place, explicit CUDAStream(const Place& place,
const Priority& priority = Priority::kNormal) { const Priority& priority = Priority::kNormal,
Init(place, priority); const StreamFlag& flag = StreamFlag::kDefaultFlag) {
Init(place, priority, flag);
} }
virtual ~CUDAStream() { Destroy(); } virtual ~CUDAStream() { Destroy(); }
bool Init(const Place& place, const Priority& priority = Priority::kNormal); bool Init(const Place& place, const Priority& priority = Priority::kNormal,
const StreamFlag& flag = StreamFlag::kDefaultFlag);
template <typename Callback> template <typename Callback>
void AddCallback(Callback&& callback) const { void AddCallback(Callback&& callback) const {
...@@ -125,6 +134,8 @@ class CUDAStream final { ...@@ -125,6 +134,8 @@ class CUDAStream final {
#endif #endif
} }
const Place& GetPlace() const { return place_; }
private: private:
Place place_; Place place_;
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
...@@ -139,6 +150,7 @@ class CUDAStream final { ...@@ -139,6 +150,7 @@ class CUDAStream final {
}; };
CUDAStream* get_current_stream(int deviceId); CUDAStream* get_current_stream(int deviceId);
CUDAStream* set_current_stream(CUDAStream* stream);
} // namespace stream } // namespace stream
} // namespace platform } // namespace platform
......
...@@ -40,6 +40,18 @@ void BindCudaStream(py::module *m_ptr) { ...@@ -40,6 +40,18 @@ void BindCudaStream(py::module *m_ptr) {
}, },
py::return_value_policy::reference); py::return_value_policy::reference);
// Expose stream::set_current_stream to Python as core._set_current_stream.
// Returns the previously current stream as a borrowed reference —
// NOTE(review): pybind does not take ownership of the returned pointer even
// though the C++ side releases it from its unique_ptr; presumably the Python
// caller keeps it alive to restore it later — verify lifetime with callers.
m.def("_set_current_stream",
[](paddle::platform::stream::CUDAStream &stream) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
return paddle::platform::stream::set_current_stream(&stream);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CUDA. Cannot set cuda current "
"stream."));
#endif
},
py::return_value_policy::reference);
m.def("_device_synchronize", [](int device_id) { m.def("_device_synchronize", [](int device_id) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (device_id == -1) { if (device_id == -1) {
...@@ -69,7 +81,7 @@ void BindCudaStream(py::module *m_ptr) { ...@@ -69,7 +81,7 @@ void BindCudaStream(py::module *m_ptr) {
If device is positive integer, it must less than the device count. Default: None. If device is positive integer, it must less than the device count. Default: None.
priority(int|None, optional): The priority of stream. The priority can be 1(high) or 2(normal). priority(int|None, optional): The priority of stream. The priority can be 1(high) or 2(normal).
If priority is None, the priority is 2(normal). Default: None.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -200,6 +212,8 @@ void BindCudaStream(py::module *m_ptr) { ...@@ -200,6 +212,8 @@ void BindCudaStream(py::module *m_ptr) {
"Priority should be 1(high) or 2(normal) ")); "Priority should be 1(high) or 2(normal) "));
} }
auto prio = paddle::platform::stream::Priority(priority); auto prio = paddle::platform::stream::Priority(priority);
auto stream_flag =
paddle::platform::stream::StreamFlag::kStreamNonBlocking;
if (device == nullptr) { if (device == nullptr) {
int curr_device_id = platform::GetCurrentDeviceId(); int curr_device_id = platform::GetCurrentDeviceId();
...@@ -207,7 +221,8 @@ void BindCudaStream(py::module *m_ptr) { ...@@ -207,7 +221,8 @@ void BindCudaStream(py::module *m_ptr) {
device = &device_tmp; device = &device_tmp;
} }
new (&self) paddle::platform::stream::CUDAStream(*device, prio); new (&self) paddle::platform::stream::CUDAStream(*device, prio,
stream_flag);
#else #else
PADDLE_THROW(platform::errors::Unavailable( PADDLE_THROW(platform::errors::Unavailable(
"Class CUDAStream can only be initialized on the GPU platform.")); "Class CUDAStream can only be initialized on the GPU platform."));
...@@ -224,6 +239,8 @@ void BindCudaStream(py::module *m_ptr) { ...@@ -224,6 +239,8 @@ void BindCudaStream(py::module *m_ptr) {
"Priority should be 1(high) or 2(normal) ")); "Priority should be 1(high) or 2(normal) "));
} }
auto prio = paddle::platform::stream::Priority(priority); auto prio = paddle::platform::stream::Priority(priority);
auto stream_flag =
paddle::platform::stream::StreamFlag::kStreamNonBlocking;
int device_count = platform::GetCUDADeviceCount(); int device_count = platform::GetCUDADeviceCount();
if (device < 0) { if (device < 0) {
...@@ -236,7 +253,7 @@ void BindCudaStream(py::module *m_ptr) { ...@@ -236,7 +253,7 @@ void BindCudaStream(py::module *m_ptr) {
} }
new (&self) paddle::platform::stream::CUDAStream( new (&self) paddle::platform::stream::CUDAStream(
platform::CUDAPlace(device), prio); platform::CUDAPlace(device), prio, stream_flag);
#else #else
PADDLE_THROW(platform::errors::Unavailable( PADDLE_THROW(platform::errors::Unavailable(
"Class CUDAStream can only be initialized on the GPU platform.")); "Class CUDAStream can only be initialized on the GPU platform."));
...@@ -246,11 +263,13 @@ void BindCudaStream(py::module *m_ptr) { ...@@ -246,11 +263,13 @@ void BindCudaStream(py::module *m_ptr) {
.def("__init__", [](paddle::platform::stream::CUDAStream &self) { .def("__init__", [](paddle::platform::stream::CUDAStream &self) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto prio = paddle::platform::stream::Priority::kNormal; auto prio = paddle::platform::stream::Priority::kNormal;
auto stream_flag =
paddle::platform::stream::StreamFlag::kStreamNonBlocking;
int device_id = platform::GetCurrentDeviceId(); int device_id = platform::GetCurrentDeviceId();
new (&self) paddle::platform::stream::CUDAStream( new (&self) paddle::platform::stream::CUDAStream(
platform::CUDAPlace(device_id), prio); platform::CUDAPlace(device_id), prio, stream_flag);
#else #else
PADDLE_THROW(platform::errors::Unavailable( PADDLE_THROW(platform::errors::Unavailable(
"Class CUDAStream can only be initialized on the GPU platform.")); "Class CUDAStream can only be initialized on the GPU platform."));
......
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from .streams import Stream # noqa: F401 from .streams import Stream # noqa: F401
from .streams import Event # noqa: F401 from .streams import Event # noqa: F401
...@@ -24,6 +26,7 @@ __all__ = [ ...@@ -24,6 +26,7 @@ __all__ = [
'synchronize', 'synchronize',
'device_count', 'device_count',
'empty_cache', 'empty_cache',
'stream_guard',
] ]
...@@ -121,7 +124,7 @@ def device_count(): ...@@ -121,7 +124,7 @@ def device_count():
def empty_cache(): def empty_cache():
""" '''
Releases idle cached memory held by the allocator so that those can be used in other GPU Releases idle cached memory held by the allocator so that those can be used in other GPU
application and visible in `nvidia-smi`. In most cases you don't need to use this function, application and visible in `nvidia-smi`. In most cases you don't need to use this function,
Paddle does not release the memory back to the OS when you remove Tensors on the GPU, Paddle does not release the memory back to the OS when you remove Tensors on the GPU,
...@@ -137,7 +140,67 @@ def empty_cache(): ...@@ -137,7 +140,67 @@ def empty_cache():
tensor = paddle.randn([512, 512, 512], "float") tensor = paddle.randn([512, 512, 512], "float")
del tensor del tensor
paddle.device.cuda.empty_cache() paddle.device.cuda.empty_cache()
""" '''
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
core.cuda_empty_cache() core.cuda_empty_cache()
def _set_current_stream(stream):
    '''
    Make ``stream`` the current CUDA stream of its device.

    Parameters:
        stream(paddle.device.cuda.Stream): The stream to make current.

    Returns:
        CUDAStream: The stream that was current before the switch.

    Raises:
        TypeError: If ``stream`` is not a paddle.device.cuda.Stream.
    '''
    if not isinstance(stream, paddle.device.cuda.Stream):
        raise TypeError("stream type should be paddle.device.cuda.Stream")

    # Short-circuit when the requested stream is already the current one.
    previous = current_stream()
    if stream is previous:
        return stream
    return core._set_current_stream(stream)
@signature_safe_contextmanager
def stream_guard(stream):
    '''
    **Notes**:
        **This API only supports dygraph mode currently.**

    A context manager that specifies the current stream context by the given stream.

    Parameters:
        stream(paddle.device.cuda.Stream): the selected stream. If stream is None,
            the context manager does nothing and just yields.

    Raises:
        TypeError: If ``stream`` is neither None nor a paddle.device.cuda.Stream.

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle

            s = paddle.device.cuda.Stream()
            data1 = paddle.ones(shape=[20])
            data2 = paddle.ones(shape=[20])
            with paddle.device.cuda.stream_guard(s):
                data3 = data1 + data2

    '''
    if stream is not None and not isinstance(stream, paddle.device.cuda.Stream):
        raise TypeError("stream type should be paddle.device.cuda.Stream")

    cur_stream = current_stream()
    if stream is None or id(stream) == id(cur_stream):
        # Nothing to switch: no stream given, or it is already current.
        yield
    else:
        pre_stream = _set_current_stream(stream)
        try:
            yield
        finally:
            # Restore the previous stream even if the body raised.
            # (Dropped the original's dead `stream = ...` rebinding of a
            # local that immediately goes out of scope.)
            _set_current_stream(pre_stream)
...@@ -276,6 +276,7 @@ if avx_supported(): ...@@ -276,6 +276,7 @@ if avx_supported():
from .core_avx import _set_cached_executor_build_strategy from .core_avx import _set_cached_executor_build_strategy
from .core_avx import _device_synchronize from .core_avx import _device_synchronize
from .core_avx import _get_current_stream from .core_avx import _get_current_stream
from .core_avx import _set_current_stream
if sys.platform != 'win32': if sys.platform != 'win32':
from .core_avx import _set_process_pids from .core_avx import _set_process_pids
from .core_avx import _erase_process_pids from .core_avx import _erase_process_pids
...@@ -328,6 +329,7 @@ if load_noavx: ...@@ -328,6 +329,7 @@ if load_noavx:
from .core_noavx import _set_cached_executor_build_strategy from .core_noavx import _set_cached_executor_build_strategy
from .core_noavx import _device_synchronize from .core_noavx import _device_synchronize
from .core_noavx import _get_current_stream from .core_noavx import _get_current_stream
from .core_noavx import _set_current_stream
if sys.platform != 'win32': if sys.platform != 'win32':
from .core_noavx import _set_process_pids from .core_noavx import _set_process_pids
from .core_noavx import _erase_process_pids from .core_noavx import _erase_process_pids
......
...@@ -16,6 +16,7 @@ from paddle.device import cuda ...@@ -16,6 +16,7 @@ from paddle.device import cuda
import paddle import paddle
import unittest import unittest
import numpy as np
class TestCurrentStream(unittest.TestCase): class TestCurrentStream(unittest.TestCase):
...@@ -104,5 +105,56 @@ class TestCUDAEvent(unittest.TestCase): ...@@ -104,5 +105,56 @@ class TestCUDAEvent(unittest.TestCase):
self.assertTrue(event_query_2) self.assertTrue(event_query_2)
class TestStreamGuard(unittest.TestCase):
    '''
    Note:
        The asynchronous execution property of CUDA Stream can only be tested offline.
    '''

    def test_stream_guard_normal(self):
        # A computation run under a user stream must produce the same result
        # as the same computation on the default stream.
        if paddle.is_compiled_with_cuda():
            s = paddle.device.cuda.Stream()
            a = paddle.to_tensor(np.array([0, 2, 4], dtype="int32"))
            b = paddle.to_tensor(np.array([1, 3, 5], dtype="int32"))
            c = a + b
            with paddle.device.cuda.stream_guard(s):
                d = a + b
            self.assertTrue(np.array_equal(np.array(c), np.array(d)))

    def test_stream_guard_default_stream(self):
        # Guarding with the already-current stream must be a no-op: the
        # current stream object is unchanged afterwards.
        if paddle.is_compiled_with_cuda():
            s1 = paddle.device.cuda.current_stream()
            with paddle.device.cuda.stream_guard(s1):
                pass
            s2 = paddle.device.cuda.current_stream()
            self.assertTrue(id(s1) == id(s2))

    def test_set_current_stream_default_stream(self):
        # Setting the current stream to itself returns the same object
        # (the short-circuit path in _set_current_stream).
        if paddle.is_compiled_with_cuda():
            cur_stream = paddle.device.cuda.current_stream()
            new_stream = paddle.device.cuda._set_current_stream(cur_stream)
            self.assertTrue(id(cur_stream) == id(new_stream))

    def test_stream_guard_raise_error(self):
        # Non-Stream inputs to stream_guard must raise TypeError.
        if paddle.is_compiled_with_cuda():

            def test_not_correct_stream_guard_input():
                tmp = np.zeros(5)
                with paddle.device.cuda.stream_guard(tmp):
                    pass

            self.assertRaises(TypeError, test_not_correct_stream_guard_input)

    def test_set_current_stream_raise_error(self):
        # _set_current_stream rejects non-Stream inputs, including None.
        if paddle.is_compiled_with_cuda():
            self.assertRaises(TypeError, paddle.device.cuda._set_current_stream,
                              np.zeros(5))
            self.assertRaises(TypeError, paddle.device.cuda._set_current_stream,
                              None)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册