Unverified commit 26d45137, authored by Zeng Jinle, committed by GitHub

Cherry-pick gpu memory limit (#22838)

* add recorded cuda memory apis, fix typo, test=develop

* add more ut, test=develop

* follow comments, test=release/1.7

* fix py35 incompatible issues, test=release/1.7
Parent a1c0b241
......@@ -25,29 +25,38 @@ namespace memory {
namespace allocation {
bool CUDAAllocator::IsAllocThreadSafe() const { return true; }
void CUDAAllocator::FreeImpl(Allocation* allocation) {
platform::CUDADeviceGuard guard(place_.device);
PADDLE_ENFORCE_EQ(boost::get<platform::CUDAPlace>(allocation->place()),
place_);
PADDLE_ENFORCE(cudaFree(allocation->ptr()));
PADDLE_ENFORCE_EQ(
boost::get<platform::CUDAPlace>(allocation->place()), place_,
platform::errors::PermissionDenied(
"GPU memory is freed in incorrect device. This may be a bug"));
platform::RecordedCudaFree(allocation->ptr(), allocation->size(),
place_.device);
delete allocation;
}
Allocation* CUDAAllocator::AllocateImpl(size_t size) {
std::call_once(once_flag_, [this] { platform::SetDeviceId(place_.device); });
platform::CUDADeviceGuard guard(place_.device);
void* ptr;
auto result = cudaMalloc(&ptr, size);
auto result = platform::RecordedCudaMalloc(&ptr, size, place_.device);
if (LIKELY(result == cudaSuccess)) {
return new Allocation(ptr, size, platform::Place(place_));
}
platform::RaiseNonOutOfMemoryError(&result);
size_t avail, total, actual_avail, actual_total;
bool is_limited = platform::RecordedCudaMemGetInfo(
&avail, &total, &actual_avail, &actual_total, place_.device);
size_t avail = 0, total = 0;
result = cudaMemGetInfo(&avail, &total);
if (result != cudaSuccess) avail = 0;
platform::RaiseNonOutOfMemoryError(&result);
std::string err_msg;
if (is_limited) {
auto limit_size = (total >> 20);
err_msg = string::Sprintf(
"Or set environment variable `FLAGS_gpu_memory_limit_mb` to a larger "
"value. Currently `FLAGS_gpu_memory_limit_mb` is %d, so the maximum "
"GPU memory usage is limited to %d MB.\n"
" The command is `export FLAGS_gpu_memory_limit_mb=xxx`.",
limit_size, limit_size);
}
PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted(
"\n\nOut of memory error on GPU %d. "
......@@ -55,9 +64,9 @@ Allocation* CUDAAllocator::AllocateImpl(size_t size) {
"available memory is only %s.\n\n"
"Please check whether there is any other process using GPU %d.\n"
"1. If yes, please stop them, or start PaddlePaddle on another GPU.\n"
"2. If no, please decrease the batch size of your model.\n",
"2. If no, please decrease the batch size of your model. %s\n\n",
place_.device, string::HumanReadableSize(size), place_.device,
string::HumanReadableSize(avail), place_.device));
string::HumanReadableSize(avail), place_.device, err_msg));
}
} // namespace allocation
......
......@@ -110,29 +110,28 @@ void* GPUAllocator::Alloc(size_t* index, size_t size) {
// if size is 0. We just make sure it does.
if (size <= 0) return nullptr;
paddle::platform::CUDADeviceGuard guard(gpu_id_);
void* p;
cudaError_t result = cudaMalloc(&p, size);
auto result = platform::RecordedCudaMalloc(&p, size, gpu_id_);
if (result == cudaSuccess) {
*index = 0;
gpu_alloc_size_ += size;
return p;
} else {
platform::RaiseNonOutOfMemoryError(&result);
/**
* NOTE(zjl): Sometimes cudaMemGetInfo would raise OOM error
* if there is very little GPU memory left. In this case, we
* should consider the available GPU memory to be 0, and throw
* exception inside this function instead of throwing exception
* inside cudaMemGetInfo.
*/
size_t avail = 0, total = 0;
result = cudaMemGetInfo(&avail, &total);
if (result != cudaSuccess) avail = 0;
platform::RaiseNonOutOfMemoryError(&result);
size_t avail, total, actual_avail, actual_total;
bool is_limited = platform::RecordedCudaMemGetInfo(
&avail, &total, &actual_avail, &actual_total, gpu_id_);
std::string err_msg;
if (is_limited) {
auto limit_size = (total >> 20);
err_msg = string::Sprintf(
"\n 3) Set environment variable `FLAGS_gpu_memory_limit_mb` to a "
"larger value. Currently `FLAGS_gpu_memory_limit_mb` is %d, so the "
"maximum GPU memory usage is limited to %d MB.\n"
" The command is `export FLAGS_gpu_memory_limit_mb=xxx`.",
limit_size, limit_size);
}
PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted(
"\n\nOut of memory error on GPU %d. "
......@@ -145,28 +144,19 @@ void* GPUAllocator::Alloc(size_t* index, size_t size) {
" 2) FLAGS_fraction_of_gpu_memory_to_use is %.2lf now, "
"please set it to a higher value but less than 1.0.\n"
" The command is "
"`export FLAGS_fraction_of_gpu_memory_to_use=xxx`.\n\n",
"`export FLAGS_fraction_of_gpu_memory_to_use=xxx`.%s\n\n",
gpu_id_, string::HumanReadableSize(size), gpu_id_,
string::HumanReadableSize(avail), gpu_id_,
FLAGS_fraction_of_gpu_memory_to_use));
FLAGS_fraction_of_gpu_memory_to_use, err_msg));
}
}
void GPUAllocator::Free(void* p, size_t size, size_t index) {
cudaError_t err;
PADDLE_ENFORCE_EQ(index, 0);
PADDLE_ENFORCE_GE(gpu_alloc_size_, size);
gpu_alloc_size_ -= size;
err = cudaFree(p);
// Purposefully allow cudaErrorCudartUnloading, because
// that is returned if you ever call cudaFree after the
// driver has already shutdown. This happens only if the
// process is terminating, in which case we don't care if
// cudaFree succeeds.
if (err != cudaErrorCudartUnloading) {
PADDLE_ENFORCE(err, "cudaFree{Host} failed in GPUAllocator::Free.");
}
platform::RecordedCudaFree(p, size, gpu_id_);
}
bool GPUAllocator::UseGpu() const { return true; }
......
......@@ -117,6 +117,8 @@ cc_test(profiler_test SRCS profiler_test.cc DEPS profiler)
nv_test(float16_gpu_test SRCS float16_test.cu DEPS lod_tensor)
cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor)
nv_test(test_limit_gpu_memory SRCS test_limit_gpu_memory.cu DEPS gpu_info flags)
nv_library(cuda_device_guard SRCS cuda_device_guard.cc DEPS gpu_info)
if(NOT APPLE AND NOT WIN32)
......
......@@ -449,6 +449,14 @@ DEFINE_uint64(reallocate_gpu_memory_in_mb, 0ul,
"size specified by this flag. Else Paddle will reallocate by "
"FLAGS_fraction_of_gpu_memory_to_use");
DEFINE_uint64(gpu_memory_limit_mb, 0UL,
"The maximum gpu memory limit that the process can allocate. "
"If it is equal to 0, there would be no limit and all gpu memory "
"would be available to the process. If it is larger than 0, "
"the process would raise out of memory error if the allocated "
"memory exceeds the limit even though there is available "
"memory on the gpu card. The unit is MB and default value is 0.");
#endif
/**
......
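For illustration only (this snippet is not part of the diff): a minimal sketch of configuring the new limit from Python before any GPU allocation happens, assuming a CUDA build and the `fluid.core.globals()` setter that is registered later in this change.
import paddle.fluid as fluid
# Cap this process at 500 MB of GPU memory; equivalent to running
# `export FLAGS_gpu_memory_limit_mb=500` before starting the program.
if fluid.is_compiled_with_cuda():
    fluid.core.globals()['FLAGS_gpu_memory_limit_mb'] = 500
# From this point on, any allocation that would push the recorded usage past
# the limit fails with an out-of-memory error, even if the card still has
# free memory.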
......@@ -15,10 +15,14 @@ limitations under the License. */
#include "paddle/fluid/platform/gpu_info.h"
#include <algorithm>
#include <cstdlib>
#include <memory>
#include <string>
#include "gflags/gflags.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/lock_guard_ptr.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/string/split.h"
DECLARE_double(fraction_of_gpu_memory_to_use);
......@@ -26,6 +30,7 @@ DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
DECLARE_bool(enable_cublas_tensor_op_math);
DECLARE_string(selected_gpus);
DECLARE_uint64(gpu_memory_limit_mb);
constexpr static float fraction_reserve_gpu_memory = 0.05f;
......@@ -241,11 +246,9 @@ void SetDeviceId(int id) {
}
void GpuMemoryUsage(size_t *available, size_t *total) {
auto error_code = cudaMemGetInfo(available, total);
PADDLE_ENFORCE(error_code,
"cudaMemGetInfo failed in "
"paddle::platform::GetMemoryUsage, error code : %d, %s",
error_code, CudaErrorWebsite());
size_t actual_available, actual_total;
RecordedCudaMemGetInfo(available, total, &actual_available, &actual_total,
platform::GetCurrentDeviceId());
}
size_t GpuAvailableMemToAlloc() {
......@@ -359,7 +362,7 @@ void GpuStreamSync(cudaStream_t stream) {
error_code, CudaErrorWebsite()));
}
void RaiseNonOutOfMemoryError(cudaError_t *status) {
static void RaiseNonOutOfMemoryError(cudaError_t *status) {
if (*status == cudaErrorMemoryAllocation) {
*status = cudaSuccess;
}
......@@ -374,5 +377,158 @@ void RaiseNonOutOfMemoryError(cudaError_t *status) {
PADDLE_ENFORCE_CUDA_SUCCESS(*status);
}
class RecordedCudaMallocHelper {
private:
explicit RecordedCudaMallocHelper(int dev_id, uint64_t limit_size = 0)
: dev_id_(dev_id), limit_size_(limit_size) {
if (NeedRecord()) {
mtx_.reset(new std::mutex());
}
}
DISABLE_COPY_AND_ASSIGN(RecordedCudaMallocHelper);
public:
static RecordedCudaMallocHelper *Instance(int dev_id) {
std::call_once(once_flag_, [] {
int dev_cnt = GetCUDADeviceCount();
instances_.reserve(dev_cnt);
for (int i = 0; i < dev_cnt; ++i) {
instances_.emplace_back(
new RecordedCudaMallocHelper(i, FLAGS_gpu_memory_limit_mb << 20));
}
});
PADDLE_ENFORCE_GE(
dev_id, 0,
platform::errors::OutOfRange(
"Device id must be not less than 0, but got %d", dev_id));
PADDLE_ENFORCE_LT(
dev_id, instances_.size(),
platform::errors::OutOfRange("Device id %d exceeds gpu card number %d",
dev_id, instances_.size()));
return instances_[dev_id].get();
}
/**
* Try to allocate `size` bytes of gpu memory. Only cudaErrorMemoryAllocation
* or cudaSuccess is returned, and the cudaGetLastError() flag
* is cleared afterwards.
*/
cudaError_t Malloc(void **ptr, size_t size) {
LockGuardPtr<std::mutex> lock(mtx_);
if (UNLIKELY(NeedRecord() && cur_size_ + size > limit_size_)) {
return cudaErrorMemoryAllocation;
}
CUDADeviceGuard guard(dev_id_);
auto result = cudaMalloc(ptr, size);
if (result == cudaSuccess) {
if (NeedRecord()) {
cur_size_ += size;
}
return cudaSuccess;
} else {
RaiseNonOutOfMemoryError(&result);
// Any non-OOM error has already been raised inside
// RaiseNonOutOfMemoryError, so the only remaining failure is OOM and
// we can return cudaErrorMemoryAllocation directly here.
return cudaErrorMemoryAllocation;
}
}
/**
* Free gpu memory. Freeing is normally not allowed to raise an error;
* if it does, the process should crash.
*/
void Free(void *ptr, size_t size) {
// Purposefully allow cudaErrorCudartUnloading, because
// that is returned if you ever call cudaFree after the
// driver has already shut down. This happens only if the
// process is terminating, in which case we don't care if
// cudaFree succeeds.
CUDADeviceGuard guard(dev_id_);
auto err = cudaFree(ptr);
if (err != cudaErrorCudartUnloading) {
PADDLE_ENFORCE_CUDA_SUCCESS(
err, platform::errors::External("cudaFree raises unexpected error"));
if (NeedRecord()) {
std::lock_guard<std::mutex> guard(*mtx_);
cur_size_ -= size;
}
} else {
cudaGetLastError(); // clear the error flag when cudaErrorCudartUnloading
}
}
bool GetMemInfo(size_t *avail, size_t *total, size_t *actual_avail,
size_t *actual_total) {
{
CUDADeviceGuard guard(dev_id_);
auto result = cudaMemGetInfo(actual_avail, actual_total);
if (result != cudaSuccess) {
*actual_avail = 0;
}
RaiseNonOutOfMemoryError(&result);
}
if (NeedRecord()) {
std::lock_guard<std::mutex> guard(*mtx_);
*avail = std::min(*actual_avail, limit_size_ - cur_size_);
*total = std::min(*actual_total, limit_size_);
return *total < *actual_total;
} else {
*avail = *actual_avail;
*total = *actual_total;
return false;
}
}
inline bool NeedRecord() const { return limit_size_ != 0; }
uint64_t RecordedSize() const {
LockGuardPtr<std::mutex> lock(mtx_);
return NeedRecord() ? cur_size_ : 0;
}
uint64_t LimitSize() const { return limit_size_; }
private:
const int dev_id_;
const uint64_t limit_size_;
uint64_t cur_size_{0};
mutable std::unique_ptr<std::mutex> mtx_;
static std::once_flag once_flag_;
static std::vector<std::unique_ptr<RecordedCudaMallocHelper>> instances_;
};
std::once_flag RecordedCudaMallocHelper::once_flag_;
std::vector<std::unique_ptr<RecordedCudaMallocHelper>>
RecordedCudaMallocHelper::instances_;
cudaError_t RecordedCudaMalloc(void **ptr, size_t size, int dev_id) {
return RecordedCudaMallocHelper::Instance(dev_id)->Malloc(ptr, size);
}
void RecordedCudaFree(void *p, size_t size, int dev_id) {
return RecordedCudaMallocHelper::Instance(dev_id)->Free(p, size);
}
bool RecordedCudaMemGetInfo(size_t *avail, size_t *total, size_t *actual_avail,
size_t *actual_total, int dev_id) {
return RecordedCudaMallocHelper::Instance(dev_id)->GetMemInfo(
avail, total, actual_avail, actual_total);
}
uint64_t RecordedCudaMallocSize(int dev_id) {
return RecordedCudaMallocHelper::Instance(dev_id)->RecordedSize();
}
bool IsCudaMallocRecorded(int dev_id) {
return RecordedCudaMallocHelper::Instance(dev_id)->NeedRecord();
}
} // namespace platform
} // namespace paddle
......@@ -104,8 +104,20 @@ void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream);
//! Blocks until stream has completed all operations.
void GpuStreamSync(cudaStream_t stream);
//! Raise error if status is not cudaSuccess or OOM, otherwise reset status.
void RaiseNonOutOfMemoryError(cudaError_t *status);
//! CudaMalloc with recorded info
cudaError_t RecordedCudaMalloc(void **ptr, size_t size, int dev_id);
//! CudaFree with recorded info
void RecordedCudaFree(void *p, size_t size, int dev_id);
//! Get available and total gpu memory, taking the memory limit into account
bool RecordedCudaMemGetInfo(size_t *avail, size_t *total, size_t *actual_avail,
size_t *actual_total, int dev_id);
//! Get recorded cudaMalloc size. If record is disabled, return 0.
uint64_t RecordedCudaMallocSize(int dev_id);
bool IsCudaMallocRecorded(int dev_id);
} // namespace platform
} // namespace paddle
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gflags/gflags.h"
#include "gtest/gtest.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/gpu_info.h"
DECLARE_uint64(gpu_memory_limit_mb);
namespace paddle {
namespace platform {
static constexpr uint64_t GPU_MEMORY_LIMIT_MB = 500;
static constexpr int DEVICE_ID = 0;
TEST(test_record_malloc, test_limit_gpu_memory) {
FLAGS_gpu_memory_limit_mb = GPU_MEMORY_LIMIT_MB;
size_t limit = FLAGS_gpu_memory_limit_mb << 20;
{
ASSERT_TRUE(IsCudaMallocRecorded(DEVICE_ID));
ASSERT_EQ(RecordedCudaMallocSize(DEVICE_ID), 0UL);
}
size_t avail, total;
{
size_t actual_avail, actual_total;
RecordedCudaMemGetInfo(&avail, &total, &actual_avail, &actual_total,
DEVICE_ID);
ASSERT_EQ(total, limit);
ASSERT_EQ(cudaGetLastError(), cudaSuccess);
}
{
CUDADeviceGuard guard(DEVICE_ID);
GpuMemoryUsage(&avail, &total);
ASSERT_EQ(total, limit);
ASSERT_EQ(cudaGetLastError(), cudaSuccess);
}
cudaError_t err = cudaSuccess;
void *p1 = nullptr;
size_t size1 = limit / 4 * 3;
{
err = platform::RecordedCudaMalloc(&p1, size1, DEVICE_ID);
ASSERT_EQ(err, cudaSuccess);
ASSERT_EQ(cudaGetLastError(), cudaSuccess);
ASSERT_NE(p1, nullptr);
ASSERT_EQ(RecordedCudaMallocSize(DEVICE_ID), size1);
}
void *p2 = nullptr;
size_t size2 = limit / 2;
{
err = platform::RecordedCudaMalloc(&p2, size2, DEVICE_ID);
ASSERT_EQ(err, cudaErrorMemoryAllocation);
ASSERT_EQ(cudaGetLastError(), cudaSuccess);
ASSERT_EQ(p2, nullptr);
ASSERT_EQ(RecordedCudaMallocSize(DEVICE_ID), size1);
}
{
platform::RecordedCudaFree(p1, size1, DEVICE_ID);
ASSERT_EQ(RecordedCudaMallocSize(DEVICE_ID), 0UL);
}
{
err = platform::RecordedCudaMalloc(&p2, size2, DEVICE_ID);
ASSERT_EQ(err, cudaSuccess);
ASSERT_EQ(cudaGetLastError(), cudaSuccess);
ASSERT_NE(p2, nullptr);
ASSERT_EQ(RecordedCudaMallocSize(DEVICE_ID), size2);
}
{
platform::RecordedCudaFree(p2, size2, DEVICE_ID);
ASSERT_EQ(RecordedCudaMallocSize(DEVICE_ID), 0UL);
}
}
} // namespace platform
} // namespace paddle
......@@ -32,6 +32,10 @@ DECLARE_bool(use_ngraph);
DECLARE_bool(use_system_allocator);
DECLARE_bool(free_idle_chunk);
DECLARE_bool(free_when_no_cache_hit);
#ifdef PADDLE_WITH_CUDA
DECLARE_uint64(gpu_memory_limit_mb);
#endif
DECLARE_string(allocator_strategy);
namespace paddle {
namespace pybind {
......@@ -169,8 +173,12 @@ static void RegisterGlobalVarGetterSetter() {
REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_use_ngraph);
REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_eager_delete_tensor_gb);
REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_use_system_allocator);
REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_allocator_strategy);
REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_free_idle_chunk);
REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_free_when_no_cache_hit);
#ifdef PADDLE_WITH_CUDA
REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_gpu_memory_limit_mb);
#endif
}
} // namespace pybind
......
......@@ -211,7 +211,7 @@ def __bootstrap__():
'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
'cudnn_exhaustive_search', 'selected_gpus', 'sync_nccl_allreduce',
'cudnn_batchnorm_spatial_persistent', 'gpu_allocator_retry_time',
'local_exe_sub_scope_limit'
'local_exe_sub_scope_limit', 'gpu_memory_limit_mb'
]
core.init_gflags([sys.argv[0]] +
["--tryfromenv=" + ",".join(read_env_flags)])
......
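For illustration only (not part of the diff): because 'gpu_memory_limit_mb' is now listed in read_env_flags, the limit can also be supplied through the environment and picked up by __bootstrap__() at import time; a minimal sketch, assuming a CUDA build.
import os
# The variable must be set before `import paddle.fluid`, since __bootstrap__()
# reads it from the environment via --tryfromenv during the import.
os.environ['FLAGS_gpu_memory_limit_mb'] = '1024'
import paddle.fluid as fluid
if fluid.is_compiled_with_cuda():
    # The value supplied through the environment is visible via the getter.
    print(fluid.core.globals()['FLAGS_gpu_memory_limit_mb'])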
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
import numpy as np
fluid.core.globals()['FLAGS_allocator_strategy'] = 'auto_growth'
if fluid.is_compiled_with_cuda():
fluid.core.globals()['FLAGS_gpu_memory_limit_mb'] = 10
class TestBase(unittest.TestCase):
def setUp(self):
if fluid.is_compiled_with_cuda():
self._limit = fluid.core.globals()['FLAGS_gpu_memory_limit_mb']
def test_allocate(self):
if not fluid.is_compiled_with_cuda():
return
other_dim = int(1024 * 1024 / 4)
place = fluid.CUDAPlace(0)
t = fluid.LoDTensor()
t.set(np.ndarray(
[int(self._limit / 2), other_dim], dtype='float32'),
place)
del t
t = fluid.LoDTensor()
large_np = np.ndarray([2 * self._limit, other_dim], dtype='float32')
with self.assertRaises(Exception):
    t.set(large_np, place)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
import numpy as np
fluid.core.globals()['FLAGS_allocator_strategy'] = 'naive_best_fit'
if fluid.is_compiled_with_cuda():
fluid.core.globals()['FLAGS_gpu_memory_limit_mb'] = 10
class TestBase(unittest.TestCase):
def setUp(self):
if fluid.is_compiled_with_cuda():
self._limit = fluid.core.globals()['FLAGS_gpu_memory_limit_mb']
def test_allocate(self):
if not fluid.is_compiled_with_cuda():
return
other_dim = int(1024 * 1024 / 4)
place = fluid.CUDAPlace(0)
t = fluid.LoDTensor()
t.set(np.ndarray(
[int(self._limit / 2), other_dim], dtype='float32'),
place)
del t
t = fluid.LoDTensor()
large_np = np.ndarray([2 * self._limit, other_dim], dtype='float32')
with self.assertRaises(Exception):
    t.set(large_np, place)
if __name__ == '__main__':
unittest.main()