From 7a6fcc7d30ab8dd8f452a9974e16798dbbe05dfe Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 28 Sep 2017 17:39:50 -0700 Subject: [PATCH] move EigenDeviceConverter to device_context.h --- paddle/framework/operator.cc | 4 ++-- paddle/framework/operator.h | 19 ++----------------- paddle/platform/device_context.cc | 15 ++++++++------- paddle/platform/device_context.h | 19 +++++++++++++++++-- 4 files changed, 29 insertions(+), 28 deletions(-) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index d7beff5bc1d..8b5560ffa12 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -22,14 +22,14 @@ namespace framework { template <> Eigen::DefaultDevice& ExecutionContext::GetEigenDevice< platform::CPUPlace, Eigen::DefaultDevice>() const { - return *device_context_.get_eigen_device(); + return *device_context_.GetEigenDevice(); } #ifndef PADDLE_ONLY_CPU template <> Eigen::GpuDevice& ExecutionContext::GetEigenDevice() const { - return *device_context_.get_eigen_device(); + return *device_context_.GetEigenDevice(); } #endif diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index ba697a43e9e..310d68d7c1b 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -296,21 +296,6 @@ template <> std::vector InferShapeContext::MultiOutput( const std::string& name) const; -template -struct EigenDeviceConverter; - -template <> -struct EigenDeviceConverter { - using EigenDeviceType = Eigen::DefaultDevice; -}; - -#ifndef PADDLE_ONLY_CPU -template <> -struct EigenDeviceConverter { - using EigenDeviceType = Eigen::GpuDevice; -}; -#endif - class ExecutionContext : public InferShapeContext { public: ExecutionContext(const OperatorBase& op, const Scope& scope, @@ -318,8 +303,8 @@ class ExecutionContext : public InferShapeContext { : InferShapeContext(op, scope), device_context_(device_context) {} template ::EigenDeviceType> + typename DeviceType = typename platform::EigenDeviceConverter< + PlaceType>::EigenDeviceType> DeviceType& GetEigenDevice() const; platform::Place GetPlace() const { return device_context_.GetPlace(); } diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index 93b472b41c8..36af1ac677f 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -16,8 +16,8 @@ namespace paddle { namespace platform { template <> -Eigen::DefaultDevice* DeviceContext::get_eigen_device() - const { +Eigen::DefaultDevice* DeviceContext::GetEigenDevice< + platform::CPUPlace, Eigen::DefaultDevice>() const { return reinterpret_cast(this)->eigen_device(); } @@ -37,6 +37,12 @@ Place CPUDeviceContext::GetPlace() const { return CPUPlace(); } #ifndef PADDLE_ONLY_CPU +template <> +Eigen::GpuDevice* +DeviceContext::GetEigenDevice() const { + return reinterpret_cast(this)->eigen_device(); +} + class EigenCudaStreamDevice : public Eigen::StreamInterface { public: EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) { @@ -90,11 +96,6 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { mutable unsigned int* semaphore_; }; -template <> -Eigen::GpuDevice* DeviceContext::get_eigen_device() const { - return reinterpret_cast(this)->eigen_device(); -} - CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { SetDeviceId(place_.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index f6a39a8e26c..d805d2ab085 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -27,13 +27,23 @@ limitations under the License. */ namespace paddle { namespace platform { +template +struct EigenDeviceConverter; + +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::DefaultDevice; +}; + class DeviceContext { public: virtual ~DeviceContext() {} virtual Place GetPlace() const = 0; - template - DeviceType* get_eigen_device() const; + template ::EigenDeviceType> + DeviceType* GetEigenDevice() const; virtual void Wait() const {} }; @@ -52,6 +62,11 @@ class CPUDeviceContext : public DeviceContext { }; #ifndef PADDLE_ONLY_CPU +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::GpuDevice; +}; + class EigenCudaStreamDevice; class CUDADeviceContext : public DeviceContext { -- GitLab