Unverified · Commit 48bf7cbf authored by Aurelius84, committed by GitHub

Polish DeviceEvent interface and Remove #ifdef in InterpreterCore (#35196)

* Add CPUDeviceEvent

* Polish DeviceEvent code

* Add DEVICE_EVENT_LIBS
Parent 7272526b
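Context for the hunks below: the commit replaces the CUDA-only DeviceOption/CudaEvent plumbing with a place-driven platform::DeviceEvent, so InterpreterCore no longer needs #ifdef PADDLE_WITH_CUDA around event handling. A minimal usage sketch of the reworked interface (illustrative only, not part of the diff; it assumes the headers and the DeviceContextPool that appear in the hunks):

#include "paddle/fluid/platform/device_event.h"

void DeviceEventDemo(const paddle::platform::Place& place) {
  using paddle::platform::DeviceEvent;
  using paddle::platform::DeviceContextPool;

  auto* dev_ctx = DeviceContextPool::Instance().Get(place);

  DeviceEvent event(place);   // backend is picked from the Place, no DeviceOption
  event.Record(dev_ctx);      // Record() no longer takes a Place argument
  bool done = event.Query();  // non-blocking status check
  (void)done;

  // Host-side wait on the recorded work. On CUDA this synchronizes the event;
  // on CPU it blocks until some worker calls event.SetFininshed() (see the CPU
  // test added at the end of this diff).
  event.Wait(paddle::platform::kCPU, dev_ctx);
}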
......@@ -2,7 +2,7 @@ cc_library(workqueue SRCS workqueue.cc)
cc_library(interpretercore SRCS interpretercore.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor workqueue device_event device_event_gpu)
graph_to_program_pass variable_helper timer monitor workqueue ${DEVICE_EVENT_LIBS})
cc_library(standalone_executor SRCS standalone_executor.cc DEPS interpretercore)
cc_test(workqueue_test SRCS workqueue_test.cc DEPS workqueue)
# cc_binary(standalone_executor_test SRCS standalone_executor_test.cc DEPS interpretercore standalone_executor operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)
......@@ -12,16 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/new_executor/interpretercore_gc_helper.h"
#if defined(PADDLE_WITH_CUDA)
using ::paddle::platform::kCUDA;
USE_EVENT(kCUDA);
#endif
#include <unordered_set>
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/new_executor/interpretercore_gc_helper.h"
namespace paddle {
namespace framework {
......@@ -74,27 +70,26 @@ std::vector<size_t> ParseEventVarIds(const Instruction& cur_instr,
}
void AssociateInputWithEvents(
const std::vector<size_t>& new_event_var_id, Instruction* next_instr,
std::map<size_t, std::shared_ptr<platform::CudaEvent>>* var_id2event,
const platform::Place& place, const std::vector<size_t>& new_event_var_id,
Instruction* next_instr,
std::map<size_t, std::shared_ptr<platform::DeviceEvent>>* var_id2event,
bool is_sync) {
#ifdef PADDLE_WITH_CUDA
for (auto var_id : new_event_var_id) {
if (var_id2event->count(var_id) == 0) {
auto cuda_event = std::make_shared<platform::CudaEvent>(
platform::get_cuda_flags(false, false, false));
var_id2event->emplace(var_id, std::move(cuda_event));
auto device_event = std::make_shared<platform::DeviceEvent>(
place, platform::get_cuda_flags(false, false, false));
var_id2event->emplace(var_id, std::move(device_event));
}
// Add events for next_instr.inputs
next_instr->intput_events_.emplace_back(var_id, var_id2event->at(var_id),
is_sync);
}
#endif
}
void ParseDirectAndEventRunOps(
const std::vector<OpFuncNode>& op_func_nodes,
const platform::Place& place, const std::vector<OpFuncNode>& op_func_nodes,
const std::vector<size_t>& downstream_ops, size_t op_index,
std::map<size_t, std::shared_ptr<platform::CudaEvent>>* var_id2event,
std::map<size_t, std::shared_ptr<platform::DeviceEvent>>* var_id2event,
std::vector<Instruction>* instructions) {
auto& op_func_type = op_func_nodes[op_index].type_;
auto& cur_instr = instructions->at(op_index);
......@@ -119,8 +114,8 @@ void ParseDirectAndEventRunOps(
bool is_sync =
(op_func_nodes[next_op_id].type_ == OpFuncType::kQueueSync);
AssociateInputWithEvents(new_event_var_ids, &next_instr, var_id2event,
is_sync);
AssociateInputWithEvents(place, new_event_var_ids, &next_instr,
var_id2event, is_sync);
if (is_sync) { // GPU -> CPU
next_instruction.synchronize_run_.emplace_back(next_op_id);
......@@ -128,7 +123,6 @@ void ParseDirectAndEventRunOps(
next_instruction.event_wait_run_.emplace_back(next_op_id);
}
}
#ifdef PADDLE_WITH_CUDA
// Create events for these cross-stream vars
VLOG(3) << cur_instr.kernel_func_.operator_base_->Type()
<< " event_var_ids.size: " << event_var_ids.size();
......@@ -136,7 +130,6 @@ void ParseDirectAndEventRunOps(
cur_instr.output_events_.emplace_back(var_id, var_id2event->at(var_id),
false /*not used*/);
}
#endif
}
}
} // namespace
......@@ -263,12 +256,10 @@ void InterpreterCore::Convert() {
}
for (size_t i = 0; i < vec_instruction_.size(); ++i) {
#if defined(PADDLE_WITH_CUDA)
int device_type = static_cast<int>(paddle::platform::DeviceType::CUDA);
paddle::platform::DeviceOption dev_opt(
device_type, BOOST_GET_CONST(platform::CUDAPlace, place_).device);
gc_event_.emplace_back(dev_opt);
#endif
// int device_type = static_cast<int>(paddle::platform::DeviceType::CUDA);
// paddle::platform::DeviceOption dev_opt(
// device_type, BOOST_GET_CONST(platform::CUDAPlace, place_).device);
gc_event_.emplace_back(place_);
std::vector<size_t> vec_temp;
for (auto& item : vec_instruction_[i].output_index_) {
......@@ -287,8 +278,8 @@ void InterpreterCore::Convert() {
}
}
ParseDirectAndEventRunOps(vec_func_list_, filter_next, i, &var_id2event_,
&vec_instruction_);
ParseDirectAndEventRunOps(place_, vec_func_list_, filter_next, i,
&var_id2event_, &vec_instruction_);
// check output
for (auto& item : vec_instruction_[i].output_index_) {
......@@ -466,7 +457,7 @@ void InterpreterCore::CheckGC(size_t instr_id,
#if defined(PADDLE_WITH_CUDA)
auto* dev_ctx = reinterpret_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place));
gc_event_[instr_id].Record(place, dev_ctx);
gc_event_[instr_id].Record(dev_ctx);
gc_queue_->AddTask(
[ container = garbages_.release(), event = &gc_event_[instr_id] ]() {
while (!event->Query()) {
......@@ -483,7 +474,7 @@ void InterpreterCore::CheckGC(size_t instr_id,
#if defined(PADDLE_WITH_CUDA)
auto* dev_ctx = reinterpret_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place));
gc_event_[instr_id].Record(place, dev_ctx);
gc_event_[instr_id].Record(dev_ctx);
gc_queue_->AddTask(
[ container = garbages_.release(), event = &gc_event_[instr_id] ]() {
while (!event->Query()) {
......@@ -857,34 +848,23 @@ void InterpreterCore::RecordEventInstruction(const Instruction& instruction,
// If InterpreterCore is on CPUPlace, do nothing.
if (platform::is_cpu_place(place_)) return;
#ifdef PADDLE_WITH_CUDA
const platform::CUDADeviceContext* dev_ctx =
reinterpret_cast<const platform::CUDADeviceContext*>(
instruction.dev_ctx_);
for (auto& event : instruction.output_events_) {
VLOG(3) << "Record event in out_var_id: " << event.var_id_;
event.event_->Record(*(dev_ctx->context()->Stream()));
event.event_->Record(instruction.dev_ctx_);
}
#endif
}
void InterpreterCore::WaitOrSync(const std::vector<EventInter>& events,
const platform::DeviceContext* dev_ctx) {
#ifdef PADDLE_WITH_CUDA
auto* cuda_dev_ctx =
reinterpret_cast<const platform::CUDADeviceContext*>(dev_ctx);
for (auto& event : events) {
if (event.is_sync_) {
VLOG(3) << "host sync wait in_var_id " << event.var_id_;
event.event_->Synchronize();
for (auto& event_iter : events) {
if (event_iter.is_sync_) {
VLOG(3) << "host sync wait in_var_id " << event_iter.var_id_;
event_iter.event_->Wait(platform::kCPU, dev_ctx);
} else {
VLOG(3) << "stream async wait in_var_id " << event.var_id_;
cuda_dev_ctx->context()->Stream()->WaitEvent(
event.event_->GetRawCudaEvent());
VLOG(3) << "stream async wait in_var_id " << event_iter.var_id_;
event_iter.event_->Wait(platform::kCUDA, dev_ctx);
}
}
#endif
}
void InterpreterCore::StreamWaitEventOrSync(const Instruction& instruction) {
......
......@@ -26,7 +26,6 @@
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_event.h"
#include "paddle/fluid/platform/event.h"
namespace paddle {
namespace framework {
......@@ -101,7 +100,7 @@ class InterpreterCore {
bool is_build_;
std::vector<std::string> feed_names_;
std::map<size_t, std::shared_ptr<platform::CudaEvent>> var_id2event_;
std::map<size_t, std::shared_ptr<platform::DeviceEvent>> var_id2event_;
std::vector<paddle::platform::DeviceEvent> gc_event_;
std::unique_ptr<GarbageQueue> garbages_;
......
......@@ -19,6 +19,7 @@
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device_event_base.h"
#include "paddle/fluid/platform/event.h"
namespace paddle {
......@@ -56,11 +57,12 @@ struct NextInstruction {
};
struct EventInter {
explicit EventInter(size_t var_id, std::shared_ptr<platform::CudaEvent> event,
explicit EventInter(size_t var_id,
std::shared_ptr<platform::DeviceEvent> event,
bool is_sync)
: var_id_(var_id), event_(event), is_sync_(is_sync) {}
size_t var_id_;
std::shared_ptr<platform::CudaEvent> event_;
std::shared_ptr<platform::DeviceEvent> event_;
bool is_sync_;
};
......
......@@ -151,19 +151,28 @@ endif()
cc_test(init_test SRCS init_test.cc DEPS device_context)
cc_library(device_event SRCS device_event.cc DEPS place enforce device_context op_registry)
cc_library(device_event_gpu SRCS device_event_gpu.cc DEPS device_event)
# Manage all device event library
set(DEVICE_EVENT_LIBS)
cc_library(device_event_base SRCS device_event_base.cc DEPS place enforce device_context op_registry)
set(DEVICE_EVENT_LIBS device_event_base CACHE INTERNAL "device event libs")
if(WITH_GPU)
nv_library(device_event_gpu SRCS device_event_gpu.cc DEPS device_event_base)
set(DEVICE_EVENT_LIBS device_event_gpu CACHE INTERNAL "device event libs")
nv_test(device_event_test SRCS device_event_test.cc DEPS device_event_gpu)
nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info)
nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda)
nv_test(cudnn_desc_test SRCS cudnn_desc_test.cc DEPS dynload_cuda)
nv_test(transform_test SRCS transform_test.cu DEPS memory place device_context)
nv_test(device_event_test SRCS device_event_test.cc DEPS device_event_gpu)
endif()
if(WITH_ROCM)
hip_library(device_event_gpu SRCS device_event_gpu.cc DEPS device_event_base)
set(DEVICE_EVENT_LIBS device_event_gpu CACHE INTERNAL "device event libs")
hip_test(device_event_test SRCS device_event_test.cc DEPS device_event_gpu)
hip_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info)
hip_test(miopen_helper_test SRCS miopen_helper_test.cc DEPS dynload_cuda)
hip_test(cudnn_desc_test SRCS cudnn_desc_test.cc DEPS dynload_cuda tensor)
......
......@@ -11,267 +11,26 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace platform {
class DeviceOption;
class DeviceEvent;
constexpr int MaxDeviceTypes =
static_cast<int>(platform::DeviceType::MAX_DEVICE_TYPES);
typedef void (*EventCreateFunction)(DeviceEvent*, const DeviceOption&);
typedef void (*EventRecordFunction)(DeviceEvent*, const platform::Place&,
const DeviceContext*);
typedef bool (*EventQueryFunction)(const DeviceEvent*);
typedef void (*EventFinishFunction)(const DeviceEvent*);
typedef void (*EventWaitFunction)(const DeviceEvent*, DeviceContext*);
inline int DeviceTypeToId(const DeviceType& device_type) {
return static_cast<int>(device_type);
}
class DeviceOption {
public:
explicit DeviceOption(int device_type) : device_type_(device_type) {}
DeviceOption(int device_type, int device_id)
: device_type_(device_type), device_id_(device_id) {}
int device_type() const { return device_type_; }
int device_id() const { return device_id_; }
private:
int device_type_;
int device_id_;
};
class DeviceEvent {
public:
explicit DeviceEvent(const DeviceOption& device_option)
: event_(),
type_(device_option.device_type()),
device_option_(device_option) {
PADDLE_ENFORCE_LT(type_, MaxDeviceTypes,
platform::errors::PreconditionNotMet(
"Required type < %d, but received type = %d",
MaxDeviceTypes, type_));
PADDLE_ENFORCE_NOT_NULL(
event_creator_[type_],
platform::errors::Unavailable(
"event_creator_[%d] shall not be nullptr.", type_));
event_creator_[type_](this, device_option_);
}
~DeviceEvent() {}
void Record(const platform::Place& place, const DeviceContext* dev_ctx) {
PADDLE_ENFORCE_NOT_NULL(
event_recorder_[type_],
platform::errors::Unavailable(
"event_recorder_[%d] shall not be nullptr.", type_));
event_recorder_[type_](this, place, dev_ctx);
}
bool Query() {
PADDLE_ENFORCE_NOT_NULL(
event_querier_[type_],
platform::errors::Unavailable(
"event_querier_[%d] shall not be nullptr.", type_));
return event_querier_[type_](this);
}
void Finish() const {
PADDLE_ENFORCE_NOT_NULL(
event_finisher_[type_],
platform::errors::Unavailable(
"event_finisher_[%d] shall not be nullptr.", type_));
event_finisher_[type_](this);
}
void Wait(const DeviceType& waiter_type, DeviceContext* context) const {
auto waiter_idx = DeviceTypeToId(waiter_type);
PADDLE_ENFORCE_NOT_NULL(
event_waiter_[waiter_idx][type_],
platform::errors::Unavailable(
"event_waiter_[%d][%d] shall not be nullptr.", waiter_idx, type_));
event_waiter_[waiter_idx][type_](this, context);
}
void InitEvent(std::shared_ptr<void> event) { event_ = event; }
std::shared_ptr<void> GetEvent() const { return event_; }
private:
std::shared_ptr<void> event_;
int type_;
DeviceOption device_option_;
static EventCreateFunction event_creator_[MaxDeviceTypes];
static EventRecordFunction event_recorder_[MaxDeviceTypes];
static EventQueryFunction event_querier_[MaxDeviceTypes];
static EventFinishFunction event_finisher_[MaxDeviceTypes];
static EventWaitFunction event_waiter_[MaxDeviceTypes][MaxDeviceTypes];
template <DeviceType device_typ>
friend struct EventCreateFunctionRegisterer;
template <DeviceType device_typ>
friend struct EventRecordFunctionRegisterer;
template <DeviceType device_typ>
friend struct EventQueryFunctionRegisterer;
template <DeviceType device_typ>
friend struct EventFinishFunctionRegisterer;
template <DeviceType waiter_typ, DeviceType event_type>
friend struct EventWaitFunctionRegisterer;
};
/**
* check if MACRO is used in GLOBAL NAMESPACE.
#pragma once
#include "paddle/fluid/platform/device_event_base.h"
/*
* NOTE: For now we generate this file manually; we may consider
* generating it automatically later, just as 'paddle/fluid/pybind/pybind.h'
* for USE_OP from op_library macros, and
* `paddle/fluid/inference/paddle_inference_pass.h`
* for USE_PASS from pass_library.
*/
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
// =============== Register for Create ===============
template <DeviceType device_type>
struct EventCreateFunctionRegisterer : public framework::Registrar {
explicit EventCreateFunctionRegisterer(EventCreateFunction func) {
auto type_idx = DeviceTypeToId(device_type);
VLOG(3) << "register event_creator with type_id :" << type_idx;
DeviceEvent::event_creator_[type_idx] = func;
}
};
#define REGISTER_EVENT_CREATE_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_creator__##device_type, \
"REGISTER_EVENT_CREATE_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventCreateFunctionRegisterer<device_type> \
__reg_event_create_##device_type##__(func); \
int TouchDeviceEventCreate##device_type() { \
__reg_event_create_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for Record ===============
template <DeviceType device_type>
struct EventRecordFunctionRegisterer : public framework::Registrar {
explicit EventRecordFunctionRegisterer(EventRecordFunction func) {
auto type_idx = DeviceTypeToId(device_type);
VLOG(3) << "register event_recorder with type_id :" << type_idx;
DeviceEvent::event_recorder_[type_idx] = func;
}
};
#define REGISTER_EVENT_RECORD_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_recorder__##device_type, \
"REGISTER_EVENT_RECORD_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventRecordFunctionRegisterer<device_type> \
__reg_event_record_##device_type##__(func); \
int TouchDeviceEventRecord##device_type() { \
__reg_event_record_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for Query ===============
template <DeviceType device_type>
struct EventQueryFunctionRegisterer : public framework::Registrar {
explicit EventQueryFunctionRegisterer(EventQueryFunction func) {
auto type_idx = DeviceTypeToId(device_type);
VLOG(3) << "register event_querier with type_id :" << type_idx;
DeviceEvent::event_querier_[type_idx] = func;
}
};
#define REGISTER_EVENT_QUERY_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_querier__##device_type, \
"REGISTER_EVENT_QUERY_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventQueryFunctionRegisterer<device_type> \
__reg_event_query_##device_type##__(func); \
int TouchDeviceEventQuery##device_type() { \
__reg_event_query_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for Finish ===============
template <DeviceType device_type>
struct EventFinishFunctionRegisterer : public framework::Registrar {
explicit EventFinishFunctionRegisterer(EventFinishFunction func) {
auto type_idx = DeviceTypeToId(device_type);
VLOG(3) << "register event_finisher with type_id :" << type_idx;
DeviceEvent::event_finisher_[type_idx] = func;
}
};
#define REGISTER_EVENT_FINISH_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_finishier__##device_type, \
"REGISTER_EVENT_FINISH_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventFinishFunctionRegisterer<device_type> \
__reg_event_finish_##device_type##__(func); \
int TouchDeviceEventFinish##device_type() { \
__reg_event_finish_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for Wait ===============
template <DeviceType waiter_type, DeviceType event_type>
struct EventWaitFunctionRegisterer : public framework::Registrar {
explicit EventWaitFunctionRegisterer(EventWaitFunction func) {
auto waiter_idx = DeviceTypeToId(waiter_type);
auto event_idx = DeviceTypeToId(event_type);
VLOG(3) << "register event_finisher with waiter_idx : " << waiter_idx
<< ", event_idx : " << event_idx;
DeviceEvent::event_waiter_[waiter_idx][event_idx] = func;
}
};
#define REGISTER_EVENT_WAIT_FUNCTION(waiter_type, event_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_waiter__##waiter_type##event_type, \
"REGISTER_EVENT_WAIT_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventWaitFunctionRegisterer<waiter_type, \
event_type> \
__reg_event_wait_##waiter_type##event_type##__(func); \
int TouchDeviceEventWait##waiter_type##event_type() { \
__reg_event_wait_##waiter_type##event_type##__.Touch(); \
return 0; \
}
#define USE_EVENT(device_type) \
extern int TouchDeviceEventCreate##device_type(); \
extern int TouchDeviceEventRecord##device_type(); \
extern int TouchDeviceEventQuery##device_type(); \
extern int TouchDeviceEventFinish##device_type(); \
UNUSED static int use_event_creator_##device_type = \
TouchDeviceEventCreate##device_type(); \
UNUSED static int use_event_recorder_##device_type = \
TouchDeviceEventRecord##device_type(); \
UNUSED static int use_event_querier_##device_type = \
TouchDeviceEventQuery##device_type(); \
UNUSED static int use_event_finisher_##device_type = \
TouchDeviceEventFinish##device_type();
using ::paddle::platform::kCUDA;
using ::paddle::platform::kCPU;
#define USE_EVENT_WAIT(waiter_type, event_type) \
extern int TouchDeviceEventWait##waiter_type##event_type(); \
UNUSED static int use_event_waiter_##waiter_type##event_type = \
TouchDeviceEventWait##waiter_type##event_type();
USE_EVENT(kCPU)
USE_EVENT_WAIT(kCPU, kCPU)
} // namespace platform
} // namespace paddle
#ifdef PADDLE_WITH_CUDA
USE_EVENT(kCUDA);
USE_EVENT_WAIT(kCUDA, kCUDA)
USE_EVENT_WAIT(kCPU, kCUDA)
#endif
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device_event_base.h"
#include "paddle/fluid/platform/device_event_cpu.h"
namespace paddle {
namespace platform {
EventCreateFunction DeviceEvent::event_creator_[MaxDeviceTypes];
EventRecordFunction DeviceEvent::event_recorder_[MaxDeviceTypes];
EventQueryFunction DeviceEvent::event_querier_[MaxDeviceTypes];
EventFinishFunction DeviceEvent::event_finisher_[MaxDeviceTypes];
EventFinishFunction DeviceEvent::event_finished_setter_[MaxDeviceTypes];
EventWaitFunction DeviceEvent::event_waiter_[MaxDeviceTypes][MaxDeviceTypes];
void DeviceEventCreateCPU(DeviceEvent* event, const platform::Place& place,
unsigned int flag) {
event->InitEvent(std::make_shared<CPUDeviceEventWrapper>(place, flag));
}
void DeviceEventRecordCPU(DeviceEvent* event, const DeviceContext* context) {
auto* wrapper = static_cast<CPUDeviceEventWrapper*>(event->GetEvent().get());
std::unique_lock<std::mutex> lock(wrapper->mutex_);
PADDLE_ENFORCE_NE(wrapper->status_.load(), EventStatus::SCHEDULED,
platform::errors::PreconditionNotMet(
"EventStatus shall be not SCHEDULED before Record()"));
if (wrapper->status_ == EventStatus::INITIALIZED) {
wrapper->status_ = EventStatus::SCHEDULED;
}
}
bool DeviceEventQueryCPU(const DeviceEvent* event) {
auto* wrapper = static_cast<CPUDeviceEventWrapper*>(event->GetEvent().get());
PADDLE_ENFORCE_NOT_NULL(
wrapper, platform::errors::PreconditionNotMet(
"Failed to dynamic_cast event into CPUDeviceEventWrapper."));
return wrapper->status_ == EventStatus::SUCCESS;
}
void DeviceEventFinishCPU(const DeviceEvent* event) {
auto* wrapper = static_cast<CPUDeviceEventWrapper*>(event->GetEvent().get());
std::unique_lock<std::mutex> lock(wrapper->mutex_);
while (wrapper->status_ != EventStatus::SUCCESS &&
wrapper->status_ != EventStatus::FAILED) {
wrapper->cv_completed_.wait(lock);
}
}
void DeviceEventCPUWaitCPU(const DeviceEvent* event,
const DeviceContext* context) {
DeviceEventFinishCPU(event);
}
void EventSetFinishedCPU(const DeviceEvent* event) {
auto* wrapper = static_cast<CPUDeviceEventWrapper*>(event->GetEvent().get());
std::unique_lock<std::mutex> lock(wrapper->mutex_);
PADDLE_ENFORCE_LE(wrapper->status_.load(), EventStatus::SCHEDULED,
platform::errors::PreconditionNotMet(
"EventStatus shall be INITIALIZED | SCHEDULED before "
"EventSetFinishedCPU()"));
wrapper->status_ = EventStatus::SUCCESS;
wrapper->cv_completed_.notify_all();
}
} // namespace platform
} // namespace paddle
using ::paddle::platform::kCPU;
REGISTER_EVENT_CREATE_FUNCTION(kCPU, paddle::platform::DeviceEventCreateCPU)
REGISTER_EVENT_RECORD_FUNCTION(kCPU, paddle::platform::DeviceEventRecordCPU)
REGISTER_EVENT_QUERY_FUNCTION(kCPU, paddle::platform::DeviceEventQueryCPU)
REGISTER_EVENT_FINISH_FUNCTION(kCPU, paddle::platform::DeviceEventFinishCPU)
REGISTER_EVENT_SET_FINISHED_FUNCTION(kCPU,
paddle::platform::EventSetFinishedCPU);
REGISTER_EVENT_WAIT_FUNCTION(kCPU, kCPU,
paddle::platform::DeviceEventCPUWaitCPU)
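Taken together, the CPU path above is a condition-variable latch: Record() moves the wrapper to SCHEDULED, SetFininshed() (spelling as in the new DeviceEvent API) moves it to SUCCESS and notifies waiters, and Finish()/Wait(kCPU, ...) block until then. A cross-thread sketch, illustrative only and assuming the headers added in this commit:

#include <thread>
#include "paddle/fluid/platform/device_event.h"

void CpuEventLatchDemo() {
  using paddle::platform::CPUPlace;
  using paddle::platform::DeviceEvent;
  using paddle::platform::DeviceContextPool;

  CPUPlace place;
  auto* ctx = DeviceContextPool::Instance().Get(place);

  DeviceEvent event(place);
  event.Record(ctx);  // status: INITIALIZED -> SCHEDULED

  std::thread producer([&event]() {
    // ... produce whatever the waiter depends on ...
    event.SetFininshed();  // status -> SUCCESS, wakes cv_completed_
  });

  event.Wait(paddle::platform::kCPU, ctx);  // blocks in DeviceEventFinishCPU
  producer.join();
}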
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace platform {
class DeviceOption;
class DeviceEvent;
constexpr int MaxDeviceTypes =
static_cast<int>(platform::DeviceType::MAX_DEVICE_TYPES);
typedef void (*EventCreateFunction)(DeviceEvent*, const platform::Place&,
unsigned int flag);
typedef void (*EventRecordFunction)(DeviceEvent*, const DeviceContext*);
typedef bool (*EventQueryFunction)(const DeviceEvent*);
typedef void (*EventFinishFunction)(const DeviceEvent*);
typedef void (*EventSetFinishedFunction)(const DeviceEvent*);
typedef void (*EventWaitFunction)(const DeviceEvent*, const DeviceContext*);
inline int DeviceTypeToId(const DeviceType& device_type) {
return static_cast<int>(device_type);
}
enum EventStatus {
INITIALIZED = 0,
SCHEDULED = 1,
SUCCESS = 2,
FAILED = 3,
};
class DeviceEvent {
public:
explicit DeviceEvent(const platform::Place& place, unsigned int flag = 0)
: event_(), place_(place), flag_(flag) {
type_id_ = DeviceTypeToId(platform::Place2DeviceType(place));
PADDLE_ENFORCE_LT(type_id_, MaxDeviceTypes,
platform::errors::PreconditionNotMet(
"Required type < %d, but received type = %d",
MaxDeviceTypes, type_id_));
// TODO(Aurelius84): only CPU/CUDA are supported for now; consider XPU/NPU later
PADDLE_ENFORCE_LT(type_id_, 3,
platform::errors::Unavailable(
"Currently DeviceEvent do not support %s", place));
PADDLE_ENFORCE_NOT_NULL(
event_creator_[type_id_],
platform::errors::Unavailable(
"event_creator_[%d] shall not be nullptr.", type_id_));
event_creator_[type_id_](this, place, flag);
}
~DeviceEvent() {}
void Record(const DeviceContext* dev_ctx) {
PADDLE_ENFORCE_NOT_NULL(
event_recorder_[type_id_],
platform::errors::Unavailable(
"event_recorder_[%d] shall not be nullptr.", type_id_));
event_recorder_[type_id_](this, dev_ctx);
}
bool Query() {
PADDLE_ENFORCE_NOT_NULL(
event_querier_[type_id_],
platform::errors::Unavailable(
"event_querier_[%d] shall not be nullptr.", type_id_));
return event_querier_[type_id_](this);
}
void Finish() const {
PADDLE_ENFORCE_NOT_NULL(
event_finisher_[type_id_],
platform::errors::Unavailable(
"event_finisher_[%d] shall not be nullptr.", type_id_));
event_finisher_[type_id_](this);
}
void SetFininshed() {
PADDLE_ENFORCE_NOT_NULL(
event_finished_setter_[type_id_],
platform::errors::Unavailable(
"event_finished_setter_[%d] shall not be nullptr.", type_id_));
event_finished_setter_[type_id_](this);
}
void Wait(const DeviceType& waiter_type, const DeviceContext* context) const {
auto waiter_idx = DeviceTypeToId(waiter_type);
PADDLE_ENFORCE_NOT_NULL(event_waiter_[waiter_idx][type_id_],
platform::errors::Unavailable(
"event_waiter_[%d][%d] shall not be nullptr.",
waiter_idx, type_id_));
event_waiter_[waiter_idx][type_id_](this, context);
}
void InitEvent(std::shared_ptr<void> event) { event_ = event; }
std::shared_ptr<void> GetEvent() const { return event_; }
private:
std::shared_ptr<void> event_;
platform::Place place_;
int type_id_;
unsigned int flag_;
static EventCreateFunction event_creator_[MaxDeviceTypes];
static EventRecordFunction event_recorder_[MaxDeviceTypes];
static EventQueryFunction event_querier_[MaxDeviceTypes];
static EventFinishFunction event_finisher_[MaxDeviceTypes];
static EventFinishFunction event_finished_setter_[MaxDeviceTypes];
static EventWaitFunction event_waiter_[MaxDeviceTypes][MaxDeviceTypes];
template <DeviceType device_typ>
friend struct EventCreateFunctionRegisterer;
template <DeviceType device_typ>
friend struct EventRecordFunctionRegisterer;
template <DeviceType device_typ>
friend struct EventQueryFunctionRegisterer;
template <DeviceType device_typ>
friend struct EventFinishFunctionRegisterer;
template <DeviceType device_typ>
friend struct EventSetFinishedFunctionRegisterer;
template <DeviceType waiter_typ, DeviceType event_type>
friend struct EventWaitFunctionRegisterer;
};
/**
* check if MACRO is used in GLOBAL NAMESPACE.
*/
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
// =============== Register for Create ===============
template <DeviceType device_type>
struct EventCreateFunctionRegisterer : public framework::Registrar {
explicit EventCreateFunctionRegisterer(EventCreateFunction func) {
auto type_idx = DeviceTypeToId(device_type);
VLOG(3) << "register event_creator with type_id :" << type_idx;
DeviceEvent::event_creator_[type_idx] = func;
}
};
#define REGISTER_EVENT_CREATE_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_creator__##device_type, \
"REGISTER_EVENT_CREATE_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventCreateFunctionRegisterer<device_type> \
__reg_event_create_##device_type##__(func); \
int TouchDeviceEventCreate##device_type() { \
__reg_event_create_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for Record ===============
template <DeviceType device_type>
struct EventRecordFunctionRegisterer : public framework::Registrar {
explicit EventRecordFunctionRegisterer(EventRecordFunction func) {
auto type_idx = DeviceTypeToId(device_type);
VLOG(3) << "register event_recorder with type_id :" << type_idx;
DeviceEvent::event_recorder_[type_idx] = func;
}
};
#define REGISTER_EVENT_RECORD_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_recorder__##device_type, \
"REGISTER_EVENT_RECORD_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventRecordFunctionRegisterer<device_type> \
__reg_event_record_##device_type##__(func); \
int TouchDeviceEventRecord##device_type() { \
__reg_event_record_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for Query ===============
template <DeviceType device_type>
struct EventQueryFunctionRegisterer : public framework::Registrar {
explicit EventQueryFunctionRegisterer(EventQueryFunction func) {
auto type_idx = DeviceTypeToId(device_type);
VLOG(3) << "register event_querier with type_id :" << type_idx;
DeviceEvent::event_querier_[type_idx] = func;
}
};
#define REGISTER_EVENT_QUERY_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_querier__##device_type, \
"REGISTER_EVENT_QUERY_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventQueryFunctionRegisterer<device_type> \
__reg_event_query_##device_type##__(func); \
int TouchDeviceEventQuery##device_type() { \
__reg_event_query_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for Finish ===============
template <DeviceType device_type>
struct EventFinishFunctionRegisterer : public framework::Registrar {
explicit EventFinishFunctionRegisterer(EventFinishFunction func) {
auto type_idx = DeviceTypeToId(device_type);
VLOG(3) << "register event_finisher with type_id :" << type_idx;
DeviceEvent::event_finisher_[type_idx] = func;
}
};
#define REGISTER_EVENT_FINISH_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_finishier__##device_type, \
"REGISTER_EVENT_FINISH_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventFinishFunctionRegisterer<device_type> \
__reg_event_finish_##device_type##__(func); \
int TouchDeviceEventFinish##device_type() { \
__reg_event_finish_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for SetFinished ===============
template <DeviceType device_type>
struct EventSetFinishedFunctionRegisterer : public framework::Registrar {
explicit EventSetFinishedFunctionRegisterer(EventSetFinishedFunction func) {
auto type_idx = DeviceTypeToId(device_type);
VLOG(3) << "register event_finished_setter with type_id :" << type_idx;
DeviceEvent::event_finished_setter_[type_idx] = func;
}
};
#define REGISTER_EVENT_SET_FINISHED_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_finished_setter__##device_type, \
"REGISTER_EVENT_FINISH_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventSetFinishedFunctionRegisterer<device_type> \
__reg_event_finished_setter_##device_type##__(func); \
int TouchDeviceEventSetFinished##device_type() { \
__reg_event_finished_setter_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for Wait ===============
template <DeviceType waiter_type, DeviceType event_type>
struct EventWaitFunctionRegisterer : public framework::Registrar {
explicit EventWaitFunctionRegisterer(EventWaitFunction func) {
auto waiter_idx = DeviceTypeToId(waiter_type);
auto event_idx = DeviceTypeToId(event_type);
VLOG(3) << "register event_finisher with waiter_idx : " << waiter_idx
<< ", event_idx : " << event_idx;
DeviceEvent::event_waiter_[waiter_idx][event_idx] = func;
}
};
#define REGISTER_EVENT_WAIT_FUNCTION(waiter_type, event_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_waiter__##waiter_type##event_type, \
"REGISTER_EVENT_WAIT_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventWaitFunctionRegisterer<waiter_type, \
event_type> \
__reg_event_wait_##waiter_type##event_type##__(func); \
int TouchDeviceEventWait##waiter_type##event_type() { \
__reg_event_wait_##waiter_type##event_type##__.Touch(); \
return 0; \
}
#define USE_EVENT(device_type) \
extern int TouchDeviceEventCreate##device_type(); \
extern int TouchDeviceEventRecord##device_type(); \
extern int TouchDeviceEventQuery##device_type(); \
extern int TouchDeviceEventFinish##device_type(); \
extern int TouchDeviceEventSetFinished##device_type(); \
UNUSED static int use_event_creator_##device_type = \
TouchDeviceEventCreate##device_type(); \
UNUSED static int use_event_recorder_##device_type = \
TouchDeviceEventRecord##device_type(); \
UNUSED static int use_event_querier_##device_type = \
TouchDeviceEventQuery##device_type(); \
UNUSED static int use_event_finisher_##device_type = \
TouchDeviceEventFinish##device_type(); \
UNUSED static int use_event_finished_setter_##device_type = \
TouchDeviceEventSetFinished##device_type();
#define USE_EVENT_WAIT(waiter_type, event_type) \
extern int TouchDeviceEventWait##waiter_type##event_type(); \
UNUSED static int use_event_waiter_##waiter_type##event_type = \
TouchDeviceEventWait##waiter_type##event_type();
} // namespace platform
} // namespace paddle
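The Touch/USE registration above mirrors Paddle's op registration pattern. Purely as an illustration of the macro contract (kXPU and every callback below are hypothetical and not part of this commit, which only registers kCPU and kCUDA), a new backend's .cc file would plug in roughly like this:

#include "paddle/fluid/platform/device_event_base.h"

namespace paddle {
namespace platform {
// Hypothetical backend callbacks; bodies reduced to stubs for illustration.
void DeviceEventCreateXPU(DeviceEvent* event, const Place& place,
                          unsigned int flag) { /* allocate native event */ }
void DeviceEventRecordXPU(DeviceEvent* event, const DeviceContext* ctx) {}
bool DeviceEventQueryXPU(const DeviceEvent* event) { return true; }
void DeviceEventFinishXPU(const DeviceEvent* event) {}
void DeviceEventSetFinishedXPU(const DeviceEvent* event) {}
void DeviceEventXPUWaitXPU(const DeviceEvent* event,
                           const DeviceContext* ctx) {}
}  // namespace platform
}  // namespace paddle

// Registration must happen in the global namespace (enforced by
// STATIC_ASSERT_GLOBAL_NAMESPACE). kXPU is an assumed DeviceType alias.
using ::paddle::platform::kXPU;
REGISTER_EVENT_CREATE_FUNCTION(kXPU, paddle::platform::DeviceEventCreateXPU)
REGISTER_EVENT_RECORD_FUNCTION(kXPU, paddle::platform::DeviceEventRecordXPU)
REGISTER_EVENT_QUERY_FUNCTION(kXPU, paddle::platform::DeviceEventQueryXPU)
REGISTER_EVENT_FINISH_FUNCTION(kXPU, paddle::platform::DeviceEventFinishXPU)
REGISTER_EVENT_SET_FINISHED_FUNCTION(
    kXPU, paddle::platform::DeviceEventSetFinishedXPU)
REGISTER_EVENT_WAIT_FUNCTION(kXPU, kXPU,
                             paddle::platform::DeviceEventXPUWaitXPU)

Consumers would then pull the registration in with USE_EVENT(kXPU) / USE_EVENT_WAIT(kXPU, kXPU), exactly as device_event.h now does for kCPU and, under PADDLE_WITH_CUDA, for kCUDA.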
......@@ -12,16 +12,40 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device_event.h"
#pragma once
#include <atomic>
#include <condition_variable>
#include <mutex>
#include "paddle/fluid/platform/device_event_base.h"
namespace paddle {
namespace platform {
EventCreateFunction DeviceEvent::event_creator_[MaxDeviceTypes];
EventRecordFunction DeviceEvent::event_recorder_[MaxDeviceTypes];
EventQueryFunction DeviceEvent::event_querier_[MaxDeviceTypes];
EventFinishFunction DeviceEvent::event_finisher_[MaxDeviceTypes];
EventWaitFunction DeviceEvent::event_waiter_[MaxDeviceTypes][MaxDeviceTypes];
struct CPUDeviceEventWrapper {
explicit CPUDeviceEventWrapper(const platform::Place& place,
unsigned int flag = 0) {
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(place), true,
platform::errors::PreconditionNotMet(
"Required device shall be CPUAPlace, but received %d. ", place));
}
std::mutex mutex_;
std::condition_variable cv_completed_;
std::atomic<int> status_;
};
void DeviceEventCreateCPU(DeviceEvent* event, const platform::Place& place);
void DeviceEventRecordCPU(DeviceEvent* event, const platform::Place& place,
const DeviceContext* context);
bool DeviceEventQueryCPU(const DeviceEvent* event);
void DeviceEventFinishCPU(const DeviceEvent* event);
void EventSetFinishedCPU(const DeviceEvent* event);
void DeviceEventCPUWaitCPU(const DeviceEvent* event, DeviceContext* context);
} // namespace platform
} // namespace paddle
......@@ -12,38 +12,38 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device_event.h"
#include "paddle/fluid/platform/device_event_base.h"
#include "paddle/fluid/platform/event.h"
#ifdef PADDLE_WITH_CUDA
namespace paddle {
namespace platform {
struct CUDADeviceEventWrapper {
explicit CUDADeviceEventWrapper(const DeviceOption& dev_opt)
: inner_event_() {
CUDADeviceEventWrapper(const platform::Place& place, unsigned int flag)
: inner_event_(flag) {
PADDLE_ENFORCE_EQ(
dev_opt.device_type(), static_cast<int>(DeviceType::CUDA),
platform::is_gpu_place(place), true,
platform::errors::PreconditionNotMet(
"Required device type shall be CUDA, but received %d. ",
dev_opt.device_type()));
"Required device shall be CUDAPlace, but received %d. ", place));
device_id_ = BOOST_GET_CONST(platform::CUDAPlace, place).device;
PADDLE_ENFORCE_GT(
dev_opt.device_id(), -1,
device_id_, -1,
platform::errors::PreconditionNotMet(
"Required DeviceOption.device_id > -1, but received %d. ",
dev_opt.device_id()));
device_id_ = dev_opt.device_id();
device_id_));
}
CudaEvent inner_event_;
int device_id_;
};
void DeviceEventCreateCUDA(DeviceEvent* event, const DeviceOption& dev_opt) {
event->InitEvent(std::make_shared<CUDADeviceEventWrapper>(dev_opt));
void DeviceEventCreateCUDA(DeviceEvent* event, const platform::Place& place,
unsigned int flag) {
event->InitEvent(std::make_shared<CUDADeviceEventWrapper>(place, flag));
}
void DeviceEventRecordCUDA(DeviceEvent* event, const platform::Place& place,
const DeviceContext* context) {
void DeviceEventRecordCUDA(DeviceEvent* event, const DeviceContext* context) {
auto* wrapper = static_cast<CUDADeviceEventWrapper*>(event->GetEvent().get());
auto* cuda_dev_ctx =
......@@ -72,7 +72,8 @@ void DeviceEventFinishCUDA(const DeviceEvent* event) {
wrapper->inner_event_.Synchronize();
}
void DeviceEventCUDAWaitCUDA(const DeviceEvent* event, DeviceContext* context) {
void DeviceEventCUDAWaitCUDA(const DeviceEvent* event,
const DeviceContext* context) {
auto* wrapper = static_cast<CUDADeviceEventWrapper*>(event->GetEvent().get());
auto* cuda_dev_ctx =
dynamic_cast<const platform::CUDADeviceContext*>(context);
......@@ -85,10 +86,15 @@ void DeviceEventCUDAWaitCUDA(const DeviceEvent* event, DeviceContext* context) {
wrapper->inner_event_.GetRawCudaEvent());
}
void DeviceEventCPUWaitCUDA(const DeviceEvent* event, DeviceContext* context) {
void DeviceEventCPUWaitCUDA(const DeviceEvent* event,
const DeviceContext* context) {
DeviceEventFinishCUDA(event);
}
void DeviceEventSetFinishedCUDA(const DeviceEvent* event) {
// do nothing
}
} // namespace platform
} // namespace paddle
......@@ -98,6 +104,8 @@ REGISTER_EVENT_CREATE_FUNCTION(kCUDA, paddle::platform::DeviceEventCreateCUDA)
REGISTER_EVENT_RECORD_FUNCTION(kCUDA, paddle::platform::DeviceEventRecordCUDA)
REGISTER_EVENT_QUERY_FUNCTION(kCUDA, paddle::platform::DeviceEventQueryCUDA)
REGISTER_EVENT_FINISH_FUNCTION(kCUDA, paddle::platform::DeviceEventFinishCUDA)
REGISTER_EVENT_SET_FINISHED_FUNCTION(
kCUDA, paddle::platform::DeviceEventSetFinishedCUDA)
REGISTER_EVENT_WAIT_FUNCTION(kCUDA, kCUDA,
paddle::platform::DeviceEventCUDAWaitCUDA)
REGISTER_EVENT_WAIT_FUNCTION(kCPU, kCUDA,
......
......@@ -16,35 +16,30 @@
#include "glog/logging.h"
#include "gtest/gtest.h"
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
using ::paddle::platform::kCUDA;
using ::paddle::platform::kCPU;
USE_EVENT(kCUDA);
USE_EVENT_WAIT(kCUDA, kCUDA)
USE_EVENT_WAIT(kCPU, kCUDA)
using paddle::platform::DeviceEvent;
using paddle::platform::DeviceContextPool;
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
TEST(DeviceEvent, CUDA) {
VLOG(1) << "In Test";
using paddle::platform::CUDAPlace;
using paddle::platform::DeviceOption;
using paddle::platform::DeviceEvent;
using paddle::platform::DeviceContextPool;
using paddle::platform::DeviceType;
auto& pool = DeviceContextPool::Instance();
auto place = CUDAPlace(0);
auto* context =
static_cast<paddle::platform::CUDADeviceContext*>(pool.Get(place));
int device_type = static_cast<int>(DeviceType::CUDA);
DeviceOption dev_opt(device_type, place.device);
ASSERT_NE(context, nullptr);
// case 1. test for event_creator
DeviceEvent event(dev_opt);
DeviceEvent event(place);
ASSERT_NE(event.GetEvent().get(), nullptr);
// case 2. test for event_recorder
event.Record(place, context);
event.Record(context);
bool status = event.Query();
ASSERT_EQ(status, false);
// case 3. test for event_finisher
......@@ -59,7 +54,7 @@ TEST(DeviceEvent, CUDA) {
cudaMalloc(reinterpret_cast<void**>(&dst_fp32), size);
cudaMemcpyAsync(dst_fp32, src_fp32, size, cudaMemcpyHostToDevice,
context->stream());
event.Record(place, context); // step 1. record it
event.Record(context); // step 1. record it
status = event.Query();
ASSERT_EQ(status, false);
......@@ -76,3 +71,17 @@ TEST(DeviceEvent, CUDA) {
cudaFreeHost(src_fp32);
}
#endif
TEST(DeviceEvent, CPU) {
using paddle::platform::CPUPlace;
auto place = CPUPlace();
DeviceEvent event(place);
auto& pool = DeviceContextPool::Instance();
auto* context = pool.Get(place);
// TODO(Aurelius84): every DeviceContext should have Record/Wait
event.Record(context);
event.SetFininshed();
bool status = event.Query();
ASSERT_EQ(status, true);
}