Commit 2f3d7fd6 authored by 石晓伟, committed by Yan Chunwei

check if passes match targets and kernels, test=develop (#2035)

* check if passes match targets and kernels, test=develop

* add pass_utils, test=develop

* fix lite/core/mir/pass_registry.h, test=develop

* improve code styles, test=develop

* fix spell error, test=develop
Parent 129b689b
lite_cc_library(mir_node SRCS node.cc DEPS kernel)
lite_cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node program)
-lite_cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph)
+lite_cc_library(mir_pass SRCS pass.cc pass_utils.cc DEPS mir_ssa_graph)
lite_cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes)
lite_cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
......
@@ -43,4 +43,4 @@ class ArgumentTypeDisplayPass : public DebugPass {
REGISTER_MIR_PASS(argument_type_display_pass,
                  paddle::lite::mir::ArgumentTypeDisplayPass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -34,4 +34,5 @@ bool RegisterDemoPass() {
}  // namespace lite
}  // namespace paddle
-REGISTER_MIR_PASS(demo, paddle::lite::mir::DemoPass).SetTargets({TARGET(kAny)});
+REGISTER_MIR_PASS(demo, paddle::lite::mir::DemoPass)
+    .BindTargets({TARGET(kAny)});
@@ -70,4 +70,4 @@ class IdentityScaleEliminatePass : public ProgramPass {
REGISTER_MIR_PASS(identity_scale_eliminate_pass,
                  paddle::lite::mir::IdentityScaleEliminatePass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -39,4 +39,4 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_conv_activation_fuse_pass,
                  paddle::lite::mir::ConvActivationFusePass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -35,4 +35,4 @@ void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
}  // namespace paddle
REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -39,4 +39,4 @@ void ConvElementwiseFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_conv_elementwise_fuse_pass,
                  paddle::lite::mir::ConvElementwiseFusePass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -34,4 +34,4 @@ void ElementwiseAddActivationFusePass::Apply(
REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass,
                  paddle::lite::mir::ElementwiseAddActivationFusePass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -32,4 +32,4 @@ void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
}  // namespace paddle
REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -36,4 +36,4 @@ void InterpolateFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_interpolate_fuse_pass,
                  paddle::lite::mir::InterpolateFusePass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -44,4 +44,4 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_quant_dequant_fuse_pass,
                  paddle::lite::mir::QuantDequantFusePass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -36,4 +36,4 @@ void ShuffleChannelFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_shuffle_channel_fuse_pass,
                  paddle::lite::mir::ShuffleChannelFusePass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -37,4 +37,4 @@ void TransposeSoftmaxTransposeFusePass::Apply(
REGISTER_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass,
                  paddle::lite::mir::TransposeSoftmaxTransposeFusePass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -39,4 +39,4 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
}  // namespace paddle
REGISTER_MIR_PASS(generate_program_pass, paddle::lite::mir::GenerateProgramPass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -99,4 +99,4 @@ std::string Visualize(mir::SSAGraph* graph) {
}  // namespace paddle
REGISTER_MIR_PASS(graph_visualze, paddle::lite::mir::GraphVisualizePass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -72,4 +72,4 @@ class IoCopyKernelPickPass : public StmtPass {
REGISTER_MIR_PASS(io_copy_kernel_pick_pass,
                  paddle::lite::mir::IoCopyKernelPickPass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -261,4 +261,4 @@ void MemoryOptimizePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
}  // namespace paddle
REGISTER_MIR_PASS(memory_optimize_pass, paddle::lite::mir::MemoryOptimizePass)
-    .SetTargets({TARGET(kARM)});
+    .BindTargets({TARGET(kARM)});
@@ -16,6 +16,7 @@
#include <memory>
#include <set>
#include <string>
+#include <unordered_map>
#include "lite/core/mir/node.h"
#include "lite/core/mir/ssa_graph.h"
@@ -46,11 +47,36 @@ class Pass {
  void set_doc(const std::string& doc) { doc_ = doc; }
  const std::string& doc() const { return doc_; }

-  void set_targets(const std::set<TargetType>& targets) { targets_ = targets; }
-  const std::set<TargetType>& targets() const { return targets_; }
-  bool is_supported_target(TargetType target) const {
-    if (targets_.find(TARGET(kAny)) != targets_.end()) return true;
-    return (targets_.find(target) != targets_.end());
-  }
+  // Some passes only apply to qualified targets, which need to be explicitly
+  // declared.
+  // Bind the target. At runtime, there must be one device in the bound targets.
+  void BindTargets(const std::set<TargetType>& targets) {
+    bound_targets_ = targets;
+  }
+  // Get all bound targets.
+  const std::set<TargetType>& Targets() const { return bound_targets_; }
+
+  // Some passes are only available on qualified kernels and need to be
+  // explicitly declared.
+  // Bind kernels. All kernels bound at runtime must be registered.
+  void BindKernels(
+      const std::unordered_map<std::string, std::set<lite_api::Place>>&
+          kernels) {
+    bound_kernels_ = kernels;
+  }
+  // Get all bound kernels.
+  const std::unordered_map<std::string, std::set<lite_api::Place>>&
+  GetBoundKernels() const {
+    return bound_kernels_;
+  }
+  // Add one kernel to the bound kernels.
+  void BindKernel(const std::string& kernel_name,
+                  const lite_api::Place& place) {
+    if (!bound_kernels_.count(kernel_name)) {
+      bound_kernels_.insert({kernel_name, {place}});
+    } else {
+      bound_kernels_.at(kernel_name).insert(place);
+    }
+  }

  Kind kind() const { return kind_; }
@@ -64,7 +90,8 @@ class Pass {
  const Kind kind_;
  std::string name_;
  std::string doc_;
-  std::set<TargetType> targets_;
+  std::set<TargetType> bound_targets_;
+  std::unordered_map<std::string, std::set<lite_api::Place>> bound_kernels_;
};

// Different kinds.
......
@@ -30,8 +30,13 @@ class PassRegistry {
      : name_(name), pass_(pass) {
    PassManager::Global().AddNewPass(name_, pass_);
  }
-  PassRegistry& SetTargets(const std::set<TargetType>& targets) {
-    pass_->set_targets(targets);
+  PassRegistry& BindTargets(const std::set<TargetType>& targets) {
+    pass_->BindTargets(targets);
+    return *this;
+  }
+  PassRegistry& BindKernel(const std::string& name,
+                           const lite_api::Place& place) {
+    pass_->BindKernel(name, place);
    return *this;
  }
  bool Touch() const { return true; }
......
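For illustration, a minimal sketch of how a pass could be registered against the new chainable API. The pass name my_demo_pass, the class MyDemoPass, and the conv2d kernel/place are made-up placeholders, not part of this commit:

// Hypothetical registration: bind the pass to the ARM target and declare that
// it relies on a float/NCHW conv2d kernel being registered.
REGISTER_MIR_PASS(my_demo_pass, paddle::lite::mir::MyDemoPass)
    .BindTargets({TARGET(kARM)})
    .BindKernel("conv2d",
                paddle::lite_api::Place{
                    TARGET(kARM), PRECISION(kFloat), DATALAYOUT(kNCHW)});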
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/pass_utils.h"
#include <set>
#include <string>
#include <unordered_map>
namespace paddle {
namespace lite {
bool PassMatchesTarget(const mir::Pass& pass, TargetType target) {
const auto& targets = pass.Targets();
if (targets.find(TARGET(kAny)) != targets.end()) return true;
return (targets.find(target) != targets.end());
}
bool PassMatchesKernels(const mir::Pass& pass) {
const auto& kernels = pass.GetBoundKernels();
for (const auto& kernel : kernels) {
for (const auto& place : kernel.second) {
if (KernelRegistry::Global()
.Create(kernel.first, place.target, place.precision, place.layout)
.empty())
return false;
}
}
return true;
}
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
// Check whether the pass matches the given hardware target.
bool PassMatchesTarget(const mir::Pass& pass, TargetType target);
// Check whether all kernels bound by the pass are registered.
bool PassMatchesKernels(const mir::Pass& pass);
} // namespace lite
} // namespace paddle
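As a rough usage sketch, the two helpers are meant to be combined by a caller that holds the list of valid places; the helper name ShouldRunPass is illustrative, and the real wiring lives in lite/core/optimizer.h (shown further below):

#include <vector>
#include "lite/core/mir/pass_utils.h"

// Illustrative helper: a pass is eligible when one of the valid places hits a
// bound target, or when every kernel it binds is registered.
bool ShouldRunPass(const paddle::lite::mir::Pass& pass,
                   const std::vector<paddle::lite_api::Place>& valid_places) {
  bool matched = false;
  for (const auto& place : valid_places) {
    if (paddle::lite::PassMatchesTarget(pass, place.target)) matched = true;
  }
  return matched || paddle::lite::PassMatchesKernels(pass);
}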
@@ -39,4 +39,4 @@ class RuntimeContextAssignPass : public StmtPass {
REGISTER_MIR_PASS(runtime_context_assign_pass,
                  paddle::lite::mir::RuntimeContextAssignPass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -133,4 +133,4 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(static_kernel_pick_pass,
                  paddle::lite::mir::StaticKernelPickPass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -215,4 +215,4 @@ std::unique_ptr<RuntimeProgram> GenerateNPUProgramPass::GenProgram() {
REGISTER_MIR_PASS(generate_npu_program_pass,
                  paddle::lite::mir::subgraph::GenerateNPUProgramPass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -311,4 +311,4 @@ int SubgraphProgramPass::FuseSubgraph(
REGISTER_MIR_PASS(subgraph_program_pass,
                  paddle::lite::mir::subgraph::SubgraphProgramPass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -174,4 +174,4 @@ void TypeLayoutTransformPass::SetValidPlaces(
REGISTER_MIR_PASS(type_layout_cast_pass,
                  paddle::lite::mir::TypeLayoutTransformPass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -180,4 +180,4 @@ void PrecisionCastPass::SetValidPlaces(const std::vector<Place>& valid_places) {
REGISTER_MIR_PASS(type_precision_cast_pass,
                  paddle::lite::mir::PrecisionCastPass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -180,4 +180,4 @@ void TypeTargetTransformPass::SetValidPlaces(
REGISTER_MIR_PASS(type_target_cast_pass,
                  paddle::lite::mir::TypeTargetTransformPass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -32,4 +32,4 @@ void VariablePlaceInferencePass::Apply(const std::unique_ptr<SSAGraph> &graph) {
REGISTER_MIR_PASS(variable_place_inference_pass,
                  paddle::lite::mir::VariablePlaceInferencePass)
-    .SetTargets({TARGET(kAny)});
+    .BindTargets({TARGET(kAny)});
@@ -18,6 +18,7 @@
#include <vector>
#include "lite/core/mir/generate_program_pass.h"
#include "lite/core/mir/pass_manager.h"
+#include "lite/core/mir/pass_utils.h"
#include "lite/core/mir/ssa_graph.h"
#include "lite/core/mir/static_kernel_pick_pass.h"
#include "lite/core/mir/type_target_cast_pass.h"
@@ -186,15 +187,15 @@ class Optimizer {
      LOG(INFO) << "== Running pass: " << x;
      mir::Pass* pass = mir::PassManager::Global().LookUp(x);
      CHECK(pass) << "Can not find pass: " << x;
-      bool supported = false;
+      bool matched = false;
      for (const auto& place : valid_places_) {
-        if (pass->is_supported_target(place.target)) {
-          supported = true;
+        if (PassMatchesTarget(*pass, place.target)) {
+          matched = true;
        }
      }
-      if (!supported) {
-        LOG(WARNING) << "Skip " << x
-                     << " pass because the target does not match.";
+      matched = matched || PassMatchesKernels(*pass);
+      if (!matched) {
+        LOG(INFO) << "Skip " << x << " pass because the target does not match.";
      } else {
        pass->Apply(graph_);
        LOG(INFO) << "== Finished running: " << x;
......