提交 9773f38f 编写于 作者: L luotao1

cache runtime_context

test=develop
上级 187cffd0
...@@ -916,7 +916,14 @@ std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig( ...@@ -916,7 +916,14 @@ std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig(
void OperatorWithKernel::RunImpl(const Scope& scope, void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
RuntimeContext ctx(Inputs(), Outputs(), scope); if (!runtime_ctx_) {
// RuntimeContext relates the input/output names of an Operator to
// the corresponding variables in Scope.
// Since the input/output names of an Operator do not change during
// execution, RuntimeContext can be created once, on the first iteration
// of the execution, and reused afterwards to save time.
runtime_ctx_ = new RuntimeContext(Inputs(), Outputs(), scope);
}
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
...@@ -931,7 +938,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -931,7 +938,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
OpKernelMap& kernels = kernels_iter->second; OpKernelMap& kernels = kernels_iter->second;
auto expected_kernel_key = this->GetExpectedKernelType( auto expected_kernel_key = this->GetExpectedKernelType(
ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr)); ExecutionContext(*this, scope, *dev_ctx, *runtime_ctx_, nullptr));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key; VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
...@@ -955,8 +962,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -955,8 +962,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// do data transform // do data transform
std::vector<std::string> transfered_inplace_vars; std::vector<std::string> transfered_inplace_vars;
auto* transfer_scope = auto* transfer_scope = PrepareData(scope, expected_kernel_key,
PrepareData(scope, expected_kernel_key, &transfered_inplace_vars, &ctx); &transfered_inplace_vars, runtime_ctx_);
// exec scope is the scope that kernel actually executed on. // exec scope is the scope that kernel actually executed on.
const Scope& exec_scope = const Scope& exec_scope =
...@@ -966,12 +973,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -966,12 +973,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
dev_ctx = pool.Get(expected_kernel_key.place_); dev_ctx = pool.Get(expected_kernel_key.place_);
} }
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx); RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, *runtime_ctx_);
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
// TODO(panyx0718): ExecutionContext should only depend on RuntimeContext // TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
// not Scope. Imperative mode only pass inputs and get outputs. // not Scope. Imperative mode only pass inputs and get outputs.
kernel_iter->second( kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx,
ExecutionContext(*this, exec_scope, *dev_ctx, ctx, kernel_configs)); *runtime_ctx_, kernel_configs));
if (!transfered_inplace_vars.empty()) { if (!transfered_inplace_vars.empty()) {
// an inplace variable has been transferred. // an inplace variable has been transferred.
......
...@@ -541,6 +541,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -541,6 +541,7 @@ class OperatorWithKernel : public OperatorBase {
protected: protected:
mutable OpKernelConfigsMap kernel_configs_map_; mutable OpKernelConfigsMap kernel_configs_map_;
mutable RuntimeContext* runtime_ctx_ = nullptr;
}; };
extern bool OpSupportGPU(const std::string& op_type); extern bool OpSupportGPU(const std::string& op_type);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册