From 46ee6bb1aa9b187a260b5c0080f28be16ab453a3 Mon Sep 17 00:00:00 2001 From: luotao1 Date: Fri, 15 Mar 2019 13:27:51 +0800 Subject: [PATCH] fix distributed unit-tests test=develop --- paddle/fluid/framework/operator.cc | 22 ++++++++++++++-------- paddle/fluid/framework/operator.h | 1 + 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index a9a53b0d74f..ac1ad2b05e1 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -874,17 +874,23 @@ std::vector* OperatorWithKernel::GetKernelConfig( return kernel_configs; } -void OperatorWithKernel::RunImpl(const Scope& scope, - const platform::Place& place) const { +RuntimeContext* OperatorWithKernel::GetRuntimeContext( + const Scope& scope) const { if (!HasAttr(kEnableRuntimeContext)) { - runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope)); + return new RuntimeContext(Inputs(), Outputs(), scope); } else { const Scope* cur_scope = &scope; if (!runtime_ctx_ || pre_scope_ != cur_scope) { runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope)); pre_scope_ = cur_scope; } + return runtime_ctx_.get(); } +} + +void OperatorWithKernel::RunImpl(const Scope& scope, + const platform::Place& place) const { + auto runtime_ctx = GetRuntimeContext(scope); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); @@ -899,7 +905,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, OpKernelMap& kernels = kernels_iter->second; auto expected_kernel_key = this->GetExpectedKernelType( - ExecutionContext(*this, scope, *dev_ctx, *runtime_ctx_, nullptr)); + ExecutionContext(*this, scope, *dev_ctx, *runtime_ctx, nullptr)); VLOG(3) << "expected_kernel_key:" << expected_kernel_key; auto kernel_iter = kernels.find(expected_kernel_key); @@ -923,8 +929,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, // do data transformScope &transfer_scope; std::vector transfered_inplace_vars; - auto* transfer_scope = PrepareData( - scope, expected_kernel_key, &transfered_inplace_vars, runtime_ctx_.get()); + auto* transfer_scope = PrepareData(scope, expected_kernel_key, + &transfered_inplace_vars, runtime_ctx); // exec scope is the scope that kernel actually executed on. const Scope& exec_scope = @@ -935,13 +941,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope, } if (!HasAttr(kAllKernelsMustComputeRuntimeShape)) { - RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, *runtime_ctx_); + RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, *runtime_ctx); this->InferShape(&infer_shape_ctx); } // TODO(panyx0718): ExecutionContext should only depend on RuntimeContext // not Scope. Imperative mode only pass inputs and get outputs. kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx, - *runtime_ctx_, kernel_configs)); + *runtime_ctx, kernel_configs)); if (!transfered_inplace_vars.empty()) { // there is inplace variable has been transfered. diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 323aa5a7f58..f0592f4f5fc 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -464,6 +464,7 @@ class OperatorWithKernel : public OperatorBase { // same. proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const; void RunImpl(const Scope& scope, const platform::Place& place) const final; + RuntimeContext* GetRuntimeContext(const Scope& scope) const; /** * Transfer data from scope to a transfered scope. If there is no data need to -- GitLab