From 48bf7cbf13c41a67576237830f76e3806e8d6c12 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 27 Aug 2021 16:26:53 +0800 Subject: [PATCH] Polish DeviceEvent interface and Remove #ifdef in InterpreterCore (#35196) * add CPUDeiveEvent * Polish DeviceEvent code * Add DEVICE_EVENT_LIBS --- .../framework/new_executor/CMakeLists.txt | 2 +- .../framework/new_executor/interpretercore.cc | 76 ++--- .../framework/new_executor/interpretercore.h | 3 +- .../new_executor/new_executor_defs.h | 6 +- paddle/fluid/platform/CMakeLists.txt | 15 +- paddle/fluid/platform/device_event.cc | 27 -- paddle/fluid/platform/device_event.h | 277 +--------------- paddle/fluid/platform/device_event_base.cc | 92 ++++++ paddle/fluid/platform/device_event_base.h | 309 ++++++++++++++++++ paddle/fluid/platform/device_event_cpu.h | 51 +++ paddle/fluid/platform/device_event_gpu.cc | 38 ++- paddle/fluid/platform/device_event_test.cc | 37 ++- 12 files changed, 562 insertions(+), 371 deletions(-) delete mode 100644 paddle/fluid/platform/device_event.cc create mode 100644 paddle/fluid/platform/device_event_base.cc create mode 100644 paddle/fluid/platform/device_event_base.h create mode 100644 paddle/fluid/platform/device_event_cpu.h diff --git a/paddle/fluid/framework/new_executor/CMakeLists.txt b/paddle/fluid/framework/new_executor/CMakeLists.txt index 13962f9852..23f0e02bb3 100644 --- a/paddle/fluid/framework/new_executor/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/CMakeLists.txt @@ -2,7 +2,7 @@ cc_library(workqueue SRCS workqueue.cc) cc_library(interpretercore SRCS interpretercore.cc DEPS op_registry device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method - graph_to_program_pass variable_helper timer monitor workqueue device_event device_event_gpu) + graph_to_program_pass variable_helper timer monitor workqueue ${DEVICE_EVENT_LIBS}) cc_library(standalone_executor SRCS standalone_executor.cc DEPS interpretercore) cc_test(workqueue_test SRCS workqueue_test.cc DEPS workqueue) # cc_binary(standalone_executor_test SRCS standalone_executor_test.cc DEPS interpretercore standalone_executor operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler) diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 277560a5ec..0f2ad0ff33 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -12,16 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/framework/new_executor/interpretercore.h" -#include "paddle/fluid/framework/executor_gc_helper.h" -#include "paddle/fluid/framework/new_executor/interpretercore_gc_helper.h" - -#if defined(PADDLE_WITH_CUDA) -using ::paddle::platform::kCUDA; -USE_EVENT(kCUDA); -#endif #include +#include "paddle/fluid/framework/executor_gc_helper.h" +#include "paddle/fluid/framework/new_executor/interpretercore_gc_helper.h" + namespace paddle { namespace framework { @@ -74,27 +70,26 @@ std::vector ParseEventVarIds(const Instruction& cur_instr, } void AssociateInputWithEvents( - const std::vector& new_event_var_id, Instruction* next_instr, - std::map>* var_id2event, + const platform::Place& place, const std::vector& new_event_var_id, + Instruction* next_instr, + std::map>* var_id2event, bool is_sync) { -#ifdef PADDLE_WITH_CUDA for (auto var_id : new_event_var_id) { if (var_id2event->count(var_id) == 0) { - auto cuda_event = std::make_shared( - platform::get_cuda_flags(false, false, false)); - var_id2event->emplace(var_id, std::move(cuda_event)); + auto device_event = std::make_shared( + place, platform::get_cuda_flags(false, false, false)); + var_id2event->emplace(var_id, std::move(device_event)); } // Add events for next_instr.inputs next_instr->intput_events_.emplace_back(var_id, var_id2event->at(var_id), is_sync); } -#endif } void ParseDirectAndEventRunOps( - const std::vector& op_func_nodes, + const platform::Place& place, const std::vector& op_func_nodes, const std::vector& downstream_ops, size_t op_index, - std::map>* var_id2event, + std::map>* var_id2event, std::vector* instructions) { auto& op_func_type = op_func_nodes[op_index].type_; auto& cur_instr = instructions->at(op_index); @@ -119,8 +114,8 @@ void ParseDirectAndEventRunOps( bool is_sync = (op_func_nodes[next_op_id].type_ == OpFuncType::kQueueSync); - AssociateInputWithEvents(new_event_var_ids, &next_instr, var_id2event, - is_sync); + AssociateInputWithEvents(place, new_event_var_ids, &next_instr, + var_id2event, is_sync); if (is_sync) { // GPU -> CPU next_instruction.synchronize_run_.emplace_back(next_op_id); @@ -128,7 +123,6 @@ void ParseDirectAndEventRunOps( next_instruction.event_wait_run_.emplace_back(next_op_id); } } -#ifdef PADDLE_WITH_CUDA // Create events for these cross-stream vars VLOG(3) << cur_instr.kernel_func_.operator_base_->Type() << " event_var_ids.size: " << event_var_ids.size(); @@ -136,7 +130,6 @@ void ParseDirectAndEventRunOps( cur_instr.output_events_.emplace_back(var_id, var_id2event->at(var_id), false /*not used*/); } -#endif } } } // namespace @@ -263,12 +256,10 @@ void InterpreterCore::Convert() { } for (size_t i = 0; i < vec_instruction_.size(); ++i) { -#if defined(PADDLE_WITH_CUDA) - int device_type = static_cast(paddle::platform::DeviceType::CUDA); - paddle::platform::DeviceOption dev_opt( - device_type, BOOST_GET_CONST(platform::CUDAPlace, place_).device); - gc_event_.emplace_back(dev_opt); -#endif + // int device_type = static_cast(paddle::platform::DeviceType::CUDA); + // paddle::platform::DeviceOption dev_opt( + // device_type, BOOST_GET_CONST(platform::CUDAPlace, place_).device); + gc_event_.emplace_back(place_); std::vector vec_temp; for (auto& item : vec_instruction_[i].output_index_) { @@ -287,8 +278,8 @@ void InterpreterCore::Convert() { } } - ParseDirectAndEventRunOps(vec_func_list_, filter_next, i, &var_id2event_, - &vec_instruction_); + ParseDirectAndEventRunOps(place_, vec_func_list_, filter_next, i, + &var_id2event_, &vec_instruction_); // checkout ouput for (auto& item : vec_instruction_[i].output_index_) { @@ -466,7 +457,7 @@ void InterpreterCore::CheckGC(size_t instr_id, #if defined(PADDLE_WITH_CUDA) auto* dev_ctx = reinterpret_cast( platform::DeviceContextPool::Instance().Get(place)); - gc_event_[instr_id].Record(place, dev_ctx); + gc_event_[instr_id].Record(dev_ctx); gc_queue_->AddTask( [ container = garbages_.release(), event = &gc_event_[instr_id] ]() { while (!event->Query()) { @@ -483,7 +474,7 @@ void InterpreterCore::CheckGC(size_t instr_id, #if defined(PADDLE_WITH_CUDA) auto* dev_ctx = reinterpret_cast( platform::DeviceContextPool::Instance().Get(place)); - gc_event_[instr_id].Record(place, dev_ctx); + gc_event_[instr_id].Record(dev_ctx); gc_queue_->AddTask( [ container = garbages_.release(), event = &gc_event_[instr_id] ]() { while (!event->Query()) { @@ -857,34 +848,23 @@ void InterpreterCore::RecordEventInstruction(const Instruction& instruction, // If InterpreterCore in on CPUPlace, do nothing. if (platform::is_cpu_place(place_)) return; -#ifdef PADDLE_WITH_CUDA - const platform::CUDADeviceContext* dev_ctx = - reinterpret_cast( - instruction.dev_ctx_); for (auto& event : instruction.output_events_) { VLOG(3) << "Record event in out_var_id: " << event.var_id_; - event.event_->Record(*(dev_ctx->context()->Stream())); + event.event_->Record(instruction.dev_ctx_); } -#endif } void InterpreterCore::WaitOrSync(const std::vector& events, const platform::DeviceContext* dev_ctx) { -#ifdef PADDLE_WITH_CUDA - auto* cuda_dev_ctx = - reinterpret_cast(dev_ctx); - - for (auto& event : events) { - if (event.is_sync_) { - VLOG(3) << "host sync wait in_var_id " << event.var_id_; - event.event_->Synchronize(); + for (auto& event_iter : events) { + if (event_iter.is_sync_) { + VLOG(3) << "host sync wait in_var_id " << event_iter.var_id_; + event_iter.event_->Wait(platform::kCPU, dev_ctx); } else { - VLOG(3) << "stream async wait in_var_id " << event.var_id_; - cuda_dev_ctx->context()->Stream()->WaitEvent( - event.event_->GetRawCudaEvent()); + VLOG(3) << "stream async wait in_var_id " << event_iter.var_id_; + event_iter.event_->Wait(platform::kCUDA, dev_ctx); } } -#endif } void InterpreterCore::StreamWaitEventOrSync(const Instruction& instruction) { diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index de47f3cba8..652e4fff7c 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -26,7 +26,6 @@ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/device_event.h" -#include "paddle/fluid/platform/event.h" namespace paddle { namespace framework { @@ -101,7 +100,7 @@ class InterpreterCore { bool is_build_; std::vector feed_names_; - std::map> var_id2event_; + std::map> var_id2event_; std::vector gc_event_; std::unique_ptr garbages_; diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index e1fee63b70..e9697b3c82 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -19,6 +19,7 @@ #include #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/device_event_base.h" #include "paddle/fluid/platform/event.h" namespace paddle { @@ -56,11 +57,12 @@ struct NextInstruction { }; struct EventInter { - explicit EventInter(size_t var_id, std::shared_ptr event, + explicit EventInter(size_t var_id, + std::shared_ptr event, bool is_sync) : var_id_(var_id), event_(event), is_sync_(is_sync) {} size_t var_id_; - std::shared_ptr event_; + std::shared_ptr event_; bool is_sync_; }; diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index fab0909c01..d99f991911 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -151,19 +151,28 @@ endif() cc_test(init_test SRCS init_test.cc DEPS device_context) -cc_library(device_event SRCS device_event.cc DEPS place enforce device_context op_registry) -cc_library(device_event_gpu SRCS device_event_gpu.cc DEPS device_event) +# Manage all device event library +set(DEVICE_EVENT_LIBS) +cc_library(device_event_base SRCS device_event_base.cc DEPS place enforce device_context op_registry) +set(DEVICE_EVENT_LIBS device_event_base CACHE INTERNAL "device event libs") if(WITH_GPU) + nv_library(device_event_gpu SRCS device_event_gpu.cc DEPS device_event_base) + set(DEVICE_EVENT_LIBS device_event_gpu CACHE INTERNAL "device event libs") + nv_test(device_event_test SRCS device_event_test.cc DEPS device_event_gpu) + nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info) nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda) nv_test(cudnn_desc_test SRCS cudnn_desc_test.cc DEPS dynload_cuda) nv_test(transform_test SRCS transform_test.cu DEPS memory place device_context) - nv_test(device_event_test SRCS device_event_test.cc DEPS device_event_gpu) endif() if(WITH_ROCM) + hip_library(device_event_gpu SRCS device_event_gpu.cc DEPS device_event_base) + set(DEVICE_EVENT_LIBS device_event_gpu CACHE INTERNAL "device event libs") + hip_test(device_event_test SRCS device_event_test.cc DEPS device_event_gpu) + hip_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info) hip_test(miopen_helper_test SRCS miopen_helper_test.cc DEPS dynload_cuda) hip_test(cudnn_desc_test SRCS cudnn_desc_test.cc DEPS dynload_cuda tensor) diff --git a/paddle/fluid/platform/device_event.cc b/paddle/fluid/platform/device_event.cc deleted file mode 100644 index 2c96de1637..0000000000 --- a/paddle/fluid/platform/device_event.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/platform/device_event.h" - -namespace paddle { -namespace platform { - -EventCreateFunction DeviceEvent::event_creator_[MaxDeviceTypes]; -EventRecordFunction DeviceEvent::event_recorder_[MaxDeviceTypes]; -EventQueryFunction DeviceEvent::event_querier_[MaxDeviceTypes]; -EventFinishFunction DeviceEvent::event_finisher_[MaxDeviceTypes]; -EventWaitFunction DeviceEvent::event_waiter_[MaxDeviceTypes][MaxDeviceTypes]; - -} // namespace platform -} // namespace paddle diff --git a/paddle/fluid/platform/device_event.h b/paddle/fluid/platform/device_event.h index c1f0acc00e..57f45a4016 100644 --- a/paddle/fluid/platform/device_event.h +++ b/paddle/fluid/platform/device_event.h @@ -11,267 +11,26 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#pragma once -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/place.h" - -namespace paddle { -namespace platform { - -class DeviceOption; -class DeviceEvent; - -constexpr int MaxDeviceTypes = - static_cast(platform::DeviceType::MAX_DEVICE_TYPES); - -typedef void (*EventCreateFunction)(DeviceEvent*, const DeviceOption&); -typedef void (*EventRecordFunction)(DeviceEvent*, const platform::Place&, - const DeviceContext*); -typedef bool (*EventQueryFunction)(const DeviceEvent*); -typedef void (*EventFinishFunction)(const DeviceEvent*); -typedef void (*EventWaitFunction)(const DeviceEvent*, DeviceContext*); - -inline int DeviceTypeToId(const DeviceType& device_type) { - return static_cast(device_type); -} - -class DeviceOption { - public: - explicit DeviceOption(int device_type) : device_type_(device_type) {} - - DeviceOption(int device_type, int device_id) - : device_type_(device_type), device_id_(device_id) {} - - int device_type() const { return device_type_; } - - int device_id() const { return device_id_; } - - private: - int device_type_; - int device_id_; -}; - -class DeviceEvent { - public: - explicit DeviceEvent(const DeviceOption& device_option) - : event_(), - type_(device_option.device_type()), - device_option_(device_option) { - PADDLE_ENFORCE_LT(type_, MaxDeviceTypes, - platform::errors::PreconditionNotMet( - "Required type < %d, but received type = %d", - MaxDeviceTypes, type_)); - PADDLE_ENFORCE_NOT_NULL( - event_creator_[type_], - platform::errors::Unavailable( - "event_creator_[%d] shall not be nullptr.", type_)); - event_creator_[type_](this, device_option_); - } - - ~DeviceEvent() {} - - void Record(const platform::Place& place, const DeviceContext* dev_ctx) { - PADDLE_ENFORCE_NOT_NULL( - event_recorder_[type_], - platform::errors::Unavailable( - "event_recorder_[%d] shall not be nullptr.", type_)); - event_recorder_[type_](this, place, dev_ctx); - } - - bool Query() { - PADDLE_ENFORCE_NOT_NULL( - event_querier_[type_], - platform::errors::Unavailable( - "event_querier_[%d] shall not be nullptr.", type_)); - return event_querier_[type_](this); - } - - void Finish() const { - PADDLE_ENFORCE_NOT_NULL( - event_finisher_[type_], - platform::errors::Unavailable( - "event_finisher_[%d] shall not be nullptr.", type_)); - event_finisher_[type_](this); - } - - void Wait(const DeviceType& waiter_type, DeviceContext* context) const { - auto waiter_idx = DeviceTypeToId(waiter_type); - PADDLE_ENFORCE_NOT_NULL( - event_waiter_[waiter_idx][type_], - platform::errors::Unavailable( - "event_waiter_[%d][%d] shall not be nullptr.", waiter_idx, type_)); - event_waiter_[waiter_idx][type_](this, context); - } - - void InitEvent(std::shared_ptr event) { event_ = event; } - - std::shared_ptr GetEvent() const { return event_; } - private: - std::shared_ptr event_; - int type_; - DeviceOption device_option_; - - static EventCreateFunction event_creator_[MaxDeviceTypes]; - static EventRecordFunction event_recorder_[MaxDeviceTypes]; - static EventQueryFunction event_querier_[MaxDeviceTypes]; - static EventFinishFunction event_finisher_[MaxDeviceTypes]; - static EventWaitFunction event_waiter_[MaxDeviceTypes][MaxDeviceTypes]; - - template - friend struct EventCreateFunctionRegisterer; - - template - friend struct EventRecordFunctionRegisterer; - - template - friend struct EventQueryFunctionRegisterer; - - template - friend struct EventFinishFunctionRegisterer; - - template - friend struct EventWaitFunctionRegisterer; -}; - -/** - * check if MACRO is used in GLOBAL NAMESPACE. +#pragma once +#include "paddle/fluid/platform/device_event_base.h" + +/* + * NOTE: Now we generate this file manually and will consider + * automatically generate it later. Just as 'paddle/fluid/pybind/pybind.h' + * for USE_OP from op_library macros, and + * `paddle/fluid/inference/paddle_inference_pass.h` + * for USE_PASS from pass_library. */ -#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \ - struct __test_global_namespace_##uniq_name##__ {}; \ - static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \ - __test_global_namespace_##uniq_name##__>::value, \ - msg) - -// =============== Register for Create =============== -template -struct EventCreateFunctionRegisterer : public framework::Registrar { - explicit EventCreateFunctionRegisterer(EventCreateFunction func) { - auto type_idx = DeviceTypeToId(device_type); - VLOG(3) << "register event_creator with type_id :" << type_idx; - DeviceEvent::event_creator_[type_idx] = func; - } -}; - -#define REGISTER_EVENT_CREATE_FUNCTION(device_type, func) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_event_creator__##device_type, \ - "REGISTER_EVENT_CREATE_FUNCTION must be called in global namespace"); \ - static ::paddle::platform::EventCreateFunctionRegisterer \ - __reg_event_create_##device_type##__(func); \ - int TouchDeviceEventCreate##device_type() { \ - __reg_event_create_##device_type##__.Touch(); \ - return 0; \ - } - -// =============== Register for Record =============== -template -struct EventRecordFunctionRegisterer : public framework::Registrar { - explicit EventRecordFunctionRegisterer(EventRecordFunction func) { - auto type_idx = DeviceTypeToId(device_type); - VLOG(3) << "register event_recorder with type_id :" << type_idx; - DeviceEvent::event_recorder_[type_idx] = func; - } -}; - -#define REGISTER_EVENT_RECORD_FUNCTION(device_type, func) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_event_recorder__##device_type, \ - "REGISTER_EVENT_RECORD_FUNCTION must be called in global namespace"); \ - static ::paddle::platform::EventRecordFunctionRegisterer \ - __reg_event_record_##device_type##__(func); \ - int TouchDeviceEventRecord##device_type() { \ - __reg_event_record_##device_type##__.Touch(); \ - return 0; \ - } - -// =============== Register for Query =============== -template -struct EventQueryFunctionRegisterer : public framework::Registrar { - explicit EventQueryFunctionRegisterer(EventQueryFunction func) { - auto type_idx = DeviceTypeToId(device_type); - VLOG(3) << "register event_querier with type_id :" << type_idx; - DeviceEvent::event_querier_[type_idx] = func; - } -}; - -#define REGISTER_EVENT_QUERY_FUNCTION(device_type, func) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_event_querier__##device_type, \ - "REGISTER_EVENT_QUERY_FUNCTION must be called in global namespace"); \ - static ::paddle::platform::EventQueryFunctionRegisterer \ - __reg_event_query_##device_type##__(func); \ - int TouchDeviceEventQuery##device_type() { \ - __reg_event_query_##device_type##__.Touch(); \ - return 0; \ - } - -// =============== Register for Finish =============== -template -struct EventFinishFunctionRegisterer : public framework::Registrar { - explicit EventFinishFunctionRegisterer(EventFinishFunction func) { - auto type_idx = DeviceTypeToId(device_type); - VLOG(3) << "register event_finisher with type_id :" << type_idx; - DeviceEvent::event_finisher_[type_idx] = func; - } -}; - -#define REGISTER_EVENT_FINISH_FUNCTION(device_type, func) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_event_finishier__##device_type, \ - "REGISTER_EVENT_FINISH_FUNCTION must be called in global namespace"); \ - static ::paddle::platform::EventFinishFunctionRegisterer \ - __reg_event_finish_##device_type##__(func); \ - int TouchDeviceEventFinish##device_type() { \ - __reg_event_finish_##device_type##__.Touch(); \ - return 0; \ - } - -// =============== Register for Wait =============== -template -struct EventWaitFunctionRegisterer : public framework::Registrar { - explicit EventWaitFunctionRegisterer(EventWaitFunction func) { - auto waiter_idx = DeviceTypeToId(waiter_type); - auto event_idx = DeviceTypeToId(event_type); - VLOG(3) << "register event_finisher with waiter_idx : " << waiter_idx - << ", event_idx : " << event_idx; - DeviceEvent::event_waiter_[waiter_idx][event_idx] = func; - } -}; - -#define REGISTER_EVENT_WAIT_FUNCTION(waiter_type, event_type, func) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_event_waiter__##waiter_type##event_type, \ - "REGISTER_EVENT_WAIT_FUNCTION must be called in global namespace"); \ - static ::paddle::platform::EventWaitFunctionRegisterer \ - __reg_event_wait_##waiter_type##event_type##__(func); \ - int TouchDeviceEventWait##waiter_type##event_type() { \ - __reg_event_wait_##waiter_type##event_type##__.Touch(); \ - return 0; \ - } -#define USE_EVENT(device_type) \ - extern int TouchDeviceEventCreate##device_type(); \ - extern int TouchDeviceEventRecord##device_type(); \ - extern int TouchDeviceEventQuery##device_type(); \ - extern int TouchDeviceEventFinish##device_type(); \ - UNUSED static int use_event_creator_##device_type = \ - TouchDeviceEventCreate##device_type(); \ - UNUSED static int use_event_recorder_##device_type = \ - TouchDeviceEventRecord##device_type(); \ - UNUSED static int use_event_querier_##device_type = \ - TouchDeviceEventQuery##device_type(); \ - UNUSED static int use_event_finisher_##device_type = \ - TouchDeviceEventFinish##device_type(); +using ::paddle::platform::kCUDA; +using ::paddle::platform::kCPU; -#define USE_EVENT_WAIT(waiter_type, event_type) \ - extern int TouchDeviceEventWait##waiter_type##event_type(); \ - UNUSED static int use_event_waiter_##waiter_type##event_type = \ - TouchDeviceEventWait##waiter_type##event_type(); +USE_EVENT(kCPU) +USE_EVENT_WAIT(kCPU, kCPU) -} // namespace platform -} // namespace paddle +#ifdef PADDLE_WITH_CUDA +USE_EVENT(kCUDA); +USE_EVENT_WAIT(kCUDA, kCUDA) +USE_EVENT_WAIT(kCPU, kCUDA) +#endif diff --git a/paddle/fluid/platform/device_event_base.cc b/paddle/fluid/platform/device_event_base.cc new file mode 100644 index 0000000000..288052edcc --- /dev/null +++ b/paddle/fluid/platform/device_event_base.cc @@ -0,0 +1,92 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/device_event_base.h" +#include "paddle/fluid/platform/device_event_cpu.h" + +namespace paddle { +namespace platform { + +EventCreateFunction DeviceEvent::event_creator_[MaxDeviceTypes]; +EventRecordFunction DeviceEvent::event_recorder_[MaxDeviceTypes]; +EventQueryFunction DeviceEvent::event_querier_[MaxDeviceTypes]; +EventFinishFunction DeviceEvent::event_finisher_[MaxDeviceTypes]; +EventFinishFunction DeviceEvent::event_finished_setter_[MaxDeviceTypes]; +EventWaitFunction DeviceEvent::event_waiter_[MaxDeviceTypes][MaxDeviceTypes]; + +void DeviceEventCreateCPU(DeviceEvent* event, const platform::Place& place, + unsigned int flag) { + event->InitEvent(std::make_shared(place, flag)); +} + +void DeviceEventRecordCPU(DeviceEvent* event, const DeviceContext* context) { + auto* wrapper = static_cast(event->GetEvent().get()); + + std::unique_lock lock(wrapper->mutex_); + PADDLE_ENFORCE_NE(wrapper->status_.load(), EventStatus::SCHEDULED, + platform::errors::PreconditionNotMet( + "EventStatus shall be not SCHEDULED before Record()")); + if (wrapper->status_ == EventStatus::INITIALIZED) { + wrapper->status_ = EventStatus::SCHEDULED; + } +} + +bool DeviceEventQueryCPU(const DeviceEvent* event) { + auto* wrapper = static_cast(event->GetEvent().get()); + PADDLE_ENFORCE_NOT_NULL( + wrapper, platform::errors::PreconditionNotMet( + "Failed to dynamic_cast event into CPUDeviceEventWrapper.")); + + return wrapper->status_ == EventStatus::SUCCESS; +} + +void DeviceEventFinishCPU(const DeviceEvent* event) { + auto* wrapper = static_cast(event->GetEvent().get()); + + std::unique_lock lock(wrapper->mutex_); + while (wrapper->status_ != EventStatus::SUCCESS && + wrapper->status_ != EventStatus::FAILED) { + wrapper->cv_completed_.wait(lock); + } +} + +void DeviceEventCPUWaitCPU(const DeviceEvent* event, + const DeviceContext* context) { + DeviceEventFinishCPU(event); +} + +void EventSetFinishedCPU(const DeviceEvent* event) { + auto* wrapper = static_cast(event->GetEvent().get()); + std::unique_lock lock(wrapper->mutex_); + + PADDLE_ENFORCE_LE(wrapper->status_.load(), EventStatus::SCHEDULED, + platform::errors::PreconditionNotMet( + "EventStatus shall be INITIALIZED | SCHEDULED before " + "EventSetFinishedCPU()")); + wrapper->status_ = EventStatus::SUCCESS; + wrapper->cv_completed_.notify_all(); +} + +} // namespace platform +} // namespace paddle + +using ::paddle::platform::kCPU; +REGISTER_EVENT_CREATE_FUNCTION(kCPU, paddle::platform::DeviceEventCreateCPU) +REGISTER_EVENT_RECORD_FUNCTION(kCPU, paddle::platform::DeviceEventRecordCPU) +REGISTER_EVENT_QUERY_FUNCTION(kCPU, paddle::platform::DeviceEventQueryCPU) +REGISTER_EVENT_FINISH_FUNCTION(kCPU, paddle::platform::DeviceEventFinishCPU) +REGISTER_EVENT_SET_FINISHED_FUNCTION(kCPU, + paddle::platform::EventSetFinishedCPU); +REGISTER_EVENT_WAIT_FUNCTION(kCPU, kCPU, + paddle::platform::DeviceEventCPUWaitCPU) diff --git a/paddle/fluid/platform/device_event_base.h b/paddle/fluid/platform/device_event_base.h new file mode 100644 index 0000000000..d713a638af --- /dev/null +++ b/paddle/fluid/platform/device_event_base.h @@ -0,0 +1,309 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace platform { + +class DeviceOption; +class DeviceEvent; + +constexpr int MaxDeviceTypes = + static_cast(platform::DeviceType::MAX_DEVICE_TYPES); + +typedef void (*EventCreateFunction)(DeviceEvent*, const platform::Place&, + unsigned int flag); +typedef void (*EventRecordFunction)(DeviceEvent*, const DeviceContext*); +typedef bool (*EventQueryFunction)(const DeviceEvent*); +typedef void (*EventFinishFunction)(const DeviceEvent*); +typedef void (*EventSetFinishedFunction)(const DeviceEvent*); +typedef void (*EventWaitFunction)(const DeviceEvent*, const DeviceContext*); + +inline int DeviceTypeToId(const DeviceType& device_type) { + return static_cast(device_type); +} + +enum EventStatus { + INITIALIZED = 0, + SCHEDULED = 1, + SUCCESS = 2, + FAILED = 3, +}; + +class DeviceEvent { + public: + explicit DeviceEvent(const platform::Place& place, unsigned int flag = 0) + : event_(), place_(place), flag_(flag) { + type_id_ = DeviceTypeToId(platform::Place2DeviceType(place)); + PADDLE_ENFORCE_LT(type_id_, MaxDeviceTypes, + platform::errors::PreconditionNotMet( + "Required type < %d, but received type = %d", + MaxDeviceTypes, type_id_)); + // TODO(Aurelius84): only support CPU/CUDA, need consider XPU/NPU later + PADDLE_ENFORCE_LT(type_id_, 3, + platform::errors::Unavailable( + "Currently DeviceEvent do not support %s", place)); + PADDLE_ENFORCE_NOT_NULL( + event_creator_[type_id_], + platform::errors::Unavailable( + "event_creator_[%d] shall not be nullptr.", type_id_)); + event_creator_[type_id_](this, place, flag); + } + + ~DeviceEvent() {} + + void Record(const DeviceContext* dev_ctx) { + PADDLE_ENFORCE_NOT_NULL( + event_recorder_[type_id_], + platform::errors::Unavailable( + "event_recorder_[%d] shall not be nullptr.", type_id_)); + event_recorder_[type_id_](this, dev_ctx); + } + + bool Query() { + PADDLE_ENFORCE_NOT_NULL( + event_querier_[type_id_], + platform::errors::Unavailable( + "event_querier_[%d] shall not be nullptr.", type_id_)); + return event_querier_[type_id_](this); + } + + void Finish() const { + PADDLE_ENFORCE_NOT_NULL( + event_finisher_[type_id_], + platform::errors::Unavailable( + "event_finisher_[%d] shall not be nullptr.", type_id_)); + event_finisher_[type_id_](this); + } + + void SetFininshed() { + PADDLE_ENFORCE_NOT_NULL( + event_finished_setter_[type_id_], + platform::errors::Unavailable( + "event_finished_setter_[%d] shall not be nullptr.", type_id_)); + event_finished_setter_[type_id_](this); + } + + void Wait(const DeviceType& waiter_type, const DeviceContext* context) const { + auto waiter_idx = DeviceTypeToId(waiter_type); + PADDLE_ENFORCE_NOT_NULL(event_waiter_[waiter_idx][type_id_], + platform::errors::Unavailable( + "event_waiter_[%d][%d] shall not be nullptr.", + waiter_idx, type_id_)); + event_waiter_[waiter_idx][type_id_](this, context); + } + + void InitEvent(std::shared_ptr event) { event_ = event; } + + std::shared_ptr GetEvent() const { return event_; } + + private: + std::shared_ptr event_; + platform::Place place_; + int type_id_; + unsigned int flag_; + + static EventCreateFunction event_creator_[MaxDeviceTypes]; + static EventRecordFunction event_recorder_[MaxDeviceTypes]; + static EventQueryFunction event_querier_[MaxDeviceTypes]; + static EventFinishFunction event_finisher_[MaxDeviceTypes]; + static EventFinishFunction event_finished_setter_[MaxDeviceTypes]; + static EventWaitFunction event_waiter_[MaxDeviceTypes][MaxDeviceTypes]; + + template + friend struct EventCreateFunctionRegisterer; + + template + friend struct EventRecordFunctionRegisterer; + + template + friend struct EventQueryFunctionRegisterer; + + template + friend struct EventFinishFunctionRegisterer; + + template + friend struct EventSetFinishedFunctionRegisterer; + + template + friend struct EventWaitFunctionRegisterer; +}; + +/** + * check if MACRO is used in GLOBAL NAMESPACE. + */ +#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \ + struct __test_global_namespace_##uniq_name##__ {}; \ + static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \ + __test_global_namespace_##uniq_name##__>::value, \ + msg) + +// =============== Register for Create =============== +template +struct EventCreateFunctionRegisterer : public framework::Registrar { + explicit EventCreateFunctionRegisterer(EventCreateFunction func) { + auto type_idx = DeviceTypeToId(device_type); + VLOG(3) << "register event_creator with type_id :" << type_idx; + DeviceEvent::event_creator_[type_idx] = func; + } +}; + +#define REGISTER_EVENT_CREATE_FUNCTION(device_type, func) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_event_creator__##device_type, \ + "REGISTER_EVENT_CREATE_FUNCTION must be called in global namespace"); \ + static ::paddle::platform::EventCreateFunctionRegisterer \ + __reg_event_create_##device_type##__(func); \ + int TouchDeviceEventCreate##device_type() { \ + __reg_event_create_##device_type##__.Touch(); \ + return 0; \ + } + +// =============== Register for Record =============== +template +struct EventRecordFunctionRegisterer : public framework::Registrar { + explicit EventRecordFunctionRegisterer(EventRecordFunction func) { + auto type_idx = DeviceTypeToId(device_type); + VLOG(3) << "register event_recorder with type_id :" << type_idx; + DeviceEvent::event_recorder_[type_idx] = func; + } +}; + +#define REGISTER_EVENT_RECORD_FUNCTION(device_type, func) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_event_recorder__##device_type, \ + "REGISTER_EVENT_RECORD_FUNCTION must be called in global namespace"); \ + static ::paddle::platform::EventRecordFunctionRegisterer \ + __reg_event_record_##device_type##__(func); \ + int TouchDeviceEventRecord##device_type() { \ + __reg_event_record_##device_type##__.Touch(); \ + return 0; \ + } + +// =============== Register for Query =============== +template +struct EventQueryFunctionRegisterer : public framework::Registrar { + explicit EventQueryFunctionRegisterer(EventQueryFunction func) { + auto type_idx = DeviceTypeToId(device_type); + VLOG(3) << "register event_querier with type_id :" << type_idx; + DeviceEvent::event_querier_[type_idx] = func; + } +}; + +#define REGISTER_EVENT_QUERY_FUNCTION(device_type, func) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_event_querier__##device_type, \ + "REGISTER_EVENT_QUERY_FUNCTION must be called in global namespace"); \ + static ::paddle::platform::EventQueryFunctionRegisterer \ + __reg_event_query_##device_type##__(func); \ + int TouchDeviceEventQuery##device_type() { \ + __reg_event_query_##device_type##__.Touch(); \ + return 0; \ + } + +// =============== Register for Finish =============== +template +struct EventFinishFunctionRegisterer : public framework::Registrar { + explicit EventFinishFunctionRegisterer(EventFinishFunction func) { + auto type_idx = DeviceTypeToId(device_type); + VLOG(3) << "register event_finisher with type_id :" << type_idx; + DeviceEvent::event_finisher_[type_idx] = func; + } +}; + +#define REGISTER_EVENT_FINISH_FUNCTION(device_type, func) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_event_finishier__##device_type, \ + "REGISTER_EVENT_FINISH_FUNCTION must be called in global namespace"); \ + static ::paddle::platform::EventFinishFunctionRegisterer \ + __reg_event_finish_##device_type##__(func); \ + int TouchDeviceEventFinish##device_type() { \ + __reg_event_finish_##device_type##__.Touch(); \ + return 0; \ + } + +// =============== Register for SetFinished =============== +template +struct EventSetFinishedFunctionRegisterer : public framework::Registrar { + explicit EventSetFinishedFunctionRegisterer(EventSetFinishedFunction func) { + auto type_idx = DeviceTypeToId(device_type); + VLOG(3) << "register event_finished_setter with type_id :" << type_idx; + DeviceEvent::event_finished_setter_[type_idx] = func; + } +}; + +#define REGISTER_EVENT_SET_FINISHED_FUNCTION(device_type, func) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_event_finished_setter__##device_type, \ + "REGISTER_EVENT_FINISH_FUNCTION must be called in global namespace"); \ + static ::paddle::platform::EventSetFinishedFunctionRegisterer \ + __reg_event_finished_setter_##device_type##__(func); \ + int TouchDeviceEventSetFinished##device_type() { \ + __reg_event_finished_setter_##device_type##__.Touch(); \ + return 0; \ + } + +// =============== Register for Wait =============== +template +struct EventWaitFunctionRegisterer : public framework::Registrar { + explicit EventWaitFunctionRegisterer(EventWaitFunction func) { + auto waiter_idx = DeviceTypeToId(waiter_type); + auto event_idx = DeviceTypeToId(event_type); + VLOG(3) << "register event_finisher with waiter_idx : " << waiter_idx + << ", event_idx : " << event_idx; + DeviceEvent::event_waiter_[waiter_idx][event_idx] = func; + } +}; + +#define REGISTER_EVENT_WAIT_FUNCTION(waiter_type, event_type, func) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_event_waiter__##waiter_type##event_type, \ + "REGISTER_EVENT_WAIT_FUNCTION must be called in global namespace"); \ + static ::paddle::platform::EventWaitFunctionRegisterer \ + __reg_event_wait_##waiter_type##event_type##__(func); \ + int TouchDeviceEventWait##waiter_type##event_type() { \ + __reg_event_wait_##waiter_type##event_type##__.Touch(); \ + return 0; \ + } + +#define USE_EVENT(device_type) \ + extern int TouchDeviceEventCreate##device_type(); \ + extern int TouchDeviceEventRecord##device_type(); \ + extern int TouchDeviceEventQuery##device_type(); \ + extern int TouchDeviceEventFinish##device_type(); \ + extern int TouchDeviceEventSetFinished##device_type(); \ + UNUSED static int use_event_creator_##device_type = \ + TouchDeviceEventCreate##device_type(); \ + UNUSED static int use_event_recorder_##device_type = \ + TouchDeviceEventRecord##device_type(); \ + UNUSED static int use_event_querier_##device_type = \ + TouchDeviceEventQuery##device_type(); \ + UNUSED static int use_event_finisher_##device_type = \ + TouchDeviceEventFinish##device_type(); \ + UNUSED static int use_event_finished_setter_##device_type = \ + TouchDeviceEventSetFinished##device_type(); + +#define USE_EVENT_WAIT(waiter_type, event_type) \ + extern int TouchDeviceEventWait##waiter_type##event_type(); \ + UNUSED static int use_event_waiter_##waiter_type##event_type = \ + TouchDeviceEventWait##waiter_type##event_type(); + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device_event_cpu.h b/paddle/fluid/platform/device_event_cpu.h new file mode 100644 index 0000000000..b08323d7f1 --- /dev/null +++ b/paddle/fluid/platform/device_event_cpu.h @@ -0,0 +1,51 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include "paddle/fluid/platform/device_event_base.h" + +namespace paddle { +namespace platform { + +struct CPUDeviceEventWrapper { + explicit CPUDeviceEventWrapper(const platform::Place& place, + unsigned int flag = 0) { + PADDLE_ENFORCE_EQ( + platform::is_cpu_place(place), true, + platform::errors::PreconditionNotMet( + "Required device shall be CPUAPlace, but received %d. ", place)); + } + std::mutex mutex_; + std::condition_variable cv_completed_; + std::atomic status_; +}; + +void DeviceEventCreateCPU(DeviceEvent* event, const platform::Place& place); + +void DeviceEventRecordCPU(DeviceEvent* event, const platform::Place& place, + const DeviceContext* context); + +bool DeviceEventQueryCPU(const DeviceEvent* event); + +void DeviceEventFinishCPU(const DeviceEvent* event); + +void EventSetFinishedCPU(const DeviceEvent* event); + +void DeviceEventCPUWaitCPU(const DeviceEvent* event, DeviceContext* context); + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device_event_gpu.cc b/paddle/fluid/platform/device_event_gpu.cc index 86bcfdad5b..252ee893bb 100644 --- a/paddle/fluid/platform/device_event_gpu.cc +++ b/paddle/fluid/platform/device_event_gpu.cc @@ -12,38 +12,38 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/platform/device_event.h" +#include "paddle/fluid/platform/device_event_base.h" #include "paddle/fluid/platform/event.h" #ifdef PADDLE_WITH_CUDA namespace paddle { namespace platform { struct CUDADeviceEventWrapper { - explicit CUDADeviceEventWrapper(const DeviceOption& dev_opt) - : inner_event_() { + CUDADeviceEventWrapper(const platform::Place& place, unsigned int flag) + : inner_event_(flag) { PADDLE_ENFORCE_EQ( - dev_opt.device_type(), static_cast(DeviceType::CUDA), + platform::is_gpu_place(place), true, platform::errors::PreconditionNotMet( - "Required device type shall be CUDA, but received %d. ", - dev_opt.device_type())); + "Required device shall be CUDAPlace, but received %d. ", place)); + + device_id_ = BOOST_GET_CONST(platform::CUDAPlace, place).device; PADDLE_ENFORCE_GT( - dev_opt.device_id(), -1, + device_id_, -1, platform::errors::PreconditionNotMet( "Required DeviceOption.device_id > -1, but received %d. ", - dev_opt.device_id())); - device_id_ = dev_opt.device_id(); + device_id_)); } CudaEvent inner_event_; int device_id_; }; -void DeviceEventCreateCUDA(DeviceEvent* event, const DeviceOption& dev_opt) { - event->InitEvent(std::make_shared(dev_opt)); +void DeviceEventCreateCUDA(DeviceEvent* event, const platform::Place& place, + unsigned int flag) { + event->InitEvent(std::make_shared(place, flag)); } -void DeviceEventRecordCUDA(DeviceEvent* event, const platform::Place& place, - const DeviceContext* context) { +void DeviceEventRecordCUDA(DeviceEvent* event, const DeviceContext* context) { auto* wrapper = static_cast(event->GetEvent().get()); auto* cuda_dev_ctx = @@ -72,7 +72,8 @@ void DeviceEventFinishCUDA(const DeviceEvent* event) { wrapper->inner_event_.Synchronize(); } -void DeviceEventCUDAWaitCUDA(const DeviceEvent* event, DeviceContext* context) { +void DeviceEventCUDAWaitCUDA(const DeviceEvent* event, + const DeviceContext* context) { auto* wrapper = static_cast(event->GetEvent().get()); auto* cuda_dev_ctx = dynamic_cast(context); @@ -85,10 +86,15 @@ void DeviceEventCUDAWaitCUDA(const DeviceEvent* event, DeviceContext* context) { wrapper->inner_event_.GetRawCudaEvent()); } -void DeviceEventCPUWaitCUDA(const DeviceEvent* event, DeviceContext* context) { +void DeviceEventCPUWaitCUDA(const DeviceEvent* event, + const DeviceContext* context) { DeviceEventFinishCUDA(event); } +void DeviceEventSetFinishedCUDA(const DeviceEvent* event) { + // do nothing +} + } // namespace platform } // namespace paddle @@ -98,6 +104,8 @@ REGISTER_EVENT_CREATE_FUNCTION(kCUDA, paddle::platform::DeviceEventCreateCUDA) REGISTER_EVENT_RECORD_FUNCTION(kCUDA, paddle::platform::DeviceEventRecordCUDA) REGISTER_EVENT_QUERY_FUNCTION(kCUDA, paddle::platform::DeviceEventQueryCUDA) REGISTER_EVENT_FINISH_FUNCTION(kCUDA, paddle::platform::DeviceEventFinishCUDA) +REGISTER_EVENT_SET_FINISHED_FUNCTION( + kCUDA, paddle::platform::DeviceEventSetFinishedCUDA) REGISTER_EVENT_WAIT_FUNCTION(kCUDA, kCUDA, paddle::platform::DeviceEventCUDAWaitCUDA) REGISTER_EVENT_WAIT_FUNCTION(kCPU, kCUDA, diff --git a/paddle/fluid/platform/device_event_test.cc b/paddle/fluid/platform/device_event_test.cc index 04288599c4..b25f9772a6 100644 --- a/paddle/fluid/platform/device_event_test.cc +++ b/paddle/fluid/platform/device_event_test.cc @@ -16,35 +16,30 @@ #include "glog/logging.h" #include "gtest/gtest.h" -#ifdef PADDLE_WITH_CUDA -#include using ::paddle::platform::kCUDA; using ::paddle::platform::kCPU; -USE_EVENT(kCUDA); -USE_EVENT_WAIT(kCUDA, kCUDA) -USE_EVENT_WAIT(kCPU, kCUDA) + +using paddle::platform::DeviceEvent; +using paddle::platform::DeviceContextPool; + +#ifdef PADDLE_WITH_CUDA +#include TEST(DeviceEvent, CUDA) { VLOG(1) << "In Test"; using paddle::platform::CUDAPlace; - using paddle::platform::DeviceOption; - using paddle::platform::DeviceEvent; - using paddle::platform::DeviceContextPool; - using paddle::platform::DeviceType; auto& pool = DeviceContextPool::Instance(); auto place = CUDAPlace(0); auto* context = static_cast(pool.Get(place)); - int device_type = static_cast(DeviceType::CUDA); - DeviceOption dev_opt(device_type, place.device); ASSERT_NE(context, nullptr); // case 1. test for event_creator - DeviceEvent event(dev_opt); + DeviceEvent event(place); ASSERT_NE(event.GetEvent().get(), nullptr); // case 2. test for event_recorder - event.Record(place, context); + event.Record(context); bool status = event.Query(); ASSERT_EQ(status, false); // case 3. test for event_finisher @@ -59,7 +54,7 @@ TEST(DeviceEvent, CUDA) { cudaMalloc(reinterpret_cast(&dst_fp32), size); cudaMemcpyAsync(dst_fp32, src_fp32, size, cudaMemcpyHostToDevice, context->stream()); - event.Record(place, context); // step 1. record it + event.Record(context); // step 1. record it status = event.Query(); ASSERT_EQ(status, false); @@ -76,3 +71,17 @@ TEST(DeviceEvent, CUDA) { cudaFreeHost(src_fp32); } #endif + +TEST(DeviceEvent, CPU) { + using paddle::platform::CPUPlace; + auto place = CPUPlace(); + DeviceEvent event(place); + auto& pool = DeviceContextPool::Instance(); + auto* context = pool.Get(place); + + // TODO(Aurelius84): All DeviceContext should has Record/Wait + event.Record(context); + event.SetFininshed(); + bool status = event.Query(); + ASSERT_EQ(status, true); +} -- GitLab