未验证 提交 e99b3cb2 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] Fix device id out of range in custom device resource pool (#56580)

上级 700bc8cd
......@@ -45,13 +45,15 @@ void StreamSafeCustomDeviceAllocation::RecordStream(
auto it = outstanding_event_map_.find(stream);
if (it == outstanding_event_map_.end()) {
outstanding_event_map_[stream].Init(place());
outstanding_event_map_.insert(
{stream, std::make_shared<phi::event::Event>()});
outstanding_event_map_[stream]->Init(place());
VLOG(9) << "Create a new event "
<< outstanding_event_map_[stream].raw_event();
<< outstanding_event_map_[stream]->raw_event();
auto stream_wrapper = phi::stream::Stream(place(), stream);
VLOG(8) << "Record event " << it->second.raw_event() << " to stream "
<< stream;
outstanding_event_map_[stream].Record(&stream_wrapper);
VLOG(8) << "Record event " << outstanding_event_map_[stream]->raw_event()
<< " to stream " << stream;
outstanding_event_map_[stream]->Record(&stream_wrapper);
}
}
......@@ -65,14 +67,16 @@ void StreamSafeCustomDeviceAllocation::MarkAsWillBeFreed() {
outstanding_event_map_.end()) {
std::call_once(once_flag_,
[this] { phi::DeviceManager::SetDevice(place_); });
outstanding_event_map_[owning_stream_].Init(place_);
outstanding_event_map_.insert(
{owning_stream_, std::make_shared<phi::event::Event>()});
outstanding_event_map_[owning_stream_]->Init(place_);
VLOG(9) << "Create a new event "
<< outstanding_event_map_[owning_stream_].raw_event();
<< outstanding_event_map_[owning_stream_]->raw_event();
auto stream_wrapper = phi::stream::Stream(place_, owning_stream_);
VLOG(8) << "Record event "
<< outstanding_event_map_[owning_stream_].raw_event()
<< outstanding_event_map_[owning_stream_]->raw_event()
<< " to stream " << owning_stream_;
outstanding_event_map_[owning_stream_].Record(&stream_wrapper);
outstanding_event_map_[owning_stream_]->Record(&stream_wrapper);
}
}
}
......@@ -87,15 +91,15 @@ bool StreamSafeCustomDeviceAllocation::CanBeFreed() {
it != outstanding_event_map_.end();
++it) {
auto& event = it->second;
if (!event.Query()) {
VLOG(9) << "Event " << event.raw_event() << " for " << ptr()
if (!event->Query()) {
VLOG(9) << "Event " << event->raw_event() << " for " << ptr()
<< " is not completed";
return false;
}
VLOG(8) << "Destroy event " << event.raw_event();
outstanding_event_map_.erase(outstanding_event_map_.begin(), it);
event.Destroy();
VLOG(8) << "Destroy event " << event->raw_event();
event->Destroy();
}
outstanding_event_map_.clear();
return true;
}
......
......@@ -44,7 +44,8 @@ class StreamSafeCustomDeviceAllocation : public Allocation {
private:
thread_local static std::once_flag once_flag_;
DecoratedAllocationPtr underlying_allocation_;
std::map<phi::stream::stream_t, phi::event::Event> outstanding_event_map_;
std::map<phi::stream::stream_t, std::shared_ptr<phi::event::Event>>
outstanding_event_map_;
phi::stream::stream_t owning_stream_;
SpinLock outstanding_event_map_lock_;
std::shared_ptr<Allocator> allocator_;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
......@@ -27,10 +26,9 @@ CustomDeviceStreamResourcePool::CustomDeviceStreamResourcePool(
platform::errors::PreconditionNotMet(
"Required device shall be CustomPlace, but received %d. ", place));
auto selected_devices =
phi::DeviceManager::GetSelectedDeviceList(place.GetDeviceType());
pool_.reserve(selected_devices.size());
for (auto& dev_idx : selected_devices) {
int dev_cnt = phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
pool_.reserve(dev_cnt);
for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) {
auto creator = [place, dev_idx, this] {
auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx);
phi::DeviceManager::SetDevice(place_);
......@@ -83,11 +81,12 @@ CustomDeviceStreamResourcePool& CustomDeviceStreamResourcePool::Instance(
if (pool.find(place.GetDeviceType()) == pool.end()) {
pool.insert({place.GetDeviceType(),
std::vector<CustomDeviceStreamResourcePool*>()});
for (auto& dev_id :
phi::DeviceManager::GetSelectedDeviceList(place.GetDeviceType())) {
for (size_t i = 0;
i < phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
++i) {
pool[place.GetDeviceType()].emplace_back(
new CustomDeviceStreamResourcePool(
paddle::platform::CustomPlace(place.GetDeviceType(), dev_id)));
paddle::platform::CustomPlace(place.GetDeviceType(), i)));
}
}
PADDLE_ENFORCE_LT(
......@@ -125,10 +124,9 @@ CustomDeviceEventResourcePool::CustomDeviceEventResourcePool(
platform::errors::PreconditionNotMet(
"Required device shall be CustomPlace, but received %d. ", place));
auto selected_devices =
phi::DeviceManager::GetSelectedDeviceList(place.GetDeviceType());
pool_.reserve(selected_devices.size());
for (auto& dev_idx : selected_devices) {
int dev_cnt = phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
pool_.reserve(dev_cnt);
for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) {
auto creator = [place, dev_idx, this] {
auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx);
phi::DeviceManager::SetDevice(place_);
......@@ -181,11 +179,12 @@ CustomDeviceEventResourcePool& CustomDeviceEventResourcePool::Instance(
if (pool.find(place.GetDeviceType()) == pool.end()) {
pool.insert(
{place.GetDeviceType(), std::vector<CustomDeviceEventResourcePool*>()});
for (auto& dev_id :
phi::DeviceManager::GetSelectedDeviceList(place.GetDeviceType())) {
for (size_t i = 0;
i < phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
++i) {
pool[place.GetDeviceType()].emplace_back(
new CustomDeviceEventResourcePool(
paddle::platform::CustomPlace(place.GetDeviceType(), dev_id)));
paddle::platform::CustomPlace(place.GetDeviceType(), i)));
}
}
PADDLE_ENFORCE_LT(
......
......@@ -90,16 +90,10 @@ class CustomDevice : public DeviceInterface {
C_Device_st device;
device.id = dev_id;
devices_pool[dev_id] = device;
InitDevice(dev_id);
}
}
void Finalize() override {
auto devices = GetDeviceList();
for (auto dev_id : devices) {
DeInitDevice(dev_id);
}
bool ok = true;
if (pimpl_->finalize && pimpl_->finalize() != C_SUCCESS) {
LOG(ERROR) << "Finalize " << Type() << " Failed\n";
......
......@@ -29,52 +29,69 @@
namespace phi {
void Device::CheckInitialized() {
std::call_once(initialized_, [&]() { this->impl_->InitDevice(dev_id_); });
}
Device::~Device() { impl_->DeInitDevice(dev_id_); }
void Device::CreateStream(stream::Stream* stream,
const stream::Stream::Priority& priority,
const stream::Stream::Flag& flag) {
CheckInitialized();
impl_->CreateStream(dev_id_, stream, priority, flag);
}
void Device::DestroyStream(stream::Stream* stream) {
CheckInitialized();
impl_->DestroyStream(dev_id_, stream);
}
void Device::SynchronizeStream(const stream::Stream* stream) {
CheckInitialized();
impl_->SynchronizeStream(dev_id_, stream);
}
bool Device::QueryStream(const stream::Stream* stream) {
CheckInitialized();
return impl_->QueryStream(dev_id_, stream);
}
void Device::AddCallback(stream::Stream* stream,
stream::Stream::Callback* callback) {
CheckInitialized();
impl_->AddCallback(dev_id_, stream, callback);
}
void Device::CreateEvent(event::Event* event, event::Event::Flag flags) {
CheckInitialized();
impl_->CreateEvent(dev_id_, event, flags);
}
void Device::DestroyEvent(event::Event* event) {
CheckInitialized();
impl_->DestroyEvent(dev_id_, event);
}
void Device::RecordEvent(const event::Event* event,
const stream::Stream* stream) {
CheckInitialized();
impl_->RecordEvent(dev_id_, event, stream);
}
void Device::SynchronizeEvent(const event::Event* event) {
CheckInitialized();
impl_->SynchronizeEvent(dev_id_, event);
}
bool Device::QueryEvent(const event::Event* event) {
CheckInitialized();
return impl_->QueryEvent(dev_id_, event);
}
void Device::StreamWaitEvent(const stream::Stream* stream,
const event::Event* event) {
CheckInitialized();
impl_->StreamWaitEvent(dev_id_, stream, event);
}
......@@ -82,6 +99,7 @@ void Device::MemoryCopyH2D(void* dst,
const void* src,
size_t size,
const stream::Stream* stream) {
CheckInitialized();
impl_->MemoryCopyH2D(dev_id_, dst, src, size, stream);
}
......@@ -89,6 +107,7 @@ void Device::MemoryCopyD2H(void* dst,
const void* src,
size_t size,
const stream::Stream* stream) {
CheckInitialized();
impl_->MemoryCopyD2H(dev_id_, dst, src, size, stream);
}
......@@ -96,6 +115,7 @@ void Device::MemoryCopyD2D(void* dst,
const void* src,
size_t size,
const stream::Stream* stream) {
CheckInitialized();
impl_->MemoryCopyD2D(dev_id_, dst, src, size, stream);
}
......@@ -104,34 +124,42 @@ void Device::MemoryCopyP2P(const Place& dst_place,
const void* src,
size_t size,
const stream::Stream* stream) {
CheckInitialized();
impl_->MemoryCopyP2P(dst_place, dst, dev_id_, src, size, stream);
}
void* Device::MemoryAllocate(size_t size) {
CheckInitialized();
return impl_->MemoryAllocate(dev_id_, size);
}
void Device::MemoryDeallocate(void* ptr, size_t size) {
CheckInitialized();
impl_->MemoryDeallocate(dev_id_, ptr, size);
}
void* Device::MemoryAllocateHost(size_t size) {
CheckInitialized();
return impl_->MemoryAllocateHost(dev_id_, size);
}
void Device::MemoryDeallocateHost(void* ptr, size_t size) {
CheckInitialized();
impl_->MemoryDeallocateHost(dev_id_, ptr, size);
}
void* Device::MemoryAllocateUnified(size_t size) {
CheckInitialized();
return impl_->MemoryAllocateUnified(dev_id_, size);
}
void Device::MemoryDeallocateUnified(void* ptr, size_t size) {
CheckInitialized();
impl_->MemoryDeallocateUnified(dev_id_, ptr, size);
}
void Device::MemorySet(void* ptr, uint8_t value, size_t size) {
CheckInitialized();
impl_->MemorySet(dev_id_, ptr, value, size);
}
......@@ -142,6 +170,7 @@ void Device::BlasAXPBY(const stream::Stream& stream,
const T* x,
float beta,
T* y) {
CheckInitialized();
impl_->BlasAXPBY(dev_id_,
stream,
phi::CppTypeToDataType<T>::Type(),
......@@ -370,20 +399,6 @@ void DeviceManager::SynchronizeDevice(const Place& place) {
dev_impl->SynchronizeDevice(device_id);
}
void DeviceManager::InitDevice(const Place& place) {
auto device_type = place.GetDeviceType();
auto device_id = place.GetDeviceId();
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->InitDevice(device_id);
}
void DeviceManager::DeInitDevice(const Place& place) {
auto device_type = place.GetDeviceType();
auto device_id = place.GetDeviceId();
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->DeInitDevice(device_id);
}
void DeviceManager::SetDevice(const std::string& device_type,
size_t device_id) {
auto dev_impl = GetDeviceInterfaceWithType(device_type);
......
......@@ -32,6 +32,10 @@ class Device final {
public:
Device(size_t dev_id, DeviceInterface* impl) : dev_id_(dev_id), impl_(impl) {}
~Device();
void CheckInitialized();
// Stream
// ! Create an asynchronous stream
void CreateStream(
......@@ -123,6 +127,7 @@ class Device final {
private:
size_t dev_id_;
DeviceInterface* impl_;
std::once_flag initialized_;
};
class DeviceManager {
......@@ -144,10 +149,6 @@ class DeviceManager {
static void SynchronizeDevice(const Place& place);
static void InitDevice(const Place& place);
static void DeInitDevice(const Place& place);
static void SetDevice(const std::string& device_type, size_t device_id);
static void SetDevice(const Place& place);
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include <list>
#include <mutex>
#include "paddle/phi/backends/event.h"
......@@ -25,8 +26,10 @@ namespace phi {
namespace event {
std::list<Event*> g_events;
std::mutex g_events_mutex;
void Event::ReleaseAll() {
std::unique_lock lock(g_events_mutex);
for (auto* event : g_events) {
event->Destroy();
}
......@@ -43,8 +46,9 @@ Event::Event(const Place& place, event_t event)
own_data_(false) {}
Event::~Event() {
g_events.remove(this);
Destroy();
std::unique_lock lock(g_events_mutex);
g_events.remove(this);
}
bool Event::Init(const Place& place, Flag flags) {
......@@ -58,13 +62,15 @@ bool Event::Init(const Place& place, Flag flags) {
VLOG(3) << "Init Event: " << event_ << ", place: " << place_
<< ", flag:" << static_cast<int>(flags);
own_data_ = true;
std::unique_lock lock(g_events_mutex);
g_events.push_back(this);
return true;
}
void Event::Destroy() {
if (device_) {
if (own_data_) {
if (own_data_ &&
phi::DeviceManager::HasDeviceType(place_.GetDeviceType())) {
phi::DeviceManager::SetDevice(place_);
device_->DestroyEvent(this);
}
......
......@@ -25,16 +25,19 @@ namespace phi {
namespace stream {
std::list<Stream*> g_streams;
std::mutex g_streams_mutex;
void Stream::ReleaseAll() {
std::unique_lock lock(g_streams_mutex);
for (auto* stream : g_streams) {
stream->Destroy();
}
}
Stream::~Stream() {
g_streams.remove(this);
Destroy();
std::unique_lock lock(g_streams_mutex);
g_streams.remove(this);
}
const stream_t& Stream::raw_stream() const { return stream_; }
......@@ -65,6 +68,7 @@ bool Stream::Init(const Place& place,
<< ", priority: " << static_cast<int>(priority)
<< ", flag:" << static_cast<int>(flag);
own_data_ = true;
std::unique_lock lock(g_streams_mutex);
g_streams.push_back(this);
return true;
}
......@@ -98,7 +102,8 @@ void Stream::WaitCallback() const { callback_manager_->Wait(); }
void Stream::Destroy() {
if (device_) {
if (own_data_) {
if (own_data_ &&
phi::DeviceManager::HasDeviceType(place_.GetDeviceType())) {
phi::DeviceManager::SetDevice(place_);
device_->DestroyStream(this);
}
......
......@@ -29,6 +29,6 @@ if(WITH_CUSTOM_DEVICE AND NOT WITH_GPU)
set_tests_properties(test_custom_cpu_plugin PROPERTIES TIMEOUT 120)
set_tests_properties(test_custom_cpu_profiler_plugin PROPERTIES TIMEOUT 120)
set_tests_properties(test_fleet_launch_custom_device PROPERTIES TIMEOUT 120)
set_tests_properties(test_custom_cpu_to_static PROPERTIES TIMEOUT 120)
set_tests_properties(test_custom_cpu_to_static PROPERTIES TIMEOUT 180)
set_tests_properties(test_custom_op_setup PROPERTIES TIMEOUT 120)
endif()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册