diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index d7beff5bc1df1def6bf35381e103cf87eeb68fd0..8b5560ffa1234145fb4291f5730f89fd7375ee15 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -22,14 +22,14 @@ namespace framework { template <> Eigen::DefaultDevice& ExecutionContext::GetEigenDevice< platform::CPUPlace, Eigen::DefaultDevice>() const { - return *device_context_.get_eigen_device(); + return *device_context_.GetEigenDevice(); } #ifndef PADDLE_ONLY_CPU template <> Eigen::GpuDevice& ExecutionContext::GetEigenDevice() const { - return *device_context_.get_eigen_device(); + return *device_context_.GetEigenDevice(); } #endif diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index ba697a43e9ebdd1837720098d74b95e2dbad77d3..310d68d7c1baac231a2f1709af28bfb58ae1a436 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -296,21 +296,6 @@ template <> std::vector InferShapeContext::MultiOutput( const std::string& name) const; -template -struct EigenDeviceConverter; - -template <> -struct EigenDeviceConverter { - using EigenDeviceType = Eigen::DefaultDevice; -}; - -#ifndef PADDLE_ONLY_CPU -template <> -struct EigenDeviceConverter { - using EigenDeviceType = Eigen::GpuDevice; -}; -#endif - class ExecutionContext : public InferShapeContext { public: ExecutionContext(const OperatorBase& op, const Scope& scope, @@ -318,8 +303,8 @@ class ExecutionContext : public InferShapeContext { : InferShapeContext(op, scope), device_context_(device_context) {} template ::EigenDeviceType> + typename DeviceType = typename platform::EigenDeviceConverter< + PlaceType>::EigenDeviceType> DeviceType& GetEigenDevice() const; platform::Place GetPlace() const { return device_context_.GetPlace(); } diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index 93b472b41c8a4c3a2bfada9d4fbf0e9e1b0cc736..36af1ac677f6bb3e5b6392ff0de678afe7e47950 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -16,8 +16,8 @@ namespace paddle { namespace platform { template <> -Eigen::DefaultDevice* DeviceContext::get_eigen_device() - const { +Eigen::DefaultDevice* DeviceContext::GetEigenDevice< + platform::CPUPlace, Eigen::DefaultDevice>() const { return reinterpret_cast(this)->eigen_device(); } @@ -37,6 +37,12 @@ Place CPUDeviceContext::GetPlace() const { return CPUPlace(); } #ifndef PADDLE_ONLY_CPU +template <> +Eigen::GpuDevice* +DeviceContext::GetEigenDevice() const { + return reinterpret_cast(this)->eigen_device(); +} + class EigenCudaStreamDevice : public Eigen::StreamInterface { public: EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) { @@ -90,11 +96,6 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { mutable unsigned int* semaphore_; }; -template <> -Eigen::GpuDevice* DeviceContext::get_eigen_device() const { - return reinterpret_cast(this)->eigen_device(); -} - CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { SetDeviceId(place_.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index f6a39a8e26c301296aac0af7f4e8b2c6c97ece24..d805d2ab085f76e119edf1c6f2acb9715883d755 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -27,13 +27,23 @@ limitations under the License. */ namespace paddle { namespace platform { +template +struct EigenDeviceConverter; + +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::DefaultDevice; +}; + class DeviceContext { public: virtual ~DeviceContext() {} virtual Place GetPlace() const = 0; - template - DeviceType* get_eigen_device() const; + template ::EigenDeviceType> + DeviceType* GetEigenDevice() const; virtual void Wait() const {} }; @@ -52,6 +62,11 @@ class CPUDeviceContext : public DeviceContext { }; #ifndef PADDLE_ONLY_CPU +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::GpuDevice; +}; + class EigenCudaStreamDevice; class CUDADeviceContext : public DeviceContext {