Unverified commit 20cfa8ba, authored by Aurelius84, committed by GitHub

Abstract GenerateDeviceEventFlag to shield platforms (#35219)

* Abstract GenerateDeviceEventFlag to shield platforms

* Remove get_cuda_flags
Parent commit: 31cd1065
...@@ -77,7 +77,7 @@ void AssociateInputWithEvents( ...@@ -77,7 +77,7 @@ void AssociateInputWithEvents(
for (auto var_id : new_event_var_id) { for (auto var_id : new_event_var_id) {
if (var_id2event->count(var_id) == 0) { if (var_id2event->count(var_id) == 0) {
auto device_event = std::make_shared<platform::DeviceEvent>( auto device_event = std::make_shared<platform::DeviceEvent>(
place, platform::get_cuda_flags(false, false, false)); place, platform::GenerateDeviceEventFlag());
var_id2event->emplace(var_id, std::move(device_event)); var_id2event->emplace(var_id, std::move(device_event));
} }
// Add events for next_instr.inputs // Add events for next_instr.inputs
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/platform/device_event_base.h" #include "paddle/fluid/platform/device_event_base.h"
#include "paddle/fluid/platform/device_event_cpu.h" #include "paddle/fluid/platform/device_event_cpu.h"
#include "paddle/fluid/platform/event.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -25,6 +26,31 @@ EventFinishFunction DeviceEvent::event_finisher_[MaxDeviceTypes]; ...@@ -25,6 +26,31 @@ EventFinishFunction DeviceEvent::event_finisher_[MaxDeviceTypes];
EventFinishFunction DeviceEvent::event_finished_setter_[MaxDeviceTypes]; EventFinishFunction DeviceEvent::event_finished_setter_[MaxDeviceTypes];
EventWaitFunction DeviceEvent::event_waiter_[MaxDeviceTypes][MaxDeviceTypes]; EventWaitFunction DeviceEvent::event_waiter_[MaxDeviceTypes][MaxDeviceTypes];
/*
 * Generate the flag used to create an event on all sorts of equipment.
 * NOTE: Supports CPU/CUDA/ROCM currently.
 *
 * @param enable_timing  record timing data for the event (CUDA/HIP only).
 * @param blocking       waiting thread blocks instead of spinning.
 * @param interprocess   event may be shared across processes.
 * @return platform event flag bitmask; 0 on CPU-only builds.
 *
 * Uses an #if/#elif chain (instead of two independent #ifdef sections) so
 * that `flags` cannot be redefined if both CUDA and HIP macros are set,
 * and the CPU fallback is explicitly the final #else branch.
 */
unsigned int GenerateDeviceEventFlag(bool enable_timing, bool blocking,
                                     bool interprocess) {
#if defined(PADDLE_WITH_CUDA)
  // cudaEventDefault is 0; OR-ing it keeps the three ternaries symmetric.
  return (blocking ? cudaEventBlockingSync : cudaEventDefault) |
         (enable_timing ? cudaEventDefault : cudaEventDisableTiming) |
         (interprocess ? cudaEventInterprocess : cudaEventDefault);
#elif defined(PADDLE_WITH_HIP)
  return (blocking ? hipEventBlockingSync : hipEventDefault) |
         (enable_timing ? hipEventDefault : hipEventDisableTiming) |
         (interprocess ? hipEventInterprocess : hipEventDefault);
#else
  // CPU events carry no platform flags; parameters are intentionally unused.
  (void)enable_timing;
  (void)blocking;
  (void)interprocess;
  return 0;
#endif
}
void DeviceEventCreateCPU(DeviceEvent* event, const platform::Place& place, void DeviceEventCreateCPU(DeviceEvent* event, const platform::Place& place,
unsigned int flag) { unsigned int flag) {
event->InitEvent(std::make_shared<CPUDeviceEventWrapper>(place, flag)); event->InitEvent(std::make_shared<CPUDeviceEventWrapper>(place, flag));
......
...@@ -39,6 +39,10 @@ inline int DeviceTypeToId(const DeviceType& device_type) { ...@@ -39,6 +39,10 @@ inline int DeviceTypeToId(const DeviceType& device_type) {
return static_cast<int>(device_type); return static_cast<int>(device_type);
} }
// Returns the platform-appropriate event-creation flag bitmask
// (CUDA/HIP event flags; 0 on CPU-only builds). The defaults request a
// non-timing, non-blocking, single-process event — the cheapest kind.
unsigned int GenerateDeviceEventFlag(bool enable_timing = false,
                                     bool blocking = false,
                                     bool interprocess = false);
enum EventStatus { enum EventStatus {
INITIALIZED = 0, INITIALIZED = 0,
SCHEDULED = 1, SCHEDULED = 1,
......
...@@ -195,30 +195,5 @@ class CudaEvent { ...@@ -195,30 +195,5 @@ class CudaEvent {
#endif #endif
}; };
/*
 * Build the flag bitmask used when creating a CUDA/HIP event.
 * NOTE(review): superseded by platform::GenerateDeviceEventFlag(); as a
 * header-defined `static`, every including TU gets its own internal-linkage
 * copy of this function.
 */
static unsigned int get_cuda_flags(bool enable_timing, bool blocking,
                                   bool interprocess) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifdef PADDLE_WITH_HIP
  // hipEventDefault is 0; OR-ing it keeps the three ternaries symmetric.
  unsigned int flags =
      (blocking ? hipEventBlockingSync : hipEventDefault) |
      (enable_timing ? hipEventDefault : hipEventDisableTiming) |
      (interprocess ? hipEventInterprocess : hipEventDefault);
  return flags;
#else
  unsigned int flags =
      (blocking ? cudaEventBlockingSync : cudaEventDefault) |
      (enable_timing ? cudaEventDefault : cudaEventDisableTiming) |
      (interprocess ? cudaEventInterprocess : cudaEventDefault);
  return flags;
#endif
#else
  // CPU-only build: event flags are meaningless here, so fail loudly.
  PADDLE_THROW(platform::errors::Unavailable(
      "Paddle is not compiled with CUDA. Cannot get the cuda event flags."));
  return 0;  // unreachable; placates compilers requiring a return value
#endif
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/platform/device_event_base.h"
#include "paddle/fluid/platform/event.h" #include "paddle/fluid/platform/event.h"
#include "paddle/fluid/platform/stream/cuda_stream.h" #include "paddle/fluid/platform/stream/cuda_stream.h"
#include "paddle/fluid/pybind/cuda_streams_py.h" #include "paddle/fluid/pybind/cuda_streams_py.h"
...@@ -331,7 +332,7 @@ void BindCudaStream(py::module *m_ptr) { ...@@ -331,7 +332,7 @@ void BindCudaStream(py::module *m_ptr) {
[](paddle::platform::CudaEvent &self, bool enable_timing, [](paddle::platform::CudaEvent &self, bool enable_timing,
bool blocking, bool interprocess) { bool blocking, bool interprocess) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
unsigned int flags = platform::get_cuda_flags( unsigned int flags = platform::GenerateDeviceEventFlag(
enable_timing, blocking, interprocess); enable_timing, blocking, interprocess);
new (&self) paddle::platform::CudaEvent(flags); new (&self) paddle::platform::CudaEvent(flags);
#else #else
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.