未验证 提交 dbb92ee4 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #16002 from luotao1/runtime_context

cache runtime_context
...@@ -69,6 +69,7 @@ pass_library(conv_affine_channel_fuse_pass inference) ...@@ -69,6 +69,7 @@ pass_library(conv_affine_channel_fuse_pass inference)
pass_library(transpose_flatten_concat_fuse_pass inference) pass_library(transpose_flatten_concat_fuse_pass inference)
pass_library(identity_scale_op_clean_pass base) pass_library(identity_scale_op_clean_pass base)
pass_library(sync_batch_norm_pass base) pass_library(sync_batch_norm_pass base)
pass_library(runtime_context_cache_pass base)
# There may be many transpose-flatten structures in a model, and the output of # There may be many transpose-flatten structures in a model, and the output of
# these structures will be used as inputs to the concat Op. This pattern will # these structures will be used as inputs to the concat Op. This pattern will
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/runtime_context_cache_pass.h"
#include <memory>
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> RuntimeContextCachePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
VLOG(3) << "Applies Runtime Context Cache strategy.";
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
n->Op()->SetAttr(kEnableCacheRuntimeContext, true);
}
}
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(runtime_context_cache_pass,
paddle::framework::ir::RuntimeContextCachePass);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class RuntimeContextCachePass : public Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -874,9 +874,23 @@ std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig( ...@@ -874,9 +874,23 @@ std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig(
return kernel_configs; return kernel_configs;
} }
RuntimeContext* OperatorWithKernel::GetRuntimeContext(
const Scope& scope) const {
if (!HasAttr(kEnableCacheRuntimeContext)) {
return new RuntimeContext(Inputs(), Outputs(), scope);
} else {
const Scope* cur_scope = &scope;
if (!runtime_ctx_ || pre_scope_ != cur_scope) {
runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope));
pre_scope_ = cur_scope;
}
return runtime_ctx_.get();
}
}
void OperatorWithKernel::RunImpl(const Scope& scope, void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
RuntimeContext ctx(Inputs(), Outputs(), scope); auto runtime_ctx = GetRuntimeContext(scope);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
...@@ -891,7 +905,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -891,7 +905,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
OpKernelMap& kernels = kernels_iter->second; OpKernelMap& kernels = kernels_iter->second;
auto expected_kernel_key = this->GetExpectedKernelType( auto expected_kernel_key = this->GetExpectedKernelType(
ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr)); ExecutionContext(*this, scope, *dev_ctx, *runtime_ctx, nullptr));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key; VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
...@@ -915,8 +929,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -915,8 +929,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// do data transformScope &transfer_scope; // do data transformScope &transfer_scope;
std::vector<std::string> transfered_inplace_vars; std::vector<std::string> transfered_inplace_vars;
auto* transfer_scope = auto* transfer_scope = PrepareData(scope, expected_kernel_key,
PrepareData(scope, expected_kernel_key, &transfered_inplace_vars, &ctx); &transfered_inplace_vars, runtime_ctx);
// exec scope is the scope that kernel actually executed on. // exec scope is the scope that kernel actually executed on.
const Scope& exec_scope = const Scope& exec_scope =
...@@ -927,13 +941,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -927,13 +941,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
} }
if (!HasAttr(kAllKernelsMustComputeRuntimeShape)) { if (!HasAttr(kAllKernelsMustComputeRuntimeShape)) {
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx); RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, *runtime_ctx);
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
} }
// TODO(panyx0718): ExecutionContext should only depend on RuntimeContext // TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
// not Scope. Imperative mode only pass inputs and get outputs. // not Scope. Imperative mode only pass inputs and get outputs.
kernel_iter->second( kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx,
ExecutionContext(*this, exec_scope, *dev_ctx, ctx, kernel_configs)); *runtime_ctx, kernel_configs));
if (!transfered_inplace_vars.empty()) { if (!transfered_inplace_vars.empty()) {
// there is inplace variable has been transfered. // there is inplace variable has been transfered.
......
...@@ -62,6 +62,14 @@ constexpr char kZeroVarSuffix[] = "@ZERO"; ...@@ -62,6 +62,14 @@ constexpr char kZeroVarSuffix[] = "@ZERO";
/// Variables with this suffix are the new Gradient. /// Variables with this suffix are the new Gradient.
constexpr char kNewGradSuffix[] = "@NEWGRAD@"; constexpr char kNewGradSuffix[] = "@NEWGRAD@";
/// RuntimeContext is used to relate input/output names of Operator with
/// the corresponding variables in name scope.
/// If an Op has attribute kEnableCacheRuntimeContext, it means that in a same
/// name scope, since the input/output names of this Op do not change in the
/// execution, RuntimeContext could be created only at the first iteration of
/// this Op's execution to save the elapsed time.
constexpr char kEnableCacheRuntimeContext[] = "@ENABLE_CACHE_RUNTIME_CONTEXT@";
/// If an Op has this attribute, all its kernels should calculate output /// If an Op has this attribute, all its kernels should calculate output
/// variable's shape in the corresponding Compute() function. And /// variable's shape in the corresponding Compute() function. And
/// OperatorWithKernel::RunImpl() would skip call this Op's InferShape() /// OperatorWithKernel::RunImpl() would skip call this Op's InferShape()
...@@ -456,6 +464,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -456,6 +464,7 @@ class OperatorWithKernel : public OperatorBase {
// same. // same.
proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const; proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
void RunImpl(const Scope& scope, const platform::Place& place) const final; void RunImpl(const Scope& scope, const platform::Place& place) const final;
RuntimeContext* GetRuntimeContext(const Scope& scope) const;
/** /**
* Transfer data from scope to a transfered scope. If there is no data need to * Transfer data from scope to a transfered scope. If there is no data need to
...@@ -474,6 +483,8 @@ class OperatorWithKernel : public OperatorBase { ...@@ -474,6 +483,8 @@ class OperatorWithKernel : public OperatorBase {
protected: protected:
mutable OpKernelConfigsMap kernel_configs_map_; mutable OpKernelConfigsMap kernel_configs_map_;
mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
mutable const Scope* pre_scope_ = nullptr;
}; };
extern bool OpSupportGPU(const std::string& op_type); extern bool OpSupportGPU(const std::string& op_type);
......
...@@ -118,6 +118,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { ...@@ -118,6 +118,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(serialized_info_cache_); CP_MEMBER(serialized_info_cache_);
// framework related.
CP_MEMBER(enable_runtime_context_cache_);
if (use_gpu_) { if (use_gpu_) {
pass_builder_.reset(new GpuPassStrategy( pass_builder_.reset(new GpuPassStrategy(
*static_cast<GpuPassStrategy *>(other.pass_builder()))); *static_cast<GpuPassStrategy *>(other.pass_builder())));
...@@ -232,6 +235,10 @@ void AnalysisConfig::Update() { ...@@ -232,6 +235,10 @@ void AnalysisConfig::Update() {
if (ir_debug_) { if (ir_debug_) {
pass_builder()->TurnOnDebug(); pass_builder()->TurnOnDebug();
} }
if (enable_runtime_context_cache_) {
pass_builder()->AppendPass("runtime_context_cache_pass");
}
} }
std::string AnalysisConfig::SerializeInfoCache() { std::string AnalysisConfig::SerializeInfoCache() {
...@@ -265,6 +272,7 @@ std::string AnalysisConfig::SerializeInfoCache() { ...@@ -265,6 +272,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << specify_input_name_; ss << specify_input_name_;
ss << cpu_math_library_num_threads_; ss << cpu_math_library_num_threads_;
ss << enable_runtime_context_cache_;
return ss.str(); return ss.str();
} }
......
...@@ -194,6 +194,23 @@ struct AnalysisConfig { ...@@ -194,6 +194,23 @@ struct AnalysisConfig {
/** Tell whether the memory optimization is activated. */ /** Tell whether the memory optimization is activated. */
bool enable_memory_optim() const; bool enable_memory_optim() const;
// framework related
/** \brief Control whether to perform runtime context cache optimization.
*
* If turned off, in Op's every execution, RuntimeContext would be called to
* relate input/output names of this Op with the corresponding variables in
* Scope.
*/
void SwitchRuntimeContextCache(int x = true) {
enable_runtime_context_cache_ = x;
}
/** A boolean state tell whether the runtime context cache optimization is
* actived.
*/
bool runtime_context_cache_enabled() const {
return enable_runtime_context_cache_;
}
friend class ::paddle::AnalysisPredictor; friend class ::paddle::AnalysisPredictor;
/** NOTE just for developer, not an official API, easily to be broken. /** NOTE just for developer, not an official API, easily to be broken.
...@@ -254,6 +271,15 @@ struct AnalysisConfig { ...@@ -254,6 +271,15 @@ struct AnalysisConfig {
int cpu_math_library_num_threads_{1}; int cpu_math_library_num_threads_{1};
// framework related
// RuntimeContext is used to relate input/output names of Operator with
// the corresponding variables in Scope.
// If enable_runtime_context_cache_ is true, it means that in a same Scope,
// since the input/output names of this Op do not change in the execution,
// RuntimeContext could be created only at the first iteration of this Op's
// execution to save the elapsed time.
bool enable_runtime_context_cache_{false};
// A runtime cache, shouldn't be transferred to others. // A runtime cache, shouldn't be transferred to others.
std::string serialized_info_cache_; std::string serialized_info_cache_;
......
...@@ -107,6 +107,7 @@ void SetConfig(AnalysisConfig *cfg) { ...@@ -107,6 +107,7 @@ void SetConfig(AnalysisConfig *cfg) {
cfg->DisableGpu(); cfg->DisableGpu();
cfg->SwitchSpecifyInputNames(); cfg->SwitchSpecifyInputNames();
cfg->SwitchIrOptim(); cfg->SwitchIrOptim();
cfg->SwitchRuntimeContextCache();
if (FLAGS_zero_copy) { if (FLAGS_zero_copy) {
cfg->SwitchUseFeedFetchOps(false); cfg->SwitchUseFeedFetchOps(false);
} }
......
...@@ -72,7 +72,8 @@ std::ostream &operator<<(std::ostream &os, const AnalysisConfig &config) { ...@@ -72,7 +72,8 @@ std::ostream &operator<<(std::ostream &os, const AnalysisConfig &config) {
} }
os << GenSpaces(num_spaces) << "enable_ir_optim: " << config.ir_optim() os << GenSpaces(num_spaces) << "enable_ir_optim: " << config.ir_optim()
<< "\n"; << "\n";
os << GenSpaces(num_spaces) << "enable_ir_optim: " << config.ir_optim() os << GenSpaces(num_spaces)
<< "use_runtime_context_cache: " << config.runtime_context_cache_enabled()
<< "\n"; << "\n";
os << GenSpaces(num_spaces) os << GenSpaces(num_spaces)
<< "use_feed_fetch_ops: " << config.use_feed_fetch_ops_enabled() << "\n"; << "use_feed_fetch_ops: " << config.use_feed_fetch_ops_enabled() << "\n";
......
...@@ -242,6 +242,10 @@ void BindAnalysisConfig(py::module *m) { ...@@ -242,6 +242,10 @@ void BindAnalysisConfig(py::module *m) {
.def("set_mkldnn_op", &AnalysisConfig::SetMKLDNNOp) .def("set_mkldnn_op", &AnalysisConfig::SetMKLDNNOp)
.def("set_model_buffer", &AnalysisConfig::SetModelBuffer) .def("set_model_buffer", &AnalysisConfig::SetModelBuffer)
.def("model_from_memory", &AnalysisConfig::model_from_memory) .def("model_from_memory", &AnalysisConfig::model_from_memory)
.def("runtime_context_cache_enabled",
&AnalysisConfig::runtime_context_cache_enabled)
.def("switch_runtime_context_cache",
&AnalysisConfig::SwitchRuntimeContextCache, py::arg("x") = true)
.def("pass_builder", &AnalysisConfig::pass_builder, .def("pass_builder", &AnalysisConfig::pass_builder,
py::return_value_policy::reference); py::return_value_policy::reference);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册