diff --git a/paddle/fluid/framework/tensor.cc b/paddle/fluid/framework/tensor.cc
index 57335847a1931de6599560c6e9395a910282b0ee..5b09cad06c3f87ce29a8c986d30217099bd10d74 100644
--- a/paddle/fluid/framework/tensor.cc
+++ b/paddle/fluid/framework/tensor.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/framework/var_type.h"
 
 namespace paddle {
 namespace framework {
@@ -27,6 +28,9 @@ void Tensor::check_memory_size() const {
       "or maybe the required data-type mismatches the data already stored.");
 }
 
+Tensor::Tensor(std::type_index type)
+    : type_(framework::ToDataType(type)), offset_(0) {}
+
 size_t Tensor::memory_size() const {
   return holder_ == nullptr ? 0UL : holder_->size() - offset_;
 }
@@ -101,5 +105,12 @@ const DDim& Tensor::dims() const { return dims_; }
 
 int64_t Tensor::numel() const { return product(dims_); }
 
+void Tensor::ResetHolder(std::shared_ptr<memory::Allocation> holder) {
+  if (holder_) {
+    PADDLE_ENFORCE_EQ(numel() * SizeOfType(type()), holder->size());
+  }
+  holder_ = holder;
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h
index 6a1cbe5cd567429c922156f8bce7ca710b15a0f5..2e110133a33ede5c58779f9f7c52abd8e74c2fa0 100644
--- a/paddle/fluid/framework/tensor.h
+++ b/paddle/fluid/framework/tensor.h
@@ -69,6 +69,8 @@ class Tensor {
  public:
   Tensor() : type_(proto::VarType::FP32), offset_(0) {}
 
+  explicit Tensor(std::type_index type);
+
   /*! Return a pointer to mutable memory block. */
   template <typename T>
   T* data();
@@ -162,6 +164,8 @@ class Tensor {
     return std::move(holder_);
   }
 
+  void ResetHolder(std::shared_ptr<memory::Allocation> holder);
+
  private:
   /*! holds the memory block if allocated. */
   std::shared_ptr<memory::Allocation> holder_;
diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h
index 249f308c13ff5636fbaa6747b28cab7886b7e736..4a7b31c7d491f0e4b73e2b574456d1567b7cc5dc 100644
--- a/paddle/fluid/operators/conv_op.h
+++ b/paddle/fluid/operators/conv_op.h
@@ -22,6 +22,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/depthwise_conv.h"
 #include "paddle/fluid/operators/math/im2col.h"
 #include "paddle/fluid/operators/math/vol2col.h"
+#include "paddle/fluid/platform/create_tensor_with_allocationptr.h"
 
 namespace paddle {
 namespace operators {
@@ -123,6 +124,8 @@ class GemmConvKernel : public framework::OpKernel<T> {
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
     std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
 
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+
     const int batch_size = static_cast<int>(input->dims()[0]);
 
     // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
@@ -155,13 +158,19 @@ class GemmConvKernel : public framework::OpKernel<T> {
     // to call the matrix multiplication interface.
     Tensor col_matrix;
     if (is_expand) {
-      col.mutable_data<T>(col_shape, context.GetPlace());
+      auto tmp_allocation_ptr =
+          platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
+              framework::product(col_shape) * sizeof(T));
+      Tensor tep_tensor =
+          platform::GetTensor<T>(std::move(tmp_allocation_ptr), col_shape);
+
+      col.ShareDataWith(tep_tensor);
       col_matrix.ShareDataWith(col);
       col_matrix.Resize(col_matrix_shape);
     }
 
-    framework::DDim input_shape = framework::slice_ddim(
-        input->dims(), 1, static_cast<int>(input->dims().size()));
+    framework::DDim input_shape =
+        framework::slice_ddim(input->dims(), 1, input->dims().size());
 
     framework::DDim filter_matrix_shape = {filter.dims()[0],
                                            filter.numel() / filter.dims()[0]};
@@ -178,7 +187,6 @@ class GemmConvKernel : public framework::OpKernel<T> {
     math::Vol2ColFunctor<DeviceContext, T> vol2col;
     math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
 
-    auto& dev_ctx = context.template device_context<DeviceContext>();
     auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
     for (int i = 0; i < batch_size; i++) {
       Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
@@ -237,6 +245,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
 
     const int batch_size = static_cast<int>(input->dims()[0]);
 
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+
     // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
     // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
@@ -262,8 +272,8 @@
     framework::DDim col_matrix_shape =
         framework::flatten_to_2d(col_shape, data_dim + 1);
 
-    framework::DDim input_shape = framework::slice_ddim(
-        input->dims(), 1, static_cast<int>(input->dims().size()));
+    framework::DDim input_shape =
+        framework::slice_ddim(input->dims(), 1, input->dims().size());
 
     framework::DDim filter_matrix_shape = {filter.dims()[0],
                                            filter.numel() / filter.dims()[0]};
@@ -286,13 +296,18 @@
     // to call the matrix multiplication interface.
     Tensor col_matrix;
     if (is_expand) {
-      col.mutable_data<T>(col_shape, context.GetPlace());
+      auto tmp_allocation_ptr =
+          platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
+              framework::product(col_shape) * sizeof(T));
+      Tensor tep_tensor =
+          platform::GetTensor<T>(std::move(tmp_allocation_ptr), col_shape);
+
+      col.ShareDataWith(tep_tensor);
       col_matrix.ShareDataWith(col);
       col_matrix.Resize(col_matrix_shape);
     }
 
     math::SetConstant<DeviceContext, T> set_zero;
-    auto& dev_ctx = context.template device_context<DeviceContext>();
     auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
 
     if (input_grad) {
diff --git a/paddle/fluid/operators/math/concat_and_split.cu b/paddle/fluid/operators/math/concat_and_split.cu
index 760a065c1081d1e55901774b258ba524471b856b..b10a19b658e383b8c7b4fbbe8f90da1fe0d4fd14 100644
--- a/paddle/fluid/operators/math/concat_and_split.cu
+++ b/paddle/fluid/operators/math/concat_and_split.cu
@@ -131,9 +131,8 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     int in_col = input[0].numel() / in_row;
     int out_row = in_row, out_col = 0;
 
-    framework::Vector<int16_t> inputs_data(in_num * sizeof(T*) / 2);
-    framework::Vector<int> inputs_col(in_num + 1);
-    T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());
+    std::vector<T*> inputs_data(in_num);
+    std::vector<int> inputs_col(in_num + 1);
 
     inputs_col[0] = 0;
     bool sameShape = true;
@@ -144,12 +143,9 @@
       }
       out_col += t_cols;
       inputs_col[i + 1] = out_col;
-      inputs_ptr[i] = const_cast<T*>(input[i].data<T>());
+      inputs_data[i] = const_cast<T*>(input[i].data<T>());
     }
 
-    T** dev_ins_data =
-        reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace()));
-
     // computation
     // set the thread block and grid according to CurrentDeviceId
     const int kThreadsPerBlock = 1024;
@@ -169,18 +165,32 @@
         std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
     dim3 grid_size = dim3(grid_cols, grid_rows, 1);
 
+    auto tmp_dev_ins_data =
+        platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
+            inputs_data.size() * sizeof(T*));
+    memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
+                 tmp_dev_ins_data->ptr(), platform::CPUPlace(),
+                 static_cast<void*>(inputs_data.data()),
+                 inputs_data.size() * sizeof(T*), context.stream());
+    T** dev_ins_data = reinterpret_cast<T**>(tmp_dev_ins_data->ptr());
+
     if (sameShape) {
       ConcatKernel<<<grid_size, block_size, 0, context.stream()>>>(
           dev_ins_data, in_col, out_row, out_col, output->data<T>());
     } else {
-      const int* dev_ins_col_data = inputs_col.CUDAData(context.GetPlace());
+      auto tmp_dev_ins_col_data =
+          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
+              inputs_col.size() * sizeof(int));
+      memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
+                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
+                   static_cast<void*>(inputs_col.data()),
+                   inputs_col.size() * sizeof(int), context.stream());
+      int* dev_ins_col_data = static_cast<int*>(tmp_dev_ins_col_data->ptr());
+
       ConcatKernel<<<grid_size, block_size, 0, context.stream()>>>(
           dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
           out_row, out_col, output->data<T>());
     }
-    // Wait() must be called because `inputs_data` may be destructed before
-    // kernel ends
-    context.Wait();
   }
 };
 
@@ -207,9 +217,8 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
     int in_col = 0, in_row = out_row;
     bool sameShape = true;
 
-    framework::Vector<int16_t> outputs_data(o_num * sizeof(T*) / 2);
-    framework::Vector<int> outputs_cols(o_num + 1);
-    T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());
+    std::vector<T*> outputs_data(o_num);
+    std::vector<int> outputs_cols(o_num + 1);
 
     outputs_cols[0] = 0;
     for (int i = 0; i < o_num; ++i) {
@@ -220,15 +229,12 @@
       in_col += t_col;
       outputs_cols[i + 1] = in_col;
       if (outputs->at(i) != nullptr) {
-        outputs_ptr[i] = outputs->at(i)->data<T>();
+        outputs_data[i] = outputs->at(i)->data<T>();
       } else {
-        outputs_ptr[i] = nullptr;
+        outputs_data[i] = nullptr;
       }
     }
 
-    T** dev_out_gpu_data =
-        reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace()));
-
     // computation
     const int kThreadsPerBlock = 1024;
     int block_cols = kThreadsPerBlock;
@@ -247,18 +253,33 @@
         std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
     dim3 grid_size = dim3(grid_cols, grid_rows, 1);
 
+    auto tmp_dev_outs_data =
+        platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
+            outputs_data.size() * sizeof(T*));
+    memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
+                 tmp_dev_outs_data->ptr(), platform::CPUPlace(),
+                 reinterpret_cast<void*>(outputs_data.data()),
+                 outputs_data.size() * sizeof(T*), context.stream());
+    T** dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
+
     if (sameShape) {
       SplitKernel<<<grid_size, block_size, 0, context.stream()>>>(
           input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
     } else {
-      const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace());
+      auto tmp_dev_ins_col_data =
+          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
+              outputs_cols.size() * sizeof(int));
+      memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
+                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
+                   reinterpret_cast<void*>(outputs_cols.data()),
+                   outputs_cols.size() * sizeof(int), context.stream());
+      int* dev_outs_col_data =
+          reinterpret_cast<int*>(tmp_dev_ins_col_data->ptr());
+
       SplitKernel<<<grid_size, block_size, 0, context.stream()>>>(
           input.data<T>(), in_row, in_col, dev_outs_col_data,
           static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
     }
-    // Wait() must be called because `outputs_data` may be destructed before
-    // kernel ends
-    context.Wait();
   }
 };
diff --git a/paddle/fluid/operators/mean_iou_op.cu b/paddle/fluid/operators/mean_iou_op.cu
index 83bb4dde46fa241affad3788e3381b6ecd8aa098..08088eb8733f28f0dc8ecade2aa4b70342244b0a 100644
--- a/paddle/fluid/operators/mean_iou_op.cu
+++ b/paddle/fluid/operators/mean_iou_op.cu
@@ -92,8 +92,8 @@ template <typename T>
 class MeanIoUCUDAOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
-                       .eigen_device();
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto& place = *dev_ctx.eigen_device();
     // get input and output tensor
     auto* predictions = ctx.Input<Tensor>("Predictions");
     auto* labels = ctx.Input<Tensor>("Labels");
@@ -115,11 +115,11 @@
     auto out_wrong_t = EigenTensor<int, 1>::From(*out_wrong);
     auto out_correct_t = EigenTensor<int, 1>::From(*out_correct);
 
-    // Temporary tensor
-    Tensor ious;
-    float* ious_data = ious.mutable_data<float>(
-        {static_cast<int64_t>(num_classes)}, ctx.GetPlace());
-    auto ious_t = EigenTensor<float, 1>::From(ious);
+    // Temporary memory
+    auto& allocator =
+        platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
+    auto tmp_ious_data = allocator.Allocate(num_classes * sizeof(float));
+    float* ious_data = static_cast<float*>(tmp_ious_data->ptr());
 
     // Init out_wrong, out_correct and out_mean_iou
     out_wrong_t.device(place) = out_wrong_t.constant(0);
@@ -148,7 +148,7 @@
     CountCUDAKernel<T><<<grid, block, 0, stream>>>(
         num_classes, predictions->numel(), predictions_data, labels_data,
         out_wrong_data, out_correct_data);
-    ctx.device_context().Wait();
+
     ComputeIoUCUDAKernel<<<1, block, 0, stream>>>(num_classes, out_wrong_data,
                                                   out_correct_data, ious_data,
                                                   out_mean_iou_data);
diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt
index 2f205e1d5ca30d67a55e4df0f5e879ffef9a9c26..d1dff16ddd859e6bf19ec22420c28819a9f14d50 100644
--- a/paddle/fluid/platform/CMakeLists.txt
+++ b/paddle/fluid/platform/CMakeLists.txt
@@ -56,6 +56,8 @@ ELSE()
     set(MKLDNN_CTX_DEPS)
 ENDIF()
 
+cc_library(temp_allocator SRCS temporary_allocator.cc DEPS allocator_facade)
+
 nv_library(stream_callback_manager SRCS stream_callback_manager.cc DEPS simple_threadpool enforce)
 IF(WITH_GPU)
     set(STREAM_CALLBACK_DEPS stream_callback_manager)
@@ -66,7 +68,8 @@ ENDIF()
 # memcpy depends on device_context, here add deps individually for
 # avoiding cycle dependencies
 cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc ${STREAM_CALLBACK_DEPS}
-    place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS})
+    place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} temp_allocator)
+
 if(WIN32)
     if(WITH_GPU AND NOT WITH_DSO)
         get_property(cuda_modules GLOBAL PROPERTY CUDA_MODULES)
@@ -92,3 +95,9 @@ IF(WITH_GPU)
     nv_test(cuda_helper_test SRCS cuda_helper_test.cu)
 ENDIF()
 nv_library(cuda_device_guard SRCS cuda_device_guard.cc DEPS gpu_info)
+
+if(WITH_GPU)
+   nv_test(temporal_allocator_test SRCS temporary_allocator_test.cc DEPS temp_allocator tensor)
+else()
+   cc_test(temporal_allocator_test SRCS temporary_allocator_test.cc DEPS temp_allocator tensor)
+endif()
diff --git a/paddle/fluid/platform/create_tensor_with_allocationptr.h b/paddle/fluid/platform/create_tensor_with_allocationptr.h
new file mode 100644
index 0000000000000000000000000000000000000000..00fcc5f86209b2a827ac070773f4b0049b0457d8
--- /dev/null
+++ b/paddle/fluid/platform/create_tensor_with_allocationptr.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/temporary_allocator.h"
+namespace paddle {
+namespace platform {
+
+template <typename T>
+paddle::framework::Tensor GetTensor(
+    memory::allocation::AllocationPtr temp_allocation_ptr,
+    const framework::DDim &dim) {
+  auto &deleter = temp_allocation_ptr.get_deleter();
+  auto *allocation_ptr = temp_allocation_ptr.release();
+  auto shared_allocation =
+      std::shared_ptr<memory::allocation::Allocation>(allocation_ptr, deleter);
+
+  PADDLE_ENFORCE(dynamic_cast<TemporaryAllocation *>(allocation_ptr) != nullptr,
+                 "The AllocationPtr must be TemporaryAllocation.");
+  PADDLE_ENFORCE_EQ(allocation_ptr->size(),
+                    framework::product(dim) * sizeof(T));
+
+  paddle::framework::Tensor temp_tensor(std::type_index(typeid(T)));
+  temp_tensor.Resize(dim);
+  temp_tensor.ResetHolder(std::move(shared_allocation));
+  return temp_tensor;
+}
+
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index d2e23d80f437e1df9216fa36e99a9be394dda074..81c443d758fcf22545af4bf8e452be8f0ecc0a89 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -85,6 +85,49 @@ DeviceContextPool::DeviceContextPool(
   }
 }
 
+DeviceTemporaryAllocator* DeviceTemporaryAllocator::allocators = nullptr;
+
+#ifdef PADDLE_WITH_CUDA
+platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
    const platform::Place& place, const cudaStream_t& stream) {
+  PADDLE_ENFORCE(platform::is_gpu_place(place));
+  auto place_stream = std::make_pair(place, stream);
+  {
+    std::unique_lock<std::mutex> lock(mtx_);
+    if (!device_allocator_.count(place_stream)) {
+      device_allocator_[place_stream].reset(new TemporaryAllocator(place));
+      device_allocator_[place_stream]->SetCallback([stream]() {
+        PADDLE_ENFORCE(cudaStreamSynchronize(stream));
+        PADDLE_ENFORCE(cudaGetLastError());
+      });
+    }
+  }
+  return *device_allocator_.at(place_stream);
+}
+
+template <>
+platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
+    const platform::CUDADeviceContext& dev_ctx) {
+  auto place_stream = std::make_pair(dev_ctx.GetPlace(), dev_ctx.stream());
+  if (device_allocator_.count(place_stream)) {
+    return *device_allocator_.at(place_stream);
+  }
+  return Get(dev_ctx.GetPlace(), dev_ctx.stream());
+}
+#endif
+
+template <>
+platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
+    const platform::CPUDeviceContext& dev_ctx) {
+  return cpu_allocator_;
+}
+
+platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
+    const platform::Place& place) {
+  PADDLE_ENFORCE(platform::is_cpu_place(place), "You should pass CPUPlace");
+  return cpu_allocator_;
+}
+
 CPUDeviceContext::CPUDeviceContext() {
   eigen_device_.reset(new Eigen::DefaultDevice());
 }
@@ -271,8 +314,12 @@ CUDADeviceContext::~CUDADeviceContext() {
 Place CUDADeviceContext::GetPlace() const { return place_; }
 
 void CUDADeviceContext::Wait() const {
-  PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
-  PADDLE_ENFORCE(cudaGetLastError());
+  auto& allocator =
+      DeviceTemporaryAllocator::Instance().Get(*this);
+  allocator.Release([=]() {
+    PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
+    PADDLE_ENFORCE(cudaGetLastError());
+  });
 }
 
 int CUDADeviceContext::GetComputeCapability() const {
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index 812e56f1f966d03207cf83ad47cb88e9fa5d55bb..af9744dcb847f8af97e87cc18d2aee376f3f3d6c 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -15,8 +15,10 @@ limitations under the License. */
 #include   // NOLINT
 #include
 #include
+#include
 #include
 #include "paddle/fluid/memory/malloc.h"
+#include "paddle/fluid/platform/temporary_allocator.h"
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/dynload/cublas.h"
 #include "paddle/fluid/platform/dynload/cudnn.h"
@@ -39,6 +41,50 @@ limitations under the License. */
 namespace paddle {
 namespace platform {
 
+/*! \brief device temporary allocator singleton */
+class DeviceTemporaryAllocator {
+ public:
+  static DeviceTemporaryAllocator& Instance() {
+    PADDLE_ENFORCE_NOT_NULL(allocators,
+                            "Need to Create DeviceTemporaryAllocator first!");
+    return *allocators;
+  }
+
+  static DeviceTemporaryAllocator& Init() {
+    if (allocators == nullptr) {
+      allocators = new DeviceTemporaryAllocator();
+    }
+    return *allocators;
+  }
+
+/*! \brief  Return handle of single temporary allocator. */
+#ifdef PADDLE_WITH_CUDA
+  platform::TemporaryAllocator& Get(const platform::Place& place,
+                                    const cudaStream_t& stream);
+#endif
+  template <typename DeviceContext>
+  platform::TemporaryAllocator& Get(const DeviceContext& dev_ctx);
+
+  platform::TemporaryAllocator& Get(const platform::Place& place);
+
+ private:
+  DeviceTemporaryAllocator() : cpu_allocator_(platform::CPUPlace()) {}
+
+  static DeviceTemporaryAllocator* allocators;
+
+  platform::TemporaryAllocator cpu_allocator_;
+
+#ifdef PADDLE_WITH_CUDA
+  std::map<std::pair<platform::Place, cudaStream_t>,
+           std::unique_ptr<TemporaryAllocator>>
+      device_allocator_;
+#endif
+
+  std::mutex mtx_;
+
+  DISABLE_COPY_AND_ASSIGN(DeviceTemporaryAllocator);
+};
+
 class DeviceContext {
  public:
   virtual ~DeviceContext() {}
diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc
index 0d10d82d74a2011b1b2bc088fe88cbfdb49600b8..ac86b38a61c9d8e3e946d9fb3f46d8feba7c034d 100644
--- a/paddle/fluid/platform/init.cc
+++ b/paddle/fluid/platform/init.cc
@@ -110,7 +110,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
   }
   places.emplace_back(platform::CPUPlace());
   platform::DeviceContextPool::Init(places);
-
+  platform::DeviceTemporaryAllocator::Init();
 #ifndef PADDLE_WITH_MKLDNN
   platform::SetNumThreads(FLAGS_paddle_num_threads);
 #endif
diff --git a/paddle/fluid/platform/temporary_allocator.cc b/paddle/fluid/platform/temporary_allocator.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0be017f75bcc8aff5073ebb2c5179cf7250be8b9
--- /dev/null
+++ b/paddle/fluid/platform/temporary_allocator.cc
@@ -0,0 +1,95 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/platform/temporary_allocator.h"
+#include "paddle/fluid/memory/allocation/allocator_facade.h"
+
+DEFINE_double(limit_of_temporary_allocation, -1,
+              "The upper limit of temporary_allocation size.");
+
+namespace paddle {
+namespace platform {
+namespace alloc = memory::allocation;
+
+TemporaryAllocation::TemporaryAllocation(
+    alloc::AllocationPtr &&underlying_allocation)
+    : Allocation(underlying_allocation->ptr(), underlying_allocation->size(),
+                 underlying_allocation->place()),
+      underlying_allocation_(std::move(underlying_allocation)) {}
+
+TemporaryAllocator::TemporaryAllocator(platform::Place place) : place_(place) {
+  temp_mem_queue_.reset(new std::deque<TemporaryAllocation *>());
+}
+
+bool TemporaryAllocator::IsAllocThreadSafe() const { return true; }
+
+void TemporaryAllocator::Release(const std::function<void()> &callback) {
+  std::shared_ptr<std::deque<TemporaryAllocation *>> t_allocations;
+  {
+    std::unique_lock<std::mutex> lock(mtx_);
+    callback();
+    t_allocations = temp_mem_queue_;
+    temp_mem_queue_.reset(new std::deque<TemporaryAllocation *>());
+    wait_delete_mem_ = 0;
+  }
+  for (auto tmp : *t_allocations) {
+    VLOG(10) << "Delete temporary allocation " << tmp->ptr()
+             << " size: " << tmp->size();
+    delete tmp;
+  }
+}
+
+void TemporaryAllocator::Free(alloc::Allocation *allocation) {
+  auto *temp_allocation = dynamic_cast<TemporaryAllocation *>(allocation);
+  PADDLE_ENFORCE_NOT_NULL(temp_allocation);
+  if (platform::is_gpu_place(temp_allocation->place())) {
+    size_t wait_delete_mem = 0;
+    {
+      std::unique_lock<std::mutex> lock(mtx_);
+      temp_mem_queue_->emplace_back(temp_allocation);
+      wait_delete_mem_ += temp_allocation->size();
+      wait_delete_mem = wait_delete_mem_;
+      VLOG(10) << "Move temporary allocation: " << temp_allocation->ptr()
+               << " to delete queue: " << temp_allocation->size() << "; "
+               << "wait_delete_mem: " << wait_delete_mem_;
+    }
+    if (FLAGS_limit_of_temporary_allocation > 0 &&
+        wait_delete_mem > FLAGS_limit_of_temporary_allocation) {
+      Release(callback_);
+    }
+    return;
+  }
+  delete temp_allocation;
+}
+
+size_t TemporaryAllocator::TemporaryAllocationQueueSize() {
+  std::unique_lock<std::mutex> lock(mtx_);
+  return temp_mem_queue_ ? temp_mem_queue_->size() : 0;
+}
+
+void TemporaryAllocator::SetCallback(const std::function<void()> &callback) {
+  callback_ = callback;
+}
+
+alloc::Allocation *TemporaryAllocator::AllocateImpl(
+    size_t size, alloc::Allocator::Attr attr) {
+  auto raw_allocation =
+      alloc::AllocatorFacade::Instance().Alloc(place_, size, attr);
+  auto temp_mem = new TemporaryAllocation(std::move(raw_allocation));
+  VLOG(10) << "Alloc temporary allocation: " << temp_mem->ptr() << ": " << size;
+  return temp_mem;
+}
+
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/fluid/platform/temporary_allocator.h b/paddle/fluid/platform/temporary_allocator.h
new file mode 100644
index 0000000000000000000000000000000000000000..4e32d2d6959e69c94e869491ef8d11708870f7df
--- /dev/null
+++ b/paddle/fluid/platform/temporary_allocator.h
@@ -0,0 +1,63 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <condition_variable>  // NOLINT
+#include <deque>
+#include <mutex>  // NOLINT
+#include "paddle/fluid/memory/allocation/allocator.h"
+#include "paddle/fluid/platform/lock_guard_ptr.h"
+namespace paddle {
+namespace platform {
+
+class TemporaryAllocation : public memory::allocation::Allocation {
+ public:
+  explicit TemporaryAllocation(
+      memory::allocation::AllocationPtr &&underlying_allocation);
+
+  memory::allocation::AllocationPtr underlying_allocation_;
+};
+
+class TemporaryAllocator : public memory::allocation::Allocator {
+ public:
+  explicit TemporaryAllocator(platform::Place place);
+
+  void Release(const std::function<void()> &callback);
+
+  size_t TemporaryAllocationQueueSize();
+
+  bool IsAllocThreadSafe() const override;
+
+  void SetCallback(const std::function<void()> &callback);
+
+ protected:
+  void Free(memory::allocation::Allocation *allocation) override;
+
+  memory::allocation::Allocation *AllocateImpl(
+      size_t size, memory::allocation::Allocator::Attr attr) override;
+
+ private:
+  platform::Place place_;
+
+  // When the allocation is not held by any variable, it should be placed
+  // to temp_mem_queue immediately.
+  std::shared_ptr<std::deque<TemporaryAllocation *>> temp_mem_queue_{nullptr};
+
+  std::mutex mtx_;
+  size_t wait_delete_mem_{0};
+  std::function<void()> callback_;
+};
+
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/fluid/platform/temporary_allocator_test.cc b/paddle/fluid/platform/temporary_allocator_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3b940b0e8243c0ae1e0eeb3a2c13f3d16c228925
--- /dev/null
+++ b/paddle/fluid/platform/temporary_allocator_test.cc
@@ -0,0 +1,165 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/platform/temporary_allocator.h"
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/create_tensor_with_allocationptr.h"
+DECLARE_double(limit_of_temporary_allocation);
+
+namespace paddle {
+namespace platform {
+
+TEST(temporary_allocator, temporary_allocator) {
+  platform::CPUPlace cpu_place;
+  TemporaryAllocator alloc(cpu_place);
+  alloc.Allocate(100);
+
+#ifdef PADDLE_WITH_CUDA
+  platform::CUDAPlace gpu_place(0);
+  TemporaryAllocator gpu_alloc(gpu_place);
+
+  auto allocation = gpu_alloc.Allocate(101);
+  PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
+  gpu_alloc.Release([]() {});
+  PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
+
+  {
+    auto allocation = gpu_alloc.Allocate(102);
+    PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
+  }
+  PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1);
+  gpu_alloc.Release([]() {});
+  PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
+#endif
+}
+
+TEST(temporary_allocator, add_callback) {
+#ifdef PADDLE_WITH_CUDA
+  FLAGS_limit_of_temporary_allocation = 10;
+  platform::CUDAPlace gpu_place(0);
+  TemporaryAllocator gpu_alloc(gpu_place);
+
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  auto* dev_ctx =
+      static_cast<platform::CUDADeviceContext*>(pool.Get(gpu_place));
+  auto stream = dev_ctx->stream();
+  bool deleted = false;
+  gpu_alloc.SetCallback([stream, &deleted]() {
+    PADDLE_ENFORCE(cudaStreamSynchronize(stream));
+    PADDLE_ENFORCE(cudaGetLastError());
+    deleted = true;
+  });
+  { gpu_alloc.Allocate(100); }
+  PADDLE_ENFORCE(deleted);
+  FLAGS_limit_of_temporary_allocation = -1;
+#endif
+}
+
+TEST(temporary_allocator, create_tensor_with_allocationptr) {
+  platform::CPUPlace cpu_place;
+  TemporaryAllocator cpu_alloc(cpu_place);
+  {
+    size_t memory_size = 200;
+    auto allocation = cpu_alloc.Allocate(memory_size);
+    void* address = allocation->ptr();
+    int numel = memory_size / sizeof(float);
+    framework::Tensor tensor =
+        GetTensor<float>(std::move(allocation), framework::make_ddim({numel}));
+    PADDLE_ENFORCE_EQ(address, tensor.data<float>());
+    PADDLE_ENFORCE_EQ(tensor.numel(), numel);
+  }
+
+#ifdef PADDLE_WITH_CUDA
+  platform::CUDAPlace gpu_place(0);
+  TemporaryAllocator gpu_alloc(gpu_place);
+
+  {
+    size_t memory_size = 300;
+    auto allocation = gpu_alloc.Allocate(memory_size);
+    void* address = allocation->ptr();
+    int numel = memory_size / sizeof(float);
+    framework::Tensor tensor =
+        GetTensor<float>(std::move(allocation), framework::make_ddim({numel}));
+    PADDLE_ENFORCE_EQ(address, tensor.data<float>());
+    PADDLE_ENFORCE_EQ(tensor.numel(), numel);
+  }
+
+  // The allocation is not held now; it should be placed in the
+  // TemporaryAllocationQueue.
+  PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1);
+  gpu_alloc.Release([]() {});
+  PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
+#endif
+}
+
+TEST(temporary_allocator, create_tensor_with_allocationptr2) {
+  platform::CPUPlace cpu_place;
+  TemporaryAllocator cpu_alloc(cpu_place);
+  {
+    size_t memory_size = 400;
+    int numel = memory_size / sizeof(float);
+
+    framework::Tensor out_side_tensor;
+    void* address;
+    {
+      auto allocation = cpu_alloc.Allocate(memory_size);
+      address = allocation->ptr();
+      framework::Tensor tensor = GetTensor<float>(
+          std::move(allocation), framework::make_ddim({numel}));
+      PADDLE_ENFORCE_EQ(address, tensor.data<float>());
+      PADDLE_ENFORCE_EQ(tensor.numel(), numel);
+
+      out_side_tensor.ShareDataWith(tensor);
+    }
+    PADDLE_ENFORCE_EQ(address, out_side_tensor.data<float>());
+    PADDLE_ENFORCE_EQ(out_side_tensor.numel(), numel);
+  }
+
+#ifdef PADDLE_WITH_CUDA
+  platform::CUDAPlace gpu_place(0);
+  TemporaryAllocator gpu_alloc(gpu_place);
+  {
+    void* address;
+    size_t memory_size = 500;
+    int numel = memory_size / sizeof(float);
+    framework::Tensor out_side_tensor;
+    {
+      auto allocation = gpu_alloc.Allocate(memory_size);
+      address = allocation->ptr();
+      framework::Tensor tensor = GetTensor<float>(
+          std::move(allocation), framework::make_ddim({numel}));
+      PADDLE_ENFORCE_EQ(address, tensor.data<float>());
+      PADDLE_ENFORCE_EQ(tensor.numel(), numel);
+
+      out_side_tensor.ShareDataWith(tensor);
+    }
+    PADDLE_ENFORCE_EQ(address, out_side_tensor.data<float>());
+    PADDLE_ENFORCE_EQ(out_side_tensor.numel(), numel);
+    // The allocation is held by out_side_tensor.
+    PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
+    gpu_alloc.Release([]() {});
+    PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
+  }
+
+  // The allocation is not held now; it should be placed in the
+  // TemporaryAllocationQueue.
+  PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1);
+  gpu_alloc.Release([]() {});
+  PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
+#endif
+}
+
+}  // namespace platform
+}  // namespace paddle
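Usage note (illustrative, not part of the diff): the pattern the operator changes above rely on is "allocate scratch memory from the temporary allocator, then wrap it in a Tensor with GetTensor". A minimal sketch on the CPU place, mirroring temporary_allocator_test.cc and assuming only the headers added in this patch, looks like this; the function name Example and the sizes (256 bytes / 64 floats) are chosen purely for illustration:

// Sketch: obtain scratch memory from TemporaryAllocator and view it as a Tensor.
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/create_tensor_with_allocationptr.h"
#include "paddle/fluid/platform/temporary_allocator.h"

void Example() {
  paddle::platform::CPUPlace cpu_place;
  paddle::platform::TemporaryAllocator alloc(cpu_place);

  // Request 256 bytes of scratch memory; the result is an AllocationPtr.
  auto allocation = alloc.Allocate(256);

  // Wrap it in a Tensor of 64 floats; GetTensor<float> enforces that
  // product(dim) * sizeof(float) equals the allocation size and hands the
  // allocation to the tensor's holder via Tensor::ResetHolder().
  paddle::framework::Tensor t = paddle::platform::GetTensor<float>(
      std::move(allocation), paddle::framework::make_ddim({64}));

  // The tensor now owns the temporary allocation through its holder.
  float* data = t.data<float>();
  data[0] = 1.0f;
}

On a CUDA place the same pattern applies, except that a freed allocation is queued in the TemporaryAllocator and only actually released in CUDADeviceContext::Wait() or once the queued size exceeds FLAGS_limit_of_temporary_allocation, as shown in the tests above.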