提交 d93b2d03 编写于 作者: Y Yu Yang

Refine code

上级 ea81f8ee
......@@ -33,8 +33,7 @@ class AlignedAllocation : public Allocation {
"kAlignment must be 2^N");
public:
AlignedAllocation(std::unique_ptr<Allocation>&& underlying_allocation,
size_t size)
AlignedAllocation(AllocationPtr&& underlying_allocation, size_t size)
: Allocation(AlignedPtr(underlying_allocation->ptr()),
size + kAlignment - Offset(underlying_allocation->ptr()),
underlying_allocation->place()),
......@@ -59,7 +58,7 @@ class AlignedAllocation : public Allocation {
}
}
std::unique_ptr<Allocation> underlying_allocation_;
AllocationPtr underlying_allocation_;
};
// Thin aligned allocator is trivial and used to generate a small size binary.
......@@ -87,10 +86,10 @@ template <size_t kAlignment>
class AlignedAllocator : public ThinAlignedAllocator {
public:
using ThinAlignedAllocator::ThinAlignedAllocator;
std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) override {
AllocationPtr Allocate(size_t size, Attr attr) override {
auto raw_allocation =
underlying_allocator_->Allocate(size + kAlignment, attr);
return std::unique_ptr<Allocation>(
return AllocationPtr(
new AlignedAllocation<kAlignment>(std::move(raw_allocation), size));
}
};
......
......@@ -13,6 +13,8 @@
// limitations under the License.
#include "paddle/fluid/memory/allocation/allocator.h"
#include <functional>
namespace paddle {
namespace memory {
namespace allocation {
......@@ -24,10 +26,20 @@ bool Allocator::IsAllocThreadSafe() const { return false; }
// Return the stored error message; noexcept as required by std::exception.
const char* BadAlloc::what() const noexcept { return msg_.c_str(); }
// On destruction, hand this allocation back to the allocator that created it,
// so the allocator (not a plain `delete`) performs the actual free.
MannualFreeAllocation::~MannualFreeAllocation() { allocator_->Free(this); }
std::unique_ptr<Allocation> MannualFreeAllocator::Allocate(
size_t size, Allocator::Attr attr) {
return std::unique_ptr<Allocation>(AllocateImpl(size, attr));
// Allocate through AllocateImpl() and attach a Deleter so that the returned
// AllocationPtr routes destruction back to this allocator's Free() instead of
// performing a plain `delete`.
AllocationPtr MannualFreeAllocator::Allocate(size_t size,
                                             Allocator::Attr attr) {
  auto allocation = AllocateImpl(size, attr);
  // NOTE: std::bind1st is deprecated since C++11 and removed in C++17; a
  // lambda capturing `this` expresses the same member-function binding
  // portably and without the adaptable-functor machinery.
  allocation->Deleter = [this](Allocation* to_free) { this->Free(to_free); };
  return AllocationPtr(allocation);
}
// Custom deleter used by AllocationPtr. If the allocation carries a Deleter
// callback, invoke it (the callback is responsible for releasing the object);
// otherwise fall back to a plain `delete`.
void AllocationDeleter::operator()(Allocation* allocation) const {
  if (allocation->Deleter) {
    // Move the std::function onto the stack before calling it: the callback
    // typically destroys `allocation` itself, which would otherwise destroy
    // the very std::function object we are executing.
    auto deleter = std::move(allocation->Deleter);
    deleter(allocation);
  } else {
    delete allocation;
  }
}
} // namespace allocation
} // namespace memory
......
......@@ -31,6 +31,11 @@ class BadAlloc : public std::exception {
std::string msg_;
};
class Allocation;
// Deleter functor for AllocationPtr (std::unique_ptr<Allocation,
// AllocationDeleter>). Dispatches to the allocation's own Deleter callback
// when one is set, otherwise deletes the allocation directly.
struct AllocationDeleter {
  void operator()(Allocation* allocation) const;
};
// Allocation is the object holding the actually pointer. Use
// `Allocation::ptr()` will returns the pointer that allocated.
//
......@@ -67,12 +72,16 @@ class Allocation {
virtual ~Allocation();
std::function<void(Allocation*)> Deleter;
private:
void* ptr_;
size_t size_;
platform::Place place_;
};
using AllocationPtr = std::unique_ptr<Allocation, AllocationDeleter>;
// Base interface class of memory Allocator.
// To allocate a memory, allocator needs two parameters:
// 1. size of bytes.
......@@ -114,36 +123,22 @@ class Allocator {
// Allocate an allocation. Note the return allocation might need to be freed
// manually if the Allocator is an `UnmanagedAllocator`.
virtual std::unique_ptr<Allocation> Allocate(
size_t size, Allocator::Attr attr = kDefault) = 0;
virtual AllocationPtr Allocate(size_t size,
Allocator::Attr attr = kDefault) = 0;
// True if the `Allocate` is thread safe.
virtual bool IsAllocThreadSafe() const;
};
class MannualFreeAllocator;
class MannualFreeAllocation : public Allocation {
public:
MannualFreeAllocation(MannualFreeAllocator* allocator, void* ptr, size_t size,
platform::Place place)
: Allocation(ptr, size, place), allocator_(allocator) {}
~MannualFreeAllocation();
private:
MannualFreeAllocator* allocator_;
};
// User need to invoke `Free` or `FreeUniquePtr` manually if allocated by
// a manally managed allocator.
class MannualFreeAllocator : public Allocator {
public:
std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) final;
AllocationPtr Allocate(size_t size, Attr attr) final;
protected:
virtual void Free(MannualFreeAllocation* allocation) = 0;
virtual MannualFreeAllocation* AllocateImpl(size_t size,
Allocator::Attr attr) = 0;
virtual void Free(Allocation* allocation) = 0;
virtual Allocation* AllocateImpl(size_t size, Allocator::Attr attr) = 0;
friend class MannualFreeAllocation;
};
......
......@@ -49,7 +49,7 @@ class CPUManagedAllocator : public Allocator {
public:
CPUManagedAllocator() : normal_allocator_(new CPUAllocator()) {}
std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) override {
AllocationPtr Allocate(size_t size, Attr attr) override {
return normal_allocator_->Allocate(size, attr);
}
......@@ -103,7 +103,7 @@ class ChunkedManagedAllocator : public Allocator {
raw_allocator_.reset();
}
std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) override {
AllocationPtr Allocate(size_t size, Attr attr) override {
return default_allocator_->Allocate(size, attr);
}
......@@ -131,7 +131,7 @@ class ChunkedManagedAllocator : public Allocator {
protected:
size_t max_chunk_size_;
int64_t retry_time_;
std::vector<std::unique_ptr<Allocation>> chunks_;
std::vector<AllocationPtr> chunks_;
std::shared_ptr<Allocator> raw_allocator_;
std::shared_ptr<Allocator> default_allocator_;
};
......@@ -236,12 +236,12 @@ AllocatorFacade& AllocatorFacade::Instance() {
std::shared_ptr<Allocation> AllocatorFacade::AllocShared(
const platform::Place& place, size_t size, Allocator::Attr attr) {
return std::shared_ptr<Allocation>(
m_->allocators_.at(place)->Allocate(size, attr).release());
m_->allocators_.at(place)->Allocate(size, attr).release(),
AllocationDeleter());
}
std::unique_ptr<Allocation> AllocatorFacade::Alloc(const platform::Place& place,
size_t size,
Allocator::Attr attr) {
AllocationPtr AllocatorFacade::Alloc(const platform::Place& place, size_t size,
Allocator::Attr attr) {
return m_->allocators_.at(place)->Allocate(size, attr);
}
......
......@@ -43,8 +43,8 @@ class AllocatorFacade {
Allocator::Attr attr = Allocator::kDefault);
// Allocate a unique allocation.
std::unique_ptr<Allocation> Alloc(const platform::Place& place, size_t size,
Allocator::Attr attr = Allocator::kDefault);
AllocationPtr Alloc(const platform::Place& place, size_t size,
Allocator::Attr attr = Allocator::kDefault);
// TODO(yy): Allocate a Copy-On-Write allocation?
private:
......
......@@ -18,8 +18,8 @@ namespace paddle {
namespace memory {
namespace allocation {
std::unique_ptr<Allocation> AutoIncrementAllocator::Allocate(
size_t size, Allocator::Attr attr) {
AllocationPtr AutoIncrementAllocator::Allocate(size_t size,
Allocator::Attr attr) {
auto cur = prev_success_allocator_.load();
size_t retry_count = allocator_num_.load();
size_t allocator_num = retry_count;
......
......@@ -54,7 +54,7 @@ class AutoIncrementAllocator : public Allocator {
explicit AutoIncrementAllocator(AllocatorCreator&& creator, size_t capacity)
: creator_(std::move(creator)), underlying_allocators_(capacity) {}
std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) override;
AllocationPtr Allocate(size_t size, Attr attr) override;
bool IsAllocThreadSafe() const override;
......
......@@ -109,7 +109,7 @@ size_t BestFitAllocator::NumFreeChunks() const {
}
return num;
}
void BestFitAllocator::Free(MannualFreeAllocation* allocation) {
void BestFitAllocator::Free(Allocation* allocation) {
auto* bf_allocation = dynamic_cast<BestFitAllocation*>(allocation);
auto chunk_it = bf_allocation->ChunkIterator();
PADDLE_ENFORCE(!chunk_it->is_free);
......@@ -136,9 +136,9 @@ void BestFitAllocator::Free(MannualFreeAllocation* allocation) {
}
InsertFreeNode(chunk_it);
delete allocation;
}
MannualFreeAllocation* BestFitAllocator::AllocateImpl(size_t size,
Allocator::Attr attr) {
Allocation* BestFitAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
auto highest_set_bit = static_cast<size_t>(HighestBitPos(size));
MapIt map_it;
for (; highest_set_bit < free_chunks_.size(); ++highest_set_bit) {
......@@ -158,11 +158,10 @@ MannualFreeAllocation* BestFitAllocator::AllocateImpl(size_t size,
BestFitAllocation::BestFitAllocation(
paddle::memory::allocation::BestFitAllocator* allocator,
typename details::ChunkList::iterator chunk_it)
: MannualFreeAllocation(
allocator, reinterpret_cast<void*>(
reinterpret_cast<uintptr_t>(allocator->BasePtr()) +
chunk_it->offset_),
chunk_it->size_, allocator->Place()),
: Allocation(reinterpret_cast<void*>(
reinterpret_cast<uintptr_t>(allocator->BasePtr()) +
chunk_it->offset_),
chunk_it->size_, allocator->Place()),
chunk_it_(chunk_it) {}
} // namespace allocation
} // namespace memory
......
......@@ -71,7 +71,7 @@ using FreeChunkBin =
class BestFitAllocator;
// The BestFitAllocation maintain the List Node iterator.
class BestFitAllocation : public MannualFreeAllocation {
class BestFitAllocation : public Allocation {
private:
using ListIt = typename details::ChunkList::iterator;
......@@ -123,9 +123,8 @@ class BestFitAllocator : public MannualFreeAllocator {
void InsertFreeNode(const ListIt& it);
protected:
void Free(MannualFreeAllocation* allocation) override;
MannualFreeAllocation* AllocateImpl(size_t size,
Allocator::Attr attr) override;
void Free(Allocation* allocation) override;
Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override;
private:
Allocation* allocation_; // not owned
......
......@@ -49,33 +49,28 @@ void BufferedAllocator::FreeCache(size_t size) {
bool BufferedAllocator::IsAllocThreadSafe() const {
return this->underlying_allocator_->IsAllocThreadSafe();
}
void BufferedAllocator::Free(MannualFreeAllocation *allocation) {
void BufferedAllocator::Free(Allocation *allocation) {
platform::LockGuardPtr<std::mutex> guard(mtx_);
std::unique_ptr<Allocation> new_allocation(new UnderlyingManualAllocation(
this, std::move(reinterpret_cast<UnderlyingManualAllocation *>(allocation)
->allocation_)));
allocations_.emplace(allocation->size(), std::move(new_allocation));
allocations_.emplace(allocation->size(), AllocationPtr(allocation));
}
MannualFreeAllocation *BufferedAllocator::AllocateImpl(size_t size,
Allocator::Attr attr) {
Allocation *BufferedAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
{
platform::LockGuardPtr<std::mutex> guard(mtx_);
auto it = allocations_.lower_bound(size);
if (it != allocations_.end() && it->first < size * 2) {
std::unique_ptr<Allocation> result(std::move(it->second));
AllocationPtr result(std::move(it->second));
allocations_.erase(it);
return new UnderlyingManualAllocation(this, std::move(result));
return new UnderlyingManualAllocation(std::move(result));
}
}
try {
return new UnderlyingManualAllocation(
this, underlying_allocator_->Allocate(size, attr));
underlying_allocator_->Allocate(size, attr));
} catch (BadAlloc &) {
FreeCache(size);
return new UnderlyingManualAllocation(
this, underlying_allocator_->Allocate(size, attr));
underlying_allocator_->Allocate(size, attr));
}
}
......
......@@ -50,13 +50,12 @@ class BufferedAllocator : public MannualFreeAllocator {
void FreeCache(size_t size);
protected:
void Free(MannualFreeAllocation *allocation) override;
MannualFreeAllocation *AllocateImpl(size_t size,
Allocator::Attr attr) override;
void Free(Allocation *allocation) override;
Allocation *AllocateImpl(size_t size, Allocator::Attr attr) override;
private:
std::unique_ptr<Allocator> underlying_allocator_;
std::multimap<size_t, std::unique_ptr<Allocation>> allocations_;
std::multimap<size_t, AllocationPtr> allocations_;
std::unique_ptr<std::mutex> mtx_;
};
......
......@@ -24,8 +24,8 @@ ConditionalAllocator& ConditionalAllocator::AddAllocator(
underlying_allocators_.emplace_back(std::move(func), std::move(allocator));
return *this;
}
std::unique_ptr<Allocation> ConditionalAllocator::Allocate(
size_t size, Allocator::Attr attr) {
AllocationPtr ConditionalAllocator::Allocate(size_t size,
Allocator::Attr attr) {
for (auto& pair : underlying_allocators_) {
if (pair.first(size, attr)) {
return pair.second->Allocate(size, attr);
......
......@@ -45,7 +45,7 @@ class ConditionalAllocator : public Allocator {
ConditionalAllocator& AddAllocator(std::function<bool(size_t, Attr)> func,
std::shared_ptr<Allocator> allocator);
std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) override;
AllocationPtr Allocate(size_t size, Attr attr) override;
bool IsAllocThreadSafe() const override;
......
......@@ -20,26 +20,25 @@ namespace paddle {
namespace memory {
namespace allocation {
CPUAllocation::CPUAllocation(
paddle::memory::allocation::CPUAllocator *allocator, void *ptr, size_t size)
: MannualFreeAllocation(allocator, ptr, size, platform::CPUPlace()) {}
CPUAllocation::CPUAllocation(void *ptr, size_t size)
: Allocation(ptr, size, platform::CPUPlace()) {}
// posix_memalign/free are thread safe, so this allocator is as well.
bool CPUAllocator::IsAllocThreadSafe() const { return true; }
void CPUAllocator::Free(MannualFreeAllocation *allocation) {
// Release a CPU allocation previously produced by AllocateImpl(): verify the
// runtime type, free the aligned buffer, then destroy the wrapper object.
void CPUAllocator::Free(Allocation *allocation) {
  PADDLE_ENFORCE_NOT_NULL(dynamic_cast<CPUAllocation *>(allocation));
  free(allocation->ptr());
  delete allocation;
}
MannualFreeAllocation *CPUAllocator::AllocateImpl(size_t size,
Allocator::Attr attr) {
// Allocate `size` bytes of kAlignment-aligned CPU memory via posix_memalign.
// Throws BadAlloc (carrying the posix_memalign status) on failure.
Allocation *CPUAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
  void *ptr;
  auto status = posix_memalign(&ptr, kAlignment, size);
  // BUGFIX: the branch hint must wrap the whole condition. The previous form
  // `UNLIKELY(status) != 0` applied __builtin_expect to `status` alone, so
  // the compiler received no hint about the comparison actually branched on.
  if (UNLIKELY(status != 0)) {
    throw BadAlloc(string::Sprintf("Cannot allocate cpu memory %d. Errno is %d",
                                   size, status));
  }
  return new CPUAllocation(ptr, size);
}
} // namespace allocation
} // namespace memory
......
......@@ -26,9 +26,9 @@ namespace allocation {
// NOTE(yy): It is no need to use `BestFitAllocator` in CPU. We can import
// an open-sourced allocator into Paddle.
class CPUAllocator;
class CPUAllocation : public MannualFreeAllocation {
class CPUAllocation : public Allocation {
public:
CPUAllocation(CPUAllocator* allocator, void* ptr, size_t size);
CPUAllocation(void* ptr, size_t size);
};
class CPUAllocator : public MannualFreeAllocator {
......@@ -37,9 +37,8 @@ class CPUAllocator : public MannualFreeAllocator {
bool IsAllocThreadSafe() const override;
protected:
void Free(MannualFreeAllocation* allocation) override;
MannualFreeAllocation* AllocateImpl(size_t size,
Allocator::Attr attr) override;
void Free(Allocation* allocation) override;
Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override;
};
} // namespace allocation
} // namespace memory
......
......@@ -22,7 +22,17 @@
namespace paddle {
namespace memory {
namespace allocation {
std::unique_ptr<Allocation> CUDAAllocator::Allocate(size_t size, Attr attr) {
// cudaMalloc/cudaFree are thread safe, so this allocator is as well.
bool CUDAAllocator::IsAllocThreadSafe() const { return true; }
// Release a GPU allocation: switch to the owning device, check that the
// allocation really belongs to this allocator's place, free the device
// memory, then destroy the wrapper object.
void CUDAAllocator::Free(Allocation* allocation) {
  platform::CUDADeviceGuard guard(place_.device);
  auto* cuda_allocation = dynamic_cast<CUDAAllocation*>(allocation);
  PADDLE_ENFORCE_NOT_NULL(cuda_allocation);
  PADDLE_ENFORCE_EQ(boost::get<platform::CUDAPlace>(cuda_allocation->place()),
                    place_);
  PADDLE_ENFORCE(cudaFree(allocation->ptr()));
  delete allocation;
}
Allocation* CUDAAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
platform::CUDADeviceGuard guard(place_.device);
void* ptr;
auto status = cudaMalloc(&ptr, size);
......@@ -31,19 +41,8 @@ std::unique_ptr<Allocation> CUDAAllocator::Allocate(size_t size, Attr attr) {
"Cannot allocate %d on GPU %d, cuda status %d, %s", size, place_.device,
status, cudaGetErrorString(status)));
}
return std::unique_ptr<Allocation>(
new CUDAAllocation(ptr, size, platform::Place(place_)));
return new CUDAAllocation(ptr, size, platform::Place(place_));
}
void CUDAAllocator::FreeUniquePtr(std::unique_ptr<Allocation> allocation) {
platform::CUDADeviceGuard guard(place_.device);
auto* cuda_allocation = dynamic_cast<CUDAAllocation*>(allocation.get());
PADDLE_ENFORCE_NOT_NULL(cuda_allocation);
PADDLE_ENFORCE_EQ(boost::get<platform::CUDAPlace>(cuda_allocation->place()),
place_);
PADDLE_ENFORCE(cudaFree(allocation->ptr()));
}
bool CUDAAllocator::IsAllocThreadSafe() const { return true; }
} // namespace allocation
} // namespace memory
} // namespace paddle
......@@ -27,16 +27,17 @@ class CUDAAllocation : public Allocation {
using Allocation::Allocation;
};
class CUDAAllocator : public UnmanagedAllocator {
class CUDAAllocator : public MannualFreeAllocator {
public:
explicit CUDAAllocator(const platform::CUDAPlace& place) : place_(place) {}
explicit CUDAAllocator(const platform::Place& place)
: place_(boost::get<platform::CUDAPlace>(place)) {}
std::unique_ptr<Allocation> Allocate(size_t size,
Attr attr = kDefault) override;
void FreeUniquePtr(std::unique_ptr<Allocation> allocation) override;
bool IsAllocThreadSafe() const override;
protected:
void Free(Allocation* allocation) override;
Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override;
private:
platform::CUDAPlace place_;
};
......
......@@ -30,16 +30,18 @@ LockedAllocator::LockedAllocator(
mtx_.reset(new std::mutex());
}
}
void LockedAllocator::Free(MannualFreeAllocation *allocation) {
platform::LockGuardPtr<std::mutex> guard(mtx_);
reinterpret_cast<UnderlyingManualAllocation *>(allocation)
->allocation_.reset();
// Free an allocation produced by this allocator. The wrapped inner allocation
// is destroyed while holding the lock (the underlying allocator may not be
// thread safe); the wrapper object itself is deleted outside the critical
// section.
void LockedAllocator::Free(Allocation *allocation) {
  {
    platform::LockGuardPtr<std::mutex> guard(mtx_);
    // static_cast is the correct downcast here; reinterpret_cast only happens
    // to work while the hierarchy uses single inheritance with no offset.
    static_cast<UnderlyingManualAllocation *>(allocation)
        ->allocation_.reset();  // Destroy inner allocation
  }
  delete allocation;
}
MannualFreeAllocation *LockedAllocator::AllocateImpl(size_t size,
Allocator::Attr attr) {
Allocation *LockedAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
platform::LockGuardPtr<std::mutex> guard(mtx_);
return new UnderlyingManualAllocation(
this, underlying_allocator_->Allocate(size, attr));
underlying_allocator_->Allocate(size, attr));
}
} // namespace allocation
} // namespace memory
......
......@@ -28,9 +28,8 @@ class LockedAllocator : public MannualFreeAllocator {
bool IsAllocThreadSafe() const override;
protected:
void Free(MannualFreeAllocation *allocation) override;
MannualFreeAllocation *AllocateImpl(size_t size,
Allocator::Attr attr) override;
void Free(Allocation *allocation) override;
Allocation *AllocateImpl(size_t size, Allocator::Attr attr) override;
private:
std::unique_ptr<Allocator> underlying_allocator_;
......
......@@ -19,25 +19,22 @@
namespace paddle {
namespace memory {
namespace allocation {
std::unique_ptr<Allocation> CPUPinnedAllocator::Allocate(size_t size,
Allocator::Attr attr) {
// cudaMallocHost/cudaFreeHost are thread safe, so this allocator is as well.
bool CPUPinnedAllocator::IsAllocThreadSafe() const { return true; }
// Release a pinned-host allocation: verify the runtime type, free the pinned
// memory via the CUDA runtime, then destroy the wrapper object.
void CPUPinnedAllocator::Free(Allocation *allocation) {
  PADDLE_ENFORCE_NOT_NULL(dynamic_cast<CPUPinnedAllocation *>(allocation));
  PADDLE_ENFORCE(cudaFreeHost(allocation->ptr()));
  delete allocation;
}
Allocation *CPUPinnedAllocator::AllocateImpl(size_t size,
Allocator::Attr attr) {
// PADDLE_ENFORCE_EQ(
// attr, kCrossDevice,
// "CPUPinnedAllocator should be used for Cross-Device Communication");
void* ptr;
void *ptr;
PADDLE_ENFORCE(cudaMallocHost(&ptr, size));
return std::unique_ptr<CPUPinnedAllocation>(
new CPUPinnedAllocation(ptr, size));
return new CPUPinnedAllocation(ptr, size);
}
void CPUPinnedAllocator::FreeUniquePtr(std::unique_ptr<Allocation> allocation) {
PADDLE_ENFORCE_NOT_NULL(dynamic_cast<CPUPinnedAllocation*>(allocation.get()));
PADDLE_ENFORCE(cudaFreeHost(allocation->ptr()));
}
bool CPUPinnedAllocator::IsAllocThreadSafe() const { return true; }
} // namespace allocation
} // namespace memory
} // namespace paddle
......@@ -22,15 +22,17 @@ namespace allocation {
// Allocator uses `cudaMallocHost`
class CPUPinnedAllocation : public Allocation {
public:
CPUPinnedAllocation(void* ptr, size_t size)
CPUPinnedAllocation(void *ptr, size_t size)
: Allocation(ptr, size, platform::CUDAPinnedPlace()) {}
};
class CPUPinnedAllocator : public UnmanagedAllocator {
class CPUPinnedAllocator : public MannualFreeAllocator {
public:
std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) override;
void FreeUniquePtr(std::unique_ptr<Allocation> allocation) override;
bool IsAllocThreadSafe() const override;
protected:
void Free(Allocation *allocation) override;
Allocation *AllocateImpl(size_t size, Allocator::Attr attr) override;
};
} // namespace allocation
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/memory/allocation/retry_allocator.h"
#include "paddle/fluid/memory/allocation/underlying_manual_allocation.h"
namespace paddle {
namespace memory {
namespace allocation {
......@@ -22,21 +22,22 @@ bool RetryAllocator::IsAllocThreadSafe() const {
return underlying_allocator_->IsAllocThreadSafe();
}
void RetryAllocator::Free(MannualFreeAllocation* allocation) {
reinterpret_cast<RetryAllocation*>(allocation)
->underlying_allocation_.reset();
// Free an allocation produced by this retry allocator. The wrapped underlying
// allocation is destroyed first so the memory is actually returned, then every
// thread waiting in AllocateImpl is woken so it can retry, and finally the
// wrapper object itself is deleted.
void RetryAllocator::Free(Allocation* allocation) {
  // Delete underlying allocation first.
  // static_cast is the correct downcast here; reinterpret_cast only happens
  // to work while the hierarchy uses single inheritance with no offset.
  static_cast<UnderlyingManualAllocation*>(allocation)->allocation_.reset();
  {
    // notify all waited allocators, they can try to allocate memory after free.
    std::lock_guard<std::mutex> lock(mutex_);
    cv_.notify_all();
  }
  delete allocation;
}
MannualFreeAllocation* RetryAllocator::AllocateImpl(size_t size,
Allocator::Attr attr) {
Allocation* RetryAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
auto alloc_func = [&, this]() {
return new RetryAllocation(underlying_allocator_->Allocate(size, attr),
this);
return new UnderlyingManualAllocation(
underlying_allocator_->Allocate(size, attr));
};
// In fact, we can unify the code of allocation success and failure
// But it would add lock even when allocation success at the first time
......
......@@ -26,17 +26,6 @@ namespace allocation {
class RetryAllocator;
class RetryAllocation : public MannualFreeAllocation {
public:
RetryAllocation(std::unique_ptr<Allocation>&& underlying_allocation,
MannualFreeAllocator* allocator)
: MannualFreeAllocation(allocator, underlying_allocation->ptr(),
underlying_allocation->size(),
underlying_allocation->place()),
underlying_allocation_(std::move(underlying_allocation)) {}
std::unique_ptr<Allocation> underlying_allocation_;
};
class RetryAllocator : public MannualFreeAllocator {
public:
RetryAllocator(std::unique_ptr<Allocator>&& allocator, size_t retry_ms)
......@@ -56,9 +45,8 @@ class RetryAllocator : public MannualFreeAllocator {
}
protected:
void Free(MannualFreeAllocation* allocation) override;
MannualFreeAllocation* AllocateImpl(size_t size,
Allocator::Attr attr) override;
void Free(Allocation* allocation) override;
Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override;
private:
std::unique_ptr<Allocator> underlying_allocator_;
......
......@@ -20,14 +20,12 @@ namespace paddle {
namespace memory {
namespace allocation {
class UnderlyingManualAllocation : public MannualFreeAllocation {
class UnderlyingManualAllocation : public Allocation {
public:
UnderlyingManualAllocation(MannualFreeAllocator* allocator,
std::unique_ptr<Allocation> allocation)
: MannualFreeAllocation(allocator, allocation->ptr(), allocation->size(),
allocation->place()),
explicit UnderlyingManualAllocation(AllocationPtr allocation)
: Allocation(allocation->ptr(), allocation->size(), allocation->place()),
allocation_(std::move(allocation)) {}
std::unique_ptr<Allocation> allocation_;
AllocationPtr allocation_;
};
} // namespace allocation
......
......@@ -18,10 +18,9 @@ namespace paddle {
namespace memory {
namespace allocation {
std::unique_ptr<Allocation> ZeroSizeAllocator::Allocate(size_t size,
Allocator::Attr attr) {
AllocationPtr ZeroSizeAllocator::Allocate(size_t size, Allocator::Attr attr) {
if (size == 0) {
return std::unique_ptr<Allocation>(new ZeroSizeAllocation(place_));
return AllocationPtr(new ZeroSizeAllocation(place_));
} else {
return underlying_allocator_->Allocate(size, attr);
}
......
......@@ -34,7 +34,7 @@ class ZeroSizeAllocator : public Allocator {
ZeroSizeAllocator(std::shared_ptr<Allocator> underlying_allocator,
const platform::Place& p)
: underlying_allocator_(std::move(underlying_allocator)), place_(p) {}
std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) override;
AllocationPtr Allocate(size_t size, Attr attr) override;
bool IsAllocThreadSafe() const override;
......
......@@ -294,13 +294,12 @@ std::shared_ptr<Allocation> AllocShared(const platform::Place& place,
}
}
std::unique_ptr<Allocation> Alloc(const platform::Place& place, size_t size,
Allocator::Attr attr) {
AllocationPtr Alloc(const platform::Place& place, size_t size,
Allocator::Attr attr) {
if (allocation::GetAllocatorStrategy() ==
allocation::AllocatorStrategy::kLegacy) {
void* p = boost::apply_visitor(legacy::AllocVisitor(size), place);
return std::unique_ptr<Allocation>(
new legacy::LegacyAllocation(p, size, place));
return AllocationPtr(new legacy::LegacyAllocation(p, size, place));
} else {
return allocation::AllocatorFacade::Instance().Alloc(place, size, attr);
}
......
......@@ -21,14 +21,14 @@ namespace paddle {
namespace memory {
using allocation::Allocation;
using allocation::Allocator;
using allocation::AllocationPtr;
extern std::shared_ptr<Allocation> AllocShared(
const platform::Place& place, size_t size,
Allocator::Attr attr = Allocator::kDefault);
extern std::unique_ptr<Allocation> Alloc(
const platform::Place& place, size_t size,
Allocator::Attr attr = Allocator::kDefault);
extern AllocationPtr Alloc(const platform::Place& place, size_t size,
Allocator::Attr attr = Allocator::kDefault);
namespace legacy {
......
......@@ -155,8 +155,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
const cudaDeviceProp* device_prop_; // not owned;
mutable void* scratch_;
mutable unsigned int* semaphore_;
mutable std::unordered_map<void*, std::unique_ptr<memory::Allocation>>
allocations_;
mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
};
CudnnHolder::CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册