未验证 提交 a663ba3e 编写于 作者: C Cwndmiao 提交者: GitHub

[LITE][XPU] 1. Add CloneFrom() method in SSAGraph for recovery from failed...

[LITE][XPU] 1. Add CloneFrom() method in SSAGraph for recovery from failed pattern match; 2. Fix BindKernel in xpu fusion passes; (#4156)

* [LITE][XPU] 1. Add CloneFrom() method in SSAGraph for recovery from failed pattern match; 2. Fix BindKernel in xpu fusion passes;

* test=develop, test=xpu
上级 4d68af14
......@@ -163,4 +163,4 @@ class XPUEmbeddingWithEltwiseAddFusePass : public ProgramPass {
// Pass registration for XPUEmbeddingWithEltwiseAddFusePass, bound to the
// kXPU target.
// NOTE(review): two BindKernel lines appear back to back, which is not valid
// as a single statement. This looks like a removed diff line ("lookup_table")
// followed by its replacement ("__xpu__embedding_with_eltwise_add") with the
// +/- markers lost - confirm against the repository.
REGISTER_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass,
paddle::lite::mir::XPUEmbeddingWithEltwiseAddFusePass)
.BindTargets({TARGET(kXPU)})
.BindKernel("lookup_table");
.BindKernel("__xpu__embedding_with_eltwise_add");
......@@ -144,4 +144,4 @@ class XPUFcFusePass : public ProgramPass {
// Pass registration for XPUFcFusePass, bound to the kXPU target.
// NOTE(review): two BindKernel lines appear back to back, which is not valid
// as a single statement. This looks like a removed diff line ("fc") followed
// by its replacement ("__xpu__fc") with the +/- markers lost - confirm
// against the repository.
REGISTER_MIR_PASS(__xpu__fc_fuse_pass, paddle::lite::mir::XPUFcFusePass)
.BindTargets({TARGET(kXPU)})
.BindKernel("fc");
.BindKernel("__xpu__fc");
......@@ -672,4 +672,4 @@ class XPUMultiEncoderFusePass : public ProgramPass {
// Pass registration for XPUMultiEncoderFusePass, bound to the kXPU target.
// NOTE(review): two BindKernel lines appear back to back, which is not valid
// as a single statement. This looks like a removed diff line ("matmul")
// followed by its replacement ("__xpu__multi_encoder") with the +/- markers
// lost - confirm against the repository.
REGISTER_MIR_PASS(__xpu__multi_encoder_fuse_pass,
paddle::lite::mir::XPUMultiEncoderFusePass)
.BindTargets({TARGET(kXPU)})
.BindKernel("matmul");
.BindKernel("__xpu__multi_encoder");
......@@ -1368,14 +1368,24 @@ class XPUResNetCbamFusePass : public ProgramPass {
public:
// Runs the ResNet-CBAM fusers on `graph`: the Block0, Block1 and Block2
// fusers first, then XPUResNetCbamFuser.
// Returns immediately when GetBoolFromEnv("XPU_ENABLE_XTCL") is true.
// A clone of the graph is taken before any fuser runs. If a block fuser
// reported at least one match (`changed`) but the final fuser reported none
// (`n_matches == 0`), the graph is restored from that clone.
//
// NOTE(review): every fuser invocation below appears twice - once as a bare
// call and once with its result captured. This looks like removed and added
// diff lines rendered without +/- markers rather than code meant to run each
// fuser twice - confirm against the repository.
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return;
// Set to true once any block fuser reports a non-zero match count.
bool changed = false;
// Snapshot of the untouched graph, used for the rollback at the end.
SSAGraph backup;
backup.CloneFrom(*graph);
fusion::XPUResNetCbamBlock0Fuser block0_fuser;
block0_fuser(graph.get());
changed |= block0_fuser(graph.get());
fusion::XPUResNetCbamBlock1Fuser block1_fuser;
block1_fuser(graph.get());
changed |= block1_fuser(graph.get());
fusion::XPUResNetCbamBlock2Fuser block2_fuser;
block2_fuser(graph.get());
changed |= block2_fuser(graph.get());
fusion::XPUResNetCbamFuser resnet_fuser;
resnet_fuser(graph.get());
size_t n_matches = resnet_fuser(graph.get());
if (changed && !n_matches) {
// Restore the graph from the backup if no whole ResNetCbam graph was found
graph->CloneFrom(backup);
}
}
};
......
......@@ -932,12 +932,22 @@ class XPUResNet50FusePass : public ProgramPass {
public:
// Runs the ResNet50 fusers on `graph`: the Block0 and Block1 fusers first,
// then XPUResNet50Fuser.
// Returns immediately when GetBoolFromEnv("XPU_ENABLE_XTCL") is true.
// A clone of the graph is taken before any fuser runs. If a block fuser
// reported at least one match (`changed`) but the final fuser reported none
// (`n_matches == 0`), the graph is restored from that clone.
//
// NOTE(review): every fuser invocation below appears twice - once as a bare
// call and once with its result captured. This looks like removed and added
// diff lines rendered without +/- markers rather than code meant to run each
// fuser twice - confirm against the repository.
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return;
// Set to true once any block fuser reports a non-zero match count.
bool changed = false;
// Snapshot of the untouched graph, used for the rollback at the end.
SSAGraph backup;
backup.CloneFrom(*graph);
fusion::XPUResNetBlock0Fuser block0_fuser;
block0_fuser(graph.get());
changed |= block0_fuser(graph.get());
fusion::XPUResNetBlock1Fuser block1_fuser;
block1_fuser(graph.get());
changed |= block1_fuser(graph.get());
fusion::XPUResNet50Fuser resnet50_fuser;
resnet50_fuser(graph.get());
size_t n_matches = resnet50_fuser(graph.get());
if (changed && !n_matches) {
// Restore the graph from the backup if no whole ResNet50 graph was found
graph->CloneFrom(backup);
}
}
};
......@@ -948,4 +958,4 @@ class XPUResNet50FusePass : public ProgramPass {
// Pass registration for XPUResNet50FusePass, bound to the kXPU target.
// NOTE(review): two BindKernel lines appear back to back, which is not valid
// as a single statement. This looks like a removed diff line ("conv2d")
// followed by its replacement ("__xpu__resnet50") with the +/- markers lost -
// confirm against the repository.
REGISTER_MIR_PASS(__xpu__resnet_fuse_pass,
paddle::lite::mir::XPUResNet50FusePass)
.BindTargets({TARGET(kXPU)})
.BindKernel("conv2d");
.BindKernel("__xpu__resnet50");
......@@ -32,7 +32,8 @@ class FuseBase {
virtual ~FuseBase() = default;
void operator()(SSAGraph* graph) {
// Returns number of matched subgraphs
size_t operator()(SSAGraph* graph) {
BuildPattern();
PerformPatternMatcher(graph);
......@@ -41,6 +42,7 @@ class FuseBase {
}
DeleteInterNodes(graph);
return key2nodes_.size();
}
// Build a PMPattern using PMNode.
......
......@@ -226,6 +226,42 @@ void SSAGraph::RemoveNode(const mir::Node *node) {
node_storage_.erase(pos);
}
void SSAGraph::CloneFrom(const SSAGraph &from) {
node_storage_.clear();
arguments_.clear();
valid_places_ = from.valid_places_;
std::map<const mir::Node *, mir::Node *> clone_node_map;
for (const auto &node : from.node_storage_) {
if (node.IsArg()) {
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
new_node.AsArg() = *node.arg();
clone_node_map.emplace(&node, &new_node);
} else {
const auto *inst = node.stmt();
auto *new_node = GraphCreateInstructNode(inst->op(), valid_places_);
clone_node_map.emplace(&node, new_node);
}
}
// Rebuild node inlinks/outlinks
for (const auto &node : from.node_storage_) {
CHECK(clone_node_map.count(&node));
auto *new_node = clone_node_map.at(&node);
for (const auto *inlink : node.inlinks) {
CHECK(clone_node_map.count(inlink));
new_node->inlinks.emplace_back(clone_node_map.at(inlink));
}
for (const auto *outlink : node.outlinks) {
CHECK(clone_node_map.count(outlink));
new_node->outlinks.emplace_back(clone_node_map.at(outlink));
}
}
CheckValid();
}
mir::Node *SSAGraph::Argument(const std::string &name) {
auto it = arguments_.find(name);
CHECK(it != arguments_.end()) << "no argument called " << name;
......
......@@ -44,6 +44,9 @@ class SSAGraph : GraphBase {
int block_idx = kRootBlockIdx);
void RemoveNode(const mir::Node *node);
// Clone from another SSAGraph: this graph's existing nodes are discarded and
// all mir::Node(s) of `from` are duplicated into it, with the inlinks and
// outlinks of the duplicates pointing at the duplicated nodes.
void CloneFrom(const SSAGraph &from);
std::vector<mir::Node *> StmtTopologicalOrder();
std::vector<mir::Node *> NodeTopologicalOrder();
......
......@@ -31,7 +31,8 @@ class XPUFuseBase {
virtual ~XPUFuseBase() = default;
void operator()(SSAGraph* graph) {
// Returns number of matched subgraphs
size_t operator()(SSAGraph* graph) {
BuildPattern();
PerformPatternMatcher(graph);
......@@ -40,6 +41,7 @@ class XPUFuseBase {
}
DeleteInterNodes(graph);
return key2nodes_.size();
}
// Build a PMPattern using PMNode.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册