未验证 提交 22da1907 编写于 作者: A Aurelius84 提交者: GitHub

Abstract DeviceEvent to manage cross-platform Event implementation (#34922)

* add device_context

* add gtest for device_event_gpu

* Remvoe duplicate DeviceType

* push for test

* add unittest

* fix macros

* fix MSVC using usage
上级 9cbba97b
......@@ -151,11 +151,16 @@ endif()
cc_test(init_test SRCS init_test.cc DEPS device_context)
cc_library(device_event SRCS device_event.cc DEPS place enforce device_context op_registry)
cc_library(device_event_gpu SRCS device_event_gpu.cc DEPS device_event)
if(WITH_GPU)
nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info)
nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda)
nv_test(cudnn_desc_test SRCS cudnn_desc_test.cc DEPS dynload_cuda)
nv_test(transform_test SRCS transform_test.cu DEPS memory place device_context)
nv_test(device_event_test SRCS device_event_test.cc DEPS device_event_gpu)
endif()
if(WITH_ROCM)
......
......@@ -97,6 +97,8 @@ enum DeviceType {
CUDA = 1,
XPU = 2,
NPU = 3,
MAX_DEVICE_TYPES = 4,
};
DeviceType Place2DeviceType(const platform::Place& place);
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device_event.h"
namespace paddle {
namespace platform {
EventCreateFunction DeviceEvent::event_creator_[MaxDeviceTypes];
EventRecordFunction DeviceEvent::event_recorder_[MaxDeviceTypes];
EventQueryFunction DeviceEvent::event_querier_[MaxDeviceTypes];
EventFinishFunction DeviceEvent::event_finisher_[MaxDeviceTypes];
EventWaitFunction DeviceEvent::event_waiter_[MaxDeviceTypes][MaxDeviceTypes];
} // namespace platform
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace platform {
class DeviceOption;
class DeviceEvent;
constexpr int MaxDeviceTypes =
static_cast<int>(platform::DeviceType::MAX_DEVICE_TYPES);
typedef void (*EventCreateFunction)(DeviceEvent*, const DeviceOption&);
typedef void (*EventRecordFunction)(DeviceEvent*, const platform::Place&,
const DeviceContext*);
typedef bool (*EventQueryFunction)(const DeviceEvent*);
typedef void (*EventFinishFunction)(const DeviceEvent*);
typedef void (*EventWaitFunction)(const DeviceEvent*, DeviceContext*);
inline int DeviceTypeToId(const DeviceType& device_type) {
return static_cast<int>(device_type);
}
class DeviceOption {
public:
explicit DeviceOption(int device_type) : device_type_(device_type) {}
DeviceOption(int device_type, int device_id)
: device_type_(device_type), device_id_(device_id) {}
int device_type() const { return device_type_; }
int device_id() const { return device_id_; }
private:
int device_type_;
int device_id_;
};
class DeviceEvent {
public:
explicit DeviceEvent(const DeviceOption& device_option)
: event_(),
type_(device_option.device_type()),
device_option_(device_option) {
PADDLE_ENFORCE_LT(type_, MaxDeviceTypes,
platform::errors::PreconditionNotMet(
"Required type < %d, but received type = %d",
MaxDeviceTypes, type_));
PADDLE_ENFORCE_NOT_NULL(
event_creator_[type_],
platform::errors::Unavailable(
"event_creator_[%d] shall not be nullptr.", type_));
event_creator_[type_](this, device_option_);
}
~DeviceEvent() {}
void Record(const platform::Place& place, const DeviceContext* dev_ctx) {
PADDLE_ENFORCE_NOT_NULL(
event_recorder_[type_],
platform::errors::Unavailable(
"event_recorder_[%d] shall not be nullptr.", type_));
event_recorder_[type_](this, place, dev_ctx);
}
bool Query() {
PADDLE_ENFORCE_NOT_NULL(
event_querier_[type_],
platform::errors::Unavailable(
"event_querier_[%d] shall not be nullptr.", type_));
return event_querier_[type_](this);
}
void Finish() const {
PADDLE_ENFORCE_NOT_NULL(
event_finisher_[type_],
platform::errors::Unavailable(
"event_finisher_[%d] shall not be nullptr.", type_));
event_finisher_[type_](this);
}
void Wait(const DeviceType& waiter_type, DeviceContext* context) const {
auto waiter_idx = DeviceTypeToId(waiter_type);
PADDLE_ENFORCE_NOT_NULL(
event_waiter_[waiter_idx][type_],
platform::errors::Unavailable(
"event_waiter_[%d][%d] shall not be nullptr.", waiter_idx, type_));
event_waiter_[waiter_idx][type_](this, context);
}
void InitEvent(std::shared_ptr<void> event) { event_ = event; }
std::shared_ptr<void> GetEvent() const { return event_; }
private:
std::shared_ptr<void> event_;
int type_;
DeviceOption device_option_;
static EventCreateFunction event_creator_[MaxDeviceTypes];
static EventRecordFunction event_recorder_[MaxDeviceTypes];
static EventQueryFunction event_querier_[MaxDeviceTypes];
static EventFinishFunction event_finisher_[MaxDeviceTypes];
static EventWaitFunction event_waiter_[MaxDeviceTypes][MaxDeviceTypes];
template <DeviceType device_typ>
friend struct EventCreateFunctionRegisterer;
template <DeviceType device_typ>
friend struct EventRecordFunctionRegisterer;
template <DeviceType device_typ>
friend struct EventQueryFunctionRegisterer;
template <DeviceType device_typ>
friend struct EventFinishFunctionRegisterer;
template <DeviceType waiter_typ, DeviceType event_type>
friend struct EventWaitFunctionRegisterer;
};
/**
* check if MACRO is used in GLOBAL NAMESPACE.
*/
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
// =============== Register for Create ===============
template <DeviceType device_type>
struct EventCreateFunctionRegisterer : public framework::Registrar {
explicit EventCreateFunctionRegisterer(EventCreateFunction func) {
auto type_idx = DeviceTypeToId(device_type);
VLOG(3) << "register event_creator with type_id :" << type_idx;
DeviceEvent::event_creator_[type_idx] = func;
}
};
#define REGISTER_EVENT_CREATE_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_creator__##device_type, \
"REGISTER_EVENT_CREATE_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventCreateFunctionRegisterer<device_type> \
__reg_event_create_##device_type##__(func); \
int TouchDeviceEventCreate##device_type() { \
__reg_event_create_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for Record ===============
template <DeviceType device_type>
struct EventRecordFunctionRegisterer : public framework::Registrar {
explicit EventRecordFunctionRegisterer(EventRecordFunction func) {
auto type_idx = DeviceTypeToId(device_type);
VLOG(3) << "register event_recorder with type_id :" << type_idx;
DeviceEvent::event_recorder_[type_idx] = func;
}
};
#define REGISTER_EVENT_RECORD_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_recorder__##device_type, \
"REGISTER_EVENT_RECORD_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventRecordFunctionRegisterer<device_type> \
__reg_event_record_##device_type##__(func); \
int TouchDeviceEventRecord##device_type() { \
__reg_event_record_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for Query ===============
template <DeviceType device_type>
struct EventQueryFunctionRegisterer : public framework::Registrar {
explicit EventQueryFunctionRegisterer(EventQueryFunction func) {
auto type_idx = DeviceTypeToId(device_type);
VLOG(3) << "register event_querier with type_id :" << type_idx;
DeviceEvent::event_querier_[type_idx] = func;
}
};
#define REGISTER_EVENT_QUERY_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_querier__##device_type, \
"REGISTER_EVENT_QUERY_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventQueryFunctionRegisterer<device_type> \
__reg_event_query_##device_type##__(func); \
int TouchDeviceEventQuery##device_type() { \
__reg_event_query_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for Finish ===============
template <DeviceType device_type>
struct EventFinishFunctionRegisterer : public framework::Registrar {
explicit EventFinishFunctionRegisterer(EventFinishFunction func) {
auto type_idx = DeviceTypeToId(device_type);
VLOG(3) << "register event_finisher with type_id :" << type_idx;
DeviceEvent::event_finisher_[type_idx] = func;
}
};
#define REGISTER_EVENT_FINISH_FUNCTION(device_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_finishier__##device_type, \
"REGISTER_EVENT_FINISH_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventFinishFunctionRegisterer<device_type> \
__reg_event_finish_##device_type##__(func); \
int TouchDeviceEventFinish##device_type() { \
__reg_event_finish_##device_type##__.Touch(); \
return 0; \
}
// =============== Register for Wait ===============
template <DeviceType waiter_type, DeviceType event_type>
struct EventWaitFunctionRegisterer : public framework::Registrar {
explicit EventWaitFunctionRegisterer(EventWaitFunction func) {
auto waiter_idx = DeviceTypeToId(waiter_type);
auto event_idx = DeviceTypeToId(event_type);
VLOG(3) << "register event_finisher with waiter_idx : " << waiter_idx
<< ", event_idx : " << event_idx;
DeviceEvent::event_waiter_[waiter_idx][event_idx] = func;
}
};
#define REGISTER_EVENT_WAIT_FUNCTION(waiter_type, event_type, func) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_event_waiter__##waiter_type##event_type, \
"REGISTER_EVENT_WAIT_FUNCTION must be called in global namespace"); \
static ::paddle::platform::EventWaitFunctionRegisterer<waiter_type, \
event_type> \
__reg_event_wait_##waiter_type##event_type##__(func); \
int TouchDeviceEventWait##waiter_type##event_type() { \
__reg_event_wait_##waiter_type##event_type##__.Touch(); \
return 0; \
}
#define USE_EVENT(device_type) \
extern int TouchDeviceEventCreate##device_type(); \
extern int TouchDeviceEventRecord##device_type(); \
extern int TouchDeviceEventQuery##device_type(); \
extern int TouchDeviceEventFinish##device_type(); \
UNUSED static int use_event_creator_##device_type = \
TouchDeviceEventCreate##device_type(); \
UNUSED static int use_event_recorder_##device_type = \
TouchDeviceEventRecord##device_type(); \
UNUSED static int use_event_querier_##device_type = \
TouchDeviceEventQuery##device_type(); \
UNUSED static int use_event_finisher_##device_type = \
TouchDeviceEventFinish##device_type();
#define USE_EVENT_WAIT(waiter_type, event_type) \
extern int TouchDeviceEventWait##waiter_type##event_type(); \
UNUSED static int use_event_waiter_##waiter_type##event_type = \
TouchDeviceEventWait##waiter_type##event_type();
} // namespace platform
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device_event.h"
#include "paddle/fluid/platform/event.h"
#ifdef PADDLE_WITH_CUDA
namespace paddle {
namespace platform {
struct CUDADeviceEventWrapper {
explicit CUDADeviceEventWrapper(const DeviceOption& dev_opt)
: inner_event_() {
PADDLE_ENFORCE_EQ(
dev_opt.device_type(), static_cast<int>(DeviceType::CUDA),
platform::errors::PreconditionNotMet(
"Required device type shall be CUDA, but received %d. ",
dev_opt.device_type()));
PADDLE_ENFORCE_GT(
dev_opt.device_id(), -1,
platform::errors::PreconditionNotMet(
"Required DeviceOption.device_id > -1, but received %d. ",
dev_opt.device_id()));
device_id_ = dev_opt.device_id();
}
CudaEvent inner_event_;
int device_id_;
};
void DeviceEventCreateCUDA(DeviceEvent* event, const DeviceOption& dev_opt) {
event->InitEvent(std::make_shared<CUDADeviceEventWrapper>(dev_opt));
}
void DeviceEventRecordCUDA(DeviceEvent* event, const platform::Place& place,
const DeviceContext* context) {
auto* wrapper = static_cast<CUDADeviceEventWrapper*>(event->GetEvent().get());
auto* cuda_dev_ctx =
dynamic_cast<const platform::CUDADeviceContext*>(context);
PADDLE_ENFORCE_NOT_NULL(
cuda_dev_ctx,
platform::errors::PreconditionNotMet(
"Failed to dynamic_cast context into CUDADeviceContext."));
wrapper->inner_event_.Record(*cuda_dev_ctx->context()->Stream());
}
bool DeviceEventQueryCUDA(const DeviceEvent* event) {
auto* wrapper = static_cast<CUDADeviceEventWrapper*>(event->GetEvent().get());
PADDLE_ENFORCE_NOT_NULL(
wrapper,
platform::errors::PreconditionNotMet(
"Failed to dynamic_cast event into CUDADeviceEventWrapper."));
return wrapper->inner_event_.Query();
}
void DeviceEventFinishCUDA(const DeviceEvent* event) {
auto* wrapper = static_cast<CUDADeviceEventWrapper*>(event->GetEvent().get());
// calling cudaEventSynchronize
wrapper->inner_event_.Synchronize();
}
void DeviceEventCUDAWaitCUDA(const DeviceEvent* event, DeviceContext* context) {
auto* wrapper = static_cast<CUDADeviceEventWrapper*>(event->GetEvent().get());
auto* cuda_dev_ctx =
dynamic_cast<const platform::CUDADeviceContext*>(context);
PADDLE_ENFORCE_NOT_NULL(
cuda_dev_ctx,
platform::errors::PreconditionNotMet(
"Failed to dynamic_cast context into CUDADeviceContext."));
// calling cudaStreamWaitEvent(stream, event, 0)
cuda_dev_ctx->context()->Stream()->WaitEvent(
wrapper->inner_event_.GetRawCudaEvent());
}
void DeviceEventCPUWaitCUDA(const DeviceEvent* event, DeviceContext* context) {
DeviceEventFinishCUDA(event);
}
} // namespace platform
} // namespace paddle
using ::paddle::platform::kCUDA;
using ::paddle::platform::kCPU;
REGISTER_EVENT_CREATE_FUNCTION(kCUDA, paddle::platform::DeviceEventCreateCUDA)
REGISTER_EVENT_RECORD_FUNCTION(kCUDA, paddle::platform::DeviceEventRecordCUDA)
REGISTER_EVENT_QUERY_FUNCTION(kCUDA, paddle::platform::DeviceEventQueryCUDA)
REGISTER_EVENT_FINISH_FUNCTION(kCUDA, paddle::platform::DeviceEventFinishCUDA)
REGISTER_EVENT_WAIT_FUNCTION(kCUDA, kCUDA,
paddle::platform::DeviceEventCUDAWaitCUDA)
REGISTER_EVENT_WAIT_FUNCTION(kCPU, kCUDA,
paddle::platform::DeviceEventCPUWaitCUDA)
#endif
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device_event.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
using ::paddle::platform::kCUDA;
using ::paddle::platform::kCPU;
USE_EVENT(kCUDA);
USE_EVENT_WAIT(kCUDA, kCUDA)
USE_EVENT_WAIT(kCPU, kCUDA)
TEST(DeviceEvent, CUDA) {
VLOG(1) << "In Test";
using paddle::platform::CUDAPlace;
using paddle::platform::DeviceOption;
using paddle::platform::DeviceEvent;
using paddle::platform::DeviceContextPool;
using paddle::platform::DeviceType;
auto& pool = DeviceContextPool::Instance();
auto place = CUDAPlace(0);
auto* context =
static_cast<paddle::platform::CUDADeviceContext*>(pool.Get(place));
int device_type = static_cast<int>(DeviceType::CUDA);
DeviceOption dev_opt(device_type, place.device);
ASSERT_NE(context, nullptr);
// case 1. test for event_creator
DeviceEvent event(dev_opt);
ASSERT_NE(event.GetEvent().get(), nullptr);
// case 2. test for event_recorder
event.Record(place, context);
bool status = event.Query();
ASSERT_EQ(status, false);
// case 3. test for event_finisher
event.Finish();
status = event.Query();
ASSERT_EQ(status, true);
// case 4. test for event_waiter
float *src_fp32, *dst_fp32;
int size = 1000000 * sizeof(float);
cudaMallocHost(reinterpret_cast<void**>(&src_fp32), size);
cudaMalloc(reinterpret_cast<void**>(&dst_fp32), size);
cudaMemcpyAsync(dst_fp32, src_fp32, size, cudaMemcpyHostToDevice,
context->stream());
event.Record(place, context); // step 1. record it
status = event.Query();
ASSERT_EQ(status, false);
event.Wait(kCUDA, context); // step 2. add streamWaitEvent
status = event.Query();
ASSERT_EQ(status, false); // async
event.Wait(kCPU, context); // step 3. EventSynchornize
status = event.Query();
ASSERT_EQ(status, true); // sync
// release resource
cudaFree(dst_fp32);
cudaFreeHost(src_fp32);
}
#endif
......@@ -105,7 +105,7 @@ void BindCudaStream(py::module *m_ptr) {
.def("wait_stream",
[](paddle::platform::stream::CUDAStream &self,
paddle::platform::stream::CUDAStream &stream) {
auto event = paddle::platform::CudaEvent();
paddle::platform::CudaEvent event;
event.Record(stream);
self.WaitEvent(event.GetRawCudaEvent());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册