Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
41483383
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
41483383
编写于
11月 21, 2022
作者:
R
RichardWooSJTU
提交者:
GitHub
11月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
delete unnecessary shape and slice op (#48112)
上级
55f6fb3d
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
113 addition
and
149 deletion
+113
-149
paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc
...e/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc
+1
-44
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc
...luid/framework/ir/fused_multi_transformer_decoder_pass.cc
+111
-105
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+1
-0
未找到文件。
paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc
浏览文件 @
41483383
...
@@ -62,34 +62,7 @@ MultiTransformerLayerPattern::operator()(bool enable_int8,
...
@@ -62,34 +62,7 @@ MultiTransformerLayerPattern::operator()(bool enable_int8,
fused_multi_transformer_name
,
"Out"
);
fused_multi_transformer_name
,
"Out"
);
if
(
is_decoder
)
{
if
(
is_decoder
)
{
auto
shape_repr
=
fused_multi_transformer
->
LinksFrom
({
x0
,
src_mask
}).
LinksTo
({
out
});
PDNodeName
(
name_scope_
,
repr_
,
id_
,
"shape_"
+
std
::
to_string
(
i
));
node_reprs
[
"shape_"
+
std
::
to_string
(
i
)]
=
shape_repr
;
auto
*
shape
=
pattern
->
NewNode
(
shape_repr
)
->
assert_is_op
(
"shape"
);
auto
shape_out_repr
=
PDNodeName
(
name_scope_
,
repr_
,
id_
,
"shape_out_"
+
std
::
to_string
(
i
));
node_reprs
[
"shape_out_"
+
std
::
to_string
(
i
)]
=
shape_out_repr
;
auto
*
shape_out
=
pattern
->
NewNode
(
shape_out_repr
)
->
assert_is_op_output
(
"shape"
,
"Out"
);
shape
->
LinksFrom
({
src_mask
}).
LinksTo
({
shape_out
});
auto
slice_repr
=
PDNodeName
(
name_scope_
,
repr_
,
id_
,
"slice_"
+
std
::
to_string
(
i
));
node_reprs
[
"slice_"
+
std
::
to_string
(
i
)]
=
slice_repr
;
auto
*
slice
=
pattern
->
NewNode
(
slice_repr
)
->
assert_is_op
(
"slice"
);
auto
slice_out_repr
=
PDNodeName
(
name_scope_
,
repr_
,
id_
,
"slice_out_"
+
std
::
to_string
(
i
));
node_reprs
[
"slice_out_"
+
std
::
to_string
(
i
)]
=
slice_out_repr
;
auto
*
slice_out
=
pattern
->
NewNode
(
slice_out_repr
)
->
assert_is_op_output
(
"slice"
,
"Out"
);
slice
->
LinksFrom
({
shape_out
}).
LinksTo
({
slice_out
});
fused_multi_transformer
->
LinksFrom
({
x0
,
src_mask
,
slice_out
})
.
LinksTo
({
out
});
}
else
{
}
else
{
auto
cache_kv_repr
=
auto
cache_kv_repr
=
PDNodeName
(
name_scope_
,
repr_
,
id_
,
"cache_kv_"
+
std
::
to_string
(
i
));
PDNodeName
(
name_scope_
,
repr_
,
id_
,
"cache_kv_"
+
std
::
to_string
(
i
));
...
@@ -187,10 +160,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
...
@@ -187,10 +160,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
std
::
vector
<
Node
*>
fuse_op_nodes
;
std
::
vector
<
Node
*>
fuse_op_nodes
;
std
::
vector
<
Node
*>
out_nodes
;
std
::
vector
<
Node
*>
out_nodes
;
std
::
vector
<
std
::
string
>
unused_node_prefixes
=
{
"shape_"
,
"shape_out_"
,
"slice_"
,
"slice_out_"
};
std
::
vector
<
Node
*>
unused_nodes
;
std
::
vector
<
OpDesc
*>
fuse_op_descs
;
std
::
vector
<
OpDesc
*>
fuse_op_descs
;
std
::
vector
<
VariableNameMap
>
fuse_op_input_var_name_maps
;
std
::
vector
<
VariableNameMap
>
fuse_op_input_var_name_maps
;
std
::
vector
<
VariableNameMap
>
fuse_op_output_var_name_maps
;
std
::
vector
<
VariableNameMap
>
fuse_op_output_var_name_maps
;
...
@@ -219,14 +188,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
...
@@ -219,14 +188,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
fill_op_node
->
Op
()
->
SetInput
(
"Input"
,
{
x0
->
Name
()});
fill_op_node
->
Op
()
->
SetInput
(
"Input"
,
{
x0
->
Name
()});
IR_NODE_UNLINK
(
out_nodes
[
i
-
1
],
fill_op_node
);
IR_NODE_UNLINK
(
out_nodes
[
i
-
1
],
fill_op_node
);
IR_NODE_LINK_TO
(
x0
,
fill_op_node
);
IR_NODE_LINK_TO
(
x0
,
fill_op_node
);
}
else
if
(
is_decoder
&&
i
!=
0
)
{
for
(
const
auto
&
unused_node_prefix
:
unused_node_prefixes
)
{
PDNode
*
unused_pdnode
=
multi_layer_pattern
.
PatternBase
::
pattern
->
RetrieveNode
(
node_reprs
[
unused_node_prefix
+
std
::
to_string
(
i
)]);
Node
*
unused_node
=
subgraph
.
at
(
unused_pdnode
);
unused_nodes
.
push_back
(
unused_node
);
}
}
}
}
}
...
@@ -293,10 +254,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
...
@@ -293,10 +254,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
std
::
unordered_set
<
const
Node
*>
marked_fuse_op_nodes
(
std
::
unordered_set
<
const
Node
*>
marked_fuse_op_nodes
(
fuse_op_nodes
.
begin
()
+
1
,
fuse_op_nodes
.
end
());
fuse_op_nodes
.
begin
()
+
1
,
fuse_op_nodes
.
end
());
if
(
is_decoder
)
{
marked_fuse_op_nodes
.
insert
(
unused_nodes
.
begin
(),
unused_nodes
.
end
());
}
GraphSafeRemoveNodes
(
graph
,
marked_fuse_op_nodes
);
GraphSafeRemoveNodes
(
graph
,
marked_fuse_op_nodes
);
++
fusion_count
;
++
fusion_count
;
};
};
...
...
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc
浏览文件 @
41483383
...
@@ -1146,35 +1146,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
...
@@ -1146,35 +1146,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
auto
cache_kv_name
=
"cache_kv"
+
std
::
to_string
(
layer_idx
);
auto
cache_kv_name
=
"cache_kv"
+
std
::
to_string
(
layer_idx
);
fused_multi_transformer_op_desc
.
SetInput
(
"CacheKV"
,
{
cache_kv_name
});
fused_multi_transformer_op_desc
.
SetInput
(
"CacheKV"
,
{
cache_kv_name
});
VarDesc
shape_out_desc
(
"shape_out."
+
std
::
to_string
(
layer_idx
));
fused_multi_transformer_op_desc
.
SetInput
(
"TimeStep"
,
{
"slice_out.0"
});
shape_out_desc
.
SetDataType
(
proto
::
VarType
::
INT32
);
shape_out_desc
.
SetPersistable
(
false
);
auto
*
shape_out
=
graph
->
CreateVarNode
(
&
shape_out_desc
);
OpDesc
shape_op_desc
(
layer_norm
->
Op
()
->
Block
());
shape_op_desc
.
SetType
(
"shape"
);
shape_op_desc
.
SetInput
(
"Input"
,
{
eltadd_qk_b
->
Name
()});
shape_op_desc
.
SetOutput
(
"Out"
,
{
shape_out
->
Name
()});
auto
*
shape_op
=
graph
->
CreateOpNode
(
&
shape_op_desc
);
VarDesc
slice_out_desc
(
"slice_out."
+
std
::
to_string
(
layer_idx
));
slice_out_desc
.
SetDataType
(
proto
::
VarType
::
INT32
);
slice_out_desc
.
SetPersistable
(
false
);
auto
*
slice_out
=
graph
->
CreateVarNode
(
&
slice_out_desc
);
OpDesc
slice_op_desc
(
layer_norm
->
Op
()
->
Block
());
slice_op_desc
.
SetType
(
"slice"
);
slice_op_desc
.
SetInput
(
"Input"
,
{
shape_out
->
Name
()});
slice_op_desc
.
SetOutput
(
"Out"
,
{
slice_out
->
Name
()});
std
::
vector
<
int
>
axes
=
{
0
};
std
::
vector
<
int
>
starts
=
{
3
};
std
::
vector
<
int
>
ends
=
{
4
};
slice_op_desc
.
SetAttr
(
"axes"
,
axes
);
slice_op_desc
.
SetAttr
(
"starts"
,
starts
);
slice_op_desc
.
SetAttr
(
"ends"
,
ends
);
auto
*
slice_op
=
graph
->
CreateOpNode
(
&
slice_op_desc
);
fused_multi_transformer_op_desc
.
SetInput
(
"TimeStep"
,
{
slice_out
->
Name
()});
// Out Linear input
// Out Linear input
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearW"
,
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearW"
,
...
@@ -1219,12 +1191,42 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
...
@@ -1219,12 +1191,42 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
IR_NODE_LINK_TO
(
eltadd0_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd0_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd_qk_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd_qk_b
,
fused_multi_transformer
);
// TimeStep link
if
(
layer_idx
==
0
)
{
IR_NODE_LINK_TO
(
eltadd_qk_b
,
shape_op
);
VarDesc
shape_out_desc
(
"shape_out.0"
);
IR_NODE_LINK_TO
(
shape_op
,
shape_out
);
shape_out_desc
.
SetDataType
(
proto
::
VarType
::
INT32
);
IR_NODE_LINK_TO
(
shape_out
,
slice_op
);
shape_out_desc
.
SetPersistable
(
false
);
IR_NODE_LINK_TO
(
slice_op
,
slice_out
);
auto
*
shape_out
=
graph
->
CreateVarNode
(
&
shape_out_desc
);
IR_NODE_LINK_TO
(
slice_out
,
fused_multi_transformer
)
OpDesc
shape_op_desc
(
layer_norm
->
Op
()
->
Block
());
shape_op_desc
.
SetType
(
"shape"
);
shape_op_desc
.
SetInput
(
"Input"
,
{
eltadd_qk_b
->
Name
()});
shape_op_desc
.
SetOutput
(
"Out"
,
{
shape_out
->
Name
()});
auto
*
shape_op
=
graph
->
CreateOpNode
(
&
shape_op_desc
);
VarDesc
slice_out_desc
(
"slice_out.0"
);
slice_out_desc
.
SetDataType
(
proto
::
VarType
::
INT32
);
slice_out_desc
.
SetPersistable
(
false
);
auto
*
slice_out
=
graph
->
CreateVarNode
(
&
slice_out_desc
);
OpDesc
slice_op_desc
(
layer_norm
->
Op
()
->
Block
());
slice_op_desc
.
SetType
(
"slice"
);
slice_op_desc
.
SetInput
(
"Input"
,
{
shape_out
->
Name
()});
slice_op_desc
.
SetOutput
(
"Out"
,
{
slice_out
->
Name
()});
std
::
vector
<
int
>
axes
=
{
0
};
std
::
vector
<
int
>
starts
=
{
3
};
std
::
vector
<
int
>
ends
=
{
4
};
slice_op_desc
.
SetAttr
(
"axes"
,
axes
);
slice_op_desc
.
SetAttr
(
"starts"
,
starts
);
slice_op_desc
.
SetAttr
(
"ends"
,
ends
);
auto
*
slice_op
=
graph
->
CreateOpNode
(
&
slice_op_desc
);
// TimeStep link
IR_NODE_LINK_TO
(
eltadd_qk_b
,
shape_op
);
IR_NODE_LINK_TO
(
shape_op
,
shape_out
);
IR_NODE_LINK_TO
(
shape_out
,
slice_op
);
IR_NODE_LINK_TO
(
slice_op
,
slice_out
);
IR_NODE_LINK_TO
(
slice_out
,
fused_multi_transformer
)
}
IR_NODE_LINK_TO
(
matmul_linear_w
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
matmul_linear_w
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd_linear_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd_linear_b
,
fused_multi_transformer
);
...
@@ -1789,35 +1791,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
...
@@ -1789,35 +1791,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
auto
cache_kv_name
=
"cache_kv"
+
std
::
to_string
(
layer_idx
);
auto
cache_kv_name
=
"cache_kv"
+
std
::
to_string
(
layer_idx
);
fused_multi_transformer_op_desc
.
SetInput
(
"CacheKV"
,
{
cache_kv_name
});
fused_multi_transformer_op_desc
.
SetInput
(
"CacheKV"
,
{
cache_kv_name
});
VarDesc
shape_out_desc
(
"shape_out."
+
std
::
to_string
(
layer_idx
));
fused_multi_transformer_op_desc
.
SetInput
(
"TimeStep"
,
{
"slice_out.0"
});
shape_out_desc
.
SetDataType
(
proto
::
VarType
::
INT32
);
shape_out_desc
.
SetPersistable
(
false
);
auto
*
shape_out
=
graph
->
CreateVarNode
(
&
shape_out_desc
);
OpDesc
shape_op_desc
(
layer_norm
->
Op
()
->
Block
());
shape_op_desc
.
SetType
(
"shape"
);
shape_op_desc
.
SetInput
(
"Input"
,
{
eltadd_qk_b
->
Name
()});
shape_op_desc
.
SetOutput
(
"Out"
,
{
shape_out
->
Name
()});
auto
*
shape_op
=
graph
->
CreateOpNode
(
&
shape_op_desc
);
VarDesc
slice_out_desc
(
"slice_out."
+
std
::
to_string
(
layer_idx
));
slice_out_desc
.
SetDataType
(
proto
::
VarType
::
INT32
);
slice_out_desc
.
SetPersistable
(
false
);
auto
*
slice_out
=
graph
->
CreateVarNode
(
&
slice_out_desc
);
OpDesc
slice_op_desc
(
layer_norm
->
Op
()
->
Block
());
slice_op_desc
.
SetType
(
"slice"
);
slice_op_desc
.
SetInput
(
"Input"
,
{
shape_out
->
Name
()});
slice_op_desc
.
SetOutput
(
"Out"
,
{
slice_out
->
Name
()});
std
::
vector
<
int
>
axes
=
{
0
};
std
::
vector
<
int
>
starts
=
{
3
};
std
::
vector
<
int
>
ends
=
{
4
};
slice_op_desc
.
SetAttr
(
"axes"
,
axes
);
slice_op_desc
.
SetAttr
(
"starts"
,
starts
);
slice_op_desc
.
SetAttr
(
"ends"
,
ends
);
auto
*
slice_op
=
graph
->
CreateOpNode
(
&
slice_op_desc
);
fused_multi_transformer_op_desc
.
SetInput
(
"TimeStep"
,
{
slice_out
->
Name
()});
// Out Linear input
// Out Linear input
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearW"
,
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearW"
,
...
@@ -1862,12 +1836,42 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
...
@@ -1862,12 +1836,42 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
IR_NODE_LINK_TO
(
eltadd0_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd0_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd_qk_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd_qk_b
,
fused_multi_transformer
);
// TimeStep link
if
(
layer_idx
==
0
)
{
IR_NODE_LINK_TO
(
eltadd_qk_b
,
shape_op
);
VarDesc
shape_out_desc
(
"shape_out.0"
);
IR_NODE_LINK_TO
(
shape_op
,
shape_out
);
shape_out_desc
.
SetDataType
(
proto
::
VarType
::
INT32
);
IR_NODE_LINK_TO
(
shape_out
,
slice_op
);
shape_out_desc
.
SetPersistable
(
false
);
IR_NODE_LINK_TO
(
slice_op
,
slice_out
);
auto
*
shape_out
=
graph
->
CreateVarNode
(
&
shape_out_desc
);
IR_NODE_LINK_TO
(
slice_out
,
fused_multi_transformer
)
OpDesc
shape_op_desc
(
layer_norm
->
Op
()
->
Block
());
shape_op_desc
.
SetType
(
"shape"
);
shape_op_desc
.
SetInput
(
"Input"
,
{
eltadd_qk_b
->
Name
()});
shape_op_desc
.
SetOutput
(
"Out"
,
{
shape_out
->
Name
()});
auto
*
shape_op
=
graph
->
CreateOpNode
(
&
shape_op_desc
);
VarDesc
slice_out_desc
(
"slice_out.0"
);
slice_out_desc
.
SetDataType
(
proto
::
VarType
::
INT32
);
slice_out_desc
.
SetPersistable
(
false
);
auto
*
slice_out
=
graph
->
CreateVarNode
(
&
slice_out_desc
);
OpDesc
slice_op_desc
(
layer_norm
->
Op
()
->
Block
());
slice_op_desc
.
SetType
(
"slice"
);
slice_op_desc
.
SetInput
(
"Input"
,
{
shape_out
->
Name
()});
slice_op_desc
.
SetOutput
(
"Out"
,
{
slice_out
->
Name
()});
std
::
vector
<
int
>
axes
=
{
0
};
std
::
vector
<
int
>
starts
=
{
3
};
std
::
vector
<
int
>
ends
=
{
4
};
slice_op_desc
.
SetAttr
(
"axes"
,
axes
);
slice_op_desc
.
SetAttr
(
"starts"
,
starts
);
slice_op_desc
.
SetAttr
(
"ends"
,
ends
);
auto
*
slice_op
=
graph
->
CreateOpNode
(
&
slice_op_desc
);
// TimeStep link
IR_NODE_LINK_TO
(
eltadd_qk_b
,
shape_op
);
IR_NODE_LINK_TO
(
shape_op
,
shape_out
);
IR_NODE_LINK_TO
(
shape_out
,
slice_op
);
IR_NODE_LINK_TO
(
slice_op
,
slice_out
);
IR_NODE_LINK_TO
(
slice_out
,
fused_multi_transformer
)
}
IR_NODE_LINK_TO
(
matmul_linear_w
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
matmul_linear_w
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd_linear_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd_linear_b
,
fused_multi_transformer
);
...
@@ -2405,35 +2409,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
...
@@ -2405,35 +2409,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
auto
cache_kv_name
=
"cache_kv"
+
std
::
to_string
(
layer_idx
);
auto
cache_kv_name
=
"cache_kv"
+
std
::
to_string
(
layer_idx
);
fused_multi_transformer_op_desc
.
SetInput
(
"CacheKV"
,
{
cache_kv_name
});
fused_multi_transformer_op_desc
.
SetInput
(
"CacheKV"
,
{
cache_kv_name
});
VarDesc
shape_out_desc
(
"shape_out."
+
std
::
to_string
(
layer_idx
));
fused_multi_transformer_op_desc
.
SetInput
(
"TimeStep"
,
{
"slice_out.0"
});
shape_out_desc
.
SetDataType
(
proto
::
VarType
::
INT32
);
shape_out_desc
.
SetPersistable
(
false
);
auto
*
shape_out
=
graph
->
CreateVarNode
(
&
shape_out_desc
);
OpDesc
shape_op_desc
(
layer_norm
->
Op
()
->
Block
());
shape_op_desc
.
SetType
(
"shape"
);
shape_op_desc
.
SetInput
(
"Input"
,
{
eltadd_qk_b
->
Name
()});
shape_op_desc
.
SetOutput
(
"Out"
,
{
shape_out
->
Name
()});
auto
*
shape_op
=
graph
->
CreateOpNode
(
&
shape_op_desc
);
VarDesc
slice_out_desc
(
"slice_out."
+
std
::
to_string
(
layer_idx
));
slice_out_desc
.
SetDataType
(
proto
::
VarType
::
INT32
);
slice_out_desc
.
SetPersistable
(
false
);
auto
*
slice_out
=
graph
->
CreateVarNode
(
&
slice_out_desc
);
OpDesc
slice_op_desc
(
layer_norm
->
Op
()
->
Block
());
slice_op_desc
.
SetType
(
"slice"
);
slice_op_desc
.
SetInput
(
"Input"
,
{
shape_out
->
Name
()});
slice_op_desc
.
SetOutput
(
"Out"
,
{
slice_out
->
Name
()});
std
::
vector
<
int
>
axes
=
{
0
};
std
::
vector
<
int
>
starts
=
{
3
};
std
::
vector
<
int
>
ends
=
{
4
};
slice_op_desc
.
SetAttr
(
"axes"
,
axes
);
slice_op_desc
.
SetAttr
(
"starts"
,
starts
);
slice_op_desc
.
SetAttr
(
"ends"
,
ends
);
auto
*
slice_op
=
graph
->
CreateOpNode
(
&
slice_op_desc
);
fused_multi_transformer_op_desc
.
SetInput
(
"TimeStep"
,
{
slice_out
->
Name
()});
// Out Linear input
// Out Linear input
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearW"
,
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearW"
,
...
@@ -2483,12 +2459,42 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
...
@@ -2483,12 +2459,42 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
IR_NODE_LINK_TO
(
eltadd0_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd0_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd_qk_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd_qk_b
,
fused_multi_transformer
);
// TimeStep link
if
(
layer_idx
==
0
)
{
IR_NODE_LINK_TO
(
eltadd_qk_b
,
shape_op
);
VarDesc
shape_out_desc
(
"shape_out.0"
);
IR_NODE_LINK_TO
(
shape_op
,
shape_out
);
shape_out_desc
.
SetDataType
(
proto
::
VarType
::
INT32
);
IR_NODE_LINK_TO
(
shape_out
,
slice_op
);
shape_out_desc
.
SetPersistable
(
false
);
IR_NODE_LINK_TO
(
slice_op
,
slice_out
);
auto
*
shape_out
=
graph
->
CreateVarNode
(
&
shape_out_desc
);
IR_NODE_LINK_TO
(
slice_out
,
fused_multi_transformer
)
OpDesc
shape_op_desc
(
layer_norm
->
Op
()
->
Block
());
shape_op_desc
.
SetType
(
"shape"
);
shape_op_desc
.
SetInput
(
"Input"
,
{
eltadd_qk_b
->
Name
()});
shape_op_desc
.
SetOutput
(
"Out"
,
{
shape_out
->
Name
()});
auto
*
shape_op
=
graph
->
CreateOpNode
(
&
shape_op_desc
);
VarDesc
slice_out_desc
(
"slice_out.0"
);
slice_out_desc
.
SetDataType
(
proto
::
VarType
::
INT32
);
slice_out_desc
.
SetPersistable
(
false
);
auto
*
slice_out
=
graph
->
CreateVarNode
(
&
slice_out_desc
);
OpDesc
slice_op_desc
(
layer_norm
->
Op
()
->
Block
());
slice_op_desc
.
SetType
(
"slice"
);
slice_op_desc
.
SetInput
(
"Input"
,
{
shape_out
->
Name
()});
slice_op_desc
.
SetOutput
(
"Out"
,
{
slice_out
->
Name
()});
std
::
vector
<
int
>
axes
=
{
0
};
std
::
vector
<
int
>
starts
=
{
3
};
std
::
vector
<
int
>
ends
=
{
4
};
slice_op_desc
.
SetAttr
(
"axes"
,
axes
);
slice_op_desc
.
SetAttr
(
"starts"
,
starts
);
slice_op_desc
.
SetAttr
(
"ends"
,
ends
);
auto
*
slice_op
=
graph
->
CreateOpNode
(
&
slice_op_desc
);
// TimeStep link
IR_NODE_LINK_TO
(
eltadd_qk_b
,
shape_op
);
IR_NODE_LINK_TO
(
shape_op
,
shape_out
);
IR_NODE_LINK_TO
(
shape_out
,
slice_op
);
IR_NODE_LINK_TO
(
slice_op
,
slice_out
);
IR_NODE_LINK_TO
(
slice_out
,
fused_multi_transformer
)
}
IR_NODE_LINK_TO
(
matmul_linear_w
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
matmul_linear_w
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd_linear_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd_linear_b
,
fused_multi_transformer
);
...
...
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
41483383
...
@@ -177,6 +177,7 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
...
@@ -177,6 +177,7 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
"fused_multi_transformer_decoder_fuse_qkv_pass"
,
"fused_multi_transformer_decoder_fuse_qkv_pass"
,
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass"
,
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass"
,
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass"
,
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass"
,
"fuse_multi_transformer_layer_pass"
,
"gpu_cpu_map_matmul_v2_to_mul_pass"
,
"gpu_cpu_map_matmul_v2_to_mul_pass"
,
"gpu_cpu_map_matmul_v2_to_matmul_pass"
,
"gpu_cpu_map_matmul_v2_to_matmul_pass"
,
"fc_fuse_pass"
,
"fc_fuse_pass"
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录