提交 ab87a882 编写于 作者: Y Yu Yang

Polish retry allocator

上级 2002e71d
......@@ -20,67 +20,67 @@ namespace allocation {
RetryAllocation::~RetryAllocation() {
auto allocator = retry_allocator_.lock();
{
// release allocation first
// Allocator is destroyed before allocation. Should not happened usually.
if (UNLIKELY(allocator == nullptr)) return;
allocator->underlying_allocator_->Free(underlying_allocation_.release());
}
{
// notify all waited allocators
std::lock_guard<std::mutex> lock(allocator->mutex_);
allocator->cv_.notify_all();
}
allocator->FreeUnderlyingAllocation(std::move(underlying_allocation_));
}
bool RetryAllocator::IsAllocThreadSafe() const { return true; }
std::shared_ptr<Allocation> RetryAllocator::AllocateShared(
size_t size, Allocator::Attr attr) {
return std::shared_ptr<Allocation>(Allocate(size, attr));
return std::shared_ptr<Allocation>(AllocateImpl(size, attr));
}
std::unique_ptr<Allocation> RetryAllocator::Allocate(size_t size,
Allocator::Attr attr) {
return std::unique_ptr<Allocation>(AllocateImpl(size, attr));
}
Allocation* RetryAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
auto alloc_func = [&, this]() {
return new RetryAllocation(underlying_allocator_->Allocate(size, attr),
this->shared_from_this());
};
// In fact, we can unify the code of allocation success and failure
// But it would add lock even when allocation success at the first time
std::unique_ptr<Allocation> ret;
try {
ret.reset(alloc_func());
} catch (BadAlloc &) {
return alloc_func();
} catch (BadAlloc& bad_alloc) {
{
// We can just write allocation retry inside the predicate function of
// wait_until
// But it needs to acquire the lock when executing predicate function
// For better performance, we use loop here
std::exception_ptr ex;
auto end_time = std::chrono::high_resolution_clock::now() + retry_time_;
std::cv_status status;
do {
{
auto wait_until = [&, this] {
std::unique_lock<std::mutex> lock(mutex_);
status = cv_.wait_until(lock, end_time);
}
return cv_.wait_until(lock, end_time);
};
while (wait_until() != std::cv_status::timeout) {
try {
ret.reset(alloc_func());
} catch (BadAlloc &) {
ex = std::current_exception();
return alloc_func();
} catch (BadAlloc& ex) {
bad_alloc = ex;
} catch (...) {
std::rethrow_exception(std::current_exception());
throw;
}
}
} while (ret == nullptr && status != std::cv_status::timeout);
if (ret == nullptr) std::rethrow_exception(ex);
throw; // rethrow the original exception or throw the internal bad_alloc
}
} catch (...) {
std::rethrow_exception(std::current_exception());
throw;
}
}
void RetryAllocator::FreeUnderlyingAllocation(
std::unique_ptr<Allocation>&& allocation) {
underlying_allocator_->Free(allocation.get());
{
// notify all waited allocators, they can try to allocate memory after free.
std::lock_guard<std::mutex> lock(mutex_);
cv_.notify_all();
}
return ret;
}
} // namespace allocation
......
......@@ -35,7 +35,7 @@ class RetryAllocation : public Allocation {
underlying_allocation_(std::move(underlying_allocation)),
retry_allocator_(retry_allocator) {}
~RetryAllocation();
~RetryAllocation() final;
private:
std::unique_ptr<Allocation> underlying_allocation_;
......@@ -61,13 +61,17 @@ class RetryAllocator : public ManagedAllocator,
bool IsAllocThreadSafe() const override;
std::unique_ptr<Allocation> Allocate(
size_t size, Allocator::Attr attr = kDefault) override;
std::unique_ptr<Allocation> Allocate(size_t size,
Allocator::Attr attr) override;
std::shared_ptr<Allocation> AllocateShared(
size_t size, Allocator::Attr attr = kDefault) override;
std::shared_ptr<Allocation> AllocateShared(size_t size,
Allocator::Attr attr) override;
void FreeUnderlyingAllocation(std::unique_ptr<Allocation>&& allocation);
private:
Allocation* AllocateImpl(size_t size, Allocator::Attr attr);
void EnforceCheck() {
PADDLE_ENFORCE_NOT_NULL(
underlying_allocator_.get(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册