提交 7a6fcc7d 编写于 作者: Q qijun

move EigenDeviceConverter to device_context.h

上级 184768e0
...@@ -22,14 +22,14 @@ namespace framework { ...@@ -22,14 +22,14 @@ namespace framework {
template <> template <>
Eigen::DefaultDevice& ExecutionContext::GetEigenDevice< Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const { platform::CPUPlace, Eigen::DefaultDevice>() const {
return *device_context_.get_eigen_device<Eigen::DefaultDevice>(); return *device_context_.GetEigenDevice<platform::CPUPlace>();
} }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <> template <>
Eigen::GpuDevice& Eigen::GpuDevice&
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return *device_context_.get_eigen_device<Eigen::GpuDevice>(); return *device_context_.GetEigenDevice<platform::GPUPlace>();
} }
#endif #endif
......
...@@ -296,21 +296,6 @@ template <> ...@@ -296,21 +296,6 @@ template <>
std::vector<Tensor*> InferShapeContext::MultiOutput<Tensor>( std::vector<Tensor*> InferShapeContext::MultiOutput<Tensor>(
const std::string& name) const; const std::string& name) const;
template <typename T>
struct EigenDeviceConverter;
template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};
#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
#endif
class ExecutionContext : public InferShapeContext { class ExecutionContext : public InferShapeContext {
public: public:
ExecutionContext(const OperatorBase& op, const Scope& scope, ExecutionContext(const OperatorBase& op, const Scope& scope,
...@@ -318,8 +303,8 @@ class ExecutionContext : public InferShapeContext { ...@@ -318,8 +303,8 @@ class ExecutionContext : public InferShapeContext {
: InferShapeContext(op, scope), device_context_(device_context) {} : InferShapeContext(op, scope), device_context_(device_context) {}
template <typename PlaceType, template <typename PlaceType,
typename DeviceType = typename DeviceType = typename platform::EigenDeviceConverter<
typename EigenDeviceConverter<PlaceType>::EigenDeviceType> PlaceType>::EigenDeviceType>
DeviceType& GetEigenDevice() const; DeviceType& GetEigenDevice() const;
platform::Place GetPlace() const { return device_context_.GetPlace(); } platform::Place GetPlace() const { return device_context_.GetPlace(); }
......
...@@ -16,8 +16,8 @@ namespace paddle { ...@@ -16,8 +16,8 @@ namespace paddle {
namespace platform { namespace platform {
template <> template <>
Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>() Eigen::DefaultDevice* DeviceContext::GetEigenDevice<
const { platform::CPUPlace, Eigen::DefaultDevice>() const {
return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device(); return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device();
} }
...@@ -37,6 +37,12 @@ Place CPUDeviceContext::GetPlace() const { return CPUPlace(); } ...@@ -37,6 +37,12 @@ Place CPUDeviceContext::GetPlace() const { return CPUPlace(); }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <>
Eigen::GpuDevice*
DeviceContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
}
class EigenCudaStreamDevice : public Eigen::StreamInterface { class EigenCudaStreamDevice : public Eigen::StreamInterface {
public: public:
EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) { EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
...@@ -90,11 +96,6 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { ...@@ -90,11 +96,6 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
mutable unsigned int* semaphore_; mutable unsigned int* semaphore_;
}; };
template <>
Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
}
CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
SetDeviceId(place_.device); SetDeviceId(place_.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_)); PADDLE_ENFORCE(cudaStreamCreate(&stream_));
......
...@@ -27,13 +27,23 @@ limitations under the License. */ ...@@ -27,13 +27,23 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace platform { namespace platform {
template <typename T>
struct EigenDeviceConverter;
template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};
class DeviceContext { class DeviceContext {
public: public:
virtual ~DeviceContext() {} virtual ~DeviceContext() {}
virtual Place GetPlace() const = 0; virtual Place GetPlace() const = 0;
template <typename DeviceType> template <typename PlaceType,
DeviceType* get_eigen_device() const; typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType* GetEigenDevice() const;
virtual void Wait() const {} virtual void Wait() const {}
}; };
...@@ -52,6 +62,11 @@ class CPUDeviceContext : public DeviceContext { ...@@ -52,6 +62,11 @@ class CPUDeviceContext : public DeviceContext {
}; };
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
class EigenCudaStreamDevice; class EigenCudaStreamDevice;
class CUDADeviceContext : public DeviceContext { class CUDADeviceContext : public DeviceContext {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册