From 064512aa47b9ea35c0b5479b32c1653512c8b7c4 Mon Sep 17 00:00:00 2001 From: chengduo Date: Thu, 10 Jan 2019 18:41:40 -0600 Subject: [PATCH] Remove workspace_handle in conv_cudnn (#15186) * remove workspace_handle in conv2d_cudnn test=develop * remove workspace_handle test=develop * fix bug test=develop * make test_conv2d_op SERIAL test=develop * save memory in conv_cudnn test=develop * enhance thread safety test=develop * enhance temporary allocator test=develop * Add excess fraction test=develop * follow comments test=develop * fix bug and code refine test=develop * fix memory size check test=develop * rename reuse_tmp_allocation_excess_fraction test=develop --- paddle/fluid/framework/operator.h | 2 +- paddle/fluid/operators/conv_cudnn_op.cu.cc | 149 ++++++++++-------- paddle/fluid/platform/device_context.cc | 28 ++-- paddle/fluid/platform/device_context.h | 2 +- paddle/fluid/platform/temporary_allocator.cc | 63 ++++++-- paddle/fluid/platform/temporary_allocator.h | 10 +- .../platform/temporary_allocator_test.cc | 58 ++++++- python/paddle/fluid/__init__.py | 3 +- 8 files changed, 208 insertions(+), 107 deletions(-) diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 4d29564aeed..041187665af 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -391,7 +391,7 @@ class ExecutionContext { PADDLE_ENFORCE( dynamic_cast(allocation_ptr) != nullptr, "The AllocationPtr must be TemporaryAllocation."); - PADDLE_ENFORCE_EQ(allocation_ptr->size(), + PADDLE_ENFORCE_GE(allocation_ptr->size(), framework::product(dim) * sizeof(T)); paddle::framework::Tensor temp_tensor( diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc index 25a723fc079..f5208e7a601 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc @@ -137,7 +137,6 @@ class CUDNNConvOpKernel : public framework::OpKernel { // ------------------- cudnn conv algorithm --------------------- cudnnConvolutionFwdAlgo_t algo; auto handle = dev_ctx.cudnn_handle(); - auto workspace_handle = dev_ctx.cudnn_workspace_handle(); bool half_float = false; #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) @@ -158,6 +157,8 @@ class CUDNNConvOpKernel : public framework::OpKernel { VLOG(5) << "NOT use cudnn_tensor_op_math"; } #endif + Tensor cudnn_workspace; + void* cudnn_workspace_ptr = nullptr; auto x_dims = framework::vectorize(input->dims()); auto f_dims = framework::vectorize(filter->dims()); @@ -180,21 +181,26 @@ class CUDNNConvOpKernel : public framework::OpKernel { .Var(kCUDNNFwdAlgoCache) ->GetMutable>(); } + cudnn_workspace = + ctx.AllocateTmpTensor( + framework::make_ddim( + {static_cast(workspace_size_limit)}), + dev_ctx); + cudnn_workspace_ptr = static_cast(cudnn_workspace.data()); + algo = algo_cache->GetAlgorithm( x_dims, f_dims, strides, paddings, dilations, 0, [&]() { int returned_algo_count; std::array fwd_perf_stat; - auto cudnn_find_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE( - platform::dynload::cudnnFindConvolutionForwardAlgorithmEx( - handle, cudnn_input_desc, input_data, cudnn_filter_desc, - filter_data, cudnn_conv_desc, cudnn_output_desc, - output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count, - fwd_perf_stat.data(), cudnn_workspace, - workspace_size_limit)); - }; - workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit); + + CUDNN_ENFORCE( + platform::dynload::cudnnFindConvolutionForwardAlgorithmEx( + handle, cudnn_input_desc, input_data, cudnn_filter_desc, + filter_data, cudnn_conv_desc, cudnn_output_desc, + output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count, + fwd_perf_stat.data(), cudnn_workspace_ptr, + workspace_size_limit)); VLOG(3) << "Perf result: (algo: stat, time, memory)"; for (int i = 0; i < returned_algo_count; ++i) { @@ -219,17 +225,23 @@ class CUDNNConvOpKernel : public framework::OpKernel { PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit, "workspace_size to be allocated exceeds the limit"); + // Allocate on GPU memory + if (!cudnn_workspace_ptr) { + cudnn_workspace = + ctx.AllocateTmpTensor( + framework::make_ddim( + {static_cast(workspace_size_in_bytes)}), + dev_ctx); + cudnn_workspace_ptr = static_cast(cudnn_workspace.data()); + } // ------------------- cudnn conv forward --------------------- ScalingParamType alpha = 1.0f, beta = 0.0f; for (int i = 0; i < groups; i++) { - auto cudnn_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( - handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, - cudnn_filter_desc, filter_data + i * group_offset_filter, - cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes, - &beta, cudnn_output_desc, output_data + i * group_offset_out)); - }; - workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( + handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, + cudnn_filter_desc, filter_data + i * group_offset_filter, + cudnn_conv_desc, algo, cudnn_workspace_ptr, workspace_size_in_bytes, + &beta, cudnn_output_desc, output_data + i * group_offset_out)); } } }; @@ -353,10 +365,20 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { workspace_size_limit = max_user_size * 1024 * 1024; } + Tensor cudnn_workspace; + void* cudnn_workspace_ptr = nullptr; + if ((input_data || filter_data) && exhaustive_search) { + cudnn_workspace = + ctx.AllocateTmpTensor( + framework::make_ddim( + {static_cast(workspace_size_limit)}), + dev_ctx); + cudnn_workspace_ptr = static_cast(cudnn_workspace.data()); + } + auto x_dims = framework::vectorize(input->dims()); auto f_dims = framework::vectorize(filter->dims()); auto handle = dev_ctx.cudnn_handle(); - auto workspace_handle = dev_ctx.cudnn_workspace_handle(); if (input_grad) { T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); if (exhaustive_search) { @@ -374,25 +396,22 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { ->GetMutable< AlgorithmsCache>(); } + data_algo = data_algo_cache->GetAlgorithm( x_dims, f_dims, strides, paddings, dilations, 0, [&]() { int returned_algo_count; std::array data_perf_stat; - auto cudnn_find_bd_data_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE( - platform::dynload:: - cudnnFindConvolutionBackwardDataAlgorithmEx( - handle, cudnn_filter_desc, filter_data, - cudnn_output_grad_desc, output_grad_data, - cudnn_conv_desc, cudnn_input_desc, input_grad_data, - kNUM_CUDNN_BWD_DATA_ALGS, &returned_algo_count, - data_perf_stat.data(), cudnn_workspace, - workspace_size_limit)); - }; - workspace_handle.RunFunc(cudnn_find_bd_data_func, - workspace_size_limit); + + CUDNN_ENFORCE(platform::dynload:: + cudnnFindConvolutionBackwardDataAlgorithmEx( + handle, cudnn_filter_desc, filter_data, + cudnn_output_grad_desc, output_grad_data, + cudnn_conv_desc, cudnn_input_desc, + input_grad_data, kNUM_CUDNN_BWD_DATA_ALGS, + &returned_algo_count, data_perf_stat.data(), + cudnn_workspace_ptr, workspace_size_limit)); VLOG(3) << "Perf result: (algo: stat, time, memory)"; for (int i = 0; i < returned_algo_count; ++i) { @@ -443,25 +462,23 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { ->GetMutable< AlgorithmsCache>(); } + filter_algo = f_algo_cache->GetAlgorithm( x_dims, f_dims, strides, paddings, dilations, 0, [&]() { int returned_algo_count; std::array filter_perf_stat; - auto cudnn_find_bd_f_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE( - platform::dynload:: - cudnnFindConvolutionBackwardFilterAlgorithmEx( - handle, cudnn_input_desc, input_data, - cudnn_output_grad_desc, output_grad_data, - cudnn_conv_desc, cudnn_filter_desc, - filter_grad_data, kNUM_CUDNN_BWD_FILTER_ALGS, - &returned_algo_count, filter_perf_stat.data(), - cudnn_workspace, workspace_size_limit)); - }; - workspace_handle.RunFunc(cudnn_find_bd_f_func, - workspace_size_limit); + + CUDNN_ENFORCE( + platform::dynload:: + cudnnFindConvolutionBackwardFilterAlgorithmEx( + handle, cudnn_input_desc, input_data, + cudnn_output_grad_desc, output_grad_data, + cudnn_conv_desc, cudnn_filter_desc, filter_grad_data, + kNUM_CUDNN_BWD_FILTER_ALGS, &returned_algo_count, + filter_perf_stat.data(), cudnn_workspace_ptr, + workspace_size_limit)); return filter_perf_stat[0].algo; }); VLOG(3) << "cuDNN backward filter algo " << filter_algo; @@ -482,6 +499,16 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); } + // ------------------- cudnn conv workspace --------------------- + if (!cudnn_workspace_ptr) { + cudnn_workspace = + ctx.AllocateTmpTensor( + framework::make_ddim( + {static_cast(workspace_size_in_bytes)}), + dev_ctx); + cudnn_workspace_ptr = static_cast(cudnn_workspace.data()); + } + // ------------------- cudnn conv backward data --------------------- ScalingParamType alpha = 1.0f, beta = 0.0f; if (input_grad) { @@ -489,15 +516,12 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { // Because beta is zero, it is unnecessary to reset input_grad. for (int i = 0; i < groups; i++) { - auto cudnn_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( - handle, &alpha, cudnn_filter_desc, - filter_data + i * group_offset_filter, cudnn_output_grad_desc, - output_grad_data + i * group_offset_out, cudnn_conv_desc, - data_algo, cudnn_workspace, workspace_size_in_bytes, &beta, - cudnn_input_desc, input_grad_data + i * group_offset_in)); - }; - workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( + handle, &alpha, cudnn_filter_desc, + filter_data + i * group_offset_filter, cudnn_output_grad_desc, + output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo, + cudnn_workspace_ptr, workspace_size_in_bytes, &beta, + cudnn_input_desc, input_grad_data + i * group_offset_in)); } } // ------------------- cudnn conv backward filter --------------------- @@ -505,15 +529,12 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { T* filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); // Because beta is zero, it is unnecessary to reset filter_grad. for (int i = 0; i < groups; i++) { - auto cudnn_func = [&](void* cudnn_workspace) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( - handle, &alpha, cudnn_input_desc, - input_data + i * group_offset_in, cudnn_output_grad_desc, - output_grad_data + i * group_offset_out, cudnn_conv_desc, - filter_algo, cudnn_workspace, workspace_size_in_bytes, &beta, - cudnn_filter_desc, filter_grad_data + i * group_offset_filter)); - }; - workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( + handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, + cudnn_output_grad_desc, output_grad_data + i * group_offset_out, + cudnn_conv_desc, filter_algo, cudnn_workspace_ptr, + workspace_size_in_bytes, &beta, cudnn_filter_desc, + filter_grad_data + i * group_offset_filter)); } } } diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 09f3d3de54e..8f80a2d7822 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -92,26 +92,24 @@ platform::TemporaryAllocator& DeviceTemporaryAllocator::Get( const platform::Place& place, const cudaStream_t& stream) { PADDLE_ENFORCE(platform::is_gpu_place(place)); auto place_stream = std::make_pair(place, stream); - { - std::unique_lock lock(mtx_); - if (!device_allocator_.count(place_stream)) { - device_allocator_[place_stream].reset(new TemporaryAllocator(place)); - device_allocator_[place_stream]->SetCallback([stream]() { - PADDLE_ENFORCE(cudaStreamSynchronize(stream)); - PADDLE_ENFORCE(cudaGetLastError()); - }); - } + std::unique_lock lock(mtx_); + auto it = device_allocator_.find(place_stream); + if (it == device_allocator_.end()) { + auto tmp_allocator = new TemporaryAllocator(place); + tmp_allocator->SetCallback([stream]() { + PADDLE_ENFORCE(cudaStreamSynchronize(stream)); + PADDLE_ENFORCE(cudaGetLastError()); + }); + device_allocator_[place_stream].reset(tmp_allocator); + return *tmp_allocator; + } else { + return *it->second; } - return *device_allocator_.at(place_stream); } template <> platform::TemporaryAllocator& DeviceTemporaryAllocator::Get( const platform::CUDADeviceContext& dev_ctx) { - auto place_stream = std::make_pair(dev_ctx.GetPlace(), dev_ctx.stream()); - if (device_allocator_.count(place_stream)) { - return *device_allocator_.at(place_stream); - } return Get(dev_ctx.GetPlace(), dev_ctx.stream()); } #endif @@ -325,7 +323,7 @@ Place CUDADeviceContext::GetPlace() const { return place_; } void CUDADeviceContext::Wait() const { auto& allocator = DeviceTemporaryAllocator::Instance().Get(*this); - allocator.Release([=]() { + allocator.Release([this]() { PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); PADDLE_ENFORCE(cudaGetLastError()); }); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index c81d17380cf..d376f90ad57 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -61,7 +61,7 @@ namespace platform { * the allocations of temp_allocation_queue: * - when the Stream calls cudaStreamSynchronize; * - when the allocation size of opportunities exceeds a certain threshold - * (defined by FLAGS_limit_of_temporary_allocation). + * (defined by FLAGS_limit_of_tmp_allocation). * * */ class DeviceTemporaryAllocator { diff --git a/paddle/fluid/platform/temporary_allocator.cc b/paddle/fluid/platform/temporary_allocator.cc index 0be017f75bc..9cbdfe46e78 100644 --- a/paddle/fluid/platform/temporary_allocator.cc +++ b/paddle/fluid/platform/temporary_allocator.cc @@ -15,8 +15,15 @@ #include "paddle/fluid/platform/temporary_allocator.h" #include "paddle/fluid/memory/allocation/allocator_facade.h" -DEFINE_double(limit_of_temporary_allocation, -1, - "The up limit of temporary_allocation size."); +DEFINE_int64(limit_of_tmp_allocation, -1, + "The up limit of temporary_allocation size."); +DEFINE_double(times_excess_than_required_tmp_allocation, 2, + "times_excess_than_required_tmp_allocation indicates the " + "max size the TemporaryAllocator can return. For example, " + "if the required memory size is N, and " + "times_excess_than_required_tmp_allocation is 2.0, " + "the TemporaryAllocator will return the available allocation " + "that the range of size is N ~ 2*N."); namespace paddle { namespace platform { @@ -29,24 +36,25 @@ TemporaryAllocation::TemporaryAllocation( underlying_allocation_(std::move(underlying_allocation)) {} TemporaryAllocator::TemporaryAllocator(platform::Place place) : place_(place) { - temp_mem_queue_.reset(new std::deque()); + temp_mem_map_.reset(new std::multimap()); } bool TemporaryAllocator::IsAllocThreadSafe() const { return true; } void TemporaryAllocator::Release(const std::function &callback) { - std::shared_ptr> t_allocations; + std::unique_ptr> t_allocations; { std::unique_lock lock(mtx_); callback(); - t_allocations = temp_mem_queue_; - temp_mem_queue_.reset(new std::deque()); + t_allocations.swap(temp_mem_map_); + temp_mem_map_.reset(new std::multimap()); wait_delete_mem_ = 0; } + for (auto tmp : *t_allocations) { - VLOG(10) << "Delete temporary allocation " << tmp->ptr() - << " size: " << tmp->size(); - delete tmp; + VLOG(10) << "Delete temporary allocation " << tmp.second->ptr() + << " size: " << tmp.second->size(); + delete tmp.second; } } @@ -54,28 +62,34 @@ void TemporaryAllocator::Free(alloc::Allocation *allocation) { auto *temp_allocation = dynamic_cast(allocation); PADDLE_ENFORCE_NOT_NULL(temp_allocation); if (platform::is_gpu_place(temp_allocation->place())) { + PADDLE_ENFORCE(platform::is_same_place(temp_allocation->place(), place_), + "The place should be the same."); size_t wait_delete_mem = 0; { std::unique_lock lock(mtx_); - temp_mem_queue_->emplace_back(temp_allocation); + temp_mem_map_->emplace(temp_allocation->size(), temp_allocation); wait_delete_mem_ += temp_allocation->size(); wait_delete_mem = wait_delete_mem_; VLOG(10) << "Move temporary allocation: " << temp_allocation->ptr() << " to delete queue: " << temp_allocation->size() << "; " - << "wait_delete_mem: " << wait_delete_mem_; + << "wait_delete_mem: " << wait_delete_mem; } - if (FLAGS_limit_of_temporary_allocation > 0 && - wait_delete_mem > FLAGS_limit_of_temporary_allocation) { + + if (FLAGS_limit_of_tmp_allocation > 0 && + wait_delete_mem > static_cast(FLAGS_limit_of_tmp_allocation)) { + PADDLE_ENFORCE(callback_ != nullptr, "The callback is non-initialized."); Release(callback_); } return; } + VLOG(10) << "Delete temporary allocation " << temp_allocation->ptr() + << " size: " << temp_allocation->size(); delete temp_allocation; } size_t TemporaryAllocator::TemporaryAllocationQueueSize() { std::unique_lock lock(mtx_); - return temp_mem_queue_ ? temp_mem_queue_->size() : 0; + return temp_mem_map_ ? temp_mem_map_->size() : 0; } void TemporaryAllocator::SetCallback(const std::function &callback) { @@ -84,6 +98,27 @@ void TemporaryAllocator::SetCallback(const std::function &callback) { alloc::Allocation *TemporaryAllocator::AllocateImpl( size_t size, alloc::Allocator::Attr attr) { + { + // Find available allocation in temp_mem_map. + std::unique_lock lock(mtx_); + if (temp_mem_map_->size()) { + auto it = temp_mem_map_->lower_bound(size); + // FIXME(zcd): Not sure the best value of excess fraction. + if (it != temp_mem_map_->end() && + it->first < + static_cast( + size * FLAGS_times_excess_than_required_tmp_allocation)) { + auto tmp_ptr = it->second; + temp_mem_map_->erase(it); + wait_delete_mem_ -= tmp_ptr->size(); + VLOG(10) << "Reuse temporary allocation: " << tmp_ptr->ptr() << ": " + << tmp_ptr->size(); + return tmp_ptr; + } + } + } + // If not find the the available allocation, get allocation from + // AllocatorFacadeInstance. auto raw_allocation = alloc::AllocatorFacade::Instance().Alloc(place_, size, attr); auto temp_mem = new TemporaryAllocation(std::move(raw_allocation)); diff --git a/paddle/fluid/platform/temporary_allocator.h b/paddle/fluid/platform/temporary_allocator.h index 812c4a33318..d657a142233 100644 --- a/paddle/fluid/platform/temporary_allocator.h +++ b/paddle/fluid/platform/temporary_allocator.h @@ -15,6 +15,7 @@ #pragma once #include // NOLINT #include +#include #include // NOLINT #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/platform/lock_guard_ptr.h" @@ -39,7 +40,7 @@ class TemporaryAllocation : public memory::allocation::Allocation { * * There is one opportunity to free the allocations of temp_allocation_queue: * - when the allocation size of opportunities exceeds a certain threshold - * (defined by FLAGS_limit_of_temporary_allocation). + * (defined by FLAGS_limit_of_tmp_allocation). * * */ class TemporaryAllocator : public memory::allocation::Allocator { @@ -62,11 +63,10 @@ class TemporaryAllocator : public memory::allocation::Allocator { private: platform::Place place_; - // When the allocation is not held by any variable, it should be placed - // to temp_mem_queue immediately. - std::shared_ptr> temp_mem_queue_{nullptr}; - + // to temp_mem_map immediately. + std::unique_ptr> temp_mem_map_{ + nullptr}; std::mutex mtx_; size_t wait_delete_mem_{0}; std::function callback_; diff --git a/paddle/fluid/platform/temporary_allocator_test.cc b/paddle/fluid/platform/temporary_allocator_test.cc index 35d1d929819..3879cd54001 100644 --- a/paddle/fluid/platform/temporary_allocator_test.cc +++ b/paddle/fluid/platform/temporary_allocator_test.cc @@ -18,7 +18,8 @@ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor_util.h" -DECLARE_double(limit_of_temporary_allocation); +DECLARE_int64(limit_of_tmp_allocation); +DECLARE_double(times_excess_than_required_tmp_allocation); namespace paddle { namespace platform { @@ -35,7 +36,7 @@ class DummyOp : public framework::OperatorBase { const platform::Place& place) const override {} }; -TEST(temporary_allocator, temporary_allocator) { +TEST(temporary_allocator, test_base_function) { platform::CPUPlace cpu_place; TemporaryAllocator alloc(cpu_place); alloc.Allocate(100); @@ -59,10 +60,10 @@ TEST(temporary_allocator, temporary_allocator) { #endif } -TEST(temporary_allocator, add_callback) { +TEST(temporary_allocator, test_flags_function) { #ifdef PADDLE_WITH_CUDA - const double limit = FLAGS_limit_of_temporary_allocation; - FLAGS_limit_of_temporary_allocation = 10; + const int64_t limit = FLAGS_limit_of_tmp_allocation; + FLAGS_limit_of_tmp_allocation = 10; platform::CUDAPlace gpu_place(0); TemporaryAllocator gpu_alloc(gpu_place); @@ -78,7 +79,52 @@ TEST(temporary_allocator, add_callback) { }); { gpu_alloc.Allocate(100); } PADDLE_ENFORCE(deleted); - FLAGS_limit_of_temporary_allocation = limit; + FLAGS_limit_of_tmp_allocation = limit; +#endif +} + +TEST(temporary_allocator, test_reuse_tmp_allocation) { +#ifdef PADDLE_WITH_CUDA + platform::CUDAPlace gpu_place(0); + TemporaryAllocator gpu_alloc(gpu_place); + gpu_alloc.SetCallback([]() {}); + + void* tmp_allocation_ptr1 = nullptr; + { + PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0); + auto tmp_allocation1 = gpu_alloc.Allocate(100); + tmp_allocation_ptr1 = tmp_allocation1->ptr(); + } + PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1); + auto tmp_allocation2 = gpu_alloc.Allocate(100); + void* tmp_allocation_ptr2 = tmp_allocation2->ptr(); + PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0); + PADDLE_ENFORCE_EQ(tmp_allocation_ptr1, tmp_allocation_ptr2); + + auto tmp_allocation3 = gpu_alloc.Allocate(100); + void* tmp_allocation_ptr3 = tmp_allocation2->ptr(); + PADDLE_ENFORCE_EQ(tmp_allocation_ptr1, tmp_allocation_ptr3); +#endif +} + +TEST(temporary_allocator, test_times_excess_than_required_tmp_allocation) { +#ifdef PADDLE_WITH_CUDA + platform::CUDAPlace gpu_place(0); + TemporaryAllocator gpu_alloc(gpu_place); + gpu_alloc.SetCallback([]() {}); + double excess_fraction = FLAGS_times_excess_than_required_tmp_allocation; + void* tmp_allocation_ptr1 = nullptr; + { + PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0); + auto tmp_allocation1 = + gpu_alloc.Allocate(static_cast(100 * excess_fraction - 1)); + tmp_allocation_ptr1 = tmp_allocation1->ptr(); + } + PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1); + auto tmp_allocation2 = gpu_alloc.Allocate(100); + void* tmp_allocation_ptr2 = tmp_allocation2->ptr(); + PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0); + PADDLE_ENFORCE_EQ(tmp_allocation_ptr1, tmp_allocation_ptr2); #endif } diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 2c17716500a..686550a3c8d 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -155,7 +155,8 @@ def __bootstrap__(): 'fraction_of_gpu_memory_to_use', 'cudnn_deterministic', 'enable_cublas_tensor_op_math', 'conv_workspace_size_limit', 'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus', - 'sync_nccl_allreduce' + 'sync_nccl_allreduce', 'limit_of_tmp_allocation', + 'times_excess_than_required_tmp_allocation' ] core.init_gflags([sys.argv[0]] + -- GitLab