未验证 提交 05bd4a89 编写于 作者: Z zhupengyang 提交者: GitHub

delete_repeated_ops_pass and reshape_unstack_concat_fuse_pass (#54846)

上级 e50266fe
...@@ -241,6 +241,8 @@ if(WITH_XPU) ...@@ -241,6 +241,8 @@ if(WITH_XPU)
pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS}) ${XPU_PASS_DEPS})
pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(reshape_unstack_concat_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS}) ${XPU_PASS_DEPS})
pass_library(multi_encoder_xpu_adaptive_seqlen_fuse_pass inference DIR xpu pass_library(multi_encoder_xpu_adaptive_seqlen_fuse_pass inference DIR xpu
......
...@@ -101,68 +101,86 @@ class DeleteRepeatedOpsPass : public FusePassBase { ...@@ -101,68 +101,86 @@ class DeleteRepeatedOpsPass : public FusePassBase {
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
private: private:
int DeleteShapePass(ir::Graph* graph) const; void DeleteRepeatedOps(
ir::Graph* graph,
int DeleteSlicePass(ir::Graph* graph) const; const std::string& op_type,
std::function<std::string(OpDesc*)> gen_op_key_fn) const;
const std::string name_scope_{"delete_repeated_ops_pass"}; const std::string name_scope_{"delete_repeated_ops_pass"};
}; };
int DeleteRepeatedOpsPass::DeleteShapePass(ir::Graph* graph) const { void DeleteRepeatedOpsPass::DeleteRepeatedOps(
ir::Graph* graph,
const std::string& op_type,
std::function<std::string(OpDesc*)> gen_op_key_fn) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::VarWithRepeatedOpsPattern pattern( patterns::VarWithRepeatedOpsPattern pattern(
gpd.mutable_pattern(), name_scope_, "shape"); gpd.mutable_pattern(), name_scope_, op_type);
int delete_counts = 0; int delete_counts = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) { Graph* graph) {
VLOG(4) << "handle DeleteShapePass"; VLOG(4) << "handle DeleteRepeatedOps";
GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern); GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern);
std::vector<std::string> invalid_shape_out_ops{"while", std::vector<std::string> invalid_out_ops{
"conditional_block"}; "while", "conditional_block", "fetch"};
std::vector<Node*> shapes; std::map<std::string, std::vector<Node*>> ops_map;
for (auto* next_op : in_var->outputs) { for (auto* next_op : in_var->outputs) {
if (next_op->Name() != "shape") continue; if (next_op->Name() != op_type) continue;
bool shape_out_op_is_invalid = false; auto* op = next_op;
for (auto* shape_out_op : next_op->outputs[0]->outputs) { bool out_op_is_invalid = false;
if (std::count(invalid_shape_out_ops.begin(), for (auto* out_op : op->outputs[0]->outputs) {
invalid_shape_out_ops.end(), if (std::count(invalid_out_ops.begin(),
shape_out_op->Name()) > 0 || invalid_out_ops.end(),
HasOutVarName(shape_out_op, next_op->outputs[0]->Name())) { out_op->Name()) > 0 ||
shape_out_op_is_invalid = true; HasOutVarName(out_op, op->outputs[0]->Name())) {
out_op_is_invalid = true;
break; break;
} }
} }
if (!shape_out_op_is_invalid) { if (out_op_is_invalid) continue;
shapes.push_back(next_op); auto attr_key = gen_op_key_fn(op->Op());
ops_map[attr_key].push_back(op);
}
for (auto iter = ops_map.begin(); iter != ops_map.end();) {
if (iter->second.size() <= 1) {
iter = ops_map.erase(iter);
} else {
iter++;
} }
} }
if (shapes.size() <= 1) return;
auto* first_shape_out = shapes[0]->outputs[0]; for (auto iter : ops_map) {
auto first_shape_out_name = first_shape_out->Name(); auto ops = iter.second;
std::unordered_set<const Node*> delete_nodes; auto* first_op_out = ops[0]->outputs[0];
for (size_t i = 1; i < shapes.size(); i++) { auto first_op_out_name = first_op_out->Name();
auto* cur_shape = shapes[i]; std::unordered_set<const Node*> delete_nodes;
auto* cur_shape_out = cur_shape->outputs[0]; for (size_t i = 1; i < ops.size(); i++) {
auto cur_shape_out_name = cur_shape_out->Name(); auto* cur_op = ops[i];
for (auto* shape_out_op : cur_shape_out->outputs) { auto* cur_op_out = cur_op->outputs[0];
shape_out_op->Op()->Rename(cur_shape_out_name, first_shape_out_name); auto cur_op_out_name = cur_op_out->Name();
IR_NODE_LINK_TO(first_shape_out, shape_out_op); for (auto* out_op : cur_op_out->outputs) {
out_op->Op()->RenameInput(cur_op_out_name, first_op_out_name);
IR_NODE_LINK_TO(first_op_out, out_op);
}
delete_nodes.insert(cur_op);
delete_nodes.insert(cur_op_out);
delete_counts++;
} }
delete_nodes.insert(cur_shape); GraphSafeRemoveNodes(graph, delete_nodes);
delete_nodes.insert(cur_shape_out);
delete_counts++;
} }
GraphSafeRemoveNodes(graph, delete_nodes);
}; };
gpd(graph, handler); gpd(graph, handler);
return delete_counts; if (delete_counts > 0) {
LOG(INFO) << "--- delete " << delete_counts << " repeated " << op_type
<< " ops";
}
} }
std::string GenShapeAttrKey(OpDesc* slice_op_desc) { return ""; }
std::string GenSliceAttrKey(OpDesc* slice_op_desc) { std::string GenSliceAttrKey(OpDesc* slice_op_desc) {
std::string attr_key; std::string attr_key;
auto starts = slice_op_desc->GetAttrIfExists<std::vector<int>>("starts"); auto starts = slice_op_desc->GetAttrIfExists<std::vector<int>>("starts");
...@@ -189,69 +207,27 @@ std::string GenSliceAttrKey(OpDesc* slice_op_desc) { ...@@ -189,69 +207,27 @@ std::string GenSliceAttrKey(OpDesc* slice_op_desc) {
return attr_key; return attr_key;
} }
int DeleteRepeatedOpsPass::DeleteSlicePass(ir::Graph* graph) const { std::string GenCastAttrKey(OpDesc* cast_op_desc) {
GraphPatternDetector gpd; auto in_dtype = cast_op_desc->GetAttrIfExists<int>("in_dtype");
patterns::VarWithRepeatedOpsPattern pattern( auto out_dtype = cast_op_desc->GetAttrIfExists<int>("out_dtype");
gpd.mutable_pattern(), name_scope_, "slice"); return "in_dtype_" + std::to_string(in_dtype) + "_out_dtype_" +
std::to_string(out_dtype);
int delete_counts = 0; }
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle DeleteSlicePass";
GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern);
std::vector<std::string> invalid_slice_out_ops{"while",
"conditional_block"};
std::map<std::string, std::vector<Node*>> slice_ops;
for (auto* next_op : in_var->outputs) {
if (next_op->Name() != "slice") continue;
auto* slice = next_op;
bool slice_out_op_is_invalid = false;
for (auto* slice_out_op : slice->outputs[0]->outputs) {
if (std::count(invalid_slice_out_ops.begin(),
invalid_slice_out_ops.end(),
slice_out_op->Name()) > 0 ||
HasOutVarName(slice_out_op, slice->outputs[0]->Name())) {
slice_out_op_is_invalid = true;
break;
}
}
if (slice_out_op_is_invalid) continue;
auto attr_key = GenSliceAttrKey(slice->Op());
slice_ops[attr_key].push_back(slice);
}
for (auto iter = slice_ops.begin(); iter != slice_ops.end();) {
if (iter->second.size() <= 1) {
iter = slice_ops.erase(iter);
} else {
iter++;
}
}
for (auto iter : slice_ops) { std::string GenAddAttrKey(OpDesc* add_op_desc) {
auto slices = iter.second; std::string x_name = add_op_desc->Input("X")[0];
auto* first_slice_out = slices[0]->outputs[0]; std::string y_name = add_op_desc->Input("Y")[0];
auto first_slice_out_name = first_slice_out->Name(); auto axis = add_op_desc->GetAttrIfExists<int>("axis");
std::unordered_set<const Node*> delete_nodes; return x_name + "_" + y_name + "_axis_" + std::to_string(axis);
for (size_t i = 1; i < slices.size(); i++) { }
auto* cur_slice = slices[i];
auto* cur_slice_out = cur_slice->outputs[0];
auto cur_slice_out_name = cur_slice_out->Name();
for (auto* slice_out_op : cur_slice_out->outputs) {
slice_out_op->Op()->RenameInput(cur_slice_out_name,
first_slice_out_name);
IR_NODE_LINK_TO(first_slice_out, slice_out_op);
}
delete_nodes.insert(cur_slice);
delete_nodes.insert(cur_slice_out);
delete_counts++;
}
GraphSafeRemoveNodes(graph, delete_nodes);
}
};
gpd(graph, handler); std::string GenScaleAttrKey(OpDesc* scale_op_desc) {
return delete_counts; auto scale = scale_op_desc->GetAttrIfExists<float>("scale");
auto bias = scale_op_desc->GetAttrIfExists<float>("bias");
auto bias_after_scale =
scale_op_desc->GetAttrIfExists<bool>("bias_after_scale");
return "scale_" + std::to_string(scale) + "_bias_" + std::to_string(bias) +
"_bias_after_scale_" + std::to_string(bias_after_scale);
} }
void DeleteRepeatedOpsPass::ApplyImpl(ir::Graph* graph) const { void DeleteRepeatedOpsPass::ApplyImpl(ir::Graph* graph) const {
...@@ -259,15 +235,12 @@ void DeleteRepeatedOpsPass::ApplyImpl(ir::Graph* graph) const { ...@@ -259,15 +235,12 @@ void DeleteRepeatedOpsPass::ApplyImpl(ir::Graph* graph) const {
graph, platform::errors::PreconditionNotMet("graph should not be null.")); graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph); Init(name_scope_, graph);
int delete_counts = DeleteShapePass(graph); DeleteRepeatedOps(graph, "shape", GenShapeAttrKey);
if (delete_counts > 0) { DeleteRepeatedOps(graph, "slice", GenSliceAttrKey);
LOG(INFO) << "--- delete " << delete_counts << " repeated shape ops"; DeleteRepeatedOps(graph, "cast", GenCastAttrKey);
} DeleteRepeatedOps(graph, "elementwise_add", GenAddAttrKey);
DeleteRepeatedOps(graph, "scale", GenScaleAttrKey);
delete_counts = DeleteSlicePass(graph); DeleteRepeatedOps(graph, "cast", GenCastAttrKey);
if (delete_counts > 0) {
LOG(INFO) << "--- delete " << delete_counts << " repeated slice ops";
}
} }
} // namespace ir } // namespace ir
......
...@@ -70,6 +70,7 @@ static const std::vector<std::string> xpu_support_subgraph_passes = { ...@@ -70,6 +70,7 @@ static const std::vector<std::string> xpu_support_subgraph_passes = {
"xpu_delete_cast_op_pass", "xpu_delete_cast_op_pass",
"fc_xpu_fuse_pass", "fc_xpu_fuse_pass",
"link_xpu_op_max_pass", "link_xpu_op_max_pass",
"xpu_delete_cast_op_pass",
}; };
Graph *Pass::Apply(Graph *graph) const { Graph *Pass::Apply(Graph *graph) const {
......
...@@ -97,10 +97,10 @@ Reshape2MatmulPattern::Reshape2MatmulPattern(PDPattern* pattern, ...@@ -97,10 +97,10 @@ Reshape2MatmulPattern::Reshape2MatmulPattern(PDPattern* pattern,
->assert_more([](Node* node) { ->assert_more([](Node* node) {
auto reshape2_in_x_shape = node->Var()->GetShape(); auto reshape2_in_x_shape = node->Var()->GetShape();
size_t reshape2_in_rank = reshape2_in_x_shape.size(); size_t reshape2_in_rank = reshape2_in_x_shape.size();
bool nice_shape = return reshape2_in_rank == 4 && ((reshape2_in_x_shape[2] == 1 &&
(reshape2_in_x_shape[2] == 1 && reshape2_in_x_shape[3] == 1) || reshape2_in_x_shape[3] == 1) ||
(reshape2_in_x_shape[1] == 1 && reshape2_in_x_shape[3] == 1); (reshape2_in_x_shape[1] == 1 &&
return (reshape2_in_rank == 4 && nice_shape); reshape2_in_x_shape[3] == 1));
}); });
auto* reshape2 = auto* reshape2 =
pattern->NewNode(reshape2_repr()) pattern->NewNode(reshape2_repr())
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// Pattern describing the chain
//   reshape_in -> reshape2 -> reshape_out -> unstack -> unstack_out0
//              -> concat -> concat_out
// Each PATTERN_DECL_NODE entry names one node of that chain; the constraints
// on every node are set up in the constructor.
struct ReshapeUnstackConcatPattern : public PatternBase {
  ReshapeUnstackConcatPattern(PDPattern* pattern,
                              const std::string& name_scope);

  // Nodes listed in dataflow order: variables and operators interleaved.
  PATTERN_DECL_NODE(reshape_in);    // variable: input "X" of reshape2
  PATTERN_DECL_NODE(reshape);       // operator: reshape2
  PATTERN_DECL_NODE(reshape_out);   // variable: reshape2 "Out" / unstack "X"
  PATTERN_DECL_NODE(unstack);       // operator: unstack
  PATTERN_DECL_NODE(unstack_out0);  // variable: unstack "Y"[0] / concat "X"[0]
  PATTERN_DECL_NODE(concat);        // operator: concat
  PATTERN_DECL_NODE(concat_out);    // variable: concat "Out"
};
// Builds the reshape2 -> unstack -> concat pattern.
//
// Constraints applied to the matched nodes:
//   * reshape2's "shape" attribute has exactly 6 entries;
//   * unstack's "axis" attribute is 0;
//   * concat's "axis" attribute is -2;
//   * concat's output is consumed by at least two ops, all of type "slice".
ReshapeUnstackConcatPattern::ReshapeUnstackConcatPattern(
    PDPattern* pattern, const std::string& name_scope)
    : PatternBase(pattern, name_scope, name_scope) {
  // Attribute / topology predicates, named so the node definitions below read
  // as a plain description of the chain.
  const auto has_six_dim_target_shape = [](Node* op_node) {
    return op_node->Op()->GetAttrIfExists<std::vector<int>>("shape").size() ==
           6;
  };
  const auto unstacks_axis_zero = [](Node* op_node) {
    return op_node->Op()->GetAttrIfExists<int>("axis") == 0;
  };
  const auto concats_axis_minus_two = [](Node* op_node) {
    return op_node->Op()->GetAttrIfExists<int>("axis") == -2;
  };
  const auto consumed_only_by_several_slices = [](Node* var_node) {
    auto consumers = var_node->outputs;
    if (consumers.size() < 2) {
      return false;
    }
    for (auto* consumer : consumers) {
      if (consumer->Name() != "slice") {
        return false;
      }
    }
    return true;
  };

  // Nodes are created in the same order as the chain they describe.
  auto* reshape_in_node =
      pattern->NewNode(reshape_in_repr())->assert_is_op_input("reshape2", "X");
  auto* reshape_node = pattern->NewNode(reshape_repr())
                           ->assert_is_op("reshape2")
                           ->assert_more(has_six_dim_target_shape);
  auto* reshape_out_node = pattern->NewNode(reshape_out_repr())
                               ->assert_is_op_output("reshape2", "Out")
                               ->assert_is_op_input("unstack", "X");
  auto* unstack_node = pattern->NewNode(unstack_repr())
                           ->assert_is_op("unstack")
                           ->assert_more(unstacks_axis_zero);
  auto* unstack_out0_node = pattern->NewNode(unstack_out0_repr())
                                ->assert_is_op_nth_output("unstack", "Y", 0)
                                ->assert_is_op_nth_input("concat", "X", 0);
  auto* concat_node = pattern->NewNode(concat_repr())
                          ->assert_is_op("concat")
                          ->assert_more(concats_axis_minus_two);
  auto* concat_out_node = pattern->NewNode(concat_out_repr())
                              ->assert_is_op_output("concat", "Out")
                              ->assert_more(consumed_only_by_several_slices);

  // Wire the chain together.
  reshape_node->LinksFrom({reshape_in_node}).LinksTo({reshape_out_node});
  unstack_node->LinksFrom({reshape_out_node}).LinksTo({unstack_out0_node});
  concat_node->LinksFrom({unstack_out0_node}).LinksTo({concat_out_node});
}
} // namespace patterns
// Fuse pass that rewrites the reshape2 -> unstack -> concat subgraph (followed
// by parallel slice/reshape2/slice/reshape2/transpose2 branches) into the
// "Optimized subgraph" form drawn in the diagram that precedes ApplyImpl's
// definition.
class ReshapeUnstackConcatFusePass : public FusePassBase {
protected:
// Entry point invoked for the graph being optimized; performs the rewrite.
void ApplyImpl(ir::Graph* graph) const override;
private:
// Name scope passed to Init() and to the pattern built by this pass.
const std::string name_scope_{"reshape_unstack_concat_fuse_pass"};
};
// clang-format off
/*
Origin subgraph:
reshape(4,-1,48,2,16,4096)
|
unstack
|
concat
|
------------------------------------------------------------------
| | |
slice(start/end/axes:0/1/1) slice(start/end/axes:1/2/1) ... slice(start/end/axes:n-1/n/1)
| | |
reshape(-1,2,64,4,1024) reshape(-1,2,64,4,1024) ... reshape(-1,2,64,4,1024)
| | |
slice(start/end/axes:0/1/3) slice(start/end/axes:0/1/3) ... slice(start/end/axes:0/1/3)
| | |
reshape(-1,2,64,16,64) reshape(-1,2,64,16,64) ... reshape(-1,2,64,16,64)
| | |
transpose(1,0,3,2,4) transpose(1,0,3,2,4) ... transpose(1,0,3,2,4)
Optimized subgraph:
reshape(-1,4,1024)
|
slice(start/end/axes:0/1/1)
|
reshape(4,-1,48,2,16,1024)
|
unstack
|
concat
|
reshape(-1,n*2,64,16,64)
|
transpose(1,0,3,2,4)
|
split(num/axis:n/0)
|
------------------------------------------------------------------
| | |
*/
// clang-format on
void ReshapeUnstackConcatFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
GraphPatternDetector gpd;
patterns::ReshapeUnstackConcatPattern pattern(gpd.mutable_pattern(),
name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle ReshapeUnstackConcatFusePass fuse";
GET_IR_NODE(reshape);
GET_IR_NODE(unstack);
GET_IR_NODE(concat);
GET_IR_NODE(reshape_in);
GET_IR_NODE(reshape_out);
GET_IR_NODE(unstack_out0);
GET_IR_NODE(concat_out);
auto* block = reshape->Op()->Block();
auto concat_out_ops = concat_out->outputs;
int ops_num = concat_out_ops.size();
std::vector<Node*> slice_0s(ops_num, nullptr);
std::vector<Node*> reshape_0s(ops_num, nullptr);
std::vector<Node*> slice_1s(ops_num, nullptr);
std::vector<Node*> reshape_1s(ops_num, nullptr);
std::vector<Node*> transposes(ops_num, nullptr);
for (int i = 0; i < ops_num; i++) {
auto slice_0 = concat_out_ops[i];
if (slice_0->Name() != "slice") return;
auto slice_0_starts =
slice_0->Op()->GetAttrIfExists<std::vector<int>>("starts");
auto slice_0_ends =
slice_0->Op()->GetAttrIfExists<std::vector<int>>("ends");
auto slice_0_axes =
slice_0->Op()->GetAttrIfExists<std::vector<int>>("axes");
if (slice_0_starts.size() != 1 ||
(slice_0_ends[0] - slice_0_starts[0] != 1) || slice_0_axes[0] != 1) {
return;
}
int op_index = slice_0_starts[0];
if (slice_0s[op_index] != nullptr) return;
slice_0s[op_index] = slice_0;
auto reshape_0 = slice_0->outputs[0]->outputs[0];
if (reshape_0->Name() != "reshape2") return;
auto reshape_0_shape =
reshape_0->Op()->GetAttrIfExists<std::vector<int>>("shape");
if (reshape_0_shape.size() != 5) return;
reshape_0s[op_index] = reshape_0;
Node* slice_1 = nullptr;
for (auto reshape_out : reshape_0->outputs) {
if (reshape_out->Name() == reshape_0->Op()->Output("Out")[0]) {
slice_1 = reshape_out->outputs[0];
if (slice_1->Name() != "slice") return;
auto slice_1_axes =
slice_1->Op()->GetAttrIfExists<std::vector<int>>("axes");
if (slice_1_axes.size() != 1 || slice_1_axes[0] != 3) {
return;
}
slice_1s[op_index] = slice_1;
}
}
auto* reshape_1 = slice_1->outputs[0]->outputs[0];
if (reshape_1->Name() != "reshape2") return;
auto reshape_1_shape =
reshape_1->Op()->GetAttrIfExists<std::vector<int>>("shape");
if (reshape_1_shape.size() != 5) return;
reshape_1s[op_index] = reshape_1;
Node* transpose = nullptr;
for (auto reshape_out : reshape_1->outputs) {
if (reshape_out->Name() == reshape_1->Op()->Output("Out")[0]) {
transpose = reshape_out->outputs[0];
if (transpose->Name() != "transpose2") return;
auto transpose_axis =
transpose->Op()->GetAttrIfExists<std::vector<int>>("axis");
if (transpose_axis != std::vector<int>{1, 0, 3, 2, 4}) return;
transposes[op_index] = transpose;
}
}
}
std::string new_reshape_0_out_name = reshape_in->Name() + "_reshape_out";
VarDesc new_reshape_0_out_desc(new_reshape_0_out_name);
Node* new_reshape_0_out = graph->CreateVarNode(&new_reshape_0_out_desc);
framework::OpDesc new_reshape_0_op_desc(block);
new_reshape_0_op_desc.SetType("reshape2");
auto reshape_0_shape =
reshape_0s[0]->Op()->GetAttrIfExists<std::vector<int>>("shape");
std::vector<int> new_reshape_0_shape{
-1, reshape_0_shape[3], reshape_0_shape[4]};
new_reshape_0_op_desc.SetAttr("shape", new_reshape_0_shape);
new_reshape_0_op_desc.SetInput("X", {reshape_in->Name()});
new_reshape_0_op_desc.SetOutput("Out", {new_reshape_0_out_name});
auto* new_reshape_0 = graph->CreateOpNode(&new_reshape_0_op_desc);
std::string new_slice_0_out_name = reshape_in->Name() + "_slice_out";
VarDesc new_slice_0_out_desc(new_slice_0_out_name);
Node* new_slice_0_out = graph->CreateVarNode(&new_slice_0_out_desc);
framework::OpDesc new_slice_0_op_desc(block);
new_slice_0_op_desc.SetType("slice");
auto new_slice_0_start =
slice_1s[0]->Op()->GetAttrIfExists<std::vector<int>>("starts");
auto new_slice_0_ends =
slice_1s[0]->Op()->GetAttrIfExists<std::vector<int>>("ends");
new_slice_0_op_desc.SetAttr("starts", new_slice_0_start);
new_slice_0_op_desc.SetAttr("ends", new_slice_0_ends);
new_slice_0_op_desc.SetAttr("axes", std::vector<int>{1});
new_slice_0_op_desc.SetAttr("decrease_axis", std::vector<int>{1});
new_slice_0_op_desc.SetInput("Input", {new_reshape_0_out_name});
new_slice_0_op_desc.SetOutput("Out", {new_slice_0_out_name});
auto* new_slice_0 = graph->CreateOpNode(&new_slice_0_op_desc);
reshape->Op()->SetInput("X", {new_slice_0_out_name});
auto reshape_shape =
reshape->Op()->GetAttrIfExists<std::vector<int>>("shape");
reshape_shape[5] /= reshape_0_shape[3];
reshape->Op()->SetAttr("shape", reshape_shape);
IR_NODE_UNLINK(reshape_in, reshape);
IR_NODE_LINK_TO(reshape_in, new_reshape_0);
IR_NODE_LINK_TO(new_reshape_0, new_reshape_0_out);
IR_NODE_LINK_TO(new_reshape_0_out, new_slice_0);
IR_NODE_LINK_TO(new_slice_0, new_slice_0_out);
IR_NODE_LINK_TO(new_slice_0_out, reshape);
std::string new_reshape_1_out_name = concat_out->Name() + "_reshape_out";
VarDesc new_reshape_1_out_desc(new_reshape_1_out_name);
Node* new_reshape_1_out = graph->CreateVarNode(&new_reshape_1_out_desc);
framework::OpDesc new_reshape_1_op_desc(block);
new_reshape_1_op_desc.SetType("reshape2");
auto new_reshape_1_shape =
reshape_1s[0]->Op()->GetAttrIfExists<std::vector<int>>("shape");
new_reshape_1_shape[1] *= ops_num;
new_reshape_1_op_desc.SetAttr("shape", new_reshape_1_shape);
new_reshape_1_op_desc.SetInput("X", {concat_out->Name()});
new_reshape_1_op_desc.SetOutput("Out", {new_reshape_1_out_name});
auto* new_reshape_1 = graph->CreateOpNode(&new_reshape_1_op_desc);
std::string new_transpose_0_out_name =
concat_out->Name() + "_transpose_out";
VarDesc new_transpose_0_out_desc(new_transpose_0_out_name);
Node* new_transpose_0_out = graph->CreateVarNode(&new_transpose_0_out_desc);
framework::OpDesc new_transpose_0_op_desc(block);
new_transpose_0_op_desc.SetType("transpose2");
auto transpose_axis =
transposes[0]->Op()->GetAttrIfExists<std::vector<int>>("axis");
new_transpose_0_op_desc.SetAttr("axis", transpose_axis);
new_transpose_0_op_desc.SetInput("X", {new_reshape_1_out_name});
new_transpose_0_op_desc.SetOutput("Out", {new_transpose_0_out_name});
auto* new_transpose_0 = graph->CreateOpNode(&new_transpose_0_op_desc);
std::vector<std::string> new_split_0_out_names;
for (auto* transpose : transposes) {
new_split_0_out_names.push_back(transpose->Op()->Output("Out")[0]);
}
framework::OpDesc new_split_0_op_desc(block);
new_split_0_op_desc.SetType("split");
new_split_0_op_desc.SetAttr("num", ops_num);
new_split_0_op_desc.SetAttr("axis", 0);
new_split_0_op_desc.SetInput("X", {new_transpose_0_out_name});
new_split_0_op_desc.SetOutput("Out", new_split_0_out_names);
auto* new_split_0 = graph->CreateOpNode(&new_split_0_op_desc);
IR_NODE_LINK_TO(concat_out, new_reshape_1);
IR_NODE_LINK_TO(new_reshape_1, new_reshape_1_out);
IR_NODE_LINK_TO(new_reshape_1_out, new_transpose_0);
IR_NODE_LINK_TO(new_transpose_0, new_transpose_0_out);
IR_NODE_LINK_TO(new_transpose_0_out, new_split_0);
for (auto* transpose : transposes) {
for (auto* transpose_out : transpose->outputs) {
if (transpose_out->Name() == transpose->Op()->Output("Out")[0]) {
IR_NODE_LINK_TO(new_split_0, transpose_out);
}
}
}
std::unordered_set<const Node*> delete_nodes;
delete_nodes.insert(slice_0s.begin(), slice_0s.end());
for (auto* slice_0 : slice_0s) {
delete_nodes.emplace(slice_0->outputs[0]);
}
delete_nodes.insert(reshape_0s.begin(), reshape_0s.end());
for (auto* reshape_0 : reshape_0s) {
auto reshape_0_outs = reshape_0->outputs;
delete_nodes.insert(reshape_0_outs.begin(), reshape_0_outs.end());
}
delete_nodes.insert(slice_1s.begin(), slice_1s.end());
for (auto* slice_1 : slice_1s) {
delete_nodes.emplace(slice_1->outputs[0]);
}
delete_nodes.insert(reshape_1s.begin(), reshape_1s.end());
for (auto* reshape_1 : reshape_1s) {
auto reshape_1_outs = reshape_1->outputs;
delete_nodes.insert(reshape_1_outs.begin(), reshape_1_outs.end());
}
delete_nodes.insert(transposes.begin(), transposes.end());
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Registers ReshapeUnstackConcatFusePass under the name
// "reshape_unstack_concat_fuse_pass".
REGISTER_PASS(reshape_unstack_concat_fuse_pass,
paddle::framework::ir::ReshapeUnstackConcatFusePass);
// NOTE(review): the capability below is keyed on the "stack" op, while the
// pass itself matches reshape2/unstack/concat/slice/transpose2 and emits
// reshape2/slice/transpose2/split. Confirm whether "stack" is intended here or
// was carried over from another pass.
REGISTER_PASS_CAPABILITY(reshape_unstack_concat_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"stack", 0));
...@@ -512,6 +512,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { ...@@ -512,6 +512,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"delete_concat_op_pass", "delete_concat_op_pass",
"identity_op_clean_pass", "identity_op_clean_pass",
"delete_repeated_ops_pass", "delete_repeated_ops_pass",
"reshape_unstack_concat_fuse_pass",
"delete_op_device_pass", "delete_op_device_pass",
"constant_folding_pass", "constant_folding_pass",
"delete_elementwise_mul_op_pass", "delete_elementwise_mul_op_pass",
...@@ -525,6 +526,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { ...@@ -525,6 +526,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"fold_interp_outsize_fuse_pass", "fold_interp_outsize_fuse_pass",
"fold_two_squeeze2_fuse_pass", "fold_two_squeeze2_fuse_pass",
"delete_cast_op_pass", "delete_cast_op_pass",
"xpu_delete_cast_op_pass",
"stack_fuse_pass", "stack_fuse_pass",
"fused_multi_transformer_xpu_pass", "fused_multi_transformer_xpu_pass",
"sigmoid_elementmul_fuse_pass", "sigmoid_elementmul_fuse_pass",
...@@ -539,7 +541,6 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { ...@@ -539,7 +541,6 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"link_xpu_op_max_pass", "link_xpu_op_max_pass",
"inplace_op_var_pass", "inplace_op_var_pass",
"delete_isolated_node_pass", "delete_isolated_node_pass",
"xpu_delete_cast_op_pass",
}); });
use_xpu_ = true; use_xpu_ = true;
} }
......
...@@ -19,10 +19,10 @@ from auto_scan_test import PassAutoScanTest ...@@ -19,10 +19,10 @@ from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig from program_config import OpConfig, ProgramConfig, TensorConfig
class TestDeleteRepeatedShapePass(PassAutoScanTest): class TestDeleteRepeatedShapeCastPass(PassAutoScanTest):
def sample_predictor_configs(self, program_config): def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True) config = self.create_inference_config(use_xpu=True)
yield config, ['shape', 'cast', 'cast', 'cast'], (1e-5, 1e-5) yield config, ['shape', 'cast', 'relu', 'relu', 'relu'], (1e-5, 1e-5)
def sample_program_config(self, draw): def sample_program_config(self, draw):
x_shape = draw( x_shape = draw(
...@@ -47,6 +47,13 @@ class TestDeleteRepeatedShapePass(PassAutoScanTest): ...@@ -47,6 +47,13 @@ class TestDeleteRepeatedShapePass(PassAutoScanTest):
out_dtype=5, out_dtype=5,
outputs={"Out": ["cast0_out"]}, outputs={"Out": ["cast0_out"]},
) )
relu_op0 = OpConfig(
"relu",
inputs={
"X": ["cast0_out"],
},
outputs={"Out": ["relu0_out"]},
)
shape_op1 = OpConfig( shape_op1 = OpConfig(
"shape", "shape",
inputs={ inputs={
...@@ -63,6 +70,13 @@ class TestDeleteRepeatedShapePass(PassAutoScanTest): ...@@ -63,6 +70,13 @@ class TestDeleteRepeatedShapePass(PassAutoScanTest):
out_dtype=5, out_dtype=5,
outputs={"Out": ["cast1_out"]}, outputs={"Out": ["cast1_out"]},
) )
relu_op1 = OpConfig(
"relu",
inputs={
"X": ["cast1_out"],
},
outputs={"Out": ["relu1_out"]},
)
shape_op2 = OpConfig( shape_op2 = OpConfig(
"shape", "shape",
inputs={ inputs={
...@@ -79,7 +93,24 @@ class TestDeleteRepeatedShapePass(PassAutoScanTest): ...@@ -79,7 +93,24 @@ class TestDeleteRepeatedShapePass(PassAutoScanTest):
out_dtype=5, out_dtype=5,
outputs={"Out": ["cast2_out"]}, outputs={"Out": ["cast2_out"]},
) )
ops = [shape_op0, cast_op0, shape_op1, cast_op1, shape_op2, cast_op2] relu_op2 = OpConfig(
"relu",
inputs={
"X": ["cast2_out"],
},
outputs={"Out": ["relu2_out"]},
)
ops = [
shape_op0,
cast_op0,
relu_op0,
shape_op1,
cast_op1,
relu_op1,
shape_op2,
cast_op2,
relu_op2,
]
program_config = ProgramConfig( program_config = ProgramConfig(
ops=ops, ops=ops,
...@@ -87,7 +118,7 @@ class TestDeleteRepeatedShapePass(PassAutoScanTest): ...@@ -87,7 +118,7 @@ class TestDeleteRepeatedShapePass(PassAutoScanTest):
inputs={ inputs={
"shape_x": TensorConfig(shape=x_shape), "shape_x": TensorConfig(shape=x_shape),
}, },
outputs=["cast0_out", "cast1_out", "cast2_out"], outputs=["relu0_out", "relu1_out", "relu2_out"],
) )
return program_config return program_config
...@@ -102,7 +133,7 @@ class TestDeleteRepeatedShapePass(PassAutoScanTest): ...@@ -102,7 +133,7 @@ class TestDeleteRepeatedShapePass(PassAutoScanTest):
class TestDeleteRepeatedSlicePass(PassAutoScanTest): class TestDeleteRepeatedSlicePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config): def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True) config = self.create_inference_config(use_xpu=True)
yield config, ['slice'], (1e-5, 1e-5) yield config, ['slice', 'relu', 'relu', 'relu'], (1e-5, 1e-5)
def sample_program_config(self, draw): def sample_program_config(self, draw):
slice_x = draw( slice_x = draw(
...@@ -122,6 +153,13 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest): ...@@ -122,6 +153,13 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest):
decrease_axis=[0], decrease_axis=[0],
outputs={"Out": ["slice0_out"]}, outputs={"Out": ["slice0_out"]},
) )
relu_op0 = OpConfig(
"relu",
inputs={
"X": ["slice0_out"],
},
outputs={"Out": ["relu0_out"]},
)
slice_op1 = OpConfig( slice_op1 = OpConfig(
"slice", "slice",
inputs={ inputs={
...@@ -133,6 +171,13 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest): ...@@ -133,6 +171,13 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest):
decrease_axis=[0], decrease_axis=[0],
outputs={"Out": ["slice1_out"]}, outputs={"Out": ["slice1_out"]},
) )
relu_op1 = OpConfig(
"relu",
inputs={
"X": ["slice1_out"],
},
outputs={"Out": ["relu1_out"]},
)
slice_op2 = OpConfig( slice_op2 = OpConfig(
"slice", "slice",
inputs={ inputs={
...@@ -144,7 +189,14 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest): ...@@ -144,7 +189,14 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest):
decrease_axis=[0], decrease_axis=[0],
outputs={"Out": ["slice2_out"]}, outputs={"Out": ["slice2_out"]},
) )
ops = [slice_op0, slice_op1, slice_op2] relu_op2 = OpConfig(
"relu",
inputs={
"X": ["slice2_out"],
},
outputs={"Out": ["relu2_out"]},
)
ops = [slice_op0, relu_op0, slice_op1, relu_op1, slice_op2, relu_op2]
program_config = ProgramConfig( program_config = ProgramConfig(
ops=ops, ops=ops,
...@@ -152,7 +204,171 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest): ...@@ -152,7 +204,171 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest):
inputs={ inputs={
"slice_x": TensorConfig(shape=slice_x), "slice_x": TensorConfig(shape=slice_x),
}, },
outputs=["slice0_out", "slice1_out", "slice2_out"], outputs=["relu0_out", "relu1_out", "relu2_out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
passes=["delete_repeated_ops_pass"],
)
class TestDeleteRepeatedAddPass(PassAutoScanTest):
    """Auto-scan case for delete_repeated_ops_pass on elementwise_add.

    The sampled program holds three elementwise_add ops with identical inputs
    and attributes, each followed by its own relu.
    """

    def sample_predictor_configs(self, program_config):
        # The op-type list is presumably what the harness expects to remain
        # after the pass has run -- confirm against PassAutoScanTest.
        config = self.create_inference_config(use_xpu=True)
        yield config, ['elementwise_add', 'relu', 'relu', 'relu'], (1e-5, 1e-5)

    def sample_program_config(self, draw):
        # Random 2-D to 4-D shape, shared by both addends.
        add_x = draw(
            st.lists(
                st.integers(min_value=1, max_value=20), min_size=2, max_size=4
            )
        )

        # Three identical add -> relu chains that differ only in output names.
        ops = []
        for branch in range(3):
            add_out_name = f"add{branch}_out"
            ops.append(
                OpConfig(
                    "elementwise_add",
                    inputs={
                        "X": ["add_x"],
                        "Y": ["add_y"],
                    },
                    axis=-1,
                    outputs={"Out": [add_out_name]},
                )
            )
            ops.append(
                OpConfig(
                    "relu",
                    inputs={
                        "X": [add_out_name],
                    },
                    outputs={"Out": [f"relu{branch}_out"]},
                )
            )

        return ProgramConfig(
            ops=ops,
            weights={},
            inputs={
                "add_x": TensorConfig(shape=add_x),
                "add_y": TensorConfig(shape=add_x),
            },
            outputs=["relu0_out", "relu1_out", "relu2_out"],
        )

    def test(self):
        self.run_and_statis(
            quant=False,
            max_examples=25,
            passes=["delete_repeated_ops_pass"],
        )
