未验证 提交 41483383 编写于 作者: R RichardWooSJTU 提交者: GitHub

delete unnecessary shape and slice op (#48112)

上级 55f6fb3d
......@@ -62,34 +62,7 @@ MultiTransformerLayerPattern::operator()(bool enable_int8,
fused_multi_transformer_name, "Out");
if (is_decoder) {
auto shape_repr =
PDNodeName(name_scope_, repr_, id_, "shape_" + std::to_string(i));
node_reprs["shape_" + std::to_string(i)] = shape_repr;
auto* shape = pattern->NewNode(shape_repr)->assert_is_op("shape");
auto shape_out_repr =
PDNodeName(name_scope_, repr_, id_, "shape_out_" + std::to_string(i));
node_reprs["shape_out_" + std::to_string(i)] = shape_out_repr;
auto* shape_out =
pattern->NewNode(shape_out_repr)->assert_is_op_output("shape", "Out");
shape->LinksFrom({src_mask}).LinksTo({shape_out});
auto slice_repr =
PDNodeName(name_scope_, repr_, id_, "slice_" + std::to_string(i));
node_reprs["slice_" + std::to_string(i)] = slice_repr;
auto* slice = pattern->NewNode(slice_repr)->assert_is_op("slice");
auto slice_out_repr =
PDNodeName(name_scope_, repr_, id_, "slice_out_" + std::to_string(i));
node_reprs["slice_out_" + std::to_string(i)] = slice_out_repr;
auto* slice_out =
pattern->NewNode(slice_out_repr)->assert_is_op_output("slice", "Out");
slice->LinksFrom({shape_out}).LinksTo({slice_out});
fused_multi_transformer->LinksFrom({x0, src_mask, slice_out})
.LinksTo({out});
fused_multi_transformer->LinksFrom({x0, src_mask}).LinksTo({out});
} else {
auto cache_kv_repr =
PDNodeName(name_scope_, repr_, id_, "cache_kv_" + std::to_string(i));
......@@ -187,10 +160,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
std::vector<Node*> fuse_op_nodes;
std::vector<Node*> out_nodes;
std::vector<std::string> unused_node_prefixes = {
"shape_", "shape_out_", "slice_", "slice_out_"};
std::vector<Node*> unused_nodes;
std::vector<OpDesc*> fuse_op_descs;
std::vector<VariableNameMap> fuse_op_input_var_name_maps;
std::vector<VariableNameMap> fuse_op_output_var_name_maps;
......@@ -219,14 +188,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
fill_op_node->Op()->SetInput("Input", {x0->Name()});
IR_NODE_UNLINK(out_nodes[i - 1], fill_op_node);
IR_NODE_LINK_TO(x0, fill_op_node);
} else if (is_decoder && i != 0) {
for (const auto& unused_node_prefix : unused_node_prefixes) {
PDNode* unused_pdnode =
multi_layer_pattern.PatternBase::pattern->RetrieveNode(
node_reprs[unused_node_prefix + std::to_string(i)]);
Node* unused_node = subgraph.at(unused_pdnode);
unused_nodes.push_back(unused_node);
}
}
}
......@@ -293,10 +254,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
std::unordered_set<const Node*> marked_fuse_op_nodes(
fuse_op_nodes.begin() + 1, fuse_op_nodes.end());
if (is_decoder) {
marked_fuse_op_nodes.insert(unused_nodes.begin(), unused_nodes.end());
}
GraphSafeRemoveNodes(graph, marked_fuse_op_nodes);
++fusion_count;
};
......
......@@ -1146,35 +1146,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
auto cache_kv_name = "cache_kv" + std::to_string(layer_idx);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv_name});
VarDesc shape_out_desc("shape_out." + std::to_string(layer_idx));
shape_out_desc.SetDataType(proto::VarType::INT32);
shape_out_desc.SetPersistable(false);
auto* shape_out = graph->CreateVarNode(&shape_out_desc);
OpDesc shape_op_desc(layer_norm->Op()->Block());
shape_op_desc.SetType("shape");
shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()});
shape_op_desc.SetOutput("Out", {shape_out->Name()});
auto* shape_op = graph->CreateOpNode(&shape_op_desc);
VarDesc slice_out_desc("slice_out." + std::to_string(layer_idx));
slice_out_desc.SetDataType(proto::VarType::INT32);
slice_out_desc.SetPersistable(false);
auto* slice_out = graph->CreateVarNode(&slice_out_desc);
OpDesc slice_op_desc(layer_norm->Op()->Block());
slice_op_desc.SetType("slice");
slice_op_desc.SetInput("Input", {shape_out->Name()});
slice_op_desc.SetOutput("Out", {slice_out->Name()});
std::vector<int> axes = {0};
std::vector<int> starts = {3};
std::vector<int> ends = {4};
slice_op_desc.SetAttr("axes", axes);
slice_op_desc.SetAttr("starts", starts);
slice_op_desc.SetAttr("ends", ends);
auto* slice_op = graph->CreateOpNode(&slice_op_desc);
fused_multi_transformer_op_desc.SetInput("TimeStep", {slice_out->Name()});
fused_multi_transformer_op_desc.SetInput("TimeStep", {"slice_out.0"});
// Out Linear input
fused_multi_transformer_op_desc.SetInput("OutLinearW",
......@@ -1219,12 +1191,42 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer);
// TimeStep link
IR_NODE_LINK_TO(eltadd_qk_b, shape_op);
IR_NODE_LINK_TO(shape_op, shape_out);
IR_NODE_LINK_TO(shape_out, slice_op);
IR_NODE_LINK_TO(slice_op, slice_out);
IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
if (layer_idx == 0) {
VarDesc shape_out_desc("shape_out.0");
shape_out_desc.SetDataType(proto::VarType::INT32);
shape_out_desc.SetPersistable(false);
auto* shape_out = graph->CreateVarNode(&shape_out_desc);
OpDesc shape_op_desc(layer_norm->Op()->Block());
shape_op_desc.SetType("shape");
shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()});
shape_op_desc.SetOutput("Out", {shape_out->Name()});
auto* shape_op = graph->CreateOpNode(&shape_op_desc);
VarDesc slice_out_desc("slice_out.0");
slice_out_desc.SetDataType(proto::VarType::INT32);
slice_out_desc.SetPersistable(false);
auto* slice_out = graph->CreateVarNode(&slice_out_desc);
OpDesc slice_op_desc(layer_norm->Op()->Block());
slice_op_desc.SetType("slice");
slice_op_desc.SetInput("Input", {shape_out->Name()});
slice_op_desc.SetOutput("Out", {slice_out->Name()});
std::vector<int> axes = {0};
std::vector<int> starts = {3};
std::vector<int> ends = {4};
slice_op_desc.SetAttr("axes", axes);
slice_op_desc.SetAttr("starts", starts);
slice_op_desc.SetAttr("ends", ends);
auto* slice_op = graph->CreateOpNode(&slice_op_desc);
// TimeStep link
IR_NODE_LINK_TO(eltadd_qk_b, shape_op);
IR_NODE_LINK_TO(shape_op, shape_out);
IR_NODE_LINK_TO(shape_out, slice_op);
IR_NODE_LINK_TO(slice_op, slice_out);
IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
}
IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);
......@@ -1789,35 +1791,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
auto cache_kv_name = "cache_kv" + std::to_string(layer_idx);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv_name});
VarDesc shape_out_desc("shape_out." + std::to_string(layer_idx));
shape_out_desc.SetDataType(proto::VarType::INT32);
shape_out_desc.SetPersistable(false);
auto* shape_out = graph->CreateVarNode(&shape_out_desc);
OpDesc shape_op_desc(layer_norm->Op()->Block());
shape_op_desc.SetType("shape");
shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()});
shape_op_desc.SetOutput("Out", {shape_out->Name()});
auto* shape_op = graph->CreateOpNode(&shape_op_desc);
VarDesc slice_out_desc("slice_out." + std::to_string(layer_idx));
slice_out_desc.SetDataType(proto::VarType::INT32);
slice_out_desc.SetPersistable(false);
auto* slice_out = graph->CreateVarNode(&slice_out_desc);
OpDesc slice_op_desc(layer_norm->Op()->Block());
slice_op_desc.SetType("slice");
slice_op_desc.SetInput("Input", {shape_out->Name()});
slice_op_desc.SetOutput("Out", {slice_out->Name()});
std::vector<int> axes = {0};
std::vector<int> starts = {3};
std::vector<int> ends = {4};
slice_op_desc.SetAttr("axes", axes);
slice_op_desc.SetAttr("starts", starts);
slice_op_desc.SetAttr("ends", ends);
auto* slice_op = graph->CreateOpNode(&slice_op_desc);
fused_multi_transformer_op_desc.SetInput("TimeStep", {slice_out->Name()});
fused_multi_transformer_op_desc.SetInput("TimeStep", {"slice_out.0"});
// Out Linear input
fused_multi_transformer_op_desc.SetInput("OutLinearW",
......@@ -1862,12 +1836,42 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer);
// TimeStep link
IR_NODE_LINK_TO(eltadd_qk_b, shape_op);
IR_NODE_LINK_TO(shape_op, shape_out);
IR_NODE_LINK_TO(shape_out, slice_op);
IR_NODE_LINK_TO(slice_op, slice_out);
IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
if (layer_idx == 0) {
VarDesc shape_out_desc("shape_out.0");
shape_out_desc.SetDataType(proto::VarType::INT32);
shape_out_desc.SetPersistable(false);
auto* shape_out = graph->CreateVarNode(&shape_out_desc);
OpDesc shape_op_desc(layer_norm->Op()->Block());
shape_op_desc.SetType("shape");
shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()});
shape_op_desc.SetOutput("Out", {shape_out->Name()});
auto* shape_op = graph->CreateOpNode(&shape_op_desc);
VarDesc slice_out_desc("slice_out.0");
slice_out_desc.SetDataType(proto::VarType::INT32);
slice_out_desc.SetPersistable(false);
auto* slice_out = graph->CreateVarNode(&slice_out_desc);
OpDesc slice_op_desc(layer_norm->Op()->Block());
slice_op_desc.SetType("slice");
slice_op_desc.SetInput("Input", {shape_out->Name()});
slice_op_desc.SetOutput("Out", {slice_out->Name()});
std::vector<int> axes = {0};
std::vector<int> starts = {3};
std::vector<int> ends = {4};
slice_op_desc.SetAttr("axes", axes);
slice_op_desc.SetAttr("starts", starts);
slice_op_desc.SetAttr("ends", ends);
auto* slice_op = graph->CreateOpNode(&slice_op_desc);
// TimeStep link
IR_NODE_LINK_TO(eltadd_qk_b, shape_op);
IR_NODE_LINK_TO(shape_op, shape_out);
IR_NODE_LINK_TO(shape_out, slice_op);
IR_NODE_LINK_TO(slice_op, slice_out);
IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
}
IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);
......@@ -2405,35 +2409,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
auto cache_kv_name = "cache_kv" + std::to_string(layer_idx);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv_name});
VarDesc shape_out_desc("shape_out." + std::to_string(layer_idx));
shape_out_desc.SetDataType(proto::VarType::INT32);
shape_out_desc.SetPersistable(false);
auto* shape_out = graph->CreateVarNode(&shape_out_desc);
OpDesc shape_op_desc(layer_norm->Op()->Block());
shape_op_desc.SetType("shape");
shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()});
shape_op_desc.SetOutput("Out", {shape_out->Name()});
auto* shape_op = graph->CreateOpNode(&shape_op_desc);
VarDesc slice_out_desc("slice_out." + std::to_string(layer_idx));
slice_out_desc.SetDataType(proto::VarType::INT32);
slice_out_desc.SetPersistable(false);
auto* slice_out = graph->CreateVarNode(&slice_out_desc);
OpDesc slice_op_desc(layer_norm->Op()->Block());
slice_op_desc.SetType("slice");
slice_op_desc.SetInput("Input", {shape_out->Name()});
slice_op_desc.SetOutput("Out", {slice_out->Name()});
std::vector<int> axes = {0};
std::vector<int> starts = {3};
std::vector<int> ends = {4};
slice_op_desc.SetAttr("axes", axes);
slice_op_desc.SetAttr("starts", starts);
slice_op_desc.SetAttr("ends", ends);
auto* slice_op = graph->CreateOpNode(&slice_op_desc);
fused_multi_transformer_op_desc.SetInput("TimeStep", {slice_out->Name()});
fused_multi_transformer_op_desc.SetInput("TimeStep", {"slice_out.0"});
// Out Linear input
fused_multi_transformer_op_desc.SetInput("OutLinearW",
......@@ -2483,12 +2459,42 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer);
// TimeStep link
IR_NODE_LINK_TO(eltadd_qk_b, shape_op);
IR_NODE_LINK_TO(shape_op, shape_out);
IR_NODE_LINK_TO(shape_out, slice_op);
IR_NODE_LINK_TO(slice_op, slice_out);
IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
if (layer_idx == 0) {
VarDesc shape_out_desc("shape_out.0");
shape_out_desc.SetDataType(proto::VarType::INT32);
shape_out_desc.SetPersistable(false);
auto* shape_out = graph->CreateVarNode(&shape_out_desc);
OpDesc shape_op_desc(layer_norm->Op()->Block());
shape_op_desc.SetType("shape");
shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()});
shape_op_desc.SetOutput("Out", {shape_out->Name()});
auto* shape_op = graph->CreateOpNode(&shape_op_desc);
VarDesc slice_out_desc("slice_out.0");
slice_out_desc.SetDataType(proto::VarType::INT32);
slice_out_desc.SetPersistable(false);
auto* slice_out = graph->CreateVarNode(&slice_out_desc);
OpDesc slice_op_desc(layer_norm->Op()->Block());
slice_op_desc.SetType("slice");
slice_op_desc.SetInput("Input", {shape_out->Name()});
slice_op_desc.SetOutput("Out", {slice_out->Name()});
std::vector<int> axes = {0};
std::vector<int> starts = {3};
std::vector<int> ends = {4};
slice_op_desc.SetAttr("axes", axes);
slice_op_desc.SetAttr("starts", starts);
slice_op_desc.SetAttr("ends", ends);
auto* slice_op = graph->CreateOpNode(&slice_op_desc);
// TimeStep link
IR_NODE_LINK_TO(eltadd_qk_b, shape_op);
IR_NODE_LINK_TO(shape_op, shape_out);
IR_NODE_LINK_TO(shape_out, slice_op);
IR_NODE_LINK_TO(slice_op, slice_out);
IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
}
IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);
......
......@@ -177,6 +177,7 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
"fused_multi_transformer_decoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass",
"fuse_multi_transformer_layer_pass",
"gpu_cpu_map_matmul_v2_to_mul_pass",
"gpu_cpu_map_matmul_v2_to_matmul_pass",
"fc_fuse_pass",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册