You need to sign in or sign up before continuing.
提交 784826a4 编写于 作者: L luotao1

enhance cache runtime_context for different scope

test=develop
上级 2fb38c10
......@@ -20,6 +20,7 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_proto_maker.h"
......@@ -917,13 +918,18 @@ std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig(
void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const {
if (!runtime_ctx_) {
const Scope* cur_scope = &scope;
if (!runtime_ctx_ || pre_scope_ != cur_scope ||
scope.FindVar(details::kLocalExecScopeName)) {
// RuntimeContext is used to relate input/output names of Operator with
// the corresponding variables in Scope.
// Since the input/output names of Operator do not change in the execution,
// RuntimeContext could be created only at the first iteration of
// the execution to save the elapsed time.
runtime_ctx_ = new RuntimeContext(Inputs(), Outputs(), scope);
// In a same Scope, since the input/output names of Operator do not change
// in the execution, RuntimeContext could be created only at the first
// iteration of the execution to save the elapsed time.
// Note that the Scope should not be the local scope, since local scope
// would be cleaned regularly.
runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope));
pre_scope_ = cur_scope;
}
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
......@@ -963,8 +969,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// do data transformScope &transfer_scope;
std::vector<std::string> transfered_inplace_vars;
auto* transfer_scope = PrepareData(scope, expected_kernel_key,
&transfered_inplace_vars, runtime_ctx_);
auto* transfer_scope = PrepareData(
scope, expected_kernel_key, &transfered_inplace_vars, runtime_ctx_.get());
// exec scope is the scope that kernel actually executed on.
const Scope& exec_scope =
......
......@@ -543,7 +543,8 @@ class OperatorWithKernel : public OperatorBase {
protected:
mutable OpKernelConfigsMap kernel_configs_map_;
mutable RuntimeContext* runtime_ctx_ = nullptr;
mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
mutable const Scope* pre_scope_ = nullptr;
};
extern bool OpSupportGPU(const std::string& op_type);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册