From 056599a738a2634c163dea044b188f7770501182 Mon Sep 17 00:00:00 2001 From: luotao1 Date: Thu, 21 Mar 2019 11:09:43 +0800 Subject: [PATCH] add expected_kernel_cache_pass test=develop --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../ir/expected_kernel_cache_pass.cc | 39 +++++++++++++++++++ .../framework/ir/expected_kernel_cache_pass.h | 32 +++++++++++++++ paddle/fluid/framework/operator.cc | 2 +- paddle/fluid/framework/operator.h | 6 +++ paddle/fluid/inference/api/analysis_config.cc | 1 + .../inference/api/paddle_pass_builder.cc | 2 + 7 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/framework/ir/expected_kernel_cache_pass.cc create mode 100644 paddle/fluid/framework/ir/expected_kernel_cache_pass.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index a79a53867..db9d2f6d8 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -71,6 +71,7 @@ pass_library(transpose_flatten_concat_fuse_pass inference) pass_library(identity_scale_op_clean_pass base) pass_library(sync_batch_norm_pass base) pass_library(runtime_context_cache_pass base) +pass_library(expected_kernel_cache_pass base) # There may be many transpose-flatten structures in a model, and the output of # these structures will be used as inputs to the concat Op. This pattern will diff --git a/paddle/fluid/framework/ir/expected_kernel_cache_pass.cc b/paddle/fluid/framework/ir/expected_kernel_cache_pass.cc new file mode 100644 index 000000000..28db7dc55 --- /dev/null +++ b/paddle/fluid/framework/ir/expected_kernel_cache_pass.cc @@ -0,0 +1,39 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/expected_kernel_cache_pass.h" +#include +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +namespace ir { + +std::unique_ptr ExpectedKernelCachePass::ApplyImpl( + std::unique_ptr graph) const { + VLOG(3) << "Applies Expected Kernel Cache strategy."; + for (const Node* n : graph->Nodes()) { + if (n->IsOp()) { + n->Op()->SetAttr(kEnableCacheExpectedKernel, true); + } + } + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(expected_kernel_cache_pass, + paddle::framework::ir::ExpectedKernelCachePass); diff --git a/paddle/fluid/framework/ir/expected_kernel_cache_pass.h b/paddle/fluid/framework/ir/expected_kernel_cache_pass.h new file mode 100644 index 000000000..738eb84fd --- /dev/null +++ b/paddle/fluid/framework/ir/expected_kernel_cache_pass.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class ExpectedKernelCachePass : public Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 1d3a38cc2..ac90f05b9 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -895,7 +895,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); - if (!kernel_type_) { + if (!HasAttr(kEnableCacheExpectedKernel) || !kernel_type_) { ChooseKernel(*runtime_ctx, scope, place); } diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index fb7829a12..74f684449 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -70,6 +70,12 @@ constexpr char kNewGradSuffix[] = "@NEWGRAD@"; /// this Op's execution to save the elapsed time. constexpr char kEnableCacheRuntimeContext[] = "@ENABLE_CACHE_RUNTIME_CONTEXT@"; +/// If an Op has attribtue kEnableCacheExpectedKernel, it means that in a same +/// name scope and same place, since the expected kerenl of this Op does not +/// change in the execution, it could be recorded only at the first iteration of +/// this Op's execution to save the elapsed time. +constexpr char kEnableCacheExpectedKernel[] = "@ENABLE_CACHE_EXPECTED_KERNEL@"; + /// If an Op has this attribute, all its kernels should calculate output /// variable's shape in the corresponding Compute() function. And /// OperatorWithKernel::RunImpl() would skip call this Op's InferShape() diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 1be25de49..b4392cf88 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -203,6 +203,7 @@ void AnalysisConfig::Update() { pass_builder()->InsertPass(3, "tensorrt_subgraph_pass"); } pass_builder()->DeletePass("runtime_context_cache_pass"); + pass_builder()->DeletePass("expected_kernel_cache_pass"); } if (use_mkldnn_) { diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 22c527cfc..20acd51f3 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -81,6 +81,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { "conv_elementwise_add2_act_fuse_pass", // "conv_elementwise_add_fuse_pass", // "runtime_context_cache_pass", // + "expected_kernel_cache_pass", // #endif }); @@ -117,6 +118,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { "is_test_pass", // "identity_scale_op_clean_pass", // "runtime_context_cache_pass", // + "expected_kernel_cache_pass", // }); use_gpu_ = false; } -- GitLab