未验证 提交 9c7f6af5 编写于 作者: C chentianyu03 提交者: GitHub

Add Cuda event and stream API (#32460)

* add cuda event and stream api

* add cuda event and stream api

* add get_current_stream api

* add get_current_stream api

* init streams

* modify get_current_stream

* modify get_cuttent_stream

* add synchronize func

* add current_stream doc and test file

* move get_current_stream into CUDA macro

* move CudaEvent into CUDA macro

* move _get_current_stream and _device_synchronize into cuda macro

* modify the macro of cuda stream and event

* add test case for synchronize

* add paddle.devices.cuda module

* event and stream support hip

* add doc for stream and event class

* move cuda stream and event into single pybind

* add cuda_streams_py.cc to cmakelist

* add _device_synchronize and _get_current_stream to core module

* add test case for cudastream and cudaevent

* move __all__ in streams.py

* fix test fail

* add cuda to devices __all__

* fix current_stream doc writing error

* move devices to device direction, and merge device.py into __init__.py

* add required:gpu to sample codes

* remove cuda direction from device/__init__.py
上级 8992e63a
......@@ -22,6 +22,7 @@ limitations under the License. */
#include <hip/hip_runtime.h>
#endif
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
namespace paddle {
namespace platform {
......@@ -117,5 +118,98 @@ class MemEvent {
std::string annotation_;
};
class CudaEvent {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
public:
CudaEvent() {
#ifdef PADDLE_WITH_HIP
hipEventCreateWithFlags(&event_, flags_);
#else
cudaEventCreateWithFlags(&event_, flags_);
#endif
}
CudaEvent(unsigned int flags) : flags_(flags) {
#ifdef PADDLE_WITH_HIP
hipEventCreateWithFlags(&event_, flags_);
#else
cudaEventCreateWithFlags(&event_, flags_);
#endif
}
void Record(paddle::platform::stream::CUDAStream& stream) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event_, stream.raw_stream()));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event_, stream.raw_stream()));
#endif
}
bool Query() {
#ifdef PADDLE_WITH_HIP
gpuError_t err = hipEventQuery(event_);
if (err == hipSuccess) {
return true;
}
if (err == hipErrorNotReady) {
return false;
}
#else
gpuError_t err = cudaEventQuery(event_);
if (err == cudaSuccess) {
return true;
}
if (err == cudaErrorNotReady) {
return false;
}
#endif
PADDLE_ENFORCE_CUDA_SUCCESS(err);
return false;
}
void Synchronize() {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipEventSynchronize(event_));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventSynchronize(event_));
#endif
}
gpuEvent_t GetRawCudaEvent() { return event_; }
private:
#ifdef PADDLE_WITH_HIP
unsigned int flags_ = hipEventDefault;
#else
unsigned int flags_ = cudaEventDefault;
#endif
gpuEvent_t event_;
#endif
};
static unsigned int get_cuda_flags(bool enable_timing, bool blocking,
bool interprocess) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifdef PADDLE_WITH_HIP
unsigned int flags =
(blocking ? hipEventBlockingSync : hipEventDefault) |
(enable_timing ? hipEventDefault : hipEventDisableTiming) |
(interprocess ? hipEventInterprocess : hipEventDefault);
return flags;
#else
unsigned int flags =
(blocking ? cudaEventBlockingSync : cudaEventDefault) |
(enable_timing ? cudaEventDefault : cudaEventDisableTiming) |
(interprocess ? cudaEventInterprocess : cudaEventDefault);
return flags;
#endif
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CUDA. Cannot get the cuda event flags."));
return 0;
#endif
}
} // namespace platform
} // namespace paddle
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/platform/stream/cuda_stream.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
......@@ -95,6 +96,28 @@ void CUDAStream::Wait() const {
PADDLE_ENFORCE_CUDA_SUCCESS(e_sync);
}
CUDAStream* get_current_stream(int deviceId) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (deviceId == -1) {
deviceId = platform::GetCurrentDeviceId();
}
auto& pool = platform::DeviceContextPool::Instance();
platform::Place device = CUDAPlace(deviceId);
auto stream = static_cast<platform::CUDADeviceContext*>(pool.Get(device))
->context()
->Stream()
.get();
return stream;
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CUDA. Cannot visit cuda current stream."));
return nullptr;
#endif
}
} // namespace stream
} // namespace platform
} // namespace paddle
......@@ -33,8 +33,9 @@ enum class Priority : uint8_t {
kHigh = 0x1,
kNormal = 0x2,
};
#endif
class CUDAStream final {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
public:
CUDAStream() = default;
explicit CUDAStream(const Place& place,
......@@ -93,6 +94,37 @@ class CUDAStream final {
#endif
void Destroy();
bool Query() const {
#ifdef PADDLE_WITH_HIP
hipError_t err = hipStreamQuery(stream_);
if (err == hipSuccess) {
return true;
}
if (err == hipErrorNotReady) {
return false;
}
#else
cudaError_t err = cudaStreamQuery(stream_);
if (err == cudaSuccess) {
return true;
}
if (err == cudaErrorNotReady) {
return false;
}
#endif
PADDLE_ENFORCE_CUDA_SUCCESS(err);
return false;
}
void Synchronize() const {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream_));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream_));
#endif
}
private:
Place place_;
#ifdef PADDLE_WITH_HIP
......@@ -102,11 +134,11 @@ class CUDAStream final {
#endif
Priority priority_{Priority::kNormal};
std::unique_ptr<StreamCallbackManager<gpuStream_t>> callback_manager_;
#endif
DISABLE_COPY_AND_ASSIGN(CUDAStream);
};
#endif
CUDAStream* get_current_stream(int deviceId);
} // namespace stream
} // namespace platform
......
......@@ -57,7 +57,8 @@ set(PYBIND_SRCS
inference_api.cc
compatible.cc
io.cc
generator_py.cc)
generator_py.cc
cuda_streams_py.cc)
if(WITH_ASCEND)
set(PYBIND_DEPS ${PYBIND_DEPS} ascend_wrapper)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "paddle/fluid/platform/event.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
#include "paddle/fluid/pybind/cuda_streams_py.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindCudaStream(py::module *m_ptr) {
auto &m = *m_ptr;
// Bind Methods
m.def("_get_current_stream",
[](int deviceId) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
return paddle::platform::stream::get_current_stream(deviceId);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CUDA. Cannot visit cuda current "
"stream."));
#endif
},
py::return_value_policy::reference);
m.def("_device_synchronize", [](int device_id) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (device_id == -1) {
device_id = paddle::platform::GetCurrentDeviceId();
}
int curr_device_id = paddle::platform::GetCurrentDeviceId();
paddle::platform::SetDeviceId(device_id);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipDeviceSynchronize());
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaDeviceSynchronize());
#endif
paddle::platform::SetDeviceId(curr_device_id);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Paddle is not compiled with CUDA. Cannot visit device synchronize."));
#endif
});
py::class_<paddle::platform::stream::CUDAStream>(m, "CUDAStream", R"DOC(
The handle of the CUDA stream.
Parameters:
device(paddle.CUDAPlace()|int|None, optional): The device which wanted to allocate the stream.
If device is None or negative integer, device will be the current device.
If device is positive integer, it must less than the device count. Default: None.
priority(int|None, optional): The priority of stream. The priority can be 1(high) or 2(normal).
If prioriyt is None, the priority is 2(normal). Default: None.
Examples:
.. code-block:: python
# required: gpu
import paddle
s1 = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
s2 = paddle.device.cuda.Stream(0, 1)
s3 = paddle.device.cuda.Stream()
)DOC")
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
.def("wait_event",
[](paddle::platform::stream::CUDAStream &self,
paddle::platform::CudaEvent &event) {
self.WaitEvent(event.GetRawCudaEvent());
},
R"DOC(
Makes all future work submitted to stream wait for all work captured in event.
Parameters:
event(CUDAEvent): The event to wait on.
Examples:
.. code-block:: python
# required: gpu
import paddle
s = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
event = paddle.device.cuda.Event()
s.wait_event(event)
)DOC")
.def("wait_stream",
[](paddle::platform::stream::CUDAStream &self,
paddle::platform::stream::CUDAStream &stream) {
auto event = paddle::platform::CudaEvent();
event.Record(stream);
self.WaitEvent(event.GetRawCudaEvent());
},
R"DOC(
Synchronizes with the given stream.
Parameters:
stream(CUDAStream): The stream to synchronize with.
Examples:
.. code-block:: python
# required: gpu
import paddle
s1 = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
s2 = paddle.device.cuda.Stream(0, 1)
s1.wait_stream(s2)
)DOC")
.def("query",
[](paddle::platform::stream::CUDAStream &self) {
return self.Query();
},
R"DOC(
Return the status whether if all operations in stream have completed.
Returns: A boolean value.
Examples:
.. code-block:: python
# required: gpu
import paddle
s = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
is_done = s.query()
)DOC")
.def("synchronize",
[](paddle::platform::stream::CUDAStream &self) {
self.Synchronize();
},
R"DOC(
Waits for stream tasks to complete.
Examples:
.. code-block:: python
# required: gpu
import paddle
s = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
s.synchronize()
)DOC")
.def("record_event",
[](paddle::platform::stream::CUDAStream &self,
paddle::platform::CudaEvent *event) {
if (event == nullptr) {
auto event_tmp = paddle::platform::CudaEvent();
event = &event_tmp;
}
event->Record(self);
return event;
},
R"DOC(
Record a CUDA event in the stream.
Parameters:
event(CUDAEvent, optional): The event to be record. If event is None, a new event is created.
Default: None.
Returns:
The recored event.
Examples:
.. code-block:: python
# required: gpu
import paddle
s = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
event = s.record_event()
)DOC",
py::arg("event") = nullptr)
#endif
.def("__init__",
[](paddle::platform::stream::CUDAStream &self,
platform::CUDAPlace *device, int priority) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (priority != 1 && priority != 2) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Priority should be 1(high) or 2(normal) "));
}
auto prio = paddle::platform::stream::Priority(priority);
if (device == nullptr) {
int curr_device_id = platform::GetCurrentDeviceId();
auto device_tmp = platform::CUDAPlace(curr_device_id);
device = &device_tmp;
}
new (&self) paddle::platform::stream::CUDAStream(*device, prio);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Class CUDAStream can only be initialized on the GPU platform."));
#endif
},
py::arg("device") = nullptr, py::arg("priority") = 2)
.def(
"__init__",
[](paddle::platform::stream::CUDAStream &self, int device,
int priority) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (priority != 1 && priority != 2) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Priority should be 1(high) or 2(normal) "));
}
auto prio = paddle::platform::stream::Priority(priority);
int device_count = platform::GetCUDADeviceCount();
if (device < 0) {
device = platform::GetCurrentDeviceId();
}
if (device >= device_count) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The device id must be inside [0, %d), but input device=%d.",
device_count, device));
}
new (&self) paddle::platform::stream::CUDAStream(
platform::CUDAPlace(device), prio);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Class CUDAStream can only be initialized on the GPU platform."));
#endif
},
py::arg("device") = -1, py::arg("priority") = 2)
.def("__init__", [](paddle::platform::stream::CUDAStream &self) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto prio = paddle::platform::stream::Priority::kNormal;
int device_id = platform::GetCurrentDeviceId();
new (&self) paddle::platform::stream::CUDAStream(
platform::CUDAPlace(device_id), prio);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Class CUDAStream can only be initialized on the GPU platform."));
#endif
});
py::class_<paddle::platform::CudaEvent>(m, "CUDAEvent", R"DOC(
The handle of the CUDA event.
Parameters:
enable_timing(bool, optional): Whether the event will measure time. Default: False.
blocking(bool, optional): Whether the wait() func will be blocking. Default: False;
interprocess(bool, optional): Whether the event can be shared between processes. Defalut: False.
Examples:
.. code-block:: python
# required: gpu
import paddle
event = paddle.device.cuda.Event()
)DOC")
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
.def("record",
[](paddle::platform::CudaEvent &self,
paddle::platform::stream::CUDAStream *stream) {
if (stream == nullptr) {
stream = paddle::platform::stream::get_current_stream(-1);
}
self.Record(*stream);
},
R"DOC(
Records the event in the given stream.
Parameters:
stream(CUDAStream, optional): The handle of CUDA stream. If None, the stream is the current stream. Default: None.
Examples:
.. code-block:: python
# required: gpu
import paddle
event = paddle.device.cuda.Event()
event.record()
)DOC",
py::arg("stream") = nullptr)
.def("query",
[](paddle::platform::CudaEvent &self) { return self.Query(); },
R"DOC(
Queries the event's status.
Returns: A boolean which indicates all work currently captured by the event has been completed.
Examples:
.. code-block:: python
# required: gpu
import paddle
event = paddle.device.cuda.Event()
is_done = event.query()
)DOC")
.def("synchronize",
[](paddle::platform::CudaEvent &self) { self.Synchronize(); }, R"DOC(
Waits for an event to complete.
Examples:
.. code-block:: python
# required: gpu
import paddle
event = paddle.device.cuda.Event()
event.synchronize()
)DOC")
#endif
.def("__init__",
[](paddle::platform::CudaEvent &self, bool enable_timing,
bool blocking, bool interprocess) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
unsigned int flags = platform::get_cuda_flags(
enable_timing, blocking, interprocess);
new (&self) paddle::platform::CudaEvent(flags);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Class CUDAEvent can only be initialized on the GPU "
"platform."));
#endif
},
py::arg("enable_timing") = false, py::arg("blocking") = false,
py::arg("interprocess") = false);
}
} // namespace pybind
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindCudaStream(py::module* m);
} // namespace pybind
} // namespace paddle
......@@ -70,6 +70,7 @@ limitations under the License. */
#include "paddle/fluid/platform/monitor.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/pybind/cuda_streams_py.h"
#include "paddle/fluid/pybind/io.h"
#ifdef PADDLE_WITH_ASCEND
#include "paddle/fluid/pybind/ascend_wrapper_py.h"
......@@ -77,6 +78,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/box_helper_py.h"
#include "paddle/fluid/pybind/compatible.h"
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/cuda_streams_py.h"
#include "paddle/fluid/pybind/data_set_py.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/fleet_wrapper_py.h"
......@@ -471,6 +473,8 @@ PYBIND11_MODULE(core_avx, m) {
PYBIND11_MODULE(core_noavx, m) {
#endif
BindCudaStream(&m);
// Not used, just make sure cpu_info.cc is linked.
paddle::platform::CpuTotalPhysicalMemory();
......
......@@ -52,6 +52,7 @@ import paddle.metric # noqa: F401
import paddle.regularizer # noqa: F401
import paddle.incubate # noqa: F401
import paddle.autograd # noqa: F401
import paddle.device # noqa: F401
import paddle.jit # noqa: F401
import paddle.amp # noqa: F401
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -20,7 +20,7 @@ from paddle.fluid import framework
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.framework import is_compiled_with_cuda # noqa: F401
from paddle.fluid.framework import is_compiled_with_rocm # noqa: F401
from . import cuda
__all__ = [ # noqa
'get_cudnn_version',
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import core
from .streams import Stream # noqa: F401
from .streams import Event # noqa: F401
__all__ = [
'Stream',
'Event',
'current_stream',
'synchronize',
]
def current_stream(device=None):
'''
Return the current CUDA stream by the device.
Parameters:
device(paddle.CUDAPlace()|int, optional): The device or the ID of the device which want to get stream from.
If device is None, the device is the current device. Default: None.
Returns:
CUDAStream: the stream to the device.
Examples:
.. code-block:: python
# required: gpu
import paddle
s1 = paddle.device.cuda.current_stream()
s2 = paddle.device.cuda.current_stream(0)
s3 = paddle.device.cuda.current_stream(paddle.CUDAPlace(0))
'''
device_id = -1
if device is not None:
if isinstance(device, int):
device_id = device
elif isinstance(device, core.CUDAPlace):
device_id = device.get_device_id()
else:
raise ValueError("device type must be int or paddle.CUDAPlace")
return core._get_current_stream(device_id)
def synchronize(device=None):
'''
Wait for the compute on the given CUDA device to finish.
Parameters:
device(paddle.CUDAPlace()|int, optional): The device or the ID of the device.
If device is None, the device is the current device. Default: None.
Examples:
.. code-block:: python
# required: gpu
import paddle
paddle.device.cuda.synchronize()
paddle.device.cuda.synchronize(0)
paddle.device.cuda.synchronize(paddle.CUDAPlace(0))
'''
device_id = -1
if device is not None:
if isinstance(device, int):
device_id = device
elif isinstance(device, core.CUDAPlace):
device_id = device.get_device_id()
else:
raise ValueError("device type must be int or paddle.CUDAPlace")
return core._device_synchronize(device_id)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.core import CUDAStream as Stream
from paddle.fluid.core import CUDAEvent as Event
......@@ -274,6 +274,8 @@ if avx_supported():
from .core_avx import _cuda_synchronize
from .core_avx import _is_compiled_with_heterps
from .core_avx import _promote_types_if_complex_exists
from .core_avx import _device_synchronize
from .core_avx import _get_current_stream
if sys.platform != 'win32':
from .core_avx import _set_process_pids
from .core_avx import _erase_process_pids
......@@ -323,6 +325,8 @@ if load_noavx:
from .core_noavx import _cuda_synchronize
from .core_noavx import _is_compiled_with_heterps
from .core_noavx import _promote_types_if_complex_exists
from .core_noavx import _device_synchronize
from .core_noavx import _get_current_stream
if sys.platform != 'win32':
from .core_noavx import _set_process_pids
from .core_noavx import _erase_process_pids
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.device import cuda
import paddle
import unittest
class TestCurrentStream(unittest.TestCase):
def test_current_stream(self):
if paddle.is_compiled_with_cuda():
s = cuda.current_stream()
self.assertTrue(isinstance(s, cuda.Stream))
s1 = cuda.current_stream(0)
self.assertTrue(isinstance(s1, cuda.Stream))
s2 = cuda.current_stream(paddle.CUDAPlace(0))
self.assertTrue(isinstance(s2, cuda.Stream))
self.assertEqual(s1, s2)
self.assertRaises(ValueError, cuda.current_stream, "gpu:0")
class TestSynchronize(unittest.TestCase):
def test_synchronize(self):
if paddle.is_compiled_with_cuda():
self.assertIsNone(cuda.synchronize())
self.assertIsNone(cuda.synchronize(0))
self.assertIsNone(cuda.synchronize(paddle.CUDAPlace(0)))
self.assertRaises(ValueError, cuda.synchronize, "gpu:0")
class TestCUDAStream(unittest.TestCase):
def test_cuda_stream(self):
if paddle.is_compiled_with_cuda():
s = paddle.device.cuda.Stream()
self.assertIsNotNone(s)
def test_cuda_stream_synchronize(self):
if paddle.is_compiled_with_cuda():
s = paddle.device.cuda.Stream()
e1 = paddle.device.cuda.Event(True, False, False)
e2 = paddle.device.cuda.Event(True, False, False)
e1.record(s)
e1.query()
tensor1 = paddle.to_tensor(paddle.rand([1000, 1000]))
tensor2 = paddle.matmul(tensor1, tensor1)
s.synchronize()
e2.record(s)
e2.synchronize()
self.assertTrue(s.query())
def test_cuda_stream_wait_event_and_record_event(self):
if paddle.is_compiled_with_cuda():
s1 = cuda.Stream(0)
tensor1 = paddle.to_tensor(paddle.rand([1000, 1000]))
tensor2 = paddle.matmul(tensor1, tensor1)
e1 = cuda.Event(False, False, False)
s1.record_event(e1)
s2 = cuda.Stream(0)
s2.wait_event(e1)
s2.synchronize()
self.assertTrue(e1.query() and s1.query() and s2.query())
class TestCUDAEvent(unittest.TestCase):
def test_cuda_event(self):
if paddle.is_compiled_with_cuda():
e = paddle.device.cuda.Event(True, False, False)
self.assertIsNotNone(e)
s = paddle.device.cuda.current_stream()
def test_cuda_event_methods(self):
if paddle.is_compiled_with_cuda():
e = paddle.device.cuda.Event(True, False, False)
s = paddle.device.cuda.current_stream()
event_query_1 = e.query()
tensor1 = paddle.to_tensor(paddle.rand([1000, 1000]))
tensor2 = paddle.matmul(tensor1, tensor1)
s.record_event(e)
e.synchronize()
event_query_2 = e.query()
self.assertTrue(event_query_1)
self.assertTrue(event_query_2)
if __name__ == "__main__":
unittest.main()
......@@ -224,6 +224,8 @@ packages=['paddle',
'paddle.tensor',
'paddle.onnx',
'paddle.autograd',
'paddle.device',
'paddle.device.cuda',
]
with open('@PADDLE_SOURCE_DIR@/python/requirements.txt') as f:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册