未验证 提交 c64d9593 编写于 作者: Z Zeng Jinle 提交者: GitHub

Merge pull request #16295 from zhhsplendid/zhenghuihuang-dev-2

Add support for init_memory and re-allocate_memory
......@@ -61,4 +61,6 @@ nv_test(allocation_and_eigen_test SRCS allocation_and_eigen_test.cu DEPS allocat
cc_test(retry_allocator_test SRCS retry_allocator_test.cc DEPS retry_allocator best_fit_allocator locked_allocator cpu_allocator)
cc_test(allocator_facade_test SRCS allocator_facade_test.cc DEPS allocator_facade)
cc_test(allocator_facade_abs_flags_test SRCS allocator_facade_abs_flags_test.cc DEPS allocator_facade)
cc_test(allocator_facade_frac_flags_test SRCS allocator_facade_frac_flags_test.cc DEPS allocator_facade)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#ifdef PADDLE_WITH_CUDA
DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_double(fraction_of_cuda_pinned_memory_to_use);
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
DECLARE_int64(gpu_allocator_retry_time);
#endif
namespace paddle {
namespace memory {
namespace allocation {
//! Run allocate test cases for different places
void AllocateTestCases() {
auto &instance = AllocatorFacade::Instance();
platform::Place place;
size_t size = 1024;
{
place = platform::CPUPlace();
size = 1024;
auto cpu_allocation = instance.Alloc(place, size);
ASSERT_NE(cpu_allocation, nullptr);
ASSERT_NE(cpu_allocation->ptr(), nullptr);
ASSERT_EQ(cpu_allocation->place(), place);
ASSERT_EQ(cpu_allocation->size(), size);
}
#ifdef PADDLE_WITH_CUDA
{
place = platform::CUDAPlace(0);
size = 1024;
auto gpu_allocation = instance.Alloc(place, size);
ASSERT_NE(gpu_allocation, nullptr);
ASSERT_NE(gpu_allocation->ptr(), nullptr);
ASSERT_EQ(gpu_allocation->place(), place);
ASSERT_GE(gpu_allocation->size(), size);
}
{
// Allocate 2GB gpu memory
place = platform::CUDAPlace(0);
size = 2 * static_cast<size_t>(1 << 30);
auto gpu_allocation = instance.Alloc(place, size);
ASSERT_NE(gpu_allocation, nullptr);
ASSERT_NE(gpu_allocation->ptr(), nullptr);
ASSERT_EQ(gpu_allocation->place(), place);
ASSERT_GE(gpu_allocation->size(), size);
}
{
place = platform::CUDAPinnedPlace();
size = (1 << 20);
auto cuda_pinned_allocation =
instance.Alloc(platform::CUDAPinnedPlace(), 1 << 20);
ASSERT_NE(cuda_pinned_allocation, nullptr);
ASSERT_NE(cuda_pinned_allocation->ptr(), nullptr);
ASSERT_EQ(cuda_pinned_allocation->place(), place);
ASSERT_GE(cuda_pinned_allocation->size(), size);
}
#endif
}
TEST(Allocator, SpecifyGpuMemory) {
#ifdef PADDLE_WITH_CUDA
// Set to 0.0 to test FLAGS_initial_gpu_memory_in_mb and
// FLAGS_reallocate_gpu_memory_in_mb
FLAGS_fraction_of_gpu_memory_to_use = 0.0;
// 512 MB
FLAGS_initial_gpu_memory_in_mb = 512;
// 4 MB
FLAGS_reallocate_gpu_memory_in_mb = 4;
FLAGS_gpu_allocator_retry_time = 500;
FLAGS_fraction_of_cuda_pinned_memory_to_use = 0.5;
#endif
AllocateTestCases();
}
} // namespace allocation
} // namespace memory
} // namespace paddle
......@@ -19,6 +19,8 @@
#ifdef PADDLE_WITH_CUDA
DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_double(fraction_of_cuda_pinned_memory_to_use);
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
DECLARE_int64(gpu_allocator_retry_time);
#endif
......@@ -26,13 +28,8 @@ namespace paddle {
namespace memory {
namespace allocation {
TEST(allocator, allocator) {
#ifdef PADDLE_WITH_CUDA
FLAGS_fraction_of_gpu_memory_to_use = 0.01;
FLAGS_gpu_allocator_retry_time = 500;
FLAGS_fraction_of_cuda_pinned_memory_to_use = 0.5;
#endif
//! Run allocate test cases for different places
void AllocateTestCases() {
auto &instance = AllocatorFacade::Instance();
platform::Place place;
size_t size = 1024;
......@@ -82,6 +79,16 @@ TEST(allocator, allocator) {
#endif
}
TEST(Allocator, Allocator) {
#ifdef PADDLE_WITH_CUDA
FLAGS_fraction_of_gpu_memory_to_use = 0.01;
FLAGS_gpu_allocator_retry_time = 500;
FLAGS_fraction_of_cuda_pinned_memory_to_use = 0.5;
#endif
AllocateTestCases();
}
} // namespace allocation
} // namespace memory
} // namespace paddle
......@@ -37,6 +37,8 @@ DEFINE_bool(init_allocated_mem, false,
"that initializing the allocated memory with a small value "
"during unit testing.");
DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
DECLARE_bool(benchmark);
namespace paddle {
......@@ -153,12 +155,18 @@ BuddyAllocator *GetGPUBuddyAllocator(int gpu_id) {
platform::GpuMinChunkSize(),
platform::GpuMaxChunkSize());
VLOG(10) << "\n\nNOTE: each GPU device use "
<< FLAGS_fraction_of_gpu_memory_to_use * 100
<< "% of GPU memory.\n"
<< "You can set GFlags environment variable '"
<< "FLAGS_fraction_of_gpu_memory_to_use"
<< "' to change the fraction of GPU usage.\n\n";
VLOG(10) << "\n\nNOTE:\n"
<< "You can set GFlags environment variable "
<< "'FLAGS_fraction_of_gpu_memory_to_use' "
<< "or 'FLAGS_initial_gpu_memory_in_mb' "
<< "or 'FLAGS_reallocate_gpu_memory_in_mb' "
<< "to change the memory size for GPU usage.\n"
<< "Current 'FLAGS_fraction_of_gpu_memory_to_use' value is "
<< FLAGS_fraction_of_gpu_memory_to_use
<< ". Current 'FLAGS_initial_gpu_memory_in_mb' value is "
<< FLAGS_initial_gpu_memory_in_mb
<< ". Current 'FLAGS_reallocate_gpu_memory_in_mb' value is "
<< FLAGS_reallocate_gpu_memory_in_mb << "\n\n";
}
});
......
......@@ -9,3 +9,5 @@ endif(${WITH_GPU})
cc_test(system_allocator_test SRCS system_allocator_test.cc DEPS system_allocator)
cc_library(buddy_allocator SRCS buddy_allocator.cc DEPS memory_block system_allocator glog)
cc_test(buddy_allocator_test SRCS buddy_allocator_test.cc DEPS buddy_allocator)
......@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/memory/detail/buddy_allocator.h"
#include <algorithm>
#include <utility>
#include "glog/logging.h"
DEFINE_bool(free_idle_memory, false,
......@@ -36,9 +40,10 @@ BuddyAllocator::~BuddyAllocator() {
"have actually been freed";
while (!pool_.empty()) {
auto block = static_cast<MemoryBlock*>(std::get<2>(*pool_.begin()));
VLOG(10) << "Free from block (" << block << ", " << max_chunk_size_ << ")";
VLOG(10) << "Free from block (" << block << ", " << block->size(cache_)
<< ")";
system_allocator_->Free(block, max_chunk_size_, block->index(cache_));
system_allocator_->Free(block, block->size(cache_), block->index(cache_));
cache_.invalidate(block);
pool_.erase(pool_.begin());
}
......@@ -71,7 +76,7 @@ void* BuddyAllocator::Alloc(size_t unaligned_size) {
// refill the pool if failure
if (it == pool_.end()) {
it = RefillPool();
it = RefillPool(size);
// if still failure, fail fatally
if (it == pool_.end()) {
return nullptr;
......@@ -184,19 +189,28 @@ void* BuddyAllocator::SystemAlloc(size_t size) {
return static_cast<MemoryBlock*>(p)->data();
}
BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() {
BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool(
size_t request_bytes) {
size_t allocate_bytes = max_chunk_size_;
size_t index = 0;
#ifdef PADDLE_WITH_CUDA
if (system_allocator_->UseGpu()) {
if ((total_used_ + total_free_) == 0) {
// Compute the maximum allocation size for the first allocation.
max_chunk_size_ = platform::GpuMaxChunkSize();
// Compute the allocation size for gpu for the first allocation.
allocate_bytes = std::max(platform::GpuInitAllocSize(), request_bytes);
} else {
// Reallocation size
if (realloc_size_ == 0) {
realloc_size_ = platform::GpuReallocSize();
}
allocate_bytes = std::max(realloc_size_, request_bytes);
}
}
#endif
// Allocate a new maximum sized block
size_t index = 0;
void* p = system_allocator_->Alloc(&index, max_chunk_size_);
// Allocate a new block
void* p = system_allocator_->Alloc(&index, allocate_bytes);
if (p == nullptr) return pool_.end();
......@@ -204,7 +218,7 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() {
<< " from system allocator";
static_cast<MemoryBlock*>(p)->init(&cache_, MemoryBlock::FREE_CHUNK, index,
max_chunk_size_, nullptr, nullptr);
allocate_bytes, nullptr, nullptr);
// gpu fallback allocation
if (system_allocator_->UseGpu() &&
......@@ -212,10 +226,10 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() {
fallback_alloc_count_++;
}
total_free_ += max_chunk_size_;
total_free_ += allocate_bytes;
// dump the block into pool
return pool_.insert(IndexSizeAddress(index, max_chunk_size_, p)).first;
return pool_.insert(IndexSizeAddress(index, allocate_bytes, p)).first;
}
BuddyAllocator::PoolSet::iterator BuddyAllocator::FindExistChunk(size_t size) {
......@@ -286,12 +300,12 @@ void BuddyAllocator::CleanIdleFallBackAlloc() {
VLOG(10) << "Return block " << block << " to fallback allocator.";
system_allocator_->Free(block, max_chunk_size_, block->index(cache_));
system_allocator_->Free(block, block->size(cache_), block->index(cache_));
cache_.invalidate(block);
pool = PoolSet::reverse_iterator(pool_.erase(std::next(pool).base()));
total_free_ -= max_chunk_size_;
total_free_ -= block->size(cache_);
fallback_alloc_count_--;
// If no fall allocation exists, return directly
......@@ -322,12 +336,12 @@ void BuddyAllocator::CleanIdleNormalAlloc() {
VLOG(10) << "Return block " << block << " to base allocator.";
system_allocator_->Free(block, max_chunk_size_, block->index(cache_));
system_allocator_->Free(block, block->size(cache_), block->index(cache_));
cache_.invalidate(block);
pool = PoolSet::reverse_iterator(pool_.erase(std::next(pool).base()));
total_free_ -= max_chunk_size_;
total_free_ -= block->size(cache_);
if (!shall_free_alloc()) return;
}
......
......@@ -60,7 +60,7 @@ class BuddyAllocator {
void* SystemAlloc(size_t size);
/*! \brief If existing chunks are not suitable, refill pool */
PoolSet::iterator RefillPool();
PoolSet::iterator RefillPool(size_t request_bytes);
/**
* \brief Find the suitable chunk from existing pool and split
......@@ -89,6 +89,8 @@ class BuddyAllocator {
size_t min_chunk_size_; // the minimum size of each chunk
size_t max_chunk_size_; // the maximum size of each chunk
size_t realloc_size_ = 0; // the size of re-allocated chunk
private:
/**
* \brief A list of free allocation
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/memory/detail/buddy_allocator.h"
#include <memory>
#include "gflags/gflags.h"
#include "gtest/gtest.h"
#include "paddle/fluid/memory/detail/system_allocator.h"
#include "paddle/fluid/platform/gpu_info.h"
#ifdef PADDLE_WITH_CUDA
DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
#endif
namespace paddle {
namespace memory {
namespace detail {
constexpr static int test_gpu_id = 0;
void TestBuddyAllocator(BuddyAllocator* allocator, size_t size_bytes) {
bool freed = false;
size_t used_bytes = allocator->Used();
if (size_bytes > 0) {
void* p = allocator->Alloc(size_bytes);
EXPECT_NE(p, nullptr);
#ifdef PADDLE_WITH_CUDA
if (size_bytes < platform::GpuMaxChunkSize()) {
#else
if (size_bytes < platform::CpuMaxChunkSize()) {
#endif
// Not allocate from SystemAllocator
EXPECT_GE(allocator->Used(), used_bytes + size_bytes);
} else {
// Allocate from SystemAllocator doesn't count in Used()
EXPECT_EQ(allocator->Used(), used_bytes);
}
int* intp = static_cast<int*>(p);
std::shared_ptr<int> ptr(intp, [&](void* p) {
allocator->Free(intp);
freed = true;
});
} else {
freed = true;
}
EXPECT_EQ(used_bytes, allocator->Used());
EXPECT_TRUE(freed);
}
#ifdef PADDLE_WITH_CUDA
TEST(BuddyAllocator, GpuFraction) {
FLAGS_fraction_of_gpu_memory_to_use = 0.01;
BuddyAllocator buddy_allocator(
std::unique_ptr<SystemAllocator>(new GPUAllocator(test_gpu_id)),
platform::GpuMinChunkSize(), platform::GpuMaxChunkSize());
TestBuddyAllocator(&buddy_allocator, 10);
TestBuddyAllocator(&buddy_allocator, 10 << 10);
TestBuddyAllocator(&buddy_allocator, 10 << 20);
TestBuddyAllocator(&buddy_allocator, 2 * static_cast<size_t>(1 << 30));
}
TEST(BuddyAllocator, InitRealloc) {
FLAGS_initial_gpu_memory_in_mb = 100;
FLAGS_reallocate_gpu_memory_in_mb = 50;
EXPECT_EQ(platform::GpuMaxChunkSize(), static_cast<size_t>(100 << 20));
BuddyAllocator buddy_allocator(
std::unique_ptr<SystemAllocator>(new GPUAllocator(test_gpu_id)),
platform::GpuMinChunkSize(), platform::GpuMaxChunkSize());
// Less then initial size and reallocate size
TestBuddyAllocator(&buddy_allocator, 10 << 20);
// Between initial size and reallocate size and not exceed pool
TestBuddyAllocator(&buddy_allocator, 80 << 20);
// Less then reallocate size and exceed pool
TestBuddyAllocator(&buddy_allocator, 40 << 20);
// Greater then reallocate size and exceed pool
TestBuddyAllocator(&buddy_allocator, 80 << 20);
// Greater then initial size and reallocate size
TestBuddyAllocator(&buddy_allocator, 2 * static_cast<size_t>(1 << 30));
}
TEST(BuddyAllocator, ReallocSizeGreaterThanInit) {
FLAGS_initial_gpu_memory_in_mb = 5;
FLAGS_reallocate_gpu_memory_in_mb = 10;
EXPECT_EQ(platform::GpuMaxChunkSize(), static_cast<size_t>(10 << 20));
BuddyAllocator buddy_allocator(
std::unique_ptr<SystemAllocator>(new GPUAllocator(test_gpu_id)),
platform::GpuMinChunkSize(), platform::GpuMaxChunkSize());
// Less then initial size and reallocate size
TestBuddyAllocator(&buddy_allocator, 1 << 20);
// Between initial size and reallocate size and not exceed pool
TestBuddyAllocator(&buddy_allocator, 3 << 20);
// Less then initial size and exceed pool
TestBuddyAllocator(&buddy_allocator, 3 << 20);
// Less then reallocate size and not exceed pool (now pool is 15 MB, used 7
// MB)
TestBuddyAllocator(&buddy_allocator, 7 << 20);
// Less then reallocate size and exceed pool
TestBuddyAllocator(&buddy_allocator, 8 << 20);
// Greater then initial size and reallocate size
TestBuddyAllocator(&buddy_allocator, 2 * static_cast<size_t>(1 << 30));
}
#endif
} // namespace detail
} // namespace memory
} // namespace paddle
......@@ -32,6 +32,9 @@ limitations under the License. */
DECLARE_bool(use_pinned_memory);
DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
namespace paddle {
namespace memory {
namespace detail {
......@@ -119,11 +122,18 @@ void* GPUAllocator::Alloc(size_t* index, size_t size) {
gpu_alloc_size_ += size;
return p;
} else {
LOG(WARNING)
<< "Cannot malloc " << size / 1024.0 / 1024.0
<< " MB GPU memory. Please shrink FLAGS_fraction_of_gpu_memory_to_use "
"environment variable to a lower value. Current value is "
<< FLAGS_fraction_of_gpu_memory_to_use;
LOG(WARNING) << "Cannot malloc " << size / 1024.0 / 1024.0
<< " MB GPU memory. Please shrink "
"FLAGS_fraction_of_gpu_memory_to_use or "
"FLAGS_initial_gpu_memory_in_mb or "
"FLAGS_reallocate_gpu_memory_in_mb"
"environment variable to a lower value. "
<< "Current FLAGS_fraction_of_gpu_memory_to_use value is "
<< FLAGS_fraction_of_gpu_memory_to_use
<< ". Current FLAGS_initial_gpu_memory_in_mb value is "
<< FLAGS_initial_gpu_memory_in_mb
<< ". Current FLAGS_reallocate_gpu_memory_in_mb value is "
<< FLAGS_reallocate_gpu_memory_in_mb;
return nullptr;
}
}
......
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/gpu_info.h"
#include <algorithm>
#include <cstdlib>
#include <string>
......@@ -31,6 +30,8 @@ constexpr static float fraction_of_gpu_memory_to_use = 0.92f;
constexpr static float fraction_of_gpu_memory_to_use = 0.5f;
#endif
constexpr static float fraction_reserve_gpu_memory = 0.05f;
DEFINE_double(fraction_of_gpu_memory_to_use, fraction_of_gpu_memory_to_use,
"Allocate a trunk of gpu memory that is this fraction of the "
"total gpu memory size. Future memory usage will be allocated "
......@@ -38,6 +39,24 @@ DEFINE_double(fraction_of_gpu_memory_to_use, fraction_of_gpu_memory_to_use,
"additional trunks of the same size will be requested from gpu "
"until the gpu has no memory left for another trunk.");
DEFINE_uint64(
initial_gpu_memory_in_mb, 0ul,
"Allocate a trunk of gpu memory whose byte size is specified by "
"the flag. Future memory usage will be allocated from the "
"truck. If the trunk doesn't have enough gpu memory, additional "
"trunks of the gpu memory will be requested from gpu with size "
"specified by FLAGS_reallocate_gpu_memory_in_mb until the gpu has "
"no memory left for the additional trunk. Note: if you set this "
"flag, the memory size set by "
"FLAGS_fraction_of_gpu_memory_to_use will be overrided by this "
"flag. If you don't set this flag, PaddlePaddle will use "
"FLAGS_fraction_of_gpu_memory_to_use to allocate gpu memory");
DEFINE_uint64(reallocate_gpu_memory_in_mb, 0ul,
"If this flag is set, Paddle will reallocate the gpu memory with "
"size specified by this flag. Else Paddle will reallocate by "
"FLAGS_fraction_of_gpu_memory_to_use");
DEFINE_bool(
enable_cublas_tensor_op_math, false,
"The enable_cublas_tensor_op_math indicate whether to use Tensor Core, "
......@@ -180,13 +199,43 @@ void GpuMemoryUsage(size_t *available, size_t *total) {
}
size_t GpuMaxAllocSize() {
return std::max(GpuInitAllocSize(), GpuReallocSize());
}
size_t GpuInitAllocSize() {
if (FLAGS_initial_gpu_memory_in_mb > 0ul) {
// Initial memory will be allocated by FLAGS_initial_gpu_memory_in_mb
return static_cast<size_t>(FLAGS_initial_gpu_memory_in_mb << 20);
}
// FLAGS_initial_gpu_memory_in_mb is 0, initial memory will be allocated by
// fraction
size_t total = 0;
size_t available = 0;
GpuMemoryUsage(&available, &total);
size_t reserving = static_cast<size_t>(fraction_reserve_gpu_memory * total);
// Reserve the rest for page tables, etc.
return static_cast<size_t>(total * FLAGS_fraction_of_gpu_memory_to_use);
return static_cast<size_t>((total - reserving) *
FLAGS_fraction_of_gpu_memory_to_use);
}
size_t GpuReallocSize() {
if (FLAGS_reallocate_gpu_memory_in_mb > 0ul) {
// Additional memory will be allocated by FLAGS_reallocate_gpu_memory_in_mb
return static_cast<size_t>(FLAGS_reallocate_gpu_memory_in_mb << 20);
}
// FLAGS_reallocate_gpu_memory_in_mb is 0, additional memory will be allocated
// by fraction
size_t total = 0;
size_t available = 0;
GpuMemoryUsage(&available, &total);
size_t reserving = static_cast<size_t>(fraction_reserve_gpu_memory * total);
return static_cast<size_t>((total - reserving) *
FLAGS_fraction_of_gpu_memory_to_use);
}
size_t GpuMinChunkSize() {
......@@ -201,16 +250,13 @@ size_t GpuMaxChunkSize() {
GpuMemoryUsage(&available, &total);
VLOG(10) << "GPU Usage " << available / 1024 / 1024 << "M/"
<< total / 1024 / 1024 << "M";
size_t reserving = static_cast<size_t>(0.05 * total);
size_t reserving = static_cast<size_t>(fraction_reserve_gpu_memory * total);
// If available less than minimum chunk size, no usable memory exists.
available =
std::min(std::max(available, GpuMinChunkSize()) - GpuMinChunkSize(),
total - reserving);
// Reserving the rest memory for page tables, etc.
size_t allocating = static_cast<size_t>(FLAGS_fraction_of_gpu_memory_to_use *
(total - reserving));
size_t allocating = GpuMaxAllocSize();
PADDLE_ENFORCE_LE(allocating, available,
"Insufficient GPU memory to allocation.");
......
......@@ -60,6 +60,12 @@ void GpuMemoryUsage(size_t *available, size_t *total);
//! Get the maximum allocation size of current GPU device.
size_t GpuMaxAllocSize();
//! Get the initial allocation size of current GPU device.
size_t GpuInitAllocSize();
//! Get the re-allocation size of current GPU device.
size_t GpuReallocSize();
//! Get the minimum chunk size for GPU buddy allocator.
size_t GpuMinChunkSize();
......
......@@ -41,6 +41,8 @@ int main(int argc, char** argv) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
envs.push_back("fraction_of_gpu_memory_to_use");
envs.push_back("initial_gpu_memory_in_mb");
envs.push_back("reallocate_gpu_memory_in_mb");
envs.push_back("allocator_strategy");
#elif __clang__
envs.push_back("use_mkldnn");
......
......@@ -163,7 +163,8 @@ def __bootstrap__():
if core.is_compiled_with_cuda():
read_env_flags += [
'fraction_of_gpu_memory_to_use', 'cudnn_deterministic',
'fraction_of_gpu_memory_to_use', 'initial_gpu_memory_in_mb',
'reallocate_gpu_memory_in_mb', 'cudnn_deterministic',
'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus',
'sync_nccl_allreduce', 'limit_of_tmp_allocation',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册