未验证 提交 8c73c1b5 编写于 作者: A Aurelius84 提交者: GitHub

Support Reset for DeviceEvent (#35443)

* Support Reset for DeviceEvent

* fix code

* add more unittest
上级 c2f76b0a
...@@ -23,8 +23,9 @@ EventCreateFunction DeviceEvent::event_creator_[MaxDeviceTypes]; ...@@ -23,8 +23,9 @@ EventCreateFunction DeviceEvent::event_creator_[MaxDeviceTypes];
EventRecordFunction DeviceEvent::event_recorder_[MaxDeviceTypes]; EventRecordFunction DeviceEvent::event_recorder_[MaxDeviceTypes];
EventQueryFunction DeviceEvent::event_querier_[MaxDeviceTypes]; EventQueryFunction DeviceEvent::event_querier_[MaxDeviceTypes];
EventFinishFunction DeviceEvent::event_finisher_[MaxDeviceTypes]; EventFinishFunction DeviceEvent::event_finisher_[MaxDeviceTypes];
EventFinishFunction DeviceEvent::event_finished_setter_[MaxDeviceTypes]; EventSetFinishedFunction DeviceEvent::event_finished_setter_[MaxDeviceTypes];
EventWaitFunction DeviceEvent::event_waiter_[MaxDeviceTypes][MaxDeviceTypes]; EventWaitFunction DeviceEvent::event_waiter_[MaxDeviceTypes][MaxDeviceTypes];
EventResetFunction DeviceEvent::event_resetter_[MaxDeviceTypes];
/* /*
* Generate flag used to create event on all sorts of equipment. * Generate flag used to create event on all sorts of equipment.
...@@ -60,9 +61,20 @@ void DeviceEventRecordCPU(DeviceEvent* event, const DeviceContext* context) { ...@@ -60,9 +61,20 @@ void DeviceEventRecordCPU(DeviceEvent* event, const DeviceContext* context) {
auto* wrapper = static_cast<CPUDeviceEventWrapper*>(event->GetEvent().get()); auto* wrapper = static_cast<CPUDeviceEventWrapper*>(event->GetEvent().get());
std::unique_lock<std::mutex> lock(wrapper->mutex_); std::unique_lock<std::mutex> lock(wrapper->mutex_);
PADDLE_ENFORCE_NE(wrapper->status_.load(), EventStatus::SCHEDULED, // NOTE: As for CudaEvent_t, it can be used to Record() repeatly. CudaEvent_t
platform::errors::PreconditionNotMet( // internally reset its status from finished into initialized.
"EventStatus shall be not SCHEDULED before Record()")); // So we simulate the process here.
if (wrapper->status_.load() == EventStatus::SUCCESS) {
VLOG(3) << "Found EventStatus is SUCCESS before RecordCPU. Reset it into "
"INITIALIZED.";
wrapper->status_ = EventStatus::INITIALIZED;
}
PADDLE_ENFORCE_LT(
wrapper->status_.load(), EventStatus::SCHEDULED,
platform::errors::PreconditionNotMet(
"EventStatus shall be not SCHEDULED before Record(), but received %s",
wrapper->status_.load()));
if (wrapper->status_ == EventStatus::INITIALIZED) { if (wrapper->status_ == EventStatus::INITIALIZED) {
wrapper->status_ = EventStatus::SCHEDULED; wrapper->status_ = EventStatus::SCHEDULED;
} }
...@@ -104,6 +116,12 @@ void EventSetFinishedCPU(const DeviceEvent* event) { ...@@ -104,6 +116,12 @@ void EventSetFinishedCPU(const DeviceEvent* event) {
wrapper->cv_completed_.notify_all(); wrapper->cv_completed_.notify_all();
} }
// Reset hook for CPU events: writes EventStatus::INITIALIZED into the
// CPUDeviceEventWrapper held by `event`, while holding the wrapper's mutex.
void EventResetCPU(const DeviceEvent* event) {
  auto* cpu_event =
      static_cast<CPUDeviceEventWrapper*>(event->GetEvent().get());
  // The guard is released automatically when the function returns.
  std::lock_guard<std::mutex> guard(cpu_event->mutex_);
  cpu_event->status_.store(EventStatus::INITIALIZED);
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -113,6 +131,7 @@ REGISTER_EVENT_RECORD_FUNCTION(kCPU, paddle::platform::DeviceEventRecordCPU) ...@@ -113,6 +131,7 @@ REGISTER_EVENT_RECORD_FUNCTION(kCPU, paddle::platform::DeviceEventRecordCPU)
REGISTER_EVENT_QUERY_FUNCTION(kCPU, paddle::platform::DeviceEventQueryCPU) REGISTER_EVENT_QUERY_FUNCTION(kCPU, paddle::platform::DeviceEventQueryCPU)
REGISTER_EVENT_FINISH_FUNCTION(kCPU, paddle::platform::DeviceEventFinishCPU) REGISTER_EVENT_FINISH_FUNCTION(kCPU, paddle::platform::DeviceEventFinishCPU)
REGISTER_EVENT_SET_FINISHED_FUNCTION(kCPU, REGISTER_EVENT_SET_FINISHED_FUNCTION(kCPU,
paddle::platform::EventSetFinishedCPU); paddle::platform::EventSetFinishedCPU)
REGISTER_EVENT_WAIT_FUNCTION(kCPU, kCPU, REGISTER_EVENT_WAIT_FUNCTION(kCPU, kCPU,
paddle::platform::DeviceEventCPUWaitCPU) paddle::platform::DeviceEventCPUWaitCPU)
REGISTER_EVENT_RESET_FUNCTION(kCPU, paddle::platform::EventResetCPU)
...@@ -34,6 +34,7 @@ typedef bool (*EventQueryFunction)(const DeviceEvent*); ...@@ -34,6 +34,7 @@ typedef bool (*EventQueryFunction)(const DeviceEvent*);
typedef void (*EventFinishFunction)(const DeviceEvent*); typedef void (*EventFinishFunction)(const DeviceEvent*);
typedef void (*EventSetFinishedFunction)(const DeviceEvent*); typedef void (*EventSetFinishedFunction)(const DeviceEvent*);
typedef void (*EventWaitFunction)(const DeviceEvent*, const DeviceContext*); typedef void (*EventWaitFunction)(const DeviceEvent*, const DeviceContext*);
typedef void (*EventResetFunction)(const DeviceEvent*);
inline int DeviceTypeToId(const DeviceType& device_type) { inline int DeviceTypeToId(const DeviceType& device_type) {
return static_cast<int>(device_type); return static_cast<int>(device_type);
...@@ -104,6 +105,14 @@ class DeviceEvent { ...@@ -104,6 +105,14 @@ class DeviceEvent {
event_finished_setter_[type_id_](this); event_finished_setter_[type_id_](this);
} }
// Resets this event by dispatching to the reset function stored in
// event_resetter_ for this event's device type (type_id_).
// The table entry is enforced to be non-null first; a missing entry is
// reported as a platform::errors::Unavailable error.
void Reset() {
PADDLE_ENFORCE_NOT_NULL(
event_resetter_[type_id_],
platform::errors::Unavailable(
"event_resetter_[%d] shall not be nullptr.", type_id_));
// The registered function receives the DeviceEvent itself.
event_resetter_[type_id_](this);
}
void Wait(const DeviceType& waiter_type, const DeviceContext* context) const { void Wait(const DeviceType& waiter_type, const DeviceContext* context) const {
auto waiter_idx = DeviceTypeToId(waiter_type); auto waiter_idx = DeviceTypeToId(waiter_type);
PADDLE_ENFORCE_NOT_NULL(event_waiter_[waiter_idx][type_id_], PADDLE_ENFORCE_NOT_NULL(event_waiter_[waiter_idx][type_id_],
...@@ -127,8 +136,9 @@ class DeviceEvent { ...@@ -127,8 +136,9 @@ class DeviceEvent {
static EventRecordFunction event_recorder_[MaxDeviceTypes]; static EventRecordFunction event_recorder_[MaxDeviceTypes];
static EventQueryFunction event_querier_[MaxDeviceTypes]; static EventQueryFunction event_querier_[MaxDeviceTypes];
static EventFinishFunction event_finisher_[MaxDeviceTypes]; static EventFinishFunction event_finisher_[MaxDeviceTypes];
static EventFinishFunction event_finished_setter_[MaxDeviceTypes]; static EventSetFinishedFunction event_finished_setter_[MaxDeviceTypes];
static EventWaitFunction event_waiter_[MaxDeviceTypes][MaxDeviceTypes]; static EventWaitFunction event_waiter_[MaxDeviceTypes][MaxDeviceTypes];
static EventResetFunction event_resetter_[MaxDeviceTypes];
template <DeviceType device_typ> template <DeviceType device_typ>
friend struct EventCreateFunctionRegisterer; friend struct EventCreateFunctionRegisterer;
...@@ -147,6 +157,9 @@ class DeviceEvent { ...@@ -147,6 +157,9 @@ class DeviceEvent {
template <DeviceType waiter_typ, DeviceType event_type> template <DeviceType waiter_typ, DeviceType event_type>
friend struct EventWaitFunctionRegisterer; friend struct EventWaitFunctionRegisterer;
template <DeviceType device_typ>
friend struct EventResetFunctionRegisterer;
}; };
/** /**
...@@ -287,12 +300,34 @@ struct EventWaitFunctionRegisterer : public framework::Registrar { ...@@ -287,12 +300,34 @@ struct EventWaitFunctionRegisterer : public framework::Registrar {
return 0; \ return 0; \
} }
// =============== Register for Reset ===============
// Registrar that installs the Reset callback for one device type.
// Constructing an instance stores `func` in the DeviceEvent::event_resetter_
// slot indexed by DeviceTypeToId(device_type).
template <DeviceType device_type>
struct EventResetFunctionRegisterer : public framework::Registrar {
  explicit EventResetFunctionRegisterer(EventResetFunction func) {
    const int slot = DeviceTypeToId(device_type);
    VLOG(3) << "register event_resetter with type_id :" << slot;
    DeviceEvent::event_resetter_[slot] = func;
  }
};
// REGISTER_EVENT_RESET_FUNCTION(device_type, func)
//
// Registers `func` as the Reset callback for `device_type`. The expansion:
//   1. statically asserts that the macro is used in the global namespace;
//   2. defines a file-local static EventResetFunctionRegisterer<device_type>
//      constructed with `func`, whose constructor fills the
//      DeviceEvent::event_resetter_ table;
//   3. defines TouchDeviceEventReset<device_type>(), which calls Touch() on
//      that registrar and returns 0. USE_EVENT declares this function extern
//      and calls it to initialize a static variable.
#define REGISTER_EVENT_RESET_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_resetter__##device_type, \
"REGISTER_EVENT_RESET_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventResetFunctionRegisterer<device_type> \
__reg_event_resetter_##device_type##__(func); \
int TouchDeviceEventReset##device_type() { \
__reg_event_resetter_##device_type##__.Touch(); \
return 0; \
}
#define USE_EVENT(device_type) \ #define USE_EVENT(device_type) \
extern int TouchDeviceEventCreate##device_type(); \ extern int TouchDeviceEventCreate##device_type(); \
extern int TouchDeviceEventRecord##device_type(); \ extern int TouchDeviceEventRecord##device_type(); \
extern int TouchDeviceEventQuery##device_type(); \ extern int TouchDeviceEventQuery##device_type(); \
extern int TouchDeviceEventFinish##device_type(); \ extern int TouchDeviceEventFinish##device_type(); \
extern int TouchDeviceEventSetFinished##device_type(); \ extern int TouchDeviceEventSetFinished##device_type(); \
extern int TouchDeviceEventReset##device_type(); \
UNUSED static int use_event_creator_##device_type = \ UNUSED static int use_event_creator_##device_type = \
TouchDeviceEventCreate##device_type(); \ TouchDeviceEventCreate##device_type(); \
UNUSED static int use_event_recorder_##device_type = \ UNUSED static int use_event_recorder_##device_type = \
...@@ -302,7 +337,9 @@ struct EventWaitFunctionRegisterer : public framework::Registrar { ...@@ -302,7 +337,9 @@ struct EventWaitFunctionRegisterer : public framework::Registrar {
UNUSED static int use_event_finisher_##device_type = \ UNUSED static int use_event_finisher_##device_type = \
TouchDeviceEventFinish##device_type(); \ TouchDeviceEventFinish##device_type(); \
UNUSED static int use_event_finished_setter_##device_type = \ UNUSED static int use_event_finished_setter_##device_type = \
TouchDeviceEventSetFinished##device_type(); TouchDeviceEventSetFinished##device_type(); \
UNUSED static int use_event_resetter_##device_type = \
TouchDeviceEventReset##device_type();
#define USE_EVENT_WAIT(waiter_type, event_type) \ #define USE_EVENT_WAIT(waiter_type, event_type) \
extern int TouchDeviceEventWait##waiter_type##event_type(); \ extern int TouchDeviceEventWait##waiter_type##event_type(); \
......
...@@ -95,6 +95,10 @@ void DeviceEventSetFinishedCUDA(const DeviceEvent* event) { ...@@ -95,6 +95,10 @@ void DeviceEventSetFinishedCUDA(const DeviceEvent* event) {
// do nothing // do nothing
} }
// Reset hook registered for kCUDA events. Intentionally a no-op: `event` is
// neither read nor modified.
// NOTE(review): presumably no explicit reset is needed because a CUDA event
// can simply be recorded again -- confirm against the CUDA runtime docs.
void EventResetCUDA(const DeviceEvent* event) {
// do nothing
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -110,4 +114,5 @@ REGISTER_EVENT_WAIT_FUNCTION(kCUDA, kCUDA, ...@@ -110,4 +114,5 @@ REGISTER_EVENT_WAIT_FUNCTION(kCUDA, kCUDA,
paddle::platform::DeviceEventCUDAWaitCUDA) paddle::platform::DeviceEventCUDAWaitCUDA)
REGISTER_EVENT_WAIT_FUNCTION(kCPU, kCUDA, REGISTER_EVENT_WAIT_FUNCTION(kCPU, kCUDA,
paddle::platform::DeviceEventCPUWaitCUDA) paddle::platform::DeviceEventCPUWaitCUDA)
REGISTER_EVENT_RESET_FUNCTION(kCUDA, paddle::platform::EventResetCUDA)
#endif #endif
...@@ -84,4 +84,12 @@ TEST(DeviceEvent, CPU) { ...@@ -84,4 +84,12 @@ TEST(DeviceEvent, CPU) {
event.SetFininshed(); event.SetFininshed();
bool status = event.Query(); bool status = event.Query();
ASSERT_EQ(status, true); ASSERT_EQ(status, true);
// Record() on an already finished event must be accepted and must move the
// event back to a not-finished state.
event.Record(context);
status = event.Query();
ASSERT_EQ(status, false);  // SCHEDULED
// Re-query after Reset(): asserting on the value read before the reset would
// pass regardless of what Reset() does.
event.Reset();
status = event.Query();
ASSERT_EQ(status, false);  // INITIALIZED
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册