未验证 提交 383a08e1 编写于 作者: L Leo Chen 提交者: GitHub

fix cross step sync problem on npu (#50517)

上级 7cc47a1d
...@@ -85,6 +85,9 @@ class DeviceEvent { ...@@ -85,6 +85,9 @@ class DeviceEvent {
event_recorder_[type_id_], event_recorder_[type_id_],
platform::errors::Unavailable( platform::errors::Unavailable(
"event_recorder_[%d] shall not be nullptr.", type_id_)); "event_recorder_[%d] shall not be nullptr.", type_id_));
if (!recorded_) {
recorded_ = true;
}
event_recorder_[type_id_](this, dev_ctx); event_recorder_[type_id_](this, dev_ctx);
} }
...@@ -93,6 +96,10 @@ class DeviceEvent { ...@@ -93,6 +96,10 @@ class DeviceEvent {
event_querier_[type_id_], event_querier_[type_id_],
platform::errors::Unavailable( platform::errors::Unavailable(
"event_querier_[%d] shall not be nullptr.", type_id_)); "event_querier_[%d] shall not be nullptr.", type_id_));
if (!recorded_) {
VLOG(4) << "Event " << this << " is not recorded yet, and skip query!";
return true;
}
return event_querier_[type_id_](this); return event_querier_[type_id_](this);
} }
...@@ -127,6 +134,10 @@ class DeviceEvent { ...@@ -127,6 +134,10 @@ class DeviceEvent {
"event_waiter_[%d][%d] shall not be nullptr.", "event_waiter_[%d][%d] shall not be nullptr.",
waiter_idx, waiter_idx,
type_id_)); type_id_));
if (!recorded_) {
VLOG(4) << "Event " << this << " is not recorded yet, and skip wait!";
return;
}
event_waiter_[waiter_idx][type_id_](this, context); event_waiter_[waiter_idx][type_id_](this, context);
} }
...@@ -140,6 +151,14 @@ class DeviceEvent { ...@@ -140,6 +151,14 @@ class DeviceEvent {
int type_id_; int type_id_;
unsigned int flag_; unsigned int flag_;
// NOTE(chenruibiao): In cross-step stream synchronization, an event may be
// recorded in the first step and waited in the second step. So, in the first
// step, the WaitEvent may be called without RecordEvent.
// On cuda device, it is ok to wait event that is not recorded yet;
// while on npu device, it results in error.
// So, we add flag recorded_ to handle this case uniformly.
bool recorded_{false};
static EventCreateFunction event_creator_[MaxDeviceTypes]; static EventCreateFunction event_creator_[MaxDeviceTypes];
static EventRecordFunction event_recorder_[MaxDeviceTypes]; static EventRecordFunction event_recorder_[MaxDeviceTypes];
static EventQueryFunction event_querier_[MaxDeviceTypes]; static EventQueryFunction event_querier_[MaxDeviceTypes];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册