class TestDeleteRepeatedScalePass(PassAutoScanTest):
    """Auto-scan case built from three identical ``scale`` ops.

    Every scale reads ``scale_x`` with ``scale=2.0``, ``bias=1.0`` and
    ``bias_after_scale=True`` and feeds its own ``relu``.
    """

    def sample_predictor_configs(self, program_config):
        """Yield one XPU inference config with an op-type list and a value pair.

        NOTE(review): the list is presumably the op sequence expected after the
        pass and the pair the comparison tolerances -- confirm against
        ``PassAutoScanTest``.
        """
        xpu_config = self.create_inference_config(use_xpu=True)
        expected_ops = ["scale"] + ["relu"] * 3
        yield xpu_config, expected_ops, (1e-5, 1e-5)

    def sample_program_config(self, draw):
        """Draw a 2-4 element shape (each entry in 1..20) and build the program."""
        dims = draw(
            st.lists(
                st.integers(min_value=1, max_value=20), min_size=2, max_size=4
            )
        )

        # Emit scale/relu pairs in interleaved order: scale0, relu0, scale1, ...
        ops = []
        for idx in range(3):
            scaled = f"scale{idx}_out"
            ops.append(
                OpConfig(
                    "scale",
                    inputs={
                        "X": ["scale_x"],
                    },
                    scale=2.0,
                    bias=1.0,
                    bias_after_scale=True,
                    outputs={"Out": [scaled]},
                )
            )
            ops.append(
                OpConfig(
                    "relu",
                    inputs={"X": [scaled]},
                    outputs={"Out": [f"relu{idx}_out"]},
                )
            )

        return ProgramConfig(
            ops=ops,
            weights={},
            inputs={
                "scale_x": TensorConfig(shape=dims),
            },
            outputs=["relu0_out", "relu1_out", "relu2_out"],
        )
