提交 7a6fcc7d 编写于 作者: Q qijun

move EigenDeviceConverter to device_context.h

上级 184768e0
......@@ -22,14 +22,14 @@ namespace framework {
template <>
Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const {
return *device_context_.get_eigen_device<Eigen::DefaultDevice>();
return *device_context_.GetEigenDevice<platform::CPUPlace>();
}
#ifndef PADDLE_ONLY_CPU
template <>
Eigen::GpuDevice&
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return *device_context_.get_eigen_device<Eigen::GpuDevice>();
return *device_context_.GetEigenDevice<platform::GPUPlace>();
}
#endif
......
......@@ -296,21 +296,6 @@ template <>
std::vector<Tensor*> InferShapeContext::MultiOutput<Tensor>(
const std::string& name) const;
template <typename T>
struct EigenDeviceConverter;
template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};
#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
#endif
class ExecutionContext : public InferShapeContext {
public:
ExecutionContext(const OperatorBase& op, const Scope& scope,
......@@ -318,8 +303,8 @@ class ExecutionContext : public InferShapeContext {
: InferShapeContext(op, scope), device_context_(device_context) {}
template <typename PlaceType,
typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
typename DeviceType = typename platform::EigenDeviceConverter<
PlaceType>::EigenDeviceType>
DeviceType& GetEigenDevice() const;
platform::Place GetPlace() const { return device_context_.GetPlace(); }
......
......@@ -16,8 +16,8 @@ namespace paddle {
namespace platform {
template <>
Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>()
const {
Eigen::DefaultDevice* DeviceContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const {
return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device();
}
......@@ -37,6 +37,12 @@ Place CPUDeviceContext::GetPlace() const { return CPUPlace(); }
#ifndef PADDLE_ONLY_CPU
template <>
Eigen::GpuDevice*
DeviceContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
}
class EigenCudaStreamDevice : public Eigen::StreamInterface {
public:
EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
......@@ -90,11 +96,6 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
mutable unsigned int* semaphore_;
};
template <>
Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
}
CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
......
......@@ -27,13 +27,23 @@ limitations under the License. */
namespace paddle {
namespace platform {
template <typename T>
struct EigenDeviceConverter;
template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};
class DeviceContext {
public:
virtual ~DeviceContext() {}
virtual Place GetPlace() const = 0;
template <typename DeviceType>
DeviceType* get_eigen_device() const;
template <typename PlaceType,
typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType* GetEigenDevice() const;
virtual void Wait() const {}
};
......@@ -52,6 +62,11 @@ class CPUDeviceContext : public DeviceContext {
};
#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
class EigenCudaStreamDevice;
class CUDADeviceContext : public DeviceContext {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册