Unverified commit ae58afc5, authored by Zeng Jinle and committed by GitHub

Feature/auto_growth_allocator (#18561)

* feature/auto_growth_allocator, test=develop

* add unittest of AlignedAllocator, test=develop

* try to turn on auto_growth to test on CI, test=develop

* fix segmentation fault in mixed_vector.h, test=develop

* add unittests, test=develop
Parent bb2f5d24
......@@ -216,7 +216,7 @@ class Vector {
void *src = gpu_->ptr();
void *dst = cpu_.data();
paddle::memory::Copy(platform::CPUPlace(), dst, CUDAPlace().get(), src,
gpu_->size(), stream);
gpu_memory_size_, stream);
dev_ctx->Wait();
}
......@@ -256,13 +256,14 @@ class Vector {
void CopyCPUDataToCUDA(const platform::Place &place) const {
void *src = cpu_.data();
gpu_ = memory::Alloc(place, cpu_.size() * sizeof(T));
gpu_memory_size_ = cpu_.size() * sizeof(T);
gpu_ = memory::Alloc(place, gpu_memory_size_);
void *dst = gpu_->ptr();
auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place));
auto stream = dev_ctx->stream();
paddle::memory::Copy(CUDAPlace().get(), dst, platform::CPUPlace(), src,
gpu_->size(), stream);
gpu_memory_size_, stream);
}
void ImmutableCPU() const {
......@@ -285,6 +286,7 @@ class Vector {
mutable std::vector<T> cpu_;
mutable paddle::memory::AllocationPtr gpu_;
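// Exact number of bytes of data held in gpu_. With the new allocators,
// gpu_->size() may be padded beyond the real data size, so copies must use
// gpu_memory_size_ instead (the segfault fix mentioned in this commit).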
mutable size_t gpu_memory_size_{0};
mutable int flag_;
mutable std::mutex mtx_;
......
......@@ -116,7 +116,7 @@ TEST(Tensor, MutableData) {
EXPECT_NE(p1, nullptr);
// set src_tensor to a new dim with a larger size
// memory is supposed to be re-allocated
p2 = src_tensor.mutable_data<float>(framework::make_ddim({3, 4}),
p2 = src_tensor.mutable_data<float>(framework::make_ddim({3, 1024}),
platform::CUDAPlace());
EXPECT_NE(p2, nullptr);
EXPECT_NE(p1, p2);
......
cc_library(allocator SRCS allocator.cc DEPS place)
cc_library(cpu_allocator SRCS cpu_allocator.cc DEPS allocator)
cc_library(best_fit_allocator SRCS best_fit_allocator.cc DEPS allocator)
cc_library(locked_allocator SRCS locked_allocator.cc DEPS allocator)
cc_library(buffered_allocator SRCS buffered_allocator.cc DEPS allocator)
cc_library(legacy_allocator SRCS legacy_allocator.cc DEPS allocator buddy_allocator profiler)
cc_test(buffered_allocator_test SRCS buffered_allocator_test.cc DEPS best_fit_allocator locked_allocator buffered_allocator cpu_allocator)
cc_library(best_fit_allocator SRCS best_fit_allocator.cc DEPS allocator)
cc_library(naive_best_fit_allocator SRCS naive_best_fit_allocator.cc DEPS allocator buddy_allocator profiler)
cc_test(buffered_allocator_test SRCS buffered_allocator_test.cc DEPS locked_allocator buffered_allocator cpu_allocator best_fit_allocator)
if (WITH_GPU)
nv_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard)
......@@ -12,6 +12,13 @@ endif()
cc_library(retry_allocator SRCS retry_allocator.cc DEPS allocator)
nv_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator)
if (WITH_GPU)
set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator cuda_device_guard)
else ()
set(AllocatorFacadeDeps)
endif()
if (WITH_GPU)
nv_test(best_fit_allocator_test
SRCS best_fit_allocator_test.cc
......@@ -30,26 +37,14 @@ else()
cpu_allocator)
endif()
nv_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator)
if (WITH_GPU)
set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator cuda_device_guard)
else ()
set(AllocatorFacadeDeps)
endif()
list(APPEND AllocatorFacadeDeps cpu_allocator locked_allocator best_fit_allocator aligned_allocator auto_increment_allocator conditional_allocator retry_allocator buffered_allocator legacy_allocator)
list(APPEND AllocatorFacadeDeps cpu_allocator locked_allocator aligned_allocator retry_allocator buffered_allocator naive_best_fit_allocator auto_growth_best_fit_allocator best_fit_allocator)
cc_library(aligned_allocator SRCS aligned_allocator.cc DEPS allocator)
cc_library(auto_increment_allocator SRCS auto_increment_allocator.cc DEPS allocator)
cc_library(conditional_allocator SRCS conditional_allocator.cc DEPS allocator)
cc_test(test_aligned_allocator SRCS test_aligned_allocator.cc DEPS aligned_allocator)
cc_library(allocator_strategy SRCS allocator_strategy.cc DEPS gflags ${AllocatorFacadeDeps})
cc_library(allocator_facade SRCS allocator_facade.cc DEPS allocator_strategy)
nv_test(allocation_and_eigen_test SRCS allocation_and_eigen_test.cu DEPS allocator_facade)
cc_test(naive_best_fit_allocator_facade_test SRCS naive_best_fit_allocator_facade_test.cc DEPS allocator_facade)
cc_test(retry_allocator_test SRCS retry_allocator_test.cc DEPS retry_allocator best_fit_allocator locked_allocator cpu_allocator)
cc_test(retry_allocator_test SRCS retry_allocator_test.cc DEPS retry_allocator locked_allocator cpu_allocator)
if (WITH_TESTING)
set_tests_properties(retry_allocator_test PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
endif()
......@@ -57,3 +52,6 @@ endif()
cc_test(allocator_facade_abs_flags_test SRCS allocator_facade_abs_flags_test.cc DEPS allocator_facade)
cc_test(allocator_facade_frac_flags_test SRCS allocator_facade_frac_flags_test.cc DEPS allocator_facade)
cc_library(auto_growth_best_fit_allocator SRCS auto_growth_best_fit_allocator.cc DEPS allocator aligned_allocator)
cc_test(auto_growth_best_fit_allocator_facade_test SRCS auto_growth_best_fit_allocator_facade_test.cc DEPS cpu_allocator auto_growth_best_fit_allocator)
......@@ -13,19 +13,46 @@
// limitations under the License.
#include "paddle/fluid/memory/allocation/aligned_allocator.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace memory {
namespace allocation {
ThinAlignedAllocator::ThinAlignedAllocator(
std::shared_ptr<Allocator> underlying_allocator)
: underlying_allocator_(std::move(underlying_allocator)) {}
class AlignedAllocation : public Allocation {
public:
AlignedAllocation(AllocationPtr underlying_allocation, size_t offset)
: Allocation(
reinterpret_cast<uint8_t*>(underlying_allocation->ptr()) + offset,
underlying_allocation->size() - offset,
underlying_allocation->place()),
underlying_allocation_(std::move(underlying_allocation)) {}
bool ThinAlignedAllocator::IsAllocThreadSafe() const {
private:
AllocationPtr underlying_allocation_;
};
AlignedAllocator::AlignedAllocator(
const std::shared_ptr<Allocator>& underlying_allocator, size_t alignment)
: underlying_allocator_(underlying_allocator), alignment_(alignment) {
PADDLE_ENFORCE(alignment_ > 0, "alignment must be positive integer");
if (alignment_ & (alignment_ - 1)) {
PADDLE_THROW("alignment must be 2^N, but got %d", alignment_);
}
}
bool AlignedAllocator::IsAllocThreadSafe() const {
return underlying_allocator_->IsAllocThreadSafe();
}
Allocation* AlignedAllocator::AllocateImpl(size_t size) {
auto raw_allocation = underlying_allocator_->Allocate(size + alignment_);
size_t offset = AlignedPtrOffset(raw_allocation->ptr(), alignment_);
return new AlignedAllocation(std::move(raw_allocation), offset);
}
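// Worked example with alignment_ == 64: Allocate(100) requests 164 bytes from
// the underlying allocator. If the returned raw pointer is, say, 8 bytes past a
// 64-byte boundary, the offset is 56, so the aligned allocation starts 56 bytes
// in and (assuming the underlying allocation is exactly 164 bytes) reports a
// size of 164 - 56 = 108 bytes.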
void AlignedAllocator::FreeImpl(Allocation* allocation) { delete allocation; }
} // namespace allocation
} // namespace memory
} // namespace paddle
......@@ -21,80 +21,21 @@ namespace paddle {
namespace memory {
namespace allocation {
// The aligned allocation and allocator wrap a managed allocator
// and return an aligned pointer.
//
// NOTE(yy): For speed reasons, I just use a template parameter to get the
// alignment; however, it could be a private member if necessary.
//
// NOTE(yy): kAlignment must be 2^N. A `static_assert` should be added.
template <size_t kAlignment>
class AlignedAllocation : public Allocation {
static_assert(kAlignment > 0 && (kAlignment & (kAlignment - 1)) == 0,
"kAlignment must be 2^N");
class AlignedAllocator : public Allocator {
public:
AlignedAllocation(AllocationPtr&& underlying_allocation, size_t size)
: Allocation(AlignedPtr(underlying_allocation->ptr()),
size + kAlignment - Offset(underlying_allocation->ptr()),
underlying_allocation->place()),
underlying_allocation_(std::move(underlying_allocation)) {}
private:
static void* AlignedPtr(void* ptr) {
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(ptr) +
Offset(ptr));
}
AlignedAllocator(const std::shared_ptr<Allocator>& underlying_allocator,
size_t alignment);
// Offset to aligned pointer.
// if ptr is already aligned, returns 0.
static size_t Offset(void* ptr) {
auto ptr_addr = reinterpret_cast<intptr_t>(ptr);
intptr_t aligned_addr = (ptr_addr & ~(kAlignment - 1));
intptr_t diff = aligned_addr - ptr_addr;
if (diff == 0) {
return 0;
} else {
return kAlignment + diff;
}
}
AllocationPtr underlying_allocation_;
};
// The thin aligned allocator is trivial and used to keep the binary size small.
//
// NOTE(yy): This is a trick for the template class. This class extracts the
// common code into a `thin` base class, so if there are multiple
// specializations of the template class, the binary size does not grow too much.
//
// NOTE(yy): This could be over-design. If it harms the readability of the code,
// it could be removed later.
class ThinAlignedAllocator : public Allocator {
public:
explicit ThinAlignedAllocator(
std::shared_ptr<Allocator> underlyning_allocator);
bool IsAllocThreadSafe() const;
bool IsAllocThreadSafe() const override;
protected:
std::shared_ptr<Allocator> underlying_allocator_;
};
Allocation* AllocateImpl(size_t size) override;
// An aligned allocator will allocate `size+kAlignment` allocation and adjust
// the pointer offset.
template <size_t kAlignment>
class AlignedAllocator : public ThinAlignedAllocator {
public:
using ThinAlignedAllocator::ThinAlignedAllocator;
void FreeImpl(Allocation* allocation) override;
protected:
Allocation* AllocateImpl(size_t size) override {
auto raw_allocation = underlying_allocator_->Allocate(size + kAlignment);
return new AlignedAllocation<kAlignment>(std::move(raw_allocation), size);
}
void FreeImpl(Allocation* allocation) override { delete allocation; }
private:
std::shared_ptr<Allocator> underlying_allocator_;
size_t alignment_;
};
} // namespace allocation
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gtest/gtest.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
#include "unsupported/Eigen/CXX11/Tensor"
// NOTE(yy): this unittest is not important. It is just used for debugging.
// It can be removed later.
struct FillZero {
public:
float* ptr_;
__device__ void operator()(size_t i) { ptr_[i] = 0.0f; }
};
namespace paddle {
TEST(Eigen, main) {
framework::Tensor tensor;
platform::CUDAPlace gpu(0);
float* ptr = tensor.mutable_data<float>({10, 10}, gpu);
auto& dev_ctx = *reinterpret_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(gpu));
PADDLE_ENFORCE(cudaMemset(ptr, 0, sizeof(float) * 100));
platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx, 100);
for_range(FillZero{ptr});
dev_ctx.Wait();
auto eigen_vec = framework::EigenVector<float>::Flatten(tensor);
auto& eigen_dev = *dev_ctx.eigen_device();
eigen_vec.device(eigen_dev) = eigen_vec.constant(0.0f);
}
} // namespace paddle
......@@ -161,6 +161,9 @@ class Allocator {
using AllocationPtr = std::unique_ptr<Allocation, AllocationDeleter>;
// Allocate an allocation.
// size may be 0, but it would be too complex if we handle size == 0
// in each Allocator. So we handle size == 0 inside AllocatorFacade
// in our design.
inline AllocationPtr Allocate(size_t size) {
auto ptr = AllocateImpl(size);
ptr->RegisterDecoratedAllocator(this);
......@@ -184,6 +187,17 @@ class Allocator {
using AllocationDeleter = Allocator::AllocationDeleter;
using AllocationPtr = Allocator::AllocationPtr;
inline size_t AlignedSize(size_t size, size_t alignment) {
auto remaining = size % alignment;
return remaining == 0 ? size : size + alignment - remaining;
}
inline size_t AlignedPtrOffset(const void* ptr, size_t alignment) {
auto ptr_addr = reinterpret_cast<uintptr_t>(ptr);
auto diff = ptr_addr % alignment;
return diff == 0 ? 0 : alignment - diff;
}
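// For example, AlignedSize(1000, 256) == 1024 and AlignedSize(1024, 256) == 1024;
// AlignedPtrOffset(ptr, 256) is 0 when ptr is already 256-byte aligned, and
// otherwise the distance to the next 256-byte boundary.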
} // namespace allocation
} // namespace memory
} // namespace paddle
......@@ -19,15 +19,12 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/memory/allocation/aligned_allocator.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/memory/allocation/allocator_strategy.h"
#include "paddle/fluid/memory/allocation/auto_increment_allocator.h"
#include "paddle/fluid/memory/allocation/best_fit_allocator.h"
#include "paddle/fluid/memory/allocation/conditional_allocator.h"
#include "paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h"
#include "paddle/fluid/memory/allocation/cpu_allocator.h"
#include "paddle/fluid/memory/allocation/legacy_allocator.h"
#include "paddle/fluid/memory/allocation/locked_allocator.h"
#include "paddle/fluid/memory/allocation/naive_best_fit_allocator.h"
#include "paddle/fluid/memory/allocation/retry_allocator.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -48,161 +45,35 @@ namespace paddle {
namespace memory {
namespace allocation {
static inline std::shared_ptr<Allocator> WrapRetryAllocator(
std::shared_ptr<Allocator> allocator, int64_t retry_time) {
if (retry_time > 0) {
auto* retry_allocator =
new RetryAllocator(std::move(allocator), retry_time);
allocator.reset(retry_allocator);
}
return allocator;
}
// TODO(yy): Dirty code here. This class should be configurable in runtime.
class CPUManagedAllocator : public Allocator {
public:
CPUManagedAllocator() : normal_allocator_(new CPUAllocator()) {}
bool IsAllocThreadSafe() const override { return true; }
protected:
Allocation* AllocateImpl(size_t size) override {
return normal_allocator_->Allocate(size).release();
}
private:
std::shared_ptr<Allocator> normal_allocator_;
};
// TODO(yy): Dirty code here. This class should be configurable in runtime.
class ChunkedAllocator : public Allocator {
public:
explicit ChunkedAllocator(std::unique_ptr<Allocator> system_allocator,
size_t max_chunk_size, size_t capacity = 1,
int64_t retry_time = -1)
: max_chunk_size_(max_chunk_size), retry_time_(retry_time) {
raw_allocator_ = std::move(system_allocator);
if (max_chunk_size_ == 0) {
default_allocator_ = raw_allocator_;
} else {
if (capacity == 1) {
VLOG(1) << "Create BestFitAllocator with chunk_size "
<< max_chunk_size_;
default_allocator_ = CreateAllocatorWithChunk();
} else {
VLOG(1) << "Create AutoIncrementAllocator with chunk_size "
<< max_chunk_size_ << " and capacity " << capacity;
default_allocator_ = std::make_shared<AutoIncrementAllocator>(
[this] { return CreateAllocatorWithChunk(); }, capacity);
}
}
auto* cond_allocator = new ConditionalAllocator();
cond_allocator
->AddAllocator([this](size_t size) { return size < max_chunk_size_; },
default_allocator_)
.AddAllocator(
[](size_t size) {
return true; // default case
},
raw_allocator_);
default_allocator_.reset(cond_allocator);
}
~ChunkedAllocator() override {
// Specify destruct order.
default_allocator_.reset();
chunks_.clear();
raw_allocator_.reset();
}
std::shared_ptr<Allocator> CreateAllocatorWithChunk() {
chunks_.emplace_back(raw_allocator_->Allocate(max_chunk_size_));
auto* allocation = chunks_.back().get();
std::shared_ptr<Allocator> allocator(new LockedAllocator(
std::shared_ptr<Allocator>(new BestFitAllocator(allocation))));
allocator = WrapRetryAllocator(allocator, retry_time_);
return std::make_shared<AlignedAllocator<64u>>(std::move(allocator));
}
bool IsAllocThreadSafe() const override { return true; }
protected:
Allocation* AllocateImpl(size_t size) override {
return default_allocator_->Allocate(size).release();
}
protected:
size_t max_chunk_size_;
int64_t retry_time_;
std::vector<AllocationPtr> chunks_;
std::shared_ptr<Allocator> raw_allocator_;
std::shared_ptr<Allocator> default_allocator_;
};
#ifdef PADDLE_WITH_CUDA
class CUDAChunkedAllocator : public ChunkedAllocator {
public:
explicit CUDAChunkedAllocator(int dev_id)
: ChunkedAllocator(std::unique_ptr<Allocator>(
new CUDAAllocator(platform::CUDAPlace(dev_id))),
GetMaxChunkSize(dev_id), GetCapcity(dev_id),
GetRetryTime()) {}
private:
static size_t GetMaxChunkSize(int dev_id) {
platform::CUDADeviceGuard guard(dev_id);
return platform::GpuMaxChunkSize();
}
static size_t GetCapcity(int dev_id) {
platform::CUDADeviceGuard guard(dev_id);
size_t available, total;
platform::GpuMemoryUsage(&available, &total);
size_t max_chunk_size = platform::GpuMaxChunkSize();
return max_chunk_size == 0 ? 0 : available / max_chunk_size;
}
static int64_t GetRetryTime() { return FLAGS_gpu_allocator_retry_time; }
};
class CUDAPinnedChunkedAllocator : public ChunkedAllocator {
public:
CUDAPinnedChunkedAllocator()
: ChunkedAllocator(std::unique_ptr<Allocator>(new CPUPinnedAllocator()),
platform::CUDAPinnedMaxChunkSize(), GetCapacity(),
-1) {} // never retry
private:
static size_t GetCapacity() {
size_t total = platform::CpuTotalPhysicalMemory();
size_t max_chunk_size = platform::CUDAPinnedMaxChunkSize();
return max_chunk_size == 0 ? 0 : total / max_chunk_size;
}
};
#endif
class AllocatorFacadePrivate {
public:
AllocatorFacadePrivate() {
auto strategy = GetAllocatorStrategy();
switch (strategy) {
case AllocatorStrategy::kLegacy: {
InitLegacyAllocator();
case AllocatorStrategy::kNaiveBestFit: {
InitNaiveBestFitCPUAllocator();
#ifdef PADDLE_WITH_CUDA
for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount();
++dev_id) {
InitNaiveBestFitCUDAAllocator(platform::CUDAPlace(dev_id));
}
InitNaiveBestFitCUDAPinnedAllocator();
#endif
break;
}
case AllocatorStrategy::kNaiveBestFit: {
InitCPUAllocator();
InitCUDAAllocator();
InitCUDAPinnedAllocator();
case AllocatorStrategy::kAutoGrowth: {
InitNaiveBestFitCPUAllocator();
#ifdef PADDLE_WITH_CUDA
for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount();
++dev_id) {
InitAutoGrowthCUDAAllocator(platform::CUDAPlace(dev_id));
}
InitNaiveBestFitCUDAPinnedAllocator();
#endif
break;
}
default: {
PADDLE_THROW("Unsupported allocator strategy: %d",
static_cast<int>(strategy));
......@@ -215,47 +86,33 @@ class AllocatorFacadePrivate {
const platform::Place& place, size_t size) {
const auto& allocators = (size > 0 ? allocators_ : zero_size_allocators_);
auto iter = allocators.find(place);
if (iter == allocators.end()) {
throw BadAlloc(
string::Sprintf("No such allocator for the place, %s", place));
}
PADDLE_ENFORCE(iter != allocators.end(),
"No such allocator for the place, %s", place);
return iter->second;
}
private:
void InitLegacyAllocator() {
std::vector<platform::Place> places{platform::CPUPlace()};
#ifdef PADDLE_WITH_CUDA
for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); ++dev_id) {
places.emplace_back(platform::CUDAPlace(dev_id));
}
places.emplace_back(platform::CUDAPinnedPlace());
#endif
for (auto& p : places) {
allocators_[p] = std::make_shared<LegacyAllocator>(p);
}
void InitNaiveBestFitCPUAllocator() {
allocators_[platform::CPUPlace()] =
std::make_shared<NaiveBestFitAllocator>(platform::CPUPlace());
}
void InitCPUAllocator() {
allocators_[platform::CPUPlace()] = std::make_shared<CPUManagedAllocator>();
#ifdef PADDLE_WITH_CUDA
void InitNaiveBestFitCUDAPinnedAllocator() {
allocators_[platform::CUDAPinnedPlace()] =
std::make_shared<NaiveBestFitAllocator>(platform::CUDAPinnedPlace());
}
void InitCUDAAllocator() {
#ifdef PADDLE_WITH_CUDA
int device_count = platform::GetCUDADeviceCount();
for (int dev_id = 0; dev_id < device_count; ++dev_id) {
allocators_[platform::CUDAPlace(dev_id)] =
std::make_shared<CUDAChunkedAllocator>(dev_id);
}
#endif
void InitNaiveBestFitCUDAAllocator(platform::CUDAPlace p) {
allocators_[p] = std::make_shared<NaiveBestFitAllocator>(p);
}
void InitCUDAPinnedAllocator() {
#ifdef PADDLE_WITH_CUDA
allocators_[platform::CUDAPinnedPlace()] =
std::make_shared<CUDAPinnedChunkedAllocator>();
#endif
void InitAutoGrowthCUDAAllocator(platform::CUDAPlace p) {
auto cuda_allocator = std::make_shared<CUDAAllocator>(p);
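// GpuMinChunkSize() serves as the alignment here; chunk_size keeps its default
// of 0, so every growth step allocates just the aligned request size (at least
// one alignment unit) from the raw CUDA allocator.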
allocators_[p] = std::make_shared<AutoGrowthBestFitAllocator>(
cuda_allocator, platform::GpuMinChunkSize());
}
#endif
class ZeroSizeAllocator : public Allocator {
public:
......
......@@ -23,6 +23,7 @@ DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
DECLARE_int64(gpu_allocator_retry_time);
#endif
DECLARE_string(allocator_strategy);
namespace paddle {
namespace memory {
......@@ -92,6 +93,8 @@ TEST(Allocator, SpecifyGpuMemory) {
FLAGS_fraction_of_cuda_pinned_memory_to_use = 0.5;
#endif
FLAGS_allocator_strategy = "naive_best_fit";
AllocateTestCases();
}
......
......@@ -23,6 +23,7 @@ DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
DECLARE_int64(gpu_allocator_retry_time);
#endif
DECLARE_string(allocator_strategy);
namespace paddle {
namespace memory {
......@@ -85,6 +86,7 @@ TEST(Allocator, Allocator) {
FLAGS_gpu_allocator_retry_time = 500;
FLAGS_fraction_of_cuda_pinned_memory_to_use = 0.5;
#endif
FLAGS_allocator_strategy = "naive_best_fit";
AllocateTestCases();
}
......
......@@ -14,27 +14,29 @@
#include "paddle/fluid/memory/allocation/allocator_strategy.h"
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h"
DEFINE_string(
allocator_strategy, "legacy",
"The allocation strategy. Legacy means the original allocator of Fluid."
"naive_best_fit means the experimental best fit allocator. "
"allocator. Enum in [legacy, naive_best_fit].");
DEFINE_string(allocator_strategy, "naive_best_fit",
"The allocation strategy. naive_best_fit means the original best "
"fit allocator of Fluid. "
"auto_growth means the experimental auto-growth allocator. "
"Enum in [naive_best_fit, auto_growth].");
namespace paddle {
namespace memory {
namespace allocation {
static AllocatorStrategy GetStrategyFromFlag() {
if (FLAGS_allocator_strategy == "legacy") {
return AllocatorStrategy::kLegacy;
} else if (FLAGS_allocator_strategy == "naive_best_fit") {
if (FLAGS_allocator_strategy == "naive_best_fit") {
return AllocatorStrategy::kNaiveBestFit;
} else {
PADDLE_THROW("Unsupported allocator strategy: %s",
FLAGS_allocator_strategy);
}
if (FLAGS_allocator_strategy == "auto_growth") {
return AllocatorStrategy::kAutoGrowth;
}
PADDLE_THROW("Unsupported allocator strategy: %s", FLAGS_allocator_strategy);
}
AllocatorStrategy GetAllocatorStrategy() {
......
......@@ -18,7 +18,7 @@ namespace paddle {
namespace memory {
namespace allocation {
enum class AllocatorStrategy { kLegacy, kNaiveBestFit };
enum class AllocatorStrategy { kNaiveBestFit, kAutoGrowth };
extern AllocatorStrategy GetAllocatorStrategy();
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h"
#include <algorithm>
#include <list>
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <unordered_map>
#include "paddle/fluid/memory/allocation/aligned_allocator.h"
namespace paddle {
namespace memory {
namespace allocation {
AutoGrowthBestFitAllocator::AutoGrowthBestFitAllocator(
const std::shared_ptr<Allocator> &underlying_allocator, size_t alignment,
size_t chunk_size)
: underlying_allocator_(
std::make_shared<AlignedAllocator>(underlying_allocator, alignment)),
alignment_(alignment),
chunk_size_(std::max(AlignedSize(chunk_size, alignment), alignment)) {}
Allocation *AutoGrowthBestFitAllocator::AllocateImpl(size_t size) {
size = AlignedSize(size, alignment_);
std::lock_guard<std::mutex> guard(mtx_);
auto iter = free_blocks_.lower_bound(std::make_pair(size, nullptr));
BlockIt block_it;
if (iter != free_blocks_.end()) {
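// Best-fit hit: take the smallest free block that can hold the request and
// split off any remainder as a new free block.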
block_it = iter->second;
free_blocks_.erase(iter);
auto *chunk = block_it->chunk_;
size_t remaining_size = block_it->size_ - size;
if (remaining_size == 0) {
block_it->is_free_ = false;
} else {
auto remaining_free_block = chunk->blocks_.insert(
block_it, Block(block_it->ptr_, remaining_size, true, chunk));
free_blocks_.emplace(std::make_pair(remaining_size, block_it->ptr_),
remaining_free_block);
block_it->ptr_ =
reinterpret_cast<uint8_t *>(block_it->ptr_) + remaining_size;
block_it->size_ = size;
block_it->is_free_ = false;
}
} else {
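// No free block fits: grow by allocating a new chunk of at least chunk_size_
// bytes from the underlying allocator, falling back to the exact size if the
// larger request fails.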
size_t realloc_size = std::max(size, chunk_size_);
try {
chunks_.emplace_back(underlying_allocator_->Allocate(realloc_size));
} catch (BadAlloc &ex) {
if (size == realloc_size) throw ex;
realloc_size = size;
chunks_.emplace_back(underlying_allocator_->Allocate(realloc_size));
}
auto *chunk = &(*chunks_.rbegin());
realloc_size = chunk->allocation_->size();
uint8_t *p = reinterpret_cast<uint8_t *>(chunk->allocation_->ptr());
auto &blocks = chunk->blocks_;
size_t remaining_size = realloc_size - size;
if (remaining_size > 0) {
blocks.emplace_back(p, remaining_size, true, chunk);
free_blocks_.emplace(std::make_pair(remaining_size, p), --(blocks.end()));
}
blocks.emplace_back(p + remaining_size, size, false, chunk);
block_it = --(blocks.end());
VLOG(2) << "Not found and reallocate " << realloc_size << ", and remaining "
<< remaining_size;
}
return new BlockAllocation(block_it);
}
void AutoGrowthBestFitAllocator::FreeImpl(Allocation *allocation) {
std::lock_guard<std::mutex> guard(mtx_);
auto block_it = static_cast<BlockAllocation *>(allocation)->block_it_;
auto &blocks = block_it->chunk_->blocks_;
block_it->is_free_ = true;
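// Coalesce with the previous block of the same chunk if it is free.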
if (block_it != blocks.begin()) {
auto prev_it = block_it;
--prev_it;
if (prev_it->is_free_) {
free_blocks_.erase(std::make_pair(prev_it->size_, prev_it->ptr_));
prev_it->size_ += block_it->size_;
blocks.erase(block_it);
block_it = prev_it;
}
}
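// Coalesce with the next block of the same chunk if it is free.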
auto next_it = block_it;
++next_it;
if (next_it != blocks.end() && next_it->is_free_) {
free_blocks_.erase(std::make_pair(next_it->size_, next_it->ptr_));
block_it->size_ += next_it->size_;
blocks.erase(next_it);
}
free_blocks_.emplace(std::make_pair(block_it->size_, block_it->ptr_),
block_it);
delete allocation;
}
} // namespace allocation
} // namespace memory
} // namespace paddle
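Below is a minimal usage sketch of the new allocator, added for illustration only; it is not part of this diff. The constructor signature comes from auto_growth_best_fit_allocator.h below, CPUAllocator from cpu_allocator.h, and the function name AutoGrowthUsageSketch is hypothetical.
#include <memory>
#include "paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h"
#include "paddle/fluid/memory/allocation/cpu_allocator.h"
namespace paddle {
namespace memory {
namespace allocation {
// Wrap a CPUAllocator so that requests are served from growing, best-fit
// managed chunks instead of hitting the underlying allocator every time.
static void AutoGrowthUsageSketch() {
  auto cpu = std::make_shared<CPUAllocator>();
  AutoGrowthBestFitAllocator auto_growth(cpu, /*alignment=*/4096,
                                         /*chunk_size=*/1 << 20);
  AllocationPtr a = auto_growth.Allocate(1000);  // rounded up to 4096 internally
  a.reset();  // the freed block is coalesced back into the free-block map
}
}  // namespace allocation
}  // namespace memory
}  // namespace paddle
In the facade above, InitAutoGrowthCUDAAllocator wires a CUDAAllocator in the same way, using platform::GpuMinChunkSize() as the alignment.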
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -13,45 +13,72 @@
// limitations under the License.
#pragma once
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <utility>
#include <vector>
#include "paddle/fluid/memory/allocation/allocator.h"
namespace paddle {
namespace memory {
namespace allocation {
// A composite allocator who will dispatch the allocation request by registered
// condition.
//
// For example:
//
// auto* cond_allocator = new ConditionalAllocator();
// cond_allocator->AddAllocator([](size_t size){
// // if size > 10
// return size > 10;
// }, allocator_b).AddAllocator([](size_t size){
// // else
// return true;
// }, allocator_c);
class ConditionalAllocator : public Allocator {
class AutoGrowthBestFitAllocator : public Allocator {
public:
ConditionalAllocator() = default;
ConditionalAllocator& AddAllocator(std::function<bool(size_t)> func,
std::shared_ptr<Allocator> allocator);
explicit AutoGrowthBestFitAllocator(
const std::shared_ptr<Allocator> &underlying_allocator, size_t alignment,
size_t chunk_size = 0);
bool IsAllocThreadSafe() const override;
bool IsAllocThreadSafe() const override { return true; }
protected:
Allocation* AllocateImpl(size_t size) override;
Allocation *AllocateImpl(size_t size) override;
void FreeImpl(Allocation *allocation) override;
private:
using AllocatorWithCond =
std::pair<std::function<bool(size_t)>, std::shared_ptr<Allocator>>;
std::vector<AllocatorWithCond> underlying_allocators_;
template <typename T>
using List = std::list<T>;
struct Chunk;
struct Block {
Block(void *ptr, size_t size, bool is_free, Chunk *chunk)
: ptr_(ptr), size_(size), is_free_(is_free), chunk_(chunk) {}
void *ptr_;
size_t size_;
bool is_free_;
Chunk *chunk_; // which chunk it is from
};
struct Chunk {
explicit Chunk(AllocationPtr allocation)
: allocation_(std::move(allocation)) {}
AllocationPtr allocation_;
List<Block> blocks_;
};
struct BlockAllocation : public Allocation {
explicit BlockAllocation(const List<Block>::iterator &it)
: Allocation(it->ptr_, it->size_, it->chunk_->allocation_->place()),
block_it_(it) {}
List<Block>::iterator block_it_;
};
using BlockIt = List<Block>::iterator;
std::shared_ptr<Allocator> underlying_allocator_;
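// Free blocks across all chunks, keyed by (block size, block ptr); a
// lower_bound lookup on {size, nullptr} yields the smallest free block that
// can satisfy the request (best fit).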
std::map<std::pair<size_t, void *>, BlockIt> free_blocks_;
std::list<Chunk> chunks_;
size_t alignment_;
size_t chunk_size_;
mutable std::mutex mtx_;
};
} // namespace allocation
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -14,7 +14,13 @@
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <chrono> // NOLINT
#include <condition_variable> // NOLINT
#include <mutex> // NOLINT
#include <random>
#include <thread> // NOLINT
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/gpu_info.h"
#ifdef PADDLE_WITH_CUDA
DECLARE_double(fraction_of_gpu_memory_to_use);
......@@ -28,6 +34,11 @@ namespace paddle {
namespace memory {
namespace allocation {
static inline size_t AlignTo(size_t size, size_t alignment) {
auto remaining = size % alignment;
return remaining == 0 ? size : size + alignment - remaining;
}
TEST(allocator, allocator) {
#ifdef PADDLE_WITH_CUDA
FLAGS_fraction_of_gpu_memory_to_use = 0.01;
......@@ -35,11 +46,11 @@ TEST(allocator, allocator) {
FLAGS_fraction_of_cuda_pinned_memory_to_use = 0.5;
#endif
FLAGS_allocator_strategy = "naive_best_fit";
FLAGS_allocator_strategy = "auto_growth";
auto &instance = AllocatorFacade::Instance();
platform::Place place;
size_t size = 1024;
platform::Place place;
{
place = platform::CPUPlace();
......@@ -48,7 +59,7 @@ TEST(allocator, allocator) {
ASSERT_NE(cpu_allocation, nullptr);
ASSERT_NE(cpu_allocation->ptr(), nullptr);
ASSERT_EQ(cpu_allocation->place(), place);
ASSERT_EQ(cpu_allocation->size(), size);
ASSERT_EQ(cpu_allocation->size(), AlignedSize(size, 1024));
}
#ifdef PADDLE_WITH_CUDA
......@@ -59,7 +70,8 @@ TEST(allocator, allocator) {
ASSERT_NE(gpu_allocation, nullptr);
ASSERT_NE(gpu_allocation->ptr(), nullptr);
ASSERT_EQ(gpu_allocation->place(), place);
ASSERT_GE(gpu_allocation->size(), size);
ASSERT_GE(gpu_allocation->size(),
AlignedSize(size, platform::GpuMinChunkSize()));
}
{
......@@ -70,7 +82,8 @@ TEST(allocator, allocator) {
ASSERT_NE(gpu_allocation, nullptr);
ASSERT_NE(gpu_allocation->ptr(), nullptr);
ASSERT_EQ(gpu_allocation->place(), place);
ASSERT_GE(gpu_allocation->size(), size);
ASSERT_GE(gpu_allocation->size(),
AlignedSize(size, platform::GpuMinChunkSize()));
}
{
......@@ -81,7 +94,51 @@ TEST(allocator, allocator) {
ASSERT_NE(cuda_pinned_allocation, nullptr);
ASSERT_NE(cuda_pinned_allocation->ptr(), nullptr);
ASSERT_EQ(cuda_pinned_allocation->place(), place);
ASSERT_GE(cuda_pinned_allocation->size(), size);
ASSERT_GE(cuda_pinned_allocation->size(), AlignedSize(size, 1 << 20));
}
#endif
}
TEST(multithread_allocate, test_segfault) {
FLAGS_allocator_strategy = "auto_growth";
#ifdef PADDLE_WITH_CUDA
std::mutex mtx;
std::condition_variable cv;
bool flag = false;
auto alloc_func = [&](int dev_id, unsigned int seed) {
auto &instance = AllocatorFacade::Instance();
std::mt19937 gen(seed);
std::uniform_int_distribution<size_t> dist(1 << 20, 1 << 25);
{
std::unique_lock<std::mutex> lock(mtx);
cv.wait(lock, [&] { return flag; });
}
for (int i = 0; i < 50; i++) {
size_t size = dist(gen);
for (int j = 0; j < 10; j++) {
instance.Alloc(platform::CUDAPlace(dev_id), size);
}
}
};
std::vector<std::thread> ths;
for (size_t i = 0; i < 50; ++i) {
std::random_device rd;
ths.emplace_back(alloc_func, 0, rd());
}
{
std::lock_guard<std::mutex> guard(mtx);
flag = true;
}
cv.notify_all();
for (auto &th : ths) {
th.join();
}
#endif
}
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/auto_increment_allocator.h"
namespace paddle {
namespace memory {
namespace allocation {
bool AutoIncrementAllocator::IsAllocThreadSafe() const { return true; }
std::shared_ptr<Allocator> AutoIncrementAllocator::CreateNewAllocator() {
std::lock_guard<std::mutex> guard(mtx_);
auto old_size = allocator_num_.load();
PADDLE_ENFORCE_LT(old_size, underlying_allocators_.size(),
"Allocator number exceeds capacity %d",
underlying_allocators_.size());
underlying_allocators_[old_size] = creator_();
prev_success_allocator_ = old_size;
++allocator_num_;
PADDLE_ENFORCE(
underlying_allocators_[old_size]->IsAllocThreadSafe(),
"the underlying allocator must be thread safe. This is a program "
"bug.");
return underlying_allocators_[old_size];
}
Allocation *AutoIncrementAllocator::AllocateImpl(size_t size) {
auto cur = prev_success_allocator_.load();
size_t retry_count = allocator_num_.load();
size_t allocator_num = retry_count;
while (retry_count-- > 0) { // until the retry count reaches zero
try {
auto res = underlying_allocators_[cur]->Allocate(size);
prev_success_allocator_ = cur;
return res.release();
} catch (BadAlloc &) {
if (++cur >= allocator_num) {
cur = 0;
}
} catch (...) {
// if there is another type of allocation, just rethrow it.
throw;
}
}
// This happens when the first allocator is exhausted and
// there is more than one allocation request.
// In this situation, the first allocation request would succeed
// and the second would fail if we did not use the allocator
// newly created by the first allocation request.
for (cur = allocator_num; cur < allocator_num_; ++cur) {
try {
auto ret = underlying_allocators_[cur]->Allocate(size);
prev_success_allocator_ = cur;
return ret.release();
} catch (BadAlloc &) {
} catch (...) {
throw;
}
}
// No suitable allocator
return CreateNewAllocator()->Allocate(size).release();
}
} // namespace allocation
} // namespace memory
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic> // NOLINT
#include <functional>
#include <memory>
#include <mutex> // NOLINT
#include <thread> // NOLINT
#include <utility>
#include <vector>
#include "paddle/fluid/memory/allocation/allocator.h"
namespace paddle {
namespace memory {
namespace allocation {
// The AutoIncrementAllocator manages many underlying allocators. If none of
// them can allocate the requested memory, a new allocator will be created and
// its `allocate` method will be invoked.
//
// NOTE(yy): The AutoIncrementAllocator will prefer to allocate memory from
// the latest successful allocator.
//
// NOTE(yy): We may need to release an underlying allocator if it allocates
// nothing. However, this is generally not useful, since it would make
// performance nondeterministic.
//
// NOTE(yy): This allocator is only locked when creating new underlying
// allocator. The allocation requests from many threads may be dispatched
// to the same underlying allocator. So the underlying allocator must be
// thread safe.
//
// NOTE(zjl): Add capacity parameters to constructor. A high-performance
// thread-safe std::vector with varying size is hard to implement.
// Fortunately, we can get the total GPU memory and each chunk size.
// Therefore, we can get the suitable capacity of AutoIncrementAllocator.
class AutoIncrementAllocator : public Allocator {
public:
// Creator is the method to create ManagedAllocator
using AllocatorCreator = std::function<std::shared_ptr<Allocator>()>;
explicit AutoIncrementAllocator(AllocatorCreator&& creator, size_t capacity)
: creator_(std::move(creator)), underlying_allocators_(capacity) {}
bool IsAllocThreadSafe() const override;
private:
std::shared_ptr<Allocator> CreateNewAllocator();
protected:
Allocation* AllocateImpl(size_t size) override;
private:
AllocatorCreator creator_;
std::vector<AllocatorCreator::result_type> underlying_allocators_;
std::atomic<size_t> allocator_num_{0};
// Use std::atomic rather than std::mutex, since std::atomic is usually
// lock-free
std::atomic<size_t> prev_success_allocator_{0};
std::mutex mtx_;
};
} // namespace allocation
} // namespace memory
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/conditional_allocator.h"
#include <memory>
namespace paddle {
namespace memory {
namespace allocation {
ConditionalAllocator& ConditionalAllocator::AddAllocator(
std::function<bool(size_t)> func, std::shared_ptr<Allocator> allocator) {
underlying_allocators_.emplace_back(std::move(func), std::move(allocator));
return *this;
}
bool ConditionalAllocator::IsAllocThreadSafe() const {
return std::all_of(underlying_allocators_.begin(),
underlying_allocators_.end(),
[](const AllocatorWithCond& allocatorWithCond) {
return allocatorWithCond.second->IsAllocThreadSafe();
});
}
Allocation* ConditionalAllocator::AllocateImpl(size_t size) {
for (auto& pair : underlying_allocators_) {
if (pair.first(size)) {
return pair.second->Allocate(size).release();
}
}
throw BadAlloc("No suitable allocator");
}
} // namespace allocation
} // namespace memory
} // namespace paddle
......@@ -18,7 +18,7 @@
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/memory/allocation/legacy_allocator.h"
#include "paddle/fluid/memory/allocation/naive_best_fit_allocator.h"
#include "paddle/fluid/memory/detail/buddy_allocator.h"
#include "paddle/fluid/memory/detail/system_allocator.h"
#include "paddle/fluid/platform/gpu_info.h"
......@@ -77,25 +77,6 @@ BuddyAllocator *GetCPUBuddyAllocator() {
return a;
}
// We compared the NaiveAllocator with BuddyAllocator in CPU memory allocation;
// they seem to have almost the same overhead.
struct NaiveAllocator {
void *Alloc(size_t size) { return malloc(size); }
void Free(void *p) {
PADDLE_ENFORCE(p);
free(p);
}
static NaiveAllocator *Instance() {
static NaiveAllocator x;
return &x;
}
private:
std::mutex lock_;
};
template <>
void *Alloc<platform::CPUPlace>(const platform::CPUPlace &place, size_t size) {
VLOG(10) << "Allocate " << size << " bytes on " << platform::Place(place);
......@@ -128,9 +109,8 @@ BuddyAllocator *GetGPUBuddyAllocator(int gpu_id) {
std::call_once(init_flag, [gpu_id]() {
devices = platform::GetSelectedDevices();
int gpu_num = devices.size();
allocation::GPUMemMonitor.Initialize(devices.size());
a_arr = new BuddyAllocator *[gpu_num];
for (size_t i = 0; i < devices.size(); ++i) {
int dev_id = devices[i];
a_arr[i] = nullptr;
......@@ -191,9 +171,6 @@ void *Alloc<platform::CUDAPlace>(const platform::CUDAPlace &place,
<< ", GPU memory used: "
<< string::HumanReadableSize(Used<platform::CUDAPlace>(place));
} else {
if (FLAGS_benchmark) {
allocation::GPUMemMonitor.Add(place.device, size);
}
if (FLAGS_init_allocated_mem) {
cudaMemset(ptr, 0xEF, size);
}
......@@ -209,9 +186,6 @@ void Free<platform::CUDAPlace>(const platform::CUDAPlace &place, void *p,
size_t size) {
#ifdef PADDLE_WITH_CUDA
GetGPUBuddyAllocator(place.device)->Free(p);
if (FLAGS_benchmark) {
allocation::GPUMemMonitor.Minus(place.device, size);
}
#else
PADDLE_THROW("'CUDAPlace' is not supported in CPU only device.");
#endif
......@@ -320,81 +294,19 @@ size_t Usage::operator()(const platform::CUDAPinnedPlace &cuda_pinned) const {
} // namespace legacy
namespace allocation {
LegacyMemMonitor GPUMemMonitor;
Allocation *LegacyAllocator::AllocateImpl(size_t size) {
Allocation *NaiveBestFitAllocator::AllocateImpl(size_t size) {
void *ptr = boost::apply_visitor(legacy::AllocVisitor(size), place_);
auto *tmp_alloc = new Allocation(ptr, size, place_);
platform::MemEvenRecorder::Instance().PushMemRecord(
static_cast<void *>(tmp_alloc), place_, size);
return tmp_alloc;
return new Allocation(ptr, size, place_);
}
void LegacyAllocator::FreeImpl(Allocation *allocation) {
void NaiveBestFitAllocator::FreeImpl(Allocation *allocation) {
boost::apply_visitor(
legacy::FreeVisitor(allocation->ptr(), allocation->size()),
allocation->place());
platform::MemEvenRecorder::Instance().PopMemRecord(
static_cast<void *>(allocation), place_);
delete allocation;
}
bool MemInfo::Add(const size_t &size) {
std::lock_guard<std::mutex> lock(mutex_);
usage_ += size;
bool peak_point = usage_ > peak_usage_;
if (peak_point) peak_usage_ = usage_;
return peak_point;
}
void MemInfo::Minus(const size_t &size) {
std::lock_guard<std::mutex> lock(mutex_);
usage_ -= size;
}
uint64_t MemInfo::GetPeakUsage() const { return peak_usage_; }
LegacyMemMonitor::~LegacyMemMonitor() {
for (auto &item : gpu_mem_info_) delete item.second;
}
void LegacyMemMonitor::Initialize(const int &device_num) {
for (auto i = 0; i < device_num; ++i) {
gpu_mem_info_[i] = new MemInfo();
}
}
void LegacyMemMonitor::Add(const int &device, const size_t &size) {
if (gpu_mem_info_[device]->Add(size)) {
VLOG(3) << "#LegacyMemMonitor# device: " << device
<< " peak memory usage : "
<< (gpu_mem_info_[device]->GetPeakUsage() >> 20) << " MiB";
}
}
void LegacyMemMonitor::Minus(const int &device, const size_t &size) {
gpu_mem_info_[device]->Minus(size);
}
uint64_t LegacyMemMonitor::GetMemUsage(const int &device) const {
return gpu_mem_info_.find(device) == gpu_mem_info_.end()
? 0
: gpu_mem_info_.at(device)->GetPeakUsage();
}
void LegacyMemMonitor::PrintMemUsage() {
std::vector<int> devices;
for (const auto &item : gpu_mem_info_) {
devices.emplace_back(item.first);
}
std::sort(devices.begin(), devices.end());
for (const auto &device : devices) {
std::cout << "Device : " << device << " Peak Memory Usage : "
<< (gpu_mem_info_[device]->GetPeakUsage() >> 20) << " MiB"
<< std::endl;
}
}
} // namespace allocation
} // namespace memory
} // namespace paddle
......@@ -24,52 +24,9 @@ namespace paddle {
namespace memory {
namespace allocation {
class MemInfo {
class NaiveBestFitAllocator : public Allocator {
public:
MemInfo() : usage_(0), peak_usage_(0) {}
// return a flag indicating whether the current operation creates a new peak point
bool Add(const size_t &);
void Minus(const size_t &);
uint64_t GetPeakUsage() const;
private:
/* current memory usage*/
uint64_t usage_;
uint64_t peak_usage_;
std::mutex mutex_;
DISABLE_COPY_AND_ASSIGN(MemInfo);
};
class LegacyMemMonitor {
public:
// used to store the GPU memory usage of each devices
using MemUsage = std::unordered_map</*device id*/ int,
/*mem usage info node*/ MemInfo *>;
MemUsage GetMemUsageInfo() { return gpu_mem_info_; }
~LegacyMemMonitor();
void Initialize(const int &);
void Add(const int &, const size_t &);
void Minus(const int &, const size_t &);
uint64_t GetMemUsage(const int &) const;
void PrintMemUsage();
private:
MemUsage gpu_mem_info_;
};
extern LegacyMemMonitor GPUMemMonitor;
class LegacyAllocatorPrivate;
class LegacyAllocator : public Allocator {
public:
explicit LegacyAllocator(const platform::Place &p) : place_(p) {}
explicit NaiveBestFitAllocator(const platform::Place &p) : place_(p) {}
protected:
Allocation *AllocateImpl(size_t size) override;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gtest/gtest.h"
#include "paddle/fluid/memory/allocation/aligned_allocator.h"
namespace paddle {
namespace memory {
namespace allocation {
TEST(aligned, aligned_size) {
ASSERT_EQ(AlignedSize(1024, 1024), 1024);
ASSERT_EQ(AlignedSize(1023, 1024), 1024);
ASSERT_EQ(AlignedSize(1025, 1024), 2048);
}
struct StubAllocator : public Allocator {
public:
StubAllocator() = default;
size_t AllocNum() const { return alloc_num_; }
protected:
Allocation *AllocateImpl(size_t size) override {
++alloc_num_;
return new Allocation(new uint8_t[size], size, platform::CPUPlace());
}
void FreeImpl(Allocation *allocation) override {
delete[] static_cast<uint8_t *>(allocation->ptr());
delete allocation;
--alloc_num_;
}
private:
size_t alloc_num_{0};
};
bool IsAligned(const AllocationPtr &alloc, size_t alignment) {
return reinterpret_cast<uintptr_t>(alloc->ptr()) % alignment == 0;
}
TEST(aligned_allocator, aligned_allocator) {
size_t alignment = 1024;
auto allocator = std::make_shared<StubAllocator>();
auto aligned_allocator =
std::make_shared<AlignedAllocator>(allocator, alignment);
auto alloc1 = aligned_allocator->Allocate(1345);
ASSERT_EQ(allocator->AllocNum(), 1);
ASSERT_TRUE(IsAligned(alloc1, alignment));
alloc1.reset();
ASSERT_EQ(allocator->AllocNum(), 0);
{
auto alloc2 = aligned_allocator->Allocate(200);
ASSERT_TRUE(IsAligned(alloc2, alignment));
ASSERT_EQ(allocator->AllocNum(), 1);
auto alloc3 = aligned_allocator->Allocate(3021);
ASSERT_TRUE(IsAligned(alloc3, alignment));
ASSERT_EQ(allocator->AllocNum(), 2);
}
ASSERT_EQ(allocator->AllocNum(), 0);
}
} // namespace allocation
} // namespace memory
} // namespace paddle
......@@ -92,16 +92,16 @@ TEST(temporary_allocator, test_reuse_tmp_allocation) {
void* tmp_allocation_ptr1 = nullptr;
{
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
auto tmp_allocation1 = gpu_alloc.Allocate(100);
auto tmp_allocation1 = gpu_alloc.Allocate(200);
tmp_allocation_ptr1 = tmp_allocation1->ptr();
}
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1);
auto tmp_allocation2 = gpu_alloc.Allocate(100);
auto tmp_allocation2 = gpu_alloc.Allocate(200);
void* tmp_allocation_ptr2 = tmp_allocation2->ptr();
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
PADDLE_ENFORCE_EQ(tmp_allocation_ptr1, tmp_allocation_ptr2);
auto tmp_allocation3 = gpu_alloc.Allocate(100);
auto tmp_allocation3 = gpu_alloc.Allocate(200);
void* tmp_allocation_ptr3 = tmp_allocation2->ptr();
PADDLE_ENFORCE_EQ(tmp_allocation_ptr1, tmp_allocation_ptr3);
#endif
......@@ -117,11 +117,11 @@ TEST(temporary_allocator, test_times_excess_than_required_tmp_allocation) {
{
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
auto tmp_allocation1 =
gpu_alloc.Allocate(static_cast<size_t>(100 * excess_fraction - 1));
gpu_alloc.Allocate(static_cast<size_t>(200 * excess_fraction - 1));
tmp_allocation_ptr1 = tmp_allocation1->ptr();
}
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1);
auto tmp_allocation2 = gpu_alloc.Allocate(100);
auto tmp_allocation2 = gpu_alloc.Allocate(200 * excess_fraction - 10);
void* tmp_allocation_ptr2 = tmp_allocation2->ptr();
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
PADDLE_ENFORCE_EQ(tmp_allocation_ptr1, tmp_allocation_ptr2);
......
......@@ -41,7 +41,6 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/memory/allocation/allocator_strategy.h"
#include "paddle/fluid/memory/allocation/legacy_allocator.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/py_func_op.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
......@@ -189,13 +188,6 @@ PYBIND11_MODULE(core_noavx, m) {
m.add_object("_cleanup",
py::capsule([]() { ScopePool::Instance().Clear(); }));
m.def("get_mem_usage", [](int device) {
return memory::allocation::GPUMemMonitor.GetMemUsage(device);
});
m.def("print_mem_usage",
[]() { return memory::allocation::GPUMemMonitor.PrintMemUsage(); });
BindImperative(&m);
py::class_<Tensor>(m, "Tensor", py::buffer_protocol())
......
......@@ -122,12 +122,14 @@ list(REMOVE_ITEM TEST_OPS test_warpctc_op)
list(REMOVE_ITEM TEST_OPS test_dist_train)
list(REMOVE_ITEM TEST_OPS test_dist_transpiler)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_crf)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_crf_auto_growth)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_fetch_feed)
list(REMOVE_ITEM TEST_OPS test_dist_se_resnext)
list(REMOVE_ITEM TEST_OPS test_dgc_op)
list(REMOVE_ITEM TEST_OPS test_dist_se_resnext_nccl)
list(REMOVE_ITEM TEST_OPS test_dist_transformer)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_transformer)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_transformer_auto_growth)
list(REMOVE_ITEM TEST_OPS test_bilinear_interp_op)
list(REMOVE_ITEM TEST_OPS test_nearest_interp_op)
list(REMOVE_ITEM TEST_OPS test_imperative_resnet)
......@@ -227,10 +229,12 @@ if(WITH_DISTRIBUTE)
endif()
py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf)
py_test_modules(test_parallel_executor_crf_auto_growth MODULES test_parallel_executor_crf_auto_growth ENVS FLAGS_allocator_strategy=auto_growth)
py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed)
set_tests_properties(test_parallel_executor_fetch_feed PROPERTIES TIMEOUT 450)
set_tests_properties(test_parallel_executor_seresnext PROPERTIES TIMEOUT 740)
py_test_modules(test_parallel_executor_transformer MODULES test_parallel_executor_transformer)
py_test_modules(test_parallel_executor_transformer_auto_growth MODULES test_parallel_executor_transformer_auto_growth ENVS FLAGS_allocator_strategy=auto_growth)
py_test_modules(test_layers MODULES test_layers ENVS FLAGS_cudnn_deterministic=1)
if(NOT WIN32)
py_test_modules(test_ir_memory_optimize_transformer MODULES test_ir_memory_optimize_transformer)
......@@ -256,4 +260,5 @@ endif()
set_tests_properties(test_recordio_reader test_parallel_executor_test_while_train test_parallel_executor_mnist
test_parallel_executor_seresnext test_parallel_executor_crf test_sync_batch_norm_op
test_parallel_executor_crf_auto_growth
test_buffer_shared_inplace_pass PROPERTIES LABELS "RUN_TYPE=DIST")
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from test_parallel_executor_crf import *
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,48 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import os
os.environ['FLAGS_benchmark'] = 'True'
import numpy
import paddle.fluid.core as core
from paddle.fluid.executor import Executor
from paddle.fluid.layers import mul, data
class TestPeakMemoryMonitoring(unittest.TestCase):
def test_mul(self):
a = data(name='a', shape=[784], dtype='float32')
b = data(
name='b',
shape=[784, 100],
dtype='float32',
append_batch_size=False)
out = mul(x=a, y=b)
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
a_np = numpy.random.random((100, 784)).astype('float32')
b_np = numpy.random.random((784, 100)).astype('float32')
self.assertEqual(0, core.get_mem_usage(0))
exe = Executor(place)
outs = exe.run(feed={'a': a_np, 'b': b_np}, fetch_list=[out])
out = outs[0]
#disable this assert since ctest will ignore the os.environ setting
#self.assertGreater(core.get_mem_usage(0), 0)
raised = False
try:
core.print_mem_usage()
except:
raised = True
self.assertFalse(raised, 'Exception raised')
os.environ['RECORDIO_FILENAME'] = './auto_growth_pe_transformer.wmt16.recordio'
import unittest
from test_parallel_executor_transformer import *
if __name__ == '__main__':
unittest.main()