提交 2a03e380 编写于 作者: Q qijun

set correct place for output tensor

上级 6dc567a5
......@@ -18,14 +18,14 @@ namespace paddle {
namespace framework {
template <>
-Eigen::DefaultDevice* OpKernel::KernelContext::get_eigen_device<
+Eigen::DefaultDevice* OpKernel::KernelContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const {
return device_context_.get_eigen_device<Eigen::DefaultDevice>();
}
#ifndef PADDLE_ONLY_CPU
template <>
-Eigen::GpuDevice* OpKernel::KernelContext::get_eigen_device<
+Eigen::GpuDevice* OpKernel::KernelContext::GetEigenDevice<
platform::GPUPlace, Eigen::GpuDevice>() const {
return device_context_.get_eigen_device<Eigen::GpuDevice>();
}
......
......@@ -109,7 +109,9 @@ class OpKernel {
template <typename PlaceType,
typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
-DeviceType* get_eigen_device() const;
+DeviceType* GetEigenDevice() const;
+platform::Place GetPlace() const { return device_context_.GetPlace(); }
const OperatorBase& op_;
const ScopePtr& scope_;
......
......@@ -27,9 +27,9 @@ public:
auto input1 = context.Input(1)->Get<framework::Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
-output->mutable_data<T>(Place());
+output->mutable_data<T>(context.GetPlace());
-output->flat<T>().device(*(context.get_eigen_device<Place>())) =
+output->flat<T>().device(*(context.GetEigenDevice<Place>())) =
input0.flat<T>() + input1.flat<T>();
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册