提交 d93b2d03 编写于 作者: Y Yu Yang

Refine code

上级 ea81f8ee
...@@ -33,8 +33,7 @@ class AlignedAllocation : public Allocation { ...@@ -33,8 +33,7 @@ class AlignedAllocation : public Allocation {
"kAlignment must be 2^N"); "kAlignment must be 2^N");
public: public:
AlignedAllocation(std::unique_ptr<Allocation>&& underlying_allocation, AlignedAllocation(AllocationPtr&& underlying_allocation, size_t size)
size_t size)
: Allocation(AlignedPtr(underlying_allocation->ptr()), : Allocation(AlignedPtr(underlying_allocation->ptr()),
size + kAlignment - Offset(underlying_allocation->ptr()), size + kAlignment - Offset(underlying_allocation->ptr()),
underlying_allocation->place()), underlying_allocation->place()),
...@@ -59,7 +58,7 @@ class AlignedAllocation : public Allocation { ...@@ -59,7 +58,7 @@ class AlignedAllocation : public Allocation {
} }
} }
std::unique_ptr<Allocation> underlying_allocation_; AllocationPtr underlying_allocation_;
}; };
// Thin aligned allocator is trivial and used to generate a small size binary. // Thin aligned allocator is trivial and used to generate a small size binary.
...@@ -87,10 +86,10 @@ template <size_t kAlignment> ...@@ -87,10 +86,10 @@ template <size_t kAlignment>
class AlignedAllocator : public ThinAlignedAllocator { class AlignedAllocator : public ThinAlignedAllocator {
public: public:
using ThinAlignedAllocator::ThinAlignedAllocator; using ThinAlignedAllocator::ThinAlignedAllocator;
std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) override { AllocationPtr Allocate(size_t size, Attr attr) override {
auto raw_allocation = auto raw_allocation =
underlying_allocator_->Allocate(size + kAlignment, attr); underlying_allocator_->Allocate(size + kAlignment, attr);
return std::unique_ptr<Allocation>( return AllocationPtr(
new AlignedAllocation<kAlignment>(std::move(raw_allocation), size)); new AlignedAllocation<kAlignment>(std::move(raw_allocation), size));
} }
}; };
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/allocation/allocator.h"
#include <functional>
namespace paddle { namespace paddle {
namespace memory { namespace memory {
namespace allocation { namespace allocation {
...@@ -24,10 +26,20 @@ bool Allocator::IsAllocThreadSafe() const { return false; } ...@@ -24,10 +26,20 @@ bool Allocator::IsAllocThreadSafe() const { return false; }
const char* BadAlloc::what() const noexcept { return msg_.c_str(); } const char* BadAlloc::what() const noexcept { return msg_.c_str(); }
MannualFreeAllocation::~MannualFreeAllocation() { allocator_->Free(this); } AllocationPtr MannualFreeAllocator::Allocate(size_t size,
std::unique_ptr<Allocation> MannualFreeAllocator::Allocate( Allocator::Attr attr) {
size_t size, Allocator::Attr attr) { auto allocation = AllocateImpl(size, attr);
return std::unique_ptr<Allocation>(AllocateImpl(size, attr)); allocation->Deleter =
std::bind1st(std::mem_fn(&MannualFreeAllocator::Free), this);
return AllocationPtr(allocation);
}
void AllocationDeleter::operator()(Allocation* allocation) const {
if (allocation->Deleter) {
auto deleter = std::move(allocation->Deleter);
deleter(allocation);
} else {
delete allocation;
}
} }
} // namespace allocation } // namespace allocation
} // namespace memory } // namespace memory
......
...@@ -31,6 +31,11 @@ class BadAlloc : public std::exception { ...@@ -31,6 +31,11 @@ class BadAlloc : public std::exception {
std::string msg_; std::string msg_;
}; };
class Allocation;
// Deleter functor for `Allocation` objects; it is the deleter type of the
// `AllocationPtr` alias (`std::unique_ptr<Allocation, AllocationDeleter>`).
struct AllocationDeleter {
// Disposes of `allocation`. Only declared here; the out-of-line definition
// calls `allocation->Deleter` when that callback is set and falls back to a
// plain `delete allocation` otherwise.
void operator()(Allocation* allocation) const;
};
// Allocation is the object holding the actual pointer. Calling // Allocation is the object holding the actual pointer. Calling
// `Allocation::ptr()` returns the pointer that was allocated. // `Allocation::ptr()` returns the pointer that was allocated.
// //
...@@ -67,12 +72,16 @@ class Allocation { ...@@ -67,12 +72,16 @@ class Allocation {
virtual ~Allocation(); virtual ~Allocation();
std::function<void(Allocation*)> Deleter;
private: private:
void* ptr_; void* ptr_;
size_t size_; size_t size_;
platform::Place place_; platform::Place place_;
}; };
using AllocationPtr = std::unique_ptr<Allocation, AllocationDeleter>;
// Base interface class of memory Allocator. // Base interface class of memory Allocator.
// To allocate a memory, allocator needs two parameters: // To allocate a memory, allocator needs two parameters:
// 1. size of bytes. // 1. size of bytes.
...@@ -114,36 +123,22 @@ class Allocator { ...@@ -114,36 +123,22 @@ class Allocator {
// Allocate an allocation. Note the returned allocation might need to be freed // Allocate an allocation. Note the returned allocation might need to be freed
// manually if the Allocator is an `UnmanagedAllocator`. // manually if the Allocator is an `UnmanagedAllocator`.
virtual std::unique_ptr<Allocation> Allocate( virtual AllocationPtr Allocate(size_t size,
size_t size, Allocator::Attr attr = kDefault) = 0; Allocator::Attr attr = kDefault) = 0;
// True if the `Allocate` is thread safe. // True if the `Allocate` is thread safe.
virtual bool IsAllocThreadSafe() const; virtual bool IsAllocThreadSafe() const;
}; };
// Forward declaration: MannualFreeAllocation stores a pointer to its allocator
// before the allocator class itself is defined.
class MannualFreeAllocator;
// An Allocation that remembers the MannualFreeAllocator it was created by.
class MannualFreeAllocation : public Allocation {
public:
// Forwards `ptr`, `size` and `place` to the Allocation base and stores the
// raw `allocator` pointer.
MannualFreeAllocation(MannualFreeAllocator* allocator, void* ptr, size_t size,
platform::Place place)
: Allocation(ptr, size, place), allocator_(allocator) {}
// Defined out of line; the definition calls `allocator_->Free(this)`.
~MannualFreeAllocation();
private:
// Allocator recorded at construction time.
MannualFreeAllocator* allocator_;
};
// Users need to invoke `Free` or `FreeUniquePtr` manually if allocated by // Users need to invoke `Free` or `FreeUniquePtr` manually if allocated by
// a manually managed allocator. // a manually managed allocator.
class MannualFreeAllocator : public Allocator { class MannualFreeAllocator : public Allocator {
public: public:
std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) final; AllocationPtr Allocate(size_t size, Attr attr) final;
protected: protected:
virtual void Free(MannualFreeAllocation* allocation) = 0; virtual void Free(Allocation* allocation) = 0;
virtual MannualFreeAllocation* AllocateImpl(size_t size, virtual Allocation* AllocateImpl(size_t size, Allocator::Attr attr) = 0;
Allocator::Attr attr) = 0;
friend class MannualFreeAllocation; friend class MannualFreeAllocation;
}; };
......
...@@ -49,7 +49,7 @@ class CPUManagedAllocator : public Allocator { ...@@ -49,7 +49,7 @@ class CPUManagedAllocator : public Allocator {
public: public:
CPUManagedAllocator() : normal_allocator_(new CPUAllocator()) {} CPUManagedAllocator() : normal_allocator_(new CPUAllocator()) {}
std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) override { AllocationPtr Allocate(size_t size, Attr attr) override {
return normal_allocator_->Allocate(size, attr); return normal_allocator_->Allocate(size, attr);
} }
...@@ -103,7 +103,7 @@ class ChunkedManagedAllocator : public Allocator { ...@@ -103,7 +103,7 @@ class ChunkedManagedAllocator : public Allocator {
raw_allocator_.reset(); raw_allocator_.reset();
} }
std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) override { AllocationPtr Allocate(size_t size, Attr attr) override {
return default_allocator_->Allocate(size, attr); return default_allocator_->Allocate(size, attr);
} }
...@@ -131,7 +131,7 @@ class ChunkedManagedAllocator : public Allocator { ...@@ -131,7 +131,7 @@ class ChunkedManagedAllocator : public Allocator {
protected: protected:
size_t max_chunk_size_; size_t max_chunk_size_;
int64_t retry_time_; int64_t retry_time_;
std::vector<std::unique_ptr<Allocation>> chunks_; std::vector<AllocationPtr> chunks_;
std::shared_ptr<Allocator> raw_allocator_; std::shared_ptr<Allocator> raw_allocator_;
std::shared_ptr<Allocator> default_allocator_; std::shared_ptr<Allocator> default_allocator_;
}; };
...@@ -236,11 +236,11 @@ AllocatorFacade& AllocatorFacade::Instance() { ...@@ -236,11 +236,11 @@ AllocatorFacade& AllocatorFacade::Instance() {
std::shared_ptr<Allocation> AllocatorFacade::AllocShared( std::shared_ptr<Allocation> AllocatorFacade::AllocShared(
const platform::Place& place, size_t size, Allocator::Attr attr) { const platform::Place& place, size_t size, Allocator::Attr attr) {
return std::shared_ptr<Allocation>( return std::shared_ptr<Allocation>(
m_->allocators_.at(place)->Allocate(size, attr).release()); m_->allocators_.at(place)->Allocate(size, attr).release(),
AllocationDeleter());
} }
std::unique_ptr<Allocation> AllocatorFacade::Alloc(const platform::Place& place, AllocationPtr AllocatorFacade::Alloc(const platform::Place& place, size_t size,
size_t size,
Allocator::Attr attr) { Allocator::Attr attr) {
return m_->allocators_.at(place)->Allocate(size, attr); return m_->allocators_.at(place)->Allocate(size, attr);
} }
......
...@@ -43,7 +43,7 @@ class AllocatorFacade { ...@@ -43,7 +43,7 @@ class AllocatorFacade {
Allocator::Attr attr = Allocator::kDefault); Allocator::Attr attr = Allocator::kDefault);
// Allocate a unique allocation. // Allocate a unique allocation.
std::unique_ptr<Allocation> Alloc(const platform::Place& place, size_t size, AllocationPtr Alloc(const platform::Place& place, size_t size,
Allocator::Attr attr = Allocator::kDefault); Allocator::Attr attr = Allocator::kDefault);
// TODO(yy): Allocate a Copy-On-Write allocation? // TODO(yy): Allocate a Copy-On-Write allocation?
......
...@@ -18,8 +18,8 @@ namespace paddle { ...@@ -18,8 +18,8 @@ namespace paddle {
namespace memory { namespace memory {
namespace allocation { namespace allocation {
std::unique_ptr<Allocation> AutoIncrementAllocator::Allocate( AllocationPtr AutoIncrementAllocator::Allocate(size_t size,
size_t size, Allocator::Attr attr) { Allocator::Attr attr) {
auto cur = prev_success_allocator_.load(); auto cur = prev_success_allocator_.load();
size_t retry_count = allocator_num_.load(); size_t retry_count = allocator_num_.load();
size_t allocator_num = retry_count; size_t allocator_num = retry_count;
......
...@@ -54,7 +54,7 @@ class AutoIncrementAllocator : public Allocator { ...@@ -54,7 +54,7 @@ class AutoIncrementAllocator : public Allocator {
explicit AutoIncrementAllocator(AllocatorCreator&& creator, size_t capacity) explicit AutoIncrementAllocator(AllocatorCreator&& creator, size_t capacity)
: creator_(std::move(creator)), underlying_allocators_(capacity) {} : creator_(std::move(creator)), underlying_allocators_(capacity) {}
std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) override; AllocationPtr Allocate(size_t size, Attr attr) override;
bool IsAllocThreadSafe() const override; bool IsAllocThreadSafe() const override;
......
...@@ -109,7 +109,7 @@ size_t BestFitAllocator::NumFreeChunks() const { ...@@ -109,7 +109,7 @@ size_t BestFitAllocator::NumFreeChunks() const {
} }
return num; return num;
} }
void BestFitAllocator::Free(MannualFreeAllocation* allocation) { void BestFitAllocator::Free(Allocation* allocation) {
auto* bf_allocation = dynamic_cast<BestFitAllocation*>(allocation); auto* bf_allocation = dynamic_cast<BestFitAllocation*>(allocation);
auto chunk_it = bf_allocation->ChunkIterator(); auto chunk_it = bf_allocation->ChunkIterator();
PADDLE_ENFORCE(!chunk_it->is_free); PADDLE_ENFORCE(!chunk_it->is_free);
...@@ -136,9 +136,9 @@ void BestFitAllocator::Free(MannualFreeAllocation* allocation) { ...@@ -136,9 +136,9 @@ void BestFitAllocator::Free(MannualFreeAllocation* allocation) {
} }
InsertFreeNode(chunk_it); InsertFreeNode(chunk_it);
delete allocation;
} }
MannualFreeAllocation* BestFitAllocator::AllocateImpl(size_t size, Allocation* BestFitAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
Allocator::Attr attr) {
auto highest_set_bit = static_cast<size_t>(HighestBitPos(size)); auto highest_set_bit = static_cast<size_t>(HighestBitPos(size));
MapIt map_it; MapIt map_it;
for (; highest_set_bit < free_chunks_.size(); ++highest_set_bit) { for (; highest_set_bit < free_chunks_.size(); ++highest_set_bit) {
...@@ -158,8 +158,7 @@ MannualFreeAllocation* BestFitAllocator::AllocateImpl(size_t size, ...@@ -158,8 +158,7 @@ MannualFreeAllocation* BestFitAllocator::AllocateImpl(size_t size,
BestFitAllocation::BestFitAllocation( BestFitAllocation::BestFitAllocation(
paddle::memory::allocation::BestFitAllocator* allocator, paddle::memory::allocation::BestFitAllocator* allocator,
typename details::ChunkList::iterator chunk_it) typename details::ChunkList::iterator chunk_it)
: MannualFreeAllocation( : Allocation(reinterpret_cast<void*>(
allocator, reinterpret_cast<void*>(
reinterpret_cast<uintptr_t>(allocator->BasePtr()) + reinterpret_cast<uintptr_t>(allocator->BasePtr()) +
chunk_it->offset_), chunk_it->offset_),
chunk_it->size_, allocator->Place()), chunk_it->size_, allocator->Place()),
......
...@@ -71,7 +71,7 @@ using FreeChunkBin = ...@@ -71,7 +71,7 @@ using FreeChunkBin =
class BestFitAllocator; class BestFitAllocator;
// The BestFitAllocation maintain the List Node iterator. // The BestFitAllocation maintain the List Node iterator.
class BestFitAllocation : public MannualFreeAllocation { class BestFitAllocation : public Allocation {
private: private:
using ListIt = typename details::ChunkList::iterator; using ListIt = typename details::ChunkList::iterator;
...@@ -123,9 +123,8 @@ class BestFitAllocator : public MannualFreeAllocator { ...@@ -123,9 +123,8 @@ class BestFitAllocator : public MannualFreeAllocator {
void InsertFreeNode(const ListIt& it); void InsertFreeNode(const ListIt& it);
protected: protected:
void Free(MannualFreeAllocation* allocation) override; void Free(Allocation* allocation) override;
MannualFreeAllocation* AllocateImpl(size_t size, Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override;
Allocator::Attr attr) override;
private: private:
Allocation* allocation_; // not owned Allocation* allocation_; // not owned
......
...@@ -49,33 +49,28 @@ void BufferedAllocator::FreeCache(size_t size) { ...@@ -49,33 +49,28 @@ void BufferedAllocator::FreeCache(size_t size) {
bool BufferedAllocator::IsAllocThreadSafe() const { bool BufferedAllocator::IsAllocThreadSafe() const {
return this->underlying_allocator_->IsAllocThreadSafe(); return this->underlying_allocator_->IsAllocThreadSafe();
} }
void BufferedAllocator::Free(MannualFreeAllocation *allocation) { void BufferedAllocator::Free(Allocation *allocation) {
platform::LockGuardPtr<std::mutex> guard(mtx_); platform::LockGuardPtr<std::mutex> guard(mtx_);
allocations_.emplace(allocation->size(), AllocationPtr(allocation));
std::unique_ptr<Allocation> new_allocation(new UnderlyingManualAllocation(
this, std::move(reinterpret_cast<UnderlyingManualAllocation *>(allocation)
->allocation_)));
allocations_.emplace(allocation->size(), std::move(new_allocation));
} }
MannualFreeAllocation *BufferedAllocator::AllocateImpl(size_t size, Allocation *BufferedAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
Allocator::Attr attr) {
{ {
platform::LockGuardPtr<std::mutex> guard(mtx_); platform::LockGuardPtr<std::mutex> guard(mtx_);
auto it = allocations_.lower_bound(size); auto it = allocations_.lower_bound(size);
if (it != allocations_.end() && it->first < size * 2) { if (it != allocations_.end() && it->first < size * 2) {
std::unique_ptr<Allocation> result(std::move(it->second)); AllocationPtr result(std::move(it->second));
allocations_.erase(it); allocations_.erase(it);
return new UnderlyingManualAllocation(this, std::move(result)); return new UnderlyingManualAllocation(std::move(result));
} }
} }
try { try {
return new UnderlyingManualAllocation( return new UnderlyingManualAllocation(
this, underlying_allocator_->Allocate(size, attr)); underlying_allocator_->Allocate(size, attr));
} catch (BadAlloc &) { } catch (BadAlloc &) {
FreeCache(size); FreeCache(size);
return new UnderlyingManualAllocation( return new UnderlyingManualAllocation(
this, underlying_allocator_->Allocate(size, attr)); underlying_allocator_->Allocate(size, attr));
} }
} }
......
...@@ -50,13 +50,12 @@ class BufferedAllocator : public MannualFreeAllocator { ...@@ -50,13 +50,12 @@ class BufferedAllocator : public MannualFreeAllocator {
void FreeCache(size_t size); void FreeCache(size_t size);
protected: protected:
void Free(MannualFreeAllocation *allocation) override; void Free(Allocation *allocation) override;
MannualFreeAllocation *AllocateImpl(size_t size, Allocation *AllocateImpl(size_t size, Allocator::Attr attr) override;
Allocator::Attr attr) override;
private: private:
std::unique_ptr<Allocator> underlying_allocator_; std::unique_ptr<Allocator> underlying_allocator_;
std::multimap<size_t, std::unique_ptr<Allocation>> allocations_; std::multimap<size_t, AllocationPtr> allocations_;
std::unique_ptr<std::mutex> mtx_; std::unique_ptr<std::mutex> mtx_;
}; };
......
...@@ -24,8 +24,8 @@ ConditionalAllocator& ConditionalAllocator::AddAllocator( ...@@ -24,8 +24,8 @@ ConditionalAllocator& ConditionalAllocator::AddAllocator(
underlying_allocators_.emplace_back(std::move(func), std::move(allocator)); underlying_allocators_.emplace_back(std::move(func), std::move(allocator));
return *this; return *this;
} }
std::unique_ptr<Allocation> ConditionalAllocator::Allocate( AllocationPtr ConditionalAllocator::Allocate(size_t size,
size_t size, Allocator::Attr attr) { Allocator::Attr attr) {
for (auto& pair : underlying_allocators_) { for (auto& pair : underlying_allocators_) {
if (pair.first(size, attr)) { if (pair.first(size, attr)) {
return pair.second->Allocate(size, attr); return pair.second->Allocate(size, attr);
......
...@@ -45,7 +45,7 @@ class ConditionalAllocator : public Allocator { ...@@ -45,7 +45,7 @@ class ConditionalAllocator : public Allocator {
ConditionalAllocator& AddAllocator(std::function<bool(size_t, Attr)> func, ConditionalAllocator& AddAllocator(std::function<bool(size_t, Attr)> func,
std::shared_ptr<Allocator> allocator); std::shared_ptr<Allocator> allocator);
std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) override; AllocationPtr Allocate(size_t size, Attr attr) override;
bool IsAllocThreadSafe() const override; bool IsAllocThreadSafe() const override;
......
...@@ -20,26 +20,25 @@ namespace paddle { ...@@ -20,26 +20,25 @@ namespace paddle {
namespace memory { namespace memory {
namespace allocation { namespace allocation {
CPUAllocation::CPUAllocation( CPUAllocation::CPUAllocation(void *ptr, size_t size)
paddle::memory::allocation::CPUAllocator *allocator, void *ptr, size_t size) : Allocation(ptr, size, platform::CPUPlace()) {}
: MannualFreeAllocation(allocator, ptr, size, platform::CPUPlace()) {}
bool CPUAllocator::IsAllocThreadSafe() const { return true; } bool CPUAllocator::IsAllocThreadSafe() const { return true; }
void CPUAllocator::Free(MannualFreeAllocation *allocation) { void CPUAllocator::Free(Allocation *allocation) {
PADDLE_ENFORCE_NOT_NULL(dynamic_cast<CPUAllocation *>(allocation)); PADDLE_ENFORCE_NOT_NULL(dynamic_cast<CPUAllocation *>(allocation));
free(allocation->ptr()); free(allocation->ptr());
delete allocation;
} }
MannualFreeAllocation *CPUAllocator::AllocateImpl(size_t size, Allocation *CPUAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
Allocator::Attr attr) {
void *ptr; void *ptr;
auto status = posix_memalign(&ptr, kAlignment, size); auto status = posix_memalign(&ptr, kAlignment, size);
if (UNLIKELY(status) != 0) { if (UNLIKELY(status) != 0) {
throw BadAlloc(string::Sprintf("Cannot allocate cpu memory %d. Errno is %d", throw BadAlloc(string::Sprintf("Cannot allocate cpu memory %d. Errno is %d",
size, status)); size, status));
} }
return new CPUAllocation(this, ptr, size); return new CPUAllocation(ptr, size);
} }
} // namespace allocation } // namespace allocation
} // namespace memory } // namespace memory
......
...@@ -26,9 +26,9 @@ namespace allocation { ...@@ -26,9 +26,9 @@ namespace allocation {
// NOTE(yy): There is no need to use `BestFitAllocator` on CPU. We can import // NOTE(yy): There is no need to use `BestFitAllocator` on CPU. We can import
// an open-source allocator into Paddle. // an open-source allocator into Paddle.
class CPUAllocator; class CPUAllocator;
class CPUAllocation : public MannualFreeAllocation { class CPUAllocation : public Allocation {
public: public:
CPUAllocation(CPUAllocator* allocator, void* ptr, size_t size); CPUAllocation(void* ptr, size_t size);
}; };
class CPUAllocator : public MannualFreeAllocator { class CPUAllocator : public MannualFreeAllocator {
...@@ -37,9 +37,8 @@ class CPUAllocator : public MannualFreeAllocator { ...@@ -37,9 +37,8 @@ class CPUAllocator : public MannualFreeAllocator {
bool IsAllocThreadSafe() const override; bool IsAllocThreadSafe() const override;
protected: protected:
void Free(MannualFreeAllocation* allocation) override; void Free(Allocation* allocation) override;
MannualFreeAllocation* AllocateImpl(size_t size, Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override;
Allocator::Attr attr) override;
}; };
} // namespace allocation } // namespace allocation
} // namespace memory } // namespace memory
......
...@@ -22,7 +22,17 @@ ...@@ -22,7 +22,17 @@
namespace paddle { namespace paddle {
namespace memory { namespace memory {
namespace allocation { namespace allocation {
std::unique_ptr<Allocation> CUDAAllocator::Allocate(size_t size, Attr attr) { bool CUDAAllocator::IsAllocThreadSafe() const { return true; }
// Releases `allocation`: validates it, frees its memory with `cudaFree`, and
// finally deletes the Allocation object itself.
void CUDAAllocator::Free(Allocation* allocation) {
// NOTE(review): presumably makes `place_.device` the current device for the
// duration of this scope -- confirm against platform::CUDADeviceGuard.
platform::CUDADeviceGuard guard(place_.device);
// The allocation must be a CUDAAllocation ...
auto* cuda_allocation = dynamic_cast<CUDAAllocation*>(allocation);
PADDLE_ENFORCE_NOT_NULL(cuda_allocation);
// ... and its place must equal this allocator's CUDAPlace.
PADDLE_ENFORCE_EQ(boost::get<platform::CUDAPlace>(cuda_allocation->place()),
place_);
PADDLE_ENFORCE(cudaFree(allocation->ptr()));
// The Allocation object is destroyed here, after its memory was freed.
delete allocation;
}
Allocation* CUDAAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
platform::CUDADeviceGuard guard(place_.device); platform::CUDADeviceGuard guard(place_.device);
void* ptr; void* ptr;
auto status = cudaMalloc(&ptr, size); auto status = cudaMalloc(&ptr, size);
...@@ -31,19 +41,8 @@ std::unique_ptr<Allocation> CUDAAllocator::Allocate(size_t size, Attr attr) { ...@@ -31,19 +41,8 @@ std::unique_ptr<Allocation> CUDAAllocator::Allocate(size_t size, Attr attr) {
"Cannot allocate %d on GPU %d, cuda status %d, %s", size, place_.device, "Cannot allocate %d on GPU %d, cuda status %d, %s", size, place_.device,
status, cudaGetErrorString(status))); status, cudaGetErrorString(status)));
} }
return std::unique_ptr<Allocation>( return new CUDAAllocation(ptr, size, platform::Place(place_));
new CUDAAllocation(ptr, size, platform::Place(place_)));
} }
// Takes ownership of `allocation`, validates it and frees its memory with
// `cudaFree`. The Allocation object itself is destroyed by the by-value
// `std::unique_ptr` parameter when the function returns.
void CUDAAllocator::FreeUniquePtr(std::unique_ptr<Allocation> allocation) {
// NOTE(review): presumably makes `place_.device` the current device for the
// duration of this scope -- confirm against platform::CUDADeviceGuard.
platform::CUDADeviceGuard guard(place_.device);
// The allocation must be a CUDAAllocation ...
auto* cuda_allocation = dynamic_cast<CUDAAllocation*>(allocation.get());
PADDLE_ENFORCE_NOT_NULL(cuda_allocation);
// ... and its place must equal this allocator's CUDAPlace.
PADDLE_ENFORCE_EQ(boost::get<platform::CUDAPlace>(cuda_allocation->place()),
place_);
PADDLE_ENFORCE(cudaFree(allocation->ptr()));
}
bool CUDAAllocator::IsAllocThreadSafe() const { return true; }
} // namespace allocation } // namespace allocation
} // namespace memory } // namespace memory
} // namespace paddle } // namespace paddle
...@@ -27,16 +27,17 @@ class CUDAAllocation : public Allocation { ...@@ -27,16 +27,17 @@ class CUDAAllocation : public Allocation {
using Allocation::Allocation; using Allocation::Allocation;
}; };
class CUDAAllocator : public UnmanagedAllocator { class CUDAAllocator : public MannualFreeAllocator {
public: public:
explicit CUDAAllocator(const platform::CUDAPlace& place) : place_(place) {} explicit CUDAAllocator(const platform::CUDAPlace& place) : place_(place) {}
explicit CUDAAllocator(const platform::Place& place) explicit CUDAAllocator(const platform::Place& place)
: place_(boost::get<platform::CUDAPlace>(place)) {} : place_(boost::get<platform::CUDAPlace>(place)) {}
std::unique_ptr<Allocation> Allocate(size_t size,
Attr attr = kDefault) override;
void FreeUniquePtr(std::unique_ptr<Allocation> allocation) override;
bool IsAllocThreadSafe() const override; bool IsAllocThreadSafe() const override;
protected:
void Free(Allocation* allocation) override;
Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override;
private: private:
platform::CUDAPlace place_; platform::CUDAPlace place_;
}; };
......
...@@ -30,16 +30,18 @@ LockedAllocator::LockedAllocator( ...@@ -30,16 +30,18 @@ LockedAllocator::LockedAllocator(
mtx_.reset(new std::mutex()); mtx_.reset(new std::mutex());
} }
} }
void LockedAllocator::Free(MannualFreeAllocation *allocation) { void LockedAllocator::Free(Allocation *allocation) {
{
platform::LockGuardPtr<std::mutex> guard(mtx_); platform::LockGuardPtr<std::mutex> guard(mtx_);
reinterpret_cast<UnderlyingManualAllocation *>(allocation) reinterpret_cast<UnderlyingManualAllocation *>(allocation)
->allocation_.reset(); ->allocation_.reset(); // Destroy inner allocation
}
delete allocation;
} }
MannualFreeAllocation *LockedAllocator::AllocateImpl(size_t size, Allocation *LockedAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
Allocator::Attr attr) {
platform::LockGuardPtr<std::mutex> guard(mtx_); platform::LockGuardPtr<std::mutex> guard(mtx_);
return new UnderlyingManualAllocation( return new UnderlyingManualAllocation(
this, underlying_allocator_->Allocate(size, attr)); underlying_allocator_->Allocate(size, attr));
} }
} // namespace allocation } // namespace allocation
} // namespace memory } // namespace memory
......
...@@ -28,9 +28,8 @@ class LockedAllocator : public MannualFreeAllocator { ...@@ -28,9 +28,8 @@ class LockedAllocator : public MannualFreeAllocator {
bool IsAllocThreadSafe() const override; bool IsAllocThreadSafe() const override;
protected: protected:
void Free(MannualFreeAllocation *allocation) override; void Free(Allocation *allocation) override;
MannualFreeAllocation *AllocateImpl(size_t size, Allocation *AllocateImpl(size_t size, Allocator::Attr attr) override;
Allocator::Attr attr) override;
private: private:
std::unique_ptr<Allocator> underlying_allocator_; std::unique_ptr<Allocator> underlying_allocator_;
......
...@@ -19,25 +19,22 @@ ...@@ -19,25 +19,22 @@
namespace paddle { namespace paddle {
namespace memory { namespace memory {
namespace allocation { namespace allocation {
bool CPUPinnedAllocator::IsAllocThreadSafe() const { return true; }
std::unique_ptr<Allocation> CPUPinnedAllocator::Allocate(size_t size, void CPUPinnedAllocator::Free(Allocation *allocation) {
PADDLE_ENFORCE_NOT_NULL(dynamic_cast<CPUPinnedAllocation *>(allocation));
PADDLE_ENFORCE(cudaFreeHost(allocation->ptr()));
delete allocation;
}
Allocation *CPUPinnedAllocator::AllocateImpl(size_t size,
Allocator::Attr attr) { Allocator::Attr attr) {
// PADDLE_ENFORCE_EQ( // PADDLE_ENFORCE_EQ(
// attr, kCrossDevice, // attr, kCrossDevice,
// "CPUPinnedAllocator should be used for Cross-Device Communication"); // "CPUPinnedAllocator should be used for Cross-Device Communication");
void* ptr; void *ptr;
PADDLE_ENFORCE(cudaMallocHost(&ptr, size)); PADDLE_ENFORCE(cudaMallocHost(&ptr, size));
return std::unique_ptr<CPUPinnedAllocation>( return new CPUPinnedAllocation(ptr, size);
new CPUPinnedAllocation(ptr, size));
} }
// Takes ownership of `allocation`, checks that it is a CPUPinnedAllocation and
// frees its memory with `cudaFreeHost`. The Allocation object itself is
// destroyed by the by-value `std::unique_ptr` parameter on return.
void CPUPinnedAllocator::FreeUniquePtr(std::unique_ptr<Allocation> allocation) {
PADDLE_ENFORCE_NOT_NULL(dynamic_cast<CPUPinnedAllocation*>(allocation.get()));
PADDLE_ENFORCE(cudaFreeHost(allocation->ptr()));
}
bool CPUPinnedAllocator::IsAllocThreadSafe() const { return true; }
} // namespace allocation } // namespace allocation
} // namespace memory } // namespace memory
} // namespace paddle } // namespace paddle
...@@ -22,15 +22,17 @@ namespace allocation { ...@@ -22,15 +22,17 @@ namespace allocation {
// Allocator uses `cudaMallocHost` // Allocator uses `cudaMallocHost`
class CPUPinnedAllocation : public Allocation { class CPUPinnedAllocation : public Allocation {
public: public:
CPUPinnedAllocation(void* ptr, size_t size) CPUPinnedAllocation(void *ptr, size_t size)
: Allocation(ptr, size, platform::CUDAPinnedPlace()) {} : Allocation(ptr, size, platform::CUDAPinnedPlace()) {}
}; };
class CPUPinnedAllocator : public UnmanagedAllocator { class CPUPinnedAllocator : public MannualFreeAllocator {
public: public:
std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) override;
void FreeUniquePtr(std::unique_ptr<Allocation> allocation) override;
bool IsAllocThreadSafe() const override; bool IsAllocThreadSafe() const override;
protected:
void Free(Allocation *allocation) override;
Allocation *AllocateImpl(size_t size, Allocator::Attr attr) override;
}; };
} // namespace allocation } // namespace allocation
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/memory/allocation/retry_allocator.h" #include "paddle/fluid/memory/allocation/retry_allocator.h"
#include "paddle/fluid/memory/allocation/underlying_manual_allocation.h"
namespace paddle { namespace paddle {
namespace memory { namespace memory {
namespace allocation { namespace allocation {
...@@ -22,21 +22,22 @@ bool RetryAllocator::IsAllocThreadSafe() const { ...@@ -22,21 +22,22 @@ bool RetryAllocator::IsAllocThreadSafe() const {
return underlying_allocator_->IsAllocThreadSafe(); return underlying_allocator_->IsAllocThreadSafe();
} }
void RetryAllocator::Free(MannualFreeAllocation* allocation) { void RetryAllocator::Free(Allocation* allocation) {
reinterpret_cast<RetryAllocation*>(allocation) // Delete underlying allocation first.
->underlying_allocation_.reset(); reinterpret_cast<UnderlyingManualAllocation*>(allocation)
->allocation_.reset();
{ {
// notify all waited allocators, they can try to allocate memory after free. // notify all waited allocators, they can try to allocate memory after free.
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
cv_.notify_all(); cv_.notify_all();
} }
delete allocation;
} }
MannualFreeAllocation* RetryAllocator::AllocateImpl(size_t size, Allocation* RetryAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
Allocator::Attr attr) {
auto alloc_func = [&, this]() { auto alloc_func = [&, this]() {
return new RetryAllocation(underlying_allocator_->Allocate(size, attr), return new UnderlyingManualAllocation(
this); underlying_allocator_->Allocate(size, attr));
}; };
// In fact, we can unify the code of allocation success and failure // In fact, we can unify the code of allocation success and failure
// But it would add lock even when allocation success at the first time // But it would add lock even when allocation success at the first time
......
...@@ -26,17 +26,6 @@ namespace allocation { ...@@ -26,17 +26,6 @@ namespace allocation {
class RetryAllocator; class RetryAllocator;
// Wraps an allocation obtained from another allocator: exposes the same
// ptr/size/place and keeps the wrapped allocation alive as a member.
class RetryAllocation : public MannualFreeAllocation {
public:
// Copies ptr/size/place from `underlying_allocation` into the base class and
// then moves the allocation into `underlying_allocation_`. The base is
// initialised before the member, so the reads happen before the move.
RetryAllocation(std::unique_ptr<Allocation>&& underlying_allocation,
MannualFreeAllocator* allocator)
: MannualFreeAllocation(allocator, underlying_allocation->ptr(),
underlying_allocation->size(),
underlying_allocation->place()),
underlying_allocation_(std::move(underlying_allocation)) {}
// The wrapped allocation; public so the allocator can reset it directly.
std::unique_ptr<Allocation> underlying_allocation_;
};
class RetryAllocator : public MannualFreeAllocator { class RetryAllocator : public MannualFreeAllocator {
public: public:
RetryAllocator(std::unique_ptr<Allocator>&& allocator, size_t retry_ms) RetryAllocator(std::unique_ptr<Allocator>&& allocator, size_t retry_ms)
...@@ -56,9 +45,8 @@ class RetryAllocator : public MannualFreeAllocator { ...@@ -56,9 +45,8 @@ class RetryAllocator : public MannualFreeAllocator {
} }
protected: protected:
void Free(MannualFreeAllocation* allocation) override; void Free(Allocation* allocation) override;
MannualFreeAllocation* AllocateImpl(size_t size, Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override;
Allocator::Attr attr) override;
private: private:
std::unique_ptr<Allocator> underlying_allocator_; std::unique_ptr<Allocator> underlying_allocator_;
......
...@@ -20,14 +20,12 @@ namespace paddle { ...@@ -20,14 +20,12 @@ namespace paddle {
namespace memory { namespace memory {
namespace allocation { namespace allocation {
// UnderlyingManualAllocation: an allocation object that takes ownership of
// another allocation. Its constructor forwards allocation->ptr(),
// allocation->size() and allocation->place() to the base-class constructor and
// then moves the owning pointer into the `allocation_` member.
// NOTE(review): every line of this span carries two renderings of the same
// source line side by side (one deriving from MannualFreeAllocation with an
// extra allocator argument, one deriving from Allocation with a single-argument
// explicit constructor); read each half separately.
class UnderlyingManualAllocation : public MannualFreeAllocation { class UnderlyingManualAllocation : public Allocation {
public: public:
UnderlyingManualAllocation(MannualFreeAllocator* allocator, explicit UnderlyingManualAllocation(AllocationPtr allocation)
std::unique_ptr<Allocation> allocation) : Allocation(allocation->ptr(), allocation->size(), allocation->place()),
: MannualFreeAllocation(allocator, allocation->ptr(), allocation->size(),
allocation->place()),
allocation_(std::move(allocation)) {} allocation_(std::move(allocation)) {}
std::unique_ptr<Allocation> allocation_; AllocationPtr allocation_;
}; };
} // namespace allocation } // namespace allocation
......
...@@ -18,10 +18,9 @@ namespace paddle { ...@@ -18,10 +18,9 @@ namespace paddle {
namespace memory { namespace memory {
namespace allocation { namespace allocation {
std::unique_ptr<Allocation> ZeroSizeAllocator::Allocate(size_t size, AllocationPtr ZeroSizeAllocator::Allocate(size_t size, Allocator::Attr attr) {
Allocator::Attr attr) {
if (size == 0) { if (size == 0) {
return std::unique_ptr<Allocation>(new ZeroSizeAllocation(place_)); return AllocationPtr(new ZeroSizeAllocation(place_));
} else { } else {
return underlying_allocator_->Allocate(size, attr); return underlying_allocator_->Allocate(size, attr);
} }
......
...@@ -34,7 +34,7 @@ class ZeroSizeAllocator : public Allocator { ...@@ -34,7 +34,7 @@ class ZeroSizeAllocator : public Allocator {
ZeroSizeAllocator(std::shared_ptr<Allocator> underlying_allocator, ZeroSizeAllocator(std::shared_ptr<Allocator> underlying_allocator,
const platform::Place& p) const platform::Place& p)
: underlying_allocator_(std::move(underlying_allocator)), place_(p) {} : underlying_allocator_(std::move(underlying_allocator)), place_(p) {}
std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) override; AllocationPtr Allocate(size_t size, Attr attr) override;
bool IsAllocThreadSafe() const override; bool IsAllocThreadSafe() const override;
......
...@@ -294,13 +294,12 @@ std::shared_ptr<Allocation> AllocShared(const platform::Place& place, ...@@ -294,13 +294,12 @@ std::shared_ptr<Allocation> AllocShared(const platform::Place& place,
} }
} }
std::unique_ptr<Allocation> Alloc(const platform::Place& place, size_t size, AllocationPtr Alloc(const platform::Place& place, size_t size,
Allocator::Attr attr) { Allocator::Attr attr) {
if (allocation::GetAllocatorStrategy() == if (allocation::GetAllocatorStrategy() ==
allocation::AllocatorStrategy::kLegacy) { allocation::AllocatorStrategy::kLegacy) {
void* p = boost::apply_visitor(legacy::AllocVisitor(size), place); void* p = boost::apply_visitor(legacy::AllocVisitor(size), place);
return std::unique_ptr<Allocation>( return AllocationPtr(new legacy::LegacyAllocation(p, size, place));
new legacy::LegacyAllocation(p, size, place));
} else { } else {
return allocation::AllocatorFacade::Instance().Alloc(place, size, attr); return allocation::AllocatorFacade::Instance().Alloc(place, size, attr);
} }
......
...@@ -21,13 +21,13 @@ namespace paddle { ...@@ -21,13 +21,13 @@ namespace paddle {
namespace memory { namespace memory {
using allocation::Allocation; using allocation::Allocation;
using allocation::Allocator; using allocation::Allocator;
using allocation::AllocationPtr;
extern std::shared_ptr<Allocation> AllocShared( extern std::shared_ptr<Allocation> AllocShared(
const platform::Place& place, size_t size, const platform::Place& place, size_t size,
Allocator::Attr attr = Allocator::kDefault); Allocator::Attr attr = Allocator::kDefault);
extern std::unique_ptr<Allocation> Alloc( extern AllocationPtr Alloc(const platform::Place& place, size_t size,
const platform::Place& place, size_t size,
Allocator::Attr attr = Allocator::kDefault); Allocator::Attr attr = Allocator::kDefault);
namespace legacy { namespace legacy {
......
...@@ -155,8 +155,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { ...@@ -155,8 +155,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
const cudaDeviceProp* device_prop_; // not owned; const cudaDeviceProp* device_prop_; // not owned;
mutable void* scratch_; mutable void* scratch_;
mutable unsigned int* semaphore_; mutable unsigned int* semaphore_;
mutable std::unordered_map<void*, std::unique_ptr<memory::Allocation>> mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
allocations_;
}; };
CudnnHolder::CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place) CudnnHolder::CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册