diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 3744eae69689633c91c6543a25e00cb769ee6043..febad37b42b1ab4a995e0b62ca3440b72efc13de 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -488,6 +488,8 @@ void OperatorWithKernel::Run(const Scope& scope, } } + VLOG(3) << "expected_kernel_key:" << expected_kernel_key; + Scope& new_scope = scope.NewScope(); for (auto& var_name_item : this->Inputs()) { @@ -520,7 +522,8 @@ void OperatorWithKernel::Run(const Scope& scope, auto kernel_iter = kernels.find(expected_kernel_key); - kernel_iter->second->Compute(ExecutionContext(*this, new_scope, *dev_ctx)); + kernel_iter->second->Compute(ExecutionContext( + *this, new_scope, *pool.Get(expected_kernel_key.place_))); } proto::DataType OperatorWithKernel::IndicateDataType( diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc index 387d1e0a747f71d85826b52d140c2838112227f6..48c01f984f825208d911a06c6e48b802fa24aa0e 100644 --- a/paddle/operators/fetch_op.cc +++ b/paddle/operators/fetch_op.cc @@ -53,7 +53,7 @@ class FetchOp : public framework::OperatorBase { // FIXME(yuyang18): Should we assume the fetch operator always generate // CPU outputs? platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(place); + auto &dev_ctx = *pool.Get(src_item.place()); CopyFrom(src_item, platform::CPUPlace(), dev_ctx, &dst_item); dev_ctx.Wait();