Unverified · Commit e00a7f38 · Authored by: Zhong Hui · Committed by: GitHub

Add checkpoint support for gpt2 model (#5257)

* Fix checkpoint problems: save optimizer state and tokenizer resources with each checkpoint, and support resuming training from a saved checkpoint directory.
Parent: 1c4b18c0
@@ -23,7 +23,7 @@
 1. Install paddle

-   This project depends on PaddlePaddle 2.0rc1 or later (or an appropriate develop build); please follow the [installation guide](https://www.paddlepaddle.org.cn/install/quick) to install it.
+   This project depends on PaddlePaddle 2.0 or later (or an appropriate develop build); please follow the [installation guide](https://www.paddlepaddle.org.cn/install/quick) to install it.

 2. Download the code
......
@@ -20,7 +20,6 @@ import argparse
 import numpy as np

 import paddle
-from paddlenlp.utils.tools import loadz
 from paddlenlp.transformers import GPT2Model, GPT2ForPretraining
 from paddlenlp.transformers import GPT2ChineseTokenizer, GPT2Tokenizer
 from paddlenlp.utils.log import logger
......
@@ -30,15 +30,18 @@ from paddlenlp.utils.log import logger
 from data import GPT2Dataset
 import lr

-MODEL_CLASSES = {
-    "gpt2-small-en": (GPT2ForPretraining, GPT2Tokenizer),
-    "gpt2-medium-en": (GPT2ForPretraining, GPT2Tokenizer),
-    "gpt2-large-en": (GPT2ForPretraining, GPT2Tokenizer),
-}
+MODEL_CLASSES = {"gpt2": (GPT2ForPretraining, GPT2Tokenizer)}


 def parse_args():
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model_type",
+        default=None,
+        type=str,
+        required=True,
+        help="Model type selected in the list: " +
+        ", ".join(MODEL_CLASSES.keys()), )
     parser.add_argument(
         "--model_name_or_path",
         default=None,
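This hunk decouples the model type from the model name: the three size-specific `MODEL_CLASSES` entries collapse into a single `"gpt2"` type selected by the new required `--model_type` flag, while `--model_name_or_path` is now free to carry either a built-in configuration name such as `gpt2-small-en` or, as the `do_train` changes below show, a local checkpoint directory.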
@@ -190,15 +193,18 @@ def do_train(args):
     worker_num = paddle.distributed.get_world_size()
     set_seed(args)
     worker_init = WorkerInitObj(args.seed + paddle.distributed.get_rank())
-    model_class, tokenizer_class = MODEL_CLASSES[args.model_name_or_path]
+    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
     tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
     eod_id = tokenizer.command_name_map["eod"].Id

-    model = GPT2ForPretraining(
-        GPT2Model(**model_class.pretrained_init_configuration[
-            args.model_name_or_path]))
-    # creat the critrion for the gpt model
-    criterion = GPT2PretrainingCriterion()
+    pretrained_models_list = list(
+        model_class.pretrained_init_configuration.keys())
+    if args.model_name_or_path in pretrained_models_list:
+        model = GPT2ForPretraining(
+            GPT2Model(**model_class.pretrained_init_configuration[
+                args.model_name_or_path]))
+    else:
+        model = GPT2ForPretraining.from_pretrained(args.model_name_or_path)

     if args.decay_steps is None:
         args.decay_steps = args.max_steps
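This branch is the load side of checkpoint support: a `--model_name_or_path` that matches one of the built-in configuration names initializes a fresh model from that stock configuration, while any other value is treated as a path and restored with `from_pretrained`, which is what lets a run pick up from a previously saved checkpoint directory.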
@@ -223,6 +229,13 @@ def do_train(args):
         p.name for n, p in model.named_parameters()
         if not any(nd in n for nd in ["bias", "norm"])
     ])

+    if args.model_name_or_path not in pretrained_models_list:
+        opt_dict = paddle.load(
+            os.path.join(args.model_name_or_path, "model_state.pdopt"))
+        optimizer.set_state_dict(opt_dict)
+
+    # create the criterion for the gpt model
+    criterion = GPT2PretrainingCriterion()

     global_step = 0
     tic_train = time.time()
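Putting this load side together with the save side further down, the checkpoint round trip the commit implements looks roughly like the single-card sketch below. The directory name, learning rate, and the stand-in AdamW optimizer are illustrative assumptions, not taken from the diff:

```python
import os

import paddle
from paddlenlp.transformers import GPT2ForPretraining, GPT2Model, GPT2Tokenizer

# Fresh start from a built-in configuration name, as in the diff above.
model = GPT2ForPretraining(
    GPT2Model(**GPT2ForPretraining.pretrained_init_configuration[
        "gpt2-small-en"]))
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-small-en")
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-4, parameters=model.parameters())

# Save side: what the training loop now writes every --save_steps steps.
ckpt_dir = "output/step_10000"  # hypothetical checkpoint directory
os.makedirs(ckpt_dir, exist_ok=True)
model.save_pretrained(ckpt_dir)        # model weights + config
tokenizer.save_pretrained(ckpt_dir)    # tokenizer resource files
paddle.save(optimizer.state_dict(),
            os.path.join(ckpt_dir, "model_state.pdopt"))

# Resume side: what a fresh run does when --model_name_or_path points at
# the checkpoint directory instead of a built-in name.
model = GPT2ForPretraining.from_pretrained(ckpt_dir)
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-4, parameters=model.parameters())
optimizer.set_state_dict(
    paddle.load(os.path.join(ckpt_dir, "model_state.pdopt")))
```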
@@ -259,7 +272,7 @@ def do_train(args):
             loss.backward()
             optimizer.step()
             lr_scheduler.step()
-            optimizer.clear_gradients()
+            optimizer.clear_grad()
             if global_step % args.save_steps == 0:
                 if worker_index == 0:
                     output_dir = os.path.join(args.output_dir,
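An API rename rides along here: `clear_gradients()` is the older Paddle 1.x-style method, and `clear_grad()` is its replacement on `paddle.optimizer` optimizers in Paddle 2.0.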
@@ -270,9 +283,14 @@ def do_train(args):
                     model_to_save = model._layers if isinstance(
                         model, paddle.DataParallel) else model
                     model_to_save.save_pretrained(output_dir)
+                    tokenizer.save_pretrained(output_dir)
+                    paddle.save(
+                        optimizer.state_dict(),
+                        os.path.join(output_dir, "model_state.pdopt"))
             if global_step >= args.max_steps:
                 del train_data_loader
                 return
     del train_data_loader
......
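With this save-side change, a checkpoint directory now holds everything the load path above expects: the model weights and config from `save_pretrained`, the tokenizer's resource files, and the optimizer state in `model_state.pdopt`.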
 export CUDA_VISIBLE_DEVICES=0

-python run_pretrain.py --model_name_or_path gpt2-small-en --input_dir "./data"\
+python run_pretrain.py --model_type gpt2\
+       --model_name_or_path gpt2-small-en\
+       --input_dir "./data"\
        --output_dir "output"\
        --max_lr 0.00015\
        --min_lr 0.00001\
......
 unset CUDA_VISIBLE_DEVICES

-python -m paddle.distributed.launch --gpus "0,1" run_pretrain.py --model_name_or_path gpt2-small-en --input_dir "./data"\
+python -m paddle.distributed.launch --gpus "0,1" run_pretrain.py \
+       --model_type gpt2\
+       --model_name_or_path gpt2-small-en\
+       --input_dir "./data"\
        --output_dir "output"\
        --max_lr 0.00015\
        --min_lr 0.00001\
......
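Both launch scripts now pass the new required `--model_type` flag alongside `--model_name_or_path`. To resume training, `--model_name_or_path` can presumably be pointed at a saved checkpoint directory instead of `gpt2-small-en` (for example `output/step_10000`, a hypothetical path; the diff does not show the exact directory naming).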
@@ -18,6 +18,7 @@ from collections import namedtuple
 import json

 import jieba
+import shutil
 from paddle.utils import try_import
 from .. import PretrainedTokenizer
@@ -111,7 +112,8 @@ class GPT2ChineseTokenizer(PretrainedTokenizer):
bod_id="<bod>", bod_id="<bod>",
eod_id="<eod>", eod_id="<eod>",
max_length=None): max_length=None):
self._vocab_file = vocab_file
self._model_file = model_file
if not os.path.isfile(vocab_file): if not os.path.isfile(vocab_file):
raise ValueError( raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the " "Can't find a vocabulary file at path '{}'. To load the "
@@ -149,6 +151,16 @@ class GPT2ChineseTokenizer(PretrainedTokenizer):
                 '\n')
         return text

+    def save_resources(self, save_directory):
+        """
+        Save tokenizer related resources to files under `save_directory`.
+
+        Args:
+            save_directory (str): Directory to save files into.
+        """
+        for name, file_name in self.resource_files_names.items():
+            save_path = os.path.join(save_directory, file_name)
+            shutil.copyfile(getattr(self, "_%s" % name), save_path)
+

 class GPT2Tokenizer(PretrainedTokenizer):
     resource_files_names = {
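`save_resources` copies the original vocabulary/model files remembered in `__init__` (the new `_vocab_file`/`_model_file` attributes) into `save_directory` under the standard names listed in `resource_files_names`. Presumably `PretrainedTokenizer.save_pretrained` invokes this hook, so that `from_pretrained` can later reload the tokenizer from a checkpoint directory.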
@@ -192,6 +204,8 @@ class GPT2Tokenizer(PretrainedTokenizer):
                  special_tokens=None,
                  max_len=None,
                  do_lower_case=True):
+        self._vocab_file = vocab_file
+        self._merges_file = merges_file
         self.max_len = int(1e12)
         self.num_command_tokens = 2
         self.num_type_tokens = 2
@@ -346,3 +360,13 @@ class GPT2Tokenizer(PretrainedTokenizer):
         text = bytearray([self.byte_decoder[c] for c in text]).decode(
             'utf-8', errors=self.errors)
         return text
+
+    def save_resources(self, save_directory):
+        """
+        Save tokenizer related resources to files under `save_directory`.
+
+        Args:
+            save_directory (str): Directory to save files into.
+        """
+        for name, file_name in self.resource_files_names.items():
+            save_path = os.path.join(save_directory, file_name)
+            shutil.copyfile(getattr(self, "_%s" % name), save_path)
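A minimal sketch of the tokenizer round trip this enables, assuming (as above) that `save_pretrained` delegates the file copying to `save_resources`; the directory name is illustrative:

```python
import os

from paddlenlp.transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2-small-en")

ckpt_dir = "output/step_10000"  # hypothetical checkpoint directory
os.makedirs(ckpt_dir, exist_ok=True)
tokenizer.save_pretrained(ckpt_dir)  # copies the vocab/merges files via save_resources

# A later run can load the tokenizer straight from the checkpoint directory,
# exactly as run_pretrain.py does with tokenizer_class.from_pretrained(...).
reloaded = GPT2Tokenizer.from_pretrained(ckpt_dir)
```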