From e27e1b08b0b1e59341f235d36cd8449c3102203a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9F=B3=E6=99=93=E4=BC=9F?= <39303645+Shixiaowei02@users.noreply.github.com> Date: Fri, 13 Sep 2019 12:06:23 +0800 Subject: [PATCH] checkout if passes match targets and kernels, test=develop (#2035) * checkout if passes match targets and kernels, test=develop * add pass_utils, test=develop * fix lite/core/mir/pass_registry.h, test=develop * improve code styles, test=develop * fix spell error, test=develop --- lite/core/mir/CMakeLists.txt | 2 +- lite/core/mir/argument_type_display_pass.cc | 2 +- lite/core/mir/demo_pass.cc | 3 +- .../identity_scale_eliminate_pass.cc | 2 +- .../mir/fusion/conv_activation_fuse_pass.cc | 2 +- lite/core/mir/fusion/conv_bn_fuse_pass.cc | 2 +- .../mir/fusion/conv_elementwise_fuse_pass.cc | 2 +- .../elementwise_add_activation_fuse_pass.cc | 2 +- lite/core/mir/fusion/fc_fuse_pass.cc | 2 +- lite/core/mir/fusion/interpolate_fuse_pass.cc | 2 +- .../mir/fusion/quant_dequant_fuse_pass.cc | 2 +- .../mir/fusion/shuffle_channel_fuse_pass.cc | 2 +- .../transpose_softmax_transpose_fuse_pass.cc | 2 +- lite/core/mir/generate_program_pass.cc | 2 +- lite/core/mir/graph_visualize_pass.cc | 2 +- lite/core/mir/io_copy_kernel_pick_pass.cc | 2 +- lite/core/mir/memory_optimize_pass.cc | 2 +- lite/core/mir/pass.h | 39 ++++++++++++++--- lite/core/mir/pass_registry.h | 9 +++- lite/core/mir/pass_utils.cc | 43 +++++++++++++++++++ lite/core/mir/pass_utils.h | 29 +++++++++++++ lite/core/mir/runtime_context_assign_pass.cc | 2 +- lite/core/mir/static_kernel_pick_pass.cc | 2 +- .../mir/subgraph/generate_npu_program_pass.cc | 2 +- .../mir/subgraph/subgraph_program_pass.cc | 2 +- lite/core/mir/type_layout_cast_pass.cc | 2 +- lite/core/mir/type_precision_cast_pass.cc | 2 +- lite/core/mir/type_target_cast_pass.cc | 2 +- .../core/mir/variable_place_inference_pass.cc | 2 +- lite/core/optimizer.h | 13 +++--- 30 files changed, 145 insertions(+), 39 deletions(-) create mode 100644 lite/core/mir/pass_utils.cc create mode 100644 lite/core/mir/pass_utils.h diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt index 7d967b15c4..a44b834871 100644 --- a/lite/core/mir/CMakeLists.txt +++ b/lite/core/mir/CMakeLists.txt @@ -1,6 +1,6 @@ lite_cc_library(mir_node SRCS node.cc DEPS kernel) lite_cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node program) -lite_cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph) +lite_cc_library(mir_pass SRCS pass.cc pass_utils.cc DEPS mir_ssa_graph) lite_cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes) lite_cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager) diff --git a/lite/core/mir/argument_type_display_pass.cc b/lite/core/mir/argument_type_display_pass.cc index ea44245225..2ed63b360c 100644 --- a/lite/core/mir/argument_type_display_pass.cc +++ b/lite/core/mir/argument_type_display_pass.cc @@ -43,4 +43,4 @@ class ArgumentTypeDisplayPass : public DebugPass { REGISTER_MIR_PASS(argument_type_display_pass, paddle::lite::mir::ArgumentTypeDisplayPass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/demo_pass.cc b/lite/core/mir/demo_pass.cc index b92a2b0751..0e0858332c 100644 --- a/lite/core/mir/demo_pass.cc +++ b/lite/core/mir/demo_pass.cc @@ -34,4 +34,5 @@ bool RegisterDemoPass() { } // namespace lite } // namespace paddle -REGISTER_MIR_PASS(demo, paddle::lite::mir::DemoPass).SetTargets({TARGET(kAny)}); +REGISTER_MIR_PASS(demo, paddle::lite::mir::DemoPass) + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/elimination/identity_scale_eliminate_pass.cc b/lite/core/mir/elimination/identity_scale_eliminate_pass.cc index 00290937b2..acea48c742 100644 --- a/lite/core/mir/elimination/identity_scale_eliminate_pass.cc +++ b/lite/core/mir/elimination/identity_scale_eliminate_pass.cc @@ -70,4 +70,4 @@ class IdentityScaleEliminatePass : public ProgramPass { REGISTER_MIR_PASS(identity_scale_eliminate_pass, paddle::lite::mir::IdentityScaleEliminatePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/conv_activation_fuse_pass.cc b/lite/core/mir/fusion/conv_activation_fuse_pass.cc index c6939e1983..ceb3b0ea34 100644 --- a/lite/core/mir/fusion/conv_activation_fuse_pass.cc +++ b/lite/core/mir/fusion/conv_activation_fuse_pass.cc @@ -39,4 +39,4 @@ void ConvActivationFusePass::Apply(const std::unique_ptr& graph) { REGISTER_MIR_PASS(lite_conv_activation_fuse_pass, paddle::lite::mir::ConvActivationFusePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/conv_bn_fuse_pass.cc b/lite/core/mir/fusion/conv_bn_fuse_pass.cc index 2e962017bc..8ac2dd252e 100644 --- a/lite/core/mir/fusion/conv_bn_fuse_pass.cc +++ b/lite/core/mir/fusion/conv_bn_fuse_pass.cc @@ -35,4 +35,4 @@ void ConvBNFusePass::Apply(const std::unique_ptr& graph) { } // namespace paddle REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc b/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc index 631c6b883e..2ff3631ba3 100644 --- a/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc +++ b/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc @@ -39,4 +39,4 @@ void ConvElementwiseFusePass::Apply(const std::unique_ptr& graph) { REGISTER_MIR_PASS(lite_conv_elementwise_fuse_pass, paddle::lite::mir::ConvElementwiseFusePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc index 71dc31d49a..67e9e56fcf 100644 --- a/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc +++ b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc @@ -34,4 +34,4 @@ void ElementwiseAddActivationFusePass::Apply( REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass, paddle::lite::mir::ElementwiseAddActivationFusePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/fc_fuse_pass.cc b/lite/core/mir/fusion/fc_fuse_pass.cc index 3a68fd19bf..380f8f932d 100644 --- a/lite/core/mir/fusion/fc_fuse_pass.cc +++ b/lite/core/mir/fusion/fc_fuse_pass.cc @@ -32,4 +32,4 @@ void FcFusePass::Apply(const std::unique_ptr& graph) { } // namespace paddle REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/interpolate_fuse_pass.cc b/lite/core/mir/fusion/interpolate_fuse_pass.cc index 5a0e1384a7..51c9868cf3 100644 --- a/lite/core/mir/fusion/interpolate_fuse_pass.cc +++ b/lite/core/mir/fusion/interpolate_fuse_pass.cc @@ -36,4 +36,4 @@ void InterpolateFusePass::Apply(const std::unique_ptr& graph) { REGISTER_MIR_PASS(lite_interpolate_fuse_pass, paddle::lite::mir::InterpolateFusePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc index 9773caa3c1..15fdff5edf 100644 --- a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc +++ b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc @@ -44,4 +44,4 @@ void QuantDequantFusePass::Apply(const std::unique_ptr& graph) { REGISTER_MIR_PASS(lite_quant_dequant_fuse_pass, paddle::lite::mir::QuantDequantFusePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc b/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc index 049be721e9..01b18a1842 100644 --- a/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc +++ b/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc @@ -36,4 +36,4 @@ void ShuffleChannelFusePass::Apply(const std::unique_ptr& graph) { REGISTER_MIR_PASS(lite_shuffle_channel_fuse_pass, paddle::lite::mir::ShuffleChannelFusePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.cc b/lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.cc index 47c866d87a..c233d64739 100644 --- a/lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.cc +++ b/lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.cc @@ -37,4 +37,4 @@ void TransposeSoftmaxTransposeFusePass::Apply( REGISTER_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass, paddle::lite::mir::TransposeSoftmaxTransposeFusePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/generate_program_pass.cc b/lite/core/mir/generate_program_pass.cc index 23f2de564e..76c97d2da6 100644 --- a/lite/core/mir/generate_program_pass.cc +++ b/lite/core/mir/generate_program_pass.cc @@ -39,4 +39,4 @@ void GenerateProgramPass::Apply(const std::unique_ptr& graph) { } // namespace paddle REGISTER_MIR_PASS(generate_program_pass, paddle::lite::mir::GenerateProgramPass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/graph_visualize_pass.cc b/lite/core/mir/graph_visualize_pass.cc index f97dbfc7cd..6e01d821df 100644 --- a/lite/core/mir/graph_visualize_pass.cc +++ b/lite/core/mir/graph_visualize_pass.cc @@ -99,4 +99,4 @@ std::string Visualize(mir::SSAGraph* graph) { } // namespace paddle REGISTER_MIR_PASS(graph_visualze, paddle::lite::mir::GraphVisualizePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/io_copy_kernel_pick_pass.cc b/lite/core/mir/io_copy_kernel_pick_pass.cc index b2ea823e0b..90cf3559e3 100644 --- a/lite/core/mir/io_copy_kernel_pick_pass.cc +++ b/lite/core/mir/io_copy_kernel_pick_pass.cc @@ -72,4 +72,4 @@ class IoCopyKernelPickPass : public StmtPass { REGISTER_MIR_PASS(io_copy_kernel_pick_pass, paddle::lite::mir::IoCopyKernelPickPass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/memory_optimize_pass.cc b/lite/core/mir/memory_optimize_pass.cc index 24d00f4b74..4a4c83baef 100644 --- a/lite/core/mir/memory_optimize_pass.cc +++ b/lite/core/mir/memory_optimize_pass.cc @@ -261,4 +261,4 @@ void MemoryOptimizePass::Apply(const std::unique_ptr& graph) { } // namespace paddle REGISTER_MIR_PASS(memory_optimize_pass, paddle::lite::mir::MemoryOptimizePass) - .SetTargets({TARGET(kARM)}); + .BindTargets({TARGET(kARM)}); diff --git a/lite/core/mir/pass.h b/lite/core/mir/pass.h index cd7684ae32..8fd12fafa3 100644 --- a/lite/core/mir/pass.h +++ b/lite/core/mir/pass.h @@ -16,6 +16,7 @@ #include #include #include +#include #include "lite/core/mir/node.h" #include "lite/core/mir/ssa_graph.h" @@ -46,11 +47,36 @@ class Pass { void set_doc(const std::string& doc) { doc_ = doc; } const std::string& doc() const { return doc_; } - void set_targets(const std::set& targets) { targets_ = targets; } - const std::set& targets() const { return targets_; } - bool is_supported_target(TargetType target) const { - if (targets_.find(TARGET(kAny)) != targets_.end()) return true; - return (targets_.find(target) != targets_.end()); + // Some passes only apply to qualified targets, which need to be explicitly + // declared. + // Bind the target. At runtime, there must be one device in the bound targets. + void BindTargets(const std::set& targets) { + bound_targets_ = targets; + } + // Get all bound targets. + const std::set& Targets() const { return bound_targets_; } + + // Some passes are only available on qualified kernels and need to be + // explicitly declared. + // Bind kernels. All kernels bound at runtime must be registered. + void BindKernels( + const std::unordered_map>& + kernels) { + bound_kernels_ = kernels; + } + // Get all bound kernels. + const std::unordered_map>& + GetBoundKernels() const { + return bound_kernels_; + } + // Add one kernel to the bound kernels. + void BindKernel(const std::string& kernel_name, + const lite_api::Place& place) { + if (!bound_kernels_.count(kernel_name)) { + bound_kernels_.insert({kernel_name, {place}}); + } else { + bound_kernels_.at(kernel_name).insert(place); + } } Kind kind() const { return kind_; } @@ -64,7 +90,8 @@ class Pass { const Kind kind_; std::string name_; std::string doc_; - std::set targets_; + std::set bound_targets_; + std::unordered_map> bound_kernels_; }; // Different kinds. diff --git a/lite/core/mir/pass_registry.h b/lite/core/mir/pass_registry.h index cc5c119ecb..89a4b3efd6 100644 --- a/lite/core/mir/pass_registry.h +++ b/lite/core/mir/pass_registry.h @@ -30,8 +30,13 @@ class PassRegistry { : name_(name), pass_(pass) { PassManager::Global().AddNewPass(name_, pass_); } - PassRegistry& SetTargets(const std::set& targets) { - pass_->set_targets(targets); + PassRegistry& BindTargets(const std::set& targets) { + pass_->BindTargets(targets); + return *this; + } + PassRegistry& BindKernel(const std::string& name, + const lite_api::Place& place) { + pass_->BindKernel(name, place); return *this; } bool Touch() const { return true; } diff --git a/lite/core/mir/pass_utils.cc b/lite/core/mir/pass_utils.cc new file mode 100644 index 0000000000..f15a7d713c --- /dev/null +++ b/lite/core/mir/pass_utils.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/pass_utils.h" +#include +#include +#include + +namespace paddle { +namespace lite { + +bool PassMatchesTarget(const mir::Pass& pass, TargetType target) { + const auto& targets = pass.Targets(); + if (targets.find(TARGET(kAny)) != targets.end()) return true; + return (targets.find(target) != targets.end()); +} + +bool PassMatchesKernels(const mir::Pass& pass) { + const auto& kernels = pass.GetBoundKernels(); + for (const auto& kernel : kernels) { + for (const auto& place : kernel.second) { + if (KernelRegistry::Global() + .Create(kernel.first, place.target, place.precision, place.layout) + .empty()) + return false; + } + } + return true; +} + +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/pass_utils.h b/lite/core/mir/pass_utils.h new file mode 100644 index 0000000000..445c91fe77 --- /dev/null +++ b/lite/core/mir/pass_utils.h @@ -0,0 +1,29 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "lite/core/mir/pass.h" + +namespace paddle { +namespace lite { + +// Check if the pass hits the hardware target. +bool PassMatchesTarget(const mir::Pass& pass, TargetType target); + +// Check if the pass hits all necessary operators. +bool PassMatchesKernels(const mir::Pass& pass); + +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/runtime_context_assign_pass.cc b/lite/core/mir/runtime_context_assign_pass.cc index 652932c149..97c4819eaf 100644 --- a/lite/core/mir/runtime_context_assign_pass.cc +++ b/lite/core/mir/runtime_context_assign_pass.cc @@ -39,4 +39,4 @@ class RuntimeContextAssignPass : public StmtPass { REGISTER_MIR_PASS(runtime_context_assign_pass, paddle::lite::mir::RuntimeContextAssignPass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/static_kernel_pick_pass.cc b/lite/core/mir/static_kernel_pick_pass.cc index 37bcb1e317..10e4f6c1b2 100644 --- a/lite/core/mir/static_kernel_pick_pass.cc +++ b/lite/core/mir/static_kernel_pick_pass.cc @@ -133,4 +133,4 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { REGISTER_MIR_PASS(static_kernel_pick_pass, paddle::lite::mir::StaticKernelPickPass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/subgraph/generate_npu_program_pass.cc b/lite/core/mir/subgraph/generate_npu_program_pass.cc index 76e295c7af..8badd357c3 100644 --- a/lite/core/mir/subgraph/generate_npu_program_pass.cc +++ b/lite/core/mir/subgraph/generate_npu_program_pass.cc @@ -215,4 +215,4 @@ std::unique_ptr GenerateNPUProgramPass::GenProgram() { REGISTER_MIR_PASS(generate_npu_program_pass, paddle::lite::mir::subgraph::GenerateNPUProgramPass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/subgraph/subgraph_program_pass.cc b/lite/core/mir/subgraph/subgraph_program_pass.cc index 2b6206f891..a3d95163ce 100644 --- a/lite/core/mir/subgraph/subgraph_program_pass.cc +++ b/lite/core/mir/subgraph/subgraph_program_pass.cc @@ -311,4 +311,4 @@ int SubgraphProgramPass::FuseSubgraph( REGISTER_MIR_PASS(subgraph_program_pass, paddle::lite::mir::subgraph::SubgraphProgramPass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/type_layout_cast_pass.cc b/lite/core/mir/type_layout_cast_pass.cc index fbd3f9e1d2..11f4a21f24 100644 --- a/lite/core/mir/type_layout_cast_pass.cc +++ b/lite/core/mir/type_layout_cast_pass.cc @@ -174,4 +174,4 @@ void TypeLayoutTransformPass::SetValidPlaces( REGISTER_MIR_PASS(type_layout_cast_pass, paddle::lite::mir::TypeLayoutTransformPass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/type_precision_cast_pass.cc b/lite/core/mir/type_precision_cast_pass.cc index 7cd22e25ac..5a99a67255 100644 --- a/lite/core/mir/type_precision_cast_pass.cc +++ b/lite/core/mir/type_precision_cast_pass.cc @@ -180,4 +180,4 @@ void PrecisionCastPass::SetValidPlaces(const std::vector& valid_places) { REGISTER_MIR_PASS(type_precision_cast_pass, paddle::lite::mir::PrecisionCastPass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/type_target_cast_pass.cc b/lite/core/mir/type_target_cast_pass.cc index 5a07fdd9d9..0af7fa3cfd 100644 --- a/lite/core/mir/type_target_cast_pass.cc +++ b/lite/core/mir/type_target_cast_pass.cc @@ -180,4 +180,4 @@ void TypeTargetTransformPass::SetValidPlaces( REGISTER_MIR_PASS(type_target_cast_pass, paddle::lite::mir::TypeTargetTransformPass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/variable_place_inference_pass.cc b/lite/core/mir/variable_place_inference_pass.cc index 1f8aea8172..f1b6381fc0 100644 --- a/lite/core/mir/variable_place_inference_pass.cc +++ b/lite/core/mir/variable_place_inference_pass.cc @@ -32,4 +32,4 @@ void VariablePlaceInferencePass::Apply(const std::unique_ptr &graph) { REGISTER_MIR_PASS(variable_place_inference_pass, paddle::lite::mir::VariablePlaceInferencePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h index 4a0e95e266..7361eed236 100644 --- a/lite/core/optimizer.h +++ b/lite/core/optimizer.h @@ -18,6 +18,7 @@ #include #include "lite/core/mir/generate_program_pass.h" #include "lite/core/mir/pass_manager.h" +#include "lite/core/mir/pass_utils.h" #include "lite/core/mir/ssa_graph.h" #include "lite/core/mir/static_kernel_pick_pass.h" #include "lite/core/mir/type_target_cast_pass.h" @@ -186,15 +187,15 @@ class Optimizer { LOG(INFO) << "== Running pass: " << x; mir::Pass* pass = mir::PassManager::Global().LookUp(x); CHECK(pass) << "Can not find pass: " << x; - bool supported = false; + bool matched = false; for (const auto& place : valid_places_) { - if (pass->is_supported_target(place.target)) { - supported = true; + if (PassMatchesTarget(*pass, place.target)) { + matched = true; } } - if (!supported) { - LOG(WARNING) << "Skip " << x - << " pass because the target does not match."; + matched = matched || PassMatchesKernels(*pass); + if (!matched) { + LOG(INFO) << "Skip " << x << " pass because the target does not match."; } else { pass->Apply(graph_); LOG(INFO) << "== Finished running: " << x; -- GitLab