提交 a5cf565c 编写于 作者: S sneaxiy

fix auto_increment_allocator thread-safety bug

上级 bb04b54e
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <atomic> // NOLINT
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <thread> // NOLINT #include <thread> // NOLINT
...@@ -55,44 +56,61 @@ class AutoIncrementAllocator : public ManagedAllocator { ...@@ -55,44 +56,61 @@ class AutoIncrementAllocator : public ManagedAllocator {
template <typename Callback> template <typename Callback>
inline typename std::result_of<Callback(ManagedAllocator&)>::type inline typename std::result_of<Callback(ManagedAllocator&)>::type
InvokeOrCreateUnderlyingAllocator(Callback callback) { InvokeOrCreateUnderlyingAllocator(Callback callback) {
size_t retry_count = underlying_allocators_.size(); std::shared_ptr<std::vector<AllocatorCreator::result_type>>
auto cur = prev_success_allocator_; underlying_allocators = underlying_allocators_;
size_t retry_count = underlying_allocators->size();
size_t allocator_num = retry_count;
auto cur = prev_success_allocator_.load();
while (retry_count-- > 0) { // until there retry count is zero while (retry_count-- > 0) { // until there retry count is zero
try { try {
auto res = callback(*underlying_allocators_[cur]); auto res = callback(*((*underlying_allocators)[cur]));
{ prev_success_allocator_.store(cur);
std::lock_guard<std::mutex> guard(mtx_);
prev_success_allocator_ = cur;
}
return std::move(res); return std::move(res);
} catch (BadAlloc&) { } catch (BadAlloc&) {
++cur; if (++cur >= allocator_num) {
if (cur >= underlying_allocators_.size()) {
cur = 0; cur = 0;
} }
} catch (...) { } catch (...) {
// if there is another type of allocation, just rethrow it. // if there is another type of allocation, just rethrow it.
throw; std::rethrow_exception(std::current_exception());
} }
} }
// No suitable allocator // No suitable allocator
ManagedAllocator* new_allocator;
{ {
std::lock_guard<std::mutex> guard(mtx_); std::lock_guard<std::mutex> guard(mtx_);
underlying_allocators_.emplace_back(creator_()); auto old_size = underlying_allocators_->size();
prev_success_allocator_ = underlying_allocators_.size() - 1; decltype(underlying_allocators_) new_allocators(
new std::vector<AllocatorCreator::result_type>(old_size + 1));
for (size_t i = 0; i < old_size; ++i) {
(*new_allocators)[i] = (*underlying_allocators_)[i];
}
(*new_allocators)[old_size] = creator_();
new_allocator = (*new_allocators)[old_size].get();
underlying_allocators_ = new_allocators;
prev_success_allocator_.store(old_size);
}
PADDLE_ENFORCE( PADDLE_ENFORCE(
underlying_allocators_[prev_success_allocator_]->IsAllocThreadSafe(), new_allocator->IsAllocThreadSafe(),
"the underlying allocator must be thread safe. This is a program " "the underlying allocator must be thread safe. This is a program "
"bug."); "bug.");
return callback(*new_allocator);
return callback(*underlying_allocators_[prev_success_allocator_]);
}
} }
AllocatorCreator creator_; AllocatorCreator creator_;
std::vector<AllocatorCreator::result_type> underlying_allocators_;
size_t prev_success_allocator_{0}; // Use std::shared_ptr to ensure thread-safety
std::mutex mtx_; // NOLINT std::shared_ptr<std::vector<AllocatorCreator::result_type>>
underlying_allocators_;
// Use std::atomic rather than std::mutex, since std::atomic is usually
// lock-free
std::atomic<size_t> prev_success_allocator_{0};
std::mutex mtx_;
}; };
} // namespace allocation } // namespace allocation
} // namespace memory } // namespace memory
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册