未验证 提交 c62fd634 编写于 作者: H hong19860320 提交者: GitHub

[LITE][ALL] Refine NPU and XPU passes, fix the pass matching based on the...

[LITE][ALL] Refine NPU and XPU passes, fix the pass matching based on the bound targets and excluded targets (#2477)

上级 820eb6d4
...@@ -146,15 +146,22 @@ std::vector<std::string> Predictor::GetOutputNames() { return output_names_; } ...@@ -146,15 +146,22 @@ std::vector<std::string> Predictor::GetOutputNames() { return output_names_; }
// append the names of inputs and outputs into input_names_ and output_names_ // append the names of inputs and outputs into input_names_ and output_names_
void Predictor::PrepareFeedFetch() { void Predictor::PrepareFeedFetch() {
std::vector<const cpp::OpDesc *> feeds;
std::vector<const cpp::OpDesc *> fetchs;
#if defined(LITE_WITH_NPU) || defined(LITE_WITH_XPU)
// The shapes of the input tensors must be determined before generating the NPU
// and XPU programs.
auto current_block = program_desc_.GetBlock<cpp::BlockDesc>(0);
for (size_t i = 0; i < current_block->OpsSize(); i++) {
auto op = current_block->GetOp<cpp::OpDesc>(i);
#else
if (!program_) { if (!program_) {
GenRuntimeProgram(); GenRuntimeProgram();
} }
std::vector<const cpp::OpDesc *> feeds;
std::vector<const cpp::OpDesc *> fetchs;
const auto &insts = program_->instructions(); const auto &insts = program_->instructions();
for (size_t i = 0; i < program_->num_instructions(); i++) { for (size_t i = 0; i < program_->num_instructions(); i++) {
const auto &op = insts[i].op()->op_info(); const auto &op = insts[i].op()->op_info();
#endif
if (op->Type() == "feed") { if (op->Type() == "feed") {
feeds.push_back(op); feeds.push_back(op);
} else if (op->Type() == "fetch") { } else if (op->Type() == "fetch") {
......
...@@ -20,7 +20,12 @@ USE_MIR_PASS(static_kernel_pick_pass); ...@@ -20,7 +20,12 @@ USE_MIR_PASS(static_kernel_pick_pass);
USE_MIR_PASS(variable_place_inference_pass); USE_MIR_PASS(variable_place_inference_pass);
USE_MIR_PASS(type_target_cast_pass); USE_MIR_PASS(type_target_cast_pass);
USE_MIR_PASS(generate_program_pass); USE_MIR_PASS(generate_program_pass);
USE_MIR_PASS(subgraph_program_pass); #ifdef LITE_WITH_NPU
USE_MIR_PASS(generate_npu_program_pass);
#endif
#ifdef LITE_WITH_XPU
USE_MIR_PASS(generate_xpu_program_pass);
#endif
USE_MIR_PASS(io_copy_kernel_pick_pass); USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass); USE_MIR_PASS(argument_type_display_pass);
......
...@@ -255,4 +255,5 @@ void MemoryOptimizePass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -255,4 +255,5 @@ void MemoryOptimizePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(memory_optimize_pass, paddle::lite::mir::MemoryOptimizePass) REGISTER_MIR_PASS(memory_optimize_pass, paddle::lite::mir::MemoryOptimizePass)
.BindTargets({TARGET(kARM)}); .BindTargets({TARGET(kARM)})
.ExcludeTargets({TARGET(kOpenCL), TARGET(kNPU), TARGET(kXPU)});
...@@ -52,34 +52,44 @@ class Pass { ...@@ -52,34 +52,44 @@ class Pass {
// Bind targets. At runtime, there must be one device in the bound targets. // Bind targets. At runtime, there must be one device in the bound targets.
void BindTargets(const std::set<TargetType>& targets) { void BindTargets(const std::set<TargetType>& targets) {
std::set<TargetType> res;
for (const auto& target : targets) { for (const auto& target : targets) {
const std::set<TargetType>& universe = ExpandValidTargets(target); const std::set<TargetType>& universe = ExpandValidTargets(target);
std::set_union(bound_targets_.begin(), std::set_union(bound_targets_.begin(),
bound_targets_.end(), bound_targets_.end(),
universe.begin(), universe.begin(),
universe.end(), universe.end(),
std::inserter(res, res.begin())); std::inserter(bound_targets_, bound_targets_.begin()));
} }
bound_targets_ = res;
} }
// Exclude targets. At runtime, there must be one device in the bound targets. // Exclude targets. At runtime, there must be one device in the bound targets.
// Disable the pass if one of the valid devices is in the excluded targets.
void ExcludeTargets(const std::set<TargetType>& targets) { void ExcludeTargets(const std::set<TargetType>& targets) {
std::set<TargetType> res;
for (const auto& target : targets) { for (const auto& target : targets) {
const std::set<TargetType>& universe = ExpandValidTargets(target); const std::set<TargetType>& universe = ExpandValidTargets(target);
std::set_difference(bound_targets_.begin(), std::set<TargetType> updated_bound_targets;
bound_targets_.end(), std::set_difference(
universe.begin(), bound_targets_.begin(),
universe.end(), bound_targets_.end(),
std::inserter(res, res.begin())); universe.begin(),
universe.end(),
std::inserter(updated_bound_targets, updated_bound_targets.begin()));
bound_targets_ = updated_bound_targets;
std::set_union(
excluded_targets_.begin(),
excluded_targets_.end(),
universe.begin(),
universe.end(),
std::inserter(excluded_targets_, excluded_targets_.begin()));
} }
bound_targets_ = res;
} }
// Get all bound targets. // Get all bound targets.
const std::set<TargetType>& Targets() const { return bound_targets_; } const std::set<TargetType>& BoundTargets() const { return bound_targets_; }
// Get all excluded targets.
const std::set<TargetType>& ExcludedTargets() const {
return excluded_targets_;
}
// Some passes are only available on qualified kernels and need to be // Some passes are only available on qualified kernels and need to be
// explicitly declared. // explicitly declared.
...@@ -116,6 +126,7 @@ class Pass { ...@@ -116,6 +126,7 @@ class Pass {
std::string name_; std::string name_;
std::string doc_; std::string doc_;
std::set<TargetType> bound_targets_; std::set<TargetType> bound_targets_;
std::set<TargetType> excluded_targets_;
std::unordered_map<std::string, std::set<lite_api::Place>> bound_kernels_; std::unordered_map<std::string, std::set<lite_api::Place>> bound_kernels_;
}; };
......
...@@ -47,10 +47,34 @@ bool KernelRegistered(const std::string name, const Place& place) { ...@@ -47,10 +47,34 @@ bool KernelRegistered(const std::string name, const Place& place) {
return false; return false;
} }
bool PassMatchesTarget(const mir::Pass& pass, TargetType target) { bool PassMatchesTarget(const mir::Pass& pass,
const auto& targets = pass.Targets(); const std::set<TargetType>& targets) {
if (targets.find(TARGET(kAny)) != targets.end()) return true; // Whether the pass is suitable for targets ? The condition is the
return (targets.find(target) != targets.end()); // intersection of targets and pass's bound targets is not empty, besides the
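// intersection of targets and pass's excluded targets is empty. The formula
// is as follows: matched = !empty(targets ^ pass.bound_targets) &&
// empty(targets ^ pass.excluded_targets), where ^ denotes set intersection.
// In addition, kAny in the bound targets counts as a match on the bound side,
// and kAny in the excluded targets always disables the pass.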
const auto& bound_targets = pass.BoundTargets();
bool matched = bound_targets.find(TARGET(kAny)) != bound_targets.end();
std::set<TargetType> inter_bound_targets;
std::set_intersection(
bound_targets.begin(),
bound_targets.end(),
targets.begin(),
targets.end(),
std::inserter(inter_bound_targets, inter_bound_targets.begin()));
matched |= !inter_bound_targets.empty();
const auto& excluded_targets = pass.ExcludedTargets();
matched &= excluded_targets.find(TARGET(kAny)) == excluded_targets.end();
std::set<TargetType> inter_excluded_targets;
std::set_intersection(
excluded_targets.begin(),
excluded_targets.end(),
targets.begin(),
targets.end(),
std::inserter(inter_excluded_targets, inter_excluded_targets.begin()));
matched &= inter_excluded_targets.empty();
return matched;
} }
bool PassMatchesKernels(const mir::Pass& pass) { bool PassMatchesKernels(const mir::Pass& pass) {
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <set>
#include <string> #include <string>
#include "lite/core/mir/pass.h" #include "lite/core/mir/pass.h"
...@@ -24,7 +25,8 @@ namespace lite { ...@@ -24,7 +25,8 @@ namespace lite {
bool KernelRegistered(const std::string name, const Place& place); bool KernelRegistered(const std::string name, const Place& place);
// Check if the pass hits the hardware target. // Check if the pass hits the hardware target.
bool PassMatchesTarget(const mir::Pass& pass, TargetType target); bool PassMatchesTarget(const mir::Pass& pass,
const std::set<TargetType>& targets);
// Check if the pass hits all necessary operators. // Check if the pass hits all necessary operators.
bool PassMatchesKernels(const mir::Pass& pass); bool PassMatchesKernels(const mir::Pass& pass);
......
...@@ -128,10 +128,10 @@ std::string GenerateNPUProgramPass::BuildNPUGraph( ...@@ -128,10 +128,10 @@ std::string GenerateNPUProgramPass::BuildNPUGraph(
// persistable=true, so that the model parser can recognize it and save it to // persistable=true, so that the model parser can recognize it and save it to
// param files // param files
if (!lite::npu::BuildModel(inputs, outputs, weight)) { if (!lite::npu::BuildModel(inputs, outputs, weight)) {
LOG(WARNING) << "[NPU] Build NPU graph failed (subgraph=" << sub_id << ")"; LOG(FATAL) << "[NPU] Build NPU graph failed (subgraph=" << sub_id << ")";
throw std::runtime_error("Build NPU graph failed."); } else {
LOG(INFO) << "[NPU] Build NPU graph success (subgraph=" << sub_id << ")";
} }
LOG(INFO) << "[NPU] Build NPU graph success (subgraph=" << sub_id << ")";
return weight_var_name; return weight_var_name;
} }
...@@ -175,40 +175,19 @@ void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -175,40 +175,19 @@ void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
supported_op_types.push_back(i.first); supported_op_types.push_back(i.first);
} }
try { int num_subgraph = FuseSubgraph(graph, supported_op_types);
int num_subgraph = FuseSubgraph(graph, supported_op_types); InferOnce(graph);
InferOnce(graph); auto op_nodes_all = ClassifySubgraph(graph);
auto op_nodes_all = ClassifySubgraph(graph); CHECK_EQ(op_nodes_all.size(), num_subgraph);
CHECK_EQ(op_nodes_all.size(), num_subgraph); int id = 1;
int id = 1; for (auto& op_nodes : op_nodes_all) {
for (auto& op_nodes : op_nodes_all) { LOG(INFO) << "[NPU] Converting Subgraph " << id;
LOG(INFO) << "[NPU] Converting Subgraph " << id; GenNPUSubgraph(graph, op_nodes.second, id);
GenNPUSubgraph(graph, op_nodes.second, id); LOG(INFO) << "[NPU] After NPU Pass Subgraph " << id << "\n"
LOG(INFO) << "[NPU] After NPU Pass Subgraph " << id << "\n" << Visualize(graph.get());
<< Visualize(graph.get()); id++;
id++;
}
} catch (...) {
LOG(WARNING) << "[NPU] Build NPU graph failed.";
throw std::runtime_error("[NPU] Build NPU graph failed.");
}
for (auto& item : graph->StmtTopologicalOrder()) {
if (item->IsStmt()) {
auto& stmt = item->AsStmt();
LOG(INFO) << stmt;
insts_.emplace_back(stmt.op(), std::move(stmt.kernels().front()));
}
} }
} }
std::unique_ptr<RuntimeProgram> GenerateNPUProgramPass::GenProgram() {
LOG(INFO) << "[NPU] program insts.size " << insts_.size();
std::unique_ptr<RuntimeProgram> program(
new RuntimeProgram(std::move(insts_)));
return program;
}
} // namespace subgraph } // namespace subgraph
} // namespace mir } // namespace mir
} // namespace lite } // namespace lite
......
...@@ -35,7 +35,6 @@ class GenerateNPUProgramPass : public SubgraphProgramPass { ...@@ -35,7 +35,6 @@ class GenerateNPUProgramPass : public SubgraphProgramPass {
using key2nodes_t = std::map<std::string, Node*>; using key2nodes_t = std::map<std::string, Node*>;
void Apply(const std::unique_ptr<SSAGraph>& graph) override; void Apply(const std::unique_ptr<SSAGraph>& graph) override;
std::unique_ptr<RuntimeProgram> GenProgram();
protected: protected:
// nodes2cvt: op nodes to convert // nodes2cvt: op nodes to convert
...@@ -54,9 +53,6 @@ class GenerateNPUProgramPass : public SubgraphProgramPass { ...@@ -54,9 +53,6 @@ class GenerateNPUProgramPass : public SubgraphProgramPass {
void GenNPUSubgraph(const std::unique_ptr<SSAGraph>& graph, void GenNPUSubgraph(const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& op_nodes, const std::unordered_set<Node*>& op_nodes,
int sub_id); int sub_id);
private:
std::vector<Instruction> insts_;
}; };
} // namespace subgraph } // namespace subgraph
......
...@@ -115,10 +115,10 @@ std::string GenerateXPUProgramPass::BuildXPUGraph( ...@@ -115,10 +115,10 @@ std::string GenerateXPUProgramPass::BuildXPUGraph(
graph_ctx.params, graph_ctx.params,
&ordered_cvted_var_nodes, &ordered_cvted_var_nodes,
weight)) { weight)) {
LOG(WARNING) << "[XPU] Build XPU graph failed (subgraph=" << sub_id << ")"; LOG(FATAL) << "[XPU] Build XPU graph failed (subgraph=" << sub_id << ")";
throw std::runtime_error("[XPU] Build XPU graph failed."); } else {
LOG(INFO) << "[XPU] Build XPU graph success (subgraph=" << sub_id << ")";
} }
LOG(INFO) << "[XPU] Build XPU graph success (subgraph=" << sub_id << ")";
return weight_var_name; return weight_var_name;
} }
...@@ -162,40 +162,19 @@ void GenerateXPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -162,40 +162,19 @@ void GenerateXPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
supported_op_types.push_back(i.first); supported_op_types.push_back(i.first);
} }
try { int num_subgraph = FuseSubgraph(graph, supported_op_types);
int num_subgraph = FuseSubgraph(graph, supported_op_types); InferOnce(graph);
InferOnce(graph); auto op_nodes_all = ClassifySubgraph(graph);
auto op_nodes_all = ClassifySubgraph(graph); CHECK_EQ(op_nodes_all.size(), num_subgraph);
CHECK_EQ(op_nodes_all.size(), num_subgraph); int id = 1;
int id = 1; for (auto& op_nodes : op_nodes_all) {
for (auto& op_nodes : op_nodes_all) { LOG(INFO) << "[XPU] Converting Subgraph " << id;
LOG(INFO) << "[XPU] Converting Subgraph " << id; GenXPUSubgraph(graph, op_nodes.second, id);
GenXPUSubgraph(graph, op_nodes.second, id); LOG(INFO) << "[XPU] After XPU Pass Subgraph " << id << "\n"
LOG(INFO) << "[XPU] After XPU Pass Subgraph " << id << "\n" << Visualize(graph.get());
<< Visualize(graph.get()); id++;
id++;
}
} catch (...) {
LOG(WARNING) << "[XPU] Build XPU graph failed.";
throw std::runtime_error("[XPU] Build XPU graph failed.");
}
for (auto& item : graph->StmtTopologicalOrder()) {
if (item->IsStmt()) {
auto& stmt = item->AsStmt();
LOG(INFO) << stmt;
insts_.emplace_back(stmt.op(), std::move(stmt.kernels().front()));
}
} }
} }
std::unique_ptr<RuntimeProgram> GenerateXPUProgramPass::GenProgram() {
LOG(INFO) << "[XPU] program insts.size=" << insts_.size();
std::unique_ptr<RuntimeProgram> program(
new RuntimeProgram(std::move(insts_)));
return program;
}
} // namespace subgraph } // namespace subgraph
} // namespace mir } // namespace mir
} // namespace lite } // namespace lite
......
...@@ -35,7 +35,6 @@ class GenerateXPUProgramPass : public SubgraphProgramPass { ...@@ -35,7 +35,6 @@ class GenerateXPUProgramPass : public SubgraphProgramPass {
using key2nodes_t = std::map<std::string, Node*>; using key2nodes_t = std::map<std::string, Node*>;
void Apply(const std::unique_ptr<SSAGraph>& graph) override; void Apply(const std::unique_ptr<SSAGraph>& graph) override;
std::unique_ptr<RuntimeProgram> GenProgram();
protected: protected:
// nodes2cvt: op nodes to convert // nodes2cvt: op nodes to convert
...@@ -58,9 +57,6 @@ class GenerateXPUProgramPass : public SubgraphProgramPass { ...@@ -58,9 +57,6 @@ class GenerateXPUProgramPass : public SubgraphProgramPass {
void GenXPUSubgraph(const std::unique_ptr<SSAGraph>& graph, void GenXPUSubgraph(const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& op_nodes, const std::unordered_set<Node*>& op_nodes,
int sub_id); int sub_id);
private:
std::vector<Instruction> insts_;
}; };
} // namespace subgraph } // namespace subgraph
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <map> #include <map>
#include <memory> #include <memory>
#include <set>
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/core/mir/generate_program_pass.h" #include "lite/core/mir/generate_program_pass.h"
...@@ -50,21 +51,6 @@ class Optimizer { ...@@ -50,21 +51,6 @@ class Optimizer {
valid_places_ = valid_places; valid_places_ = valid_places;
CHECK(!valid_places.empty()) << "At least one valid_place should be set"; CHECK(!valid_places.empty()) << "At least one valid_place should be set";
CHECK(!graph_) << "duplicate optimize found"; CHECK(!graph_) << "duplicate optimize found";
auto valid_places_has_target = [&](TargetType t) -> bool {
for (auto& p : valid_places) {
if (p.target == t) {
return true;
}
}
return false;
};
std::map<std::string, bool> lite_with_targets{
{"kOpenCL", valid_places_has_target(TARGET(kOpenCL))},
{"kNPU", valid_places_has_target(TARGET(kNPU))},
{"kXPU", valid_places_has_target(TARGET(kXPU))}};
VLOG(4) << "lite_with_targets['kOpenCL']:" << lite_with_targets["kOpenCL"];
VLOG(4) << "lite_with_targets['kNPU']:" << lite_with_targets["kNPU"];
VLOG(4) << "lite_with_targets['kXPU']:" << lite_with_targets["kXPU"];
graph_.reset(new mir::SSAGraph); graph_.reset(new mir::SSAGraph);
graph_->Build(program, valid_places); graph_->Build(program, valid_places);
...@@ -122,13 +108,8 @@ class Optimizer { ...@@ -122,13 +108,8 @@ class Optimizer {
"argument_type_display_pass", "argument_type_display_pass",
"runtime_context_assign_pass", "runtime_context_assign_pass",
"argument_type_display_pass"}}; "argument_type_display_pass",
if ((!lite_with_targets["kOpenCL"]) && (!lite_with_targets["kNPU"]) && "memory_optimize_pass"}};
(!lite_with_targets["kXPU"])) {
// TODO(ysh329): cause CL_INVALID_MEM_OBJECT when setArg in OpenCL
// kernel
passes_local.emplace_back("memory_optimize_pass");
}
RunPasses(passes_local); RunPasses(passes_local);
} else { } else {
RunPasses(passes); RunPasses(passes);
...@@ -140,40 +121,13 @@ class Optimizer { ...@@ -140,40 +121,13 @@ class Optimizer {
// Generate a new program based on the mir graph. // Generate a new program based on the mir graph.
std::unique_ptr<RuntimeProgram> GenRuntimeProgram() { std::unique_ptr<RuntimeProgram> GenRuntimeProgram() {
#if defined(LITE_WITH_NPU) || defined(LITE_WITH_XPU) // Extra passes are applied for NPU and XPU; they depend on the shapes
auto target_place = Place{ // of input tensors, so GenRuntimeProgram() must be called after the shapes
#ifdef LITE_WITH_NPU // of input tensors are determined.
TARGET(kNPU), std::vector<std::string> subgraph_passes{"generate_npu_program_pass",
#endif "generate_xpu_program_pass"};
#ifdef LITE_WITH_XPU RunPasses(subgraph_passes);
TARGET(kXPU),
#endif
PRECISION(kFloat)};
if (std::find(valid_places_.begin(), valid_places_.end(), target_place) !=
valid_places_.end()) {
#ifdef LITE_WITH_NPU
auto pass = mir::PassManager::Global()
.LookUp<mir::subgraph::GenerateNPUProgramPass>(
"generate_npu_program_pass");
#endif
#ifdef LITE_WITH_XPU
auto pass = mir::PassManager::Global()
.LookUp<mir::subgraph::GenerateXPUProgramPass>(
"generate_xpu_program_pass");
#endif
try {
pass->Apply(graph_);
auto program = pass->GenProgram();
CHECK(exec_scope_);
program->set_exec_scope(exec_scope_);
return program;
} catch (...) {
LOG(WARNING) << "Build " << TargetToStr(target_place.target)
<< " program failed!";
}
}
#endif
auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>( auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
"generate_program_pass"); "generate_program_pass");
pass->Apply(graph_); pass->Apply(graph_);
...@@ -215,14 +169,16 @@ class Optimizer { ...@@ -215,14 +169,16 @@ class Optimizer {
for (auto& x : passes) { for (auto& x : passes) {
LOG(INFO) << "== Running pass: " << x; LOG(INFO) << "== Running pass: " << x;
mir::Pass* pass = mir::PassManager::Global().LookUp(x); mir::Pass* pass = mir::PassManager::Global().LookUp(x);
CHECK(pass) << "Can not find pass: " << x; if (!pass) {
bool matched = false; LOG(INFO) << " - Skip " << x << " because the pass isn't found.";
continue;
}
std::set<TargetType> targets;
for (const auto& place : valid_places_) { for (const auto& place : valid_places_) {
if (PassMatchesTarget(*pass, place.target)) { targets.insert(place.target);
matched = true;
}
} }
matched = matched && PassMatchesKernels(*pass); bool matched =
PassMatchesTarget(*pass, targets) && PassMatchesKernels(*pass);
if (!matched) { if (!matched) {
LOG(INFO) << " - Skip " << x LOG(INFO) << " - Skip " << x
<< " because the target or kernel does not match."; << " because the target or kernel does not match.";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册