Commit e5c4cf61 authored by Yu Yang

Polish allocation

Clean allocation->Deleter

test=develop
Parent 0d6718fc
@@ -86,11 +86,12 @@ template <size_t kAlignment>
 class AlignedAllocator : public ThinAlignedAllocator {
  public:
   using ThinAlignedAllocator::ThinAlignedAllocator;
-  AllocationPtr Allocate(size_t size, Attr attr) override {
+
+ protected:
+  Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override {
     auto raw_allocation =
         underlying_allocator_->Allocate(size + kAlignment, attr);
-    return AllocationPtr(
-        new AlignedAllocation<kAlignment>(std::move(raw_allocation), size));
+    return new AlignedAllocation<kAlignment>(std::move(raw_allocation), size);
   }
 };
...
@@ -20,9 +20,9 @@ namespace paddle {
 namespace memory {
 namespace allocation {
-class UnderlyingManualAllocation : public Allocation {
+class AllocationWithUnderlying : public Allocation {
  public:
-  explicit UnderlyingManualAllocation(AllocationPtr allocation)
+  explicit AllocationWithUnderlying(AllocationPtr allocation)
       : Allocation(allocation->ptr(), allocation->size(), allocation->place()),
         allocation_(std::move(allocation)) {}

   AllocationPtr allocation_;
...
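The class renamed above wraps an existing allocation and re-exposes its pointer and size while owning it, so the wrapper's lifetime controls the wrapped block. Below is a minimal, self-contained sketch of that idea, not Paddle's actual code: the constructor is simplified (no place()), and HeapAllocation is invented purely for the example.

// Sketch only: simplified stand-ins for Paddle's Allocation types.
#include <cstddef>
#include <iostream>
#include <memory>

class Allocation {
 public:
  Allocation(void* ptr, std::size_t size) : ptr_(ptr), size_(size) {}
  virtual ~Allocation() = default;
  void* ptr() const { return ptr_; }
  std::size_t size() const { return size_; }

 private:
  void* ptr_;
  std::size_t size_;
};

// Stand-in for a "real" allocation that owns heap memory.
class HeapAllocation : public Allocation {
 public:
  explicit HeapAllocation(std::size_t size)
      : Allocation(::operator new(size), size) {}
  ~HeapAllocation() override { ::operator delete(ptr()); }
};

// The wrapper idea: expose the same ptr()/size() as the wrapped allocation
// and keep the wrapped object alive by owning it.
class AllocationWithUnderlying : public Allocation {
 public:
  explicit AllocationWithUnderlying(std::unique_ptr<Allocation> allocation)
      : Allocation(allocation->ptr(), allocation->size()),
        allocation_(std::move(allocation)) {}

  std::unique_ptr<Allocation> allocation_;
};

int main() {
  AllocationWithUnderlying wrapped(std::make_unique<HeapAllocation>(64));
  std::cout << "wrapped " << wrapped.size() << " bytes at " << wrapped.ptr()
            << "\n";
  // Destroying `wrapped` releases allocation_, which frees the heap block.
}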
@@ -13,6 +13,7 @@
 // limitations under the License.

 #include "paddle/fluid/memory/allocation/allocator.h"
 #include <functional>

 namespace paddle {
@@ -24,23 +25,20 @@ Allocator::~Allocator() {}
 bool Allocator::IsAllocThreadSafe() const { return false; }

+AllocationPtr Allocator::Allocate(size_t size, Allocator::Attr attr) {
+  auto ptr = AllocateImpl(size, attr);
+  ptr->set_allocator(this);
+  return AllocationPtr(ptr);
+}
+
+void Allocator::Free(Allocation* allocation) { delete allocation; }
+
 const char* BadAlloc::what() const noexcept { return msg_.c_str(); }

-AllocationPtr MannualFreeAllocator::Allocate(size_t size,
-                                             Allocator::Attr attr) {
-  auto allocation = AllocateImpl(size, attr);
-  allocation->Deleter =
-      std::bind1st(std::mem_fn(&MannualFreeAllocator::Free), this);
-  return AllocationPtr(allocation);
-}
-
 void AllocationDeleter::operator()(Allocation* allocation) const {
-  if (allocation->Deleter) {
-    auto deleter = std::move(allocation->Deleter);
-    deleter(allocation);
-  } else {
-    delete allocation;
-  }
+  allocation->allocator()->Free(allocation);
 }

 } // namespace allocation
 } // namespace memory
 } // namespace paddle
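The hunk above is the core of the change: Allocator::Allocate becomes a non-virtual entry point that tags each Allocation with its owning allocator, and AllocationDeleter routes destruction back through Allocator::Free instead of a per-allocation std::function. The sketch below is a minimal, self-contained illustration of that pattern, not Paddle's actual headers; the signatures are simplified and MallocAllocator is made up for the example.

// Sketch only: simplified Allocator/Allocation pair plus a toy subclass.
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <memory>

class Allocator;

class Allocation {
 public:
  Allocation(void* ptr, std::size_t size)
      : allocator_(nullptr), ptr_(ptr), size_(size) {}
  virtual ~Allocation() = default;
  void* ptr() const { return ptr_; }
  std::size_t size() const { return size_; }
  Allocator* allocator() const { return allocator_; }
  void set_allocator(Allocator* allocator) { allocator_ = allocator; }

 private:
  Allocator* allocator_;
  void* ptr_;
  std::size_t size_;
};

class AllocationDeleter {
 public:
  void operator()(Allocation* allocation) const;
};

using AllocationPtr = std::unique_ptr<Allocation, AllocationDeleter>;

class Allocator {
 public:
  virtual ~Allocator() = default;

  // Non-virtual entry point: tag the allocation with its owner so the deleter
  // knows which Free() to call later.
  AllocationPtr Allocate(std::size_t size) {
    Allocation* a = AllocateImpl(size);
    a->set_allocator(this);
    return AllocationPtr(a);
  }

 protected:
  virtual Allocation* AllocateImpl(std::size_t size) = 0;
  virtual void Free(Allocation* allocation) { delete allocation; }

 private:
  friend class AllocationDeleter;
};

void AllocationDeleter::operator()(Allocation* allocation) const {
  // Route destruction back through the allocator that created the block.
  allocation->allocator()->Free(allocation);
}

// Toy subclass: only AllocateImpl/Free are overridden, Allocate is inherited.
class MallocAllocator : public Allocator {
 protected:
  Allocation* AllocateImpl(std::size_t size) override {
    return new Allocation(std::malloc(size), size);
  }
  void Free(Allocation* allocation) override {
    std::free(allocation->ptr());
    delete allocation;
  }
};

int main() {
  MallocAllocator allocator;
  AllocationPtr p = allocator.Allocate(128);
  std::cout << "allocated " << p->size() << " bytes\n";
  // When p goes out of scope, AllocationDeleter calls MallocAllocator::Free.
}

With this shape, subclasses only implement AllocateImpl and Free, and every AllocationPtr frees itself through the allocator that produced it.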
@@ -32,10 +32,12 @@ class BadAlloc : public std::exception {
 };

 class Allocation;
-struct AllocationDeleter {
+class AllocationDeleter {
+ public:
   void operator()(Allocation* allocation) const;
 };

+class Allocator;
+
 // Allocation is the object holding the actually pointer. Use
 // `Allocation::ptr()` will returns the pointer that allocated.
 //
@@ -45,7 +47,7 @@ struct AllocationDeleter {
 class Allocation {
  public:
   Allocation(void* ptr, size_t size, platform::Place place)
-      : ptr_(ptr), size_(size), place_(place) {}
+      : allocator_(nullptr), ptr_(ptr), size_(size), place_(place) {}

   Allocation(const Allocation& o) = delete;
   Allocation& operator=(const Allocation& o) = delete;
@@ -70,11 +72,14 @@ class Allocation {
   const platform::Place& place() const { return place_; }

-  virtual ~Allocation();
-  std::function<void(Allocation*)> Deleter;
+  Allocator* allocator() { return allocator_; }
+
+  void set_allocator(Allocator* allocator) { allocator_ = allocator; }
+
+  virtual ~Allocation();

  private:
+  Allocator* allocator_;
   void* ptr_;
   size_t size_;
   platform::Place place_;
@@ -121,25 +126,18 @@ class Allocator {
   virtual ~Allocator();

-  // Allocate an allocation. Note the return allocation might need to be freed
-  // manually if the Allocator is an `UnmanagedAllocator`.
-  virtual AllocationPtr Allocate(size_t size,
-                                 Allocator::Attr attr = kDefault) = 0;
+  // Allocate an allocation.
+  AllocationPtr Allocate(size_t size, Allocator::Attr attr = kDefault);

   // True if the `Allocate` is thread safe.
   virtual bool IsAllocThreadSafe() const;
-};
-
-// User need to invoke `Free` or `FreeUniquePtr` manually if allocated by
-// a manally managed allocator.
-class MannualFreeAllocator : public Allocator {
- public:
-  AllocationPtr Allocate(size_t size, Attr attr) final;

  protected:
-  virtual void Free(Allocation* allocation) = 0;
+  virtual void Free(Allocation* allocation);
   virtual Allocation* AllocateImpl(size_t size, Allocator::Attr attr) = 0;
-  friend class MannualFreeAllocation;
+
+ private:
+  friend class AllocationDeleter;
 };

 } // namespace allocation
...
@@ -49,12 +49,13 @@ class CPUManagedAllocator : public Allocator {
  public:
   CPUManagedAllocator() : normal_allocator_(new CPUAllocator()) {}

-  AllocationPtr Allocate(size_t size, Attr attr) override {
-    return normal_allocator_->Allocate(size, attr);
-  }
-
   bool IsAllocThreadSafe() const override { return true; }

+ protected:
+  Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override {
+    return normal_allocator_->Allocate(size, attr).release();
+  }
+
  private:
   std::shared_ptr<Allocator> normal_allocator_;
 };
@@ -103,10 +104,6 @@ class ChunkedManagedAllocator : public Allocator {
     raw_allocator_.reset();
   }

-  AllocationPtr Allocate(size_t size, Attr attr) override {
-    return default_allocator_->Allocate(size, attr);
-  }
-
   std::shared_ptr<Allocator> BestFitAllocatorCreator() {
     chunks_.emplace_back(raw_allocator_->Allocate(max_chunk_size_));
     auto* allocation = chunks_.back().get();
@@ -128,6 +125,11 @@ class ChunkedManagedAllocator : public Allocator {
   bool IsAllocThreadSafe() const override { return true; }

+ protected:
+  Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override {
+    return default_allocator_->Allocate(size, attr).release();
+  }
+
  protected:
   size_t max_chunk_size_;
   int64_t retry_time_;
...
@@ -17,9 +17,25 @@
 namespace paddle {
 namespace memory {
 namespace allocation {

-AllocationPtr AutoIncrementAllocator::Allocate(size_t size,
-                                               Allocator::Attr attr) {
+bool AutoIncrementAllocator::IsAllocThreadSafe() const { return true; }
+
+std::shared_ptr<Allocator> AutoIncrementAllocator::CreateNewAllocator() {
+  std::lock_guard<std::mutex> guard(mtx_);
+  auto old_size = allocator_num_.load();
+  PADDLE_ENFORCE_LT(old_size, underlying_allocators_.size(),
+                    "Allocator number exceeds capacity %d",
+                    underlying_allocators_.size());
+  underlying_allocators_[old_size] = creator_();
+  prev_success_allocator_ = old_size;
+  ++allocator_num_;
+  PADDLE_ENFORCE(
+      underlying_allocators_[old_size]->IsAllocThreadSafe(),
+      "the underlying allocator must be thread safe. This is a program "
+      "bug.");
+  return underlying_allocators_[old_size];
+}
+
+Allocation *AutoIncrementAllocator::AllocateImpl(size_t size,
+                                                 Allocator::Attr attr) {
   auto cur = prev_success_allocator_.load();
   size_t retry_count = allocator_num_.load();
   size_t allocator_num = retry_count;
@@ -27,8 +43,8 @@ AllocationPtr AutoIncrementAllocator::Allocate(size_t size,
     try {
       auto res = underlying_allocators_[cur]->Allocate(size, attr);
       prev_success_allocator_ = cur;
-      return res;
-    } catch (BadAlloc&) {
+      return res.release();
+    } catch (BadAlloc &) {
       if (++cur >= allocator_num) {
         cur = 0;
       }
@@ -47,32 +63,14 @@ AllocationPtr AutoIncrementAllocator::Allocate(size_t size,
     try {
       auto ret = underlying_allocators_[cur]->Allocate(size, attr);
       prev_success_allocator_ = cur;
-      return ret;
-    } catch (BadAlloc&) {
+      return ret.release();
+    } catch (BadAlloc &) {
     } catch (...) {
       throw;
     }
   }
   // No suitable allocator
-  return CreateNewAllocator()->Allocate(size, attr);
-}
-
-bool AutoIncrementAllocator::IsAllocThreadSafe() const { return true; }
-
-std::shared_ptr<Allocator> AutoIncrementAllocator::CreateNewAllocator() {
-  std::lock_guard<std::mutex> guard(mtx_);
-  auto old_size = allocator_num_.load();
-  PADDLE_ENFORCE_LT(old_size, underlying_allocators_.size(),
-                    "Allocator number exceeds capacity %d",
-                    underlying_allocators_.size());
-  underlying_allocators_[old_size] = creator_();
-  prev_success_allocator_ = old_size;
-  ++allocator_num_;
-  PADDLE_ENFORCE(
-      underlying_allocators_[old_size]->IsAllocThreadSafe(),
-      "the underlying allocator must be thread safe. This is a program "
-      "bug.");
-  return underlying_allocators_[old_size];
+  return CreateNewAllocator()->Allocate(size, attr).release();
 }

 } // namespace allocation
...
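The strategy moved into AllocateImpl above works like this: try the previously successful allocator first, walk through the existing ones on BadAlloc, and only create a new allocator when all of them fail. The following is a simplified, single-threaded sketch of that strategy, not Paddle's code: it uses std::bad_alloc instead of BadAlloc, a made-up FixedPoolAllocator, and drops the thread-safety guarantees of the real class.

// Sketch only: round-robin retry over a growing set of allocators.
#include <atomic>
#include <cstddef>
#include <iostream>
#include <memory>
#include <mutex>
#include <new>
#include <vector>

// Toy allocator with a fixed budget; throws std::bad_alloc when exhausted.
class FixedPoolAllocator {
 public:
  explicit FixedPoolAllocator(std::size_t capacity) : capacity_(capacity) {}

  std::unique_ptr<char[]> Allocate(std::size_t size) {
    if (used_ + size > capacity_) throw std::bad_alloc();
    used_ += size;
    return std::unique_ptr<char[]>(new char[size]);
  }

 private:
  std::size_t capacity_;
  std::size_t used_ = 0;
};

// Single-threaded sketch: the real AutoIncrementAllocator pre-sizes its
// allocator vector to `capacity` so slots never move while other threads read
// them; that detail is dropped here to keep the example short.
class AutoIncrementAllocator {
 public:
  explicit AutoIncrementAllocator(std::size_t chunk) : chunk_(chunk) {}

  std::unique_ptr<char[]> Allocate(std::size_t size) {
    std::size_t num = allocator_num_.load();
    std::size_t cur = prev_success_.load();
    for (std::size_t i = 0; i < num; ++i, cur = (cur + 1) % num) {
      try {
        auto res = allocators_[cur]->Allocate(size);
        prev_success_ = cur;
        return res;
      } catch (std::bad_alloc&) {
        // This allocator is full; try the next one.
      }
    }
    // No existing allocator could serve the request; grow the pool.
    return CreateNewAllocator()->Allocate(size);
  }

 private:
  FixedPoolAllocator* CreateNewAllocator() {
    std::lock_guard<std::mutex> guard(mtx_);
    allocators_.emplace_back(new FixedPoolAllocator(chunk_));
    prev_success_ = allocators_.size() - 1;
    ++allocator_num_;
    return allocators_.back().get();
  }

  std::size_t chunk_;
  std::vector<std::unique_ptr<FixedPoolAllocator>> allocators_;
  std::atomic<std::size_t> allocator_num_{0};
  std::atomic<std::size_t> prev_success_{0};
  std::mutex mtx_;
};

int main() {
  AutoIncrementAllocator allocator(/*chunk=*/256);
  auto a = allocator.Allocate(200);  // creates the first pool
  auto b = allocator.Allocate(200);  // first pool is full, creates a second
  std::cout << "served both requests\n";
}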
@@ -54,13 +54,15 @@ class AutoIncrementAllocator : public Allocator {
   explicit AutoIncrementAllocator(AllocatorCreator&& creator, size_t capacity)
       : creator_(std::move(creator)), underlying_allocators_(capacity) {}

-  AllocationPtr Allocate(size_t size, Attr attr) override;
   bool IsAllocThreadSafe() const override;

  private:
   std::shared_ptr<Allocator> CreateNewAllocator();

+ protected:
+  Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override;
+
+ private:
   AllocatorCreator creator_;
   std::vector<AllocatorCreator::result_type> underlying_allocators_;
...
@@ -98,7 +98,7 @@ class BestFitAllocation : public Allocation {
 //
 // To free an allocation, it will set the chunk of allocation to free and merge
 // the prev-chunk and the next-chunk when possible.
-class BestFitAllocator : public MannualFreeAllocator {
+class BestFitAllocator : public Allocator {
  public:
   explicit BestFitAllocator(Allocation* allocation);
...
@@ -16,7 +16,7 @@
 #include <algorithm>
 #include <limits>
 #include <utility>
-#include "paddle/fluid/memory/allocation/underlying_manual_allocation.h"
+#include "paddle/fluid/memory/allocation/allocation_with_underlying.h"

 namespace paddle {
 namespace memory {
@@ -60,16 +60,16 @@ Allocation *BufferedAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
     if (it != allocations_.end() && it->first < size * 2) {
       AllocationPtr result(std::move(it->second));
       allocations_.erase(it);
-      return new UnderlyingManualAllocation(std::move(result));
+      return new AllocationWithUnderlying(std::move(result));
     }
   }

   try {
-    return new UnderlyingManualAllocation(
+    return new AllocationWithUnderlying(
         underlying_allocator_->Allocate(size, attr));
   } catch (BadAlloc &) {
     FreeCache(size);
-    return new UnderlyingManualAllocation(
+    return new AllocationWithUnderlying(
         underlying_allocator_->Allocate(size, attr));
   }
 }
...
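The reuse path in the hunk above looks for a cached block whose size is at least the request but less than twice the request, and on allocation failure flushes the cache once before retrying. Below is a small stand-alone sketch of that policy, not Paddle's class: it uses plain new char[] blocks and a std::multimap cache, has no locking, and all names are invented for the example.

// Sketch only: size-keyed reuse cache in front of a raw allocation.
#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <new>
#include <utility>

// A freed block is kept around, keyed by its size, for later reuse.
struct Block {
  std::unique_ptr<char[]> data;
  std::size_t size;
};

class BufferedAllocator {
 public:
  Block Allocate(std::size_t size) {
    // Reuse a cached block if it is big enough but not wastefully big.
    auto it = cache_.lower_bound(size);
    if (it != cache_.end() && it->first < size * 2) {
      Block reused = std::move(it->second);
      cache_.erase(it);
      return reused;
    }
    try {
      return Block{std::unique_ptr<char[]>(new char[size]), size};
    } catch (std::bad_alloc&) {
      // Underlying allocation failed: drop everything we buffered, retry once.
      cache_.clear();
      return Block{std::unique_ptr<char[]>(new char[size]), size};
    }
  }

  // Instead of freeing, keep the block for a later request of similar size.
  void Free(Block block) { cache_.emplace(block.size, std::move(block)); }

 private:
  std::multimap<std::size_t, Block> cache_;
};

int main() {
  BufferedAllocator allocator;
  Block a = allocator.Allocate(100);
  allocator.Free(std::move(a));
  Block b = allocator.Allocate(90);  // reuses the 100-byte block (90 <= 100 < 180)
  std::cout << "got a block of " << b.size << " bytes\n";
}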
@@ -29,7 +29,7 @@ namespace allocation {
 // memory allocation and reuse memory.
 // BufferedAllocator provides the same thread-safety level as
 // underlying_allocator_
-class BufferedAllocator : public MannualFreeAllocator {
+class BufferedAllocator : public Allocator {
  public:
   explicit BufferedAllocator(std::unique_ptr<Allocator> &&allocator);
...
@@ -52,7 +52,7 @@ class StubAllocation : public Allocation {
   using Allocation::Allocation;
 };

-class StubAllocator : public MannualFreeAllocator {
+class StubAllocator : public Allocator {
  public:
   void ResetCounter() {
     construct_count_ = 0;
...
@@ -24,15 +24,6 @@ ConditionalAllocator& ConditionalAllocator::AddAllocator(
   underlying_allocators_.emplace_back(std::move(func), std::move(allocator));
   return *this;
 }
-
-AllocationPtr ConditionalAllocator::Allocate(size_t size,
-                                             Allocator::Attr attr) {
-  for (auto& pair : underlying_allocators_) {
-    if (pair.first(size, attr)) {
-      return pair.second->Allocate(size, attr);
-    }
-  }
-  throw BadAlloc("No suitable allocator");
-}

 bool ConditionalAllocator::IsAllocThreadSafe() const {
   return std::all_of(underlying_allocators_.begin(),
@@ -42,6 +33,16 @@ bool ConditionalAllocator::IsAllocThreadSafe() const {
                      });
 }

+Allocation* ConditionalAllocator::AllocateImpl(size_t size,
+                                               Allocator::Attr attr) {
+  for (auto& pair : underlying_allocators_) {
+    if (pair.first(size, attr)) {
+      return pair.second->Allocate(size, attr).release();
+    }
+  }
+  throw BadAlloc("No suitable allocator");
+}
+
 } // namespace allocation
 } // namespace memory
 } // namespace paddle
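ConditionalAllocator dispatches on (size, attr) predicates: the first underlying allocator whose predicate matches serves the request, and a bad alloc is reported when none match. A simplified sketch of that dispatch follows; it is not Paddle's API, using std::function allocators, a made-up Attr enum, and std::bad_alloc instead of BadAlloc.

// Sketch only: predicate-guarded dispatch to a list of allocators.
#include <cstddef>
#include <functional>
#include <iostream>
#include <memory>
#include <new>
#include <utility>
#include <vector>

enum class Attr { kDefault, kTiny };

using AllocFn = std::function<std::unique_ptr<char[]>(std::size_t)>;
using Predicate = std::function<bool(std::size_t, Attr)>;

class ConditionalAllocator {
 public:
  ConditionalAllocator& AddAllocator(Predicate pred, AllocFn alloc) {
    underlying_.emplace_back(std::move(pred), std::move(alloc));
    return *this;
  }

  std::unique_ptr<char[]> Allocate(std::size_t size, Attr attr) {
    for (auto& pair : underlying_) {
      if (pair.first(size, attr)) {
        return pair.second(size);  // first matching predicate wins
      }
    }
    throw std::bad_alloc();  // no suitable allocator
  }

 private:
  std::vector<std::pair<Predicate, AllocFn>> underlying_;
};

int main() {
  ConditionalAllocator allocator;
  allocator
      .AddAllocator([](std::size_t size, Attr) { return size < 1024; },
                    [](std::size_t size) {
                      std::cout << "small path: " << size << "\n";
                      return std::unique_ptr<char[]>(new char[size]);
                    })
      .AddAllocator([](std::size_t, Attr) { return true; },
                    [](std::size_t size) {
                      std::cout << "fallback path: " << size << "\n";
                      return std::unique_ptr<char[]>(new char[size]);
                    });
  allocator.Allocate(128, Attr::kDefault);   // small path
  allocator.Allocate(4096, Attr::kDefault);  // fallback path
}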
@@ -45,10 +45,13 @@ class ConditionalAllocator : public Allocator {
   ConditionalAllocator& AddAllocator(std::function<bool(size_t, Attr)> func,
                                      std::shared_ptr<Allocator> allocator);

-  AllocationPtr Allocate(size_t size, Attr attr) override;
+  // AllocationPtr Allocate(size_t size, Attr attr) override;
   bool IsAllocThreadSafe() const override;

+ protected:
+  Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override;
+
  private:
   using AllocatorWithCond =
       std::pair<std::function<bool(size_t, Attr)>, std::shared_ptr<Allocator>>;
...
@@ -31,7 +31,7 @@ class CPUAllocation : public Allocation {
   CPUAllocation(void* ptr, size_t size);
 };

-class CPUAllocator : public MannualFreeAllocator {
+class CPUAllocator : public Allocator {
  public:
   constexpr static size_t kAlignment = 64u;
   bool IsAllocThreadSafe() const override;
...
@@ -27,7 +27,7 @@ class CUDAAllocation : public Allocation {
   using Allocation::Allocation;
 };

-class CUDAAllocator : public MannualFreeAllocator {
+class CUDAAllocator : public Allocator {
  public:
   explicit CUDAAllocator(const platform::CUDAPlace& place) : place_(place) {}
   explicit CUDAAllocator(const platform::Place& place)
...
@@ -14,7 +14,7 @@
 #include "paddle/fluid/memory/allocation/locked_allocator.h"
 #include <mutex>  // NOLINT
-#include "paddle/fluid/memory/allocation/underlying_manual_allocation.h"
+#include "paddle/fluid/memory/allocation/allocation_with_underlying.h"
 #include "paddle/fluid/platform/lock_guard_ptr.h"

 namespace paddle {
 namespace memory {
@@ -33,14 +33,14 @@ LockedAllocator::LockedAllocator(
 void LockedAllocator::Free(Allocation *allocation) {
   {
     platform::LockGuardPtr<std::mutex> guard(mtx_);
-    reinterpret_cast<UnderlyingManualAllocation *>(allocation)
+    reinterpret_cast<AllocationWithUnderlying *>(allocation)
         ->allocation_.reset();  // Destroy inner allocation
   }
   delete allocation;
 }

 Allocation *LockedAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
   platform::LockGuardPtr<std::mutex> guard(mtx_);
-  return new UnderlyingManualAllocation(
+  return new AllocationWithUnderlying(
       underlying_allocator_->Allocate(size, attr));
 }
 } // namespace allocation
...
@@ -22,7 +22,7 @@ namespace memory {
 namespace allocation {

 // A allocator to make underlying allocator thread safe.
-class LockedAllocator : public MannualFreeAllocator {
+class LockedAllocator : public Allocator {
  public:
   explicit LockedAllocator(std::unique_ptr<Allocator> &&underlying_allocator);
   bool IsAllocThreadSafe() const override;
...
@@ -26,7 +26,7 @@ class CPUPinnedAllocation : public Allocation {
       : Allocation(ptr, size, platform::CUDAPinnedPlace()) {}
 };

-class CPUPinnedAllocator : public MannualFreeAllocator {
+class CPUPinnedAllocator : public Allocator {
  public:
   bool IsAllocThreadSafe() const override;
...
@@ -13,7 +13,7 @@
 // limitations under the License.

 #include "paddle/fluid/memory/allocation/retry_allocator.h"
-#include "paddle/fluid/memory/allocation/underlying_manual_allocation.h"
+#include "paddle/fluid/memory/allocation/allocation_with_underlying.h"

 namespace paddle {
 namespace memory {
 namespace allocation {
@@ -24,8 +24,7 @@ bool RetryAllocator::IsAllocThreadSafe() const {
 void RetryAllocator::Free(Allocation* allocation) {
   // Delete underlying allocation first.
-  reinterpret_cast<UnderlyingManualAllocation*>(allocation)
-      ->allocation_.reset();
+  reinterpret_cast<AllocationWithUnderlying*>(allocation)->allocation_.reset();
   {
     // notify all waited allocators, they can try to allocate memory after free.
     std::lock_guard<std::mutex> lock(mutex_);
@@ -36,7 +35,7 @@ void RetryAllocator::Free(Allocation* allocation) {
 Allocation* RetryAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
   auto alloc_func = [&, this]() {
-    return new UnderlyingManualAllocation(
+    return new AllocationWithUnderlying(
         underlying_allocator_->Allocate(size, attr));
   };
   // In fact, we can unify the code of allocation success and failure
...
@@ -26,7 +26,7 @@ namespace allocation {
 class RetryAllocator;

-class RetryAllocator : public MannualFreeAllocator {
+class RetryAllocator : public Allocator {
  public:
   RetryAllocator(std::unique_ptr<Allocator>&& allocator, size_t retry_ms)
       : underlying_allocator_(std::move(allocator)), retry_time_(retry_ms) {
...
@@ -18,17 +18,17 @@ namespace paddle {
 namespace memory {
 namespace allocation {

-AllocationPtr ZeroSizeAllocator::Allocate(size_t size, Allocator::Attr attr) {
+bool ZeroSizeAllocator::IsAllocThreadSafe() const {
+  return underlying_allocator_->IsAllocThreadSafe();
+}
+
+Allocation *ZeroSizeAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
   if (size == 0) {
-    return AllocationPtr(new ZeroSizeAllocation(place_));
+    return new ZeroSizeAllocation(place_);
   } else {
-    return underlying_allocator_->Allocate(size, attr);
+    return underlying_allocator_->Allocate(size, attr).release();
   }
 }

-bool ZeroSizeAllocator::IsAllocThreadSafe() const {
-  return underlying_allocator_->IsAllocThreadSafe();
-}
-
 } // namespace allocation
 } // namespace memory
 } // namespace paddle
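ZeroSizeAllocator short-circuits zero-byte requests with a cheap placeholder allocation and forwards everything else to the underlying allocator. A trimmed-down sketch of that behaviour follows; it is not Paddle's code (no Place, plain unique_ptr ownership, and HeapAllocation invented for the example).

// Sketch only: zero-byte requests never reach the real allocator.
#include <cstddef>
#include <iostream>
#include <memory>

class Allocation {
 public:
  Allocation(void* ptr, std::size_t size) : ptr_(ptr), size_(size) {}
  virtual ~Allocation() = default;
  void* ptr() const { return ptr_; }
  std::size_t size() const { return size_; }

 private:
  void* ptr_;
  std::size_t size_;
};

// Placeholder for zero-byte requests: no memory is actually reserved.
class ZeroSizeAllocation : public Allocation {
 public:
  ZeroSizeAllocation() : Allocation(nullptr, 0) {}
};

class HeapAllocation : public Allocation {
 public:
  explicit HeapAllocation(std::size_t size)
      : Allocation(::operator new(size), size) {}
  ~HeapAllocation() override { ::operator delete(ptr()); }
};

std::unique_ptr<Allocation> Allocate(std::size_t size) {
  if (size == 0) {
    return std::make_unique<ZeroSizeAllocation>();
  }
  return std::make_unique<HeapAllocation>(size);
}

int main() {
  auto a = Allocate(0);
  auto b = Allocate(32);
  std::cout << a->size() << " " << b->size() << "\n";  // prints "0 32"
}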
@@ -34,10 +34,12 @@ class ZeroSizeAllocator : public Allocator {
   ZeroSizeAllocator(std::shared_ptr<Allocator> underlying_allocator,
                     const platform::Place& p)
       : underlying_allocator_(std::move(underlying_allocator)), place_(p) {}

-  AllocationPtr Allocate(size_t size, Attr attr) override;
   bool IsAllocThreadSafe() const override;

+ protected:
+  Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override;
+
  private:
   std::shared_ptr<Allocator> underlying_allocator_;
   const platform::Place& place_;
...