Unverified commit dc4b48f6, authored by W wz1qqx, committed by GitHub

eliminate small pattern (#55843)

Parent c4694c15
......@@ -240,8 +240,8 @@ if(WITH_XPU)
pass_library(yolo_box_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(conv1d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(conv2d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(redundant_onnx_ops_elimination_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(redundant_unsqueeze_squeeze_elimination_pass inference DIR xpu
DEPS ${XPU_PASS_DEPS})
pass_library(redundant_squeeze_unsqueeze_elimination_pass inference DIR xpu
DEPS ${XPU_PASS_DEPS})
pass_library(conv2d_transpose_xpu_fuse_pass inference DIR xpu DEPS
......
......@@ -43,17 +43,17 @@ namespace patterns {
fuse ele_add + layer_norm block into add_layernorm_xpu op
For example:
graph:
ele_x
add_x
|
elementwise_add -----ele_y
elementwise_add -----add_y
|
layernorm
|
output
------------------------------------------------------
After the pass is applied:
ele_x
| ele_y
add_x
| add_y
| /
| /
scale---- add_layernorm_fusion ---- bias
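In effect (a sketch of the underlying math, assuming the standard layer_norm definition; not text from the patch), the fused op computes

    z = x + y
    out = (z - mean(z)) / sqrt(var(z) + epsilon) * scale + bias

where the mean and variance of z are taken over the trailing dimensions selected by begin_norm_axis.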
......@@ -68,8 +68,8 @@ struct AddLayernormXPUPattern : public PatternBase {
PATTERN_DECL_NODE(ele_add);
PATTERN_DECL_NODE(l_norm);
// declare variable node's name
PATTERN_DECL_NODE(ele_x);
PATTERN_DECL_NODE(ele_y);
PATTERN_DECL_NODE(add_x);
PATTERN_DECL_NODE(add_y);
PATTERN_DECL_NODE(ele_out);
PATTERN_DECL_NODE(norm_bias);
PATTERN_DECL_NODE(norm_scale);
......@@ -83,17 +83,16 @@ AddLayernormXPUPattern::AddLayernormXPUPattern(PDPattern* pattern,
: PatternBase(pattern, name_scope, name_scope) {
auto ele_add =
pattern->NewNode(ele_add_repr())->assert_is_op("elementwise_add");
auto ele_x = pattern->NewNode(ele_x_repr())
auto add_x = pattern->NewNode(add_x_repr())
->assert_is_op_input("elementwise_add", "X")
->AsInput();
auto ele_y = pattern->NewNode(ele_y_repr())
auto add_y = pattern->NewNode(add_y_repr())
->assert_is_op_input("elementwise_add", "Y")
->AsInput();
auto ele_out = pattern->NewNode(ele_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("layer_norm", "X")
->assert_has_n_outputs(1);
ele_add->LinksFrom({ele_x, ele_y}).LinksTo({ele_out});
->assert_is_op_input("layer_norm", "X");
ele_add->LinksFrom({add_x, add_y}).LinksTo({ele_out});
auto l_norm = pattern->NewNode(l_norm_repr())->assert_is_op("layer_norm");
auto norm_bias = pattern->NewNode(norm_bias_repr())
->AsInput()
......@@ -169,8 +168,8 @@ void AddLayernormXPUFusePass::FuseAddLayernorm(ir::Graph* graph) const {
GET_IR_NODE(ele_add);
GET_IR_NODE(l_norm);
// declare variable node's name
GET_IR_NODE(ele_x);
GET_IR_NODE(ele_y);
GET_IR_NODE(add_x);
GET_IR_NODE(add_y);
GET_IR_NODE(ele_out);
GET_IR_NODE(norm_bias);
GET_IR_NODE(norm_scale);
......@@ -178,21 +177,21 @@ void AddLayernormXPUFusePass::FuseAddLayernorm(ir::Graph* graph) const {
GET_IR_NODE(norm_variance);
GET_IR_NODE(norm_out);
auto* block = ele_add->Op()->Block();
auto* block = l_norm->Op()->Block();
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
auto x_shape = add_x->Var()->GetShape();
auto x_rank = x_shape.size();
auto y_shape = add_y->Var()->GetShape();
auto y_rank = y_shape.size();
if (x_rank != y_rank) return;
// delete useless nodes
std::unordered_set<const Node*> delete_nodes;
float eps = PADDLE_GET_CONST(float, l_norm->Op()->GetAttr("epsilon"));
int begin_norm_axis =
PADDLE_GET_CONST(int, l_norm->Op()->GetAttr("begin_norm_axis"));
auto layer_norm_x_dims = ele_out->Var()->GetShape();
auto layer_norm_x_mat_dims =
phi::flatten_to_2d(phi::make_ddim(layer_norm_x_dims), begin_norm_axis);
int64_t m = layer_norm_x_mat_dims[0];
int64_t n = layer_norm_x_mat_dims[1];
std::string fused_op_out_name;
fused_op_out_name = norm_out->Name();
......@@ -200,28 +199,26 @@ void AddLayernormXPUFusePass::FuseAddLayernorm(ir::Graph* graph) const {
framework::OpDesc fused_op_desc(block);
fused_op_desc.SetType("add_layernorm_xpu");
// set attrs for fused op
fused_op_desc.SetInput("x", {ele_x->Name()});
fused_op_desc.SetInput("y", {ele_y->Name()});
fused_op_desc.SetInput("x", {add_x->Name()});
fused_op_desc.SetInput("y", {add_y->Name()});
fused_op_desc.SetInput("scale", {norm_scale->Name()});
fused_op_desc.SetInput("bias", {norm_bias->Name()});
fused_op_desc.SetAttr("m", m);
fused_op_desc.SetAttr("n", n);
fused_op_desc.SetAttr("epsilon", eps);
fused_op_desc.SetAttr("begin_norm_axis", begin_norm_axis);
fused_op_desc.SetOutput("out", {fused_op_out_name});
setIntermediateOut(&fused_op_desc, "mean", name_scope_);
setIntermediateOut(&fused_op_desc, "variance", name_scope_);
setIntermediateOut(&fused_op_desc, "z_add", name_scope_);
// relink fused op
auto* fused_op = graph->CreateOpNode(&fused_op_desc);
IR_NODE_LINK_TO(ele_x, fused_op);
IR_NODE_LINK_TO(ele_y, fused_op);
IR_NODE_LINK_TO(add_x, fused_op);
IR_NODE_LINK_TO(add_y, fused_op);
IR_NODE_LINK_TO(norm_scale, fused_op);
IR_NODE_LINK_TO(norm_bias, fused_op);
IR_NODE_LINK_TO(fused_op, norm_out);
addIntermediateOut(fused_op, "mean", name_scope_, graph);
addIntermediateOut(fused_op, "variance", name_scope_, graph);
addIntermediateOut(fused_op, "z_add", name_scope_, graph);
delete_nodes.insert({ele_add, l_norm, ele_out, norm_mean, norm_variance});
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
......
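For reference, a minimal sketch (FlattenTo2d is a hypothetical stand-in for phi::flatten_to_2d, not code from the patch) of how begin_norm_axis produces the [m, n] matrix shape that this pass used to store as attributes and that the kernel now derives itself:

#include <cstdint>
#include <utility>
#include <vector>

// Every dim before begin_norm_axis is folded into m, the rest into n.
std::pair<int64_t, int64_t> FlattenTo2d(const std::vector<int64_t>& dims,
                                        int begin_norm_axis) {
  int64_t m = 1, n = 1;
  for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
    (i < begin_norm_axis ? m : n) *= dims[i];
  }
  return {m, n};
}

// e.g. FlattenTo2d({8, 128, 768}, 2) returns {1024, 768}: layer_norm then
// normalizes 1024 rows of length 768.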
......@@ -88,7 +88,7 @@ ReduceMaxFusePattern::ReduceMaxFusePattern(PDPattern* pattern,
auto* op_desc = node->Op();
auto input_var = node->inputs[0]->Var();
auto pool2d_x_shape = input_var->GetShape();
std::vector<int> HW = {static_cast<int>(pool2d_x_shape[2]),
std::vector<int> hw = {static_cast<int>(pool2d_x_shape[2]),
static_cast<int>(pool2d_x_shape[3])};
auto pool_type =
op_desc->GetAttrIfExists<std::string>("pooling_type");
......@@ -98,8 +98,8 @@ ReduceMaxFusePattern::ReduceMaxFusePattern(PDPattern* pattern,
op_desc->GetAttrIfExists<std::vector<int>>("strides");
auto paddings_array =
op_desc->GetAttrIfExists<std::vector<int>>("paddings");
return pool_type == "max" && ksize_array == HW &&
strides_array == HW &&
return pool_type == "max" && ksize_array == hw &&
strides_array == hw &&
paddings_array == std::vector<int>{0, 0};
});
auto* pool2d_out = pattern->NewNode(pool2d_out_repr())
......@@ -181,7 +181,7 @@ ReduceMeanFusePattern::ReduceMeanFusePattern(PDPattern* pattern,
auto* op_desc = node->Op();
auto input_var = node->inputs[0]->Var();
auto pool2d_x_shape = input_var->GetShape();
std::vector<int> HW = {static_cast<int>(pool2d_x_shape[2]),
std::vector<int> hw = {static_cast<int>(pool2d_x_shape[2]),
static_cast<int>(pool2d_x_shape[3])};
auto pool_type =
op_desc->GetAttrIfExists<std::string>("pooling_type");
......@@ -191,8 +191,8 @@ ReduceMeanFusePattern::ReduceMeanFusePattern(PDPattern* pattern,
op_desc->GetAttrIfExists<std::vector<int>>("strides");
auto paddings_array =
op_desc->GetAttrIfExists<std::vector<int>>("paddings");
return pool_type == "avg" && ksize_array == HW &&
strides_array == HW &&
return pool_type == "avg" && ksize_array == hw &&
strides_array == hw &&
paddings_array == std::vector<int>{0, 0};
});
auto* pool2d_out = pattern->NewNode(pool2d_out_repr())
......
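As context for the HW -> hw rename above, a small sketch (illustrative helper, not part of the pass) of the condition both patterns assert: the pool2d must be a single global window over the full H x W extent with zero padding, which is what makes it equivalent to reduce_max / reduce_mean over dims {2, 3}.

#include <string>
#include <vector>

// Returns true when pool2d(ksize, strides, paddings) covers the whole spatial
// extent of an NCHW input, i.e. it reduces H and W to 1x1 in one window.
bool IsGlobalPool2d(const std::vector<int64_t>& x_shape,  // NCHW
                    const std::vector<int>& ksize,
                    const std::vector<int>& strides,
                    const std::vector<int>& paddings) {
  std::vector<int> hw = {static_cast<int>(x_shape[2]),
                         static_cast<int>(x_shape[3])};
  return ksize == hw && strides == hw && paddings == std::vector<int>{0, 0};
}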
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/xpu/redundant_onnx_ops_elimination_pass.h"
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct FoldConv1dSqueeze2Pattern : public PatternBase {
FoldConv1dSqueeze2Pattern(PDPattern* pattern,
const std::string& name_scope,
const std::string& act_type);
// declare operator node's name
PATTERN_DECL_NODE(squeeze2);
PATTERN_DECL_NODE(bn);
PATTERN_DECL_NODE(act);
PATTERN_DECL_NODE(unsqueeze2);
// declare variable node's name
PATTERN_DECL_NODE(x);
PATTERN_DECL_NODE(squeeze2_out);
PATTERN_DECL_NODE(bn_bias);
PATTERN_DECL_NODE(bn_mean);
PATTERN_DECL_NODE(bn_scale);
PATTERN_DECL_NODE(bn_var);
PATTERN_DECL_NODE(bn_out);
PATTERN_DECL_NODE(bn_mean_out);
PATTERN_DECL_NODE(bn_saved_mean);
PATTERN_DECL_NODE(bn_saved_var);
PATTERN_DECL_NODE(bn_var_out);
PATTERN_DECL_NODE(act_out);
PATTERN_DECL_NODE(unsqueeze2_out);
private:
std::string act_type_;
};
FoldConv1dSqueeze2Pattern::FoldConv1dSqueeze2Pattern(
PDPattern* pattern,
const std::string& name_scope,
const std::string& act_type)
: PatternBase(pattern, name_scope, name_scope), act_type_(act_type) {
auto* x = pattern->NewNode(x_repr())
->assert_is_op_input("squeeze2", "X")
->assert_more([](Node* node) {
auto x_shape = node->Var()->GetShape();
size_t x_rank = x_shape.size();
return x_rank == 4 && x_shape[2] == 1;
});
auto* squeeze2 = pattern->NewNode(squeeze2_repr())
->assert_is_op("squeeze2")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto axes_array =
op_desc->GetAttrIfExists<std::vector<int>>("axes");
return axes_array == std::vector<int>{-2};
});
auto* squeeze2_out = pattern->NewNode(squeeze2_out_repr())
->assert_is_op_output("squeeze2", "Out")
->assert_is_op_input("batch_norm", "X");
squeeze2->LinksFrom({x}).LinksTo({squeeze2_out});
auto* bn_bias = pattern->NewNode(bn_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Bias")
->assert_has_n_outputs(1);
auto* bn_mean = pattern->NewNode(bn_mean_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Mean")
->assert_has_n_outputs(1);
auto* bn_scale = pattern->NewNode(bn_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Scale")
->assert_has_n_outputs(1);
auto* bn_var = pattern->NewNode(bn_var_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Variance")
->assert_has_n_outputs(1);
auto* bn = pattern->NewNode(bn_repr())->assert_is_op("batch_norm");
auto* bn_out = pattern->NewNode(bn_out_repr())
->assert_is_op_output("batch_norm", "Y")
->assert_is_op_input(act_type_, "X");
auto* bn_mean_out = pattern->NewNode(bn_mean_out_repr())
->assert_is_op_output("batch_norm", "MeanOut");
auto* bn_saved_mean = pattern->NewNode(bn_saved_mean_repr())
->assert_is_op_output("batch_norm", "SavedMean");
auto* bn_var_out = pattern->NewNode(bn_var_out_repr())
->assert_is_op_output("batch_norm", "VarianceOut");
auto* bn_saved_var = pattern->NewNode(bn_saved_var_repr())
->assert_is_op_output("batch_norm", "SavedVariance");
bn->LinksFrom({squeeze2_out, bn_bias, bn_mean, bn_scale, bn_var})
.LinksTo({bn_out, bn_mean_out, bn_var_out, bn_saved_mean, bn_saved_var});
auto act = pattern->NewNode(act_repr())->assert_is_op(act_type_);
auto act_out = pattern->NewNode(act_out_repr())
->assert_is_op_output(act_type_, "Out")
->assert_is_op_input("unsqueeze2", "X");
act->LinksFrom({bn_out}).LinksTo({act_out});
auto* unsqueeze2 =
pattern->NewNode(unsqueeze2_repr())
->assert_is_op("unsqueeze2")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto axes_array =
op_desc->GetAttrIfExists<std::vector<int>>("axes");
return axes_array == std::vector<int>{-2} ||
axes_array == std::vector<int>{2};
});
auto* unsqueeze2_out = pattern->NewNode(unsqueeze2_out_repr())
->assert_is_op_output("unsqueeze2", "Out");
unsqueeze2->LinksFrom({act_out}).LinksTo({unsqueeze2_out});
}
} // namespace patterns
void RedundantOnnxOpsEliminationPass::FoldConv1dSqueeze2Ops(
ir::Graph* graph, const std::string& act_type) const {
GraphPatternDetector gpd;
patterns::FoldConv1dSqueeze2Pattern pattern(
gpd.mutable_pattern(), name_scope_, act_type);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle FoldConv1dSqueeze2Ops";
// declare operator node's name
GET_IR_NODE(squeeze2);
GET_IR_NODE(bn);
GET_IR_NODE(act);
GET_IR_NODE(unsqueeze2);
// declare variable node's name
GET_IR_NODE(x);
GET_IR_NODE(squeeze2_out);
GET_IR_NODE(bn_out);
GET_IR_NODE(act_out);
GET_IR_NODE(unsqueeze2_out);
auto bn_op_desc = bn->Op();
bn_op_desc->RenameInput(squeeze2_out->Var()->Name(), x->Var()->Name());
bn_out->Var()->SetShape(x->Var()->GetShape());
act_out->Var()->SetShape(x->Var()->GetShape());
bn_op_desc->Flush();
IR_NODE_LINK_TO(x, bn);
// relink the ops that consume unsqueeze2_out
auto unsqueeze_out_link_nodes = unsqueeze2_out->outputs;
for (auto out_link_node : unsqueeze_out_link_nodes) {
auto op_desc = out_link_node->Op();
op_desc->RenameInput(unsqueeze2_out->Var()->Name(),
act_out->Var()->Name());
op_desc->Flush();
IR_NODE_LINK_TO(act_out, out_link_node);
}
// delete useless nodes
std::unordered_set<const Node*> delete_nodes = {
squeeze2, squeeze2_out, unsqueeze2, unsqueeze2_out};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void RedundantOnnxOpsEliminationPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
for (auto act_type : {"leaky_relu", "elu"}) {
FoldConv1dSqueeze2Ops(graph, act_type);
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(redundant_onnx_ops_elimination_pass,
paddle::framework::ir::RedundantOnnxOpsEliminationPass);
REGISTER_PASS_CAPABILITY(redundant_onnx_ops_elimination_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"conv2d", 0));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.h"
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct FoldTranspose2OpsPattern : public PatternBase {
FoldTranspose2OpsPattern(PDPattern* pattern,
const std::string& name_scope,
const std::string& act_type);
// declare operator node's name
PATTERN_DECL_NODE(transpose2_1);
PATTERN_DECL_NODE(unsqueeze2);
PATTERN_DECL_NODE(reduce_sum);
PATTERN_DECL_NODE(act);
PATTERN_DECL_NODE(transpose2_2);
// declare variable node's name
PATTERN_DECL_NODE(x);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(unsqueeze2_out);
PATTERN_DECL_NODE(sum_out);
PATTERN_DECL_NODE(act_out);
PATTERN_DECL_NODE(transpose2_2_out);
private:
std::string act_type_;
};
FoldTranspose2OpsPattern::FoldTranspose2OpsPattern(
PDPattern* pattern,
const std::string& name_scope,
const std::string& act_type)
: PatternBase(pattern, name_scope, name_scope), act_type_(act_type) {
auto* x = pattern->NewNode(x_repr())
->assert_is_op_input("transpose2", "X")
->assert_more([](Node* node) {
auto x_shape = node->Var()->GetShape();
size_t x_rank = x_shape.size();
return x_rank == 3;
});
auto* transpose2_1 =
pattern->NewNode(transpose2_1_repr())
->assert_is_op("transpose2")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto axis_array =
op_desc->GetAttrIfExists<std::vector<int>>("axis");
return axis_array == std::vector<int>{0, 2, 1};
});
auto* transpose2_1_out = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input("unsqueeze2", "X");
transpose2_1->LinksFrom({x}).LinksTo({transpose2_1_out});
auto* unsqueeze2 =
pattern->NewNode(unsqueeze2_repr())
->assert_is_op("unsqueeze2")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto axes_array =
op_desc->GetAttrIfExists<std::vector<int>>("axes");
return axes_array == std::vector<int>{-2};
});
auto* unsqueeze2_out = pattern->NewNode(unsqueeze2_out_repr())
->assert_is_op_output("unsqueeze2", "Out")
->assert_is_op_input("reduce_sum", "X");
unsqueeze2->LinksFrom({transpose2_1_out}).LinksTo({unsqueeze2_out});
auto* reduce_sum =
pattern->NewNode(reduce_sum_repr())
->assert_is_op("reduce_sum")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto keep_dim = op_desc->GetAttrIfExists<bool>("keep_dim");
auto dim_array = op_desc->GetAttrIfExists<std::vector<int>>("dim");
return dim_array == std::vector<int>{-2} && !keep_dim;
});
auto* sum_out = pattern->NewNode(sum_out_repr())
->assert_is_op_output("reduce_sum", "Out")
->assert_is_op_input(act_type_, "X");
reduce_sum->LinksFrom({unsqueeze2_out}).LinksTo({sum_out});
auto* act = pattern->NewNode(act_repr())->assert_is_op(act_type_);
auto* act_out = pattern->NewNode(act_out_repr())
->assert_is_op_output(act_type_, "Out")
->assert_is_op_input("transpose2", "X");
act->LinksFrom({sum_out}).LinksTo({act_out});
auto* transpose2_2 =
pattern->NewNode(transpose2_2_repr())
->assert_is_op("transpose2")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto axis_array =
op_desc->GetAttrIfExists<std::vector<int>>("axis");
return axis_array == std::vector<int>{0, 2, 1};
});
auto* transpose2_2_out = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2", "Out");
transpose2_2->LinksFrom({act_out}).LinksTo({transpose2_2_out});
}
struct FoldGatherSqueeze2Pattern : public PatternBase {
FoldGatherSqueeze2Pattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(unsqueeze2_op);
PATTERN_DECL_NODE(gather_op);
PATTERN_DECL_NODE(squeeze2_op);
// declare variable node's name
PATTERN_DECL_NODE(x);
PATTERN_DECL_NODE(unsqueeze2_op_out);
PATTERN_DECL_NODE(gather_i);
PATTERN_DECL_NODE(gather_op_out);
PATTERN_DECL_NODE(squeeze2_op_out);
};
FoldGatherSqueeze2Pattern::FoldGatherSqueeze2Pattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* x = pattern->NewNode(x_repr())->assert_is_op_input("unsqueeze2", "X");
auto* unsqueeze2_op =
pattern->NewNode(unsqueeze2_op_repr())
->assert_is_op("unsqueeze2")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto axes_array =
op_desc->GetAttrIfExists<std::vector<int>>("axes");
return axes_array.size() == 1;
});
auto* unsqueeze2_op_out = pattern->NewNode(unsqueeze2_op_out_repr())
->assert_is_op_output("unsqueeze2", "Out")
->assert_is_op_input("gather", "X");
unsqueeze2_op->LinksFrom({x}).LinksTo({unsqueeze2_op_out});
auto* gather_op = pattern->NewNode(gather_op_repr())->assert_is_op("gather");
auto* gather_i = pattern->NewNode(gather_i_repr())
->assert_is_op_input("gather", "Index")
->assert_is_persistable_var()
->assert_more([](Node* node) {
auto i_shape = node->Var()->GetShape();
size_t i_rank = i_shape.size();
return i_rank == 1;
});
auto* gather_op_out = pattern->NewNode(gather_op_out_repr())
->assert_is_op_output("gather", "Out")
->assert_is_op_input("squeeze2", "X");
gather_op->LinksFrom({unsqueeze2_op_out, gather_i}).LinksTo({gather_op_out});
auto* squeeze2_op =
pattern->NewNode(squeeze2_op_repr())
->assert_is_op("squeeze2")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto axes_array =
op_desc->GetAttrIfExists<std::vector<int>>("axes");
return axes_array.size() == 1;
});
auto* squeeze2_op_out = pattern->NewNode(squeeze2_op_out_repr())
->assert_is_op_output("squeeze2", "Out");
squeeze2_op->LinksFrom({gather_op_out}).LinksTo({squeeze2_op_out});
}
} // namespace patterns
void RedundantUnsqueeze2EliminationPass::FoldTranspose2Ops(
ir::Graph* graph, const std::string& act_type) const {
GraphPatternDetector gpd;
patterns::FoldTranspose2OpsPattern pattern(
gpd.mutable_pattern(), name_scope_, act_type);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle FoldTranspose2Ops";
// declare operator node's name
GET_IR_NODE(transpose2_1);
GET_IR_NODE(unsqueeze2);
GET_IR_NODE(reduce_sum);
GET_IR_NODE(act);
GET_IR_NODE(transpose2_2);
// declare variable node's name
GET_IR_NODE(x);
GET_IR_NODE(transpose2_1_out);
GET_IR_NODE(unsqueeze2_out);
GET_IR_NODE(sum_out);
GET_IR_NODE(act_out);
GET_IR_NODE(transpose2_2_out);
auto act_op_desc = act->Op();
act_op_desc->RenameInput(sum_out->Var()->Name(), x->Var()->Name());
act_out->Var()->SetShape(x->Var()->GetShape());
act_op_desc->Flush();
IR_NODE_LINK_TO(x, act);
// relink the ops that consume transpose2_2_out
auto final_out_link_nodes = transpose2_2_out->outputs;
for (auto out_link_node : final_out_link_nodes) {
auto op_desc = out_link_node->Op();
op_desc->RenameInput(transpose2_2_out->Var()->Name(),
act_out->Var()->Name());
op_desc->Flush();
IR_NODE_LINK_TO(act_out, out_link_node);
}
// delete useless nodes
std::unordered_set<const Node*> delete_nodes = {transpose2_1,
transpose2_1_out,
unsqueeze2,
unsqueeze2_out,
reduce_sum,
sum_out,
transpose2_2,
transpose2_2_out};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void RedundantUnsqueeze2EliminationPass::FoldGatherSqueeze2Ops(
ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::FoldGatherSqueeze2Pattern pattern(gpd.mutable_pattern(),
name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle FoldGatherSqueeze2Ops";
// declare operator node's name
GET_IR_NODE(unsqueeze2_op);
GET_IR_NODE(gather_op);
GET_IR_NODE(squeeze2_op);
// declare variable node's name
GET_IR_NODE(x);
GET_IR_NODE(unsqueeze2_op_out);
GET_IR_NODE(gather_i);
GET_IR_NODE(gather_op_out);
GET_IR_NODE(squeeze2_op_out);
bool flag = true;
auto x_shape = x->Var()->GetShape();
auto x_rank = static_cast<int>(x_shape.size());
std::vector<int> unsqueeze_axes_attr = PADDLE_GET_CONST(
std::vector<int>, unsqueeze2_op->Op()->GetAttr("axes"));
auto unsqueeze_axes = unsqueeze_axes_attr.front();
unsqueeze_axes =
unsqueeze_axes < 0 ? unsqueeze_axes + x_rank : unsqueeze_axes;
auto gather_axis = PADDLE_GET_CONST(int, gather_op->Op()->GetAttr("axis"));
gather_axis = gather_axis < 0 ? gather_axis + x_rank + 1 : gather_axis;
std::vector<int> squeeze_axes_attr =
PADDLE_GET_CONST(std::vector<int>, squeeze2_op->Op()->GetAttr("axes"));
auto squeeze_axes = squeeze_axes_attr.front();
squeeze_axes = squeeze_axes < 0 ? squeeze_axes + x_rank + 1 : squeeze_axes;
flag &= (unsqueeze_axes >= 0 && unsqueeze_axes < x_rank);
flag &=
((gather_axis == unsqueeze_axes + 1) && (squeeze_axes == gather_axis));
if (!flag) return;
// x->gather->squeeze2_op_out
auto gather_op_desc = gather_op->Op();
gather_op_desc->RenameInput(unsqueeze2_op_out->Var()->Name(),
x->Var()->Name());
gather_op_desc->SetAttr("axis", gather_axis - 1);
gather_op_out->Var()->SetShape(squeeze2_op_out->Var()->GetShape());
gather_op_desc->Flush();
IR_NODE_LINK_TO(x, gather_op);
// relink the ops that consume squeeze2_op_out
auto squeeze_out_link_nodes = squeeze2_op_out->outputs;
for (auto out_link_node : squeeze_out_link_nodes) {
auto op_desc = out_link_node->Op();
op_desc->RenameInput(squeeze2_op_out->Var()->Name(),
gather_op_out->Var()->Name());
op_desc->Flush();
IR_NODE_LINK_TO(gather_op_out, out_link_node);
}
std::unordered_set<const Node*> delete_nodes{
squeeze2_op, squeeze2_op_out, unsqueeze2_op, unsqueeze2_op_out};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void RedundantUnsqueeze2EliminationPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
for (auto act_type : {"relu"}) {
FoldTranspose2Ops(graph, act_type);
}
FoldGatherSqueeze2Ops(graph);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(redundant_unsqueeze_squeeze_elimination_pass,
paddle::framework::ir::RedundantUnsqueeze2EliminationPass);
REGISTER_PASS_CAPABILITY(redundant_unsqueeze_squeeze_elimination_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"conv2d", 0));
......@@ -31,51 +31,51 @@ namespace paddle {
namespace framework {
namespace ir {
class RedundantOnnxOpsEliminationPass : public FusePassBase {
class RedundantUnsqueeze2EliminationPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
/*
Origin subgraph:
x filter
| |
unsqueeze2(axes={-2}) unsqueeze2(axes={-2})
\ /
\ /
conv2d(conv1d)
x
|
elementwise_add
transpose2
|
squeeze2(axes={-2})
unsqueeze2(axes={-2})
|
batch_norm
reduce_sum
|
act
|
unsqueeze2
transpose2
|
conv2d(conv1d)
Fused subgraph:
x filter
| |
unsqueeze2(axes={-2}) unsqueeze2(axes={-2})
\ /
\ /
conv2d(conv1d)
x
|
elementwise_add
act
|
batch_norm
*/
void FoldTranspose2Ops(ir::Graph* graph, const std::string& act_type) const;
/*
Origin subgraph:
x
|
act
unsqueeze2(axes={-2})
|
gather
|
squeeze2
|
Fused subgraph:
x
|
gather
|
conv2d(conv1d)
*/
void FoldConv1dSqueeze2Ops(ir::Graph* graph,
const std::string& act_type) const;
void FoldGatherSqueeze2Ops(ir::Graph* graph) const;
const std::string name_scope_{"redundant_onnx_ops_elimination_pass"};
const std::string name_scope_{"redundant_unsqueeze_squeeze_elimination_pass"};
};
} // namespace ir
......
......@@ -527,7 +527,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"fold_interp_outsize_fuse_pass",
"fold_two_squeeze2_fuse_pass",
"conv1d_xpu_fuse_pass",
"redundant_onnx_ops_elimination_pass",
"redundant_unsqueeze_squeeze_elimination_pass",
"reduce_ops_fuse_pass",
"delete_cast_op_pass",
"xpu_delete_cast_op_pass",
......
......@@ -15,7 +15,7 @@
optional : x_max, y_max
- op : add_layernorm_xpu
args : (Tensor x, Tensor y, Tensor scale, Tensor bias, int64_t m, int64_t n, float epsilon)
args : (Tensor x, Tensor y, Tensor scale, Tensor bias, int begin_norm_axis, float epsilon)
output : Tensor(out), Tensor(mean), Tensor(variance), Tensor(z_add)
infer_meta :
func : AddLayernormXPUInferMeta
......
......@@ -24,8 +24,7 @@ XPUOpMap& get_kl2_ops() {
static XPUOpMap s_xpu2_kernels{
{"add_act_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"add_layernorm_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"add_layernorm_xpu", XPUKernelSet({phi::DataType::FLOAT32})},
{"abs", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"abs_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
......
......@@ -96,8 +96,7 @@ void AddLayernormXPUInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& scale,
const MetaTensor& bias,
int64_t m,
int64_t n,
int begin_norm_axis,
float epsilon,
MetaTensor* out,
MetaTensor* mean,
......@@ -106,12 +105,16 @@ void AddLayernormXPUInferMeta(const MetaTensor& x,
int axis = -1;
auto x_dims = x.dims();
auto y_dims = y.dims();
auto out_dims = x_dims;
if (x_dims != y_dims) {
auto out_dims = BroadCastInferShape(x_dims, y_dims, axis);
out_dims = BroadCastInferShape(x_dims, y_dims, axis);
out->set_dims(out_dims);
} else {
out->set_dims(x_dims);
out->set_dims(out_dims);
}
auto layer_norm_x_mat_dims = phi::flatten_to_2d(out_dims, begin_norm_axis);
int64_t m = layer_norm_x_mat_dims[0];
int64_t n = layer_norm_x_mat_dims[1];
out->set_dtype(x.dtype());
out->set_layout(x.layout());
out->share_lod(x);
......
......@@ -34,8 +34,7 @@ void AddLayernormXPUInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& scale,
const MetaTensor& bias,
int64_t m,
int64_t n,
int begin_norm_axis,
float epsilon,
MetaTensor* out,
MetaTensor* mean,
......
......@@ -13,19 +13,65 @@
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "glog/logging.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
namespace phi {
namespace fusion {
static phi::DDim BroadCastInferShape(const DDim x_dims,
const DDim y_dims,
int axis) {
std::vector<int> out_dims_array(x_dims.size(), -1);
if (x_dims != y_dims) {
int max_dim = std::max(x_dims.size(), y_dims.size());
if (x_dims.size() == y_dims.size()) {
PADDLE_ENFORCE_EQ((axis == -1) || (axis == 0),
true,
phi::errors::InvalidArgument(
"axis should be -1 or 0 while the dimension of "
"tensor X (%s) is equal to the dimension of "
"tensor Y (%s), but received axis: %s",
x_dims.size(),
y_dims.size(),
axis));
}
PADDLE_ENFORCE_EQ((axis >= (-1 * max_dim)) && (axis < max_dim),
true,
phi::errors::InvalidArgument(
"The axis range must be [%s, %s), but axis is %s. "
"Please set the axis again.",
-1 * max_dim,
max_dim,
axis));
axis = (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1)
: axis);
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
out_dims_array.resize(max_dim);
phi::funcs::GetBroadcastDimsArrays(x_dims,
y_dims,
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
return phi::make_ddim(out_dims_array);
}
return x_dims;
}
template <typename T, typename Context>
void AddLayernormXPUKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& scale,
const DenseTensor& bias,
int64_t m,
int64_t n,
int begin_norm_axis,
float epsilon,
DenseTensor* out,
DenseTensor* mean,
......@@ -37,12 +83,19 @@ void AddLayernormXPUKernel(const Context& ctx,
auto* y_data = reinterpret_cast<const XPUType*>(y.data<T>());
const float* scale_data = scale.data<float>();
const float* bias_data = bias.data<float>();
auto* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out));
float* mean_data = ctx.template Alloc<float>(mean);
float* variance_data = ctx.template Alloc<float>(variance);
auto* z_add_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(z_add));
auto x_dims = x.dims();
auto y_dims = y.dims();
auto out_dims = BroadCastInferShape(x_dims, y_dims, -1);
auto layer_norm_x_mat_dims = phi::flatten_to_2d(out_dims, begin_norm_axis);
int64_t m = layer_norm_x_mat_dims[0];
int64_t n = layer_norm_x_mat_dims[1];
auto* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out));
int r = xpu::add_layer_norm_fusion<XPUType>( // T
/* baidu::xpu::api::Context* ctx */ ctx.x_context(),
/* const T* x */ x_data,
......@@ -66,5 +119,4 @@ PD_REGISTER_KERNEL(add_layernorm_xpu,
XPU,
ALL_LAYOUT,
phi::fusion::AddLayernormXPUKernel,
float,
phi::dtype::float16) {}
float) {}
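Since m and n are no longer op attributes, the kernel derives them from the broadcasted output shape at run time. A simplified sketch of that shape work (assuming axis == -1 and already-compatible shapes; the real code uses phi::funcs::GetBroadcastDimsArrays and phi::flatten_to_2d):

#include <algorithm>
#include <cstdint>
#include <vector>

// Right-aligned broadcast of two compatible shapes (every pair of aligned
// dims is equal or one of them is 1).
std::vector<int64_t> BroadcastShape(std::vector<int64_t> x,
                                    std::vector<int64_t> y) {
  if (x.size() < y.size()) std::swap(x, y);
  std::vector<int64_t> out = x;
  size_t offset = x.size() - y.size();
  for (size_t i = 0; i < y.size(); ++i) {
    out[offset + i] = std::max(x[offset + i], y[i]);
  }
  return out;
}

// e.g. x = {2, 100, 768}, y = {768} -> out = {2, 100, 768}; with
// begin_norm_axis = 2 the kernel then obtains m = 200, n = 768.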