diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 1c13d2902017ff5c6601aacde6d4bf04d7bc0717..e6a4c5e8e73ff5bbc44fa9c832e11e8ed7ceb159 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -241,6 +241,8 @@ if(WITH_XPU) pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) + pass_library(reshape_unstack_concat_fuse_pass inference DIR xpu DEPS + ${XPU_PASS_DEPS}) pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(multi_encoder_xpu_adaptive_seqlen_fuse_pass inference DIR xpu diff --git a/paddle/fluid/framework/ir/delete_repeated_ops_pass.cc b/paddle/fluid/framework/ir/delete_repeated_ops_pass.cc index 43be86b1f1bf822164eafde33141035aa326e25c..13393ec1b6895245348ccf4238220749e6cb2380 100644 --- a/paddle/fluid/framework/ir/delete_repeated_ops_pass.cc +++ b/paddle/fluid/framework/ir/delete_repeated_ops_pass.cc @@ -101,68 +101,86 @@ class DeleteRepeatedOpsPass : public FusePassBase { void ApplyImpl(ir::Graph* graph) const override; private: - int DeleteShapePass(ir::Graph* graph) const; - - int DeleteSlicePass(ir::Graph* graph) const; + void DeleteRepeatedOps( + ir::Graph* graph, + const std::string& op_type, + std::function gen_op_key_fn) const; const std::string name_scope_{"delete_repeated_ops_pass"}; }; -int DeleteRepeatedOpsPass::DeleteShapePass(ir::Graph* graph) const { +void DeleteRepeatedOpsPass::DeleteRepeatedOps( + ir::Graph* graph, + const std::string& op_type, + std::function gen_op_key_fn) const { GraphPatternDetector gpd; patterns::VarWithRepeatedOpsPattern pattern( - gpd.mutable_pattern(), name_scope_, "shape"); + gpd.mutable_pattern(), name_scope_, op_type); int delete_counts = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { - VLOG(4) << "handle DeleteShapePass"; + VLOG(4) << "handle DeleteRepeatedOps"; GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern); - std::vector invalid_shape_out_ops{"while", - "conditional_block"}; - std::vector shapes; + std::vector invalid_out_ops{ + "while", "conditional_block", "fetch"}; + std::map> ops_map; for (auto* next_op : in_var->outputs) { - if (next_op->Name() != "shape") continue; - bool shape_out_op_is_invalid = false; - for (auto* shape_out_op : next_op->outputs[0]->outputs) { - if (std::count(invalid_shape_out_ops.begin(), - invalid_shape_out_ops.end(), - shape_out_op->Name()) > 0 || - HasOutVarName(shape_out_op, next_op->outputs[0]->Name())) { - shape_out_op_is_invalid = true; + if (next_op->Name() != op_type) continue; + auto* op = next_op; + bool out_op_is_invalid = false; + for (auto* out_op : op->outputs[0]->outputs) { + if (std::count(invalid_out_ops.begin(), + invalid_out_ops.end(), + out_op->Name()) > 0 || + HasOutVarName(out_op, op->outputs[0]->Name())) { + out_op_is_invalid = true; break; } } - if (!shape_out_op_is_invalid) { - shapes.push_back(next_op); + if (out_op_is_invalid) continue; + auto attr_key = gen_op_key_fn(op->Op()); + ops_map[attr_key].push_back(op); + } + for (auto iter = ops_map.begin(); iter != ops_map.end();) { + if (iter->second.size() <= 1) { + iter = ops_map.erase(iter); + } else { + iter++; } } - if (shapes.size() <= 1) return; - auto* first_shape_out = shapes[0]->outputs[0]; - auto first_shape_out_name = first_shape_out->Name(); - std::unordered_set delete_nodes; - for (size_t i = 1; i < shapes.size(); i++) { - 
auto* cur_shape = shapes[i]; - auto* cur_shape_out = cur_shape->outputs[0]; - auto cur_shape_out_name = cur_shape_out->Name(); - for (auto* shape_out_op : cur_shape_out->outputs) { - shape_out_op->Op()->Rename(cur_shape_out_name, first_shape_out_name); - IR_NODE_LINK_TO(first_shape_out, shape_out_op); + for (auto iter : ops_map) { + auto ops = iter.second; + auto* first_op_out = ops[0]->outputs[0]; + auto first_op_out_name = first_op_out->Name(); + std::unordered_set delete_nodes; + for (size_t i = 1; i < ops.size(); i++) { + auto* cur_op = ops[i]; + auto* cur_op_out = cur_op->outputs[0]; + auto cur_op_out_name = cur_op_out->Name(); + for (auto* out_op : cur_op_out->outputs) { + out_op->Op()->RenameInput(cur_op_out_name, first_op_out_name); + IR_NODE_LINK_TO(first_op_out, out_op); + } + delete_nodes.insert(cur_op); + delete_nodes.insert(cur_op_out); + delete_counts++; } - delete_nodes.insert(cur_shape); - delete_nodes.insert(cur_shape_out); - delete_counts++; + GraphSafeRemoveNodes(graph, delete_nodes); } - - GraphSafeRemoveNodes(graph, delete_nodes); }; gpd(graph, handler); - return delete_counts; + if (delete_counts > 0) { + LOG(INFO) << "--- delete " << delete_counts << " repeated " << op_type + << " ops"; + } } +std::string GenShapeAttrKey(OpDesc* slice_op_desc) { return ""; } + std::string GenSliceAttrKey(OpDesc* slice_op_desc) { std::string attr_key; auto starts = slice_op_desc->GetAttrIfExists>("starts"); @@ -189,69 +207,27 @@ std::string GenSliceAttrKey(OpDesc* slice_op_desc) { return attr_key; } -int DeleteRepeatedOpsPass::DeleteSlicePass(ir::Graph* graph) const { - GraphPatternDetector gpd; - patterns::VarWithRepeatedOpsPattern pattern( - gpd.mutable_pattern(), name_scope_, "slice"); - - int delete_counts = 0; - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, - Graph* graph) { - VLOG(4) << "handle DeleteSlicePass"; - GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern); - - std::vector invalid_slice_out_ops{"while", - "conditional_block"}; - std::map> slice_ops; - for (auto* next_op : in_var->outputs) { - if (next_op->Name() != "slice") continue; - auto* slice = next_op; - bool slice_out_op_is_invalid = false; - for (auto* slice_out_op : slice->outputs[0]->outputs) { - if (std::count(invalid_slice_out_ops.begin(), - invalid_slice_out_ops.end(), - slice_out_op->Name()) > 0 || - HasOutVarName(slice_out_op, slice->outputs[0]->Name())) { - slice_out_op_is_invalid = true; - break; - } - } - if (slice_out_op_is_invalid) continue; - auto attr_key = GenSliceAttrKey(slice->Op()); - slice_ops[attr_key].push_back(slice); - } - for (auto iter = slice_ops.begin(); iter != slice_ops.end();) { - if (iter->second.size() <= 1) { - iter = slice_ops.erase(iter); - } else { - iter++; - } - } +std::string GenCastAttrKey(OpDesc* cast_op_desc) { + auto in_dtype = cast_op_desc->GetAttrIfExists("in_dtype"); + auto out_dtype = cast_op_desc->GetAttrIfExists("out_dtype"); + return "in_dtype_" + std::to_string(in_dtype) + "_out_dtype_" + + std::to_string(out_dtype); +} - for (auto iter : slice_ops) { - auto slices = iter.second; - auto* first_slice_out = slices[0]->outputs[0]; - auto first_slice_out_name = first_slice_out->Name(); - std::unordered_set delete_nodes; - for (size_t i = 1; i < slices.size(); i++) { - auto* cur_slice = slices[i]; - auto* cur_slice_out = cur_slice->outputs[0]; - auto cur_slice_out_name = cur_slice_out->Name(); - for (auto* slice_out_op : cur_slice_out->outputs) { - slice_out_op->Op()->RenameInput(cur_slice_out_name, - first_slice_out_name); - 
IR_NODE_LINK_TO(first_slice_out, slice_out_op); - } - delete_nodes.insert(cur_slice); - delete_nodes.insert(cur_slice_out); - delete_counts++; - } - GraphSafeRemoveNodes(graph, delete_nodes); - } - }; +std::string GenAddAttrKey(OpDesc* add_op_desc) { + std::string x_name = add_op_desc->Input("X")[0]; + std::string y_name = add_op_desc->Input("Y")[0]; + auto axis = add_op_desc->GetAttrIfExists("axis"); + return x_name + "_" + y_name + "_axis_" + std::to_string(axis); +} - gpd(graph, handler); - return delete_counts; +std::string GenScaleAttrKey(OpDesc* scale_op_desc) { + auto scale = scale_op_desc->GetAttrIfExists("scale"); + auto bias = scale_op_desc->GetAttrIfExists("bias"); + auto bias_after_scale = + scale_op_desc->GetAttrIfExists("bias_after_scale"); + return "scale_" + std::to_string(scale) + "_bias_" + std::to_string(bias) + + "_bias_after_scale_" + std::to_string(bias_after_scale); } void DeleteRepeatedOpsPass::ApplyImpl(ir::Graph* graph) const { @@ -259,15 +235,12 @@ void DeleteRepeatedOpsPass::ApplyImpl(ir::Graph* graph) const { graph, platform::errors::PreconditionNotMet("graph should not be null.")); Init(name_scope_, graph); - int delete_counts = DeleteShapePass(graph); - if (delete_counts > 0) { - LOG(INFO) << "--- delete " << delete_counts << " repeated shape ops"; - } - - delete_counts = DeleteSlicePass(graph); - if (delete_counts > 0) { - LOG(INFO) << "--- delete " << delete_counts << " repeated slice ops"; - } + DeleteRepeatedOps(graph, "shape", GenShapeAttrKey); + DeleteRepeatedOps(graph, "slice", GenSliceAttrKey); + DeleteRepeatedOps(graph, "cast", GenCastAttrKey); + DeleteRepeatedOps(graph, "elementwise_add", GenAddAttrKey); + DeleteRepeatedOps(graph, "scale", GenScaleAttrKey); + DeleteRepeatedOps(graph, "cast", GenCastAttrKey); } } // namespace ir diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index c70a05829952c99685d540cb7b2d66c628c339da..59b43d87447a59a0e44d569eb8dc98bfee5694d9 100755 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -70,6 +70,7 @@ static const std::vector xpu_support_subgraph_passes = { "xpu_delete_cast_op_pass", "fc_xpu_fuse_pass", "link_xpu_op_max_pass", + "xpu_delete_cast_op_pass", }; Graph *Pass::Apply(Graph *graph) const { diff --git a/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc index 6c56113b0f4c4b358ceffe8047aa6d144a668e75..205740c0e24d82057bbc5f13cb2569e8a5e1cb72 100644 --- a/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc @@ -97,10 +97,10 @@ Reshape2MatmulPattern::Reshape2MatmulPattern(PDPattern* pattern, ->assert_more([](Node* node) { auto reshape2_in_x_shape = node->Var()->GetShape(); size_t reshape2_in_rank = reshape2_in_x_shape.size(); - bool nice_shape = - (reshape2_in_x_shape[2] == 1 && reshape2_in_x_shape[3] == 1) || - (reshape2_in_x_shape[1] == 1 && reshape2_in_x_shape[3] == 1); - return (reshape2_in_rank == 4 && nice_shape); + return reshape2_in_rank == 4 && ((reshape2_in_x_shape[2] == 1 && + reshape2_in_x_shape[3] == 1) || + (reshape2_in_x_shape[1] == 1 && + reshape2_in_x_shape[3] == 1)); }); auto* reshape2 = pattern->NewNode(reshape2_repr()) diff --git a/paddle/fluid/framework/ir/xpu/reshape_unstack_concat_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/reshape_unstack_concat_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..600b1862f1fd8eff1013a7249181ca860f4fdc59 
--- /dev/null +++ b/paddle/fluid/framework/ir/xpu/reshape_unstack_concat_fuse_pass.cc @@ -0,0 +1,381 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "glog/logging.h" + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct ReshapeUnstackConcatPattern : public PatternBase { + ReshapeUnstackConcatPattern(PDPattern* pattern, + const std::string& name_scope); + + // declare operator node's name + PATTERN_DECL_NODE(reshape); + PATTERN_DECL_NODE(unstack); + PATTERN_DECL_NODE(concat); + // declare variable node's name + PATTERN_DECL_NODE(reshape_in); + PATTERN_DECL_NODE(reshape_out); + PATTERN_DECL_NODE(unstack_out0); + PATTERN_DECL_NODE(concat_out); +}; + +ReshapeUnstackConcatPattern::ReshapeUnstackConcatPattern( + PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto* reshape_in = + pattern->NewNode(reshape_in_repr())->assert_is_op_input("reshape2", "X"); + auto* reshape = + pattern->NewNode(reshape_repr()) + ->assert_is_op("reshape2") + ->assert_more([](Node* node) { + auto shape = node->Op()->GetAttrIfExists>("shape"); + return shape.size() == 6; + }); + auto* reshape_out = pattern->NewNode(reshape_out_repr()) + ->assert_is_op_output("reshape2", "Out") + ->assert_is_op_input("unstack", "X"); + auto* unstack = pattern->NewNode(unstack_repr()) + ->assert_is_op("unstack") + ->assert_more([](Node* node) { + auto axis = node->Op()->GetAttrIfExists("axis"); + return axis == 0; + }); + auto* unstack_out0 = pattern->NewNode(unstack_out0_repr()) + ->assert_is_op_nth_output("unstack", "Y", 0) + ->assert_is_op_nth_input("concat", "X", 0); + auto* concat = pattern->NewNode(concat_repr()) + ->assert_is_op("concat") + ->assert_more([](Node* node) { + auto axis = node->Op()->GetAttrIfExists("axis"); + return axis == -2; + }); + auto* concat_out = pattern->NewNode(concat_out_repr()) + ->assert_is_op_output("concat", "Out") + ->assert_more([](Node* node) { + auto out_nodes = node->outputs; + if (out_nodes.size() <= 1) { + return false; + } + for (auto out_node : out_nodes) { + if (out_node->Name() != "slice") { + return false; + } + } + return true; + }); + reshape->LinksFrom({reshape_in}).LinksTo({reshape_out}); + unstack->LinksFrom({reshape_out}).LinksTo({unstack_out0}); + concat->LinksFrom({unstack_out0}).LinksTo({concat_out}); +} + +} // namespace patterns + +class ReshapeUnstackConcatFusePass : public FusePassBase { + protected: + void 
ApplyImpl(ir::Graph* graph) const override; + + private: + const std::string name_scope_{"reshape_unstack_concat_fuse_pass"}; +}; + +// clang-format off +/* +Origin subgraph: + reshape(4,-1,48,2,16,4096) + | + unstack + | + concat + | + ------------------------------------------------------------------ + | | | +slice(start/end/axes:0/1/1) slice(start/end/axes:1/2/1) ... slice(start/end/axes:n-1/n/1) + | | | +reshape(-1,2,64,4,1024) reshape(-1,2,64,4,1024) ... reshape(-1,2,64,4,1024) + | | | +slice(start/end/axes:0/1/3) slice(start/end/axes:0/1/3) ... slice(start/end/axes:0/1/3) + | | | +reshape(-1,2,64,16,64) reshape(-1,2,64,16,64) ... reshape(-1,2,64,16,64) + | | | +transpose(1,0,3,2,4) transpose(1,0,3,2,4) ... transpose(1,0,3,2,4) + +Optimized subgraph: + reshape(-1,4,1024) + | + slice(start/end/axes:0/1/2) + | + reshape(4,-1,48,2,16,1024) + | + unstack + | + concat + | + reshape(-1,n*2,64,16,64) + | + transpose(1,0,3,2,4) + | + split(num/axis:n/0) + | + ------------------------------------------------------------------ + | | | +*/ +// clang-format on +void ReshapeUnstackConcatFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + GraphPatternDetector gpd; + patterns::ReshapeUnstackConcatPattern pattern(gpd.mutable_pattern(), + name_scope_); + + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle ReshapeUnstackConcatFusePass fuse"; + GET_IR_NODE(reshape); + GET_IR_NODE(unstack); + GET_IR_NODE(concat); + GET_IR_NODE(reshape_in); + GET_IR_NODE(reshape_out); + GET_IR_NODE(unstack_out0); + GET_IR_NODE(concat_out); + auto* block = reshape->Op()->Block(); + + auto concat_out_ops = concat_out->outputs; + int ops_num = concat_out_ops.size(); + std::vector slice_0s(ops_num, nullptr); + std::vector reshape_0s(ops_num, nullptr); + std::vector slice_1s(ops_num, nullptr); + std::vector reshape_1s(ops_num, nullptr); + std::vector transposes(ops_num, nullptr); + for (int i = 0; i < ops_num; i++) { + auto slice_0 = concat_out_ops[i]; + if (slice_0->Name() != "slice") return; + auto slice_0_starts = + slice_0->Op()->GetAttrIfExists>("starts"); + auto slice_0_ends = + slice_0->Op()->GetAttrIfExists>("ends"); + auto slice_0_axes = + slice_0->Op()->GetAttrIfExists>("axes"); + if (slice_0_starts.size() != 1 || + (slice_0_ends[0] - slice_0_starts[0] != 1) || slice_0_axes[0] != 1) { + return; + } + int op_index = slice_0_starts[0]; + if (slice_0s[op_index] != nullptr) return; + slice_0s[op_index] = slice_0; + + auto reshape_0 = slice_0->outputs[0]->outputs[0]; + if (reshape_0->Name() != "reshape2") return; + auto reshape_0_shape = + reshape_0->Op()->GetAttrIfExists>("shape"); + if (reshape_0_shape.size() != 5) return; + reshape_0s[op_index] = reshape_0; + + Node* slice_1 = nullptr; + for (auto reshape_out : reshape_0->outputs) { + if (reshape_out->Name() == reshape_0->Op()->Output("Out")[0]) { + slice_1 = reshape_out->outputs[0]; + if (slice_1->Name() != "slice") return; + auto slice_1_axes = + slice_1->Op()->GetAttrIfExists>("axes"); + if (slice_1_axes.size() != 1 || slice_1_axes[0] != 3) { + return; + } + slice_1s[op_index] = slice_1; + } + } + + auto* reshape_1 = slice_1->outputs[0]->outputs[0]; + if (reshape_1->Name() != "reshape2") return; + auto reshape_1_shape = + reshape_1->Op()->GetAttrIfExists>("shape"); + if (reshape_1_shape.size() != 5) return; + reshape_1s[op_index] = reshape_1; + + Node* 
transpose = nullptr; + for (auto reshape_out : reshape_1->outputs) { + if (reshape_out->Name() == reshape_1->Op()->Output("Out")[0]) { + transpose = reshape_out->outputs[0]; + if (transpose->Name() != "transpose2") return; + auto transpose_axis = + transpose->Op()->GetAttrIfExists>("axis"); + if (transpose_axis != std::vector{1, 0, 3, 2, 4}) return; + transposes[op_index] = transpose; + } + } + } + + std::string new_reshape_0_out_name = reshape_in->Name() + "_reshape_out"; + VarDesc new_reshape_0_out_desc(new_reshape_0_out_name); + Node* new_reshape_0_out = graph->CreateVarNode(&new_reshape_0_out_desc); + + framework::OpDesc new_reshape_0_op_desc(block); + new_reshape_0_op_desc.SetType("reshape2"); + auto reshape_0_shape = + reshape_0s[0]->Op()->GetAttrIfExists>("shape"); + std::vector new_reshape_0_shape{ + -1, reshape_0_shape[3], reshape_0_shape[4]}; + new_reshape_0_op_desc.SetAttr("shape", new_reshape_0_shape); + new_reshape_0_op_desc.SetInput("X", {reshape_in->Name()}); + new_reshape_0_op_desc.SetOutput("Out", {new_reshape_0_out_name}); + auto* new_reshape_0 = graph->CreateOpNode(&new_reshape_0_op_desc); + + std::string new_slice_0_out_name = reshape_in->Name() + "_slice_out"; + VarDesc new_slice_0_out_desc(new_slice_0_out_name); + Node* new_slice_0_out = graph->CreateVarNode(&new_slice_0_out_desc); + + framework::OpDesc new_slice_0_op_desc(block); + new_slice_0_op_desc.SetType("slice"); + auto new_slice_0_start = + slice_1s[0]->Op()->GetAttrIfExists>("starts"); + auto new_slice_0_ends = + slice_1s[0]->Op()->GetAttrIfExists>("ends"); + new_slice_0_op_desc.SetAttr("starts", new_slice_0_start); + new_slice_0_op_desc.SetAttr("ends", new_slice_0_ends); + new_slice_0_op_desc.SetAttr("axes", std::vector{1}); + new_slice_0_op_desc.SetAttr("decrease_axis", std::vector{1}); + new_slice_0_op_desc.SetInput("Input", {new_reshape_0_out_name}); + new_slice_0_op_desc.SetOutput("Out", {new_slice_0_out_name}); + auto* new_slice_0 = graph->CreateOpNode(&new_slice_0_op_desc); + + reshape->Op()->SetInput("X", {new_slice_0_out_name}); + auto reshape_shape = + reshape->Op()->GetAttrIfExists>("shape"); + reshape_shape[5] /= reshape_0_shape[3]; + reshape->Op()->SetAttr("shape", reshape_shape); + IR_NODE_UNLINK(reshape_in, reshape); + IR_NODE_LINK_TO(reshape_in, new_reshape_0); + IR_NODE_LINK_TO(new_reshape_0, new_reshape_0_out); + IR_NODE_LINK_TO(new_reshape_0_out, new_slice_0); + IR_NODE_LINK_TO(new_slice_0, new_slice_0_out); + IR_NODE_LINK_TO(new_slice_0_out, reshape); + + std::string new_reshape_1_out_name = concat_out->Name() + "_reshape_out"; + VarDesc new_reshape_1_out_desc(new_reshape_1_out_name); + Node* new_reshape_1_out = graph->CreateVarNode(&new_reshape_1_out_desc); + + framework::OpDesc new_reshape_1_op_desc(block); + new_reshape_1_op_desc.SetType("reshape2"); + auto new_reshape_1_shape = + reshape_1s[0]->Op()->GetAttrIfExists>("shape"); + new_reshape_1_shape[1] *= ops_num; + new_reshape_1_op_desc.SetAttr("shape", new_reshape_1_shape); + new_reshape_1_op_desc.SetInput("X", {concat_out->Name()}); + new_reshape_1_op_desc.SetOutput("Out", {new_reshape_1_out_name}); + auto* new_reshape_1 = graph->CreateOpNode(&new_reshape_1_op_desc); + + std::string new_transpose_0_out_name = + concat_out->Name() + "_transpose_out"; + VarDesc new_transpose_0_out_desc(new_transpose_0_out_name); + Node* new_transpose_0_out = graph->CreateVarNode(&new_transpose_0_out_desc); + + framework::OpDesc new_transpose_0_op_desc(block); + new_transpose_0_op_desc.SetType("transpose2"); + auto transpose_axis = + 
transposes[0]->Op()->GetAttrIfExists>("axis"); + new_transpose_0_op_desc.SetAttr("axis", transpose_axis); + new_transpose_0_op_desc.SetInput("X", {new_reshape_1_out_name}); + new_transpose_0_op_desc.SetOutput("Out", {new_transpose_0_out_name}); + auto* new_transpose_0 = graph->CreateOpNode(&new_transpose_0_op_desc); + + std::vector new_split_0_out_names; + for (auto* transpose : transposes) { + new_split_0_out_names.push_back(transpose->Op()->Output("Out")[0]); + } + + framework::OpDesc new_split_0_op_desc(block); + new_split_0_op_desc.SetType("split"); + new_split_0_op_desc.SetAttr("num", ops_num); + new_split_0_op_desc.SetAttr("axis", 0); + new_split_0_op_desc.SetInput("X", {new_transpose_0_out_name}); + new_split_0_op_desc.SetOutput("Out", new_split_0_out_names); + auto* new_split_0 = graph->CreateOpNode(&new_split_0_op_desc); + + IR_NODE_LINK_TO(concat_out, new_reshape_1); + IR_NODE_LINK_TO(new_reshape_1, new_reshape_1_out); + IR_NODE_LINK_TO(new_reshape_1_out, new_transpose_0); + IR_NODE_LINK_TO(new_transpose_0, new_transpose_0_out); + IR_NODE_LINK_TO(new_transpose_0_out, new_split_0); + for (auto* transpose : transposes) { + for (auto* transpose_out : transpose->outputs) { + if (transpose_out->Name() == transpose->Op()->Output("Out")[0]) { + IR_NODE_LINK_TO(new_split_0, transpose_out); + } + } + } + + std::unordered_set delete_nodes; + delete_nodes.insert(slice_0s.begin(), slice_0s.end()); + for (auto* slice_0 : slice_0s) { + delete_nodes.emplace(slice_0->outputs[0]); + } + delete_nodes.insert(reshape_0s.begin(), reshape_0s.end()); + for (auto* reshape_0 : reshape_0s) { + auto reshape_0_outs = reshape_0->outputs; + delete_nodes.insert(reshape_0_outs.begin(), reshape_0_outs.end()); + } + delete_nodes.insert(slice_1s.begin(), slice_1s.end()); + for (auto* slice_1 : slice_1s) { + delete_nodes.emplace(slice_1->outputs[0]); + } + delete_nodes.insert(reshape_1s.begin(), reshape_1s.end()); + for (auto* reshape_1 : reshape_1s) { + auto reshape_1_outs = reshape_1->outputs; + delete_nodes.insert(reshape_1_outs.begin(), reshape_1_outs.end()); + } + delete_nodes.insert(transposes.begin(), transposes.end()); + GraphSafeRemoveNodes(graph, delete_nodes); + + found_subgraph_count++; + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(reshape_unstack_concat_fuse_pass, + paddle::framework::ir::ReshapeUnstackConcatFusePass); + +REGISTER_PASS_CAPABILITY(reshape_unstack_concat_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "stack", 0)); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 9334214d683c36e0f42f479648e3b15abd669716..0edcaac4335a2588b28f9b09b28fc07aa93238d7 100755 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -512,6 +512,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "delete_concat_op_pass", "identity_op_clean_pass", "delete_repeated_ops_pass", + "reshape_unstack_concat_fuse_pass", "delete_op_device_pass", "constant_folding_pass", "delete_elementwise_mul_op_pass", @@ -525,6 +526,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "fold_interp_outsize_fuse_pass", "fold_two_squeeze2_fuse_pass", "delete_cast_op_pass", + "xpu_delete_cast_op_pass", "stack_fuse_pass", "fused_multi_transformer_xpu_pass", "sigmoid_elementmul_fuse_pass", @@ -539,7 +541,6 @@ 
XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "link_xpu_op_max_pass", "inplace_op_var_pass", "delete_isolated_node_pass", - "xpu_delete_cast_op_pass", }); use_xpu_ = true; } diff --git a/test/ir/inference/test_xpu_delete_repeated_ops_pass.py b/test/ir/inference/test_xpu_delete_repeated_ops_pass.py index d05e529cb31ae11d5a4f6efbbb60275dd696f0f6..5f7799aaee83d45fb9dffdafc8914889d94a106b 100644 --- a/test/ir/inference/test_xpu_delete_repeated_ops_pass.py +++ b/test/ir/inference/test_xpu_delete_repeated_ops_pass.py @@ -19,10 +19,10 @@ from auto_scan_test import PassAutoScanTest from program_config import OpConfig, ProgramConfig, TensorConfig -class TestDeleteRepeatedShapePass(PassAutoScanTest): +class TestDeleteRepeatedShapeCastPass(PassAutoScanTest): def sample_predictor_configs(self, program_config): config = self.create_inference_config(use_xpu=True) - yield config, ['shape', 'cast', 'cast', 'cast'], (1e-5, 1e-5) + yield config, ['shape', 'cast', 'relu', 'relu', 'relu'], (1e-5, 1e-5) def sample_program_config(self, draw): x_shape = draw( @@ -47,6 +47,13 @@ class TestDeleteRepeatedShapePass(PassAutoScanTest): out_dtype=5, outputs={"Out": ["cast0_out"]}, ) + relu_op0 = OpConfig( + "relu", + inputs={ + "X": ["cast0_out"], + }, + outputs={"Out": ["relu0_out"]}, + ) shape_op1 = OpConfig( "shape", inputs={ @@ -63,6 +70,13 @@ class TestDeleteRepeatedShapePass(PassAutoScanTest): out_dtype=5, outputs={"Out": ["cast1_out"]}, ) + relu_op1 = OpConfig( + "relu", + inputs={ + "X": ["cast1_out"], + }, + outputs={"Out": ["relu1_out"]}, + ) shape_op2 = OpConfig( "shape", inputs={ @@ -79,7 +93,24 @@ class TestDeleteRepeatedShapePass(PassAutoScanTest): out_dtype=5, outputs={"Out": ["cast2_out"]}, ) - ops = [shape_op0, cast_op0, shape_op1, cast_op1, shape_op2, cast_op2] + relu_op2 = OpConfig( + "relu", + inputs={ + "X": ["cast2_out"], + }, + outputs={"Out": ["relu2_out"]}, + ) + ops = [ + shape_op0, + cast_op0, + relu_op0, + shape_op1, + cast_op1, + relu_op1, + shape_op2, + cast_op2, + relu_op2, + ] program_config = ProgramConfig( ops=ops, @@ -87,7 +118,7 @@ class TestDeleteRepeatedShapePass(PassAutoScanTest): inputs={ "shape_x": TensorConfig(shape=x_shape), }, - outputs=["cast0_out", "cast1_out", "cast2_out"], + outputs=["relu0_out", "relu1_out", "relu2_out"], ) return program_config @@ -102,7 +133,7 @@ class TestDeleteRepeatedShapePass(PassAutoScanTest): class TestDeleteRepeatedSlicePass(PassAutoScanTest): def sample_predictor_configs(self, program_config): config = self.create_inference_config(use_xpu=True) - yield config, ['slice'], (1e-5, 1e-5) + yield config, ['slice', 'relu', 'relu', 'relu'], (1e-5, 1e-5) def sample_program_config(self, draw): slice_x = draw( @@ -122,6 +153,13 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest): decrease_axis=[0], outputs={"Out": ["slice0_out"]}, ) + relu_op0 = OpConfig( + "relu", + inputs={ + "X": ["slice0_out"], + }, + outputs={"Out": ["relu0_out"]}, + ) slice_op1 = OpConfig( "slice", inputs={ @@ -133,6 +171,13 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest): decrease_axis=[0], outputs={"Out": ["slice1_out"]}, ) + relu_op1 = OpConfig( + "relu", + inputs={ + "X": ["slice1_out"], + }, + outputs={"Out": ["relu1_out"]}, + ) slice_op2 = OpConfig( "slice", inputs={ @@ -144,7 +189,14 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest): decrease_axis=[0], outputs={"Out": ["slice2_out"]}, ) - ops = [slice_op0, slice_op1, slice_op2] + relu_op2 = OpConfig( + "relu", + inputs={ + "X": ["slice2_out"], + }, + outputs={"Out": ["relu2_out"]}, + ) + ops = 
[slice_op0, relu_op0, slice_op1, relu_op1, slice_op2, relu_op2] program_config = ProgramConfig( ops=ops, @@ -152,7 +204,171 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest): inputs={ "slice_x": TensorConfig(shape=slice_x), }, - outputs=["slice0_out", "slice1_out", "slice2_out"], + outputs=["relu0_out", "relu1_out", "relu2_out"], + ) + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=25, + passes=["delete_repeated_ops_pass"], + ) + + +class TestDeleteRepeatedAddPass(PassAutoScanTest): + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_xpu=True) + yield config, ['elementwise_add', 'relu', 'relu', 'relu'], (1e-5, 1e-5) + + def sample_program_config(self, draw): + add_x = draw( + st.lists( + st.integers(min_value=1, max_value=20), min_size=2, max_size=4 + ) + ) + + add_op0 = OpConfig( + "elementwise_add", + inputs={ + "X": ["add_x"], + "Y": ["add_y"], + }, + axis=-1, + outputs={"Out": ["add0_out"]}, + ) + relu_op0 = OpConfig( + "relu", + inputs={ + "X": ["add0_out"], + }, + outputs={"Out": ["relu0_out"]}, + ) + add_op1 = OpConfig( + "elementwise_add", + inputs={ + "X": ["add_x"], + "Y": ["add_y"], + }, + axis=-1, + outputs={"Out": ["add1_out"]}, + ) + relu_op1 = OpConfig( + "relu", + inputs={ + "X": ["add1_out"], + }, + outputs={"Out": ["relu1_out"]}, + ) + add_op2 = OpConfig( + "elementwise_add", + inputs={ + "X": ["add_x"], + "Y": ["add_y"], + }, + axis=-1, + outputs={"Out": ["add2_out"]}, + ) + relu_op2 = OpConfig( + "relu", + inputs={ + "X": ["add2_out"], + }, + outputs={"Out": ["relu2_out"]}, + ) + ops = [add_op0, relu_op0, add_op1, relu_op1, add_op2, relu_op2] + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "add_x": TensorConfig(shape=add_x), + "add_y": TensorConfig(shape=add_x), + }, + outputs=["relu0_out", "relu1_out", "relu2_out"], + ) + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=25, + passes=["delete_repeated_ops_pass"], + ) + + +class TestDeleteRepeatedScalePass(PassAutoScanTest): + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_xpu=True) + yield config, ['scale', 'relu', 'relu', 'relu'], (1e-5, 1e-5) + + def sample_program_config(self, draw): + scale_x = draw( + st.lists( + st.integers(min_value=1, max_value=20), min_size=2, max_size=4 + ) + ) + + scale_op0 = OpConfig( + "scale", + inputs={ + "X": ["scale_x"], + }, + scale=2.0, + bias=1.0, + bias_after_scale=True, + outputs={"Out": ["scale0_out"]}, + ) + relu_op0 = OpConfig( + "relu", + inputs={ + "X": ["scale0_out"], + }, + outputs={"Out": ["relu0_out"]}, + ) + scale_op1 = OpConfig( + "scale", + inputs={ + "X": ["scale_x"], + }, + scale=2.0, + bias=1.0, + bias_after_scale=True, + outputs={"Out": ["scale1_out"]}, + ) + relu_op1 = OpConfig( + "relu", + inputs={ + "X": ["scale1_out"], + }, + outputs={"Out": ["relu1_out"]}, + ) + scale_op2 = OpConfig( + "scale", + inputs={ + "X": ["scale_x"], + }, + scale=2.0, + bias=1.0, + bias_after_scale=True, + outputs={"Out": ["scale2_out"]}, + ) + relu_op2 = OpConfig( + "relu", + inputs={ + "X": ["scale2_out"], + }, + outputs={"Out": ["relu2_out"]}, + ) + ops = [scale_op0, relu_op0, scale_op1, relu_op1, scale_op2, relu_op2] + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "scale_x": TensorConfig(shape=scale_x), + }, + outputs=["relu0_out", "relu1_out", "relu2_out"], ) return program_config diff --git 
a/test/ir/inference/test_xpu_multi_encoder_xpu_fuse_pass.py b/test/ir/inference/test_xpu_multi_encoder_xpu_fuse_pass.py index a43fb2e3839c92027bd486a19f0cc7c8d890d114..47e367da7b52e0f305c98119689e6c8e7b1b3c48 100644 --- a/test/ir/inference/test_xpu_multi_encoder_xpu_fuse_pass.py +++ b/test/ir/inference/test_xpu_multi_encoder_xpu_fuse_pass.py @@ -294,9 +294,6 @@ class TestMultiEncoderXPUFusePass(PassAutoScanTest): qkv_add_3_bias_shape = [qkv_matmul_3_w_shape[1]] ln_1_bias_shape = [q_matmul_x_shape[2]] - # def generate_q_matmul_w(): - # return np.random.random(x_shape).astype(np.float32) - program_config = ProgramConfig( ops=ops, weights={ diff --git a/test/ir/inference/test_xpu_reshape_unstack_concat_fuse_pass.py b/test/ir/inference/test_xpu_reshape_unstack_concat_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..80d5a3eaf6457590c27ad12262fcf91837a8324f --- /dev/null +++ b/test/ir/inference/test_xpu_reshape_unstack_concat_fuse_pass.py @@ -0,0 +1,162 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from auto_scan_test import PassAutoScanTest +from program_config import OpConfig, ProgramConfig, TensorConfig + + +class TestReshapeUnstackConcatFusePass(PassAutoScanTest): + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_xpu=True) + yield config, [ + "reshape2", + "slice", + "reshape2", + "unstack", + "concat", + "reshape2", + "transpose2", + "split", + ], (1e-3, 1e-3) + + def sample_program_config(self, draw): + reshape_x_shape = [4, 48, 2, 16, 4096] + + reshape_op = OpConfig( + "reshape2", + inputs={"X": ["reshape_x"]}, + outputs={"Out": ["reshape_out"], "XShape": ["reshape_xshape"]}, + shape=[4, -1, 48, 2, 16, 4096], + ) + unstack_op = OpConfig( + "unstack", + inputs={"X": ["reshape_out"]}, + outputs={ + "Y": [ + "unstakc_out0", + "unstakc_out1", + "unstakc_out2", + "unstakc_out3", + ] + }, + axis=0, + num=4, + ) + concat_op = OpConfig( + "concat", + inputs={ + "X": [ + "unstakc_out0", + "unstakc_out1", + "unstakc_out2", + "unstakc_out3", + ] + }, + outputs={"Out": ["concat_out"]}, + axis=-2, + ) + + slice_0s = [] + reshape_0s = [] + slice_1s = [] + reshape_1s = [] + transposes = [] + out_names = [] + for i in range(48): + slice_0_op = OpConfig( + "slice", + inputs={"Input": ["concat_out"]}, + outputs={"Out": ["slice_0_" + str(i) + "_out"]}, + starts=[i], + ends=[i + 1], + axes=[1], + decrease_axis=[], + ) + slice_0s.append(slice_0_op) + + reshape_0_op = OpConfig( + "reshape2", + inputs={"X": ["slice_0_" + str(i) + "_out"]}, + outputs={ + "Out": ["reshape_0_" + str(i) + "_out"], + "XShape": ["reshape_0_" + str(i) + "_xshape"], + }, + shape=[-1, 2, 64, 4, 1024], + ) + reshape_0s.append(reshape_0_op) + + slice_1_op = OpConfig( + "slice", + inputs={"Input": ["reshape_0_" + str(i) + "_out"]}, + outputs={"Out": ["slice_1_" + str(i) + "_out"]}, + starts=[1], + ends=[2], + axes=[3], + decrease_axis=[3], + ) + slice_1s.append(slice_1_op) 
+ + reshape_1_op = OpConfig( + "reshape2", + inputs={"X": ["slice_1_" + str(i) + "_out"]}, + outputs={ + "Out": ["reshape_1_" + str(i) + "_out"], + "XShape": ["reshape_1_" + str(i) + "_xshape"], + }, + shape=[-1, 2, 64, 16, 64], + ) + reshape_1s.append(reshape_1_op) + + transpose_op = OpConfig( + "transpose2", + inputs={"X": ["reshape_1_" + str(i) + "_out"]}, + outputs={ + "Out": ["transpose_" + str(i) + "_out"], + "XShape": ["transpose_" + str(i) + "_xshape"], + }, + axis=[1, 0, 3, 2, 4], + ) + transposes.append(transpose_op) + out_names.append("transpose_" + str(i) + "_out") + + ops = [reshape_op, unstack_op, concat_op] + ops.extend(slice_0s) + ops.extend(reshape_0s) + ops.extend(slice_1s) + ops.extend(reshape_1s) + ops.extend(transposes) + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "reshape_x": TensorConfig(shape=reshape_x_shape), + }, + outputs=out_names, + ) + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=1, + min_success_num=1, + passes=["reshape_unstack_concat_fuse_pass"], + ) + + +if __name__ == "__main__": + unittest.main()
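
A minimal Python sketch (not part of this PR) of the dedup strategy behind the refactored delete_repeated_ops_pass: ops of a given type that read the same input variable are grouped by an attribute key (in the GenSliceAttrKey / GenCastAttrKey style), the first op of each group survives, and consumers of the duplicates are redirected to the surviving output. The dict-based op representation and helper names below are illustrative only; the real pass also skips groups whose consumers are while/conditional_block/fetch ops.

from collections import defaultdict

def dedup_repeated_ops(ops, op_type, gen_key):
    # ops: dicts like {"type": ..., "attrs": {...}, "out": "var_name"} that all
    # read the same input variable. Keep the first op of each attribute-key group;
    # return a rename map so consumers of deleted outputs read the survivor instead.
    groups = defaultdict(list)
    survivors, rename_map = [], {}
    for op in ops:
        if op["type"] != op_type:
            survivors.append(op)
            continue
        key = gen_key(op)
        groups[key].append(op)
        if len(groups[key]) == 1:
            survivors.append(op)          # first op with this attribute key is kept
        else:
            rename_map[op["out"]] = groups[key][0]["out"]
    return survivors, rename_map

# Two casts with identical in/out dtypes collapse into one (GenCastAttrKey style).
cast_key = lambda op: "in_%d_out_%d" % (op["attrs"]["in_dtype"], op["attrs"]["out_dtype"])
ops = [
    {"type": "cast", "attrs": {"in_dtype": 5, "out_dtype": 3}, "out": "cast0_out"},
    {"type": "cast", "attrs": {"in_dtype": 5, "out_dtype": 3}, "out": "cast1_out"},
]
survivors, rename_map = dedup_repeated_ops(ops, "cast", cast_key)
print(len(survivors), rename_map)   # 1 {'cast1_out': 'cast0_out'}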
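
A standalone NumPy sketch (again, not part of this PR) that mirrors the shapes used in test_xpu_reshape_unstack_concat_fuse_pass.py and checks that the original reshape/unstack/concat/slice subgraph and the fused form from the ASCII diagram produce identical tensors. The function names are made up, NumPy ops stand in for the Paddle ops, and the per-branch inner slice picks index 1 on the size-4 axis as the unit test does (the diagram shows starts/ends 0/1).

import numpy as np

n = 48                                              # number of slice branches in the test
x = np.random.default_rng(0).random((4, 48, 2, 16, 4096), dtype=np.float32)

def origin_subgraph(x):
    r = x.reshape(4, -1, 48, 2, 16, 4096)
    c = np.concatenate([r[i] for i in range(4)], axis=-2)   # unstack(axis=0) + concat(axis=-2)
    outs = []
    for i in range(n):
        s0 = c[:, i:i + 1]                          # slice, axes=[1], no decrease_axis
        r0 = s0.reshape(-1, 2, 64, 4, 1024)
        s1 = r0[:, :, :, 1]                         # slice, axes=[3], decrease_axis=[3]
        r1 = s1.reshape(-1, 2, 64, 16, 64)
        outs.append(r1.transpose(1, 0, 3, 2, 4))
    return outs

def fused_subgraph(x):
    s = x.reshape(-1, 4, 1024)[:, 1]                # new reshape + slice hoisted above unstack
    r = s.reshape(4, -1, 48, 2, 16, 1024)           # original reshape, last dim divided by 4
    c = np.concatenate([r[i] for i in range(4)], axis=-2)   # unstack(axis=0) + concat(axis=-2)
    t = c.reshape(-1, n * 2, 64, 16, 64).transpose(1, 0, 3, 2, 4)
    return np.split(t, n, axis=0)                   # split(num=n, axis=0) feeds the n consumers

for a, b in zip(origin_subgraph(x), fused_subgraph(x)):
    assert a.shape == b.shape == (2, 1, 16, 64, 64)
    assert np.array_equal(a, b)
print("all", n, "branch outputs match")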