diff --git a/cmake/generic.cmake b/cmake/generic.cmake
index d846e08b3c390d674426058f1f98515bc2d9a815..f6749c2ab858d2daee55ede8cddb8a18d522f90e 100644
--- a/cmake/generic.cmake
+++ b/cmake/generic.cmake
@@ -389,7 +389,6 @@ function(cc_test_run TARGET_NAME)
              WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
     set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
     set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
-    set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
     set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
     # No unit test should exceed 10 minutes.
     set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
@@ -472,7 +471,6 @@ function(nv_test TARGET_NAME)
     add_test(${TARGET_NAME} ${TARGET_NAME})
     set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
     set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
-    set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
     set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
   endif()
 endfunction(nv_test)
@@ -725,7 +723,7 @@ function(py_test TARGET_NAME)
     if(WITH_COVERAGE)
       add_test(NAME ${TARGET_NAME}
         COMMAND ${CMAKE_COMMAND} -E env FLAGS_init_allocated_mem=true FLAGS_cudnn_deterministic=true
-          FLAGS_cpu_deterministic=true FLAGS_limit_of_tmp_allocation=4294967296 # 4G
+          FLAGS_cpu_deterministic=true
           PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_ENVS}
           COVERAGE_FILE=${PADDLE_BINARY_DIR}/python-coverage.data
           ${PYTHON_EXECUTABLE} -m coverage run --branch -p ${py_test_SRCS} ${py_test_ARGS}
@@ -733,7 +731,7 @@ function(py_test TARGET_NAME)
     else()
       add_test(NAME ${TARGET_NAME}
         COMMAND ${CMAKE_COMMAND} -E env FLAGS_init_allocated_mem=true FLAGS_cudnn_deterministic=true
-          FLAGS_cpu_deterministic=true FLAGS_limit_of_tmp_allocation=4294967296 # 4G
+          FLAGS_cpu_deterministic=true
           PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_ENVS}
           ${PYTHON_EXECUTABLE} -u ${py_test_SRCS} ${py_test_ARGS}
           WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 4d42220d2cd4dc3488bde95a45a115012ccb6336..275bb07a1f2f6f8e1cddce4b5ef0fcae67ddf801 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -123,8 +123,8 @@ cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_co
 cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
 cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)

-cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope
-  glog shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack)
+cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog data_feed_proto
+  shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack)

 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)

diff --git a/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc b/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc
index 58538f918c61c2dfeaaaef2676f7775606ff278e..070a17a9de591a2a2130338d7f82bc5d534fa066 100644
--- a/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc
@@ -18,6 +18,7 @@
 #include "paddle/fluid/framework/details/reduce_and_gather.h"
 #include "paddle/fluid/framework/details/variable_visitor.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/platform/gpu_info.h"
 #include "paddle/fluid/platform/profiler.h"
@@ -103,16 +104,15 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
   int dev_id = boost::get<platform::CUDAPlace>(place).device;
   auto *nccl_ctxs = nccl_ctxs_->GetRunEnvNCCLCtx(run_order_, false);
   auto &nccl_ctx = nccl_ctxs->at(dev_id);
+  auto *dev_ctx = nccl_ctxs->DevCtx(dev_id);
   auto stream = nccl_ctx.stream();
   auto comm = nccl_ctx.comm_;

-  auto &allocator =
-      platform::DeviceTemporaryAllocator::Instance().Get(place, stream);
   int encode_size = 2 * k * sizeof(int);
   // dgc use ncclAllGather to get all the encoded data
   // so the buffer need nranks.
   int buf_size = nranks_ * encode_size;
-  auto tmp_ious_data = allocator.Allocate(buf_size);
+  auto tmp_ious_data = memory::Alloc(*dev_ctx, buf_size);
   void *gather_buff = reinterpret_cast<void *>(tmp_ious_data->ptr());

   VLOG(10) << "in_numel:" << in_numel << ", out_numel:" << out_numel
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index 03f083cdbd9084005052a0284152eb4f45f16915..359f58328a86c10896f5a852c3683e60841f1eab 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -35,6 +35,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/variant.h"

@@ -360,9 +361,7 @@ class ExecutionContext {
   template <typename T, typename DevContext>
   Tensor AllocateTmpTensor(const framework::DDim& dim,
                            const DevContext& dev_ctx) const {
-    auto tmp_allocation_ptr = platform::DeviceTemporaryAllocator::Instance()
-                                  .Get(dev_ctx)
-                                  .Allocate(product(dim) * sizeof(T));
+    auto tmp_allocation_ptr = memory::Alloc(dev_ctx, product(dim) * sizeof(T));
     auto& deleter = tmp_allocation_ptr.get_deleter();
     auto* allocation_ptr = tmp_allocation_ptr.release();
     auto shared_allocation = std::shared_ptr<memory::Allocation>(
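A minimal usage sketch of the reworked AllocateTmpTensor (not part of the patch): only the helper's signature and the memory::Alloc path above come from this diff, the wrapper function and the 128 x 64 shape are invented for illustration.

    #include "paddle/fluid/framework/operator.h"
    #include "paddle/fluid/platform/device_context.h"

    // Illustrative only: obtain a scratch Tensor whose backing memory comes from
    // memory::Alloc(dev_ctx, ...) and is released through the allocation's own
    // deleter once the returned Tensor is destroyed.
    paddle::framework::Tensor MakeScratch(
        const paddle::framework::ExecutionContext &ctx,
        const paddle::platform::CUDADeviceContext &dev_ctx) {
      return ctx.AllocateTmpTensor<float, paddle::platform::CUDADeviceContext>(
          paddle::framework::make_ddim({128, 64}), dev_ctx);
    }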
*/ #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/temporary_allocator.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/memory/CMakeLists.txt b/paddle/fluid/memory/CMakeLists.txt index 7eb663ea280e65f3c10304aa47c9970df099b901..8cf53b9739992760240284a172f108d6753a5608 100644 --- a/paddle/fluid/memory/CMakeLists.txt +++ b/paddle/fluid/memory/CMakeLists.txt @@ -1,12 +1,29 @@ add_subdirectory(detail) add_subdirectory(allocation) -cc_library(malloc SRCS malloc.cc DEPS place enforce allocator_facade profiler) + +if (WITH_MKLDNN) + set(MKLDNN_CTX_DEPS mkldnn) +else () + set(MKLDNN_CTX_DEPS) +endif() + +cc_library(malloc SRCS malloc.cc DEPS + place enforce allocator_facade profiler ${MKLDNN_CTX_DEPS}) cc_library(memcpy SRCS memcpy.cc DEPS place) cc_library(memory DEPS malloc memcpy) + +if (WITH_GPU) + add_dependencies(malloc cuda_device_context_allocator_pool) + target_link_libraries(malloc cuda_device_context_allocator_pool) + nv_test(malloc_test + SRCS malloc_test.cu + DEPS device_context malloc) +endif() + #if (WITH_GPU) # nv_test(pinned_memory_test SRCS pinned_memory_test.cu DEPS place memory) #endif() diff --git a/paddle/fluid/memory/allocation/CMakeLists.txt b/paddle/fluid/memory/allocation/CMakeLists.txt index 565951ed5e61a5b5cf5aa277c9422646972faece..f00dda0b54843e7e2e50b151e91b8ee0664c3618 100644 --- a/paddle/fluid/memory/allocation/CMakeLists.txt +++ b/paddle/fluid/memory/allocation/CMakeLists.txt @@ -6,8 +6,20 @@ cc_library(best_fit_allocator SRCS best_fit_allocator.cc DEPS allocator) cc_library(naive_best_fit_allocator SRCS naive_best_fit_allocator.cc DEPS allocator buddy_allocator profiler) cc_test(buffered_allocator_test SRCS buffered_allocator_test.cc DEPS locked_allocator buffered_allocator cpu_allocator best_fit_allocator) +if (WITH_MKLDNN) + set(MKLDNN_CTX_DEPS mkldnn) +else () + set(MKLDNN_CTX_DEPS) +endif() + if (WITH_GPU) nv_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard) + nv_library(cuda_device_context_allocation SRCS cuda_device_context_allocation.cc + DEPS allocator enforce place ${MKLDNN_CTX_DEPS}) + nv_library(cuda_device_context_allocator SRCS cuda_device_context_allocator.cc + DEPS allocator enforce place cuda_device_context_allocation ${MKLDNN_CTX_DEPS}) + nv_library(cuda_device_context_allocator_pool SRCS cuda_device_context_allocator_pool.cc + DEPS allocator enforce place cuda_device_context_allocation cuda_device_context_allocator ${MKLDNN_CTX_DEPS}) endif() cc_library(retry_allocator SRCS retry_allocator.cc DEPS allocator) diff --git a/paddle/fluid/memory/allocation/cuda_device_context_allocation.cc b/paddle/fluid/memory/allocation/cuda_device_context_allocation.cc new file mode 100644 index 0000000000000000000000000000000000000000..e361f71f4f75b14fa46e4bd5940ab100e7110cb2 --- /dev/null +++ b/paddle/fluid/memory/allocation/cuda_device_context_allocation.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/memory/allocation/cuda_device_context_allocation.h"
+#include
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace memory {
+namespace allocation {
+
+CUDADeviceContextAllocation::CUDADeviceContextAllocation(
+    AllocationPtr allocation)
+    : Allocation(allocation->ptr(), allocation->size(), allocation->place()),
+      underlying_allocation_(std::move(allocation)) {}
+
+CUDADeviceContextAllocation::~CUDADeviceContextAllocation() {
+  PADDLE_ENFORCE_NOT_NULL(
+      dev_ctx_, "Didn't set device context for CUDADeviceContextAllocation");
+  auto *p_allocation = underlying_allocation_.release();
+  VLOG(4) << "Adding callback to delete CUDADeviceContextAllocation at "
+          << p_allocation;
+  dev_ctx_->AddStreamCallback([p_allocation] {
+    VLOG(4) << "Delete CUDADeviceContextAllocation at " << p_allocation;
+    AllocationDeleter()(p_allocation);
+  });
+}
+
+void CUDADeviceContextAllocation::SetCUDADeviceContext(
+    const platform::CUDADeviceContext *dev_ctx) {
+  dev_ctx_ = dev_ctx;
+}
+
+}  // namespace allocation
+}  // namespace memory
+}  // namespace paddle
diff --git a/paddle/fluid/memory/allocation/cuda_device_context_allocation.h b/paddle/fluid/memory/allocation/cuda_device_context_allocation.h
new file mode 100644
index 0000000000000000000000000000000000000000..cf0d8792d0ab4cb7bd1e23344950d924aae71280
--- /dev/null
+++ b/paddle/fluid/memory/allocation/cuda_device_context_allocation.h
@@ -0,0 +1,36 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/fluid/memory/allocation/allocator.h"
+#include "paddle/fluid/platform/device_context.h"
+
+namespace paddle {
+namespace memory {
+namespace allocation {
+
+class CUDADeviceContextAllocation : public Allocation {
+ public:
+  explicit CUDADeviceContextAllocation(AllocationPtr allocation);
+  ~CUDADeviceContextAllocation();
+  void SetCUDADeviceContext(const platform::CUDADeviceContext *dev_ctx);
+
+ private:
+  AllocationPtr underlying_allocation_;
+  const platform::CUDADeviceContext *dev_ctx_{nullptr};
+};
+
+}  // namespace allocation
+}  // namespace memory
+}  // namespace paddle
diff --git a/paddle/fluid/memory/allocation/cuda_device_context_allocator.cc b/paddle/fluid/memory/allocation/cuda_device_context_allocator.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bc9adc5caa261485fe383dc0ebd33f92beaebdff
--- /dev/null
+++ b/paddle/fluid/memory/allocation/cuda_device_context_allocator.cc
@@ -0,0 +1,66 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
+
+#include "paddle/fluid/memory/allocation/cuda_device_context_allocation.h"
+#include "paddle/fluid/platform/cuda_device_guard.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace memory {
+namespace allocation {
+
+CUDADeviceContextAllocator::CUDADeviceContextAllocator(
+    const platform::CUDAPlace place, cudaStream_t default_stream)
+    : place_(place), default_stream_(default_stream) {
+  platform::CUDADeviceGuard guard(place_.device);
+  PADDLE_ENFORCE_CUDA_SUCCESS(
+      cudaEventCreate(&event_, cudaEventDisableTiming),
+      "Create event failed in CUDADeviceContextAllocator");
+}
+
+CUDADeviceContextAllocator::~CUDADeviceContextAllocator() {
+  if (event_) {
+    platform::CUDADeviceGuard guard(place_.device);
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        cudaEventDestroy(event_),
+        "Destroy event failed in CUDADeviceContextAllocator destructor");
+  }
+}
+
+Allocation *CUDADeviceContextAllocator::AllocateImpl(size_t size) {
+  PADDLE_ENFORCE_NOT_NULL(
+      default_stream_,
+      "Didn't set default stream for CUDADeviceContextAllocator");
+  platform::CUDADeviceGuard guard(place_.device);
+  auto allocation =
+      new CUDADeviceContextAllocation(memory::Alloc(place_, size));
+  // Wait for the event on stream
+  PADDLE_ENFORCE_CUDA_SUCCESS(
+      cudaEventRecord(event_, default_stream_),
+      "Failed to record event in CUDADeviceContextAllocator");
+  PADDLE_ENFORCE_CUDA_SUCCESS(
+      cudaStreamWaitEvent(default_stream_, event_, 0),
+      "Failed to wait event in CUDADeviceContextAllocator");
+  return allocation;
+}
+
+void CUDADeviceContextAllocator::FreeImpl(Allocation *allocation) {
+  delete allocation;
+}
+
+}  // namespace allocation
+}  // namespace memory
+}  // namespace paddle
diff --git a/paddle/fluid/memory/allocation/cuda_device_context_allocator.h b/paddle/fluid/memory/allocation/cuda_device_context_allocator.h
new file mode 100644
index 0000000000000000000000000000000000000000..e27cb72af6e0961fb1aafcf9e7587b81f38c541a
--- /dev/null
+++ b/paddle/fluid/memory/allocation/cuda_device_context_allocator.h
@@ -0,0 +1,45 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+
+#include "paddle/fluid/memory/allocation/allocator.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace memory {
+namespace allocation {
+
+class CUDADeviceContextAllocator : public Allocator {
+ public:
+  explicit CUDADeviceContextAllocator(platform::CUDAPlace place,
+                                      cudaStream_t default_stream);
+  ~CUDADeviceContextAllocator();
+
+ protected:
+  Allocation *AllocateImpl(size_t size) override;
+  void FreeImpl(Allocation *allocation) override;
+
+ private:
+  platform::CUDAPlace place_;
+  cudaEvent_t event_{nullptr};
+  cudaStream_t default_stream_{nullptr};
+};
+
+}  // namespace allocation
+}  // namespace memory
+}  // namespace paddle
diff --git a/paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.cc b/paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e0b6825944cac826e9e0571c633e8a98250c7570
--- /dev/null
+++ b/paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.cc
@@ -0,0 +1,59 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.h"
+
+#include
+#include
+#include "paddle/fluid/memory/allocation/cuda_device_context_allocation.h"
+#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace memory {
+namespace allocation {
+
+CUDADeviceContextAllocatorPool &CUDADeviceContextAllocatorPool::Instance() {
+  static CUDADeviceContextAllocatorPool pool;
+  return pool;
+}
+
+AllocationPtr CUDADeviceContextAllocatorPool::Alloc(
+    const platform::CUDADeviceContext &dev_ctx, size_t size) {
+  auto iter =
+      allocators_.find(boost::get<platform::CUDAPlace>(dev_ctx.GetPlace()));
+  PADDLE_ENFORCE_EQ(iter != allocators_.end(), true,
+                    "CUDADeviceContextAllocatorPool initialization error");
+  auto &allocator = iter->second;
+  AllocationPtr allocation = allocator->Allocate(size);
+  static_cast<CUDADeviceContextAllocation *>(allocation.get())
+      ->SetCUDADeviceContext(&dev_ctx);
+  return allocation;
+}
+
+CUDADeviceContextAllocatorPool::CUDADeviceContextAllocatorPool() {
+  std::vector<int> devices = platform::GetSelectedDevices();
+  for (int i : devices) {
+    auto place = platform::CUDAPlace(i);
+    auto compute_stream =
+        platform::DeviceContextPool::Instance().GetByPlace(place)->stream();
+    auto allocator = std::shared_ptr<CUDADeviceContextAllocator>(
+        new CUDADeviceContextAllocator(place, compute_stream));
+    allocators_.insert(make_pair(place, allocator));
+  }
+}
+
+}  // namespace allocation
+}  // namespace memory
+}  // namespace paddle
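CUDADeviceContextAllocation above defers the real free to a callback registered on the owning stream, so memory handed to another stream is only recycled after that stream has drained. In plain CUDA the same deferred-release idea looks roughly like the sketch below; it is illustrative only (a stream callback must not call CUDA APIs itself, which is why this callback only queues the pointer and the host frees it later, and why Paddle routes AddStreamCallback through a worker thread instead):

    // Illustrative sketch, not Paddle code: defer releasing `ptr` until all work
    // currently queued on `stream` has finished.
    #include <cuda_runtime.h>
    #include <mutex>
    #include <vector>

    static std::mutex g_retired_mutex;
    static std::vector<void *> g_retired;  // pointers whose streams have drained

    static void CUDART_CB OnStreamDrained(cudaStream_t, cudaError_t, void *ptr) {
      std::lock_guard<std::mutex> lock(g_retired_mutex);  // no CUDA calls in here
      g_retired.push_back(ptr);
    }

    void DeferredRelease(cudaStream_t stream, void *ptr) {
      cudaStreamAddCallback(stream, OnStreamDrained, ptr, 0);
    }

    void DrainRetired() {  // call from the host at a convenient point
      std::lock_guard<std::mutex> lock(g_retired_mutex);
      for (void *p : g_retired) cudaFree(p);
      g_retired.clear();
    }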
diff --git a/paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.h b/paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.h
new file mode 100644
index 0000000000000000000000000000000000000000..03b7c34f71e8ad5141fec6c8d50c2f4dbd781654
--- /dev/null
+++ b/paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.h
@@ -0,0 +1,41 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include
+#include
+#include "paddle/fluid/memory/allocation/allocator.h"
+#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace memory {
+namespace allocation {
+
+class CUDADeviceContextAllocatorPool {
+ public:
+  static CUDADeviceContextAllocatorPool &Instance();
+
+  AllocationPtr Alloc(const platform::CUDADeviceContext &dev_ctx, size_t size);
+
+ private:
+  CUDADeviceContextAllocatorPool();
+  std::map<platform::CUDAPlace, std::shared_ptr<CUDADeviceContextAllocator>>
+      allocators_;
+};
+
+}  // namespace allocation
+}  // namespace memory
+}  // namespace paddle
diff --git a/paddle/fluid/memory/malloc.cc b/paddle/fluid/memory/malloc.cc
index 5884433aaff115c053b10848b32f8610fcb69747..f1a75f2add384910b706ac69c2d001ad1e659359 100644
--- a/paddle/fluid/memory/malloc.cc
+++ b/paddle/fluid/memory/malloc.cc
@@ -17,17 +17,44 @@ limitations under the License. */
 #include
 #include "paddle/fluid/memory/allocation/allocator_facade.h"
 #include "paddle/fluid/memory/allocation/allocator_strategy.h"
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.h"
+#endif
+#include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/place.h"
+
 namespace paddle {
 namespace memory {
-std::shared_ptr<Allocation> AllocShared(const platform::Place& place,
+
+std::shared_ptr<Allocation> AllocShared(const platform::Place &place,
                                         size_t size) {
   return allocation::AllocatorFacade::Instance().AllocShared(place, size);
 }

-AllocationPtr Alloc(const platform::Place& place, size_t size) {
+AllocationPtr Alloc(const platform::Place &place, size_t size) {
   return allocation::AllocatorFacade::Instance().Alloc(place, size);
 }

+AllocationPtr Alloc(const platform::DeviceContext &dev_ctx, size_t size) {
+  auto place = dev_ctx.GetPlace();
+#ifdef PADDLE_WITH_CUDA
+  if (size == 0 || !platform::is_gpu_place(place)) {
+    return Alloc(place, size);
+  }
+  auto *default_dev_ctx = static_cast<platform::CUDADeviceContext *>(
+      platform::DeviceContextPool::Instance().Get(place));
+  auto &desired_dev_ctx =
+      static_cast<const platform::CUDADeviceContext &>(dev_ctx);
+  if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
+    return Alloc(place, size);
+  } else {
+    return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
+        desired_dev_ctx, size);
+  }
+#else
+  return Alloc(place, size);
+#endif
+}
+
 }  // namespace memory
 }  // namespace paddle
diff --git a/paddle/fluid/memory/malloc.h b/paddle/fluid/memory/malloc.h
index 6731203fccb67fc5ded018bbe2ca51878da1a4c3..9ba572acaca9eba2b913847c52e5a54e19d79bdf 100644
--- a/paddle/fluid/memory/malloc.h
+++ b/paddle/fluid/memory/malloc.h
@@ -18,7 +18,13 @@ limitations under the License. */
 #include "paddle/fluid/memory/allocation/allocator.h"
 #include "paddle/fluid/platform/place.h"
 namespace paddle {
+
+namespace platform {
+class DeviceContext;
+}  // platform
+
 namespace memory {
+
 using allocation::Allocation;
 using allocation::Allocator;
 using allocation::AllocationPtr;
@@ -28,5 +34,7 @@ extern std::shared_ptr<Allocation> AllocShared(const platform::Place& place,

 extern AllocationPtr Alloc(const platform::Place& place, size_t size);

+extern AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size);
+
 }  // namespace memory
 }  // namespace paddle
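For reference, a minimal sketch of calling the overload added above directly (the function name and buffer size are made up; only memory::Alloc, AllocationPtr, and the stream-aware dispatch come from the patch):

    #include "paddle/fluid/memory/malloc.h"
    #include "paddle/fluid/platform/device_context.h"

    // Illustrative only: grab a temporary device buffer tied to a specific
    // CUDADeviceContext. The AllocationPtr frees itself when it goes out of
    // scope; for a non-default stream the allocation is routed through
    // CUDADeviceContextAllocatorPool so the release is deferred safely.
    void TempBufferExample(const paddle::platform::CUDADeviceContext &dev_ctx,
                           size_t n) {
      paddle::memory::AllocationPtr buf =
          paddle::memory::Alloc(dev_ctx, n * sizeof(float));
      float *data = reinterpret_cast<float *>(buf->ptr());
      // ... launch kernels on dev_ctx.stream() that read/write `data` ...
    }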
*/ #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/platform/place.h" namespace paddle { + +namespace platform { +class DeviceContext; +} // platform + namespace memory { + using allocation::Allocation; using allocation::Allocator; using allocation::AllocationPtr; @@ -28,5 +34,7 @@ extern std::shared_ptr AllocShared(const platform::Place& place, extern AllocationPtr Alloc(const platform::Place& place, size_t size); +extern AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size); + } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/malloc_test.cu b/paddle/fluid/memory/malloc_test.cu new file mode 100644 index 0000000000000000000000000000000000000000..89853e159bde378ff1084ff656718c5f4316f051 --- /dev/null +++ b/paddle/fluid/memory/malloc_test.cu @@ -0,0 +1,137 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include // NOLINT +#include + +#include "gtest/gtest.h" +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace memory { + +const int NUM_STREAMS = 8; +const int N = 2; +const float DELTA = 1e-1; + +using CudaDevCtxVec = std::vector>; + +__global__ void kernel(float *x, int n) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + for (int i = tid; i < n; i += blockDim.x * gridDim.x) { + x[i] = 3.14159 * i; + } +} + +void CheckKernelOutput(float *x, int n) { + auto host_x = std::unique_ptr(new float[n]); + for (int i = 0; i < n; ++i) { + EXPECT_TRUE(cudaSuccess == cudaMemcpy(host_x.get(), x, n * sizeof(float), + cudaMemcpyDeviceToHost)); + EXPECT_GE(host_x[i] + DELTA, 3.14159f * i); + EXPECT_LE(host_x[i] - DELTA, 3.14159f * i); + } +} + +void MultiStreamCompute(float **data, float **second_data, + const platform::CUDADeviceContext &ctx) { + // multi-streams + AllocationPtr allocation_ptr = Alloc(ctx, N * sizeof(float)); + EXPECT_GE(allocation_ptr->size(), N * sizeof(float)); + *data = reinterpret_cast(allocation_ptr->ptr()); + kernel<<<1, 64, 0, ctx.stream()>>>(*data, N); + + // allocate and compute on same stream again + allocation_ptr = Alloc(ctx, N * sizeof(float)); + EXPECT_GE(allocation_ptr->size(), N * sizeof(float)); + *second_data = reinterpret_cast(allocation_ptr->ptr()); + kernel<<<1, 64, 0, ctx.stream()>>>(*second_data, N); +} + +TEST(Malloc, CUDADeviceContextMultiStream) { + auto place = platform::CUDAPlace(0); + EXPECT_TRUE(cudaSuccess == cudaSetDevice(0)); + + AllocationPtr main_stream_alloc_ptr = Alloc(place, N * sizeof(float)); + EXPECT_GE(main_stream_alloc_ptr->size(), N * sizeof(float)); + float *main_stream_data = + reinterpret_cast(main_stream_alloc_ptr->ptr()); + + float *data[NUM_STREAMS]; + float *second_data[NUM_STREAMS]; + CudaDevCtxVec dev_ctx; + + // default stream + kernel<<<1, 64>>>(main_stream_data, N); + main_stream_alloc_ptr.reset(); + + for (int i = 0; i < NUM_STREAMS; ++i) { + dev_ctx.push_back(std::unique_ptr( + new 
+        new platform::CUDADeviceContext(place)));
+    MultiStreamCompute(&data[i], &second_data[i], *dev_ctx[i]);
+  }
+
+  EXPECT_TRUE(cudaSuccess == cudaDeviceSynchronize());
+  for (int i = 0; i < NUM_STREAMS; ++i) {
+    CheckKernelOutput(data[i], N);
+    CheckKernelOutput(second_data[i], N);
+  }
+}
+
+TEST(Malloc, CUDADeviceContextMultiThreadMultiStream) {
+  auto place = platform::CUDAPlace(0);
+  EXPECT_TRUE(cudaSuccess == cudaSetDevice(0));
+
+  AllocationPtr main_stream_alloc_ptr = Alloc(place, N * sizeof(float));
+  EXPECT_GE(main_stream_alloc_ptr->size(), N * sizeof(float));
+  float *main_stream_data =
+      reinterpret_cast<float *>(main_stream_alloc_ptr->ptr());
+
+  float *data[NUM_STREAMS];
+  float *second_data[NUM_STREAMS];
+  CudaDevCtxVec dev_ctx;
+  std::vector<std::thread> threads;
+
+  // default stream
+  kernel<<<1, 64>>>(main_stream_data, N);
+  main_stream_alloc_ptr.reset();
+
+  for (int i = 0; i < NUM_STREAMS; ++i) {
+    dev_ctx.push_back(std::unique_ptr<platform::CUDADeviceContext>(
+        new platform::CUDADeviceContext(place)));
+    threads.push_back(std::thread(MultiStreamCompute, &data[i], &second_data[i],
+                                  std::cref(*dev_ctx[i])));
+  }
+
+  for (int i = 0; i < NUM_STREAMS; ++i) {
+    threads[i].join();
+  }
+
+  EXPECT_TRUE(cudaSuccess == cudaDeviceSynchronize());
+  for (int i = 0; i < NUM_STREAMS; ++i) {
+    CheckKernelOutput(data[i], N);
+    CheckKernelOutput(second_data[i], N);
+  }
+}
+
+TEST(Malloc, AllocZero) {
+  auto place = platform::CUDAPlace(0);
+  AllocationPtr allocation_ptr = Alloc(place, 0);
+  EXPECT_GE(allocation_ptr->size(), 0);
+}
+}  // namespace memory
+}  // namespace paddle
diff --git a/paddle/fluid/operators/deformable_psroi_pooling_op.cu b/paddle/fluid/operators/deformable_psroi_pooling_op.cu
index c38e955385c79de82b28884bee81bfb57b993af6..4bf0416725b7f210345e7e09fb1951697d8575f7 100644
--- a/paddle/fluid/operators/deformable_psroi_pooling_op.cu
+++ b/paddle/fluid/operators/deformable_psroi_pooling_op.cu
@@ -28,6 +28,7 @@
 #include
 #include
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/deformable_psroi_pooling_op.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/math_function.h"
@@ -231,10 +232,8 @@ class DeformablePSROIPoolCUDAKernel : public framework::OpKernel {
     }
     auto& dev_ctx = ctx.cuda_device_context();
-    auto& allocator =
-        platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
     int bytes = roi_batch_id_list.numel() * sizeof(int);
-    auto roi_ptr = allocator.Allocate(bytes);
+    auto roi_ptr = memory::Alloc(dev_ctx, bytes);
     int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
     const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
     memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
@@ -499,10 +498,8 @@ class DeformablePSROIPoolGradCUDAKernel : public framework::OpKernel {
       }
     }

-    auto& allocator =
-        platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
     int bytes = roi_batch_id_list.numel() * sizeof(int);
-    auto roi_ptr = allocator.Allocate(bytes);
+    auto roi_ptr = memory::Alloc(dev_ctx, bytes);
     int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
     const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
     memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
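The operator conversions in the remaining files all follow the same pattern as the two kernels above: take the temporary from memory::Alloc(dev_ctx, bytes) and stage host data into it with memory::Copy on the context's stream. A condensed, hypothetical version of that pattern (variable and function names are illustrative, not taken from any one operator):

    #include <vector>
    #include "paddle/fluid/memory/malloc.h"
    #include "paddle/fluid/memory/memcpy.h"
    #include "paddle/fluid/platform/device_context.h"

    // Illustrative only: copy a host int list into a device scratch buffer
    // allocated against the kernel's device context.
    void StageIntList(const paddle::platform::CUDADeviceContext &dev_ctx,
                      const std::vector<int> &host_vals) {
      size_t bytes = host_vals.size() * sizeof(int);
      auto dev_buf = paddle::memory::Alloc(dev_ctx, bytes);
      int *dev_ptr = reinterpret_cast<int *>(dev_buf->ptr());
      auto gplace =
          boost::get<paddle::platform::CUDAPlace>(dev_ctx.GetPlace());
      paddle::memory::Copy(gplace, dev_ptr, paddle::platform::CPUPlace(),
                           host_vals.data(), bytes, dev_ctx.stream());
      // ... launch kernels on dev_ctx.stream() that consume dev_ptr ...
    }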
diff --git a/paddle/fluid/operators/detection/box_coder_op.cu b/paddle/fluid/operators/detection/box_coder_op.cu
index 19a5bb90fa828899ad6270c051090dd3662aeed8..b3dd142de77e2f8087ee4493378978f30b00fc58 100644
--- a/paddle/fluid/operators/detection/box_coder_op.cu
+++ b/paddle/fluid/operators/detection/box_coder_op.cu
@@ -11,7 +11,7 @@ limitations under the License. */
 #include
 #include
-#include "paddle/fluid/memory/memcpy.h"
+#include "paddle/fluid/memory/memory.h"
 #include "paddle/fluid/operators/detection/box_coder_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"

@@ -174,10 +174,8 @@ class BoxCoderCUDAKernel : public framework::OpKernel {
     int grid = (row * col + block - 1) / block;
     auto& device_ctx = context.cuda_device_context();
-    auto& allocator =
-        platform::DeviceTemporaryAllocator::Instance().Get(device_ctx);
     int bytes = var_size * sizeof(float);
-    auto dev_var = allocator.Allocate(bytes);
+    auto dev_var = memory::Alloc(device_ctx, bytes);
     float* dev_var_data = reinterpret_cast<float*>(dev_var->ptr());
     auto cplace = platform::CPUPlace();
     const auto gplace = boost::get<platform::CUDAPlace>(context.GetPlace());
diff --git a/paddle/fluid/operators/detection/yolo_box_op.cu b/paddle/fluid/operators/detection/yolo_box_op.cu
index 5a882958e66a79507e053a96b15be8cbbcc83164..08ea62bc14e47f0ecad9a51215ae8a42590d0109 100644
--- a/paddle/fluid/operators/detection/yolo_box_op.cu
+++ b/paddle/fluid/operators/detection/yolo_box_op.cu
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/detection/yolo_box_op.h"
 #include "paddle/fluid/operators/math/math_function.h"

@@ -84,10 +85,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel {
     int input_size = downsample_ratio * h;

     auto& dev_ctx = ctx.cuda_device_context();
-    auto& allocator =
-        platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
     int bytes = sizeof(int) * anchors.size();
-    auto anchors_ptr = allocator.Allocate(sizeof(int) * anchors.size());
+    auto anchors_ptr = memory::Alloc(dev_ctx, sizeof(int) * anchors.size());
     int* anchors_data = reinterpret_cast<int*>(anchors_ptr->ptr());
     const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
     const auto cplace = platform::CPUPlace();
diff --git a/paddle/fluid/operators/dgc_op.h b/paddle/fluid/operators/dgc_op.h
index a1dcc2bcc13d233441bdd3e19d373b35b2e40a86..1285daae094ab28cd4ec059094d4baf603870d7d 100644
--- a/paddle/fluid/operators/dgc_op.h
+++ b/paddle/fluid/operators/dgc_op.h
@@ -16,6 +16,7 @@ limitations under the License. */
 #include
 #include "dgc/dgc.h"
 #include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"

 namespace paddle {
@@ -112,9 +113,7 @@ class DGCOpKernel : public framework::OpKernel {
                     framework::DDim{2 * k}, ctx.GetPlace());

     int buf_size = paddle::communication::dgc::get_buffer_size(k);
-    auto& allocator = platform::DeviceTemporaryAllocator::Instance().Get(
-        ctx.GetPlace(), dev_ctx.stream());
-    auto tmp_ious_data = allocator.Allocate(buf_size);
+    auto tmp_ious_data = memory::Alloc(dev_ctx, buf_size);
     void* buf = reinterpret_cast<void*>(tmp_ious_data->ptr());

     if (!paddle::communication::dgc::k_select(
diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h
index 422d99dd433055bdc91c4a25e5eab36259011df8..285947567e3603079737f1073ce5120f908238dd 100644
--- a/paddle/fluid/operators/fake_quantize_op.h
+++ b/paddle/fluid/operators/fake_quantize_op.h
@@ -18,6 +18,7 @@ limitations under the License. */
*/ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/operators/math/blas.h" namespace paddle { @@ -184,9 +185,7 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel { // training auto* in_accum = context.Input("InAccum"); auto* in_state = context.Input("InState"); - auto& allocator = - platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx); - auto cur_scale = allocator.Allocate(1 * sizeof(T)); + auto cur_scale = memory::Alloc(dev_ctx, sizeof(T)); T* cur_scale_data = static_cast(cur_scale->ptr()); FindAbsMaxFunctor()(dev_ctx, in->data(), in->numel(), @@ -251,9 +250,7 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel { // training auto* in_accum = context.Input("InAccum"); auto* in_state = context.Input("InState"); - auto& allocator = - platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx); - auto cur_scale = allocator.Allocate(1 * sizeof(T)); + auto cur_scale = memory::Alloc(dev_ctx, sizeof(T)); T* cur_scale_data = static_cast(cur_scale->ptr()); FindAbsMaxFunctor()(dev_ctx, in->data(), in->numel(), diff --git a/paddle/fluid/operators/gather.cu.h b/paddle/fluid/operators/gather.cu.h index d0ab24a39e9e99c378caf60bc3f8474982538303..b3264ec0ad3fa984726244d911dab6f7bd8e95b8 100644 --- a/paddle/fluid/operators/gather.cu.h +++ b/paddle/fluid/operators/gather.cu.h @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/framework/dim.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/place.h" @@ -142,9 +143,8 @@ void GPUGatherNd(const framework::ExecutionContext& context, } auto& dev_ctx = context.cuda_device_context(); - auto& allocator = platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx); int bytes = input_dims_size * sizeof(int); - auto p_input_dims = allocator.Allocate(bytes); + auto p_input_dims = memory::Alloc(dev_ctx, bytes); int* g_input_dims = reinterpret_cast(p_input_dims->ptr()); memory::Copy(gplace, g_input_dims, cplace, v_input_dims.data(), bytes, ctx.stream()); diff --git a/paddle/fluid/operators/math/concat_and_split.cu b/paddle/fluid/operators/math/concat_and_split.cu index 153e6117227bf9fd273f83f8e64f859a54380053..5a7cd602c857b1345f4f48f3e799403130782a48 100644 --- a/paddle/fluid/operators/math/concat_and_split.cu +++ b/paddle/fluid/operators/math/concat_and_split.cu @@ -15,6 +15,7 @@ limitations under the License. 
 #include
 #include
 #include "paddle/fluid/framework/mixed_vector.h"
+#include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/math/concat_and_split.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 #include "paddle/fluid/platform/float16.h"
@@ -264,8 +265,7 @@ class ConcatFunctor {
     const T** dev_ins_data = nullptr;
     if (!has_same_shape || in_num < 2 || in_num > 4) {
       tmp_dev_ins_data =
-          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
-              inputs_data.size() * sizeof(T*));
+          memory::Alloc(context, inputs_data.size() * sizeof(T*));
       memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                    tmp_dev_ins_data->ptr(), platform::CPUPlace(),
                    static_cast(inputs_data.data()),
@@ -292,8 +292,7 @@ class ConcatFunctor {
       }
     } else {
       auto tmp_dev_ins_col_data =
-          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
-              inputs_col.size() * sizeof(int));
+          memory::Alloc(context, inputs_col.size() * sizeof(int));
       memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                    tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
                    static_cast(inputs_col.data()),
@@ -356,8 +355,7 @@ class SplitFunctor {
     T** dev_out_gpu_data = nullptr;
     if (!has_same_shape || o_num < 2 || o_num > 4) {
       tmp_dev_outs_data =
-          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
-              outputs_data.size() * sizeof(T*));
+          memory::Alloc(context, outputs_data.size() * sizeof(T*));
       memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                    tmp_dev_outs_data->ptr(), platform::CPUPlace(),
                    reinterpret_cast(outputs_data.data()),
@@ -384,8 +382,9 @@ class SplitFunctor {
       }
     } else {
       auto tmp_dev_ins_col_data =
-          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
-              outputs_cols.size() * sizeof(int));
+          memory::Alloc(context,
+
+                        outputs_cols.size() * sizeof(int));
       memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                    tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
                    reinterpret_cast(outputs_cols.data()),
diff --git a/paddle/fluid/operators/mean_iou_op.cu b/paddle/fluid/operators/mean_iou_op.cu
index 08088eb8733f28f0dc8ecade2aa4b70342244b0a..ada1892f43dcf33cf4db64215732189947f03579 100644
--- a/paddle/fluid/operators/mean_iou_op.cu
+++ b/paddle/fluid/operators/mean_iou_op.cu
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/mean_iou_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
@@ -116,9 +117,7 @@ class MeanIoUCUDAOpKernel : public framework::OpKernel {
     auto out_correct_t = EigenTensor::From(*out_correct);

     // Temporary memory
-    auto& allocator =
-        platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
-    auto tmp_ious_data = allocator.Allocate(num_classes * sizeof(float));
+    auto tmp_ious_data = memory::Alloc(dev_ctx, num_classes * sizeof(float));
     float* ious_data = static_cast<float*>(tmp_ious_data->ptr());

     // Init out_wrong, out_correct and out_mean_iou
diff --git a/paddle/fluid/operators/roi_align_op.cu b/paddle/fluid/operators/roi_align_op.cu
index 8d695fdedd04055215864ca4f0a7059ed7a5d6b0..943c5c81dc47a99f6e2489757b1b15a6ae41bde8 100644
--- a/paddle/fluid/operators/roi_align_op.cu
+++ b/paddle/fluid/operators/roi_align_op.cu
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/memory/memcpy.h"
+#include "paddle/fluid/memory/memory.h"
 #include "paddle/fluid/operators/roi_align_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"

@@ -272,10 +272,8 @@ class GPUROIAlignOpKernel : public framework::OpKernel {
       }
     }
     auto& dev_ctx = ctx.cuda_device_context();
-    auto& allocator =
-        platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
     int bytes = roi_batch_id_list.numel() * sizeof(int);
-    auto roi_ptr = allocator.Allocate(bytes);
+    auto roi_ptr = memory::Alloc(dev_ctx, bytes);
     int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
     const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
     memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
@@ -322,9 +320,8 @@ class GPUROIAlignGradOpKernel : public framework::OpKernel {
       }
     }
     auto& dev_ctx = ctx.cuda_device_context();
-    auto& allocator =
-        platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
-    auto roi_ptr = allocator.Allocate(roi_batch_id_list.numel() * sizeof(int));
+    auto roi_ptr =
+        memory::Alloc(dev_ctx, roi_batch_id_list.numel() * sizeof(int));
     int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
     int bytes = roi_batch_id_list.numel() * sizeof(int);
     const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
diff --git a/paddle/fluid/operators/roi_pool_op.cu b/paddle/fluid/operators/roi_pool_op.cu
index ac3a4201e65256ae16c3376b385dd6000da60fe6..da8088d2ea70f589b6a5b8a443f16429cd0d1034 100644
--- a/paddle/fluid/operators/roi_pool_op.cu
+++ b/paddle/fluid/operators/roi_pool_op.cu
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/memory/memcpy.h"
+#include "paddle/fluid/memory/memory.h"
 #include "paddle/fluid/operators/roi_pool_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"

@@ -170,10 +170,8 @@ class GPUROIPoolOpKernel : public framework::OpKernel {
     }
     auto& dev_ctx = ctx.cuda_device_context();
-    auto& allocator =
-        platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
     int bytes = roi_batch_id_list.numel() * sizeof(int);
-    auto roi_ptr = allocator.Allocate(bytes);
+    auto roi_ptr = memory::Alloc(dev_ctx, bytes);
     int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
     const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
     memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
@@ -221,10 +219,8 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel {
     }
     auto& dev_ctx = ctx.cuda_device_context();
-    auto& allocator =
-        platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
     int bytes = roi_batch_id_list.numel() * sizeof(int);
-    auto roi_ptr = allocator.Allocate(bytes);
+    auto roi_ptr = memory::Alloc(dev_ctx, bytes);
     int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
     const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
     memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
diff --git a/paddle/fluid/operators/scatter.cu.h b/paddle/fluid/operators/scatter.cu.h
index 8d28173c8edbb44ae8eca0f0ec269a8ee9ae123d..0e83219ded28d561b2bf7ef03154632503b75ea4 100644
--- a/paddle/fluid/operators/scatter.cu.h
+++ b/paddle/fluid/operators/scatter.cu.h
@@ -17,6 +17,7 @@ limitations under the License. */
 #include
 #include "math/math_function.h"
 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 #include "paddle/fluid/platform/place.h"

@@ -170,9 +171,8 @@ void GPUScatterNdAdd(const framework::ExecutionContext& context,
     v_output_dims[i] = static_cast<int>(output_dims[i]);
   }

   auto& dev_ctx = context.cuda_device_context();
-  auto& allocator = platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
   int bytes = output_dims_size * sizeof(int);
-  auto output_dims_ptr = allocator.Allocate(bytes);
+  auto output_dims_ptr = memory::Alloc(dev_ctx, bytes);
   int* g_output_dims = reinterpret_cast<int*>(output_dims_ptr->ptr());
   memory::Copy(gplace, g_output_dims, cplace, v_output_dims.data(), bytes,
                ctx.stream());
diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu
index aea69de6434a38aa834ff14f6d3d15ad5bbfc3e6..7c3a0ecba02a5d16dcb45025284680ba933ce9d5 100644
--- a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu
+++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "cub/cub.cuh"
+#include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/math.h"
 #include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
@@ -116,9 +117,7 @@ class GPUSigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel {
     bool normalize = context.Attr("normalize");

     // Temporary memory
-    auto &allocator =
-        platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
-    auto cnt_ptr = allocator.Allocate(Labels->numel() * sizeof(T));
+    auto cnt_ptr = memory::Alloc(dev_ctx, Labels->numel() * sizeof(T));
     T *counts = reinterpret_cast<T *>(cnt_ptr->ptr());

     int limit = Out->numel();
@@ -127,7 +126,7 @@ class GPUSigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel {
     GPUSigmoidForward<<>>(
         X->data(), Labels->data(), ignore_index, limit, out_data, counts);
     if (normalize) {
-      auto norm_ptr = allocator.Allocate(sizeof(T));
+      auto norm_ptr = memory::Alloc(dev_ctx, sizeof(T));
       T *norm = reinterpret_cast<T *>(norm_ptr->ptr());
       Sum<<<1, kNumCUDAThreads, 0, dev_ctx.stream()>>>(
           counts, limit, static_cast<T>(1e-5), norm);
@@ -152,9 +151,7 @@ class GPUSigmoidCrossEntropyWithLogitsGradKernel
     auto &dev_ctx = context.cuda_device_context();

     // Temporary memory
-    auto &allocator =
-        platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
-    auto cnt_ptr = allocator.Allocate(X->numel() * sizeof(T));
+    auto cnt_ptr = memory::Alloc(dev_ctx, X->numel() * sizeof(T));
     T *counts = reinterpret_cast<T *>(cnt_ptr->ptr());

     int limit = dX->numel();
@@ -165,7 +162,7 @@ class GPUSigmoidCrossEntropyWithLogitsGradKernel
         dx_data, counts);
     bool normalize = context.Attr("normalize");
     if (normalize) {
-      auto norm_ptr = allocator.Allocate(sizeof(T));
+      auto norm_ptr = memory::Alloc(dev_ctx, sizeof(T));
       T *norm = reinterpret_cast<T *>(norm_ptr->ptr());
       Sum<<<1, kNumCUDAThreads, 0, dev_ctx.stream()>>>(
           counts, limit, static_cast<T>(1e-5), norm);
diff --git a/paddle/fluid/operators/sum_op.cu b/paddle/fluid/operators/sum_op.cu
index e3f31c0ae8ecd07b2f06ea2bfa13b32e4a8bdb37..3564ed0c4f0faf45461374ba1faa68c1c7992cb6 100644
--- a/paddle/fluid/operators/sum_op.cu
+++ b/paddle/fluid/operators/sum_op.cu
@@ -11,6 +11,7 @@
 limitations under the License. */
 #include
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/sum_op.h"
 #include "paddle/fluid/platform/float16.h"

@@ -197,8 +198,7 @@ void SumToLoDTensor(const framework::ExecutionContext &context) {
   }
   if (!sr_in_out_data.empty()) {
     auto tmp_sr_in_out_array =
-        platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
-            sr_in_out_data.size() * sizeof(T *));
+        memory::Alloc(dev_ctx, sr_in_out_data.size() * sizeof(T *));

     memory::Copy(boost::get<platform::CUDAPlace>(dev_ctx.GetPlace()),
                  tmp_sr_in_out_array->ptr(), platform::CPUPlace(),
@@ -216,9 +216,7 @@ void SumToLoDTensor(const framework::ExecutionContext &context) {
   }
   // if indata not null, merge into one kernel call.
   if (!in_data.empty()) {
-    auto tmp_in_array =
-        platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
-            in_data.size() * sizeof(T *));
+    auto tmp_in_array = memory::Alloc(dev_ctx, in_data.size() * sizeof(T *));

     memory::Copy(boost::get<platform::CUDAPlace>(dev_ctx.GetPlace()),
                  tmp_in_array->ptr(), platform::CPUPlace(),
diff --git a/paddle/fluid/operators/sync_batch_norm_op.cu b/paddle/fluid/operators/sync_batch_norm_op.cu
index 8c57b0c9dd985394d6450ab80791f88c6d8b8f90..059effd22d851b14b73fd9ae974e362d563f3cbd 100644
--- a/paddle/fluid/operators/sync_batch_norm_op.cu
+++ b/paddle/fluid/operators/sync_batch_norm_op.cu
@@ -18,6 +18,7 @@ limitations under the License. */
 #include
 #include "cub/cub.cuh"
 #include "paddle/fluid/framework/data_layout.h"
+#include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/batch_norm_op.h"
 #include "paddle/fluid/platform/cudnn_helper.h"
 #include "paddle/fluid/platform/float16.h"
@@ -149,12 +150,10 @@ class SyncBatchNormKernel : public framework::OpKernel {
       mean_data = est_mean->data();
       var_data = est_var->data();
     } else {
-      auto &allocator =
-          platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
       // x, x^2, 1, here 1 is used to calc device num
       // device num also can be got from platform::DeviceContextPool
       const int bytes = (C * 2 + 1) * sizeof(T);
-      alloc_ptr = allocator.Allocate(bytes);
+      alloc_ptr = memory::Alloc(dev_ctx, bytes);
       T *stats = reinterpret_cast<T *>(alloc_ptr->ptr());

       const int threads = 256;
@@ -373,10 +372,8 @@ class SyncBatchNormGradKernel : public framework::OpKernel {
     const T *saved_mean = ctx.Input("SavedMean")->data();
     const T *saved_inv_var = ctx.Input("SavedVariance")->data();

-    auto &allocator =
-        platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
     const int bytes = (C * 2 + 1) * sizeof(T);
-    auto alloc_ptr = allocator.Allocate(bytes);
+    auto alloc_ptr = memory::Alloc(dev_ctx, bytes);
     T *stats = reinterpret_cast<T *>(alloc_ptr->ptr());

     const int threads = 256;
diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt
index 0b3b96e82e4c33e691ed7a2de417fc9265cc61e2..3e0f2490ff469a28e6615bddeab5a7282f40c610 100644
--- a/paddle/fluid/platform/CMakeLists.txt
+++ b/paddle/fluid/platform/CMakeLists.txt
@@ -61,8 +61,6 @@ ELSE()
   set(MKLDNN_CTX_DEPS)
 ENDIF()

-cc_library(temp_allocator SRCS temporary_allocator.cc DEPS allocator_facade)
-
 nv_library(stream_callback_manager SRCS stream_callback_manager.cc DEPS simple_threadpool enforce)
 IF(WITH_GPU)
   set(STREAM_CALLBACK_DEPS stream_callback_manager)
@@ -74,7 +72,7 @@ ENDIF()
 # avoiding cycle dependencies
 cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc ${STREAM_CALLBACK_DEPS} place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}
-    temp_allocator ${dgc_deps})
+    ${dgc_deps})

 if (WITH_DISTRIBUTE)
   cc_library(collective_helper SRCS collective_helper.cc DEPS framework_proto device_context enforce)
@@ -117,12 +115,6 @@ cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor)

 nv_library(cuda_device_guard SRCS cuda_device_guard.cc DEPS gpu_info)

-if(WITH_GPU)
-  nv_test(temporal_allocator_test SRCS temporary_allocator_test.cc DEPS temp_allocator tensor operator)
-else()
-  cc_test(temporal_allocator_test SRCS temporary_allocator_test.cc DEPS temp_allocator tensor operator)
-endif()
-
 if(NOT APPLE AND NOT WIN32)
   cc_library(device_code SRCS device_code.cc DEPS device_context)
   if(WITH_GPU)
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index f8099c7e51526d28c0047d8206315f0251768bcb..cd5af6f3abc563ebc90360a1c0f29165505fc768 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -89,47 +89,6 @@ DeviceContextPool::DeviceContextPool(
   }
 }

-DeviceTemporaryAllocator* DeviceTemporaryAllocator::allocators = nullptr;
-
-#ifdef PADDLE_WITH_CUDA
-platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
-    const platform::Place& place, const cudaStream_t& stream) {
-  PADDLE_ENFORCE(platform::is_gpu_place(place));
-  auto place_stream = std::make_pair(place, stream);
-  std::unique_lock<std::mutex> lock(mtx_);
-  auto it = device_allocator_.find(place_stream);
-  if (it == device_allocator_.end()) {
-    auto tmp_allocator = new TemporaryAllocator(place);
-    tmp_allocator->SetCallback([stream]() {
-      PADDLE_ENFORCE(cudaStreamSynchronize(stream));
-      PADDLE_ENFORCE(cudaGetLastError());
-    });
-    device_allocator_[place_stream].reset(tmp_allocator);
-    return *tmp_allocator;
-  } else {
-    return *it->second;
-  }
-}
-
-template <>
-platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
-    const platform::CUDADeviceContext& dev_ctx) {
-  return Get(dev_ctx.GetPlace(), dev_ctx.stream());
-}
-#endif
-
-template <>
-platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
-    const platform::CPUDeviceContext& dev_ctx) {
-  return cpu_allocator_;
-}
-
-platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
-    const platform::Place& place) {
-  PADDLE_ENFORCE(platform::is_cpu_place(place), "You should pass CPUPlace");
-  return cpu_allocator_;
-}
-
 CPUDeviceContext::CPUDeviceContext() {
   eigen_device_.reset(new Eigen::DefaultDevice());
 }
@@ -169,7 +128,9 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
     if (UNLIKELY(num_bytes == 0)) {
       return nullptr;
     }
-    auto buf = paddle::memory::Alloc(place_, num_bytes);
+    auto buf = memory::Alloc(place_, num_bytes);
+    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size " << buf->size()
+            << " requested " << num_bytes;
     void* retv = buf->ptr();
     {
       std::lock_guard<std::mutex> lock(mtx_);
@@ -197,7 +158,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
       char* scratch = static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize;
       semaphore_ = reinterpret_cast<unsigned int*>(scratch);
-      PADDLE_ENFORCE(
+      PADDLE_ENFORCE_CUDA_SUCCESS(
           cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
     }
     return semaphore_;
@@ -213,36 +174,12 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
   mutable std::unordered_map allocations_;
 };

-CudnnHolder::CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
-    : workspace_(nullptr), stream_(stream), place_(place) {
-  PADDLE_ENFORCE(cudaSetDevice(place_.device));
-  PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
-  PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
-}
-
-CudnnHolder::~CudnnHolder() {
-  PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
-}
-
-void CudnnHolder::ReallocateWorkspace(size_t required_workspace_len) {
-  if (required_workspace_len <= WorkspaceSize()) {
-    return;
-  }
-  if (workspace_ != nullptr) {
-    // Maybe someone is using the current workspace
-    PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
-    workspace_.reset();
-  }
-  workspace_ = paddle::memory::Alloc(place_, required_workspace_len);
-}
-
-CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
-    : place_(place), cudnn_holder_(nullptr) {
+CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
   CUDADeviceGuard guard(place_.device);
   compute_capability_ = GetCUDAComputeCapability(place_.device);
   multi_process_ = GetCUDAMultiProcessors(place_.device);
   max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
-  PADDLE_ENFORCE(cudaStreamCreate(&stream_));
+  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream_));
   eigen_stream_.reset(new EigenCudaStreamDevice());
   eigen_stream_->Reinitialize(&stream_, place);
   eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
@@ -302,6 +239,14 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
           << "Please recompile or reinstall Paddle with compatible CUDNN "
              "version.";
     }
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        dynload::cudnnCreate(&cudnn_handle_),
+        "Failed to create Cudnn handle in DeviceContext");
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        dynload::cudnnSetStream(cudnn_handle_, stream_),
+        "Failed to set stream for Cudnn handle in DeviceContext");
+  } else {
+    cudnn_handle_ = nullptr;
   }
 }

@@ -316,10 +261,14 @@ CUDADeviceContext::~CUDADeviceContext() {
   cublas_tensor_core_handle_.reset();
   eigen_stream_.reset();
   eigen_device_.reset();
-  PADDLE_ENFORCE(cudaStreamDestroy(stream_));
+  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(stream_));
+  if (cudnn_handle_) {
+    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroy(cudnn_handle_),
+                                "Failed to destroy Cudnn handle");
+  }
 #if !defined(_WIN32)
   if (nccl_comm_) {
-    PADDLE_ENFORCE(dynload::ncclCommDestroy(nccl_comm_));
+    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
   }
 #endif
 }
@@ -327,21 +276,17 @@ CUDADeviceContext::~CUDADeviceContext() {
 Place CUDADeviceContext::GetPlace() const { return place_; }

 void CUDADeviceContext::Wait() const {
-  auto& allocator =
-      DeviceTemporaryAllocator::Instance().Get(*this);
-  allocator.Release([this]() {
-    cudaError_t e_sync = cudaStreamSynchronize(stream_);
-    if (e_sync != 0) {
-      LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync)
-                 << " errno:" << e_sync;
-    }
+  cudaError_t e_sync = cudaStreamSynchronize(stream_);
+  if (e_sync != 0) {
+    LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync)
+               << " errno: " << e_sync;
+  }

-    cudaError_t e_get = cudaGetLastError();
-    if (e_get != 0) {
-      LOG(FATAL) << "cudaGetLastError " << cudaGetErrorString(e_get)
-                 << " errno:" << e_get;
-    }
-  });
+  cudaError_t e_get = cudaGetLastError();
+  if (e_get != 0) {
+    LOG(FATAL) << "cudaGetLastError " << cudaGetErrorString(e_get)
+               << " errno: " << e_get;
+  }
 }

 int CUDADeviceContext::GetComputeCapability() const {
@@ -360,21 +305,10 @@ bool CUDADeviceContext::tensor_core_available() const {
   return cublas_tensor_core_handle_ != nullptr;
 }

-CudnnHolder* CUDADeviceContext::cudnn_holder() const {
-  std::call_once(init_cudnn_, [&]() {
-    if (dynload::HasCUDNN()) {
-      cudnn_holder_.reset(new CudnnHolder(&stream_, place_));
-    }
-  });
-  return cudnn_holder_.get();
cudnn_holder_.get(); -} - -cudnnHandle_t CUDADeviceContext::cudnn_handle() const { - return cudnn_holder()->cudnn_handle(); -} +cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const { - return CudnnWorkspaceHandle(cudnn_holder()); + return CudnnWorkspaceHandle(*this); } cudaStream_t CUDADeviceContext::stream() const { return stream_; } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 8ad9f14a786f7acb054cd720a63efcdd4f58cd79..bdc2670dc20f07f0aed366686526aeba1d8066cf 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -18,7 +18,6 @@ limitations under the License. */ #include #include #include "paddle/fluid/memory/malloc.h" -#include "paddle/fluid/platform/temporary_allocator.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/dynload/cublas.h" @@ -45,71 +44,6 @@ limitations under the License. */ namespace paddle { namespace platform { -/*! \brief device temporary allocator singleton. - * - * Some operator needs temporary memory during computation, for example, - * conv_gemm, which needs use col to store the result of im2col. If we - * create a stack memory which is used by CUDA Kernel, before the - * Computation(...) returns, we should add ctx->Wait(), because the - * execution of CUDA is async, if there doesn't have ctx->Wait(), - * the temporary memory will be released before the CUDA Kernel uses - * it. - * - * DeviceTemporaryAllocator is a singleton, which contains a - * `TemporaryAllocator` for each . And the TemporaryAllocator - * contains a temp_allocation_queue which is used to store the temporary - * allocations. The allocation, which is allocated by TemporaryAllocator, - * is a unique_ptr, and when it is not held by any variable, it will be - * pushed into the temp_allocation_queue. There are two opportunities to free - * the allocations of temp_allocation_queue: - * - when the Stream calls cudaStreamSynchronize; - * - when the allocation size of opportunities exceeds a certain threshold - * (defined by FLAGS_limit_of_tmp_allocation). - * - * */ -class DeviceTemporaryAllocator { - public: - static DeviceTemporaryAllocator& Instance() { - PADDLE_ENFORCE_NOT_NULL(allocators, - "Need to Create DeviceTemporaryAllocator first!"); - return *allocators; - } - - static DeviceTemporaryAllocator& Init() { - if (allocators == nullptr) { - allocators = new DeviceTemporaryAllocator(); - } - return *allocators; - } - -/*! \brief Return handle of single temporary allocator. 
*/ -#ifdef PADDLE_WITH_CUDA - platform::TemporaryAllocator& Get(const platform::Place& place, - const cudaStream_t& stream); -#endif - template - platform::TemporaryAllocator& Get(const DeviceContext& dev_ctx); - - platform::TemporaryAllocator& Get(const platform::Place& place); - - private: - DeviceTemporaryAllocator() : cpu_allocator_(platform::CPUPlace()) {} - - static DeviceTemporaryAllocator* allocators; - - platform::TemporaryAllocator cpu_allocator_; - -#ifdef PADDLE_WITH_CUDA - std::map, - std::unique_ptr> - device_allocator_; -#endif - - std::mutex mtx_; - - DISABLE_COPY_AND_ASSIGN(DeviceTemporaryAllocator); -}; - class DeviceContext { public: virtual ~DeviceContext() {} @@ -143,102 +77,7 @@ struct DefaultDeviceContextType { #ifdef PADDLE_WITH_CUDA class EigenCudaStreamDevice; -class CudnnHolder { - public: - CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place); - ~CudnnHolder(); - cudnnHandle_t cudnn_handle() const { return cudnn_handle_; } - - private: - friend class CudnnWorkspaceHandle; - void ReallocateWorkspace(size_t required_workspace_len); - - template - void RunFuncImpl(Callback&& cudnn_func, size_t required_workspace_len) { - if (required_workspace_len > WorkspaceSize()) { - ReallocateWorkspace(required_workspace_len); - } - VLOG(2) << "Cudnn workspace size: " - << static_cast(WorkspaceSize()) / (1 << 20) << " MB"; - cudnn_func(WorkspacePtr()); - } - - /*! \brief Reset workspace thus release the memory */ - inline void ResetWorkspace() { - if (workspace_) { - // Maybe someone is using the current workspace - PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(*stream_)); - workspace_ = nullptr; - } - } - - inline void* WorkspacePtr() { - if (workspace_) { - return workspace_->ptr(); - } else { - return nullptr; - } - } - - inline size_t WorkspaceSize() { - if (workspace_) { - return workspace_->size(); - } else { - return 0; - } - } - - std::mutex& Mutex() { return mtx_; } - - cudnnHandle_t cudnn_handle_; - memory::AllocationPtr workspace_; - - const cudaStream_t* stream_; // not owned; - const CUDAPlace place_; - - std::mutex mtx_; -}; - -class CudnnWorkspaceHandle { - public: - /*! \brief The lock would not be acquired when constructor calls. - * The lock would be acquired when RunFunc() is called first time. */ - inline explicit CudnnWorkspaceHandle(CudnnHolder* holder) : holder_(holder) {} - - /*! \brief Thread which call RunFunc() would acquire the lock first - * before invoking cudnn functions. */ - template - inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_len) { - if (!guard_) { - guard_.reset(new std::lock_guard(holder_->Mutex())); - } - holder_->RunFuncImpl(std::forward(cudnn_func), - required_workspace_len); - } - - /*! \brief Thread which call RunFuncSync() would acquire the lock first - * before invoking cudnn function and release gpu memory after running - * the function. 
Currently this function is only used when cudnn - * exhaustive searching and callers have to guarantee that the input function - * is host blocking */ - template - inline void RunFuncSync(Callback&& cudnn_func, - size_t required_workspace_len) { - if (!guard_) { - guard_.reset(new std::lock_guard(holder_->Mutex())); - } - holder_->RunFuncImpl(std::forward(cudnn_func), - required_workspace_len); - holder_->ResetWorkspace(); - } - - CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default; - CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete; - - private: - CudnnHolder* holder_; // not own - std::unique_ptr> guard_; -}; +class CudnnWorkspaceHandle; class CUDADeviceContext : public DeviceContext { public: @@ -323,9 +162,8 @@ class CUDADeviceContext : public DeviceContext { std::unique_ptr eigen_device_; std::unique_ptr eigen_stream_; - mutable std::unique_ptr cudnn_holder_; cudaStream_t stream_; - + cudnnHandle_t cudnn_handle_; std::unique_ptr cublas_handle_; std::unique_ptr cublas_tensor_core_handle_; @@ -346,11 +184,60 @@ class CUDADeviceContext : public DeviceContext { // StreamCallbackManager is thread-safe std::unique_ptr callback_manager_; - CudnnHolder* cudnn_holder() const; DISABLE_COPY_AND_ASSIGN(CUDADeviceContext); }; +class CudnnWorkspaceHandle { + public: + inline explicit CudnnWorkspaceHandle(const CUDADeviceContext& dev_ctx) + : device_context_(dev_ctx) {} + + template + inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_bytes) { + if (required_workspace_bytes > WorkspaceSize()) { + ReallocWorkspace(required_workspace_bytes); + } + VLOG(2) << "Cudnn workspace size at RunFunc: " + << static_cast(WorkspaceSize()) / (1 << 20) << " MB"; + cudnn_func(allocation_ ? allocation_->ptr() : nullptr); + } + + /*! \brief Thread which call RunFuncSync() would release gpu memory after + * running the function. 
Currently this function is only used when cudnn + * exhaustive searching and callers have to guarantee that the input function + * is host blocking */ + template + inline void RunFuncSync(Callback&& cudnn_func, + size_t required_workspace_bytes) { + RunFunc(cudnn_func, required_workspace_bytes); + ResetWorkspace(); + } + + inline void ReallocWorkspace(size_t required_workspace_bytes) { + if (required_workspace_bytes <= WorkspaceSize()) { + return; + } + allocation_ = memory::Alloc(device_context_, required_workspace_bytes); + } + + inline void ResetWorkspace() { allocation_ = nullptr; } + + inline size_t WorkspaceSize() { + if (allocation_ == nullptr) { + return 0; + } + return allocation_->size(); + } + + CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default; + CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete; + + private: + memory::allocation::AllocationPtr allocation_; + const CUDADeviceContext& device_context_; +}; + template <> struct DefaultDeviceContextType { using TYPE = CUDADeviceContext; diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index 0b9b61dbc5797c334837546ced588baafd1493b9..be6519b189011e0d2b09aa09bc2ddb173db2389e 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -153,7 +153,6 @@ void InitDevices(bool init_p2p, const std::vector devices) { } places.emplace_back(platform::CPUPlace()); platform::DeviceContextPool::Init(places); - platform::DeviceTemporaryAllocator::Init(); #ifndef PADDLE_WITH_MKLDNN platform::SetNumThreads(FLAGS_paddle_num_threads); diff --git a/paddle/fluid/platform/temporary_allocator.cc b/paddle/fluid/platform/temporary_allocator.cc deleted file mode 100644 index 6177b024f0ccbeeae14106868e2fc5ca7b8789eb..0000000000000000000000000000000000000000 --- a/paddle/fluid/platform/temporary_allocator.cc +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/platform/temporary_allocator.h" -#include -#include "paddle/fluid/memory/allocation/allocator_facade.h" - -DEFINE_int64(limit_of_tmp_allocation, -1, - "The up limit of temporary_allocation size."); -DEFINE_double(times_excess_than_required_tmp_allocation, 2, - "times_excess_than_required_tmp_allocation indicates the " - "max size the TemporaryAllocator can return. 
For example, " - "if the required memory size is N, and " - "times_excess_than_required_tmp_allocation is 2.0, " - "the TemporaryAllocator will return the available allocation " - "that the range of size is N ~ 2*N."); - -namespace paddle { -namespace platform { -namespace alloc = memory::allocation; - -TemporaryAllocator::TemporaryAllocator(platform::Place place) : place_(place) { - temp_mem_map_.reset(new std::multimap()); -} - -bool TemporaryAllocator::IsAllocThreadSafe() const { return true; } - -void TemporaryAllocator::Release(const std::function &callback) { - std::unique_ptr> t_allocations; - { - std::unique_lock lock(mtx_); - callback(); - t_allocations.swap(temp_mem_map_); - temp_mem_map_.reset(new std::multimap()); - wait_delete_mem_ = 0; - } - - alloc::AllocationDeleter deleter; - for (auto tmp : *t_allocations) { - VLOG(10) << "Delete temporary allocation " << tmp.second->ptr() - << " size: " << tmp.second->size(); - deleter(tmp.second); - } -} - -void TemporaryAllocator::FreeImpl(alloc::Allocation *temp_allocation) { - if (platform::is_gpu_place(temp_allocation->place())) { - PADDLE_ENFORCE(platform::is_same_place(temp_allocation->place(), place_), - "The place should be the same."); - size_t wait_delete_mem = 0; - { - std::unique_lock lock(mtx_); - temp_mem_map_->emplace(temp_allocation->size(), temp_allocation); - wait_delete_mem_ += temp_allocation->size(); - wait_delete_mem = wait_delete_mem_; - VLOG(10) << "Move temporary allocation: " << temp_allocation->ptr() - << " to delete queue: " << temp_allocation->size() << "; " - << "wait_delete_mem: " << wait_delete_mem; - } - - if (FLAGS_limit_of_tmp_allocation >= 0 && - wait_delete_mem >= static_cast(FLAGS_limit_of_tmp_allocation)) { - PADDLE_ENFORCE(callback_ != nullptr, "The callback is non-initialized."); - Release(callback_); - } - return; - } - VLOG(10) << "Delete temporary allocation " << temp_allocation->ptr() - << " size: " << temp_allocation->size(); - alloc::AllocationDeleter()(temp_allocation); -} - -size_t TemporaryAllocator::TemporaryAllocationQueueSize() { - std::unique_lock lock(mtx_); - return temp_mem_map_ ? temp_mem_map_->size() : 0; -} - -void TemporaryAllocator::SetCallback(const std::function &callback) { - callback_ = callback; -} - -alloc::Allocation *TemporaryAllocator::AllocateImpl(size_t size) { - { - // Find available allocation in temp_mem_map. - std::unique_lock lock(mtx_); - if (temp_mem_map_->size()) { - auto it = temp_mem_map_->lower_bound(size); - // FIXME(zcd): Not sure the best value of excess fraction. - if (it != temp_mem_map_->end() && - it->first < - static_cast( - size * FLAGS_times_excess_than_required_tmp_allocation)) { - auto tmp_ptr = it->second; - temp_mem_map_->erase(it); - wait_delete_mem_ -= tmp_ptr->size(); - VLOG(10) << "Reuse temporary allocation: " << tmp_ptr->ptr() << ": " - << tmp_ptr->size(); - return tmp_ptr; - } - } - } - // If not find the the available allocation, get allocation from - // AllocatorFacadeInstance. 
- auto temp_mem = alloc::AllocatorFacade::Instance().Alloc(place_, size); - VLOG(10) << "Alloc temporary allocation: " << temp_mem->ptr() << ": " << size; - return temp_mem.release(); -} - -} // namespace platform -} // namespace paddle diff --git a/paddle/fluid/platform/temporary_allocator.h b/paddle/fluid/platform/temporary_allocator.h deleted file mode 100644 index 41f0e4a80b735e6c4eabce864ac5a1dfe1d67ced..0000000000000000000000000000000000000000 --- a/paddle/fluid/platform/temporary_allocator.h +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include // NOLINT -#include -#include -#include -#include // NOLINT -#include "paddle/fluid/memory/allocation/allocator.h" -#include "paddle/fluid/platform/lock_guard_ptr.h" -namespace paddle { -namespace platform { - -/*! \brief the TemporaryAllocator is used to alloc the temporary allocation - * which used by CUDA's async operation. - * - * The TemporaryAllocator contains a temp_allocation_queue which - * is used to store the temporary allocations. The allocation, which is - * allocated by TemporaryAllocator, is a unique_ptr, and when it is not held - * by any variable, it will be pushed into the temp_allocation_queue. - * - * There is one opportunity to free the allocations of temp_allocation_queue: - * - when the allocation size of opportunities exceeds a certain threshold - * (defined by FLAGS_limit_of_tmp_allocation). - * - * */ -class TemporaryAllocator : public memory::allocation::Allocator { - public: - explicit TemporaryAllocator(platform::Place place); - - void Release(const std::function &callback); - - size_t TemporaryAllocationQueueSize(); - - bool IsAllocThreadSafe() const override; - - void SetCallback(const std::function &callback); - - protected: - void FreeImpl(memory::allocation::Allocation *allocation) override; - - memory::allocation::Allocation *AllocateImpl(size_t size) override; - - private: - platform::Place place_; - // When the allocation is not held by any variable, it should be placed - // to temp_mem_map immediately. - std::unique_ptr> - temp_mem_map_{nullptr}; - std::mutex mtx_; - size_t wait_delete_mem_{0}; - std::function callback_; -}; - -} // namespace platform -} // namespace paddle diff --git a/paddle/fluid/platform/temporary_allocator_test.cc b/paddle/fluid/platform/temporary_allocator_test.cc deleted file mode 100644 index a5068eff4943444f3ffe5de555e31888ca2986df..0000000000000000000000000000000000000000 --- a/paddle/fluid/platform/temporary_allocator_test.cc +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/platform/temporary_allocator.h" -#include -#include -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/framework/tensor_util.h" - -DECLARE_int64(limit_of_tmp_allocation); -DECLARE_double(times_excess_than_required_tmp_allocation); - -namespace paddle { -namespace platform { - -class DummyOp : public framework::OperatorBase { - public: - DummyOp(const std::string& type, const framework::VariableNameMap& inputs, - const framework::VariableNameMap& outputs, - const framework::AttributeMap& attrs) - : OperatorBase(type, inputs, outputs, attrs) {} - - protected: - void RunImpl(const framework::Scope& scope, - const platform::Place& place) const override {} -}; - -TEST(temporary_allocator, test_base_function) { - platform::CPUPlace cpu_place; - TemporaryAllocator alloc(cpu_place); - alloc.Allocate(100); - -#ifdef PADDLE_WITH_CUDA - platform::CUDAPlace gpu_place(0); - TemporaryAllocator gpu_alloc(gpu_place); - - auto allocation = gpu_alloc.Allocate(101); - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0); - gpu_alloc.Release([]() {}); - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0); - - { - auto allocation = gpu_alloc.Allocate(102); - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0); - } - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1); - gpu_alloc.Release([]() {}); - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0); -#endif -} - -TEST(temporary_allocator, test_flags_function) { -#ifdef PADDLE_WITH_CUDA - const int64_t limit = FLAGS_limit_of_tmp_allocation; - FLAGS_limit_of_tmp_allocation = 10; - platform::CUDAPlace gpu_place(0); - TemporaryAllocator gpu_alloc(gpu_place); - - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto* dev_ctx = - static_cast(pool.Get(gpu_place)); - auto stream = dev_ctx->stream(); - bool deleted = false; - gpu_alloc.SetCallback([stream, &deleted]() { - PADDLE_ENFORCE(cudaStreamSynchronize(stream)); - PADDLE_ENFORCE(cudaGetLastError()); - deleted = true; - }); - { gpu_alloc.Allocate(100); } - PADDLE_ENFORCE(deleted); - FLAGS_limit_of_tmp_allocation = limit; -#endif -} - -TEST(temporary_allocator, test_reuse_tmp_allocation) { -#ifdef PADDLE_WITH_CUDA - platform::CUDAPlace gpu_place(0); - TemporaryAllocator gpu_alloc(gpu_place); - gpu_alloc.SetCallback([]() {}); - - void* tmp_allocation_ptr1 = nullptr; - { - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0); - auto tmp_allocation1 = gpu_alloc.Allocate(200); - tmp_allocation_ptr1 = tmp_allocation1->ptr(); - } - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1); - auto tmp_allocation2 = gpu_alloc.Allocate(200); - void* tmp_allocation_ptr2 = tmp_allocation2->ptr(); - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0); - PADDLE_ENFORCE_EQ(tmp_allocation_ptr1, tmp_allocation_ptr2); - - auto tmp_allocation3 = gpu_alloc.Allocate(200); - void* tmp_allocation_ptr3 = tmp_allocation2->ptr(); - PADDLE_ENFORCE_EQ(tmp_allocation_ptr1, tmp_allocation_ptr3); -#endif -} - -TEST(temporary_allocator, 
test_times_excess_than_required_tmp_allocation) { -#ifdef PADDLE_WITH_CUDA - platform::CUDAPlace gpu_place(0); - TemporaryAllocator gpu_alloc(gpu_place); - gpu_alloc.SetCallback([]() {}); - double excess_fraction = FLAGS_times_excess_than_required_tmp_allocation; - void* tmp_allocation_ptr1 = nullptr; - { - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0); - auto tmp_allocation1 = - gpu_alloc.Allocate(static_cast(200 * excess_fraction - 1)); - tmp_allocation_ptr1 = tmp_allocation1->ptr(); - } - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1); - auto tmp_allocation2 = gpu_alloc.Allocate(200 * excess_fraction - 10); - void* tmp_allocation_ptr2 = tmp_allocation2->ptr(); - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0); - PADDLE_ENFORCE_EQ(tmp_allocation_ptr1, tmp_allocation_ptr2); -#endif -} - -TEST(temporary_allocator, create_tensor_with_allocationptr) { - framework::VariableNameMap dummy_vars; - framework::AttributeMap dummy_attrs; - DummyOp op("dummy", dummy_vars, dummy_vars, dummy_attrs); - framework::Scope scope; - framework::VariableValueMap vars; - framework::RuntimeContext run_ctx(vars, vars); - size_t memory_size = 300; - { - platform::CPUPlace cpu_place; - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto* dev_ctx = - static_cast(pool.Get(cpu_place)); - framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx, nullptr); - - int numel = memory_size / sizeof(float); - framework::Tensor tensor = - ctx.AllocateTmpTensor( - framework::make_ddim({numel}), *dev_ctx); - PADDLE_ENFORCE_EQ(tensor.numel(), numel); - } - -#ifdef PADDLE_WITH_CUDA - { - platform::CUDAPlace gpu_place(0); - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto* dev_ctx = - static_cast(pool.Get(gpu_place)); - framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx, nullptr); - int numel = memory_size / sizeof(float); - framework::Tensor tensor = - ctx.AllocateTmpTensor( - framework::make_ddim({numel}), *dev_ctx); - PADDLE_ENFORCE_EQ(tensor.numel(), numel); - } -#endif -} - -TEST(temporary_allocator, create_tensor_with_allocationptr2) { - framework::VariableNameMap dummy_vars; - framework::AttributeMap dummy_attrs; - DummyOp op("dummy", dummy_vars, dummy_vars, dummy_attrs); - framework::Scope scope; - framework::VariableValueMap vars; - framework::RuntimeContext run_ctx(vars, vars); - size_t memory_size = 400; - { - platform::CPUPlace cpu_place; - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto* dev_ctx = - static_cast(pool.Get(cpu_place)); - framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx, nullptr); - int numel = memory_size / sizeof(float); - - framework::Tensor out_side_tensor; - { - framework::Tensor tensor = - ctx.AllocateTmpTensor( - framework::make_ddim({numel}), *dev_ctx); - PADDLE_ENFORCE_EQ(tensor.numel(), numel); - - out_side_tensor.ShareDataWith(tensor); - } - PADDLE_ENFORCE_EQ(out_side_tensor.numel(), numel); - } - -#ifdef PADDLE_WITH_CUDA - { - platform::CUDAPlace gpu_place(0); - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto* dev_ctx = - static_cast(pool.Get(gpu_place)); - framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx, nullptr); - - size_t memory_size = 500; - int numel = memory_size / sizeof(float); - framework::Tensor out_side_tensor; - { - framework::Tensor tensor = - ctx.AllocateTmpTensor( - framework::make_ddim({numel}), *dev_ctx); - PADDLE_ENFORCE_EQ(tensor.numel(), numel); - - 
out_side_tensor.ShareDataWith(tensor); - } - PADDLE_ENFORCE_EQ(out_side_tensor.numel(), numel); - } -#endif -} - -} // namespace platform -} // namespace paddle diff --git a/paddle/testing/paddle_gtest_main.cc b/paddle/testing/paddle_gtest_main.cc index 6eb7a246b8588377850a5d77fc552913c7b0514a..d5acff56a9aa9136b84e216f6f8b0f28b528dbc5 100644 --- a/paddle/testing/paddle_gtest_main.cc +++ b/paddle/testing/paddle_gtest_main.cc @@ -57,6 +57,7 @@ int main(int argc, char** argv) { envs.push_back("initial_cpu_memory_in_mb"); envs.push_back("allocator_strategy"); + undefok.push_back("use_pinned_memory"); undefok.push_back("use_mkldnn"); undefok.push_back("initial_cpu_memory_in_mb"); #endif diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 2395def84ea80e0877cf78bf6c7a9a8145594184..8a62b9a4b227e2b56daf8423fc440e2ac85d57b6 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -202,8 +202,6 @@ def __bootstrap__(): 'reallocate_gpu_memory_in_mb', 'cudnn_deterministic', 'enable_cublas_tensor_op_math', 'conv_workspace_size_limit', 'cudnn_exhaustive_search', 'selected_gpus', 'sync_nccl_allreduce', - 'limit_of_tmp_allocation', - 'times_excess_than_required_tmp_allocation', 'cudnn_batchnorm_spatial_persistent', 'gpu_allocator_retry_time' ] core.init_gflags([sys.argv[0]] +
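Reviewer note: with DeviceTemporaryAllocator removed, code that needs a short-lived device buffer now asks memory::Alloc for it directly, passing the device context rather than only a place so the allocator can coordinate with that context's stream; the memory is handed back when the returned AllocationPtr goes out of scope, so none of the old Release()/SetCallback plumbing remains. Below is a minimal sketch of the new pattern, not code from this patch: the DeviceContextPool lookup, the function name, and the 1024-byte size are illustrative assumptions.

#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device_context.h"

namespace platform = paddle::platform;
namespace memory = paddle::memory;

// Hypothetical helper showing the post-patch temporary-buffer pattern.
void TempBufferSketch() {
  platform::CUDAPlace gpu_place(0);
  auto* dev_ctx = static_cast<platform::CUDADeviceContext*>(
      platform::DeviceContextPool::Instance().Get(gpu_place));
  // Allocate a scratch buffer bound to this context (size is illustrative).
  memory::allocation::AllocationPtr buf = memory::Alloc(*dev_ctx, 1024);
  void* gpu_ptr = buf->ptr();  // pass to a kernel launched on dev_ctx->stream()
  (void)gpu_ptr;
  // `buf` is destroyed here and the memory returns to the allocator; no
  // explicit queue flush or FLAGS_limit_of_tmp_allocation tuning is needed.
}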
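A second note on the cuDNN side: CudnnWorkspaceHandle now owns its workspace allocation directly (grabbed via memory::Alloc on demand and dropped by ResetWorkspace), instead of borrowing a shared buffer from the deleted CudnnHolder under a mutex. A usage sketch follows, assuming the caller has already computed the workspace size it needs; the function name and the lambda body are placeholders, not code from this patch.

#include "paddle/fluid/platform/device_context.h"

// Hypothetical caller of the reworked per-call workspace handle.
void RunWithCudnnWorkspaceSketch(
    const paddle::platform::CUDADeviceContext& dev_ctx,
    size_t required_workspace_bytes) {
  auto workspace_handle = dev_ctx.cudnn_workspace_handle();
  workspace_handle.RunFunc(
      [&](void* workspace_ptr) {
        // Invoke the cuDNN routine here with workspace_ptr and
        // required_workspace_bytes; the buffer is owned by workspace_handle
        // and stays valid for the duration of this callback.
      },
      required_workspace_bytes);
  // RunFuncSync() behaves like RunFunc() but frees the workspace right after
  // the callback returns; it is intended for host-blocking calls such as
  // cuDNN exhaustive algorithm search.
}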