From e5c4cf614046565d5ca27494385c9332a55a03c4 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 15 Nov 2018 16:11:09 +0800 Subject: [PATCH] Polish allocation Clean allocation->Deleter test=develop --- .../memory/allocation/aligned_allocator.h | 7 +-- ...ocation.h => allocation_with_underlying.h} | 4 +- paddle/fluid/memory/allocation/allocator.cc | 24 +++++----- paddle/fluid/memory/allocation/allocator.h | 32 ++++++------- .../memory/allocation/allocator_facade.cc | 18 +++---- .../allocation/auto_increment_allocator.cc | 48 +++++++++---------- .../allocation/auto_increment_allocator.h | 6 ++- .../memory/allocation/best_fit_allocator.h | 2 +- .../memory/allocation/buffered_allocator.cc | 8 ++-- .../memory/allocation/buffered_allocator.h | 2 +- .../allocation/buffered_allocator_test.cc | 2 +- .../allocation/conditional_allocator.cc | 19 ++++---- .../memory/allocation/conditional_allocator.h | 5 +- .../fluid/memory/allocation/cpu_allocator.h | 2 +- .../fluid/memory/allocation/cuda_allocator.h | 2 +- .../memory/allocation/locked_allocator.cc | 6 +-- .../memory/allocation/locked_allocator.h | 2 +- .../memory/allocation/pinned_allocator.h | 2 +- .../memory/allocation/retry_allocator.cc | 7 ++- .../fluid/memory/allocation/retry_allocator.h | 2 +- .../memory/allocation/zero_size_allocator.cc | 14 +++--- .../memory/allocation/zero_size_allocator.h | 4 +- 22 files changed, 111 insertions(+), 107 deletions(-) rename paddle/fluid/memory/allocation/{underlying_manual_allocation.h => allocation_with_underlying.h} (89%) diff --git a/paddle/fluid/memory/allocation/aligned_allocator.h b/paddle/fluid/memory/allocation/aligned_allocator.h index 0818bdc68a2..fc1a8e9247b 100644 --- a/paddle/fluid/memory/allocation/aligned_allocator.h +++ b/paddle/fluid/memory/allocation/aligned_allocator.h @@ -86,11 +86,12 @@ template class AlignedAllocator : public ThinAlignedAllocator { public: using ThinAlignedAllocator::ThinAlignedAllocator; - AllocationPtr Allocate(size_t size, Attr attr) override { + + protected: + Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override { auto raw_allocation = underlying_allocator_->Allocate(size + kAlignment, attr); - return AllocationPtr( - new AlignedAllocation(std::move(raw_allocation), size)); + return new AlignedAllocation(std::move(raw_allocation), size); } }; diff --git a/paddle/fluid/memory/allocation/underlying_manual_allocation.h b/paddle/fluid/memory/allocation/allocation_with_underlying.h similarity index 89% rename from paddle/fluid/memory/allocation/underlying_manual_allocation.h rename to paddle/fluid/memory/allocation/allocation_with_underlying.h index c02dff74475..69f78667d7d 100644 --- a/paddle/fluid/memory/allocation/underlying_manual_allocation.h +++ b/paddle/fluid/memory/allocation/allocation_with_underlying.h @@ -20,9 +20,9 @@ namespace paddle { namespace memory { namespace allocation { -class UnderlyingManualAllocation : public Allocation { +class AllocationWithUnderlying : public Allocation { public: - explicit UnderlyingManualAllocation(AllocationPtr allocation) + explicit AllocationWithUnderlying(AllocationPtr allocation) : Allocation(allocation->ptr(), allocation->size(), allocation->place()), allocation_(std::move(allocation)) {} AllocationPtr allocation_; diff --git a/paddle/fluid/memory/allocation/allocator.cc b/paddle/fluid/memory/allocation/allocator.cc index 7593b6776cc..41b4234de54 100644 --- a/paddle/fluid/memory/allocation/allocator.cc +++ b/paddle/fluid/memory/allocation/allocator.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/memory/allocation/allocator.h" + #include namespace paddle { @@ -24,23 +25,20 @@ Allocator::~Allocator() {} bool Allocator::IsAllocThreadSafe() const { return false; } +AllocationPtr Allocator::Allocate(size_t size, Allocator::Attr attr) { + auto ptr = AllocateImpl(size, attr); + ptr->set_allocator(this); + return AllocationPtr(ptr); +} + +void Allocator::Free(Allocation* allocation) { delete allocation; } + const char* BadAlloc::what() const noexcept { return msg_.c_str(); } -AllocationPtr MannualFreeAllocator::Allocate(size_t size, - Allocator::Attr attr) { - auto allocation = AllocateImpl(size, attr); - allocation->Deleter = - std::bind1st(std::mem_fn(&MannualFreeAllocator::Free), this); - return AllocationPtr(allocation); -} void AllocationDeleter::operator()(Allocation* allocation) const { - if (allocation->Deleter) { - auto deleter = std::move(allocation->Deleter); - deleter(allocation); - } else { - delete allocation; - } + allocation->allocator()->Free(allocation); } + } // namespace allocation } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/allocation/allocator.h b/paddle/fluid/memory/allocation/allocator.h index 90b55f19e83..f2b6f438c38 100644 --- a/paddle/fluid/memory/allocation/allocator.h +++ b/paddle/fluid/memory/allocation/allocator.h @@ -32,10 +32,12 @@ class BadAlloc : public std::exception { }; class Allocation; -struct AllocationDeleter { +class AllocationDeleter { + public: void operator()(Allocation* allocation) const; }; +class Allocator; // Allocation is the object holding the actually pointer. Use // `Allocation::ptr()` will returns the pointer that allocated. // @@ -45,7 +47,7 @@ struct AllocationDeleter { class Allocation { public: Allocation(void* ptr, size_t size, platform::Place place) - : ptr_(ptr), size_(size), place_(place) {} + : allocator_(nullptr), ptr_(ptr), size_(size), place_(place) {} Allocation(const Allocation& o) = delete; Allocation& operator=(const Allocation& o) = delete; @@ -70,11 +72,14 @@ class Allocation { const platform::Place& place() const { return place_; } - virtual ~Allocation(); + Allocator* allocator() { return allocator_; } - std::function Deleter; + void set_allocator(Allocator* allocator) { allocator_ = allocator; } + + virtual ~Allocation(); private: + Allocator* allocator_; void* ptr_; size_t size_; platform::Place place_; @@ -121,25 +126,18 @@ class Allocator { virtual ~Allocator(); - // Allocate an allocation. Note the return allocation might need to be freed - // manually if the Allocator is an `UnmanagedAllocator`. - virtual AllocationPtr Allocate(size_t size, - Allocator::Attr attr = kDefault) = 0; + // Allocate an allocation. + AllocationPtr Allocate(size_t size, Allocator::Attr attr = kDefault); // True if the `Allocate` is thread safe. virtual bool IsAllocThreadSafe() const; -}; - -// User need to invoke `Free` or `FreeUniquePtr` manually if allocated by -// a manally managed allocator. -class MannualFreeAllocator : public Allocator { - public: - AllocationPtr Allocate(size_t size, Attr attr) final; protected: - virtual void Free(Allocation* allocation) = 0; + virtual void Free(Allocation* allocation); virtual Allocation* AllocateImpl(size_t size, Allocator::Attr attr) = 0; - friend class MannualFreeAllocation; + + private: + friend class AllocationDeleter; }; } // namespace allocation diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 597742690cd..ec8a64a1d1f 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -49,12 +49,13 @@ class CPUManagedAllocator : public Allocator { public: CPUManagedAllocator() : normal_allocator_(new CPUAllocator()) {} - AllocationPtr Allocate(size_t size, Attr attr) override { - return normal_allocator_->Allocate(size, attr); - } - bool IsAllocThreadSafe() const override { return true; } + protected: + Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override { + return normal_allocator_->Allocate(size, attr).release(); + } + private: std::shared_ptr normal_allocator_; }; @@ -103,10 +104,6 @@ class ChunkedManagedAllocator : public Allocator { raw_allocator_.reset(); } - AllocationPtr Allocate(size_t size, Attr attr) override { - return default_allocator_->Allocate(size, attr); - } - std::shared_ptr BestFitAllocatorCreator() { chunks_.emplace_back(raw_allocator_->Allocate(max_chunk_size_)); auto* allocation = chunks_.back().get(); @@ -128,6 +125,11 @@ class ChunkedManagedAllocator : public Allocator { bool IsAllocThreadSafe() const override { return true; } + protected: + Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override { + return default_allocator_->Allocate(size, attr).release(); + } + protected: size_t max_chunk_size_; int64_t retry_time_; diff --git a/paddle/fluid/memory/allocation/auto_increment_allocator.cc b/paddle/fluid/memory/allocation/auto_increment_allocator.cc index 399b3c02867..c4785d20786 100644 --- a/paddle/fluid/memory/allocation/auto_increment_allocator.cc +++ b/paddle/fluid/memory/allocation/auto_increment_allocator.cc @@ -17,9 +17,25 @@ namespace paddle { namespace memory { namespace allocation { +bool AutoIncrementAllocator::IsAllocThreadSafe() const { return true; } -AllocationPtr AutoIncrementAllocator::Allocate(size_t size, - Allocator::Attr attr) { +std::shared_ptr AutoIncrementAllocator::CreateNewAllocator() { + std::lock_guard guard(mtx_); + auto old_size = allocator_num_.load(); + PADDLE_ENFORCE_LT(old_size, underlying_allocators_.size(), + "Allocator number exceeds capacity %d", + underlying_allocators_.size()); + underlying_allocators_[old_size] = creator_(); + prev_success_allocator_ = old_size; + ++allocator_num_; + PADDLE_ENFORCE( + underlying_allocators_[old_size]->IsAllocThreadSafe(), + "the underlying allocator must be thread safe. This is a program " + "bug."); + return underlying_allocators_[old_size]; +} +Allocation *AutoIncrementAllocator::AllocateImpl(size_t size, + Allocator::Attr attr) { auto cur = prev_success_allocator_.load(); size_t retry_count = allocator_num_.load(); size_t allocator_num = retry_count; @@ -27,8 +43,8 @@ AllocationPtr AutoIncrementAllocator::Allocate(size_t size, try { auto res = underlying_allocators_[cur]->Allocate(size, attr); prev_success_allocator_ = cur; - return res; - } catch (BadAlloc&) { + return res.release(); + } catch (BadAlloc &) { if (++cur >= allocator_num) { cur = 0; } @@ -47,32 +63,14 @@ AllocationPtr AutoIncrementAllocator::Allocate(size_t size, try { auto ret = underlying_allocators_[cur]->Allocate(size, attr); prev_success_allocator_ = cur; - return ret; - } catch (BadAlloc&) { + return ret.release(); + } catch (BadAlloc &) { } catch (...) { throw; } } // No suitable allocator - return CreateNewAllocator()->Allocate(size, attr); -} - -bool AutoIncrementAllocator::IsAllocThreadSafe() const { return true; } - -std::shared_ptr AutoIncrementAllocator::CreateNewAllocator() { - std::lock_guard guard(mtx_); - auto old_size = allocator_num_.load(); - PADDLE_ENFORCE_LT(old_size, underlying_allocators_.size(), - "Allocator number exceeds capacity %d", - underlying_allocators_.size()); - underlying_allocators_[old_size] = creator_(); - prev_success_allocator_ = old_size; - ++allocator_num_; - PADDLE_ENFORCE( - underlying_allocators_[old_size]->IsAllocThreadSafe(), - "the underlying allocator must be thread safe. This is a program " - "bug."); - return underlying_allocators_[old_size]; + return CreateNewAllocator()->Allocate(size, attr).release(); } } // namespace allocation diff --git a/paddle/fluid/memory/allocation/auto_increment_allocator.h b/paddle/fluid/memory/allocation/auto_increment_allocator.h index f0a46af9264..382588f17a9 100644 --- a/paddle/fluid/memory/allocation/auto_increment_allocator.h +++ b/paddle/fluid/memory/allocation/auto_increment_allocator.h @@ -54,13 +54,15 @@ class AutoIncrementAllocator : public Allocator { explicit AutoIncrementAllocator(AllocatorCreator&& creator, size_t capacity) : creator_(std::move(creator)), underlying_allocators_(capacity) {} - AllocationPtr Allocate(size_t size, Attr attr) override; - bool IsAllocThreadSafe() const override; private: std::shared_ptr CreateNewAllocator(); + protected: + Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override; + + private: AllocatorCreator creator_; std::vector underlying_allocators_; diff --git a/paddle/fluid/memory/allocation/best_fit_allocator.h b/paddle/fluid/memory/allocation/best_fit_allocator.h index 69a8260c861..141fb55d6c9 100644 --- a/paddle/fluid/memory/allocation/best_fit_allocator.h +++ b/paddle/fluid/memory/allocation/best_fit_allocator.h @@ -98,7 +98,7 @@ class BestFitAllocation : public Allocation { // // To free an allocation, it will set the chunk of allocation to free and merge // the prev-chunk and the next-chunk when possible. -class BestFitAllocator : public MannualFreeAllocator { +class BestFitAllocator : public Allocator { public: explicit BestFitAllocator(Allocation* allocation); diff --git a/paddle/fluid/memory/allocation/buffered_allocator.cc b/paddle/fluid/memory/allocation/buffered_allocator.cc index 5b6855b1254..4b57ea86694 100644 --- a/paddle/fluid/memory/allocation/buffered_allocator.cc +++ b/paddle/fluid/memory/allocation/buffered_allocator.cc @@ -16,7 +16,7 @@ #include #include #include -#include "paddle/fluid/memory/allocation/underlying_manual_allocation.h" +#include "paddle/fluid/memory/allocation/allocation_with_underlying.h" namespace paddle { namespace memory { @@ -60,16 +60,16 @@ Allocation *BufferedAllocator::AllocateImpl(size_t size, Allocator::Attr attr) { if (it != allocations_.end() && it->first < size * 2) { AllocationPtr result(std::move(it->second)); allocations_.erase(it); - return new UnderlyingManualAllocation(std::move(result)); + return new AllocationWithUnderlying(std::move(result)); } } try { - return new UnderlyingManualAllocation( + return new AllocationWithUnderlying( underlying_allocator_->Allocate(size, attr)); } catch (BadAlloc &) { FreeCache(size); - return new UnderlyingManualAllocation( + return new AllocationWithUnderlying( underlying_allocator_->Allocate(size, attr)); } } diff --git a/paddle/fluid/memory/allocation/buffered_allocator.h b/paddle/fluid/memory/allocation/buffered_allocator.h index c1db1b76be3..54b0dd244a6 100644 --- a/paddle/fluid/memory/allocation/buffered_allocator.h +++ b/paddle/fluid/memory/allocation/buffered_allocator.h @@ -29,7 +29,7 @@ namespace allocation { // memory allocation and reuse memory. // BufferedAllocator provides the same thread-safety level as // underlying_allocator_ -class BufferedAllocator : public MannualFreeAllocator { +class BufferedAllocator : public Allocator { public: explicit BufferedAllocator(std::unique_ptr &&allocator); diff --git a/paddle/fluid/memory/allocation/buffered_allocator_test.cc b/paddle/fluid/memory/allocation/buffered_allocator_test.cc index f1a57ea2e98..41ebb9dbeaf 100644 --- a/paddle/fluid/memory/allocation/buffered_allocator_test.cc +++ b/paddle/fluid/memory/allocation/buffered_allocator_test.cc @@ -52,7 +52,7 @@ class StubAllocation : public Allocation { using Allocation::Allocation; }; -class StubAllocator : public MannualFreeAllocator { +class StubAllocator : public Allocator { public: void ResetCounter() { construct_count_ = 0; diff --git a/paddle/fluid/memory/allocation/conditional_allocator.cc b/paddle/fluid/memory/allocation/conditional_allocator.cc index 2a7fd691972..96a818e03e5 100644 --- a/paddle/fluid/memory/allocation/conditional_allocator.cc +++ b/paddle/fluid/memory/allocation/conditional_allocator.cc @@ -24,15 +24,6 @@ ConditionalAllocator& ConditionalAllocator::AddAllocator( underlying_allocators_.emplace_back(std::move(func), std::move(allocator)); return *this; } -AllocationPtr ConditionalAllocator::Allocate(size_t size, - Allocator::Attr attr) { - for (auto& pair : underlying_allocators_) { - if (pair.first(size, attr)) { - return pair.second->Allocate(size, attr); - } - } - throw BadAlloc("No suitable allocator"); -} bool ConditionalAllocator::IsAllocThreadSafe() const { return std::all_of(underlying_allocators_.begin(), @@ -42,6 +33,16 @@ bool ConditionalAllocator::IsAllocThreadSafe() const { }); } +Allocation* ConditionalAllocator::AllocateImpl(size_t size, + Allocator::Attr attr) { + for (auto& pair : underlying_allocators_) { + if (pair.first(size, attr)) { + return pair.second->Allocate(size, attr).release(); + } + } + throw BadAlloc("No suitable allocator"); +} + } // namespace allocation } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/allocation/conditional_allocator.h b/paddle/fluid/memory/allocation/conditional_allocator.h index 7716fc98650..7140e1b3082 100644 --- a/paddle/fluid/memory/allocation/conditional_allocator.h +++ b/paddle/fluid/memory/allocation/conditional_allocator.h @@ -45,10 +45,13 @@ class ConditionalAllocator : public Allocator { ConditionalAllocator& AddAllocator(std::function func, std::shared_ptr allocator); - AllocationPtr Allocate(size_t size, Attr attr) override; + // AllocationPtr Allocate(size_t size, Attr attr) override; bool IsAllocThreadSafe() const override; + protected: + Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override; + private: using AllocatorWithCond = std::pair, std::shared_ptr>; diff --git a/paddle/fluid/memory/allocation/cpu_allocator.h b/paddle/fluid/memory/allocation/cpu_allocator.h index 1b16b22a31e..9e0044c47ae 100644 --- a/paddle/fluid/memory/allocation/cpu_allocator.h +++ b/paddle/fluid/memory/allocation/cpu_allocator.h @@ -31,7 +31,7 @@ class CPUAllocation : public Allocation { CPUAllocation(void* ptr, size_t size); }; -class CPUAllocator : public MannualFreeAllocator { +class CPUAllocator : public Allocator { public: constexpr static size_t kAlignment = 64u; bool IsAllocThreadSafe() const override; diff --git a/paddle/fluid/memory/allocation/cuda_allocator.h b/paddle/fluid/memory/allocation/cuda_allocator.h index 7e1360d13c4..63726f5820b 100644 --- a/paddle/fluid/memory/allocation/cuda_allocator.h +++ b/paddle/fluid/memory/allocation/cuda_allocator.h @@ -27,7 +27,7 @@ class CUDAAllocation : public Allocation { using Allocation::Allocation; }; -class CUDAAllocator : public MannualFreeAllocator { +class CUDAAllocator : public Allocator { public: explicit CUDAAllocator(const platform::CUDAPlace& place) : place_(place) {} explicit CUDAAllocator(const platform::Place& place) diff --git a/paddle/fluid/memory/allocation/locked_allocator.cc b/paddle/fluid/memory/allocation/locked_allocator.cc index ab4d6f4d121..835f6527c8a 100644 --- a/paddle/fluid/memory/allocation/locked_allocator.cc +++ b/paddle/fluid/memory/allocation/locked_allocator.cc @@ -14,7 +14,7 @@ #include "paddle/fluid/memory/allocation/locked_allocator.h" #include // NOLINT -#include "paddle/fluid/memory/allocation/underlying_manual_allocation.h" +#include "paddle/fluid/memory/allocation/allocation_with_underlying.h" #include "paddle/fluid/platform/lock_guard_ptr.h" namespace paddle { namespace memory { @@ -33,14 +33,14 @@ LockedAllocator::LockedAllocator( void LockedAllocator::Free(Allocation *allocation) { { platform::LockGuardPtr guard(mtx_); - reinterpret_cast(allocation) + reinterpret_cast(allocation) ->allocation_.reset(); // Destroy inner allocation } delete allocation; } Allocation *LockedAllocator::AllocateImpl(size_t size, Allocator::Attr attr) { platform::LockGuardPtr guard(mtx_); - return new UnderlyingManualAllocation( + return new AllocationWithUnderlying( underlying_allocator_->Allocate(size, attr)); } } // namespace allocation diff --git a/paddle/fluid/memory/allocation/locked_allocator.h b/paddle/fluid/memory/allocation/locked_allocator.h index 1675aa57402..4967b9bb8d3 100644 --- a/paddle/fluid/memory/allocation/locked_allocator.h +++ b/paddle/fluid/memory/allocation/locked_allocator.h @@ -22,7 +22,7 @@ namespace memory { namespace allocation { // A allocator to make underlying allocator thread safe. -class LockedAllocator : public MannualFreeAllocator { +class LockedAllocator : public Allocator { public: explicit LockedAllocator(std::unique_ptr &&underlying_allocator); bool IsAllocThreadSafe() const override; diff --git a/paddle/fluid/memory/allocation/pinned_allocator.h b/paddle/fluid/memory/allocation/pinned_allocator.h index 9a6677b5a82..26d12dd91c7 100644 --- a/paddle/fluid/memory/allocation/pinned_allocator.h +++ b/paddle/fluid/memory/allocation/pinned_allocator.h @@ -26,7 +26,7 @@ class CPUPinnedAllocation : public Allocation { : Allocation(ptr, size, platform::CUDAPinnedPlace()) {} }; -class CPUPinnedAllocator : public MannualFreeAllocator { +class CPUPinnedAllocator : public Allocator { public: bool IsAllocThreadSafe() const override; diff --git a/paddle/fluid/memory/allocation/retry_allocator.cc b/paddle/fluid/memory/allocation/retry_allocator.cc index 829434e5302..981705051b4 100644 --- a/paddle/fluid/memory/allocation/retry_allocator.cc +++ b/paddle/fluid/memory/allocation/retry_allocator.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/fluid/memory/allocation/retry_allocator.h" -#include "paddle/fluid/memory/allocation/underlying_manual_allocation.h" +#include "paddle/fluid/memory/allocation/allocation_with_underlying.h" namespace paddle { namespace memory { namespace allocation { @@ -24,8 +24,7 @@ bool RetryAllocator::IsAllocThreadSafe() const { void RetryAllocator::Free(Allocation* allocation) { // Delete underlying allocation first. - reinterpret_cast(allocation) - ->allocation_.reset(); + reinterpret_cast(allocation)->allocation_.reset(); { // notify all waited allocators, they can try to allocate memory after free. std::lock_guard lock(mutex_); @@ -36,7 +35,7 @@ void RetryAllocator::Free(Allocation* allocation) { Allocation* RetryAllocator::AllocateImpl(size_t size, Allocator::Attr attr) { auto alloc_func = [&, this]() { - return new UnderlyingManualAllocation( + return new AllocationWithUnderlying( underlying_allocator_->Allocate(size, attr)); }; // In fact, we can unify the code of allocation success and failure diff --git a/paddle/fluid/memory/allocation/retry_allocator.h b/paddle/fluid/memory/allocation/retry_allocator.h index 537c2bd1a70..5efcac8b108 100644 --- a/paddle/fluid/memory/allocation/retry_allocator.h +++ b/paddle/fluid/memory/allocation/retry_allocator.h @@ -26,7 +26,7 @@ namespace allocation { class RetryAllocator; -class RetryAllocator : public MannualFreeAllocator { +class RetryAllocator : public Allocator { public: RetryAllocator(std::unique_ptr&& allocator, size_t retry_ms) : underlying_allocator_(std::move(allocator)), retry_time_(retry_ms) { diff --git a/paddle/fluid/memory/allocation/zero_size_allocator.cc b/paddle/fluid/memory/allocation/zero_size_allocator.cc index 52ef0de20fb..cb2df1a0298 100644 --- a/paddle/fluid/memory/allocation/zero_size_allocator.cc +++ b/paddle/fluid/memory/allocation/zero_size_allocator.cc @@ -18,17 +18,17 @@ namespace paddle { namespace memory { namespace allocation { -AllocationPtr ZeroSizeAllocator::Allocate(size_t size, Allocator::Attr attr) { +bool ZeroSizeAllocator::IsAllocThreadSafe() const { + return underlying_allocator_->IsAllocThreadSafe(); +} + +Allocation *ZeroSizeAllocator::AllocateImpl(size_t size, Allocator::Attr attr) { if (size == 0) { - return AllocationPtr(new ZeroSizeAllocation(place_)); + return new ZeroSizeAllocation(place_); } else { - return underlying_allocator_->Allocate(size, attr); + return underlying_allocator_->Allocate(size, attr).release(); } } - -bool ZeroSizeAllocator::IsAllocThreadSafe() const { - return underlying_allocator_->IsAllocThreadSafe(); -} } // namespace allocation } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/allocation/zero_size_allocator.h b/paddle/fluid/memory/allocation/zero_size_allocator.h index d6e2d30d996..6b80245a34e 100644 --- a/paddle/fluid/memory/allocation/zero_size_allocator.h +++ b/paddle/fluid/memory/allocation/zero_size_allocator.h @@ -34,10 +34,12 @@ class ZeroSizeAllocator : public Allocator { ZeroSizeAllocator(std::shared_ptr underlying_allocator, const platform::Place& p) : underlying_allocator_(std::move(underlying_allocator)), place_(p) {} - AllocationPtr Allocate(size_t size, Attr attr) override; bool IsAllocThreadSafe() const override; + protected: + Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override; + private: std::shared_ptr underlying_allocator_; const platform::Place& place_; -- GitLab