未验证 提交 79bd6dfa 编写于 作者: C chengduo 提交者: GitHub

[Feature] Add Temporary Allocator (#14875)

* Add Temporal Allocator

* add Temporay Allocator to DeviceContext
test=develop

* code refine
test=develop

* fix mean_iou
test=develop

* Add DeviceTemporaryAllocator
test=develop

* fix conv_op bug
test=develop

* small fix
test=develop

* code refine
test=develop

* log refine
test=develop

* fix unit test
test=develop

* move double check

* refine concat_and_split
test=develop

* add limit_of_temporary_allocation
test=develop

* fix name
test=develop
上级 484c24b7
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/var_type.h"
namespace paddle {
namespace framework {
......@@ -27,6 +28,9 @@ void Tensor::check_memory_size() const {
"or maybe the required data-type mismatches the data already stored.");
}
Tensor::Tensor(std::type_index type)
: type_(framework::ToDataType(type)), offset_(0) {}
size_t Tensor::memory_size() const {
return holder_ == nullptr ? 0UL : holder_->size() - offset_;
}
......@@ -101,5 +105,12 @@ const DDim& Tensor::dims() const { return dims_; }
int64_t Tensor::numel() const { return product(dims_); }
void Tensor::ResetHolder(std::shared_ptr<memory::Allocation> holder) {
if (holder_) {
PADDLE_ENFORCE_EQ(numel() * SizeOfType(type()), holder->size());
}
holder_ = holder;
}
} // namespace framework
} // namespace paddle
......@@ -69,6 +69,8 @@ class Tensor {
public:
Tensor() : type_(proto::VarType::FP32), offset_(0) {}
explicit Tensor(std::type_index type);
/*! Return a pointer to mutable memory block. */
template <typename T>
T* data();
......@@ -162,6 +164,8 @@ class Tensor {
return std::move(holder_);
}
void ResetHolder(std::shared_ptr<memory::Allocation> holder);
private:
/*! holds the memory block if allocated. */
std::shared_ptr<memory::Allocation> holder_;
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/fluid/platform/create_tensor_with_allocationptr.h"
namespace paddle {
namespace operators {
......@@ -123,6 +124,8 @@ class GemmConvKernel : public framework::OpKernel<T> {
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
auto& dev_ctx = context.template device_context<DeviceContext>();
const int batch_size = static_cast<int>(input->dims()[0]);
// filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
......@@ -155,13 +158,19 @@ class GemmConvKernel : public framework::OpKernel<T> {
// to call the matrix multiplication interface.
Tensor col_matrix;
if (is_expand) {
col.mutable_data<T>(col_shape, context.GetPlace());
auto tmp_allocation_ptr =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
framework::product(col_shape) * sizeof(T));
Tensor tep_tensor =
platform::GetTensor<T>(std::move(tmp_allocation_ptr), col_shape);
col.ShareDataWith(tep_tensor);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int>(input->dims().size()));
framework::DDim input_shape =
framework::slice_ddim(input->dims(), 1, input->dims().size());
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
......@@ -178,7 +187,6 @@ class GemmConvKernel : public framework::OpKernel<T> {
math::Vol2ColFunctor<DeviceContext, T> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
auto& dev_ctx = context.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
......@@ -237,6 +245,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
const int batch_size = static_cast<int>(input->dims()[0]);
auto& dev_ctx = context.template device_context<DeviceContext>();
// filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
// output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
......@@ -262,8 +272,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int>(input->dims().size()));
framework::DDim input_shape =
framework::slice_ddim(input->dims(), 1, input->dims().size());
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
......@@ -286,13 +296,18 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
// to call the matrix multiplication interface.
Tensor col_matrix;
if (is_expand) {
col.mutable_data<T>(col_shape, context.GetPlace());
auto tmp_allocation_ptr =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
framework::product(col_shape) * sizeof(T));
Tensor tep_tensor =
platform::GetTensor<T>(std::move(tmp_allocation_ptr), col_shape);
col.ShareDataWith(tep_tensor);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
if (input_grad) {
......
......@@ -131,9 +131,8 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
int in_col = input[0].numel() / in_row;
int out_row = in_row, out_col = 0;
framework::Vector<int16_t> inputs_data(in_num * sizeof(T*) / 2);
framework::Vector<int> inputs_col(in_num + 1);
T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());
std::vector<T*> inputs_data(in_num);
std::vector<int> inputs_col(in_num + 1);
inputs_col[0] = 0;
bool sameShape = true;
......@@ -144,12 +143,9 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
}
out_col += t_cols;
inputs_col[i + 1] = out_col;
inputs_ptr[i] = const_cast<T*>(input[i].data<T>());
inputs_data[i] = const_cast<T*>(input[i].data<T>());
}
T** dev_ins_data =
reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace()));
// computation
// set the thread block and grid according to CurrentDeviceId
const int kThreadsPerBlock = 1024;
......@@ -169,18 +165,32 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
dim3 grid_size = dim3(grid_cols, grid_rows, 1);
auto tmp_dev_ins_data =
platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
inputs_data.size() * sizeof(T*));
memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
tmp_dev_ins_data->ptr(), platform::CPUPlace(),
static_cast<void*>(inputs_data.data()),
inputs_data.size() * sizeof(T*), context.stream());
T** dev_ins_data = reinterpret_cast<T**>(tmp_dev_ins_data->ptr());
if (sameShape) {
ConcatKernel<<<grid_size, block_size, 0, context.stream()>>>(
dev_ins_data, in_col, out_row, out_col, output->data<T>());
} else {
const int* dev_ins_col_data = inputs_col.CUDAData(context.GetPlace());
auto tmp_dev_ins_col_data =
platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
inputs_col.size() * sizeof(int));
memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
static_cast<void*>(inputs_col.data()),
inputs_col.size() * sizeof(int), context.stream());
int* dev_ins_col_data = static_cast<int*>(tmp_dev_ins_col_data->ptr());
ConcatKernel<<<grid_size, block_size, 0, context.stream()>>>(
dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
out_row, out_col, output->data<T>());
}
// Wait() must be called because `inputs_data` may be destructed before
// kernel ends
context.Wait();
}
};
......@@ -207,9 +217,8 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
int in_col = 0, in_row = out_row;
bool sameShape = true;
framework::Vector<int16_t> outputs_data(o_num * sizeof(T*) / 2);
framework::Vector<int> outputs_cols(o_num + 1);
T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());
std::vector<T*> outputs_data(o_num);
std::vector<int> outputs_cols(o_num + 1);
outputs_cols[0] = 0;
for (int i = 0; i < o_num; ++i) {
......@@ -220,15 +229,12 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
in_col += t_col;
outputs_cols[i + 1] = in_col;
if (outputs->at(i) != nullptr) {
outputs_ptr[i] = outputs->at(i)->data<T>();
outputs_data[i] = outputs->at(i)->data<T>();
} else {
outputs_ptr[i] = nullptr;
outputs_data[i] = nullptr;
}
}
T** dev_out_gpu_data =
reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace()));
// computation
const int kThreadsPerBlock = 1024;
int block_cols = kThreadsPerBlock;
......@@ -247,18 +253,33 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
dim3 grid_size = dim3(grid_cols, grid_rows, 1);
auto tmp_dev_outs_data =
platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
outputs_data.size() * sizeof(T*));
memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
tmp_dev_outs_data->ptr(), platform::CPUPlace(),
reinterpret_cast<void*>(outputs_data.data()),
outputs_data.size() * sizeof(T*), context.stream());
T** dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
if (sameShape) {
SplitKernel<<<grid_size, block_size, 0, context.stream()>>>(
input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
} else {
const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace());
auto tmp_dev_ins_col_data =
platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
outputs_cols.size() * sizeof(int));
memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
reinterpret_cast<void*>(outputs_cols.data()),
outputs_cols.size() * sizeof(int), context.stream());
int* dev_outs_col_data =
reinterpret_cast<int*>(tmp_dev_ins_col_data->ptr());
SplitKernel<<<grid_size, block_size, 0, context.stream()>>>(
input.data<T>(), in_row, in_col, dev_outs_col_data,
static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
}
// Wait() must be called because `outputs_data` may be destructed before
// kernel ends
context.Wait();
}
};
......
......@@ -92,8 +92,8 @@ template <typename T>
class MeanIoUCUDAOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
.eigen_device();
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto& place = *dev_ctx.eigen_device();
// get input and output tensor
auto* predictions = ctx.Input<Tensor>("Predictions");
auto* labels = ctx.Input<Tensor>("Labels");
......@@ -115,11 +115,11 @@ class MeanIoUCUDAOpKernel : public framework::OpKernel<T> {
auto out_wrong_t = EigenTensor<int, 1>::From(*out_wrong);
auto out_correct_t = EigenTensor<int, 1>::From(*out_correct);
// Temporary tensor
Tensor ious;
float* ious_data = ious.mutable_data<float>(
{static_cast<int64_t>(num_classes)}, ctx.GetPlace());
auto ious_t = EigenTensor<float, 1>::From(ious);
// Temporary memory
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
auto tmp_ious_data = allocator.Allocate(num_classes * sizeof(float));
float* ious_data = static_cast<float*>(tmp_ious_data->ptr());
// Init out_wrong, out_correct and out_mean_iou
out_wrong_t.device(place) = out_wrong_t.constant(0);
......@@ -148,7 +148,7 @@ class MeanIoUCUDAOpKernel : public framework::OpKernel<T> {
CountCUDAKernel<T><<<grid, block, cache_size, stream>>>(
num_classes, predictions->numel(), predictions_data, labels_data,
out_wrong_data, out_correct_data);
ctx.device_context().Wait();
ComputeIoUCUDAKernel<<<1, block, 0, stream>>>(num_classes, out_wrong_data,
out_correct_data, ious_data,
out_mean_iou_data);
......
......@@ -56,6 +56,8 @@ ELSE()
set(MKLDNN_CTX_DEPS)
ENDIF()
cc_library(temp_allocator SRCS temporary_allocator.cc DEPS allocator_facade)
nv_library(stream_callback_manager SRCS stream_callback_manager.cc DEPS simple_threadpool enforce)
IF(WITH_GPU)
set(STREAM_CALLBACK_DEPS stream_callback_manager)
......@@ -66,7 +68,8 @@ ENDIF()
# memcpy depends on device_context, here add deps individually for
# avoiding cycle dependencies
cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc ${STREAM_CALLBACK_DEPS}
place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS})
place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} temp_allocator)
if(WIN32)
if(WITH_GPU AND NOT WITH_DSO)
get_property(cuda_modules GLOBAL PROPERTY CUDA_MODULES)
......@@ -92,3 +95,9 @@ IF(WITH_GPU)
nv_test(cuda_helper_test SRCS cuda_helper_test.cu)
ENDIF()
nv_library(cuda_device_guard SRCS cuda_device_guard.cc DEPS gpu_info)
if(WITH_GPU)
nv_test(temporal_allocator_test SRCS temporary_allocator_test.cc DEPS temp_allocator tensor)
else()
cc_test(temporal_allocator_test SRCS temporary_allocator_test.cc DEPS temp_allocator tensor)
endif()
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/temporary_allocator.h"
namespace paddle {
namespace platform {
template <typename T>
paddle::framework::Tensor GetTensor(
memory::allocation::AllocationPtr temp_allocation_ptr,
const framework::DDim &dim) {
auto &deleter = temp_allocation_ptr.get_deleter();
auto *allocation_ptr = temp_allocation_ptr.release();
auto shared_allocation =
std::shared_ptr<memory::allocation::Allocation>(allocation_ptr, deleter);
PADDLE_ENFORCE(dynamic_cast<TemporaryAllocation *>(allocation_ptr) != nullptr,
"The AllocationPtr must be TemporaryAllocation.");
PADDLE_ENFORCE_EQ(allocation_ptr->size(),
framework::product(dim) * sizeof(T));
paddle::framework::Tensor temp_tensor(std::type_index(typeid(T)));
temp_tensor.Resize(dim);
temp_tensor.ResetHolder(std::move(shared_allocation));
return temp_tensor;
}
} // namespace platform
} // namespace paddle
......@@ -85,6 +85,49 @@ DeviceContextPool::DeviceContextPool(
}
}
DeviceTemporaryAllocator* DeviceTemporaryAllocator::allocators = nullptr;
#ifdef PADDLE_WITH_CUDA
platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
const platform::Place& place, const cudaStream_t& stream) {
PADDLE_ENFORCE(platform::is_gpu_place(place));
auto place_stream = std::make_pair(place, stream);
{
std::unique_lock<std::mutex> lock(mtx_);
if (!device_allocator_.count(place_stream)) {
device_allocator_[place_stream].reset(new TemporaryAllocator(place));
device_allocator_[place_stream]->SetCallback([stream]() {
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
PADDLE_ENFORCE(cudaGetLastError());
});
}
}
return *device_allocator_.at(place_stream);
}
template <>
platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
const platform::CUDADeviceContext& dev_ctx) {
auto place_stream = std::make_pair(dev_ctx.GetPlace(), dev_ctx.stream());
if (device_allocator_.count(place_stream)) {
return *device_allocator_.at(place_stream);
}
return Get(dev_ctx.GetPlace(), dev_ctx.stream());
}
#endif
template <>
platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
const platform::CPUDeviceContext& dev_ctx) {
return cpu_allocator_;
}
platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
const platform::Place& place) {
PADDLE_ENFORCE(platform::is_cpu_place(place), "You should pass CPUPlace");
return cpu_allocator_;
}
CPUDeviceContext::CPUDeviceContext() {
eigen_device_.reset(new Eigen::DefaultDevice());
}
......@@ -271,8 +314,12 @@ CUDADeviceContext::~CUDADeviceContext() {
Place CUDADeviceContext::GetPlace() const { return place_; }
void CUDADeviceContext::Wait() const {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaGetLastError());
auto& allocator =
DeviceTemporaryAllocator::Instance().Get<CUDADeviceContext>(*this);
allocator.Release([=]() {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaGetLastError());
});
}
int CUDADeviceContext::GetComputeCapability() const {
......
......@@ -15,8 +15,10 @@ limitations under the License. */
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/temporary_allocator.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
......@@ -39,6 +41,50 @@ limitations under the License. */
namespace paddle {
namespace platform {
/*! \brief device temporary allocator singleton */
class DeviceTemporaryAllocator {
public:
static DeviceTemporaryAllocator& Instance() {
PADDLE_ENFORCE_NOT_NULL(allocators,
"Need to Create DeviceTemporaryAllocator first!");
return *allocators;
}
static DeviceTemporaryAllocator& Init() {
if (allocators == nullptr) {
allocators = new DeviceTemporaryAllocator();
}
return *allocators;
}
/*! \brief Return handle of single temporary allocator. */
#ifdef PADDLE_WITH_CUDA
platform::TemporaryAllocator& Get(const platform::Place& place,
const cudaStream_t& stream);
#endif
template <typename DeviceContext>
platform::TemporaryAllocator& Get(const DeviceContext& dev_ctx);
platform::TemporaryAllocator& Get(const platform::Place& place);
private:
DeviceTemporaryAllocator() : cpu_allocator_(platform::CPUPlace()) {}
static DeviceTemporaryAllocator* allocators;
platform::TemporaryAllocator cpu_allocator_;
#ifdef PADDLE_WITH_CUDA
std::map<std::pair<platform::Place, cudaStream_t>,
std::unique_ptr<platform::TemporaryAllocator>>
device_allocator_;
#endif
std::mutex mtx_;
DISABLE_COPY_AND_ASSIGN(DeviceTemporaryAllocator);
};
class DeviceContext {
public:
virtual ~DeviceContext() {}
......
......@@ -110,7 +110,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
}
places.emplace_back(platform::CPUPlace());
platform::DeviceContextPool::Init(places);
platform::DeviceTemporaryAllocator::Init();
#ifndef PADDLE_WITH_MKLDNN
platform::SetNumThreads(FLAGS_paddle_num_threads);
#endif
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/temporary_allocator.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
DEFINE_double(limit_of_temporary_allocation, -1,
"The up limit of temporary_allocation size.");
namespace paddle {
namespace platform {
namespace alloc = memory::allocation;
TemporaryAllocation::TemporaryAllocation(
alloc::AllocationPtr &&underlying_allocation)
: Allocation(underlying_allocation->ptr(), underlying_allocation->size(),
underlying_allocation->place()),
underlying_allocation_(std::move(underlying_allocation)) {}
TemporaryAllocator::TemporaryAllocator(platform::Place place) : place_(place) {
temp_mem_queue_.reset(new std::deque<TemporaryAllocation *>());
}
bool TemporaryAllocator::IsAllocThreadSafe() const { return true; }
void TemporaryAllocator::Release(const std::function<void()> &callback) {
std::shared_ptr<std::deque<TemporaryAllocation *>> t_allocations;
{
std::unique_lock<std::mutex> lock(mtx_);
callback();
t_allocations = temp_mem_queue_;
temp_mem_queue_.reset(new std::deque<TemporaryAllocation *>());
wait_delete_mem_ = 0;
}
for (auto tmp : *t_allocations) {
VLOG(10) << "Delete temporary allocation " << tmp->ptr()
<< " size: " << tmp->size();
delete tmp;
}
}
void TemporaryAllocator::Free(alloc::Allocation *allocation) {
auto *temp_allocation = dynamic_cast<TemporaryAllocation *>(allocation);
PADDLE_ENFORCE_NOT_NULL(temp_allocation);
if (platform::is_gpu_place(temp_allocation->place())) {
size_t wait_delete_mem = 0;
{
std::unique_lock<std::mutex> lock(mtx_);
temp_mem_queue_->emplace_back(temp_allocation);
wait_delete_mem_ += temp_allocation->size();
wait_delete_mem = wait_delete_mem_;
VLOG(10) << "Move temporary allocation: " << temp_allocation->ptr()
<< " to delete queue: " << temp_allocation->size() << "; "
<< "wait_delete_mem: " << wait_delete_mem_;
}
if (FLAGS_limit_of_temporary_allocation > 0 &&
wait_delete_mem > FLAGS_limit_of_temporary_allocation) {
Release(callback_);
}
return;
}
delete temp_allocation;
}
size_t TemporaryAllocator::TemporaryAllocationQueueSize() {
std::unique_lock<std::mutex> lock(mtx_);
return temp_mem_queue_ ? temp_mem_queue_->size() : 0;
}
void TemporaryAllocator::SetCallback(const std::function<void()> &callback) {
callback_ = callback;
}
alloc::Allocation *TemporaryAllocator::AllocateImpl(
size_t size, alloc::Allocator::Attr attr) {
auto raw_allocation =
alloc::AllocatorFacade::Instance().Alloc(place_, size, attr);
auto temp_mem = new TemporaryAllocation(std::move(raw_allocation));
VLOG(10) << "Alloc temporary allocation: " << temp_mem->ptr() << ": " << size;
return temp_mem;
}
} // namespace platform
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable> // NOLINT
#include <deque>
#include <mutex> // NOLINT
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/lock_guard_ptr.h"
namespace paddle {
namespace platform {
class TemporaryAllocation : public memory::allocation::Allocation {
public:
explicit TemporaryAllocation(
memory::allocation::AllocationPtr &&underlying_allocation);
memory::allocation::AllocationPtr underlying_allocation_;
};
class TemporaryAllocator : public memory::allocation::Allocator {
public:
explicit TemporaryAllocator(platform::Place place);
void Release(const std::function<void()> &callback);
size_t TemporaryAllocationQueueSize();
bool IsAllocThreadSafe() const override;
void SetCallback(const std::function<void()> &callback);
protected:
void Free(memory::allocation::Allocation *allocation) override;
memory::allocation::Allocation *AllocateImpl(
size_t size, memory::allocation::Allocator::Attr attr) override;
private:
platform::Place place_;
// When the allocation is not held by any variable, it should be placed
// to temp_mem_queue immediately.
std::shared_ptr<std::deque<TemporaryAllocation *>> temp_mem_queue_{nullptr};
std::mutex mtx_;
size_t wait_delete_mem_{0};
std::function<void()> callback_;
};
} // namespace platform
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/temporary_allocator.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/create_tensor_with_allocationptr.h"
DECLARE_double(limit_of_temporary_allocation);
namespace paddle {
namespace platform {
TEST(temporary_allocator, temporary_allocator) {
platform::CPUPlace cpu_place;
TemporaryAllocator alloc(cpu_place);
alloc.Allocate(100);
#ifdef PADDLE_WITH_CUDA
platform::CUDAPlace gpu_place(0);
TemporaryAllocator gpu_alloc(gpu_place);
auto allocation = gpu_alloc.Allocate(101);
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
gpu_alloc.Release([]() {});
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
{
auto allocation = gpu_alloc.Allocate(102);
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
}
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1);
gpu_alloc.Release([]() {});
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
#endif
}
TEST(temporary_allocator, add_callback) {
#ifdef PADDLE_WITH_CUDA
FLAGS_limit_of_temporary_allocation = 10;
platform::CUDAPlace gpu_place(0);
TemporaryAllocator gpu_alloc(gpu_place);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx =
static_cast<platform::CUDADeviceContext*>(pool.Get(gpu_place));
auto stream = dev_ctx->stream();
bool deleted = false;
gpu_alloc.SetCallback([stream, &deleted]() {
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
PADDLE_ENFORCE(cudaGetLastError());
deleted = true;
});
{ gpu_alloc.Allocate(100); }
PADDLE_ENFORCE(deleted);
FLAGS_limit_of_temporary_allocation = -1;
#endif
}
TEST(temporary_allocator, create_tensor_with_allocationptr) {
platform::CPUPlace cpu_place;
TemporaryAllocator cpu_alloc(cpu_place);
{
size_t memory_size = 200;
auto allocation = cpu_alloc.Allocate(memory_size);
void* address = allocation->ptr();
int numel = memory_size / sizeof(float);
framework::Tensor tensor =
GetTensor<float>(std::move(allocation), framework::make_ddim({numel}));
PADDLE_ENFORCE_EQ(address, tensor.data<float>());
PADDLE_ENFORCE_EQ(tensor.numel(), numel);
}
#ifdef PADDLE_WITH_CUDA
platform::CUDAPlace gpu_place(0);
TemporaryAllocator gpu_alloc(gpu_place);
{
size_t memory_size = 300;
auto allocation = gpu_alloc.Allocate(memory_size);
void* address = allocation->ptr();
int numel = memory_size / sizeof(float);
framework::Tensor tensor =
GetTensor<float>(std::move(allocation), framework::make_ddim({numel}));
PADDLE_ENFORCE_EQ(address, tensor.data<float>());
PADDLE_ENFORCE_EQ(tensor.numel(), numel);
}
// The allocation is not holded now, it should be placed to
// TemporaryAllocationQueue.
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1);
gpu_alloc.Release([]() {});
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
#endif
}
TEST(temporary_allocator, create_tensor_with_allocationptr2) {
platform::CPUPlace cpu_place;
TemporaryAllocator cpu_alloc(cpu_place);
{
size_t memory_size = 400;
int numel = memory_size / sizeof(float);
framework::Tensor out_side_tensor;
void* address;
{
auto allocation = cpu_alloc.Allocate(memory_size);
address = allocation->ptr();
framework::Tensor tensor = GetTensor<float>(
std::move(allocation), framework::make_ddim({numel}));
PADDLE_ENFORCE_EQ(address, tensor.data<float>());
PADDLE_ENFORCE_EQ(tensor.numel(), numel);
out_side_tensor.ShareDataWith(tensor);
}
PADDLE_ENFORCE_EQ(address, out_side_tensor.data<float>());
PADDLE_ENFORCE_EQ(out_side_tensor.numel(), numel);
}
#ifdef PADDLE_WITH_CUDA
platform::CUDAPlace gpu_place(0);
TemporaryAllocator gpu_alloc(gpu_place);
{
void* address;
size_t memory_size = 500;
int numel = memory_size / sizeof(float);
framework::Tensor out_side_tensor;
{
auto allocation = gpu_alloc.Allocate(memory_size);
address = allocation->ptr();
framework::Tensor tensor = GetTensor<float>(
std::move(allocation), framework::make_ddim({numel}));
PADDLE_ENFORCE_EQ(address, tensor.data<float>());
PADDLE_ENFORCE_EQ(tensor.numel(), numel);
out_side_tensor.ShareDataWith(tensor);
}
PADDLE_ENFORCE_EQ(address, out_side_tensor.data<float>());
PADDLE_ENFORCE_EQ(out_side_tensor.numel(), numel);
// The allocation is holded by out_side_tensor.
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
gpu_alloc.Release([]() {});
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
}
// The allocation is not holded now, it should be placed to
// TemporaryAllocationQueue.
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1);
gpu_alloc.Release([]() {});
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
#endif
}
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册