diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index c277bd7cb69bba899296efe64107ee538c4aa847..128a5344fbb8c64c36ade24475bd0d99bdb3e0f5 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -21,6 +21,9 @@ #include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" #endif +#include +#include + namespace paddle { namespace framework { namespace details { @@ -168,6 +171,11 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( */ PolishGraphToSupportDataHazards(&result); + /* + * Only variables should be the leaves of graph. + */ + AddOutputToLeafOps(&result); + if (VLOG_IS_ON(10)) { std::ostringstream sout; PrintGraphviz(*graph, sout); diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 361ba6d39721eed406a30fea325b3b4508ec45d0..0a4febd22f3feefdcac99cafc2cb58269380d192 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -136,6 +136,17 @@ void SSAGraphBuilder::PrintGraphviz(const SSAGraph &graph, std::ostream &sout) { sout << "}\n"; } + +void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) { + for (auto &op : graph->ops_) { + if (!op->outputs_.empty()) { + continue; + } + auto *dummy_leaf = new DummyVarHandle(); + graph->dep_vars_.emplace(dummy_leaf); + op->AddOutput(dummy_leaf); + } +} } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index bf20e7164a100718c1dcfe3ef971cfff60bbbaa2..be1f0460e45402806b18835f054a7195df1374cc 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -14,13 +14,13 @@ #pragma once +#include +#include + #include "paddle/fluid/framework/details/ssa_graph.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/platform/place.h" -#include -#include - namespace paddle { namespace framework { namespace details { @@ -52,6 +52,8 @@ class SSAGraphBuilder { const std::string &each_var_name, const platform::Place &place, size_t place_offset); + static void AddOutputToLeafOps(SSAGraph *graph); + static void PrintGraphviz(const SSAGraph &graph, std::ostream &sout); }; } // namespace details diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 1f96b9dc6235a18f7566c98cca60baa964e6aa56..596e5731868630cebc3cf51b2e78d4deb39a9b33 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -87,7 +87,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( // Step 2. Insert FetchOps std::vector> fetch_ops; - std::vector dummy_vars; FeedFetchList fetch_data(fetch_tensors.size()); std::unordered_map> fetched_vars; @@ -101,13 +100,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( } } + std::unordered_set> fetch_dependencies; for (size_t i = 0; i < fetch_tensors.size(); ++i) { auto &var_name = fetch_tensors[i]; auto &vars = fetched_vars.at(var_name); auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_); fetch_ops.emplace_back(op); - // FIXME: Use new device context for (auto &p : places_) { op->dev_ctxes_[p] = fetch_ctxs_.Get(p); } @@ -115,6 +114,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( for (auto *var : vars) { op->AddInput(var); } + + auto *fetch_dummy = new DummyVarHandle(); + op->AddOutput(fetch_dummy); + fetch_dependencies.emplace(fetch_dummy); + InsertPendingVar(*fetch_dummy); InsertPendingOp(*op); } diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index f7a6b5ba84ca1762bd903790aa3c0346b22ed035..6f878541e6de1deec1829145b1b325ecd176a034 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -45,11 +45,10 @@ class Tensor { friend struct EigenVector; public: - Tensor() : offset_(0), is_pinned_(false) {} + Tensor() : offset_(0) {} /*! Constructor with place should only be used in pybind. */ - explicit Tensor(const platform::Place& place) - : offset_(0), is_pinned_(false) { + explicit Tensor(const platform::Place& place) : offset_(0) { holder_->set_place(place); } @@ -70,12 +69,11 @@ class Tensor { * @note If not exist, then allocation. */ template - inline T* mutable_data(platform::Place place, bool is_pinned = false); + inline T* mutable_data(platform::Place place); - inline void* mutable_data(platform::Place place, std::type_index type, - bool is_pinned = false); + inline void* mutable_data(platform::Place place, std::type_index type); - inline void* mutable_data(platform::Place place, bool is_pinned = false); + inline void* mutable_data(platform::Place place); /** * @brief Return a pointer to mutable memory block. @@ -86,8 +84,7 @@ class Tensor { * @note If not exist, then allocation. */ template - inline T* mutable_data(DDim dims, platform::Place place, - bool is_pinned = false); + inline T* mutable_data(DDim dims, platform::Place place); /*! Return the dimensions of the memory block. */ inline const DDim& dims() const; @@ -95,9 +92,6 @@ class Tensor { /*! Return the numel of the memory block. */ inline int64_t numel() const; - /*! Return the numel of the memory block. */ - inline bool isPinned() const; - /*! Resize the dimensions of the memory block. */ inline Tensor& Resize(const DDim& dims); @@ -152,14 +146,12 @@ class Tensor { template struct PlaceholderImpl : public Placeholder { - PlaceholderImpl(Place place, size_t size, std::type_index type, - bool is_pinned = false) - : ptr_(static_cast(memory::Alloc(place, size, is_pinned)), - memory::PODDeleter(place, is_pinned)), + PlaceholderImpl(Place place, size_t size, std::type_index type) + : ptr_(static_cast(memory::Alloc(place, size)), + memory::PODDeleter(place)), place_(place), size_(size), - type_(type), - is_pinned_(is_pinned) { + type_(type) { PADDLE_ENFORCE_NOT_NULL(ptr_, "Insufficient %s memory to allocation.", (is_cpu_place(place_) ? "CPU" : "GPU")); } @@ -182,9 +174,6 @@ class Tensor { /* the current type of memory */ std::type_index type_; - - /*! use pinned memory or not. */ - bool is_pinned_; }; /*! holds the memory block if allocated. */ @@ -219,7 +208,6 @@ class Tensor { * PlaceHolder::ptr_ and where the tensor data really begins. */ size_t offset_; - bool is_pinned_; }; inline void Tensor::switch_place(platform::Place new_place) { diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h index 113814971e115fa88bd0ded34017fa26a9dd5803..7a4839044008338dda43f75b5ee6def500b78270 100644 --- a/paddle/fluid/framework/tensor_impl.h +++ b/paddle/fluid/framework/tensor_impl.h @@ -101,21 +101,19 @@ inline T* Tensor::data() { } template -inline T* Tensor::mutable_data(DDim dims, platform::Place place, - bool is_pinned) { +inline T* Tensor::mutable_data(DDim dims, platform::Place place) { static_assert(std::is_pod::value, "T must be POD"); Resize(dims); - return mutable_data(place, is_pinned); + return mutable_data(place); } template -inline T* Tensor::mutable_data(platform::Place place, bool is_pinned) { +inline T* Tensor::mutable_data(platform::Place place) { static_assert(std::is_pod::value, "T must be POD"); - return reinterpret_cast(mutable_data(place, typeid(T), is_pinned)); + return reinterpret_cast(mutable_data(place, typeid(T))); } -inline void* Tensor::mutable_data(platform::Place place, std::type_index type, - bool is_pinned) { +inline void* Tensor::mutable_data(platform::Place place, std::type_index type) { if (holder_ != nullptr) { holder_->set_type(type); } @@ -129,27 +127,26 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type, holder_->size() < size + offset_) { if (platform::is_cpu_place(place)) { holder_.reset(new PlaceholderImpl( - boost::get(place), size, type, is_pinned)); + boost::get(place), size, type)); } else if (platform::is_gpu_place(place)) { #ifndef PADDLE_WITH_CUDA PADDLE_THROW("'CUDAPlace' is not supported in CPU only device."); } #else holder_.reset(new PlaceholderImpl( - boost::get(place), size, type, is_pinned)); + boost::get(place), size, type)); } #endif offset_ = 0; - is_pinned_ = is_pinned; } return reinterpret_cast(reinterpret_cast(holder_->ptr()) + offset_); } -inline void* Tensor::mutable_data(platform::Place place, bool is_pinned) { +inline void* Tensor::mutable_data(platform::Place place) { PADDLE_ENFORCE(this->holder_ != nullptr, "Cannot invoke mutable data if current hold nothing"); - return mutable_data(place, holder_->type(), is_pinned); + return mutable_data(place, holder_->type()); } inline Tensor& Tensor::ShareDataWith(const Tensor& src) { @@ -191,8 +188,6 @@ inline const DDim& Tensor::dims() const { return dims_; } inline int64_t Tensor::numel() const { return product(dims_); } -inline bool Tensor::isPinned() const { return is_pinned_; } - inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { Tensor res; res.ShareDataWith(src); diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 8b7533ce712b0a01060842b6f71449ed6bd23e2c..1d864af011bced9df188147ec436b8de12947ba9 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -148,6 +148,11 @@ struct AnyVisitor : public boost::static_visitor { const platform::CPUPlace& cpu) const { return *out.data(); } + + bool GetResult(const framework::Tensor& out, + const platform::CUDAPinnedPlace& cpu) const { + return *out.data(); + } }; template diff --git a/paddle/fluid/memory/CMakeLists.txt b/paddle/fluid/memory/CMakeLists.txt index 1a61c484823b292234d4758cdc1959d7a21510e6..8b3043af7a18787a08583d47b76da679ccb63740 100644 --- a/paddle/fluid/memory/CMakeLists.txt +++ b/paddle/fluid/memory/CMakeLists.txt @@ -4,13 +4,17 @@ cc_library(memory SRCS memory.cc DEPS place enforce) cc_library(memcpy SRCS memcpy.cc DEPS place) cc_library(paddle_memory - DEPS - memory - memcpy - meta_data - meta_cache - memory_block - buddy_allocator - system_allocator) + DEPS + memory + memcpy + meta_data + meta_cache + memory_block + buddy_allocator + system_allocator) cc_test(memory_test SRCS memory_test.cc DEPS place paddle_memory) + +#if (WITH_GPU) +# nv_test(pinned_memory_test SRCS pinned_memory_test.cu DEPS place paddle_memory) +#endif() diff --git a/paddle/fluid/memory/detail/system_allocator.cc b/paddle/fluid/memory/detail/system_allocator.cc index 22f6f506748735d1a0fe75375aeea22bd92b8b7e..a45f8c33ee5956f3409ee1b7c43628aa0acafb98 100644 --- a/paddle/fluid/memory/detail/system_allocator.cc +++ b/paddle/fluid/memory/detail/system_allocator.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/memory/detail/system_allocator.h" #include "paddle/fluid/platform/assert.h" +#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/gpu_info.h" @@ -134,21 +135,31 @@ bool GPUAllocator::UseGpu() const { return true; } // memory. It’s locked to a physical address. void* CUDAPinnedAllocator::Alloc(size_t& index, size_t size) { if (size <= 0) return nullptr; - void* p; - // NOTE: here, we use GpuMaxAllocSize() as the maximum memory size + + // NOTE: here, we use CUDAPinnedMaxAllocSize as the maximum memory size // of host pinned allocation. Allocates too much would reduce // the amount of memory available to the underlying system for paging. + size_t usable = + paddle::platform::CUDAPinnedMaxAllocSize() - cuda_pinnd_alloc_size_; - size_t usable = paddle::platform::GpuMaxAllocSize() - fallback_alloc_size_; - - if (size > usable) return nullptr; + if (size > usable) { + LOG(WARNING) << "Cannot malloc " << size / 1024.0 / 1024.0 + << " MB pinned memory." + << ", available " << usable / 1024.0 / 1024.0 << " MB"; + return nullptr; + } + void* p; // PINNED memory is visible to all CUDA contexts. cudaError_t result = cudaMallocHost(&p, size); + if (result == cudaSuccess) { - index = 1; - fallback_alloc_size_ += size; + index = 1; // PINNED memory + cuda_pinnd_alloc_size_ += size; return p; + } else { + LOG(WARNING) << "cudaMallocHost failed."; + return nullptr; } return nullptr; @@ -158,8 +169,8 @@ void CUDAPinnedAllocator::Free(void* p, size_t size, size_t index) { cudaError_t err; PADDLE_ASSERT(index == 1); - PADDLE_ASSERT(fallback_alloc_size_ >= size); - fallback_alloc_size_ -= size; + PADDLE_ASSERT(cuda_pinnd_alloc_size_ >= size); + cuda_pinnd_alloc_size_ -= size; err = cudaFreeHost(p); // Purposefully allow cudaErrorCudartUnloading, because @@ -172,7 +183,7 @@ void CUDAPinnedAllocator::Free(void* p, size_t size, size_t index) { } } -bool CUDAPinnedAllocator::UseGpu() const { return true; } +bool CUDAPinnedAllocator::UseGpu() const { return false; } #endif diff --git a/paddle/fluid/memory/detail/system_allocator.h b/paddle/fluid/memory/detail/system_allocator.h index e8479e73f433f1d741b2933da4843c0ba80276d5..e3c50ef6483c61e2016bbd967a4100057c87dca3 100644 --- a/paddle/fluid/memory/detail/system_allocator.h +++ b/paddle/fluid/memory/detail/system_allocator.h @@ -21,8 +21,9 @@ namespace memory { namespace detail { /** - * \brief SystemAllocator is the parent class of CPUAllocator and GPUAllocator. - * A BuddyAllocator object uses a SystemAllocator* pointing to the + * \brief SystemAllocator is the parent class of CPUAllocator, + * CUDAPinnedAllocator and GPUAllocator. A BuddyAllocator + * object uses a SystemAllocator* pointing to the * underlying system allocator. */ class SystemAllocator { @@ -62,9 +63,7 @@ class CUDAPinnedAllocator : public SystemAllocator { virtual bool UseGpu() const; private: - size_t gpu_alloc_size_ = - 0; // TODO(zcd): how to define the upper limit of CUDAPinnedMemory? - size_t fallback_alloc_size_ = 0; + size_t cuda_pinnd_alloc_size_ = 0; }; #endif diff --git a/paddle/fluid/memory/memcpy.cc b/paddle/fluid/memory/memcpy.cc index b991360d0442ec2d258443a931a9dcf10b332f1e..eddcaab8befda84dd14ed46c31ac025dfbcc7ca9 100644 --- a/paddle/fluid/memory/memcpy.cc +++ b/paddle/fluid/memory/memcpy.cc @@ -56,6 +56,45 @@ void Copy( } } +template <> +void Copy( + platform::CPUPlace dst_place, void* dst, + platform::CUDAPinnedPlace src_place, const void* src, size_t num) { + std::memcpy(dst, src, num); +} + +template <> +void Copy( + platform::CUDAPinnedPlace dst_place, void* dst, + platform::CPUPlace src_place, const void* src, size_t num) { + std::memcpy(dst, src, num); +} + +template <> +void Copy( + platform::CUDAPinnedPlace dst_place, void* dst, + platform::CUDAPinnedPlace src_place, const void* src, size_t num) { + std::memcpy(dst, src, num); +} + +template <> +void Copy( + platform::CUDAPinnedPlace dst_place, void* dst, + platform::CUDAPlace src_place, const void* src, size_t num, + cudaStream_t stream) { + platform::SetDeviceId(src_place.device); + platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream); +} + +template <> +void Copy( + platform::CUDAPlace dst_place, void* dst, + platform::CUDAPinnedPlace src_place, const void* src, size_t num, + cudaStream_t stream) { + platform::SetDeviceId(dst_place.device); + platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream); +} + #endif } // namespace memory diff --git a/paddle/fluid/memory/memory.cc b/paddle/fluid/memory/memory.cc index 56593653a622bce323306d86156d140c46f58d18..09f82166beab369416e351dbb8ecd09f759bfbda 100644 --- a/paddle/fluid/memory/memory.cc +++ b/paddle/fluid/memory/memory.cc @@ -38,8 +38,7 @@ BuddyAllocator* GetCPUBuddyAllocator() { } template <> -void* Alloc(platform::CPUPlace place, size_t size, - bool is_pinned) { +void* Alloc(platform::CPUPlace place, size_t size) { VLOG(10) << "Allocate " << size << " bytes on " << platform::Place(place); void* p = GetCPUBuddyAllocator()->Alloc(size); VLOG(10) << " pointer=" << p; @@ -47,8 +46,7 @@ void* Alloc(platform::CPUPlace place, size_t size, } template <> -void Free(platform::CPUPlace place, void* p, - bool is_pinned) { +void Free(platform::CPUPlace place, void* p) { VLOG(10) << "Free pointer=" << p << " on " << platform::Place(place); GetCPUBuddyAllocator()->Free(p); } @@ -84,47 +82,15 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { return as[gpu_id]; } -BuddyAllocator* GetCUDAPinnedBuddyAllocator(int gpu_id) { - static BuddyAllocator** as = NULL; - if (as == NULL) { - int gpu_num = platform::GetCUDADeviceCount(); - as = new BuddyAllocator*[gpu_num]; - for (int gpu = 0; gpu < gpu_num; gpu++) { - as[gpu] = nullptr; - } - } - platform::SetDeviceId(gpu_id); - if (!as[gpu_id]) { - as[gpu_id] = new BuddyAllocator(new detail::CUDAPinnedAllocator, - platform::GpuMinChunkSize(), - platform::GpuMaxChunkSize()); - VLOG(10) << "\n\nNOTE: each GPU device use " - << FLAGS_fraction_of_gpu_memory_to_use * 100 - << "% of GPU memory.\n" - << "You can set GFlags environment variable '" - << "FLAGS_fraction_of_gpu_memory_to_use" - << "' to change the fraction of GPU usage.\n\n"; - } - return as[gpu_id]; -} - template <> size_t Used(platform::CUDAPlace place) { return GetGPUBuddyAllocator(place.device)->Used(); } template <> -void* Alloc(platform::CUDAPlace place, size_t size, - bool is_pinned) { - void* ptr; - if (is_pinned) { - auto* buddy_allocator = GetCUDAPinnedBuddyAllocator(place.device); - ptr = buddy_allocator->Alloc(size); - } else { - auto* buddy_allocator = GetGPUBuddyAllocator(place.device); - ptr = buddy_allocator->Alloc(size); - } - +void* Alloc(platform::CUDAPlace place, size_t size) { + auto* buddy_allocator = GetGPUBuddyAllocator(place.device); + auto* ptr = buddy_allocator->Alloc(size); if (ptr == nullptr) { int cur_dev = platform::GetCurrentDeviceId(); platform::SetDeviceId(place.device); @@ -142,15 +108,42 @@ void* Alloc(platform::CUDAPlace place, size_t size, } template <> -void Free(platform::CUDAPlace place, void* p, - bool is_pinned) { - if (is_pinned) { - GetCUDAPinnedBuddyAllocator(place.device)->Free(p); - } else { - GetGPUBuddyAllocator(place.device)->Free(p); +void Free(platform::CUDAPlace place, void* p) { + GetGPUBuddyAllocator(place.device)->Free(p); +} + +BuddyAllocator* GetCUDAPinnedBuddyAllocator() { + static BuddyAllocator* ba = NULL; + if (ba == NULL) { + ba = new BuddyAllocator(new detail::CUDAPinnedAllocator, + platform::CUDAPinnedMinChunkSize(), + platform::CUDAPinnedMaxChunkSize()); } + return ba; } +template <> +size_t Used(platform::CUDAPinnedPlace place) { + return GetCUDAPinnedBuddyAllocator()->Used(); +} + +template <> +void* Alloc(platform::CUDAPinnedPlace place, + size_t size) { + auto* buddy_allocator = GetCUDAPinnedBuddyAllocator(); + void* ptr = buddy_allocator->Alloc(size); + + if (ptr == nullptr) { + LOG(WARNING) << "cudaMallocHost Cannot allocate " << size + << " bytes in CUDAPinnedPlace"; + } + return ptr; +} + +template <> +void Free(platform::CUDAPinnedPlace place, void* p) { + GetCUDAPinnedBuddyAllocator()->Free(p); +} #endif size_t Usage::operator()(const platform::CPUPlace& cpu) const { @@ -165,6 +158,14 @@ size_t Usage::operator()(const platform::CUDAPlace& gpu) const { #endif } +size_t Usage::operator()(const platform::CUDAPinnedPlace& cuda_pinned) const { +#ifdef PADDLE_WITH_CUDA + return Used(cuda_pinned); +#else + PADDLE_THROW("'CUDAPinnedPlace' is not supported in CPU only device."); +#endif +} + size_t memory_usage(const platform::Place& p) { return boost::apply_visitor(Usage(), p); } diff --git a/paddle/fluid/memory/memory.h b/paddle/fluid/memory/memory.h index 062bfc880e78dc5d90c567ffe5c4e521704c9ca6..3e6bfddd69cb16edf323d040ea5369cd551f299e 100644 --- a/paddle/fluid/memory/memory.h +++ b/paddle/fluid/memory/memory.h @@ -33,7 +33,7 @@ namespace memory { * address is valid or not. */ template -void* Alloc(Place place, size_t size, bool is_pinned = false); +void* Alloc(Place place, size_t size); /** * \brief Free memory block in one place. @@ -43,7 +43,7 @@ void* Alloc(Place place, size_t size, bool is_pinned = false); * */ template -void Free(Place place, void* ptr, bool is_pinned = false); +void Free(Place place, void* ptr); /** * \brief Total size of used memory in one place. @@ -57,6 +57,7 @@ size_t Used(Place place); struct Usage : public boost::static_visitor { size_t operator()(const platform::CPUPlace& cpu) const; size_t operator()(const platform::CUDAPlace& gpu) const; + size_t operator()(const platform::CUDAPinnedPlace& cuda_pinned) const; }; size_t memory_usage(const platform::Place& p); @@ -74,13 +75,11 @@ class PODDeleter { static_assert(std::is_pod::value, "T must be POD"); public: - explicit PODDeleter(Place place, bool is_pinned = false) - : place_(place), is_pinned_(is_pinned) {} - void operator()(T* ptr) { Free(place_, static_cast(ptr), is_pinned_); } + explicit PODDeleter(Place place) : place_(place) {} + void operator()(T* ptr) { Free(place_, static_cast(ptr)); } private: Place place_; - bool is_pinned_; }; /** diff --git a/paddle/fluid/memory/memory_test.cc b/paddle/fluid/memory/memory_test.cc index eb27a52b254c1cda065197746eb179bbd1d7f2f1..03829702a0c5c3dc177381b4ad3d012fda8f537d 100644 --- a/paddle/fluid/memory/memory_test.cc +++ b/paddle/fluid/memory/memory_test.cc @@ -141,4 +141,59 @@ TEST(BuddyAllocator, GPUMultAlloc) { } } +size_t align(size_t size, paddle::platform::CUDAPinnedPlace place) { + size += sizeof(paddle::memory::detail::Metadata); + size_t alignment = paddle::platform::CUDAPinnedMinChunkSize(); + size_t remaining = size % alignment; + return remaining == 0 ? size : size + (alignment - remaining); +} + +TEST(BuddyAllocator, CUDAPinnedAllocator) { + void *p = nullptr; + + EXPECT_EQ(p, nullptr); + + paddle::platform::CUDAPinnedPlace cpu; + p = paddle::memory::Alloc(cpu, 4096); + + EXPECT_NE(p, nullptr); + + paddle::platform::Place place = cpu; + EXPECT_EQ(paddle::memory::Used(cpu), paddle::memory::memory_usage(place)); + + paddle::memory::Free(cpu, p); +} + +TEST(BuddyAllocator, CUDAPinnedMultAllocator) { + paddle::platform::CUDAPinnedPlace cpu; + + std::unordered_map ps; + + size_t total_size = paddle::memory::Used(cpu); + EXPECT_EQ(total_size, 0UL); + + for (auto size : + {0, 128, 256, 1024, 4096, 16384, 65536, 262144, 1048576, 4194304}) { + ps[paddle::memory::Alloc(cpu, size)] = size; + + // Buddy Allocator doesn't manage too large memory chunk + if (paddle::memory::Used(cpu) == total_size) continue; + + size_t aligned_size = align(size, cpu); + total_size += aligned_size; + EXPECT_EQ(total_size, paddle::memory::Used(cpu)); + } + + for (auto p : ps) { + EXPECT_EQ(is_aligned(p.first), true); + paddle::memory::Free(cpu, p.first); + + // Buddy Allocator doesn't manage too large memory chunk + if (paddle::memory::Used(cpu) == total_size) continue; + + size_t aligned_size = align(p.second, cpu); + total_size -= aligned_size; + EXPECT_EQ(total_size, paddle::memory::Used(cpu)); + } +} #endif diff --git a/paddle/fluid/memory/pinned_memory_test.cu b/paddle/fluid/memory/pinned_memory_test.cu new file mode 100644 index 0000000000000000000000000000000000000000..a000001f41788fb16ac075426f06357cbe42d642 --- /dev/null +++ b/paddle/fluid/memory/pinned_memory_test.cu @@ -0,0 +1,147 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include +#include + +#include "paddle/fluid/memory/detail/memory_block.h" +#include "paddle/fluid/memory/detail/meta_data.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/memory/memory.h" + +#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/fluid/platform/gpu_info.h" +#include "paddle/fluid/platform/place.h" + +// This unit test is an example comparing the performance between using pinned +// memory and not. In general, using pinned memory will be faster. +template +__global__ void Kernel(T* output, int dim) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < dim) { + output[tid] = output[tid] * output[tid] / 100; + } +} + +template +float test_pinned_memory() { + Place cpu_place; + paddle::platform::CUDAPlace cuda_place; + + const int data_size = 4096; + const int iteration = 10; + + // create event start and end + cudaEvent_t start_e, stop_e, copying_e; + float elapsedTime = 0; + cudaEventCreate(&start_e); + cudaEventCreate(&stop_e); + cudaEventCreate(©ing_e); + + // create computation stream, data copying stream + cudaStream_t computation_stream, copying_stream; + cudaStreamCreate(&computation_stream); + cudaStreamCreate(©ing_stream); + + // create record event, pinned memory, gpu memory + std::vector record_event(iteration); + std::vector input_pinned_mem(iteration); + std::vector gpu_mem(iteration); + std::vector output_pinned_mem(iteration); + + // initial data + for (int j = 0; j < iteration; ++j) { + cudaEventCreateWithFlags(&record_event[j], cudaEventDisableTiming); + cudaEventCreate(&(record_event[j])); + input_pinned_mem[j] = static_cast( + paddle::memory::Alloc(cpu_place, data_size * sizeof(float))); + output_pinned_mem[j] = static_cast( + paddle::memory::Alloc(cpu_place, data_size * sizeof(float))); + gpu_mem[j] = static_cast( + paddle::memory::Alloc(cuda_place, data_size * sizeof(float))); + + for (int k = 0; k < data_size; ++k) { + input_pinned_mem[j][k] = k; + } + } + + cudaEventRecord(start_e, computation_stream); + + // computation + for (int m = 0; m < 30; ++m) { + for (int i = 0; i < iteration; ++i) { + // cpu -> GPU on computation stream. + // note: this operation is async for pinned memory. + paddle::memory::Copy(cuda_place, gpu_mem[i], cpu_place, + input_pinned_mem[i], data_size * sizeof(float), + computation_stream); + + // call kernel on computation stream. + Kernel<<<4, 1024, 0, computation_stream>>>(gpu_mem[i], data_size); + + // record event_computation on computation stream + cudaEventRecord(record_event[i], computation_stream); + + // wait event_computation on copy stream. + // note: this operation is async. + cudaStreamWaitEvent(copying_stream, record_event[i], 0); + + // copy data GPU->CPU, on copy stream. + // note: this operation is async for pinned memory. + paddle::memory::Copy(cpu_place, output_pinned_mem[i], cuda_place, + gpu_mem[i], data_size * sizeof(float), + copying_stream); + } + } + + cudaEventRecord(copying_e, copying_stream); + cudaStreamWaitEvent(computation_stream, copying_e, 0); + + cudaEventRecord(stop_e, computation_stream); + + cudaEventSynchronize(start_e); + cudaEventSynchronize(stop_e); + cudaEventElapsedTime(&elapsedTime, start_e, stop_e); + + // std::cout << cpu_place << " " + // << "time consume:" << elapsedTime / 30 << std::endl; + + for (int l = 0; l < iteration; ++l) { + for (int k = 0; k < data_size; ++k) { + float temp = input_pinned_mem[l][k]; + temp = temp * temp / 100; + EXPECT_FLOAT_EQ(temp, output_pinned_mem[l][k]); + } + } + + // destroy resource + cudaEventDestroy(copying_e); + cudaEventDestroy(start_e); + cudaEventDestroy(stop_e); + for (int j = 0; j < 10; ++j) { + cudaEventDestroy((record_event[j])); + paddle::memory::Free(cpu_place, input_pinned_mem[j]); + paddle::memory::Free(cpu_place, output_pinned_mem[j]); + paddle::memory::Free(cuda_place, gpu_mem[j]); + } + return elapsedTime / 30; +} + +TEST(CPUANDCUDAPinned, CPUAllocatorAndCUDAPinnedAllocator) { + // Generally speaking, operation on pinned_memory is faster than that on + // unpinned-memory, but if this unit test fails frequently, please close this + // test for the time being. + float time1 = test_pinned_memory(); + float time2 = test_pinned_memory(); + EXPECT_GT(time1, time2); +} diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 650bc92be22af9ea8afcacf590a11190109e8811..695db841a4ec666b2c8783dfc7df959711341d85 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/conv_op.h" + +#include +#include + #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cudnn_helper.h" #endif diff --git a/paddle/fluid/operators/detail/bytebuffer_stream.cc b/paddle/fluid/operators/detail/bytebuffer_stream.cc index 741dd51de9e75feb608161579e56cb160b058ebb..a14171563edb0ac9a22b7ae493c965de3efb7823 100644 --- a/paddle/fluid/operators/detail/bytebuffer_stream.cc +++ b/paddle/fluid/operators/detail/bytebuffer_stream.cc @@ -17,7 +17,7 @@ limitations under the License. */ // file and did some modifications so that we can send gRPC // requests without too much copying of the tensor data. -#include "bytebuffer_stream.h" +#include "paddle/fluid/operators/detail/bytebuffer_stream.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/detail/bytebuffer_stream.h b/paddle/fluid/operators/detail/bytebuffer_stream.h index 1791a48aab1b66147f645c90757b35ef5f6e001b..054dd4ff294414cca55d7e033f2c5403bbb85526 100644 --- a/paddle/fluid/operators/detail/bytebuffer_stream.h +++ b/paddle/fluid/operators/detail/bytebuffer_stream.h @@ -19,9 +19,11 @@ limitations under the License. */ #pragma once -#include +#include + #include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/zero_copy_stream.h" +#include "grpc++/grpc++.h" namespace grpc { // A ZeroCopyInputStream that reads from grpc_byte_buffer @@ -56,7 +58,7 @@ class GrpcBufferReader final *data = GRPC_SLICE_START_PTR(slice_) + GRPC_SLICE_LENGTH(slice_) - backup_count_; GPR_CODEGEN_ASSERT(backup_count_ <= INT_MAX); - *size = (int)backup_count_; + *size = static_cast(backup_count_); backup_count_ = 0; return true; } @@ -68,7 +70,7 @@ class GrpcBufferReader final *data = GRPC_SLICE_START_PTR(slice_); // On win x64, int is only 32bit GPR_CODEGEN_ASSERT(GRPC_SLICE_LENGTH(slice_) <= INT_MAX); - byte_count_ += * size = (int)GRPC_SLICE_LENGTH(slice_); + byte_count_ += * size = static_cast(GRPC_SLICE_LENGTH(slice_)); return true; } diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index d79ba6d291950e1f089eb11713bd1c3e4d154b27..ef987d07f08525bff5267cdc2076ae767417e4f1 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -14,6 +14,8 @@ limitations under the License. */ #include "paddle/fluid/operators/detail/grpc_client.h" +#include + #include #include "paddle/fluid/framework/threadpool.h" @@ -54,7 +56,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, auto call = s->stub_g_.PrepareUnaryCall( s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_); call->StartCall(); - call->Finish(&s->reply_, &s->status_, static_cast(s)); + call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); }); req_count_++; @@ -66,7 +68,7 @@ void ProcGetResponse(const VarHandle& var_h, // const sendrecv::VariableMessage& ret_msg) { const ::grpc::ByteBuffer& ret_msg) { framework::Variable* outvar = NULL; - DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, outvar); + DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar); } template @@ -110,7 +112,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, auto call = s->stub_g_.PrepareUnaryCall( s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_); call->StartCall(); - call->Finish(&s->reply_, &s->status_, static_cast(s)); + call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); }); req_count_++; @@ -170,7 +172,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { sendrecv::VariableMessage req; req.set_varname(BATCH_BARRIER_MESSAGE); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); - rpc->Finish(&s->reply_, &s->status_, static_cast(s)); + rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); req_count_++; } @@ -182,7 +184,7 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { sendrecv::VariableMessage req; req.set_varname(FETCH_BARRIER_MESSAGE); auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); - rpc->Finish(&s->reply_, &s->status_, static_cast(s)); + rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); req_count_++; } diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index fe237e54ef61fb5b6e9bfa46fbe6b3df3dd40265..4425b19328f503eb7f9022916ed6452cdfea4eeb 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -14,10 +14,9 @@ limitations under the License. */ #pragma once -#include -#include #include -#include + +#include // NOLINT #include #include #include @@ -25,11 +24,11 @@ limitations under the License. */ #include #include -#include -#include -#include -#include - +#include "grpc++/generic/generic_stub.h" +#include "grpc++/grpc++.h" +#include "grpc++/support/byte_buffer.h" +#include "grpc++/support/slice.h" +#include "grpc/support/log.h" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 7c978b28b6873d05afb435de4caf7f4ce5d33193..19bba46e3bd49a689fbe1d0c93efe01806fb0228 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -273,7 +273,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() { // FIXME(typhoonzero): change cq_name to enum. void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, - std::string cq_name, + const std::string& cq_name, std::function TryToRegisterNewOne) { TryToRegisterNewOne(); diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index b0596d3cd1e108f28e8f1485d6b5c989c55be7e9..5b5033018c6aefdc165886fc4e9086ff0a7e9201 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -14,10 +14,11 @@ limitations under the License. */ #pragma once -#include #include +#include // NOLINT #include +#include "grpc++/grpc++.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/program_desc.h" @@ -71,7 +72,8 @@ class AsyncGRPCServer final { void ShutDown(); protected: - void HandleRequest(::grpc::ServerCompletionQueue *cq, std::string cq_name, + void HandleRequest(::grpc::ServerCompletionQueue *cq, + const std::string &cq_name, std::function TryToRegisterNewOne); void TryToRegisterNewSendOne(); void TryToRegisterNewGetOne(); diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc index 1ad62863a1a98c28cb08f47dfa8a5bfae463ba91..b89aed0157de8e95564015b3e7f42316a39537f5 100644 --- a/paddle/fluid/operators/detail/grpc_server_test.cc +++ b/paddle/fluid/operators/detail/grpc_server_test.cc @@ -14,7 +14,7 @@ limitations under the License. */ #include #include -#include +#include // NOLINT #include "gtest/gtest.h" #include "paddle/fluid/operators/detail/grpc_client.h" diff --git a/paddle/fluid/operators/detail/proto_encoder_helper.h b/paddle/fluid/operators/detail/proto_encoder_helper.h index 4a7bfb8bd586fe84c9243bc64117d146c4386674..d91d054b2507f32d1e948dde33da06a70cabe775 100644 --- a/paddle/fluid/operators/detail/proto_encoder_helper.h +++ b/paddle/fluid/operators/detail/proto_encoder_helper.h @@ -19,7 +19,9 @@ limitations under the License. */ #pragma once -#include +#include + +#include "grpc++/grpc++.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -142,6 +144,6 @@ class ProtoEncodeHelper { char* limit_; // Just for CHECKs }; -} // detail -} // operators -} // paddle +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index 7e3f015dabdb3fd6190d1ca2f422aa526e8889cd..f8576d01b10f4c0fda4d12d371b2966739acfc21 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/detail/sendrecvop_utils.h" + #include -#include +#include // NOLINT + #include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/zero_copy_stream.h" #include "paddle/fluid/framework/data_type.h" @@ -42,7 +44,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, void* buf = malloc(1024); void* payload = nullptr; size_t payload_size; - ProtoEncodeHelper e((char*)buf, 1024); + ProtoEncodeHelper e(static_cast(buf), 1024); e.WriteString(VarMsg::kVarnameFieldNumber, name); if (var->IsType()) { e.WriteUint64(VarMsg::kTypeFieldNumber, 0); @@ -152,7 +154,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, framework::proto::VarType_Type_SELECTED_ROWS) { auto* slr = var->GetMutable(); - ProtoEncodeHelper e2((char*)buf, 128); + ProtoEncodeHelper e2(static_cast(buf), 128); // NOTE: rows is of type int64_t size_t rows_memory_size = slr->rows().size() * framework::SizeOfType(typeid(int64_t)); @@ -181,10 +183,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, const platform::DeviceContext& ctx, const framework::Scope* scope, - framework::Variable*& var) { + framework::Variable** var) { operators::detail::VariableResponse resp(scope, &ctx); PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!"); - var = resp.GetVar(); + *var = resp.GetVar(); } } // namespace detail diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.h b/paddle/fluid/operators/detail/sendrecvop_utils.h index b3b2b8469c8f19313038f2551ab04708a05656d5..d7954440846b8db9a9add0110fb9a546a762774d 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.h +++ b/paddle/fluid/operators/detail/sendrecvop_utils.h @@ -51,7 +51,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, const platform::DeviceContext& ctx, const framework::Scope* scope, - framework::Variable*& var); + framework::Variable** var); inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) { switch (type) { diff --git a/paddle/fluid/operators/detail/serde_test.cc b/paddle/fluid/operators/detail/serde_test.cc index ea1670e56f3c2fedc2617db1425472e52c6519f5..f8cae6b26acf9d37ca286487065d70ede4c03120 100644 --- a/paddle/fluid/operators/detail/serde_test.cc +++ b/paddle/fluid/operators/detail/serde_test.cc @@ -14,9 +14,9 @@ limitations under the License. */ #include #include -#include +#include // NOLINT -#include +#include "google/protobuf/text_format.h" #include "gtest/gtest.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/tensor_util.h" @@ -107,7 +107,7 @@ void RunSerdeTestSelectedRows(platform::Place place) { for (int i = 0; i < tensor_numel; ++i) { EXPECT_FLOAT_EQ(tensor_data2[i], 32.7); } - for (int i = 0; i < rows2->size(); ++i) { + for (int64_t i = 0; i < rows2->size(); ++i) { EXPECT_EQ(rows_data2[i], i); } EXPECT_EQ(slr2->height(), 1000); diff --git a/paddle/fluid/operators/detail/simple_block_queue.h b/paddle/fluid/operators/detail/simple_block_queue.h index 36b58b0c6700b5af7eaea92d2b0c32adaba35bb8..69773e05df7ed76f31c26f4304693fec2e9aac9c 100644 --- a/paddle/fluid/operators/detail/simple_block_queue.h +++ b/paddle/fluid/operators/detail/simple_block_queue.h @@ -14,9 +14,9 @@ limitations under the License. */ #pragma once -#include +#include // NOLINT #include -#include +#include // NOLINT namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc index 01eb8acc558231d443d4617578cc56d4e895c2f2..78e1d274a92241b5f2093beb63acdc8c497dfb83 100644 --- a/paddle/fluid/operators/detail/variable_response.cc +++ b/paddle/fluid/operators/detail/variable_response.cc @@ -112,7 +112,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input, bool VariableResponse::CopyLodTensorData( ::google::protobuf::io::CodedInputStream* input, - const platform::DeviceContext& ctx, framework::DDim& dims, int length) { + const platform::DeviceContext& ctx, const framework::DDim& dims, + int length) { auto var = scope_->FindVar(meta_.varname()); auto* tensor = var->GetMutable(); tensor->Resize(dims); @@ -148,7 +149,8 @@ inline framework::DDim GetDims( bool VariableResponse::CopySelectRowsTensorData( ::google::protobuf::io::CodedInputStream* input, - const platform::DeviceContext& ctx, framework::DDim& dims, int length) { + const platform::DeviceContext& ctx, const framework::DDim& dims, + int length) { auto var = scope_->FindVar(meta_.varname()); auto* slr = var->GetMutable(); slr->set_height(meta_.slr_height()); diff --git a/paddle/fluid/operators/detail/variable_response.h b/paddle/fluid/operators/detail/variable_response.h index e121ed7bce966d7dea94f71087f2187dcaa17cec..050b6b84010b4f3e95bc88e5bb738ff18b7fe423 100644 --- a/paddle/fluid/operators/detail/variable_response.h +++ b/paddle/fluid/operators/detail/variable_response.h @@ -14,6 +14,8 @@ #pragma once +#include + #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" @@ -60,14 +62,14 @@ class VariableResponse { private: bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input, const platform::DeviceContext& ctx, - framework::DDim& dims, int length); + const framework::DDim& dims, int length); bool CopySelectRowsData(::google::protobuf::io::CodedInputStream* input, const platform::DeviceContext& ctx, int length); bool CopyLodTensorData(::google::protobuf::io::CodedInputStream* input, const platform::DeviceContext& ctx, - framework::DDim& dims, int length); + const framework::DDim& dims, int length); private: const framework::Scope* scope_; diff --git a/paddle/fluid/operators/fc_mkldnn_op.cc b/paddle/fluid/operators/fc_mkldnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..9c704a2949f7100e0812eafe1e58ef04bf71f840 --- /dev/null +++ b/paddle/fluid/operators/fc_mkldnn_op.cc @@ -0,0 +1,303 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/fc_op.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/mkldnn_helper.h" + +namespace paddle { +namespace operators { + +using paddle::framework::Tensor; +using paddle::platform::MKLDNNDeviceContext; + +template +class MKLDNNMD { + public: + explicit MKLDNNMD(const T* in, const T* w, bool bias) + : in{paddle::framework::vectorize2int(in->dims())}, + w{paddle::framework::vectorize2int(w->dims())} { + with_bias_ = bias; + } + + mkldnn::memory::desc dst() const { + return platform::MKLDNNMemDesc({in[0], w[1]}, + mkldnn::memory::data_type::f32, + mkldnn::memory::format::nc); + } + + mkldnn::memory::desc src() const { + return is_spatial() + ? platform::MKLDNNMemDesc({in[0], in[1], in[2], in[3]}, + mkldnn::memory::data_type::f32, + mkldnn::memory::format::nchw) + : platform::MKLDNNMemDesc({in[0], in[1]}, + mkldnn::memory::data_type::f32, + mkldnn::memory::format::nc); + } + + mkldnn::memory::desc weights() const { + return is_spatial() + ? platform::MKLDNNMemDesc({w[1], in[1], in[2], in[3]}, + mkldnn::memory::data_type::f32, + mkldnn::memory::format::oihw) + : platform::MKLDNNMemDesc({w[1], in[1]}, + mkldnn::memory::data_type::f32, + mkldnn::memory::format::oi); + } + + mkldnn::memory::desc bias() const { + return with_bias_ + ? platform::MKLDNNMemDesc({w[1]}, mkldnn::memory::data_type::f32, + mkldnn::memory::format::format_undef) + : platform::MKLDNNMemDesc({}, mkldnn::memory::data_type::f32, + mkldnn::memory::format::format_undef); + } + + private: + bool is_spatial() const { return in.size() > 1 && w.size() > 1; } + + std::vector in; + std::vector w; + bool with_bias_; + bool is_spatial_; +}; + +class MKLDNNMemory { + public: + MKLDNNMemory(MKLDNNMD* t, const mkldnn::engine& e) + : md_{t}, engine_{e} {} + virtual ~MKLDNNMemory() = default; + + template + mkldnn::memory dst(const Output* out) { + return mkldnn::memory({md_->dst(), engine_}, + static_cast(const_cast(out))); + } + + template + mkldnn::memory dst(Output* out) { + return mkldnn::memory({md_->dst(), engine_}, out); + } + + template + mkldnn::memory src(const Input* in) { + return mkldnn::memory({md_->src(), engine_}, + static_cast(const_cast(in))); + } + + template + mkldnn::memory weights(const Weight* w) { + return mkldnn::memory({md_->weights(), engine_}, + static_cast(const_cast(w))); + } + + mkldnn::memory bias() { + return mkldnn::memory(mkldnn::memory::primitive_desc(md_->bias(), engine_)); + } + + private: + MKLDNNMD* md_; + const mkldnn::engine& engine_; +}; + +template +class FCMKLDNNOpKernel : public paddle::framework::OpKernel { + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + + auto& dev_ctx = ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + + auto input = ctx.Input("Input"); + auto w = ctx.Input("W"); + + PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4, + "Input must be with 2 or 4 dimensions, i.e. NCHW"); + PADDLE_ENFORCE(w->dims().size() == 2 || w->dims().size() == 4, + "Weights must be with 2 or 4 dimensions, i.e. OI or OIHW"); + + bool with_bias = ctx.Attr("bias_attr"); + MKLDNNMD md(input, w, with_bias); + + std::shared_ptr pd = + FcFwdPrimitiveDesc(md.src(), md.weights(), md.dst(), md.bias(), + with_bias, mkldnn_engine); + + const std::string key = ctx.op().Output("Out"); + const std::string key_fc_pd = key + "@fc_pd"; + + dev_ctx.SetBlob(key_fc_pd, pd); + + MKLDNNMemory mem(&md, mkldnn_engine); + + const T* input_data = input->data(); + const T* w_data = w->data(); + + auto output = ctx.Output("Out"); + T* output_data = output->mutable_data(ctx.GetPlace()); + + auto dst_memory = mem.dst(output_data); + auto src_memory = mem.src(input_data); + auto weights_memory = mem.weights(w_data); + auto bias_memory = mem.bias(); + + auto forward = with_bias ? mkldnn::inner_product_forward( + *pd, src_memory, weights_memory, bias_memory, + dst_memory) + : mkldnn::inner_product_forward( + *pd, src_memory, weights_memory, dst_memory); + + std::vector pipeline = {forward}; + mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); + } + + private: + std::unique_ptr + FcFwdPrimitiveDesc(const mkldnn::memory::desc& src, + const mkldnn::memory::desc& weights, + const mkldnn::memory::desc& dst, + const mkldnn::memory::desc& bias, const bool with_bias, + const mkldnn::engine& engine) const { + auto desc = with_bias + ? mkldnn::inner_product_forward::desc( + mkldnn::prop_kind::forward, src, weights, bias, dst) + : mkldnn::inner_product_forward::desc( + mkldnn::prop_kind::forward, src, weights, dst); + + auto pd = new mkldnn::inner_product_forward::primitive_desc(desc, engine); + return std::unique_ptr(pd); + } +}; + +template +class FCMKLDNNGradOpKernel : public paddle::framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + + auto& dev_ctx = ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + + T* input_grad_data = nullptr; + T* w_grad_data = nullptr; + + Tensor* input_grad = ctx.Output(framework::GradVarName("Input")); + Tensor* w_grad = ctx.Output(framework::GradVarName("W")); + + if (input_grad) { + input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + } + if (w_grad) { + w_grad_data = w_grad->mutable_data(ctx.GetPlace()); + } + + const Tensor* input = ctx.Input("Input"); + const T* input_data = input->data(); + + const Tensor* w = ctx.Input("W"); + const T* w_data = w->data(); + + const Tensor* out_grad = ctx.Input(framework::GradVarName("Out")); + const T* out_grad_data = out_grad->data(); + + bool with_bias = ctx.Attr("bias_attr"); + + MKLDNNMD md(input, w, with_bias); + MKLDNNMemory mem(&md, mkldnn_engine); + + auto dst_memory = mem.dst(out_grad_data); + auto src_memory = mem.src(input_data); + auto weights_memory = mem.weights(w_data); + auto bias_memory = mem.bias(); + + const std::string key = ctx.op().Input("Out"); + const std::string key_fc_pd = key + "@fc_pd"; + + auto pd = + std::static_pointer_cast( + dev_ctx.GetBlob(key_fc_pd)); + + PADDLE_ENFORCE(pd != nullptr, "Fail to find key_fc_pd in device context"); + + if (w_grad) { + auto weights_grad_memory = mem.weights(w_grad_data); + + mkldnn::inner_product_backward_weights::primitive_desc bwd_weight_pd = + FcBwdWeightsPrimitiveDesc(md.src(), md.weights(), md.dst(), md.bias(), + with_bias, *pd, mkldnn_engine); + + auto bwd_weights_prim = mkldnn::inner_product_backward_weights( + bwd_weight_pd, src_memory, dst_memory, weights_grad_memory, + bias_memory); + + std::vector pipeline{bwd_weights_prim}; + mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); + } + + if (input_grad) { + auto src_grad_memory = mem.src(input_grad_data); + + mkldnn::inner_product_backward_data::primitive_desc bwd_data_pd = + FcBwdDataPrimitiveDesc(md.src(), md.weights(), md.dst(), *pd, + mkldnn_engine); + + auto bwd_data_prim = mkldnn::inner_product_backward_data( + bwd_data_pd, dst_memory, weights_memory, src_grad_memory); + + std::vector pipeline{bwd_data_prim}; + mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); + } + } + + private: + mkldnn::inner_product_backward_weights::primitive_desc + FcBwdWeightsPrimitiveDesc( + const mkldnn::memory::desc& src, const mkldnn::memory::desc& diff_weights, + const mkldnn::memory::desc& diff_dst, const mkldnn::memory::desc& bias, + const bool with_bias, + const mkldnn::inner_product_forward::primitive_desc& pd, + const mkldnn::engine& engine) const { + auto bwd_weight_desc = with_bias + ? mkldnn::inner_product_backward_weights::desc( + src, diff_weights, bias, diff_dst) + : mkldnn::inner_product_backward_weights::desc( + src, diff_weights, bias, diff_dst); + + return mkldnn::inner_product_backward_weights::primitive_desc( + bwd_weight_desc, engine, pd); + } + + mkldnn::inner_product_backward_data::primitive_desc FcBwdDataPrimitiveDesc( + const mkldnn::memory::desc& diff_src, const mkldnn::memory::desc& weights, + const mkldnn::memory::desc& diff_dst, + const mkldnn::inner_product_forward::primitive_desc& pd, + const mkldnn::engine& engine) const { + auto bwd_data_desc = + mkldnn::inner_product_backward_data::desc(diff_src, weights, diff_dst); + return mkldnn::inner_product_backward_data::primitive_desc(bwd_data_desc, + engine, pd); + } +}; +} // namespace operators +} // namespace paddle + +REGISTER_OP_KERNEL(fc, MKLDNN, ::paddle::platform::CPUPlace, + paddle::operators::FCMKLDNNOpKernel); + +REGISTER_OP_KERNEL(fc_grad, MKLDNN, ::paddle::platform::CPUPlace, + paddle::operators::FCMKLDNNGradOpKernel); diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..381771f157d78fb04e54f0a07c40e4df2c91441a --- /dev/null +++ b/paddle/fluid/operators/fc_op.cc @@ -0,0 +1,102 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fc_op.h" +#include + +namespace paddle { +namespace operators { + +void FCOp::InferShape(framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "X(Input) of Fully Connected should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Out(Output) of Fully Connected should not be null."); + PADDLE_ENFORCE(ctx->HasInput("W"), + "W(Input) of Fully Connected should not be null."); + + auto in_dims = ctx->GetInputDim("Input"); + auto w_dims = ctx->GetInputDim("W"); + std::vector output_shape({in_dims[0], w_dims[1]}); + + PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4, + "Fully Connected input should be 2-D or 4-D tensor."); + + PADDLE_ENFORCE(w_dims.size() == 2 || w_dims.size() == 4, + "Fully Connected input should be 2-D or 4-D tensor."); + + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + ctx->ShareLoD("Input", "Out"); +} + +framework::OpKernelType FCOp::GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + framework::LibraryType library{framework::LibraryType::kMKLDNN}; + framework::DataLayout layout{framework::DataLayout::kAnyLayout}; + + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Input")->type()), ctx.GetPlace(), + layout, library); +} + +void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const { + auto in_dims = ctx->GetInputDim("Input"); + auto w_dims = ctx->GetInputDim("W"); + + if (ctx->HasOutput(framework::GradVarName("Input"))) { + ctx->SetOutputDim(framework::GradVarName("Input"), in_dims); + } + if (ctx->HasOutput(framework::GradVarName("W"))) { + ctx->SetOutputDim(framework::GradVarName("W"), w_dims); + } +} + +framework::OpKernelType FCOpGrad::GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + framework::LibraryType library{framework::LibraryType::kMKLDNN}; + framework::DataLayout layout{framework::DataLayout::kAnyLayout}; + + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Input")->type()), ctx.GetPlace(), + layout, library); +} + +FCOpMaker::FCOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Input", "(Tensor) The input tensor of fully connected operator. "); + AddInput("W", "(Tensor), The second input tensor of fc op."); + AddOutput("Out", "(Tensor) The output tensor of fully connected operator. "); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddAttr("bias_attr", "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddComment(R"DOC( + Fully Connected Operator. + + The fully connected operation calculates the output based on the input, weights and bias attribute. + The size of each dimension of the parameters checked in the infer-shape. + The matrix of bias is generated by the mkldnn framework, when the bias_attr is True. + Additional parametrs are use_mkldnn and bias_attr. + The input(X) size and output(Out) size may be diffrent. + + The fully connected layer only supports MKLDNN version +)DOC"); +} + +} // namespace operators +} // namespace paddle + +REGISTER_OP(fc, paddle::operators::FCOp, paddle::operators::FCOpMaker, fc_grad, + paddle::operators::FCOpGrad); diff --git a/paddle/fluid/operators/fc_op.h b/paddle/fluid/operators/fc_op.h new file mode 100644 index 0000000000000000000000000000000000000000..70fa96440d344397a7427c1338afee85bde923d4 --- /dev/null +++ b/paddle/fluid/operators/fc_op.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class FCOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; +}; + +class FCOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; +}; + +class FCOpMaker : public framework::OpProtoAndCheckerMaker { + public: + FCOpMaker(OpProto* proto, OpAttrChecker* op_checker); +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index 299a0aed01dfe0448d896738d9fd33319b1b2887..44fd739fb1d161c6c7d6ab1cc611c59220280a4e 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -322,6 +322,14 @@ void set_constant_with_place( TensorSetConstantCPU(tensor, value)); } +template <> +void set_constant_with_place( + const platform::DeviceContext& context, framework::Tensor* tensor, + float value) { + framework::VisitDataType(framework::ToDataType(tensor->type()), + TensorSetConstantCPU(tensor, value)); +} + struct TensorSetConstantWithPlace : public boost::static_visitor { TensorSetConstantWithPlace(const platform::DeviceContext& context, framework::Tensor* tensor, float value) diff --git a/paddle/fluid/operators/nccl_op_test.cu.cc b/paddle/fluid/operators/nccl_op_test.cu.cc index a31d64e899df33f16f707e96d7ff7b85eca8d6ea..20b8a5c98ab16ac8121cb2fd01deb8ecc1966d44 100644 --- a/paddle/fluid/operators/nccl_op_test.cu.cc +++ b/paddle/fluid/operators/nccl_op_test.cu.cc @@ -15,8 +15,8 @@ limitations under the License. */ #include #include #include -#include -#include +#include // NOLINT +#include // NOLINT #include #include "paddle/fluid/framework/init.h" @@ -43,7 +43,7 @@ const f::DDim kDims = {20, 20}; // nccl op common tester, init communicator. class NCCLTester : public ::testing::Test { public: - virtual void SetUp() override { + void SetUp() override { int count = p::GetCUDADeviceCount(); if (count <= 1) { LOG(WARNING) @@ -64,7 +64,7 @@ class NCCLTester : public ::testing::Test { NCCLInitOp(); } - virtual void TearDown() override { + void TearDown() override { for (auto &device_context : dev_ctxs_) { delete device_context; } diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index f9a8058f2a32b6736d6513b017b761a31ddc2e37..96c0c1cbe6d588364416925a7ab1bc8f90ac6fd7 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#include // NOLINT + #include "paddle/fluid/framework/channel.h" #include "paddle/fluid/operators/reader/reader_op_registry.h" diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index b6ac7b21d56f7760b3f4814581c90b0ff2cc4a6a..eacedeea8835d27b712b287824b9d30b03ebebbf 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -21,6 +21,22 @@ namespace reader { class MultipleReader : public framework::ReaderBase { public: + class ThreadBufferMap { + public: + std::vector& operator[]( + const std::thread::id& thread_id) { + std::lock_guard lock(mutex_); + return buffer_[thread_id]; + } + + void Clear() { buffer_.clear(); } + + private: + std::mutex mutex_; + std::unordered_map> + buffer_; + }; + MultipleReader(const std::vector& file_names, const std::vector& dims, size_t thread_num) : file_names_(file_names), dims_(dims) { @@ -47,28 +63,27 @@ class MultipleReader : public framework::ReaderBase { framework::Channel* waiting_file_idx_; framework::Channel* available_thread_idx_; framework::Channel>* buffer_; - mutable std::vector local_buffer_; + mutable ThreadBufferMap thread_buffer_map_; }; void MultipleReader::ReadNext(std::vector* out) { if (!HasNext()) { PADDLE_THROW("There is no next data!"); } - - if (local_buffer_.empty()) { - buffer_->Receive(&local_buffer_); - } - *out = local_buffer_; - local_buffer_.clear(); + auto& thread_local_buffer = thread_buffer_map_[std::this_thread::get_id()]; + *out = thread_local_buffer; + thread_local_buffer.clear(); } bool MultipleReader::HasNext() const { - return local_buffer_.empty() ? buffer_->Receive(&local_buffer_) : true; + auto& thread_local_buffer = thread_buffer_map_[std::this_thread::get_id()]; + return thread_local_buffer.empty() ? buffer_->Receive(&thread_local_buffer) + : true; } void MultipleReader::ReInit() { EndScheduler(); - local_buffer_.clear(); + thread_buffer_map_.Clear(); StartNewScheduler(); } @@ -176,7 +191,7 @@ class OpenFilesOp : public framework::OperatorBase { const auto& ranks = Attr>("ranks"); PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), - int(shape_concat.size()), + static_cast(shape_concat.size()), "The accumulate of all ranks should be equal to the " "shape concat's length."); const auto& file_names = Attr>("file_names"); diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index b87b8e6b26cdeb017e700870998a53c1b295988c..93f9c74b809770136d3d3300e0e0700b1bc0459e 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -14,6 +14,9 @@ limitations under the License. */ #include "paddle/fluid/operators/reshape_op.h" +#include +#include + namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/reshape_op.h b/paddle/fluid/operators/reshape_op.h index 871b4d38d56f10f3c0c178caa566508ab75f316c..807e5ad951b893a4c027a96d743f0606b70cf160 100644 --- a/paddle/fluid/operators/reshape_op.h +++ b/paddle/fluid/operators/reshape_op.h @@ -14,6 +14,9 @@ limitations under the License. */ #pragma once +#include +#include + #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" diff --git a/paddle/fluid/platform/cpu_info.cc b/paddle/fluid/platform/cpu_info.cc index 8db08edba805e41d33ec6a6a4b338cca0d4906ef..4fc9aae8e36e9b43d65fab0b92ec3a2549057128 100644 --- a/paddle/fluid/platform/cpu_info.cc +++ b/paddle/fluid/platform/cpu_info.cc @@ -27,6 +27,11 @@ DEFINE_double(fraction_of_cpu_memory_to_use, 1, "Default use 100% of CPU memory for PaddlePaddle," "reserve the rest for page tables, etc"); +DEFINE_double( + fraction_of_cuda_pinned_memory_to_use, 0.5, + "Default use 50% of CPU memory as the pinned_memory for PaddlePaddle," + "reserve the rest for page tables, etc"); + namespace paddle { namespace platform { @@ -62,5 +67,22 @@ size_t CpuMaxChunkSize() { return CpuMaxAllocSize() / 32; } +size_t CUDAPinnedMaxAllocSize() { + // For distributed systems, it requires configuring and limiting + // the fraction of memory to use. + return FLAGS_fraction_of_cuda_pinned_memory_to_use * CpuTotalPhysicalMemory(); +} + +size_t CUDAPinnedMinChunkSize() { + // Allow to allocate the minimum chunk size is 64 KB. + return 1 << 16; +} + +size_t CUDAPinnedMaxChunkSize() { + // Allow to allocate the maximum chunk size is roughly 1/256 of CUDA_PINNED + // memory. + return CUDAPinnedMaxAllocSize() / 256; +} + } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/cpu_info.h b/paddle/fluid/platform/cpu_info.h index a930151bd15a33d5b8861c6239e7dd964822f0f6..f06c2b67fe4385f427322e9bb2f3080fdd3acc94 100644 --- a/paddle/fluid/platform/cpu_info.h +++ b/paddle/fluid/platform/cpu_info.h @@ -22,11 +22,20 @@ namespace platform { //! Get the maximum allocation size for a machine. size_t CpuMaxAllocSize(); +//! Get the maximum allocation size for a machine. +size_t CUDAPinnedMaxAllocSize(); + //! Get the minimum chunk size for buddy allocator. size_t CpuMinChunkSize(); //! Get the maximum chunk size for buddy allocator. size_t CpuMaxChunkSize(); +//! Get the minimum chunk size for buddy allocator. +size_t CUDAPinnedMinChunkSize(); + +//! Get the maximum chunk size for buddy allocator. +size_t CUDAPinnedMaxChunkSize(); + } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 59b76a1edb5ec5900520fbccb6a6f8f6e7a70aa4..feb4f367008d76d86a93c561a8eec1f2485c99d6 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -53,6 +53,16 @@ DeviceContextPool::DeviceContextPool( PADDLE_THROW( "'CUDAPlace' is not supported, Please re-compile with WITH_GPU " "option"); +#endif + } else if (platform::is_cuda_pinned_place(p)) { +#ifdef PADDLE_WITH_CUDA + device_contexts_.emplace( + p, + PtrType(new CUDAPinnedDeviceContext(boost::get(p)))); +#else + PADDLE_THROW( + "'CUDAPlace' is not supported, Please re-compile with WITH_GPU " + "option"); #endif } } @@ -186,6 +196,20 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } cudaStream_t CUDADeviceContext::stream() const { return stream_; } +CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() { + eigen_device_.reset(new Eigen::DefaultDevice()); +} + +CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place) + : place_(place) { + eigen_device_.reset(new Eigen::DefaultDevice()); +} + +Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const { + return eigen_device_.get(); +} + +Place CUDAPinnedDeviceContext::GetPlace() const { return place_; } #endif #ifdef PADDLE_WITH_MKLDNN diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 202394c7be7e103a609dd0999fc883c794ef0edd..6b796d92d09cdde2db60c7651c03d3782ff013e6 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -118,6 +118,25 @@ struct DefaultDeviceContextType { using TYPE = CUDADeviceContext; }; +// Currently, CUDAPinnedDeviceContext is only used to data copying. +class CUDAPinnedDeviceContext : public DeviceContext { + public: + CUDAPinnedDeviceContext(); + explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place); + + Place GetPlace() const override; + + Eigen::DefaultDevice* eigen_device() const; + + private: + CUDAPinnedPlace place_; + std::unique_ptr eigen_device_; +}; + +template <> +struct DefaultDeviceContextType { + using TYPE = CUDAPinnedDeviceContext; +}; #endif #ifdef PADDLE_WITH_MKLDNN diff --git a/paddle/fluid/platform/place.cc b/paddle/fluid/platform/place.cc index de8f958eb012cb1ac563cbbbac8951e439bf8f33..655ce8485d4584aa0955315b045da6bf541f7fe2 100644 --- a/paddle/fluid/platform/place.cc +++ b/paddle/fluid/platform/place.cc @@ -26,6 +26,7 @@ class PlacePrinter : public boost::static_visitor<> { void operator()(const CUDAPlace &p) { os_ << "CUDAPlace(" << p.device << ")"; } + void operator()(const CUDAPinnedPlace &p) { os_ << "CUDAPinnedPlace"; } private: std::ostream &os_; @@ -40,12 +41,19 @@ const Place &get_place() { return the_default_place; } const CUDAPlace default_gpu() { return CUDAPlace(0); } const CPUPlace default_cpu() { return CPUPlace(); } +const CUDAPinnedPlace default_cuda_pinned() { return CUDAPinnedPlace(); } bool is_gpu_place(const Place &p) { return boost::apply_visitor(IsCUDAPlace(), p); } -bool is_cpu_place(const Place &p) { return !is_gpu_place(p); } +bool is_cpu_place(const Place &p) { + return boost::apply_visitor(IsCPUPlace(), p); +} + +bool is_cuda_pinned_place(const Place &p) { + return boost::apply_visitor(IsCUDAPinnedPlace(), p); +} bool places_are_same_class(const Place &p1, const Place &p2) { return p1.which() == p2.which(); @@ -53,7 +61,7 @@ bool places_are_same_class(const Place &p1, const Place &p2) { bool is_same_place(const Place &p1, const Place &p2) { if (places_are_same_class(p1, p2)) { - if (is_cpu_place(p1)) { + if (is_cpu_place(p1) || is_cuda_pinned_place(p1)) { return true; } else { return boost::get(p1) == boost::get(p2); diff --git a/paddle/fluid/platform/place.h b/paddle/fluid/platform/place.h index 4cc8b377b8b671eb5a446ecbae21ba9628fbd2c8..d0bdcb0da5177f9f8ad517787e612f1b98b3fbb4 100644 --- a/paddle/fluid/platform/place.h +++ b/paddle/fluid/platform/place.h @@ -45,12 +45,33 @@ struct CUDAPlace { int device; }; +struct CUDAPinnedPlace { + CUDAPinnedPlace() {} + + // needed for variant equality comparison + inline bool operator==(const CUDAPinnedPlace &) const { return true; } + inline bool operator!=(const CUDAPinnedPlace &) const { return false; } +}; + struct IsCUDAPlace : public boost::static_visitor { bool operator()(const CPUPlace &) const { return false; } bool operator()(const CUDAPlace &gpu) const { return true; } + bool operator()(const CUDAPinnedPlace &) const { return false; } }; -typedef boost::variant Place; +struct IsCPUPlace : public boost::static_visitor { + bool operator()(const CPUPlace &cpu) const { return true; } + bool operator()(const CUDAPlace &) const { return false; } + bool operator()(const CUDAPinnedPlace &) const { return false; } +}; + +struct IsCUDAPinnedPlace : public boost::static_visitor { + bool operator()(const CPUPlace &) const { return false; } + bool operator()(const CUDAPlace &) const { return false; } + bool operator()(const CUDAPinnedPlace &cuda_pinned) const { return true; } +}; + +typedef boost::variant Place; using PlaceList = std::vector; @@ -59,9 +80,11 @@ const Place &get_place(); const CUDAPlace default_gpu(); const CPUPlace default_cpu(); +const CUDAPinnedPlace default_cuda_pinned(); bool is_gpu_place(const Place &); bool is_cpu_place(const Place &); +bool is_cuda_pinned_place(const Place &); bool places_are_same_class(const Place &, const Place &); bool is_same_place(const Place &, const Place &); @@ -95,6 +118,16 @@ struct PlaceVisitorWrapper #else PADDLE_THROW("Paddle is not compiled with CUDA. Cannot visit cuda device"); return typename Visitor::result_type(); +#endif + } + + typename Visitor::result_type operator()( + const CUDAPinnedPlace &cuda_pinned) const { +#ifdef PADDLE_WITH_CUDA + return visitor_(cuda_pinned); +#else + PADDLE_THROW("Paddle is not compiled with CUDA. Cannot visit cuda_pinned"); + return typename Visitor::result_type(); #endif } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 3d13133bf25aa3f538f6f574bd2ae682e1bc7e39..d2e7d58524bfb11627b6acb36ef873c41b348f0f 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -133,6 +133,8 @@ def fc(input, bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias of this layer. If it is set to None, no bias will be added to the output units. act (str, default None): Activation to be applied to the output of this layer. + use_mkldnn(bool): Use mkldnn kernel or not, it is valid only when the mkldnn + library is installed. Default: False name (str, default None): The name of this layer. Returns: @@ -153,38 +155,64 @@ def fc(input, dtype = helper.input_dtype() mul_results = [] - for input_var, param_attr in helper.iter_inputs_and_params(): - input_shape = input_var.shape + if use_mkldnn: + tmp = helper.create_tmp_variable(dtype) + input_shape = input.shape param_shape = [ reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1) ] + [size] w = helper.create_parameter( - attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False) - tmp = helper.create_tmp_variable(dtype) + attr=helper.param_attr, + shape=param_shape, + dtype=dtype, + is_bias=False) + if bias_attr is None or bias_attr is False: + bias_attr = False + else: + bias_attr = True helper.append_op( - type="mul", - inputs={"X": input_var, - "Y": w}, + type="fc", + inputs={"Input": input, + "W": w}, outputs={"Out": tmp}, - attrs={ - "x_num_col_dims": num_flatten_dims, - "y_num_col_dims": 1, - 'use_mkldnn': use_mkldnn - }) - mul_results.append(tmp) - - # sum - if len(mul_results) == 1: - pre_bias = mul_results[0] + attrs={"use_mkldnn": use_mkldnn, + "bias_attr": bias_attr}) + return helper.append_activation(tmp) else: - pre_bias = helper.create_tmp_variable(dtype) - helper.append_op( - type="sum", inputs={"X": mul_results}, outputs={"Out": pre_bias}) - # add bias - pre_activation = helper.append_bias_op(pre_bias, dim_start=num_flatten_dims) - # add activation - return helper.append_activation(pre_activation) + for input_var, param_attr in helper.iter_inputs_and_params(): + input_shape = input_var.shape + param_shape = [ + reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1) + ] + [size] + + w = helper.create_parameter( + attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False) + tmp = helper.create_tmp_variable(dtype) + helper.append_op( + type="mul", + inputs={"X": input_var, + "Y": w}, + outputs={"Out": tmp}, + attrs={ + "x_num_col_dims": num_flatten_dims, + "y_num_col_dims": 1, + }) + mul_results.append(tmp) + + if len(mul_results) == 1: + pre_bias = mul_results[0] + else: + pre_bias = helper.create_tmp_variable(dtype) + helper.append_op( + type="sum", + inputs={"X": mul_results}, + outputs={"Out": pre_bias}) + # add bias + pre_activation = helper.append_bias_op( + pre_bias, dim_start=num_flatten_dims) + # add activation + return helper.append_activation(pre_activation) def embedding(input, diff --git a/python/paddle/fluid/tests/unittests/test_fc_op.py b/python/paddle/fluid/tests/unittests/test_fc_op.py new file mode 100644 index 0000000000000000000000000000000000000000..3f547f3c484bf034a87823a75d946ef130a5cb70 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fc_op.py @@ -0,0 +1,99 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +def fully_connected_naive(input, weights, bias_data=None): + in_n, in_c, in_h, in_w = input.shape + w_h, w_c = weights.shape + + x_data = np.reshape(input, [in_n, in_c * in_h * in_w]) + w_data = np.transpose(np.reshape(weights, (w_c, in_c * in_h * in_w))) + result = None + + if not bias_data: + result = np.dot(x_data, w_data) + else: + result = np.dot(x_data, w_data) + bias_data + + return result + + +class MatrixGenerate: + def __init__(self, mb, ic, oc, h, w): + self.input = np.random.random((mb, ic, h, w)).astype("float32") + self.weights = np.random.random((ic * h * w, oc)).astype("float32") + + +class TestFCMKLDNNOp(OpTest): + def setUp(self): + self.op_type = "fc" + self.use_mkldnn = True + self.with_bias = True + self.matrix = MatrixGenerate(1, 10, 15, 3, 3) + + self.inputs = {'Input': self.matrix.input, 'W': self.matrix.weights} + + self.attrs = { + 'use_mkldnn': self.use_mkldnn, + 'with_bias': self.with_bias + } + + self.outputs = { + 'Out': fully_connected_naive(self.matrix.input, self.matrix.weights) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(set(['Input', 'W']), 'Out', max_relative_error=0.9) + + def test_check_grad_no_weight(self): + self.check_grad( + ['Input'], 'Out', max_relative_error=0.5, no_grad_set=set('W')) + + +class TestFCMKLDNNOp1(TestFCMKLDNNOp): + def init_op_type(self): + self.matrix = MatrixGenerate(2, 15, 48, 2, 2) + + +class TestFCMKLDNNOp2(TestFCMKLDNNOp): + def init_op_type(self): + self.matrix = MatrixGenerate(2, 32, 40, 1, 1) + + +class TestFCMKLDNNOp3(TestFCMKLDNNOp): + def init_op_type(self): + self.matrix = MatrixGenerate(2, 2, 4, 1, 1) + + +class TestFCMKLDNNOp4(TestFCMKLDNNOp): + def init_op_type(self): + self.with_bias = False + self.matrix = MatrixGenerate(2, 32, 48, 2, 2) + + +class TestFCMKLDNNOp4(TestFCMKLDNNOp): + def init_op_type(self): + self.with_bias = False + self.matrix = MatrixGenerate(2, 32, 1000, 6, 6) + + +if __name__ == "__main__": + unittest.main()