diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 1d8866c43ed23c6fb1a31efbc756717f1eb947ce..7887045e9dc06c1b15a4900346b4da39d9e5a06d 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -213,6 +213,10 @@ class AllocatorFacadePrivate { platform::CustomPlace(dev_type, dev_id)); } } + if (FLAGS_use_stream_safe_cuda_allocator) { + WrapStreamSafeCustomDeviceAllocatorForDefault(); + is_stream_safe_cuda_allocator_used_ = true; + } #endif break; } @@ -275,6 +279,10 @@ class AllocatorFacadePrivate { platform::CustomPlace(dev_type, dev_id), allow_free_idle_chunk); } } + if (FLAGS_use_stream_safe_cuda_allocator) { + WrapStreamSafeCustomDeviceAllocatorForDefault(); + is_stream_safe_cuda_allocator_used_ = true; + } #endif break; } @@ -607,6 +615,55 @@ class AllocatorFacadePrivate { return custom_device_allocators_[place][stream]; } } + + const std::shared_ptr + GetDefaultStreamSafeCustomDeviceAllocator( + const platform::CustomPlace& place) const { + const auto iter = default_stream_safe_custom_device_allocators_.find(place); + PADDLE_ENFORCE_NE( + iter, + default_stream_safe_custom_device_allocators_.end(), + platform::errors::NotFound( + "No StreamSafeCustomDeviceAllocator found for the place, %s", + place)); + return iter->second; + } + + void RecordStream(std::shared_ptr allocation, + phi::stream::stream_t stream) { + std::shared_ptr + stream_safe_custom_device_allocation = + std::dynamic_pointer_cast( + allocation); + if (stream_safe_custom_device_allocation != nullptr) { + stream_safe_custom_device_allocation->RecordStream(stream); + } else { + VLOG(6) << "RecordStream for a non-StreamSafeCustomDeviceAllocation"; + } + } + + void SetDefaultStream(const platform::CustomPlace& place, + phi::stream::stream_t stream) { + const std::shared_ptr& allocator = + GetDefaultStreamSafeCustomDeviceAllocator(place); + + PADDLE_ENFORCE_EQ(allocator->GetDefaultStream(), + nullptr, + platform::errors::Unavailable( + "The default stream for " + "StreamSafeCustomDeviceAllocator(%p) in %s has been " + "set to %p, not allow to change it to %p.", + allocator.get(), + place, + allocator->GetDefaultStream(), + stream)); + + allocator->SetDefaultStream(stream); + VLOG(8) << "Set default stream to " << stream + << " for StreamSafeCustomDeviceAllocator(" << allocator.get() + << ") in " << place; + } + #endif private: @@ -1073,6 +1130,24 @@ class AllocatorFacadePrivate { allow_free_idle_chunk); } + void WrapStreamSafeCustomDeviceAllocatorForDefault() { + for (auto& pair : allocators_) { + auto& place = pair.first; + if (platform::is_custom_place(place)) { + std::shared_ptr&& allocator = + std::make_shared( + pair.second, + place, + /* default_stream = */ + nullptr); + pair.second = allocator; + default_stream_safe_custom_device_allocators_[place] = allocator; + VLOG(8) << "WrapStreamSafeCustomDeviceAllocatorForDefault for " << place + << ", allocator address = " << pair.second.get(); + } + } + } + void InitAutoGrowthCustomDeviceAllocator(platform::CustomPlace p, phi::stream::stream_t stream) { auto chunk_size = FLAGS_auto_growth_chunk_size_in_mb << 20; @@ -1486,13 +1561,25 @@ void AllocatorFacade::RemoveMemoryPoolOfCUDAGraph(int64_t id) { const std::shared_ptr& AllocatorFacade::GetAllocator( const platform::Place& place, phi::stream::stream_t stream) { AllocatorFacadePrivate* m = GetPrivate(); - if (!FLAGS_use_stream_safe_cuda_allocator) { + if (!m->IsStreamSafeCUDAAllocatorUsed()) { return m->GetAllocator(place, stream, /*create_if_not_found=*/true); } return m->GetAllocator(place, /* A non-zero num to choose allocator_ */ 1); } + +void AllocatorFacade::RecordStream(std::shared_ptr allocation, + phi::stream::stream_t stream) { + GetPrivate()->RecordStream(allocation, stream); +} + +void AllocatorFacade::SetDefaultStream(const platform::CustomPlace& place, + phi::stream::stream_t stream) { + if (m_->IsStreamSafeCUDAAllocatorUsed()) { + m_->SetDefaultStream(place, stream); + } +} #endif UNUSED static std::shared_ptr unused_obj = diff --git a/paddle/fluid/memory/allocation/allocator_facade.h b/paddle/fluid/memory/allocation/allocator_facade.h index 986222a1b03c4691964b44ee66e4f32e4df1ef4b..20be4c11bfe2d53d7fdfd34639b27886a990dddc 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.h +++ b/paddle/fluid/memory/allocation/allocator_facade.h @@ -99,6 +99,10 @@ class AllocatorFacade { #ifdef PADDLE_WITH_CUSTOM_DEVICE const std::shared_ptr& GetAllocator(const platform::Place& place, phi::stream::stream_t stream); + void RecordStream(std::shared_ptr allocation, + phi::stream::stream_t stream); + void SetDefaultStream(const platform::CustomPlace& place, + phi::stream::stream_t stream); #endif // TODO(yy): Allocate a Copy-On-Write allocation? private: diff --git a/paddle/fluid/memory/allocation/stream_safe_custom_device_allocator.cc b/paddle/fluid/memory/allocation/stream_safe_custom_device_allocator.cc index 59f1ec86a56c36286e44ed2723c6763f18e047cd..0658dacc98621a0e2599c3ac7a9607388670d5a1 100644 --- a/paddle/fluid/memory/allocation/stream_safe_custom_device_allocator.cc +++ b/paddle/fluid/memory/allocation/stream_safe_custom_device_allocator.cc @@ -55,7 +55,6 @@ void StreamSafeCustomDeviceAllocation::RecordStream( } void StreamSafeCustomDeviceAllocation::MarkAsWillBeFreed() { - std::call_once(once_flag_, [this] { phi::DeviceManager::SetDevice(place_); }); std::lock_guard lock_guard(outstanding_event_map_lock_); if (!will_be_freed_) { will_be_freed_ = false; @@ -63,6 +62,8 @@ void StreamSafeCustomDeviceAllocation::MarkAsWillBeFreed() { if (phi::DeviceManager::HasDeviceType(place_.GetDeviceType()) && outstanding_event_map_.find(owning_stream_) == outstanding_event_map_.end()) { + std::call_once(once_flag_, + [this] { phi::DeviceManager::SetDevice(place_); }); outstanding_event_map_[owning_stream_].Init(place_); VLOG(9) << "Create a new event " << outstanding_event_map_[owning_stream_].raw_event(); @@ -76,11 +77,11 @@ void StreamSafeCustomDeviceAllocation::MarkAsWillBeFreed() { } bool StreamSafeCustomDeviceAllocation::CanBeFreed() { - std::call_once(once_flag_, [this] { phi::DeviceManager::SetDevice(place_); }); std::lock_guard lock_guard(outstanding_event_map_lock_); if (!phi::DeviceManager::HasDeviceType(place_.GetDeviceType())) { return true; } + std::call_once(once_flag_, [this] { phi::DeviceManager::SetDevice(place_); }); for (auto it = outstanding_event_map_.begin(); it != outstanding_event_map_.end(); ++it) { @@ -102,6 +103,11 @@ phi::stream::stream_t StreamSafeCustomDeviceAllocation::GetOwningStream() return owning_stream_; } +void StreamSafeCustomDeviceAllocation::SetOwningStream( + phi::stream::stream_t s) { + owning_stream_ = s; +} + StreamSafeCustomDeviceAllocator::StreamSafeCustomDeviceAllocator( std::shared_ptr underlying_allocator, platform::CustomPlace place, @@ -172,6 +178,9 @@ void StreamSafeCustomDeviceAllocator::FreeImpl(phi::Allocation* allocation) { static_cast(allocation); VLOG(8) << "Try free allocation " << stream_safe_cuda_allocation->ptr(); + if (!stream_safe_cuda_allocation->GetOwningStream()) { + stream_safe_cuda_allocation->SetOwningStream(default_stream_); + } stream_safe_cuda_allocation->MarkAsWillBeFreed(); if (stream_safe_cuda_allocation->CanBeFreed()) { VLOG(9) << "Directly delete allocation"; diff --git a/paddle/fluid/memory/allocation/stream_safe_custom_device_allocator.h b/paddle/fluid/memory/allocation/stream_safe_custom_device_allocator.h index 997075e124e61fd2b431082d8fb0e278a71a4c72..a73201b76bcc91db895d965fb60a700d32c2b834 100644 --- a/paddle/fluid/memory/allocation/stream_safe_custom_device_allocator.h +++ b/paddle/fluid/memory/allocation/stream_safe_custom_device_allocator.h @@ -39,6 +39,7 @@ class StreamSafeCustomDeviceAllocation : public Allocation { bool CanBeFreed(); void MarkAsWillBeFreed(); phi::stream::stream_t GetOwningStream() const; + void SetOwningStream(phi::stream::stream_t s); private: thread_local static std::once_flag once_flag_; diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 34ab16400ae4d5fbbf515b545449ed26d5adaa59..c4f40767fd52ce379dbca164a43d4e105c5b4493 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -111,6 +111,15 @@ inline std::unique_ptr CreateDeviceContext( #ifdef PADDLE_WITH_CUSTOM_DEVICE } else if (p.GetType() == phi::AllocationType::CUSTOM) { auto* custom_ctx = dynamic_cast(dev_ctx); + PADDLE_ENFORCE_NOT_NULL( + custom_ctx, + phi::errors::InvalidArgument( + "Failed to dynamic_cast dev_ctx into phi::CustomContext.")); + + if (!disable_setting_default_stream_for_allocator) { + instance.SetDefaultStream(CustomPlace(p.GetDeviceType(), p.GetDeviceId()), + custom_ctx->stream()); + } dev_ctx->SetAllocator(instance.GetAllocator(p, custom_ctx->stream()).get()); dev_ctx->SetGenerator(phi::DefaultCustomDeviceGenerator(p).get()); #endif