From ab601c19c37da50f391de62581e572dbe32d6b7f Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 26 Mar 2018 22:47:20 +0800 Subject: [PATCH] Add CUDAPinnedPlace --- paddle/fluid/framework/tensor_util.cc | 5 ++++ paddle/fluid/memory/memory.cc | 4 +++ paddle/fluid/operators/math/math_function.cc | 8 ++++++ paddle/fluid/platform/device_context.cc | 24 +++++++++++++++++ paddle/fluid/platform/device_context.h | 27 ++++++++++++-------- paddle/fluid/platform/place.h | 5 ++++ 6 files changed, 63 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 8b7533ce71..1d864af011 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -148,6 +148,11 @@ struct AnyVisitor : public boost::static_visitor { const platform::CPUPlace& cpu) const { return *out.data(); } + + bool GetResult(const framework::Tensor& out, + const platform::CUDAPinnedPlace& cpu) const { + return *out.data(); + } }; template diff --git a/paddle/fluid/memory/memory.cc b/paddle/fluid/memory/memory.cc index 94b43af147..f6cbdaaa97 100644 --- a/paddle/fluid/memory/memory.cc +++ b/paddle/fluid/memory/memory.cc @@ -160,7 +160,11 @@ size_t Usage::operator()(const platform::CUDAPlace& gpu) const { } size_t Usage::operator()(const platform::CUDAPinnedPlace& cuda_pinned) const { +#ifdef PADDLE_WITH_CUDA return Used(cuda_pinned); +#else + PADDLE_THROW("'CUDAPinnedPlace' is not supported in CPU only device."); +#endif } size_t memory_usage(const platform::Place& p) { diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index 299a0aed01..44fd739fb1 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -322,6 +322,14 @@ void set_constant_with_place( TensorSetConstantCPU(tensor, value)); } +template <> +void set_constant_with_place( + const platform::DeviceContext& context, framework::Tensor* tensor, 
+ float value) { + framework::VisitDataType(framework::ToDataType(tensor->type()), + TensorSetConstantCPU(tensor, value)); +} + struct TensorSetConstantWithPlace : public boost::static_visitor { TensorSetConstantWithPlace(const platform::DeviceContext& context, framework::Tensor* tensor, float value) diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 59b76a1edb..feb4f36700 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -53,6 +53,16 @@ DeviceContextPool::DeviceContextPool( PADDLE_THROW( "'CUDAPlace' is not supported, Please re-compile with WITH_GPU " "option"); +#endif + } else if (platform::is_cuda_pinned_place(p)) { +#ifdef PADDLE_WITH_CUDA + device_contexts_.emplace( + p, + PtrType(new CUDAPinnedDeviceContext(boost::get(p)))); +#else + PADDLE_THROW( + "'CUDAPinnedPlace' is not supported, Please re-compile with WITH_GPU " + "option"); #endif } } @@ -186,6 +196,20 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } cudaStream_t CUDADeviceContext::stream() const { return stream_; } +CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() { + eigen_device_.reset(new Eigen::DefaultDevice()); +} + +CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place) + : place_(place) { + eigen_device_.reset(new Eigen::DefaultDevice()); +} + +Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const { + return eigen_device_.get(); +} + +Place CUDAPinnedDeviceContext::GetPlace() const { return place_; } #endif #ifdef PADDLE_WITH_MKLDNN diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index e25cfe60b1..6b796d92d0 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -119,17 +119,24 @@ struct DefaultDeviceContextType { }; // Currently, CUDAPinnedDeviceContext is only used to data copying. 
-// class CUDAPinnedDeviceContext : public DeviceContext { -// public: -// CUDAPinnedDeviceContext(); -// explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place); -// -// Place GetPlace() const override; -// -// private: -// CUDAPinnedPlace place_; -//}; +class CUDAPinnedDeviceContext : public DeviceContext { + public: + CUDAPinnedDeviceContext(); + explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place); + + Place GetPlace() const override; + Eigen::DefaultDevice* eigen_device() const; + + private: + CUDAPinnedPlace place_; + std::unique_ptr eigen_device_; +}; + +template <> +struct DefaultDeviceContextType { + using TYPE = CUDAPinnedDeviceContext; +}; #endif #ifdef PADDLE_WITH_MKLDNN diff --git a/paddle/fluid/platform/place.h b/paddle/fluid/platform/place.h index 8f3acd8df6..d0bdcb0da5 100644 --- a/paddle/fluid/platform/place.h +++ b/paddle/fluid/platform/place.h @@ -123,7 +123,12 @@ struct PlaceVisitorWrapper typename Visitor::result_type operator()( const CUDAPinnedPlace &cuda_pinned) const { +#ifdef PADDLE_WITH_CUDA return visitor_(cuda_pinned); +#else + PADDLE_THROW("Paddle is not compiled with CUDA. Cannot visit cuda_pinned"); + return typename Visitor::result_type(); +#endif } }; -- GitLab