From b18e6141639807406e5569a0e447cd0d1198bcf6 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Mon, 7 Aug 2017 09:43:57 +0800 Subject: [PATCH] "change device context to pointer" --- paddle/framework/operator.cc | 4 ++-- paddle/framework/operator.h | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index cb86e6be2be..beb67932898 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -22,14 +22,14 @@ namespace framework { template <> Eigen::DefaultDevice& ExecutionContext::GetEigenDevice< platform::CPUPlace, Eigen::DefaultDevice>() const { - return *device_context_.get_eigen_device(); + return *device_context_->get_eigen_device(); } #ifndef PADDLE_ONLY_CPU template <> Eigen::GpuDevice& ExecutionContext::GetEigenDevice() const { - return *device_context_.get_eigen_device(); + return *device_context_->get_eigen_device(); } #endif diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index d42e21c0a23..b25362fef33 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -252,7 +252,7 @@ struct EigenDeviceConverter { class ExecutionContext : public OperatorContext { public: ExecutionContext(const OperatorBase* op, const Scope& scope, - const platform::DeviceContext& device_context) + const platform::DeviceContext* device_context) : OperatorContext(op, scope), device_context_(device_context) {} template ::EigenDeviceType> DeviceType& GetEigenDevice() const; - platform::Place GetPlace() const { return device_context_.GetPlace(); } + platform::Place GetPlace() const { return device_context_->GetPlace(); } - const platform::DeviceContext& device_context_; + const platform::DeviceContext* device_context_; }; class OpKernel { @@ -311,7 +311,7 @@ class OperatorWithKernel : public OperatorBase { void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const final { auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); - opKernel->Compute(ExecutionContext(this, scope, dev_ctx)); + opKernel->Compute(ExecutionContext(this, scope, &dev_ctx)); } static std::unordered_map& -- GitLab