提交 a5cf565c 编写于 作者: S sneaxiy

fix auto_increment_allocator thread-safety bug

上级 bb04b54e
......@@ -14,6 +14,7 @@
#pragma once
#include <atomic> // NOLINT
#include <functional>
#include <memory>
#include <thread> // NOLINT
......@@ -55,44 +56,61 @@ class AutoIncrementAllocator : public ManagedAllocator {
template <typename Callback>
inline typename std::result_of<Callback(ManagedAllocator&)>::type
InvokeOrCreateUnderlyingAllocator(Callback callback) {
size_t retry_count = underlying_allocators_.size();
auto cur = prev_success_allocator_;
std::shared_ptr<std::vector<AllocatorCreator::result_type>>
underlying_allocators = underlying_allocators_;
size_t retry_count = underlying_allocators->size();
size_t allocator_num = retry_count;
auto cur = prev_success_allocator_.load();
while (retry_count-- > 0) { // until there retry count is zero
try {
auto res = callback(*underlying_allocators_[cur]);
{
std::lock_guard<std::mutex> guard(mtx_);
prev_success_allocator_ = cur;
}
auto res = callback(*((*underlying_allocators)[cur]));
prev_success_allocator_.store(cur);
return std::move(res);
} catch (BadAlloc&) {
++cur;
if (cur >= underlying_allocators_.size()) {
if (++cur >= allocator_num) {
cur = 0;
}
} catch (...) {
// if there is another type of allocation, just rethrow it.
throw;
std::rethrow_exception(std::current_exception());
}
}
// No suitable allocator
ManagedAllocator* new_allocator;
{
std::lock_guard<std::mutex> guard(mtx_);
underlying_allocators_.emplace_back(creator_());
prev_success_allocator_ = underlying_allocators_.size() - 1;
PADDLE_ENFORCE(
underlying_allocators_[prev_success_allocator_]->IsAllocThreadSafe(),
"the underlying allocator must be thread safe. This is a program "
"bug.");
auto old_size = underlying_allocators_->size();
decltype(underlying_allocators_) new_allocators(
new std::vector<AllocatorCreator::result_type>(old_size + 1));
for (size_t i = 0; i < old_size; ++i) {
(*new_allocators)[i] = (*underlying_allocators_)[i];
}
return callback(*underlying_allocators_[prev_success_allocator_]);
(*new_allocators)[old_size] = creator_();
new_allocator = (*new_allocators)[old_size].get();
underlying_allocators_ = new_allocators;
prev_success_allocator_.store(old_size);
}
PADDLE_ENFORCE(
new_allocator->IsAllocThreadSafe(),
"the underlying allocator must be thread safe. This is a program "
"bug.");
return callback(*new_allocator);
}
AllocatorCreator creator_;
std::vector<AllocatorCreator::result_type> underlying_allocators_;
size_t prev_success_allocator_{0};
std::mutex mtx_; // NOLINT
// Use std::shared_ptr to ensure thread-safety
std::shared_ptr<std::vector<AllocatorCreator::result_type>>
underlying_allocators_;
// Use std::atomic rather than std::mutex, since std::atomic is usually
// lock-free
std::atomic<size_t> prev_success_allocator_{0};
std::mutex mtx_;
};
} // namespace allocation
} // namespace memory
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册