diff --git a/paddle/fluid/framework/tensor.cc b/paddle/fluid/framework/tensor.cc index 48d300eba957a3469876eab1f5dabb56376e81b5..41566800e5781d576120ccf5dfbb3024bf4bea24 100644 --- a/paddle/fluid/framework/tensor.cc +++ b/paddle/fluid/framework/tensor.cc @@ -32,6 +32,7 @@ size_t Tensor::memory_size() const { } void* Tensor::mutable_data(platform::Place place, std::type_index type, + memory::Allocator::Attr attr, size_t requested_size) { type_ = type; PADDLE_ENFORCE_GE(numel(), 0, @@ -46,17 +47,18 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type, /* some versions of boost::variant don't have operator!= */ if (holder_ == nullptr || !(holder_->place() == place) || holder_->size() < size + offset_) { - holder_ = memory::AllocShared(place, size); + holder_ = memory::AllocShared(place, size, attr); offset_ = 0; } return reinterpret_cast(reinterpret_cast(holder_->ptr()) + offset_); } -void* Tensor::mutable_data(platform::Place place, size_t requested_size) { +void* Tensor::mutable_data(platform::Place place, memory::Allocator::Attr attr, + size_t requested_size) { PADDLE_ENFORCE(this->holder_ != nullptr, "Cannot invoke mutable data if current hold nothing."); - return mutable_data(place, type_, requested_size); + return mutable_data(place, type_, attr, requested_size); } Tensor& Tensor::ShareDataWith(const Tensor& src) { diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index 232b5a67a0a290027dc1e61ccfaf0432b5727fd8..0a4aebefacd21b8c8f6386b0c6f30fd51ceca38c 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -84,12 +84,17 @@ class Tensor { * @note If not exist, then allocation. */ template - T* mutable_data(platform::Place place, size_t requested_size = 0); + T* mutable_data(platform::Place place, + memory::Allocator::Attr attr = memory::Allocator::kDefault, + size_t requested_size = 0); void* mutable_data(platform::Place place, std::type_index type, + memory::Allocator::Attr attr = memory::Allocator::kDefault, size_t requested_size = 0); - void* mutable_data(platform::Place place, size_t requested_size = 0); + void* mutable_data(platform::Place place, + memory::Allocator::Attr attr = memory::Allocator::kDefault, + size_t requested_size = 0); /** * @brief Return a pointer to mutable memory block. @@ -101,7 +106,9 @@ class Tensor { * @note If not exist, then allocation. */ template - T* mutable_data(DDim dims, platform::Place place, size_t requested_size = 0); + T* mutable_data(DDim dims, platform::Place place, + memory::Allocator::Attr attr = memory::Allocator::kDefault, + size_t requested_size = 0); /*! Return the dimensions of the memory block. */ const DDim& dims() const; diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h index dfa251c02da03c1e1d1545a538e1982c6bccb168..0c9c0d782fc73bd8278b82bebf7fd84a4f297b94 100644 --- a/paddle/fluid/framework/tensor_impl.h +++ b/paddle/fluid/framework/tensor_impl.h @@ -47,16 +47,20 @@ inline T* Tensor::data() { template inline T* Tensor::mutable_data(DDim dims, platform::Place place, + memory::Allocator::Attr attr, size_t requested_size) { static_assert(std::is_pod::value, "T must be POD"); Resize(dims); - return mutable_data(place, requested_size); + return mutable_data(place, attr, requested_size); } template -inline T* Tensor::mutable_data(platform::Place place, size_t requested_size) { +inline T* Tensor::mutable_data(platform::Place place, + memory::Allocator::Attr attr, + size_t requested_size) { static_assert(std::is_pod::value, "T must be POD"); - return reinterpret_cast(mutable_data(place, typeid(T), requested_size)); + return reinterpret_cast( + mutable_data(place, typeid(T), attr, requested_size)); } inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { diff --git a/paddle/fluid/memory/allocation/CMakeLists.txt b/paddle/fluid/memory/allocation/CMakeLists.txt index 937b26f807545cc8cf844787fc9e94c18499ba77..44a354cf223f1147aa016fedca7f0c0b7c6bf1f2 100644 --- a/paddle/fluid/memory/allocation/CMakeLists.txt +++ b/paddle/fluid/memory/allocation/CMakeLists.txt @@ -25,9 +25,9 @@ endif() cc_library(naive_managed_allocator SRCS naive_managed_allocator.cc DEPS allocator) cc_test(naive_managed_allocator_test SRCS naive_managed_allocator_test.cc DEPS naive_managed_allocator) - +nv_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator) if (WITH_GPU) - set(AllocatorFacadeDeps gpu_info cuda_allocator) + set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator) else () set(AllocatorFacadeDeps) endif() diff --git a/paddle/fluid/memory/allocation/allocator.h b/paddle/fluid/memory/allocation/allocator.h index 500fc28645bc0aa7003a0849a55273da8e19152a..1ee80a3b40e449615bcab10c1e05920215ebda38 100644 --- a/paddle/fluid/memory/allocation/allocator.h +++ b/paddle/fluid/memory/allocation/allocator.h @@ -60,7 +60,8 @@ class Allocator { kFixedHuge = 2, kFluxHuge = 3, kTmp = 4, - NumOfAttrs = 5 + kCommunication = 5, + NumOfAttrs = 6 }; virtual ~Allocator(); diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index bfd5f959faca56168cd0acfd31d23cdf7cbbb965..2a5fd608bcc95ab906544b2a782a4a45885539d7 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -21,6 +21,7 @@ #include "paddle/fluid/memory/allocation/cpu_allocator.h" #include "paddle/fluid/memory/allocation/locked_allocator.h" #include "paddle/fluid/memory/allocation/naive_managed_allocator.h" +#include "paddle/fluid/memory/allocation/pinned_allocator.h" #include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/place.h" @@ -32,6 +33,35 @@ namespace paddle { namespace memory { namespace allocation { +class CPUManagedAllocator : public ManagedAllocator { + public: + CPUManagedAllocator() + : normal_allocator_(NaiveManagedAllocator::Create( + std::unique_ptr(new CPUAllocator()))), + communication_allocator_(NaiveManagedAllocator::Create( + std::unique_ptr(new CPUPinnedAllocator()))) {} + + std::unique_ptr Allocate(size_t size, Attr attr) override { + if (attr == kCommunication) { + return communication_allocator_->Allocate(size, attr); + } else { + return normal_allocator_->Allocate(size, attr); + } + } + + std::shared_ptr AllocateShared(size_t size, Attr attr) override { + if (attr == kCommunication) { + return communication_allocator_->AllocateShared(size, attr); + } else { + return normal_allocator_->AllocateShared(size, attr); + } + } + + private: + std::shared_ptr normal_allocator_; + std::shared_ptr communication_allocator_; +}; + class AllocatorFacadePrivate { public: std::map> allocators_; @@ -52,10 +82,7 @@ class AllocatorFacadePrivate { private: void InitCPUAllocator() { - auto all = NaiveManagedAllocator::Create( - std::unique_ptr(new CPUAllocator())); - - allocators_[platform::CPUPlace()] = all; + allocators_[platform::CPUPlace()] = std::make_shared(); } void InitCUDAAllocator() { diff --git a/paddle/fluid/memory/allocation/pinned_allocator.cc b/paddle/fluid/memory/allocation/pinned_allocator.cc new file mode 100644 index 0000000000000000000000000000000000000000..39f4b78421592d9916db192ffc0be1b2b59c7dfc --- /dev/null +++ b/paddle/fluid/memory/allocation/pinned_allocator.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/memory/allocation/pinned_allocator.h" +#include +#include + +namespace paddle { +namespace memory { +namespace allocation { + +std::unique_ptr CPUPinnedAllocator::Allocate(size_t size, + Allocator::Attr attr) { + PADDLE_ENFORCE_EQ( + attr, kCommunication, + "CPUPinnedAllocator should be used for Cross-Device Communication"); + + void* ptr; + PADDLE_ENFORCE(cudaMallocHost(&ptr, size)); + return std::unique_ptr( + new CPUPinnedAllocation(ptr, size)); +} + +void CPUPinnedAllocator::Free(Allocation* allocation) { + PADDLE_ENFORCE_NOT_NULL(dynamic_cast(allocation)); + PADDLE_ENFORCE(cudaFreeHost(allocation->ptr())); +} + +bool CPUPinnedAllocator::IsAllocThreadSafe() const { return true; } +} // namespace allocation +} // namespace memory +} // namespace paddle diff --git a/paddle/fluid/memory/allocation/pinned_allocator.h b/paddle/fluid/memory/allocation/pinned_allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..eb249192dd016dcd1405f65083e9310d32172a57 --- /dev/null +++ b/paddle/fluid/memory/allocation/pinned_allocator.h @@ -0,0 +1,37 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/memory/allocation/allocator.h" + +namespace paddle { +namespace memory { +namespace allocation { + +class CPUPinnedAllocation : public Allocation { + public: + CPUPinnedAllocation(void* ptr, size_t size) + : Allocation(ptr, size, platform::CPUPlace()) {} +}; + +class CPUPinnedAllocator : public UnmanagedAllocator { + public: + std::unique_ptr Allocate(size_t size, Attr attr) override; + void Free(Allocation* allocation) override; + bool IsAllocThreadSafe() const override; +}; + +} // namespace allocation +} // namespace memory +} // namespace paddle diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index eae65968285703f5882d910e29bc5d8e1511cba6..68faa1b2b64b65bb0dbec85b9dd9edd7cc9e6301 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -303,7 +303,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { bool fuse_eltwise = ctx.Attr("fuse_eltwise"); int groups = ctx.Attr("groups"); - // TODO: add support for dilation + // TODO: add support for dilation // NOLINT PADDLE_ENFORCE( dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1, "dilation in convolution is not implemented yet"); @@ -386,8 +386,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { auto user_weights_memory_p = handler.AcquireWeightsMemory( user_weights_md, to_void_cast(filter_data)); - T* output_data = - output->mutable_data(ctx.GetPlace(), handler.GetDstMemorySize()); + T* output_data = output->mutable_data( + ctx.GetPlace(), paddle::memory::Allocator::kDefault, + handler.GetDstMemorySize()); // create reorder primitive if the input format is not the preferred one auto src_memory_p = handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); @@ -626,7 +627,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { user_diff_dst_memory_p, pipeline); const size_t size = handler.GetDiffWeightsMemorySize(); - filter_grad_data = filter_grad->mutable_data(ctx.GetPlace(), size); + filter_grad_data = filter_grad->mutable_data( + ctx.GetPlace(), paddle::memory::Allocator::kDefault, size); auto diff_weights_memory_p = handler.AcquireDiffWeightsMemoryFromWeightsPrimitive( @@ -651,7 +653,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { pipeline); const size_t size = handler.GetDiffSourceMemorySize(); - input_grad_data = input_grad->mutable_data(ctx.GetPlace(), size); + input_grad_data = input_grad->mutable_data( + ctx.GetPlace(), paddle::memory::Allocator::kDefault, size); auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive( reinterpret_cast(input_grad_data)); diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index 51614a6a3dd2f7f830cf533fc365b56a99d3b918..7a5bf3230e0ca5ad9da0e127fc7f8f7a4eac97db 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -112,17 +112,16 @@ T TensorGetElement(const framework::Tensor &self, size_t offset) { } } -// TODO(dzhwinter) : fix the redundent Tensor allocate and free +// TODO(dzhwinter) : fix the redundant Tensor allocate and free template void TensorSetElement(framework::Tensor *self, size_t offset, T elem) { if (platform::is_gpu_place(self->place())) { - std::shared_ptr dst(new framework::Tensor); - framework::TensorCopySync(*self, platform::CPUPlace(), dst.get()); - dst->data()[offset] = elem; - framework::TensorCopySync(*dst.get(), self->place(), self); - + framework::Tensor dst; + framework::TensorCopySync(*self, platform::CPUPlace(), &dst); + dst.mutable_data(platform::CPUPlace())[offset] = elem; + framework::TensorCopySync(dst, self->place(), self); } else if (platform::is_cpu_place(self->place())) { - self->data()[offset] = elem; + self->mutable_data(self->place())[offset] = elem; } } diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_op.py index 6a2732e9399aa5a93f4c47eb73bfd23dba608c3d..6514fd29cb766f472f9f9ba035ba9cc344a107ae 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_op.py @@ -113,7 +113,7 @@ class TestConv2dOp(OpTest): return place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace() self.check_grad_with_place( - place, set(['Input', 'Filter']), 'Output', max_relative_error=0.02) + place, {'Input', 'Filter'}, 'Output', max_relative_error=0.02) def test_check_grad_no_filter(self): if self.dtype == np.float16: