未验证 提交 6a18c0f9 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #15278 from chengduoZH/revert_remove_workspace_handle_in_conv2d_cudnn

Revert "Remove workspace_handle in conv_cudnn (#15186)"
...@@ -391,7 +391,7 @@ class ExecutionContext { ...@@ -391,7 +391,7 @@ class ExecutionContext {
PADDLE_ENFORCE( PADDLE_ENFORCE(
dynamic_cast<platform::TemporaryAllocation*>(allocation_ptr) != nullptr, dynamic_cast<platform::TemporaryAllocation*>(allocation_ptr) != nullptr,
"The AllocationPtr must be TemporaryAllocation."); "The AllocationPtr must be TemporaryAllocation.");
PADDLE_ENFORCE_GE(allocation_ptr->size(), PADDLE_ENFORCE_EQ(allocation_ptr->size(),
framework::product(dim) * sizeof(T)); framework::product(dim) * sizeof(T));
paddle::framework::Tensor temp_tensor( paddle::framework::Tensor temp_tensor(
......
...@@ -137,6 +137,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -137,6 +137,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv algorithm --------------------- // ------------------- cudnn conv algorithm ---------------------
cudnnConvolutionFwdAlgo_t algo; cudnnConvolutionFwdAlgo_t algo;
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
bool half_float = false; bool half_float = false;
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
...@@ -157,8 +158,6 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -157,8 +158,6 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
VLOG(5) << "NOT use cudnn_tensor_op_math"; VLOG(5) << "NOT use cudnn_tensor_op_math";
} }
#endif #endif
Tensor cudnn_workspace;
void* cudnn_workspace_ptr = nullptr;
auto x_dims = framework::vectorize(input->dims()); auto x_dims = framework::vectorize(input->dims());
auto f_dims = framework::vectorize(filter->dims()); auto f_dims = framework::vectorize(filter->dims());
...@@ -181,26 +180,21 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -181,26 +180,21 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
.Var(kCUDNNFwdAlgoCache) .Var(kCUDNNFwdAlgoCache)
->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>(); ->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
} }
cudnn_workspace =
ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
framework::make_ddim(
{static_cast<int64_t>(workspace_size_limit)}),
dev_ctx);
cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
algo = algo_cache->GetAlgorithm( algo = algo_cache->GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() { x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
int returned_algo_count; int returned_algo_count;
std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS> std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
fwd_perf_stat; fwd_perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE( CUDNN_ENFORCE(
platform::dynload::cudnnFindConvolutionForwardAlgorithmEx( platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
handle, cudnn_input_desc, input_data, cudnn_filter_desc, handle, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, cudnn_output_desc, filter_data, cudnn_conv_desc, cudnn_output_desc,
output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count, output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
fwd_perf_stat.data(), cudnn_workspace_ptr, fwd_perf_stat.data(), cudnn_workspace,
workspace_size_limit)); workspace_size_limit));
};
workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit);
VLOG(3) << "Perf result: (algo: stat, time, memory)"; VLOG(3) << "Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) { for (int i = 0; i < returned_algo_count; ++i) {
...@@ -225,23 +219,17 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -225,23 +219,17 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit, PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
"workspace_size to be allocated exceeds the limit"); "workspace_size to be allocated exceeds the limit");
// Allocate on GPU memory
if (!cudnn_workspace_ptr) {
cudnn_workspace =
ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
framework::make_ddim(
{static_cast<int64_t>(workspace_size_in_bytes)}),
dev_ctx);
cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
}
// ------------------- cudnn conv forward --------------------- // ------------------- cudnn conv forward ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f; ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
for (int i = 0; i < groups; i++) { for (int i = 0; i < groups; i++) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( auto cudnn_func = [&](void* cudnn_workspace) {
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
cudnn_filter_desc, filter_data + i * group_offset_filter, handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
cudnn_conv_desc, algo, cudnn_workspace_ptr, workspace_size_in_bytes, cudnn_filter_desc, filter_data + i * group_offset_filter,
&beta, cudnn_output_desc, output_data + i * group_offset_out)); cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes,
&beta, cudnn_output_desc, output_data + i * group_offset_out));
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
} }
} }
}; };
...@@ -365,20 +353,10 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -365,20 +353,10 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
workspace_size_limit = max_user_size * 1024 * 1024; workspace_size_limit = max_user_size * 1024 * 1024;
} }
Tensor cudnn_workspace;
void* cudnn_workspace_ptr = nullptr;
if ((input_data || filter_data) && exhaustive_search) {
cudnn_workspace =
ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
framework::make_ddim(
{static_cast<int64_t>(workspace_size_limit)}),
dev_ctx);
cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
}
auto x_dims = framework::vectorize(input->dims()); auto x_dims = framework::vectorize(input->dims());
auto f_dims = framework::vectorize(filter->dims()); auto f_dims = framework::vectorize(filter->dims());
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
if (input_grad) { if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace()); T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
if (exhaustive_search) { if (exhaustive_search) {
...@@ -396,22 +374,25 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -396,22 +374,25 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
->GetMutable< ->GetMutable<
AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>(); AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>();
} }
data_algo = data_algo_cache->GetAlgorithm( data_algo = data_algo_cache->GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() { x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
int returned_algo_count; int returned_algo_count;
std::array<cudnnConvolutionBwdDataAlgoPerf_t, std::array<cudnnConvolutionBwdDataAlgoPerf_t,
kNUM_CUDNN_BWD_DATA_ALGS> kNUM_CUDNN_BWD_DATA_ALGS>
data_perf_stat; data_perf_stat;
auto cudnn_find_bd_data_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload:: CUDNN_ENFORCE(
cudnnFindConvolutionBackwardDataAlgorithmEx( platform::dynload::
handle, cudnn_filter_desc, filter_data, cudnnFindConvolutionBackwardDataAlgorithmEx(
cudnn_output_grad_desc, output_grad_data, handle, cudnn_filter_desc, filter_data,
cudnn_conv_desc, cudnn_input_desc, cudnn_output_grad_desc, output_grad_data,
input_grad_data, kNUM_CUDNN_BWD_DATA_ALGS, cudnn_conv_desc, cudnn_input_desc, input_grad_data,
&returned_algo_count, data_perf_stat.data(), kNUM_CUDNN_BWD_DATA_ALGS, &returned_algo_count,
cudnn_workspace_ptr, workspace_size_limit)); data_perf_stat.data(), cudnn_workspace,
workspace_size_limit));
};
workspace_handle.RunFunc(cudnn_find_bd_data_func,
workspace_size_limit);
VLOG(3) << "Perf result: (algo: stat, time, memory)"; VLOG(3) << "Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) { for (int i = 0; i < returned_algo_count; ++i) {
...@@ -462,23 +443,25 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -462,23 +443,25 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
->GetMutable< ->GetMutable<
AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>(); AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>();
} }
filter_algo = f_algo_cache->GetAlgorithm( filter_algo = f_algo_cache->GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() { x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
int returned_algo_count; int returned_algo_count;
std::array<cudnnConvolutionBwdFilterAlgoPerf_t, std::array<cudnnConvolutionBwdFilterAlgoPerf_t,
kNUM_CUDNN_BWD_FILTER_ALGS> kNUM_CUDNN_BWD_FILTER_ALGS>
filter_perf_stat; filter_perf_stat;
auto cudnn_find_bd_f_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE( CUDNN_ENFORCE(
platform::dynload:: platform::dynload::
cudnnFindConvolutionBackwardFilterAlgorithmEx( cudnnFindConvolutionBackwardFilterAlgorithmEx(
handle, cudnn_input_desc, input_data, handle, cudnn_input_desc, input_data,
cudnn_output_grad_desc, output_grad_data, cudnn_output_grad_desc, output_grad_data,
cudnn_conv_desc, cudnn_filter_desc, filter_grad_data, cudnn_conv_desc, cudnn_filter_desc,
kNUM_CUDNN_BWD_FILTER_ALGS, &returned_algo_count, filter_grad_data, kNUM_CUDNN_BWD_FILTER_ALGS,
filter_perf_stat.data(), cudnn_workspace_ptr, &returned_algo_count, filter_perf_stat.data(),
workspace_size_limit)); cudnn_workspace, workspace_size_limit));
};
workspace_handle.RunFunc(cudnn_find_bd_f_func,
workspace_size_limit);
return filter_perf_stat[0].algo; return filter_perf_stat[0].algo;
}); });
VLOG(3) << "cuDNN backward filter algo " << filter_algo; VLOG(3) << "cuDNN backward filter algo " << filter_algo;
...@@ -499,16 +482,6 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -499,16 +482,6 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
} }
// ------------------- cudnn conv workspace ---------------------
if (!cudnn_workspace_ptr) {
cudnn_workspace =
ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
framework::make_ddim(
{static_cast<int64_t>(workspace_size_in_bytes)}),
dev_ctx);
cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
}
// ------------------- cudnn conv backward data --------------------- // ------------------- cudnn conv backward data ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f; ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
if (input_grad) { if (input_grad) {
...@@ -516,12 +489,15 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -516,12 +489,15 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
// Because beta is zero, it is unnecessary to reset input_grad. // Because beta is zero, it is unnecessary to reset input_grad.
for (int i = 0; i < groups; i++) { for (int i = 0; i < groups; i++) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( auto cudnn_func = [&](void* cudnn_workspace) {
handle, &alpha, cudnn_filter_desc, CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
filter_data + i * group_offset_filter, cudnn_output_grad_desc, handle, &alpha, cudnn_filter_desc,
output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo, filter_data + i * group_offset_filter, cudnn_output_grad_desc,
cudnn_workspace_ptr, workspace_size_in_bytes, &beta, output_grad_data + i * group_offset_out, cudnn_conv_desc,
cudnn_input_desc, input_grad_data + i * group_offset_in)); data_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_input_desc, input_grad_data + i * group_offset_in));
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
} }
} }
// ------------------- cudnn conv backward filter --------------------- // ------------------- cudnn conv backward filter ---------------------
...@@ -529,12 +505,15 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -529,12 +505,15 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace()); T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset filter_grad. // Because beta is zero, it is unnecessary to reset filter_grad.
for (int i = 0; i < groups; i++) { for (int i = 0; i < groups; i++) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( auto cudnn_func = [&](void* cudnn_workspace) {
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
cudnn_output_grad_desc, output_grad_data + i * group_offset_out, handle, &alpha, cudnn_input_desc,
cudnn_conv_desc, filter_algo, cudnn_workspace_ptr, input_data + i * group_offset_in, cudnn_output_grad_desc,
workspace_size_in_bytes, &beta, cudnn_filter_desc, output_grad_data + i * group_offset_out, cudnn_conv_desc,
filter_grad_data + i * group_offset_filter)); filter_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_filter_desc, filter_grad_data + i * group_offset_filter));
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
} }
} }
} }
......
...@@ -61,7 +61,7 @@ namespace platform { ...@@ -61,7 +61,7 @@ namespace platform {
* the allocations of temp_allocation_queue: * the allocations of temp_allocation_queue:
* - when the Stream calls cudaStreamSynchronize; * - when the Stream calls cudaStreamSynchronize;
* - when the allocation size of opportunities exceeds a certain threshold * - when the allocation size of opportunities exceeds a certain threshold
* (defined by FLAGS_limit_of_tmp_allocation). * (defined by FLAGS_limit_of_temporary_allocation).
* *
* */ * */
class DeviceTemporaryAllocator { class DeviceTemporaryAllocator {
......
...@@ -15,15 +15,8 @@ ...@@ -15,15 +15,8 @@
#include "paddle/fluid/platform/temporary_allocator.h" #include "paddle/fluid/platform/temporary_allocator.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h" #include "paddle/fluid/memory/allocation/allocator_facade.h"
DEFINE_int64(limit_of_tmp_allocation, -1, DEFINE_double(limit_of_temporary_allocation, -1,
"The up limit of temporary_allocation size."); "The up limit of temporary_allocation size.");
DEFINE_double(times_excess_than_required_tmp_allocation, 2,
"times_excess_than_required_tmp_allocation indicates the "
"max size the TemporaryAllocator can return. For example, "
"if the required memory size is N, and "
"times_excess_than_required_tmp_allocation is 2.0, "
"the TemporaryAllocator will return the available allocation "
"that the range of size is N ~ 2*N.");
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -36,25 +29,24 @@ TemporaryAllocation::TemporaryAllocation( ...@@ -36,25 +29,24 @@ TemporaryAllocation::TemporaryAllocation(
underlying_allocation_(std::move(underlying_allocation)) {} underlying_allocation_(std::move(underlying_allocation)) {}
TemporaryAllocator::TemporaryAllocator(platform::Place place) : place_(place) { TemporaryAllocator::TemporaryAllocator(platform::Place place) : place_(place) {
temp_mem_map_.reset(new std::multimap<size_t, TemporaryAllocation *>()); temp_mem_queue_.reset(new std::deque<TemporaryAllocation *>());
} }
bool TemporaryAllocator::IsAllocThreadSafe() const { return true; } bool TemporaryAllocator::IsAllocThreadSafe() const { return true; }
void TemporaryAllocator::Release(const std::function<void()> &callback) { void TemporaryAllocator::Release(const std::function<void()> &callback) {
std::unique_ptr<std::multimap<size_t, TemporaryAllocation *>> t_allocations; std::shared_ptr<std::deque<TemporaryAllocation *>> t_allocations;
{ {
std::unique_lock<std::mutex> lock(mtx_); std::unique_lock<std::mutex> lock(mtx_);
callback(); callback();
t_allocations.swap(temp_mem_map_); t_allocations = temp_mem_queue_;
temp_mem_map_.reset(new std::multimap<size_t, TemporaryAllocation *>()); temp_mem_queue_.reset(new std::deque<TemporaryAllocation *>());
wait_delete_mem_ = 0; wait_delete_mem_ = 0;
} }
for (auto tmp : *t_allocations) { for (auto tmp : *t_allocations) {
VLOG(10) << "Delete temporary allocation " << tmp.second->ptr() VLOG(10) << "Delete temporary allocation " << tmp->ptr()
<< " size: " << tmp.second->size(); << " size: " << tmp->size();
delete tmp.second; delete tmp;
} }
} }
...@@ -62,34 +54,28 @@ void TemporaryAllocator::Free(alloc::Allocation *allocation) { ...@@ -62,34 +54,28 @@ void TemporaryAllocator::Free(alloc::Allocation *allocation) {
auto *temp_allocation = dynamic_cast<TemporaryAllocation *>(allocation); auto *temp_allocation = dynamic_cast<TemporaryAllocation *>(allocation);
PADDLE_ENFORCE_NOT_NULL(temp_allocation); PADDLE_ENFORCE_NOT_NULL(temp_allocation);
if (platform::is_gpu_place(temp_allocation->place())) { if (platform::is_gpu_place(temp_allocation->place())) {
PADDLE_ENFORCE(platform::is_same_place(temp_allocation->place(), place_),
"The place should be the same.");
size_t wait_delete_mem = 0; size_t wait_delete_mem = 0;
{ {
std::unique_lock<std::mutex> lock(mtx_); std::unique_lock<std::mutex> lock(mtx_);
temp_mem_map_->emplace(temp_allocation->size(), temp_allocation); temp_mem_queue_->emplace_back(temp_allocation);
wait_delete_mem_ += temp_allocation->size(); wait_delete_mem_ += temp_allocation->size();
wait_delete_mem = wait_delete_mem_; wait_delete_mem = wait_delete_mem_;
VLOG(10) << "Move temporary allocation: " << temp_allocation->ptr() VLOG(10) << "Move temporary allocation: " << temp_allocation->ptr()
<< " to delete queue: " << temp_allocation->size() << "; " << " to delete queue: " << temp_allocation->size() << "; "
<< "wait_delete_mem: " << wait_delete_mem; << "wait_delete_mem: " << wait_delete_mem_;
} }
if (FLAGS_limit_of_temporary_allocation > 0 &&
if (FLAGS_limit_of_tmp_allocation > 0 && wait_delete_mem > FLAGS_limit_of_temporary_allocation) {
wait_delete_mem > static_cast<size_t>(FLAGS_limit_of_tmp_allocation)) {
PADDLE_ENFORCE(callback_ != nullptr, "The callback is non-initialized.");
Release(callback_); Release(callback_);
} }
return; return;
} }
VLOG(10) << "Delete temporary allocation " << temp_allocation->ptr()
<< " size: " << temp_allocation->size();
delete temp_allocation; delete temp_allocation;
} }
size_t TemporaryAllocator::TemporaryAllocationQueueSize() { size_t TemporaryAllocator::TemporaryAllocationQueueSize() {
std::unique_lock<std::mutex> lock(mtx_); std::unique_lock<std::mutex> lock(mtx_);
return temp_mem_map_ ? temp_mem_map_->size() : 0; return temp_mem_queue_ ? temp_mem_queue_->size() : 0;
} }
void TemporaryAllocator::SetCallback(const std::function<void()> &callback) { void TemporaryAllocator::SetCallback(const std::function<void()> &callback) {
...@@ -98,27 +84,6 @@ void TemporaryAllocator::SetCallback(const std::function<void()> &callback) { ...@@ -98,27 +84,6 @@ void TemporaryAllocator::SetCallback(const std::function<void()> &callback) {
alloc::Allocation *TemporaryAllocator::AllocateImpl( alloc::Allocation *TemporaryAllocator::AllocateImpl(
size_t size, alloc::Allocator::Attr attr) { size_t size, alloc::Allocator::Attr attr) {
{
// Find available allocation in temp_mem_map.
std::unique_lock<std::mutex> lock(mtx_);
if (temp_mem_map_->size()) {
auto it = temp_mem_map_->lower_bound(size);
// FIXME(zcd): Not sure the best value of excess fraction.
if (it != temp_mem_map_->end() &&
it->first <
static_cast<size_t>(
size * FLAGS_times_excess_than_required_tmp_allocation)) {
auto tmp_ptr = it->second;
temp_mem_map_->erase(it);
wait_delete_mem_ -= tmp_ptr->size();
VLOG(10) << "Reuse temporary allocation: " << tmp_ptr->ptr() << ": "
<< tmp_ptr->size();
return tmp_ptr;
}
}
}
// If not find the the available allocation, get allocation from
// AllocatorFacadeInstance.
auto raw_allocation = auto raw_allocation =
alloc::AllocatorFacade::Instance().Alloc(place_, size, attr); alloc::AllocatorFacade::Instance().Alloc(place_, size, attr);
auto temp_mem = new TemporaryAllocation(std::move(raw_allocation)); auto temp_mem = new TemporaryAllocation(std::move(raw_allocation));
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#pragma once #pragma once
#include <condition_variable> // NOLINT #include <condition_variable> // NOLINT
#include <deque> #include <deque>
#include <map>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/lock_guard_ptr.h" #include "paddle/fluid/platform/lock_guard_ptr.h"
...@@ -40,7 +39,7 @@ class TemporaryAllocation : public memory::allocation::Allocation { ...@@ -40,7 +39,7 @@ class TemporaryAllocation : public memory::allocation::Allocation {
* *
* There is one opportunity to free the allocations of temp_allocation_queue: * There is one opportunity to free the allocations of temp_allocation_queue:
* - when the allocation size of opportunities exceeds a certain threshold * - when the allocation size of opportunities exceeds a certain threshold
* (defined by FLAGS_limit_of_tmp_allocation). * (defined by FLAGS_limit_of_temporary_allocation).
* *
* */ * */
class TemporaryAllocator : public memory::allocation::Allocator { class TemporaryAllocator : public memory::allocation::Allocator {
...@@ -63,10 +62,11 @@ class TemporaryAllocator : public memory::allocation::Allocator { ...@@ -63,10 +62,11 @@ class TemporaryAllocator : public memory::allocation::Allocator {
private: private:
platform::Place place_; platform::Place place_;
// When the allocation is not held by any variable, it should be placed // When the allocation is not held by any variable, it should be placed
// to temp_mem_map immediately. // to temp_mem_queue immediately.
std::unique_ptr<std::multimap<size_t, TemporaryAllocation *>> temp_mem_map_{ std::shared_ptr<std::deque<TemporaryAllocation *>> temp_mem_queue_{nullptr};
nullptr};
std::mutex mtx_; std::mutex mtx_;
size_t wait_delete_mem_{0}; size_t wait_delete_mem_{0};
std::function<void()> callback_; std::function<void()> callback_;
......
...@@ -18,8 +18,7 @@ ...@@ -18,8 +18,7 @@
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
DECLARE_int64(limit_of_tmp_allocation); DECLARE_double(limit_of_temporary_allocation);
DECLARE_double(times_excess_than_required_tmp_allocation);
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -36,7 +35,7 @@ class DummyOp : public framework::OperatorBase { ...@@ -36,7 +35,7 @@ class DummyOp : public framework::OperatorBase {
const platform::Place& place) const override {} const platform::Place& place) const override {}
}; };
TEST(temporary_allocator, test_base_function) { TEST(temporary_allocator, temporary_allocator) {
platform::CPUPlace cpu_place; platform::CPUPlace cpu_place;
TemporaryAllocator alloc(cpu_place); TemporaryAllocator alloc(cpu_place);
alloc.Allocate(100); alloc.Allocate(100);
...@@ -60,10 +59,10 @@ TEST(temporary_allocator, test_base_function) { ...@@ -60,10 +59,10 @@ TEST(temporary_allocator, test_base_function) {
#endif #endif
} }
TEST(temporary_allocator, test_flags_function) { TEST(temporary_allocator, add_callback) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
const int64_t limit = FLAGS_limit_of_tmp_allocation; const double limit = FLAGS_limit_of_temporary_allocation;
FLAGS_limit_of_tmp_allocation = 10; FLAGS_limit_of_temporary_allocation = 10;
platform::CUDAPlace gpu_place(0); platform::CUDAPlace gpu_place(0);
TemporaryAllocator gpu_alloc(gpu_place); TemporaryAllocator gpu_alloc(gpu_place);
...@@ -79,52 +78,7 @@ TEST(temporary_allocator, test_flags_function) { ...@@ -79,52 +78,7 @@ TEST(temporary_allocator, test_flags_function) {
}); });
{ gpu_alloc.Allocate(100); } { gpu_alloc.Allocate(100); }
PADDLE_ENFORCE(deleted); PADDLE_ENFORCE(deleted);
FLAGS_limit_of_tmp_allocation = limit; FLAGS_limit_of_temporary_allocation = limit;
#endif
}
TEST(temporary_allocator, test_reuse_tmp_allocation) {
#ifdef PADDLE_WITH_CUDA
platform::CUDAPlace gpu_place(0);
TemporaryAllocator gpu_alloc(gpu_place);
gpu_alloc.SetCallback([]() {});
void* tmp_allocation_ptr1 = nullptr;
{
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
auto tmp_allocation1 = gpu_alloc.Allocate(100);
tmp_allocation_ptr1 = tmp_allocation1->ptr();
}
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1);
auto tmp_allocation2 = gpu_alloc.Allocate(100);
void* tmp_allocation_ptr2 = tmp_allocation2->ptr();
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
PADDLE_ENFORCE_EQ(tmp_allocation_ptr1, tmp_allocation_ptr2);
auto tmp_allocation3 = gpu_alloc.Allocate(100);
void* tmp_allocation_ptr3 = tmp_allocation2->ptr();
PADDLE_ENFORCE_EQ(tmp_allocation_ptr1, tmp_allocation_ptr3);
#endif
}
TEST(temporary_allocator, test_times_excess_than_required_tmp_allocation) {
#ifdef PADDLE_WITH_CUDA
platform::CUDAPlace gpu_place(0);
TemporaryAllocator gpu_alloc(gpu_place);
gpu_alloc.SetCallback([]() {});
double excess_fraction = FLAGS_times_excess_than_required_tmp_allocation;
void* tmp_allocation_ptr1 = nullptr;
{
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
auto tmp_allocation1 =
gpu_alloc.Allocate(static_cast<size_t>(100 * excess_fraction - 1));
tmp_allocation_ptr1 = tmp_allocation1->ptr();
}
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1);
auto tmp_allocation2 = gpu_alloc.Allocate(100);
void* tmp_allocation_ptr2 = tmp_allocation2->ptr();
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
PADDLE_ENFORCE_EQ(tmp_allocation_ptr1, tmp_allocation_ptr2);
#endif #endif
} }
......
...@@ -155,8 +155,7 @@ def __bootstrap__(): ...@@ -155,8 +155,7 @@ def __bootstrap__():
'fraction_of_gpu_memory_to_use', 'cudnn_deterministic', 'fraction_of_gpu_memory_to_use', 'cudnn_deterministic',
'enable_cublas_tensor_op_math', 'conv_workspace_size_limit', 'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus', 'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus',
'sync_nccl_allreduce', 'limit_of_tmp_allocation', 'sync_nccl_allreduce'
'times_excess_than_required_tmp_allocation'
] ]
core.init_gflags([sys.argv[0]] + core.init_gflags([sys.argv[0]] +
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册