diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 5709eb1a7d4d26e6cc358651c1521ebf9a279801..4d29564aeed74558b7f0ec580568f70dad0b40cc 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -377,6 +377,30 @@ class ExecutionContext { return op_.Outputs(name); } + template + Tensor AllocateTmpTensor(const framework::DDim& dim, + const DevContext& dev_ctx) const { + auto tmp_allocation_ptr = platform::DeviceTemporaryAllocator::Instance() + .Get(dev_ctx) + .Allocate(product(dim) * sizeof(T)); + auto& deleter = tmp_allocation_ptr.get_deleter(); + auto* allocation_ptr = tmp_allocation_ptr.release(); + auto shared_allocation = std::shared_ptr( + allocation_ptr, deleter); + + PADDLE_ENFORCE( + dynamic_cast(allocation_ptr) != nullptr, + "The AllocationPtr must be TemporaryAllocation."); + PADDLE_ENFORCE_EQ(allocation_ptr->size(), + framework::product(dim) * sizeof(T)); + + paddle::framework::Tensor temp_tensor( + framework::ToDataType(std::type_index(typeid(T)))); + temp_tensor.Resize(dim); + temp_tensor.ResetHolder(std::move(shared_allocation)); + return temp_tensor; + } + private: const OperatorBase& op_; const Scope& scope_; diff --git a/paddle/fluid/framework/tensor_util.h b/paddle/fluid/framework/tensor_util.h index 871c7bd2a77d1cc5057177619b5cd7b2083ff308..1ffd357e62b4bdc72dbec627c463730aa2c8f720 100644 --- a/paddle/fluid/framework/tensor_util.h +++ b/paddle/fluid/framework/tensor_util.h @@ -151,27 +151,5 @@ void TensorToVector(const Tensor& src, std::vector* dst) { memory::Copy(dst_place, dst_ptr, boost::get(src.place()), src_ptr, size); } - -template -paddle::framework::Tensor GetTensor( - memory::allocation::AllocationPtr temp_allocation_ptr, - const framework::DDim& dim) { - auto& deleter = temp_allocation_ptr.get_deleter(); - auto* allocation_ptr = temp_allocation_ptr.release(); - auto shared_allocation = - std::shared_ptr(allocation_ptr, deleter); - - PADDLE_ENFORCE( - dynamic_cast(allocation_ptr) != nullptr, - "The AllocationPtr must be TemporaryAllocation."); - PADDLE_ENFORCE_EQ(allocation_ptr->size(), - framework::product(dim) * sizeof(T)); - - paddle::framework::Tensor temp_tensor( - framework::ToDataType(std::type_index(typeid(T)))); - temp_tensor.Resize(dim); - temp_tensor.ResetHolder(std::move(shared_allocation)); - return temp_tensor; -} } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h index 2519f5e7acdb7828743c6e114adfe5e530058406..24b8e238799d22584fa68ccd5d1b2305a736c6c3 100644 --- a/paddle/fluid/operators/conv_op.h +++ b/paddle/fluid/operators/conv_op.h @@ -18,7 +18,6 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/depthwise_conv.h" #include "paddle/fluid/operators/math/im2col.h" @@ -158,10 +157,7 @@ class GemmConvKernel : public framework::OpKernel { // to call the matrix multiplication interface. Tensor col_matrix; if (is_expand) { - auto tmp_allocation_ptr = - platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate( - framework::product(col_shape) * sizeof(T)); - col = framework::GetTensor(std::move(tmp_allocation_ptr), col_shape); + col = context.AllocateTmpTensor(col_shape, dev_ctx); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); } @@ -293,10 +289,7 @@ class GemmConvGradKernel : public framework::OpKernel { // to call the matrix multiplication interface. Tensor col_matrix; if (is_expand) { - auto tmp_allocation_ptr = - platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate( - framework::product(col_shape) * sizeof(T)); - col = framework::GetTensor(std::move(tmp_allocation_ptr), col_shape); + col = context.AllocateTmpTensor(col_shape, dev_ctx); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); } diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 05a0f14440732e5aef2ff665fbd3a5c1c7094581..1f51b5bab3068cc89bffa85de28a9438359659f3 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -100,7 +100,7 @@ ENDIF() nv_library(cuda_device_guard SRCS cuda_device_guard.cc DEPS gpu_info) if(WITH_GPU) - nv_test(temporal_allocator_test SRCS temporary_allocator_test.cc DEPS temp_allocator tensor) + nv_test(temporal_allocator_test SRCS temporary_allocator_test.cc DEPS temp_allocator tensor operator) else() - cc_test(temporal_allocator_test SRCS temporary_allocator_test.cc DEPS temp_allocator tensor) + cc_test(temporal_allocator_test SRCS temporary_allocator_test.cc DEPS temp_allocator tensor operator) endif() diff --git a/paddle/fluid/platform/temporary_allocator_test.cc b/paddle/fluid/platform/temporary_allocator_test.cc index e4e5be5b89f4cbecd6b5e9deec9cc5bffa6a4917..35d1d929819c41b213bc51ec24ac725021a76c88 100644 --- a/paddle/fluid/platform/temporary_allocator_test.cc +++ b/paddle/fluid/platform/temporary_allocator_test.cc @@ -14,12 +14,27 @@ #include "paddle/fluid/platform/temporary_allocator.h" #include +#include +#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor_util.h" + DECLARE_double(limit_of_temporary_allocation); namespace paddle { namespace platform { +class DummyOp : public framework::OperatorBase { + public: + DummyOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + protected: + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override {} +}; + TEST(temporary_allocator, temporary_allocator) { platform::CPUPlace cpu_place; TemporaryAllocator alloc(cpu_place); @@ -68,96 +83,92 @@ TEST(temporary_allocator, add_callback) { } TEST(temporary_allocator, create_tensor_with_allocationptr) { - platform::CPUPlace cpu_place; - TemporaryAllocator cpu_alloc(cpu_place); + framework::VariableNameMap dummy_vars; + framework::AttributeMap dummy_attrs; + DummyOp op("dummy", dummy_vars, dummy_vars, dummy_attrs); + framework::Scope scope; + framework::VariableValueMap vars; + framework::RuntimeContext run_ctx(vars, vars); + size_t memory_size = 300; { - size_t memory_size = 200; - auto allocation = cpu_alloc.Allocate(memory_size); - void* address = allocation->ptr(); + platform::CPUPlace cpu_place; + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto* dev_ctx = + static_cast(pool.Get(cpu_place)); + framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx); + int numel = memory_size / sizeof(float); - framework::Tensor tensor = framework::GetTensor( - std::move(allocation), framework::make_ddim({numel})); - PADDLE_ENFORCE_EQ(address, tensor.data()); + framework::Tensor tensor = + ctx.AllocateTmpTensor( + framework::make_ddim({numel}), *dev_ctx); PADDLE_ENFORCE_EQ(tensor.numel(), numel); } #ifdef PADDLE_WITH_CUDA - platform::CUDAPlace gpu_place(0); - TemporaryAllocator gpu_alloc(gpu_place); - { - size_t memory_size = 300; - auto allocation = gpu_alloc.Allocate(memory_size); - void* address = allocation->ptr(); + platform::CUDAPlace gpu_place(0); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto* dev_ctx = + static_cast(pool.Get(gpu_place)); + framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx); int numel = memory_size / sizeof(float); - framework::Tensor tensor = framework::GetTensor( - std::move(allocation), framework::make_ddim({numel})); - PADDLE_ENFORCE_EQ(address, tensor.data()); + framework::Tensor tensor = + ctx.AllocateTmpTensor( + framework::make_ddim({numel}), *dev_ctx); PADDLE_ENFORCE_EQ(tensor.numel(), numel); } - - // The allocation is not holded now, it should be placed to - // TemporaryAllocationQueue. - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1); - gpu_alloc.Release([]() {}); - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0); #endif } TEST(temporary_allocator, create_tensor_with_allocationptr2) { - platform::CPUPlace cpu_place; - TemporaryAllocator cpu_alloc(cpu_place); + framework::VariableNameMap dummy_vars; + framework::AttributeMap dummy_attrs; + DummyOp op("dummy", dummy_vars, dummy_vars, dummy_attrs); + framework::Scope scope; + framework::VariableValueMap vars; + framework::RuntimeContext run_ctx(vars, vars); + size_t memory_size = 400; { - size_t memory_size = 400; + platform::CPUPlace cpu_place; + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto* dev_ctx = + static_cast(pool.Get(cpu_place)); + framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx); int numel = memory_size / sizeof(float); framework::Tensor out_side_tensor; - void* address; { - auto allocation = cpu_alloc.Allocate(memory_size); - address = allocation->ptr(); - framework::Tensor tensor = framework::GetTensor( - std::move(allocation), framework::make_ddim({numel})); - PADDLE_ENFORCE_EQ(address, tensor.data()); + framework::Tensor tensor = + ctx.AllocateTmpTensor( + framework::make_ddim({numel}), *dev_ctx); PADDLE_ENFORCE_EQ(tensor.numel(), numel); out_side_tensor.ShareDataWith(tensor); } - PADDLE_ENFORCE_EQ(address, out_side_tensor.data()); PADDLE_ENFORCE_EQ(out_side_tensor.numel(), numel); } #ifdef PADDLE_WITH_CUDA - platform::CUDAPlace gpu_place(0); - TemporaryAllocator gpu_alloc(gpu_place); { - void* address; + platform::CUDAPlace gpu_place(0); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto* dev_ctx = + static_cast(pool.Get(gpu_place)); + framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx); + size_t memory_size = 500; int numel = memory_size / sizeof(float); framework::Tensor out_side_tensor; { - auto allocation = gpu_alloc.Allocate(memory_size); - address = allocation->ptr(); - framework::Tensor tensor = framework::GetTensor( - std::move(allocation), framework::make_ddim({numel})); - PADDLE_ENFORCE_EQ(address, tensor.data()); + framework::Tensor tensor = + ctx.AllocateTmpTensor( + framework::make_ddim({numel}), *dev_ctx); PADDLE_ENFORCE_EQ(tensor.numel(), numel); out_side_tensor.ShareDataWith(tensor); } - PADDLE_ENFORCE_EQ(address, out_side_tensor.data()); PADDLE_ENFORCE_EQ(out_side_tensor.numel(), numel); - // The allocation is holded by out_side_tensor. - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0); - gpu_alloc.Release([]() {}); - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0); } - - // The allocation is not holded now, it should be placed to - // TemporaryAllocationQueue. - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1); - gpu_alloc.Release([]() {}); - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0); #endif }