未验证 提交 e99b3cb2 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] Fix device id out of range in custom device resource pool (#56580)

上级 700bc8cd
...@@ -45,13 +45,15 @@ void StreamSafeCustomDeviceAllocation::RecordStream( ...@@ -45,13 +45,15 @@ void StreamSafeCustomDeviceAllocation::RecordStream(
auto it = outstanding_event_map_.find(stream); auto it = outstanding_event_map_.find(stream);
if (it == outstanding_event_map_.end()) { if (it == outstanding_event_map_.end()) {
outstanding_event_map_[stream].Init(place()); outstanding_event_map_.insert(
{stream, std::make_shared<phi::event::Event>()});
outstanding_event_map_[stream]->Init(place());
VLOG(9) << "Create a new event " VLOG(9) << "Create a new event "
<< outstanding_event_map_[stream].raw_event(); << outstanding_event_map_[stream]->raw_event();
auto stream_wrapper = phi::stream::Stream(place(), stream); auto stream_wrapper = phi::stream::Stream(place(), stream);
VLOG(8) << "Record event " << it->second.raw_event() << " to stream " VLOG(8) << "Record event " << outstanding_event_map_[stream]->raw_event()
<< stream; << " to stream " << stream;
outstanding_event_map_[stream].Record(&stream_wrapper); outstanding_event_map_[stream]->Record(&stream_wrapper);
} }
} }
...@@ -65,14 +67,16 @@ void StreamSafeCustomDeviceAllocation::MarkAsWillBeFreed() { ...@@ -65,14 +67,16 @@ void StreamSafeCustomDeviceAllocation::MarkAsWillBeFreed() {
outstanding_event_map_.end()) { outstanding_event_map_.end()) {
std::call_once(once_flag_, std::call_once(once_flag_,
[this] { phi::DeviceManager::SetDevice(place_); }); [this] { phi::DeviceManager::SetDevice(place_); });
outstanding_event_map_[owning_stream_].Init(place_); outstanding_event_map_.insert(
{owning_stream_, std::make_shared<phi::event::Event>()});
outstanding_event_map_[owning_stream_]->Init(place_);
VLOG(9) << "Create a new event " VLOG(9) << "Create a new event "
<< outstanding_event_map_[owning_stream_].raw_event(); << outstanding_event_map_[owning_stream_]->raw_event();
auto stream_wrapper = phi::stream::Stream(place_, owning_stream_); auto stream_wrapper = phi::stream::Stream(place_, owning_stream_);
VLOG(8) << "Record event " VLOG(8) << "Record event "
<< outstanding_event_map_[owning_stream_].raw_event() << outstanding_event_map_[owning_stream_]->raw_event()
<< " to stream " << owning_stream_; << " to stream " << owning_stream_;
outstanding_event_map_[owning_stream_].Record(&stream_wrapper); outstanding_event_map_[owning_stream_]->Record(&stream_wrapper);
} }
} }
} }
...@@ -87,15 +91,15 @@ bool StreamSafeCustomDeviceAllocation::CanBeFreed() { ...@@ -87,15 +91,15 @@ bool StreamSafeCustomDeviceAllocation::CanBeFreed() {
it != outstanding_event_map_.end(); it != outstanding_event_map_.end();
++it) { ++it) {
auto& event = it->second; auto& event = it->second;
if (!event.Query()) { if (!event->Query()) {
VLOG(9) << "Event " << event.raw_event() << " for " << ptr() VLOG(9) << "Event " << event->raw_event() << " for " << ptr()
<< " is not completed"; << " is not completed";
return false; return false;
} }
VLOG(8) << "Destroy event " << event.raw_event(); VLOG(8) << "Destroy event " << event->raw_event();
outstanding_event_map_.erase(outstanding_event_map_.begin(), it); event->Destroy();
event.Destroy();
} }
outstanding_event_map_.clear();
return true; return true;
} }
......
...@@ -44,7 +44,8 @@ class StreamSafeCustomDeviceAllocation : public Allocation { ...@@ -44,7 +44,8 @@ class StreamSafeCustomDeviceAllocation : public Allocation {
private: private:
thread_local static std::once_flag once_flag_; thread_local static std::once_flag once_flag_;
DecoratedAllocationPtr underlying_allocation_; DecoratedAllocationPtr underlying_allocation_;
std::map<phi::stream::stream_t, phi::event::Event> outstanding_event_map_; std::map<phi::stream::stream_t, std::shared_ptr<phi::event::Event>>
outstanding_event_map_;
phi::stream::stream_t owning_stream_; phi::stream::stream_t owning_stream_;
SpinLock outstanding_event_map_lock_; SpinLock outstanding_event_map_lock_;
std::shared_ptr<Allocator> allocator_; std::shared_ptr<Allocator> allocator_;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
...@@ -27,10 +26,9 @@ CustomDeviceStreamResourcePool::CustomDeviceStreamResourcePool( ...@@ -27,10 +26,9 @@ CustomDeviceStreamResourcePool::CustomDeviceStreamResourcePool(
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Required device shall be CustomPlace, but received %d. ", place)); "Required device shall be CustomPlace, but received %d. ", place));
auto selected_devices = int dev_cnt = phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
phi::DeviceManager::GetSelectedDeviceList(place.GetDeviceType()); pool_.reserve(dev_cnt);
pool_.reserve(selected_devices.size()); for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) {
for (auto& dev_idx : selected_devices) {
auto creator = [place, dev_idx, this] { auto creator = [place, dev_idx, this] {
auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx); auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx);
phi::DeviceManager::SetDevice(place_); phi::DeviceManager::SetDevice(place_);
...@@ -83,11 +81,12 @@ CustomDeviceStreamResourcePool& CustomDeviceStreamResourcePool::Instance( ...@@ -83,11 +81,12 @@ CustomDeviceStreamResourcePool& CustomDeviceStreamResourcePool::Instance(
if (pool.find(place.GetDeviceType()) == pool.end()) { if (pool.find(place.GetDeviceType()) == pool.end()) {
pool.insert({place.GetDeviceType(), pool.insert({place.GetDeviceType(),
std::vector<CustomDeviceStreamResourcePool*>()}); std::vector<CustomDeviceStreamResourcePool*>()});
for (auto& dev_id : for (size_t i = 0;
phi::DeviceManager::GetSelectedDeviceList(place.GetDeviceType())) { i < phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
++i) {
pool[place.GetDeviceType()].emplace_back( pool[place.GetDeviceType()].emplace_back(
new CustomDeviceStreamResourcePool( new CustomDeviceStreamResourcePool(
paddle::platform::CustomPlace(place.GetDeviceType(), dev_id))); paddle::platform::CustomPlace(place.GetDeviceType(), i)));
} }
} }
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
...@@ -125,10 +124,9 @@ CustomDeviceEventResourcePool::CustomDeviceEventResourcePool( ...@@ -125,10 +124,9 @@ CustomDeviceEventResourcePool::CustomDeviceEventResourcePool(
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Required device shall be CustomPlace, but received %d. ", place)); "Required device shall be CustomPlace, but received %d. ", place));
auto selected_devices = int dev_cnt = phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
phi::DeviceManager::GetSelectedDeviceList(place.GetDeviceType()); pool_.reserve(dev_cnt);
pool_.reserve(selected_devices.size()); for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) {
for (auto& dev_idx : selected_devices) {
auto creator = [place, dev_idx, this] { auto creator = [place, dev_idx, this] {
auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx); auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx);
phi::DeviceManager::SetDevice(place_); phi::DeviceManager::SetDevice(place_);
...@@ -181,11 +179,12 @@ CustomDeviceEventResourcePool& CustomDeviceEventResourcePool::Instance( ...@@ -181,11 +179,12 @@ CustomDeviceEventResourcePool& CustomDeviceEventResourcePool::Instance(
if (pool.find(place.GetDeviceType()) == pool.end()) { if (pool.find(place.GetDeviceType()) == pool.end()) {
pool.insert( pool.insert(
{place.GetDeviceType(), std::vector<CustomDeviceEventResourcePool*>()}); {place.GetDeviceType(), std::vector<CustomDeviceEventResourcePool*>()});
for (auto& dev_id : for (size_t i = 0;
phi::DeviceManager::GetSelectedDeviceList(place.GetDeviceType())) { i < phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
++i) {
pool[place.GetDeviceType()].emplace_back( pool[place.GetDeviceType()].emplace_back(
new CustomDeviceEventResourcePool( new CustomDeviceEventResourcePool(
paddle::platform::CustomPlace(place.GetDeviceType(), dev_id))); paddle::platform::CustomPlace(place.GetDeviceType(), i)));
} }
} }
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
......
...@@ -90,16 +90,10 @@ class CustomDevice : public DeviceInterface { ...@@ -90,16 +90,10 @@ class CustomDevice : public DeviceInterface {
C_Device_st device; C_Device_st device;
device.id = dev_id; device.id = dev_id;
devices_pool[dev_id] = device; devices_pool[dev_id] = device;
InitDevice(dev_id);
} }
} }
void Finalize() override { void Finalize() override {
auto devices = GetDeviceList();
for (auto dev_id : devices) {
DeInitDevice(dev_id);
}
bool ok = true; bool ok = true;
if (pimpl_->finalize && pimpl_->finalize() != C_SUCCESS) { if (pimpl_->finalize && pimpl_->finalize() != C_SUCCESS) {
LOG(ERROR) << "Finalize " << Type() << " Failed\n"; LOG(ERROR) << "Finalize " << Type() << " Failed\n";
......
...@@ -29,52 +29,69 @@ ...@@ -29,52 +29,69 @@
namespace phi { namespace phi {
void Device::CheckInitialized() {
std::call_once(initialized_, [&]() { this->impl_->InitDevice(dev_id_); });
}
Device::~Device() { impl_->DeInitDevice(dev_id_); }
void Device::CreateStream(stream::Stream* stream, void Device::CreateStream(stream::Stream* stream,
const stream::Stream::Priority& priority, const stream::Stream::Priority& priority,
const stream::Stream::Flag& flag) { const stream::Stream::Flag& flag) {
CheckInitialized();
impl_->CreateStream(dev_id_, stream, priority, flag); impl_->CreateStream(dev_id_, stream, priority, flag);
} }
void Device::DestroyStream(stream::Stream* stream) { void Device::DestroyStream(stream::Stream* stream) {
CheckInitialized();
impl_->DestroyStream(dev_id_, stream); impl_->DestroyStream(dev_id_, stream);
} }
void Device::SynchronizeStream(const stream::Stream* stream) { void Device::SynchronizeStream(const stream::Stream* stream) {
CheckInitialized();
impl_->SynchronizeStream(dev_id_, stream); impl_->SynchronizeStream(dev_id_, stream);
} }
bool Device::QueryStream(const stream::Stream* stream) { bool Device::QueryStream(const stream::Stream* stream) {
CheckInitialized();
return impl_->QueryStream(dev_id_, stream); return impl_->QueryStream(dev_id_, stream);
} }
void Device::AddCallback(stream::Stream* stream, void Device::AddCallback(stream::Stream* stream,
stream::Stream::Callback* callback) { stream::Stream::Callback* callback) {
CheckInitialized();
impl_->AddCallback(dev_id_, stream, callback); impl_->AddCallback(dev_id_, stream, callback);
} }
void Device::CreateEvent(event::Event* event, event::Event::Flag flags) { void Device::CreateEvent(event::Event* event, event::Event::Flag flags) {
CheckInitialized();
impl_->CreateEvent(dev_id_, event, flags); impl_->CreateEvent(dev_id_, event, flags);
} }
void Device::DestroyEvent(event::Event* event) { void Device::DestroyEvent(event::Event* event) {
CheckInitialized();
impl_->DestroyEvent(dev_id_, event); impl_->DestroyEvent(dev_id_, event);
} }
void Device::RecordEvent(const event::Event* event, void Device::RecordEvent(const event::Event* event,
const stream::Stream* stream) { const stream::Stream* stream) {
CheckInitialized();
impl_->RecordEvent(dev_id_, event, stream); impl_->RecordEvent(dev_id_, event, stream);
} }
void Device::SynchronizeEvent(const event::Event* event) { void Device::SynchronizeEvent(const event::Event* event) {
CheckInitialized();
impl_->SynchronizeEvent(dev_id_, event); impl_->SynchronizeEvent(dev_id_, event);
} }
bool Device::QueryEvent(const event::Event* event) { bool Device::QueryEvent(const event::Event* event) {
CheckInitialized();
return impl_->QueryEvent(dev_id_, event); return impl_->QueryEvent(dev_id_, event);
} }
void Device::StreamWaitEvent(const stream::Stream* stream, void Device::StreamWaitEvent(const stream::Stream* stream,
const event::Event* event) { const event::Event* event) {
CheckInitialized();
impl_->StreamWaitEvent(dev_id_, stream, event); impl_->StreamWaitEvent(dev_id_, stream, event);
} }
...@@ -82,6 +99,7 @@ void Device::MemoryCopyH2D(void* dst, ...@@ -82,6 +99,7 @@ void Device::MemoryCopyH2D(void* dst,
const void* src, const void* src,
size_t size, size_t size,
const stream::Stream* stream) { const stream::Stream* stream) {
CheckInitialized();
impl_->MemoryCopyH2D(dev_id_, dst, src, size, stream); impl_->MemoryCopyH2D(dev_id_, dst, src, size, stream);
} }
...@@ -89,6 +107,7 @@ void Device::MemoryCopyD2H(void* dst, ...@@ -89,6 +107,7 @@ void Device::MemoryCopyD2H(void* dst,
const void* src, const void* src,
size_t size, size_t size,
const stream::Stream* stream) { const stream::Stream* stream) {
CheckInitialized();
impl_->MemoryCopyD2H(dev_id_, dst, src, size, stream); impl_->MemoryCopyD2H(dev_id_, dst, src, size, stream);
} }
...@@ -96,6 +115,7 @@ void Device::MemoryCopyD2D(void* dst, ...@@ -96,6 +115,7 @@ void Device::MemoryCopyD2D(void* dst,
const void* src, const void* src,
size_t size, size_t size,
const stream::Stream* stream) { const stream::Stream* stream) {
CheckInitialized();
impl_->MemoryCopyD2D(dev_id_, dst, src, size, stream); impl_->MemoryCopyD2D(dev_id_, dst, src, size, stream);
} }
...@@ -104,34 +124,42 @@ void Device::MemoryCopyP2P(const Place& dst_place, ...@@ -104,34 +124,42 @@ void Device::MemoryCopyP2P(const Place& dst_place,
const void* src, const void* src,
size_t size, size_t size,
const stream::Stream* stream) { const stream::Stream* stream) {
CheckInitialized();
impl_->MemoryCopyP2P(dst_place, dst, dev_id_, src, size, stream); impl_->MemoryCopyP2P(dst_place, dst, dev_id_, src, size, stream);
} }
void* Device::MemoryAllocate(size_t size) { void* Device::MemoryAllocate(size_t size) {
CheckInitialized();
return impl_->MemoryAllocate(dev_id_, size); return impl_->MemoryAllocate(dev_id_, size);
} }
void Device::MemoryDeallocate(void* ptr, size_t size) { void Device::MemoryDeallocate(void* ptr, size_t size) {
CheckInitialized();
impl_->MemoryDeallocate(dev_id_, ptr, size); impl_->MemoryDeallocate(dev_id_, ptr, size);
} }
void* Device::MemoryAllocateHost(size_t size) { void* Device::MemoryAllocateHost(size_t size) {
CheckInitialized();
return impl_->MemoryAllocateHost(dev_id_, size); return impl_->MemoryAllocateHost(dev_id_, size);
} }
void Device::MemoryDeallocateHost(void* ptr, size_t size) { void Device::MemoryDeallocateHost(void* ptr, size_t size) {
CheckInitialized();
impl_->MemoryDeallocateHost(dev_id_, ptr, size); impl_->MemoryDeallocateHost(dev_id_, ptr, size);
} }
void* Device::MemoryAllocateUnified(size_t size) { void* Device::MemoryAllocateUnified(size_t size) {
CheckInitialized();
return impl_->MemoryAllocateUnified(dev_id_, size); return impl_->MemoryAllocateUnified(dev_id_, size);
} }
void Device::MemoryDeallocateUnified(void* ptr, size_t size) { void Device::MemoryDeallocateUnified(void* ptr, size_t size) {
CheckInitialized();
impl_->MemoryDeallocateUnified(dev_id_, ptr, size); impl_->MemoryDeallocateUnified(dev_id_, ptr, size);
} }
void Device::MemorySet(void* ptr, uint8_t value, size_t size) { void Device::MemorySet(void* ptr, uint8_t value, size_t size) {
CheckInitialized();
impl_->MemorySet(dev_id_, ptr, value, size); impl_->MemorySet(dev_id_, ptr, value, size);
} }
...@@ -142,6 +170,7 @@ void Device::BlasAXPBY(const stream::Stream& stream, ...@@ -142,6 +170,7 @@ void Device::BlasAXPBY(const stream::Stream& stream,
const T* x, const T* x,
float beta, float beta,
T* y) { T* y) {
CheckInitialized();
impl_->BlasAXPBY(dev_id_, impl_->BlasAXPBY(dev_id_,
stream, stream,
phi::CppTypeToDataType<T>::Type(), phi::CppTypeToDataType<T>::Type(),
...@@ -370,20 +399,6 @@ void DeviceManager::SynchronizeDevice(const Place& place) { ...@@ -370,20 +399,6 @@ void DeviceManager::SynchronizeDevice(const Place& place) {
dev_impl->SynchronizeDevice(device_id); dev_impl->SynchronizeDevice(device_id);
} }
void DeviceManager::InitDevice(const Place& place) {
auto device_type = place.GetDeviceType();
auto device_id = place.GetDeviceId();
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->InitDevice(device_id);
}
void DeviceManager::DeInitDevice(const Place& place) {
auto device_type = place.GetDeviceType();
auto device_id = place.GetDeviceId();
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->DeInitDevice(device_id);
}
void DeviceManager::SetDevice(const std::string& device_type, void DeviceManager::SetDevice(const std::string& device_type,
size_t device_id) { size_t device_id) {
auto dev_impl = GetDeviceInterfaceWithType(device_type); auto dev_impl = GetDeviceInterfaceWithType(device_type);
......
...@@ -32,6 +32,10 @@ class Device final { ...@@ -32,6 +32,10 @@ class Device final {
public: public:
Device(size_t dev_id, DeviceInterface* impl) : dev_id_(dev_id), impl_(impl) {} Device(size_t dev_id, DeviceInterface* impl) : dev_id_(dev_id), impl_(impl) {}
~Device();
void CheckInitialized();
// Stream // Stream
// ! Create an asynchronous stream // ! Create an asynchronous stream
void CreateStream( void CreateStream(
...@@ -123,6 +127,7 @@ class Device final { ...@@ -123,6 +127,7 @@ class Device final {
private: private:
size_t dev_id_; size_t dev_id_;
DeviceInterface* impl_; DeviceInterface* impl_;
std::once_flag initialized_;
}; };
class DeviceManager { class DeviceManager {
...@@ -144,10 +149,6 @@ class DeviceManager { ...@@ -144,10 +149,6 @@ class DeviceManager {
static void SynchronizeDevice(const Place& place); static void SynchronizeDevice(const Place& place);
static void InitDevice(const Place& place);
static void DeInitDevice(const Place& place);
static void SetDevice(const std::string& device_type, size_t device_id); static void SetDevice(const std::string& device_type, size_t device_id);
static void SetDevice(const Place& place); static void SetDevice(const Place& place);
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include <list> #include <list>
#include <mutex>
#include "paddle/phi/backends/event.h" #include "paddle/phi/backends/event.h"
...@@ -25,8 +26,10 @@ namespace phi { ...@@ -25,8 +26,10 @@ namespace phi {
namespace event { namespace event {
std::list<Event*> g_events; std::list<Event*> g_events;
std::mutex g_events_mutex;
void Event::ReleaseAll() { void Event::ReleaseAll() {
std::unique_lock lock(g_events_mutex);
for (auto* event : g_events) { for (auto* event : g_events) {
event->Destroy(); event->Destroy();
} }
...@@ -43,8 +46,9 @@ Event::Event(const Place& place, event_t event) ...@@ -43,8 +46,9 @@ Event::Event(const Place& place, event_t event)
own_data_(false) {} own_data_(false) {}
Event::~Event() { Event::~Event() {
g_events.remove(this);
Destroy(); Destroy();
std::unique_lock lock(g_events_mutex);
g_events.remove(this);
} }
bool Event::Init(const Place& place, Flag flags) { bool Event::Init(const Place& place, Flag flags) {
...@@ -58,13 +62,15 @@ bool Event::Init(const Place& place, Flag flags) { ...@@ -58,13 +62,15 @@ bool Event::Init(const Place& place, Flag flags) {
VLOG(3) << "Init Event: " << event_ << ", place: " << place_ VLOG(3) << "Init Event: " << event_ << ", place: " << place_
<< ", flag:" << static_cast<int>(flags); << ", flag:" << static_cast<int>(flags);
own_data_ = true; own_data_ = true;
std::unique_lock lock(g_events_mutex);
g_events.push_back(this); g_events.push_back(this);
return true; return true;
} }
void Event::Destroy() { void Event::Destroy() {
if (device_) { if (device_) {
if (own_data_) { if (own_data_ &&
phi::DeviceManager::HasDeviceType(place_.GetDeviceType())) {
phi::DeviceManager::SetDevice(place_); phi::DeviceManager::SetDevice(place_);
device_->DestroyEvent(this); device_->DestroyEvent(this);
} }
......
...@@ -25,16 +25,19 @@ namespace phi { ...@@ -25,16 +25,19 @@ namespace phi {
namespace stream { namespace stream {
std::list<Stream*> g_streams; std::list<Stream*> g_streams;
std::mutex g_streams_mutex;
void Stream::ReleaseAll() { void Stream::ReleaseAll() {
std::unique_lock lock(g_streams_mutex);
for (auto* stream : g_streams) { for (auto* stream : g_streams) {
stream->Destroy(); stream->Destroy();
} }
} }
Stream::~Stream() { Stream::~Stream() {
g_streams.remove(this);
Destroy(); Destroy();
std::unique_lock lock(g_streams_mutex);
g_streams.remove(this);
} }
const stream_t& Stream::raw_stream() const { return stream_; } const stream_t& Stream::raw_stream() const { return stream_; }
...@@ -65,6 +68,7 @@ bool Stream::Init(const Place& place, ...@@ -65,6 +68,7 @@ bool Stream::Init(const Place& place,
<< ", priority: " << static_cast<int>(priority) << ", priority: " << static_cast<int>(priority)
<< ", flag:" << static_cast<int>(flag); << ", flag:" << static_cast<int>(flag);
own_data_ = true; own_data_ = true;
std::unique_lock lock(g_streams_mutex);
g_streams.push_back(this); g_streams.push_back(this);
return true; return true;
} }
...@@ -98,7 +102,8 @@ void Stream::WaitCallback() const { callback_manager_->Wait(); } ...@@ -98,7 +102,8 @@ void Stream::WaitCallback() const { callback_manager_->Wait(); }
void Stream::Destroy() { void Stream::Destroy() {
if (device_) { if (device_) {
if (own_data_) { if (own_data_ &&
phi::DeviceManager::HasDeviceType(place_.GetDeviceType())) {
phi::DeviceManager::SetDevice(place_); phi::DeviceManager::SetDevice(place_);
device_->DestroyStream(this); device_->DestroyStream(this);
} }
......
...@@ -29,6 +29,6 @@ if(WITH_CUSTOM_DEVICE AND NOT WITH_GPU) ...@@ -29,6 +29,6 @@ if(WITH_CUSTOM_DEVICE AND NOT WITH_GPU)
set_tests_properties(test_custom_cpu_plugin PROPERTIES TIMEOUT 120) set_tests_properties(test_custom_cpu_plugin PROPERTIES TIMEOUT 120)
set_tests_properties(test_custom_cpu_profiler_plugin PROPERTIES TIMEOUT 120) set_tests_properties(test_custom_cpu_profiler_plugin PROPERTIES TIMEOUT 120)
set_tests_properties(test_fleet_launch_custom_device PROPERTIES TIMEOUT 120) set_tests_properties(test_fleet_launch_custom_device PROPERTIES TIMEOUT 120)
set_tests_properties(test_custom_cpu_to_static PROPERTIES TIMEOUT 120) set_tests_properties(test_custom_cpu_to_static PROPERTIES TIMEOUT 180)
set_tests_properties(test_custom_op_setup PROPERTIES TIMEOUT 120) set_tests_properties(test_custom_op_setup PROPERTIES TIMEOUT 120)
endif() endif()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册