Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
e00a7f38
M
models
项目概览
PaddlePaddle
/
models
1 年多 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e00a7f38
编写于
2月 03, 2021
作者:
Z
Zhong Hui
提交者:
GitHub
2月 03, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add checkpoint support for gpt2 model (#5257)
* fix checkpoints problem.
上级
1c4b18c0
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
63 addition
and
17 deletion
+63
-17
PaddleNLP/examples/language_model/gpt2/README.md
PaddleNLP/examples/language_model/gpt2/README.md
+1
-1
PaddleNLP/examples/language_model/gpt2/generate_sample.py
PaddleNLP/examples/language_model/gpt2/generate_sample.py
+0
-1
PaddleNLP/examples/language_model/gpt2/run_pretrain.py
PaddleNLP/examples/language_model/gpt2/run_pretrain.py
+30
-12
PaddleNLP/examples/language_model/gpt2/scripts/run.sh
PaddleNLP/examples/language_model/gpt2/scripts/run.sh
+3
-1
PaddleNLP/examples/language_model/gpt2/scripts/run_multi.sh
PaddleNLP/examples/language_model/gpt2/scripts/run_multi.sh
+4
-1
PaddleNLP/paddlenlp/transformers/gpt2/tokenizer.py
PaddleNLP/paddlenlp/transformers/gpt2/tokenizer.py
+25
-1
未找到文件。
PaddleNLP/examples/language_model/gpt2/README.md
浏览文件 @
e00a7f38
...
...
@@ -23,7 +23,7 @@
1.
paddle安装
本项目依赖于 PaddlePaddle 2.0
rc1
及以上版本或适当的develop版本,请参考 [安装指南](https://www.paddlepaddle.org.cn/install/quick) 进行安装
本项目依赖于 PaddlePaddle 2.0及以上版本或适当的develop版本,请参考 [安装指南](https://www.paddlepaddle.org.cn/install/quick) 进行安装
2.
下载代码
...
...
PaddleNLP/examples/language_model/gpt2/generate_sample.py
浏览文件 @
e00a7f38
...
...
@@ -20,7 +20,6 @@ import argparse
import
numpy
as
np
import
paddle
from
paddlenlp.utils.tools
import
loadz
from
paddlenlp.transformers
import
GPT2Model
,
GPT2ForPretraining
from
paddlenlp.transformers
import
GPT2ChineseTokenizer
,
GPT2Tokenizer
from
paddlenlp.utils.log
import
logger
...
...
PaddleNLP/examples/language_model/gpt2/run_pretrain.py
浏览文件 @
e00a7f38
...
...
@@ -30,15 +30,18 @@ from paddlenlp.utils.log import logger
from
data
import
GPT2Dataset
import
lr
MODEL_CLASSES
=
{
"gpt2-small-en"
:
(
GPT2ForPretraining
,
GPT2Tokenizer
),
"gpt2-medium-en"
:
(
GPT2ForPretraining
,
GPT2Tokenizer
),
"gpt2-large-en"
:
(
GPT2ForPretraining
,
GPT2Tokenizer
),
}
MODEL_CLASSES
=
{
"gpt2"
:
(
GPT2ForPretraining
,
GPT2Tokenizer
)}
def
parse_args
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--model_type"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Model type selected in the list: "
+
", "
.
join
(
MODEL_CLASSES
.
keys
()),
)
parser
.
add_argument
(
"--model_name_or_path"
,
default
=
None
,
...
...
@@ -190,15 +193,18 @@ def do_train(args):
worker_num
=
paddle
.
distributed
.
get_world_size
()
set_seed
(
args
)
worker_init
=
WorkerInitObj
(
args
.
seed
+
paddle
.
distributed
.
get_rank
())
model_class
,
tokenizer_class
=
MODEL_CLASSES
[
args
.
model_
name_or_path
]
model_class
,
tokenizer_class
=
MODEL_CLASSES
[
args
.
model_
type
]
tokenizer
=
tokenizer_class
.
from_pretrained
(
args
.
model_name_or_path
)
eod_id
=
tokenizer
.
command_name_map
[
"eod"
].
Id
pretrained_models_list
=
list
(
model_class
.
pretrained_init_configuration
.
keys
())
if
args
.
model_name_or_path
in
pretrained_models_list
:
model
=
GPT2ForPretraining
(
GPT2Model
(
**
model_class
.
pretrained_init_configuration
[
args
.
model_name_or_path
]))
# creat the critrion for the gpt model
criterion
=
GPT2PretrainingCriterion
(
)
else
:
model
=
GPT2ForPretraining
.
from_pretrained
(
args
.
model_name_or_path
)
if
args
.
decay_steps
is
None
:
args
.
decay_steps
=
args
.
max_steps
...
...
@@ -223,6 +229,13 @@ def do_train(args):
p
.
name
for
n
,
p
in
model
.
named_parameters
()
if
not
any
(
nd
in
n
for
nd
in
[
"bias"
,
"norm"
])
])
if
args
.
model_name_or_path
not
in
pretrained_models_list
:
opt_dict
=
paddle
.
load
(
os
.
path
.
join
(
args
.
model_name_or_path
,
"model_state.pdopt"
))
optimizer
.
set_state_dict
(
opt_dict
)
# creat the critrion for the gpt model
criterion
=
GPT2PretrainingCriterion
()
global_step
=
0
tic_train
=
time
.
time
()
...
...
@@ -259,7 +272,7 @@ def do_train(args):
loss
.
backward
()
optimizer
.
step
()
lr_scheduler
.
step
()
optimizer
.
clear_grad
ients
()
optimizer
.
clear_grad
()
if
global_step
%
args
.
save_steps
==
0
:
if
worker_index
==
0
:
output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
...
...
@@ -270,9 +283,14 @@ def do_train(args):
model_to_save
=
model
.
_layers
if
isinstance
(
model
,
paddle
.
DataParallel
)
else
model
model_to_save
.
save_pretrained
(
output_dir
)
tokenizer
.
save_pretrained
(
output_dir
)
paddle
.
save
(
optimizer
.
state_dict
(),
os
.
path
.
join
(
output_dir
,
"model_state.pdopt"
))
if
global_step
>=
args
.
max_steps
:
del
train_data_loader
return
del
train_data_loader
...
...
PaddleNLP/examples/language_model/gpt2/scripts/run.sh
浏览文件 @
e00a7f38
export
CUDA_VISIBLE_DEVICES
=
0
python run_pretrain.py
--model_name_or_path
gpt2-small-en
--input_dir
"./data"
\
python run_pretrain.py
--model_type
gpt2
\
--model_name_or_path
gpt2-small-en
\
--input_dir
"./data"
\
--output_dir
"output"
\
--max_lr
0.00015
\
--min_lr
0.00001
\
...
...
PaddleNLP/examples/language_model/gpt2/scripts/run_multi.sh
浏览文件 @
e00a7f38
unset
CUDA_VISIBLE_DEVICES
python
-m
paddle.distributed.launch
--gpus
"0,1"
run_pretrain.py
--model_name_or_path
gpt2-small-en
--input_dir
"./data"
\
python
-m
paddle.distributed.launch
--gpus
"0,1"
run_pretrain.py
\
--model_type
gpt2
\
--model_name_or_path
gpt2-small-en
\
--input_dir
"./data"
\
--output_dir
"output"
\
--max_lr
0.00015
\
--min_lr
0.00001
\
...
...
PaddleNLP/paddlenlp/transformers/gpt2/tokenizer.py
浏览文件 @
e00a7f38
...
...
@@ -18,6 +18,7 @@ from collections import namedtuple
import
json
import
jieba
import
shutil
from
paddle.utils
import
try_import
from
..
import
PretrainedTokenizer
...
...
@@ -111,7 +112,8 @@ class GPT2ChineseTokenizer(PretrainedTokenizer):
bod_id
=
"<bod>"
,
eod_id
=
"<eod>"
,
max_length
=
None
):
self
.
_vocab_file
=
vocab_file
self
.
_model_file
=
model_file
if
not
os
.
path
.
isfile
(
vocab_file
):
raise
ValueError
(
"Can't find a vocabulary file at path '{}'. To load the "
...
...
@@ -149,6 +151,16 @@ class GPT2ChineseTokenizer(PretrainedTokenizer):
'
\n
'
)
return
text
def
save_resources
(
self
,
save_directory
):
"""
Save tokenizer related resources to files under `save_directory`.
Args:
save_directory (str): Directory to save files into.
"""
for
name
,
file_name
in
self
.
resource_files_names
.
items
():
save_path
=
os
.
path
.
join
(
save_directory
,
file_name
)
shutil
.
copyfile
(
getattr
(
self
,
"_%s"
%
name
),
save_path
)
class
GPT2Tokenizer
(
PretrainedTokenizer
):
resource_files_names
=
{
...
...
@@ -192,6 +204,8 @@ class GPT2Tokenizer(PretrainedTokenizer):
special_tokens
=
None
,
max_len
=
None
,
do_lower_case
=
True
):
self
.
_vocab_file
=
vocab_file
self
.
_merges_file
=
merges_file
self
.
max_len
=
int
(
1e12
)
self
.
num_command_tokens
=
2
self
.
num_type_tokens
=
2
...
...
@@ -346,3 +360,13 @@ class GPT2Tokenizer(PretrainedTokenizer):
text
=
bytearray
([
self
.
byte_decoder
[
c
]
for
c
in
text
]).
decode
(
'utf-8'
,
errors
=
self
.
errors
)
return
text
def
save_resources
(
self
,
save_directory
):
"""
Save tokenizer related resources to files under `save_directory`.
Args:
save_directory (str): Directory to save files into.
"""
for
name
,
file_name
in
self
.
resource_files_names
.
items
():
save_path
=
os
.
path
.
join
(
save_directory
,
file_name
)
shutil
.
copyfile
(
getattr
(
self
,
"_%s"
%
name
),
save_path
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录