提交 3c49e7b1 编写于 作者: Q qijun

move EigenDeviceConverter to device_context.h

上级 0f42e564
......@@ -22,14 +22,14 @@ namespace framework {
template <>
Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const {
return *device_context_->get_eigen_device<Eigen::DefaultDevice>();
return *device_context_->get_eigen_device<platform::CPUPlace>();
}
#ifndef PADDLE_ONLY_CPU
template <>
Eigen::GpuDevice&
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return *device_context_->get_eigen_device<Eigen::GpuDevice>();
return *device_context_->get_eigen_device<platform::GPUPlace>();
}
#endif
......
......@@ -331,21 +331,6 @@ class InferShapeContext {
const Scope& scope_;
};
template <typename T>
struct EigenDeviceConverter;
template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};
#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
#endif
class ExecutionContext : public InferShapeContext {
public:
ExecutionContext(const OperatorBase& op, const Scope& scope,
......@@ -353,8 +338,8 @@ class ExecutionContext : public InferShapeContext {
: InferShapeContext(op, scope), device_context_(device_context) {}
template <typename PlaceType,
typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
typename DeviceType = typename platform::EigenDeviceConverter<
PlaceType>::EigenDeviceType>
DeviceType& GetEigenDevice() const;
platform::Place GetPlace() const { return device_context_->GetPlace(); }
......
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h"
namespace paddle {
namespace operators {
namespace math {
template <typename Place, typename T>
struct sigmoid {
void operator()(const platform::DeviceContext& deice_context,
const framework::Tensor& input, framework::Tensor* output) {
auto x = framework::EigenVector<T>::Flatten(*output);
auto y = framework::EigenVector<T>::Flatten(input);
auto* place = device_context.get_eigen_device<Place>();
y.device(*place) = 1. / (1. + (-x).exp());
}
};
}
}
}
......@@ -16,8 +16,8 @@ namespace paddle {
namespace platform {
template <>
Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>()
const {
Eigen::DefaultDevice*
DeviceContext::get_eigen_device<CPUPlace, Eigen::DefaultDevice>() const {
return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device();
}
......@@ -91,7 +91,8 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
};
template <>
Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
Eigen::GpuDevice* DeviceContext::get_eigen_device<GPUPlace, Eigen::GpuDevice>()
const {
return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
}
......
......@@ -27,12 +27,29 @@ limitations under the License. */
namespace paddle {
namespace platform {
template <typename T>
struct EigenDeviceConverter;
template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};
#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
#endif
class DeviceContext {
public:
virtual ~DeviceContext() {}
virtual Place GetPlace() const = 0;
template <typename DeviceType>
template <typename PlaceType,
typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType* get_eigen_device() const;
};
......
......@@ -24,7 +24,7 @@ TEST(Device, Init) {
for (int i = 0; i < count; i++) {
DeviceContext* device_context = new CUDADeviceContext(GPUPlace(i));
Eigen::GpuDevice* gpu_device =
device_context->template get_eigen_device<Eigen::GpuDevice>();
device_context->template get_eigen_device<GPUPlace>();
ASSERT_NE(nullptr, gpu_device);
delete device_context;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册