diff --git a/paddle/fluid/framework/new_executor/data_transfer.cc b/paddle/fluid/framework/new_executor/data_transfer.cc index fccb2ee5a755085e4964841af7055789a1c9c17e..06962f7b5e77313663af8bda640f164e7959fec3 100644 --- a/paddle/fluid/framework/new_executor/data_transfer.cc +++ b/paddle/fluid/framework/new_executor/data_transfer.cc @@ -169,7 +169,12 @@ void DataTranferHelper::RunAndConstructOpFuncNode( // NOTE(winter-wang): in npu device, D2H kernel is asynchronous. need to // explicit synchronization. #ifdef PADDLE_WITH_ASCEND_CL - if (op_type == kMemcpyD2H) { + if (op_type == kMemcpyD2H && platform::is_npu_place(dev_ctx->GetPlace())) { + dev_ctx->Wait(); + } +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + if (op_type == kMemcpyD2H && platform::is_custom_place(dev_ctx->GetPlace())) { dev_ctx->Wait(); } #endif @@ -363,11 +368,12 @@ std::shared_ptr TransferDevice(const std::string& var_name, src_place)); if (IsSupportedHetePlace(dst_place)) { op_type = kMemcpyH2D; - int dst_place_type = platform::is_gpu_place(dst_place) ? 0 - : platform::is_npu_place(dst_place) ? 1 - : platform::is_ipu_place(dst_place) ? 3 - : platform::is_xpu_place(dst_place) ? 2 - : -1; + int dst_place_type = platform::is_gpu_place(dst_place) ? 0 + : platform::is_npu_place(dst_place) ? 1 + : platform::is_ipu_place(dst_place) ? 3 + : platform::is_xpu_place(dst_place) ? 2 + : platform::is_custom_place(dst_place) ? 6 + : -1; attr_map = {{"dst_place_type", dst_place_type}}; } else if (IsSupportedHetePlace(src_place)) { op_type = kMemcpyD2H; diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 0a0170110de2aed4bcefa90cab6ceb464f4c1ee6..63d6fcbf823e4fce0b1b260543aec38475a607d8 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -165,7 +165,14 @@ paddle::framework::FetchList InterpreterCore::Run( ExecuteInstructionList(vec_instruction_); #ifdef PADDLE_WITH_ASCEND_CL - platform::DeviceContextPool::Instance().Get(place_)->Wait(); + if (platform::is_npu_place(place_)) { + platform::DeviceContextPool::Instance().Get(place_)->Wait(); + } +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + if (platform::is_custom_place(place_)) { + platform::DeviceContextPool::Instance().Get(place_)->Wait(); + } #endif } if (create_local_scope_) { @@ -223,7 +230,14 @@ paddle::framework::FetchList InterpreterCore::Run( ExecuteInstructionList(vec_instruction_); #ifdef PADDLE_WITH_ASCEND_CL - platform::DeviceContextPool::Instance().Get(place_)->Wait(); + if (platform::is_npu_place(place_)) { + platform::DeviceContextPool::Instance().Get(place_)->Wait(); + } +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + if (platform::is_custom_place(place_)) { + platform::DeviceContextPool::Instance().Get(place_)->Wait(); + } #endif } diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index 2df8892f5bd8aac45ad0af4cce3aead84da683c5..4f065f2452e287b8a6b0815801095878864e0e7d 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -399,7 +399,8 @@ static bool IsCpuOp(const Instruction& instr) { // is supported heterogeneous place static bool IsSupportedHetePlace(const phi::Place& place) { return platform::is_gpu_place(place) || platform::is_npu_place(place) || - platform::is_xpu_place(place) || platform::is_ipu_place(place); + platform::is_xpu_place(place) || platform::is_ipu_place(place) || + platform::is_custom_place(place); } } // namespace interpreter diff --git a/paddle/fluid/framework/new_executor/stream_analyzer.cc b/paddle/fluid/framework/new_executor/stream_analyzer.cc index 760a852baee68f3c3f53386ce28611e923d80342..3025f017471c766495cfa940dc6ffe1ebae7a196 100644 --- a/paddle/fluid/framework/new_executor/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/stream_analyzer.cc @@ -30,7 +30,8 @@ std::mutex ctx_mtx; } // namespace StreamAnalyzer::StreamAnalyzer(const platform::Place& place) : place_(place) { - if (platform::is_gpu_place(place) || platform::is_npu_place(place)) { + if (platform::is_gpu_place(place) || platform::is_npu_place(place) || + platform::is_custom_place(place)) { std::lock_guard lk(ctx_mtx); if (d2h_ctxs == nullptr) { d2h_ctxs = new std::map< @@ -178,7 +179,8 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( auto* dev_ctx = op_func_node.dev_ctx_; // only gpu/npu need update. xpu not need, because xpu memcpy op kernel is // synchronous. - if (platform::is_gpu_place(place_) || platform::is_npu_place(place_)) { + if (platform::is_gpu_place(place_) || platform::is_npu_place(place_) || + platform::is_custom_place(place_)) { if (op_type == interpreter::kMemcpyD2H) { VLOG(3) << "Get dev_ctx from d2h_context_pool_"; dev_ctx = d2h_ctx_.get().get(); @@ -209,7 +211,7 @@ bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr, return true; // npu d2h kernel is asynchronous. - if (platform::is_npu_place(place_)) { + if (platform::is_npu_place(place_) || platform::is_custom_place(place_)) { return interpreter::IsCpuOp(cur_instr) || interpreter::IsMemcpyH2D(next_instr); } @@ -227,6 +229,8 @@ platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) { return platform::kXPU; } else if (platform::is_npu_place(place_)) { return platform::kNPU; + } else if (platform::is_custom_place(place_)) { + return platform::kCUSTOM_DEVICE; } return platform::kCUDA; } diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 6ed27fd9b326d53f145ca71cb41751eec8273d77..c910e4b4ea0fbf0e43caee3be8bf8ab23a6ce236 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -251,6 +251,10 @@ if(WITH_MLU) target_link_libraries(device_context mlu_resource_pool) endif() +if(WITH_CUSTOM_DEVICE) + target_link_libraries(device_context custom_device_resource_pool) +endif() + cc_test( init_test SRCS init_test.cc @@ -284,11 +288,17 @@ if(WITH_GPU) set(DEVICE_EVENT_LIBS device_event_gpu CACHE INTERNAL "device event libs") - nv_test( - device_event_test - SRCS device_event_test.cc - DEPS device_event_gpu) - + if(WITH_CUSTOM_DEVICE) + nv_test( + device_event_test + SRCS device_event_test.cc + DEPS device_event_gpu device_event_custom_device) + else() + nv_test( + device_event_test + SRCS device_event_test.cc + DEPS device_event_gpu) + endif() nv_test( device_context_test SRCS device_context_test.cu @@ -311,11 +321,17 @@ if(WITH_ROCM) set(DEVICE_EVENT_LIBS device_event_gpu CACHE INTERNAL "device event libs") - hip_test( - device_event_test - SRCS device_event_test.cc - DEPS device_event_gpu) - + if(WITH_CUSTOM_DEVICE) + hip_test( + device_event_test + SRCS device_event_test.cc + DEPS device_event_gpu device_event_custom_device) + else() + hip_test( + device_event_test + SRCS device_event_test.cc + DEPS device_event_gpu) + endif() hip_test( device_context_test SRCS device_context_test.cu @@ -470,3 +486,13 @@ if(NOT APPLE AND NOT WIN32) DEPS device_code lod_tensor) endif() endif() + +if(WITH_CUSTOM_DEVICE) + cc_library( + device_event_custom_device + SRCS device_event_custom_device.cc + DEPS device_event_base custom_device_resource_pool) + set(DEVICE_EVENT_LIBS + ${DEVICE_EVENT_LIBS} device_event_custom_device + CACHE INTERNAL "device event libs") +endif() diff --git a/paddle/fluid/platform/device/CMakeLists.txt b/paddle/fluid/platform/device/CMakeLists.txt index 62745883023cb0fb37f6cfcbd6dca70d43f44faa..d01cb2288adaa435e92f083384c846bef66731f2 100644 --- a/paddle/fluid/platform/device/CMakeLists.txt +++ b/paddle/fluid/platform/device/CMakeLists.txt @@ -24,3 +24,7 @@ endif() if(WITH_MLU) add_subdirectory(mlu) endif() + +if(WITH_CUSTOM_DEVICE) + add_subdirectory(custom) +endif() diff --git a/paddle/fluid/platform/device/custom/CMakeLists.txt b/paddle/fluid/platform/device/custom/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..142c6c70e797d1746c56ee4b58f4cd18ef9bc22d --- /dev/null +++ b/paddle/fluid/platform/device/custom/CMakeLists.txt @@ -0,0 +1,6 @@ +if(WITH_CUSTOM_DEVICE) + cc_library( + custom_device_resource_pool + SRCS custom_device_resource_pool.cc + DEPS gflags glog enforce monitor) +endif() diff --git a/paddle/fluid/platform/device/custom/custom_device_resource_pool.cc b/paddle/fluid/platform/device/custom/custom_device_resource_pool.cc new file mode 100644 index 0000000000000000000000000000000000000000..1cd6c3bb3f7458cdb512d53c52fcfecef43dce91 --- /dev/null +++ b/paddle/fluid/platform/device/custom/custom_device_resource_pool.cc @@ -0,0 +1,190 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef PADDLE_WITH_CUSTOM_DEVICE +#include "paddle/fluid/platform/device/custom/custom_device_resource_pool.h" + +namespace paddle { +namespace platform { + +CustomDeviceStreamResourcePool::CustomDeviceStreamResourcePool( + const paddle::Place& place) { + PADDLE_ENFORCE_EQ( + platform::is_custom_place(place), + true, + platform::errors::PreconditionNotMet( + "Required device shall be CustomPlace, but received %d. ", place)); + + int dev_cnt = phi::DeviceManager::GetDeviceCount(place.GetDeviceType()); + pool_.reserve(dev_cnt); + for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) { + auto creator = [place, dev_idx] { + auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx); + phi::DeviceManager::SetDevice(place_); + + phi::stream::Stream* stream = new phi::stream::Stream(place_, nullptr); + phi::DeviceManager::GetDeviceWithPlace(place_)->CreateStream(stream); + return stream; + }; + + auto deleter = [place, dev_idx](phi::stream::Stream* stream) { + auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx); + phi::DeviceManager::SetDevice(place_); + + phi::DeviceManager::GetDeviceWithPlace(place_)->DestroyStream(stream); + delete stream; + }; + + pool_.emplace_back( + ResourcePool::Create(creator, deleter)); + } +} + +CustomDeviceStreamResourcePool& CustomDeviceStreamResourcePool::Instance( + const paddle::Place& place) { + static std::unordered_map< + std::string, + std::vector>> + pool; + PADDLE_ENFORCE_EQ( + platform::is_custom_place(place), + true, + platform::errors::PreconditionNotMet( + "Required device shall be CustomPlace, but received %d. ", place)); + if (pool.find(place.GetDeviceType()) == pool.end()) { + pool.insert( + {place.GetDeviceType(), + std::vector>()}); + for (size_t i = 0; + i < phi::DeviceManager::GetDeviceCount(place.GetDeviceType()); + ++i) { + pool[place.GetDeviceType()].emplace_back( + new CustomDeviceStreamResourcePool( + paddle::platform::CustomPlace(place.GetDeviceType(), i))); + } + } + PADDLE_ENFORCE_LT( + place.GetDeviceId(), + pool[place.GetDeviceType()].size(), + platform::errors::OutOfRange("Device id is out of range, device id shall " + "be less than %d, but received %d. ", + pool[place.GetDeviceType()].size(), + place.GetDeviceId())); + return *pool[place.GetDeviceType()][place.GetDeviceId()]; +} + +std::shared_ptr CustomDeviceStreamResourcePool::New( + int dev_idx) { + PADDLE_ENFORCE_GE( + dev_idx, + 0, + platform::errors::InvalidArgument( + "The dev_idx should be not less than 0, but got %d.", dev_idx)); + PADDLE_ENFORCE_LT( + dev_idx, + pool_.size(), + platform::errors::OutOfRange( + "The dev_idx should be less than device count %d, but got %d.", + pool_.size(), + dev_idx)); + return pool_[dev_idx]->New(); +} + +CustomDeviceEventResourcePool::CustomDeviceEventResourcePool( + const paddle::Place& place) { + PADDLE_ENFORCE_EQ( + platform::is_custom_place(place), + true, + platform::errors::PreconditionNotMet( + "Required device shall be CustomPlace, but received %d. ", place)); + + int dev_cnt = phi::DeviceManager::GetDeviceCount(place.GetDeviceType()); + pool_.reserve(dev_cnt); + for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) { + auto creator = [place, dev_idx] { + auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx); + phi::DeviceManager::SetDevice(place_); + + phi::event::Event* event = new phi::event::Event(place_, nullptr); + phi::DeviceManager::GetDeviceWithPlace(place_)->CreateEvent(event); + return event; + }; + + auto deleter = [place, dev_idx](phi::event::Event* event) { + auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx); + phi::DeviceManager::SetDevice(place_); + + phi::DeviceManager::GetDeviceWithPlace(place_)->DestroyEvent(event); + }; + + pool_.emplace_back( + ResourcePool::Create(creator, deleter)); + } +} + +CustomDeviceEventResourcePool& CustomDeviceEventResourcePool::Instance( + const phi::Place& place) { + static std::unordered_map< + std::string, + std::vector>> + pool; + PADDLE_ENFORCE_EQ( + platform::is_custom_place(place), + true, + platform::errors::PreconditionNotMet( + "Required device shall be CustomPlace, but received %d. ", place)); + if (pool.find(place.GetDeviceType()) == pool.end()) { + pool.insert( + {place.GetDeviceType(), + std::vector>()}); + for (size_t i = 0; + i < phi::DeviceManager::GetDeviceCount(place.GetDeviceType()); + ++i) { + pool[place.GetDeviceType()].emplace_back( + new CustomDeviceEventResourcePool( + paddle::platform::CustomPlace(place.GetDeviceType(), i))); + } + } + PADDLE_ENFORCE_LT( + place.GetDeviceId(), + pool[place.GetDeviceType()].size(), + platform::errors::OutOfRange("Device id is out of range, device id shall " + "be less than %d, but received %d. ", + pool[place.GetDeviceType()].size(), + place.GetDeviceId())); + return *pool[place.GetDeviceType()][place.GetDeviceId()]; +} + +std::shared_ptr CustomDeviceEventResourcePool::New( + int dev_idx) { + PADDLE_ENFORCE_GE( + dev_idx, + 0, + platform::errors::InvalidArgument( + "The dev_idx should be not less than 0, but got %d.", dev_idx)); + PADDLE_ENFORCE_LT( + dev_idx, + pool_.size(), + platform::errors::OutOfRange( + "The dev_idx should be less than device count %d, but got %d.", + pool_.size(), + dev_idx)); + return pool_[dev_idx]->New(); +} + +} // namespace platform +} // namespace paddle + +#endif diff --git a/paddle/fluid/platform/device/custom/custom_device_resource_pool.h b/paddle/fluid/platform/device/custom/custom_device_resource_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..c643cff7b54512257e11df9ab4cd2cbfaf836c9c --- /dev/null +++ b/paddle/fluid/platform/device/custom/custom_device_resource_pool.h @@ -0,0 +1,65 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef PADDLE_WITH_CUSTOM_DEVICE +#include +#include +#include + +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/platform/resource_pool.h" +#include "paddle/phi/backends/device_manager.h" + +namespace paddle { +namespace platform { + +using CustomDeviceStreamObject = phi::stream::Stream; +using CustomDeviceEventObject = phi::event::Event; + +class CustomDeviceStreamResourcePool { + public: + std::shared_ptr New(int dev_idx); + + static CustomDeviceStreamResourcePool& Instance(const paddle::Place& place); + + private: + explicit CustomDeviceStreamResourcePool(const paddle::Place& place); + + DISABLE_COPY_AND_ASSIGN(CustomDeviceStreamResourcePool); + + private: + std::vector>> pool_; +}; + +class CustomDeviceEventResourcePool { + public: + std::shared_ptr New(int dev_idx); + + static CustomDeviceEventResourcePool& Instance(const paddle::Place& place); + + private: + explicit CustomDeviceEventResourcePool(const paddle::Place& place); + + DISABLE_COPY_AND_ASSIGN(CustomDeviceEventResourcePool); + + private: + std::vector>> pool_; +}; + +} // namespace platform +} // namespace paddle + +#endif diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 7939f8ff7c066062036d31f709d9ac0b5d0e768a..d8ebb019fc69166b98e6edf46ad64dec0e3a2fea 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -130,7 +130,7 @@ constexpr DeviceType kXPU = DeviceType::XPU; constexpr DeviceType kNPU = DeviceType::NPU; constexpr DeviceType kIPU = DeviceType::IPU; constexpr DeviceType kMLU = DeviceType::MLU; -constexpr DeviceType kCUSOTM_DEVICE = DeviceType::CUSTOM_DEVICE; +constexpr DeviceType kCUSTOM_DEVICE = DeviceType::CUSTOM_DEVICE; using DeviceContext = phi::DeviceContext; diff --git a/paddle/fluid/platform/device_event.h b/paddle/fluid/platform/device_event.h index cf80266050af2ecfe8f6439c4bd898ff3a6a5f23..8659d8be902b61d299e9eb1bc5a904f55b945653 100644 --- a/paddle/fluid/platform/device_event.h +++ b/paddle/fluid/platform/device_event.h @@ -25,6 +25,7 @@ using ::paddle::platform::kCPU; using ::paddle::platform::kCUDA; +using ::paddle::platform::kCUSTOM_DEVICE; using ::paddle::platform::kNPU; using ::paddle::platform::kXPU; @@ -42,3 +43,9 @@ USE_EVENT(kNPU); USE_EVENT_WAIT(kNPU, kNPU) USE_EVENT_WAIT(kCPU, kNPU) #endif + +#ifdef PADDLE_WITH_CUSTOM_DEVICE +USE_EVENT(kCUSTOM_DEVICE); +USE_EVENT_WAIT(kCUSTOM_DEVICE, kCUSTOM_DEVICE) +USE_EVENT_WAIT(kCPU, kCUSTOM_DEVICE) +#endif diff --git a/paddle/fluid/platform/device_event_base.h b/paddle/fluid/platform/device_event_base.h index 6a2948480b549727edc9af12f973e0f61a87160d..d0458dcb9e4e46ba4836e862264db0d36b048b31 100644 --- a/paddle/fluid/platform/device_event_base.h +++ b/paddle/fluid/platform/device_event_base.h @@ -64,11 +64,13 @@ class DeviceEvent { "Required type < %d, but received type = %d", MaxDeviceTypes, type_id_)); +#ifndef PADDLE_WITH_CUSTOM_DEVICE // TODO(Aurelius84): only support CPU/CUDA/NPU. PADDLE_ENFORCE_LT(type_id_, 3, platform::errors::Unavailable( "Currently DeviceEvent do not support %s", place)); +#endif PADDLE_ENFORCE_NOT_NULL( event_creator_[type_id_], platform::errors::Unavailable( diff --git a/paddle/fluid/platform/device_event_custom_device.cc b/paddle/fluid/platform/device_event_custom_device.cc new file mode 100644 index 0000000000000000000000000000000000000000..a45cb43baf2ec3980d046b3ebadabc4a1614ac63 --- /dev/null +++ b/paddle/fluid/platform/device_event_custom_device.cc @@ -0,0 +1,134 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef PADDLE_WITH_CUSTOM_DEVICE + +#include "paddle/fluid/platform/device/custom/custom_device_resource_pool.h" +#include "paddle/fluid/platform/device_event_base.h" +#include "paddle/fluid/platform/event.h" +namespace paddle { +namespace platform { +struct CustomDeviceEventWrapper { + explicit CustomDeviceEventWrapper(const platform::Place& place) { + PADDLE_ENFORCE_EQ( + platform::is_custom_place(place), + true, + platform::errors::PreconditionNotMet( + "Required device shall be CustomPlace, but received %d. ", place)); + + device_id_ = place.device; + PADDLE_ENFORCE_GT( + device_id_, + -1, + platform::errors::PreconditionNotMet( + "Required DeviceOption.device_id > -1, but received %d. ", + device_id_)); + inner_event_ = + CustomDeviceEventResourcePool::Instance(place).New(device_id_); + } + std::shared_ptr inner_event_; + int device_id_; +}; + +void DeviceEventCreateCustomDevice(DeviceEvent* event, + const platform::Place& place, + unsigned int) { + event->InitEvent(std::make_shared(place)); +} + +void DeviceEventRecordCustomDevice(DeviceEvent* event, + const DeviceContext* context) { + auto* wrapper = + static_cast(event->GetEvent().get()); + auto* custom_device_ctx = + dynamic_cast(context); + PADDLE_ENFORCE_NOT_NULL( + custom_device_ctx, + platform::errors::PreconditionNotMet( + "Failed to dynamic_cast context into NPUDeviceContext.")); + + phi::stream::Stream stream_wrapper(custom_device_ctx->GetPlace(), + custom_device_ctx->stream()); + wrapper->inner_event_->Record(&stream_wrapper); +} + +bool DeviceEventQueryCustomDevice(const DeviceEvent* event) { + auto* wrapper = + static_cast(event->GetEvent().get()); + PADDLE_ENFORCE_NOT_NULL( + wrapper, + platform::errors::PreconditionNotMet( + "Failed to dynamic_cast event into CustomDeviceEventWrapper.")); + return wrapper->inner_event_->Query(); +} + +void DeviceEventFinishCustomDevice(const DeviceEvent* event) { + auto* wrapper = + static_cast(event->GetEvent().get()); + wrapper->inner_event_->Synchonrize(); +} + +void DeviceEventCustomDeviceWaitCustomDevice(const DeviceEvent* event, + const DeviceContext* context) { + auto* wrapper = + static_cast(event->GetEvent().get()); + auto* custom_device_ctx = + dynamic_cast(context); + PADDLE_ENFORCE_NOT_NULL( + custom_device_ctx, + platform::errors::PreconditionNotMet( + "Failed to dynamic_cast context into NPUDeviceContext.")); + phi::stream::Stream stream_wrapper(custom_device_ctx->GetPlace(), + custom_device_ctx->stream()); + stream_wrapper.WaitEvent(wrapper->inner_event_.get()); +} + +void DeviceEventCPUWaitCustomDevice(const DeviceEvent* event, + const DeviceContext* context) { + DeviceEventFinishCustomDevice(event); +} + +void DeviceEventSetFinishedCustomDevice(const DeviceEvent* event) { + // do nothing +} + +void EventResetCustomDevice(const DeviceEvent* event) { + // do nothing +} + +} // namespace platform +} // namespace paddle + +using ::paddle::platform::kCPU; +using ::paddle::platform::kCUSTOM_DEVICE; +REGISTER_EVENT_CREATE_FUNCTION(kCUSTOM_DEVICE, + paddle::platform::DeviceEventCreateCustomDevice) +REGISTER_EVENT_RECORD_FUNCTION(kCUSTOM_DEVICE, + paddle::platform::DeviceEventRecordCustomDevice) +REGISTER_EVENT_QUERY_FUNCTION(kCUSTOM_DEVICE, + paddle::platform::DeviceEventQueryCustomDevice) +REGISTER_EVENT_FINISH_FUNCTION(kCUSTOM_DEVICE, + paddle::platform::DeviceEventFinishCustomDevice) +REGISTER_EVENT_SET_FINISHED_FUNCTION( + kCUSTOM_DEVICE, paddle::platform::DeviceEventSetFinishedCustomDevice) +REGISTER_EVENT_WAIT_FUNCTION( + kCUSTOM_DEVICE, + kCUSTOM_DEVICE, + paddle::platform::DeviceEventCustomDeviceWaitCustomDevice) +REGISTER_EVENT_WAIT_FUNCTION(kCPU, + kCUSTOM_DEVICE, + paddle::platform::DeviceEventCPUWaitCustomDevice) +REGISTER_EVENT_RESET_FUNCTION(kCUSTOM_DEVICE, + paddle::platform::EventResetCustomDevice) +#endif diff --git a/paddle/phi/backends/device_base.h b/paddle/phi/backends/device_base.h index a0d95349196e09a9dacb0cfebfac4616fcd1ea05..7030777474d5aa3dc659d31857453c60d803c467 100644 --- a/paddle/phi/backends/device_base.h +++ b/paddle/phi/backends/device_base.h @@ -96,9 +96,10 @@ class DeviceInterface { // Driver / Runtime // Event // ! Create an event. - virtual void CreateEvent(size_t dev_id, - event::Event* event, - event::Event::Flag flags); + virtual void CreateEvent( + size_t dev_id, + event::Event* event, + event::Event::Flag flags = event::Event::Flag::Default); // ! Destroy an event. virtual void DestroyEvent(size_t dev_id, event::Event* event); diff --git a/paddle/phi/backends/device_manager.h b/paddle/phi/backends/device_manager.h index 54bafd796df46322e08cd6f8bbc444cc3a86058f..130f8fab449ac7838681818979bb770967eda0fa 100644 --- a/paddle/phi/backends/device_manager.h +++ b/paddle/phi/backends/device_manager.h @@ -55,7 +55,8 @@ class Device final { // Event // ! Create an event. - void CreateEvent(event::Event* event, event::Event::Flag flags); + void CreateEvent(event::Event* event, + event::Event::Flag flags = event::Event::Flag::Default); // ! Destroy an event. void DestroyEvent(event::Event* event);