提交 65381f13 编写于 作者: H hong19860320 提交者: GitHub

[LITE][ALL] Refine NPU and XPU passes, fix the pass matching based on the...

[LITE][ALL] Refine NPU and XPU passes, fix the pass matching based on the bound targets and excluded targets (#2477)

上级 b3a5fc1a
......@@ -146,15 +146,22 @@ std::vector<std::string> Predictor::GetOutputNames() { return output_names_; }
// append the names of inputs and outputs into input_names_ and output_names_
void Predictor::PrepareFeedFetch() {
std::vector<const cpp::OpDesc *> feeds;
std::vector<const cpp::OpDesc *> fetchs;
#if defined(LITE_WITH_NPU) || defined(LITE_WITH_XPU)
// The shape of input tensors must be determined before generating NPU and XPU
// program.
auto current_block = program_desc_.GetBlock<cpp::BlockDesc>(0);
for (size_t i = 0; i < current_block->OpsSize(); i++) {
auto op = current_block->GetOp<cpp::OpDesc>(i);
#else
if (!program_) {
GenRuntimeProgram();
}
std::vector<const cpp::OpDesc *> feeds;
std::vector<const cpp::OpDesc *> fetchs;
const auto &insts = program_->instructions();
for (size_t i = 0; i < program_->num_instructions(); i++) {
const auto &op = insts[i].op()->op_info();
#endif
if (op->Type() == "feed") {
feeds.push_back(op);
} else if (op->Type() == "fetch") {
......
......@@ -20,7 +20,12 @@ USE_MIR_PASS(static_kernel_pick_pass);
USE_MIR_PASS(variable_place_inference_pass);
USE_MIR_PASS(type_target_cast_pass);
USE_MIR_PASS(generate_program_pass);
USE_MIR_PASS(subgraph_program_pass);
#ifdef LITE_WITH_NPU
USE_MIR_PASS(generate_npu_program_pass);
#endif
#ifdef LITE_WITH_XPU
USE_MIR_PASS(generate_xpu_program_pass);
#endif
USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass);
......
......@@ -255,4 +255,5 @@ void MemoryOptimizePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle
REGISTER_MIR_PASS(memory_optimize_pass, paddle::lite::mir::MemoryOptimizePass)
.BindTargets({TARGET(kARM)});
.BindTargets({TARGET(kARM)})
.ExcludeTargets({TARGET(kOpenCL), TARGET(kNPU), TARGET(kXPU)});
......@@ -52,34 +52,44 @@ class Pass {
// Bind targets. At runtime, there must be one device in the bound targets.
void BindTargets(const std::set<TargetType>& targets) {
std::set<TargetType> res;
for (const auto& target : targets) {
const std::set<TargetType>& universe = ExpandValidTargets(target);
std::set_union(bound_targets_.begin(),
bound_targets_.end(),
universe.begin(),
universe.end(),
std::inserter(res, res.begin()));
std::inserter(bound_targets_, bound_targets_.begin()));
}
bound_targets_ = res;
}
// Exclude targets. At runtime, there must be one device in the bound targets.
// Disable the pass if one of the valid devices is in the excluded targets.
void ExcludeTargets(const std::set<TargetType>& targets) {
std::set<TargetType> res;
for (const auto& target : targets) {
const std::set<TargetType>& universe = ExpandValidTargets(target);
std::set_difference(bound_targets_.begin(),
bound_targets_.end(),
universe.begin(),
universe.end(),
std::inserter(res, res.begin()));
std::set<TargetType> updated_bound_targets;
std::set_difference(
bound_targets_.begin(),
bound_targets_.end(),
universe.begin(),
universe.end(),
std::inserter(updated_bound_targets, updated_bound_targets.begin()));
bound_targets_ = updated_bound_targets;
std::set_union(
excluded_targets_.begin(),
excluded_targets_.end(),
universe.begin(),
universe.end(),
std::inserter(excluded_targets_, excluded_targets_.begin()));
}
bound_targets_ = res;
}
// Get all bound targets.
const std::set<TargetType>& Targets() const { return bound_targets_; }
const std::set<TargetType>& BoundTargets() const { return bound_targets_; }
// Get all excluded targets.
const std::set<TargetType>& ExcludedTargets() const {
return excluded_targets_;
}
// Some passes are only available on qualified kernels and need to be
// explicitly declared.
......@@ -116,6 +126,7 @@ class Pass {
std::string name_;
std::string doc_;
std::set<TargetType> bound_targets_;
std::set<TargetType> excluded_targets_;
std::unordered_map<std::string, std::set<lite_api::Place>> bound_kernels_;
};
......
......@@ -47,10 +47,34 @@ bool KernelRegistered(const std::string name, const Place& place) {
return false;
}
bool PassMatchesTarget(const mir::Pass& pass, TargetType target) {
const auto& targets = pass.Targets();
if (targets.find(TARGET(kAny)) != targets.end()) return true;
return (targets.find(target) != targets.end());
bool PassMatchesTarget(const mir::Pass& pass,
const std::set<TargetType>& targets) {
// Whether the pass is suitable for targets ? The condition is the
// intersection of targets and pass's bound targets is not empty, besides the
// intersection of targets and pass's excluded targets is empty. The formula
// is as follows: matched = !empty(targets ^ pass.bound_targets) &&
// empty(targets ^ pass.excluded_targets), where ^ is intersection operation.
const auto& bound_targets = pass.BoundTargets();
bool matched = bound_targets.find(TARGET(kAny)) != bound_targets.end();
std::set<TargetType> inter_bound_targets;
std::set_intersection(
bound_targets.begin(),
bound_targets.end(),
targets.begin(),
targets.end(),
std::inserter(inter_bound_targets, inter_bound_targets.begin()));
matched |= !inter_bound_targets.empty();
const auto& excluded_targets = pass.ExcludedTargets();
matched &= excluded_targets.find(TARGET(kAny)) == excluded_targets.end();
std::set<TargetType> inter_excluded_targets;
std::set_intersection(
excluded_targets.begin(),
excluded_targets.end(),
targets.begin(),
targets.end(),
std::inserter(inter_excluded_targets, inter_excluded_targets.begin()));
matched &= inter_excluded_targets.empty();
return matched;
}
bool PassMatchesKernels(const mir::Pass& pass) {
......
......@@ -14,6 +14,7 @@
#pragma once
#include <set>
#include <string>
#include "lite/core/mir/pass.h"
......@@ -24,7 +25,8 @@ namespace lite {
bool KernelRegistered(const std::string name, const Place& place);
// Check if the pass hits the hardware target.
bool PassMatchesTarget(const mir::Pass& pass, TargetType target);
bool PassMatchesTarget(const mir::Pass& pass,
const std::set<TargetType>& targets);
// Check if the pass hits all necessary operators.
bool PassMatchesKernels(const mir::Pass& pass);
......
......@@ -128,10 +128,10 @@ std::string GenerateNPUProgramPass::BuildNPUGraph(
// persistable=true, Sothat the model parser can recognize it and save it to
// param files
if (!lite::npu::BuildModel(inputs, outputs, weight)) {
LOG(WARNING) << "[NPU] Build NPU graph failed (subgraph=" << sub_id << ")";
throw std::runtime_error("Build NPU graph failed.");
LOG(FATAL) << "[NPU] Build NPU graph failed (subgraph=" << sub_id << ")";
} else {
LOG(INFO) << "[NPU] Build NPU graph success (subgraph=" << sub_id << ")";
}
LOG(INFO) << "[NPU] Build NPU graph success (subgraph=" << sub_id << ")";
return weight_var_name;
}
......@@ -175,40 +175,19 @@ void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
supported_op_types.push_back(i.first);
}
try {
int num_subgraph = FuseSubgraph(graph, supported_op_types);
InferOnce(graph);
auto op_nodes_all = ClassifySubgraph(graph);
CHECK_EQ(op_nodes_all.size(), num_subgraph);
int id = 1;
for (auto& op_nodes : op_nodes_all) {
LOG(INFO) << "[NPU] Converting Subgraph " << id;
GenNPUSubgraph(graph, op_nodes.second, id);
LOG(INFO) << "[NPU] After NPU Pass Subgraph " << id << "\n"
<< Visualize(graph.get());
id++;
}
} catch (...) {
LOG(WARNING) << "[NPU] Build NPU graph failed.";
throw std::runtime_error("[NPU] Build NPU graph failed.");
}
for (auto& item : graph->StmtTopologicalOrder()) {
if (item->IsStmt()) {
auto& stmt = item->AsStmt();
LOG(INFO) << stmt;
insts_.emplace_back(stmt.op(), std::move(stmt.kernels().front()));
}
int num_subgraph = FuseSubgraph(graph, supported_op_types);
InferOnce(graph);
auto op_nodes_all = ClassifySubgraph(graph);
CHECK_EQ(op_nodes_all.size(), num_subgraph);
int id = 1;
for (auto& op_nodes : op_nodes_all) {
LOG(INFO) << "[NPU] Converting Subgraph " << id;
GenNPUSubgraph(graph, op_nodes.second, id);
LOG(INFO) << "[NPU] After NPU Pass Subgraph " << id << "\n"
<< Visualize(graph.get());
id++;
}
}
std::unique_ptr<RuntimeProgram> GenerateNPUProgramPass::GenProgram() {
LOG(INFO) << "[NPU] program insts.size " << insts_.size();
std::unique_ptr<RuntimeProgram> program(
new RuntimeProgram(std::move(insts_)));
return program;
}
} // namespace subgraph
} // namespace mir
} // namespace lite
......
......@@ -35,7 +35,6 @@ class GenerateNPUProgramPass : public SubgraphProgramPass {
using key2nodes_t = std::map<std::string, Node*>;
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
std::unique_ptr<RuntimeProgram> GenProgram();
protected:
// nodes2cvt: op nodes to convert
......@@ -54,9 +53,6 @@ class GenerateNPUProgramPass : public SubgraphProgramPass {
void GenNPUSubgraph(const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& op_nodes,
int sub_id);
private:
std::vector<Instruction> insts_;
};
} // namespace subgraph
......
......@@ -115,10 +115,10 @@ std::string GenerateXPUProgramPass::BuildXPUGraph(
graph_ctx.params,
&ordered_cvted_var_nodes,
weight)) {
LOG(WARNING) << "[XPU] Build XPU graph failed (subgraph=" << sub_id << ")";
throw std::runtime_error("[XPU] Build XPU graph failed.");
LOG(FATAL) << "[XPU] Build XPU graph failed (subgraph=" << sub_id << ")";
} else {
LOG(INFO) << "[XPU] Build XPU graph success (subgraph=" << sub_id << ")";
}
LOG(INFO) << "[XPU] Build XPU graph success (subgraph=" << sub_id << ")";
return weight_var_name;
}
......@@ -162,40 +162,19 @@ void GenerateXPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
supported_op_types.push_back(i.first);
}
try {
int num_subgraph = FuseSubgraph(graph, supported_op_types);
InferOnce(graph);
auto op_nodes_all = ClassifySubgraph(graph);
CHECK_EQ(op_nodes_all.size(), num_subgraph);
int id = 1;
for (auto& op_nodes : op_nodes_all) {
LOG(INFO) << "[XPU] Converting Subgraph " << id;
GenXPUSubgraph(graph, op_nodes.second, id);
LOG(INFO) << "[XPU] After XPU Pass Subgraph " << id << "\n"
<< Visualize(graph.get());
id++;
}
} catch (...) {
LOG(WARNING) << "[XPU] Build XPU graph failed.";
throw std::runtime_error("[XPU] Build XPU graph failed.");
}
for (auto& item : graph->StmtTopologicalOrder()) {
if (item->IsStmt()) {
auto& stmt = item->AsStmt();
LOG(INFO) << stmt;
insts_.emplace_back(stmt.op(), std::move(stmt.kernels().front()));
}
int num_subgraph = FuseSubgraph(graph, supported_op_types);
InferOnce(graph);
auto op_nodes_all = ClassifySubgraph(graph);
CHECK_EQ(op_nodes_all.size(), num_subgraph);
int id = 1;
for (auto& op_nodes : op_nodes_all) {
LOG(INFO) << "[XPU] Converting Subgraph " << id;
GenXPUSubgraph(graph, op_nodes.second, id);
LOG(INFO) << "[XPU] After XPU Pass Subgraph " << id << "\n"
<< Visualize(graph.get());
id++;
}
}
std::unique_ptr<RuntimeProgram> GenerateXPUProgramPass::GenProgram() {
LOG(INFO) << "[XPU] program insts.size=" << insts_.size();
std::unique_ptr<RuntimeProgram> program(
new RuntimeProgram(std::move(insts_)));
return program;
}
} // namespace subgraph
} // namespace mir
} // namespace lite
......
......@@ -35,7 +35,6 @@ class GenerateXPUProgramPass : public SubgraphProgramPass {
using key2nodes_t = std::map<std::string, Node*>;
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
std::unique_ptr<RuntimeProgram> GenProgram();
protected:
// nodes2cvt: op nodes to convert
......@@ -58,9 +57,6 @@ class GenerateXPUProgramPass : public SubgraphProgramPass {
void GenXPUSubgraph(const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& op_nodes,
int sub_id);
private:
std::vector<Instruction> insts_;
};
} // namespace subgraph
......
......@@ -15,6 +15,7 @@
#pragma once
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "lite/core/mir/generate_program_pass.h"
......@@ -50,21 +51,6 @@ class Optimizer {
valid_places_ = valid_places;
CHECK(!valid_places.empty()) << "At least one valid_place should be set";
CHECK(!graph_) << "duplicate optimize found";
auto valid_places_has_target = [&](TargetType t) -> bool {
for (auto& p : valid_places) {
if (p.target == t) {
return true;
}
}
return false;
};
std::map<std::string, bool> lite_with_targets{
{"kOpenCL", valid_places_has_target(TARGET(kOpenCL))},
{"kNPU", valid_places_has_target(TARGET(kNPU))},
{"kXPU", valid_places_has_target(TARGET(kXPU))}};
VLOG(4) << "lite_with_targets['kOpenCL']:" << lite_with_targets["kOpenCL"];
VLOG(4) << "lite_with_targets['kNPU']:" << lite_with_targets["kNPU"];
VLOG(4) << "lite_with_targets['kXPU']:" << lite_with_targets["kXPU"];
graph_.reset(new mir::SSAGraph);
graph_->Build(program, valid_places);
......@@ -122,13 +108,8 @@ class Optimizer {
"argument_type_display_pass",
"runtime_context_assign_pass",
"argument_type_display_pass"}};
if ((!lite_with_targets["kOpenCL"]) && (!lite_with_targets["kNPU"]) &&
(!lite_with_targets["kXPU"])) {
// TODO(ysh329): cause CL_INVALID_MEM_OBJECT when setArg in OpenCL
// kernel
passes_local.emplace_back("memory_optimize_pass");
}
"argument_type_display_pass",
"memory_optimize_pass"}};
RunPasses(passes_local);
} else {
RunPasses(passes);
......@@ -140,40 +121,13 @@ class Optimizer {
// Generate a new program based on the mir graph.
std::unique_ptr<RuntimeProgram> GenRuntimeProgram() {
#if defined(LITE_WITH_NPU) || defined(LITE_WITH_XPU)
auto target_place = Place{
#ifdef LITE_WITH_NPU
TARGET(kNPU),
#endif
#ifdef LITE_WITH_XPU
TARGET(kXPU),
#endif
PRECISION(kFloat)};
if (std::find(valid_places_.begin(), valid_places_.end(), target_place) !=
valid_places_.end()) {
#ifdef LITE_WITH_NPU
auto pass = mir::PassManager::Global()
.LookUp<mir::subgraph::GenerateNPUProgramPass>(
"generate_npu_program_pass");
#endif
// Extra passes are applied for NPU and XPU, they depends on the shapes
// of input tensors. so GenRuntimeProgram() must be called after the shapes
// of input tensors are determined.
std::vector<std::string> subgraph_passes{"generate_npu_program_pass",
"generate_xpu_program_pass"};
RunPasses(subgraph_passes);
#ifdef LITE_WITH_XPU
auto pass = mir::PassManager::Global()
.LookUp<mir::subgraph::GenerateXPUProgramPass>(
"generate_xpu_program_pass");
#endif
try {
pass->Apply(graph_);
auto program = pass->GenProgram();
CHECK(exec_scope_);
program->set_exec_scope(exec_scope_);
return program;
} catch (...) {
LOG(WARNING) << "Build " << TargetToStr(target_place.target)
<< " program failed!";
}
}
#endif
auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
"generate_program_pass");
pass->Apply(graph_);
......@@ -215,14 +169,16 @@ class Optimizer {
for (auto& x : passes) {
LOG(INFO) << "== Running pass: " << x;
mir::Pass* pass = mir::PassManager::Global().LookUp(x);
CHECK(pass) << "Can not find pass: " << x;
bool matched = false;
if (!pass) {
LOG(INFO) << " - Skip " << x << " because the pass isn't found.";
continue;
}
std::set<TargetType> targets;
for (const auto& place : valid_places_) {
if (PassMatchesTarget(*pass, place.target)) {
matched = true;
}
targets.insert(place.target);
}
matched = matched && PassMatchesKernels(*pass);
bool matched =
PassMatchesTarget(*pass, targets) && PassMatchesKernels(*pass);
if (!matched) {
LOG(INFO) << " - Skip " << x
<< " because the target or kernel does not match.";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册