提交 e5c4cf61 编写于 作者: Y Yu Yang

Polish allocation

Clean allocation->Deleter

test=develop
上级 0d6718fc
......@@ -86,11 +86,12 @@ template <size_t kAlignment>
class AlignedAllocator : public ThinAlignedAllocator {
public:
using ThinAlignedAllocator::ThinAlignedAllocator;
AllocationPtr Allocate(size_t size, Attr attr) override {
protected:
Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override {
auto raw_allocation =
underlying_allocator_->Allocate(size + kAlignment, attr);
return AllocationPtr(
new AlignedAllocation<kAlignment>(std::move(raw_allocation), size));
return new AlignedAllocation<kAlignment>(std::move(raw_allocation), size);
}
};
......
......@@ -20,9 +20,9 @@ namespace paddle {
namespace memory {
namespace allocation {
class UnderlyingManualAllocation : public Allocation {
class AllocationWithUnderlying : public Allocation {
public:
explicit UnderlyingManualAllocation(AllocationPtr allocation)
explicit AllocationWithUnderlying(AllocationPtr allocation)
: Allocation(allocation->ptr(), allocation->size(), allocation->place()),
allocation_(std::move(allocation)) {}
AllocationPtr allocation_;
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/memory/allocation/allocator.h"
#include <functional>
namespace paddle {
......@@ -24,23 +25,20 @@ Allocator::~Allocator() {}
bool Allocator::IsAllocThreadSafe() const { return false; }
AllocationPtr Allocator::Allocate(size_t size, Allocator::Attr attr) {
auto ptr = AllocateImpl(size, attr);
ptr->set_allocator(this);
return AllocationPtr(ptr);
}
void Allocator::Free(Allocation* allocation) { delete allocation; }
const char* BadAlloc::what() const noexcept { return msg_.c_str(); }
AllocationPtr MannualFreeAllocator::Allocate(size_t size,
Allocator::Attr attr) {
auto allocation = AllocateImpl(size, attr);
allocation->Deleter =
std::bind1st(std::mem_fn(&MannualFreeAllocator::Free), this);
return AllocationPtr(allocation);
}
void AllocationDeleter::operator()(Allocation* allocation) const {
if (allocation->Deleter) {
auto deleter = std::move(allocation->Deleter);
deleter(allocation);
} else {
delete allocation;
}
allocation->allocator()->Free(allocation);
}
} // namespace allocation
} // namespace memory
} // namespace paddle
......@@ -32,10 +32,12 @@ class BadAlloc : public std::exception {
};
class Allocation;
struct AllocationDeleter {
class AllocationDeleter {
public:
void operator()(Allocation* allocation) const;
};
class Allocator;
// Allocation is the object holding the actually pointer. Use
// `Allocation::ptr()` will returns the pointer that allocated.
//
......@@ -45,7 +47,7 @@ struct AllocationDeleter {
class Allocation {
public:
Allocation(void* ptr, size_t size, platform::Place place)
: ptr_(ptr), size_(size), place_(place) {}
: allocator_(nullptr), ptr_(ptr), size_(size), place_(place) {}
Allocation(const Allocation& o) = delete;
Allocation& operator=(const Allocation& o) = delete;
......@@ -70,11 +72,14 @@ class Allocation {
const platform::Place& place() const { return place_; }
virtual ~Allocation();
Allocator* allocator() { return allocator_; }
std::function<void(Allocation*)> Deleter;
void set_allocator(Allocator* allocator) { allocator_ = allocator; }
virtual ~Allocation();
private:
Allocator* allocator_;
void* ptr_;
size_t size_;
platform::Place place_;
......@@ -121,25 +126,18 @@ class Allocator {
virtual ~Allocator();
// Allocate an allocation. Note the return allocation might need to be freed
// manually if the Allocator is an `UnmanagedAllocator`.
virtual AllocationPtr Allocate(size_t size,
Allocator::Attr attr = kDefault) = 0;
// Allocate an allocation.
AllocationPtr Allocate(size_t size, Allocator::Attr attr = kDefault);
// True if the `Allocate` is thread safe.
virtual bool IsAllocThreadSafe() const;
};
// User need to invoke `Free` or `FreeUniquePtr` manually if allocated by
// a manally managed allocator.
class MannualFreeAllocator : public Allocator {
public:
AllocationPtr Allocate(size_t size, Attr attr) final;
protected:
virtual void Free(Allocation* allocation) = 0;
virtual void Free(Allocation* allocation);
virtual Allocation* AllocateImpl(size_t size, Allocator::Attr attr) = 0;
friend class MannualFreeAllocation;
private:
friend class AllocationDeleter;
};
} // namespace allocation
......
......@@ -49,12 +49,13 @@ class CPUManagedAllocator : public Allocator {
public:
CPUManagedAllocator() : normal_allocator_(new CPUAllocator()) {}
AllocationPtr Allocate(size_t size, Attr attr) override {
return normal_allocator_->Allocate(size, attr);
}
bool IsAllocThreadSafe() const override { return true; }
protected:
Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override {
return normal_allocator_->Allocate(size, attr).release();
}
private:
std::shared_ptr<Allocator> normal_allocator_;
};
......@@ -103,10 +104,6 @@ class ChunkedManagedAllocator : public Allocator {
raw_allocator_.reset();
}
AllocationPtr Allocate(size_t size, Attr attr) override {
return default_allocator_->Allocate(size, attr);
}
std::shared_ptr<Allocator> BestFitAllocatorCreator() {
chunks_.emplace_back(raw_allocator_->Allocate(max_chunk_size_));
auto* allocation = chunks_.back().get();
......@@ -128,6 +125,11 @@ class ChunkedManagedAllocator : public Allocator {
bool IsAllocThreadSafe() const override { return true; }
protected:
Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override {
return default_allocator_->Allocate(size, attr).release();
}
protected:
size_t max_chunk_size_;
int64_t retry_time_;
......
......@@ -17,9 +17,25 @@
namespace paddle {
namespace memory {
namespace allocation {
bool AutoIncrementAllocator::IsAllocThreadSafe() const { return true; }
AllocationPtr AutoIncrementAllocator::Allocate(size_t size,
Allocator::Attr attr) {
std::shared_ptr<Allocator> AutoIncrementAllocator::CreateNewAllocator() {
std::lock_guard<std::mutex> guard(mtx_);
auto old_size = allocator_num_.load();
PADDLE_ENFORCE_LT(old_size, underlying_allocators_.size(),
"Allocator number exceeds capacity %d",
underlying_allocators_.size());
underlying_allocators_[old_size] = creator_();
prev_success_allocator_ = old_size;
++allocator_num_;
PADDLE_ENFORCE(
underlying_allocators_[old_size]->IsAllocThreadSafe(),
"the underlying allocator must be thread safe. This is a program "
"bug.");
return underlying_allocators_[old_size];
}
Allocation *AutoIncrementAllocator::AllocateImpl(size_t size,
Allocator::Attr attr) {
auto cur = prev_success_allocator_.load();
size_t retry_count = allocator_num_.load();
size_t allocator_num = retry_count;
......@@ -27,8 +43,8 @@ AllocationPtr AutoIncrementAllocator::Allocate(size_t size,
try {
auto res = underlying_allocators_[cur]->Allocate(size, attr);
prev_success_allocator_ = cur;
return res;
} catch (BadAlloc&) {
return res.release();
} catch (BadAlloc &) {
if (++cur >= allocator_num) {
cur = 0;
}
......@@ -47,32 +63,14 @@ AllocationPtr AutoIncrementAllocator::Allocate(size_t size,
try {
auto ret = underlying_allocators_[cur]->Allocate(size, attr);
prev_success_allocator_ = cur;
return ret;
} catch (BadAlloc&) {
return ret.release();
} catch (BadAlloc &) {
} catch (...) {
throw;
}
}
// No suitable allocator
return CreateNewAllocator()->Allocate(size, attr);
}
bool AutoIncrementAllocator::IsAllocThreadSafe() const { return true; }
std::shared_ptr<Allocator> AutoIncrementAllocator::CreateNewAllocator() {
std::lock_guard<std::mutex> guard(mtx_);
auto old_size = allocator_num_.load();
PADDLE_ENFORCE_LT(old_size, underlying_allocators_.size(),
"Allocator number exceeds capacity %d",
underlying_allocators_.size());
underlying_allocators_[old_size] = creator_();
prev_success_allocator_ = old_size;
++allocator_num_;
PADDLE_ENFORCE(
underlying_allocators_[old_size]->IsAllocThreadSafe(),
"the underlying allocator must be thread safe. This is a program "
"bug.");
return underlying_allocators_[old_size];
return CreateNewAllocator()->Allocate(size, attr).release();
}
} // namespace allocation
......
......@@ -54,13 +54,15 @@ class AutoIncrementAllocator : public Allocator {
explicit AutoIncrementAllocator(AllocatorCreator&& creator, size_t capacity)
: creator_(std::move(creator)), underlying_allocators_(capacity) {}
AllocationPtr Allocate(size_t size, Attr attr) override;
bool IsAllocThreadSafe() const override;
private:
std::shared_ptr<Allocator> CreateNewAllocator();
protected:
Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override;
private:
AllocatorCreator creator_;
std::vector<AllocatorCreator::result_type> underlying_allocators_;
......
......@@ -98,7 +98,7 @@ class BestFitAllocation : public Allocation {
//
// To free an allocation, it will set the chunk of allocation to free and merge
// the prev-chunk and the next-chunk when possible.
class BestFitAllocator : public MannualFreeAllocator {
class BestFitAllocator : public Allocator {
public:
explicit BestFitAllocator(Allocation* allocation);
......
......@@ -16,7 +16,7 @@
#include <algorithm>
#include <limits>
#include <utility>
#include "paddle/fluid/memory/allocation/underlying_manual_allocation.h"
#include "paddle/fluid/memory/allocation/allocation_with_underlying.h"
namespace paddle {
namespace memory {
......@@ -60,16 +60,16 @@ Allocation *BufferedAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
if (it != allocations_.end() && it->first < size * 2) {
AllocationPtr result(std::move(it->second));
allocations_.erase(it);
return new UnderlyingManualAllocation(std::move(result));
return new AllocationWithUnderlying(std::move(result));
}
}
try {
return new UnderlyingManualAllocation(
return new AllocationWithUnderlying(
underlying_allocator_->Allocate(size, attr));
} catch (BadAlloc &) {
FreeCache(size);
return new UnderlyingManualAllocation(
return new AllocationWithUnderlying(
underlying_allocator_->Allocate(size, attr));
}
}
......
......@@ -29,7 +29,7 @@ namespace allocation {
// memory allocation and reuse memory.
// BufferedAllocator provides the same thread-safety level as
// underlying_allocator_
class BufferedAllocator : public MannualFreeAllocator {
class BufferedAllocator : public Allocator {
public:
explicit BufferedAllocator(std::unique_ptr<Allocator> &&allocator);
......
......@@ -52,7 +52,7 @@ class StubAllocation : public Allocation {
using Allocation::Allocation;
};
class StubAllocator : public MannualFreeAllocator {
class StubAllocator : public Allocator {
public:
void ResetCounter() {
construct_count_ = 0;
......
......@@ -24,15 +24,6 @@ ConditionalAllocator& ConditionalAllocator::AddAllocator(
underlying_allocators_.emplace_back(std::move(func), std::move(allocator));
return *this;
}
AllocationPtr ConditionalAllocator::Allocate(size_t size,
Allocator::Attr attr) {
for (auto& pair : underlying_allocators_) {
if (pair.first(size, attr)) {
return pair.second->Allocate(size, attr);
}
}
throw BadAlloc("No suitable allocator");
}
bool ConditionalAllocator::IsAllocThreadSafe() const {
return std::all_of(underlying_allocators_.begin(),
......@@ -42,6 +33,16 @@ bool ConditionalAllocator::IsAllocThreadSafe() const {
});
}
Allocation* ConditionalAllocator::AllocateImpl(size_t size,
Allocator::Attr attr) {
for (auto& pair : underlying_allocators_) {
if (pair.first(size, attr)) {
return pair.second->Allocate(size, attr).release();
}
}
throw BadAlloc("No suitable allocator");
}
} // namespace allocation
} // namespace memory
} // namespace paddle
......@@ -45,10 +45,13 @@ class ConditionalAllocator : public Allocator {
ConditionalAllocator& AddAllocator(std::function<bool(size_t, Attr)> func,
std::shared_ptr<Allocator> allocator);
AllocationPtr Allocate(size_t size, Attr attr) override;
// AllocationPtr Allocate(size_t size, Attr attr) override;
bool IsAllocThreadSafe() const override;
protected:
Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override;
private:
using AllocatorWithCond =
std::pair<std::function<bool(size_t, Attr)>, std::shared_ptr<Allocator>>;
......
......@@ -31,7 +31,7 @@ class CPUAllocation : public Allocation {
CPUAllocation(void* ptr, size_t size);
};
class CPUAllocator : public MannualFreeAllocator {
class CPUAllocator : public Allocator {
public:
constexpr static size_t kAlignment = 64u;
bool IsAllocThreadSafe() const override;
......
......@@ -27,7 +27,7 @@ class CUDAAllocation : public Allocation {
using Allocation::Allocation;
};
class CUDAAllocator : public MannualFreeAllocator {
class CUDAAllocator : public Allocator {
public:
explicit CUDAAllocator(const platform::CUDAPlace& place) : place_(place) {}
explicit CUDAAllocator(const platform::Place& place)
......
......@@ -14,7 +14,7 @@
#include "paddle/fluid/memory/allocation/locked_allocator.h"
#include <mutex> // NOLINT
#include "paddle/fluid/memory/allocation/underlying_manual_allocation.h"
#include "paddle/fluid/memory/allocation/allocation_with_underlying.h"
#include "paddle/fluid/platform/lock_guard_ptr.h"
namespace paddle {
namespace memory {
......@@ -33,14 +33,14 @@ LockedAllocator::LockedAllocator(
void LockedAllocator::Free(Allocation *allocation) {
{
platform::LockGuardPtr<std::mutex> guard(mtx_);
reinterpret_cast<UnderlyingManualAllocation *>(allocation)
reinterpret_cast<AllocationWithUnderlying *>(allocation)
->allocation_.reset(); // Destroy inner allocation
}
delete allocation;
}
Allocation *LockedAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
platform::LockGuardPtr<std::mutex> guard(mtx_);
return new UnderlyingManualAllocation(
return new AllocationWithUnderlying(
underlying_allocator_->Allocate(size, attr));
}
} // namespace allocation
......
......@@ -22,7 +22,7 @@ namespace memory {
namespace allocation {
// A allocator to make underlying allocator thread safe.
class LockedAllocator : public MannualFreeAllocator {
class LockedAllocator : public Allocator {
public:
explicit LockedAllocator(std::unique_ptr<Allocator> &&underlying_allocator);
bool IsAllocThreadSafe() const override;
......
......@@ -26,7 +26,7 @@ class CPUPinnedAllocation : public Allocation {
: Allocation(ptr, size, platform::CUDAPinnedPlace()) {}
};
class CPUPinnedAllocator : public MannualFreeAllocator {
class CPUPinnedAllocator : public Allocator {
public:
bool IsAllocThreadSafe() const override;
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/memory/allocation/retry_allocator.h"
#include "paddle/fluid/memory/allocation/underlying_manual_allocation.h"
#include "paddle/fluid/memory/allocation/allocation_with_underlying.h"
namespace paddle {
namespace memory {
namespace allocation {
......@@ -24,8 +24,7 @@ bool RetryAllocator::IsAllocThreadSafe() const {
void RetryAllocator::Free(Allocation* allocation) {
// Delete underlying allocation first.
reinterpret_cast<UnderlyingManualAllocation*>(allocation)
->allocation_.reset();
reinterpret_cast<AllocationWithUnderlying*>(allocation)->allocation_.reset();
{
// notify all waited allocators, they can try to allocate memory after free.
std::lock_guard<std::mutex> lock(mutex_);
......@@ -36,7 +35,7 @@ void RetryAllocator::Free(Allocation* allocation) {
Allocation* RetryAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
auto alloc_func = [&, this]() {
return new UnderlyingManualAllocation(
return new AllocationWithUnderlying(
underlying_allocator_->Allocate(size, attr));
};
// In fact, we can unify the code of allocation success and failure
......
......@@ -26,7 +26,7 @@ namespace allocation {
class RetryAllocator;
class RetryAllocator : public MannualFreeAllocator {
class RetryAllocator : public Allocator {
public:
RetryAllocator(std::unique_ptr<Allocator>&& allocator, size_t retry_ms)
: underlying_allocator_(std::move(allocator)), retry_time_(retry_ms) {
......
......@@ -18,17 +18,17 @@ namespace paddle {
namespace memory {
namespace allocation {
AllocationPtr ZeroSizeAllocator::Allocate(size_t size, Allocator::Attr attr) {
bool ZeroSizeAllocator::IsAllocThreadSafe() const {
return underlying_allocator_->IsAllocThreadSafe();
}
Allocation *ZeroSizeAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
if (size == 0) {
return AllocationPtr(new ZeroSizeAllocation(place_));
return new ZeroSizeAllocation(place_);
} else {
return underlying_allocator_->Allocate(size, attr);
return underlying_allocator_->Allocate(size, attr).release();
}
}
bool ZeroSizeAllocator::IsAllocThreadSafe() const {
return underlying_allocator_->IsAllocThreadSafe();
}
} // namespace allocation
} // namespace memory
} // namespace paddle
......@@ -34,10 +34,12 @@ class ZeroSizeAllocator : public Allocator {
ZeroSizeAllocator(std::shared_ptr<Allocator> underlying_allocator,
const platform::Place& p)
: underlying_allocator_(std::move(underlying_allocator)), place_(p) {}
AllocationPtr Allocate(size_t size, Attr attr) override;
bool IsAllocThreadSafe() const override;
protected:
Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override;
private:
std::shared_ptr<Allocator> underlying_allocator_;
const platform::Place& place_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册