Commit 071708fa (unverified)
Authored on Nov 17, 2022 by taixiurong; committed via GitHub on Nov 17, 2022
xpu-paddlepaddle-41 [task] ffn and attention test=kunlun (#46658)
Parent: b4460eee
Showing 8 changed files with 2,723 additions and 1 deletion (+2723 −1)
paddle/fluid/operators/fused/CMakeLists.txt                                  +2    -0
paddle/fluid/operators/fused/fused_attention_op_xpu.cc                       +939  -0
paddle/fluid/operators/fused/fused_feedforward_op_xpu.cc                     +828  -0
paddle/fluid/operators/fused/xpu_fused_common_function.h                     +224  -0
paddle/fluid/platform/device/xpu/xpu2_op_list.h                              +12   -0
paddle/phi/kernels/xpu/xpu_api_wrapper.h                                     +8    -1
python/paddle/fluid/tests/unittests/xpu/test_fused_attention_op_xpu.py       +331  -0
python/paddle/fluid/tests/unittests/xpu/test_fused_feedforward_op_xpu.py     +379  -0
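
Taken together, this commit wires the existing fused_attention and fused_feedforward operators up to Kunlun XPU (FP32/FP16), adds a common helper header, and adds unit tests. As a quick orientation, a minimal dynamic-graph sketch of calling the feed-forward fusion is given below; it mirrors the call pattern in the new test_fused_feedforward_op_xpu.py, but the shapes, the "xpu:0" device string, and the random weights are illustrative assumptions rather than part of the commit.

    # Hedged sketch: assumes an XPU build of Paddle that contains this commit.
    import paddle
    import paddle.incubate.nn.functional as incubate_f

    paddle.set_device("xpu:0")  # assumption: an XPU device is available
    x = paddle.randn((1, 8, 8), dtype="float32")
    linear1_weight = paddle.randn((8, 32), dtype="float32")
    linear2_weight = paddle.randn((32, 8), dtype="float32")

    # fused linear -> activation -> dropout -> linear -> residual -> layer_norm
    out = incubate_f.fused_feedforward(
        x, linear1_weight, linear2_weight, activation="relu", pre_layer_norm=False
    )
    print(out.shape)  # [1, 8, 8]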
paddle/fluid/operators/fused/CMakeLists.txt
@@ -38,6 +38,8 @@ if(WITH_XPU)
   op_library(resnet_basic_block_op)
   op_library(resnet_unit_op)
   op_library(fused_gemm_epilogue_op)
+  op_library(fused_attention_op)
+  op_library(fused_feedforward_op)
 endif()

 if(WITH_GPU OR WITH_ROCM)
paddle/fluid/operators/fused/fused_attention_op_xpu.cc
new file (mode 100644)
This diff is collapsed and not reproduced here (+939 lines).
paddle/fluid/operators/fused/fused_feedforward_op_xpu.cc
new file (mode 100644)
This diff is collapsed and not reproduced here (+828 lines).
paddle/fluid/operators/fused/xpu_fused_common_function.h
new file (mode 100644)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace paddle {
namespace operators {

using Tensor = phi::DenseTensor;

struct XPUDropoutParam {
  float dropout_prob;
  bool is_upscale_in_train;
  bool is_test;
  bool fix_seed;
  const Tensor *tensor_seed;
  int seed_val;

  XPUDropoutParam() {
    fix_seed = false;
    is_test = false;
    is_upscale_in_train = false;
    dropout_prob = 0.5;
    tensor_seed = nullptr;
    seed_val = 0;
  }

  XPUDropoutParam(const framework::ExecutionContext &context,
                  const int dropout_index) {
    std::string pre_fix = "dropout";
    std::string str_index = std::to_string(dropout_index);
    if (dropout_index > 0) {
      pre_fix = pre_fix + str_index + "_";
    } else {
      pre_fix = pre_fix + "_";
    }
    dropout_prob = context.Attr<float>(pre_fix + "rate");
    auto &dropout_implementation =
        context.Attr<std::string>(pre_fix + "implementation");
    is_upscale_in_train = (dropout_implementation == "upscale_in_train");
    is_test = context.Attr<bool>("is_test");
    fix_seed = context.Attr<bool>(pre_fix + "fix_seed");

    std::string str_seed = "Dropout";
    if (dropout_index > 0) {
      str_seed = str_seed + str_index + "Seed";
    } else {
      str_seed = str_seed + "Seed";
    }
    tensor_seed =
        context.HasInput(str_seed) ? context.Input<Tensor>(str_seed) : nullptr;
    if (tensor_seed) {
      seed_val = *(tensor_seed->data<int>());
    } else {
      seed_val = fix_seed ? context.Attr<int>(pre_fix + "seed") : 0;
    }
  }

  void initXPUDropoutParam(float dropout_prob_,
                           bool is_upscale_in_train_,
                           bool is_test_,
                           bool fix_seed_,
                           const Tensor *tensor_seed,
                           int seed_val_) {
    dropout_prob = dropout_prob_;
    is_upscale_in_train = is_upscale_in_train_;
    is_test = is_test_;
    fix_seed = fix_seed_;
    if (tensor_seed) {
      seed_val = *(tensor_seed->data<int>());
    } else {
      seed_val = fix_seed ? seed_val_ : 0;
    }
  }

  void initXPUDropoutParam(const framework::ExecutionContext &context,
                           int dropout_index) {
    std::string pre_fix = "dropout";
    std::string str_index = std::to_string(dropout_index);
    if (dropout_index > 0) {
      pre_fix = pre_fix + str_index + "_";
    } else {
      pre_fix = pre_fix + "_";
    }
    dropout_prob = context.Attr<float>(pre_fix + "rate");
    auto &dropout_implementation =
        context.Attr<std::string>(pre_fix + "implementation");
    is_upscale_in_train = (dropout_implementation == "upscale_in_train");
    is_test = context.Attr<bool>("is_test");
    fix_seed = context.Attr<bool>(pre_fix + "fix_seed");
    std::string str_seed = "Dropout";
    if (dropout_index > 0) {
      str_seed = str_seed + str_index + "Seed";
    } else {
      str_seed = str_seed + "Seed";
    }
    tensor_seed =
        context.HasInput(str_seed) ? context.Input<Tensor>(str_seed) : nullptr;
    if (tensor_seed) {
      seed_val = *(tensor_seed->data<int>());
    } else {
      seed_val = fix_seed ? context.Attr<int>(pre_fix + "seed") : 0;
    }
  }
};

/******************
 * check is l3
 *******************/
static bool is_in_l3(const void *addr) {
  int64_t addr_int = (int64_t)addr;
  int addr_int_high = addr_int >> 32;
  return (addr_int_high == 0);
}

/*************************
 * dropout
 *************************/
template <typename T>
void Dropout(xpu::Context *xpu_ctx,
             const T *x,
             T *mask,
             T *y,
             const XPUDropoutParam &param,
             int len) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  int r = XPU_SUCCESS;
  if (param.dropout_prob == 0.0f) {
    r = xpu::copy(xpu_ctx,
                  reinterpret_cast<const XPUType *>(x),
                  reinterpret_cast<XPUType *>(y),
                  len);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
    return;
  }
  if (!param.is_test) {
    if (param.dropout_prob == 1.0f) {
      r = xpu::constant(
          xpu_ctx, reinterpret_cast<XPUType *>(y), len, XPUType(0));
      PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
      r = xpu::constant(
          xpu_ctx, reinterpret_cast<XPUType *>(mask), len, XPUType(0));
      PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
    } else {
      r = xpu::dropout(xpu_ctx,
                       reinterpret_cast<const XPUType *>(x),
                       reinterpret_cast<XPUType *>(y),
                       reinterpret_cast<XPUType *>(mask),
                       param.seed_val,
                       len,
                       param.is_upscale_in_train,
                       param.dropout_prob);
      PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout");
    }
  } else {
    float scale = (param.is_upscale_in_train)
                      ? (1.0)
                      : (static_cast<float>(1.0f - param.dropout_prob));
    r = xpu::scale(xpu_ctx,
                   reinterpret_cast<const XPUType *>(x),
                   reinterpret_cast<XPUType *>(y),
                   len,
                   false,
                   scale,
                   0.0f);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
  }
}

template <typename T>
void DropoutGrad(xpu::Context *xpu_ctx,
                 const T *dy,
                 const T *mask,
                 T *dx,
                 const XPUDropoutParam &param,
                 int len) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  if (param.dropout_prob == 0.0f) {
    int r = xpu::copy(xpu_ctx,
                      reinterpret_cast<const XPUType *>(dy),
                      reinterpret_cast<XPUType *>(dx),
                      len);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
    return;
  }
  if (!param.is_upscale_in_train) {
    int r = xpu::mul(xpu_ctx,
                     reinterpret_cast<const XPUType *>(dy),
                     reinterpret_cast<const XPUType *>(mask),
                     reinterpret_cast<XPUType *>(dx),
                     len);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
  } else {
    int r = xpu::dropout_grad(xpu_ctx,
                              reinterpret_cast<const XPUType *>(mask),
                              reinterpret_cast<const XPUType *>(dy),
                              reinterpret_cast<XPUType *>(dx),
                              param.dropout_prob,
                              len);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout_grad");
  }
}

}  // namespace operators
}  // namespace paddle
#endif
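
The Dropout() helper above follows Paddle's two dropout conventions at inference time (the is_test branch): with "upscale_in_train" the data is passed through unscaled, otherwise it is scaled by (1 - dropout_prob). A small NumPy illustration of just that scaling rule, with hypothetical names of my own, is:

    import numpy as np

    def dropout_infer(x, dropout_prob, is_upscale_in_train):
        # mirrors the xpu::scale call in the is_test branch above
        scale = 1.0 if is_upscale_in_train else 1.0 - dropout_prob
        return x * scale

    x = np.ones(4, dtype=np.float32)
    print(dropout_infer(x, 0.1, True))   # [1. 1. 1. 1.]
    print(dropout_infer(x, 0.1, False))  # [0.9 0.9 0.9 0.9]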
paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -704,6 +704,18 @@ XPUOpMap& get_kl2_ops() {
     {"fused_gemm_epilogue_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                    pOpKernelType(vartype::FP16, XPUPlace())})},
+    {"fused_attention",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                   pOpKernelType(vartype::FP16, XPUPlace())})},
+    {"fused_attention_grad",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                   pOpKernelType(vartype::FP16, XPUPlace())})},
+    {"fused_feedforward",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                   pOpKernelType(vartype::FP16, XPUPlace())})},
+    {"fused_feedforward_grad",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                   pOpKernelType(vartype::FP16, XPUPlace())})},
   };
   return s_xpu2_kernels;
paddle/phi/kernels/xpu/xpu_api_wrapper.h
@@ -382,7 +382,8 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
                               const T* y,
                               T* out,
                               const XpuFcInfo& fcinfo,
-                              float alpha) {
+                              float alpha,
+                              bool is_grad = false) {
   using XPUType = typename XPUTypeTrait<T>::Type;
   int fccal_type = FCCalcType<XPUType>();
@@ -398,6 +399,12 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
   };
   auto fc_api = fc_api_list[fccal_type];
+  if (std::getenv("XPU_PADDLE_FC_GRAD_LOCAL") != nullptr) {
+    if (is_grad) {
+      fc_api = fc_api_list[2];
+    }
+  }
   auto fc_batch_api = fc_batch_api_list[fccal_type];
   int m = fcinfo.m;
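
The new is_grad argument, combined with the XPU_PADDLE_FC_GRAD_LOCAL environment variable, forces gradient matmuls onto fc_api_list[2]. Since the flag is read with std::getenv, it only has to be present in the process environment before the backward pass runs; a minimal sketch (the value "1" is arbitrary, only presence is checked):

    import os
    os.environ["XPU_PADDLE_FC_GRAD_LOCAL"] = "1"  # set before backward matmuls run

    import paddle  # assumption: an XPU build of Paddle containing this commit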
python/paddle/fluid/tests/unittests/xpu/test_fused_attention_op_xpu.py
new file (mode 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import sys

sys.path.append("..")
import paddle
import paddle.nn.functional as F
import paddle.incubate.nn.functional as incubate_f
from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.common import Linear, Dropout
from paddle.nn.layer.transformer import _convert_attention_mask
from paddle import tensor
from paddle.fluid import layers
import unittest
from op_test_xpu import XPUOpTest
from paddle.fluid.framework import default_main_program
from xpu.get_test_cover_info import (
    create_test_class,
    get_xpu_op_support_types,
    XPUOpTestWrapper,
)

default_main_program().random_seed = 42


class XPUTestFusedAttentionOp(XPUOpTestWrapper):
    def __init__(self):
        self.op_name = 'fused_attention'
        self.use_dynamic_create_class = False

    class TestFusedAttentionOp(XPUOpTest):
        def setUp(self):
            self.config()
            self.generate_input_data()

            self.rtol = 1e-5
            self.atol = 1e-3
            if self.x_type == np.float16 or str(self.x_type) == "float16":
                self.atol = 1e-1

            paddle.set_default_dtype(self.x_type)
            self.__class__.op_type = "fused_attention"
            # use autograd to check grad in this unittest.
            self.__class__.no_need_check_grad = True

            self.q_proj = Linear(
                self.embed_dim,
                self.embed_dim,
                self.weight_attr,
                bias_attr=self.bias_attr,
            )
            self.k_proj = Linear(
                self.kdim,
                self.embed_dim,
                self.weight_attr,
                bias_attr=self.bias_attr,
            )
            self.v_proj = Linear(
                self.vdim,
                self.embed_dim,
                self.weight_attr,
                bias_attr=self.bias_attr,
            )
            self.out_proj = Linear(
                self.embed_dim,
                self.embed_dim,
                self.weight_attr,
                bias_attr=self.bias_attr,
            )
            paddle.set_default_dtype(np.float32)
            self.norm1 = LayerNorm(self.embed_dim)
            self.norm2 = LayerNorm(self.embed_dim)
            paddle.set_default_dtype(self.x_type)
            self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train")

        def config(self):
            self.x_type = self.in_type
            self.attn_mask_type = np.float32
            self.pre_layer_norm = True
            self.has_attn_mask = False
            self.training = True

            self.batch_size = 8
            self.query_length = 128
            self.cache_length = 128
            self.head_dim = 64
            self.num_heads = 16
            self.embed_dim = self.head_dim * self.num_heads

            self.dropout_prob = 0.0
            self.attn_dropout_prob = 0.0
            self.weight_attr = None
            self.bias_attr = None
            self.kdim, self.vdim = self.embed_dim, self.embed_dim
            self.key_length, self.value_length = (
                self.query_length,
                self.query_length,
            )

        def generate_input_data(self):
            self.query = np.random.rand(
                self.batch_size, self.query_length, self.embed_dim
            ).astype(self.x_type)
            out_seq_len = self.key_length
            if self.has_attn_mask:
                # [B, n_head, seq_len, out_seq_len]
                self.attn_mask = np.ones(
                    (
                        self.batch_size,
                        self.num_heads,
                        self.query_length,
                        out_seq_len,
                    ),
                    dtype=self.attn_mask_type,
                )
            else:
                self.attn_mask = None
            self.key, self.value = self.query, self.query

            self.dout = np.random.random(
                (self.batch_size, self.query_length, self.embed_dim)
            ).astype(self.x_type)

        def GetBaselineOut(self):
            paddle.disable_static()
            tensor_query = paddle.to_tensor(self.query, stop_gradient=False)

            if self.has_attn_mask:
                attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
            else:
                attn_mask = None
            residual = tensor_query

            ln1_out = tensor_query
            if self.pre_layer_norm:
                ln1_out = self.norm1(tensor_query)

            q = self.q_proj(ln1_out)
            q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
            q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3])
            k = self.k_proj(ln1_out)
            v = self.v_proj(ln1_out)
            k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
            k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3])
            v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
            v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])

            # [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
            # --> [B, n_head, seq_len, out_seq_len]
            qk_out = layers.matmul(
                x=q_out * self.head_dim**-0.5, y=k_out, transpose_y=True
            )

            if attn_mask is not None:
                attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype)
                attn_mask_out = qk_out + attn_mask
                softmax_out = F.softmax(attn_mask_out)
            else:
                softmax_out = F.softmax(qk_out)

            if self.dropout_prob:
                dropout_out = F.dropout(
                    softmax_out,
                    self.dropout_prob,
                    training=self.training,
                    mode="upscale_in_train",
                )
                # [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
                # --> [B, n_head, seq_len, head_dim]
                qktv_out = tensor.matmul(dropout_out, v_out)
            else:
                qktv_out = tensor.matmul(softmax_out, v_out)

            fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3])
            out_linear_in = tensor.reshape(
                x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]
            )
            out = self.out_proj(out_linear_in)

            residual_out = residual + self.dropout(out)
            if not self.pre_layer_norm:
                final_out = self.norm1(residual_out)
            else:
                final_out = residual_out
            paddle.autograd.backward(
                [final_out], [paddle.to_tensor(self.dout)], retain_graph=True
            )
            return final_out, tensor_query.grad

        def GetFusedAttentionOut(self):
            paddle.disable_static()
            q_proj_weight = paddle.to_tensor(
                self.q_proj.weight, stop_gradient=False
            )
            k_proj_weight = paddle.to_tensor(
                self.k_proj.weight, stop_gradient=False
            )
            v_proj_weight = paddle.to_tensor(
                self.v_proj.weight, stop_gradient=False
            )
            out_linear_weight = paddle.to_tensor(
                self.out_proj.weight, stop_gradient=False
            )

            if self.bias_attr is False:
                qkv_bias_tensor = None
                out_linear_bias = None
            else:
                q_proj_bias = paddle.to_tensor(
                    self.q_proj.bias, stop_gradient=False
                )
                k_proj_bias = paddle.to_tensor(
                    self.k_proj.bias, stop_gradient=False
                )
                v_proj_bias = paddle.to_tensor(
                    self.v_proj.bias, stop_gradient=False
                )
                qkv_bias = np.concatenate(
                    (
                        q_proj_bias.numpy(),
                        k_proj_bias.numpy(),
                        v_proj_bias.numpy(),
                    )
                )
                qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
                qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
                out_linear_bias = paddle.to_tensor(
                    self.out_proj.bias, stop_gradient=False
                )

            ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False)
            ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False)
            ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False)
            ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False)

            q_proj_weight = q_proj_weight.numpy().transpose((1, 0))
            k_proj_weight = k_proj_weight.numpy().transpose((1, 0))
            v_proj_weight = v_proj_weight.numpy().transpose((1, 0))
            qkv_weight = np.concatenate(
                (q_proj_weight, k_proj_weight, v_proj_weight)
            )
            qkv_weight = qkv_weight.reshape(
                (3, self.num_heads, self.head_dim, self.embed_dim)
            )

            x = paddle.to_tensor(self.query, stop_gradient=False)
            cache_kv = None
            if self.has_attn_mask:
                attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
            else:
                attn_mask = None
            qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False)
            epsilon = 1e-05
            ln2_epsilon = 1e-05

            if attn_mask is not None:
                attn_mask = _convert_attention_mask(attn_mask, x.dtype)

            final_out = incubate_f.fused_multi_head_attention(
                x,
                qkv_weight_tensor,
                out_linear_weight,
                self.pre_layer_norm,
                ln1_scale,
                ln1_bias,
                ln2_scale,
                ln2_bias,
                epsilon,
                qkv_bias_tensor,
                out_linear_bias,
                cache_kv,
                attn_mask,
                self.dropout_prob,
                self.attn_dropout_prob,
                ln2_epsilon,
            )
            paddle.autograd.backward(
                [final_out], [paddle.to_tensor(self.dout)], retain_graph=True
            )
            return final_out, x.grad

        def test_fused_attention_op(self):
            final_out_ref, x_grad_ref = self.GetBaselineOut()
            final_out, x_grad = self.GetFusedAttentionOut()
            np.testing.assert_allclose(
                final_out_ref, final_out.numpy(), rtol=self.rtol, atol=self.atol
            )
            np.testing.assert_allclose(
                x_grad_ref, x_grad.numpy(), rtol=self.rtol, atol=self.atol
            )

    class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
        def config(self):
            super().config()
            self.pre_layer_norm = True

    class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):
        def config(self):
            super().config()
            self.pre_layer_norm = True
            self.has_attn_mask = False


support_types = get_xpu_op_support_types('fused_attention')
for stype in support_types:
    create_test_class(globals(), XPUTestFusedAttentionOp, stype)

if __name__ == "__main__":
    unittest.main()
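
For reference, the QKV packing that GetFusedAttentionOut() performs above is exactly the layout the fused op expects: qkv_weight of shape [3, num_heads, head_dim, embed_dim] and qkv_bias of shape [3, num_heads, head_dim]. A stand-alone sketch with random tensors follows; the concrete sizes and the positional None defaults are illustrative, taken from the test's call pattern, and it assumes a device (GPU or XPU with this commit) that provides the fused kernel.

    import paddle
    import paddle.incubate.nn.functional as incubate_f

    # x: [batch_size, seq_len, embed_dim]; here 4 heads of size 32, embed_dim 128
    x = paddle.rand((2, 4, 128), dtype="float32")
    qkv_weight = paddle.rand((3, 4, 32, 128), dtype="float32")  # [3, num_heads, head_dim, embed_dim]
    qkv_bias = paddle.rand((3, 4, 32), dtype="float32")         # [3, num_heads, head_dim]
    linear_weight = paddle.rand((128, 128), dtype="float32")    # output projection weight
    linear_bias = paddle.rand((128,), dtype="float32")

    # positional arguments follow the same order as in the test above
    out = incubate_f.fused_multi_head_attention(
        x, qkv_weight, linear_weight, False,   # pre_layer_norm=False
        None, None, None, None, 1e-5,          # layer-norm scales/biases left at defaults
        qkv_bias, linear_bias, None, None)     # no cache_kv, no attn_mask
    print(out.shape)  # [2, 4, 128]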
python/paddle/fluid/tests/unittests/xpu/test_fused_feedforward_op_xpu.py
new file (mode 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import sys

sys.path.append("..")
import paddle
from paddle.nn.layer import transformer
import paddle.nn.functional as F
import paddle.incubate.nn.functional as incubate_f
from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.common import Linear, Dropout
import unittest
from op_test_xpu import XPUOpTest
from paddle.fluid.framework import default_main_program
from xpu.get_test_cover_info import (
    create_test_class,
    XPUOpTestWrapper,
)


class XPUTestFusedFFNOp(XPUOpTestWrapper):
    def __init__(self):
        self.op_name = 'fused_feedforward'
        self.use_dynamic_create_class = False

    class TestFusedFFNOp(XPUOpTest):
        def getDtype(self):
            self.dtype = self.in_type
            self.layer_norm_dtype = "float32"

        def getShape(self):
            self.batch_size = np.random.randint(1, 32)
            self.query_length = np.random.randint(32, 128)
            self.d_model = np.random.randint(32, 512)
            self.dim_feedforward = np.random.randint(32, 512)

        def getDiff(self):
            self.rtol = 1e-2
            self.atol = 1e-3
            if self.dtype == np.float16 or self.dtype == "float16":
                self.atol = 1e-1

        def getActivation(self):
            self.act_method = "gelu"

        def getNormalizeBefore(self):
            self.pre_layer_norm = False

        def setUp(self):
            paddle.disable_static()
            self.__class__.op_type = "fused_feedforward"
            # check grad in test_out_and_grad()
            self.__class__.no_need_check_grad = True
            self.getDtype()
            self.getShape()
            self.getDiff()
            self.getActivation()
            self.getNormalizeBefore()
            paddle.set_default_dtype(self.dtype)
            self.weight_attr = None
            self.bias_attr = None

            self.weight_attrs = transformer._convert_param_attr_to_list(
                self.weight_attr, 2
            )
            self.bias_attrs = transformer._convert_param_attr_to_list(
                self.bias_attr, 2
            )
            self.linear1 = Linear(
                self.d_model,
                self.dim_feedforward,
                self.weight_attrs[1],
                bias_attr=self.bias_attrs[1],
            )
            self.linear2 = Linear(
                self.dim_feedforward,
                self.d_model,
                self.weight_attrs[1],
                bias_attr=self.bias_attrs[1],
            )

            paddle.set_default_dtype(self.layer_norm_dtype)
            self.norm1 = LayerNorm(self.d_model)
            self.norm2 = LayerNorm(self.d_model)
            paddle.set_default_dtype(self.dtype)
            self.dropout1 = Dropout(0.0, mode="upscale_in_train")
            self.dropout2 = Dropout(0.0, mode="upscale_in_train")
            self.activation = getattr(F, self.act_method)

            self.src = np.random.random(
                (self.batch_size, self.query_length, self.d_model)
            ).astype(self.dtype)
            self.dout = np.random.random(
                (self.batch_size, self.query_length, self.d_model)
            ).astype(self.dtype)

        def Base(self):
            paddle.disable_static()
            tensor_src = paddle.to_tensor(self.src, stop_gradient=False)
            residual = tensor_src
            if self.pre_layer_norm:
                ln1_out = self.norm1(tensor_src)
                linear2_out = self.linear2(
                    self.dropout1(self.activation(self.linear1(ln1_out)))
                )
                dropout2_out = residual + self.dropout2(linear2_out)
                paddle.autograd.backward(
                    [dropout2_out], [paddle.to_tensor(self.dout)], True
                )
                return dropout2_out, tensor_src.grad
            else:
                linear2_out = self.linear2(
                    self.dropout1(self.activation(self.linear1(tensor_src)))
                )
                dropout2_out = residual + self.dropout2(linear2_out)
                dropout2_out = self.norm2(dropout2_out)
                paddle.autograd.backward(
                    [dropout2_out], [paddle.to_tensor(self.dout)], True
                )
                return dropout2_out, tensor_src.grad

        def FusedFFN(self):
            paddle.disable_static()
            linear1_weight = paddle.to_tensor(
                self.linear1.weight, stop_gradient=False
            )
            linear1_bias = paddle.to_tensor(self.linear1.bias, stop_gradient=False)
            linear2_weight = paddle.to_tensor(
                self.linear2.weight, stop_gradient=False
            )
            linear2_bias = paddle.to_tensor(self.linear2.bias, stop_gradient=False)
            ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False)
            ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False)
            ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False)
            ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False)
            x = paddle.to_tensor(self.src, stop_gradient=False)
            out = incubate_f.fused_feedforward(
                x,
                linear1_weight,
                linear2_weight,
                linear1_bias,
                linear2_bias,
                ln1_scale,
                ln1_bias,
                ln2_scale,
                ln2_bias,
                0.0,
                0.0,
                activation=self.act_method,
                pre_layer_norm=self.pre_layer_norm,
            )
            paddle.autograd.backward([out], [paddle.to_tensor(self.dout)])
            return out, x.grad

        def test_out_and_grad(self):
            default_main_program().random_seed = 42
            base_out, base_grad = self.Base()
            fused_out, fused_grad = self.FusedFFN()
            np.testing.assert_allclose(
                base_out.numpy(),
                fused_out.numpy(),
                rtol=self.rtol,
                atol=self.atol,
            )
            np.testing.assert_allclose(
                base_grad.numpy(),
                fused_grad.numpy(),
                rtol=self.rtol,
                atol=self.atol,
            )

    class TestFusedFFNOpActivation(TestFusedFFNOp):
        def getActivation(self):
            self.act_method = "relu"

    class TestFusedFFNOpNormalizeBefore(TestFusedFFNOp):
        def getNormalizeBefore(self):
            self.pre_layer_norm = True

        def getShape(self):
            self.batch_size = 1
            self.query_length = 1
            self.d_model = 8
            self.dim_feedforward = 8


class APITestStaticFusedFFN(unittest.TestCase):
    def test_static(self):
        paddle.enable_static()
        default_main_program().random_seed = 42
        dtype = "float32"
        layer_norm_dtype = "float32"
        batch_size = 1
        d_model = 8
        dim_feedforward = 8

        x = paddle.static.data(
            name='x', shape=[batch_size, d_model, dim_feedforward], dtype=dtype
        )
        linear1_weight = paddle.static.data(
            name='linear1_weight', shape=[d_model, dim_feedforward], dtype=dtype
        )
        linear1_bias = paddle.static.data(
            name='linear1_bias', shape=[dim_feedforward], dtype=dtype
        )
        linear2_weight = paddle.static.data(
            name='linear2_weight', shape=[dim_feedforward, d_model], dtype=dtype
        )
        linear2_bias = paddle.static.data(name='linear2_bias', shape=[d_model])
        ln1_scale = paddle.static.data(name='ln1_scale', shape=[d_model])
        ln1_bias = paddle.static.data(name='ln1_scale', shape=[d_model])
        ln2_scale = paddle.static.data(name='ln2_scale', shape=[d_model])
        ln2_bias = paddle.static.data(name='ln2_scale', shape=[d_model])

        fused_out = incubate_f.fused_feedforward(
            x,
            linear1_weight,
            linear2_weight,
            linear1_bias,
            linear2_bias,
            ln1_scale,
            ln1_bias,
            ln2_scale,
            ln2_bias,
            0.0,
            0.0,
            activation="relu",
            pre_layer_norm=False,
        )

        linear1_out = F.linear(x, linear1_weight, linear1_bias)
        act_out = F.relu(linear1_out)
        dropout1_out = F.dropout(x=act_out, p=0.0, training=False)
        linear2_out = F.linear(dropout1_out, linear2_weight, linear2_bias)
        dropout2_out = x + F.dropout(x=linear2_out, p=0.0, training=False)
        ln_out = F.layer_norm(
            dropout2_out,
            normalized_shape=list([d_model]),
            weight=ln2_scale,
            bias=ln2_bias,
        )

        exe = paddle.static.Executor(paddle.XPUPlace(0))

        x_data = np.random.random(
            (batch_size, d_model, dim_feedforward)
        ).astype(dtype)
        linear1_weight_data = np.random.random(
            (d_model, dim_feedforward)
        ).astype(dtype)
        linear1_bias_data = np.zeros((dim_feedforward)).astype(dtype)
        linear2_weight_data = np.random.random(
            (dim_feedforward, d_model)
        ).astype(dtype)
        linear2_bias_data = np.zeros((d_model)).astype(dtype)

        ln1_scale_data = np.ones((d_model)).astype(layer_norm_dtype)
        ln1_bias_data = np.zeros((d_model)).astype(layer_norm_dtype)
        ln2_scale_data = np.ones((d_model)).astype(layer_norm_dtype)
        ln2_bias_data = np.zeros((d_model)).astype(layer_norm_dtype)

        res_list = [fused_out, ln_out]
        real_res = []

        for res in res_list:
            fetch = exe.run(
                feed={
                    'x': x_data,
                    'linear1_weight': linear1_weight_data,
                    'linear1_bias': linear1_bias_data,
                    'linear2_weight': linear2_weight_data,
                    'linear2_bias': linear2_bias_data,
                    'ln1_scale': ln1_scale_data,
                    'ln1_bias': ln1_bias_data,
                    'ln2_scale': ln2_scale_data,
                    'ln2_bias': ln2_bias_data,
                },
                fetch_list=[res],
            )
            real_res.append(fetch)
        np.testing.assert_allclose(
            real_res[0], real_res[1], rtol=1e-05, atol=0.001
        )


class TestFusedFFNOpError(unittest.TestCase):
    def test_errors(self):
        paddle.enable_static()
        with paddle.static.program_guard(
            paddle.static.Program(), paddle.static.Program()
        ):

            def test_dtype():
                x = paddle.static.data(
                    name='x', shape=[1, 10, 10], dtype="int32"
                )
                linear1_weight = paddle.static.data(
                    name='linear1_weight', shape=[1, 10, 10], dtype="float32"
                )
                linear2_weight = paddle.static.data(
                    name='linear2_weight', shape=[1, 10, 10], dtype="float32"
                )
                incubate_f.fused_feedforward(x, linear1_weight, linear2_weight)

            self.assertRaises(TypeError, test_dtype)

            def test_dropout_rate_type():
                x = paddle.static.data(
                    name='x1', shape=[1, 10, 10], dtype="float32"
                )
                linear1_weight = paddle.static.data(
                    name='linear1_weight1', shape=[10, 10], dtype="float32"
                )
                linear2_weight = paddle.static.data(
                    name='linear2_weight1', shape=[10, 10], dtype="float32"
                )
                incubate_f.fused_feedforward(
                    x, linear1_weight, linear2_weight, dropout1_rate="a"
                )

            self.assertRaises(TypeError, test_dropout_rate_type)

            def test_dropout_rate_value():
                x = paddle.static.data(
                    name='x2', shape=[1, 10, 10], dtype="float32"
                )
                linear1_weight = paddle.static.data(
                    name='linear1_weight2', shape=[10, 10], dtype="float32"
                )
                linear2_weight = paddle.static.data(
                    name='linear2_weight2', shape=[10, 10], dtype="float32"
                )
                incubate_f.fused_feedforward(
                    x, linear1_weight, linear2_weight, dropout2_rate=-1
                )

            self.assertRaises(ValueError, test_dropout_rate_value)

            def test_dropout_mode():
                x = paddle.static.data(
                    name='x3', shape=[1, 10, 10], dtype="float32"
                )
                linear1_weight = paddle.static.data(
                    name='linear1_weight3', shape=[10, 10], dtype="float32"
                )
                linear2_weight = paddle.static.data(
                    name='linear2_weight3', shape=[10, 10], dtype="float32"
                )
                incubate_f.fused_feedforward(
                    x, linear1_weight, linear2_weight, mode='test'
                )

            self.assertRaises(ValueError, test_dropout_mode)


support_types = {"float32"}  # get_xpu_op_support_types('fused_feedforward')
for stype in support_types:
    create_test_class(globals(), XPUTestFusedFFNOp, stype)

if __name__ == "__main__":
    unittest.main()