Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
2039115c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2039115c
编写于
5月 05, 2023
作者:
S
shentanyue
提交者:
GitHub
5月 05, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XPU] Fusion of gather and assign operators to fused_mt op for reducing memory usage (#53262)
上级
d27f15ed
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
277 addition
and
83 deletion
+277
-83
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+4
-4
paddle/fluid/framework/ir/pass.cc
paddle/fluid/framework/ir/pass.cc
+1
-1
paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_pass.cc
...luid/framework/ir/xpu/fused_multi_transformer_xpu_pass.cc
+118
-24
paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_pass_tester.cc
...amework/ir/xpu/fused_multi_transformer_xpu_pass_tester.cc
+62
-9
paddle/fluid/framework/ir/xpu/one_beam_size_fuse_pass.cc
paddle/fluid/framework/ir/xpu/one_beam_size_fuse_pass.cc
+0
-17
paddle/fluid/framework/ir/xpu/pass_utils.cc
paddle/fluid/framework/ir/xpu/pass_utils.cc
+17
-0
paddle/fluid/framework/ir/xpu/pass_utils.h
paddle/fluid/framework/ir/xpu/pass_utils.h
+3
-0
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+1
-1
paddle/phi/api/yaml/fused_ops.yaml
paddle/phi/api/yaml/fused_ops.yaml
+2
-2
paddle/phi/infermeta/fusion.cc
paddle/phi/infermeta/fusion.cc
+2
-7
paddle/phi/infermeta/fusion.h
paddle/phi/infermeta/fusion.h
+2
-0
paddle/phi/kernels/fusion/xpu/fused_multi_transformer_xpu_kernel.cc
.../kernels/fusion/xpu/fused_multi_transformer_xpu_kernel.cc
+65
-18
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
2039115c
...
...
@@ -242,7 +242,7 @@ if(WITH_XPU)
pass_library
(
one_beam_size_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
delete_isolated_node_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
fused_multi_transformer_xpu_
quant_
pass inference DIR xpu DEPS
pass_library
(
fused_multi_transformer_xpu_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
stack_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
fused_multi_transformer_cachekv_layout_trans_pass inference DIR
...
...
@@ -519,9 +519,9 @@ if(WITH_XPU)
SRCS xpu/delete_isolated_node_pass_test.cc
DEPS delete_isolated_node_pass
)
cc_test
(
test_fused_multi_transformer_xpu_
quant_
pass
SRCS xpu/fused_multi_transformer_xpu_
quant_
pass_tester.cc
DEPS fused_multi_transformer_xpu_
quant_
pass
)
test_fused_multi_transformer_xpu_pass
SRCS xpu/fused_multi_transformer_xpu_pass_tester.cc
DEPS fused_multi_transformer_xpu_pass
)
cc_test
(
test_one_beam_size_fuse_pass
SRCS xpu/one_beam_size_fuse_pass_test.cc
...
...
paddle/fluid/framework/ir/pass.cc
浏览文件 @
2039115c
...
...
@@ -65,7 +65,7 @@ static const std::vector<std::string> xpu_support_subgraph_passes = {
"fused_multi_transformer_cachekv_layout_trans_pass"
,
"one_beam_size_fuse_pass"
,
"stack_fuse_pass"
,
"fused_multi_transformer_xpu_
quant_
pass"
,
"fused_multi_transformer_xpu_pass"
,
"fc_xpu_fuse_pass"
,
"link_xpu_op_max_pass"
,
};
...
...
paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_
quant_
pass.cc
→
paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_pass.cc
浏览文件 @
2039115c
...
...
@@ -39,6 +39,32 @@ namespace framework {
namespace
ir
{
namespace
patterns
{
// Pattern describing an `assign` operator node linked to its "Out" variable
// node. The constructor, defined out of line, creates and connects the two
// pattern nodes declared here.
struct FusedMultiTransformerAssignPattern : PatternBase {
  FusedMultiTransformerAssignPattern(PDPattern* pattern,
                                     const std::string& name_scope);

  // Operator node of the pattern.
  PATTERN_DECL_NODE(assign);
  // Variable node of the pattern (output of `assign`).
  PATTERN_DECL_NODE(assign_out);
};
// Builds the pattern `assign -> assign_out`.
//
// An `assign` op is accepted only when its first input variable has exactly
// one producer and that producer is a `fused_multi_transformer` op.
// `assign_out` must be the "Out" output of an `assign` op.
FusedMultiTransformerAssignPattern::FusedMultiTransformerAssignPattern(
    PDPattern* pattern, const std::string& name_scope)
    : PatternBase(pattern, name_scope, name_scope) {
  auto* assign =
      pattern->NewNode(assign_repr())
          ->assert_is_op("assign")
          ->assert_more([](Node* node) {
            // Reject the candidate instead of indexing out of range when the
            // op has no input variable linked in the graph.
            if (node->inputs.empty()) return false;
            // Bind by reference so the producer list is not copied for every
            // candidate node.
            const auto& pre_op_nodes = node->inputs[0]->inputs;
            if (pre_op_nodes.size() != 1) return false;
            Node* producer = pre_op_nodes[0];
            // Only query the op description on operator nodes.
            return producer->IsOp() && producer->Op() != nullptr &&
                   producer->Op()->Type() == "fused_multi_transformer";
          });
  auto* assign_out = pattern->NewNode(assign_out_repr())
                         ->assert_is_op_output("assign", "Out");
  assign->LinksTo({assign_out});
}
struct
FusedMultiTransformerPattern
:
public
PatternBase
{
FusedMultiTransformerPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
...
...
@@ -47,7 +73,6 @@ struct FusedMultiTransformerPattern : public PatternBase {
bool
with_time_step
,
bool
with_seq_lengths
,
bool
with_src_mask
);
// declare operator node's name
PATTERN_DECL_NODE
(
fused_mt
);
// declare variable node's name
...
...
@@ -234,39 +259,101 @@ FusedMultiTransformerPattern::FusedMultiTransformerPattern(
}
// namespace patterns
/*
1. transpose and quantify the weights of fused_multi_transformer op from fp32 to
1. Remove the assign and gather ops that follow fused_multi_transformer to reduce device memory consumption
2. Transpose and quantize the weights of the fused_multi_transformer op from fp32 to
int16
*/
class
FusedMultiTransformerXPU
Quant
Pass
:
public
FusePassBase
{
class
FusedMultiTransformerXPUPass
:
public
FusePassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
private:
int
ApplyImpl
(
ir
::
Graph
*
graph
,
/*
Origin subgraph:
fused_multi_transformer
| | |
assign assign ...
| | |
gather gather ...
Fused subgraph:
fused_multi_transformer
*/
void
RemoveAssignGather
(
ir
::
Graph
*
graph
)
const
;
/*
Origin subgraph:
fused_multi_transformer
Fused subgraph:
fused_multi_transformer_xpu
*/
int
FusedMultiTransformerXPUQuant
(
ir
::
Graph
*
graph
,
bool
with_pre_caches
,
bool
with_rotary_pos_emb
,
bool
with_time_step
,
bool
with_seq_lengths
,
bool
with_src_mask
)
const
;
const
std
::
string
name_scope_
{
"fused_multi_transformer_xpu_
quant_
pass"
};
const
std
::
string
name_scope_
{
"fused_multi_transformer_xpu_pass"
};
};
void
FusedMultiTransformerXPU
Quant
Pass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
void
FusedMultiTransformerXPUPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
Init
(
name_scope_
,
graph
);
VLOG
(
3
)
<<
"in FusedMultiTransformerXPU
Quant
Pass::ApplyImpl"
;
VLOG
(
3
)
<<
"in FusedMultiTransformerXPUPass::ApplyImpl"
;
int
found_subgraph_count
=
0
;
RemoveAssignGather
(
graph
);
for
(
bool
with_time_step
:
{
true
,
false
})
{
found_subgraph_count
+=
ApplyImpl
(
graph
,
false
,
false
,
with_time_step
,
false
,
true
);
found_subgraph_count
+=
FusedMultiTransformerXPUQuant
(
graph
,
false
,
false
,
with_time_step
,
false
,
true
);
}
AddStatis
(
found_subgraph_count
);
}
int
FusedMultiTransformerXPUQuantPass
::
ApplyImpl
(
ir
::
Graph
*
graph
,
// Folds the `assign -> gather` chain that consumes an output of
// fused_multi_transformer into the fused op itself.
//
// For every matched `assign` whose output is read by exactly one op, and that
// op is a `gather`:
//   * the fused op's argument is renamed from assign's input to assign's
//     output;
//   * gather's "Index" argument becomes the fused op's "gather_index" input
//     and gather's "axis" attribute becomes its "gather_axis" attribute;
//   * the assign op, its input variable and the gather op are removed.
// The number of rewritten subgraphs is reported through AddStatis().
void FusedMultiTransformerXPUPass::RemoveAssignGather(ir::Graph* graph) const {
  // detect assign + gather
  GraphPatternDetector gpd;
  patterns::FusedMultiTransformerAssignPattern pattern(gpd.mutable_pattern(),
                                                       name_scope_);
  int found_subgraph_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    VLOG(1) << "handle RemoveAssignGather";
    GET_IR_NODE(assign);
    GET_IR_NODE(assign_out);
    // Assign_out may not link to gather, so we find gather by input name.
    auto next_ops = FindOpNodeByInputName(graph, assign_out->Name());
    if (next_ops.size() != 1 || next_ops[0]->Name() != "gather") return;
    auto* gather = next_ops[0];

    // Locate the index variable by the name stored in gather's "Index"
    // argument rather than by its position in `gather->inputs`; the position
    // depends on how the graph linked the op's inputs.
    const auto& index_names = gather->Op()->Input("Index");
    if (index_names.empty()) return;
    Node* gather_index = nullptr;
    for (auto* in_node : gather->inputs) {
      if (in_node->Name() == index_names[0]) {
        gather_index = in_node;
        break;
      }
    }
    // Leave the subgraph untouched when the index variable cannot be
    // resolved, instead of wiring an unrelated node into the fused op.
    if (gather_index == nullptr) return;

    // "assign_out" is used in multi blocks. "assign_out" should be reserved.
    auto* assign_in = assign->inputs[0];
    auto* fused_multi_transformer = assign_in->inputs[0];
    fused_multi_transformer->Op()->Rename(assign_in->Name(),
                                          assign_out->Name());
    fused_multi_transformer->Op()->SetInput("gather_index", index_names);
    fused_multi_transformer->Op()->SetAttr("gather_axis",
                                           gather->Op()->GetAttr("axis"));
    IR_NODE_LINK_TO(gather_index, fused_multi_transformer);
    IR_NODE_LINK_TO(fused_multi_transformer, assign_out);

    std::unordered_set<const Node*> delete_nodes{assign, assign_in, gather};
    GraphSafeRemoveNodes(graph, delete_nodes);
    ++found_subgraph_count;
  };

  gpd(graph, handler);
  AddStatis(found_subgraph_count);
}
int
FusedMultiTransformerXPUPass
::
FusedMultiTransformerXPUQuant
(
ir
::
Graph
*
graph
,
bool
with_pre_caches
,
bool
with_rotary_pos_emb
,
bool
with_time_step
,
...
...
@@ -286,7 +373,7 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
int
found_subgraph_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
4
)
<<
"handle FusedMultiTransformerXPUQuant
Pass fuse
"
;
VLOG
(
4
)
<<
"handle FusedMultiTransformerXPUQuant"
;
GET_IR_NODE
(
x
);
GET_IR_NODE
(
ln_scale
);
...
...
@@ -459,6 +546,13 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
if
(
name_caches
.
count
(
"CacheKV"
)
>
0
)
{
fused_mt_xpu_op_desc
->
SetInput
(
"cache_kv"
,
name_caches
.
at
(
"CacheKV"
));
}
if
(
name_caches
.
count
(
"gather_index"
)
>
0
)
{
fused_mt_xpu_op_desc
->
SetInput
(
"gather_index"
,
name_caches
.
at
(
"gather_index"
));
}
if
(
!
fused_mt_xpu_op_desc
->
HasAttr
(
"gather_axis"
))
{
fused_mt_xpu_op_desc
->
SetAttr
(
"gather_axis"
,
0
);
}
if
(
pre_caches
)
{
fused_mt_xpu_op_desc
->
SetInput
(
"pre_caches"
,
name_caches
.
at
(
"PreCaches"
));
}
...
...
@@ -529,5 +623,5 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
fused_multi_transformer_xpu_
quant_
pass
,
paddle
::
framework
::
ir
::
FusedMultiTransformerXPU
Quant
Pass
);
REGISTER_PASS
(
fused_multi_transformer_xpu_pass
,
paddle
::
framework
::
ir
::
FusedMultiTransformerXPUPass
);
paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_
quant_
pass_tester.cc
→
paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_pass_tester.cc
浏览文件 @
2039115c
...
...
@@ -64,7 +64,62 @@ Scope* CreateParamScope() {
return
param_scope
;
}
TEST
(
FusedMultiTransformerXPUQuantPass
,
context_stage
)
{
// Test helper: obtains the variable `name` from `block` via BlockDesc::Var and
// configures it as a LOD_TENSOR with the requested element type, shape and
// persistable flag. Returns the configured variable descriptor.
VarDesc* Data(paddle::framework::BlockDesc* block,
              std::string name,
              std::vector<int64_t> shape = {},
              bool is_persistable = false,
              proto::VarType::Type data_type = proto::VarType::FP32) {
  VarDesc* desc = block->Var(name);
  // The variable type is set before dtype/shape, in the same order as before.
  desc->SetType(proto::VarType::LOD_TENSOR);
  desc->SetDataType(data_type);
  desc->SetShape(shape);
  desc->SetPersistable(is_persistable);
  return desc;
}
// Builds a one-block program `fused_multi_transformer -> assign -> gather`,
// runs fused_multi_transformer_xpu_pass on it and checks that neither an
// `assign` nor a `gather` op node is left in the graph.
TEST(RemoveAssignGather, basic) {
  paddle::framework::ProgramDesc program;
  auto* block = program.MutableBlock(0);

  // fused_multi_transformer: reads X and CacheKV, writes CacheKVOut in place.
  auto* x_var = Data(block, "fused_multi_transformer_x", {1, 1, 1536});
  auto* cache_kv_var =
      Data(block, "fused_multi_transformer_cache_kv", {2, 1, 24, 512, 64});
  OpDesc* mt_op = block->AppendOp();
  mt_op->SetType("fused_multi_transformer");
  mt_op->SetInput("X", {x_var->Name()});
  mt_op->SetInput("CacheKV", {cache_kv_var->Name()});
  mt_op->SetOutput("CacheKVOut", {cache_kv_var->Name()});

  // assign: copies the cache into a separate variable.
  auto* assign_out_var = Data(block, "assign_out", cache_kv_var->GetShape());
  OpDesc* assign_desc = block->AppendOp();
  assign_desc->SetType("assign");
  assign_desc->SetInput("X", {cache_kv_var->Name()});
  assign_desc->SetOutput("Out", {assign_out_var->Name()});

  // gather: reads the assigned copy and writes back into the cache variable.
  OpDesc* gather_desc = block->AppendOp();
  auto* index_var = Data(block, "gather_index", {10});
  gather_desc->SetType("gather");
  gather_desc->SetInput("X", {assign_out_var->Name()});
  gather_desc->SetInput("Index", {index_var->Name()});
  gather_desc->SetAttr("axis", {1});
  gather_desc->SetOutput("Out", {cache_kv_var->Name()});

  std::unique_ptr<ir::Graph> graph = std::make_unique<ir::Graph>(program);
  auto pass = PassRegistry::Instance().Get("fused_multi_transformer_xpu_pass");
  pass->Apply(graph.get());

  auto assign_num = GetNumOpNodes(graph, "assign");
  auto gather_num = GetNumOpNodes(graph, "gather");
  PADDLE_ENFORCE_EQ(assign_num,
                    0,
                    platform::errors::PreconditionNotMet(
                        "assign op should be removed from the graph."));
  PADDLE_ENFORCE_EQ(gather_num,
                    0,
                    platform::errors::PreconditionNotMet(
                        "gather op should be removed from the graph."));
}
TEST
(
FusedMultiTransformerXPUPass
,
context_stage
)
{
DEF_INPUT_DATA
auto
*
cache_kv
=
layers
.
fill_constant_batch_size_like
(
...
...
@@ -95,10 +150,9 @@ TEST(FusedMultiTransformerXPUQuantPass, context_stage) {
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
layers
.
main_program
()));
graph
->
Set
(
"__param_scope__"
,
CreateParamScope
());
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"fused_multi_transformer_xpu_quant_pass"
);
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"fused_multi_transformer_xpu_pass"
);
if
(
pass
.
get
()
==
nullptr
)
{
LOG
(
INFO
)
<<
"get fused_multi_transformer_xpu_
quant_
pass failed"
;
LOG
(
INFO
)
<<
"get fused_multi_transformer_xpu_pass failed"
;
}
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
...
...
@@ -114,7 +168,7 @@ TEST(FusedMultiTransformerXPUQuantPass, context_stage) {
num_nodes_after
));
}
TEST
(
FusedMultiTransformerXPU
Quant
Pass
,
decoder_stage
)
{
TEST
(
FusedMultiTransformerXPUPass
,
decoder_stage
)
{
DEF_INPUT_DATA
auto
*
cache_kv
=
layers
.
fill_constant_batch_size_like
(
...
...
@@ -146,10 +200,9 @@ TEST(FusedMultiTransformerXPUQuantPass, decoder_stage) {
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
layers
.
main_program
()));
graph
->
Set
(
"__param_scope__"
,
CreateParamScope
());
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"fused_multi_transformer_xpu_quant_pass"
);
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"fused_multi_transformer_xpu_pass"
);
if
(
pass
.
get
()
==
nullptr
)
{
LOG
(
INFO
)
<<
"get fused_multi_transformer_xpu_
quant_
pass failed"
;
LOG
(
INFO
)
<<
"get fused_multi_transformer_xpu_pass failed"
;
}
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
...
...
@@ -169,4 +222,4 @@ TEST(FusedMultiTransformerXPUQuantPass, decoder_stage) {
}
// namespace framework
}
// namespace paddle
USE_PASS
(
fused_multi_transformer_xpu_
quant_
pass
);
USE_PASS
(
fused_multi_transformer_xpu_pass
);
paddle/fluid/framework/ir/xpu/one_beam_size_fuse_pass.cc
浏览文件 @
2039115c
...
...
@@ -259,23 +259,6 @@ bool OnlyOneBeamSearchAndOneBeamSize(ir::Graph* graph) {
beam_search_nodes
[
0
]
->
Op
()
->
GetAttrIfExists
<
int
>
(
"beam_size"
)
==
1
;
}
// Returns every op node in `graph` that lists `var_name` among the argument
// names of any of its inputs. Each matching op is added once.
std::vector<Node*> FindOpNodeByInputName(Graph* graph,
                                         const std::string& var_name) {
  std::vector<Node*> ret;
  for (auto* node : graph->Nodes()) {
    if (!node->IsOp()) continue;
    // Bind by const reference to avoid copying the input map and each
    // argument-name vector for every op visited.
    const auto& inputs = node->Op()->Inputs();
    for (const auto& input : inputs) {
      const auto& in_names = input.second;
      if (std::find(in_names.begin(), in_names.end(), var_name) !=
          in_names.end()) {
        ret.push_back(node);
        break;  // one hit is enough; do not add the same op twice
      }
    }
  }
  return ret;
}
void
OneBeamSizeFusePass
::
RemoveAssignGather
(
ir
::
Graph
*
graph
)
const
{
// detect assign + gather
GraphPatternDetector
gpd
;
...
...
paddle/fluid/framework/ir/xpu/pass_utils.cc
浏览文件 @
2039115c
...
...
@@ -71,6 +71,23 @@ Node* FindNodeWithName(Graph* graph, std::string name) {
return
nullptr
;
}
// Returns every op node in `graph` that lists `var_name` among the argument
// names of any of its inputs. Each matching op is added once.
std::vector<Node*> FindOpNodeByInputName(Graph* graph,
                                         const std::string& var_name) {
  std::vector<Node*> ret;
  for (auto* node : graph->Nodes()) {
    if (!node->IsOp()) continue;
    // Bind by const reference to avoid copying the input map and each
    // argument-name vector for every op visited.
    const auto& inputs = node->Op()->Inputs();
    for (const auto& input : inputs) {
      const auto& in_names = input.second;
      if (std::find(in_names.begin(), in_names.end(), var_name) !=
          in_names.end()) {
        ret.push_back(node);
        break;  // one hit is enough; do not add the same op twice
      }
    }
  }
  return ret;
}
template
<
typename
T
>
std
::
string
IntTypeToString
()
{
LOG
(
FATAL
)
<<
"Not support type."
;
...
...
paddle/fluid/framework/ir/xpu/pass_utils.h
浏览文件 @
2039115c
...
...
@@ -51,6 +51,9 @@ int ConvertActivationType(std::string act_type);
Node
*
FindNodeWithName
(
Graph
*
graph
,
std
::
string
name
);
std
::
vector
<
Node
*>
FindOpNodeByInputName
(
Graph
*
graph
,
const
std
::
string
&
var_name
);
template
<
typename
T
>
size_t
HashTensor
(
const
phi
::
DenseTensor
&
in
);
...
...
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
2039115c
...
...
@@ -523,7 +523,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"one_beam_size_fuse_pass"
,
"delete_cast_op_pass"
,
"stack_fuse_pass"
,
"fused_multi_transformer_xpu_
quant_
pass"
,
"fused_multi_transformer_xpu_pass"
,
"fc_xpu_fuse_pass"
,
"conv2d_xpu_fuse_pass"
,
"link_xpu_op_max_pass"
,
...
...
paddle/phi/api/yaml/fused_ops.yaml
浏览文件 @
2039115c
...
...
@@ -58,14 +58,14 @@
support_dygraph_mode
:
true
-
op
:
fused_multi_transformer_xpu
args
:
(Tensor x, Tensor[] ln_scale, Tensor[] ln_bias, Tensor[] qkvw, Tensor[] qkvw_max, Tensor[] qkv_bias, Tensor[] out_linear_w, Tensor[] out_linear_wmax, Tensor[] out_linear_bias, Tensor[] ffn_ln_scale, Tensor[] ffn_ln_bias, Tensor[] ffn1_weight, Tensor[] ffn1_weight_max, Tensor[] ffn1_bias, Tensor[] ffn2_weight, Tensor[] ffn2_weight_max, Tensor[] ffn2_bias, Tensor[] cache_kv, Tensor[] pre_caches, Tensor rotary_pos_emb, Tensor time_step, Tensor seq_lengths, Tensor src_mask,
bool pre_layer_norm, int rotary_emb_dims, float epsilon, float dropout_rate, bool is_test, str dropout_implementation, str act_method, bool trans_qkvw, int ring_id
)
args
:
(Tensor x, Tensor[] ln_scale, Tensor[] ln_bias, Tensor[] qkvw, Tensor[] qkvw_max, Tensor[] qkv_bias, Tensor[] out_linear_w, Tensor[] out_linear_wmax, Tensor[] out_linear_bias, Tensor[] ffn_ln_scale, Tensor[] ffn_ln_bias, Tensor[] ffn1_weight, Tensor[] ffn1_weight_max, Tensor[] ffn1_bias, Tensor[] ffn2_weight, Tensor[] ffn2_weight_max, Tensor[] ffn2_bias, Tensor[] cache_kv, Tensor[] pre_caches, Tensor rotary_pos_emb, Tensor time_step, Tensor seq_lengths, Tensor src_mask,
Tensor gather_index, bool pre_layer_norm, int rotary_emb_dims, float epsilon, float dropout_rate, bool is_test, str dropout_implementation, str act_method, bool trans_qkvw, int ring_id, int gather_axis
)
output
:
Tensor(out), Tensor[](cache_kv_out){out_linear_w.size()}
infer_meta
:
func
:
FusedMultiTransformerXpuInferMeta
kernel
:
func
:
fused_multi_transformer_xpu
data_type
:
x
optional
:
cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask
optional
:
cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask
, gather_index
-
op
:
generate_sequence_xpu
args
:
(Tensor x, DataType dtype)
...
...
paddle/phi/infermeta/fusion.cc
浏览文件 @
2039115c
...
...
@@ -278,6 +278,7 @@ void FusedMultiTransformerXpuInferMeta(
const
std
::
vector
<
const
MetaTensor
*>&
time_step
,
const
std
::
vector
<
const
MetaTensor
*>&
seq_lengths
,
const
std
::
vector
<
const
MetaTensor
*>&
src_mask
,
const
std
::
vector
<
const
MetaTensor
*>&
gather_index
,
bool
pre_layer_norm
,
int
rotary_emb_dims
,
float
epsilon
,
...
...
@@ -287,6 +288,7 @@ void FusedMultiTransformerXpuInferMeta(
const
std
::
string
&
act_method
,
bool
trans_qkvw
,
int
ring_id
,
int
gather_axis
,
MetaTensor
*
out
,
std
::
vector
<
MetaTensor
*>
cache_kv_out
)
{
auto
x_dim
=
x
.
dims
();
...
...
@@ -325,13 +327,6 @@ void FusedMultiTransformerXpuInferMeta(
phi
::
errors
::
InvalidArgument
(
"The first dim of CacheKV must be 2, but got %d"
,
c_dim
[
0
]));
// 2
PADDLE_ENFORCE_EQ
(
c_dim
[
2
],
x_dim
[
0
],
phi
::
errors
::
InvalidArgument
(
"The third dim of CacheKV must be equal "
"with batch size %d, but got %d"
,
x_dim
[
0
],
c_dim
[
2
]));
// batch_size
PADDLE_ENFORCE_EQ
(
c_dim
[
3
],
trans_qkvw
?
y_dim
[
1
]
:
y_dim
[
2
],
...
...
paddle/phi/infermeta/fusion.h
浏览文件 @
2039115c
...
...
@@ -108,6 +108,7 @@ void FusedMultiTransformerXpuInferMeta(
const
std
::
vector
<
const
MetaTensor
*>&
time_step
,
const
std
::
vector
<
const
MetaTensor
*>&
seq_lengths
,
const
std
::
vector
<
const
MetaTensor
*>&
src_mask
,
const
std
::
vector
<
const
MetaTensor
*>&
gather_index
,
bool
pre_layer_norm
,
int
rotary_emb_dims
,
float
epsilon
,
...
...
@@ -117,6 +118,7 @@ void FusedMultiTransformerXpuInferMeta(
const
std
::
string
&
act_method
,
bool
trans_qkvw
,
int
ring_id
,
int
gather_axis
,
MetaTensor
*
out
,
std
::
vector
<
MetaTensor
*>
cache_kv_out
);
}
// namespace phi
paddle/phi/kernels/fusion/xpu/fused_multi_transformer_xpu_kernel.cc
浏览文件 @
2039115c
...
...
@@ -17,6 +17,8 @@
#include "glog/logging.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/kernels/memcpy_kernel.h"
#ifdef PADDLE_WITH_XPU_XFT
#include "models/fused_multi_transformer_op.h"
...
...
@@ -52,6 +54,7 @@ void FusedMultiTransformerXpuKernel(
const
paddle
::
optional
<
DenseTensor
>&
time_step
,
const
paddle
::
optional
<
DenseTensor
>&
seq_lengths
,
const
paddle
::
optional
<
DenseTensor
>&
src_mask
,
const
paddle
::
optional
<
DenseTensor
>&
gather_index
,
bool
pre_layer_norm
,
int
rotary_emb_dims
,
float
epsilon
,
...
...
@@ -61,6 +64,7 @@ void FusedMultiTransformerXpuKernel(
const
std
::
string
&
act_method
,
bool
trans_qkvw
,
int
ring_id
,
int
gather_axis
,
DenseTensor
*
out
,
std
::
vector
<
DenseTensor
*>
cache_kv_out
)
{
#ifdef PADDLE_WITH_XPU_XFT
...
...
@@ -160,6 +164,21 @@ void FusedMultiTransformerXpuKernel(
std
::
vector
<
xft
::
xftTensor
<
XPUTypeT
,
5
>>
xft_cache_kv
;
std
::
vector
<
xft
::
xftTensor
<
XPUTypeT
,
5
>>
xft_cache_kv_out
;
// Create a temporary Tensor to store the gather output of cache_kv
auto
gather_index_t
=
gather_index
.
get_ptr
();
auto
cache_kv_dims
=
cache_kv
.
get_ptr
()
->
at
(
0
)
->
dims
();
auto
cache_kv_gather_dims
=
cache_kv_dims
;
phi
::
DenseTensor
cache_kv_gather_tensor
;
if
(
gather_index_t
)
{
MetaTensor
cache_kv_gather_meta
(
&
cache_kv_gather_tensor
);
phi
::
GatherInferMeta
(
*
cache_kv
.
get_ptr
()
->
at
(
0
),
*
gather_index_t
,
Scalar
(
gather_axis
),
&
cache_kv_gather_meta
);
cache_kv_gather_dims
=
cache_kv_gather_meta
.
dims
();
ctx
.
template
Alloc
<
T
>(
&
cache_kv_gather_tensor
);
}
int
layers
=
qkvw
.
size
();
for
(
int
i
=
0
;
i
<
layers
;
++
i
)
{
// step1. layer_norm
...
...
@@ -211,27 +230,55 @@ void FusedMultiTransformerXpuKernel(
xft_ffn2_bias
.
emplace_back
(
const_cast
<
float
*>
(
ffn2_bias
[
i
]
->
data
<
float
>
()),
std
::
array
<
int64_t
,
1
>
{
ffn2_bias
[
i
]
->
dims
()[
0
]});
// cache kv in
if
(
time_step_value
>
0
)
{
auto
cachekv_dims
=
cache_kv
.
get_ptr
()
->
at
(
i
)
->
dims
();
xft_cache_kv
.
emplace_back
(
reinterpret_cast
<
XPUTypeT
*>
(
const_cast
<
T
*>
(
cache_kv
.
get_ptr
()
->
at
(
i
)
->
data
<
T
>
())),
std
::
array
<
int64_t
,
5
>
{
cachekv_dims
[
0
],
cachekv_dims
[
1
],
cachekv_dims
[
2
],
cachekv_dims
[
3
],
cachekv_dims
[
4
]});
auto
cache_kv_data
=
reinterpret_cast
<
XPUTypeT
*>
(
const_cast
<
T
*>
(
cache_kv
.
get_ptr
()
->
at
(
i
)
->
data
<
T
>
()));
if
(
gather_index_t
)
{
const
auto
&
index_type
=
gather_index_t
->
dtype
();
if
(
index_type
==
DataType
::
INT32
)
{
r
=
xpu
::
gather
<
XPUTypeT
,
int32_t
>
(
ctx
.
x_context
(),
cache_kv_data
,
gather_index_t
->
data
<
int32_t
>
(),
reinterpret_cast
<
XPUTypeT
*>
(
cache_kv_gather_tensor
.
data
<
T
>
()),
phi
::
vectorize
<
int32_t
>
(
cache_kv_dims
),
gather_index_t
->
dims
().
size
()
==
0
?
1
:
gather_index_t
->
dims
()[
0
],
gather_axis
);
}
else
{
r
=
xpu
::
gather
<
XPUTypeT
,
int64_t
>
(
ctx
.
x_context
(),
cache_kv_data
,
gather_index_t
->
data
<
int64_t
>
(),
reinterpret_cast
<
XPUTypeT
*>
(
cache_kv_gather_tensor
.
data
<
T
>
()),
phi
::
vectorize
<
int32_t
>
(
cache_kv_dims
),
gather_index_t
->
dims
().
size
()
==
0
?
1
:
gather_index_t
->
dims
()[
0
],
gather_axis
);
}
// cache kv out
auto
cachekv_out_dims
=
cache_kv_out
[
i
]
->
dims
();
xft_cache_kv_out
.
emplace_back
(
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"xpu::gather"
);
cache_kv_out
[
i
]
->
ResizeAndAllocate
(
cache_kv_gather_dims
);
r
=
xpu
::
copy
<
XPUTypeT
>
(
ctx
.
x_context
(),
reinterpret_cast
<
XPUTypeT
*>
(
cache_kv_gather_tensor
.
data
<
T
>
()),
reinterpret_cast
<
XPUTypeT
*>
(
ctx
.
template
Alloc
<
T
>(
cache_kv_out
[
i
])),
std
::
array
<
int64_t
,
5
>
{
cachekv_out_dims
[
0
],
cachekv_out_dims
[
1
],
cachekv_out_dims
[
2
],
cachekv_out_dims
[
3
],
cachekv_out_dims
[
4
]});
cache_kv_out
[
i
]
->
numel
());
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"xpu::copy"
);
}
cache_kv_data
=
reinterpret_cast
<
XPUTypeT
*>
(
const_cast
<
T
*>
(
cache_kv
.
get_ptr
()
->
at
(
i
)
->
data
<
T
>
()));
xft_cache_kv
.
emplace_back
(
cache_kv_data
,
std
::
array
<
int64_t
,
5
>
{
cache_kv_gather_dims
[
0
],
cache_kv_gather_dims
[
1
],
cache_kv_gather_dims
[
2
],
cache_kv_gather_dims
[
3
],
cache_kv_gather_dims
[
4
]});
// cache kv out direct use cache_kv_data
xft_cache_kv_out
.
emplace_back
(
cache_kv_data
,
std
::
array
<
int64_t
,
5
>
{
cache_kv_gather_dims
[
0
],
cache_kv_gather_dims
[
1
],
cache_kv_gather_dims
[
2
],
cache_kv_gather_dims
[
3
],
cache_kv_gather_dims
[
4
]});
}
xft
::
NlpParam
param
;
param
.
num_layer
=
layers
;
param
.
n_head
=
num_head
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录