提交 784826a4 编写于 作者: L luotao1

enhance cache runtime_context for different scope

test=develop
上级 2fb38c10
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_transform.h" #include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_proto_maker.h"
...@@ -917,13 +918,18 @@ std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig( ...@@ -917,13 +918,18 @@ std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig(
void OperatorWithKernel::RunImpl(const Scope& scope, void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
if (!runtime_ctx_) { const Scope* cur_scope = &scope;
if (!runtime_ctx_ || pre_scope_ != cur_scope ||
scope.FindVar(details::kLocalExecScopeName)) {
// RuntimeContext is used to relate input/output names of Operator with // RuntimeContext is used to relate input/output names of Operator with
// the corresponding variables in Scope. // the corresponding variables in Scope.
// Since the input/output names of Operator do not change in the execution, // In a same Scope, since the input/output names of Operator do not change
// RuntimeContext could be created only at the first iteration of // in the execution, RuntimeContext could be created only at the first
// the execution to save the elapsed time. // iteration of the execution to save the elapsed time.
runtime_ctx_ = new RuntimeContext(Inputs(), Outputs(), scope); // Note that the Scope should not be the local scope, since local scope
// would be cleaned regularly.
runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope));
pre_scope_ = cur_scope;
} }
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
...@@ -963,8 +969,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -963,8 +969,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// do data transformScope &transfer_scope; // do data transformScope &transfer_scope;
std::vector<std::string> transfered_inplace_vars; std::vector<std::string> transfered_inplace_vars;
auto* transfer_scope = PrepareData(scope, expected_kernel_key, auto* transfer_scope = PrepareData(
&transfered_inplace_vars, runtime_ctx_); scope, expected_kernel_key, &transfered_inplace_vars, runtime_ctx_.get());
// exec scope is the scope that kernel actually executed on. // exec scope is the scope that kernel actually executed on.
const Scope& exec_scope = const Scope& exec_scope =
......
...@@ -543,7 +543,8 @@ class OperatorWithKernel : public OperatorBase { ...@@ -543,7 +543,8 @@ class OperatorWithKernel : public OperatorBase {
protected: protected:
mutable OpKernelConfigsMap kernel_configs_map_; mutable OpKernelConfigsMap kernel_configs_map_;
mutable RuntimeContext* runtime_ctx_ = nullptr; mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
mutable const Scope* pre_scope_ = nullptr;
}; };
extern bool OpSupportGPU(const std::string& op_type); extern bool OpSupportGPU(const std::string& op_type);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册