......
...@@ -294,9 +294,6 @@ class TestMultiEncoderXPUFusePass(PassAutoScanTest): ...@@ -294,9 +294,6 @@ class TestMultiEncoderXPUFusePass(PassAutoScanTest):
qkv_add_3_bias_shape = [qkv_matmul_3_w_shape[1]] qkv_add_3_bias_shape = [qkv_matmul_3_w_shape[1]]
ln_1_bias_shape = [q_matmul_x_shape[2]] ln_1_bias_shape = [q_matmul_x_shape[2]]
# def generate_q_matmul_w():
# return np.random.random(x_shape).astype(np.float32)
program_config = ProgramConfig( program_config = ProgramConfig(
ops=ops, ops=ops,
weights={ weights={
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestReshapeUnstackConcatFusePass(PassAutoScanTest):
    """Auto-scan case for ``reshape_unstack_concat_fuse_pass``.

    The program is a reshape2 -> unstack -> concat stem followed by 48
    slice -> reshape2 -> slice -> reshape2 -> transpose2 branches, all with
    fixed shapes.
    """

    def sample_predictor_configs(self, program_config):
        """Yield one XPU inference config with an op-type list and a value pair.

        NOTE(review): the list is presumably the op sequence expected after the
        pass and the pair the comparison tolerances -- confirm against
        ``PassAutoScanTest``.
        """
        xpu_config = self.create_inference_config(use_xpu=True)
        expected_ops = [
            "reshape2",
            "slice",
            "reshape2",
            "unstack",
            "concat",
            "reshape2",
            "transpose2",
            "split",
        ]
        yield xpu_config, expected_ops, (1e-3, 1e-3)

    def sample_program_config(self, draw):
        """Build the fixed-shape program; ``draw`` is not used."""
        reshape_x_shape = [4, 48, 2, 16, 4096]
        branch_count = 48
        unstack_count = 4

        # Stem: reshape2 -> unstack (axis 0, 4 outputs) -> concat (axis -2).
        stem = [
            OpConfig(
                "reshape2",
                inputs={"X": ["reshape_x"]},
                outputs={"Out": ["reshape_out"], "XShape": ["reshape_xshape"]},
                shape=[4, -1, 48, 2, 16, 4096],
            ),
            OpConfig(
                "unstack",
                inputs={"X": ["reshape_out"]},
                outputs={
                    "Y": [f"unstakc_out{k}" for k in range(unstack_count)]
                },
                axis=0,
                num=4,
            ),
            OpConfig(
                "concat",
                inputs={
                    "X": [f"unstakc_out{k}" for k in range(unstack_count)]
                },
                outputs={"Out": ["concat_out"]},
                axis=-2,
            ),
        ]

        first_slices = []
        first_reshapes = []
        second_slices = []
        second_reshapes = []
        transposes = []
        out_names = []
        for idx in range(branch_count):
            slice_0_out = f"slice_0_{idx}_out"
            reshape_0_out = f"reshape_0_{idx}_out"
            slice_1_out = f"slice_1_{idx}_out"
            reshape_1_out = f"reshape_1_{idx}_out"
            transpose_out = f"transpose_{idx}_out"

            # Take index ``idx`` along axis 1 of the concat result.
            first_slices.append(
                OpConfig(
                    "slice",
                    inputs={"Input": ["concat_out"]},
                    outputs={"Out": [slice_0_out]},
                    starts=[idx],
                    ends=[idx + 1],
                    axes=[1],
                    decrease_axis=[],
                )
            )
            first_reshapes.append(
                OpConfig(
                    "reshape2",
                    inputs={"X": [slice_0_out]},
                    outputs={
                        "Out": [reshape_0_out],
                        "XShape": [f"reshape_0_{idx}_xshape"],
                    },
                    shape=[-1, 2, 64, 4, 1024],
                )
            )
            # Take index 1 along axis 3, with that axis in decrease_axis.
            second_slices.append(
                OpConfig(
                    "slice",
                    inputs={"Input": [reshape_0_out]},
                    outputs={"Out": [slice_1_out]},
                    starts=[1],
                    ends=[2],
                    axes=[3],
                    decrease_axis=[3],
                )
            )
            second_reshapes.append(
                OpConfig(
                    "reshape2",
                    inputs={"X": [slice_1_out]},
                    outputs={
                        "Out": [reshape_1_out],
                        "XShape": [f"reshape_1_{idx}_xshape"],
                    },
                    shape=[-1, 2, 64, 16, 64],
                )
            )
            transposes.append(
                OpConfig(
                    "transpose2",
                    inputs={"X": [reshape_1_out]},
                    outputs={
                        "Out": [transpose_out],
                        "XShape": [f"transpose_{idx}_xshape"],
                    },
                    axis=[1, 0, 3, 2, 4],
                )
            )
            out_names.append(transpose_out)

        # Ops are grouped stage by stage, not branch by branch.
        ops = (
            stem
            + first_slices
            + first_reshapes
            + second_slices
            + second_reshapes
            + transposes
        )
        return ProgramConfig(
            ops=ops,
            weights={},
            inputs={
                "reshape_x": TensorConfig(shape=reshape_x_shape),
            },
            outputs=out_names,
        )

    def test(self):
        """Run the auto-scan driver once with only the fuse pass enabled."""
        scan_options = {
            "quant": False,
            "max_examples": 1,
            "min_success_num": 1,
            "passes": ["reshape_unstack_concat_fuse_pass"],
        }
        self.run_and_statis(**scan_options)
# Run the unittest runner only when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册