未验证 提交 be1b3fc3 编写于 作者: S sprouteer 提交者: GitHub

[XPU][BUG] Fix link_xpu_op_max_pass bug (#53258)

上级 503f422e
...@@ -110,7 +110,8 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern, ...@@ -110,7 +110,8 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern,
->assert_is_op_input(conv_type_, "Filter") ->assert_is_op_input(conv_type_, "Filter")
->AsInput(); ->AsInput();
auto conv_out = pattern->NewNode(conv_out_repr()) auto conv_out = pattern->NewNode(conv_out_repr())
->assert_is_op_output(conv_type_, "Output"); ->assert_is_op_output(conv_type_, "Output")
->assert_has_n_outputs(1);
conv->LinksFrom({input, conv_filter}).LinksTo({conv_out}); conv->LinksFrom({input, conv_filter}).LinksTo({conv_out});
// ew_bias_add op // ew_bias_add op
PDNode* ew_bias_add = nullptr; PDNode* ew_bias_add = nullptr;
...@@ -190,12 +191,12 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern, ...@@ -190,12 +191,12 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern,
// ew_branch_add op // ew_branch_add op
if (with_branch_) { if (with_branch_) {
if (with_branch_x_) { if (with_branch_x_) {
bn_out->assert_is_op_input("elementwise_add", "Y")->AsIntermediate(); bn_out->assert_is_op_input("elementwise_add", "Y");
ew_branch_add_in = pattern->NewNode(ew_branch_add_in_repr()) ew_branch_add_in = pattern->NewNode(ew_branch_add_in_repr())
->assert_is_op_input("elementwise_add", "X") ->assert_is_op_input("elementwise_add", "X")
->AsInput(); ->AsInput();
} else if (with_branch_y_) { } else if (with_branch_y_) {
bn_out->assert_is_op_input("elementwise_add", "X")->AsIntermediate(); bn_out->assert_is_op_input("elementwise_add", "X");
ew_branch_add_in = pattern->NewNode(ew_branch_add_in_repr()) ew_branch_add_in = pattern->NewNode(ew_branch_add_in_repr())
->assert_is_op_input("elementwise_add", "Y") ->assert_is_op_input("elementwise_add", "Y")
->AsInput(); ->AsInput();
...@@ -221,13 +222,15 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern, ...@@ -221,13 +222,15 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern,
} }
// act op // act op
if (!act_type_.empty()) { if (!act_type_.empty()) {
ew_branch_add_out->assert_is_op_input(act_type_, "X")->AsIntermediate(); ew_branch_add_out->assert_is_op_input(act_type_, "X");
act = pattern->NewNode(act_repr())->assert_is_op(act_type_); act = pattern->NewNode(act_repr())->assert_is_op(act_type_);
act_out = pattern->NewNode(act_out_repr()) act_out =
->assert_is_op_output(act_type_, "Out") pattern->NewNode(act_out_repr())->assert_is_op_output(act_type_, "Out");
->assert_var_not_persistable();
act->LinksFrom({ew_branch_add_out}).LinksTo({act_out}); act->LinksFrom({ew_branch_add_out}).LinksTo({act_out});
} else {
act_out = ew_branch_add_out;
} }
act_out->AsOutput();
} }
} // namespace patterns } // namespace patterns
......
...@@ -41,30 +41,39 @@ namespace patterns { ...@@ -41,30 +41,39 @@ namespace patterns {
struct FusionXPUOpPattern : public PatternBase { struct FusionXPUOpPattern : public PatternBase {
FusionXPUOpPattern(PDPattern* pattern, FusionXPUOpPattern(PDPattern* pattern,
const std::string& name_scope, const std::string& name_scope,
const std::string& op_type); const std::string& op_type,
bool with_branch);
// declare operator node's name // declare operator node's name
PATTERN_DECL_NODE(fusion_op); PATTERN_DECL_NODE(fusion_op);
// declare variable node's name // declare variable node's name
PATTERN_DECL_NODE(out); PATTERN_DECL_NODE(input);
PATTERN_DECL_NODE(out_max); PATTERN_DECL_NODE(branch);
private: private:
std::string op_type_; std::string op_type_;
bool with_branch_{false};
}; };
FusionXPUOpPattern::FusionXPUOpPattern(PDPattern* pattern, FusionXPUOpPattern::FusionXPUOpPattern(PDPattern* pattern,
const std::string& name_scope, const std::string& name_scope,
const std::string& op_type) const std::string& op_type,
: PatternBase(pattern, name_scope, name_scope), op_type_(op_type) { bool with_branch)
: PatternBase(pattern, name_scope, name_scope),
op_type_(op_type),
with_branch_(with_branch) {
auto* fusion_op = pattern->NewNode(fusion_op_repr())->assert_is_op(op_type_); auto* fusion_op = pattern->NewNode(fusion_op_repr())->assert_is_op(op_type_);
auto* out = pattern->NewNode(out_repr()) auto* input =
->assert_is_op_output(op_type_, "out") pattern->NewNode(input_repr())->assert_is_op_input(op_type_, "x");
->assert_var_not_persistable();
auto* out_max = pattern->NewNode(out_max_repr()) PDNode* branch = nullptr;
->assert_is_op_output(op_type_, "out_max") if (with_branch_) {
->assert_var_not_persistable(); branch =
fusion_op->LinksTo({out, out_max}); pattern->NewNode(branch_repr())->assert_is_op_input(op_type_, "branch");
fusion_op->LinksFrom({input, branch});
} else {
fusion_op->LinksFrom({input});
}
} }
} // namespace patterns } // namespace patterns
...@@ -74,7 +83,9 @@ class LinkXPUOpMaxPass : public FusePassBase { ...@@ -74,7 +83,9 @@ class LinkXPUOpMaxPass : public FusePassBase {
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
private: private:
void ApplyImpl(ir::Graph* graph, const std::string& op_type) const; void ApplyImpl(ir::Graph* graph,
const std::string& op_type,
bool with_branch) const;
const std::string name_scope_{"multi_encoder_xpu_slice_fuse_pass"}; const std::string name_scope_{"multi_encoder_xpu_slice_fuse_pass"};
// ops with x_max/out_max // ops with x_max/out_max
...@@ -89,8 +100,7 @@ Origin subgraph: ...@@ -89,8 +100,7 @@ Origin subgraph:
out0 out0_max out0 out0_max
| |
\ \
fusion_xpu_op1 fusion_op
Fused subgraph: Fused subgraph:
fusion_xpu_op0 fusion_xpu_op0
/ \ / \
...@@ -98,36 +108,77 @@ Fused subgraph: ...@@ -98,36 +108,77 @@ Fused subgraph:
out0 out0_max out0 out0_max
| | | |
\ / \ /
fusion_xpu_op1 fusion_op
Origin subgraph1:
fusion_xpu_op0 fusion_xpu_op1
/ \ / \
| | | |
out0 out0_max out1 out1_max
| |
(x) \ / (branch)
fusion_xpu_op2
Fused subgraph1:
fusion_xpu_op0 fusion_xpu_op1
/ \ / \
| | | |
out0 out0_max out1 out1_max
| | | |
(x) \ |(x_max) |(branch) /(branch_max)
\ | | /
\ | | /
\ | | /
fusion_xpu_op2
*/ */
void LinkXPUOpMaxPass::ApplyImpl(ir::Graph* graph) const { void LinkXPUOpMaxPass::ApplyImpl(ir::Graph* graph) const {
Init(name_scope_, graph); Init(name_scope_, graph);
for (auto op_type : op_types_) { for (auto op_type : op_types_) {
ApplyImpl(graph, op_type); for (auto with_branch : {true, false}) {
ApplyImpl(graph, op_type, with_branch);
}
} }
} }
void LinkXPUOpMaxPass::ApplyImpl(ir::Graph* graph, void LinkXPUOpMaxPass::ApplyImpl(ir::Graph* graph,
const std::string& op_type) const { const std::string& op_type,
bool with_branch) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null.")); graph, platform::errors::PreconditionNotMet("graph should not be null."));
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::FusionXPUOpPattern pattern( patterns::FusionXPUOpPattern pattern(
gpd.mutable_pattern(), name_scope_, op_type); gpd.mutable_pattern(), name_scope_, op_type, with_branch);
int found_subgraph_count = 0; int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) { Graph* graph) {
VLOG(4) << "handle LinkXPUOpMaxPass fuse"; VLOG(4) << "handle LinkXPUOpMaxPass fuse";
GET_IR_NODE(fusion_op); GET_IR_NODE(fusion_op);
GET_IR_NODE(out); GET_IR_NODE(input);
GET_IR_NODE(out_max); GET_IR_NODE(branch);
for (auto next_op : out->outputs) {
auto* next_op_desc = next_op->Op(); auto* fusion_op_desc = fusion_op->Op();
if (op_types_.count(next_op_desc->Type()) == 0) continue; if (input->inputs[0]->Op()->HasOutput("out_max")) {
next_op_desc->SetInput("x_max", {out_max->Name()}); auto input_max_name = input->inputs[0]->Op()->Output("out_max");
IR_NODE_LINK_TO(out_max, next_op); for (auto max_node : input->inputs[0]->outputs) {
found_subgraph_count++; if (input_max_name[0] == max_node->Name()) {
fusion_op_desc->SetInput("x_max", {max_node->Name()});
IR_NODE_LINK_TO(max_node, fusion_op);
found_subgraph_count++;
}
}
}
if (with_branch) {
if (branch->inputs[0]->Op()->HasOutput("out_max")) {
auto branch_max_name = branch->inputs[0]->Op()->Output("out_max");
for (auto max_node : branch->inputs[0]->outputs) {
if (branch_max_name[0] == max_node->Name()) {
fusion_op_desc->SetInput("branch_max", {max_node->Name()});
IR_NODE_LINK_TO(max_node, fusion_op);
found_subgraph_count++;
}
}
}
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册