diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index f7a6b5ba84ca1762bd903790aa3c0346b22ed035..6f878541e6de1deec1829145b1b325ecd176a034 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -45,11 +45,10 @@ class Tensor { friend struct EigenVector; public: - Tensor() : offset_(0), is_pinned_(false) {} + Tensor() : offset_(0) {} /*! Constructor with place should only be used in pybind. */ - explicit Tensor(const platform::Place& place) - : offset_(0), is_pinned_(false) { + explicit Tensor(const platform::Place& place) : offset_(0) { holder_->set_place(place); } @@ -70,12 +69,11 @@ class Tensor { * @note If not exist, then allocation. */ template - inline T* mutable_data(platform::Place place, bool is_pinned = false); + inline T* mutable_data(platform::Place place); - inline void* mutable_data(platform::Place place, std::type_index type, - bool is_pinned = false); + inline void* mutable_data(platform::Place place, std::type_index type); - inline void* mutable_data(platform::Place place, bool is_pinned = false); + inline void* mutable_data(platform::Place place); /** * @brief Return a pointer to mutable memory block. @@ -86,8 +84,7 @@ class Tensor { * @note If not exist, then allocation. */ template - inline T* mutable_data(DDim dims, platform::Place place, - bool is_pinned = false); + inline T* mutable_data(DDim dims, platform::Place place); /*! Return the dimensions of the memory block. */ inline const DDim& dims() const; @@ -95,9 +92,6 @@ class Tensor { /*! Return the numel of the memory block. */ inline int64_t numel() const; - /*! Return the numel of the memory block. */ - inline bool isPinned() const; - /*! Resize the dimensions of the memory block. */ inline Tensor& Resize(const DDim& dims); @@ -152,14 +146,12 @@ class Tensor { template struct PlaceholderImpl : public Placeholder { - PlaceholderImpl(Place place, size_t size, std::type_index type, - bool is_pinned = false) - : ptr_(static_cast(memory::Alloc(place, size, is_pinned)), - memory::PODDeleter(place, is_pinned)), + PlaceholderImpl(Place place, size_t size, std::type_index type) + : ptr_(static_cast(memory::Alloc(place, size)), + memory::PODDeleter(place)), place_(place), size_(size), - type_(type), - is_pinned_(is_pinned) { + type_(type) { PADDLE_ENFORCE_NOT_NULL(ptr_, "Insufficient %s memory to allocation.", (is_cpu_place(place_) ? "CPU" : "GPU")); } @@ -182,9 +174,6 @@ class Tensor { /* the current type of memory */ std::type_index type_; - - /*! use pinned memory or not. */ - bool is_pinned_; }; /*! holds the memory block if allocated. */ @@ -219,7 +208,6 @@ class Tensor { * PlaceHolder::ptr_ and where the tensor data really begins. 
*/ size_t offset_; - bool is_pinned_; }; inline void Tensor::switch_place(platform::Place new_place) { diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h index 113814971e115fa88bd0ded34017fa26a9dd5803..7a4839044008338dda43f75b5ee6def500b78270 100644 --- a/paddle/fluid/framework/tensor_impl.h +++ b/paddle/fluid/framework/tensor_impl.h @@ -101,21 +101,19 @@ inline T* Tensor::data() { } template -inline T* Tensor::mutable_data(DDim dims, platform::Place place, - bool is_pinned) { +inline T* Tensor::mutable_data(DDim dims, platform::Place place) { static_assert(std::is_pod::value, "T must be POD"); Resize(dims); - return mutable_data(place, is_pinned); + return mutable_data(place); } template -inline T* Tensor::mutable_data(platform::Place place, bool is_pinned) { +inline T* Tensor::mutable_data(platform::Place place) { static_assert(std::is_pod::value, "T must be POD"); - return reinterpret_cast(mutable_data(place, typeid(T), is_pinned)); + return reinterpret_cast(mutable_data(place, typeid(T))); } -inline void* Tensor::mutable_data(platform::Place place, std::type_index type, - bool is_pinned) { +inline void* Tensor::mutable_data(platform::Place place, std::type_index type) { if (holder_ != nullptr) { holder_->set_type(type); } @@ -129,27 +127,26 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type, holder_->size() < size + offset_) { if (platform::is_cpu_place(place)) { holder_.reset(new PlaceholderImpl( - boost::get(place), size, type, is_pinned)); + boost::get(place), size, type)); } else if (platform::is_gpu_place(place)) { #ifndef PADDLE_WITH_CUDA PADDLE_THROW("'CUDAPlace' is not supported in CPU only device."); } #else holder_.reset(new PlaceholderImpl( - boost::get(place), size, type, is_pinned)); + boost::get(place), size, type)); } #endif offset_ = 0; - is_pinned_ = is_pinned; } return reinterpret_cast(reinterpret_cast(holder_->ptr()) + offset_); } -inline void* Tensor::mutable_data(platform::Place place, bool is_pinned) { +inline void* Tensor::mutable_data(platform::Place place) { PADDLE_ENFORCE(this->holder_ != nullptr, "Cannot invoke mutable data if current hold nothing"); - return mutable_data(place, holder_->type(), is_pinned); + return mutable_data(place, holder_->type()); } inline Tensor& Tensor::ShareDataWith(const Tensor& src) { @@ -191,8 +188,6 @@ inline const DDim& Tensor::dims() const { return dims_; } inline int64_t Tensor::numel() const { return product(dims_); } -inline bool Tensor::isPinned() const { return is_pinned_; } - inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { Tensor res; res.ShareDataWith(src); diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 8b7533ce712b0a01060842b6f71449ed6bd23e2c..1d864af011bced9df188147ec436b8de12947ba9 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -148,6 +148,11 @@ struct AnyVisitor : public boost::static_visitor { const platform::CPUPlace& cpu) const { return *out.data(); } + + bool GetResult(const framework::Tensor& out, + const platform::CUDAPinnedPlace& cpu) const { + return *out.data(); + } }; template diff --git a/paddle/fluid/memory/CMakeLists.txt b/paddle/fluid/memory/CMakeLists.txt index 1a61c484823b292234d4758cdc1959d7a21510e6..8b3043af7a18787a08583d47b76da679ccb63740 100644 --- a/paddle/fluid/memory/CMakeLists.txt +++ b/paddle/fluid/memory/CMakeLists.txt @@ -4,13 +4,17 @@ cc_library(memory SRCS memory.cc DEPS place enforce) 
 cc_library(memcpy SRCS memcpy.cc DEPS place)
 
 cc_library(paddle_memory
-  DEPS
-  memory
-  memcpy
-  meta_data
-  meta_cache
-  memory_block
-  buddy_allocator
-  system_allocator)
+    DEPS
+    memory
+    memcpy
+    meta_data
+    meta_cache
+    memory_block
+    buddy_allocator
+    system_allocator)
 
 cc_test(memory_test SRCS memory_test.cc DEPS place paddle_memory)
+
+#if (WITH_GPU)
+#  nv_test(pinned_memory_test SRCS pinned_memory_test.cu DEPS place paddle_memory)
+#endif()
diff --git a/paddle/fluid/memory/detail/system_allocator.cc b/paddle/fluid/memory/detail/system_allocator.cc
index 22f6f506748735d1a0fe75375aeea22bd92b8b7e..a45f8c33ee5956f3409ee1b7c43628aa0acafb98 100644
--- a/paddle/fluid/memory/detail/system_allocator.cc
+++ b/paddle/fluid/memory/detail/system_allocator.cc
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/memory/detail/system_allocator.h"
 #include "paddle/fluid/platform/assert.h"
+#include "paddle/fluid/platform/cpu_info.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/gpu_info.h"
 
@@ -134,21 +135,31 @@ bool GPUAllocator::UseGpu() const { return true; }
 // memory. It’s locked to a physical address.
 void* CUDAPinnedAllocator::Alloc(size_t& index, size_t size) {
   if (size <= 0) return nullptr;
-  void* p;
-  // NOTE: here, we use GpuMaxAllocSize() as the maximum memory size
+
+  // NOTE: here, we use CUDAPinnedMaxAllocSize as the maximum memory size
   // of host pinned allocation. Allocates too much would reduce
   // the amount of memory available to the underlying system for paging.
+  size_t usable =
+      paddle::platform::CUDAPinnedMaxAllocSize() - cuda_pinned_alloc_size_;
 
-  size_t usable = paddle::platform::GpuMaxAllocSize() - fallback_alloc_size_;
-
-  if (size > usable) return nullptr;
+  if (size > usable) {
+    LOG(WARNING) << "Cannot malloc " << size / 1024.0 / 1024.0
+                 << " MB pinned memory"
+                 << ", available " << usable / 1024.0 / 1024.0 << " MB";
+    return nullptr;
+  }
 
+  void* p;
   // PINNED memory is visible to all CUDA contexts.
   cudaError_t result = cudaMallocHost(&p, size);
+
   if (result == cudaSuccess) {
-    index = 1;
-    fallback_alloc_size_ += size;
+    index = 1;  // PINNED memory
+    cuda_pinned_alloc_size_ += size;
     return p;
+  } else {
+    LOG(WARNING) << "cudaMallocHost failed.";
+    return nullptr;
   }
 
   return nullptr;
@@ -158,8 +169,8 @@ void CUDAPinnedAllocator::Free(void* p, size_t size, size_t index) {
   cudaError_t err;
   PADDLE_ASSERT(index == 1);
 
-  PADDLE_ASSERT(fallback_alloc_size_ >= size);
-  fallback_alloc_size_ -= size;
+  PADDLE_ASSERT(cuda_pinned_alloc_size_ >= size);
+  cuda_pinned_alloc_size_ -= size;
   err = cudaFreeHost(p);
 
   // Purposefully allow cudaErrorCudartUnloading, because
@@ -172,7 +183,7 @@ void CUDAPinnedAllocator::Free(void* p, size_t size, size_t index) {
   }
 }
 
-bool CUDAPinnedAllocator::UseGpu() const { return true; }
+bool CUDAPinnedAllocator::UseGpu() const { return false; }
 
 #endif
 
diff --git a/paddle/fluid/memory/detail/system_allocator.h b/paddle/fluid/memory/detail/system_allocator.h
index e8479e73f433f1d741b2933da4843c0ba80276d5..e3c50ef6483c61e2016bbd967a4100057c87dca3 100644
--- a/paddle/fluid/memory/detail/system_allocator.h
+++ b/paddle/fluid/memory/detail/system_allocator.h
@@ -21,8 +21,9 @@ namespace memory {
 namespace detail {
 
 /**
- * \brief SystemAllocator is the parent class of CPUAllocator and GPUAllocator.
- * A BuddyAllocator object uses a SystemAllocator* pointing to the
+ * \brief SystemAllocator is the parent class of CPUAllocator,
+ * CUDAPinnedAllocator and GPUAllocator. A BuddyAllocator
+ * object uses a SystemAllocator* pointing to the
 * underlying system allocator.
 */
 class SystemAllocator {
@@ -62,9 +63,7 @@ class CUDAPinnedAllocator : public SystemAllocator {
   virtual bool UseGpu() const;
 
  private:
-  size_t gpu_alloc_size_ =
-      0;  // TODO(zcd): how to define the upper limit of CUDAPinnedMemory?
-  size_t fallback_alloc_size_ = 0;
+  size_t cuda_pinned_alloc_size_ = 0;
 };
 #endif
 
diff --git a/paddle/fluid/memory/memcpy.cc b/paddle/fluid/memory/memcpy.cc
index b991360d0442ec2d258443a931a9dcf10b332f1e..eddcaab8befda84dd14ed46c31ac025dfbcc7ca9 100644
--- a/paddle/fluid/memory/memcpy.cc
+++ b/paddle/fluid/memory/memcpy.cc
@@ -56,6 +56,45 @@ void Copy<platform::CUDAPlace, platform::CUDAPlace>(
   }
 }
 
+template <>
+void Copy<platform::CPUPlace, platform::CUDAPinnedPlace>(
+    platform::CPUPlace dst_place, void* dst,
+    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
+  std::memcpy(dst, src, num);
+}
+
+template <>
+void Copy<platform::CUDAPinnedPlace, platform::CPUPlace>(
+    platform::CUDAPinnedPlace dst_place, void* dst,
+    platform::CPUPlace src_place, const void* src, size_t num) {
+  std::memcpy(dst, src, num);
+}
+
+template <>
+void Copy<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>(
+    platform::CUDAPinnedPlace dst_place, void* dst,
+    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
+  std::memcpy(dst, src, num);
+}
+
+template <>
+void Copy<platform::CUDAPinnedPlace, platform::CUDAPlace>(
+    platform::CUDAPinnedPlace dst_place, void* dst,
+    platform::CUDAPlace src_place, const void* src, size_t num,
+    cudaStream_t stream) {
+  platform::SetDeviceId(src_place.device);
+  platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
+}
+
+template <>
+void Copy<platform::CUDAPlace, platform::CUDAPinnedPlace>(
+    platform::CUDAPlace dst_place, void* dst,
+    platform::CUDAPinnedPlace src_place, const void* src, size_t num,
+    cudaStream_t stream) {
+  platform::SetDeviceId(dst_place.device);
+  platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
+}
+
 #endif
 
 }  // namespace memory
diff --git a/paddle/fluid/memory/memory.cc b/paddle/fluid/memory/memory.cc
index 56593653a622bce323306d86156d140c46f58d18..09f82166beab369416e351dbb8ecd09f759bfbda 100644
--- a/paddle/fluid/memory/memory.cc
+++ b/paddle/fluid/memory/memory.cc
@@ -38,8 +38,7 @@ BuddyAllocator* GetCPUBuddyAllocator() {
 }
 
 template <>
-void* Alloc<platform::CPUPlace>(platform::CPUPlace place, size_t size,
-                                bool is_pinned) {
+void* Alloc<platform::CPUPlace>(platform::CPUPlace place, size_t size) {
   VLOG(10) << "Allocate " << size << " bytes on " << platform::Place(place);
   void* p = GetCPUBuddyAllocator()->Alloc(size);
   VLOG(10) << " pointer=" << p;
@@ -47,8 +46,7 @@ void* Alloc<platform::CPUPlace>(platform::CPUPlace place, size_t size,
 }
 
 template <>
-void Free<platform::CPUPlace>(platform::CPUPlace place, void* p,
-                              bool is_pinned) {
+void Free<platform::CPUPlace>(platform::CPUPlace place, void* p) {
   VLOG(10) << "Free pointer=" << p << " on " << platform::Place(place);
   GetCPUBuddyAllocator()->Free(p);
 }
@@ -84,47 +82,15 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
   return as[gpu_id];
 }
 
-BuddyAllocator* GetCUDAPinnedBuddyAllocator(int gpu_id) {
-  static BuddyAllocator** as = NULL;
-  if (as == NULL) {
-    int gpu_num = platform::GetCUDADeviceCount();
-    as = new BuddyAllocator*[gpu_num];
-    for (int gpu = 0; gpu < gpu_num; gpu++) {
-      as[gpu] = nullptr;
-    }
-  }
-  platform::SetDeviceId(gpu_id);
-  if (!as[gpu_id]) {
-    as[gpu_id] = new BuddyAllocator(new detail::CUDAPinnedAllocator,
-                                    platform::GpuMinChunkSize(),
-                                    platform::GpuMaxChunkSize());
-    VLOG(10) << "\n\nNOTE: each GPU device use "
-             << FLAGS_fraction_of_gpu_memory_to_use * 100
-             << "% of GPU memory.\n"
-             << "You can set GFlags environment variable '"
-             << "FLAGS_fraction_of_gpu_memory_to_use"
-             << "' to change the fraction of GPU usage.\n\n";
-  }
-  return as[gpu_id];
-}
-
template <> size_t Used(platform::CUDAPlace place) { return GetGPUBuddyAllocator(place.device)->Used(); } template <> -void* Alloc(platform::CUDAPlace place, size_t size, - bool is_pinned) { - void* ptr; - if (is_pinned) { - auto* buddy_allocator = GetCUDAPinnedBuddyAllocator(place.device); - ptr = buddy_allocator->Alloc(size); - } else { - auto* buddy_allocator = GetGPUBuddyAllocator(place.device); - ptr = buddy_allocator->Alloc(size); - } - +void* Alloc(platform::CUDAPlace place, size_t size) { + auto* buddy_allocator = GetGPUBuddyAllocator(place.device); + auto* ptr = buddy_allocator->Alloc(size); if (ptr == nullptr) { int cur_dev = platform::GetCurrentDeviceId(); platform::SetDeviceId(place.device); @@ -142,15 +108,42 @@ void* Alloc(platform::CUDAPlace place, size_t size, } template <> -void Free(platform::CUDAPlace place, void* p, - bool is_pinned) { - if (is_pinned) { - GetCUDAPinnedBuddyAllocator(place.device)->Free(p); - } else { - GetGPUBuddyAllocator(place.device)->Free(p); +void Free(platform::CUDAPlace place, void* p) { + GetGPUBuddyAllocator(place.device)->Free(p); +} + +BuddyAllocator* GetCUDAPinnedBuddyAllocator() { + static BuddyAllocator* ba = NULL; + if (ba == NULL) { + ba = new BuddyAllocator(new detail::CUDAPinnedAllocator, + platform::CUDAPinnedMinChunkSize(), + platform::CUDAPinnedMaxChunkSize()); } + return ba; } +template <> +size_t Used(platform::CUDAPinnedPlace place) { + return GetCUDAPinnedBuddyAllocator()->Used(); +} + +template <> +void* Alloc(platform::CUDAPinnedPlace place, + size_t size) { + auto* buddy_allocator = GetCUDAPinnedBuddyAllocator(); + void* ptr = buddy_allocator->Alloc(size); + + if (ptr == nullptr) { + LOG(WARNING) << "cudaMallocHost Cannot allocate " << size + << " bytes in CUDAPinnedPlace"; + } + return ptr; +} + +template <> +void Free(platform::CUDAPinnedPlace place, void* p) { + GetCUDAPinnedBuddyAllocator()->Free(p); +} #endif size_t Usage::operator()(const platform::CPUPlace& cpu) const { @@ -165,6 +158,14 @@ size_t Usage::operator()(const platform::CUDAPlace& gpu) const { #endif } +size_t Usage::operator()(const platform::CUDAPinnedPlace& cuda_pinned) const { +#ifdef PADDLE_WITH_CUDA + return Used(cuda_pinned); +#else + PADDLE_THROW("'CUDAPinnedPlace' is not supported in CPU only device."); +#endif +} + size_t memory_usage(const platform::Place& p) { return boost::apply_visitor(Usage(), p); } diff --git a/paddle/fluid/memory/memory.h b/paddle/fluid/memory/memory.h index 062bfc880e78dc5d90c567ffe5c4e521704c9ca6..3e6bfddd69cb16edf323d040ea5369cd551f299e 100644 --- a/paddle/fluid/memory/memory.h +++ b/paddle/fluid/memory/memory.h @@ -33,7 +33,7 @@ namespace memory { * address is valid or not. */ template -void* Alloc(Place place, size_t size, bool is_pinned = false); +void* Alloc(Place place, size_t size); /** * \brief Free memory block in one place. @@ -43,7 +43,7 @@ void* Alloc(Place place, size_t size, bool is_pinned = false); * */ template -void Free(Place place, void* ptr, bool is_pinned = false); +void Free(Place place, void* ptr); /** * \brief Total size of used memory in one place. 
@@ -57,6 +57,7 @@ size_t Used(Place place); struct Usage : public boost::static_visitor { size_t operator()(const platform::CPUPlace& cpu) const; size_t operator()(const platform::CUDAPlace& gpu) const; + size_t operator()(const platform::CUDAPinnedPlace& cuda_pinned) const; }; size_t memory_usage(const platform::Place& p); @@ -74,13 +75,11 @@ class PODDeleter { static_assert(std::is_pod::value, "T must be POD"); public: - explicit PODDeleter(Place place, bool is_pinned = false) - : place_(place), is_pinned_(is_pinned) {} - void operator()(T* ptr) { Free(place_, static_cast(ptr), is_pinned_); } + explicit PODDeleter(Place place) : place_(place) {} + void operator()(T* ptr) { Free(place_, static_cast(ptr)); } private: Place place_; - bool is_pinned_; }; /** diff --git a/paddle/fluid/memory/memory_test.cc b/paddle/fluid/memory/memory_test.cc index eb27a52b254c1cda065197746eb179bbd1d7f2f1..03829702a0c5c3dc177381b4ad3d012fda8f537d 100644 --- a/paddle/fluid/memory/memory_test.cc +++ b/paddle/fluid/memory/memory_test.cc @@ -141,4 +141,59 @@ TEST(BuddyAllocator, GPUMultAlloc) { } } +size_t align(size_t size, paddle::platform::CUDAPinnedPlace place) { + size += sizeof(paddle::memory::detail::Metadata); + size_t alignment = paddle::platform::CUDAPinnedMinChunkSize(); + size_t remaining = size % alignment; + return remaining == 0 ? size : size + (alignment - remaining); +} + +TEST(BuddyAllocator, CUDAPinnedAllocator) { + void *p = nullptr; + + EXPECT_EQ(p, nullptr); + + paddle::platform::CUDAPinnedPlace cpu; + p = paddle::memory::Alloc(cpu, 4096); + + EXPECT_NE(p, nullptr); + + paddle::platform::Place place = cpu; + EXPECT_EQ(paddle::memory::Used(cpu), paddle::memory::memory_usage(place)); + + paddle::memory::Free(cpu, p); +} + +TEST(BuddyAllocator, CUDAPinnedMultAllocator) { + paddle::platform::CUDAPinnedPlace cpu; + + std::unordered_map ps; + + size_t total_size = paddle::memory::Used(cpu); + EXPECT_EQ(total_size, 0UL); + + for (auto size : + {0, 128, 256, 1024, 4096, 16384, 65536, 262144, 1048576, 4194304}) { + ps[paddle::memory::Alloc(cpu, size)] = size; + + // Buddy Allocator doesn't manage too large memory chunk + if (paddle::memory::Used(cpu) == total_size) continue; + + size_t aligned_size = align(size, cpu); + total_size += aligned_size; + EXPECT_EQ(total_size, paddle::memory::Used(cpu)); + } + + for (auto p : ps) { + EXPECT_EQ(is_aligned(p.first), true); + paddle::memory::Free(cpu, p.first); + + // Buddy Allocator doesn't manage too large memory chunk + if (paddle::memory::Used(cpu) == total_size) continue; + + size_t aligned_size = align(p.second, cpu); + total_size -= aligned_size; + EXPECT_EQ(total_size, paddle::memory::Used(cpu)); + } +} #endif diff --git a/paddle/fluid/memory/pinned_memory_test.cu b/paddle/fluid/memory/pinned_memory_test.cu new file mode 100644 index 0000000000000000000000000000000000000000..a000001f41788fb16ac075426f06357cbe42d642 --- /dev/null +++ b/paddle/fluid/memory/pinned_memory_test.cu @@ -0,0 +1,147 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ +#include +#include + +#include "paddle/fluid/memory/detail/memory_block.h" +#include "paddle/fluid/memory/detail/meta_data.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/memory/memory.h" + +#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/fluid/platform/gpu_info.h" +#include "paddle/fluid/platform/place.h" + +// This unit test is an example comparing the performance between using pinned +// memory and not. In general, using pinned memory will be faster. +template +__global__ void Kernel(T* output, int dim) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < dim) { + output[tid] = output[tid] * output[tid] / 100; + } +} + +template +float test_pinned_memory() { + Place cpu_place; + paddle::platform::CUDAPlace cuda_place; + + const int data_size = 4096; + const int iteration = 10; + + // create event start and end + cudaEvent_t start_e, stop_e, copying_e; + float elapsedTime = 0; + cudaEventCreate(&start_e); + cudaEventCreate(&stop_e); + cudaEventCreate(©ing_e); + + // create computation stream, data copying stream + cudaStream_t computation_stream, copying_stream; + cudaStreamCreate(&computation_stream); + cudaStreamCreate(©ing_stream); + + // create record event, pinned memory, gpu memory + std::vector record_event(iteration); + std::vector input_pinned_mem(iteration); + std::vector gpu_mem(iteration); + std::vector output_pinned_mem(iteration); + + // initial data + for (int j = 0; j < iteration; ++j) { + cudaEventCreateWithFlags(&record_event[j], cudaEventDisableTiming); + cudaEventCreate(&(record_event[j])); + input_pinned_mem[j] = static_cast( + paddle::memory::Alloc(cpu_place, data_size * sizeof(float))); + output_pinned_mem[j] = static_cast( + paddle::memory::Alloc(cpu_place, data_size * sizeof(float))); + gpu_mem[j] = static_cast( + paddle::memory::Alloc(cuda_place, data_size * sizeof(float))); + + for (int k = 0; k < data_size; ++k) { + input_pinned_mem[j][k] = k; + } + } + + cudaEventRecord(start_e, computation_stream); + + // computation + for (int m = 0; m < 30; ++m) { + for (int i = 0; i < iteration; ++i) { + // cpu -> GPU on computation stream. + // note: this operation is async for pinned memory. + paddle::memory::Copy(cuda_place, gpu_mem[i], cpu_place, + input_pinned_mem[i], data_size * sizeof(float), + computation_stream); + + // call kernel on computation stream. + Kernel<<<4, 1024, 0, computation_stream>>>(gpu_mem[i], data_size); + + // record event_computation on computation stream + cudaEventRecord(record_event[i], computation_stream); + + // wait event_computation on copy stream. + // note: this operation is async. + cudaStreamWaitEvent(copying_stream, record_event[i], 0); + + // copy data GPU->CPU, on copy stream. + // note: this operation is async for pinned memory. 
+      paddle::memory::Copy(cpu_place, output_pinned_mem[i], cuda_place,
+                           gpu_mem[i], data_size * sizeof(float),
+                           copying_stream);
+    }
+  }
+
+  cudaEventRecord(copying_e, copying_stream);
+  cudaStreamWaitEvent(computation_stream, copying_e, 0);
+
+  cudaEventRecord(stop_e, computation_stream);
+
+  cudaEventSynchronize(start_e);
+  cudaEventSynchronize(stop_e);
+  cudaEventElapsedTime(&elapsedTime, start_e, stop_e);
+
+  // std::cout << cpu_place << " "
+  //           << "time consume:" << elapsedTime / 30 << std::endl;
+
+  for (int l = 0; l < iteration; ++l) {
+    for (int k = 0; k < data_size; ++k) {
+      float temp = input_pinned_mem[l][k];
+      temp = temp * temp / 100;
+      EXPECT_FLOAT_EQ(temp, output_pinned_mem[l][k]);
+    }
+  }
+
+  // destroy resources
+  cudaEventDestroy(copying_e);
+  cudaEventDestroy(start_e);
+  cudaEventDestroy(stop_e);
+  for (int j = 0; j < iteration; ++j) {
+    cudaEventDestroy(record_event[j]);
+    paddle::memory::Free(cpu_place, input_pinned_mem[j]);
+    paddle::memory::Free(cpu_place, output_pinned_mem[j]);
+    paddle::memory::Free(cuda_place, gpu_mem[j]);
+  }
+  return elapsedTime / 30;
+}
+
+TEST(CPUANDCUDAPinned, CPUAllocatorAndCUDAPinnedAllocator) {
+  // Generally speaking, operations on pinned memory are faster than those on
+  // unpinned memory, but if this unit test fails frequently, please disable
+  // it for the time being.
+  float time1 = test_pinned_memory<paddle::platform::CPUPlace>();
+  float time2 = test_pinned_memory<paddle::platform::CUDAPinnedPlace>();
+  EXPECT_GT(time1, time2);
+}
diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc
index 299a0aed01dfe0448d896738d9fd33319b1b2887..44fd739fb1d161c6c7d6ab1cc611c59220280a4e 100644
--- a/paddle/fluid/operators/math/math_function.cc
+++ b/paddle/fluid/operators/math/math_function.cc
@@ -322,6 +322,14 @@ void set_constant_with_place<platform::CPUDeviceContext>(
                            TensorSetConstantCPU(tensor, value));
 }
 
+template <>
+void set_constant_with_place<platform::CUDAPinnedDeviceContext>(
+    const platform::DeviceContext& context, framework::Tensor* tensor,
+    float value) {
+  framework::VisitDataType(framework::ToDataType(tensor->type()),
+                           TensorSetConstantCPU(tensor, value));
+}
+
 struct TensorSetConstantWithPlace : public boost::static_visitor<void> {
   TensorSetConstantWithPlace(const platform::DeviceContext& context,
                              framework::Tensor* tensor, float value)
diff --git a/paddle/fluid/platform/cpu_info.cc b/paddle/fluid/platform/cpu_info.cc
index 8db08edba805e41d33ec6a6a4b338cca0d4906ef..4fc9aae8e36e9b43d65fab0b92ec3a2549057128 100644
--- a/paddle/fluid/platform/cpu_info.cc
+++ b/paddle/fluid/platform/cpu_info.cc
@@ -27,6 +27,11 @@ DEFINE_double(fraction_of_cpu_memory_to_use, 1,
               "Default use 100% of CPU memory for PaddlePaddle,"
               "reserve the rest for page tables, etc");
 
+DEFINE_double(
+    fraction_of_cuda_pinned_memory_to_use, 0.5,
+    "Default use 50% of CPU memory as the pinned_memory for PaddlePaddle,"
+    "reserve the rest for page tables, etc");
+
 namespace paddle {
 namespace platform {
 
@@ -62,5 +67,22 @@ size_t CpuMaxChunkSize() {
   return CpuMaxAllocSize() / 32;
 }
 
+size_t CUDAPinnedMaxAllocSize() {
+  // For distributed systems, it requires configuring and limiting
+  // the fraction of memory to use.
+  return FLAGS_fraction_of_cuda_pinned_memory_to_use * CpuTotalPhysicalMemory();
+}
+
+size_t CUDAPinnedMinChunkSize() {
+  // The minimum chunk size allowed is 64 KB.
+  return 1 << 16;
+}
+
+size_t CUDAPinnedMaxChunkSize() {
+  // The maximum chunk size allowed is roughly 1/256 of the total pinned
+  // memory.
+  return CUDAPinnedMaxAllocSize() / 256;
+}
+
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/cpu_info.h b/paddle/fluid/platform/cpu_info.h
index a930151bd15a33d5b8861c6239e7dd964822f0f6..f06c2b67fe4385f427322e9bb2f3080fdd3acc94 100644
--- a/paddle/fluid/platform/cpu_info.h
+++ b/paddle/fluid/platform/cpu_info.h
@@ -22,11 +22,20 @@ namespace platform {
 //! Get the maximum allocation size for a machine.
 size_t CpuMaxAllocSize();
 
+//! Get the maximum allocation size of host pinned memory for a machine.
+size_t CUDAPinnedMaxAllocSize();
+
 //! Get the minimum chunk size for buddy allocator.
 size_t CpuMinChunkSize();
 
 //! Get the maximum chunk size for buddy allocator.
 size_t CpuMaxChunkSize();
 
+//! Get the minimum chunk size for the CUDA pinned memory buddy allocator.
+size_t CUDAPinnedMinChunkSize();
+
+//! Get the maximum chunk size for the CUDA pinned memory buddy allocator.
+size_t CUDAPinnedMaxChunkSize();
+
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 59b76a1edb5ec5900520fbccb6a6f8f6e7a70aa4..feb4f367008d76d86a93c561a8eec1f2485c99d6 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -53,6 +53,16 @@ DeviceContextPool::DeviceContextPool(
       PADDLE_THROW(
           "'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
          "option");
+#endif
+    } else if (platform::is_cuda_pinned_place(p)) {
+#ifdef PADDLE_WITH_CUDA
+      device_contexts_.emplace(
+          p,
+          PtrType(new CUDAPinnedDeviceContext(boost::get<CUDAPinnedPlace>(p))));
+#else
+      PADDLE_THROW(
+          "'CUDAPinnedPlace' is not supported, Please re-compile with "
+          "WITH_GPU option");
 #endif
     }
   }
@@ -186,6 +196,20 @@
 cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
 
 cudaStream_t CUDADeviceContext::stream() const { return stream_; }
 
+CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
+  eigen_device_.reset(new Eigen::DefaultDevice());
+}
+
+CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
+    : place_(place) {
+  eigen_device_.reset(new Eigen::DefaultDevice());
+}
+
+Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
+  return eigen_device_.get();
+}
+
+Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
 #endif
 
 #ifdef PADDLE_WITH_MKLDNN
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index 202394c7be7e103a609dd0999fc883c794ef0edd..6b796d92d09cdde2db60c7651c03d3782ff013e6 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -118,6 +118,25 @@ struct DefaultDeviceContextType<platform::CUDAPlace> {
   using TYPE = CUDADeviceContext;
 };
 
+// Currently, CUDAPinnedDeviceContext is only used for data copying.
+class CUDAPinnedDeviceContext : public DeviceContext {
+ public:
+  CUDAPinnedDeviceContext();
+  explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place);
+
+  Place GetPlace() const override;
+
+  Eigen::DefaultDevice* eigen_device() const;
+
+ private:
+  CUDAPinnedPlace place_;
+  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
+};
+
+template <>
+struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
+  using TYPE = CUDAPinnedDeviceContext;
+};
 #endif
 
 #ifdef PADDLE_WITH_MKLDNN
diff --git a/paddle/fluid/platform/place.cc b/paddle/fluid/platform/place.cc
index de8f958eb012cb1ac563cbbbac8951e439bf8f33..655ce8485d4584aa0955315b045da6bf541f7fe2 100644
--- a/paddle/fluid/platform/place.cc
+++ b/paddle/fluid/platform/place.cc
@@ -26,6 +26,7 @@ class PlacePrinter : public boost::static_visitor<> {
   void operator()(const CUDAPlace &p) {
     os_ << "CUDAPlace(" << p.device << ")";
   }
+  void operator()(const CUDAPinnedPlace &p) { os_ << "CUDAPinnedPlace"; }
 
  private:
   std::ostream &os_;
@@ -40,12 +41,19 @@ const Place &get_place() { return the_default_place; }
 
 const CUDAPlace default_gpu() { return CUDAPlace(0); }
 const CPUPlace default_cpu() { return CPUPlace(); }
+const CUDAPinnedPlace default_cuda_pinned() { return CUDAPinnedPlace(); }
 
 bool is_gpu_place(const Place &p) {
   return boost::apply_visitor(IsCUDAPlace(), p);
 }
 
-bool is_cpu_place(const Place &p) { return !is_gpu_place(p); }
+bool is_cpu_place(const Place &p) {
+  return boost::apply_visitor(IsCPUPlace(), p);
+}
+
+bool is_cuda_pinned_place(const Place &p) {
+  return boost::apply_visitor(IsCUDAPinnedPlace(), p);
+}
 
 bool places_are_same_class(const Place &p1, const Place &p2) {
   return p1.which() == p2.which();
@@ -53,7 +61,7 @@ bool places_are_same_class(const Place &p1, const Place &p2) {
 
 bool is_same_place(const Place &p1, const Place &p2) {
   if (places_are_same_class(p1, p2)) {
-    if (is_cpu_place(p1)) {
+    if (is_cpu_place(p1) || is_cuda_pinned_place(p1)) {
       return true;
     } else {
       return boost::get<CUDAPlace>(p1) == boost::get<CUDAPlace>(p2);
diff --git a/paddle/fluid/platform/place.h b/paddle/fluid/platform/place.h
index 4cc8b377b8b671eb5a446ecbae21ba9628fbd2c8..d0bdcb0da5177f9f8ad517787e612f1b98b3fbb4 100644
--- a/paddle/fluid/platform/place.h
+++ b/paddle/fluid/platform/place.h
@@ -45,12 +45,33 @@ struct CUDAPlace {
   int device;
 };
 
+struct CUDAPinnedPlace {
+  CUDAPinnedPlace() {}
+
+  // needed for variant equality comparison
+  inline bool operator==(const CUDAPinnedPlace &) const { return true; }
+  inline bool operator!=(const CUDAPinnedPlace &) const { return false; }
+};
+
 struct IsCUDAPlace : public boost::static_visitor<bool> {
   bool operator()(const CPUPlace &) const { return false; }
   bool operator()(const CUDAPlace &gpu) const { return true; }
+  bool operator()(const CUDAPinnedPlace &) const { return false; }
 };
 
-typedef boost::variant<CUDAPlace, CPUPlace> Place;
+struct IsCPUPlace : public boost::static_visitor<bool> {
+  bool operator()(const CPUPlace &cpu) const { return true; }
+  bool operator()(const CUDAPlace &) const { return false; }
+  bool operator()(const CUDAPinnedPlace &) const { return false; }
+};
+
+struct IsCUDAPinnedPlace : public boost::static_visitor<bool> {
+  bool operator()(const CPUPlace &) const { return false; }
+  bool operator()(const CUDAPlace &) const { return false; }
+  bool operator()(const CUDAPinnedPlace &cuda_pinned) const { return true; }
+};
+
+typedef boost::variant<CUDAPlace, CPUPlace, CUDAPinnedPlace> Place;
 
 using PlaceList = std::vector<Place>;
 
@@ -59,9 +80,11 @@ const Place &get_place();
 
 const CUDAPlace default_gpu();
 const CPUPlace default_cpu();
+const CUDAPinnedPlace default_cuda_pinned();
 
 bool is_gpu_place(const Place &);
 bool is_cpu_place(const Place &);
+bool is_cuda_pinned_place(const Place &);
 bool places_are_same_class(const Place &, const Place &);
 bool is_same_place(const Place &, const Place &);
 
@@ -95,6 +118,16 @@ struct PlaceVisitorWrapper
 #else
     PADDLE_THROW("Paddle is not compiled with CUDA. Cannot visit cuda device");
     return typename Visitor::result_type();
+#endif
+  }
+
+  typename Visitor::result_type operator()(
+      const CUDAPinnedPlace &cuda_pinned) const {
+#ifdef PADDLE_WITH_CUDA
+    return visitor_(cuda_pinned);
+#else
+    PADDLE_THROW("Paddle is not compiled with CUDA. Cannot visit cuda_pinned");
+    return typename Visitor::result_type();
 #endif
   }
 };