提交 b18e6141 编写于 作者: D dongzhihong

"change device context to pointer"

上级 d911b1b5
...@@ -22,14 +22,14 @@ namespace framework { ...@@ -22,14 +22,14 @@ namespace framework {
template <> template <>
Eigen::DefaultDevice& ExecutionContext::GetEigenDevice< Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const { platform::CPUPlace, Eigen::DefaultDevice>() const {
return *device_context_.get_eigen_device<Eigen::DefaultDevice>(); return *device_context_->get_eigen_device<Eigen::DefaultDevice>();
} }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <> template <>
Eigen::GpuDevice& Eigen::GpuDevice&
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return *device_context_.get_eigen_device<Eigen::GpuDevice>(); return *device_context_->get_eigen_device<Eigen::GpuDevice>();
} }
#endif #endif
......
...@@ -252,7 +252,7 @@ struct EigenDeviceConverter<platform::GPUPlace> { ...@@ -252,7 +252,7 @@ struct EigenDeviceConverter<platform::GPUPlace> {
class ExecutionContext : public OperatorContext { class ExecutionContext : public OperatorContext {
public: public:
ExecutionContext(const OperatorBase* op, const Scope& scope, ExecutionContext(const OperatorBase* op, const Scope& scope,
const platform::DeviceContext& device_context) const platform::DeviceContext* device_context)
: OperatorContext(op, scope), device_context_(device_context) {} : OperatorContext(op, scope), device_context_(device_context) {}
template <typename PlaceType, template <typename PlaceType,
...@@ -260,9 +260,9 @@ class ExecutionContext : public OperatorContext { ...@@ -260,9 +260,9 @@ class ExecutionContext : public OperatorContext {
typename EigenDeviceConverter<PlaceType>::EigenDeviceType> typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType& GetEigenDevice() const; DeviceType& GetEigenDevice() const;
platform::Place GetPlace() const { return device_context_.GetPlace(); } platform::Place GetPlace() const { return device_context_->GetPlace(); }
const platform::DeviceContext& device_context_; const platform::DeviceContext* device_context_;
}; };
class OpKernel { class OpKernel {
...@@ -311,7 +311,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -311,7 +311,7 @@ class OperatorWithKernel : public OperatorBase {
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const final { const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(ExecutionContext(this, scope, dev_ctx)); opKernel->Compute(ExecutionContext(this, scope, &dev_ctx));
} }
static std::unordered_map<std::string /* op_type */, OpKernelMap>& static std::unordered_map<std::string /* op_type */, OpKernelMap>&
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册