diff --git a/paddle/fluid/memory/allocation/retry_allocator.cc b/paddle/fluid/memory/allocation/retry_allocator.cc index ae54ac13ac62975c567646f13925c9e05990423d..9a4ff2f51d08713b425f2a21c3287b71a1857327 100644 --- a/paddle/fluid/memory/allocation/retry_allocator.cc +++ b/paddle/fluid/memory/allocation/retry_allocator.cc @@ -20,67 +20,67 @@ namespace allocation { RetryAllocation::~RetryAllocation() { auto allocator = retry_allocator_.lock(); - { - // release allocation first - if (UNLIKELY(allocator == nullptr)) return; - allocator->underlying_allocator_->Free(underlying_allocation_.release()); - } - - { - // notify all waited allocators - std::lock_guard lock(allocator->mutex_); - allocator->cv_.notify_all(); - } + // Allocator is destroyed before allocation. Should not happened usually. + if (UNLIKELY(allocator == nullptr)) return; + allocator->FreeUnderlyingAllocation(std::move(underlying_allocation_)); } bool RetryAllocator::IsAllocThreadSafe() const { return true; } std::shared_ptr RetryAllocator::AllocateShared( size_t size, Allocator::Attr attr) { - return std::shared_ptr(Allocate(size, attr)); + return std::shared_ptr(AllocateImpl(size, attr)); } std::unique_ptr RetryAllocator::Allocate(size_t size, Allocator::Attr attr) { + return std::unique_ptr(AllocateImpl(size, attr)); +} + +Allocation* RetryAllocator::AllocateImpl(size_t size, Allocator::Attr attr) { auto alloc_func = [&, this]() { return new RetryAllocation(underlying_allocator_->Allocate(size, attr), this->shared_from_this()); }; - // In fact, we can unify the code of allocation success and failure // But it would add lock even when allocation success at the first time - std::unique_ptr ret; try { - ret.reset(alloc_func()); - } catch (BadAlloc &) { + return alloc_func(); + } catch (BadAlloc& bad_alloc) { { // We can just write allocation retry inside the predicate function of // wait_until // But it needs to acquire the lock when executing predicate function // For better performance, we use loop here - std::exception_ptr ex; auto end_time = std::chrono::high_resolution_clock::now() + retry_time_; - std::cv_status status; - do { - { - std::unique_lock lock(mutex_); - status = cv_.wait_until(lock, end_time); - } + auto wait_until = [&, this] { + std::unique_lock lock(mutex_); + return cv_.wait_until(lock, end_time); + }; + while (wait_until() != std::cv_status::timeout) { try { - ret.reset(alloc_func()); - } catch (BadAlloc &) { - ex = std::current_exception(); + return alloc_func(); + } catch (BadAlloc& ex) { + bad_alloc = ex; } catch (...) { - std::rethrow_exception(std::current_exception()); + throw; } - } while (ret == nullptr && status != std::cv_status::timeout); + } - if (ret == nullptr) std::rethrow_exception(ex); + throw; // rethrow the original exception or throw the internal bad_alloc } } catch (...) { - std::rethrow_exception(std::current_exception()); + throw; + } +} +void RetryAllocator::FreeUnderlyingAllocation( + std::unique_ptr&& allocation) { + underlying_allocator_->Free(allocation.get()); + { + // notify all waited allocators, they can try to allocate memory after free. + std::lock_guard lock(mutex_); + cv_.notify_all(); } - return ret; } } // namespace allocation diff --git a/paddle/fluid/memory/allocation/retry_allocator.h b/paddle/fluid/memory/allocation/retry_allocator.h index ef7945e750295458d1995d668220751d6b5d94f7..25461e5423a20425d9d191b5edb68063bb8a796c 100644 --- a/paddle/fluid/memory/allocation/retry_allocator.h +++ b/paddle/fluid/memory/allocation/retry_allocator.h @@ -35,7 +35,7 @@ class RetryAllocation : public Allocation { underlying_allocation_(std::move(underlying_allocation)), retry_allocator_(retry_allocator) {} - ~RetryAllocation(); + ~RetryAllocation() final; private: std::unique_ptr underlying_allocation_; @@ -61,13 +61,17 @@ class RetryAllocator : public ManagedAllocator, bool IsAllocThreadSafe() const override; - std::unique_ptr Allocate( - size_t size, Allocator::Attr attr = kDefault) override; + std::unique_ptr Allocate(size_t size, + Allocator::Attr attr) override; - std::shared_ptr AllocateShared( - size_t size, Allocator::Attr attr = kDefault) override; + std::shared_ptr AllocateShared(size_t size, + Allocator::Attr attr) override; + + void FreeUnderlyingAllocation(std::unique_ptr&& allocation); private: + Allocation* AllocateImpl(size_t size, Allocator::Attr attr); + void EnforceCheck() { PADDLE_ENFORCE_NOT_NULL( underlying_allocator_.get(),