提交 2a03e380 编写于 作者: Q qijun

set correct place for output tensor

上级 6dc567a5
...@@ -18,14 +18,14 @@ namespace paddle { ...@@ -18,14 +18,14 @@ namespace paddle {
namespace framework { namespace framework {
template <> template <>
Eigen::DefaultDevice* OpKernel::KernelContext::get_eigen_device< Eigen::DefaultDevice* OpKernel::KernelContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const { platform::CPUPlace, Eigen::DefaultDevice>() const {
return device_context_.get_eigen_device<Eigen::DefaultDevice>(); return device_context_.get_eigen_device<Eigen::DefaultDevice>();
} }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <> template <>
Eigen::GpuDevice* OpKernel::KernelContext::get_eigen_device< Eigen::GpuDevice* OpKernel::KernelContext::GetEigenDevice<
platform::GPUPlace, Eigen::GpuDevice>() const { platform::GPUPlace, Eigen::GpuDevice>() const {
return device_context_.get_eigen_device<Eigen::GpuDevice>(); return device_context_.get_eigen_device<Eigen::GpuDevice>();
} }
......
...@@ -109,7 +109,9 @@ class OpKernel { ...@@ -109,7 +109,9 @@ class OpKernel {
template <typename PlaceType, template <typename PlaceType,
typename DeviceType = typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType> typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType* get_eigen_device() const; DeviceType* GetEigenDevice() const;
platform::Place GetPlace() const { return device_context_.GetPlace(); }
const OperatorBase& op_; const OperatorBase& op_;
const ScopePtr& scope_; const ScopePtr& scope_;
......
...@@ -27,9 +27,9 @@ public: ...@@ -27,9 +27,9 @@ public:
auto input1 = context.Input(1)->Get<framework::Tensor>(); auto input1 = context.Input(1)->Get<framework::Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>(); auto* output = context.Output(0)->GetMutable<framework::Tensor>();
output->mutable_data<T>(Place()); output->mutable_data<T>(context.GetPlace());
output->flat<T>().device(*(context.get_eigen_device<Place>())) = output->flat<T>().device(*(context.GetEigenDevice<Place>())) =
input0.flat<T>() + input1.flat<T>(); input0.flat<T>() + input1.flat<T>();
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册