Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
52e1742f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
52e1742f
编写于
3月 20, 2023
作者:
M
mayang002
提交者:
GitHub
3月 20, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[xpu] fused_multi_transformer_xpu pass&kernel support (#51571)
上级
c36e3fd2
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
1196 addition
and
31 deletion
+1196
-31
cmake/external/xpu.cmake
cmake/external/xpu.cmake
+2
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+6
-0
paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc
.../framework/ir/fuse_multi_transformer_layer_pass_tester.cc
+4
-4
paddle/fluid/framework/ir/node.h
paddle/fluid/framework/ir/node.h
+9
-0
paddle/fluid/framework/ir/pass.cc
paddle/fluid/framework/ir/pass.cc
+1
-0
paddle/fluid/framework/ir/pass_tester_helper.h
paddle/fluid/framework/ir/pass_tester_helper.h
+31
-27
paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc
...ramework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc
+546
-0
paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass_tester.cc
...k/ir/xpu/fused_multi_transformer_xpu_quant_pass_tester.cc
+170
-0
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+1
-0
paddle/phi/api/yaml/static_ops.yaml
paddle/phi/api/yaml/static_ops.yaml
+10
-0
paddle/phi/backends/xpu/xpu2_op_list.cc
paddle/phi/backends/xpu/xpu2_op_list.cc
+2
-0
paddle/phi/infermeta/fusion.cc
paddle/phi/infermeta/fusion.cc
+104
-0
paddle/phi/infermeta/fusion.h
paddle/phi/infermeta/fusion.h
+35
-0
paddle/phi/kernels/fusion/xpu/fused_multi_transformer_xpu_kernel.cc
.../kernels/fusion/xpu/fused_multi_transformer_xpu_kernel.cc
+275
-0
未找到文件。
cmake/external/xpu.cmake
浏览文件 @
52e1742f
...
...
@@ -142,6 +142,8 @@ if(WITH_XPU_XFT)
message
(
STATUS
"Compile with XPU XFT!"
)
add_definitions
(
-DPADDLE_WITH_XPU_XFT
)
set
(
XPU_XFT_INC_DIR
"
${
XPU_INC_DIR
}
/xft"
)
include_directories
(
${
XPU_XFT_INC_DIR
}
)
set
(
XPU_XFT_LIB
"
${
XPU_LIB_DIR
}
/
${
XPU_XFT_LIB_NAME
}
"
)
endif
()
...
...
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
52e1742f
...
...
@@ -235,6 +235,8 @@ if(WITH_XPU)
pass_library
(
link_xpu_op_max_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
delete_isolated_node_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
endif
()
cc_library
(
...
...
@@ -493,4 +495,8 @@ if(WITH_XPU)
test_delete_isolated_node_pass
SRCS xpu/delete_isolated_node_pass_test.cc
DEPS delete_isolated_node_pass
)
cc_test
(
test_fused_multi_transformer_xpu_quant_pass
SRCS xpu/fused_multi_transformer_xpu_quant_pass_tester.cc
DEPS fused_multi_transformer_xpu_quant_pass
)
endif
()
paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc
浏览文件 @
52e1742f
...
...
@@ -75,7 +75,7 @@ TEST(FuseMultiTransformerLayerPass, encoder_fp) {
1
,
{
2
,
-
1
,
16
,
1024
,
64
},
0
);
auto
*
out
=
layers
.
fused_multi_transformer
(
x
,
auto
outs
=
layers
.
fused_multi_transformer
(
x
,
cache_kv
,
src_mask
,
qkv_w
,
...
...
@@ -93,7 +93,7 @@ TEST(FuseMultiTransformerLayerPass, encoder_fp) {
0.1
,
1e-12
);
x
=
out
;
x
=
out
s
[
0
]
;
}
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
layers
.
main_program
()));
graph
->
Set
(
"__param_scope__"
,
CreateParamScope
());
...
...
@@ -126,7 +126,7 @@ TEST(FuseMultiTransformerLayerPass, decoder_fp) {
for
(
int
i
=
0
;
i
<
num_layers
;
++
i
)
{
auto
*
shape_out
=
layers
.
shape
(
src_mask
);
auto
*
time_stamp
=
layers
.
slice
(
shape_out
,
{
0
},
{
3
},
{
4
});
auto
*
out
=
layers
.
fused_multi_transformer
(
x
,
auto
outs
=
layers
.
fused_multi_transformer
(
x
,
cache_kv
,
src_mask
,
qkv_w
,
...
...
@@ -145,7 +145,7 @@ TEST(FuseMultiTransformerLayerPass, decoder_fp) {
1e-12
,
time_stamp
);
x
=
out
;
x
=
out
s
[
0
]
;
}
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
layers
.
main_program
()));
auto
param_scope
=
CreateParamScope
();
...
...
paddle/fluid/framework/ir/node.h
浏览文件 @
52e1742f
...
...
@@ -151,6 +151,15 @@ class Node {
var_desc_
->
SetName
(
new_name
);
}
void
RenameOp
(
const
std
::
string
&
new_name
)
{
PADDLE_ENFORCE_EQ
(
type_
==
Type
::
kOperation
&&
op_desc_
,
true
,
platform
::
errors
::
InvalidArgument
(
"Node must be type of variable."
));
name_
=
new_name
;
op_desc_
->
SetType
(
new_name
);
}
int
DescOrder
()
const
{
return
desc_order_
;
}
int
GetVarNodeBlockId
()
const
{
...
...
paddle/fluid/framework/ir/pass.cc
浏览文件 @
52e1742f
...
...
@@ -49,6 +49,7 @@ static const std::vector<std::string> support_subgraph_passes = {
"fuse_multi_transformer_layer_pass"
,
"delete_quant_dequant_linear_op_pass"
,
"delete_weight_dequant_linear_op_pass"
,
"fused_multi_transformer_xpu_quant_pass"
,
"fc_xpu_fuse_pass"
,
"delete_op_device_pass"
};
...
...
paddle/fluid/framework/ir/pass_tester_helper.h
浏览文件 @
52e1742f
...
...
@@ -571,33 +571,35 @@ struct Layers {
return
out
;
}
VarDesc
*
fused_multi_transformer
(
VarDesc
*
x
,
VarDesc
*
cache_kv
,
VarDesc
*
src_mask
,
VarDesc
*
qkv_w
,
VarDesc
*
qkv_bias
,
VarDesc
*
out_linear_w
,
VarDesc
*
out_linear_bias
,
VarDesc
*
ffn1_w
,
VarDesc
*
ffn1_bias
,
VarDesc
*
ffn2_w
,
VarDesc
*
ffn2_bias
,
VarDesc
*
ln_scale
,
VarDesc
*
ln_bias
,
VarDesc
*
ffn_ln_scale
,
VarDesc
*
ffn_ln_bias
,
float
epsilon
,
float
dropout_rate
,
VarDesc
*
time_stamp
=
nullptr
,
VarDesc
*
qkv_out_scale
=
nullptr
,
VarDesc
*
out_linear_out_scale
=
nullptr
,
VarDesc
*
ffn1_out_scale
=
nullptr
,
VarDesc
*
ffn2_out_scale
=
nullptr
,
std
::
vector
<
float
>
qkv_in_scale
=
{},
std
::
vector
<
float
>
out_linear_in_scale
=
{},
std
::
vector
<
float
>
ffn1_in_scale
=
{},
std
::
vector
<
float
>
ffn2_in_scale
=
{})
{
std
::
vector
<
VarDesc
*>
fused_multi_transformer
(
VarDesc
*
x
,
VarDesc
*
cache_kv
,
VarDesc
*
src_mask
,
VarDesc
*
qkv_w
,
VarDesc
*
qkv_bias
,
VarDesc
*
out_linear_w
,
VarDesc
*
out_linear_bias
,
VarDesc
*
ffn1_w
,
VarDesc
*
ffn1_bias
,
VarDesc
*
ffn2_w
,
VarDesc
*
ffn2_bias
,
VarDesc
*
ln_scale
,
VarDesc
*
ln_bias
,
VarDesc
*
ffn_ln_scale
,
VarDesc
*
ffn_ln_bias
,
float
epsilon
,
float
dropout_rate
,
VarDesc
*
time_stamp
=
nullptr
,
VarDesc
*
qkv_out_scale
=
nullptr
,
VarDesc
*
out_linear_out_scale
=
nullptr
,
VarDesc
*
ffn1_out_scale
=
nullptr
,
VarDesc
*
ffn2_out_scale
=
nullptr
,
std
::
vector
<
float
>
qkv_in_scale
=
{},
std
::
vector
<
float
>
out_linear_in_scale
=
{},
std
::
vector
<
float
>
ffn1_in_scale
=
{},
std
::
vector
<
float
>
ffn2_in_scale
=
{})
{
VarDesc
*
out
=
lod_tensor
(
unique_name
());
VarDesc
*
cache_kv_out
=
lod_tensor
(
unique_name
());
OpDesc
*
op
=
program_
.
MutableBlock
(
0
)
->
AppendOp
();
std
::
string
op_type
=
qkv_out_scale
?
"fused_multi_transformer_int8"
:
"fused_multi_transformer"
;
...
...
@@ -623,6 +625,7 @@ struct Layers {
op
->
SetAttr
(
"dropout_rate"
,
dropout_rate
);
op
->
SetAttr
(
"epsilon"
,
epsilon
);
op
->
SetOutput
(
"Out"
,
{
out
->
Name
()});
op
->
SetOutput
(
"CacheKVOut"
,
{
cache_kv_out
->
Name
()});
if
(
time_stamp
)
{
op
->
SetInput
(
"TimeStep"
,
{
time_stamp
->
Name
()});
...
...
@@ -638,7 +641,8 @@ struct Layers {
op
->
SetAttr
(
"ffn1_in_scale"
,
ffn1_in_scale
);
op
->
SetAttr
(
"ffn2_in_scale"
,
ffn2_in_scale
);
}
return
out
;
std
::
vector
<
VarDesc
*>
outs
=
{
out
,
cache_kv_out
};
return
outs
;
}
VarDesc
*
dequantize_linear
(
VarDesc
*
x
,
...
...
paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc
0 → 100644
浏览文件 @
52e1742f
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace
phi
{
class
DenseTensor
;
}
// namespace phi
namespace
paddle
{
namespace
framework
{
class
Scope
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
struct
FusedMultiTransformerPattern
:
public
PatternBase
{
FusedMultiTransformerPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
bool
with_cache_kv
,
bool
with_pre_caches
,
bool
with_rotary_pos_emb
,
bool
with_time_step
,
bool
with_seq_lengths
,
bool
with_src_mask
);
// declare operator node's name
PATTERN_DECL_NODE
(
fused_mt
);
// declare variable node's name
PATTERN_DECL_NODE
(
x
);
PATTERN_DECL_NODE
(
ln_scale
);
PATTERN_DECL_NODE
(
ln_bias
);
PATTERN_DECL_NODE
(
qkv_w
);
PATTERN_DECL_NODE
(
qkv_bias
);
PATTERN_DECL_NODE
(
cache_kv
);
PATTERN_DECL_NODE
(
pre_caches
);
PATTERN_DECL_NODE
(
rotary_pos_emb
);
PATTERN_DECL_NODE
(
time_step
);
PATTERN_DECL_NODE
(
seq_lengths
);
PATTERN_DECL_NODE
(
src_mask
);
PATTERN_DECL_NODE
(
out_linear_w
);
PATTERN_DECL_NODE
(
out_linear_bias
);
PATTERN_DECL_NODE
(
ffn_ln_scale
);
PATTERN_DECL_NODE
(
ffn_ln_bias
);
PATTERN_DECL_NODE
(
ffn1_w
);
PATTERN_DECL_NODE
(
ffn1_bias
);
PATTERN_DECL_NODE
(
ffn2_w
);
PATTERN_DECL_NODE
(
ffn2_bias
);
PATTERN_DECL_NODE
(
cache_kv_out
);
PATTERN_DECL_NODE
(
out
);
private:
bool
with_cache_kv_
{
false
};
bool
with_pre_caches_
{
false
};
bool
with_rotary_pos_emb_
{
false
};
bool
with_time_step_
{
false
};
bool
with_seq_lengths_
{
false
};
bool
with_src_mask_
{
false
};
};
FusedMultiTransformerPattern
::
FusedMultiTransformerPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
bool
with_cache_kv
,
bool
with_pre_caches
,
bool
with_rotary_pos_emb
,
bool
with_time_step
,
bool
with_seq_lengths
,
bool
with_src_mask
)
:
PatternBase
(
pattern
,
name_scope
,
name_scope
),
with_cache_kv_
(
with_cache_kv
),
with_pre_caches_
(
with_pre_caches
),
with_rotary_pos_emb_
(
with_rotary_pos_emb
),
with_time_step_
(
with_time_step
),
with_seq_lengths_
(
with_seq_lengths
),
with_src_mask_
(
with_src_mask
)
{
std
::
string
op_type
=
"fused_multi_transformer"
;
auto
*
fused_mt
=
pattern
->
NewNode
(
fused_mt_repr
())
->
assert_is_op
(
op_type
);
// inputs and outputs
auto
*
x
=
pattern
->
NewNode
(
x_repr
())
->
assert_is_op_input
(
op_type
,
"X"
)
->
assert_var_not_persistable
();
auto
*
cache_kv_out
=
pattern
->
NewNode
(
cache_kv_out_repr
())
->
assert_is_op_output
(
op_type
,
"CacheKVOut"
)
->
assert_var_not_persistable
();
auto
*
out
=
pattern
->
NewNode
(
out_repr
())
->
assert_is_op_output
(
op_type
,
"Out"
)
->
assert_var_not_persistable
();
// weights and biases
auto
*
ln_scale
=
pattern
->
NewNode
(
ln_scale_repr
())
->
assert_is_op_input
(
op_type
,
"LnScale"
)
->
assert_is_persistable_var
()
->
assert_more
([](
Node
*
node
)
{
return
node
->
Var
()
->
GetShape
().
size
()
==
1
;
});
auto
*
ln_bias
=
pattern
->
NewNode
(
ln_bias_repr
())
->
assert_is_op_input
(
op_type
,
"LnBias"
)
->
assert_is_persistable_var
()
->
assert_more
([](
Node
*
node
)
{
return
node
->
Var
()
->
GetShape
().
size
()
==
1
;
});
auto
*
qkv_w
=
pattern
->
NewNode
(
qkv_w_repr
())
->
assert_is_op_input
(
op_type
,
"QKVW"
)
->
assert_is_persistable_var
()
->
assert_more
([](
Node
*
node
)
{
return
node
->
Var
()
->
GetShape
().
size
()
==
4
;
});
auto
*
qkv_bias
=
pattern
->
NewNode
(
qkv_bias_repr
())
->
assert_is_op_input
(
op_type
,
"QKVBias"
)
->
assert_is_persistable_var
()
->
assert_more
([](
Node
*
node
)
{
return
node
->
Var
()
->
GetShape
().
size
()
==
3
;
});
auto
*
out_linear_w
=
pattern
->
NewNode
(
out_linear_w_repr
())
->
assert_is_op_input
(
op_type
,
"OutLinearW"
)
->
assert_is_persistable_var
()
->
assert_more
([](
Node
*
node
)
{
return
node
->
Var
()
->
GetShape
().
size
()
==
2
;
});
auto
*
out_linear_bias
=
pattern
->
NewNode
(
out_linear_bias_repr
())
->
assert_is_op_input
(
op_type
,
"OutLinearBias"
)
->
assert_is_persistable_var
()
->
assert_more
([](
Node
*
node
)
{
return
node
->
Var
()
->
GetShape
().
size
()
==
1
;
});
auto
*
ffn_ln_scale
=
pattern
->
NewNode
(
ffn_ln_scale_repr
())
->
assert_is_op_input
(
op_type
,
"FFNLnScale"
)
->
assert_is_persistable_var
()
->
assert_more
([](
Node
*
node
)
{
return
node
->
Var
()
->
GetShape
().
size
()
==
1
;
});
auto
*
ffn_ln_bias
=
pattern
->
NewNode
(
ffn_ln_bias_repr
())
->
assert_is_op_input
(
op_type
,
"FFNLnBias"
)
->
assert_is_persistable_var
()
->
assert_more
([](
Node
*
node
)
{
return
node
->
Var
()
->
GetShape
().
size
()
==
1
;
});
auto
*
ffn1_w
=
pattern
->
NewNode
(
ffn1_w_repr
())
->
assert_is_op_input
(
op_type
,
"FFN1Weight"
)
->
assert_is_persistable_var
()
->
assert_more
([](
Node
*
node
)
{
return
node
->
Var
()
->
GetShape
().
size
()
==
2
;
});
auto
*
ffn1_bias
=
pattern
->
NewNode
(
ffn1_bias_repr
())
->
assert_is_op_input
(
op_type
,
"FFN1Bias"
)
->
assert_is_persistable_var
()
->
assert_more
([](
Node
*
node
)
{
return
node
->
Var
()
->
GetShape
().
size
()
==
1
;
});
auto
*
ffn2_w
=
pattern
->
NewNode
(
ffn2_w_repr
())
->
assert_is_op_input
(
op_type
,
"FFN2Weight"
)
->
assert_is_persistable_var
()
->
assert_more
([](
Node
*
node
)
{
return
node
->
Var
()
->
GetShape
().
size
()
==
2
;
});
auto
*
ffn2_bias
=
pattern
->
NewNode
(
ffn2_bias_repr
())
->
assert_is_op_input
(
op_type
,
"FFN2Bias"
)
->
assert_is_persistable_var
()
->
assert_more
([](
Node
*
node
)
{
return
node
->
Var
()
->
GetShape
().
size
()
==
1
;
});
std
::
vector
<
PDNode
*>
input_vars
{
x
,
ln_scale
,
ln_bias
,
qkv_w
,
qkv_bias
,
out_linear_w
,
out_linear_bias
,
ffn_ln_scale
,
ffn_ln_bias
,
ffn1_w
,
ffn1_bias
,
ffn2_w
,
ffn2_bias
};
std
::
vector
<
PDNode
*>
output_vars
{
cache_kv_out
,
out
};
// optional node
PDNode
*
cache_kv
=
nullptr
;
PDNode
*
pre_caches
=
nullptr
;
PDNode
*
rotary_pos_emb
=
nullptr
;
PDNode
*
time_step
=
nullptr
;
PDNode
*
seq_lengths
=
nullptr
;
PDNode
*
src_mask
=
nullptr
;
if
(
with_cache_kv_
)
{
cache_kv
=
pattern
->
NewNode
(
cache_kv_repr
())
->
assert_is_op_input
(
op_type
,
"CacheKV"
)
->
assert_var_not_persistable
();
input_vars
.
push_back
(
cache_kv
);
}
if
(
with_pre_caches_
)
{
pre_caches
=
pattern
->
NewNode
(
pre_caches_repr
())
->
assert_is_op_input
(
op_type
,
"PreCaches"
)
->
assert_var_not_persistable
();
input_vars
.
push_back
(
pre_caches
);
}
if
(
with_rotary_pos_emb_
)
{
rotary_pos_emb
=
pattern
->
NewNode
(
rotary_pos_emb_repr
())
->
assert_is_op_input
(
op_type
,
"RotaryPosEmb"
)
->
assert_var_not_persistable
();
input_vars
.
push_back
(
rotary_pos_emb
);
}
if
(
with_time_step_
)
{
time_step
=
pattern
->
NewNode
(
time_step_repr
())
->
assert_is_op_input
(
op_type
,
"TimeStep"
)
->
assert_var_not_persistable
();
input_vars
.
push_back
(
time_step
);
}
if
(
with_seq_lengths_
)
{
seq_lengths
=
pattern
->
NewNode
(
seq_lengths_repr
())
->
assert_is_op_input
(
op_type
,
"SeqLengths"
)
->
assert_var_not_persistable
();
input_vars
.
push_back
(
seq_lengths
);
}
if
(
with_src_mask_
)
{
src_mask
=
pattern
->
NewNode
(
src_mask_repr
())
->
assert_is_op_input
(
op_type
,
"SrcMask"
)
->
assert_var_not_persistable
();
input_vars
.
push_back
(
src_mask
);
}
fused_mt
->
LinksFrom
(
input_vars
).
LinksTo
(
output_vars
);
}
}
// namespace patterns
/*
1. transpose and quantify the weights of fused_multi_transformer op from fp32 to
int16
*/
class
FusedMultiTransformerXPUQuantPass
:
public
FusePassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
private:
int
ApplyImpl
(
ir
::
Graph
*
graph
,
bool
with_cache_kv
,
bool
with_pre_caches
,
bool
with_rotary_pos_emb
,
bool
with_time_step
,
bool
with_seq_lengths
,
bool
with_src_mask
)
const
;
const
std
::
string
name_scope_
{
"fused_multi_transformer_xpu_quant_pass"
};
};
void
FusedMultiTransformerXPUQuantPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
Init
(
name_scope_
,
graph
);
VLOG
(
3
)
<<
"in FusedMultiTransformerXPUQuantPass::ApplyImpl"
;
int
found_subgraph_count
=
0
;
for
(
bool
with_time_step
:
{
true
,
false
})
{
found_subgraph_count
+=
ApplyImpl
(
graph
,
true
,
false
,
false
,
with_time_step
,
false
,
true
);
}
AddStatis
(
found_subgraph_count
);
}
int
FusedMultiTransformerXPUQuantPass
::
ApplyImpl
(
ir
::
Graph
*
graph
,
bool
with_cache_kv
,
bool
with_pre_caches
,
bool
with_rotary_pos_emb
,
bool
with_time_step
,
bool
with_seq_lengths
,
bool
with_src_mask
)
const
{
GraphPatternDetector
gpd
;
patterns
::
FusedMultiTransformerPattern
pattern
(
gpd
.
mutable_pattern
(),
name_scope_
,
with_cache_kv
,
with_pre_caches
,
with_rotary_pos_emb
,
with_time_step
,
with_seq_lengths
,
with_src_mask
);
int
found_subgraph_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
4
)
<<
"handle FusedMultiTransformerXPUQuantPass fuse"
;
GET_IR_NODE
(
x
);
GET_IR_NODE
(
ln_scale
);
GET_IR_NODE
(
ln_bias
);
GET_IR_NODE
(
qkv_w
);
GET_IR_NODE
(
qkv_bias
);
GET_IR_NODE
(
cache_kv
);
GET_IR_NODE
(
pre_caches
);
GET_IR_NODE
(
rotary_pos_emb
);
GET_IR_NODE
(
time_step
);
GET_IR_NODE
(
seq_lengths
);
GET_IR_NODE
(
src_mask
);
GET_IR_NODE
(
out_linear_w
);
GET_IR_NODE
(
out_linear_bias
);
GET_IR_NODE
(
ffn_ln_scale
);
GET_IR_NODE
(
ffn_ln_bias
);
GET_IR_NODE
(
ffn1_w
);
GET_IR_NODE
(
ffn1_bias
);
GET_IR_NODE
(
ffn2_w
);
GET_IR_NODE
(
ffn2_bias
);
GET_IR_NODE
(
cache_kv_out
);
GET_IR_NODE
(
out
);
GET_IR_NODE
(
fused_mt
);
auto
*
block
=
fused_mt
->
Op
()
->
Block
();
auto
*
scope
=
param_scope
();
// quant weight nodes
// w_nodes_vec: [QKVW, OutLinearW, FFN1Weight, FFN2Weight]
std
::
vector
<
std
::
vector
<
Node
*>>
w_nodes_vec
(
4
);
std
::
vector
<
std
::
vector
<
Node
*>>
w_int16_nodes_vec
(
4
);
std
::
vector
<
std
::
vector
<
Node
*>>
w_max_nodes_vec
(
4
);
std
::
vector
<
std
::
vector
<
std
::
string
>>
w_int16_names_vec
(
4
);
std
::
vector
<
std
::
vector
<
std
::
string
>>
w_max_names_vec
(
4
);
auto
quant_func
=
[
&
](
const
std
::
string
&
input_name
,
std
::
vector
<
Node
*>*
w_nodes
,
std
::
vector
<
Node
*>*
w_int16_nodes
,
std
::
vector
<
Node
*>*
w_max_nodes
,
std
::
vector
<
std
::
string
>*
w_int16_names
,
std
::
vector
<
std
::
string
>*
w_max_names
,
bool
need_transpose
)
{
typedef
int16_t
TW
;
auto
w_names
=
fused_mt
->
Op
()
->
Input
(
input_name
);
for
(
auto
w_name
:
w_names
)
{
Node
*
w_node
=
FindNodeWithName
(
graph
,
w_name
);
Node
*
w_int16
=
nullptr
;
Node
*
w_max
=
nullptr
;
PADDLE_ENFORCE_NE
(
w_node
,
nullptr
,
platform
::
errors
::
Fatal
(
"w node should not be nullptr"
));
PrepareWeight
<
TW
>
(
graph
,
scope
,
block
,
w_node
,
&
w_int16
,
&
w_max
,
need_transpose
);
w_nodes
->
push_back
(
w_node
);
w_int16_nodes
->
push_back
(
w_int16
);
w_max_nodes
->
push_back
(
w_max
);
}
for
(
size_t
i
=
0
;
i
<
w_names
.
size
();
++
i
)
{
w_int16_names
->
push_back
(
w_int16_nodes
->
at
(
i
)
->
Name
());
w_max_names
->
push_back
(
w_max_nodes
->
at
(
i
)
->
Name
());
}
PADDLE_ENFORCE_EQ
(
w_names
.
size
(),
w_nodes
->
size
(),
platform
::
errors
::
Fatal
(
"The size of w_names(%d) should be equal to w_nodes(%d)"
,
static_cast
<
int
>
(
w_names
.
size
()),
static_cast
<
int
>
(
w_nodes
->
size
())));
PADDLE_ENFORCE_EQ
(
w_names
.
size
(),
w_int16_nodes
->
size
(),
platform
::
errors
::
Fatal
(
"The size of w_names(%d) should be equal to w_int16_nodes(%d)"
,
static_cast
<
int
>
(
w_names
.
size
()),
static_cast
<
int
>
(
w_int16_nodes
->
size
())));
PADDLE_ENFORCE_EQ
(
w_names
.
size
(),
w_max_nodes
->
size
(),
platform
::
errors
::
Fatal
(
"The size of w_names(%d) should be equal to w_max_nodes(%d)"
,
static_cast
<
int
>
(
w_names
.
size
()),
static_cast
<
int
>
(
w_max_nodes
->
size
())));
PADDLE_ENFORCE_EQ
(
w_names
.
size
(),
w_int16_names
->
size
(),
platform
::
errors
::
Fatal
(
"The size of w_names(%d) should be equal to w_int16_names(%d)"
,
static_cast
<
int
>
(
w_names
.
size
()),
static_cast
<
int
>
(
w_int16_names
->
size
())));
PADDLE_ENFORCE_EQ
(
w_names
.
size
(),
w_max_names
->
size
(),
platform
::
errors
::
Fatal
(
"The size of w_names(%d) should be equal to w_max_names(%d)"
,
static_cast
<
int
>
(
w_names
.
size
()),
static_cast
<
int
>
(
w_max_names
->
size
())));
};
quant_func
(
"QKVW"
,
&
(
w_nodes_vec
[
0
]),
&
(
w_int16_nodes_vec
[
0
]),
&
(
w_max_nodes_vec
[
0
]),
&
(
w_int16_names_vec
[
0
]),
&
(
w_max_names_vec
[
0
]),
false
);
quant_func
(
"OutLinearW"
,
&
(
w_nodes_vec
[
1
]),
&
(
w_int16_nodes_vec
[
1
]),
&
(
w_max_nodes_vec
[
1
]),
&
(
w_int16_names_vec
[
1
]),
&
(
w_max_names_vec
[
1
]),
true
);
quant_func
(
"FFN1Weight"
,
&
(
w_nodes_vec
[
2
]),
&
(
w_int16_nodes_vec
[
2
]),
&
(
w_max_nodes_vec
[
2
]),
&
(
w_int16_names_vec
[
2
]),
&
(
w_max_names_vec
[
2
]),
true
);
quant_func
(
"FFN2Weight"
,
&
(
w_nodes_vec
[
3
]),
&
(
w_int16_nodes_vec
[
3
]),
&
(
w_max_nodes_vec
[
3
]),
&
(
w_int16_names_vec
[
3
]),
&
(
w_max_names_vec
[
3
]),
true
);
// cast some nodes to fp32 nodes
std
::
vector
<
Node
*>
fp32_nodes
;
auto
cast_tofp32_func
=
[
&
](
const
std
::
string
&
input_name
)
{
auto
names
=
fused_mt
->
Op
()
->
Input
(
input_name
);
for
(
auto
name
:
names
)
{
auto
*
curr_tensor
=
scope
->
Var
(
name
)
->
GetMutable
<
phi
::
DenseTensor
>
();
PADDLE_ENFORCE_NE
(
curr_tensor
,
nullptr
,
platform
::
errors
::
Fatal
(
"tensor node should not be nullptr"
));
CastToFp32
(
curr_tensor
);
Node
*
curr_node
=
FindNodeWithName
(
graph
,
name
);
fp32_nodes
.
push_back
(
curr_node
);
}
};
cast_tofp32_func
(
"LnScale"
);
cast_tofp32_func
(
"LnBias"
);
cast_tofp32_func
(
"QKVBias"
);
cast_tofp32_func
(
"OutLinearBias"
);
cast_tofp32_func
(
"FFNLnScale"
);
cast_tofp32_func
(
"FFNLnBias"
);
cast_tofp32_func
(
"FFN1Bias"
);
cast_tofp32_func
(
"FFN2Bias"
);
// Generate fused_multi_transformer_xpu op inplace
fused_mt
->
RenameOp
(
"fused_multi_transformer_xpu"
);
framework
::
OpDesc
*
fused_mt_xpu_op_desc
=
fused_mt
->
Op
();
fused_mt_xpu_op_desc
->
SetType
(
"fused_multi_transformer_xpu"
);
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
name_caches
;
for
(
auto
key
:
fused_mt_xpu_op_desc
->
InputNames
())
{
name_caches
.
insert
({
key
,
fused_mt_xpu_op_desc
->
Input
(
key
)});
}
for
(
auto
key
:
fused_mt_xpu_op_desc
->
OutputNames
())
{
name_caches
.
insert
({
key
,
fused_mt_xpu_op_desc
->
Output
(
key
)});
}
fused_mt_xpu_op_desc
->
MutableInputs
()
->
clear
();
fused_mt_xpu_op_desc
->
MutableOutputs
()
->
clear
();
fused_mt_xpu_op_desc
->
SetInput
(
"x"
,
name_caches
.
at
(
"X"
));
fused_mt_xpu_op_desc
->
SetInput
(
"ln_scale"
,
name_caches
.
at
(
"LnScale"
));
fused_mt_xpu_op_desc
->
SetInput
(
"ln_bias"
,
name_caches
.
at
(
"LnBias"
));
fused_mt_xpu_op_desc
->
SetInput
(
"qkv_bias"
,
name_caches
.
at
(
"QKVBias"
));
if
(
cache_kv
)
{
fused_mt_xpu_op_desc
->
SetInput
(
"cache_kv"
,
name_caches
.
at
(
"CacheKV"
));
}
if
(
pre_caches
)
{
fused_mt_xpu_op_desc
->
SetInput
(
"pre_caches"
,
name_caches
.
at
(
"PreCaches"
));
}
if
(
rotary_pos_emb
)
{
fused_mt_xpu_op_desc
->
SetInput
(
"rotary_pos_emb"
,
name_caches
.
at
(
"RotaryPosEmb"
));
}
if
(
time_step
)
{
fused_mt_xpu_op_desc
->
SetInput
(
"time_step"
,
name_caches
.
at
(
"TimeStep"
));
}
if
(
seq_lengths
)
{
fused_mt_xpu_op_desc
->
SetInput
(
"seq_lengths"
,
name_caches
.
at
(
"SeqLengths"
));
}
if
(
src_mask
)
{
fused_mt_xpu_op_desc
->
SetInput
(
"src_mask"
,
name_caches
.
at
(
"SrcMask"
));
}
fused_mt_xpu_op_desc
->
SetInput
(
"out_linear_bias"
,
name_caches
.
at
(
"OutLinearBias"
));
fused_mt_xpu_op_desc
->
SetInput
(
"ffn_ln_scale"
,
name_caches
.
at
(
"FFNLnScale"
));
fused_mt_xpu_op_desc
->
SetInput
(
"ffn_ln_bias"
,
name_caches
.
at
(
"FFNLnBias"
));
fused_mt_xpu_op_desc
->
SetInput
(
"ffn1_bias"
,
name_caches
.
at
(
"FFN1Bias"
));
fused_mt_xpu_op_desc
->
SetInput
(
"ffn2_bias"
,
name_caches
.
at
(
"FFN2Bias"
));
fused_mt_xpu_op_desc
->
SetOutput
(
"cache_kv_out"
,
name_caches
.
at
(
"CacheKVOut"
));
fused_mt_xpu_op_desc
->
SetOutput
(
"out"
,
name_caches
.
at
(
"Out"
));
fused_mt_xpu_op_desc
->
SetInput
(
"qkvw"
,
w_int16_names_vec
[
0
]);
fused_mt_xpu_op_desc
->
SetInput
(
"qkvw_max"
,
w_max_names_vec
[
0
]);
fused_mt_xpu_op_desc
->
SetInput
(
"out_linear_w"
,
w_int16_names_vec
[
1
]);
fused_mt_xpu_op_desc
->
SetInput
(
"out_linear_wmax"
,
w_max_names_vec
[
1
]);
fused_mt_xpu_op_desc
->
SetInput
(
"ffn1_weight"
,
w_int16_names_vec
[
2
]);
fused_mt_xpu_op_desc
->
SetInput
(
"ffn1_weight_max"
,
w_max_names_vec
[
2
]);
fused_mt_xpu_op_desc
->
SetInput
(
"ffn2_weight"
,
w_int16_names_vec
[
3
]);
fused_mt_xpu_op_desc
->
SetInput
(
"ffn2_weight_max"
,
w_max_names_vec
[
3
]);
if
(
!
fused_mt_xpu_op_desc
->
HasAttr
(
"rotary_emb_dims"
))
{
fused_mt_xpu_op_desc
->
SetAttr
(
"rotary_emb_dims"
,
0
);
}
// unlink QKVW/OutLinearW/FFN1Weight/FFN2Weight from fused_mt_xpu
for
(
auto
nodes
:
w_nodes_vec
)
{
for
(
auto
node
:
nodes
)
{
IR_NODE_UNLINK
(
node
,
fused_mt
);
}
}
// link int16 format of QKVW/OutLinearW/FFN1Weight/FFN2Weight to
// fused_mt_xpu
for
(
auto
nodes
:
w_int16_nodes_vec
)
{
for
(
auto
node
:
nodes
)
{
IR_NODE_LINK_TO
(
node
,
fused_mt
);
}
}
// link QKVWMax/OutLinearWMax/FFN1WeightMax/FFN2WeightMax to fused_mt_xpu
for
(
auto
nodes
:
w_max_nodes_vec
)
{
for
(
auto
node
:
nodes
)
{
IR_NODE_LINK_TO
(
node
,
fused_mt
);
}
}
found_subgraph_count
++
;
};
gpd
(
graph
,
handler
);
return
found_subgraph_count
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
fused_multi_transformer_xpu_quant_pass
,
paddle
::
framework
::
ir
::
FusedMultiTransformerXPUQuantPass
);
paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass_tester.cc
0 → 100644
浏览文件 @
52e1742f
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#define DEF_INPUT_DATA \
Layers layers; \
auto* x = layers.data("x", {1, 128, 1024}); \
auto* src_mask = layers.data("src_mask", {1, 16, 128, 128}); \
auto* ln_scale = layers.data("ln_scale", {1024}, true); \
auto* ln_bias = layers.data("ln_bias", {1024}, true); \
auto* qkv_w = layers.data("qkv_w", {3, 16, 64, 1024}, true); \
auto* qkv_bias = layers.data("qkv_bias", {3, 16, 64}, true); \
auto* out_linear_w = layers.data("out_linear_w", {1024, 1024}, true); \
auto* out_linear_bias = layers.data("out_linear_bias", {1024}, true); \
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); \
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); \
auto* ffn1_w = layers.data("ffn1_w", {1024, 4096}, true); \
auto* ffn1_bias = layers.data("ffn1_bias", {4096}, true); \
auto* ffn2_w = layers.data("ffn2_w", {4096, 1024}, true); \
auto* ffn2_bias = layers.data("ffn2_bias", {1024}, true);
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
AddVarToScope
(
Scope
*
param_scope
,
const
std
::
string
&
name
,
const
DDim
&
dims
)
{
auto
*
tensor
=
param_scope
->
Var
(
name
)
->
GetMutable
<
phi
::
DenseTensor
>
();
tensor
->
Resize
(
dims
);
tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
}
Scope
*
CreateParamScope
()
{
auto
param_scope
=
new
Scope
();
AddVarToScope
(
param_scope
,
"ln_scale"
,
{
1024
});
AddVarToScope
(
param_scope
,
"ln_bias"
,
{
1024
});
AddVarToScope
(
param_scope
,
"ffn_ln_scale"
,
{
1024
});
AddVarToScope
(
param_scope
,
"ffn_ln_bias"
,
{
1024
});
AddVarToScope
(
param_scope
,
"qkv_w"
,
{
3
,
16
,
64
,
1024
});
AddVarToScope
(
param_scope
,
"out_linear_w"
,
{
1024
,
1024
});
AddVarToScope
(
param_scope
,
"ffn1_w"
,
{
1024
,
4096
});
AddVarToScope
(
param_scope
,
"ffn2_w"
,
{
4096
,
1024
});
AddVarToScope
(
param_scope
,
"qkv_bias"
,
{
3072
});
AddVarToScope
(
param_scope
,
"out_linear_bias"
,
{
1024
});
AddVarToScope
(
param_scope
,
"ffn1_bias"
,
{
4096
});
AddVarToScope
(
param_scope
,
"ffn2_bias"
,
{
1024
});
return
param_scope
;
}
TEST
(
FusedMultiTransformerXPUQuantPass
,
context_stage
)
{
DEF_INPUT_DATA
auto
*
cache_kv
=
layers
.
fill_constant_batch_size_like
(
x
,
static_cast
<
int
>
(
proto
::
VarType
::
FP32
),
0
,
1
,
{
2
,
-
1
,
16
,
1024
,
64
},
0
);
layers
.
fused_multi_transformer
(
x
,
cache_kv
,
src_mask
,
qkv_w
,
qkv_bias
,
out_linear_w
,
out_linear_bias
,
ffn1_w
,
ffn1_bias
,
ffn2_w
,
ffn2_bias
,
ln_scale
,
ln_bias
,
ffn_ln_scale
,
ffn_ln_bias
,
0.1
,
1e-12
);
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
layers
.
main_program
()));
graph
->
Set
(
"__param_scope__"
,
CreateParamScope
());
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"fused_multi_transformer_xpu_quant_pass"
);
if
(
pass
.
get
()
==
nullptr
)
{
LOG
(
INFO
)
<<
"get fused_multi_transformer_xpu_quant_pass failed"
;
}
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
int
num_nodes_after
=
GetNumOpNodes
(
graph
,
"fused_multi_transformer_xpu"
);
VLOG
(
3
)
<<
DebugString
(
graph
);
PADDLE_ENFORCE_EQ
(
num_nodes_after
,
1
,
platform
::
errors
::
InvalidArgument
(
"After the fuse_multi_transformer_layer_pass, "
"The node num in graph should be 1, but the result is %d"
,
num_nodes_after
));
}
TEST
(
FusedMultiTransformerXPUQuantPass
,
decoder_stage
)
{
DEF_INPUT_DATA
auto
*
cache_kv
=
layers
.
fill_constant_batch_size_like
(
x
,
static_cast
<
int
>
(
proto
::
VarType
::
FP32
),
0
,
1
,
{
2
,
-
1
,
16
,
1024
,
64
},
0
);
auto
*
time_step
=
layers
.
data
(
"time_step"
,
{
1
});
layers
.
fused_multi_transformer
(
x
,
cache_kv
,
src_mask
,
qkv_w
,
qkv_bias
,
out_linear_w
,
out_linear_bias
,
ffn1_w
,
ffn1_bias
,
ffn2_w
,
ffn2_bias
,
ln_scale
,
ln_bias
,
ffn_ln_scale
,
ffn_ln_bias
,
0.1
,
1e-12
,
time_step
);
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
layers
.
main_program
()));
graph
->
Set
(
"__param_scope__"
,
CreateParamScope
());
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"fused_multi_transformer_xpu_quant_pass"
);
if
(
pass
.
get
()
==
nullptr
)
{
LOG
(
INFO
)
<<
"get fused_multi_transformer_xpu_quant_pass failed"
;
}
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
int
num_nodes_after
=
GetNumOpNodes
(
graph
,
"fused_multi_transformer_xpu"
);
VLOG
(
3
)
<<
DebugString
(
graph
);
PADDLE_ENFORCE_EQ
(
num_nodes_after
,
1
,
platform
::
errors
::
InvalidArgument
(
"After the fuse_multi_transformer_layer_pass, "
"The node num in graph should be 1, but the result is %d"
,
num_nodes_after
));
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
fused_multi_transformer_xpu_quant_pass
);
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
52e1742f
...
...
@@ -524,6 +524,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"embedding_with_eltwise_add_xpu_fuse_pass"
,
"multi_encoder_xpu_fuse_pass"
,
"multi_encoder_xpu_slice_fuse_pass"
,
"fused_multi_transformer_xpu_quant_pass"
,
"fc_xpu_fuse_pass"
,
"link_xpu_op_max_pass"
,
"delete_op_device_pass"
,
...
...
paddle/phi/api/yaml/static_ops.yaml
浏览文件 @
52e1742f
...
...
@@ -47,6 +47,16 @@
param
:
[
x
,
axis
,
keepdim
,
reduce_all
]
backward
:
frobenius_norm_grad
-
op
:
fused_multi_transformer_xpu
args
:
(Tensor x, Tensor[] ln_scale, Tensor[] ln_bias, Tensor[] qkvw, Tensor[] qkvw_max, Tensor[] qkv_bias, Tensor[] out_linear_w, Tensor[] out_linear_wmax, Tensor[] out_linear_bias, Tensor[] ffn_ln_scale, Tensor[] ffn_ln_bias, Tensor[] ffn1_weight, Tensor[] ffn1_weight_max, Tensor[] ffn1_bias, Tensor[] ffn2_weight, Tensor[] ffn2_weight_max, Tensor[] ffn2_bias, Tensor[] cache_kv, Tensor[] pre_caches, Tensor rotary_pos_emb, Tensor time_step, Tensor seq_lengths, Tensor src_mask, bool pre_layer_norm, int rotary_emb_dims, float epsilon, float dropout_rate, bool is_test, str dropout_implementation, str act_method, bool trans_qkvw, int ring_id)
output
:
Tensor(out), Tensor[](cache_kv_out){out_linear_w.size()}
infer_meta
:
func
:
FusedMultiTransformerXpuInferMeta
kernel
:
func
:
fused_multi_transformer_xpu
data_type
:
x
optional
:
cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask
-
op
:
generate_sequence_xpu
args
:
(Tensor x, DataType dtype)
output
:
Tensor
...
...
paddle/phi/backends/xpu/xpu2_op_list.cc
浏览文件 @
52e1742f
...
...
@@ -331,6 +331,8 @@ XPUOpMap& get_kl2_ops() {
phi
::
DataType
::
INT32
,
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"fused_multi_transformer_xpu"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"unfold"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"unfold_grad"
,
...
...
paddle/phi/infermeta/fusion.cc
浏览文件 @
52e1742f
...
...
@@ -114,4 +114,108 @@ void MultiEncoderXPUInferMeta(
}
}
void
FusedMultiTransformerXpuInferMeta
(
const
MetaTensor
&
x
,
const
std
::
vector
<
const
MetaTensor
*>&
ln_scale
,
const
std
::
vector
<
const
MetaTensor
*>&
ln_bias
,
const
std
::
vector
<
const
MetaTensor
*>&
qkvw
,
const
std
::
vector
<
const
MetaTensor
*>&
qkvw_max
,
const
std
::
vector
<
const
MetaTensor
*>&
qkv_bias
,
const
std
::
vector
<
const
MetaTensor
*>&
out_linear_w
,
const
std
::
vector
<
const
MetaTensor
*>&
out_linear_wmax
,
const
std
::
vector
<
const
MetaTensor
*>&
out_linear_bias
,
const
std
::
vector
<
const
MetaTensor
*>&
ffn_ln_scale
,
const
std
::
vector
<
const
MetaTensor
*>&
ffn_ln_bias
,
const
std
::
vector
<
const
MetaTensor
*>&
ffn1_weight
,
const
std
::
vector
<
const
MetaTensor
*>&
ffn1_weight_max
,
const
std
::
vector
<
const
MetaTensor
*>&
ffn1_bias
,
const
std
::
vector
<
const
MetaTensor
*>&
ffn2_weight
,
const
std
::
vector
<
const
MetaTensor
*>&
ffn2_weight_max
,
const
std
::
vector
<
const
MetaTensor
*>&
ffn2_bias
,
const
std
::
vector
<
const
MetaTensor
*>&
cache_kv
,
const
std
::
vector
<
const
MetaTensor
*>&
pre_caches
,
const
std
::
vector
<
const
MetaTensor
*>&
rotary_pos_emb
,
const
std
::
vector
<
const
MetaTensor
*>&
time_step
,
const
std
::
vector
<
const
MetaTensor
*>&
seq_lengths
,
const
std
::
vector
<
const
MetaTensor
*>&
src_mask
,
bool
pre_layer_norm
,
int
rotary_emb_dims
,
float
epsilon
,
float
dropout_rate
,
bool
is_test
,
const
std
::
string
&
dropout_implementation
,
const
std
::
string
&
act_method
,
bool
trans_qkvw
,
int
ring_id
,
MetaTensor
*
out
,
std
::
vector
<
MetaTensor
*>
cache_kv_out
)
{
auto
x_dim
=
x
.
dims
();
auto
y_dim
=
qkvw
[
0
]
->
dims
();
PADDLE_ENFORCE_EQ
(
x_dim
.
size
(),
3
,
phi
::
errors
::
InvalidArgument
(
"The dimensions of x must be 3"
"(batch_size, seq_len, dim_embed),"
"but received dimensions of"
"Input is [%d]"
,
x_dim
.
size
()));
PADDLE_ENFORCE_EQ
(
y_dim
.
size
(),
4
,
phi
::
errors
::
InvalidArgument
(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"but received dimensions of"
"Input is [%d]"
,
y_dim
.
size
()));
PADDLE_ENFORCE_EQ
(
x_dim
[
2
],
trans_qkvw
?
y_dim
[
3
]
:
y_dim
[
0
],
phi
::
errors
::
InvalidArgument
(
"ShapeError: the dimension of x_dim[2] and y_dim[3](trans_qkvw is "
"true) or y_dim[0](trans_qkvw is false)"
"must be equal. But received: the shape "
"of input x = [%s], and the shape of "
"input qkv_weight = [%s]"
,
x_dim
,
y_dim
));
if
(
cache_kv
.
size
()
>
0
)
{
const
auto
&
c_dim
=
cache_kv
[
0
]
->
dims
();
PADDLE_ENFORCE_EQ
(
c_dim
.
size
(),
5
,
phi
::
errors
::
InvalidArgument
(
"The CacheKV must be 5 dims, but got %d"
,
c_dim
.
size
()));
PADDLE_ENFORCE_EQ
(
c_dim
[
0
],
2
,
phi
::
errors
::
InvalidArgument
(
"The first dim of CacheKV must be 2, but got %d"
,
c_dim
[
0
]));
// 2
PADDLE_ENFORCE_EQ
(
c_dim
[
1
],
x_dim
[
0
],
phi
::
errors
::
InvalidArgument
(
"The second dim of CacheKV must be equal with "
"batch size %d, but got %d"
,
x_dim
[
0
],
c_dim
[
1
]));
// batch_size
PADDLE_ENFORCE_EQ
(
c_dim
[
2
],
trans_qkvw
?
y_dim
[
1
]
:
y_dim
[
2
],
phi
::
errors
::
InvalidArgument
(
"The third dim of CacheKV must be equal with num "
"head %d, but got %d"
,
trans_qkvw
?
y_dim
[
1
]
:
y_dim
[
2
],
c_dim
[
2
]));
// num_head
PADDLE_ENFORCE_EQ
(
c_dim
[
4
],
trans_qkvw
?
y_dim
[
2
]
:
y_dim
[
3
],
phi
::
errors
::
InvalidArgument
(
"The fifth dim of CacheKV must be equal with head "
"size %d, but got %d"
,
trans_qkvw
?
y_dim
[
2
]
:
y_dim
[
3
],
c_dim
[
4
]));
// head_size
}
out
->
set_dims
(
x_dim
);
out
->
set_dtype
(
x
.
dtype
());
out
->
set_layout
(
x
.
layout
());
}
}
// namespace phi
paddle/phi/infermeta/fusion.h
浏览文件 @
52e1742f
...
...
@@ -66,4 +66,39 @@ void MultiEncoderXPUInferMeta(
MetaTensor
*
x_fp16
,
MetaTensor
*
out_fp16
);
void
FusedMultiTransformerXpuInferMeta
(
const
MetaTensor
&
x
,
const
std
::
vector
<
const
MetaTensor
*>&
ln_scale
,
const
std
::
vector
<
const
MetaTensor
*>&
ln_bias
,
const
std
::
vector
<
const
MetaTensor
*>&
qkvw
,
const
std
::
vector
<
const
MetaTensor
*>&
qkvw_max
,
const
std
::
vector
<
const
MetaTensor
*>&
qkv_bias
,
const
std
::
vector
<
const
MetaTensor
*>&
out_linear_w
,
const
std
::
vector
<
const
MetaTensor
*>&
out_linear_wmax
,
const
std
::
vector
<
const
MetaTensor
*>&
out_linear_bias
,
const
std
::
vector
<
const
MetaTensor
*>&
ffn_ln_scale
,
const
std
::
vector
<
const
MetaTensor
*>&
ffn_ln_bias
,
const
std
::
vector
<
const
MetaTensor
*>&
ffn1_weight
,
const
std
::
vector
<
const
MetaTensor
*>&
ffn1_weight_max
,
const
std
::
vector
<
const
MetaTensor
*>&
ffn1_bias
,
const
std
::
vector
<
const
MetaTensor
*>&
ffn2_weight
,
const
std
::
vector
<
const
MetaTensor
*>&
ffn2_weight_max
,
const
std
::
vector
<
const
MetaTensor
*>&
ffn2_bias
,
const
std
::
vector
<
const
MetaTensor
*>&
cache_kv
,
const
std
::
vector
<
const
MetaTensor
*>&
pre_caches
,
const
std
::
vector
<
const
MetaTensor
*>&
rotary_pos_emb
,
const
std
::
vector
<
const
MetaTensor
*>&
time_step
,
const
std
::
vector
<
const
MetaTensor
*>&
seq_lengths
,
const
std
::
vector
<
const
MetaTensor
*>&
src_mask
,
bool
pre_layer_norm
,
int
rotary_emb_dims
,
float
epsilon
,
float
dropout_rate
,
bool
is_test
,
const
std
::
string
&
dropout_implementation
,
const
std
::
string
&
act_method
,
bool
trans_qkvw
,
int
ring_id
,
MetaTensor
*
out
,
std
::
vector
<
MetaTensor
*>
cache_kv_out
);
}
// namespace phi
paddle/phi/kernels/fusion/xpu/fused_multi_transformer_xpu_kernel.cc
0 → 100644
浏览文件 @
52e1742f
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/memcpy_kernel.h"
#ifdef PADDLE_WITH_XPU_XFT
#include "models/fused_multi_transformer_op.h"
namespace
xft
=
baidu
::
xpu
::
xft
;
#endif
namespace
phi
{
namespace
fusion
{
template
<
typename
T
,
typename
Context
>
void
FusedMultiTransformerXpuKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
xx
,
const
std
::
vector
<
const
DenseTensor
*>&
ln_scale
,
const
std
::
vector
<
const
DenseTensor
*>&
ln_bias
,
const
std
::
vector
<
const
DenseTensor
*>&
qkvw
,
const
std
::
vector
<
const
DenseTensor
*>&
qkvw_max
,
const
std
::
vector
<
const
DenseTensor
*>&
qkv_bias
,
const
std
::
vector
<
const
DenseTensor
*>&
out_linear_w
,
const
std
::
vector
<
const
DenseTensor
*>&
out_linear_wmax
,
const
std
::
vector
<
const
DenseTensor
*>&
out_linear_bias
,
const
std
::
vector
<
const
DenseTensor
*>&
ffn_ln_scale
,
const
std
::
vector
<
const
DenseTensor
*>&
ffn_ln_bias
,
const
std
::
vector
<
const
DenseTensor
*>&
ffn1_weight
,
const
std
::
vector
<
const
DenseTensor
*>&
ffn1_weight_max
,
const
std
::
vector
<
const
DenseTensor
*>&
ffn1_bias
,
const
std
::
vector
<
const
DenseTensor
*>&
ffn2_weight
,
const
std
::
vector
<
const
DenseTensor
*>&
ffn2_weight_max
,
const
std
::
vector
<
const
DenseTensor
*>&
ffn2_bias
,
const
paddle
::
optional
<
std
::
vector
<
const
DenseTensor
*>>&
cache_kv
,
const
paddle
::
optional
<
std
::
vector
<
const
DenseTensor
*>>&
pre_caches
,
const
paddle
::
optional
<
DenseTensor
>&
rotary_pos_emb
,
const
paddle
::
optional
<
DenseTensor
>&
time_step
,
const
paddle
::
optional
<
DenseTensor
>&
seq_lengths
,
const
paddle
::
optional
<
DenseTensor
>&
src_mask
,
bool
pre_layer_norm
,
int
rotary_emb_dims
,
float
epsilon
,
float
dropout_rate
,
bool
is_test
,
const
std
::
string
&
dropout_implementation
,
const
std
::
string
&
act_method
,
bool
trans_qkvw
,
int
ring_id
,
DenseTensor
*
out
,
std
::
vector
<
DenseTensor
*>
cache_kv_out
)
{
#ifdef PADDLE_WITH_XPU_XFT
using
XPUTypeT
=
typename
XPUTypeTrait
<
T
>::
Type
;
PADDLE_ENFORCE_EQ
(
pre_layer_norm
,
true
,
phi
::
errors
::
PreconditionNotMet
(
"Only support pre_layer_norm = true at now."
));
PADDLE_ENFORCE_EQ
(
seq_lengths
.
get_ptr
(),
nullptr
,
phi
::
errors
::
PreconditionNotMet
(
"seq_lengths not support at now."
));
PADDLE_ENFORCE_EQ
(
rotary_pos_emb
.
get_ptr
(),
nullptr
,
phi
::
errors
::
PreconditionNotMet
(
"rotary_pos_emb not support at now."
));
PADDLE_ENFORCE_EQ
(
pre_caches
.
get_ptr
(),
nullptr
,
phi
::
errors
::
PreconditionNotMet
(
"pre_caches not support at now."
));
PADDLE_ENFORCE_NE
(
src_mask
.
get_ptr
(),
nullptr
,
phi
::
errors
::
PreconditionNotMet
(
"src_mask should not be nullptr."
));
PADDLE_ENFORCE_EQ
(
trans_qkvw
,
true
,
phi
::
errors
::
PreconditionNotMet
(
"Only support trans_qkvw == true at now."
));
const
auto
x_dims
=
xx
.
dims
();
int
seq_len
=
x_dims
[
1
];
const
auto
qkv_w_dims
=
qkvw
[
0
]
->
dims
();
int
num_head
=
trans_qkvw
?
qkv_w_dims
[
1
]
:
qkv_w_dims
[
2
];
int
dim_head
=
trans_qkvw
?
qkv_w_dims
[
2
]
:
qkv_w_dims
[
3
];
int
time_step_value
=
-
1
;
if
(
time_step
)
{
PADDLE_ENFORCE_EQ
(
time_step
.
get_ptr
()
->
place
(),
phi
::
CPUPlace
(),
phi
::
errors
::
PreconditionNotMet
(
"The place of input(time_step) must be CPUPlace."
));
// cache_seq_len
time_step_value
=
time_step
.
get_ptr
()
->
data
<
int
>
()[
0
];
PADDLE_ENFORCE_GT
(
time_step_value
,
0
,
phi
::
errors
::
PreconditionNotMet
(
"The value of time_step must > 0, but now is %d"
,
time_step_value
));
PADDLE_ENFORCE_EQ
(
seq_len
,
1
,
phi
::
errors
::
PreconditionNotMet
(
"In decode stage, the seq_len of input must be 1, but now is %d"
,
seq_len
));
}
XPUTypeT
*
x_data
=
reinterpret_cast
<
XPUTypeT
*>
(
const_cast
<
T
*>
(
xx
.
data
<
T
>
()));
XPUTypeT
*
src_mask_data
=
reinterpret_cast
<
XPUTypeT
*>
(
const_cast
<
T
*>
(
src_mask
.
get_ptr
()
->
data
<
T
>
()));
auto
*
out_data
=
reinterpret_cast
<
XPUTypeT
*>
(
ctx
.
template
Alloc
<
T
>(
out
));
auto
src_mask_dims
=
src_mask
.
get_ptr
()
->
dims
();
auto
out_dims
=
out
->
dims
();
auto
xft_x
=
xft
::
xftTensor
<
XPUTypeT
,
3
>
(
x_data
,
std
::
array
<
int64_t
,
3
>
{
x_dims
[
0
],
x_dims
[
1
],
x_dims
[
2
]});
// TODO(mayang02): xft support mask.dtype = float16
xpu
::
ctx_guard
RAII_GUARD
(
ctx
.
x_context
());
float
*
src_mask_fp32_data
=
RAII_GUARD
.
alloc
<
float
>
(
src_mask
.
get_ptr
()
->
numel
());
int
r
=
xpu
::
cast
<
XPUTypeT
,
float
>
(
ctx
.
x_context
(),
src_mask_data
,
src_mask_fp32_data
,
src_mask
.
get_ptr
()
->
numel
());
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"xpu::cast"
);
auto
xft_src_mask
=
xft
::
xftTensor
<
float
,
4
>
(
src_mask_fp32_data
,
std
::
array
<
int64_t
,
4
>
{
src_mask_dims
[
0
],
src_mask_dims
[
1
],
src_mask_dims
[
2
],
src_mask_dims
[
3
]});
auto
xft_out
=
xft
::
xftTensor
<
XPUTypeT
,
3
>
(
out_data
,
std
::
array
<
int64_t
,
3
>
{
out_dims
[
0
],
out_dims
[
1
],
out_dims
[
2
]});
typedef
int16_t
TW
;
std
::
vector
<
xft
::
xftVec
<
float
>>
xft_ln_scale
;
std
::
vector
<
xft
::
xftVec
<
float
>>
xft_ln_bias
;
std
::
vector
<
xft
::
xftMat
<
TW
>>
xft_qkvw
;
std
::
vector
<
xft
::
xftVec
<
float
>>
xft_qkv_bias
;
std
::
vector
<
xft
::
xftMat
<
TW
>>
xft_out_linear_w
;
std
::
vector
<
xft
::
xftVec
<
float
>>
xft_out_linear_bias
;
std
::
vector
<
xft
::
xftVec
<
float
>>
xft_ffn_ln_scale
;
std
::
vector
<
xft
::
xftVec
<
float
>>
xft_ffn_ln_bias
;
std
::
vector
<
xft
::
xftMat
<
TW
>>
xft_ffn1_w
;
std
::
vector
<
xft
::
xftVec
<
float
>>
xft_ffn1_bias
;
std
::
vector
<
xft
::
xftMat
<
TW
>>
xft_ffn2_w
;
std
::
vector
<
xft
::
xftVec
<
float
>>
xft_ffn2_bias
;
std
::
vector
<
xft
::
xftTensor
<
XPUTypeT
,
5
>>
xft_cache_kv
;
std
::
vector
<
xft
::
xftTensor
<
XPUTypeT
,
5
>>
xft_cache_kv_out
;
int
layers
=
qkvw
.
size
();
for
(
int
i
=
0
;
i
<
layers
;
++
i
)
{
// step1. layer_norm
xft_ln_scale
.
emplace_back
(
const_cast
<
float
*>
(
ln_scale
[
i
]
->
data
<
float
>
()),
std
::
array
<
int64_t
,
1
>
{
ln_scale
[
i
]
->
dims
()[
0
]});
xft_ln_bias
.
emplace_back
(
const_cast
<
float
*>
(
ln_bias
[
i
]
->
data
<
float
>
()),
std
::
array
<
int64_t
,
1
>
{
ln_bias
[
i
]
->
dims
()[
0
]});
// step2. qkv
auto
qkvw_dims
=
qkvw
[
i
]
->
dims
();
xft_qkvw
.
emplace_back
(
const_cast
<
TW
*>
(
qkvw
[
i
]
->
data
<
TW
>
()),
const_cast
<
float
*>
(
qkvw_max
[
i
]
->
data
<
float
>
()),
std
::
array
<
int64_t
,
2
>
{
qkvw_dims
[
0
]
*
qkvw_dims
[
1
]
*
qkvw_dims
[
2
],
qkvw_dims
[
3
]});
auto
qkvb_dims
=
qkv_bias
[
i
]
->
dims
();
xft_qkv_bias
.
emplace_back
(
const_cast
<
float
*>
(
qkv_bias
[
i
]
->
data
<
float
>
()),
std
::
array
<
int64_t
,
1
>
{
qkvb_dims
[
0
]
*
qkvb_dims
[
1
]
*
qkvb_dims
[
2
]});
// attn out
auto
outw_dims
=
out_linear_w
[
i
]
->
dims
();
xft_out_linear_w
.
emplace_back
(
const_cast
<
TW
*>
(
out_linear_w
[
i
]
->
data
<
TW
>
()),
const_cast
<
float
*>
(
out_linear_wmax
[
i
]
->
data
<
float
>
()),
std
::
array
<
int64_t
,
2
>
{
outw_dims
[
0
],
outw_dims
[
1
]});
xft_out_linear_bias
.
emplace_back
(
const_cast
<
float
*>
(
out_linear_bias
[
i
]
->
data
<
float
>
()),
std
::
array
<
int64_t
,
1
>
{
out_linear_bias
[
i
]
->
dims
()[
0
]});
// ffn ln
xft_ffn_ln_scale
.
emplace_back
(
const_cast
<
float
*>
(
ffn_ln_scale
[
i
]
->
data
<
float
>
()),
std
::
array
<
int64_t
,
1
>
{
ffn_ln_scale
[
i
]
->
dims
()[
0
]});
xft_ffn_ln_bias
.
emplace_back
(
const_cast
<
float
*>
(
ffn_ln_bias
[
i
]
->
data
<
float
>
()),
std
::
array
<
int64_t
,
1
>
{
ffn_ln_bias
[
i
]
->
dims
()[
0
]});
// ffn1
auto
ffn1w_dims
=
ffn1_weight
[
i
]
->
dims
();
xft_ffn1_w
.
emplace_back
(
const_cast
<
TW
*>
(
ffn1_weight
[
i
]
->
data
<
TW
>
()),
const_cast
<
float
*>
(
ffn1_weight_max
[
i
]
->
data
<
float
>
()),
std
::
array
<
int64_t
,
2
>
{
ffn1w_dims
[
0
],
ffn1w_dims
[
1
]});
xft_ffn1_bias
.
emplace_back
(
const_cast
<
float
*>
(
ffn1_bias
[
i
]
->
data
<
float
>
()),
std
::
array
<
int64_t
,
1
>
{
ffn1_bias
[
i
]
->
dims
()[
0
]});
// ffn2
auto
ffn2w_dims
=
ffn2_weight
[
i
]
->
dims
();
xft_ffn2_w
.
emplace_back
(
const_cast
<
TW
*>
(
ffn2_weight
[
i
]
->
data
<
TW
>
()),
const_cast
<
float
*>
(
ffn2_weight_max
[
i
]
->
data
<
float
>
()),
std
::
array
<
int64_t
,
2
>
{
ffn2w_dims
[
0
],
ffn2w_dims
[
1
]});
xft_ffn2_bias
.
emplace_back
(
const_cast
<
float
*>
(
ffn2_bias
[
i
]
->
data
<
float
>
()),
std
::
array
<
int64_t
,
1
>
{
ffn2_bias
[
i
]
->
dims
()[
0
]});
// cache kv in
if
(
time_step_value
>
0
)
{
auto
cachekv_dims
=
cache_kv
.
get_ptr
()
->
at
(
i
)
->
dims
();
xft_cache_kv
.
emplace_back
(
reinterpret_cast
<
XPUTypeT
*>
(
const_cast
<
T
*>
(
cache_kv
.
get_ptr
()
->
at
(
i
)
->
data
<
T
>
())),
std
::
array
<
int64_t
,
5
>
{
cachekv_dims
[
0
],
cachekv_dims
[
1
],
cachekv_dims
[
2
],
cachekv_dims
[
3
],
cachekv_dims
[
4
]});
}
// cache kv out
auto
cachekv_out_dims
=
cache_kv_out
[
i
]
->
dims
();
xft_cache_kv_out
.
emplace_back
(
reinterpret_cast
<
XPUTypeT
*>
(
ctx
.
template
Alloc
<
T
>(
cache_kv_out
[
i
])),
std
::
array
<
int64_t
,
5
>
{
cachekv_out_dims
[
0
],
cachekv_out_dims
[
1
],
cachekv_out_dims
[
2
],
cachekv_out_dims
[
3
],
cachekv_out_dims
[
4
]});
}
xft
::
NlpParam
param
;
param
.
num_layer
=
layers
;
param
.
n_head
=
num_head
;
param
.
size_per_head
=
dim_head
;
param
.
hidden_act
=
act_method
;
param
.
is_fuse_qkv
=
true
;
r
=
xft
::
fused_multi_transformer
<
XPUTypeT
,
TW
,
int16_t
>
(
ctx
.
x_context
(),
xft_x
,
xft_cache_kv
,
xft_src_mask
,
xft_ln_scale
,
xft_ln_bias
,
xft_qkvw
,
xft_qkv_bias
,
xft_out_linear_w
,
xft_out_linear_bias
,
xft_ffn_ln_scale
,
xft_ffn_ln_bias
,
xft_ffn1_w
,
xft_ffn1_bias
,
xft_ffn2_w
,
xft_ffn2_bias
,
param
,
time_step_value
,
&
xft_out
,
xft_cache_kv_out
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"xft::fused_multi_transformer"
);
#else
LOG
(
FATAL
)
<<
"fused_multi_transformer_xpu is not supported since it's not "
"compiled with XPU_XFT"
;
#endif
}
}
// namespace fusion
}
// namespace phi
PD_REGISTER_KERNEL
(
fused_multi_transformer_xpu
,
XPU
,
ALL_LAYOUT
,
phi
::
fusion
::
FusedMultiTransformerXpuKernel
,
float
,
phi
::
dtype
::
float16
)
{
kernel
->
InputAt
(
20
).
SetBackend
(
phi
::
Backend
::
CPU
);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录