未验证 提交 84f481c0 编写于 作者: Z Zhong Hui 提交者: GitHub

add gpt2 model for the paddlenlp

add gpt2 model for the paddlenlp
上级 b1fcba33
# GPT2
## 模型介绍
[GPT2](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)(Language Models are Unsupervised Multitask Learners) 以[Transformer](https://arxiv.org/abs/1706.03762) 解码器为网络基本组件,使用自回归的方式在大规模无标注文本语料上进行预训练(pre-train),得到的语言生成模型。
本项目是语言模型 GPT2 的 PaddlePaddle 实现, 包含模型训练,预测等内容。下是本例的简要目录结构及说明:
```text
.
├── data.py # 数据处理
├── decompress.sh # 数据集解压脚本
├── generate_sample.py # inference demo
├── lr.py # 学习率控制
├── process_data.py # 数据预处理脚本
├── README.md # 文档
├── run_pretrain.py # 预训练入口
└── scripts # 训练脚本
```
## 快速开始
### 安装说明
1. paddle安装
本项目依赖于 PaddlePaddle 2.0rc1及以上版本或适当的develop版本,请参考 [安装指南](https://www.paddlepaddle.org.cn/install/quick) 进行安装
2. 下载代码
克隆代码库到本地
3. 环境依赖
该模型使用PaddlePaddle,关于环境依赖部分,请先参考PaddlePaddle[安装说明](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/index_cn.html)关于环境依赖部分的内容。
### 数据准备
#### 原始数据获取
[OpenWebTextCorpus](https://skylion007.github.io/OpenWebTextCorpus/)是一个开源的英文网页文本数据集,数据来源于Reddit,经过去重、清洗、提取,最终包含800多万个文档。
下载以后通过以下命令解压:
```shell
xz -d openwebtext.tar.xz
tar xf openwebtext.tar
mkdir raw_data
bash decompress.sh
```
解压以后得到的raw_data目录大小约为54GB。
#### 数据预处理
为了提升训练速度,我们在训练前将文本数据转成相应的id,并保存为npz格式:
```shell
python process_data.py --input_path raw_data \
--model_name gpt2-medium-en \
--append_eod \
--workers 8
```
运行命令后,产出`raw_data_ids.npz`文件。为了方便用户运行测试本模型,本项目提供了处理好的300M的训练样本:
```shell
wget https://paddlenlp.bj.bcebos.com/models/transformers/gpt2/train.data.json_ids.npz
```
将所有预处理得到的npz文件统一放入一个文件夹中,以备训练使用:
```
mkdir data
mv train.data.json_ids.npz data
```
#### 单卡训练
```shell
CUDA_VISIBLE_DEVICES=0 python run_pretrain.py --model_name_or_path gpt2-small-en \
--input_dir "./data"\
--output_dir "output"\
--weight_decay 0.01\
--grad_clip 1.0\
--max_steps 500000\
--save_steps 100000\
--warmup_rate 0.01\
--batch_size 8\
--device gpu
```
其中参数释义如下:
- `model_name_or_path` 要训练的模型或者之前训练的checkpoint。
- `input_dir` 指定输入文件,可以使用目录,指定目录时将包括目录中的所有文件。
- `output_dir` 指定输出文件。
- `weight_decay` 权重衰减参数。
- `grad_clip` 梯度裁剪范围。
- `max_steps` 最大训练步数
- `save_steps` 保存模型间隔
- `batch_size` 训练的batch大小
- `device` 训练设备
用户也可以使用提供的shell脚本直接训练`sh scripts/run.sh`.
### 单机多卡
同样,可以执行如下命令实现八卡训练:
```shell
unset CUDA_VISIBLE_DEVICES
python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_pretrain.py --model_name_or_path gpt2-small-en \
--input_dir "./data"\
--output_dir "output"\
--weight_decay 0.01\
--grad_clip 1.0\
--max_steps 500000\
--save_steps 100000\
--warmup_rate 0.01\
--batch_size 8\
--device gpu
```
用户也可以使用提供的shell脚本直接训练`sh scripts/run_multi.sh`.
#### 文本生成
本项目提供了简单的文本生成的demo,供用户测试文本生成效果。
```shell
python generate_sample.py
```
生成效果展示:
```text
问题:中国的首都是哪里?答案:北京。
问题:百度的厂长是谁? 答案:
李彦宏。
默写古诗: 大漠孤烟直,长河落日圆。
举杯邀明月,
对影成三人。
```
## 参考文献
- [Language Models are Unsupervised Multitask Learners](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
- [CPM: A Large-scale Generative Chinese Pre-trained Language Model](https://arxiv.org/abs/2012.00413)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import os
import numpy as np
import paddle
def construct_samples_and_shuffle_data(name, data_prefix, documents, sizes,
num_samples, seq_length, seed,
worker_index):
# Number of tokens in each epoch and number of required epochs.
tokens_per_epoch = _num_tokens(sizes)
num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
# rng state
np_rng = np.random.RandomState(seed=seed)
# Filename of the index mappings.
_filename = data_prefix
_filename += '_{}_indexmap'.format(name)
_filename += '_{}ns'.format(num_samples)
_filename += '_{}sl'.format(seq_length)
doc_idx_filename = _filename + '_doc_idx.npy'
sample_idx_filename = _filename + '_sample_idx.npy'
shuffle_idx_filename = _filename + '_shuffle_idx.npy'
# Build the indexed mapping if not exist.
if worker_index == 0:
if (not os.path.isfile(doc_idx_filename)) or \
(not os.path.isfile(sample_idx_filename)) or \
(not os.path.isfile(shuffle_idx_filename)):
if num_epochs == 1:
separate_last_epoch = False
else:
num_samples_from_epochs_minus_one = (
(num_epochs - 1) * tokens_per_epoch - 1) // seq_length
last_epoch_num_samples = num_samples - \
num_samples_from_epochs_minus_one
assert last_epoch_num_samples >= 0, \
'last epoch number of samples should be non-negative.'
num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length
assert last_epoch_num_samples < (num_samples_per_epoch + 1), \
'last epoch number of samples exceeded max value.'
separate_last_epoch = (
last_epoch_num_samples < int(0.80 * num_samples_per_epoch))
doc_idx = _build_doc_idx(documents, num_epochs, np_rng,
separate_last_epoch)
np.save(doc_idx_filename, doc_idx, allow_pickle=True)
# sample-idx.
assert doc_idx.dtype == np.int32
sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
num_epochs, tokens_per_epoch)
np.save(sample_idx_filename, sample_idx, allow_pickle=True)
if separate_last_epoch:
num_samples_ = num_samples_from_epochs_minus_one
else:
num_samples_ = sample_idx.shape[0] - 1
shuffle_idx = _build_shuffle_idx(num_samples_,
sample_idx.shape[0] - 1, np_rng)
np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
else:
while True:
if (not os.path.isfile(doc_idx_filename)) or \
(not os.path.isfile(sample_idx_filename)) or \
(not os.path.isfile(shuffle_idx_filename)):
time.sleep(3)
else:
break
# Load mappings.
doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r')
sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r')
shuffle_idx = np.load(
shuffle_idx_filename, allow_pickle=True, mmap_mode='r')
return doc_idx, sample_idx, shuffle_idx
def _num_tokens(lens):
"""Total number of tokens in the dataset."""
return np.sum(lens)
def _num_epochs(tokens_per_epoch, seq_length, num_samples):
"""Based on number of samples and sequence lenght, calculate how many
epochs will be needed."""
num_epochs = 0
total_tokens = 0
while True:
num_epochs += 1
total_tokens += tokens_per_epoch
if ((total_tokens - 1) // seq_length) >= num_samples:
return num_epochs
def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch):
"""Build an array with length = number-of-epochs * number-of-dcuments.
Each index is mapped to a corresponding document."""
if not separate_last_epoch or num_epochs == 1:
doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1]
doc_idx[:] = documents
doc_idx = doc_idx.reshape(-1)
doc_idx = doc_idx.astype(np.int32)
# np_rng.shuffle(doc_idx)
return doc_idx
doc_idx_first = _build_doc_idx(documents, num_epochs - 1, np_rng, False)
doc_idx_last = _build_doc_idx(documents, 1, np_rng, False)
return np.concatenate((doc_idx_first, doc_idx_last))
def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch):
num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
sample_idx = np.zeros([int(num_samples) + 1, 2], dtype=np.int32)
sample_index = 0
doc_idx_index = 0
doc_offset = 0
sample_idx[sample_index][0] = doc_idx_index
sample_idx[sample_index][1] = doc_offset
sample_index += 1
while sample_index <= num_samples:
remaining_seq_length = seq_length + 1
while remaining_seq_length != 0:
doc_id = doc_idx[doc_idx_index]
doc_length = sizes[doc_id] - doc_offset
remaining_seq_length -= doc_length
if remaining_seq_length <= 0:
doc_offset += (remaining_seq_length + doc_length - 1)
remaining_seq_length = 0
else:
doc_idx_index += 1
doc_offset = 0
sample_idx[sample_index][0] = doc_idx_index
sample_idx[sample_index][1] = doc_offset
sample_index += 1
return sample_idx
def _build_shuffle_idx(num_samples, total_size, np_rng):
dtype_ = np.uint32
if total_size >= (np.iinfo(np.uint32).max - 1):
dtype_ = np.int64
shuffle_idx_first = np.arange(
start=0, stop=num_samples, step=1, dtype=dtype_)
np_rng.shuffle(shuffle_idx_first)
if num_samples == total_size:
return shuffle_idx_first
shuffle_idx_last = np.arange(
start=num_samples, stop=total_size, step=1, dtype=dtype_)
np_rng.shuffle(shuffle_idx_last)
return np.concatenate((shuffle_idx_first, shuffle_idx_last))
class GPT2Dataset(paddle.io.Dataset):
def __init__(self,
file_path,
worker_index,
num_samples,
eod_id,
name="gpt2",
max_seq_len=1024,
mode="train",
seed=1234):
self.file_path = file_path
self.max_seq_len = max_seq_len
self.name = name
process_datas = np.load(
self.file_path, mmap_mode="r+", allow_pickle=True)
self.sample_ids = process_datas["ids"]
self.sample_lens = process_datas["lens"]
document_ids = np.arange(0, self.sample_lens.shape[0])
self.eod_id = eod_id
self.doc_idx, self.sample_idx, self.shuffle_idx = \
construct_samples_and_shuffle_data(self.name, self.file_path, document_ids,\
self.sample_lens, num_samples, max_seq_len, seed, worker_index)
self.start_pos = [0] + np.cumsum(self.sample_lens).tolist()
def _construct_sample(self, tokens):
tokens = np.array(tokens).astype("int64").tolist()
labels = tokens[1:]
tokens = tokens[:-1]
seq_length = len(tokens)
# attention mask for the attention calulate
attention_mask = np.tri(seq_length, seq_length).reshape(
(1, seq_length, seq_length))
# the pad and eod tokens do not contribute the loss
loss_mask = np.ones(seq_length, dtype="float32")
loss_mask[np.where(np.array(tokens) == self.eod_id)] = 0.0
position_ids = np.arange(0, seq_length, dtype="int64")
# -INF mask value as default
attention_mask = (attention_mask - 1.0) * 1e9
# Bool mask of attention
# attention_mask = attention_mask.astype("float32")
return [tokens, loss_mask, attention_mask, position_ids, labels]
def _get_single_sample_from_idx(self, doc_index_f, doc_index_l, offset_f,
offset_l):
if doc_index_f == doc_index_l:
current_start_pos = self.start_pos[doc_index_f]
return self.sample_ids[current_start_pos+offset_f:\
current_start_pos+offset_l+1].tolist()
elif doc_index_f < doc_index_l:
current_start_pos = self.start_pos[doc_index_f]
next_start_pos = self.start_pos[doc_index_f + 1]
tokens = self.sample_ids[current_start_pos + offset_f:
next_start_pos].tolist()
for i in range(doc_index_f + 1, doc_index_l):
current_start_pos = self.start_pos[i]
next_start_pos = self.start_pos[i + 1]
tokens.extend(self.sample_ids[current_start_pos:next_start_pos]
.tolist())
last_start_pos = self.start_pos[doc_index_l]
tokens.extend(self.sample_ids[last_start_pos:last_start_pos +
offset_l + 1].tolist())
else:
current_start_pos = self.start_pos[doc_index_f]
next_start_pos = self.start_pos[-1]
tokens = self.sample_ids[current_start_pos + offset_f:
next_start_pos].tolist()
for i in range(0, doc_index_l):
current_start_pos = self.start_pos[i]
next_start_pos = self.start_pos[i + 1]
tokens.extend(self.sample_ids[current_start_pos:next_start_pos]
.tolist())
last_start_pos = self.start_pos[doc_index_l]
tokens.extend(self.sample_ids[last_start_pos:last_start_pos +
offset_l + 1].tolist())
return tokens
def __getitem__(self, index):
idx = self.shuffle_idx[index]
# Start and end documents and offsets.
doc_index_f_raw = self.sample_idx[idx][0]
doc_index_l_raw = self.sample_idx[idx + 1][0]
doc_index_f = self.doc_idx[self.sample_idx[idx][0]]
doc_index_l = self.doc_idx[self.sample_idx[idx + 1][0]]
offset_f = self.sample_idx[idx][1]
offset_l = self.sample_idx[idx + 1][1]
tokens = self._get_single_sample_from_idx(doc_index_f, doc_index_l,
offset_f, offset_l)
token_arr = np.array(tokens, dtype="int64")
return self._construct_sample(tokens)
def __len__(self):
return self.sample_idx.shape[0] - 1
#!/bin/bash
n=0
maxjobs=2 # 最大进程数
m=0
maxfiles=12800 # 每个目录中的最大文件数
for i in $(ls openwebtext); do
echo $i;
if ((n % $maxfiles == 0)); then
((m=n))
mkdir -p raw_data/data_$m
fi
if ((++n % $maxjobs == 0)) ; then
wait
fi
tar xJf openwebtext/$i --warning=no-timestamp -C raw_data/data_$m/ &
done
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Many thanks for following projects.
# https://github.com/TsinghuaAI/CPM-Generate
# https://github.com/jm12138/CPM-Generate-Paddle
import argparse
import numpy as np
import paddle
from paddlenlp.utils.tools import loadz
from paddlenlp.transformers import GPT2Model, GPT2ForPretraining
from paddlenlp.transformers import GPT2ChineseTokenizer, GPT2Tokenizer
from paddlenlp.utils.log import logger
MODEL_CLASSES = {
"gpt2-base-cn": (GPT2ForPretraining, GPT2ChineseTokenizer),
"gpt2-medium-en": (GPT2ForPretraining, GPT2Tokenizer),
}
class Demo:
def __init__(self, model_name_or_path="gpt2-base-cn"):
model_class, tokenizer_class = MODEL_CLASSES[model_name_or_path]
self.tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
logger.info('Loading the model parameters, please wait...')
self.model = model_class.from_pretrained(model_name_or_path)
self.model.eval()
logger.info('Model loaded.')
# prediction function
def predict(self, text, max_len=10):
ids = self.tokenizer.encode(text)
input_id = paddle.to_tensor(
np.array(ids).reshape(1, -1).astype('int64'))
output, cached_kvs = self.model(input_id, use_cache=True, cache=None)
nid = int(np.argmax(output[0, -1].numpy()))
ids.append(nid)
out = [nid]
for i in range(max_len):
input_id = paddle.to_tensor(
np.array([nid]).reshape(1, -1).astype('int64'))
output, cached_kvs = self.model(
input_id, use_cache=True, cache=cached_kvs)
nid = int(np.argmax(output[0, -1].numpy()))
ids.append(nid)
# if nid is '\n', the predicion is over.
if nid == 3:
break
out.append(nid)
logger.info(text)
logger.info(self.tokenizer.decode(out))
# One shot example
def ask_question(self, question, max_len=10):
self.predict("问题:中国的首都是哪里?答案:北京。\n问题:%s 答案:" % question, max_len)
# dictation poetry
def dictation_poetry(self, front, max_len=10):
self.predict('''默写古诗: 大漠孤烟直,长河落日圆。\n%s''' % front, max_len)
if __name__ == "__main__":
demo = Demo("gpt2-base-cn")
demo.ask_question("百度的厂长是谁?")
demo.dictation_poetry("举杯邀明月,")
del demo
# demo = Demo("gpt2-medium-en")
# demo.predict("Question: Where is the capital of China? Answer: Beijing. \nQuestion: Who is the CEO of Apple? Answer:", 20)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy
import warnings
from paddle import Tensor
from paddle.optimizer.lr import LRScheduler
class CosineAnnealingWithWarmupDecay(LRScheduler):
def __init__(self,
max_lr,
min_lr,
warmup_step,
decay_step,
last_epoch=-1,
verbose=False):
self.decay_step = decay_step
self.warmup_step = warmup_step
self.max_lr = max_lr
self.min_lr = min_lr
super(CosineAnnealingWithWarmupDecay, self).__init__(max_lr, last_epoch,
verbose)
def get_lr(self):
if self.warmup_step > 0 and self.last_epoch <= self.warmup_step:
return float(self.max_lr) * (self.last_epoch) / self.warmup_step
if self.last_epoch > self.decay_step:
return self.min_lr
num_step_ = self.last_epoch - self.warmup_step
decay_step_ = self.decay_step - self.warmup_step
decay_ratio = float(num_step_) / float(decay_step_)
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
return self.min_lr + coeff * (self.max_lr - self.min_lr)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import argparse
import json
import multiprocessing
import numpy as np
from paddlenlp.transformers import GPT2Tokenizer
from tqdm import tqdm
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
'--input_path', type=str, required=True, help='Path to input JSON')
parser.add_argument(
'--model_name', type=str, required=True, help='What model to use.')
parser.add_argument(
'--append_eod',
action='store_true',
help='Append an <eod> token to the end of a document.')
parser.add_argument(
'--workers',
type=int,
default=1,
help='Number of worker processes to launch')
args = parser.parse_args()
return args
class Converter(object):
def __init__(self, model_name, append_eod):
self.append_eod = append_eod
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
Converter.tokenizer = tokenizer
self.eod_id = tokenizer.command_name_map["eod"].Id
self.vocab_size = len(tokenizer)
def encode(self, text):
tokens = self.tokenizer.encode(text)
if self.append_eod:
tokens.append(self.eod_id)
return tokens, len(tokens)
def main():
args = get_args()
file_paths = []
if os.path.isfile(args.input_path):
file_paths.append(args.input_path)
else:
for root, _, fs in os.walk(args.input_path):
for f in fs:
file_paths.append(os.path.join(root, f))
all_doc_ids = []
lens = []
convert = Converter(args.model_name, args.append_eod)
pool = multiprocessing.Pool(args.workers)
if convert.vocab_size < 65500:
save_dtype = np.uint16
else:
save_dtype = np.int32
for file_path in tqdm(file_paths):
text = open(file_path, 'r', encoding='utf-8').read()
text = re.sub('[\n]+', '\n', text)
text = re.sub('[ ]+', ' ', text)
encoded_docs = pool.imap(convert.encode, [text], 25)
for tokens, sizes in encoded_docs:
all_doc_ids.extend(tokens)
lens.append(sizes)
all_doc_ids = np.array(all_doc_ids, dtype=save_dtype)
lens = np.array(lens, dtype=save_dtype)
np.savez(args.input_path + "_ids.npz", ids=all_doc_ids, lens=lens)
if __name__ == "__main__":
main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import math
import os
import random
import time
import numpy as np
import paddle
from paddle.io import DataLoader, Dataset
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.transformers import GPT2Model, GPT2ForPretraining, GPT2PretrainingCriterion
from paddlenlp.transformers import GPT2Tokenizer
from paddlenlp.utils.log import logger
from data import GPT2Dataset
import lr
MODEL_CLASSES = {
"gpt2-small-en": (GPT2ForPretraining, GPT2Tokenizer),
"gpt2-medium-en": (GPT2ForPretraining, GPT2Tokenizer),
"gpt2-large-en": (GPT2ForPretraining, GPT2Tokenizer),
}
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: "
+ ", ".join(
sum([
list(classes[-1].pretrained_init_configuration.keys())
for classes in MODEL_CLASSES.values()
], [])), )
parser.add_argument(
"--input_dir",
default=None,
type=str,
required=True,
help="The input directory where the data will be read from.", )
parser.add_argument(
"--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--batch_size",
default=8,
type=int,
help="Batch size per GPU/CPU for training.", )
parser.add_argument(
"--weight_decay",
default=0.0,
type=float,
help="Weight decay if we apply some.")
parser.add_argument(
"--grad_clip",
default=0.0,
type=float,
help="Grad clip for the parameter.")
parser.add_argument(
"--adam_epsilon",
default=1e-8,
type=float,
help="Epsilon for Adam optimizer.")
parser.add_argument(
"--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument(
"--num_train_epochs",
default=1,
type=int,
help="Total number of training epochs to perform.", )
parser.add_argument(
"--max_steps",
default=520000,
type=int,
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
)
parser.add_argument(
"--decay_steps",
default=360000,
type=int,
help="The steps use to control the learing rate. If the step > decay_steps, will use the min_lr.",
)
parser.add_argument(
"--max_lr",
default=1e-5,
type=float,
help="The initial max learning rate for Adam.")
parser.add_argument(
"--min_lr",
default=5e-5,
type=float,
help="The initial min learning rate for Adam.")
parser.add_argument(
"--warmup_rate",
default=0.01,
type=float,
help="Linear warmup over warmup_steps.")
parser.add_argument(
"--logging_steps",
type=int,
default=1,
help="Log every X updates steps.")
parser.add_argument(
"--save_steps",
type=int,
default=500,
help="Save checkpoint every X updates steps.")
parser.add_argument(
"--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument(
"--device",
type=str,
default="gpu",
help="select cpu, gpu, xpu devices.")
args = parser.parse_args()
return args
class WorkerInitObj(object):
def __init__(self, seed):
self.seed = seed
def __call__(self, id):
np.random.seed(seed=self.seed + id)
random.seed(self.seed + id)
def create_pretrained_dataset(args, input_path, worker_init, worker_index,
eod_id):
train_data = GPT2Dataset(
file_path=input_path,
worker_index=worker_index,
num_samples=args.batch_size * args.max_steps,
eod_id=eod_id,
seed=args.seed + worker_index)
train_batch_sampler = paddle.io.DistributedBatchSampler(
train_data, batch_size=args.batch_size, shuffle=True, drop_last=True)
train_data_loader = DataLoader(
dataset=train_data,
batch_sampler=train_batch_sampler,
num_workers=0,
worker_init_fn=worker_init,
collate_fn=Tuple(Stack(), Stack(), Stack(), Stack(), Stack()))
return train_data_loader
def set_seed(args):
if args.device == "cpu":
idx = 0
else:
idx = paddle.distributed.get_rank()
random.seed(args.seed + idx)
np.random.seed(args.seed + idx)
paddle.seed(args.seed + idx)
def do_train(args):
assert args.device in [
"cpu", "gpu", "xpu"
], "Invalid device! Available device should be cpu, gpu, or xpu."
paddle.set_device(args.device)
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()
worker_index = paddle.distributed.get_rank()
worker_num = paddle.distributed.get_world_size()
set_seed(args)
worker_init = WorkerInitObj(args.seed + paddle.distributed.get_rank())
model_class, tokenizer_class = MODEL_CLASSES[args.model_name_or_path]
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
eod_id = tokenizer.command_name_map["eod"].Id
model = GPT2ForPretraining(
GPT2Model(**model_class.pretrained_init_configuration[
args.model_name_or_path]))
# creat the critrion for the gpt model
criterion = GPT2PretrainingCriterion()
if args.decay_steps is None:
args.decay_steps = args.max_steps
warmup_step = args.warmup_rate * args.decay_steps
lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
max_lr=args.max_lr,
min_lr=args.min_lr,
warmup_step=warmup_step,
decay_step=args.decay_steps)
clip = None
if args.grad_clip > 0:
clip = paddle.nn.ClipGradByNorm(clip_norm=args.grad_clip)
optimizer = paddle.optimizer.AdamW(
learning_rate=lr_scheduler,
epsilon=args.adam_epsilon,
parameters=model.parameters(),
weight_decay=args.weight_decay,
grad_clip=clip,
apply_decay_param_fun=lambda x: x in [
p.name for n, p in model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
])
global_step = 0
tic_train = time.time()
for epoch in range(args.num_train_epochs):
files = [
os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
if (os.path.isfile(os.path.join(args.input_dir, f)) and "npz_"
not in str(f))
]
files.sort()
num_files = len(files)
for f_id in range(num_files):
data_file = files[f_id]
train_data_loader = create_pretrained_dataset(
args, data_file, worker_init, worker_index, eod_id=eod_id)
for step, batch in enumerate(train_data_loader):
global_step += 1
tokens, loss_mask, attention_mask, position_ids, labels = batch
loss_mask.stop_gradient = True
attention_mask.stop_gradient = True
preds = model(tokens, position_ids, attention_mask)
loss = criterion(preds, labels, loss_mask)
if global_step % args.logging_steps == 0:
if worker_index == 0:
logger.info(
"global step %d, epoch: %d, lr: %.10f, batch: %d, loss: %f, speed: %.2f step/s"
% (global_step, epoch, optimizer.get_lr(), step,
loss,
args.logging_steps / (time.time() - tic_train)))
tic_train = time.time()
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.clear_gradients()
if global_step % args.save_steps == 0:
if worker_index == 0:
output_dir = os.path.join(args.output_dir,
"model_%d" % global_step)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# need better way to get inner model of DataParallel
model_to_save = model._layers if isinstance(
model, paddle.DataParallel) else model
model_to_save.save_pretrained(output_dir)
if global_step >= args.max_steps:
del train_data_loader
return
del train_data_loader
if __name__ == "__main__":
args = parse_args()
do_train(args)
export CUDA_VISIBLE_DEVICES=0
python run_pretrain.py --model_name_or_path gpt2-small-en --input_dir "./data"\
--output_dir "output"\
--max_lr 0.00015\
--min_lr 0.00001\
--weight_decay 0.01\
--grad_clip 1.0\
--max_steps 500000\
--save_steps 100000\
--decay_steps 320000\
--warmup_rate 0.01\
--batch_size 8\
--device gpu
unset CUDA_VISIBLE_DEVICES
python -m paddle.distributed.launch --gpus "0,1" run_pretrain.py --model_name_or_path gpt2-small-en --input_dir "./data"\
--output_dir "output"\
--max_lr 0.00015\
--min_lr 0.00001\
--weight_decay 0.01\
--grad_clip 1.0\
--max_steps 500000\
--save_steps 100000\
--decay_steps 320000\
--warmup_rate 0.01\
--batch_size 8\
--device gpu
......@@ -19,6 +19,8 @@ from .bert.modeling import *
from .bert.tokenizer import *
from .ernie.modeling import *
from .ernie.tokenizer import *
from .gpt2.modeling import *
from .gpt2.tokenizer import *
from .roberta.modeling import *
from .roberta.tokenizer import *
from .electra.modeling import *
......
from .modeling import *
from .tokenizer import GPT2ChineseTokenizer
此差异已折叠。
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import regex as re
import unicodedata
import json
import sentencepiece
import jieba
from functools import lru_cache
from collections import namedtuple
from .. import PretrainedTokenizer
from ..tokenizer_utils import convert_to_unicode, whitespace_tokenize,\
_is_whitespace, _is_control, _is_punctuation
__all__ = [
'GPT2Tokenizer',
'GPT2ChineseTokenizer',
]
COMMAND_TUPLE = namedtuple('CommandToken', ('name', 'token', 'Id'))
TYPE_TUPLE = namedtuple('TypeToken', ('name', 'token', 'Id'))
class CommandToken(object):
def __init__(self, name, token, Id):
self.name = name
self.token = token
self.Id = Id
def __str__(self):
return str(COMMAND_TUPLE(self.name, self.token, self.Id))
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
_chr = chr
bs = list(range(ord("!"), ord("~") + 1)) + list(
range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [_chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class GPT2ChineseTokenizer(PretrainedTokenizer):
"""
Constructs a GPT2 Chinese tokenizer. It uses a basic tokenizer to do punctuation
splitting, lower casing and so on, and follows a WordPiece tokenizer to
tokenize as subwords.
"""
resource_files_names = {
"vocab_file": "vocab.json",
"model_file": "sentencepiece.model"
} # for save_pretrained
pretrained_resource_files_map = {
"vocab_file": {
"gpt2-base-cn":
"https://paddlenlp.bj.bcebos.com/models/transformers/gpt2/gpt2-base-cn-vocab.json",
},
"model_file": {
"gpt2-base-cn":
"https://paddlenlp.bj.bcebos.com/models/transformers/gpt2/gpt2-base-cn-sentencepiece.model"
}
}
pretrained_init_configuration = {"gpt2-base-cn": {"do_lower_case": True}, }
def __init__(self,
vocab_file,
model_file,
do_lower_case=True,
max_len=512,
bod_id="<bod>",
eod_id="<eod>",
max_length=None):
if not os.path.isfile(vocab_file):
raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the "
"vocabulary from a pretrained model please use "
"`tokenizer = GPT2Tokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
.format(vocab_file))
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file))
self.decoder = {v: k for k, v in self.encoder.items()}
self.sp = sentencepiece.SentencePieceProcessor(model_file=model_file)
self.translator = str.maketrans(" \n", "\u2582\u2583")
def tokenize(self, text):
""" Tokenize a string. """
seg_list = [
x.translate(self.translator) for x in jieba.cut(text, cut_all=False)
]
new_seg = " ".join(seg_list)
return self.sp.encode(new_seg)
def encode(self, text):
return self.convert_tokens_to_ids(text)
def decode(self, tokens):
return self.convert_ids_to_tokens(tokens)
def convert_tokens_to_ids(self, text):
res = self.tokenize(text)
return res
def convert_ids_to_tokens(self, tokens):
text = self.sp.decode(tokens)
text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583',
'\n')
return text
class GPT2Tokenizer(PretrainedTokenizer):
resource_files_names = {
"vocab_file": "vocab.json",
"merges_file": "merges.txt"
} # for save_pretrained
pretrained_resource_files_map = {
"vocab_file": {
"gpt2-large-en":
"http://paddlenlp.bj.bcebos.com/models/transformers/gpt2/gpt2-large-en-vocab.json",
"gpt2-medium-en":
"http://paddlenlp.bj.bcebos.com/models/transformers/gpt2/gpt2-medium-en-vocab.json",
"gpt2-small-en":
"http://paddlenlp.bj.bcebos.com/models/transformers/gpt2/gpt2-small-en-vocab.json",
},
"merges_file": {
"gpt2-large-en":
"http://paddlenlp.bj.bcebos.com/models/transformers/gpt2/gpt2-large-en-merges.txt",
"gpt2-medium-en":
"http://paddlenlp.bj.bcebos.com/models/transformers/gpt2/gpt2-medium-en-merges.txt",
"gpt2-small-en":
"http://paddlenlp.bj.bcebos.com/models/transformers/gpt2/gpt2-small-en-merges.txt",
}
}
pretrained_init_configuration = {
"gpt2-large-en": {
"do_lower_case": True
},
"gpt2-medium-en": {
"do_lower_case": True
},
"gpt2-small-en": {
"do_lower_case": True
},
}
def __init__(self,
vocab_file,
merges_file,
errors='replace',
special_tokens=None,
max_len=None,
do_lower_case=True):
self.max_len = int(1e12)
self.num_command_tokens = 2
self.num_type_tokens = 2
self.encoder = json.load(open(vocab_file))
self.decoder = {v: k for k, v in self.encoder.items()}
# construct the command tokens
self._command_tokens = [
CommandToken('pad', '<|endoftext|>', self.encoder['<|endoftext|>']),
CommandToken('eod', '<|endoftext|>', self.encoder['<|endoftext|>']),
]
self.command_name_map = {tok.name: tok for tok in self._command_tokens}
self.command_token_map = {
tok.token: tok
for tok in self._command_tokens
}
self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
self.num_tokens = len(self.encoder)
self.num_text_tokens = self.num_tokens - 1
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
self.pat = re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i)
for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {
v: k
for k, v in self.special_tokens.items()
}
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i +
1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def tokenize(self, text):
""" Tokenize a string. """
bpe_tokens = []
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(
bpe_token for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def convert_tokens_to_ids(self, tokens):
""" Converts a sequence of tokens into ids using the vocab. """
ids = []
if isinstance(tokens, str):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.encoder.get(tokens, 0)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.encoder.get(token, 0))
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".
format(len(ids), self.max_len))
return ids
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
tokens = []
for i in ids:
if i in self.special_tokens_decoder:
if not skip_special_tokens:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.decoder[i])
return tokens
def encode(self, text, fn=None):
processed_text = text
if fn is not None:
processed_text = fn(text)
ids = self.convert_tokens_to_ids(self.tokenize(processed_text))
return ids
def decode(self, tokens):
# TODO
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode(
'utf-8', errors=self.errors)
return text
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册