From 2a03e3808d48257a71366f5802aeec052914e1cc Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 17 Jul 2017 16:45:42 +0800 Subject: [PATCH] set correct place for output tensor --- paddle/framework/operator.cc | 4 ++-- paddle/framework/operator.h | 4 +++- paddle/operators/add_op.h | 4 ++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 946bde573..1a7e33222 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -18,14 +18,14 @@ namespace paddle { namespace framework { template <> -Eigen::DefaultDevice* OpKernel::KernelContext::get_eigen_device< +Eigen::DefaultDevice* OpKernel::KernelContext::GetEigenDevice< platform::CPUPlace, Eigen::DefaultDevice>() const { return device_context_.get_eigen_device(); } #ifndef PADDLE_ONLY_CPU template <> -Eigen::GpuDevice* OpKernel::KernelContext::get_eigen_device< +Eigen::GpuDevice* OpKernel::KernelContext::GetEigenDevice< platform::GPUPlace, Eigen::GpuDevice>() const { return device_context_.get_eigen_device(); } diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index e6cae9c32..b8c5098e4 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -109,7 +109,9 @@ class OpKernel { template ::EigenDeviceType> - DeviceType* get_eigen_device() const; + DeviceType* GetEigenDevice() const; + + platform::Place GetPlace() const { return device_context_.GetPlace(); } const OperatorBase& op_; const ScopePtr& scope_; diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index e8c718669..e9a793d23 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -27,9 +27,9 @@ public: auto input1 = context.Input(1)->Get(); auto* output = context.Output(0)->GetMutable(); - output->mutable_data(Place()); + output->mutable_data(context.GetPlace()); - output->flat().device(*(context.get_eigen_device())) = + output->flat().device(*(context.GetEigenDevice())) = input0.flat() + input1.flat(); } }; -- GitLab