提交 46ee6bb1 编写于 作者: L luotao1

fix distributed unit-tests

test=develop
上级 1b59bed9
...@@ -874,17 +874,23 @@ std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig( ...@@ -874,17 +874,23 @@ std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig(
return kernel_configs; return kernel_configs;
} }
void OperatorWithKernel::RunImpl(const Scope& scope, RuntimeContext* OperatorWithKernel::GetRuntimeContext(
const platform::Place& place) const { const Scope& scope) const {
if (!HasAttr(kEnableRuntimeContext)) { if (!HasAttr(kEnableRuntimeContext)) {
runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope)); return new RuntimeContext(Inputs(), Outputs(), scope);
} else { } else {
const Scope* cur_scope = &scope; const Scope* cur_scope = &scope;
if (!runtime_ctx_ || pre_scope_ != cur_scope) { if (!runtime_ctx_ || pre_scope_ != cur_scope) {
runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope)); runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope));
pre_scope_ = cur_scope; pre_scope_ = cur_scope;
} }
return runtime_ctx_.get();
} }
}
void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const {
auto runtime_ctx = GetRuntimeContext(scope);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
...@@ -899,7 +905,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -899,7 +905,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
OpKernelMap& kernels = kernels_iter->second; OpKernelMap& kernels = kernels_iter->second;
auto expected_kernel_key = this->GetExpectedKernelType( auto expected_kernel_key = this->GetExpectedKernelType(
ExecutionContext(*this, scope, *dev_ctx, *runtime_ctx_, nullptr)); ExecutionContext(*this, scope, *dev_ctx, *runtime_ctx, nullptr));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key; VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
...@@ -923,8 +929,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -923,8 +929,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// do data transformScope &transfer_scope; // do data transformScope &transfer_scope;
std::vector<std::string> transfered_inplace_vars; std::vector<std::string> transfered_inplace_vars;
auto* transfer_scope = PrepareData( auto* transfer_scope = PrepareData(scope, expected_kernel_key,
scope, expected_kernel_key, &transfered_inplace_vars, runtime_ctx_.get()); &transfered_inplace_vars, runtime_ctx);
// exec scope is the scope that kernel actually executed on. // exec scope is the scope that kernel actually executed on.
const Scope& exec_scope = const Scope& exec_scope =
...@@ -935,13 +941,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -935,13 +941,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
} }
if (!HasAttr(kAllKernelsMustComputeRuntimeShape)) { if (!HasAttr(kAllKernelsMustComputeRuntimeShape)) {
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, *runtime_ctx_); RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, *runtime_ctx);
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
} }
// TODO(panyx0718): ExecutionContext should only depend on RuntimeContext // TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
// not Scope. Imperative mode only pass inputs and get outputs. // not Scope. Imperative mode only pass inputs and get outputs.
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx, kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx,
*runtime_ctx_, kernel_configs)); *runtime_ctx, kernel_configs));
if (!transfered_inplace_vars.empty()) { if (!transfered_inplace_vars.empty()) {
// there is inplace variable has been transfered. // there is inplace variable has been transfered.
......
...@@ -464,6 +464,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -464,6 +464,7 @@ class OperatorWithKernel : public OperatorBase {
// same. // same.
proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const; proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
void RunImpl(const Scope& scope, const platform::Place& place) const final; void RunImpl(const Scope& scope, const platform::Place& place) const final;
RuntimeContext* GetRuntimeContext(const Scope& scope) const;
/** /**
* Transfer data from scope to a transfered scope. If there is no data need to * Transfer data from scope to a transfered scope. If there is no data need to
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册