From d94fd972306e2fb237212c341adc4f90f7181c06 Mon Sep 17 00:00:00 2001 From: luotao1 Date: Wed, 13 Mar 2019 17:22:39 +0800 Subject: [PATCH] add runtime_context_cache_pass test=develop --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../ir/runtime_context_cache_pass.cc | 39 +++++++++++++++++++ .../framework/ir/runtime_context_cache_pass.h | 32 +++++++++++++++ paddle/fluid/framework/operator.cc | 20 ++++------ paddle/fluid/framework/operator.h | 8 ++++ paddle/fluid/framework/scope.cc | 4 -- paddle/fluid/framework/scope.h | 4 -- paddle/fluid/inference/api/analysis_config.cc | 8 ++++ .../inference/api/paddle_analysis_config.h | 26 +++++++++++++ .../inference/tests/api/config_printer.h | 3 +- paddle/fluid/pybind/inference_api.cc | 4 ++ 11 files changed, 127 insertions(+), 22 deletions(-) create mode 100644 paddle/fluid/framework/ir/runtime_context_cache_pass.cc create mode 100644 paddle/fluid/framework/ir/runtime_context_cache_pass.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index ca6b0229e..f7d82d5ea 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -66,6 +66,7 @@ pass_library(conv_elementwise_add_fuse_pass inference) pass_library(conv_affine_channel_fuse_pass inference) pass_library(transpose_flatten_concat_fuse_pass inference) pass_library(identity_scale_op_clean_pass base) +pass_library(runtime_context_cache_pass base) # There may be many transpose-flatten structures in a model, and the output of # these structures will be used as inputs to the concat Op. This pattern will diff --git a/paddle/fluid/framework/ir/runtime_context_cache_pass.cc b/paddle/fluid/framework/ir/runtime_context_cache_pass.cc new file mode 100644 index 000000000..75f379518 --- /dev/null +++ b/paddle/fluid/framework/ir/runtime_context_cache_pass.cc @@ -0,0 +1,39 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/runtime_context_cache_pass.h" +#include <memory> +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +namespace ir { + +std::unique_ptr<ir::Graph> RuntimeContextCachePass::ApplyImpl( + std::unique_ptr<ir::Graph> graph) const { + VLOG(3) << "Applies Runtime Context Cache strategy."; + for (const Node* n : graph->Nodes()) { + if (n->IsOp()) { + n->Op()->SetAttr(kEnableRuntimeContext, true); + } + } + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(runtime_context_cache_pass, + paddle::framework::ir::RuntimeContextCachePass); diff --git a/paddle/fluid/framework/ir/runtime_context_cache_pass.h b/paddle/fluid/framework/ir/runtime_context_cache_pass.h new file mode 100644 index 000000000..a6cf1a9ae --- /dev/null +++ b/paddle/fluid/framework/ir/runtime_context_cache_pass.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include <memory> +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class RuntimeContextCachePass : public Pass { + protected: + std::unique_ptr<ir::Graph> ApplyImpl( + std::unique_ptr<ir::Graph> graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index eeced516e..980c5d858 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -20,7 +20,6 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/data_transform.h" -#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_proto_maker.h" @@ -877,19 +876,14 @@ std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig( void OperatorWithKernel::RunImpl(const Scope& scope, const platform::Place& place) const { - const Scope* cur_scope = &scope; - // RuntimeContext is used to relate input/output names of Operator with - // the corresponding variables in Scope. - // In a same Scope, since the input/output names of Operator do not change - // in the execution, RuntimeContext could be created only at the first - // iteration of the execution to save the elapsed time. - // Note that the Scope should not be the local scope, since local scope - // would be cleaned regularly. 
- if (scope.FindVar(details::kLocalExecScopeName)) { + if (!HasAttr(kEnableRuntimeContext)) { runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope)); - } else if (!runtime_ctx_ || pre_scope_ != cur_scope) { - runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope)); - pre_scope_ = cur_scope; + } else { + const Scope* cur_scope = &scope; + if (!runtime_ctx_ || pre_scope_ != cur_scope) { + runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope)); + pre_scope_ = cur_scope; + } } platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 6a2d4478a..29b9c45cc 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -62,6 +62,14 @@ constexpr char kZeroVarSuffix[] = "@ZERO"; /// Variables with this suffix are the new Gradient. constexpr char kNewGradSuffix[] = "@NEWGRAD@"; +/// RuntimeContext is used to relate input/output names of Operator with +/// the corresponding variables in Scope. +/// If an Op has attribute kEnableRuntimeContext, it means that in a same Scope, +/// since the input/output names of this Op do not change in the execution, +/// RuntimeContext could be created only at the first iteration of this Op's +/// execution to save the elapsed time. 
+constexpr char kEnableRuntimeContext[] = "@ENABLE_RUNTIME_CONTEXT@"; + // define some kernel priority /* Define multiple kernel type fallback order*/ extern std::vector> kKernelPriority; diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index e6de47717..87f0f307d 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -107,10 +107,6 @@ const Scope* Scope::FindScope(const Variable* var) const { return FindScopeInternal(var); } -bool Scope::HasLocalVar(const std::string& name) const { - return vars_.find(name) != vars_.end(); -} - void Scope::DropKids() { SCOPE_KIDS_WRITER_LOCK for (Scope* s : kids_) delete s; diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index 38d3b4d6c..f0915d2ee 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -75,10 +75,6 @@ class Scope { /// Caller doesn't own the returned Variable. Variable* FindLocalVar(const std::string& name) const; - /// Find whether a variable in the current scope. - /// Return false if cannot find. - bool HasLocalVar(const std::string& name) const; - const Scope* parent() const { return parent_; } /// Find the scope or an ancestor scope that contains the given variable. diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 774111122..a9e477f88 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -118,6 +118,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { CP_MEMBER(serialized_info_cache_); + // framework related. 
+ CP_MEMBER(enable_runtime_context_cache_); + if (use_gpu_) { pass_builder_.reset(new GpuPassStrategy( *static_cast<GpuPassStrategy *>(other.pass_builder()))); @@ -225,6 +228,10 @@ void AnalysisConfig::Update() { if (ir_debug_) { pass_builder()->TurnOnDebug(); } + + if (enable_runtime_context_cache_) { + pass_builder()->AppendPass("runtime_context_cache_pass"); + } } std::string AnalysisConfig::SerializeInfoCache() { @@ -258,6 +265,7 @@ std::string AnalysisConfig::SerializeInfoCache() { ss << specify_input_name_; ss << cpu_math_library_num_threads_; + ss << enable_runtime_context_cache_; return ss.str(); } diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index 9b05c3350..85639eebe 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -194,6 +194,23 @@ struct AnalysisConfig { /** Tell whether the memory optimization is activated. */ bool enable_memory_optim() const; + // framework related + /** \brief Control whether to perform runtime context cache optimization. + * + * If turned off, in Op's every execution, RuntimeContext would be called to + * relate input/output names of this Op with the corresponding variables in + * Scope. + */ + void SwitchRuntimeContextCache(bool x = true) { + enable_runtime_context_cache_ = x; + } + /** A boolean state telling whether the runtime context cache optimization is + * activated. + */ + bool runtime_context_cache_enabled() const { + return enable_runtime_context_cache_; + } + + friend class ::paddle::AnalysisPredictor; /** NOTE just for developer, not an official API, easily to be broken. @@ -254,6 +271,15 @@ struct AnalysisConfig { int cpu_math_library_num_threads_{1}; + // framework related + // RuntimeContext is used to relate input/output names of Operator with + // the corresponding variables in Scope. 
+ // If enable_runtime_context_cache_ is true, it means that in a same Scope, + // since the input/output names of this Op do not change in the execution, + // RuntimeContext could be created only at the first iteration of this Op's + // execution to save the elapsed time. + bool enable_runtime_context_cache_{true}; + // A runtime cache, shouldn't be transferred to others. std::string serialized_info_cache_; diff --git a/paddle/fluid/inference/tests/api/config_printer.h b/paddle/fluid/inference/tests/api/config_printer.h index b0c23fbd5..b7b39d4dd 100644 --- a/paddle/fluid/inference/tests/api/config_printer.h +++ b/paddle/fluid/inference/tests/api/config_printer.h @@ -72,7 +72,8 @@ std::ostream &operator<<(std::ostream &os, const AnalysisConfig &config) { } os << GenSpaces(num_spaces) << "enable_ir_optim: " << config.ir_optim() << "\n"; - os << GenSpaces(num_spaces) << "enable_ir_optim: " << config.ir_optim() + os << GenSpaces(num_spaces) + << "use_runtime_context_cache: " << config.runtime_context_cache_enabled() << "\n"; os << GenSpaces(num_spaces) << "use_feed_fetch_ops: " << config.use_feed_fetch_ops_enabled() << "\n"; diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 236afc77f..11e9725ae 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -242,6 +242,10 @@ void BindAnalysisConfig(py::module *m) { .def("set_mkldnn_op", &AnalysisConfig::SetMKLDNNOp) .def("set_model_buffer", &AnalysisConfig::SetModelBuffer) .def("model_from_memory", &AnalysisConfig::model_from_memory) + .def("runtime_context_cache_enabled", + &AnalysisConfig::runtime_context_cache_enabled) + .def("switch_runtime_context_cache", + &AnalysisConfig::SwitchRuntimeContextCache, py::arg("x") = true) .def("pass_builder", &AnalysisConfig::pass_builder, py::return_value_policy::reference); } -- GitLab