diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt index 7d967b15c4eaa13d9a98f129addfcf316350b6b5..a44b8348716449519486d37f6784e31ecc39f554 100644 --- a/lite/core/mir/CMakeLists.txt +++ b/lite/core/mir/CMakeLists.txt @@ -1,6 +1,6 @@ lite_cc_library(mir_node SRCS node.cc DEPS kernel) lite_cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node program) -lite_cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph) +lite_cc_library(mir_pass SRCS pass.cc pass_utils.cc DEPS mir_ssa_graph) lite_cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes) lite_cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager) diff --git a/lite/core/mir/argument_type_display_pass.cc b/lite/core/mir/argument_type_display_pass.cc index ea44245225fb38cd3f3ca427f513d19f3b21cbf6..2ed63b360c955b53eaa37af2f1e4832d0f88fd03 100644 --- a/lite/core/mir/argument_type_display_pass.cc +++ b/lite/core/mir/argument_type_display_pass.cc @@ -43,4 +43,4 @@ class ArgumentTypeDisplayPass : public DebugPass { REGISTER_MIR_PASS(argument_type_display_pass, paddle::lite::mir::ArgumentTypeDisplayPass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/demo_pass.cc b/lite/core/mir/demo_pass.cc index b92a2b0751e7c28490d74de3811e50a619dce953..0e0858332c9d10382d71fe7b50b3b2beb6ac257b 100644 --- a/lite/core/mir/demo_pass.cc +++ b/lite/core/mir/demo_pass.cc @@ -34,4 +34,5 @@ bool RegisterDemoPass() { } // namespace lite } // namespace paddle -REGISTER_MIR_PASS(demo, paddle::lite::mir::DemoPass).SetTargets({TARGET(kAny)}); +REGISTER_MIR_PASS(demo, paddle::lite::mir::DemoPass) + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/elimination/identity_scale_eliminate_pass.cc b/lite/core/mir/elimination/identity_scale_eliminate_pass.cc index 00290937b2b657f69b8ae35d9785c0a456c1f6fb..acea48c742522d5b6b5f1f3b570fcbfe0c4be08d 100644 --- a/lite/core/mir/elimination/identity_scale_eliminate_pass.cc +++ b/lite/core/mir/elimination/identity_scale_eliminate_pass.cc @@ -70,4 +70,4 @@ class IdentityScaleEliminatePass : public ProgramPass { REGISTER_MIR_PASS(identity_scale_eliminate_pass, paddle::lite::mir::IdentityScaleEliminatePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/conv_activation_fuse_pass.cc b/lite/core/mir/fusion/conv_activation_fuse_pass.cc index c6939e1983728ebf6ae049a4019e088b1998c9c8..ceb3b0ea349232c7b248a3e414288be1688d0214 100644 --- a/lite/core/mir/fusion/conv_activation_fuse_pass.cc +++ b/lite/core/mir/fusion/conv_activation_fuse_pass.cc @@ -39,4 +39,4 @@ void ConvActivationFusePass::Apply(const std::unique_ptr& graph) { REGISTER_MIR_PASS(lite_conv_activation_fuse_pass, paddle::lite::mir::ConvActivationFusePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/conv_bn_fuse_pass.cc b/lite/core/mir/fusion/conv_bn_fuse_pass.cc index 2e962017bc6e0856bd0b79ce009bb62720ae77da..8ac2dd252e713396033bdae379b103adf773f289 100644 --- a/lite/core/mir/fusion/conv_bn_fuse_pass.cc +++ b/lite/core/mir/fusion/conv_bn_fuse_pass.cc @@ -35,4 +35,4 @@ void ConvBNFusePass::Apply(const std::unique_ptr& graph) { } // namespace paddle REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc b/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc index 631c6b883e333a7b908d97fc7fb790e046f6260a..2ff3631ba31a807f215822fa25198c39776ea572 100644 --- a/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc +++ b/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc @@ -39,4 +39,4 @@ void ConvElementwiseFusePass::Apply(const std::unique_ptr& graph) { REGISTER_MIR_PASS(lite_conv_elementwise_fuse_pass, paddle::lite::mir::ConvElementwiseFusePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc index 71dc31d49a3dfd3bc715ea8f335c6b5342a9d596..67e9e56fcf35203200dbcd22b81d680f10d65e60 100644 --- a/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc +++ b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc @@ -34,4 +34,4 @@ void ElementwiseAddActivationFusePass::Apply( REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass, paddle::lite::mir::ElementwiseAddActivationFusePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/fc_fuse_pass.cc b/lite/core/mir/fusion/fc_fuse_pass.cc index 3a68fd19bf748cfe0b4960e996b339b405427027..380f8f932dafb43f61aac1dddf1631dee7e78d7c 100644 --- a/lite/core/mir/fusion/fc_fuse_pass.cc +++ b/lite/core/mir/fusion/fc_fuse_pass.cc @@ -32,4 +32,4 @@ void FcFusePass::Apply(const std::unique_ptr& graph) { } // namespace paddle REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/interpolate_fuse_pass.cc b/lite/core/mir/fusion/interpolate_fuse_pass.cc index 5a0e1384a7e442fd013b5f28dae51867eb8b52ba..51c9868cf3ed76ee6f02ac954f74c330e9f1a8e1 100644 --- a/lite/core/mir/fusion/interpolate_fuse_pass.cc +++ b/lite/core/mir/fusion/interpolate_fuse_pass.cc @@ -36,4 +36,4 @@ void InterpolateFusePass::Apply(const std::unique_ptr& graph) { REGISTER_MIR_PASS(lite_interpolate_fuse_pass, paddle::lite::mir::InterpolateFusePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc index 9773caa3c1c6ae0fc0ad215d1b3f5027bbfce639..15fdff5edf3c8a480779ed780419227a5fe19306 100644 --- a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc +++ b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc @@ -44,4 +44,4 @@ void QuantDequantFusePass::Apply(const std::unique_ptr& graph) { REGISTER_MIR_PASS(lite_quant_dequant_fuse_pass, paddle::lite::mir::QuantDequantFusePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc b/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc index 049be721e9b38bad3c3589f8bf35d838d58aa4a2..01b18a1842bcee832a8010581a67e90c0aa72683 100644 --- a/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc +++ b/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc @@ -36,4 +36,4 @@ void ShuffleChannelFusePass::Apply(const std::unique_ptr& graph) { REGISTER_MIR_PASS(lite_shuffle_channel_fuse_pass, paddle::lite::mir::ShuffleChannelFusePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.cc b/lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.cc index 47c866d87a91219fd14071f1185d5a826355b3f9..c233d6473959d2cb2c7e15fe6074844db0ba5850 100644 --- a/lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.cc +++ b/lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.cc @@ -37,4 +37,4 @@ void TransposeSoftmaxTransposeFusePass::Apply( REGISTER_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass, paddle::lite::mir::TransposeSoftmaxTransposeFusePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/generate_program_pass.cc b/lite/core/mir/generate_program_pass.cc index 23f2de564e4e3d9c06d48623db132e777eb60544..76c97d2da6ed9e7c6fc1f1889d80095278b68ec0 100644 --- a/lite/core/mir/generate_program_pass.cc +++ b/lite/core/mir/generate_program_pass.cc @@ -39,4 +39,4 @@ void GenerateProgramPass::Apply(const std::unique_ptr& graph) { } // namespace paddle REGISTER_MIR_PASS(generate_program_pass, paddle::lite::mir::GenerateProgramPass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/graph_visualize_pass.cc b/lite/core/mir/graph_visualize_pass.cc index f97dbfc7cd820ffb5a67d285df582a1b7b8ae67d..6e01d821dfe41feda5f9bb723054c240eea3efbd 100644 --- a/lite/core/mir/graph_visualize_pass.cc +++ b/lite/core/mir/graph_visualize_pass.cc @@ -99,4 +99,4 @@ std::string Visualize(mir::SSAGraph* graph) { } // namespace paddle REGISTER_MIR_PASS(graph_visualze, paddle::lite::mir::GraphVisualizePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/io_copy_kernel_pick_pass.cc b/lite/core/mir/io_copy_kernel_pick_pass.cc index b2ea823e0b4a6ee500410bb0a6973c4adaf5d399..90cf3559e3dd92998a2800e42b2c8abce8eb1355 100644 --- a/lite/core/mir/io_copy_kernel_pick_pass.cc +++ b/lite/core/mir/io_copy_kernel_pick_pass.cc @@ -72,4 +72,4 @@ class IoCopyKernelPickPass : public StmtPass { REGISTER_MIR_PASS(io_copy_kernel_pick_pass, paddle::lite::mir::IoCopyKernelPickPass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/memory_optimize_pass.cc b/lite/core/mir/memory_optimize_pass.cc index 24d00f4b7457d294acfbd0e883984f30ce1b4cd9..4a4c83baef6c055320327409f2d8008a35f2f875 100644 --- a/lite/core/mir/memory_optimize_pass.cc +++ b/lite/core/mir/memory_optimize_pass.cc @@ -261,4 +261,4 @@ void MemoryOptimizePass::Apply(const std::unique_ptr& graph) { } // namespace paddle REGISTER_MIR_PASS(memory_optimize_pass, paddle::lite::mir::MemoryOptimizePass) - .SetTargets({TARGET(kARM)}); + .BindTargets({TARGET(kARM)}); diff --git a/lite/core/mir/pass.h b/lite/core/mir/pass.h index cd7684ae32111c6804c0d8b242f4cf47debfb4f4..8fd12fafa3fd6183eb3bba894be04d96075f1bc3 100644 --- a/lite/core/mir/pass.h +++ b/lite/core/mir/pass.h @@ -16,6 +16,7 @@ #include #include #include +#include #include "lite/core/mir/node.h" #include "lite/core/mir/ssa_graph.h" @@ -46,11 +47,36 @@ class Pass { void set_doc(const std::string& doc) { doc_ = doc; } const std::string& doc() const { return doc_; } - void set_targets(const std::set& targets) { targets_ = targets; } - const std::set& targets() const { return targets_; } - bool is_supported_target(TargetType target) const { - if (targets_.find(TARGET(kAny)) != targets_.end()) return true; - return (targets_.find(target) != targets_.end()); + // Some passes only apply to qualified targets, which need to be explicitly + // declared. + // Bind the target. At runtime, there must be one device in the bound targets. + void BindTargets(const std::set& targets) { + bound_targets_ = targets; + } + // Get all bound targets. + const std::set& Targets() const { return bound_targets_; } + + // Some passes are only available on qualified kernels and need to be + // explicitly declared. + // Bind kernels. All kernels bound at runtime must be registered. + void BindKernels( + const std::unordered_map>& + kernels) { + bound_kernels_ = kernels; + } + // Get all bound kernels. + const std::unordered_map>& + GetBoundKernels() const { + return bound_kernels_; + } + // Add one kernel to the bound kernels. + void BindKernel(const std::string& kernel_name, + const lite_api::Place& place) { + if (!bound_kernels_.count(kernel_name)) { + bound_kernels_.insert({kernel_name, {place}}); + } else { + bound_kernels_.at(kernel_name).insert(place); + } } Kind kind() const { return kind_; } @@ -64,7 +90,8 @@ class Pass { const Kind kind_; std::string name_; std::string doc_; - std::set targets_; + std::set bound_targets_; + std::unordered_map> bound_kernels_; }; // Different kinds. diff --git a/lite/core/mir/pass_registry.h b/lite/core/mir/pass_registry.h index cc5c119ecbc53baaf3f9664463a8e4bdd42147d2..89a4b3efd665d5e45b436ef677b99a3a79a5ae54 100644 --- a/lite/core/mir/pass_registry.h +++ b/lite/core/mir/pass_registry.h @@ -30,8 +30,13 @@ class PassRegistry { : name_(name), pass_(pass) { PassManager::Global().AddNewPass(name_, pass_); } - PassRegistry& SetTargets(const std::set& targets) { - pass_->set_targets(targets); + PassRegistry& BindTargets(const std::set& targets) { + pass_->BindTargets(targets); + return *this; + } + PassRegistry& BindKernel(const std::string& name, + const lite_api::Place& place) { + pass_->BindKernel(name, place); return *this; } bool Touch() const { return true; } diff --git a/lite/core/mir/pass_utils.cc b/lite/core/mir/pass_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..f15a7d713c00de720e2a0f99ac2bf974fbe26d95 --- /dev/null +++ b/lite/core/mir/pass_utils.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/pass_utils.h" +#include +#include +#include + +namespace paddle { +namespace lite { + +bool PassMatchesTarget(const mir::Pass& pass, TargetType target) { + const auto& targets = pass.Targets(); + if (targets.find(TARGET(kAny)) != targets.end()) return true; + return (targets.find(target) != targets.end()); +} + +bool PassMatchesKernels(const mir::Pass& pass) { + const auto& kernels = pass.GetBoundKernels(); + for (const auto& kernel : kernels) { + for (const auto& place : kernel.second) { + if (KernelRegistry::Global() + .Create(kernel.first, place.target, place.precision, place.layout) + .empty()) + return false; + } + } + return true; +} + +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/pass_utils.h b/lite/core/mir/pass_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..445c91fe77ba650ca5dd5acf33428814e0d49d9e --- /dev/null +++ b/lite/core/mir/pass_utils.h @@ -0,0 +1,29 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "lite/core/mir/pass.h" + +namespace paddle { +namespace lite { + +// Check if the pass hits the hardware target. +bool PassMatchesTarget(const mir::Pass& pass, TargetType target); + +// Check if the pass hits all necessary operators. +bool PassMatchesKernels(const mir::Pass& pass); + +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/runtime_context_assign_pass.cc b/lite/core/mir/runtime_context_assign_pass.cc index 652932c14926b2511edd658138776fc02cd4c3c2..97c4819eaf6734ba9b374444166d17cb15e8ae65 100644 --- a/lite/core/mir/runtime_context_assign_pass.cc +++ b/lite/core/mir/runtime_context_assign_pass.cc @@ -39,4 +39,4 @@ class RuntimeContextAssignPass : public StmtPass { REGISTER_MIR_PASS(runtime_context_assign_pass, paddle::lite::mir::RuntimeContextAssignPass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/static_kernel_pick_pass.cc b/lite/core/mir/static_kernel_pick_pass.cc index 37bcb1e31798e66e4a688ad61429ea663e448151..10e4f6c1b2ec29f468db527699df0371036244f2 100644 --- a/lite/core/mir/static_kernel_pick_pass.cc +++ b/lite/core/mir/static_kernel_pick_pass.cc @@ -133,4 +133,4 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { REGISTER_MIR_PASS(static_kernel_pick_pass, paddle::lite::mir::StaticKernelPickPass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/subgraph/generate_npu_program_pass.cc b/lite/core/mir/subgraph/generate_npu_program_pass.cc index 76e295c7af044a1981d3075c7390b1fdcee8b93a..8badd357c3e268cd7b0281b434081f55747caa7f 100644 --- a/lite/core/mir/subgraph/generate_npu_program_pass.cc +++ b/lite/core/mir/subgraph/generate_npu_program_pass.cc @@ -215,4 +215,4 @@ std::unique_ptr GenerateNPUProgramPass::GenProgram() { REGISTER_MIR_PASS(generate_npu_program_pass, paddle::lite::mir::subgraph::GenerateNPUProgramPass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/subgraph/subgraph_program_pass.cc b/lite/core/mir/subgraph/subgraph_program_pass.cc index 2b6206f891856eff462ee3257be44316e9149a46..a3d95163ce5e8f130e32dc6425c526b9e405a0fb 100644 --- a/lite/core/mir/subgraph/subgraph_program_pass.cc +++ b/lite/core/mir/subgraph/subgraph_program_pass.cc @@ -311,4 +311,4 @@ int SubgraphProgramPass::FuseSubgraph( REGISTER_MIR_PASS(subgraph_program_pass, paddle::lite::mir::subgraph::SubgraphProgramPass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/type_layout_cast_pass.cc b/lite/core/mir/type_layout_cast_pass.cc index fbd3f9e1d2a765ba20c6a220a87309623135f141..11f4a21f240f39cd3b15511231e748b33d7a1ae5 100644 --- a/lite/core/mir/type_layout_cast_pass.cc +++ b/lite/core/mir/type_layout_cast_pass.cc @@ -174,4 +174,4 @@ void TypeLayoutTransformPass::SetValidPlaces( REGISTER_MIR_PASS(type_layout_cast_pass, paddle::lite::mir::TypeLayoutTransformPass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/type_precision_cast_pass.cc b/lite/core/mir/type_precision_cast_pass.cc index 7cd22e25aca0ae36115ef4104b719241f3645c1d..5a99a67255a00da566e181ab59dba9acda5d9647 100644 --- a/lite/core/mir/type_precision_cast_pass.cc +++ b/lite/core/mir/type_precision_cast_pass.cc @@ -180,4 +180,4 @@ void PrecisionCastPass::SetValidPlaces(const std::vector& valid_places) { REGISTER_MIR_PASS(type_precision_cast_pass, paddle::lite::mir::PrecisionCastPass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/type_target_cast_pass.cc b/lite/core/mir/type_target_cast_pass.cc index 5a07fdd9d956b7c65370f2d9bfbd613be1332adf..0af7fa3cfd67d01f0075ed86fe881bcc99a7848e 100644 --- a/lite/core/mir/type_target_cast_pass.cc +++ b/lite/core/mir/type_target_cast_pass.cc @@ -180,4 +180,4 @@ void TypeTargetTransformPass::SetValidPlaces( REGISTER_MIR_PASS(type_target_cast_pass, paddle::lite::mir::TypeTargetTransformPass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/variable_place_inference_pass.cc b/lite/core/mir/variable_place_inference_pass.cc index 1f8aea81729575b13144406e93a431821f037b9f..f1b6381fc0010e08cffa4baee4dc7b33a678b387 100644 --- a/lite/core/mir/variable_place_inference_pass.cc +++ b/lite/core/mir/variable_place_inference_pass.cc @@ -32,4 +32,4 @@ void VariablePlaceInferencePass::Apply(const std::unique_ptr &graph) { REGISTER_MIR_PASS(variable_place_inference_pass, paddle::lite::mir::VariablePlaceInferencePass) - .SetTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h index 4a0e95e26654dda58bd88828042b05bedddbc684..7361eed23696de3fee98712590940bab2658d580 100644 --- a/lite/core/optimizer.h +++ b/lite/core/optimizer.h @@ -18,6 +18,7 @@ #include #include "lite/core/mir/generate_program_pass.h" #include "lite/core/mir/pass_manager.h" +#include "lite/core/mir/pass_utils.h" #include "lite/core/mir/ssa_graph.h" #include "lite/core/mir/static_kernel_pick_pass.h" #include "lite/core/mir/type_target_cast_pass.h" @@ -186,15 +187,15 @@ class Optimizer { LOG(INFO) << "== Running pass: " << x; mir::Pass* pass = mir::PassManager::Global().LookUp(x); CHECK(pass) << "Can not find pass: " << x; - bool supported = false; + bool matched = false; for (const auto& place : valid_places_) { - if (pass->is_supported_target(place.target)) { - supported = true; + if (PassMatchesTarget(*pass, place.target)) { + matched = true; } } - if (!supported) { - LOG(WARNING) << "Skip " << x - << " pass because the target does not match."; + matched = matched || PassMatchesKernels(*pass); + if (!matched) { + LOG(INFO) << "Skip " << x << " pass because the target does not match."; } else { pass->Apply(graph_); LOG(INFO) << "== Finished running: " << x;