Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
76eb371e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
76eb371e
编写于
12月 17, 2021
作者:
Z
zhaoyingli
提交者:
GitHub
12月 17, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[AutoParallel] add gpt model for unittest (#38202)
* add gpt modeling * update file name
上级
9a8a4c77
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
989 addition
and
0 deletion
+989
-0
python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
...n/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
+989
-0
未找到文件。
python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
0 → 100644
浏览文件 @
76eb371e
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
collections
import
random
import
numpy
as
np
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
import
paddle.tensor
as
tensor
import
paddle.distributed.auto_parallel
as
auto
from
paddle
import
fluid
from
paddle.fluid
import
layers
from
paddle.distributed
import
fleet
from
paddle.nn.layer.transformer
import
_convert_param_attr_to_list
from
paddle.fluid.initializer
import
Normal
,
NumpyArrayInitializer
paddle
.
enable_static
()
def
init_global
():
global
_global_parallel_strategy
global
_global_process_mesh
global
PP_MESH_LIST
global
DPPP_MESH_LIST
global
MPPP_MESH_LIST
global
DPMPPP_MESH_LIST
class
MultiHeadAttention
(
nn
.
Layer
):
"""
Attention mapps queries and a set of key-value pairs to outputs, and
Multi-Head Attention performs multiple parallel attention to jointly attending
to information from different representation subspaces.
"""
Cache
=
collections
.
namedtuple
(
"Cache"
,
[
"k"
,
"v"
])
StaticCache
=
collections
.
namedtuple
(
"StaticCache"
,
[
"k"
,
"v"
])
def
__init__
(
self
,
embed_dim
,
num_heads
,
dropout
=
0.
,
kdim
=
None
,
vdim
=
None
,
need_weights
=
False
,
weight_attr
=
None
,
bias_attr
=
None
,
fuse
=
False
,
mesh_idx
=
None
):
super
(
MultiHeadAttention
,
self
).
__init__
()
self
.
embed_dim
=
embed_dim
self
.
kdim
=
kdim
if
kdim
is
not
None
else
embed_dim
self
.
vdim
=
vdim
if
vdim
is
not
None
else
embed_dim
self
.
num_heads
=
num_heads
self
.
dropout
=
dropout
self
.
need_weights
=
need_weights
self
.
fuse
=
fuse
self
.
mesh_idx
=
mesh_idx
self
.
head_dim
=
embed_dim
//
num_heads
assert
self
.
head_dim
*
num_heads
==
self
.
embed_dim
,
"embed_dim must be divisible by num_heads"
if
self
.
fuse
:
assert
self
.
kdim
==
embed_dim
assert
self
.
vdim
==
embed_dim
self
.
qkv_proj
=
nn
.
Linear
(
embed_dim
,
3
*
embed_dim
,
weight_attr
,
bias_attr
=
bias_attr
)
else
:
self
.
q_proj
=
nn
.
Linear
(
embed_dim
,
embed_dim
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
)
self
.
k_proj
=
nn
.
Linear
(
self
.
kdim
,
embed_dim
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
)
self
.
v_proj
=
nn
.
Linear
(
self
.
vdim
,
embed_dim
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
)
self
.
out_proj
=
nn
.
Linear
(
embed_dim
,
embed_dim
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
)
def
_fuse_prepare_qkv
(
self
,
query
):
mix_layer
=
self
.
qkv_proj
(
query
)
mix_layer
=
paddle
.
reshape_
(
mix_layer
,
[
0
,
0
,
self
.
num_heads
,
3
*
self
.
head_dim
])
mix_layer
=
paddle
.
transpose
(
mix_layer
,
[
0
,
2
,
1
,
3
])
q
,
k
,
v
=
paddle
.
split
(
mix_layer
,
num_or_sections
=
3
,
axis
=-
1
)
return
q
,
k
,
v
def
_prepare_qkv
(
self
,
query
,
key
,
value
,
use_cache
=
False
,
cache
=
None
):
"""
Prapares linear projected queries, keys and values for usage of subsequnt
multiple parallel attention. If `cache` is not None, using cached results
to reduce redundant calculations.
"""
q
=
self
.
q_proj
(
query
)
if
_global_parallel_strategy
==
"mp"
:
auto
.
shard_tensor
(
self
.
q_proj
.
weight
,
dist_attr
=
{
"process_mesh"
:
_global_process_mesh
,
"dims_mapping"
:
[
-
1
,
0
]
})
elif
_global_parallel_strategy
==
"dp_mp"
:
auto
.
shard_tensor
(
self
.
q_proj
.
weight
,
dist_attr
=
{
"process_mesh"
:
_global_process_mesh
,
"dims_mapping"
:
[
-
1
,
1
]
})
elif
_global_parallel_strategy
==
"mp_pp"
:
auto
.
shard_tensor
(
self
.
q_proj
.
weight
,
dist_attr
=
{
"process_mesh"
:
MPPP_MESH_LIST
[
self
.
mesh_idx
],
"dims_mapping"
:
[
-
1
,
0
]
})
elif
_global_parallel_strategy
==
"dp_mp_pp"
:
auto
.
shard_tensor
(
self
.
q_proj
.
weight
,
dist_attr
=
{
"process_mesh"
:
DPMPPP_MESH_LIST
[
self
.
mesh_idx
],
"dims_mapping"
:
[
-
1
,
1
]
})
q
=
tensor
.
reshape
(
x
=
q
,
shape
=
[
0
,
0
,
self
.
num_heads
,
self
.
head_dim
])
q
=
tensor
.
transpose
(
x
=
q
,
perm
=
[
0
,
2
,
1
,
3
])
if
isinstance
(
cache
,
self
.
StaticCache
):
# for encoder-decoder attention in inference and has cached
k
,
v
=
cache
.
k
,
cache
.
v
else
:
k
,
v
=
self
.
compute_kv
(
key
,
value
)
if
isinstance
(
cache
,
self
.
Cache
):
# for decoder self-attention in inference
k
=
tensor
.
concat
([
cache
.
k
,
k
],
axis
=
2
)
v
=
tensor
.
concat
([
cache
.
v
,
v
],
axis
=
2
)
if
use_cache
is
True
:
cache
=
self
.
Cache
(
k
,
v
)
return
(
q
,
k
,
v
)
if
use_cache
is
False
else
(
q
,
k
,
v
,
cache
)
def
compute_kv
(
self
,
key
,
value
):
"""
Applies linear projection on input keys and values, then splits heads
(reshape and transpose) to get keys and values from different representation
subspaces. The results are used as key-values pairs for subsequent multiple
parallel attention.
It is part of calculations in multi-head attention, and is provided as
a method to pre-compute and prefetch these results, thus we can use them
to construct cache for inference.
"""
k
=
self
.
k_proj
(
key
)
if
_global_parallel_strategy
==
"mp"
:
auto
.
shard_tensor
(
self
.
k_proj
.
weight
,
dist_attr
=
{
"process_mesh"
:
_global_process_mesh
,
"dims_mapping"
:
[
-
1
,
0
]
})
elif
_global_parallel_strategy
==
"dp_mp"
:
auto
.
shard_tensor
(
self
.
k_proj
.
weight
,
dist_attr
=
{
"process_mesh"
:
_global_process_mesh
,
"dims_mapping"
:
[
-
1
,
1
]
})
elif
_global_parallel_strategy
==
"mp_pp"
:
auto
.
shard_tensor
(
self
.
k_proj
.
weight
,
dist_attr
=
{
"process_mesh"
:
MPPP_MESH_LIST
[
self
.
mesh_idx
],
"dims_mapping"
:
[
-
1
,
0
]
})
elif
_global_parallel_strategy
==
"dp_mp_pp"
:
auto
.
shard_tensor
(
self
.
k_proj
.
weight
,
dist_attr
=
{
"process_mesh"
:
DPMPPP_MESH_LIST
[
self
.
mesh_idx
],
"dims_mapping"
:
[
-
1
,
1
]
})
v
=
self
.
v_proj
(
value
)
if
_global_parallel_strategy
==
"mp"
:
auto
.
shard_tensor
(
self
.
v_proj
.
weight
,
dist_attr
=
{
"process_mesh"
:
_global_process_mesh
,
"dims_mapping"
:
[
-
1
,
0
]
})
elif
_global_parallel_strategy
==
"dp_mp"
:
auto
.
shard_tensor
(
self
.
v_proj
.
weight
,
dist_attr
=
{
"process_mesh"
:
_global_process_mesh
,
"dims_mapping"
:
[
-
1
,
1
]
})
elif
_global_parallel_strategy
==
"mp_pp"
:
auto
.
shard_tensor
(
self
.
v_proj
.
weight
,
dist_attr
=
{
"process_mesh"
:
MPPP_MESH_LIST
[
self
.
mesh_idx
],
"dims_mapping"
:
[
-
1
,
0
]
})
elif
_global_parallel_strategy
==
"dp_mp_pp"
:
auto
.
shard_tensor
(
self
.
v_proj
.
weight
,
dist_attr
=
{
"process_mesh"
:
DPMPPP_MESH_LIST
[
self
.
mesh_idx
],
"dims_mapping"
:
[
-
1
,
1
]
})
k
=
tensor
.
reshape
(
x
=
k
,
shape
=
[
0
,
0
,
self
.
num_heads
,
self
.
head_dim
])
k
=
tensor
.
transpose
(
x
=
k
,
perm
=
[
0
,
2
,
1
,
3
])
v
=
tensor
.
reshape
(
x
=
v
,
shape
=
[
0
,
0
,
self
.
num_heads
,
self
.
head_dim
])
v
=
tensor
.
transpose
(
x
=
v
,
perm
=
[
0
,
2
,
1
,
3
])
return
k
,
v
def
gen_cache
(
self
,
key
,
value
=
None
,
type
=
Cache
):
"""
Generates cache for `forward` usage in inference accroding to arguments.
The generated cache is an instance of `MultiHeadAttention.Cache` or an
instance of `MultiHeadAttention.StaticCache`.
"""
if
type
==
MultiHeadAttention
.
StaticCache
:
# static_kv
k
,
v
=
self
.
compute_kv
(
key
,
value
)
return
self
.
StaticCache
(
k
,
v
)
elif
value
is
None
:
# incremental_state
k
=
layers
.
fill_constant_batch_size_like
(
input
=
key
,
shape
=
[
-
1
,
self
.
num_heads
,
0
,
self
.
head_dim
],
dtype
=
key
.
dtype
,
value
=
0
)
v
=
layers
.
fill_constant_batch_size_like
(
input
=
key
,
shape
=
[
-
1
,
self
.
num_heads
,
0
,
self
.
head_dim
],
dtype
=
key
.
dtype
,
value
=
0
)
return
self
.
Cache
(
k
,
v
)
else
:
# incremental_state with initial value, mainly for usage like UniLM
return
self
.
Cache
(
key
,
value
)
def
forward
(
self
,
query
,
key
,
value
,
attn_mask
=
None
,
use_cache
=
False
,
cache
=
None
):
"""
Applies multi-head attention to map queries and a set of key-value pairs
to outputs.
"""
key
=
query
if
key
is
None
else
key
value
=
query
if
value
is
None
else
value
# compute q ,k ,v
if
use_cache
is
False
:
if
self
.
fuse
:
q
,
k
,
v
=
self
.
_fuse_prepare_qkv
(
query
)
else
:
q
,
k
,
v
=
self
.
_prepare_qkv
(
query
,
key
,
value
,
use_cache
,
cache
)
else
:
q
,
k
,
v
,
cache
=
self
.
_prepare_qkv
(
query
,
key
,
value
,
use_cache
,
cache
)
product
=
layers
.
matmul
(
x
=
q
,
y
=
k
,
transpose_y
=
True
,
alpha
=
self
.
head_dim
**-
0.5
)
if
attn_mask
is
not
None
:
product
=
product
+
attn_mask
weights
=
F
.
softmax
(
product
)
if
self
.
dropout
:
weights
=
F
.
dropout
(
weights
,
self
.
dropout
,
training
=
self
.
training
,
mode
=
"upscale_in_train"
)
out
=
tensor
.
matmul
(
weights
,
v
)
# combine heads
out
=
tensor
.
transpose
(
out
,
perm
=
[
0
,
2
,
1
,
3
])
out
=
tensor
.
reshape
(
x
=
out
,
shape
=
[
0
,
0
,
out
.
shape
[
2
]
*
out
.
shape
[
3
]])
# project to output
out
=
self
.
out_proj
(
out
)
if
_global_parallel_strategy
==
"mp"
:
auto
.
shard_tensor
(
self
.
out_proj
.
weight
,
dist_attr
=
{
"process_mesh"
:
_global_process_mesh
,
"dims_mapping"
:
[
0
,
-
1
]
})
elif
_global_parallel_strategy
==
"dp_mp"
:
auto
.
shard_tensor
(
self
.
out_proj
.
weight
,
dist_attr
=
{
"process_mesh"
:
_global_process_mesh
,
"dims_mapping"
:
[
1
,
-
1
]
})
elif
_global_parallel_strategy
==
"mp_pp"
:
auto
.
shard_tensor
(
self
.
out_proj
.
weight
,
dist_attr
=
{
"process_mesh"
:
MPPP_MESH_LIST
[
self
.
mesh_idx
],
"dims_mapping"
:
[
0
,
-
1
]
})
elif
_global_parallel_strategy
==
"dp_mp_pp"
:
auto
.
shard_tensor
(
self
.
out_proj
.
weight
,
dist_attr
=
{
"process_mesh"
:
DPMPPP_MESH_LIST
[
self
.
mesh_idx
],
"dims_mapping"
:
[
1
,
-
1
]
})
outs
=
[
out
]
if
self
.
need_weights
:
outs
.
append
(
weights
)
if
use_cache
:
outs
.
append
(
cache
)
return
out
if
len
(
outs
)
==
1
else
tuple
(
outs
)
class
TransformerDecoder
(
nn
.
Layer
):
"""
TransformerDecoder is a stack of N decoder layers.
"""
def
__init__
(
self
,
decoder_layers
,
num_layers
,
norm
=
None
,
hidden_size
=
None
):
super
(
TransformerDecoder
,
self
).
__init__
()
self
.
num_layers
=
num_layers
self
.
layers
=
decoder_layers
self
.
norm
=
norm
if
norm
is
"LayerNorm"
:
self
.
norm
=
nn
.
LayerNorm
(
hidden_size
)
elif
norm
is
not
None
:
raise
ValueError
(
"Only support LayerNorm"
)
self
.
checkpoints
=
[]
def
forward
(
self
,
tgt
,
memory
,
tgt_mask
=
None
,
memory_mask
=
None
,
use_cache
=
False
,
cache
=
None
):
"""
Applies a stack of N Transformer decoder layers on inputs. If `norm` is
provided, also applies layer normalization on the output of last decoder
layer.
"""
output
=
tgt
new_caches
=
[]
self
.
checkpoints
=
[]
if
_global_parallel_strategy
==
"pp"
:
auto
.
shard_tensor
(
output
,
dist_attr
=
{
"process_mesh"
:
PP_MESH_LIST
[
0
],
"dims_mapping"
:
[
-
1
for
i
in
range
(
len
(
output
.
shape
))]
})
if
_global_parallel_strategy
==
"dp_pp"
:
auto
.
shard_tensor
(
output
,
dist_attr
=
{
"process_mesh"
:
DPPP_MESH_LIST
[
0
],
"dims_mapping"
:
[
0
]
+
[
-
1
for
i
in
range
(
len
(
output
.
shape
)
-
1
)]
})
if
_global_parallel_strategy
==
"mp_pp"
:
auto
.
shard_tensor
(
output
,
dist_attr
=
{
"process_mesh"
:
MPPP_MESH_LIST
[
0
],
"dims_mapping"
:
[
-
1
]
+
[
-
1
for
i
in
range
(
len
(
output
.
shape
)
-
1
)]
})
if
_global_parallel_strategy
==
"dp_mp_pp"
:
auto
.
shard_tensor
(
output
,
dist_attr
=
{
"process_mesh"
:
DPMPPP_MESH_LIST
[
0
],
"dims_mapping"
:
[
0
]
+
[
-
1
for
i
in
range
(
len
(
output
.
shape
)
-
1
)]
})
for
i
,
mod
in
enumerate
(
self
.
layers
):
if
cache
is
None
:
if
use_cache
:
if
_global_parallel_strategy
==
"pp"
:
output
,
new_cache
=
auto
.
shard_op
(
mod
,
dist_attr
=
{
"process_mesh"
:
PP_MESH_LIST
[
mod
.
mesh_idx
]
})(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)[
0
]
auto
.
shard_tensor
(
output
,
dist_attr
=
{
"process_mesh"
:
PP_MESH_LIST
[
mod
.
mesh_idx
],
"dims_mapping"
:
[
-
1
for
i
in
range
(
len
(
output
.
shape
))]
})
elif
_global_parallel_strategy
==
"dp_pp"
:
output
,
new_cache
=
auto
.
shard_op
(
mod
,
dist_attr
=
{
"process_mesh"
:
DPPP_MESH_LIST
[
mod
.
mesh_idx
]
})(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)[
0
]
auto
.
shard_tensor
(
output
,
dist_attr
=
{
"process_mesh"
:
DPPP_MESH_LIST
[
mod
.
mesh_idx
],
"dims_mapping"
:
[
0
]
+
[
-
1
for
i
in
range
(
len
(
output
.
shape
)
-
1
)]
})
elif
_global_parallel_strategy
==
"mp_pp"
:
output
,
new_cache
=
auto
.
shard_op
(
mod
,
dist_attr
=
{
"process_mesh"
:
MPPP_MESH_LIST
[
mod
.
mesh_idx
]
})(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)[
0
]
auto
.
shard_tensor
(
output
,
dist_attr
=
{
"process_mesh"
:
MPPP_MESH_LIST
[
mod
.
mesh_idx
],
"dims_mapping"
:
[
-
1
]
+
[
-
1
for
i
in
range
(
len
(
output
.
shape
)
-
1
)]
})
elif
_global_parallel_strategy
==
"dp_mp_pp"
:
output
,
new_cache
=
auto
.
shard_op
(
mod
,
dist_attr
=
{
"process_mesh"
:
DPMPPP_MESH_LIST
[
mod
.
mesh_idx
]
})(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)[
0
]
auto
.
shard_tensor
(
output
,
dist_attr
=
{
"process_mesh"
:
DPMPPP_MESH_LIST
[
mod
.
mesh_idx
],
"dims_mapping"
:
[
0
]
+
[
-
1
for
i
in
range
(
len
(
output
.
shape
)
-
1
)]
})
else
:
output
,
new_cache
=
mod
(
output
,
memory
,
tgt_mask
=
tgt_mask
,
use_cache
=
use_cache
,
cache
=
cache
)
new_caches
.
append
(
new_cache
)
else
:
if
_global_parallel_strategy
==
"pp"
:
output
=
auto
.
shard_op
(
mod
,
dist_attr
=
{
"process_mesh"
:
PP_MESH_LIST
[
mod
.
mesh_idx
]
})(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)[
0
]
auto
.
shard_tensor
(
output
,
dist_attr
=
{
"process_mesh"
:
PP_MESH_LIST
[
mod
.
mesh_idx
],
"dims_mapping"
:
[
-
1
for
i
in
range
(
len
(
output
.
shape
))]
})
elif
_global_parallel_strategy
==
"dp_pp"
:
output
=
auto
.
shard_op
(
mod
,
dist_attr
=
{
"process_mesh"
:
DPPP_MESH_LIST
[
mod
.
mesh_idx
]
})(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)[
0
]
auto
.
shard_tensor
(
output
,
dist_attr
=
{
"process_mesh"
:
DPPP_MESH_LIST
[
mod
.
mesh_idx
],
"dims_mapping"
:
[
0
]
+
[
-
1
for
i
in
range
(
len
(
output
.
shape
)
-
1
)]
})
elif
_global_parallel_strategy
==
"mp_pp"
:
output
=
auto
.
shard_op
(
mod
,
dist_attr
=
{
"process_mesh"
:
MPPP_MESH_LIST
[
mod
.
mesh_idx
]
})(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)[
0
]
auto
.
shard_tensor
(
output
,
dist_attr
=
{
"process_mesh"
:
MPPP_MESH_LIST
[
mod
.
mesh_idx
],
"dims_mapping"
:
[
-
1
]
+
[
-
1
for
i
in
range
(
len
(
output
.
shape
)
-
1
)]
})
elif
_global_parallel_strategy
==
"dp_mp_pp"
:
output
=
auto
.
shard_op
(
mod
,
dist_attr
=
{
"process_mesh"
:
DPMPPP_MESH_LIST
[
mod
.
mesh_idx
]
})(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)[
0
]
auto
.
shard_tensor
(
output
,
dist_attr
=
{
"process_mesh"
:
DPMPPP_MESH_LIST
[
mod
.
mesh_idx
],
"dims_mapping"
:
[
0
]
+
[
-
1
for
i
in
range
(
len
(
output
.
shape
)
-
1
)]
})
else
:
output
=
mod
(
output
,
memory
,
tgt_mask
=
tgt_mask
,
use_cache
=
use_cache
,
cache
=
cache
)
else
:
if
_global_parallel_strategy
==
"pp"
:
output
,
new_cache
=
auto
.
shard_op
(
mod
,
dist_attr
=
{
"process_mesh"
:
PP_MESH_LIST
[
mod
.
mesh_idx
]})(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)
auto
.
shard_tensor
(
output
,
dist_attr
=
{
"process_mesh"
:
PP_MESH_LIST
[
mod
.
mesh_idx
],
"dims_mapping"
:
[
-
1
for
i
in
range
(
len
(
output
.
shape
))]
})
elif
_global_parallel_strategy
==
"dp_pp"
:
output
,
new_cache
=
auto
.
shard_op
(
mod
,
dist_attr
=
{
"process_mesh"
:
DPPP_MESH_LIST
[
mod
.
mesh_idx
]
})(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)
auto
.
shard_tensor
(
output
,
dist_attr
=
{
"process_mesh"
:
DPPP_MESH_LIST
[
mod
.
mesh_idx
],
"dims_mapping"
:
[
0
]
+
[
-
1
for
i
in
range
(
len
(
output
.
shape
)
-
1
)]
})
elif
_global_parallel_strategy
==
"mp_pp"
:
output
,
new_cache
=
auto
.
shard_op
(
mod
,
dist_attr
=
{
"process_mesh"
:
MPPP_MESH_LIST
[
mod
.
mesh_idx
]
})(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)
auto
.
shard_tensor
(
output
,
dist_attr
=
{
"process_mesh"
:
MPPP_MESH_LIST
[
mod
.
mesh_idx
],
"dims_mapping"
:
[
-
1
]
+
[
-
1
for
i
in
range
(
len
(
output
.
shape
)
-
1
)]
})
elif
_global_parallel_strategy
==
"dp_mp_pp"
:
output
,
new_cache
=
auto
.
shard_op
(
mod
,
dist_attr
=
{
"process_mesh"
:
DPMPPP_MESH_LIST
[
mod
.
mesh_idx
]
})(
output
,
memory
,
tgt_mask
,
use_cache
,
cache
)
auto
.
shard_tensor
(
output
,
dist_attr
=
{
"process_mesh"
:
DPMPPP_MESH_LIST
[
mod
.
mesh_idx
],
"dims_mapping"
:
[
0
]
+
[
-
1
for
i
in
range
(
len
(
output
.
shape
)
-
1
)]
})
else
:
output
,
new_cache
=
mod
(
output
,
memory
,
tgt_mask
=
tgt_mask
,
use_cache
=
use_cache
,
cache
=
cache
[
i
])
new_caches
.
append
(
new_cache
)
self
.
checkpoints
.
append
(
output
.
name
)
if
self
.
norm
is
not
None
:
output
=
self
.
norm
(
output
)
return
output
if
use_cache
is
False
else
(
output
,
new_caches
)
def
gen_cache
(
self
,
memory
,
do_zip
=
False
):
"""
Generates cache for `forward` usage. The generated cache is a list, and
each element in it is a tuple( :code:`(incremental_cache, static_cache)` )
produced by `TransformerDecoderLayer.gen_cache`. See `TransformerDecoderLayer.gen_cache`
for more details. If `do_zip` is True, apply `zip` on these tuples to get
a list with two elements.
"""
cache
=
[
layer
.
gen_cache
(
memory
)
for
layer
in
self
.
layers
]
if
do_zip
:
cache
=
list
(
zip
(
*
cache
))
return
cache
class
TransformerDecoderLayer
(
nn
.
Layer
):
"""
The transformer decoder layer.
It contains multiheadattention and some linear layers.
"""
def
__init__
(
self
,
d_model
,
nhead
,
dim_feedforward
,
dropout
=
0.1
,
activation
=
"gelu"
,
attn_dropout
=
None
,
act_dropout
=
None
,
normalize_before
=
True
,
weight_attr
=
None
,
bias_attr
=
None
,
mesh_idx
=
None
):
self
.
_config
=
locals
()
self
.
_config
.
pop
(
"self"
)
self
.
_config
.
pop
(
"__class__"
,
None
)
# py3
self
.
mesh_idx
=
mesh_idx
super
(
TransformerDecoderLayer
,
self
).
__init__
()
attn_dropout
=
dropout
if
attn_dropout
is
None
else
attn_dropout
act_dropout
=
dropout
if
act_dropout
is
None
else
act_dropout
self
.
normalize_before
=
normalize_before
weight_attrs
=
_convert_param_attr_to_list
(
weight_attr
,
3
)
bias_attrs
=
_convert_param_attr_to_list
(
bias_attr
,
3
)
self
.
self_attn
=
MultiHeadAttention
(
d_model
,
nhead
,
dropout
=
attn_dropout
,
weight_attr
=
weight_attrs
[
0
],
bias_attr
=
bias_attrs
[
0
],
mesh_idx
=
self
.
mesh_idx
)
self
.
linear1
=
nn
.
Linear
(
d_model
,
dim_feedforward
,
weight_attrs
[
2
],
bias_attr
=
bias_attrs
[
2
])
self
.
linear2
=
nn
.
Linear
(
dim_feedforward
,
d_model
,
weight_attrs
[
2
],
bias_attr
=
bias_attrs
[
2
])
self
.
norm1
=
nn
.
LayerNorm
(
d_model
,
epsilon
=
1e-5
)
self
.
norm2
=
nn
.
LayerNorm
(
d_model
,
epsilon
=
1e-5
)
self
.
dropout1
=
nn
.
Dropout
(
dropout
,
mode
=
"upscale_in_train"
)
self
.
dropout2
=
nn
.
Dropout
(
act_dropout
,
mode
=
"upscale_in_train"
)
self
.
activation
=
getattr
(
F
,
activation
)
def
forward
(
self
,
tgt
,
memory
,
tgt_mask
=
None
,
use_cache
=
False
,
cache
=
None
):
residual
=
tgt
if
self
.
normalize_before
:
tgt
=
self
.
norm1
(
tgt
)
if
use_cache
is
False
:
tgt
=
self
.
self_attn
(
tgt
,
tgt
,
tgt
,
tgt_mask
,
use_cache
,
cache
)
else
:
tgt
,
incremental_cache
=
self
.
self_attn
(
tgt
,
tgt
,
tgt
,
tgt_mask
,
use_cache
,
cache
)
tgt
=
residual
+
self
.
dropout1
(
tgt
)
if
not
self
.
normalize_before
:
tgt
=
self
.
norm1
(
tgt
)
residual
=
tgt
if
self
.
normalize_before
:
tgt
=
self
.
norm2
(
tgt
)
if
_global_parallel_strategy
==
"mp"
:
auto
.
shard_tensor
(
self
.
linear1
.
weight
,
dist_attr
=
{
"process_mesh"
:
_global_process_mesh
,
"dims_mapping"
:
[
-
1
,
0
]
})
elif
_global_parallel_strategy
==
"dp_mp"
:
auto
.
shard_tensor
(
self
.
linear1
.
weight
,
dist_attr
=
{
"process_mesh"
:
_global_process_mesh
,
"dims_mapping"
:
[
-
1
,
1
]
})
elif
_global_parallel_strategy
==
"mp_pp"
:
auto
.
shard_tensor
(
self
.
linear1
.
weight
,
dist_attr
=
{
"process_mesh"
:
MPPP_MESH_LIST
[
self
.
mesh_idx
],
"dims_mapping"
:
[
-
1
,
0
]
})
if
_global_parallel_strategy
==
"dp_mp_pp"
:
auto
.
shard_tensor
(
self
.
linear1
.
weight
,
dist_attr
=
{
"process_mesh"
:
DPMPPP_MESH_LIST
[
self
.
mesh_idx
],
"dims_mapping"
:
[
-
1
,
1
]
})
if
_global_parallel_strategy
==
"mp"
:
auto
.
shard_tensor
(
self
.
linear2
.
weight
,
dist_attr
=
{
"process_mesh"
:
_global_process_mesh
,
"dims_mapping"
:
[
0
,
-
1
]
})
elif
_global_parallel_strategy
==
"dp_mp"
:
auto
.
shard_tensor
(
self
.
linear2
.
weight
,
dist_attr
=
{
"process_mesh"
:
_global_process_mesh
,
"dims_mapping"
:
[
1
,
-
1
]
})
elif
_global_parallel_strategy
==
"mp_pp"
:
auto
.
shard_tensor
(
self
.
linear2
.
weight
,
dist_attr
=
{
"process_mesh"
:
MPPP_MESH_LIST
[
self
.
mesh_idx
],
"dims_mapping"
:
[
0
,
-
1
]
})
elif
_global_parallel_strategy
==
"dp_mp_pp"
:
auto
.
shard_tensor
(
self
.
linear2
.
weight
,
dist_attr
=
{
"process_mesh"
:
DPMPPP_MESH_LIST
[
self
.
mesh_idx
],
"dims_mapping"
:
[
1
,
-
1
]
})
tgt
=
self
.
dropout2
(
self
.
linear2
(
F
.
gelu
(
self
.
linear1
(
tgt
),
approximate
=
True
)))
tgt
=
residual
+
tgt
if
not
self
.
normalize_before
:
tgt
=
self
.
norm2
(
tgt
)
return
tgt
if
use_cache
is
False
else
(
tgt
,
incremental_cache
)
def
gen_cache
(
self
,
memory
):
incremental_cache
=
self
.
self_attn
.
gen_cache
(
memory
,
type
=
self
.
self_attn
.
Cache
)
return
incremental_cache
class
GPTEmbeddings
(
nn
.
Layer
):
"""
Include embeddings from word, position and token_type embeddings
"""
def
__init__
(
self
,
vocab_size
,
hidden_size
=
768
,
hidden_dropout_prob
=
0.1
,
max_position_embeddings
=
512
,
type_vocab_size
=
16
,
initializer_range
=
0.02
):
super
(
GPTEmbeddings
,
self
).
__init__
()
self
.
word_embeddings
=
nn
.
Embedding
(
vocab_size
,
hidden_size
,
weight_attr
=
paddle
.
ParamAttr
(
name
=
"word_embeddings"
,
initializer
=
nn
.
initializer
.
Normal
(
mean
=
0.0
,
std
=
initializer_range
)))
self
.
position_embeddings
=
nn
.
Embedding
(
max_position_embeddings
,
hidden_size
,
weight_attr
=
paddle
.
ParamAttr
(
name
=
"pos_embeddings"
,
initializer
=
nn
.
initializer
.
Normal
(
mean
=
0.0
,
std
=
initializer_range
)))
self
.
dropout
=
nn
.
Dropout
(
hidden_dropout_prob
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
):
if
position_ids
is
None
:
ones
=
paddle
.
ones_like
(
input_ids
,
dtype
=
"int64"
)
seq_length
=
paddle
.
cumsum
(
ones
,
axis
=-
1
)
position_ids
=
seq_length
-
ones
input_embedings
=
self
.
word_embeddings
(
input_ids
)
if
_global_parallel_strategy
==
"mp"
:
auto
.
shard_tensor
(
self
.
word_embeddings
.
weight
,
dist_attr
=
{
"process_mesh"
:
_global_process_mesh
,
"dims_mapping"
:
[
0
,
-
1
]
})
elif
_global_parallel_strategy
==
"dp_mp"
:
auto
.
shard_tensor
(
self
.
word_embeddings
.
weight
,
dist_attr
=
{
"process_mesh"
:
_global_process_mesh
,
"dims_mapping"
:
[
1
,
-
1
]
})
elif
_global_parallel_strategy
==
"mp_pp"
:
auto
.
shard_tensor
(
self
.
word_embeddings
.
weight
,
dist_attr
=
{
"process_mesh"
:
MPPP_MESH_LIST
[
0
],
"dims_mapping"
:
[
0
,
-
1
]
})
elif
_global_parallel_strategy
==
"dp_mp_pp"
:
auto
.
shard_tensor
(
self
.
word_embeddings
.
weight
,
dist_attr
=
{
"process_mesh"
:
DPMPPP_MESH_LIST
[
0
],
"dims_mapping"
:
[
1
,
-
1
]
})
position_embeddings
=
self
.
position_embeddings
(
position_ids
)
embeddings
=
input_embedings
+
position_embeddings
embeddings
=
self
.
dropout
(
embeddings
)
return
embeddings
class
GPTModel
(
nn
.
Layer
):
"""
The base model of gpt.
"""
def
__init__
(
self
,
vocab_size
=
50304
,
hidden_size
=
1024
,
num_hidden_layers
=
24
,
num_attention_heads
=
16
,
intermediate_size
=
4096
,
hidden_act
=
"gelu"
,
hidden_dropout_prob
=
0.
,
attention_probs_dropout_prob
=
0.
,
max_position_embeddings
=
512
,
type_vocab_size
=
16
,
initializer_range
=
0.02
,
pad_token_id
=
0
,
eos_token_id
=
7
,
bos_token_id
=
0
,
eol_token_id
=
3
,
pp_degree
=
None
):
super
(
GPTModel
,
self
).
__init__
()
self
.
pad_token_id
=
pad_token_id
self
.
initializer_range
=
initializer_range
self
.
hidden_size
=
hidden_size
self
.
vocab_size
=
vocab_size
self
.
layer_per_stage
=
None
self
.
pipline_mode
=
(
pp_degree
is
not
None
and
pp_degree
>
1
)
if
self
.
pipline_mode
:
self
.
layer_per_stage
=
num_hidden_layers
//
pp_degree
self
.
embeddings
=
GPTEmbeddings
(
vocab_size
,
hidden_size
,
hidden_dropout_prob
,
max_position_embeddings
,
type_vocab_size
,
self
.
initializer_range
)
decoder_layers
=
nn
.
LayerList
()
for
i
in
range
(
num_hidden_layers
):
mesh_index
=
None
DecoderLayer
=
TransformerDecoderLayer
if
self
.
layer_per_stage
is
not
None
:
mesh_index
=
i
//
self
.
layer_per_stage
decoder_layers
.
append
(
DecoderLayer
(
d_model
=
hidden_size
,
nhead
=
num_attention_heads
,
dim_feedforward
=
intermediate_size
,
dropout
=
hidden_dropout_prob
,
activation
=
hidden_act
,
attn_dropout
=
attention_probs_dropout_prob
,
act_dropout
=
hidden_dropout_prob
,
weight_attr
=
paddle
.
ParamAttr
(
initializer
=
nn
.
initializer
.
Normal
(
mean
=
0.0
,
std
=
self
.
initializer_range
)),
bias_attr
=
None
,
mesh_idx
=
mesh_index
))
Decoder
=
TransformerDecoder
self
.
decoder
=
Decoder
(
decoder_layers
,
num_hidden_layers
,
norm
=
"LayerNorm"
,
hidden_size
=
hidden_size
)
self
.
checkpoints
=
[]
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
attention_mask
=
None
,
use_cache
=
False
,
cache
=
None
):
self
.
checkpoints
=
[]
if
position_ids
is
None
:
past_length
=
0
if
cache
is
not
None
:
past_length
=
paddle
.
shape
(
cache
[
0
].
k
)[
-
2
]
position_ids
=
paddle
.
arange
(
past_length
,
paddle
.
shape
(
input_ids
)[
-
1
]
+
past_length
,
dtype
=
'int64'
)
position_ids
=
position_ids
.
unsqueeze
(
0
)
position_ids
=
paddle
.
fluid
.
layers
.
expand_as
(
position_ids
,
input_ids
)
embedding_output
=
self
.
embeddings
(
input_ids
=
input_ids
,
position_ids
=
position_ids
)
if
_global_parallel_strategy
==
"pp"
:
auto
.
shard_tensor
(
input_ids
,
dist_attr
=
{
"process_mesh"
:
PP_MESH_LIST
[
0
],
"dims_mapping"
:
[
-
1
for
i
in
range
(
len
(
input_ids
.
shape
))]
})
if
_global_parallel_strategy
==
"dp_pp"
:
auto
.
shard_tensor
(
input_ids
,
dist_attr
=
{
"process_mesh"
:
DPPP_MESH_LIST
[
0
],
"dims_mapping"
:
[
0
]
+
[
-
1
for
i
in
range
(
len
(
input_ids
.
shape
)
-
1
)]
})
if
_global_parallel_strategy
==
"dp_mp_pp"
:
auto
.
shard_tensor
(
input_ids
,
dist_attr
=
{
"process_mesh"
:
DPMPPP_MESH_LIST
[
0
],
"dims_mapping"
:
[
0
]
+
[
-
1
for
i
in
range
(
len
(
input_ids
.
shape
)
-
1
)]
})
attention_mask
.
stop_gradient
=
True
encoder_outputs
=
self
.
decoder
(
embedding_output
,
memory
=
None
,
tgt_mask
=
attention_mask
,
use_cache
=
use_cache
,
cache
=
cache
)
self
.
checkpoints
.
extend
(
self
.
decoder
.
checkpoints
)
return
encoder_outputs
class
GPTForPretraining
(
nn
.
Layer
):
"""
The pretraining model of GPT.
It returns some logits and cached_kvs.
"""
def
__init__
(
self
,
gpt
,
vocab_size
=
50304
,
hidden_size
=
768
,
initializer_range
=
0.02
,
):
super
(
GPTForPretraining
,
self
).
__init__
()
self
.
output_embeddings
=
nn
.
Embedding
(
vocab_size
,
hidden_size
,
weight_attr
=
paddle
.
ParamAttr
(
name
=
"output_embeddings"
,
initializer
=
nn
.
initializer
.
Normal
(
mean
=
0.0
,
std
=
initializer_range
)))
self
.
gpt
=
gpt
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
attention_mask
=
None
,
masked_positions
=
None
,
use_cache
=
False
,
cache
=
None
):
outputs
=
self
.
gpt
(
input_ids
,
position_ids
=
position_ids
,
attention_mask
=
attention_mask
,
use_cache
=
use_cache
,
cache
=
cache
)
if
use_cache
:
encoder_outputs
,
cached_kvs
=
outputs
[:
2
]
else
:
encoder_outputs
=
outputs
logits
=
paddle
.
matmul
(
encoder_outputs
,
self
.
output_embeddings
.
weight
,
transpose_y
=
True
)
if
use_cache
:
return
logits
,
cached_kvs
else
:
return
logits
class
GPTPretrainingCriterion
(
nn
.
Layer
):
"""
Criterion for GPT.
It calculates the final loss.
"""
def
__init__
(
self
):
super
(
GPTPretrainingCriterion
,
self
).
__init__
()
self
.
loss_func
=
paddle
.
nn
.
CrossEntropyLoss
(
reduction
=
"none"
)
def
forward
(
self
,
prediction_scores
,
masked_lm_labels
,
loss_mask
):
masked_lm_loss
=
self
.
loss_func
(
prediction_scores
,
masked_lm_labels
.
unsqueeze
(
2
))
loss_mask
=
loss_mask
.
reshape
([
-
1
])
masked_lm_loss
=
paddle
.
sum
(
masked_lm_loss
.
reshape
([
-
1
])
*
loss_mask
)
total_loss
=
masked_lm_loss
/
loss_mask
.
sum
()
pp_total_loss
=
None
loss
=
total_loss
if
"pp"
in
_global_parallel_strategy
:
total_loss
=
total_loss
masked_lm_loss
.
persistable
=
True
total_loss
.
persistable
=
True
total_loss
.
persistable
=
True
pp_total_loss
=
paddle
.
fluid
.
layers
.
fill_constant
([
1
,
],
"float32"
,
0.0
)
pp_total_loss
.
persistable
=
True
block
=
paddle
.
static
.
default_main_program
().
global_block
()
acc_steps
=
1
tmp
=
total_loss
/
acc_steps
block
.
append_op
(
type
=
"elementwise_add"
,
inputs
=
{
"X"
:
[
pp_total_loss
],
"Y"
:
[
tmp
]},
outputs
=
{
"Out"
:
[
pp_total_loss
]})
loss
=
pp_total_loss
return
loss
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录