diff --git a/paddle/fluid/memory/allocation/auto_increment_allocator.h b/paddle/fluid/memory/allocation/auto_increment_allocator.h index 116d4ca6892fdceba0dda6da8a5c6039ac24ebb5..650f1d1cc6c268185924bd77e145beda40772dd8 100644 --- a/paddle/fluid/memory/allocation/auto_increment_allocator.h +++ b/paddle/fluid/memory/allocation/auto_increment_allocator.h @@ -14,6 +14,7 @@ #pragma once +#include // NOLINT #include #include #include // NOLINT @@ -55,44 +56,61 @@ class AutoIncrementAllocator : public ManagedAllocator { template inline typename std::result_of::type InvokeOrCreateUnderlyingAllocator(Callback callback) { - size_t retry_count = underlying_allocators_.size(); - auto cur = prev_success_allocator_; + std::shared_ptr> + underlying_allocators = underlying_allocators_; + size_t retry_count = underlying_allocators->size(); + size_t allocator_num = retry_count; + auto cur = prev_success_allocator_.load(); while (retry_count-- > 0) { // until there retry count is zero try { - auto res = callback(*underlying_allocators_[cur]); - { - std::lock_guard guard(mtx_); - prev_success_allocator_ = cur; - } + auto res = callback(*((*underlying_allocators)[cur])); + prev_success_allocator_.store(cur); return std::move(res); } catch (BadAlloc&) { - ++cur; - if (cur >= underlying_allocators_.size()) { + if (++cur >= allocator_num) { cur = 0; } } catch (...) { // if there is another type of allocation, just rethrow it. - throw; + std::rethrow_exception(std::current_exception()); } } // No suitable allocator + + ManagedAllocator* new_allocator; { std::lock_guard guard(mtx_); - underlying_allocators_.emplace_back(creator_()); - prev_success_allocator_ = underlying_allocators_.size() - 1; - PADDLE_ENFORCE( - underlying_allocators_[prev_success_allocator_]->IsAllocThreadSafe(), - "the underlying allocator must be thread safe. This is a program " - "bug."); + auto old_size = underlying_allocators_->size(); + decltype(underlying_allocators_) new_allocators( + new std::vector(old_size + 1)); + for (size_t i = 0; i < old_size; ++i) { + (*new_allocators)[i] = (*underlying_allocators_)[i]; + } - return callback(*underlying_allocators_[prev_success_allocator_]); + (*new_allocators)[old_size] = creator_(); + new_allocator = (*new_allocators)[old_size].get(); + underlying_allocators_ = new_allocators; + prev_success_allocator_.store(old_size); } + + PADDLE_ENFORCE( + new_allocator->IsAllocThreadSafe(), + "the underlying allocator must be thread safe. This is a program " + "bug."); + return callback(*new_allocator); } AllocatorCreator creator_; - std::vector underlying_allocators_; - size_t prev_success_allocator_{0}; - std::mutex mtx_; // NOLINT + + // Use std::shared_ptr to ensure thread-safety + std::shared_ptr> + underlying_allocators_; + + // Use std::atomic rather than std::mutex, since std::atomic is usually + // lock-free + std::atomic prev_success_allocator_{0}; + + std::mutex mtx_; }; } // namespace allocation } // namespace memory