提交 7ffc9fd8 编写于 作者: Y Yu Yang

Merge branch 'rewrite_allocation' of https://github.com/sneaxiy/Paddle into rewrite_allocation

...@@ -2,6 +2,8 @@ cc_library(allocator SRCS allocator.cc DEPS place) ...@@ -2,6 +2,8 @@ cc_library(allocator SRCS allocator.cc DEPS place)
cc_library(cpu_allocator SRCS cpu_allocator.cc DEPS allocator) cc_library(cpu_allocator SRCS cpu_allocator.cc DEPS allocator)
cc_library(best_fit_allocator SRCS best_fit_allocator.cc DEPS allocator) cc_library(best_fit_allocator SRCS best_fit_allocator.cc DEPS allocator)
cc_library(locked_allocator SRCS locked_allocator.cc DEPS allocator) cc_library(locked_allocator SRCS locked_allocator.cc DEPS allocator)
cc_library(buffered_allocator SRCS buffered_allocator.cc DEPS allocator)
cc_test(buffered_allocator_test SRCS buffered_allocator_test.cc DEPS best_fit_allocator locked_allocator buffered_allocator cpu_allocator)
if (WITH_GPU) if (WITH_GPU)
nv_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard) nv_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard)
...@@ -51,7 +53,8 @@ cc_library(allocator_facade SRCS allocator_facade.cc DEPS ...@@ -51,7 +53,8 @@ cc_library(allocator_facade SRCS allocator_facade.cc DEPS
auto_increment_allocator auto_increment_allocator
zero_size_allocator zero_size_allocator
conditional_allocator conditional_allocator
retry_allocator) retry_allocator
buffered_allocator)
nv_test(allocation_and_eigen_test SRCS allocation_and_eigen_test.cu DEPS allocator_facade) nv_test(allocation_and_eigen_test SRCS allocation_and_eigen_test.cu DEPS allocator_facade)
......
...@@ -12,22 +12,6 @@ ...@@ -12,22 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <utility>
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once #pragma once
#include <memory> #include <memory>
#include <string> #include <string>
...@@ -141,11 +125,7 @@ class Allocator { ...@@ -141,11 +125,7 @@ class Allocator {
// a manally managed allocator. // a manally managed allocator.
class UnmanagedAllocator : public Allocator { class UnmanagedAllocator : public Allocator {
public: public:
virtual void Free(Allocation* allocation) = 0; virtual void FreeUniquePtr(std::unique_ptr<Allocation> allocation) = 0;
void FreeUniquePtr(std::unique_ptr<Allocation> allocation) {
Free(allocation.get());
}
}; };
// The allocation will be managed by smart pointers. i.e., users do not need // The allocation will be managed by smart pointers. i.e., users do not need
......
...@@ -104,8 +104,8 @@ BestFitAllocator::ListIt BestFitAllocator::SplitChunk(size_t request_size, ...@@ -104,8 +104,8 @@ BestFitAllocator::ListIt BestFitAllocator::SplitChunk(size_t request_size,
return to_use_it; return to_use_it;
} }
void BestFitAllocator::Free(Allocation* allocation) { void BestFitAllocator::FreeUniquePtr(std::unique_ptr<Allocation> allocation) {
auto* bf_allocation = dynamic_cast<BestFitAllocation*>(allocation); auto* bf_allocation = dynamic_cast<BestFitAllocation*>(allocation.get());
auto chunk_it = bf_allocation->ChunkIterator(); auto chunk_it = bf_allocation->ChunkIterator();
PADDLE_ENFORCE(!chunk_it->is_free); PADDLE_ENFORCE(!chunk_it->is_free);
chunk_it->is_free = true; chunk_it->is_free = true;
......
...@@ -109,7 +109,7 @@ class BestFitAllocator : public UnmanagedAllocator { ...@@ -109,7 +109,7 @@ class BestFitAllocator : public UnmanagedAllocator {
std::unique_ptr<Allocation> Allocate(size_t size, std::unique_ptr<Allocation> Allocate(size_t size,
Attr attr = kDefault) override; Attr attr = kDefault) override;
void Free(Allocation* allocation) override; void FreeUniquePtr(std::unique_ptr<Allocation> allocation) override;
size_t NumFreeChunks() const; size_t NumFreeChunks() const;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/buffered_allocator.h"
#include <algorithm>
#include <limits>
#include <utility>
namespace paddle {
namespace memory {
namespace allocation {
BufferedAllocator::BufferedAllocator(std::unique_ptr<Allocator>&& allocator) {
std::vector<size_t> division_plan(8 * sizeof(size_t));
for (size_t i = 0; i < 8 * sizeof(size_t); ++i) {
division_plan[i] = (static_cast<size_t>(1) << i);
}
InitAndEnforceCheck(std::move(allocator), division_plan);
}
BufferedAllocator::BufferedAllocator(std::unique_ptr<Allocator>&& allocator,
const std::vector<size_t>& division_plan) {
InitAndEnforceCheck(std::move(allocator), division_plan);
}
BufferedAllocator::~BufferedAllocator() { FlushImpl(); }
void BufferedAllocator::FlushImpl() {
for (auto& v : allocations_) {
for (auto& pair : v) {
underlying_allocator_->FreeUniquePtr(std::move(pair.second));
}
v.clear();
}
}
void BufferedAllocator::Flush() {
if (mtx_) {
std::lock_guard<std::mutex> lock(*mtx_);
FlushImpl();
} else {
FlushImpl();
}
}
void BufferedAllocator::InitAndEnforceCheck(
std::unique_ptr<Allocator>&& allocator,
const std::vector<size_t>& division_plan) {
underlying_allocator_.reset(
dynamic_cast<UnmanagedAllocator*>(allocator.release()));
PADDLE_ENFORCE_NOT_NULL(
underlying_allocator_,
"Underlying allocator of BufferedAllocator must be unmanaged");
if (underlying_allocator_->IsAllocThreadSafe()) {
mtx_.reset(new std::mutex());
}
constexpr size_t kMax = std::numeric_limits<size_t>::max();
if (division_plan.empty()) {
division_plan_.assign({0, kMax});
} else {
auto from = division_plan.front() == 0 ? division_plan.begin() + 1
: division_plan.begin();
auto to = division_plan.back() == kMax ? division_plan.end() - 1
: division_plan.end();
division_plan_.reserve(to - from + 2);
division_plan_.push_back(0);
division_plan_.insert(division_plan_.end(), from, to);
division_plan_.push_back(kMax);
for (size_t i = 1; i < division_plan_.size(); ++i) {
PADDLE_ENFORCE_LT(division_plan_[i - 1], division_plan_[i],
"Division plan must be strictly sorted");
}
}
allocations_.resize(division_plan_.size() - 1);
}
void BufferedAllocator::InsertAllocationImpl(
std::unique_ptr<Allocation>&& allocation) {
auto size = allocation->size();
auto idx = GetListIndex(size);
allocations_[idx].emplace(size, std::move(allocation));
}
void BufferedAllocator::InsertAllocation(
std::unique_ptr<Allocation>&& allocation) {
if (mtx_) {
std::lock_guard<std::mutex> lock(*mtx_);
InsertAllocationImpl(std::move(allocation));
} else {
InsertAllocationImpl(std::move(allocation));
}
}
bool BufferedAllocator::Match(size_t actual_size, size_t requested_size) {
return (actual_size >> 1) < requested_size;
}
size_t BufferedAllocator::GetListIndex(size_t size) {
auto it =
std::upper_bound(division_plan_.begin(), division_plan_.end(), size);
return static_cast<size_t>(it - division_plan_.begin()) - 1;
}
std::unique_ptr<Allocation> BufferedAllocator::RemoveAllocationImpl(
size_t size) {
auto idx = GetListIndex(size);
auto& allocation_map = allocations_[idx];
auto it = allocation_map.lower_bound(size);
// Only remove allocation whose size is not more than twice of requested size
if (it != allocation_map.end()) {
if (Match(it->second->size(), size)) {
auto ret = std::move(it->second);
allocation_map.erase(it);
return ret;
} else {
return nullptr;
}
} else {
while (++idx < allocations_.size() && Match(division_plan_[idx], size)) {
auto& allocation_map = allocations_[idx];
if (!allocation_map.empty()) {
auto it = allocation_map.begin();
if (Match(it->second->size(), size)) {
auto ret = std::move(it->second);
allocation_map.erase(it);
return ret;
} else {
return nullptr;
}
}
}
return nullptr;
}
}
std::unique_ptr<Allocation> BufferedAllocator::RemoveAllocation(size_t size) {
if (mtx_) {
std::lock_guard<std::mutex> lock(*mtx_);
return RemoveAllocationImpl(size);
} else {
return RemoveAllocationImpl(size);
}
}
std::unique_ptr<Allocation> BufferedAllocator::Allocate(size_t size,
Allocator::Attr attr) {
auto ret = RemoveAllocation(size);
if (!ret) {
try {
return underlying_allocator_->Allocate(size, attr);
} catch (BadAlloc&) {
// if allocation failed, try to free some memorys from buffers
FreeAllocations(size);
return underlying_allocator_->Allocate(size, attr);
}
}
return ret;
}
void BufferedAllocator::FreeAllocationsImpl(size_t size) {
if (UNLIKELY(size == 0)) return;
size_t cur = 0;
for (auto& alloc_map : allocations_) {
// use reverse iterator to free large allocations first
while (!alloc_map.empty()) {
auto it = --(alloc_map.end());
cur += it->second->size();
underlying_allocator_->FreeUniquePtr(std::move(it->second));
alloc_map.erase(it);
if (cur >= size) return;
}
}
}
void BufferedAllocator::FreeAllocations(size_t size) {
if (mtx_) {
std::lock_guard<std::mutex> lock(*mtx_);
FreeAllocationsImpl(size);
} else {
FreeAllocationsImpl(size);
}
}
void BufferedAllocator::FreeUniquePtr(std::unique_ptr<Allocation> allocation) {
InsertAllocation(std::move(allocation));
}
bool BufferedAllocator::IsAllocThreadSafe() const { return mtx_ != nullptr; }
const std::vector<size_t>& BufferedAllocator::GetDivisionPlan() const {
return division_plan_;
}
} // namespace allocation
} // namespace memory
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <map>
#include <memory>
#include <vector>
#include "paddle/fluid/memory/allocation/allocator.h"
namespace paddle {
namespace memory {
namespace allocation {
// NOTE(zjl): BufferedAllocator maintains a memory pool to accelerate
// memory allocation and reuse memory.
// BufferedAllocator provides the same thread-safety level as
// underlying_allocator_
class BufferedAllocator : public UnmanagedAllocator {
public:
explicit BufferedAllocator(std::unique_ptr<Allocator>&& allocator);
BufferedAllocator(std::unique_ptr<Allocator>&& allocator,
const std::vector<size_t>& division_plan);
~BufferedAllocator();
std::unique_ptr<Allocation> Allocate(
size_t size, Allocator::Attr attr = Allocator::Attr::kDefault) override;
void FreeUniquePtr(std::unique_ptr<Allocation> allocation) override;
bool IsAllocThreadSafe() const override;
const std::vector<size_t>& GetDivisionPlan() const;
void Flush();
private:
void InitAndEnforceCheck(std::unique_ptr<Allocator>&& allocator,
const std::vector<size_t>& division_plan);
void InsertAllocation(std::unique_ptr<Allocation>&& allocation);
void InsertAllocationImpl(std::unique_ptr<Allocation>&& allocation);
static bool Match(size_t actual_size, size_t requested_size);
std::unique_ptr<Allocation> RemoveAllocation(size_t size);
std::unique_ptr<Allocation> RemoveAllocationImpl(size_t size);
void FreeAllocations(size_t size);
void FreeAllocationsImpl(size_t size);
void FlushImpl();
size_t GetListIndex(size_t size);
std::unique_ptr<UnmanagedAllocator> underlying_allocator_;
std::vector<std::multimap<size_t, std::unique_ptr<Allocation>>> allocations_;
std::vector<size_t> division_plan_;
std::unique_ptr<std::mutex> mtx_;
};
} // namespace allocation
} // namespace memory
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/buffered_allocator.h"
#include <gtest/gtest.h>
#include "paddle/fluid/memory/allocation/best_fit_allocator.h"
#include "paddle/fluid/memory/allocation/cpu_allocator.h"
#include "paddle/fluid/memory/allocation/locked_allocator.h"
namespace paddle {
namespace memory {
namespace allocation {
inline std::unique_ptr<BufferedAllocator> GetBufferedAllocator(
Allocation *allocation, bool thread_safe) {
std::unique_ptr<Allocator> allocator(new BestFitAllocator(allocation));
if (thread_safe) {
allocator.reset(new LockedAllocator(std::move(allocator)));
}
return std::unique_ptr<BufferedAllocator>(
new BufferedAllocator(std::move(allocator)));
}
TEST(buffered_allocator, thread_safety) {
std::unique_ptr<CPUAllocator> allocator(new CPUAllocator());
auto chunk = allocator->Allocate(1 << 20);
{
auto buf_allocator = GetBufferedAllocator(chunk.get(), true);
ASSERT_EQ(buf_allocator->IsAllocThreadSafe(), true);
}
{
auto buf_allocator = GetBufferedAllocator(chunk.get(), false);
ASSERT_EQ(buf_allocator->IsAllocThreadSafe(), false);
}
allocator->FreeUniquePtr(std::move(chunk));
}
class StubAllocation : public Allocation {
public:
using Allocation::Allocation;
};
class StubAllocator : public UnmanagedAllocator {
public:
std::unique_ptr<Allocation> Allocate(size_t size,
Allocator::Attr attr) override {
++construct_count_;
if (size == 0) {
return std::unique_ptr<Allocation>(
new StubAllocation(nullptr, 0, platform::CPUPlace()));
} else {
return std::unique_ptr<Allocation>(
new StubAllocation(new uint8_t[size], size, platform::CPUPlace()));
}
}
void FreeUniquePtr(std::unique_ptr<Allocation> allocation) {
StubAllocation *alloc = dynamic_cast<StubAllocation *>(allocation.get());
PADDLE_ENFORCE_NOT_NULL(alloc);
if (alloc->ptr()) delete[] static_cast<uint8_t *>(alloc->ptr());
++destruct_count_;
}
void ResetCounter() {
construct_count_ = 0;
destruct_count_ = 0;
}
size_t GetAllocCount() const { return construct_count_; }
size_t GetFreeCount() const { return destruct_count_; }
private:
size_t construct_count_ = 0;
size_t destruct_count_ = 0;
};
constexpr size_t kZero = 0;
constexpr size_t kOne = 1;
constexpr size_t kTwo = 2;
TEST(buffered_allocator, lazy_free) {
std::unique_ptr<StubAllocator> stub_allocator(new StubAllocator());
auto *underlying_allocator = stub_allocator.get();
std::unique_ptr<BufferedAllocator> allocator(
new BufferedAllocator(std::move(stub_allocator)));
{
underlying_allocator->ResetCounter();
auto x = allocator->Allocate(1025);
ASSERT_EQ(underlying_allocator->GetAllocCount(), kOne);
ASSERT_EQ(underlying_allocator->GetFreeCount(), kZero);
allocator->FreeUniquePtr(std::move(x));
ASSERT_EQ(underlying_allocator->GetFreeCount(), kZero);
}
{
underlying_allocator->ResetCounter();
auto x = allocator->Allocate(900);
ASSERT_EQ(underlying_allocator->GetAllocCount(), kZero);
ASSERT_EQ(underlying_allocator->GetFreeCount(), kZero);
auto y = allocator->Allocate(2048);
ASSERT_EQ(underlying_allocator->GetAllocCount(), kOne);
ASSERT_EQ(underlying_allocator->GetFreeCount(), kZero);
allocator->FreeUniquePtr(std::move(x));
ASSERT_EQ(underlying_allocator->GetFreeCount(), kZero);
allocator->FreeUniquePtr(std::move(y));
ASSERT_EQ(underlying_allocator->GetFreeCount(), kZero);
}
{
underlying_allocator->ResetCounter();
allocator->Flush();
ASSERT_EQ(underlying_allocator->GetAllocCount(), kZero);
ASSERT_EQ(underlying_allocator->GetFreeCount(), kTwo);
}
}
TEST(buffered_allocator, garbage_collection) {
std::unique_ptr<CPUAllocator> cpu_allocator(new CPUAllocator());
auto chunk = cpu_allocator->Allocate(2048);
auto allocator = GetBufferedAllocator(chunk.get(), false);
auto x1 = allocator->Allocate(1600);
auto x2 = allocator->Allocate(400);
allocator->FreeUniquePtr(std::move(x1));
allocator->FreeUniquePtr(std::move(x2));
auto x3 = allocator->Allocate(1600);
ASSERT_NE(x3, nullptr);
ASSERT_NE(x3->ptr(), nullptr);
}
} // namespace allocation
} // namespace memory
} // namespace paddle
...@@ -29,8 +29,8 @@ std::unique_ptr<Allocation> CPUAllocator::Allocate(size_t size, Attr attr) { ...@@ -29,8 +29,8 @@ std::unique_ptr<Allocation> CPUAllocator::Allocate(size_t size, Attr attr) {
} }
return std::unique_ptr<Allocation>(new CPUAllocation(ptr, size)); return std::unique_ptr<Allocation>(new CPUAllocation(ptr, size));
} }
void CPUAllocator::Free(Allocation* allocation) { void CPUAllocator::FreeUniquePtr(std::unique_ptr<Allocation> allocation) {
PADDLE_ENFORCE_NOT_NULL(dynamic_cast<CPUAllocation*>(allocation)); PADDLE_ENFORCE_NOT_NULL(dynamic_cast<CPUAllocation*>(allocation.get()));
free(allocation->ptr()); free(allocation->ptr());
} }
......
...@@ -36,7 +36,7 @@ class CPUAllocator : public UnmanagedAllocator { ...@@ -36,7 +36,7 @@ class CPUAllocator : public UnmanagedAllocator {
constexpr static size_t kAlignment = 64u; constexpr static size_t kAlignment = 64u;
std::unique_ptr<Allocation> Allocate(size_t size, std::unique_ptr<Allocation> Allocate(size_t size,
Attr attr = kDefault) override; Attr attr = kDefault) override;
void Free(Allocation* allocation) override; void FreeUniquePtr(std::unique_ptr<Allocation> allocation) override;
bool IsAllocThreadSafe() const override; bool IsAllocThreadSafe() const override;
}; };
} // namespace allocation } // namespace allocation
......
...@@ -35,9 +35,9 @@ std::unique_ptr<Allocation> CUDAAllocator::Allocate(size_t size, Attr attr) { ...@@ -35,9 +35,9 @@ std::unique_ptr<Allocation> CUDAAllocator::Allocate(size_t size, Attr attr) {
new CUDAAllocation(ptr, size, platform::Place(place_))); new CUDAAllocation(ptr, size, platform::Place(place_)));
} }
void CUDAAllocator::Free(Allocation* allocation) { void CUDAAllocator::FreeUniquePtr(std::unique_ptr<Allocation> allocation) {
platform::CUDADeviceGuard guard(place_.device); platform::CUDADeviceGuard guard(place_.device);
auto* cuda_allocation = dynamic_cast<CUDAAllocation*>(allocation); auto* cuda_allocation = dynamic_cast<CUDAAllocation*>(allocation.get());
PADDLE_ENFORCE_NOT_NULL(cuda_allocation); PADDLE_ENFORCE_NOT_NULL(cuda_allocation);
PADDLE_ENFORCE_EQ(boost::get<platform::CUDAPlace>(cuda_allocation->place()), PADDLE_ENFORCE_EQ(boost::get<platform::CUDAPlace>(cuda_allocation->place()),
place_); place_);
......
...@@ -34,7 +34,7 @@ class CUDAAllocator : public UnmanagedAllocator { ...@@ -34,7 +34,7 @@ class CUDAAllocator : public UnmanagedAllocator {
: place_(boost::get<platform::CUDAPlace>(place)) {} : place_(boost::get<platform::CUDAPlace>(place)) {}
std::unique_ptr<Allocation> Allocate(size_t size, std::unique_ptr<Allocation> Allocate(size_t size,
Attr attr = kDefault) override; Attr attr = kDefault) override;
void Free(Allocation* allocation) override; void FreeUniquePtr(std::unique_ptr<Allocation> allocation) override;
bool IsAllocThreadSafe() const override; bool IsAllocThreadSafe() const override;
private: private:
......
...@@ -27,12 +27,12 @@ std::unique_ptr<Allocation> LockedAllocator::Allocate(size_t size, Attr attr) { ...@@ -27,12 +27,12 @@ std::unique_ptr<Allocation> LockedAllocator::Allocate(size_t size, Attr attr) {
return underlying_allocator_->Allocate(size, attr); return underlying_allocator_->Allocate(size, attr);
} }
} }
void LockedAllocator::Free(Allocation *allocation) { void LockedAllocator::FreeUniquePtr(std::unique_ptr<Allocation> allocation) {
if (underlying_allocator_->IsAllocThreadSafe()) { if (underlying_allocator_->IsAllocThreadSafe()) {
return underlying_allocator_->Free(allocation); return underlying_allocator_->FreeUniquePtr(std::move(allocation));
} else { } else {
std::lock_guard<std::mutex> guard(mtx_); std::lock_guard<std::mutex> guard(mtx_);
return underlying_allocator_->Free(allocation); return underlying_allocator_->FreeUniquePtr(std::move(allocation));
} }
} }
bool LockedAllocator::IsAllocThreadSafe() const { return true; } bool LockedAllocator::IsAllocThreadSafe() const { return true; }
......
...@@ -27,7 +27,7 @@ class LockedAllocator : public UnmanagedAllocator { ...@@ -27,7 +27,7 @@ class LockedAllocator : public UnmanagedAllocator {
explicit LockedAllocator(std::unique_ptr<Allocator>&& underlying_allocator); explicit LockedAllocator(std::unique_ptr<Allocator>&& underlying_allocator);
std::unique_ptr<Allocation> Allocate(size_t size, std::unique_ptr<Allocation> Allocate(size_t size,
Attr attr = kDefault) override; Attr attr = kDefault) override;
void Free(Allocation* allocation) override; void FreeUniquePtr(std::unique_ptr<Allocation> allocation) override;
bool IsAllocThreadSafe() const override; bool IsAllocThreadSafe() const override;
private: private:
......
...@@ -31,7 +31,9 @@ class StubAllocator : public UnmanagedAllocator { ...@@ -31,7 +31,9 @@ class StubAllocator : public UnmanagedAllocator {
return std::unique_ptr<Allocation>( return std::unique_ptr<Allocation>(
new Allocation(nullptr, size, platform::CPUPlace())); new Allocation(nullptr, size, platform::CPUPlace()));
} }
void Free(Allocation* allocation) override { counter_.fetch_sub(1); } void FreeUniquePtr(std::unique_ptr<Allocation> allocation) override {
counter_.fetch_sub(1);
}
bool IsAllocThreadSafe() const override { return true; } bool IsAllocThreadSafe() const override { return true; }
std::atomic<int> counter_{0}; std::atomic<int> counter_{0};
......
...@@ -32,8 +32,8 @@ std::unique_ptr<Allocation> CPUPinnedAllocator::Allocate(size_t size, ...@@ -32,8 +32,8 @@ std::unique_ptr<Allocation> CPUPinnedAllocator::Allocate(size_t size,
new CPUPinnedAllocation(ptr, size)); new CPUPinnedAllocation(ptr, size));
} }
void CPUPinnedAllocator::Free(Allocation* allocation) { void CPUPinnedAllocator::FreeUniquePtr(std::unique_ptr<Allocation> allocation) {
PADDLE_ENFORCE_NOT_NULL(dynamic_cast<CPUPinnedAllocation*>(allocation)); PADDLE_ENFORCE_NOT_NULL(dynamic_cast<CPUPinnedAllocation*>(allocation.get()));
PADDLE_ENFORCE(cudaFreeHost(allocation->ptr())); PADDLE_ENFORCE(cudaFreeHost(allocation->ptr()));
} }
......
...@@ -29,7 +29,7 @@ class CPUPinnedAllocation : public Allocation { ...@@ -29,7 +29,7 @@ class CPUPinnedAllocation : public Allocation {
class CPUPinnedAllocator : public UnmanagedAllocator { class CPUPinnedAllocator : public UnmanagedAllocator {
public: public:
std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) override; std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) override;
void Free(Allocation* allocation) override; void FreeUniquePtr(std::unique_ptr<Allocation> allocation) override;
bool IsAllocThreadSafe() const override; bool IsAllocThreadSafe() const override;
}; };
......
...@@ -75,7 +75,7 @@ Allocation* RetryAllocator::AllocateImpl(size_t size, Allocator::Attr attr) { ...@@ -75,7 +75,7 @@ Allocation* RetryAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
} }
void RetryAllocator::FreeUnderlyingAllocation( void RetryAllocator::FreeUnderlyingAllocation(
std::unique_ptr<Allocation>&& allocation) { std::unique_ptr<Allocation>&& allocation) {
underlying_allocator_->Free(allocation.get()); underlying_allocator_->FreeUniquePtr(std::move(allocation));
{ {
// notify all waited allocators, they can try to allocate memory after free. // notify all waited allocators, they can try to allocate memory after free.
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册