提交 d94fd972 编写于 作者: L luotao1

add runtime_context_cache_pass

test=develop
上级 b561ad1e
...@@ -66,6 +66,7 @@ pass_library(conv_elementwise_add_fuse_pass inference) ...@@ -66,6 +66,7 @@ pass_library(conv_elementwise_add_fuse_pass inference)
pass_library(conv_affine_channel_fuse_pass inference) pass_library(conv_affine_channel_fuse_pass inference)
pass_library(transpose_flatten_concat_fuse_pass inference) pass_library(transpose_flatten_concat_fuse_pass inference)
pass_library(identity_scale_op_clean_pass base) pass_library(identity_scale_op_clean_pass base)
pass_library(runtime_context_cache_pass base)
# There may be many transpose-flatten structures in a model, and the output of # There may be many transpose-flatten structures in a model, and the output of
# these structures will be used as inputs to the concat Op. This pattern will # these structures will be used as inputs to the concat Op. This pattern will
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/runtime_context_cache_pass.h"
#include <memory>
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> RuntimeContextCachePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
VLOG(3) << "Applies Runtime Context Cache strategy.";
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
n->Op()->SetAttr(kEnableRuntimeContext, true);
}
}
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(runtime_context_cache_pass,
paddle::framework::ir::RuntimeContextCachePass);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class RuntimeContextCachePass : public Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -20,7 +20,6 @@ limitations under the License. */ ...@@ -20,7 +20,6 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_transform.h" #include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_proto_maker.h"
...@@ -877,19 +876,14 @@ std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig( ...@@ -877,19 +876,14 @@ std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig(
void OperatorWithKernel::RunImpl(const Scope& scope, void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
const Scope* cur_scope = &scope; if (!HasAttr(kEnableRuntimeContext)) {
// RuntimeContext is used to relate input/output names of Operator with
// the corresponding variables in Scope.
// In a same Scope, since the input/output names of Operator do not change
// in the execution, RuntimeContext could be created only at the first
// iteration of the execution to save the elapsed time.
// Note that the Scope should not be the local scope, since local scope
// would be cleaned regularly.
if (scope.FindVar(details::kLocalExecScopeName)) {
runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope)); runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope));
} else if (!runtime_ctx_ || pre_scope_ != cur_scope) { } else {
runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope)); const Scope* cur_scope = &scope;
pre_scope_ = cur_scope; if (!runtime_ctx_ || pre_scope_ != cur_scope) {
runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope));
pre_scope_ = cur_scope;
}
} }
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
......
...@@ -62,6 +62,14 @@ constexpr char kZeroVarSuffix[] = "@ZERO"; ...@@ -62,6 +62,14 @@ constexpr char kZeroVarSuffix[] = "@ZERO";
/// Variables with this suffix are the new Gradient. /// Variables with this suffix are the new Gradient.
constexpr char kNewGradSuffix[] = "@NEWGRAD@"; constexpr char kNewGradSuffix[] = "@NEWGRAD@";
/// RuntimeContext is used to relate input/output names of Operator with
/// the corresponding variables in Scope.
/// If an Op has attribute kEnableRuntimeContext, it means that in a same Scope,
/// since the input/output names of this Op do not change in the execution,
/// RuntimeContext could be created only at the first iteration of this Op's
/// execution to save the elapsed time.
constexpr char kEnableRuntimeContext[] = "@ENABLE_RUNTIME_CONTEXT@";
// define some kernel priority // define some kernel priority
/* Define multiple kernel type fallback order*/ /* Define multiple kernel type fallback order*/
extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority; extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;
......
...@@ -107,10 +107,6 @@ const Scope* Scope::FindScope(const Variable* var) const { ...@@ -107,10 +107,6 @@ const Scope* Scope::FindScope(const Variable* var) const {
return FindScopeInternal(var); return FindScopeInternal(var);
} }
bool Scope::HasLocalVar(const std::string& name) const {
return vars_.find(name) != vars_.end();
}
void Scope::DropKids() { void Scope::DropKids() {
SCOPE_KIDS_WRITER_LOCK SCOPE_KIDS_WRITER_LOCK
for (Scope* s : kids_) delete s; for (Scope* s : kids_) delete s;
......
...@@ -75,10 +75,6 @@ class Scope { ...@@ -75,10 +75,6 @@ class Scope {
/// Caller doesn't own the returned Variable. /// Caller doesn't own the returned Variable.
Variable* FindLocalVar(const std::string& name) const; Variable* FindLocalVar(const std::string& name) const;
/// Find whether a variable in the current scope.
/// Return false if cannot find.
bool HasLocalVar(const std::string& name) const;
const Scope* parent() const { return parent_; } const Scope* parent() const { return parent_; }
/// Find the scope or an ancestor scope that contains the given variable. /// Find the scope or an ancestor scope that contains the given variable.
......
...@@ -118,6 +118,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { ...@@ -118,6 +118,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(serialized_info_cache_); CP_MEMBER(serialized_info_cache_);
// framework related.
CP_MEMBER(enable_runtime_context_cache_);
if (use_gpu_) { if (use_gpu_) {
pass_builder_.reset(new GpuPassStrategy( pass_builder_.reset(new GpuPassStrategy(
*static_cast<GpuPassStrategy *>(other.pass_builder()))); *static_cast<GpuPassStrategy *>(other.pass_builder())));
...@@ -225,6 +228,10 @@ void AnalysisConfig::Update() { ...@@ -225,6 +228,10 @@ void AnalysisConfig::Update() {
if (ir_debug_) { if (ir_debug_) {
pass_builder()->TurnOnDebug(); pass_builder()->TurnOnDebug();
} }
if (enable_runtime_context_cache_) {
pass_builder()->AppendPass("runtime_context_cache_pass");
}
} }
std::string AnalysisConfig::SerializeInfoCache() { std::string AnalysisConfig::SerializeInfoCache() {
...@@ -258,6 +265,7 @@ std::string AnalysisConfig::SerializeInfoCache() { ...@@ -258,6 +265,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << specify_input_name_; ss << specify_input_name_;
ss << cpu_math_library_num_threads_; ss << cpu_math_library_num_threads_;
ss << enable_runtime_context_cache_;
return ss.str(); return ss.str();
} }
......
...@@ -194,6 +194,23 @@ struct AnalysisConfig { ...@@ -194,6 +194,23 @@ struct AnalysisConfig {
/** Tell whether the memory optimization is activated. */ /** Tell whether the memory optimization is activated. */
bool enable_memory_optim() const; bool enable_memory_optim() const;
// framework related
/** \brief Control whether to perform runtime context cache optimization.
*
* If turned off, in Op's every execution, RuntimeContext would be called to
* relate input/output names of this Op with the corresponding variables in
* Scope.
*/
void SwitchRuntimeContextCache(int x = true) {
enable_runtime_context_cache_ = x;
}
/** A boolean state tell whether the runtime context cache optimization is
* actived.
*/
bool runtime_context_cache_enabled() const {
return enable_runtime_context_cache_;
}
friend class ::paddle::AnalysisPredictor; friend class ::paddle::AnalysisPredictor;
/** NOTE just for developer, not an official API, easily to be broken. /** NOTE just for developer, not an official API, easily to be broken.
...@@ -254,6 +271,15 @@ struct AnalysisConfig { ...@@ -254,6 +271,15 @@ struct AnalysisConfig {
int cpu_math_library_num_threads_{1}; int cpu_math_library_num_threads_{1};
// framework related
// RuntimeContext is used to relate input/output names of Operator with
// the corresponding variables in Scope.
// If enable_runtime_context_cache_ is true, it means that in a same Scope,
// since the input/output names of this Op do not change in the execution,
// RuntimeContext could be created only at the first iteration of this Op's
// execution to save the elapsed time.
bool enable_runtime_context_cache_{true};
// A runtime cache, shouldn't be transferred to others. // A runtime cache, shouldn't be transferred to others.
std::string serialized_info_cache_; std::string serialized_info_cache_;
......
...@@ -72,7 +72,8 @@ std::ostream &operator<<(std::ostream &os, const AnalysisConfig &config) { ...@@ -72,7 +72,8 @@ std::ostream &operator<<(std::ostream &os, const AnalysisConfig &config) {
} }
os << GenSpaces(num_spaces) << "enable_ir_optim: " << config.ir_optim() os << GenSpaces(num_spaces) << "enable_ir_optim: " << config.ir_optim()
<< "\n"; << "\n";
os << GenSpaces(num_spaces) << "enable_ir_optim: " << config.ir_optim() os << GenSpaces(num_spaces)
<< "use_runtime_context_cache: " << config.runtime_context_cache_enabled()
<< "\n"; << "\n";
os << GenSpaces(num_spaces) os << GenSpaces(num_spaces)
<< "use_feed_fetch_ops: " << config.use_feed_fetch_ops_enabled() << "\n"; << "use_feed_fetch_ops: " << config.use_feed_fetch_ops_enabled() << "\n";
......
...@@ -242,6 +242,10 @@ void BindAnalysisConfig(py::module *m) { ...@@ -242,6 +242,10 @@ void BindAnalysisConfig(py::module *m) {
.def("set_mkldnn_op", &AnalysisConfig::SetMKLDNNOp) .def("set_mkldnn_op", &AnalysisConfig::SetMKLDNNOp)
.def("set_model_buffer", &AnalysisConfig::SetModelBuffer) .def("set_model_buffer", &AnalysisConfig::SetModelBuffer)
.def("model_from_memory", &AnalysisConfig::model_from_memory) .def("model_from_memory", &AnalysisConfig::model_from_memory)
.def("runtime_context_cache_enabled",
&AnalysisConfig::runtime_context_cache_enabled)
.def("switch_runtime_context_cache",
&AnalysisConfig::SwitchRuntimeContextCache, py::arg("x") = true)
.def("pass_builder", &AnalysisConfig::pass_builder, .def("pass_builder", &AnalysisConfig::pass_builder,
py::return_value_policy::reference); py::return_value_policy::reference);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册