Unverified commit be1b3fc3 authored by sprouteer, committed by GitHub

[XPU][BUG] Fix link_xpu_op_max_pass bug (#53258)
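
This pass used to push a fusion op's `out_max` output forward: the handler walked `out->outputs` and set `x_max` on every consumer. The rewritten handler instead matches each fusion op together with its `x` input (and, in a new `with_branch` variant of `FusionXPUOpPattern`, a `branch` input), checks whether the producer of that input exposes an `out_max` output, and links that node in as `x_max` or `branch_max`. The Conv2dXPUPattern asserts are adjusted to match: `conv_out` must have exactly one consumer, and the `bn_out`/`ew_branch_add_out` nodes are no longer marked as intermediate.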

Parent 503f422e
@@ -110,7 +110,8 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern,
->assert_is_op_input(conv_type_, "Filter")
->AsInput();
auto conv_out = pattern->NewNode(conv_out_repr())
->assert_is_op_output(conv_type_, "Output");
->assert_is_op_output(conv_type_, "Output")
->assert_has_n_outputs(1);
conv->LinksFrom({input, conv_filter}).LinksTo({conv_out});
// ew_bias_add op
PDNode* ew_bias_add = nullptr;
@@ -190,12 +191,12 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern,
// ew_branch_add op
if (with_branch_) {
if (with_branch_x_) {
bn_out->assert_is_op_input("elementwise_add", "Y")->AsIntermediate();
bn_out->assert_is_op_input("elementwise_add", "Y");
ew_branch_add_in = pattern->NewNode(ew_branch_add_in_repr())
->assert_is_op_input("elementwise_add", "X")
->AsInput();
} else if (with_branch_y_) {
bn_out->assert_is_op_input("elementwise_add", "X")->AsIntermediate();
bn_out->assert_is_op_input("elementwise_add", "X");
ew_branch_add_in = pattern->NewNode(ew_branch_add_in_repr())
->assert_is_op_input("elementwise_add", "Y")
->AsInput();
@@ -221,13 +222,15 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern,
}
// act op
if (!act_type_.empty()) {
ew_branch_add_out->assert_is_op_input(act_type_, "X")->AsIntermediate();
ew_branch_add_out->assert_is_op_input(act_type_, "X");
act = pattern->NewNode(act_repr())->assert_is_op(act_type_);
act_out = pattern->NewNode(act_out_repr())
->assert_is_op_output(act_type_, "Out")
->assert_var_not_persistable();
act_out =
pattern->NewNode(act_out_repr())->assert_is_op_output(act_type_, "Out");
act->LinksFrom({ew_branch_add_out}).LinksTo({act_out});
} else {
act_out = ew_branch_add_out;
}
act_out->AsOutput();
}
} // namespace patterns
@@ -41,30 +41,39 @@ namespace patterns {
struct FusionXPUOpPattern : public PatternBase {
FusionXPUOpPattern(PDPattern* pattern,
const std::string& name_scope,
const std::string& op_type);
const std::string& op_type,
bool with_branch);
// declare operator node's name
PATTERN_DECL_NODE(fusion_op);
// declare variable node's name
PATTERN_DECL_NODE(out);
PATTERN_DECL_NODE(out_max);
PATTERN_DECL_NODE(input);
PATTERN_DECL_NODE(branch);
private:
std::string op_type_;
bool with_branch_{false};
};
FusionXPUOpPattern::FusionXPUOpPattern(PDPattern* pattern,
const std::string& name_scope,
const std::string& op_type)
: PatternBase(pattern, name_scope, name_scope), op_type_(op_type) {
const std::string& op_type,
bool with_branch)
: PatternBase(pattern, name_scope, name_scope),
op_type_(op_type),
with_branch_(with_branch) {
auto* fusion_op = pattern->NewNode(fusion_op_repr())->assert_is_op(op_type_);
auto* out = pattern->NewNode(out_repr())
->assert_is_op_output(op_type_, "out")
->assert_var_not_persistable();
auto* out_max = pattern->NewNode(out_max_repr())
->assert_is_op_output(op_type_, "out_max")
->assert_var_not_persistable();
fusion_op->LinksTo({out, out_max});
auto* input =
pattern->NewNode(input_repr())->assert_is_op_input(op_type_, "x");
PDNode* branch = nullptr;
if (with_branch_) {
branch =
pattern->NewNode(branch_repr())->assert_is_op_input(op_type_, "branch");
fusion_op->LinksFrom({input, branch});
} else {
fusion_op->LinksFrom({input});
}
}
} // namespace patterns
@@ -74,7 +83,9 @@ class LinkXPUOpMaxPass : public FusePassBase {
void ApplyImpl(ir::Graph* graph) const override;
private:
void ApplyImpl(ir::Graph* graph, const std::string& op_type) const;
void ApplyImpl(ir::Graph* graph,
const std::string& op_type,
bool with_branch) const;
const std::string name_scope_{"multi_encoder_xpu_slice_fuse_pass"};
// ops with x_max/out_max
@@ -89,8 +100,7 @@ Origin subgraph:
out0 out0_max
|
\
fusion_xpu_op1
fusion_op
Fused subgraph:
fusion_xpu_op0
/ \
@@ -98,36 +108,77 @@ Fused subgraph:
out0 out0_max
| |
\ /
fusion_xpu_op1
fusion_op
Origin subgraph1:
fusion_xpu_op0 fusion_xpu_op1
/ \ / \
| | | |
out0 out0_max out1 out1_max
| |
(x) \ / (branch)
fusion_xpu_op2
Fused subgraph1:
fusion_xpu_op0 fusion_xpu_op1
/ \ / \
| | | |
out0 out0_max out1 out1_max
| | | |
(x) \ |(x_max) |(branch) /(branch_max)
\ | | /
\ | | /
\ | | /
fusion_xpu_op2
*/
void LinkXPUOpMaxPass::ApplyImpl(ir::Graph* graph) const {
Init(name_scope_, graph);
for (auto op_type : op_types_) {
ApplyImpl(graph, op_type);
for (auto with_branch : {true, false}) {
ApplyImpl(graph, op_type, with_branch);
}
}
}
void LinkXPUOpMaxPass::ApplyImpl(ir::Graph* graph,
const std::string& op_type) const {
const std::string& op_type,
bool with_branch) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
GraphPatternDetector gpd;
patterns::FusionXPUOpPattern pattern(
gpd.mutable_pattern(), name_scope_, op_type);
gpd.mutable_pattern(), name_scope_, op_type, with_branch);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle LinkXPUOpMaxPass fuse";
GET_IR_NODE(fusion_op);
GET_IR_NODE(out);
GET_IR_NODE(out_max);
for (auto next_op : out->outputs) {
auto* next_op_desc = next_op->Op();
if (op_types_.count(next_op_desc->Type()) == 0) continue;
next_op_desc->SetInput("x_max", {out_max->Name()});
IR_NODE_LINK_TO(out_max, next_op);
found_subgraph_count++;
GET_IR_NODE(input);
GET_IR_NODE(branch);
auto* fusion_op_desc = fusion_op->Op();
if (input->inputs[0]->Op()->HasOutput("out_max")) {
auto input_max_name = input->inputs[0]->Op()->Output("out_max");
for (auto max_node : input->inputs[0]->outputs) {
if (input_max_name[0] == max_node->Name()) {
fusion_op_desc->SetInput("x_max", {max_node->Name()});
IR_NODE_LINK_TO(max_node, fusion_op);
found_subgraph_count++;
}
}
}
if (with_branch) {
if (branch->inputs[0]->Op()->HasOutput("out_max")) {
auto branch_max_name = branch->inputs[0]->Op()->Output("out_max");
for (auto max_node : branch->inputs[0]->outputs) {
if (branch_max_name[0] == max_node->Name()) {
fusion_op_desc->SetInput("branch_max", {max_node->Name()});
IR_NODE_LINK_TO(max_node, fusion_op);
found_subgraph_count++;
}
}
}
}
};
......
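
To illustrate the linking step outside Paddle's pattern-matching machinery, here is a minimal standalone sketch in plain C++. The `ToyOp`/`ToyVar` structs and the `LinkMaxInput` helper are hypothetical stand-ins invented for this example (they are not Paddle APIs); they only model the idea above: walk from a fusion op's `x` or `branch` input back to its producer and, if the producer exposes an `out_max` output, attach it as `x_max` / `branch_max`.

```cpp
#include <iostream>
#include <map>
#include <string>

// Hypothetical toy IR (not Paddle types): an op has named input/output slots
// holding variable names, and a variable remembers which op produced it.
struct ToyOp {
  std::string type;
  std::map<std::string, std::string> inputs;   // slot name -> variable name
  std::map<std::string, std::string> outputs;  // slot name -> variable name
};

struct ToyVar {
  std::string name;
  ToyOp* producer;  // op whose output slot holds this variable
};

// Models the reworked handler: for the given input slot ("x" or "branch"),
// find the producer of the connected variable; if that producer exposes an
// "out_max" output, wire it into the fusion op as "<slot>_max".
void LinkMaxInput(ToyOp* fusion_op,
                  const std::string& slot,
                  const std::map<std::string, ToyVar>& vars) {
  auto in_it = fusion_op->inputs.find(slot);
  if (in_it == fusion_op->inputs.end()) return;  // op has no such input
  auto var_it = vars.find(in_it->second);
  if (var_it == vars.end() || var_it->second.producer == nullptr) return;
  const ToyOp* producer = var_it->second.producer;
  auto max_it = producer->outputs.find("out_max");
  if (max_it == producer->outputs.end()) return;      // producer has no max
  fusion_op->inputs[slot + "_max"] = max_it->second;  // link x_max/branch_max
}

int main() {
  // conv0 and conv1 each produce an out and an out_max; conv2 consumes
  // conv0's out as "x" and conv1's out as "branch" (cf. "Origin subgraph1").
  ToyOp conv0{"conv2d_xpu", {}, {{"out", "out0"}, {"out_max", "out0_max"}}};
  ToyOp conv1{"conv2d_xpu", {}, {{"out", "out1"}, {"out_max", "out1_max"}}};
  ToyOp conv2{"conv2d_xpu",
              {{"x", "out0"}, {"branch", "out1"}},
              {{"out", "out2"}, {"out_max", "out2_max"}}};

  std::map<std::string, ToyVar> vars = {{"out0", {"out0", &conv0}},
                                        {"out1", {"out1", &conv1}}};

  LinkMaxInput(&conv2, "x", vars);       // adds x_max -> out0_max
  LinkMaxInput(&conv2, "branch", vars);  // adds branch_max -> out1_max

  for (const auto& kv : conv2.inputs) {
    std::cout << kv.first << " -> " << kv.second << "\n";
  }
  return 0;
}
```

In the real pass, the producer's `out_max` variable node is located by name among the producer's output nodes and connected with `IR_NODE_LINK_TO`; the map lookup above merely stands in for that search.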