未验证 提交 8c73c1b5 编写于 作者: A Aurelius84 提交者: GitHub

Support Reset for DeviceEvent (#35443)

* Support Reset for DeviceEvent

* fix code

* add more unittest
上级 c2f76b0a
......@@ -23,8 +23,9 @@ EventCreateFunction DeviceEvent::event_creator_[MaxDeviceTypes];
EventRecordFunction DeviceEvent::event_recorder_[MaxDeviceTypes];
EventQueryFunction DeviceEvent::event_querier_[MaxDeviceTypes];
EventFinishFunction DeviceEvent::event_finisher_[MaxDeviceTypes];
EventFinishFunction DeviceEvent::event_finished_setter_[MaxDeviceTypes];
EventSetFinishedFunction DeviceEvent::event_finished_setter_[MaxDeviceTypes];
EventWaitFunction DeviceEvent::event_waiter_[MaxDeviceTypes][MaxDeviceTypes];
EventResetFunction DeviceEvent::event_resetter_[MaxDeviceTypes];
/*
* Generate flag used to create event on all sorts of equipment.
......@@ -60,9 +61,20 @@ void DeviceEventRecordCPU(DeviceEvent* event, const DeviceContext* context) {
auto* wrapper = static_cast<CPUDeviceEventWrapper*>(event->GetEvent().get());
std::unique_lock<std::mutex> lock(wrapper->mutex_);
PADDLE_ENFORCE_NE(wrapper->status_.load(), EventStatus::SCHEDULED,
// NOTE: As for CudaEvent_t, it can be used to Record() repeatly. CudaEvent_t
// internally reset its status from finished into initialized.
// So we simulate the process here.
if (wrapper->status_.load() == EventStatus::SUCCESS) {
VLOG(3) << "Found EventStatus is SUCCESS before RecordCPU. Reset it into "
"INITIALIZED.";
wrapper->status_ = EventStatus::INITIALIZED;
}
PADDLE_ENFORCE_LT(
wrapper->status_.load(), EventStatus::SCHEDULED,
platform::errors::PreconditionNotMet(
"EventStatus shall be not SCHEDULED before Record()"));
"EventStatus shall be not SCHEDULED before Record(), but received %s",
wrapper->status_.load()));
if (wrapper->status_ == EventStatus::INITIALIZED) {
wrapper->status_ = EventStatus::SCHEDULED;
}
......@@ -104,6 +116,12 @@ void EventSetFinishedCPU(const DeviceEvent* event) {
wrapper->cv_completed_.notify_all();
}
void EventResetCPU(const DeviceEvent* event) {
auto* wrapper = static_cast<CPUDeviceEventWrapper*>(event->GetEvent().get());
std::unique_lock<std::mutex> lock(wrapper->mutex_);
wrapper->status_ = EventStatus::INITIALIZED;
}
} // namespace platform
} // namespace paddle
......@@ -113,6 +131,7 @@ REGISTER_EVENT_RECORD_FUNCTION(kCPU, paddle::platform::DeviceEventRecordCPU)
REGISTER_EVENT_QUERY_FUNCTION(kCPU, paddle::platform::DeviceEventQueryCPU)
REGISTER_EVENT_FINISH_FUNCTION(kCPU, paddle::platform::DeviceEventFinishCPU)
REGISTER_EVENT_SET_FINISHED_FUNCTION(kCPU,
paddle::platform::EventSetFinishedCPU);
paddle::platform::EventSetFinishedCPU)
REGISTER_EVENT_WAIT_FUNCTION(kCPU, kCPU,
paddle::platform::DeviceEventCPUWaitCPU)
REGISTER_EVENT_RESET_FUNCTION(kCPU, paddle::platform::EventResetCPU)
......@@ -34,6 +34,7 @@ typedef bool (*EventQueryFunction)(const DeviceEvent*);
typedef void (*EventFinishFunction)(const DeviceEvent*);
typedef void (*EventSetFinishedFunction)(const DeviceEvent*);
typedef void (*EventWaitFunction)(const DeviceEvent*, const DeviceContext*);
typedef void (*EventResetFunction)(const DeviceEvent*);
inline int DeviceTypeToId(const DeviceType& device_type) {
return static_cast<int>(device_type);
......@@ -104,6 +105,14 @@ class DeviceEvent {
event_finished_setter_[type_id_](this);
}
void Reset() {
PADDLE_ENFORCE_NOT_NULL(
event_resetter_[type_id_],
platform::errors::Unavailable(
"event_resetter_[%d] shall not be nullptr.", type_id_));
event_resetter_[type_id_](this);
}
void Wait(const DeviceType& waiter_type, const DeviceContext* context) const {
auto waiter_idx = DeviceTypeToId(waiter_type);
PADDLE_ENFORCE_NOT_NULL(event_waiter_[waiter_idx][type_id_],
......@@ -127,8 +136,9 @@ class DeviceEvent {
static EventRecordFunction event_recorder_[MaxDeviceTypes];
static EventQueryFunction event_querier_[MaxDeviceTypes];
static EventFinishFunction event_finisher_[MaxDeviceTypes];
static EventFinishFunction event_finished_setter_[MaxDeviceTypes];
static EventSetFinishedFunction event_finished_setter_[MaxDeviceTypes];
static EventWaitFunction event_waiter_[MaxDeviceTypes][MaxDeviceTypes];
static EventResetFunction event_resetter_[MaxDeviceTypes];
template <DeviceType device_typ>
friend struct EventCreateFunctionRegisterer;
......@@ -147,6 +157,9 @@ class DeviceEvent {
template <DeviceType waiter_typ, DeviceType event_type>
friend struct EventWaitFunctionRegisterer;
template <DeviceType device_typ>
friend struct EventResetFunctionRegisterer;
};
/**
......@@ -287,12 +300,34 @@ struct EventWaitFunctionRegisterer : public framework::Registrar {
return 0; \
}
// =============== Register for Reset ===============
template <DeviceType device_type>
struct EventResetFunctionRegisterer : public framework::Registrar {
explicit EventResetFunctionRegisterer(EventResetFunction func) {
auto type_idx = DeviceTypeToId(device_type);
VLOG(3) << "register event_resetter with type_id :" << type_idx;
DeviceEvent::event_resetter_[type_idx] = func;
}
};
#define REGISTER_EVENT_RESET_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_resetter__##device_type, \
"REGISTER_EVENT_RESET_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventResetFunctionRegisterer<device_type> \
__reg_event_resetter_##device_type##__(func); \
int TouchDeviceEventReset##device_type() { \
__reg_event_resetter_##device_type##__.Touch(); \
return 0; \
}
#define USE_EVENT(device_type) \
extern int TouchDeviceEventCreate##device_type(); \
extern int TouchDeviceEventRecord##device_type(); \
extern int TouchDeviceEventQuery##device_type(); \
extern int TouchDeviceEventFinish##device_type(); \
extern int TouchDeviceEventSetFinished##device_type(); \
extern int TouchDeviceEventReset##device_type(); \
UNUSED static int use_event_creator_##device_type = \
TouchDeviceEventCreate##device_type(); \
UNUSED static int use_event_recorder_##device_type = \
......@@ -302,7 +337,9 @@ struct EventWaitFunctionRegisterer : public framework::Registrar {
UNUSED static int use_event_finisher_##device_type = \
TouchDeviceEventFinish##device_type(); \
UNUSED static int use_event_finished_setter_##device_type = \
TouchDeviceEventSetFinished##device_type();
TouchDeviceEventSetFinished##device_type(); \
UNUSED static int use_event_resetter_##device_type = \
TouchDeviceEventReset##device_type();
#define USE_EVENT_WAIT(waiter_type, event_type) \
extern int TouchDeviceEventWait##waiter_type##event_type(); \
......
......@@ -95,6 +95,10 @@ void DeviceEventSetFinishedCUDA(const DeviceEvent* event) {
// do nothing
}
void EventResetCUDA(const DeviceEvent* event) {
// do nothing
}
} // namespace platform
} // namespace paddle
......@@ -110,4 +114,5 @@ REGISTER_EVENT_WAIT_FUNCTION(kCUDA, kCUDA,
paddle::platform::DeviceEventCUDAWaitCUDA)
REGISTER_EVENT_WAIT_FUNCTION(kCPU, kCUDA,
paddle::platform::DeviceEventCPUWaitCUDA)
REGISTER_EVENT_RESET_FUNCTION(kCUDA, paddle::platform::EventResetCUDA)
#endif
......@@ -84,4 +84,12 @@ TEST(DeviceEvent, CPU) {
event.SetFininshed();
bool status = event.Query();
ASSERT_EQ(status, true);
// test for Record again
event.Record(context);
status = event.Query();
ASSERT_EQ(status, false); // SCHEDULED
event.Reset();
ASSERT_EQ(status, false); // INITIALIZED
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册