未验证 提交 48bf7cbf 编写于 作者: A Aurelius84 提交者: GitHub

Polish DeviceEvent interface and Remove #ifdef in InterpreterCore (#35196)

* add CPUDeviceEvent

* Polish DeviceEvent code

* Add DEVICE_EVENT_LIBS
上级 7272526b
...@@ -2,7 +2,7 @@ cc_library(workqueue SRCS workqueue.cc) ...@@ -2,7 +2,7 @@ cc_library(workqueue SRCS workqueue.cc)
cc_library(interpretercore SRCS interpretercore.cc DEPS op_registry cc_library(interpretercore SRCS interpretercore.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor workqueue device_event device_event_gpu) graph_to_program_pass variable_helper timer monitor workqueue ${DEVICE_EVENT_LIBS})
cc_library(standalone_executor SRCS standalone_executor.cc DEPS interpretercore) cc_library(standalone_executor SRCS standalone_executor.cc DEPS interpretercore)
cc_test(workqueue_test SRCS workqueue_test.cc DEPS workqueue) cc_test(workqueue_test SRCS workqueue_test.cc DEPS workqueue)
# cc_binary(standalone_executor_test SRCS standalone_executor_test.cc DEPS interpretercore standalone_executor operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler) # cc_binary(standalone_executor_test SRCS standalone_executor_test.cc DEPS interpretercore standalone_executor operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)
...@@ -12,16 +12,12 @@ ...@@ -12,16 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/new_executor/interpretercore.h" #include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/new_executor/interpretercore_gc_helper.h"
#if defined(PADDLE_WITH_CUDA)
using ::paddle::platform::kCUDA;
USE_EVENT(kCUDA);
#endif
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/new_executor/interpretercore_gc_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -74,27 +70,26 @@ std::vector<size_t> ParseEventVarIds(const Instruction& cur_instr, ...@@ -74,27 +70,26 @@ std::vector<size_t> ParseEventVarIds(const Instruction& cur_instr,
} }
void AssociateInputWithEvents( void AssociateInputWithEvents(
const std::vector<size_t>& new_event_var_id, Instruction* next_instr, const platform::Place& place, const std::vector<size_t>& new_event_var_id,
std::map<size_t, std::shared_ptr<platform::CudaEvent>>* var_id2event, Instruction* next_instr,
std::map<size_t, std::shared_ptr<platform::DeviceEvent>>* var_id2event,
bool is_sync) { bool is_sync) {
#ifdef PADDLE_WITH_CUDA
for (auto var_id : new_event_var_id) { for (auto var_id : new_event_var_id) {
if (var_id2event->count(var_id) == 0) { if (var_id2event->count(var_id) == 0) {
auto cuda_event = std::make_shared<platform::CudaEvent>( auto device_event = std::make_shared<platform::DeviceEvent>(
platform::get_cuda_flags(false, false, false)); place, platform::get_cuda_flags(false, false, false));
var_id2event->emplace(var_id, std::move(cuda_event)); var_id2event->emplace(var_id, std::move(device_event));
} }
// Add events for next_instr.inputs // Add events for next_instr.inputs
next_instr->intput_events_.emplace_back(var_id, var_id2event->at(var_id), next_instr->intput_events_.emplace_back(var_id, var_id2event->at(var_id),
is_sync); is_sync);
} }
#endif
} }
void ParseDirectAndEventRunOps( void ParseDirectAndEventRunOps(
const std::vector<OpFuncNode>& op_func_nodes, const platform::Place& place, const std::vector<OpFuncNode>& op_func_nodes,
const std::vector<size_t>& downstream_ops, size_t op_index, const std::vector<size_t>& downstream_ops, size_t op_index,
std::map<size_t, std::shared_ptr<platform::CudaEvent>>* var_id2event, std::map<size_t, std::shared_ptr<platform::DeviceEvent>>* var_id2event,
std::vector<Instruction>* instructions) { std::vector<Instruction>* instructions) {
auto& op_func_type = op_func_nodes[op_index].type_; auto& op_func_type = op_func_nodes[op_index].type_;
auto& cur_instr = instructions->at(op_index); auto& cur_instr = instructions->at(op_index);
...@@ -119,8 +114,8 @@ void ParseDirectAndEventRunOps( ...@@ -119,8 +114,8 @@ void ParseDirectAndEventRunOps(
bool is_sync = bool is_sync =
(op_func_nodes[next_op_id].type_ == OpFuncType::kQueueSync); (op_func_nodes[next_op_id].type_ == OpFuncType::kQueueSync);
AssociateInputWithEvents(new_event_var_ids, &next_instr, var_id2event, AssociateInputWithEvents(place, new_event_var_ids, &next_instr,
is_sync); var_id2event, is_sync);
if (is_sync) { // GPU -> CPU if (is_sync) { // GPU -> CPU
next_instruction.synchronize_run_.emplace_back(next_op_id); next_instruction.synchronize_run_.emplace_back(next_op_id);
...@@ -128,7 +123,6 @@ void ParseDirectAndEventRunOps( ...@@ -128,7 +123,6 @@ void ParseDirectAndEventRunOps(
next_instruction.event_wait_run_.emplace_back(next_op_id); next_instruction.event_wait_run_.emplace_back(next_op_id);
} }
} }
#ifdef PADDLE_WITH_CUDA
// Create events for these cross-stream vars // Create events for these cross-stream vars
VLOG(3) << cur_instr.kernel_func_.operator_base_->Type() VLOG(3) << cur_instr.kernel_func_.operator_base_->Type()
<< " event_var_ids.size: " << event_var_ids.size(); << " event_var_ids.size: " << event_var_ids.size();
...@@ -136,7 +130,6 @@ void ParseDirectAndEventRunOps( ...@@ -136,7 +130,6 @@ void ParseDirectAndEventRunOps(
cur_instr.output_events_.emplace_back(var_id, var_id2event->at(var_id), cur_instr.output_events_.emplace_back(var_id, var_id2event->at(var_id),
false /*not used*/); false /*not used*/);
} }
#endif
} }
} }
} // namespace } // namespace
...@@ -263,12 +256,10 @@ void InterpreterCore::Convert() { ...@@ -263,12 +256,10 @@ void InterpreterCore::Convert() {
} }
for (size_t i = 0; i < vec_instruction_.size(); ++i) { for (size_t i = 0; i < vec_instruction_.size(); ++i) {
#if defined(PADDLE_WITH_CUDA) // int device_type = static_cast<int>(paddle::platform::DeviceType::CUDA);
int device_type = static_cast<int>(paddle::platform::DeviceType::CUDA); // paddle::platform::DeviceOption dev_opt(
paddle::platform::DeviceOption dev_opt( // device_type, BOOST_GET_CONST(platform::CUDAPlace, place_).device);
device_type, BOOST_GET_CONST(platform::CUDAPlace, place_).device); gc_event_.emplace_back(place_);
gc_event_.emplace_back(dev_opt);
#endif
std::vector<size_t> vec_temp; std::vector<size_t> vec_temp;
for (auto& item : vec_instruction_[i].output_index_) { for (auto& item : vec_instruction_[i].output_index_) {
...@@ -287,8 +278,8 @@ void InterpreterCore::Convert() { ...@@ -287,8 +278,8 @@ void InterpreterCore::Convert() {
} }
} }
ParseDirectAndEventRunOps(vec_func_list_, filter_next, i, &var_id2event_, ParseDirectAndEventRunOps(place_, vec_func_list_, filter_next, i,
&vec_instruction_); &var_id2event_, &vec_instruction_);
// checkout ouput // checkout ouput
for (auto& item : vec_instruction_[i].output_index_) { for (auto& item : vec_instruction_[i].output_index_) {
...@@ -466,7 +457,7 @@ void InterpreterCore::CheckGC(size_t instr_id, ...@@ -466,7 +457,7 @@ void InterpreterCore::CheckGC(size_t instr_id,
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA)
auto* dev_ctx = reinterpret_cast<platform::CUDADeviceContext*>( auto* dev_ctx = reinterpret_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place)); platform::DeviceContextPool::Instance().Get(place));
gc_event_[instr_id].Record(place, dev_ctx); gc_event_[instr_id].Record(dev_ctx);
gc_queue_->AddTask( gc_queue_->AddTask(
[ container = garbages_.release(), event = &gc_event_[instr_id] ]() { [ container = garbages_.release(), event = &gc_event_[instr_id] ]() {
while (!event->Query()) { while (!event->Query()) {
...@@ -483,7 +474,7 @@ void InterpreterCore::CheckGC(size_t instr_id, ...@@ -483,7 +474,7 @@ void InterpreterCore::CheckGC(size_t instr_id,
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA)
auto* dev_ctx = reinterpret_cast<platform::CUDADeviceContext*>( auto* dev_ctx = reinterpret_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place)); platform::DeviceContextPool::Instance().Get(place));
gc_event_[instr_id].Record(place, dev_ctx); gc_event_[instr_id].Record(dev_ctx);
gc_queue_->AddTask( gc_queue_->AddTask(
[ container = garbages_.release(), event = &gc_event_[instr_id] ]() { [ container = garbages_.release(), event = &gc_event_[instr_id] ]() {
while (!event->Query()) { while (!event->Query()) {
...@@ -857,34 +848,23 @@ void InterpreterCore::RecordEventInstruction(const Instruction& instruction, ...@@ -857,34 +848,23 @@ void InterpreterCore::RecordEventInstruction(const Instruction& instruction,
// If InterpreterCore in on CPUPlace, do nothing. // If InterpreterCore in on CPUPlace, do nothing.
if (platform::is_cpu_place(place_)) return; if (platform::is_cpu_place(place_)) return;
#ifdef PADDLE_WITH_CUDA
const platform::CUDADeviceContext* dev_ctx =
reinterpret_cast<const platform::CUDADeviceContext*>(
instruction.dev_ctx_);
for (auto& event : instruction.output_events_) { for (auto& event : instruction.output_events_) {
VLOG(3) << "Record event in out_var_id: " << event.var_id_; VLOG(3) << "Record event in out_var_id: " << event.var_id_;
event.event_->Record(*(dev_ctx->context()->Stream())); event.event_->Record(instruction.dev_ctx_);
} }
#endif
} }
void InterpreterCore::WaitOrSync(const std::vector<EventInter>& events, void InterpreterCore::WaitOrSync(const std::vector<EventInter>& events,
const platform::DeviceContext* dev_ctx) { const platform::DeviceContext* dev_ctx) {
#ifdef PADDLE_WITH_CUDA for (auto& event_iter : events) {
auto* cuda_dev_ctx = if (event_iter.is_sync_) {
reinterpret_cast<const platform::CUDADeviceContext*>(dev_ctx); VLOG(3) << "host sync wait in_var_id " << event_iter.var_id_;
event_iter.event_->Wait(platform::kCPU, dev_ctx);
for (auto& event : events) {
if (event.is_sync_) {
VLOG(3) << "host sync wait in_var_id " << event.var_id_;
event.event_->Synchronize();
} else { } else {
VLOG(3) << "stream async wait in_var_id " << event.var_id_; VLOG(3) << "stream async wait in_var_id " << event_iter.var_id_;
cuda_dev_ctx->context()->Stream()->WaitEvent( event_iter.event_->Wait(platform::kCUDA, dev_ctx);
event.event_->GetRawCudaEvent());
} }
} }
#endif
} }
void InterpreterCore::StreamWaitEventOrSync(const Instruction& instruction) { void InterpreterCore::StreamWaitEventOrSync(const Instruction& instruction) {
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_event.h" #include "paddle/fluid/platform/device_event.h"
#include "paddle/fluid/platform/event.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -101,7 +100,7 @@ class InterpreterCore { ...@@ -101,7 +100,7 @@ class InterpreterCore {
bool is_build_; bool is_build_;
std::vector<std::string> feed_names_; std::vector<std::string> feed_names_;
std::map<size_t, std::shared_ptr<platform::CudaEvent>> var_id2event_; std::map<size_t, std::shared_ptr<platform::DeviceEvent>> var_id2event_;
std::vector<paddle::platform::DeviceEvent> gc_event_; std::vector<paddle::platform::DeviceEvent> gc_event_;
std::unique_ptr<GarbageQueue> garbages_; std::unique_ptr<GarbageQueue> garbages_;
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device_event_base.h"
#include "paddle/fluid/platform/event.h" #include "paddle/fluid/platform/event.h"
namespace paddle { namespace paddle {
...@@ -56,11 +57,12 @@ struct NextInstruction { ...@@ -56,11 +57,12 @@ struct NextInstruction {
}; };
struct EventInter { struct EventInter {
explicit EventInter(size_t var_id, std::shared_ptr<platform::CudaEvent> event, explicit EventInter(size_t var_id,
std::shared_ptr<platform::DeviceEvent> event,
bool is_sync) bool is_sync)
: var_id_(var_id), event_(event), is_sync_(is_sync) {} : var_id_(var_id), event_(event), is_sync_(is_sync) {}
size_t var_id_; size_t var_id_;
std::shared_ptr<platform::CudaEvent> event_; std::shared_ptr<platform::DeviceEvent> event_;
bool is_sync_; bool is_sync_;
}; };
......
...@@ -151,19 +151,28 @@ endif() ...@@ -151,19 +151,28 @@ endif()
cc_test(init_test SRCS init_test.cc DEPS device_context) cc_test(init_test SRCS init_test.cc DEPS device_context)
cc_library(device_event SRCS device_event.cc DEPS place enforce device_context op_registry) # Manage all device event library
cc_library(device_event_gpu SRCS device_event_gpu.cc DEPS device_event) set(DEVICE_EVENT_LIBS)
cc_library(device_event_base SRCS device_event_base.cc DEPS place enforce device_context op_registry)
set(DEVICE_EVENT_LIBS device_event_base CACHE INTERNAL "device event libs")
if(WITH_GPU) if(WITH_GPU)
nv_library(device_event_gpu SRCS device_event_gpu.cc DEPS device_event_base)
set(DEVICE_EVENT_LIBS device_event_gpu CACHE INTERNAL "device event libs")
nv_test(device_event_test SRCS device_event_test.cc DEPS device_event_gpu)
nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info) nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info)
nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda) nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda)
nv_test(cudnn_desc_test SRCS cudnn_desc_test.cc DEPS dynload_cuda) nv_test(cudnn_desc_test SRCS cudnn_desc_test.cc DEPS dynload_cuda)
nv_test(transform_test SRCS transform_test.cu DEPS memory place device_context) nv_test(transform_test SRCS transform_test.cu DEPS memory place device_context)
nv_test(device_event_test SRCS device_event_test.cc DEPS device_event_gpu)
endif() endif()
if(WITH_ROCM) if(WITH_ROCM)
hip_library(device_event_gpu SRCS device_event_gpu.cc DEPS device_event_base)
set(DEVICE_EVENT_LIBS device_event_gpu CACHE INTERNAL "device event libs")
hip_test(device_event_test SRCS device_event_test.cc DEPS device_event_gpu)
hip_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info) hip_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info)
hip_test(miopen_helper_test SRCS miopen_helper_test.cc DEPS dynload_cuda) hip_test(miopen_helper_test SRCS miopen_helper_test.cc DEPS dynload_cuda)
hip_test(cudnn_desc_test SRCS cudnn_desc_test.cc DEPS dynload_cuda tensor) hip_test(cudnn_desc_test SRCS cudnn_desc_test.cc DEPS dynload_cuda tensor)
......
...@@ -11,267 +11,26 @@ ...@@ -11,267 +11,26 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace platform {
class DeviceOption;
class DeviceEvent;
constexpr int MaxDeviceTypes =
static_cast<int>(platform::DeviceType::MAX_DEVICE_TYPES);
typedef void (*EventCreateFunction)(DeviceEvent*, const DeviceOption&);
typedef void (*EventRecordFunction)(DeviceEvent*, const platform::Place&,
const DeviceContext*);
typedef bool (*EventQueryFunction)(const DeviceEvent*);
typedef void (*EventFinishFunction)(const DeviceEvent*);
typedef void (*EventWaitFunction)(const DeviceEvent*, DeviceContext*);
inline int DeviceTypeToId(const DeviceType& device_type) {
return static_cast<int>(device_type);
}
class DeviceOption {
public:
explicit DeviceOption(int device_type) : device_type_(device_type) {}
DeviceOption(int device_type, int device_id)
: device_type_(device_type), device_id_(device_id) {}
int device_type() const { return device_type_; }
int device_id() const { return device_id_; }
private:
int device_type_;
int device_id_;
};
class DeviceEvent {
public:
explicit DeviceEvent(const DeviceOption& device_option)
: event_(),
type_(device_option.device_type()),
device_option_(device_option) {
PADDLE_ENFORCE_LT(type_, MaxDeviceTypes,
platform::errors::PreconditionNotMet(
"Required type < %d, but received type = %d",
MaxDeviceTypes, type_));
PADDLE_ENFORCE_NOT_NULL(
event_creator_[type_],
platform::errors::Unavailable(
"event_creator_[%d] shall not be nullptr.", type_));
event_creator_[type_](this, device_option_);
}
~DeviceEvent() {}
void Record(const platform::Place& place, const DeviceContext* dev_ctx) {
PADDLE_ENFORCE_NOT_NULL(
event_recorder_[type_],
platform::errors::Unavailable(
"event_recorder_[%d] shall not be nullptr.", type_));
event_recorder_[type_](this, place, dev_ctx);
}
bool Query() {
PADDLE_ENFORCE_NOT_NULL(
event_querier_[type_],
platform::errors::Unavailable(
"event_querier_[%d] shall not be nullptr.", type_));
return event_querier_[type_](this);
}
void Finish() const {
PADDLE_ENFORCE_NOT_NULL(
event_finisher_[type_],
platform::errors::Unavailable(
"event_finisher_[%d] shall not be nullptr.", type_));
event_finisher_[type_](this);
}
void Wait(const DeviceType& waiter_type, DeviceContext* context) const {
auto waiter_idx = DeviceTypeToId(waiter_type);
PADDLE_ENFORCE_NOT_NULL(
event_waiter_[waiter_idx][type_],
platform::errors::Unavailable(
"event_waiter_[%d][%d] shall not be nullptr.", waiter_idx, type_));
event_waiter_[waiter_idx][type_](this, context);
}
void InitEvent(std::shared_ptr<void> event) { event_ = event; }
std::shared_ptr<void> GetEvent() const { return event_; }
private: #pragma once
std::shared_ptr<void> event_; #include "paddle/fluid/platform/device_event_base.h"
int type_;
DeviceOption device_option_; /*
* NOTE: Now we generate this file manually and will consider
static EventCreateFunction event_creator_[MaxDeviceTypes]; * automatically generate it later. Just as 'paddle/fluid/pybind/pybind.h'
static EventRecordFunction event_recorder_[MaxDeviceTypes]; * for USE_OP from op_library macros, and
static EventQueryFunction event_querier_[MaxDeviceTypes]; * `paddle/fluid/inference/paddle_inference_pass.h`
static EventFinishFunction event_finisher_[MaxDeviceTypes]; * for USE_PASS from pass_library.
static EventWaitFunction event_waiter_[MaxDeviceTypes][MaxDeviceTypes];
template <DeviceType device_typ>
friend struct EventCreateFunctionRegisterer;
template <DeviceType device_typ>
friend struct EventRecordFunctionRegisterer;
template <DeviceType device_typ>
friend struct EventQueryFunctionRegisterer;
template <DeviceType device_typ>
friend struct EventFinishFunctionRegisterer;
template <DeviceType waiter_typ, DeviceType event_type>
friend struct EventWaitFunctionRegisterer;
};
/**
* check if MACRO is used in GLOBAL NAMESPACE.
*/ */
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
// =============== Register for Create ===============
template <DeviceType device_type>
struct EventCreateFunctionRegisterer : public framework::Registrar {
explicit EventCreateFunctionRegisterer(EventCreateFunction func) {
auto type_idx = DeviceTypeToId(device_type);
VLOG(3) << "register event_creator with type_id :" << type_idx;
DeviceEvent::event_creator_[type_idx] = func;
}
};
#define REGISTER_EVENT_CREATE_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_creator__##device_type, \
"REGISTER_EVENT_CREATE_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventCreateFunctionRegisterer<device_type> \
__reg_event_create_##device_type##__(func); \
int TouchDeviceEventCreate##device_type() { \
__reg_event_create_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for Record ===============
template <DeviceType device_type>
struct EventRecordFunctionRegisterer : public framework::Registrar {
explicit EventRecordFunctionRegisterer(EventRecordFunction func) {
auto type_idx = DeviceTypeToId(device_type);
VLOG(3) << "register event_recorder with type_id :" << type_idx;
DeviceEvent::event_recorder_[type_idx] = func;
}
};
#define REGISTER_EVENT_RECORD_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_recorder__##device_type, \
"REGISTER_EVENT_RECORD_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventRecordFunctionRegisterer<device_type> \
__reg_event_record_##device_type##__(func); \
int TouchDeviceEventRecord##device_type() { \
__reg_event_record_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for Query ===============
template <DeviceType device_type>
struct EventQueryFunctionRegisterer : public framework::Registrar {
explicit EventQueryFunctionRegisterer(EventQueryFunction func) {
auto type_idx = DeviceTypeToId(device_type);
VLOG(3) << "register event_querier with type_id :" << type_idx;
DeviceEvent::event_querier_[type_idx] = func;
}
};
#define REGISTER_EVENT_QUERY_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_querier__##device_type, \
"REGISTER_EVENT_QUERY_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventQueryFunctionRegisterer<device_type> \
__reg_event_query_##device_type##__(func); \
int TouchDeviceEventQuery##device_type() { \
__reg_event_query_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for Finish ===============
template <DeviceType device_type>
struct EventFinishFunctionRegisterer : public framework::Registrar {
explicit EventFinishFunctionRegisterer(EventFinishFunction func) {
auto type_idx = DeviceTypeToId(device_type);
VLOG(3) << "register event_finisher with type_id :" << type_idx;
DeviceEvent::event_finisher_[type_idx] = func;
}
};
#define REGISTER_EVENT_FINISH_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_finishier__##device_type, \
"REGISTER_EVENT_FINISH_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventFinishFunctionRegisterer<device_type> \
__reg_event_finish_##device_type##__(func); \
int TouchDeviceEventFinish##device_type() { \
__reg_event_finish_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for Wait ===============
template <DeviceType waiter_type, DeviceType event_type>
struct EventWaitFunctionRegisterer : public framework::Registrar {
explicit EventWaitFunctionRegisterer(EventWaitFunction func) {
auto waiter_idx = DeviceTypeToId(waiter_type);
auto event_idx = DeviceTypeToId(event_type);
VLOG(3) << "register event_finisher with waiter_idx : " << waiter_idx
<< ", event_idx : " << event_idx;
DeviceEvent::event_waiter_[waiter_idx][event_idx] = func;
}
};
#define REGISTER_EVENT_WAIT_FUNCTION(waiter_type, event_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_waiter__##waiter_type##event_type, \
"REGISTER_EVENT_WAIT_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventWaitFunctionRegisterer<waiter_type, \
event_type> \
__reg_event_wait_##waiter_type##event_type##__(func); \
int TouchDeviceEventWait##waiter_type##event_type() { \
__reg_event_wait_##waiter_type##event_type##__.Touch(); \
return 0; \
}
#define USE_EVENT(device_type) \ using ::paddle::platform::kCUDA;
extern int TouchDeviceEventCreate##device_type(); \ using ::paddle::platform::kCPU;
extern int TouchDeviceEventRecord##device_type(); \
extern int TouchDeviceEventQuery##device_type(); \
extern int TouchDeviceEventFinish##device_type(); \
UNUSED static int use_event_creator_##device_type = \
TouchDeviceEventCreate##device_type(); \
UNUSED static int use_event_recorder_##device_type = \
TouchDeviceEventRecord##device_type(); \
UNUSED static int use_event_querier_##device_type = \
TouchDeviceEventQuery##device_type(); \
UNUSED static int use_event_finisher_##device_type = \
TouchDeviceEventFinish##device_type();
#define USE_EVENT_WAIT(waiter_type, event_type) \ USE_EVENT(kCPU)
extern int TouchDeviceEventWait##waiter_type##event_type(); \ USE_EVENT_WAIT(kCPU, kCPU)
UNUSED static int use_event_waiter_##waiter_type##event_type = \
TouchDeviceEventWait##waiter_type##event_type();
} // namespace platform #ifdef PADDLE_WITH_CUDA
} // namespace paddle USE_EVENT(kCUDA);
USE_EVENT_WAIT(kCUDA, kCUDA)
USE_EVENT_WAIT(kCPU, kCUDA)
#endif
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device_event_base.h"
#include "paddle/fluid/platform/device_event_cpu.h"
namespace paddle {
namespace platform {
EventCreateFunction DeviceEvent::event_creator_[MaxDeviceTypes];
EventRecordFunction DeviceEvent::event_recorder_[MaxDeviceTypes];
EventQueryFunction DeviceEvent::event_querier_[MaxDeviceTypes];
EventFinishFunction DeviceEvent::event_finisher_[MaxDeviceTypes];
EventFinishFunction DeviceEvent::event_finished_setter_[MaxDeviceTypes];
EventWaitFunction DeviceEvent::event_waiter_[MaxDeviceTypes][MaxDeviceTypes];
// Creator registered for kCPU: builds a CPUDeviceEventWrapper from `place`
// and `flag` and installs it as the payload of `event`.
void DeviceEventCreateCPU(DeviceEvent* event, const platform::Place& place,
                          unsigned int flag) {
  auto cpu_event = std::make_shared<CPUDeviceEventWrapper>(place, flag);
  event->InitEvent(cpu_event);
}
// Recorder registered for kCPU. Under the wrapper's mutex it rejects an event
// that is already SCHEDULED and moves an INITIALIZED event to SCHEDULED; any
// other status is left unchanged. `context` is not used here.
void DeviceEventRecordCPU(DeviceEvent* event, const DeviceContext* context) {
  // The payload stays owned by `event`, so the raw pointer remains valid after
  // the temporary shared_ptr returned by GetEvent() is gone.
  auto* cpu_event =
      static_cast<CPUDeviceEventWrapper*>(event->GetEvent().get());
  std::lock_guard<std::mutex> guard(cpu_event->mutex_);

  PADDLE_ENFORCE_NE(cpu_event->status_.load(), EventStatus::SCHEDULED,
                    platform::errors::PreconditionNotMet(
                        "EventStatus shall be not SCHEDULED before Record()"));

  const bool never_recorded = (cpu_event->status_ == EventStatus::INITIALIZED);
  if (never_recorded) {
    cpu_event->status_ = EventStatus::SCHEDULED;
  }
}
// Querier registered for kCPU: returns true only when the event's status is
// SUCCESS. The status is read without taking the wrapper's mutex.
bool DeviceEventQueryCPU(const DeviceEvent* event) {
  auto* wrapper = static_cast<CPUDeviceEventWrapper*>(event->GetEvent().get());
  // A static_cast from the type-erased payload cannot fail, so the only way to
  // get nullptr here is an event whose payload was never initialized. The
  // message says that instead of blaming a dynamic_cast that is not performed.
  PADDLE_ENFORCE_NOT_NULL(
      wrapper,
      platform::errors::PreconditionNotMet(
          "The CPUDeviceEventWrapper held by the event shall not be nullptr."));
  return wrapper->status_ == EventStatus::SUCCESS;
}
// Finisher registered for kCPU: blocks the calling thread until the event's
// status is SUCCESS or FAILED, sleeping on the wrapper's condition variable
// between checks.
void DeviceEventFinishCPU(const DeviceEvent* event) {
  auto* cpu_event =
      static_cast<CPUDeviceEventWrapper*>(event->GetEvent().get());
  std::unique_lock<std::mutex> guard(cpu_event->mutex_);
  for (;;) {
    // Re-evaluated after every wake-up, so a wake-up that did not change the
    // status simply goes back to waiting.
    const bool completed = cpu_event->status_ == EventStatus::SUCCESS ||
                           cpu_event->status_ == EventStatus::FAILED;
    if (completed) {
      break;
    }
    cpu_event->cv_completed_.wait(guard);
  }
}
// Waiter registered for the (kCPU waiter, kCPU event) pair: a CPU waiting on a
// CPU event is the same blocking wait as DeviceEventFinishCPU.
void DeviceEventCPUWaitCPU(const DeviceEvent* event,
                           const DeviceContext* context) {
  // The waiter's device context plays no role on this path.
  (void)context;
  DeviceEventFinishCPU(event);
}
// Finished-setter registered for kCPU: marks the event SUCCESS and wakes every
// thread blocked on the wrapper's condition variable. The event must still be
// INITIALIZED or SCHEDULED when this is called.
void EventSetFinishedCPU(const DeviceEvent* event) {
  auto* cpu_event =
      static_cast<CPUDeviceEventWrapper*>(event->GetEvent().get());
  std::lock_guard<std::mutex> guard(cpu_event->mutex_);

  // Relies on the numeric order INITIALIZED < SCHEDULED < SUCCESS < FAILED.
  PADDLE_ENFORCE_LE(cpu_event->status_.load(), EventStatus::SCHEDULED,
                    platform::errors::PreconditionNotMet(
                        "EventStatus shall be INITIALIZED | SCHEDULED before "
                        "EventSetFinishedCPU()"));

  cpu_event->status_ = EventStatus::SUCCESS;
  cpu_event->cv_completed_.notify_all();
}
} // namespace platform
} // namespace paddle
// Register the CPU implementations defined above under device type kCPU.
// These macros must be used at global namespace scope, hence the
// using-declaration so that the bare token `kCPU` can be pasted into the
// generated identifiers.
using ::paddle::platform::kCPU;
// Create / record / query / finish handlers for CPU events.
REGISTER_EVENT_CREATE_FUNCTION(kCPU, paddle::platform::DeviceEventCreateCPU)
REGISTER_EVENT_RECORD_FUNCTION(kCPU, paddle::platform::DeviceEventRecordCPU)
REGISTER_EVENT_QUERY_FUNCTION(kCPU, paddle::platform::DeviceEventQueryCPU)
REGISTER_EVENT_FINISH_FUNCTION(kCPU, paddle::platform::DeviceEventFinishCPU)
// Handler behind DeviceEvent::SetFininshed() for CPU events.
// NOTE(review): unlike the sibling invocations this one ends with ';' --
// presumably redundant; confirm against the macro definition before removing.
REGISTER_EVENT_SET_FINISHED_FUNCTION(kCPU,
paddle::platform::EventSetFinishedCPU);
// Waiter used when a CPU (first argument) waits on a CPU event (second).
REGISTER_EVENT_WAIT_FUNCTION(kCPU, kCPU,
paddle::platform::DeviceEventCPUWaitCPU)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <utility>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace platform {
class DeviceOption;
class DeviceEvent;
constexpr int MaxDeviceTypes =
static_cast<int>(platform::DeviceType::MAX_DEVICE_TYPES);
typedef void (*EventCreateFunction)(DeviceEvent*, const platform::Place&,
unsigned int flag);
typedef void (*EventRecordFunction)(DeviceEvent*, const DeviceContext*);
typedef bool (*EventQueryFunction)(const DeviceEvent*);
typedef void (*EventFinishFunction)(const DeviceEvent*);
typedef void (*EventSetFinishedFunction)(const DeviceEvent*);
typedef void (*EventWaitFunction)(const DeviceEvent*, const DeviceContext*);
// Converts a DeviceType enumerator to the integer id used to index the
// DeviceEvent dispatch tables.
inline int DeviceTypeToId(const DeviceType& device_type) {
  const int type_id = static_cast<int>(device_type);
  return type_id;
}
// Life-cycle states of an event. The numeric order is significant: callers
// compare states with <= (e.g. "status <= SCHEDULED" means "not finished yet").
enum EventStatus {
// Initial state; a Record() on a CPU event moves it to SCHEDULED.
INITIALIZED = 0,
// The event has been recorded and is not finished yet.
SCHEDULED = 1,
// The event finished; Query() on a CPU event returns true only in this state.
SUCCESS = 2,
// Terminal state that also ends a blocking Finish() on a CPU event.
FAILED = 3,
};
class DeviceEvent {
public:
explicit DeviceEvent(const platform::Place& place, unsigned int flag = 0)
: event_(), place_(place), flag_(flag) {
type_id_ = DeviceTypeToId(platform::Place2DeviceType(place));
PADDLE_ENFORCE_LT(type_id_, MaxDeviceTypes,
platform::errors::PreconditionNotMet(
"Required type < %d, but received type = %d",
MaxDeviceTypes, type_id_));
// TODO(Aurelius84): only support CPU/CUDA, need consider XPU/NPU later
PADDLE_ENFORCE_LT(type_id_, 3,
platform::errors::Unavailable(
"Currently DeviceEvent do not support %s", place));
PADDLE_ENFORCE_NOT_NULL(
event_creator_[type_id_],
platform::errors::Unavailable(
"event_creator_[%d] shall not be nullptr.", type_id_));
event_creator_[type_id_](this, place, flag);
}
~DeviceEvent() {}
void Record(const DeviceContext* dev_ctx) {
PADDLE_ENFORCE_NOT_NULL(
event_recorder_[type_id_],
platform::errors::Unavailable(
"event_recorder_[%d] shall not be nullptr.", type_id_));
event_recorder_[type_id_](this, dev_ctx);
}
bool Query() {
PADDLE_ENFORCE_NOT_NULL(
event_querier_[type_id_],
platform::errors::Unavailable(
"event_querier_[%d] shall not be nullptr.", type_id_));
return event_querier_[type_id_](this);
}
void Finish() const {
PADDLE_ENFORCE_NOT_NULL(
event_finisher_[type_id_],
platform::errors::Unavailable(
"event_finisher_[%d] shall not be nullptr.", type_id_));
event_finisher_[type_id_](this);
}
void SetFininshed() {
PADDLE_ENFORCE_NOT_NULL(
event_finished_setter_[type_id_],
platform::errors::Unavailable(
"event_finished_setter_[%d] shall not be nullptr.", type_id_));
event_finished_setter_[type_id_](this);
}
void Wait(const DeviceType& waiter_type, const DeviceContext* context) const {
auto waiter_idx = DeviceTypeToId(waiter_type);
PADDLE_ENFORCE_NOT_NULL(event_waiter_[waiter_idx][type_id_],
platform::errors::Unavailable(
"event_waiter_[%d][%d] shall not be nullptr.",
waiter_idx, type_id_));
event_waiter_[waiter_idx][type_id_](this, context);
}
void InitEvent(std::shared_ptr<void> event) { event_ = event; }
std::shared_ptr<void> GetEvent() const { return event_; }
private:
std::shared_ptr<void> event_;
platform::Place place_;
int type_id_;
unsigned int flag_;
static EventCreateFunction event_creator_[MaxDeviceTypes];
static EventRecordFunction event_recorder_[MaxDeviceTypes];
static EventQueryFunction event_querier_[MaxDeviceTypes];
static EventFinishFunction event_finisher_[MaxDeviceTypes];
static EventFinishFunction event_finished_setter_[MaxDeviceTypes];
static EventWaitFunction event_waiter_[MaxDeviceTypes][MaxDeviceTypes];
template <DeviceType device_typ>
friend struct EventCreateFunctionRegisterer;
template <DeviceType device_typ>
friend struct EventRecordFunctionRegisterer;
template <DeviceType device_typ>
friend struct EventQueryFunctionRegisterer;
template <DeviceType device_typ>
friend struct EventFinishFunctionRegisterer;
template <DeviceType device_typ>
friend struct EventSetFinishedFunctionRegisterer;
template <DeviceType waiter_typ, DeviceType event_type>
friend struct EventWaitFunctionRegisterer;
};
/**
 * Compile-time check that the enclosing macro is expanded in the global
 * namespace.
 *
 * Declares an empty struct whose name is built from `uniq_name` in the
 * current scope, then static_asserts that the globally-qualified name
 * (`::...`) and the unqualified name denote the same type. `msg` is the
 * static_assert message. The expansion does not end with a semicolon, so the
 * caller supplies it.
 */
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
// =============== Register for Create ===============
// Registrar whose constructor stores `func` in DeviceEvent::event_creator_ at
// the slot belonging to `device_type`.
template <DeviceType device_type>
struct EventCreateFunctionRegisterer : public framework::Registrar {
  explicit EventCreateFunctionRegisterer(EventCreateFunction func) {
    const int slot = DeviceTypeToId(device_type);
    VLOG(3) << "register event_creator with type_id :" << slot;
    DeviceEvent::event_creator_[slot] = func;
  }
};
// Registers `func` as the event-creation callback for `device_type`.
//
// Expands to three things: a global-namespace check, a file-static
// EventCreateFunctionRegisterer<device_type> constructed with `func`, and a
// function TouchDeviceEventCreate<device_type>() that calls Touch() on that
// instance and returns 0 (USE_EVENT declares and calls it).
// `device_type` is pasted into identifiers with ##, so it must be a single
// identifier token (e.g. kCUDA).
#define REGISTER_EVENT_CREATE_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_creator__##device_type, \
"REGISTER_EVENT_CREATE_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventCreateFunctionRegisterer<device_type> \
__reg_event_create_##device_type##__(func); \
int TouchDeviceEventCreate##device_type() { \
__reg_event_create_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for Record ===============
// Registrar whose constructor stores `func` in DeviceEvent::event_recorder_
// at the slot belonging to `device_type`.
template <DeviceType device_type>
struct EventRecordFunctionRegisterer : public framework::Registrar {
  explicit EventRecordFunctionRegisterer(EventRecordFunction func) {
    const int slot = DeviceTypeToId(device_type);
    VLOG(3) << "register event_recorder with type_id :" << slot;
    DeviceEvent::event_recorder_[slot] = func;
  }
};
// Registers `func` as the event-record callback for `device_type`.
//
// Expands to three things: a global-namespace check, a file-static
// EventRecordFunctionRegisterer<device_type> constructed with `func`, and a
// function TouchDeviceEventRecord<device_type>() that calls Touch() on that
// instance and returns 0 (USE_EVENT declares and calls it).
// `device_type` is pasted into identifiers with ##, so it must be a single
// identifier token.
#define REGISTER_EVENT_RECORD_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_recorder__##device_type, \
"REGISTER_EVENT_RECORD_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventRecordFunctionRegisterer<device_type> \
__reg_event_record_##device_type##__(func); \
int TouchDeviceEventRecord##device_type() { \
__reg_event_record_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for Query ===============
// Registrar whose constructor stores `func` in DeviceEvent::event_querier_ at
// the slot belonging to `device_type`.
template <DeviceType device_type>
struct EventQueryFunctionRegisterer : public framework::Registrar {
  explicit EventQueryFunctionRegisterer(EventQueryFunction func) {
    const int slot = DeviceTypeToId(device_type);
    VLOG(3) << "register event_querier with type_id :" << slot;
    DeviceEvent::event_querier_[slot] = func;
  }
};
// Registers `func` as the event-query callback for `device_type`.
//
// Expands to three things: a global-namespace check, a file-static
// EventQueryFunctionRegisterer<device_type> constructed with `func`, and a
// function TouchDeviceEventQuery<device_type>() that calls Touch() on that
// instance and returns 0 (USE_EVENT declares and calls it).
// `device_type` is pasted into identifiers with ##, so it must be a single
// identifier token.
#define REGISTER_EVENT_QUERY_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_querier__##device_type, \
"REGISTER_EVENT_QUERY_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventQueryFunctionRegisterer<device_type> \
__reg_event_query_##device_type##__(func); \
int TouchDeviceEventQuery##device_type() { \
__reg_event_query_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for Finish ===============
// Registrar whose constructor stores `func` in DeviceEvent::event_finisher_
// at the slot belonging to `device_type`.
template <DeviceType device_type>
struct EventFinishFunctionRegisterer : public framework::Registrar {
  explicit EventFinishFunctionRegisterer(EventFinishFunction func) {
    const int slot = DeviceTypeToId(device_type);
    VLOG(3) << "register event_finisher with type_id :" << slot;
    DeviceEvent::event_finisher_[slot] = func;
  }
};
// Registers `func` as the event-finish callback for `device_type`.
//
// Expands to three things: a global-namespace check, a file-static
// EventFinishFunctionRegisterer<device_type> constructed with `func`, and a
// function TouchDeviceEventFinish<device_type>() that calls Touch() on that
// instance and returns 0 (USE_EVENT declares and calls it).
// `device_type` is pasted into identifiers with ##, so it must be a single
// identifier token.
#define REGISTER_EVENT_FINISH_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_finishier__##device_type, \
"REGISTER_EVENT_FINISH_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventFinishFunctionRegisterer<device_type> \
__reg_event_finish_##device_type##__(func); \
int TouchDeviceEventFinish##device_type() { \
__reg_event_finish_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for SetFinished ===============
// Registrar whose constructor stores `func` in
// DeviceEvent::event_finished_setter_ at the slot belonging to `device_type`.
template <DeviceType device_type>
struct EventSetFinishedFunctionRegisterer : public framework::Registrar {
  explicit EventSetFinishedFunctionRegisterer(EventSetFinishedFunction func) {
    const int slot = DeviceTypeToId(device_type);
    VLOG(3) << "register event_finished_setter with type_id :" << slot;
    DeviceEvent::event_finished_setter_[slot] = func;
  }
};
// Registers `func` as the "set finished" callback for `device_type`.
//
// Expands to three things: a global-namespace check, a file-static
// EventSetFinishedFunctionRegisterer<device_type> constructed with `func`,
// and a function TouchDeviceEventSetFinished<device_type>() that calls
// Touch() on that instance and returns 0 (USE_EVENT declares and calls it).
// `device_type` is pasted into identifiers with ##, so it must be a single
// identifier token.
// The static_assert message names this macro, so a misuse outside the global
// namespace is reported against the right macro.
#define REGISTER_EVENT_SET_FINISHED_FUNCTION(device_type, func) \
  STATIC_ASSERT_GLOBAL_NAMESPACE( \
      __reg_event_finished_setter__##device_type, \
      "REGISTER_EVENT_SET_FINISHED_FUNCTION must be called in global " \
      "namespace"); \
  static ::paddle::platform::EventSetFinishedFunctionRegisterer<device_type> \
      __reg_event_finished_setter_##device_type##__(func); \
  int TouchDeviceEventSetFinished##device_type() { \
    __reg_event_finished_setter_##device_type##__.Touch(); \
    return 0; \
  }
// =============== Register for Wait ===============
// Registrar whose constructor stores `func` in DeviceEvent::event_waiter_ at
// [waiter_type][event_type], i.e. the callback used when a device of
// `waiter_type` waits on an event of `event_type`.
template <DeviceType waiter_type, DeviceType event_type>
struct EventWaitFunctionRegisterer : public framework::Registrar {
  explicit EventWaitFunctionRegisterer(EventWaitFunction func) {
    auto waiter_idx = DeviceTypeToId(waiter_type);
    auto event_idx = DeviceTypeToId(event_type);
    // The log names the table actually being written (event_waiter), not the
    // finisher table.
    VLOG(3) << "register event_waiter with waiter_idx : " << waiter_idx
            << ", event_idx : " << event_idx;
    DeviceEvent::event_waiter_[waiter_idx][event_idx] = func;
  }
};
// Registers `func` as the wait callback for the pair
// (`waiter_type`, `event_type`).
//
// Expands to three things: a global-namespace check, a file-static
// EventWaitFunctionRegisterer<waiter_type, event_type> constructed with
// `func`, and a function TouchDeviceEventWait<waiter_type><event_type>() that
// calls Touch() on that instance and returns 0 (USE_EVENT_WAIT declares and
// calls it).
// Both type arguments are pasted into identifiers with ##, so each must be a
// single identifier token.
#define REGISTER_EVENT_WAIT_FUNCTION(waiter_type, event_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_waiter__##waiter_type##event_type, \
"REGISTER_EVENT_WAIT_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventWaitFunctionRegisterer<waiter_type, \
event_type> \
__reg_event_wait_##waiter_type##event_type##__(func); \
int TouchDeviceEventWait##waiter_type##event_type() { \
__reg_event_wait_##waiter_type##event_type##__.Touch(); \
return 0; \
}
// Declares the five TouchDeviceEvent*<device_type>() functions emitted by the
// REGISTER_EVENT_{CREATE,RECORD,QUERY,FINISH,SET_FINISHED}_FUNCTION macros
// and calls each one to initialize an UNUSED file-static int.
// NOTE(review): presumably this exists so that the translation unit holding
// the registrations is referenced and not dropped at link time - confirm.
#define USE_EVENT(device_type) \
extern int TouchDeviceEventCreate##device_type(); \
extern int TouchDeviceEventRecord##device_type(); \
extern int TouchDeviceEventQuery##device_type(); \
extern int TouchDeviceEventFinish##device_type(); \
extern int TouchDeviceEventSetFinished##device_type(); \
UNUSED static int use_event_creator_##device_type = \
TouchDeviceEventCreate##device_type(); \
UNUSED static int use_event_recorder_##device_type = \
TouchDeviceEventRecord##device_type(); \
UNUSED static int use_event_querier_##device_type = \
TouchDeviceEventQuery##device_type(); \
UNUSED static int use_event_finisher_##device_type = \
TouchDeviceEventFinish##device_type(); \
UNUSED static int use_event_finished_setter_##device_type = \
TouchDeviceEventSetFinished##device_type();
// Declares the TouchDeviceEventWait<waiter_type><event_type>() function
// emitted by REGISTER_EVENT_WAIT_FUNCTION and calls it to initialize an
// UNUSED file-static int.
// NOTE(review): presumably this exists so that the translation unit holding
// the registration is referenced and not dropped at link time - confirm.
#define USE_EVENT_WAIT(waiter_type, event_type) \
extern int TouchDeviceEventWait##waiter_type##event_type(); \
UNUSED static int use_event_waiter_##waiter_type##event_type = \
TouchDeviceEventWait##waiter_type##event_type();
} // namespace platform
} // namespace paddle
...@@ -12,16 +12,40 @@ ...@@ -12,16 +12,40 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/platform/device_event.h" #pragma once
#include <atomic>
#include <condition_variable>
#include <mutex>
#include "paddle/fluid/platform/device_event_base.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
EventCreateFunction DeviceEvent::event_creator_[MaxDeviceTypes]; struct CPUDeviceEventWrapper {
EventRecordFunction DeviceEvent::event_recorder_[MaxDeviceTypes]; explicit CPUDeviceEventWrapper(const platform::Place& place,
EventQueryFunction DeviceEvent::event_querier_[MaxDeviceTypes]; unsigned int flag = 0) {
EventFinishFunction DeviceEvent::event_finisher_[MaxDeviceTypes]; PADDLE_ENFORCE_EQ(
EventWaitFunction DeviceEvent::event_waiter_[MaxDeviceTypes][MaxDeviceTypes]; platform::is_cpu_place(place), true,
platform::errors::PreconditionNotMet(
"Required device shall be CPUAPlace, but received %d. ", place));
}
std::mutex mutex_;
std::condition_variable cv_completed_;
std::atomic<int> status_;
};
void DeviceEventCreateCPU(DeviceEvent* event, const platform::Place& place);
void DeviceEventRecordCPU(DeviceEvent* event, const platform::Place& place,
const DeviceContext* context);
bool DeviceEventQueryCPU(const DeviceEvent* event);
void DeviceEventFinishCPU(const DeviceEvent* event);
void EventSetFinishedCPU(const DeviceEvent* event);
void DeviceEventCPUWaitCPU(const DeviceEvent* event, DeviceContext* context);
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -12,38 +12,38 @@ ...@@ -12,38 +12,38 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/platform/device_event.h" #include "paddle/fluid/platform/device_event_base.h"
#include "paddle/fluid/platform/event.h" #include "paddle/fluid/platform/event.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
namespace paddle { namespace paddle {
namespace platform { namespace platform {
struct CUDADeviceEventWrapper { struct CUDADeviceEventWrapper {
explicit CUDADeviceEventWrapper(const DeviceOption& dev_opt) CUDADeviceEventWrapper(const platform::Place& place, unsigned int flag)
: inner_event_() { : inner_event_(flag) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dev_opt.device_type(), static_cast<int>(DeviceType::CUDA), platform::is_gpu_place(place), true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Required device type shall be CUDA, but received %d. ", "Required device shall be CUDAPlace, but received %d. ", place));
dev_opt.device_type()));
device_id_ = BOOST_GET_CONST(platform::CUDAPlace, place).device;
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
dev_opt.device_id(), -1, device_id_, -1,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Required DeviceOption.device_id > -1, but received %d. ", "Required DeviceOption.device_id > -1, but received %d. ",
dev_opt.device_id())); device_id_));
device_id_ = dev_opt.device_id();
} }
CudaEvent inner_event_; CudaEvent inner_event_;
int device_id_; int device_id_;
}; };
void DeviceEventCreateCUDA(DeviceEvent* event, const DeviceOption& dev_opt) { void DeviceEventCreateCUDA(DeviceEvent* event, const platform::Place& place,
event->InitEvent(std::make_shared<CUDADeviceEventWrapper>(dev_opt)); unsigned int flag) {
event->InitEvent(std::make_shared<CUDADeviceEventWrapper>(place, flag));
} }
void DeviceEventRecordCUDA(DeviceEvent* event, const platform::Place& place, void DeviceEventRecordCUDA(DeviceEvent* event, const DeviceContext* context) {
const DeviceContext* context) {
auto* wrapper = static_cast<CUDADeviceEventWrapper*>(event->GetEvent().get()); auto* wrapper = static_cast<CUDADeviceEventWrapper*>(event->GetEvent().get());
auto* cuda_dev_ctx = auto* cuda_dev_ctx =
...@@ -72,7 +72,8 @@ void DeviceEventFinishCUDA(const DeviceEvent* event) { ...@@ -72,7 +72,8 @@ void DeviceEventFinishCUDA(const DeviceEvent* event) {
wrapper->inner_event_.Synchronize(); wrapper->inner_event_.Synchronize();
} }
void DeviceEventCUDAWaitCUDA(const DeviceEvent* event, DeviceContext* context) { void DeviceEventCUDAWaitCUDA(const DeviceEvent* event,
const DeviceContext* context) {
auto* wrapper = static_cast<CUDADeviceEventWrapper*>(event->GetEvent().get()); auto* wrapper = static_cast<CUDADeviceEventWrapper*>(event->GetEvent().get());
auto* cuda_dev_ctx = auto* cuda_dev_ctx =
dynamic_cast<const platform::CUDADeviceContext*>(context); dynamic_cast<const platform::CUDADeviceContext*>(context);
...@@ -85,10 +86,15 @@ void DeviceEventCUDAWaitCUDA(const DeviceEvent* event, DeviceContext* context) { ...@@ -85,10 +86,15 @@ void DeviceEventCUDAWaitCUDA(const DeviceEvent* event, DeviceContext* context) {
wrapper->inner_event_.GetRawCudaEvent()); wrapper->inner_event_.GetRawCudaEvent());
} }
void DeviceEventCPUWaitCUDA(const DeviceEvent* event, DeviceContext* context) { void DeviceEventCPUWaitCUDA(const DeviceEvent* event,
const DeviceContext* context) {
DeviceEventFinishCUDA(event); DeviceEventFinishCUDA(event);
} }
void DeviceEventSetFinishedCUDA(const DeviceEvent* event) {
// do nothing
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -98,6 +104,8 @@ REGISTER_EVENT_CREATE_FUNCTION(kCUDA, paddle::platform::DeviceEventCreateCUDA) ...@@ -98,6 +104,8 @@ REGISTER_EVENT_CREATE_FUNCTION(kCUDA, paddle::platform::DeviceEventCreateCUDA)
REGISTER_EVENT_RECORD_FUNCTION(kCUDA, paddle::platform::DeviceEventRecordCUDA) REGISTER_EVENT_RECORD_FUNCTION(kCUDA, paddle::platform::DeviceEventRecordCUDA)
REGISTER_EVENT_QUERY_FUNCTION(kCUDA, paddle::platform::DeviceEventQueryCUDA) REGISTER_EVENT_QUERY_FUNCTION(kCUDA, paddle::platform::DeviceEventQueryCUDA)
REGISTER_EVENT_FINISH_FUNCTION(kCUDA, paddle::platform::DeviceEventFinishCUDA) REGISTER_EVENT_FINISH_FUNCTION(kCUDA, paddle::platform::DeviceEventFinishCUDA)
REGISTER_EVENT_SET_FINISHED_FUNCTION(
kCUDA, paddle::platform::DeviceEventSetFinishedCUDA)
REGISTER_EVENT_WAIT_FUNCTION(kCUDA, kCUDA, REGISTER_EVENT_WAIT_FUNCTION(kCUDA, kCUDA,
paddle::platform::DeviceEventCUDAWaitCUDA) paddle::platform::DeviceEventCUDAWaitCUDA)
REGISTER_EVENT_WAIT_FUNCTION(kCPU, kCUDA, REGISTER_EVENT_WAIT_FUNCTION(kCPU, kCUDA,
......
...@@ -16,35 +16,30 @@ ...@@ -16,35 +16,30 @@
#include "glog/logging.h" #include "glog/logging.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
using ::paddle::platform::kCUDA; using ::paddle::platform::kCUDA;
using ::paddle::platform::kCPU; using ::paddle::platform::kCPU;
USE_EVENT(kCUDA);
USE_EVENT_WAIT(kCUDA, kCUDA) using paddle::platform::DeviceEvent;
USE_EVENT_WAIT(kCPU, kCUDA) using paddle::platform::DeviceContextPool;
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
TEST(DeviceEvent, CUDA) { TEST(DeviceEvent, CUDA) {
VLOG(1) << "In Test"; VLOG(1) << "In Test";
using paddle::platform::CUDAPlace; using paddle::platform::CUDAPlace;
using paddle::platform::DeviceOption;
using paddle::platform::DeviceEvent;
using paddle::platform::DeviceContextPool;
using paddle::platform::DeviceType;
auto& pool = DeviceContextPool::Instance(); auto& pool = DeviceContextPool::Instance();
auto place = CUDAPlace(0); auto place = CUDAPlace(0);
auto* context = auto* context =
static_cast<paddle::platform::CUDADeviceContext*>(pool.Get(place)); static_cast<paddle::platform::CUDADeviceContext*>(pool.Get(place));
int device_type = static_cast<int>(DeviceType::CUDA);
DeviceOption dev_opt(device_type, place.device);
ASSERT_NE(context, nullptr); ASSERT_NE(context, nullptr);
// case 1. test for event_creator // case 1. test for event_creator
DeviceEvent event(dev_opt); DeviceEvent event(place);
ASSERT_NE(event.GetEvent().get(), nullptr); ASSERT_NE(event.GetEvent().get(), nullptr);
// case 2. test for event_recorder // case 2. test for event_recorder
event.Record(place, context); event.Record(context);
bool status = event.Query(); bool status = event.Query();
ASSERT_EQ(status, false); ASSERT_EQ(status, false);
// case 3. test for event_finisher // case 3. test for event_finisher
...@@ -59,7 +54,7 @@ TEST(DeviceEvent, CUDA) { ...@@ -59,7 +54,7 @@ TEST(DeviceEvent, CUDA) {
cudaMalloc(reinterpret_cast<void**>(&dst_fp32), size); cudaMalloc(reinterpret_cast<void**>(&dst_fp32), size);
cudaMemcpyAsync(dst_fp32, src_fp32, size, cudaMemcpyHostToDevice, cudaMemcpyAsync(dst_fp32, src_fp32, size, cudaMemcpyHostToDevice,
context->stream()); context->stream());
event.Record(place, context); // step 1. record it event.Record(context); // step 1. record it
status = event.Query(); status = event.Query();
ASSERT_EQ(status, false); ASSERT_EQ(status, false);
...@@ -76,3 +71,17 @@ TEST(DeviceEvent, CUDA) { ...@@ -76,3 +71,17 @@ TEST(DeviceEvent, CUDA) {
cudaFreeHost(src_fp32); cudaFreeHost(src_fp32);
} }
#endif #endif
// Exercises the CPU event path: create an event for a CPUPlace, record it on
// the pooled CPU context, mark it finished, and expect Query() to be true.
TEST(DeviceEvent, CPU) {
  using paddle::platform::CPUPlace;

  auto cpu_place = CPUPlace();
  DeviceEvent cpu_event(cpu_place);
  auto* cpu_ctx = DeviceContextPool::Instance().Get(cpu_place);

  // TODO(Aurelius84): All DeviceContext should have Record/Wait
  cpu_event.Record(cpu_ctx);
  cpu_event.SetFininshed();

  ASSERT_EQ(cpu_event.Query(), true);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册