提交 2f3d7fd6 编写于 作者: 石晓伟 提交者: Yan Chunwei

checkout if passes match targets and kernels, test=develop (#2035)

* checkout if passes match targets and kernels, test=develop

* add pass_utils, test=develop

* fix lite/core/mir/pass_registry.h, test=develop

* improve code styles, test=develop

* fix spell error, test=develop
上级 129b689b
lite_cc_library(mir_node SRCS node.cc DEPS kernel)
lite_cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node program)
lite_cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph)
lite_cc_library(mir_pass SRCS pass.cc pass_utils.cc DEPS mir_ssa_graph)
lite_cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes)
lite_cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
......
......@@ -43,4 +43,4 @@ class ArgumentTypeDisplayPass : public DebugPass {
REGISTER_MIR_PASS(argument_type_display_pass,
paddle::lite::mir::ArgumentTypeDisplayPass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -34,4 +34,5 @@ bool RegisterDemoPass() {
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(demo, paddle::lite::mir::DemoPass).SetTargets({TARGET(kAny)});
REGISTER_MIR_PASS(demo, paddle::lite::mir::DemoPass)
.BindTargets({TARGET(kAny)});
......@@ -70,4 +70,4 @@ class IdentityScaleEliminatePass : public ProgramPass {
REGISTER_MIR_PASS(identity_scale_eliminate_pass,
paddle::lite::mir::IdentityScaleEliminatePass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -39,4 +39,4 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_conv_activation_fuse_pass,
paddle::lite::mir::ConvActivationFusePass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -35,4 +35,4 @@ void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle
REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -39,4 +39,4 @@ void ConvElementwiseFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_conv_elementwise_fuse_pass,
paddle::lite::mir::ConvElementwiseFusePass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -34,4 +34,4 @@ void ElementwiseAddActivationFusePass::Apply(
REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass,
paddle::lite::mir::ElementwiseAddActivationFusePass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -32,4 +32,4 @@ void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle
REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -36,4 +36,4 @@ void InterpolateFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_interpolate_fuse_pass,
paddle::lite::mir::InterpolateFusePass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -44,4 +44,4 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_quant_dequant_fuse_pass,
paddle::lite::mir::QuantDequantFusePass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -36,4 +36,4 @@ void ShuffleChannelFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_shuffle_channel_fuse_pass,
paddle::lite::mir::ShuffleChannelFusePass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -37,4 +37,4 @@ void TransposeSoftmaxTransposeFusePass::Apply(
REGISTER_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass,
paddle::lite::mir::TransposeSoftmaxTransposeFusePass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -39,4 +39,4 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle
REGISTER_MIR_PASS(generate_program_pass, paddle::lite::mir::GenerateProgramPass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -99,4 +99,4 @@ std::string Visualize(mir::SSAGraph* graph) {
} // namespace paddle
REGISTER_MIR_PASS(graph_visualze, paddle::lite::mir::GraphVisualizePass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -72,4 +72,4 @@ class IoCopyKernelPickPass : public StmtPass {
REGISTER_MIR_PASS(io_copy_kernel_pick_pass,
paddle::lite::mir::IoCopyKernelPickPass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -261,4 +261,4 @@ void MemoryOptimizePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle
REGISTER_MIR_PASS(memory_optimize_pass, paddle::lite::mir::MemoryOptimizePass)
.SetTargets({TARGET(kARM)});
.BindTargets({TARGET(kARM)});
......@@ -16,6 +16,7 @@
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include "lite/core/mir/node.h"
#include "lite/core/mir/ssa_graph.h"
......@@ -46,11 +47,36 @@ class Pass {
void set_doc(const std::string& doc) { doc_ = doc; }
const std::string& doc() const { return doc_; }
void set_targets(const std::set<TargetType>& targets) { targets_ = targets; }
const std::set<TargetType>& targets() const { return targets_; }
bool is_supported_target(TargetType target) const {
if (targets_.find(TARGET(kAny)) != targets_.end()) return true;
return (targets_.find(target) != targets_.end());
// Some passes only apply to qualified targets, which need to be explicitly
// declared.
// Bind the target. At runtime, there must be one device in the bound targets.
void BindTargets(const std::set<TargetType>& targets) {
bound_targets_ = targets;
}
// Get all bound targets.
const std::set<TargetType>& Targets() const { return bound_targets_; }
// Some passes are only available on qualified kernels and need to be
// explicitly declared.
// Bind kernels. All kernels bound at runtime must be registered.
void BindKernels(
const std::unordered_map<std::string, std::set<lite_api::Place>>&
kernels) {
bound_kernels_ = kernels;
}
// Get all bound kernels.
const std::unordered_map<std::string, std::set<lite_api::Place>>&
GetBoundKernels() const {
return bound_kernels_;
}
// Add one kernel to the bound kernels.
void BindKernel(const std::string& kernel_name,
const lite_api::Place& place) {
if (!bound_kernels_.count(kernel_name)) {
bound_kernels_.insert({kernel_name, {place}});
} else {
bound_kernels_.at(kernel_name).insert(place);
}
}
Kind kind() const { return kind_; }
......@@ -64,7 +90,8 @@ class Pass {
const Kind kind_;
std::string name_;
std::string doc_;
std::set<TargetType> targets_;
std::set<TargetType> bound_targets_;
std::unordered_map<std::string, std::set<lite_api::Place>> bound_kernels_;
};
// Different kinds.
......
......@@ -30,8 +30,13 @@ class PassRegistry {
: name_(name), pass_(pass) {
PassManager::Global().AddNewPass(name_, pass_);
}
PassRegistry& SetTargets(const std::set<TargetType>& targets) {
pass_->set_targets(targets);
PassRegistry& BindTargets(const std::set<TargetType>& targets) {
pass_->BindTargets(targets);
return *this;
}
PassRegistry& BindKernel(const std::string& name,
const lite_api::Place& place) {
pass_->BindKernel(name, place);
return *this;
}
bool Touch() const { return true; }
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/pass_utils.h"
#include <set>
#include <string>
#include <unordered_map>
namespace paddle {
namespace lite {
bool PassMatchesTarget(const mir::Pass& pass, TargetType target) {
const auto& targets = pass.Targets();
if (targets.find(TARGET(kAny)) != targets.end()) return true;
return (targets.find(target) != targets.end());
}
bool PassMatchesKernels(const mir::Pass& pass) {
const auto& kernels = pass.GetBoundKernels();
for (const auto& kernel : kernels) {
for (const auto& place : kernel.second) {
if (KernelRegistry::Global()
.Create(kernel.first, place.target, place.precision, place.layout)
.empty())
return false;
}
}
return true;
}
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
// Check if the pass hits the hardware target.
bool PassMatchesTarget(const mir::Pass& pass, TargetType target);
// Check if the pass hits all necessary operators.
bool PassMatchesKernels(const mir::Pass& pass);
} // namespace lite
} // namespace paddle
......@@ -39,4 +39,4 @@ class RuntimeContextAssignPass : public StmtPass {
REGISTER_MIR_PASS(runtime_context_assign_pass,
paddle::lite::mir::RuntimeContextAssignPass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -133,4 +133,4 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(static_kernel_pick_pass,
paddle::lite::mir::StaticKernelPickPass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -215,4 +215,4 @@ std::unique_ptr<RuntimeProgram> GenerateNPUProgramPass::GenProgram() {
REGISTER_MIR_PASS(generate_npu_program_pass,
paddle::lite::mir::subgraph::GenerateNPUProgramPass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -311,4 +311,4 @@ int SubgraphProgramPass::FuseSubgraph(
REGISTER_MIR_PASS(subgraph_program_pass,
paddle::lite::mir::subgraph::SubgraphProgramPass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -174,4 +174,4 @@ void TypeLayoutTransformPass::SetValidPlaces(
REGISTER_MIR_PASS(type_layout_cast_pass,
paddle::lite::mir::TypeLayoutTransformPass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -180,4 +180,4 @@ void PrecisionCastPass::SetValidPlaces(const std::vector<Place>& valid_places) {
REGISTER_MIR_PASS(type_precision_cast_pass,
paddle::lite::mir::PrecisionCastPass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -180,4 +180,4 @@ void TypeTargetTransformPass::SetValidPlaces(
REGISTER_MIR_PASS(type_target_cast_pass,
paddle::lite::mir::TypeTargetTransformPass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -32,4 +32,4 @@ void VariablePlaceInferencePass::Apply(const std::unique_ptr<SSAGraph> &graph) {
REGISTER_MIR_PASS(variable_place_inference_pass,
paddle::lite::mir::VariablePlaceInferencePass)
.SetTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)});
......@@ -18,6 +18,7 @@
#include <vector>
#include "lite/core/mir/generate_program_pass.h"
#include "lite/core/mir/pass_manager.h"
#include "lite/core/mir/pass_utils.h"
#include "lite/core/mir/ssa_graph.h"
#include "lite/core/mir/static_kernel_pick_pass.h"
#include "lite/core/mir/type_target_cast_pass.h"
......@@ -186,15 +187,15 @@ class Optimizer {
LOG(INFO) << "== Running pass: " << x;
mir::Pass* pass = mir::PassManager::Global().LookUp(x);
CHECK(pass) << "Can not find pass: " << x;
bool supported = false;
bool matched = false;
for (const auto& place : valid_places_) {
if (pass->is_supported_target(place.target)) {
supported = true;
if (PassMatchesTarget(*pass, place.target)) {
matched = true;
}
}
if (!supported) {
LOG(WARNING) << "Skip " << x
<< " pass because the target does not match.";
matched = matched || PassMatchesKernels(*pass);
if (!matched) {
LOG(INFO) << "Skip " << x << " pass because the target does not match.";
} else {
pass->Apply(graph_);
LOG(INFO) << "== Finished running: " << x;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册