diff --git a/paddle/fluid/platform/device_event_base.h b/paddle/fluid/platform/device_event_base.h index d0458dcb9e4e46ba4836e862264db0d36b048b31..f56688de09a326e6846d8d423fbaf66e9c0f6e9f 100644 --- a/paddle/fluid/platform/device_event_base.h +++ b/paddle/fluid/platform/device_event_base.h @@ -85,6 +85,9 @@ class DeviceEvent { event_recorder_[type_id_], platform::errors::Unavailable( "event_recorder_[%d] shall not be nullptr.", type_id_)); + if (!recorded_) { + recorded_ = true; + } event_recorder_[type_id_](this, dev_ctx); } @@ -93,6 +96,10 @@ class DeviceEvent { event_querier_[type_id_], platform::errors::Unavailable( "event_querier_[%d] shall not be nullptr.", type_id_)); + if (!recorded_) { + VLOG(4) << "Event " << this << " is not recorded yet, and skip query!"; + return true; + } return event_querier_[type_id_](this); } @@ -127,6 +134,10 @@ class DeviceEvent { "event_waiter_[%d][%d] shall not be nullptr.", waiter_idx, type_id_)); + if (!recorded_) { + VLOG(4) << "Event " << this << " is not recorded yet, and skip wait!"; + return; + } event_waiter_[waiter_idx][type_id_](this, context); } @@ -140,6 +151,14 @@ class DeviceEvent { int type_id_; unsigned int flag_; + // NOTE(chenruibiao): In cross-step stream synchronization, an event may be + // recorded in the first step and waited on in the second step. So, in the + // first step, WaitEvent may be called without a preceding RecordEvent. + // On CUDA devices, it is OK to wait on an event that is not recorded yet, + // while on NPU devices, doing so results in an error. + // So, we add the flag recorded_ to handle this case uniformly. + bool recorded_{false}; + static EventCreateFunction event_creator_[MaxDeviceTypes]; static EventRecordFunction event_recorder_[MaxDeviceTypes]; static EventQueryFunction event_querier_[MaxDeviceTypes];