未验证 提交 ac8aad59 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] add default stream safe allocator support (#56380)

上级 d7cce317
......@@ -213,6 +213,10 @@ class AllocatorFacadePrivate {
platform::CustomPlace(dev_type, dev_id));
}
}
if (FLAGS_use_stream_safe_cuda_allocator) {
WrapStreamSafeCustomDeviceAllocatorForDefault();
is_stream_safe_cuda_allocator_used_ = true;
}
#endif
break;
}
......@@ -275,6 +279,10 @@ class AllocatorFacadePrivate {
platform::CustomPlace(dev_type, dev_id), allow_free_idle_chunk);
}
}
if (FLAGS_use_stream_safe_cuda_allocator) {
WrapStreamSafeCustomDeviceAllocatorForDefault();
is_stream_safe_cuda_allocator_used_ = true;
}
#endif
break;
}
......@@ -607,6 +615,55 @@ class AllocatorFacadePrivate {
return custom_device_allocators_[place][stream];
}
}
// Returns the StreamSafeCustomDeviceAllocator stored for `place` in
// default_stream_safe_custom_device_allocators_.
//
// The map is only searched, never modified. If it holds no entry for `place`,
// the PADDLE_ENFORCE_NE check below fails with a platform::errors::NotFound
// error that names the place.
const std::shared_ptr<StreamSafeCustomDeviceAllocator>
GetDefaultStreamSafeCustomDeviceAllocator(
const platform::CustomPlace& place) const {
const auto iter = default_stream_safe_custom_device_allocators_.find(place);
// A missing entry is reported as an error rather than returned as nullptr.
PADDLE_ENFORCE_NE(
iter,
default_stream_safe_custom_device_allocators_.end(),
platform::errors::NotFound(
"No StreamSafeCustomDeviceAllocator found for the place, %s",
place));
return iter->second;
}
// Records `stream` on `allocation` when the allocation is a
// StreamSafeCustomDeviceAllocation.
//
// Any other allocation (including one for which the dynamic cast yields
// nullptr) is left untouched and only a VLOG(6) message is emitted.
void RecordStream(std::shared_ptr<phi::Allocation> allocation,
                  phi::stream::stream_t stream) {
  auto stream_safe_allocation =
      std::dynamic_pointer_cast<StreamSafeCustomDeviceAllocation>(allocation);
  if (stream_safe_allocation == nullptr) {
    // Not a stream-safe allocation: nothing to record.
    VLOG(6) << "RecordStream for a non-StreamSafeCustomDeviceAllocation";
    return;
  }
  stream_safe_allocation->RecordStream(stream);
}
// Sets `stream` as the default stream of the StreamSafeCustomDeviceAllocator
// registered for `place`.
//
// The allocator is obtained through GetDefaultStreamSafeCustomDeviceAllocator,
// which raises NotFound when nothing is registered for `place`.
// The default stream can only be set while the allocator's current default
// stream is nullptr; otherwise PADDLE_ENFORCE_EQ fails with
// platform::errors::Unavailable and the existing stream is kept.
void SetDefaultStream(const platform::CustomPlace& place,
phi::stream::stream_t stream) {
const std::shared_ptr<StreamSafeCustomDeviceAllocator>& allocator =
GetDefaultStreamSafeCustomDeviceAllocator(place);
// Refuse to replace a default stream that has already been set.
PADDLE_ENFORCE_EQ(allocator->GetDefaultStream(),
nullptr,
platform::errors::Unavailable(
"The default stream for "
"StreamSafeCustomDeviceAllocator(%p) in %s has been "
"set to %p, not allow to change it to %p.",
allocator.get(),
place,
allocator->GetDefaultStream(),
stream));
allocator->SetDefaultStream(stream);
VLOG(8) << "Set default stream to " << stream
<< " for StreamSafeCustomDeviceAllocator(" << allocator.get()
<< ") in " << place;
}
#endif
private:
......@@ -1073,6 +1130,24 @@ class AllocatorFacadePrivate {
allow_free_idle_chunk);
}
// Wraps every allocator in allocators_ whose place is a custom place with a
// StreamSafeCustomDeviceAllocator.
//
// For each such place the wrapper is constructed around the allocator that is
// currently registered, with a null default stream. The wrapper then replaces
// the entry in allocators_ and is also stored in
// default_stream_safe_custom_device_allocators_ under the same place.
// Entries for non-custom places are left unchanged.
void WrapStreamSafeCustomDeviceAllocatorForDefault() {
  for (auto& entry : allocators_) {
    const auto& device_place = entry.first;
    if (!platform::is_custom_place(device_place)) {
      continue;
    }
    auto wrapped = std::make_shared<StreamSafeCustomDeviceAllocator>(
        entry.second,
        device_place,
        /* default_stream = */
        nullptr);
    entry.second = wrapped;
    default_stream_safe_custom_device_allocators_[device_place] = wrapped;
    VLOG(8) << "WrapStreamSafeCustomDeviceAllocatorForDefault for "
            << device_place
            << ", allocator address = " << entry.second.get();
  }
}
void InitAutoGrowthCustomDeviceAllocator(platform::CustomPlace p,
phi::stream::stream_t stream) {
auto chunk_size = FLAGS_auto_growth_chunk_size_in_mb << 20;
......@@ -1486,13 +1561,25 @@ void AllocatorFacade::RemoveMemoryPoolOfCUDAGraph(int64_t id) {
// Returns the allocator to use for `place` and `stream`.
//
// When the private facade reports that the stream-safe allocator is not in
// use, the allocator is looked up per (place, stream) and created if it does
// not exist yet. Otherwise `stream` is ignored and the lookup is done by place
// with a non-zero size argument, which selects the regular allocators_ entry.
//
// Only the IsStreamSafeCUDAAllocatorUsed() check is kept. The stale
// FLAGS_use_stream_safe_cuda_allocator condition that preceded it opened an
// `if` that was never closed, leaving the function's braces unbalanced.
const std::shared_ptr<Allocator>& AllocatorFacade::GetAllocator(
    const platform::Place& place, phi::stream::stream_t stream) {
  AllocatorFacadePrivate* m = GetPrivate();
  if (!m->IsStreamSafeCUDAAllocatorUsed()) {
    return m->GetAllocator(place,
                           stream,
                           /*create_if_not_found=*/true);
  }
  return m->GetAllocator(place, /* A non-zero num to choose allocator_ */ 1);
}
// Forwards `allocation` and `stream` to RecordStream on the private facade
// returned by GetPrivate().
void AllocatorFacade::RecordStream(std::shared_ptr<phi::Allocation> allocation,
                                   phi::stream::stream_t stream) {
  AllocatorFacadePrivate* m = GetPrivate();
  m->RecordStream(allocation, stream);
}
// Sets `stream` as the default stream for `place` on the private facade m_.
//
// Does nothing when m_ reports that the stream-safe allocator is not in use.
void AllocatorFacade::SetDefaultStream(const platform::CustomPlace& place,
                                       phi::stream::stream_t stream) {
  if (!m_->IsStreamSafeCUDAAllocatorUsed()) {
    return;
  }
  m_->SetDefaultStream(place, stream);
}
#endif
UNUSED static std::shared_ptr<NaiveBestFitAllocator> unused_obj =
......
......@@ -99,6 +99,10 @@ class AllocatorFacade {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
// Returns a reference to the allocator to use for `place` with `stream`.
const std::shared_ptr<Allocator>& GetAllocator(const platform::Place& place,
phi::stream::stream_t stream);
// Forwards `allocation` and `stream` to the private facade's RecordStream.
void RecordStream(std::shared_ptr<phi::Allocation> allocation,
phi::stream::stream_t stream);
// Sets `stream` as the default stream for `place`; a no-op when the
// stream-safe allocator is not in use.
void SetDefaultStream(const platform::CustomPlace& place,
phi::stream::stream_t stream);
#endif
// TODO(yy): Allocate a Copy-On-Write allocation?
private:
......
......@@ -55,7 +55,6 @@ void StreamSafeCustomDeviceAllocation::RecordStream(
}
void StreamSafeCustomDeviceAllocation::MarkAsWillBeFreed() {
std::call_once(once_flag_, [this] { phi::DeviceManager::SetDevice(place_); });
std::lock_guard<SpinLock> lock_guard(outstanding_event_map_lock_);
if (!will_be_freed_) {
will_be_freed_ = false;
......@@ -63,6 +62,8 @@ void StreamSafeCustomDeviceAllocation::MarkAsWillBeFreed() {
if (phi::DeviceManager::HasDeviceType(place_.GetDeviceType()) &&
outstanding_event_map_.find(owning_stream_) ==
outstanding_event_map_.end()) {
std::call_once(once_flag_,
[this] { phi::DeviceManager::SetDevice(place_); });
outstanding_event_map_[owning_stream_].Init(place_);
VLOG(9) << "Create a new event "
<< outstanding_event_map_[owning_stream_].raw_event();
......@@ -76,11 +77,11 @@ void StreamSafeCustomDeviceAllocation::MarkAsWillBeFreed() {
}
bool StreamSafeCustomDeviceAllocation::CanBeFreed() {
std::call_once(once_flag_, [this] { phi::DeviceManager::SetDevice(place_); });
std::lock_guard<SpinLock> lock_guard(outstanding_event_map_lock_);
if (!phi::DeviceManager::HasDeviceType(place_.GetDeviceType())) {
return true;
}
std::call_once(once_flag_, [this] { phi::DeviceManager::SetDevice(place_); });
for (auto it = outstanding_event_map_.begin();
it != outstanding_event_map_.end();
++it) {
......@@ -102,6 +103,11 @@ phi::stream::stream_t StreamSafeCustomDeviceAllocation::GetOwningStream()
return owning_stream_;
}
// Replaces the stream recorded in owning_stream_ with `s`.
void StreamSafeCustomDeviceAllocation::SetOwningStream(
    phi::stream::stream_t s) {
  this->owning_stream_ = s;
}
StreamSafeCustomDeviceAllocator::StreamSafeCustomDeviceAllocator(
std::shared_ptr<Allocator> underlying_allocator,
platform::CustomPlace place,
......@@ -172,6 +178,9 @@ void StreamSafeCustomDeviceAllocator::FreeImpl(phi::Allocation* allocation) {
static_cast<StreamSafeCustomDeviceAllocation*>(allocation);
VLOG(8) << "Try free allocation " << stream_safe_cuda_allocation->ptr();
if (!stream_safe_cuda_allocation->GetOwningStream()) {
stream_safe_cuda_allocation->SetOwningStream(default_stream_);
}
stream_safe_cuda_allocation->MarkAsWillBeFreed();
if (stream_safe_cuda_allocation->CanBeFreed()) {
VLOG(9) << "Directly delete allocation";
......
......@@ -39,6 +39,7 @@ class StreamSafeCustomDeviceAllocation : public Allocation {
// Returns true when the allocation's device type is no longer registered with
// phi::DeviceManager; otherwise the result depends on the entries in
// outstanding_event_map_ (NOTE(review): definition only partly visible here).
bool CanBeFreed();
// Called before the allocation is freed; may create an event for
// owning_stream_ if none is recorded yet (NOTE(review): definition only partly
// visible here — confirm).
void MarkAsWillBeFreed();
// Returns owning_stream_.
phi::stream::stream_t GetOwningStream() const;
// Overwrites owning_stream_ with `s`.
void SetOwningStream(phi::stream::stream_t s);
private:
thread_local static std::once_flag once_flag_;
......
......@@ -111,6 +111,15 @@ inline std::unique_ptr<DeviceContext> CreateDeviceContext(
#ifdef PADDLE_WITH_CUSTOM_DEVICE
} else if (p.GetType() == phi::AllocationType::CUSTOM) {
auto* custom_ctx = dynamic_cast<phi::CustomContext*>(dev_ctx);
PADDLE_ENFORCE_NOT_NULL(
custom_ctx,
phi::errors::InvalidArgument(
"Failed to dynamic_cast dev_ctx into phi::CustomContext."));
if (!disable_setting_default_stream_for_allocator) {
instance.SetDefaultStream(CustomPlace(p.GetDeviceType(), p.GetDeviceId()),
custom_ctx->stream());
}
dev_ctx->SetAllocator(instance.GetAllocator(p, custom_ctx->stream()).get());
dev_ctx->SetGenerator(phi::DefaultCustomDeviceGenerator(p).get());
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册