diff --git a/paddle/fluid/platform/device_event_base.cc b/paddle/fluid/platform/device_event_base.cc index 0cd1cff556b3a484982980ef60d9dd2006bca107..67fad3857f2c142870ad08a14a82210a76a48cb9 100644 --- a/paddle/fluid/platform/device_event_base.cc +++ b/paddle/fluid/platform/device_event_base.cc @@ -23,8 +23,9 @@ EventCreateFunction DeviceEvent::event_creator_[MaxDeviceTypes]; EventRecordFunction DeviceEvent::event_recorder_[MaxDeviceTypes]; EventQueryFunction DeviceEvent::event_querier_[MaxDeviceTypes]; EventFinishFunction DeviceEvent::event_finisher_[MaxDeviceTypes]; -EventFinishFunction DeviceEvent::event_finished_setter_[MaxDeviceTypes]; +EventSetFinishedFunction DeviceEvent::event_finished_setter_[MaxDeviceTypes]; EventWaitFunction DeviceEvent::event_waiter_[MaxDeviceTypes][MaxDeviceTypes]; +EventResetFunction DeviceEvent::event_resetter_[MaxDeviceTypes]; /* * Generate flag used to create event on all sorts of equipment. @@ -60,9 +61,20 @@ void DeviceEventRecordCPU(DeviceEvent* event, const DeviceContext* context) { auto* wrapper = static_cast(event->GetEvent().get()); std::unique_lock lock(wrapper->mutex_); - PADDLE_ENFORCE_NE(wrapper->status_.load(), EventStatus::SCHEDULED, - platform::errors::PreconditionNotMet( - "EventStatus shall be not SCHEDULED before Record()")); + // NOTE: As for CudaEvent_t, it can be used to Record() repeatly. CudaEvent_t + // internally reset its status from finished into initialized. + // So we simulate the process here. + if (wrapper->status_.load() == EventStatus::SUCCESS) { + VLOG(3) << "Found EventStatus is SUCCESS before RecordCPU. Reset it into " + "INITIALIZED."; + wrapper->status_ = EventStatus::INITIALIZED; + } + + PADDLE_ENFORCE_LT( + wrapper->status_.load(), EventStatus::SCHEDULED, + platform::errors::PreconditionNotMet( + "EventStatus shall be not SCHEDULED before Record(), but received %s", + wrapper->status_.load())); if (wrapper->status_ == EventStatus::INITIALIZED) { wrapper->status_ = EventStatus::SCHEDULED; } @@ -104,6 +116,12 @@ void EventSetFinishedCPU(const DeviceEvent* event) { wrapper->cv_completed_.notify_all(); } +void EventResetCPU(const DeviceEvent* event) { + auto* wrapper = static_cast(event->GetEvent().get()); + std::unique_lock lock(wrapper->mutex_); + wrapper->status_ = EventStatus::INITIALIZED; +} + } // namespace platform } // namespace paddle @@ -113,6 +131,7 @@ REGISTER_EVENT_RECORD_FUNCTION(kCPU, paddle::platform::DeviceEventRecordCPU) REGISTER_EVENT_QUERY_FUNCTION(kCPU, paddle::platform::DeviceEventQueryCPU) REGISTER_EVENT_FINISH_FUNCTION(kCPU, paddle::platform::DeviceEventFinishCPU) REGISTER_EVENT_SET_FINISHED_FUNCTION(kCPU, - paddle::platform::EventSetFinishedCPU); + paddle::platform::EventSetFinishedCPU) REGISTER_EVENT_WAIT_FUNCTION(kCPU, kCPU, paddle::platform::DeviceEventCPUWaitCPU) +REGISTER_EVENT_RESET_FUNCTION(kCPU, paddle::platform::EventResetCPU) diff --git a/paddle/fluid/platform/device_event_base.h b/paddle/fluid/platform/device_event_base.h index e2e1fdf29d51aafdcfa25de72e6a3c532befe5d9..e018de9577beb4799cd2ad0e7e811b166d2d12bf 100644 --- a/paddle/fluid/platform/device_event_base.h +++ b/paddle/fluid/platform/device_event_base.h @@ -34,6 +34,7 @@ typedef bool (*EventQueryFunction)(const DeviceEvent*); typedef void (*EventFinishFunction)(const DeviceEvent*); typedef void (*EventSetFinishedFunction)(const DeviceEvent*); typedef void (*EventWaitFunction)(const DeviceEvent*, const DeviceContext*); +typedef void (*EventResetFunction)(const DeviceEvent*); inline int DeviceTypeToId(const DeviceType& device_type) { return static_cast(device_type); @@ -104,6 +105,14 @@ class DeviceEvent { event_finished_setter_[type_id_](this); } + void Reset() { + PADDLE_ENFORCE_NOT_NULL( + event_resetter_[type_id_], + platform::errors::Unavailable( + "event_resetter_[%d] shall not be nullptr.", type_id_)); + event_resetter_[type_id_](this); + } + void Wait(const DeviceType& waiter_type, const DeviceContext* context) const { auto waiter_idx = DeviceTypeToId(waiter_type); PADDLE_ENFORCE_NOT_NULL(event_waiter_[waiter_idx][type_id_], @@ -127,8 +136,9 @@ class DeviceEvent { static EventRecordFunction event_recorder_[MaxDeviceTypes]; static EventQueryFunction event_querier_[MaxDeviceTypes]; static EventFinishFunction event_finisher_[MaxDeviceTypes]; - static EventFinishFunction event_finished_setter_[MaxDeviceTypes]; + static EventSetFinishedFunction event_finished_setter_[MaxDeviceTypes]; static EventWaitFunction event_waiter_[MaxDeviceTypes][MaxDeviceTypes]; + static EventResetFunction event_resetter_[MaxDeviceTypes]; template friend struct EventCreateFunctionRegisterer; @@ -147,6 +157,9 @@ class DeviceEvent { template friend struct EventWaitFunctionRegisterer; + + template + friend struct EventResetFunctionRegisterer; }; /** @@ -287,12 +300,34 @@ struct EventWaitFunctionRegisterer : public framework::Registrar { return 0; \ } +// =============== Register for Reset =============== +template +struct EventResetFunctionRegisterer : public framework::Registrar { + explicit EventResetFunctionRegisterer(EventResetFunction func) { + auto type_idx = DeviceTypeToId(device_type); + VLOG(3) << "register event_resetter with type_id :" << type_idx; + DeviceEvent::event_resetter_[type_idx] = func; + } +}; + +#define REGISTER_EVENT_RESET_FUNCTION(device_type, func) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_event_resetter__##device_type, \ + "REGISTER_EVENT_RESET_FUNCTION must be called in global namespace"); \ + static ::paddle::platform::EventResetFunctionRegisterer \ + __reg_event_resetter_##device_type##__(func); \ + int TouchDeviceEventReset##device_type() { \ + __reg_event_resetter_##device_type##__.Touch(); \ + return 0; \ + } + #define USE_EVENT(device_type) \ extern int TouchDeviceEventCreate##device_type(); \ extern int TouchDeviceEventRecord##device_type(); \ extern int TouchDeviceEventQuery##device_type(); \ extern int TouchDeviceEventFinish##device_type(); \ extern int TouchDeviceEventSetFinished##device_type(); \ + extern int TouchDeviceEventReset##device_type(); \ UNUSED static int use_event_creator_##device_type = \ TouchDeviceEventCreate##device_type(); \ UNUSED static int use_event_recorder_##device_type = \ @@ -302,7 +337,9 @@ struct EventWaitFunctionRegisterer : public framework::Registrar { UNUSED static int use_event_finisher_##device_type = \ TouchDeviceEventFinish##device_type(); \ UNUSED static int use_event_finished_setter_##device_type = \ - TouchDeviceEventSetFinished##device_type(); + TouchDeviceEventSetFinished##device_type(); \ + UNUSED static int use_event_resetter_##device_type = \ + TouchDeviceEventReset##device_type(); #define USE_EVENT_WAIT(waiter_type, event_type) \ extern int TouchDeviceEventWait##waiter_type##event_type(); \ diff --git a/paddle/fluid/platform/device_event_gpu.cc b/paddle/fluid/platform/device_event_gpu.cc index 252ee893bb2820d051a8f34a09563acf2f9b7d89..bc842ef9c74de8c69d79cd7128785913b6540549 100644 --- a/paddle/fluid/platform/device_event_gpu.cc +++ b/paddle/fluid/platform/device_event_gpu.cc @@ -95,6 +95,10 @@ void DeviceEventSetFinishedCUDA(const DeviceEvent* event) { // do nothing } +void EventResetCUDA(const DeviceEvent* event) { + // do nothing +} + } // namespace platform } // namespace paddle @@ -110,4 +114,5 @@ REGISTER_EVENT_WAIT_FUNCTION(kCUDA, kCUDA, paddle::platform::DeviceEventCUDAWaitCUDA) REGISTER_EVENT_WAIT_FUNCTION(kCPU, kCUDA, paddle::platform::DeviceEventCPUWaitCUDA) +REGISTER_EVENT_RESET_FUNCTION(kCUDA, paddle::platform::EventResetCUDA) #endif diff --git a/paddle/fluid/platform/device_event_test.cc b/paddle/fluid/platform/device_event_test.cc index b25f9772a6c2e93df9d38ce8c40464c4bded913e..a56d94b892e98cecb0b7b2c0d0bb67b9fee2bd29 100644 --- a/paddle/fluid/platform/device_event_test.cc +++ b/paddle/fluid/platform/device_event_test.cc @@ -84,4 +84,12 @@ TEST(DeviceEvent, CPU) { event.SetFininshed(); bool status = event.Query(); ASSERT_EQ(status, true); + + // test for Record again + event.Record(context); + status = event.Query(); + ASSERT_EQ(status, false); // SCHEDULED + + event.Reset(); + ASSERT_EQ(status, false); // INITIALIZED }