未验证 提交 569951c4 编写于 作者: L liuwei1031 提交者: GitHub

improve the efficiency of BuddyAllocator (#19888)

* improve save and load behaviour, test=develop

* code cleaning, test=develop

* disable check_guards and update_guards in release version, test=develop

* fix compilation issue, test=develop

* add buddy_allocator speed test data, test=develop

* fix compilation issue, test=develop

* fix comment, test=develop

* update function names according to the google C++ style guide, test=develop

* tweak the test data format, test=develop

* move buddy_allocator_test_data to paddle/fluid/testdata, test=develop

* add accessor and mutator for Desc, test=develop
上级 eafc7023
...@@ -11,3 +11,9 @@ cc_test(system_allocator_test SRCS system_allocator_test.cc DEPS system_allocato ...@@ -11,3 +11,9 @@ cc_test(system_allocator_test SRCS system_allocator_test.cc DEPS system_allocato
cc_library(buddy_allocator SRCS buddy_allocator.cc DEPS memory_block system_allocator glog) cc_library(buddy_allocator SRCS buddy_allocator.cc DEPS memory_block system_allocator glog)
cc_test(buddy_allocator_test SRCS buddy_allocator_test.cc DEPS buddy_allocator) cc_test(buddy_allocator_test SRCS buddy_allocator_test.cc DEPS buddy_allocator)
# After buddy_allocator_test is built, copy the buddy_allocator_test_data file
# from the source tree's testdata directory into the current binary directory.
if(WITH_TESTING)
  add_custom_command(
    TARGET buddy_allocator_test POST_BUILD
    COMMAND ${CMAKE_COMMAND} -E copy
            ${CMAKE_CURRENT_SOURCE_DIR}/../../testdata/buddy_allocator_test_data
            ${CMAKE_CURRENT_BINARY_DIR}/buddy_allocator_test_data)
endif()
...@@ -40,11 +40,11 @@ BuddyAllocator::~BuddyAllocator() { ...@@ -40,11 +40,11 @@ BuddyAllocator::~BuddyAllocator() {
"have actually been freed"; "have actually been freed";
while (!pool_.empty()) { while (!pool_.empty()) {
auto block = static_cast<MemoryBlock*>(std::get<2>(*pool_.begin())); auto block = static_cast<MemoryBlock*>(std::get<2>(*pool_.begin()));
VLOG(10) << "Free from block (" << block << ", " << block->size(cache_) auto desc = cache_.LoadDesc(block);
<< ")"; VLOG(10) << "Free from block (" << block << ", " << desc->get_size() << ")";
system_allocator_->Free(block, block->size(cache_), block->index(cache_)); system_allocator_->Free(block, desc->get_size(), desc->get_index());
cache_.invalidate(block); cache_.Invalidate(block);
pool_.erase(pool_.begin()); pool_.erase(pool_.begin());
} }
} }
...@@ -84,82 +84,83 @@ void* BuddyAllocator::Alloc(size_t unaligned_size) { ...@@ -84,82 +84,83 @@ void* BuddyAllocator::Alloc(size_t unaligned_size) {
} else { } else {
VLOG(10) << "Allocation from existing memory block " << std::get<2>(*it) VLOG(10) << "Allocation from existing memory block " << std::get<2>(*it)
<< " at address " << " at address "
<< reinterpret_cast<MemoryBlock*>(std::get<2>(*it))->data(); << reinterpret_cast<MemoryBlock*>(std::get<2>(*it))->Data();
} }
total_used_ += size; total_used_ += size;
total_free_ -= size; total_free_ -= size;
// split the allocation and return data for use // split the allocation and return data for use
return reinterpret_cast<MemoryBlock*>(SplitToAlloc(it, size))->data(); return reinterpret_cast<MemoryBlock*>(SplitToAlloc(it, size))->Data();
} }
void BuddyAllocator::Free(void* p) { void BuddyAllocator::Free(void* p) {
// Point back to metadata // Point back to metadata
auto block = static_cast<MemoryBlock*>(p)->metadata(); auto block = static_cast<MemoryBlock*>(p)->Metadata();
// Acquire the allocator lock // Acquire the allocator lock
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
VLOG(10) << "Free from address " << block; VLOG(10) << "Free from address " << block;
if (block->type(cache_) == MemoryBlock::HUGE_CHUNK) { auto* desc = cache_.LoadDesc(block);
if (desc->get_type() == MemoryBlock::HUGE_CHUNK) {
VLOG(10) << "Free directly from system allocator"; VLOG(10) << "Free directly from system allocator";
system_allocator_->Free(block, block->total_size(cache_), system_allocator_->Free(block, desc->get_total_size(), desc->get_index());
block->index(cache_));
// Invalidate GPU allocation from cache // Invalidate GPU allocation from cache
cache_.invalidate(block); cache_.Invalidate(block);
return; return;
} }
block->mark_as_free(&cache_); block->MarkAsFree(&cache_);
total_used_ -= block->total_size(cache_); total_used_ -= desc->get_total_size();
total_free_ += block->total_size(cache_); total_free_ += desc->get_total_size();
// Trying to merge the right buddy // Trying to merge the right buddy
if (block->has_right_buddy(cache_)) { MemoryBlock* right_buddy = block->GetRightBuddy(&cache_);
if (right_buddy) {
VLOG(10) << "Merging this block " << block << " with its right buddy " VLOG(10) << "Merging this block " << block << " with its right buddy "
<< block->right_buddy(cache_); << right_buddy;
auto right_buddy = block->right_buddy(cache_); auto rb_desc = cache_.LoadDesc(right_buddy);
if (rb_desc->get_type() == MemoryBlock::FREE_CHUNK) {
if (right_buddy->type(cache_) == MemoryBlock::FREE_CHUNK) {
// Take away right buddy from pool // Take away right buddy from pool
pool_.erase(IndexSizeAddress(right_buddy->index(cache_), pool_.erase(IndexSizeAddress(rb_desc->get_index(),
right_buddy->total_size(cache_), rb_desc->get_total_size(), right_buddy));
right_buddy));
// merge its right buddy to the block // merge its right buddy to the block
block->merge(&cache_, right_buddy); block->Merge(&cache_, right_buddy);
} }
} }
// Trying to merge the left buddy // Trying to merge the left buddy
if (block->has_left_buddy(cache_)) { MemoryBlock* left_buddy = block->GetLeftBuddy(&cache_);
if (left_buddy) {
VLOG(10) << "Merging this block " << block << " with its left buddy " VLOG(10) << "Merging this block " << block << " with its left buddy "
<< block->left_buddy(cache_); << left_buddy;
auto left_buddy = block->left_buddy(cache_);
if (left_buddy->type(cache_) == MemoryBlock::FREE_CHUNK) { // auto left_buddy = block->left_buddy(cache_);
auto* lb_desc = cache_.LoadDesc(left_buddy);
if (lb_desc->get_type() == MemoryBlock::FREE_CHUNK) {
// Take away left buddy from pool // Take away left buddy from pool
pool_.erase(IndexSizeAddress(left_buddy->index(cache_), pool_.erase(IndexSizeAddress(lb_desc->get_index(),
left_buddy->total_size(cache_), left_buddy)); lb_desc->get_total_size(), left_buddy));
// merge the block to its left buddy // merge the block to its left buddy
left_buddy->merge(&cache_, block); left_buddy->Merge(&cache_, block);
block = left_buddy; block = left_buddy;
desc = lb_desc;
} }
} }
// Dumping this block into pool // Dumping this block into pool
VLOG(10) << "Inserting free block (" << block << ", " VLOG(10) << "Inserting free block (" << block << ", "
<< block->total_size(cache_) << ")"; << desc->get_total_size() << ")";
pool_.insert( pool_.insert(
IndexSizeAddress(block->index(cache_), block->total_size(cache_), block)); IndexSizeAddress(desc->get_index(), desc->get_total_size(), block));
} }
size_t BuddyAllocator::Used() { return total_used_; } size_t BuddyAllocator::Used() { return total_used_; }
...@@ -174,10 +175,10 @@ void* BuddyAllocator::SystemAlloc(size_t size) { ...@@ -174,10 +175,10 @@ void* BuddyAllocator::SystemAlloc(size_t size) {
if (p == nullptr) return nullptr; if (p == nullptr) return nullptr;
static_cast<MemoryBlock*>(p)->init(&cache_, MemoryBlock::HUGE_CHUNK, index, static_cast<MemoryBlock*>(p)->Init(&cache_, MemoryBlock::HUGE_CHUNK, index,
size, nullptr, nullptr); size, nullptr, nullptr);
return static_cast<MemoryBlock*>(p)->data(); return static_cast<MemoryBlock*>(p)->Data();
} }
BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool( BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool(
...@@ -209,7 +210,7 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool( ...@@ -209,7 +210,7 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool(
VLOG(10) << "Creating and inserting new block " << p VLOG(10) << "Creating and inserting new block " << p
<< " from system allocator"; << " from system allocator";
static_cast<MemoryBlock*>(p)->init(&cache_, MemoryBlock::FREE_CHUNK, index, static_cast<MemoryBlock*>(p)->Init(&cache_, MemoryBlock::FREE_CHUNK, index,
allocate_bytes, nullptr, nullptr); allocate_bytes, nullptr, nullptr);
total_free_ += allocate_bytes; total_free_ += allocate_bytes;
...@@ -243,26 +244,26 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::FindExistChunk(size_t size) { ...@@ -243,26 +244,26 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::FindExistChunk(size_t size) {
void* BuddyAllocator::SplitToAlloc(BuddyAllocator::PoolSet::iterator it, void* BuddyAllocator::SplitToAlloc(BuddyAllocator::PoolSet::iterator it,
size_t size) { size_t size) {
auto block = static_cast<MemoryBlock*>(std::get<2>(*it)); auto block = static_cast<MemoryBlock*>(std::get<2>(*it));
auto desc = cache_.LoadDesc(block);
pool_.erase(it); pool_.erase(it);
VLOG(10) << "Split block (" << block << ", " << block->total_size(cache_) VLOG(10) << "Split block (" << block << ", " << desc->get_total_size()
<< ") into"; << ") into";
block->split(&cache_, size); block->Split(&cache_, size);
VLOG(10) << "Left block (" << block << ", " << block->total_size(cache_) VLOG(10) << "Left block (" << block << ", " << desc->get_total_size() << ")";
<< ")"; desc->set_type(MemoryBlock::ARENA_CHUNK);
block->set_type(&cache_, MemoryBlock::ARENA_CHUNK);
// the rest of memory if exist // the rest of memory if exist
if (block->has_right_buddy(cache_)) { MemoryBlock* right_buddy = block->GetRightBuddy(&cache_);
if (block->right_buddy(cache_)->type(cache_) == MemoryBlock::FREE_CHUNK) { if (right_buddy) {
VLOG(10) << "Insert right block (" << block->right_buddy(cache_) << ", " auto* rb_desc = cache_.LoadDesc(right_buddy);
<< block->right_buddy(cache_)->total_size(cache_) << ")"; if (rb_desc->get_type() == MemoryBlock::FREE_CHUNK) {
VLOG(10) << "Insert right block (" << right_buddy << ", "
pool_.insert( << rb_desc->get_total_size() << ")";
IndexSizeAddress(block->right_buddy(cache_)->index(cache_),
block->right_buddy(cache_)->total_size(cache_), pool_.insert(IndexSizeAddress(rb_desc->get_index(),
block->right_buddy(cache_))); rb_desc->get_total_size(), right_buddy));
} }
} }
......
...@@ -16,6 +16,9 @@ limitations under the License. */ ...@@ -16,6 +16,9 @@ limitations under the License. */
#include <memory> #include <memory>
#ifdef WITH_GPERFTOOLS
#include "gperftools/profiler.h"
#endif
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/memory/detail/system_allocator.h" #include "paddle/fluid/memory/detail/system_allocator.h"
...@@ -24,6 +27,9 @@ limitations under the License. */ ...@@ -24,6 +27,9 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <fstream>
#include <string>
DECLARE_double(fraction_of_gpu_memory_to_use); DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_uint64(initial_gpu_memory_in_mb); DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb); DECLARE_uint64(reallocate_gpu_memory_in_mb);
...@@ -235,6 +241,77 @@ TEST(BuddyAllocator, AllocFromAvailableWhenFractionIsOne) { ...@@ -235,6 +241,77 @@ TEST(BuddyAllocator, AllocFromAvailableWhenFractionIsOne) {
} }
} }
// Replays a recorded allocation trace against a GPU BuddyAllocator 5000 times
// and prints the elapsed wall-clock time to stderr.
//
// The trace is read from "buddy_allocator_test_data" (resolved relative to the
// working directory). Each record is a pair "<size> <id>":
//   * id == -1 : allocate <size> through the allocator;
//   * otherwise: free the pointer produced by record number <id>.
TEST(BuddyAllocator, SpeedAna) {
  // Request half of the GPU memory through the fraction flag and zero the
  // explicit initial/reallocate sizes.
  FLAGS_fraction_of_gpu_memory_to_use = 0.5;
  FLAGS_initial_gpu_memory_in_mb = 0;
  FLAGS_reallocate_gpu_memory_in_mb = 0;

  BuddyAllocator buddy_allocator(
      std::unique_ptr<SystemAllocator>(new GPUAllocator(TEST_GPU_ID)),
      platform::GpuMinChunkSize(), platform::GpuMaxChunkSize());

  // Run the shared helper with sizes 10, 10 << 10 and 10 << 20 first.
  TestBuddyAllocator(&buddy_allocator, 10);
  TestBuddyAllocator(&buddy_allocator, 10 << 10);
  TestBuddyAllocator(&buddy_allocator, 10 << 20);

  // Without this check a missing data file would leave the trace empty and
  // the benchmark would "pass" without exercising the allocator at all.
  std::ifstream in_file("buddy_allocator_test_data");
  ASSERT_TRUE(in_file.is_open())
      << "cannot open buddy_allocator_test_data in the working directory";

  std::vector<int> vec_size;
  std::vector<int> vec_pos;
  int size = 0;
  int id = 0;
  while (in_file >> size >> id) {
    vec_size.push_back(size);
    vec_pos.push_back(id);
  }
  ASSERT_FALSE(vec_size.empty()) << "buddy_allocator_test_data has no records";

  // Validate the trace once, outside the timed region: a free record must
  // refer to an earlier record, otherwise the replay below would index
  // vec_ptr out of bounds.
  for (size_t i = 0; i < vec_pos.size(); ++i) {
    if (vec_pos[i] == -1) continue;
    ASSERT_GE(vec_pos[i], 0) << "record " << i;
    ASSERT_LT(static_cast<size_t>(vec_pos[i]), i) << "record " << i;
  }

  std::vector<void*> vec_ptr;
  vec_ptr.reserve(vec_size.size());

  auto start = std::chrono::steady_clock::now();
#ifdef WITH_GPERFTOOLS
  ProfilerStart("test.prof");
#endif

  for (size_t loop = 0; loop < 5000; ++loop) {
    vec_ptr.clear();
    for (size_t i = 0; i < vec_size.size(); ++i) {
      if (vec_pos[i] == -1) {
        vec_ptr.push_back(buddy_allocator.Alloc(vec_size[i]));
      } else {
        // Push a placeholder so vec_ptr stays index-aligned with the trace.
        vec_ptr.push_back(nullptr);
        const size_t target = static_cast<size_t>(vec_pos[i]);
        void* free_ptr = vec_ptr[target];
        // Fatal on purpose: a null pointer must never reach Free().
        ASSERT_NE(free_ptr, nullptr);
        vec_ptr[target] = nullptr;
        buddy_allocator.Free(free_ptr);
      }
    }

    // Release whatever the trace left allocated before the next replay.
    for (void* leftover : vec_ptr) {
      if (leftover != nullptr) {
        buddy_allocator.Free(leftover);
      }
    }
  }

#ifdef WITH_GPERFTOOLS
  ProfilerStop();
#endif
  auto end = std::chrono::steady_clock::now();
  std::chrono::duration<double> diff = end - start;
  std::cerr << "time cost " << diff.count() << std::endl;
}
#endif #endif
} // namespace detail } // namespace detail
......
...@@ -19,134 +19,104 @@ namespace paddle { ...@@ -19,134 +19,104 @@ namespace paddle {
namespace memory { namespace memory {
namespace detail { namespace detail {
void MemoryBlock::init(MetadataCache* cache, Type t, size_t index, size_t size, void MemoryBlock::Init(MetadataCache* cache, Type t, size_t index, size_t size,
void* left_buddy, void* right_buddy) { void* left_buddy, void* right_buddy) {
cache->save( cache->Save(
this, MemoryBlock::Desc(t, index, size - sizeof(MemoryBlock::Desc), size, this, MemoryBlock::Desc(t, index, size - sizeof(MemoryBlock::Desc), size,
static_cast<MemoryBlock*>(left_buddy), static_cast<MemoryBlock*>(left_buddy),
static_cast<MemoryBlock*>(right_buddy))); static_cast<MemoryBlock*>(right_buddy)));
} }
MemoryBlock::Type MemoryBlock::type(const MetadataCache& cache) const { MemoryBlock* MemoryBlock::GetLeftBuddy(MetadataCache* cache) {
return cache.load(this).type; return cache->LoadDesc(this)->left_buddy;
} }
size_t MemoryBlock::size(const MetadataCache& cache) const { MemoryBlock* MemoryBlock::GetRightBuddy(MetadataCache* cache) {
return cache.load(this).size; return cache->LoadDesc(this)->right_buddy;
} }
size_t MemoryBlock::index(const MetadataCache& cache) const { void MemoryBlock::Split(MetadataCache* cache, size_t size) {
return cache.load(this).index; auto desc = cache->LoadDesc(this);
}
size_t MemoryBlock::total_size(const MetadataCache& cache) const {
return cache.load(this).total_size;
}
bool MemoryBlock::has_left_buddy(const MetadataCache& cache) const {
return left_buddy(cache) != nullptr;
}
bool MemoryBlock::has_right_buddy(const MetadataCache& cache) const {
return right_buddy(cache) != nullptr;
}
MemoryBlock* MemoryBlock::left_buddy(const MetadataCache& cache) const {
return cache.load(this).left_buddy;
}
MemoryBlock* MemoryBlock::right_buddy(const MetadataCache& cache) const {
return cache.load(this).right_buddy;
}
void MemoryBlock::split(MetadataCache* cache, size_t size) {
// make sure the split fits // make sure the split fits
PADDLE_ENFORCE_GE(total_size(*cache), size); PADDLE_ENFORCE_GE(desc->total_size, size);
// bail out if there is no room for another partition // bail out if there is no room for another partition
if (total_size(*cache) - size <= sizeof(MemoryBlock::Desc)) { if (desc->total_size - size <= sizeof(MemoryBlock::Desc)) {
return; return;
} }
// find the position of the split // find the position of the split
void* right_partition = reinterpret_cast<uint8_t*>(this) + size; void* right_partition = reinterpret_cast<uint8_t*>(this) + size;
size_t remaining_size = total_size(*cache) - size; size_t remaining_size = desc->total_size - size;
// Add the new block as a buddy // Add the new block as a buddy
auto metadata = cache->load(this);
// Write the metadata for the new block // Write the metadata for the new block
auto new_block_right_buddy = metadata.right_buddy; auto new_block_right_buddy = desc->right_buddy;
cache->save(static_cast<MemoryBlock*>(right_partition), cache->Save(static_cast<MemoryBlock*>(right_partition),
MemoryBlock::Desc(FREE_CHUNK, index(*cache), MemoryBlock::Desc(FREE_CHUNK, desc->index,
remaining_size - sizeof(MemoryBlock::Desc), remaining_size - sizeof(MemoryBlock::Desc),
remaining_size, this, new_block_right_buddy)); remaining_size, this, new_block_right_buddy));
metadata.right_buddy = static_cast<MemoryBlock*>(right_partition); desc->right_buddy = static_cast<MemoryBlock*>(right_partition);
metadata.size = size - sizeof(MemoryBlock::Desc); desc->size = size - sizeof(MemoryBlock::Desc);
metadata.total_size = size; desc->total_size = size;
cache->save(this, metadata); desc->UpdateGuards();
// Write metadata for the new block's right buddy // Write metadata for the new block's right buddy
if (new_block_right_buddy != nullptr) { if (new_block_right_buddy != nullptr) {
auto buddy_metadata = cache->load(new_block_right_buddy); auto buddy_desc = cache->LoadDesc(new_block_right_buddy);
buddy_metadata.left_buddy = static_cast<MemoryBlock*>(right_partition); buddy_desc->left_buddy = static_cast<MemoryBlock*>(right_partition);
buddy_desc->UpdateGuards();
cache->save(new_block_right_buddy, buddy_metadata);
} }
} }
void MemoryBlock::merge(MetadataCache* cache, MemoryBlock* right_buddy) { void MemoryBlock::Merge(MetadataCache* cache, MemoryBlock* right_buddy) {
// only free blocks can be merged // only free blocks can be merged
PADDLE_ENFORCE_EQ(type(*cache), FREE_CHUNK); auto desc = cache->LoadDesc(this);
PADDLE_ENFORCE_EQ(right_buddy->type(*cache), FREE_CHUNK); auto rb_desc = cache->LoadDesc(right_buddy);
PADDLE_ENFORCE_EQ(desc->type, FREE_CHUNK);
auto metadata = cache->load(this); PADDLE_ENFORCE_EQ(rb_desc->type, FREE_CHUNK);
// link this->buddy's buddy // link this->buddy's buddy
metadata.right_buddy = right_buddy->right_buddy(*cache); desc->right_buddy = rb_desc->right_buddy;
// link buddy's buddy -> this // link buddy's buddy -> this
if (metadata.right_buddy != nullptr) { if (desc->right_buddy != nullptr) {
auto buddy_metadata = cache->load(metadata.right_buddy); auto buddy_metadata = cache->LoadDesc(desc->right_buddy);
buddy_metadata.left_buddy = this; buddy_metadata->left_buddy = this;
buddy_metadata->UpdateGuards();
cache->save(metadata.right_buddy, buddy_metadata);
} }
metadata.size += right_buddy->total_size(*cache); desc->size += rb_desc->total_size;
metadata.total_size += right_buddy->total_size(*cache); desc->total_size += rb_desc->total_size;
desc->UpdateGuards();
cache->save(this, metadata); cache->Save(right_buddy,
cache->save(right_buddy,
MemoryBlock::Desc(INVALID_CHUNK, 0, 0, 0, nullptr, nullptr)); MemoryBlock::Desc(INVALID_CHUNK, 0, 0, 0, nullptr, nullptr));
} }
void MemoryBlock::mark_as_free(MetadataCache* cache) { void MemoryBlock::MarkAsFree(MetadataCache* cache) {
// check for double free or corruption // check for double free or corruption
PADDLE_ENFORCE_NE(type(*cache), FREE_CHUNK); auto desc = cache->LoadDesc(this);
PADDLE_ENFORCE_NE(type(*cache), INVALID_CHUNK); PADDLE_ENFORCE_NE(desc->type, FREE_CHUNK);
set_type(cache, FREE_CHUNK); PADDLE_ENFORCE_NE(desc->type, INVALID_CHUNK);
} desc->type = FREE_CHUNK;
desc->UpdateGuards();
void MemoryBlock::set_type(MetadataCache* cache, Type t) {
auto metadata = cache->load(this);
metadata.type = t;
cache->save(this, metadata);
} }
void* MemoryBlock::data() const { void* MemoryBlock::Data() const {
return const_cast<MemoryBlock::Desc*>( return const_cast<MemoryBlock::Desc*>(
reinterpret_cast<const MemoryBlock::Desc*>(this)) + reinterpret_cast<const MemoryBlock::Desc*>(this)) +
1; 1;
} }
MemoryBlock* MemoryBlock::metadata() const { MemoryBlock* MemoryBlock::Metadata() const {
return const_cast<MemoryBlock*>(reinterpret_cast<const MemoryBlock*>( return const_cast<MemoryBlock*>(reinterpret_cast<const MemoryBlock*>(
reinterpret_cast<const MemoryBlock::Desc*>(this) - 1)); reinterpret_cast<const MemoryBlock::Desc*>(this) - 1));
} }
......
...@@ -38,35 +38,23 @@ struct MemoryBlock { ...@@ -38,35 +38,23 @@ struct MemoryBlock {
// MemoryBlock::Desc to the beginning of the block; or, if it is a GPU memory // MemoryBlock::Desc to the beginning of the block; or, if it is a GPU memory
// block, the MetadataCache writes the Metadata to a std::map in // block, the MetadataCache writes the Metadata to a std::map in
// the CPU. // the CPU.
void init(MetadataCache* cache, Type t, size_t index, size_t size, void Init(MetadataCache* cache, Type t, size_t index, size_t size,
void* left_buddy, void* right_buddy); void* left_buddy, void* right_buddy);
// All these accessors returns fields in the MemoryBlock::Desc of the memory MemoryBlock* GetLeftBuddy(MetadataCache* cache);
// block. They all need a MetadataCache instance as their first MemoryBlock* GetRightBuddy(MetadataCache* cache);
// parameter because they read the MemoryBlock::Desc from the cache.
Type type(const MetadataCache& cache) const;
size_t size(const MetadataCache& cache) const;
size_t index(const MetadataCache& cache) const;
size_t total_size(const MetadataCache& cache) const;
bool has_left_buddy(const MetadataCache& cache) const;
bool has_right_buddy(const MetadataCache& cache) const;
MemoryBlock* left_buddy(const MetadataCache& cache) const;
MemoryBlock* right_buddy(const MetadataCache& cache) const;
// Split the allocation into left/right blocks. // Split the allocation into left/right blocks.
void split(MetadataCache* cache, size_t size); void Split(MetadataCache* cache, size_t size);
// Merge left and right blocks together. // Merge left and right blocks together.
void merge(MetadataCache* cache, MemoryBlock* right_buddy); void Merge(MetadataCache* cache, MemoryBlock* right_buddy);
// Mark the allocation as free. // Mark the allocation as free.
void mark_as_free(MetadataCache* cache); void MarkAsFree(MetadataCache* cache);
// Change the type of the allocation. void* Data() const;
void set_type(MetadataCache* cache, Type t); MemoryBlock* Metadata() const;
void* data() const;
MemoryBlock* metadata() const;
// MemoryBlock::Desc describes a MemoryBlock. // MemoryBlock::Desc describes a MemoryBlock.
struct Desc { struct Desc {
...@@ -74,11 +62,29 @@ struct MemoryBlock { ...@@ -74,11 +62,29 @@ struct MemoryBlock {
MemoryBlock* r); MemoryBlock* r);
Desc(); Desc();
// Stores a new chunk type in this descriptor, then calls UpdateGuards().
void set_type(const MemoryBlock::Type& type) {
  this->type = type;  // parameter shadows the member, so `this->` is required
  UpdateGuards();
}
// Read-only access to the chunk type stored in this descriptor.
const MemoryBlock::Type& get_type() const { return type; }
// accessor for index
inline const size_t& get_index() const { return this->index; }
// accessor for size
inline const size_t& get_size() const { return this->size; }
// accessor for total_size
inline const size_t& get_total_size() const { return this->total_size; }
// Updates guard_begin and guard_end by hashes of the Metadata object. // Updates guard_begin and guard_end by hashes of the Metadata object.
void update_guards(); void UpdateGuards();
// Checks that guard_begin and guard_end are hashes of the Metadata object. // Checks that guard_begin and guard_end are hashes of the Metadata object.
bool check_guards() const; bool CheckGuards() const;
// TODO(gangliao): compress this // TODO(gangliao): compress this
size_t guard_begin = 0; size_t guard_begin = 0;
...@@ -109,15 +115,15 @@ class MetadataCache { ...@@ -109,15 +115,15 @@ class MetadataCache {
// used to manage CPU memory, the MemoryBlock::Desc resides at the beginning // used to manage CPU memory, the MemoryBlock::Desc resides at the beginning
// of the memory block; when used to manage GPU memory, the // of the memory block; when used to manage GPU memory, the
// Metadata resides in CPU memory indexed by cache_. // Metadata resides in CPU memory indexed by cache_.
MemoryBlock::Desc load(const MemoryBlock* memory_block) const; MemoryBlock::Desc* LoadDesc(MemoryBlock* memory_block);
// Saves the MemoryBlock::Desc of a memory block into the cache. For CPU // Saves the MemoryBlock::Desc of a memory block into the cache. For CPU
// memory block, writes the MemoryBlock::Desc to the beginning of the memory // memory block, writes the MemoryBlock::Desc to the beginning of the memory
// block; whereas for GPU memory, writes it to cache_. // block; whereas for GPU memory, writes it to cache_.
void save(MemoryBlock* memory_block, const MemoryBlock::Desc& meta_data); void Save(MemoryBlock* memory_block, const MemoryBlock::Desc& meta_data);
// For GPU memory block, erases its MemoryBlock::Desc from cache_. // For GPU memory block, erases its MemoryBlock::Desc from cache_.
void invalidate(MemoryBlock* memory_block); void Invalidate(MemoryBlock* memory_block);
private: private:
typedef std::unordered_map<const MemoryBlock*, MemoryBlock::Desc> MetadataMap; typedef std::unordered_map<const MemoryBlock*, MemoryBlock::Desc> MetadataMap;
......
...@@ -60,13 +60,19 @@ inline size_t hash(const MemoryBlock::Desc& metadata, size_t initial_seed) { ...@@ -60,13 +60,19 @@ inline size_t hash(const MemoryBlock::Desc& metadata, size_t initial_seed) {
} // namespace } // namespace
void MemoryBlock::Desc::update_guards() { void MemoryBlock::Desc::UpdateGuards() {
#ifdef PADDLE_WITH_TESTING
guard_begin = hash(*this, 1); guard_begin = hash(*this, 1);
guard_end = hash(*this, 2); guard_end = hash(*this, 2);
#endif
} }
bool MemoryBlock::Desc::check_guards() const { bool MemoryBlock::Desc::CheckGuards() const {
#ifdef PADDLE_WITH_TESTING
return guard_begin == hash(*this, 1) && guard_end == hash(*this, 2); return guard_begin == hash(*this, 1) && guard_end == hash(*this, 2);
#else
return true;
#endif
} }
} // namespace detail } // namespace detail
......
...@@ -22,23 +22,23 @@ namespace detail { ...@@ -22,23 +22,23 @@ namespace detail {
MetadataCache::MetadataCache(bool uses_gpu) : uses_gpu_(uses_gpu) {} MetadataCache::MetadataCache(bool uses_gpu) : uses_gpu_(uses_gpu) {}
MemoryBlock::Desc MetadataCache::load(const MemoryBlock* block) const { MemoryBlock::Desc* MetadataCache::LoadDesc(MemoryBlock* block) {
if (uses_gpu_) { if (uses_gpu_) {
auto existing_desc = cache_.find(block); auto existing_desc = cache_.find(block);
PADDLE_ENFORCE_EQ(existing_desc->second.check_guards(), true); PADDLE_ENFORCE_EQ(existing_desc->second.CheckGuards(), true);
return existing_desc->second; return &(existing_desc->second);
} else { } else {
auto* desc = reinterpret_cast<const MemoryBlock::Desc*>(block); auto* desc = reinterpret_cast<MemoryBlock::Desc*>(block);
VLOG(10) << "Load MemoryBlock::Desc type=" << desc->type; VLOG(10) << "Load MemoryBlock::Desc type=" << desc->type;
PADDLE_ENFORCE_EQ(desc->check_guards(), true); PADDLE_ENFORCE_EQ(desc->CheckGuards(), true);
return *reinterpret_cast<const MemoryBlock::Desc*>(block); return reinterpret_cast<MemoryBlock::Desc*>(block);
} }
} }
void MetadataCache::save(MemoryBlock* block, void MetadataCache::Save(MemoryBlock* block,
const MemoryBlock::Desc& original_desc) { const MemoryBlock::Desc& original_desc) {
auto desc = original_desc; auto desc = original_desc;
desc.update_guards(); desc.UpdateGuards();
if (uses_gpu_) { if (uses_gpu_) {
cache_[block] = desc; cache_[block] = desc;
...@@ -47,7 +47,7 @@ void MetadataCache::save(MemoryBlock* block, ...@@ -47,7 +47,7 @@ void MetadataCache::save(MemoryBlock* block,
} }
} }
void MetadataCache::invalidate(MemoryBlock* block) { void MetadataCache::Invalidate(MemoryBlock* block) {
if (uses_gpu_) { if (uses_gpu_) {
cache_.erase(block); cache_.erase(block);
} }
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册