未验证 提交 0e26361c 编写于 作者: 王明冬 提交者: GitHub

add xpu garbage collector for standalone executor. (#44572)

上级 cd55385a
......@@ -39,15 +39,10 @@ set(STANDALONE_EXECUTOR_DEPS
scope
glog
workqueue
interpretercore_event_garbage_collector
interpretercore_garbage_collector
${DEVICE_EVENT_LIBS}
glog)
if(WITH_GPU OR WITH_ROCM)
set(STANDALONE_EXECUTOR_DEPS ${STANDALONE_EXECUTOR_DEPS}
interpretercore_fast_garbage_collector)
endif()
cc_library(
standalone_executor
SRCS ${STANDALONE_EXECUTOR_SRCS}
......
cc_library(
interpretercore_garbage_collector
SRCS garbage_collector.cc
SRCS garbage_collector.cc event_garbage_collector.cc fast_garbage_collector.cc
no_event_garbage_collector.cc
DEPS garbage_collector)
cc_library(
interpretercore_event_garbage_collector
SRCS event_garbage_collector.cc
DEPS interpretercore_garbage_collector)
if(WITH_GPU OR WITH_ROCM)
if(WITH_GPU)
nv_library(
interpretercore_fast_garbage_collector
SRCS fast_garbage_collector.cc
DEPS interpretercore_garbage_collector)
elseif(WITH_ROCM)
hip_library(
interpretercore_fast_garbage_collector
SRCS fast_garbage_collector.cc
DEPS interpretercore_garbage_collector)
endif()
endif()
......@@ -24,48 +24,33 @@
namespace paddle {
namespace framework {
InterpreterCoreEventGarbageCollector::InterpreterCoreEventGarbageCollector() {
InterpreterCoreEventGarbageCollector::InterpreterCoreEventGarbageCollector(
const std::vector<Instruction>& vec_instruction) {
WorkQueueOptions options(/*name*/ "GarbageCollector",
/*num_threads*/ 1,
/*allow_spinning*/ true,
/*track_task*/ false);
queue_ = CreateSingleThreadedWorkQueue(options);
for (auto& instruc : vec_instruction) {
gc_event_.emplace_back(instruc.DeviceContext().GetPlace(),
platform::GenerateDeviceEventFlag());
}
}
InterpreterCoreEventGarbageCollector::~InterpreterCoreEventGarbageCollector() {
queue_.reset(nullptr);
}
void InterpreterCoreEventGarbageCollector::Add(
Garbage garbage,
platform::DeviceEvent* event,
const platform::DeviceContext* ctx) {
if (!garbage) {
return;
}
if (max_memory_size_ <= 1) {
Free(garbage, event, ctx);
} else {
std::unique_ptr<GarbageQueue> pending_delete_garbages;
{ // lock guard
std::lock_guard<memory::SpinLock> guard(spinlock_);
cur_memory_size_ += garbage->size();
garbages_->push_back(std::move(garbage));
if (cur_memory_size_ >= max_memory_size_) {
cur_memory_size_ = 0;
pending_delete_garbages = std::move(garbages_);
garbages_ = std::make_unique<GarbageQueue>();
}
}
}
}
void InterpreterCoreEventGarbageCollector::Add(Variable* var) {
PADDLE_THROW(platform::errors::Unimplemented(
"Add(Variable* var) is not implemented for "
"InterpreterCoreEventGarbageCollector."));
void InterpreterCoreEventGarbageCollector::Add(Variable* var,
const Instruction& instr) {
PADDLE_ENFORCE_LT(instr.Id(),
gc_event_.size(),
platform::errors::OutOfRange(
"The index should be less than the size of gc event "
", but got index is %d and size is %d",
instr.Id(),
gc_event_.size()));
Add(var, &gc_event_.at(instr.Id()), &instr.DeviceContext());
}
void InterpreterCoreEventGarbageCollector::Add(
......@@ -109,23 +94,28 @@ void InterpreterCoreEventGarbageCollector::Add(
}
}
void InterpreterCoreEventGarbageCollector::Free(
GarbageQueue* garbages,
void InterpreterCoreEventGarbageCollector::Add(
Garbage garbage,
platform::DeviceEvent* event,
const platform::DeviceContext* ctx) {
event->Record(ctx);
event->SetFininshed(); // Only for CPU Event
queue_->AddTask([container = garbages, event = event]() {
while (!event->Query()) {
#if defined(_WIN32)
SleepEx(50, FALSE);
#else
sched_yield();
#endif
continue;
if (!garbage) {
return;
}
if (max_memory_size_ <= 1) {
Free(garbage, event, ctx);
} else {
{ // lock guard
std::lock_guard<memory::SpinLock> guard(spinlock_);
cur_memory_size_ += garbage->size();
garbages_->push_back(std::move(garbage));
events_[ctx] = event;
if (cur_memory_size_ >= max_memory_size_) {
FreeGarbages();
}
}
delete container;
});
}
}
void InterpreterCoreEventGarbageCollector::Free(
......@@ -146,5 +136,28 @@ void InterpreterCoreEventGarbageCollector::Free(
});
}
void InterpreterCoreEventGarbageCollector::FreeGarbages() {
for (auto& vals : events_) {
vals.second->Record(vals.first);
vals.second->SetFininshed(); // Only for CPU Event
}
queue_->AddTask(
[container = std::move(*garbages_), events = std::move(events_)]() {
for (auto& vals : events) {
while (!vals.second->Query()) {
#if defined(_WIN32)
SleepEx(50, FALSE);
#else
sched_yield();
#endif
continue;
}
}
});
cur_memory_size_ = 0;
garbages_->clear();
events_.clear();
}
} // namespace framework
} // namespace paddle
......@@ -24,28 +24,31 @@ namespace framework {
class InterpreterCoreEventGarbageCollector
: public InterpreterCoreGarbageCollector {
public:
InterpreterCoreEventGarbageCollector();
InterpreterCoreEventGarbageCollector(
const std::vector<Instruction>& vec_instruction);
~InterpreterCoreEventGarbageCollector();
void Add(Variable* var) override;
virtual void Add(Variable* var,
platform::DeviceEvent* event,
const platform::DeviceContext* ctx);
void Add(Variable* var, const Instruction& instruction) override;
private:
void Add(Variable* var,
platform::DeviceEvent* event,
const platform::DeviceContext* ctx);
void Add(Garbage garbage,
platform::DeviceEvent* event,
const platform::DeviceContext* ctx);
void Free(GarbageQueue* garbages,
platform::DeviceEvent* event,
const platform::DeviceContext* ctx);
void Free(const Garbage& garbage,
platform::DeviceEvent* event,
const platform::DeviceContext* ctx);
void FreeGarbages();
std::unique_ptr<WorkQueue> queue_;
paddle::memory::SpinLock spinlock_;
std::vector<paddle::platform::DeviceEvent> gc_event_;
std::unordered_map<const platform::DeviceContext*,
paddle::platform::DeviceEvent*>
events_;
};
} // namespace framework
} // namespace paddle
......@@ -17,12 +17,9 @@
namespace paddle {
namespace framework {
void InterpreterCoreFastGarbageCollector::Add(
Variable* var,
platform::DeviceEvent* event,
const platform::DeviceContext* ctx) {
PADDLE_THROW(platform::errors::Unimplemented(
"Not implemented for InterpreterCoreFastGarbageCollector."));
void InterpreterCoreFastGarbageCollector::Add(Variable* var,
const Instruction&) {
Add(var);
}
void InterpreterCoreFastGarbageCollector::Add(Variable* var) {
......
......@@ -13,8 +13,6 @@
// limitations under the License.
#pragma once
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.h"
namespace paddle {
......@@ -23,15 +21,11 @@ namespace framework {
class InterpreterCoreFastGarbageCollector
: public InterpreterCoreGarbageCollector {
public:
void Add(Variable* var) override;
void Add(Variable* var,
platform::DeviceEvent* event,
const platform::DeviceContext* ctx) override;
void Add(Variable* var, const Instruction& instr) override;
private:
void Add(Variable* var);
void Add(Garbage garbage);
};
} // namespace framework
} // namespace paddle
#endif
......@@ -13,17 +13,48 @@
// limitations under the License.
#include "paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/new_executor/garbage_collector/event_garbage_collector.h"
#include "paddle/fluid/framework/new_executor/garbage_collector/fast_garbage_collector.h"
#include "paddle/fluid/framework/new_executor/garbage_collector/no_event_garbage_collector.h"
DECLARE_bool(fast_eager_deletion_mode);
namespace paddle {
namespace framework {
bool IsInterpretercoreFastGCEnabled() {
return memory::allocation::AllocatorFacade::Instance()
.IsStreamSafeCUDAAllocatorUsed() &&
FLAGS_fast_eager_deletion_mode;
}
InterpreterCoreGarbageCollector::InterpreterCoreGarbageCollector() {
garbages_ = std::make_unique<GarbageQueue>();
max_memory_size_ = static_cast<int64_t>(GetEagerDeletionThreshold());
cur_memory_size_ = 0;
}
std::unique_ptr<InterpreterCoreGarbageCollector>
CreateInterpreterCoreGarbageCollector(
const platform::Place& place,
const std::vector<Instruction>& vec_instruction) {
if (platform::is_gpu_place(place)) {
if (IsInterpretercoreFastGCEnabled()) {
return std::unique_ptr<InterpreterCoreGarbageCollector>(
new InterpreterCoreFastGarbageCollector());
} else {
return std::unique_ptr<InterpreterCoreGarbageCollector>(
new InterpreterCoreEventGarbageCollector(vec_instruction));
}
} else if (platform::is_xpu_place(place) || platform::is_ipu_place(place)) {
return std::unique_ptr<InterpreterCoreGarbageCollector>(
new InterpreterCoreNoEventGarbageCollector());
} else {
return std::unique_ptr<InterpreterCoreGarbageCollector>(
new InterpreterCoreEventGarbageCollector(vec_instruction));
}
}
} // namespace framework
} // namespace paddle
......@@ -15,6 +15,7 @@
#include <queue>
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/memory/allocation/spin_lock.h"
#include "paddle/fluid/platform/device_event.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -30,10 +31,9 @@ class InterpreterCoreGarbageCollector {
public:
InterpreterCoreGarbageCollector();
virtual ~InterpreterCoreGarbageCollector() {}
virtual void Add(Variable* var) = 0;
virtual void Add(Variable* var,
platform::DeviceEvent* event,
const platform::DeviceContext* ctx) = 0;
virtual void Add(Variable* var, const Instruction& instruction) = 0;
DISABLE_COPY_AND_ASSIGN(InterpreterCoreGarbageCollector);
protected:
......@@ -43,5 +43,12 @@ class InterpreterCoreGarbageCollector {
memory::SpinLock spinlock_;
};
bool IsInterpretercoreFastGCEnabled();
std::unique_ptr<InterpreterCoreGarbageCollector>
CreateInterpreterCoreGarbageCollector(
const platform::Place& place,
const std::vector<Instruction>& vec_instruction);
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/garbage_collector/no_event_garbage_collector.h"
namespace paddle {
namespace framework {
InterpreterCoreNoEventGarbageCollector::
InterpreterCoreNoEventGarbageCollector() {
WorkQueueOptions options(/*name*/ "NoEventGarbageCollector",
/*num_threads*/ 1,
/*allow_spinning*/ true,
/*track_task*/ false);
queue_ = CreateSingleThreadedWorkQueue(options);
}
InterpreterCoreNoEventGarbageCollector::
~InterpreterCoreNoEventGarbageCollector() {
queue_.reset(nullptr);
}
void InterpreterCoreNoEventGarbageCollector::Add(Variable* var,
const Instruction& instr) {
Add(var, &instr.DeviceContext());
}
void InterpreterCoreNoEventGarbageCollector::Add(
Variable* var, const platform::DeviceContext* ctx) {
if (UNLIKELY(max_memory_size_ < 0) || var == nullptr) {
return;
}
if (var->IsType<LoDTensor>()) {
Add(var->GetMutable<LoDTensor>()->MoveMemoryHolder(), ctx);
} else if (var->IsType<
operators::reader::
OrderedMultiDeviceLoDTensorBlockingQueueHolder>()) {
// TODO(xiongkun03) in old executor, this type of variable is not support
// eager deletion. so we just leave it here ?
} else if (var->IsType<LoDRankTable>()) {
// TODO(xiongkun03) in old executor, this type of variable is not support
// eager deletion. so we just leave it here ?
} else if (var->IsType<phi::SelectedRows>()) {
Add(var->GetMutable<phi::SelectedRows>()
->mutable_value()
->MoveMemoryHolder(),
ctx);
var->GetMutable<phi::SelectedRows>()->mutable_rows()->clear();
} else if (var->IsType<LoDTensorArray>()) {
auto* tensor_arr = var->GetMutable<LoDTensorArray>();
for (auto& t : *tensor_arr) {
Add(t.MoveMemoryHolder(), ctx);
}
} else if (var->IsType<std::vector<Scope*>>()) {
// NOTE(@xiongkun03) conditional_op / while_op will create a STEP_SCOPE
// refer to executor.cc to see what old garbage collector does.
// do nothing, because the sub scope will be deleted by sub-executor.
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"The variable(%s) is not supported in eager deletion.",
framework::ToTypeName(var->Type())));
}
}
void InterpreterCoreNoEventGarbageCollector::Add(
Garbage garbage, const platform::DeviceContext* ctx) {
if (!garbage) {
return;
}
if (max_memory_size_ <= 1) {
queue_->AddTask([container = garbage, ctx = ctx]() { ctx->Wait(); });
} else {
// lock guard
std::lock_guard<memory::SpinLock> guard(spinlock_);
cur_memory_size_ += garbage->size();
garbages_->emplace_back(std::move(garbage));
ctxs_.insert(ctx);
if (cur_memory_size_ >= max_memory_size_) {
cur_memory_size_ = 0;
queue_->AddTask(
[container = std::move(*garbages_), dev_ctxs = std::move(ctxs_)]() {
for (auto& ctx : dev_ctxs) {
ctx->Wait();
}
});
ctxs_.clear();
garbages_->clear();
}
}
}
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <queue>
#include "paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.h"
#include "paddle/fluid/framework/new_executor/workqueue/workqueue.h"
namespace paddle {
namespace framework {
class InterpreterCoreNoEventGarbageCollector
: public InterpreterCoreGarbageCollector {
public:
InterpreterCoreNoEventGarbageCollector();
~InterpreterCoreNoEventGarbageCollector();
void Add(Variable* var, const Instruction& instr) override;
private:
void Add(Variable* var, const platform::DeviceContext* ctx);
void Add(Garbage garbage, const platform::DeviceContext* ctx);
std::unique_ptr<WorkQueue> queue_;
std::unordered_set<const platform::DeviceContext*> ctxs_;
};
} // namespace framework
} // namespace paddle
......@@ -18,8 +18,6 @@
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/details/share_tensor_buffer_functor.h"
#include "paddle/fluid/framework/new_executor/garbage_collector/event_garbage_collector.h"
#include "paddle/fluid/framework/new_executor/garbage_collector/fast_garbage_collector.h"
#include "paddle/fluid/framework/new_executor/interpretercore_util.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/os_info.h"
......@@ -41,7 +39,6 @@ PADDLE_DEFINE_EXPORTED_bool(new_executor_use_local_scope,
DECLARE_bool(check_nan_inf);
DECLARE_bool(benchmark);
DECLARE_bool(fast_eager_deletion_mode);
constexpr const char* kExceptionCaught = "ExceptionCaught";
constexpr const char* kTaskCompletion = "TaskCompletion";
......@@ -52,12 +49,6 @@ namespace framework {
static constexpr size_t kHostNumThreads = 4;
static constexpr size_t kDeviceNumThreads = 1;
bool IsInterpretercoreFastGCEnabled() {
return memory::allocation::AllocatorFacade::Instance()
.IsStreamSafeCUDAAllocatorUsed() &&
FLAGS_fast_eager_deletion_mode;
}
InterpreterCore::InterpreterCore(const platform::Place& place,
const BlockDesc& block,
const std::set<std::string>& skip_gc_vars,
......@@ -71,16 +62,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
is_build_ = false;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (IsInterpretercoreFastGCEnabled()) {
gc_ = std::make_unique<InterpreterCoreFastGarbageCollector>();
} else {
gc_ = std::make_unique<InterpreterCoreEventGarbageCollector>();
}
#else
gc_ = std::make_unique<InterpreterCoreEventGarbageCollector>();
#endif
exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught);
completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion);
......@@ -498,16 +479,7 @@ void InterpreterCore::Convert(
}
BuildSkipShareLoDInfo();
for (size_t i = 0; i < vec_instruction_.size(); ++i) {
#ifdef PADDLE_WITH_IPU
gc_event_.emplace_back(phi::CPUPlace(), 0);
#else
gc_event_.emplace_back(vec_instruction_[i].DeviceContext().GetPlace(),
platform::GenerateDeviceEventFlag());
#endif
}
gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_);
bool inplaced = false;
for (auto inst : vec_instruction_) {
if (inst.OpBase()->Type() == "share_buffer" ||
......@@ -828,9 +800,6 @@ void InterpreterCore::RunInstructionAsync(
RunInstruction(instr_node);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
RecordStreamForGC(instr_node);
#endif
CheckGC(instr_node, atomic_var_ref);
interpreter::RecordEvent(instr_node, place_);
......@@ -969,7 +938,9 @@ void InterpreterCore::CheckGC(
std::vector<std::atomic<size_t>>* atomic_var_ref) {
platform::RecordEvent record(
"CheckGC", platform::TracerEventType::UserDefined, 10);
size_t instr_id = instr.Id();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
RecordStreamForGC(instr);
#endif
auto& var_scope = var_scope_;
for (auto var_id : instr.GCCheckVars()) {
......@@ -986,23 +957,7 @@ void InterpreterCore::CheckGC(
if (is_ready) {
VLOG(6) << "Async delete variable with name : "
<< var_scope.GetNameById(var_id);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (IsInterpretercoreFastGCEnabled()) {
static_cast<InterpreterCoreFastGarbageCollector*>(gc_.get())->Add(
var_scope_.VarRef(var_id));
} else {
static_cast<InterpreterCoreEventGarbageCollector*>(gc_.get())->Add(
var_scope_.VarRef(var_id),
&gc_event_.at(instr_id),
&instr.DeviceContext());
}
#else
static_cast<InterpreterCoreEventGarbageCollector*>(gc_.get())->Add(
var_scope_.VarRef(var_id),
&gc_event_.at(instr_id),
&instr.DeviceContext());
#endif
gc_->Add(var_scope_.VarRef(var_id), instr);
}
}
}
......
......@@ -141,7 +141,6 @@ class InterpreterCore {
std::shared_ptr<EventsWaiter::EventNotifier> completion_notifier_{nullptr};
std::unique_ptr<InterpreterCoreGarbageCollector> gc_;
std::vector<paddle::platform::DeviceEvent> gc_event_;
std::future<std::unique_ptr<AtomicVectorSizeT>> atomic_deps_;
std::future<std::unique_ptr<AtomicVectorSizeT>> atomic_var_ref_;
......
......@@ -265,15 +265,6 @@ cc_library(
set(DEVICE_EVENT_LIBS
device_event_base
CACHE INTERNAL "device event libs")
if(WITH_XPU)
cc_library(
device_event_xpu
SRCS device_event_xpu.cc
DEPS device_event_base xpu_info)
set(DEVICE_EVENT_LIBS
device_event_xpu
CACHE INTERNAL "device event libs")
endif()
if(WITH_ASCEND_CL)
cc_library(
......
......@@ -113,8 +113,8 @@ bool AllowTF32Cudnn();
enum DeviceType {
CPU = 0,
CUDA = 1,
XPU = 2,
NPU = 3,
NPU = 2,
XPU = 3,
IPU = 4,
MLU = 5,
......
......@@ -37,12 +37,6 @@ USE_EVENT_WAIT(kCUDA, kCUDA)
USE_EVENT_WAIT(kCPU, kCUDA)
#endif
#ifdef PADDLE_WITH_XPU
USE_EVENT(kXPU);
USE_EVENT_WAIT(kXPU, kXPU)
USE_EVENT_WAIT(kCPU, kXPU)
#endif
#ifdef PADDLE_WITH_ASCEND_CL
USE_EVENT(kNPU);
USE_EVENT_WAIT(kNPU, kNPU)
......
......@@ -64,9 +64,9 @@ class DeviceEvent {
"Required type < %d, but received type = %d",
MaxDeviceTypes,
type_id_));
// TODO(Aurelius84): only support CPU/CUDA/XPU/NPU.
// TODO(Aurelius84): only support CPU/CUDA/NPU.
PADDLE_ENFORCE_LT(type_id_,
4,
3,
platform::errors::Unavailable(
"Currently DeviceEvent do not support %s", place));
PADDLE_ENFORCE_NOT_NULL(
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/xpu/xpu_info.h"
#include "paddle/fluid/platform/device_event_base.h"
#ifdef PADDLE_WITH_XPU
namespace paddle {
namespace platform {
struct XPUDeviceEventWrapper {
explicit XPUDeviceEventWrapper(const platform::Place& place) {
PADDLE_ENFORCE_EQ(
platform::is_xpu_place(place),
true,
platform::errors::PreconditionNotMet(
"Required device shall be XPUPlace, but received %d. ", place));
device_id_ = place.device;
PADDLE_ENFORCE_GT(
device_id_,
-1,
platform::errors::PreconditionNotMet(
"Required DeviceOption.device_id > -1, but received %d. ",
device_id_));
xpu_event_create(&handle_);
}
xpuEventHandle handle_;
int device_id_;
};
void DeviceEventCreateXPU(DeviceEvent* event,
const platform::Place& place,
unsigned int) {
event->InitEvent(std::make_shared<XPUDeviceEventWrapper>(place));
}
void DeviceEventRecordXPU(DeviceEvent* event, const DeviceContext* context) {
auto* wrapper = static_cast<XPUDeviceEventWrapper*>(event->GetEvent().get());
PADDLE_ENFORCE_NOT_NULL(
wrapper,
platform::errors::PreconditionNotMet(
"Failed to dynamic_cast event into XPUDeviceEventWrapper."));
auto* xpu_dev_ctx = dynamic_cast<const platform::XPUDeviceContext*>(context);
PADDLE_ENFORCE_NOT_NULL(
xpu_dev_ctx,
platform::errors::PreconditionNotMet(
"Failed to dynamic_cast context into XPUDeviceContext."));
xpu_event_record(wrapper->handle_, xpu_dev_ctx->stream());
}
void DeviceEventFinishXPU(const DeviceEvent* event) {
auto* wrapper = static_cast<XPUDeviceEventWrapper*>(event->GetEvent().get());
PADDLE_ENFORCE_NOT_NULL(
wrapper,
platform::errors::PreconditionNotMet(
"Failed to dynamic_cast event into XPUDeviceEventWrapper."));
xpu_event_wait(wrapper->handle_);
}
// current xpu not support query, used wait to instead.
bool DeviceEventQueryXPU(const DeviceEvent* event) {
DeviceEventFinishXPU(event);
return true;
}
void DeviceEventXPUWaitXPU(const DeviceEvent* event,
const DeviceContext* context) {
auto* wrapper = static_cast<XPUDeviceEventWrapper*>(event->GetEvent().get());
PADDLE_ENFORCE_NOT_NULL(
wrapper,
platform::errors::PreconditionNotMet(
"Failed to dynamic_cast event into XPUDeviceEventWrapper."));
auto* xpu_dev_ctx = dynamic_cast<const platform::XPUDeviceContext*>(context);
PADDLE_ENFORCE_NOT_NULL(
xpu_dev_ctx,
platform::errors::PreconditionNotMet(
"Failed to dynamic_cast context into XOUDeviceContext."));
xpu_stream_wait_event(xpu_dev_ctx->stream(), wrapper->handle_);
}
void DeviceEventCPUWaitXPU(const DeviceEvent* event,
const DeviceContext* context) {
DeviceEventFinishXPU(event);
}
void DeviceEventSetFinishedXPU(const DeviceEvent* event) {
// do nothing
}
void EventResetXPU(const DeviceEvent* event) {
// do nothing
}
} // namespace platform
} // namespace paddle
using ::paddle::platform::kCPU;
using ::paddle::platform::kXPU;
REGISTER_EVENT_CREATE_FUNCTION(kXPU, paddle::platform::DeviceEventCreateXPU)
REGISTER_EVENT_RECORD_FUNCTION(kXPU, paddle::platform::DeviceEventRecordXPU)
REGISTER_EVENT_QUERY_FUNCTION(kXPU, paddle::platform::DeviceEventQueryXPU)
REGISTER_EVENT_FINISH_FUNCTION(kXPU, paddle::platform::DeviceEventFinishXPU)
REGISTER_EVENT_SET_FINISHED_FUNCTION(
kXPU, paddle::platform::DeviceEventSetFinishedXPU)
REGISTER_EVENT_WAIT_FUNCTION(kXPU,
kXPU,
paddle::platform::DeviceEventXPUWaitXPU)
REGISTER_EVENT_WAIT_FUNCTION(kCPU,
kXPU,
paddle::platform::DeviceEventCPUWaitXPU)
REGISTER_EVENT_RESET_FUNCTION(kXPU, paddle::platform::EventResetXPU)
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册