diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index cb86e6be2be3624bf54ee28193ca5d4c7bafa0eb..beb6793289812cfaa6991d28379126ff29fa2547 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -22,14 +22,14 @@ namespace framework { template <> Eigen::DefaultDevice& ExecutionContext::GetEigenDevice< platform::CPUPlace, Eigen::DefaultDevice>() const { - return *device_context_.get_eigen_device(); + return *device_context_->get_eigen_device(); } #ifndef PADDLE_ONLY_CPU template <> Eigen::GpuDevice& ExecutionContext::GetEigenDevice() const { - return *device_context_.get_eigen_device(); + return *device_context_->get_eigen_device(); } #endif diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index d42e21c0a235791db42076555d0568ff8f4acbe2..b25362fef336fd84934e901108b6c8358463fe03 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -252,7 +252,7 @@ struct EigenDeviceConverter { class ExecutionContext : public OperatorContext { public: ExecutionContext(const OperatorBase* op, const Scope& scope, - const platform::DeviceContext& device_context) + const platform::DeviceContext* device_context) : OperatorContext(op, scope), device_context_(device_context) {} template ::EigenDeviceType> DeviceType& GetEigenDevice() const; - platform::Place GetPlace() const { return device_context_.GetPlace(); } + platform::Place GetPlace() const { return device_context_->GetPlace(); } - const platform::DeviceContext& device_context_; + const platform::DeviceContext* device_context_; }; class OpKernel { @@ -311,7 +311,7 @@ class OperatorWithKernel : public OperatorBase { void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const final { auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); - opKernel->Compute(ExecutionContext(this, scope, dev_ctx)); + opKernel->Compute(ExecutionContext(this, scope, &dev_ctx)); } static std::unordered_map&