未验证 提交 a663ba3e 编写于 作者: C Cwndmiao 提交者: GitHub

[LITE][XPU] 1. Add CloneFrom() method in SSAGraph for recovery from failed...

[LITE][XPU] 1. Add CloneFrom() method in SSAGraph for recovery from failed pattern match; 2. Fix BindKernel in xpu fusion passes; (#4156)

* [LITE][XPU] 1. Add CloneFrom() method in SSAGraph for recovery from failed pattern match; 2. Fix BindKernel in xpu fusion passes;

* test=develop, test=xpu
上级 4d68af14
...@@ -163,4 +163,4 @@ class XPUEmbeddingWithEltwiseAddFusePass : public ProgramPass { ...@@ -163,4 +163,4 @@ class XPUEmbeddingWithEltwiseAddFusePass : public ProgramPass {
REGISTER_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass, REGISTER_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass,
paddle::lite::mir::XPUEmbeddingWithEltwiseAddFusePass) paddle::lite::mir::XPUEmbeddingWithEltwiseAddFusePass)
.BindTargets({TARGET(kXPU)}) .BindTargets({TARGET(kXPU)})
.BindKernel("lookup_table"); .BindKernel("__xpu__embedding_with_eltwise_add");
...@@ -144,4 +144,4 @@ class XPUFcFusePass : public ProgramPass { ...@@ -144,4 +144,4 @@ class XPUFcFusePass : public ProgramPass {
REGISTER_MIR_PASS(__xpu__fc_fuse_pass, paddle::lite::mir::XPUFcFusePass) REGISTER_MIR_PASS(__xpu__fc_fuse_pass, paddle::lite::mir::XPUFcFusePass)
.BindTargets({TARGET(kXPU)}) .BindTargets({TARGET(kXPU)})
.BindKernel("fc"); .BindKernel("__xpu__fc");
...@@ -672,4 +672,4 @@ class XPUMultiEncoderFusePass : public ProgramPass { ...@@ -672,4 +672,4 @@ class XPUMultiEncoderFusePass : public ProgramPass {
REGISTER_MIR_PASS(__xpu__multi_encoder_fuse_pass, REGISTER_MIR_PASS(__xpu__multi_encoder_fuse_pass,
paddle::lite::mir::XPUMultiEncoderFusePass) paddle::lite::mir::XPUMultiEncoderFusePass)
.BindTargets({TARGET(kXPU)}) .BindTargets({TARGET(kXPU)})
.BindKernel("matmul"); .BindKernel("__xpu__multi_encoder");
...@@ -1368,14 +1368,24 @@ class XPUResNetCbamFusePass : public ProgramPass { ...@@ -1368,14 +1368,24 @@ class XPUResNetCbamFusePass : public ProgramPass {
public: public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override { void Apply(const std::unique_ptr<SSAGraph>& graph) override {
if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return; if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return;
bool changed = false;
SSAGraph backup;
backup.CloneFrom(*graph);
fusion::XPUResNetCbamBlock0Fuser block0_fuser; fusion::XPUResNetCbamBlock0Fuser block0_fuser;
block0_fuser(graph.get()); changed |= block0_fuser(graph.get());
fusion::XPUResNetCbamBlock1Fuser block1_fuser; fusion::XPUResNetCbamBlock1Fuser block1_fuser;
block1_fuser(graph.get()); changed |= block1_fuser(graph.get());
fusion::XPUResNetCbamBlock2Fuser block2_fuser; fusion::XPUResNetCbamBlock2Fuser block2_fuser;
block2_fuser(graph.get()); changed |= block2_fuser(graph.get());
fusion::XPUResNetCbamFuser resnet_fuser; fusion::XPUResNetCbamFuser resnet_fuser;
resnet_fuser(graph.get()); size_t n_matches = resnet_fuser(graph.get());
if (changed && !n_matches) {
// Restore the graph from the backup if no complete ResNetCbam graph was found
graph->CloneFrom(backup);
}
} }
}; };
......
...@@ -932,12 +932,22 @@ class XPUResNet50FusePass : public ProgramPass { ...@@ -932,12 +932,22 @@ class XPUResNet50FusePass : public ProgramPass {
public: public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override { void Apply(const std::unique_ptr<SSAGraph>& graph) override {
if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return; if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return;
bool changed = false;
SSAGraph backup;
backup.CloneFrom(*graph);
fusion::XPUResNetBlock0Fuser block0_fuser; fusion::XPUResNetBlock0Fuser block0_fuser;
block0_fuser(graph.get()); changed |= block0_fuser(graph.get());
fusion::XPUResNetBlock1Fuser block1_fuser; fusion::XPUResNetBlock1Fuser block1_fuser;
block1_fuser(graph.get()); changed |= block1_fuser(graph.get());
fusion::XPUResNet50Fuser resnet50_fuser; fusion::XPUResNet50Fuser resnet50_fuser;
resnet50_fuser(graph.get()); size_t n_matches = resnet50_fuser(graph.get());
if (changed && !n_matches) {
// Restore the graph from the backup if no complete ResNet50 graph was found
graph->CloneFrom(backup);
}
} }
}; };
...@@ -948,4 +958,4 @@ class XPUResNet50FusePass : public ProgramPass { ...@@ -948,4 +958,4 @@ class XPUResNet50FusePass : public ProgramPass {
REGISTER_MIR_PASS(__xpu__resnet_fuse_pass, REGISTER_MIR_PASS(__xpu__resnet_fuse_pass,
paddle::lite::mir::XPUResNet50FusePass) paddle::lite::mir::XPUResNet50FusePass)
.BindTargets({TARGET(kXPU)}) .BindTargets({TARGET(kXPU)})
.BindKernel("conv2d"); .BindKernel("__xpu__resnet50");
...@@ -32,7 +32,8 @@ class FuseBase { ...@@ -32,7 +32,8 @@ class FuseBase {
virtual ~FuseBase() = default; virtual ~FuseBase() = default;
void operator()(SSAGraph* graph) { // Returns number of matched subgraphs
size_t operator()(SSAGraph* graph) {
BuildPattern(); BuildPattern();
PerformPatternMatcher(graph); PerformPatternMatcher(graph);
...@@ -41,6 +42,7 @@ class FuseBase { ...@@ -41,6 +42,7 @@ class FuseBase {
} }
DeleteInterNodes(graph); DeleteInterNodes(graph);
return key2nodes_.size();
} }
// Build a PMPattern using PMNode. // Build a PMPattern using PMNode.
......
...@@ -226,6 +226,42 @@ void SSAGraph::RemoveNode(const mir::Node *node) { ...@@ -226,6 +226,42 @@ void SSAGraph::RemoveNode(const mir::Node *node) {
node_storage_.erase(pos); node_storage_.erase(pos);
} }
void SSAGraph::CloneFrom(const SSAGraph &from) {
node_storage_.clear();
arguments_.clear();
valid_places_ = from.valid_places_;
std::map<const mir::Node *, mir::Node *> clone_node_map;
for (const auto &node : from.node_storage_) {
if (node.IsArg()) {
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
new_node.AsArg() = *node.arg();
clone_node_map.emplace(&node, &new_node);
} else {
const auto *inst = node.stmt();
auto *new_node = GraphCreateInstructNode(inst->op(), valid_places_);
clone_node_map.emplace(&node, new_node);
}
}
// Rebuild node inlinks/outlinks
for (const auto &node : from.node_storage_) {
CHECK(clone_node_map.count(&node));
auto *new_node = clone_node_map.at(&node);
for (const auto *inlink : node.inlinks) {
CHECK(clone_node_map.count(inlink));
new_node->inlinks.emplace_back(clone_node_map.at(inlink));
}
for (const auto *outlink : node.outlinks) {
CHECK(clone_node_map.count(outlink));
new_node->outlinks.emplace_back(clone_node_map.at(outlink));
}
}
CheckValid();
}
mir::Node *SSAGraph::Argument(const std::string &name) { mir::Node *SSAGraph::Argument(const std::string &name) {
auto it = arguments_.find(name); auto it = arguments_.find(name);
CHECK(it != arguments_.end()) << "no argument called " << name; CHECK(it != arguments_.end()) << "no argument called " << name;
......
...@@ -44,6 +44,9 @@ class SSAGraph : GraphBase { ...@@ -44,6 +44,9 @@ class SSAGraph : GraphBase {
int block_idx = kRootBlockIdx); int block_idx = kRootBlockIdx);
void RemoveNode(const mir::Node *node); void RemoveNode(const mir::Node *node);
// Replaces this graph's contents with a copy of `from`: the current nodes and
// arguments are discarded, every mir::Node of `from` is duplicated, and the
// copies' inlinks/outlinks are rewired to point at the duplicated nodes.
void CloneFrom(const SSAGraph &from);
std::vector<mir::Node *> StmtTopologicalOrder(); std::vector<mir::Node *> StmtTopologicalOrder();
std::vector<mir::Node *> NodeTopologicalOrder(); std::vector<mir::Node *> NodeTopologicalOrder();
......
...@@ -31,7 +31,8 @@ class XPUFuseBase { ...@@ -31,7 +31,8 @@ class XPUFuseBase {
virtual ~XPUFuseBase() = default; virtual ~XPUFuseBase() = default;
void operator()(SSAGraph* graph) { // Returns number of matched subgraphs
size_t operator()(SSAGraph* graph) {
BuildPattern(); BuildPattern();
PerformPatternMatcher(graph); PerformPatternMatcher(graph);
...@@ -40,6 +41,7 @@ class XPUFuseBase { ...@@ -40,6 +41,7 @@ class XPUFuseBase {
} }
DeleteInterNodes(graph); DeleteInterNodes(graph);
return key2nodes_.size();
} }
// Build a PMPattern using PMNode. // Build a PMPattern using PMNode.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册