提交 3c49e7b1 编写于 作者: Q qijun

move EigenDeviceConverter to device_context.h

上级 0f42e564
...@@ -22,14 +22,14 @@ namespace framework { ...@@ -22,14 +22,14 @@ namespace framework {
template <> template <>
Eigen::DefaultDevice& ExecutionContext::GetEigenDevice< Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const { platform::CPUPlace, Eigen::DefaultDevice>() const {
return *device_context_->get_eigen_device<Eigen::DefaultDevice>(); return *device_context_->get_eigen_device<platform::CPUPlace>();
} }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <> template <>
Eigen::GpuDevice& Eigen::GpuDevice&
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return *device_context_->get_eigen_device<Eigen::GpuDevice>(); return *device_context_->get_eigen_device<platform::GPUPlace>();
} }
#endif #endif
......
...@@ -331,21 +331,6 @@ class InferShapeContext { ...@@ -331,21 +331,6 @@ class InferShapeContext {
const Scope& scope_; const Scope& scope_;
}; };
template <typename T>
struct EigenDeviceConverter;
template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};
#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
#endif
class ExecutionContext : public InferShapeContext { class ExecutionContext : public InferShapeContext {
public: public:
ExecutionContext(const OperatorBase& op, const Scope& scope, ExecutionContext(const OperatorBase& op, const Scope& scope,
...@@ -353,8 +338,8 @@ class ExecutionContext : public InferShapeContext { ...@@ -353,8 +338,8 @@ class ExecutionContext : public InferShapeContext {
: InferShapeContext(op, scope), device_context_(device_context) {} : InferShapeContext(op, scope), device_context_(device_context) {}
template <typename PlaceType, template <typename PlaceType,
typename DeviceType = typename DeviceType = typename platform::EigenDeviceConverter<
typename EigenDeviceConverter<PlaceType>::EigenDeviceType> PlaceType>::EigenDeviceType>
DeviceType& GetEigenDevice() const; DeviceType& GetEigenDevice() const;
platform::Place GetPlace() const { return device_context_->GetPlace(); } platform::Place GetPlace() const { return device_context_->GetPlace(); }
......
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h"
namespace paddle {
namespace operators {
namespace math {
template <typename Place, typename T>
struct sigmoid {
void operator()(const platform::DeviceContext& deice_context,
const framework::Tensor& input, framework::Tensor* output) {
auto x = framework::EigenVector<T>::Flatten(*output);
auto y = framework::EigenVector<T>::Flatten(input);
auto* place = device_context.get_eigen_device<Place>();
y.device(*place) = 1. / (1. + (-x).exp());
}
};
}
}
}
...@@ -16,8 +16,8 @@ namespace paddle { ...@@ -16,8 +16,8 @@ namespace paddle {
namespace platform { namespace platform {
template <> template <>
Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>() Eigen::DefaultDevice*
const { DeviceContext::get_eigen_device<CPUPlace, Eigen::DefaultDevice>() const {
return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device(); return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device();
} }
...@@ -91,7 +91,8 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { ...@@ -91,7 +91,8 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
}; };
template <> template <>
Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const { Eigen::GpuDevice* DeviceContext::get_eigen_device<GPUPlace, Eigen::GpuDevice>()
const {
return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device(); return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
} }
......
...@@ -27,12 +27,29 @@ limitations under the License. */ ...@@ -27,12 +27,29 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace platform { namespace platform {
template <typename T>
struct EigenDeviceConverter;
template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};
#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
#endif
class DeviceContext { class DeviceContext {
public: public:
virtual ~DeviceContext() {} virtual ~DeviceContext() {}
virtual Place GetPlace() const = 0; virtual Place GetPlace() const = 0;
template <typename DeviceType> template <typename PlaceType,
typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType* get_eigen_device() const; DeviceType* get_eigen_device() const;
}; };
......
...@@ -24,7 +24,7 @@ TEST(Device, Init) { ...@@ -24,7 +24,7 @@ TEST(Device, Init) {
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i++) {
DeviceContext* device_context = new CUDADeviceContext(GPUPlace(i)); DeviceContext* device_context = new CUDADeviceContext(GPUPlace(i));
Eigen::GpuDevice* gpu_device = Eigen::GpuDevice* gpu_device =
device_context->template get_eigen_device<Eigen::GpuDevice>(); device_context->template get_eigen_device<GPUPlace>();
ASSERT_NE(nullptr, gpu_device); ASSERT_NE(nullptr, gpu_device);
delete device_context; delete device_context;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册