From 82af8031d9f75d9ad16f844324f9f2634c9b5003 Mon Sep 17 00:00:00 2001 From: luotao1 Date: Tue, 19 Mar 2019 22:59:59 +0800 Subject: [PATCH] add runtime_context_cache_pass test=develop --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../ir/runtime_context_cache_pass.cc | 39 +++++++++++++++++++ .../framework/ir/runtime_context_cache_pass.h | 32 +++++++++++++++ paddle/fluid/framework/operator.cc | 29 ++++++++++---- paddle/fluid/framework/operator.h | 12 ++++++ paddle/fluid/inference/api/analysis_config.cc | 1 + .../inference/api/paddle_pass_builder.cc | 2 + 7 files changed, 109 insertions(+), 7 deletions(-) create mode 100644 paddle/fluid/framework/ir/runtime_context_cache_pass.cc create mode 100644 paddle/fluid/framework/ir/runtime_context_cache_pass.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 8d2cc5ade..a79a53867 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -70,6 +70,7 @@ pass_library(conv_affine_channel_fuse_pass inference) pass_library(transpose_flatten_concat_fuse_pass inference) pass_library(identity_scale_op_clean_pass base) pass_library(sync_batch_norm_pass base) +pass_library(runtime_context_cache_pass base) # There may be many transpose-flatten structures in a model, and the output of # these structures will be used as inputs to the concat Op. This pattern will diff --git a/paddle/fluid/framework/ir/runtime_context_cache_pass.cc b/paddle/fluid/framework/ir/runtime_context_cache_pass.cc new file mode 100644 index 000000000..67b29512c --- /dev/null +++ b/paddle/fluid/framework/ir/runtime_context_cache_pass.cc @@ -0,0 +1,39 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/runtime_context_cache_pass.h" +#include +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +namespace ir { + +std::unique_ptr RuntimeContextCachePass::ApplyImpl( + std::unique_ptr graph) const { + VLOG(3) << "Applies Runtime Context Cache strategy."; + for (const Node* n : graph->Nodes()) { + if (n->IsOp()) { + n->Op()->SetAttr(kEnableCacheRuntimeContext, true); + } + } + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(runtime_context_cache_pass, + paddle::framework::ir::RuntimeContextCachePass); diff --git a/paddle/fluid/framework/ir/runtime_context_cache_pass.h b/paddle/fluid/framework/ir/runtime_context_cache_pass.h new file mode 100644 index 000000000..a6cf1a9ae --- /dev/null +++ b/paddle/fluid/framework/ir/runtime_context_cache_pass.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class RuntimeContextCachePass : public Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 44821aadf..1ba2bed88 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -876,7 +876,22 @@ std::vector* OperatorWithKernel::GetKernelConfig( void OperatorWithKernel::RunImpl(const Scope& scope, const platform::Place& place) const { - RuntimeContext ctx(Inputs(), Outputs(), scope); + if (!HasAttr(kEnableCacheRuntimeContext)) { + RuntimeContext ctx(Inputs(), Outputs(), scope); + RunImpl(scope, place, &ctx); + } else { + const Scope* cur_scope = &scope; + if (!runtime_ctx_ || pre_scope_ != cur_scope) { + runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope)); + pre_scope_ = cur_scope; + } + RunImpl(scope, place, runtime_ctx_.get()); + } +} + +void OperatorWithKernel::RunImpl(const Scope& scope, + const platform::Place& place, + RuntimeContext* runtime_ctx) const { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); @@ -891,7 +906,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, OpKernelMap& kernels = kernels_iter->second; auto expected_kernel_key = this->GetExpectedKernelType( - ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr)); + ExecutionContext(*this, scope, *dev_ctx, *runtime_ctx, nullptr)); VLOG(3) << "expected_kernel_key:" << expected_kernel_key; auto kernel_iter = kernels.find(expected_kernel_key); @@ -915,8 +930,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, // do data transformScope &transfer_scope; std::vector transfered_inplace_vars; - auto* transfer_scope = - PrepareData(scope, expected_kernel_key, &transfered_inplace_vars, &ctx); + auto* transfer_scope = PrepareData(scope, expected_kernel_key, + &transfered_inplace_vars, runtime_ctx); // exec scope is the scope that kernel actually executed on. const Scope& exec_scope = @@ -927,13 +942,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope, } if (!HasAttr(kAllKernelsMustComputeRuntimeShape)) { - RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx); + RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, *runtime_ctx); this->InferShape(&infer_shape_ctx); } // TODO(panyx0718): ExecutionContext should only depend on RuntimeContext // not Scope. Imperative mode only pass inputs and get outputs. - kernel_iter->second( - ExecutionContext(*this, exec_scope, *dev_ctx, ctx, kernel_configs)); + kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx, + *runtime_ctx, kernel_configs)); if (!transfered_inplace_vars.empty()) { // there is inplace variable has been transfered. diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 822bf5c9c..684960c23 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -62,6 +62,14 @@ constexpr char kZeroVarSuffix[] = "@ZERO"; /// Variables with this suffix are the new Gradient. constexpr char kNewGradSuffix[] = "@NEWGRAD@"; +/// RuntimeContext is used to relate input/output names of Operator with +/// the corresponding variables in name scope. +/// If an Op has attribute kEnableCacheRuntimeContext, it means that in a same +/// name scope, since the input/output names of this Op do not change in the +/// execution, RuntimeContext could be created only at the first iteration of +/// this Op's execution to save the elapsed time. +constexpr char kEnableCacheRuntimeContext[] = "@ENABLE_CACHE_RUNTIME_CONTEXT@"; + /// If an Op has this attribute, all its kernels should calculate output /// variable's shape in the corresponding Compute() function. And /// OperatorWithKernel::RunImpl() would skip call this Op's InferShape() @@ -456,6 +464,8 @@ class OperatorWithKernel : public OperatorBase { // same. proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const; void RunImpl(const Scope& scope, const platform::Place& place) const final; + void RunImpl(const Scope& scope, const platform::Place& place, + RuntimeContext* runtime_ctx) const; /** * Transfer data from scope to a transfered scope. If there is no data need to @@ -474,6 +484,8 @@ class OperatorWithKernel : public OperatorBase { protected: mutable OpKernelConfigsMap kernel_configs_map_; + mutable std::unique_ptr runtime_ctx_; + mutable const Scope* pre_scope_ = nullptr; }; extern bool OpSupportGPU(const std::string& op_type); diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 92526f4e7..1be25de49 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -202,6 +202,7 @@ void AnalysisConfig::Update() { // Append after the Affine_channel_conv_fuse pass. pass_builder()->InsertPass(3, "tensorrt_subgraph_pass"); } + pass_builder()->DeletePass("runtime_context_cache_pass"); } if (use_mkldnn_) { diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 92c24647e..22c527cfc 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -80,6 +80,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { "conv_elementwise_add_act_fuse_pass", // "conv_elementwise_add2_act_fuse_pass", // "conv_elementwise_add_fuse_pass", // + "runtime_context_cache_pass", // #endif }); @@ -115,6 +116,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { "conv_eltwiseadd_bn_fuse_pass", // "is_test_pass", // "identity_scale_op_clean_pass", // + "runtime_context_cache_pass", // }); use_gpu_ = false; } -- GitLab