diff --git a/paddle/fluid/memory/allocation/cuda_allocator.cc b/paddle/fluid/memory/allocation/cuda_allocator.cc
index 154ff1c87aafebde3fe28a8f380b652376651065..56a2ebb35554f24527731a62da65c019dfeb04d6 100644
--- a/paddle/fluid/memory/allocation/cuda_allocator.cc
+++ b/paddle/fluid/memory/allocation/cuda_allocator.cc
@@ -25,29 +25,38 @@ namespace memory {
 namespace allocation {
 bool CUDAAllocator::IsAllocThreadSafe() const { return true; }
 void CUDAAllocator::FreeImpl(Allocation* allocation) {
-  platform::CUDADeviceGuard guard(place_.device);
-  PADDLE_ENFORCE_EQ(boost::get<platform::CUDAPlace>(allocation->place()),
-                    place_);
-  PADDLE_ENFORCE(cudaFree(allocation->ptr()));
+  PADDLE_ENFORCE_EQ(
+      boost::get<platform::CUDAPlace>(allocation->place()), place_,
+      platform::errors::PermissionDenied(
+          "GPU memory is freed in incorrect device. This may be a bug"));
+  platform::RecordedCudaFree(allocation->ptr(), allocation->size(),
+                             place_.device);
   delete allocation;
 }
 Allocation* CUDAAllocator::AllocateImpl(size_t size) {
   std::call_once(once_flag_, [this] { platform::SetDeviceId(place_.device); });
-  platform::CUDADeviceGuard guard(place_.device);
   void* ptr;
-  auto result = cudaMalloc(&ptr, size);
+  auto result = platform::RecordedCudaMalloc(&ptr, size, place_.device);
   if (LIKELY(result == cudaSuccess)) {
     return new Allocation(ptr, size, platform::Place(place_));
   }
-  platform::RaiseNonOutOfMemoryError(&result);
+  size_t avail, total, actual_avail, actual_total;
+  bool is_limited = platform::RecordedCudaMemGetInfo(
+      &avail, &total, &actual_avail, &actual_total, place_.device);
-  size_t avail = 0, total = 0;
-  result = cudaMemGetInfo(&avail, &total);
-  if (result != cudaSuccess) avail = 0;
-  platform::RaiseNonOutOfMemoryError(&result);
+  std::string err_msg;
+  if (is_limited) {
+    auto limit_size = (total >> 20);
+    err_msg = string::Sprintf(
+        "Or set environment variable `FLAGS_gpu_memory_limit_mb` to a larger "
+        "value. Currently `FLAGS_gpu_memory_limit_mb` is %d, so the maximum "
+        "GPU memory usage is limited to %d MB.\n"
+        " The command is `export FLAGS_gpu_memory_limit_mb=xxx`.",
+        limit_size, limit_size);
+  }
   PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted(
       "\n\nOut of memory error on GPU %d. "
@@ -55,9 +64,9 @@ Allocation* CUDAAllocator::AllocateImpl(size_t size) {
       "available memory is only %s.\n\n"
       "Please check whether there is any other process using GPU %d.\n"
      "1. If yes, please stop them, or start PaddlePaddle on another GPU.\n"
-      "2. If no, please decrease the batch size of your model.\n",
+      "2. If no, please decrease the batch size of your model. %s\n\n",
       place_.device, string::HumanReadableSize(size), place_.device,
-      string::HumanReadableSize(avail), place_.device));
+      string::HumanReadableSize(avail), place_.device, err_msg));
 }
 }  // namespace allocation
diff --git a/paddle/fluid/memory/detail/system_allocator.cc b/paddle/fluid/memory/detail/system_allocator.cc
index 058485b8ddd1e934dcfc5133c9052f76346fb587..a3f96ea58729dc908cad77152680860bd773c3f0 100644
--- a/paddle/fluid/memory/detail/system_allocator.cc
+++ b/paddle/fluid/memory/detail/system_allocator.cc
@@ -110,29 +110,28 @@ void* GPUAllocator::Alloc(size_t* index, size_t size) {
   // if size is 0. We just make sure it does.
   if (size <= 0) return nullptr;
-  paddle::platform::CUDADeviceGuard guard(gpu_id_);
-
   void* p;
-  cudaError_t result = cudaMalloc(&p, size);
+  auto result = platform::RecordedCudaMalloc(&p, size, gpu_id_);
   if (result == cudaSuccess) {
     *index = 0;
     gpu_alloc_size_ += size;
     return p;
   } else {
-    platform::RaiseNonOutOfMemoryError(&result);
-
-    /**
-     * NOTE(zjl): Sometimes cudaMemGetInfo would raise OOM error
-     * if there is very little GPU memory left. In this case, we
-     * should consider the available GPU memory to be 0, and throw
-     * exception inside this function instead of throwing exception
-     * inside cudaMemGetInfo.
-     */
-    size_t avail = 0, total = 0;
-    result = cudaMemGetInfo(&avail, &total);
-    if (result != cudaSuccess) avail = 0;
-    platform::RaiseNonOutOfMemoryError(&result);
+    size_t avail, total, actual_avail, actual_total;
+    bool is_limited = platform::RecordedCudaMemGetInfo(
+        &avail, &total, &actual_avail, &actual_total, gpu_id_);
+
+    std::string err_msg;
+    if (is_limited) {
+      auto limit_size = (total >> 20);
+      err_msg = string::Sprintf(
+          "\n 3) Set environment variable `FLAGS_gpu_memory_limit_mb` to a "
+          "larger value. Currently `FLAGS_gpu_memory_limit_mb` is %d, so the "
+          "maximum GPU memory usage is limited to %d MB.\n"
+          " The command is `export FLAGS_gpu_memory_limit_mb=xxx`.",
+          limit_size, limit_size);
+    }
     PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted(
         "\n\nOut of memory error on GPU %d. "
@@ -145,28 +144,19 @@ void* GPUAllocator::Alloc(size_t* index, size_t size) {
         " 2) FLAGS_fraction_of_gpu_memory_to_use is %.2lf now, "
         "please set it to a higher value but less than 1.0.\n"
         " The command is "
-        "`export FLAGS_fraction_of_gpu_memory_to_use=xxx`.\n\n",
+        "`export FLAGS_fraction_of_gpu_memory_to_use=xxx`.%s\n\n",
         gpu_id_, string::HumanReadableSize(size), gpu_id_,
         string::HumanReadableSize(avail), gpu_id_,
-        FLAGS_fraction_of_gpu_memory_to_use));
+        FLAGS_fraction_of_gpu_memory_to_use, err_msg));
   }
 }
 void GPUAllocator::Free(void* p, size_t size, size_t index) {
-  cudaError_t err;
   PADDLE_ENFORCE_EQ(index, 0);
   PADDLE_ENFORCE_GE(gpu_alloc_size_, size);
   gpu_alloc_size_ -= size;
-  err = cudaFree(p);
-  // Purposefully allow cudaErrorCudartUnloading, because
-  // that is returned if you ever call cudaFree after the
-  // driver has already shutdown. This happens only if the
-  // process is terminating, in which case we don't care if
-  // cudaFree succeeds.
-  if (err != cudaErrorCudartUnloading) {
-    PADDLE_ENFORCE(err, "cudaFree{Host} failed in GPUAllocator::Free.");
-  }
+  platform::RecordedCudaFree(p, size, gpu_id_);
 }
 bool GPUAllocator::UseGpu() const { return true; }
diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt
index 9e119ba66df18fa6ba6363966ee3470c7a16d8bb..08ed66542a72fe1e853dd24c7b9bf4e16423253e 100644
--- a/paddle/fluid/platform/CMakeLists.txt
+++ b/paddle/fluid/platform/CMakeLists.txt
@@ -117,6 +117,8 @@ cc_test(profiler_test SRCS profiler_test.cc DEPS profiler)
 nv_test(float16_gpu_test SRCS float16_test.cu DEPS lod_tensor)
 cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor)
+nv_test(test_limit_gpu_memory SRCS test_limit_gpu_memory.cu DEPS gpu_info flags)
+
 nv_library(cuda_device_guard SRCS cuda_device_guard.cc DEPS gpu_info)
 if(NOT APPLE AND NOT WIN32)
diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc
index 1f16ce335778849bb3164e989592cf294bae5ad5..046fd16fb15932a80eb7536d715522ba5e54156f 100644
--- a/paddle/fluid/platform/flags.cc
+++ b/paddle/fluid/platform/flags.cc
@@ -449,6 +449,14 @@ DEFINE_uint64(reallocate_gpu_memory_in_mb, 0ul,
              "size specified by this flag. Else Paddle will reallocate by "
              "FLAGS_fraction_of_gpu_memory_to_use");
+DEFINE_uint64(gpu_memory_limit_mb, 0UL,
+             "The maximum gpu memory limit that the process can allocate. "
+             "If it is equal to 0, there would be no limit and all gpu memory "
+             "would be available to the process. If it is larger than 0, "
+             "the process would raise out of memory error if the allocated "
+             "memory exceeds the limit even though there is available "
+             "memory on the gpu card. The unit is MB and default value is 0.");
+
 #endif
 /**
diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc
index ad664b88e606f41fd32f6c431c2c054c630ac46d..40d6bc54ccf928845095cbc6561a0447657b0a25 100644
--- a/paddle/fluid/platform/gpu_info.cc
+++ b/paddle/fluid/platform/gpu_info.cc
@@ -15,10 +15,14 @@ limitations under the License.
 */
 #include "paddle/fluid/platform/gpu_info.h"
 #include
 #include
+#include
 #include
 #include "gflags/gflags.h"
+#include "paddle/fluid/platform/cuda_device_guard.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/lock_guard_ptr.h"
+#include "paddle/fluid/platform/macros.h"
 #include "paddle/fluid/string/split.h"
 DECLARE_double(fraction_of_gpu_memory_to_use);
@@ -26,6 +30,7 @@ DECLARE_uint64(initial_gpu_memory_in_mb);
 DECLARE_uint64(reallocate_gpu_memory_in_mb);
 DECLARE_bool(enable_cublas_tensor_op_math);
 DECLARE_string(selected_gpus);
+DECLARE_uint64(gpu_memory_limit_mb);
 constexpr static float fraction_reserve_gpu_memory = 0.05f;
@@ -241,11 +246,9 @@ void SetDeviceId(int id) {
 }
 void GpuMemoryUsage(size_t *available, size_t *total) {
-  auto error_code = cudaMemGetInfo(available, total);
-  PADDLE_ENFORCE(error_code,
-                 "cudaMemGetInfo failed in "
-                 "paddle::platform::GetMemoryUsage, error code : %d, %s",
-                 error_code, CudaErrorWebsite());
+  size_t actual_available, actual_total;
+  RecordedCudaMemGetInfo(available, total, &actual_available, &actual_total,
+                         platform::GetCurrentDeviceId());
 }
 size_t GpuAvailableMemToAlloc() {
@@ -359,7 +362,7 @@ void GpuStreamSync(cudaStream_t stream) {
                         error_code, CudaErrorWebsite()));
 }
-void RaiseNonOutOfMemoryError(cudaError_t *status) {
+static void RaiseNonOutOfMemoryError(cudaError_t *status) {
   if (*status == cudaErrorMemoryAllocation) {
     *status = cudaSuccess;
   }
@@ -374,5 +377,158 @@ void RaiseNonOutOfMemoryError(cudaError_t *status) {
   PADDLE_ENFORCE_CUDA_SUCCESS(*status);
 }
+class RecordedCudaMallocHelper {
+ private:
+  explicit RecordedCudaMallocHelper(int dev_id, uint64_t limit_size = 0)
+      : dev_id_(dev_id), limit_size_(limit_size) {
+    if (NeedRecord()) {
+      mtx_.reset(new std::mutex());
+    }
+  }
+
+  DISABLE_COPY_AND_ASSIGN(RecordedCudaMallocHelper);
+
+ public:
+  static RecordedCudaMallocHelper *Instance(int dev_id) {
+    std::call_once(once_flag_, [] {
+      int dev_cnt = GetCUDADeviceCount();
+      instances_.reserve(dev_cnt);
+      for (int i = 0; i < dev_cnt; ++i) {
+        instances_.emplace_back(
+            new RecordedCudaMallocHelper(i, FLAGS_gpu_memory_limit_mb << 20));
+      }
+    });
+
+    PADDLE_ENFORCE_GE(
+        dev_id, 0,
+        platform::errors::OutOfRange(
+            "Device id must be not less than 0, but got %d", dev_id));
+    PADDLE_ENFORCE_LT(
+        dev_id, instances_.size(),
+        platform::errors::OutOfRange("Device id %d exceeds gpu card number %d",
+                                     dev_id, instances_.size()));
+    return instances_[dev_id].get();
+  }
+
+  /**
+   * Try to allocate `size` gpu memory. Only cudaErrorMemoryAllocation
+   * or cudaSuccess would be returned, and the cudaGetLastError() flag
+   * would be clear.
+   */
+  cudaError_t Malloc(void **ptr, size_t size) {
+    LockGuardPtr<std::mutex> lock(mtx_);
+    if (UNLIKELY(NeedRecord() && cur_size_ + size > limit_size_)) {
+      return cudaErrorMemoryAllocation;
+    }
+
+    CUDADeviceGuard guard(dev_id_);
+    auto result = cudaMalloc(ptr, size);
+    if (result == cudaSuccess) {
+      if (NeedRecord()) {
+        cur_size_ += size;
+      }
+      return cudaSuccess;
+    } else {
+      RaiseNonOutOfMemoryError(&result);
+      // Non out of memory error would be raised inside
+      // RaiseNonOutOfMemoryError. Therefore, we can
+      // return cudaErrorMemoryAllocation directly here.
+      return cudaErrorMemoryAllocation;
+    }
+  }
+
+  /**
+   * Free gpu memory. Usually, free is not allowed to raise error.
+   * If it does raise error, the process should be crashed.
+   */
+  void Free(void *ptr, size_t size) {
+    // Purposefully allow cudaErrorCudartUnloading, because
+    // that is returned if you ever call cudaFree after the
+    // driver has already shutdown. This happens only if the
+    // process is terminating, in which case we don't care if
+    // cudaFree succeeds.
+    CUDADeviceGuard guard(dev_id_);
+    auto err = cudaFree(ptr);
+    if (err != cudaErrorCudartUnloading) {
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          err, platform::errors::External("cudaFree raises unexpected error"));
+      if (NeedRecord()) {
+        std::lock_guard<std::mutex> guard(*mtx_);
+        cur_size_ -= size;
+      }
+    } else {
+      cudaGetLastError();  // clear the error flag when cudaErrorCudartUnloading
+    }
+  }
+
+  bool GetMemInfo(size_t *avail, size_t *total, size_t *actual_avail,
+                  size_t *actual_total) {
+    {
+      CUDADeviceGuard guard(dev_id_);
+      auto result = cudaMemGetInfo(actual_avail, actual_total);
+      if (result != cudaSuccess) {
+        *actual_avail = 0;
+      }
+      RaiseNonOutOfMemoryError(&result);
+    }
+
+    if (NeedRecord()) {
+      std::lock_guard<std::mutex> guard(*mtx_);
+      *avail = std::min(*actual_avail, limit_size_ - cur_size_);
+      *total = std::min(*actual_total, limit_size_);
+      return *total < *actual_total;
+    } else {
+      *avail = *actual_avail;
+      *total = *actual_total;
+      return false;
+    }
+  }
+
+  inline bool NeedRecord() const { return limit_size_ != 0; }
+
+  uint64_t RecordedSize() const {
+    LockGuardPtr<std::mutex> lock(mtx_);
+    return NeedRecord() ? cur_size_ : 0;
+  }
+
+  uint64_t LimitSize() const { return limit_size_; }
+
+ private:
+  const int dev_id_;
+  const uint64_t limit_size_;
+  uint64_t cur_size_{0};
+
+  mutable std::unique_ptr<std::mutex> mtx_;
+
+  static std::once_flag once_flag_;
+  static std::vector<std::unique_ptr<RecordedCudaMallocHelper>> instances_;
+};
+
+std::once_flag RecordedCudaMallocHelper::once_flag_;
+std::vector<std::unique_ptr<RecordedCudaMallocHelper>>
+    RecordedCudaMallocHelper::instances_;
+
+cudaError_t RecordedCudaMalloc(void **ptr, size_t size, int dev_id) {
+  return RecordedCudaMallocHelper::Instance(dev_id)->Malloc(ptr, size);
+}
+
+void RecordedCudaFree(void *p, size_t size, int dev_id) {
+  return RecordedCudaMallocHelper::Instance(dev_id)->Free(p, size);
+}
+
+bool RecordedCudaMemGetInfo(size_t *avail, size_t *total, size_t *actual_avail,
+                            size_t *actual_total, int dev_id) {
+  return RecordedCudaMallocHelper::Instance(dev_id)->GetMemInfo(
+      avail, total, actual_avail, actual_total);
+}
+
+uint64_t RecordedCudaMallocSize(int dev_id) {
+  return RecordedCudaMallocHelper::Instance(dev_id)->RecordedSize();
+}
+
+bool IsCudaMallocRecorded(int dev_id) {
+  return RecordedCudaMallocHelper::Instance(dev_id)->NeedRecord();
+}
+
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/gpu_info.h b/paddle/fluid/platform/gpu_info.h
index 46e5326c8b76cef6f51969273c5bf7f0921bbaa1..6a9893647172e2c63f4749fdb0ae1cb0fdfaaf04 100644
--- a/paddle/fluid/platform/gpu_info.h
+++ b/paddle/fluid/platform/gpu_info.h
@@ -104,8 +104,20 @@ void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream);
 //! Blocks until stream has completed all operations.
 void GpuStreamSync(cudaStream_t stream);
-//! Raise error if status is not cudaSuccess or OOM, otherwise reset status.
-void RaiseNonOutOfMemoryError(cudaError_t *status);
+//! CudaMalloc with recorded info
+cudaError_t RecordedCudaMalloc(void **ptr, size_t size, int dev_id);
+
+//! CudaFree with recorded info
+void RecordedCudaFree(void *p, size_t size, int dev_id);
+
+//! Get available and total gpu memory with considering limitation
+bool RecordedCudaMemGetInfo(size_t *avail, size_t *total, size_t *actual_avail,
+                            size_t *actual_total, int dev_id);
+
+//! Get recorded cudaMalloc size. If record is disabled, return 0.
+uint64_t RecordedCudaMallocSize(int dev_id);
+
+bool IsCudaMallocRecorded(int dev_id);
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/test_limit_gpu_memory.cu b/paddle/fluid/platform/test_limit_gpu_memory.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ab42feba74629b009d3e999b5753ebc5fc1980d7
--- /dev/null
+++ b/paddle/fluid/platform/test_limit_gpu_memory.cu
@@ -0,0 +1,97 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gflags/gflags.h"
+#include "gtest/gtest.h"
+#include "paddle/fluid/platform/cuda_device_guard.h"
+#include "paddle/fluid/platform/gpu_info.h"
+
+DECLARE_uint64(gpu_memory_limit_mb);
+
+namespace paddle {
+namespace platform {
+
+static constexpr uint64_t GPU_MEMORY_LIMIT_MB = 500;
+static constexpr int DEVICE_ID = 0;
+
+TEST(test_record_malloc, test_limit_gpu_memory) {
+  FLAGS_gpu_memory_limit_mb = GPU_MEMORY_LIMIT_MB;
+  size_t limit = FLAGS_gpu_memory_limit_mb << 20;
+
+  {
+    ASSERT_TRUE(IsCudaMallocRecorded(DEVICE_ID));
+    ASSERT_EQ(RecordedCudaMallocSize(DEVICE_ID), 0UL);
+  }
+
+  size_t avail, total;
+  {
+    size_t actual_avail, actual_total;
+    RecordedCudaMemGetInfo(&avail, &total, &actual_avail, &actual_total,
+                           DEVICE_ID);
+    ASSERT_EQ(total, limit);
+    ASSERT_EQ(cudaGetLastError(), cudaSuccess);
+  }
+
+  {
+    CUDADeviceGuard guard(DEVICE_ID);
+    GpuMemoryUsage(&avail, &total);
+    ASSERT_EQ(total, limit);
+    ASSERT_EQ(cudaGetLastError(), cudaSuccess);
+  }
+
+  cudaError_t err = cudaSuccess;
+
+  void *p1 = nullptr;
+  size_t size1 = limit / 4 * 3;
+  {
+    err = platform::RecordedCudaMalloc(&p1, size1, DEVICE_ID);
+    ASSERT_EQ(err, cudaSuccess);
+    ASSERT_EQ(cudaGetLastError(), cudaSuccess);
+    ASSERT_NE(p1, nullptr);
+
+    ASSERT_EQ(RecordedCudaMallocSize(DEVICE_ID), size1);
+  }
+
+  void *p2 = nullptr;
+  size_t size2 = limit / 2;
+  {
+    err = platform::RecordedCudaMalloc(&p2, size2, DEVICE_ID);
+    ASSERT_EQ(err, cudaErrorMemoryAllocation);
+    ASSERT_EQ(cudaGetLastError(), cudaSuccess);
+    ASSERT_EQ(p2, nullptr);
+
+    ASSERT_EQ(RecordedCudaMallocSize(DEVICE_ID), size1);
+  }
+
+  {
+    platform::RecordedCudaFree(p1, size1, DEVICE_ID);
+    ASSERT_EQ(RecordedCudaMallocSize(DEVICE_ID), 0UL);
+  }
+
+  {
+    err = platform::RecordedCudaMalloc(&p2, size2, DEVICE_ID);
+    ASSERT_EQ(err, cudaSuccess);
+    ASSERT_EQ(cudaGetLastError(), cudaSuccess);
+    ASSERT_NE(p2, nullptr);
+    ASSERT_EQ(RecordedCudaMallocSize(DEVICE_ID), size2);
+  }
+
+  {
+    platform::RecordedCudaFree(p2, size2, DEVICE_ID);
+    ASSERT_EQ(RecordedCudaMallocSize(DEVICE_ID), 0UL);
+  }
+}
+
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/fluid/pybind/global_value_getter_setter.cc b/paddle/fluid/pybind/global_value_getter_setter.cc
index 4a0e09bb2ae7239ea429b51e464c1451ec0ae27f..3a803e03090639919cdb365fa693f18425e0651d 100644
--- a/paddle/fluid/pybind/global_value_getter_setter.cc
+++ b/paddle/fluid/pybind/global_value_getter_setter.cc
@@ -32,6 +32,10 @@ DECLARE_bool(use_ngraph);
 DECLARE_bool(use_system_allocator);
 DECLARE_bool(free_idle_chunk);
 DECLARE_bool(free_when_no_cache_hit);
+#ifdef PADDLE_WITH_CUDA
+DECLARE_uint64(gpu_memory_limit_mb);
+#endif
+DECLARE_string(allocator_strategy);
 namespace paddle {
 namespace pybind {
@@ -169,8 +173,12 @@ static void RegisterGlobalVarGetterSetter() {
   REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_use_ngraph);
   REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_eager_delete_tensor_gb);
   REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_use_system_allocator);
+  REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_allocator_strategy);
   REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_free_idle_chunk);
   REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_free_when_no_cache_hit);
+#ifdef PADDLE_WITH_CUDA
+  REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_gpu_memory_limit_mb);
+#endif
 }
 }  // namespace pybind
diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index 654627af445e33cc040935106b7f2316479cb4e1..d8d8410cf7e19c863194f55f79e33d0bbdc5001e 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -211,7 +211,7 @@ def __bootstrap__():
             'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
             'cudnn_exhaustive_search', 'selected_gpus', 'sync_nccl_allreduce',
             'cudnn_batchnorm_spatial_persistent', 'gpu_allocator_retry_time',
-            'local_exe_sub_scope_limit'
+            'local_exe_sub_scope_limit', 'gpu_memory_limit_mb'
         ]
     core.init_gflags([sys.argv[0]] +
                      ["--tryfromenv=" + ",".join(read_env_flags)])
diff --git a/python/paddle/fluid/tests/unittests/test_auto_growth_gpu_memory_limit.py b/python/paddle/fluid/tests/unittests/test_auto_growth_gpu_memory_limit.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ff67a923a209e6f4c20232d2a5207c9f9e69909
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_auto_growth_gpu_memory_limit.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.fluid as fluid
+import unittest
+import numpy as np
+
+fluid.core.globals()['FLAGS_allocator_strategy'] = 'auto_growth'
+
+if fluid.is_compiled_with_cuda():
+    fluid.core.globals()['FLAGS_gpu_memory_limit_mb'] = 10
+
+
+class TestBase(unittest.TestCase):
+    def setUp(self):
+        if fluid.is_compiled_with_cuda():
+            self._limit = fluid.core.globals()['FLAGS_gpu_memory_limit_mb']
+
+    def test_allocate(self):
+        if not fluid.is_compiled_with_cuda():
+            return
+
+        other_dim = int(1024 * 1024 / 4)
+
+        place = fluid.CUDAPlace(0)
+        t = fluid.LoDTensor()
+        t.set(np.ndarray(
+            [int(self._limit / 2), other_dim], dtype='float32'),
+              place)
+        del t
+
+        t = fluid.LoDTensor()
+        large_np = np.ndarray([2 * self._limit, other_dim], dtype='float32')
+
+        try:
+            t.set(large_np, place)
+            self.assertTrue(False)
+        except:
+            self.assertTrue(True)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_naive_best_fit_gpu_memory_limit.py b/python/paddle/fluid/tests/unittests/test_naive_best_fit_gpu_memory_limit.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8d10816bf97aba29fd8708d0a67b768d64bd417
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_naive_best_fit_gpu_memory_limit.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.fluid as fluid
+import unittest
+import numpy as np
+
+fluid.core.globals()['FLAGS_allocator_strategy'] = 'naive_best_fit'
+
+if fluid.is_compiled_with_cuda():
+    fluid.core.globals()['FLAGS_gpu_memory_limit_mb'] = 10
+
+
+class TestBase(unittest.TestCase):
+    def setUp(self):
+        if fluid.is_compiled_with_cuda():
+            self._limit = fluid.core.globals()['FLAGS_gpu_memory_limit_mb']
+
+    def test_allocate(self):
+        if not fluid.is_compiled_with_cuda():
+            return
+
+        other_dim = int(1024 * 1024 / 4)
+
+        place = fluid.CUDAPlace(0)
+        t = fluid.LoDTensor()
+        t.set(np.ndarray(
+            [int(self._limit / 2), other_dim], dtype='float32'),
+              place)
+        del t
+
+        t = fluid.LoDTensor()
+        large_np = np.ndarray([2 * self._limit, other_dim], dtype='float32')
+
+        try:
+            t.set(large_np, place)
+            self.assertTrue(False)
+        except:
+            self.assertTrue(True)
+
+
+if __name__ == '__main__':
+    unittest.main()
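
Note (not part of the patch): a minimal Python sketch of how the new FLAGS_gpu_memory_limit_mb flag is meant to be exercised, mirroring the two unit tests added above. It assumes a CUDA build of Paddle that contains this change; the 10 MB limit and the tensor shapes are illustrative only, and the same effect can be obtained with `export FLAGS_gpu_memory_limit_mb=10` before starting the process.

    import numpy as np
    import paddle.fluid as fluid

    # Either 'auto_growth' or 'naive_best_fit', as in the two new unit tests.
    fluid.core.globals()['FLAGS_allocator_strategy'] = 'auto_growth'
    # Cap this process at 10 MB of GPU memory per card.
    fluid.core.globals()['FLAGS_gpu_memory_limit_mb'] = 10

    place = fluid.CUDAPlace(0)

    # Roughly 2 MB: fits under the limit.
    small = fluid.LoDTensor()
    small.set(np.zeros([2, 1024 * 256], dtype='float32'), place)

    # Roughly 40 MB: exceeds the limit, so allocation fails with the
    # enriched out-of-memory error even if the card has free memory.
    big = fluid.LoDTensor()
    try:
        big.set(np.zeros([40, 1024 * 256], dtype='float32'), place)
    except Exception:
        print('allocation beyond FLAGS_gpu_memory_limit_mb was rejected')

The limit is tracked per device by RecordedCudaMallocHelper, so each GPU card visible to the process gets its own FLAGS_gpu_memory_limit_mb budget.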