diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index a79a53867d85e91250ac4810caa5806c25f35fee..db9d2f6d8072dea8a8bf49a53cc3373798617473 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -71,6 +71,7 @@ pass_library(transpose_flatten_concat_fuse_pass inference) pass_library(identity_scale_op_clean_pass base) pass_library(sync_batch_norm_pass base) pass_library(runtime_context_cache_pass base) +pass_library(expected_kernel_cache_pass base) # There may be many transpose-flatten structures in a model, and the output of # these structures will be used as inputs to the concat Op. This pattern will diff --git a/paddle/fluid/framework/ir/expected_kernel_cache_pass.cc b/paddle/fluid/framework/ir/expected_kernel_cache_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..28db7dc55992b8f05c5a0a404fa677d8023312f9 --- /dev/null +++ b/paddle/fluid/framework/ir/expected_kernel_cache_pass.cc @@ -0,0 +1,39 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/expected_kernel_cache_pass.h" +#include +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +namespace ir { + +std::unique_ptr ExpectedKernelCachePass::ApplyImpl( + std::unique_ptr graph) const { + VLOG(3) << "Applies Expected Kernel Cache strategy."; + for (const Node* n : graph->Nodes()) { + if (n->IsOp()) { + n->Op()->SetAttr(kEnableCacheExpectedKernel, true); + } + } + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(expected_kernel_cache_pass, + paddle::framework::ir::ExpectedKernelCachePass); diff --git a/paddle/fluid/framework/ir/expected_kernel_cache_pass.h b/paddle/fluid/framework/ir/expected_kernel_cache_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..738eb84fd91afaaa5b841b43d58323db90a6f304 --- /dev/null +++ b/paddle/fluid/framework/ir/expected_kernel_cache_pass.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class ExpectedKernelCachePass : public Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 1d3a38cc286fdbc26d62a02e9d9086feb2826a07..ac90f05b9fa5cd79c89c2aae912a5d2072f62cc5 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -895,7 +895,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); - if (!kernel_type_) { + if (!HasAttr(kEnableCacheExpectedKernel) || !kernel_type_) { ChooseKernel(*runtime_ctx, scope, place); } diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index fb7829a12cd1efcab115b5136025c6f324505f2f..74f6844494cd092a74ce8503391b97b00b5e398b 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -70,6 +70,12 @@ constexpr char kNewGradSuffix[] = "@NEWGRAD@"; /// this Op's execution to save the elapsed time. constexpr char kEnableCacheRuntimeContext[] = "@ENABLE_CACHE_RUNTIME_CONTEXT@"; +/// If an Op has attribtue kEnableCacheExpectedKernel, it means that in a same +/// name scope and same place, since the expected kerenl of this Op does not +/// change in the execution, it could be recorded only at the first iteration of +/// this Op's execution to save the elapsed time. +constexpr char kEnableCacheExpectedKernel[] = "@ENABLE_CACHE_EXPECTED_KERNEL@"; + /// If an Op has this attribute, all its kernels should calculate output /// variable's shape in the corresponding Compute() function. And /// OperatorWithKernel::RunImpl() would skip call this Op's InferShape() diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 1be25de497346913f24eec147a2db58b0f7065f4..b4392cf886f8025249fcaa9da96bd7a576f4f13a 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -203,6 +203,7 @@ void AnalysisConfig::Update() { pass_builder()->InsertPass(3, "tensorrt_subgraph_pass"); } pass_builder()->DeletePass("runtime_context_cache_pass"); + pass_builder()->DeletePass("expected_kernel_cache_pass"); } if (use_mkldnn_) { diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 22c527cfc117a5e6ababf264744745e41e0bf71a..20acd51f31c7a54557b098cda511d42e6125e78d 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -81,6 +81,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { "conv_elementwise_add2_act_fuse_pass", // "conv_elementwise_add_fuse_pass", // "runtime_context_cache_pass", // + "expected_kernel_cache_pass", // #endif }); @@ -117,6 +118,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { "is_test_pass", // "identity_scale_op_clean_pass", // "runtime_context_cache_pass", // + "expected_kernel_cache_pass", // }); use_gpu_ = false; }