未验证 提交 ac8aad59 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] add default stream safe allocator support (#56380)

上级 d7cce317
...@@ -213,6 +213,10 @@ class AllocatorFacadePrivate { ...@@ -213,6 +213,10 @@ class AllocatorFacadePrivate {
platform::CustomPlace(dev_type, dev_id)); platform::CustomPlace(dev_type, dev_id));
} }
} }
if (FLAGS_use_stream_safe_cuda_allocator) {
WrapStreamSafeCustomDeviceAllocatorForDefault();
is_stream_safe_cuda_allocator_used_ = true;
}
#endif #endif
break; break;
} }
...@@ -275,6 +279,10 @@ class AllocatorFacadePrivate { ...@@ -275,6 +279,10 @@ class AllocatorFacadePrivate {
platform::CustomPlace(dev_type, dev_id), allow_free_idle_chunk); platform::CustomPlace(dev_type, dev_id), allow_free_idle_chunk);
} }
} }
if (FLAGS_use_stream_safe_cuda_allocator) {
WrapStreamSafeCustomDeviceAllocatorForDefault();
is_stream_safe_cuda_allocator_used_ = true;
}
#endif #endif
break; break;
} }
...@@ -607,6 +615,55 @@ class AllocatorFacadePrivate { ...@@ -607,6 +615,55 @@ class AllocatorFacadePrivate {
return custom_device_allocators_[place][stream]; return custom_device_allocators_[place][stream];
} }
} }
// Returns the StreamSafeCustomDeviceAllocator stored for `place` in
// default_stream_safe_custom_device_allocators_.
// Fails through PADDLE_ENFORCE_NE with a NotFound error when the map holds no
// entry for `place`.
const std::shared_ptr<StreamSafeCustomDeviceAllocator>
GetDefaultStreamSafeCustomDeviceAllocator(
const platform::CustomPlace& place) const {
const auto iter = default_stream_safe_custom_device_allocators_.find(place);
// A missing entry is treated as an error, not as a request to create one.
PADDLE_ENFORCE_NE(
iter,
default_stream_safe_custom_device_allocators_.end(),
platform::errors::NotFound(
"No StreamSafeCustomDeviceAllocator found for the place, %s",
place));
return iter->second;
}
// Forwards `stream` to the allocation's own RecordStream() when `allocation`
// is a StreamSafeCustomDeviceAllocation. Any other allocation type is left
// untouched and only reported at VLOG level 6.
void RecordStream(std::shared_ptr<phi::Allocation> allocation,
                  phi::stream::stream_t stream) {
  const auto stream_safe_allocation =
      std::dynamic_pointer_cast<StreamSafeCustomDeviceAllocation>(allocation);
  if (stream_safe_allocation == nullptr) {
    // The cast failed, so there is nothing to record on this allocation.
    VLOG(6) << "RecordStream for a non-StreamSafeCustomDeviceAllocation";
    return;
  }
  stream_safe_allocation->RecordStream(stream);
}
// Assigns `stream` as the default stream of the stream-safe allocator
// registered for `place`.
// The allocator's current default stream must still be nullptr; otherwise
// PADDLE_ENFORCE_EQ reports an Unavailable error and the stream is not changed.
void SetDefaultStream(const platform::CustomPlace& place,
                      phi::stream::stream_t stream) {
  const auto allocator = GetDefaultStreamSafeCustomDeviceAllocator(place);

  // Refuse to overwrite a default stream that was already assigned.
  PADDLE_ENFORCE_EQ(allocator->GetDefaultStream(),
                    nullptr,
                    platform::errors::Unavailable(
                        "The default stream for "
                        "StreamSafeCustomDeviceAllocator(%p) in %s has been "
                        "set to %p, not allow to change it to %p.",
                        allocator.get(),
                        place,
                        allocator->GetDefaultStream(),
                        stream));

  allocator->SetDefaultStream(stream);

  VLOG(8) << "Set default stream to " << stream
          << " for StreamSafeCustomDeviceAllocator(" << allocator.get()
          << ") in " << place;
}
#endif #endif
private: private:
...@@ -1073,6 +1130,24 @@ class AllocatorFacadePrivate { ...@@ -1073,6 +1130,24 @@ class AllocatorFacadePrivate {
allow_free_idle_chunk); allow_free_idle_chunk);
} }
// Wraps every custom-place entry of allocators_ in a
// StreamSafeCustomDeviceAllocator created with a null default stream.
// The wrapper replaces the original entry in allocators_ and is also stored in
// default_stream_safe_custom_device_allocators_ under the same place.
void WrapStreamSafeCustomDeviceAllocatorForDefault() {
  for (auto& entry : allocators_) {
    const auto& place = entry.first;
    if (!platform::is_custom_place(place)) {
      continue;  // Entries for other place kinds are left as they are.
    }

    // The previous allocator becomes the underlying allocator of the wrapper.
    auto wrapped = std::make_shared<StreamSafeCustomDeviceAllocator>(
        entry.second,
        place,
        /* default_stream = */
        nullptr);
    entry.second = wrapped;
    default_stream_safe_custom_device_allocators_[place] = wrapped;

    VLOG(8) << "WrapStreamSafeCustomDeviceAllocatorForDefault for " << place
            << ", allocator address = " << entry.second.get();
  }
}
void InitAutoGrowthCustomDeviceAllocator(platform::CustomPlace p, void InitAutoGrowthCustomDeviceAllocator(platform::CustomPlace p,
phi::stream::stream_t stream) { phi::stream::stream_t stream) {
auto chunk_size = FLAGS_auto_growth_chunk_size_in_mb << 20; auto chunk_size = FLAGS_auto_growth_chunk_size_in_mb << 20;
...@@ -1486,13 +1561,25 @@ void AllocatorFacade::RemoveMemoryPoolOfCUDAGraph(int64_t id) { ...@@ -1486,13 +1561,25 @@ void AllocatorFacade::RemoveMemoryPoolOfCUDAGraph(int64_t id) {
const std::shared_ptr<Allocator>& AllocatorFacade::GetAllocator( const std::shared_ptr<Allocator>& AllocatorFacade::GetAllocator(
const platform::Place& place, phi::stream::stream_t stream) { const platform::Place& place, phi::stream::stream_t stream) {
AllocatorFacadePrivate* m = GetPrivate(); AllocatorFacadePrivate* m = GetPrivate();
if (!FLAGS_use_stream_safe_cuda_allocator) { if (!m->IsStreamSafeCUDAAllocatorUsed()) {
return m->GetAllocator(place, return m->GetAllocator(place,
stream, stream,
/*create_if_not_found=*/true); /*create_if_not_found=*/true);
} }
return m->GetAllocator(place, /* A non-zero num to choose allocator_ */ 1); return m->GetAllocator(place, /* A non-zero num to choose allocator_ */ 1);
} }
// Delegates to RecordStream() on the private implementation returned by
// GetPrivate(), passing `allocation` and `stream` through unchanged.
void AllocatorFacade::RecordStream(std::shared_ptr<phi::Allocation> allocation,
                                   phi::stream::stream_t stream) {
  AllocatorFacadePrivate* m = GetPrivate();
  m->RecordStream(allocation, stream);
}
// Sets the default stream for `place` on m_, but only when m_ reports that the
// stream-safe allocator is in use; otherwise the call does nothing.
void AllocatorFacade::SetDefaultStream(const platform::CustomPlace& place,
                                       phi::stream::stream_t stream) {
  if (!m_->IsStreamSafeCUDAAllocatorUsed()) {
    return;
  }
  m_->SetDefaultStream(place, stream);
}
#endif #endif
UNUSED static std::shared_ptr<NaiveBestFitAllocator> unused_obj = UNUSED static std::shared_ptr<NaiveBestFitAllocator> unused_obj =
......
...@@ -99,6 +99,10 @@ class AllocatorFacade { ...@@ -99,6 +99,10 @@ class AllocatorFacade {
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
const std::shared_ptr<Allocator>& GetAllocator(const platform::Place& place, const std::shared_ptr<Allocator>& GetAllocator(const platform::Place& place,
phi::stream::stream_t stream); phi::stream::stream_t stream);
// Forwards `allocation` and `stream` to RecordStream() on the private
// implementation.
void RecordStream(std::shared_ptr<phi::Allocation> allocation,
phi::stream::stream_t stream);
// Sets `stream` as the default stream for `place` when the stream-safe
// allocator is in use; does nothing otherwise.
void SetDefaultStream(const platform::CustomPlace& place,
phi::stream::stream_t stream);
#endif #endif
// TODO(yy): Allocate a Copy-On-Write allocation? // TODO(yy): Allocate a Copy-On-Write allocation?
private: private:
......
...@@ -55,7 +55,6 @@ void StreamSafeCustomDeviceAllocation::RecordStream( ...@@ -55,7 +55,6 @@ void StreamSafeCustomDeviceAllocation::RecordStream(
} }
void StreamSafeCustomDeviceAllocation::MarkAsWillBeFreed() { void StreamSafeCustomDeviceAllocation::MarkAsWillBeFreed() {
std::call_once(once_flag_, [this] { phi::DeviceManager::SetDevice(place_); });
std::lock_guard<SpinLock> lock_guard(outstanding_event_map_lock_); std::lock_guard<SpinLock> lock_guard(outstanding_event_map_lock_);
if (!will_be_freed_) { if (!will_be_freed_) {
will_be_freed_ = false; will_be_freed_ = false;
...@@ -63,6 +62,8 @@ void StreamSafeCustomDeviceAllocation::MarkAsWillBeFreed() { ...@@ -63,6 +62,8 @@ void StreamSafeCustomDeviceAllocation::MarkAsWillBeFreed() {
if (phi::DeviceManager::HasDeviceType(place_.GetDeviceType()) && if (phi::DeviceManager::HasDeviceType(place_.GetDeviceType()) &&
outstanding_event_map_.find(owning_stream_) == outstanding_event_map_.find(owning_stream_) ==
outstanding_event_map_.end()) { outstanding_event_map_.end()) {
std::call_once(once_flag_,
[this] { phi::DeviceManager::SetDevice(place_); });
outstanding_event_map_[owning_stream_].Init(place_); outstanding_event_map_[owning_stream_].Init(place_);
VLOG(9) << "Create a new event " VLOG(9) << "Create a new event "
<< outstanding_event_map_[owning_stream_].raw_event(); << outstanding_event_map_[owning_stream_].raw_event();
...@@ -76,11 +77,11 @@ void StreamSafeCustomDeviceAllocation::MarkAsWillBeFreed() { ...@@ -76,11 +77,11 @@ void StreamSafeCustomDeviceAllocation::MarkAsWillBeFreed() {
} }
bool StreamSafeCustomDeviceAllocation::CanBeFreed() { bool StreamSafeCustomDeviceAllocation::CanBeFreed() {
std::call_once(once_flag_, [this] { phi::DeviceManager::SetDevice(place_); });
std::lock_guard<SpinLock> lock_guard(outstanding_event_map_lock_); std::lock_guard<SpinLock> lock_guard(outstanding_event_map_lock_);
if (!phi::DeviceManager::HasDeviceType(place_.GetDeviceType())) { if (!phi::DeviceManager::HasDeviceType(place_.GetDeviceType())) {
return true; return true;
} }
std::call_once(once_flag_, [this] { phi::DeviceManager::SetDevice(place_); });
for (auto it = outstanding_event_map_.begin(); for (auto it = outstanding_event_map_.begin();
it != outstanding_event_map_.end(); it != outstanding_event_map_.end();
++it) { ++it) {
...@@ -102,6 +103,11 @@ phi::stream::stream_t StreamSafeCustomDeviceAllocation::GetOwningStream() ...@@ -102,6 +103,11 @@ phi::stream::stream_t StreamSafeCustomDeviceAllocation::GetOwningStream()
return owning_stream_; return owning_stream_;
} }
// Replaces the stream stored in owning_stream_ with `s`.
void StreamSafeCustomDeviceAllocation::SetOwningStream(
    phi::stream::stream_t s) {
  this->owning_stream_ = s;
}
StreamSafeCustomDeviceAllocator::StreamSafeCustomDeviceAllocator( StreamSafeCustomDeviceAllocator::StreamSafeCustomDeviceAllocator(
std::shared_ptr<Allocator> underlying_allocator, std::shared_ptr<Allocator> underlying_allocator,
platform::CustomPlace place, platform::CustomPlace place,
...@@ -172,6 +178,9 @@ void StreamSafeCustomDeviceAllocator::FreeImpl(phi::Allocation* allocation) { ...@@ -172,6 +178,9 @@ void StreamSafeCustomDeviceAllocator::FreeImpl(phi::Allocation* allocation) {
static_cast<StreamSafeCustomDeviceAllocation*>(allocation); static_cast<StreamSafeCustomDeviceAllocation*>(allocation);
VLOG(8) << "Try free allocation " << stream_safe_cuda_allocation->ptr(); VLOG(8) << "Try free allocation " << stream_safe_cuda_allocation->ptr();
if (!stream_safe_cuda_allocation->GetOwningStream()) {
stream_safe_cuda_allocation->SetOwningStream(default_stream_);
}
stream_safe_cuda_allocation->MarkAsWillBeFreed(); stream_safe_cuda_allocation->MarkAsWillBeFreed();
if (stream_safe_cuda_allocation->CanBeFreed()) { if (stream_safe_cuda_allocation->CanBeFreed()) {
VLOG(9) << "Directly delete allocation"; VLOG(9) << "Directly delete allocation";
......
...@@ -39,6 +39,7 @@ class StreamSafeCustomDeviceAllocation : public Allocation { ...@@ -39,6 +39,7 @@ class StreamSafeCustomDeviceAllocation : public Allocation {
bool CanBeFreed(); bool CanBeFreed();
void MarkAsWillBeFreed(); void MarkAsWillBeFreed();
phi::stream::stream_t GetOwningStream() const; phi::stream::stream_t GetOwningStream() const;
// Overwrites the owning stream of this allocation with `s`.
void SetOwningStream(phi::stream::stream_t s);
private: private:
thread_local static std::once_flag once_flag_; thread_local static std::once_flag once_flag_;
......
...@@ -111,6 +111,15 @@ inline std::unique_ptr<DeviceContext> CreateDeviceContext( ...@@ -111,6 +111,15 @@ inline std::unique_ptr<DeviceContext> CreateDeviceContext(
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
} else if (p.GetType() == phi::AllocationType::CUSTOM) { } else if (p.GetType() == phi::AllocationType::CUSTOM) {
auto* custom_ctx = dynamic_cast<phi::CustomContext*>(dev_ctx); auto* custom_ctx = dynamic_cast<phi::CustomContext*>(dev_ctx);
PADDLE_ENFORCE_NOT_NULL(
custom_ctx,
phi::errors::InvalidArgument(
"Failed to dynamic_cast dev_ctx into phi::CustomContext."));
if (!disable_setting_default_stream_for_allocator) {
instance.SetDefaultStream(CustomPlace(p.GetDeviceType(), p.GetDeviceId()),
custom_ctx->stream());
}
dev_ctx->SetAllocator(instance.GetAllocator(p, custom_ctx->stream()).get()); dev_ctx->SetAllocator(instance.GetAllocator(p, custom_ctx->stream()).get());
dev_ctx->SetGenerator(phi::DefaultCustomDeviceGenerator(p).get()); dev_ctx->SetGenerator(phi::DefaultCustomDeviceGenerator(p).get());
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册