未验证 提交 268f097e 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] add new executor support (#46038)

* [CustomDevice] add custom_device_resource_pool & device_event_custom_device

* update

* update

* update

* update
上级 cbda49e6
...@@ -169,7 +169,12 @@ void DataTranferHelper::RunAndConstructOpFuncNode( ...@@ -169,7 +169,12 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
// NOTE(winter-wang): in npu device, D2H kernel is asynchronous. need to // NOTE(winter-wang): in npu device, D2H kernel is asynchronous. need to
// explicit synchronization. // explicit synchronization.
#ifdef PADDLE_WITH_ASCEND_CL #ifdef PADDLE_WITH_ASCEND_CL
if (op_type == kMemcpyD2H) { if (op_type == kMemcpyD2H && platform::is_npu_place(dev_ctx->GetPlace())) {
dev_ctx->Wait();
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (op_type == kMemcpyD2H && platform::is_custom_place(dev_ctx->GetPlace())) {
dev_ctx->Wait(); dev_ctx->Wait();
} }
#endif #endif
...@@ -363,11 +368,12 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name, ...@@ -363,11 +368,12 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name,
src_place)); src_place));
if (IsSupportedHetePlace(dst_place)) { if (IsSupportedHetePlace(dst_place)) {
op_type = kMemcpyH2D; op_type = kMemcpyH2D;
int dst_place_type = platform::is_gpu_place(dst_place) ? 0 int dst_place_type = platform::is_gpu_place(dst_place) ? 0
: platform::is_npu_place(dst_place) ? 1 : platform::is_npu_place(dst_place) ? 1
: platform::is_ipu_place(dst_place) ? 3 : platform::is_ipu_place(dst_place) ? 3
: platform::is_xpu_place(dst_place) ? 2 : platform::is_xpu_place(dst_place) ? 2
: -1; : platform::is_custom_place(dst_place) ? 6
: -1;
attr_map = {{"dst_place_type", dst_place_type}}; attr_map = {{"dst_place_type", dst_place_type}};
} else if (IsSupportedHetePlace(src_place)) { } else if (IsSupportedHetePlace(src_place)) {
op_type = kMemcpyD2H; op_type = kMemcpyD2H;
......
...@@ -165,7 +165,14 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -165,7 +165,14 @@ paddle::framework::FetchList InterpreterCore::Run(
ExecuteInstructionList(vec_instruction_); ExecuteInstructionList(vec_instruction_);
#ifdef PADDLE_WITH_ASCEND_CL #ifdef PADDLE_WITH_ASCEND_CL
platform::DeviceContextPool::Instance().Get(place_)->Wait(); if (platform::is_npu_place(place_)) {
platform::DeviceContextPool::Instance().Get(place_)->Wait();
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (platform::is_custom_place(place_)) {
platform::DeviceContextPool::Instance().Get(place_)->Wait();
}
#endif #endif
} }
if (create_local_scope_) { if (create_local_scope_) {
...@@ -223,7 +230,14 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -223,7 +230,14 @@ paddle::framework::FetchList InterpreterCore::Run(
ExecuteInstructionList(vec_instruction_); ExecuteInstructionList(vec_instruction_);
#ifdef PADDLE_WITH_ASCEND_CL #ifdef PADDLE_WITH_ASCEND_CL
platform::DeviceContextPool::Instance().Get(place_)->Wait(); if (platform::is_npu_place(place_)) {
platform::DeviceContextPool::Instance().Get(place_)->Wait();
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (platform::is_custom_place(place_)) {
platform::DeviceContextPool::Instance().Get(place_)->Wait();
}
#endif #endif
} }
......
...@@ -399,7 +399,8 @@ static bool IsCpuOp(const Instruction& instr) { ...@@ -399,7 +399,8 @@ static bool IsCpuOp(const Instruction& instr) {
// is supported heterogeneous place // is supported heterogeneous place
static bool IsSupportedHetePlace(const phi::Place& place) { static bool IsSupportedHetePlace(const phi::Place& place) {
return platform::is_gpu_place(place) || platform::is_npu_place(place) || return platform::is_gpu_place(place) || platform::is_npu_place(place) ||
platform::is_xpu_place(place) || platform::is_ipu_place(place); platform::is_xpu_place(place) || platform::is_ipu_place(place) ||
platform::is_custom_place(place);
} }
} // namespace interpreter } // namespace interpreter
......
...@@ -30,7 +30,8 @@ std::mutex ctx_mtx; ...@@ -30,7 +30,8 @@ std::mutex ctx_mtx;
} // namespace } // namespace
StreamAnalyzer::StreamAnalyzer(const platform::Place& place) : place_(place) { StreamAnalyzer::StreamAnalyzer(const platform::Place& place) : place_(place) {
if (platform::is_gpu_place(place) || platform::is_npu_place(place)) { if (platform::is_gpu_place(place) || platform::is_npu_place(place) ||
platform::is_custom_place(place)) {
std::lock_guard<std::mutex> lk(ctx_mtx); std::lock_guard<std::mutex> lk(ctx_mtx);
if (d2h_ctxs == nullptr) { if (d2h_ctxs == nullptr) {
d2h_ctxs = new std::map< d2h_ctxs = new std::map<
...@@ -178,7 +179,8 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( ...@@ -178,7 +179,8 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
auto* dev_ctx = op_func_node.dev_ctx_; auto* dev_ctx = op_func_node.dev_ctx_;
// only gpu/npu need update. xpu not need, because xpu memcpy op kernel is // only gpu/npu need update. xpu not need, because xpu memcpy op kernel is
// synchronous. // synchronous.
if (platform::is_gpu_place(place_) || platform::is_npu_place(place_)) { if (platform::is_gpu_place(place_) || platform::is_npu_place(place_) ||
platform::is_custom_place(place_)) {
if (op_type == interpreter::kMemcpyD2H) { if (op_type == interpreter::kMemcpyD2H) {
VLOG(3) << "Get dev_ctx from d2h_context_pool_"; VLOG(3) << "Get dev_ctx from d2h_context_pool_";
dev_ctx = d2h_ctx_.get().get(); dev_ctx = d2h_ctx_.get().get();
...@@ -209,7 +211,7 @@ bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr, ...@@ -209,7 +211,7 @@ bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr,
return true; return true;
// npu d2h kernel is asynchronous. // npu d2h kernel is asynchronous.
if (platform::is_npu_place(place_)) { if (platform::is_npu_place(place_) || platform::is_custom_place(place_)) {
return interpreter::IsCpuOp(cur_instr) || return interpreter::IsCpuOp(cur_instr) ||
interpreter::IsMemcpyH2D(next_instr); interpreter::IsMemcpyH2D(next_instr);
} }
...@@ -227,6 +229,8 @@ platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) { ...@@ -227,6 +229,8 @@ platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) {
return platform::kXPU; return platform::kXPU;
} else if (platform::is_npu_place(place_)) { } else if (platform::is_npu_place(place_)) {
return platform::kNPU; return platform::kNPU;
} else if (platform::is_custom_place(place_)) {
return platform::kCUSTOM_DEVICE;
} }
return platform::kCUDA; return platform::kCUDA;
} }
......
...@@ -251,6 +251,10 @@ if(WITH_MLU) ...@@ -251,6 +251,10 @@ if(WITH_MLU)
target_link_libraries(device_context mlu_resource_pool) target_link_libraries(device_context mlu_resource_pool)
endif() endif()
if(WITH_CUSTOM_DEVICE)
target_link_libraries(device_context custom_device_resource_pool)
endif()
cc_test( cc_test(
init_test init_test
SRCS init_test.cc SRCS init_test.cc
...@@ -284,11 +288,17 @@ if(WITH_GPU) ...@@ -284,11 +288,17 @@ if(WITH_GPU)
set(DEVICE_EVENT_LIBS set(DEVICE_EVENT_LIBS
device_event_gpu device_event_gpu
CACHE INTERNAL "device event libs") CACHE INTERNAL "device event libs")
nv_test( if(WITH_CUSTOM_DEVICE)
device_event_test nv_test(
SRCS device_event_test.cc device_event_test
DEPS device_event_gpu) SRCS device_event_test.cc
DEPS device_event_gpu device_event_custom_device)
else()
nv_test(
device_event_test
SRCS device_event_test.cc
DEPS device_event_gpu)
endif()
nv_test( nv_test(
device_context_test device_context_test
SRCS device_context_test.cu SRCS device_context_test.cu
...@@ -311,11 +321,17 @@ if(WITH_ROCM) ...@@ -311,11 +321,17 @@ if(WITH_ROCM)
set(DEVICE_EVENT_LIBS set(DEVICE_EVENT_LIBS
device_event_gpu device_event_gpu
CACHE INTERNAL "device event libs") CACHE INTERNAL "device event libs")
hip_test( if(WITH_CUSTOM_DEVICE)
device_event_test hip_test(
SRCS device_event_test.cc device_event_test
DEPS device_event_gpu) SRCS device_event_test.cc
DEPS device_event_gpu device_event_custom_device)
else()
hip_test(
device_event_test
SRCS device_event_test.cc
DEPS device_event_gpu)
endif()
hip_test( hip_test(
device_context_test device_context_test
SRCS device_context_test.cu SRCS device_context_test.cu
...@@ -470,3 +486,13 @@ if(NOT APPLE AND NOT WIN32) ...@@ -470,3 +486,13 @@ if(NOT APPLE AND NOT WIN32)
DEPS device_code lod_tensor) DEPS device_code lod_tensor)
endif() endif()
endif() endif()
if(WITH_CUSTOM_DEVICE)
cc_library(
device_event_custom_device
SRCS device_event_custom_device.cc
DEPS device_event_base custom_device_resource_pool)
set(DEVICE_EVENT_LIBS
${DEVICE_EVENT_LIBS} device_event_custom_device
CACHE INTERNAL "device event libs")
endif()
...@@ -24,3 +24,7 @@ endif() ...@@ -24,3 +24,7 @@ endif()
if(WITH_MLU) if(WITH_MLU)
add_subdirectory(mlu) add_subdirectory(mlu)
endif() endif()
if(WITH_CUSTOM_DEVICE)
add_subdirectory(custom)
endif()
if(WITH_CUSTOM_DEVICE)
cc_library(
custom_device_resource_pool
SRCS custom_device_resource_pool.cc
DEPS gflags glog enforce monitor)
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/fluid/platform/device/custom/custom_device_resource_pool.h"
namespace paddle {
namespace platform {
CustomDeviceStreamResourcePool::CustomDeviceStreamResourcePool(
const paddle::Place& place) {
PADDLE_ENFORCE_EQ(
platform::is_custom_place(place),
true,
platform::errors::PreconditionNotMet(
"Required device shall be CustomPlace, but received %d. ", place));
int dev_cnt = phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
pool_.reserve(dev_cnt);
for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) {
auto creator = [place, dev_idx] {
auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx);
phi::DeviceManager::SetDevice(place_);
phi::stream::Stream* stream = new phi::stream::Stream(place_, nullptr);
phi::DeviceManager::GetDeviceWithPlace(place_)->CreateStream(stream);
return stream;
};
auto deleter = [place, dev_idx](phi::stream::Stream* stream) {
auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx);
phi::DeviceManager::SetDevice(place_);
phi::DeviceManager::GetDeviceWithPlace(place_)->DestroyStream(stream);
delete stream;
};
pool_.emplace_back(
ResourcePool<CustomDeviceStreamObject>::Create(creator, deleter));
}
}
CustomDeviceStreamResourcePool& CustomDeviceStreamResourcePool::Instance(
const paddle::Place& place) {
static std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceStreamResourcePool>>>
pool;
PADDLE_ENFORCE_EQ(
platform::is_custom_place(place),
true,
platform::errors::PreconditionNotMet(
"Required device shall be CustomPlace, but received %d. ", place));
if (pool.find(place.GetDeviceType()) == pool.end()) {
pool.insert(
{place.GetDeviceType(),
std::vector<std::shared_ptr<CustomDeviceStreamResourcePool>>()});
for (size_t i = 0;
i < phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
++i) {
pool[place.GetDeviceType()].emplace_back(
new CustomDeviceStreamResourcePool(
paddle::platform::CustomPlace(place.GetDeviceType(), i)));
}
}
PADDLE_ENFORCE_LT(
place.GetDeviceId(),
pool[place.GetDeviceType()].size(),
platform::errors::OutOfRange("Device id is out of range, device id shall "
"be less than %d, but received %d. ",
pool[place.GetDeviceType()].size(),
place.GetDeviceId()));
return *pool[place.GetDeviceType()][place.GetDeviceId()];
}
std::shared_ptr<CustomDeviceStreamObject> CustomDeviceStreamResourcePool::New(
int dev_idx) {
PADDLE_ENFORCE_GE(
dev_idx,
0,
platform::errors::InvalidArgument(
"The dev_idx should be not less than 0, but got %d.", dev_idx));
PADDLE_ENFORCE_LT(
dev_idx,
pool_.size(),
platform::errors::OutOfRange(
"The dev_idx should be less than device count %d, but got %d.",
pool_.size(),
dev_idx));
return pool_[dev_idx]->New();
}
CustomDeviceEventResourcePool::CustomDeviceEventResourcePool(
const paddle::Place& place) {
PADDLE_ENFORCE_EQ(
platform::is_custom_place(place),
true,
platform::errors::PreconditionNotMet(
"Required device shall be CustomPlace, but received %d. ", place));
int dev_cnt = phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
pool_.reserve(dev_cnt);
for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) {
auto creator = [place, dev_idx] {
auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx);
phi::DeviceManager::SetDevice(place_);
phi::event::Event* event = new phi::event::Event(place_, nullptr);
phi::DeviceManager::GetDeviceWithPlace(place_)->CreateEvent(event);
return event;
};
auto deleter = [place, dev_idx](phi::event::Event* event) {
auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx);
phi::DeviceManager::SetDevice(place_);
phi::DeviceManager::GetDeviceWithPlace(place_)->DestroyEvent(event);
};
pool_.emplace_back(
ResourcePool<CustomDeviceEventObject>::Create(creator, deleter));
}
}
CustomDeviceEventResourcePool& CustomDeviceEventResourcePool::Instance(
const phi::Place& place) {
static std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceEventResourcePool>>>
pool;
PADDLE_ENFORCE_EQ(
platform::is_custom_place(place),
true,
platform::errors::PreconditionNotMet(
"Required device shall be CustomPlace, but received %d. ", place));
if (pool.find(place.GetDeviceType()) == pool.end()) {
pool.insert(
{place.GetDeviceType(),
std::vector<std::shared_ptr<CustomDeviceEventResourcePool>>()});
for (size_t i = 0;
i < phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
++i) {
pool[place.GetDeviceType()].emplace_back(
new CustomDeviceEventResourcePool(
paddle::platform::CustomPlace(place.GetDeviceType(), i)));
}
}
PADDLE_ENFORCE_LT(
place.GetDeviceId(),
pool[place.GetDeviceType()].size(),
platform::errors::OutOfRange("Device id is out of range, device id shall "
"be less than %d, but received %d. ",
pool[place.GetDeviceType()].size(),
place.GetDeviceId()));
return *pool[place.GetDeviceType()][place.GetDeviceId()];
}
std::shared_ptr<CustomDeviceEventObject> CustomDeviceEventResourcePool::New(
int dev_idx) {
PADDLE_ENFORCE_GE(
dev_idx,
0,
platform::errors::InvalidArgument(
"The dev_idx should be not less than 0, but got %d.", dev_idx));
PADDLE_ENFORCE_LT(
dev_idx,
pool_.size(),
platform::errors::OutOfRange(
"The dev_idx should be less than device count %d, but got %d.",
pool_.size(),
dev_idx));
return pool_[dev_idx]->New();
}
} // namespace platform
} // namespace paddle
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include <memory>
#include <type_traits>
#include <vector>
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/resource_pool.h"
#include "paddle/phi/backends/device_manager.h"
namespace paddle {
namespace platform {
using CustomDeviceStreamObject = phi::stream::Stream;
using CustomDeviceEventObject = phi::event::Event;
class CustomDeviceStreamResourcePool {
public:
std::shared_ptr<CustomDeviceStreamObject> New(int dev_idx);
static CustomDeviceStreamResourcePool& Instance(const paddle::Place& place);
private:
explicit CustomDeviceStreamResourcePool(const paddle::Place& place);
DISABLE_COPY_AND_ASSIGN(CustomDeviceStreamResourcePool);
private:
std::vector<std::shared_ptr<ResourcePool<CustomDeviceStreamObject>>> pool_;
};
class CustomDeviceEventResourcePool {
public:
std::shared_ptr<CustomDeviceEventObject> New(int dev_idx);
static CustomDeviceEventResourcePool& Instance(const paddle::Place& place);
private:
explicit CustomDeviceEventResourcePool(const paddle::Place& place);
DISABLE_COPY_AND_ASSIGN(CustomDeviceEventResourcePool);
private:
std::vector<std::shared_ptr<ResourcePool<CustomDeviceEventObject>>> pool_;
};
} // namespace platform
} // namespace paddle
#endif
...@@ -130,7 +130,7 @@ constexpr DeviceType kXPU = DeviceType::XPU; ...@@ -130,7 +130,7 @@ constexpr DeviceType kXPU = DeviceType::XPU;
constexpr DeviceType kNPU = DeviceType::NPU; constexpr DeviceType kNPU = DeviceType::NPU;
constexpr DeviceType kIPU = DeviceType::IPU; constexpr DeviceType kIPU = DeviceType::IPU;
constexpr DeviceType kMLU = DeviceType::MLU; constexpr DeviceType kMLU = DeviceType::MLU;
constexpr DeviceType kCUSOTM_DEVICE = DeviceType::CUSTOM_DEVICE; constexpr DeviceType kCUSTOM_DEVICE = DeviceType::CUSTOM_DEVICE;
using DeviceContext = phi::DeviceContext; using DeviceContext = phi::DeviceContext;
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
using ::paddle::platform::kCPU; using ::paddle::platform::kCPU;
using ::paddle::platform::kCUDA; using ::paddle::platform::kCUDA;
using ::paddle::platform::kCUSTOM_DEVICE;
using ::paddle::platform::kNPU; using ::paddle::platform::kNPU;
using ::paddle::platform::kXPU; using ::paddle::platform::kXPU;
...@@ -42,3 +43,9 @@ USE_EVENT(kNPU); ...@@ -42,3 +43,9 @@ USE_EVENT(kNPU);
USE_EVENT_WAIT(kNPU, kNPU) USE_EVENT_WAIT(kNPU, kNPU)
USE_EVENT_WAIT(kCPU, kNPU) USE_EVENT_WAIT(kCPU, kNPU)
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
USE_EVENT(kCUSTOM_DEVICE);
USE_EVENT_WAIT(kCUSTOM_DEVICE, kCUSTOM_DEVICE)
USE_EVENT_WAIT(kCPU, kCUSTOM_DEVICE)
#endif
...@@ -64,11 +64,13 @@ class DeviceEvent { ...@@ -64,11 +64,13 @@ class DeviceEvent {
"Required type < %d, but received type = %d", "Required type < %d, but received type = %d",
MaxDeviceTypes, MaxDeviceTypes,
type_id_)); type_id_));
#ifndef PADDLE_WITH_CUSTOM_DEVICE
// TODO(Aurelius84): only support CPU/CUDA/NPU. // TODO(Aurelius84): only support CPU/CUDA/NPU.
PADDLE_ENFORCE_LT(type_id_, PADDLE_ENFORCE_LT(type_id_,
3, 3,
platform::errors::Unavailable( platform::errors::Unavailable(
"Currently DeviceEvent do not support %s", place)); "Currently DeviceEvent do not support %s", place));
#endif
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
event_creator_[type_id_], event_creator_[type_id_],
platform::errors::Unavailable( platform::errors::Unavailable(
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/fluid/platform/device/custom/custom_device_resource_pool.h"
#include "paddle/fluid/platform/device_event_base.h"
#include "paddle/fluid/platform/event.h"
namespace paddle {
namespace platform {
struct CustomDeviceEventWrapper {
explicit CustomDeviceEventWrapper(const platform::Place& place) {
PADDLE_ENFORCE_EQ(
platform::is_custom_place(place),
true,
platform::errors::PreconditionNotMet(
"Required device shall be CustomPlace, but received %d. ", place));
device_id_ = place.device;
PADDLE_ENFORCE_GT(
device_id_,
-1,
platform::errors::PreconditionNotMet(
"Required DeviceOption.device_id > -1, but received %d. ",
device_id_));
inner_event_ =
CustomDeviceEventResourcePool::Instance(place).New(device_id_);
}
std::shared_ptr<CustomDeviceEventObject> inner_event_;
int device_id_;
};
void DeviceEventCreateCustomDevice(DeviceEvent* event,
const platform::Place& place,
unsigned int) {
event->InitEvent(std::make_shared<CustomDeviceEventWrapper>(place));
}
void DeviceEventRecordCustomDevice(DeviceEvent* event,
const DeviceContext* context) {
auto* wrapper =
static_cast<CustomDeviceEventWrapper*>(event->GetEvent().get());
auto* custom_device_ctx =
dynamic_cast<const platform::CustomDeviceContext*>(context);
PADDLE_ENFORCE_NOT_NULL(
custom_device_ctx,
platform::errors::PreconditionNotMet(
"Failed to dynamic_cast context into NPUDeviceContext."));
phi::stream::Stream stream_wrapper(custom_device_ctx->GetPlace(),
custom_device_ctx->stream());
wrapper->inner_event_->Record(&stream_wrapper);
}
bool DeviceEventQueryCustomDevice(const DeviceEvent* event) {
auto* wrapper =
static_cast<CustomDeviceEventWrapper*>(event->GetEvent().get());
PADDLE_ENFORCE_NOT_NULL(
wrapper,
platform::errors::PreconditionNotMet(
"Failed to dynamic_cast event into CustomDeviceEventWrapper."));
return wrapper->inner_event_->Query();
}
void DeviceEventFinishCustomDevice(const DeviceEvent* event) {
auto* wrapper =
static_cast<CustomDeviceEventWrapper*>(event->GetEvent().get());
wrapper->inner_event_->Synchonrize();
}
void DeviceEventCustomDeviceWaitCustomDevice(const DeviceEvent* event,
const DeviceContext* context) {
auto* wrapper =
static_cast<CustomDeviceEventWrapper*>(event->GetEvent().get());
auto* custom_device_ctx =
dynamic_cast<const platform::CustomDeviceContext*>(context);
PADDLE_ENFORCE_NOT_NULL(
custom_device_ctx,
platform::errors::PreconditionNotMet(
"Failed to dynamic_cast context into NPUDeviceContext."));
phi::stream::Stream stream_wrapper(custom_device_ctx->GetPlace(),
custom_device_ctx->stream());
stream_wrapper.WaitEvent(wrapper->inner_event_.get());
}
void DeviceEventCPUWaitCustomDevice(const DeviceEvent* event,
const DeviceContext* context) {
DeviceEventFinishCustomDevice(event);
}
void DeviceEventSetFinishedCustomDevice(const DeviceEvent* event) {
// do nothing
}
void EventResetCustomDevice(const DeviceEvent* event) {
// do nothing
}
} // namespace platform
} // namespace paddle
using ::paddle::platform::kCPU;
using ::paddle::platform::kCUSTOM_DEVICE;
REGISTER_EVENT_CREATE_FUNCTION(kCUSTOM_DEVICE,
paddle::platform::DeviceEventCreateCustomDevice)
REGISTER_EVENT_RECORD_FUNCTION(kCUSTOM_DEVICE,
paddle::platform::DeviceEventRecordCustomDevice)
REGISTER_EVENT_QUERY_FUNCTION(kCUSTOM_DEVICE,
paddle::platform::DeviceEventQueryCustomDevice)
REGISTER_EVENT_FINISH_FUNCTION(kCUSTOM_DEVICE,
paddle::platform::DeviceEventFinishCustomDevice)
REGISTER_EVENT_SET_FINISHED_FUNCTION(
kCUSTOM_DEVICE, paddle::platform::DeviceEventSetFinishedCustomDevice)
REGISTER_EVENT_WAIT_FUNCTION(
kCUSTOM_DEVICE,
kCUSTOM_DEVICE,
paddle::platform::DeviceEventCustomDeviceWaitCustomDevice)
REGISTER_EVENT_WAIT_FUNCTION(kCPU,
kCUSTOM_DEVICE,
paddle::platform::DeviceEventCPUWaitCustomDevice)
REGISTER_EVENT_RESET_FUNCTION(kCUSTOM_DEVICE,
paddle::platform::EventResetCustomDevice)
#endif
...@@ -96,9 +96,10 @@ class DeviceInterface { // Driver / Runtime ...@@ -96,9 +96,10 @@ class DeviceInterface { // Driver / Runtime
// Event // Event
// ! Create an event. // ! Create an event.
virtual void CreateEvent(size_t dev_id, virtual void CreateEvent(
event::Event* event, size_t dev_id,
event::Event::Flag flags); event::Event* event,
event::Event::Flag flags = event::Event::Flag::Default);
// ! Destroy an event. // ! Destroy an event.
virtual void DestroyEvent(size_t dev_id, event::Event* event); virtual void DestroyEvent(size_t dev_id, event::Event* event);
......
...@@ -55,7 +55,8 @@ class Device final { ...@@ -55,7 +55,8 @@ class Device final {
// Event // Event
// ! Create an event. // ! Create an event.
void CreateEvent(event::Event* event, event::Event::Flag flags); void CreateEvent(event::Event* event,
event::Event::Flag flags = event::Event::Flag::Default);
// ! Destroy an event. // ! Destroy an event.
void DestroyEvent(event::Event* event); void DestroyEvent(event::Event* event);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册