提交 bb04b54e 编写于 作者: S sneaxiy

add retry_allocator

add unittest of retry_allocator
上级 ea61e4ef
...@@ -4,6 +4,8 @@ cc_library(best_fit_allocator SRCS best_fit_allocator.cc DEPS allocator) ...@@ -4,6 +4,8 @@ cc_library(best_fit_allocator SRCS best_fit_allocator.cc DEPS allocator)
cc_library(locked_allocator SRCS locked_allocator.cc DEPS allocator) cc_library(locked_allocator SRCS locked_allocator.cc DEPS allocator)
nv_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard) nv_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard)
cc_library(retry_allocator SRCS retry_allocator.cc DEPS allocator)
if (WITH_GPU) if (WITH_GPU)
nv_test(best_fit_allocator_test nv_test(best_fit_allocator_test
SRCS best_fit_allocator_test.cc SRCS best_fit_allocator_test.cc
...@@ -49,3 +51,5 @@ cc_library(allocator_facade SRCS allocator_facade.cc DEPS ...@@ -49,3 +51,5 @@ cc_library(allocator_facade SRCS allocator_facade.cc DEPS
cuda_device_guard) cuda_device_guard)
nv_test(allocation_and_eigen_test SRCS allocation_and_eigen_test.cu DEPS allocator_facade) nv_test(allocation_and_eigen_test SRCS allocation_and_eigen_test.cu DEPS allocator_facade)
cc_test(retry_allocator_test SRCS retry_allocator_test.cc DEPS retry_allocator naive_managed_allocator best_fit_allocator locked_allocator cpu_allocator)
...@@ -29,6 +29,9 @@ namespace allocation { ...@@ -29,6 +29,9 @@ namespace allocation {
// NOTE(yy): kAlignment must be 2^N. a `static_assert` should be added. // NOTE(yy): kAlignment must be 2^N. a `static_assert` should be added.
template <size_t kAlignment> template <size_t kAlignment>
class AlignedAllocation : public Allocation { class AlignedAllocation : public Allocation {
static_assert(kAlignment > 0 && (kAlignment & (kAlignment - 1)) == 0,
"kAlignment must be 2^N");
public: public:
AlignedAllocation(std::unique_ptr<Allocation>&& underlying_allocation, AlignedAllocation(std::unique_ptr<Allocation>&& underlying_allocation,
size_t size) size_t size)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/retry_allocator.h"
namespace paddle {
namespace memory {
namespace allocation {
RetryAllocation::~RetryAllocation() {
auto allocator = retry_allocator_.lock();
{
// release allocation first
if (UNLIKELY(allocator == nullptr)) return;
allocator->underlying_allocator_->Free(underlying_allocation_.release());
}
{
// notify all waited allocators
std::lock_guard<std::mutex> lock(allocator->mutex_);
allocator->cv_.notify_all();
}
}
bool RetryAllocator::IsAllocThreadSafe() const { return true; }
std::shared_ptr<Allocation> RetryAllocator::AllocateShared(
size_t size, Allocator::Attr attr) {
return std::shared_ptr<Allocation>(Allocate(size, attr));
}
std::unique_ptr<Allocation> RetryAllocator::Allocate(size_t size,
Allocator::Attr attr) {
auto alloc_func = [&, this]() {
return new RetryAllocation(underlying_allocator_->Allocate(size, attr),
this->shared_from_this());
};
// In fact, we can unify the code of allocation success and failure
// But it would add lock even when allocation success at the first time
std::unique_ptr<Allocation> ret;
try {
ret.reset(alloc_func());
} catch (BadAlloc &) {
{
// We can just write allocation retry inside the predicate function of
// wait_until
// But it needs to acquire the lock when executing predicate function
// For better performance, we use loop here
std::exception_ptr ex;
auto end_time = std::chrono::high_resolution_clock::now() + retry_time_;
std::cv_status status;
do {
{
std::unique_lock<std::mutex> lock(mutex_);
status = cv_.wait_until(lock, end_time);
}
try {
ret.reset(alloc_func());
} catch (BadAlloc &) {
ex = std::current_exception();
} catch (...) {
std::rethrow_exception(std::current_exception());
}
} while (ret == nullptr && status != std::cv_status::timeout);
if (ret == nullptr) std::rethrow_exception(ex);
}
} catch (...) {
std::rethrow_exception(std::current_exception());
}
return ret;
}
} // namespace allocation
} // namespace memory
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono> // NOLINT
#include <condition_variable> // NOLINT
#include <memory>
#include <mutex> // NOLINT
#include "paddle/fluid/memory/allocation/allocator.h"
namespace paddle {
namespace memory {
namespace allocation {
class RetryAllocator;
class RetryAllocation : public Allocation {
public:
RetryAllocation(std::unique_ptr<Allocation>&& underlying_allocation,
const std::shared_ptr<RetryAllocator>& retry_allocator)
: Allocation(underlying_allocation->ptr(), underlying_allocation->size(),
underlying_allocation->place()),
underlying_allocation_(std::move(underlying_allocation)),
retry_allocator_(retry_allocator) {}
~RetryAllocation();
private:
std::unique_ptr<Allocation> underlying_allocation_;
std::weak_ptr<RetryAllocator> retry_allocator_;
};
class RetryAllocator : public ManagedAllocator,
public std::enable_shared_from_this<RetryAllocator> {
private:
RetryAllocator(std::unique_ptr<Allocator>&& allocator, size_t retry_ms)
: underlying_allocator_(
dynamic_cast<UnmanagedAllocator*>(allocator.release())),
retry_time_(retry_ms) {
EnforceCheck();
}
public:
template <typename... Args>
static std::shared_ptr<ManagedAllocator> Create(Args... args) {
return std::shared_ptr<ManagedAllocator>(
new RetryAllocator(std::forward<Args>(args)...));
}
bool IsAllocThreadSafe() const override;
std::unique_ptr<Allocation> Allocate(
size_t size, Allocator::Attr attr = kDefault) override;
std::shared_ptr<Allocation> AllocateShared(
size_t size, Allocator::Attr attr = kDefault) override;
private:
void EnforceCheck() {
PADDLE_ENFORCE_NOT_NULL(
underlying_allocator_.get(),
"UnderlyingAllocator of RetryAllocator must be UnmanagedAllocator");
PADDLE_ENFORCE(underlying_allocator_->IsAllocThreadSafe(),
"UnderlyingAllocator of RetryAllocator must be thread-safe");
}
std::unique_ptr<UnmanagedAllocator> underlying_allocator_;
std::chrono::milliseconds retry_time_;
std::mutex mutex_;
std::condition_variable cv_;
// For debug, We can add an atomic integer to record how many memory sizes are
// waited to allocate
// std::atomic<size_t> waited_allocate_size_{0};
friend class RetryAllocation;
};
} // namespace allocation
} // namespace memory
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/retry_allocator.h"
#include <algorithm>
#include <chrono> // NOLINT
#include <condition_variable> // NOLINT
#include <mutex> // NOLINT
#include <thread> // NOLINT
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/memory/allocation/best_fit_allocator.h"
#include "paddle/fluid/memory/allocation/cpu_allocator.h"
#include "paddle/fluid/memory/allocation/locked_allocator.h"
namespace paddle {
namespace memory {
namespace allocation {
TEST(RetryAllocator, RetryAllocator) {
CPUAllocator cpu_allocator;
size_t size = (1 << 20);
auto cpu_allocation = cpu_allocator.Allocate(size);
std::unique_ptr<BestFitAllocator> best_fit_allocator(
new BestFitAllocator(cpu_allocation.get()));
std::unique_ptr<LockedAllocator> locked_allocator(
new LockedAllocator(std::move(best_fit_allocator)));
size_t thread_num = 32;
size_t sleep_time = 40;
size_t extra_time = 2;
// Reserve to perform more tests in the future
std::vector<std::shared_ptr<ManagedAllocator>> allocators;
{
std::unique_ptr<BestFitAllocator> best_fit_allocator(
new BestFitAllocator(cpu_allocation.get()));
std::unique_ptr<LockedAllocator> locked_allocator(
new LockedAllocator(std::move(best_fit_allocator)));
allocators.push_back(
RetryAllocator::Create(std::move(locked_allocator),
(thread_num - 1) * (sleep_time + extra_time)));
}
for (auto &allocator : allocators) {
std::vector<std::thread> threads(thread_num);
std::vector<void *> addresses(threads.size(), nullptr);
std::mutex mutex;
std::condition_variable cv;
bool flag = false;
for (size_t i = 0; i < threads.size(); ++i) {
threads[i] = std::thread([&, i]() {
{
std::unique_lock<std::mutex> lock(mutex);
cv.wait(lock, [&] { return flag; });
}
auto ret = allocator->Allocate(size - 1);
addresses[i] = ret->ptr();
std::this_thread::sleep_for(std::chrono::milliseconds(sleep_time));
});
}
{
std::lock_guard<std::mutex> lock(mutex);
flag = true;
cv.notify_all();
}
for (auto &th : threads) {
th.join();
}
void *val = cpu_allocation->ptr();
bool is_all_equal = std::all_of(addresses.begin(), addresses.end(),
[val](void *p) { return p == val; });
ASSERT_TRUE(is_all_equal);
}
cpu_allocator.FreeUniquePtr(std::move(cpu_allocation));
}
} // namespace allocation
} // namespace memory
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册