PaddlePaddle / models

Commit d7009805 (unverified)
Authored by liu zhengxi on Dec 18, 2020; committed via GitHub on Dec 18, 2020

Add transformer-xl for language model (#4987)
* add transformer-xl for language model

Parent commit: 592b4cb2

10 changed files, with 2,321 additions and 0 deletions (+2321, -0):
- PaddleNLP/examples/language_model/transformer-xl/README.md (+89, -0)
- PaddleNLP/examples/language_model/transformer-xl/configs/enwik8.yaml (+112, -0)
- PaddleNLP/examples/language_model/transformer-xl/configs/text8.yaml (+112, -0)
- PaddleNLP/examples/language_model/transformer-xl/configs/wt103.yaml (+112, -0)
- PaddleNLP/examples/language_model/transformer-xl/eval.py (+134, -0)
- PaddleNLP/examples/language_model/transformer-xl/gen_data.sh (+55, -0)
- PaddleNLP/examples/language_model/transformer-xl/mem_transformer.py (+1181, -0)
- PaddleNLP/examples/language_model/transformer-xl/reader.py (+197, -0)
- PaddleNLP/examples/language_model/transformer-xl/train.py (+308, -0)
- PaddleNLP/examples/language_model/transformer-xl/utils/preprocess_text8.py (+21, -0)

PaddleNLP/examples/language_model/transformer-xl/README.md (new file, mode 100644)

# Language Model
## Transformer-XL

The brief directory structure of this example is as follows:

```text
.
├── eval.py           # prediction script
├── reader.py         # data reading interface
├── README.md         # documentation
├── train.py          # training script
└── configs           # config files
```

## Model Introduction

This project is a PaddlePaddle implementation of the Transformer-XL language model, covering model training, prediction, and so on.

## Quick Start

### Installation

1. Install PaddlePaddle

   This project depends on PaddlePaddle 2.0rc or later (or an appropriate develop version). Please follow the [installation guide](https://www.paddlepaddle.org.cn/install/quick) to install it.

2. Download the code

   Clone this repository to your local machine.

3. Environment dependencies

   The model uses PaddlePaddle. For environment dependencies, please first refer to the relevant part of the PaddlePaddle [installation instructions](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/index_cn.html).

   In addition, the following packages are required:

   * attrdict
   * pyyaml

### Data Preparation

The public datasets enwik8, text8, and wt103 are commonly used as language-model benchmarks. They can be downloaded and processed as follows:

```shell
bash gen_data.sh
```

This generates the required data under ./gen_data/ in the current directory.

### Single-Machine Training

#### Single GPU

Taking the provided enwik8 data as an example, the model can be trained with the following commands:

```sh
# setting visible devices for training
export CUDA_VISIBLE_DEVICES=0
python train.py --config ./configs/enwik8.yaml
```

Parameters such as `batch_size` and `epoch` can be set in the enwik8.yaml file.

#### Multiple GPUs

Similarly, the following commands launch training on eight GPUs:

```sh
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" train.py --config ./configs/enwik8.yaml
```

### Inference

Taking the enwik8 data as an example, after training completes, the model can be evaluated with the following commands:

```sh
# setting visible devices for prediction
export CUDA_VISIBLE_DEVICES=0
python eval.py --config ./configs/enwik8.yaml
```

After inference completes, the results on the validation set and the test set are printed.

## References

PaddleNLP/examples/language_model/transformer-xl/configs/enwik8.yaml (new file, mode 100644)

```yaml
# The frequency to save trained models when training.
save_step: 10000
# The frequency to fetch and print output when training.
print_step: 100
# Path of the checkpoint, to resume the previous training
init_from_checkpoint: ""
# Path of the pretrain model, to better solve the current task
init_from_pretrain_model: ""
# Path of trained parameter, to make prediction
init_from_params: "./trained_models/step_final/"
# The directory for saving model
save_model: "trained_models"
# The directory for saving inference model.
inference_model_dir: "infer_model"
# Set seed for CE or debug
random_seed: None
# The path to data files
data: "./gen_data/enwik8/"
# The name of dataset
dataset: "enwik8"
# Whether to use cuda
use_gpu: True

# Args for reader, see reader.py for details
token_delimiter: None
batch_size: 16
eval_batch_size: 2

# Hyperparams for training:
# The number of epochs for training
epoch: 30

# The hyper parameters for optimizer.
# Type of optimizer.
optim: adam
# Learning rate schedule.
scheduler: cosine
# This static learning_rate will be applied to the LearningRateScheduler
# derived learning rate to get the final learning rate.
learning_rate: 0.00025
# The hyper parameters for Adam optimizer.
beta1: 0.9
beta2: 0.997
eps: 1e-9
# The hyper parameters for Momentum optimizer.
mom: 0.0
# Global gradient clip.
clip: 0.25
# The parameters for learning rate scheduling.
warmup_steps: 0
# The parameters for CosineAnnealingDecay. Minimum learning rate.
eta_min: 0.0
# The parameters for ReduceLROnPlateau.
# The ratio by which the learning rate will be reduced.
decay_rate: 0.5
# When loss doesn't improve for this number of epochs, learning rate will be reduced.
patience: 0
# The lower bound of the learning rate after reduction.
min_lr: 0.0

# Hyperparams for model:
# Whether to use adaptive softmax.
adaptive: False
# Size of dictionary. This can be obtained automatically.
ntokens: 10000
# The dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model: 512
# Dimension of heads.
d_head: 64
# Size of the hidden layer in position-wise feed-forward networks.
d_inner_hid: 2048
# Number of heads used in multi-head attention.
n_head: 8
# Number of sub-layers to be stacked in the encoder and decoder.
n_layer: 12
# Dropout rates.
dropout: 0.1
# Attention dropout
attn_dropout: 0.0
# Attention type for decoder.
# 0 for relative partial MHA (in Transformer-XL).
# 1 for relative MHA (in Shaw et al).
attn_type: 0
# Apply layer normalization before or after sublayers.
normalize_before: False
# Whether to tie weight or not.
tie_weight: True
# The length of the extended context.
ext_len: 0
# The divisor value for softmax and adaptive input.
div_val: 1
# Target length. The number of tokens to predict.
tgt_len: 512
# Memory length. The length of the retained previous heads.
mem_len: 512
# Use the same attention length for all tokens.
same_length: False
# Use the same positional encoding after clamp len.
clamp_len: -1
# The number of samples in sample softmax. -1 means do not use sampled softmax.
sample_softmax: -1
# Max step for training.
max_step: 400000
# Target length for evaluation. That is, the number of tokens to predict for evaluation.
eval_tgt_len: 128
# What kind of mode for evaluation. valid, test or both("all").
mode: "all"
# Maximum evaluation step.
max_eval_steps: -1
```
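
For context, train.py and eval.py in this commit load these YAML files with `yaml.safe_load` and wrap the result in an `AttrDict`, so every key above is read as an attribute such as `args.batch_size`. A minimal sketch of that loading step (the config path is just an example):

```python
import yaml
from attrdict import AttrDict

# Any of the three configs added in this commit can be loaded the same way.
with open("./configs/enwik8.yaml", "rt") as f:
    args = AttrDict(yaml.safe_load(f))

# Keys become attributes. Note that values such as `None` and `1e-9` are read
# as strings by PyYAML; train.py converts them later with eval(str(...)).
print(args.batch_size, args.d_model, args.random_seed)
```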

PaddleNLP/examples/language_model/transformer-xl/configs/text8.yaml (new file, mode 100644)

```yaml
# The frequency to save trained models when training.
save_step: 10000
# The frequency to fetch and print output when training.
print_step: 100
# Path of the checkpoint, to resume the previous training
init_from_checkpoint: ""
# Path of the pretrain model, to better solve the current task
init_from_pretrain_model: ""
# Path of trained parameter, to make prediction
init_from_params: "./trained_models/step_final/"
# The directory for saving model
save_model: "trained_models"
# The directory for saving inference model.
inference_model_dir: "infer_model"
# Set seed for CE or debug
random_seed: None
# The path to data files
data: "./gen_data/text8/"
# The name of dataset
dataset: "text8"
# Whether to use cuda
use_gpu: True

# Args for reader, see reader.py for details
token_delimiter: None
batch_size: 15
eval_batch_size: 5

# Hyperparams for training:
# The number of epochs for training
epoch: 30

# The hyper parameters for optimizer.
# Type of optimizer.
optim: adam
# Learning rate schedule.
scheduler: cosine
# This static learning_rate will be applied to the LearningRateScheduler
# derived learning rate to get the final learning rate.
learning_rate: 0.00025
# The hyper parameters for Adam optimizer.
beta1: 0.9
beta2: 0.997
eps: 1e-9
# The hyper parameters for Momentum optimizer.
mom: 0.0
# Global gradient clip.
clip: 0.25
# The parameters for learning rate scheduling.
warmup_steps: 0
# The parameters for CosineAnnealingDecay. Minimum learning rate.
eta_min: 0.0
# The parameters for ReduceLROnPlateau.
# The ratio by which the learning rate will be reduced.
decay_rate: 0.5
# When loss doesn't improve for this number of epochs, learning rate will be reduced.
patience: 0
# The lower bound of the learning rate after reduction.
min_lr: 0.0

# Hyperparams for model:
# Whether to use adaptive softmax.
adaptive: False
# Size of dictionary. This can be obtained automatically.
ntokens: 10000
# The dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model: 512
# Dimension of heads.
d_head: 64
# Size of the hidden layer in position-wise feed-forward networks.
d_inner_hid: 2048
# Number of heads used in multi-head attention.
n_head: 8
# Number of sub-layers to be stacked in the encoder and decoder.
n_layer: 12
# Dropout rates.
dropout: 0.1
# Attention dropout
attn_dropout: 0.0
# Attention type for decoder.
# 0 for relative partial MHA (in Transformer-XL).
# 1 for relative MHA (in Shaw et al).
attn_type: 0
# Apply layer normalization before or after sublayers.
normalize_before: False
# Whether to tie weight or not.
tie_weight: True
# The length of the extended context.
ext_len: 0
# The divisor value for softmax and adaptive input.
div_val: 1
# Target length. The number of tokens to predict.
tgt_len: 512
# Memory length. The length of the retained previous heads.
mem_len: 512
# Use the same attention length for all tokens.
same_length: False
# Use the same positional encoding after clamp len.
clamp_len: -1
# The number of samples in sample softmax. -1 means do not use sampled softmax.
sample_softmax: -1
# Max step for training.
max_step: 400000
# Target length for evaluation. That is, the number of tokens to predict for evaluation.
eval_tgt_len: 128
# What kind of mode for evaluation. valid, test or both("all").
mode: "all"
# Maximum evaluation step.
max_eval_steps: -1
```

PaddleNLP/examples/language_model/transformer-xl/configs/wt103.yaml (new file, mode 100644)

```yaml
# The frequency to save trained models when training.
save_step: 10000
# The frequency to fetch and print output when training.
print_step: 100
# Path of the checkpoint, to resume the previous training
init_from_checkpoint: ""
# Path of the pretrain model, to better solve the current task
init_from_pretrain_model: ""
# Path of trained parameter, to make prediction
init_from_params: "./trained_models/step_final/"
# The directory for saving model
save_model: "trained_models"
# The directory for saving inference model.
inference_model_dir: "infer_model"
# Set seed for CE or debug
random_seed: None
# The path to data files
data: "./gen_data/wikitext-103/"
# The name of dataset
dataset: "wt103"
# Whether to use cuda
use_gpu: True

# Args for reader, see reader.py for details
token_delimiter: None
batch_size: 32
eval_batch_size: 5

# Hyperparams for training:
# The number of epochs for training
epoch: 30

# The hyper parameters for optimizer.
# Type of optimizer.
optim: adam
# Learning rate schedule.
scheduler: cosine
# This static learning_rate will be applied to the LearningRateScheduler
# derived learning rate to get the final learning rate.
learning_rate: 0.00025
# The hyper parameters for Adam optimizer.
beta1: 0.9
beta2: 0.997
eps: 1e-9
# The hyper parameters for Momentum optimizer.
mom: 0.0
# Global gradient clip.
clip: 0.25
# The parameters for learning rate scheduling.
warmup_steps: 0
# The parameters for CosineAnnealingDecay. Minimum learning rate.
eta_min: 0.0
# The parameters for ReduceLROnPlateau.
# The ratio by which the learning rate will be reduced.
decay_rate: 0.5
# When loss doesn't improve for this number of epochs, learning rate will be reduced.
patience: 0
# The lower bound of the learning rate after reduction.
min_lr: 0.0

# Hyperparams for model:
# Whether to use adaptive softmax.
adaptive: True
# Size of dictionary. This can be obtained automatically.
ntokens: 10000
# The dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model: 410
# Dimension of heads.
d_head: 41
# Size of the hidden layer in position-wise feed-forward networks.
d_inner_hid: 2100
# Number of heads used in multi-head attention.
n_head: 10
# Number of sub-layers to be stacked in the encoder and decoder.
n_layer: 16
# Dropout rates.
dropout: 0.1
# Attention dropout
attn_dropout: 0.0
# Attention type for decoder.
# 0 for relative partial MHA (in Transformer-XL).
# 1 for relative MHA (in Shaw et al).
attn_type: 0
# Apply layer normalization before or after sublayers.
normalize_before: False
# Whether to tie weight or not.
tie_weight: True
# The length of the extended context.
ext_len: 0
# The divisor value for softmax and adaptive input.
div_val: 1
# Target length. The number of tokens to predict.
tgt_len: 150
# Memory length. The length of the retained previous heads.
mem_len: 150
# Target length for evaluation. That is, the number of tokens to predict for evaluation.
eval_tgt_len: 150
# Use the same attention length for all tokens.
same_length: False
# Use the same positional encoding after clamp len.
clamp_len: -1
# The number of samples in sample softmax. -1 means do not use sampled softmax.
sample_softmax: -1
# Max step for training.
max_step: 200000
# What kind of mode for evaluation. valid, test or both("all").
mode: "all"
# Maximum evaluation step.
max_eval_steps: -1
```

PaddleNLP/examples/language_model/transformer-xl/eval.py (new file, mode 100644)

```python
import os
import time
import yaml
import logging
import argparse
import numpy as np
from pprint import pprint
from attrdict import AttrDict

import paddle

from reader import get_lm_vocab, get_lm_data_loader
from mem_transformer import MemTransformerLM

FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="./configs/enwik8.yaml",
        type=str,
        help="Path of the config file. ")
    args = parser.parse_args()
    return args


def do_eval(args):
    assert args.ext_len >= 0, 'Extended context length must be no less than 0'

    def _evaluate(loader):
        total_len, total_loss = 0, 0.

        eval_mems = tuple()
        for i, (src, target, seq_len) in enumerate(loader):
            if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                break
            ret = mem_transformer(src, target, *eval_mems)
            loss, eval_mems = ret[0], ret[1:]
            seq_len = seq_len.numpy()
            eval_cur_loss = seq_len * loss.numpy()
            total_loss += eval_cur_loss
            total_len += seq_len
        return total_loss / total_len

    def _logger(loss):
        if args.dataset in ['enwik8', 'text8']:
            logger_info = "loss: %f, bpc: %f" % \
                          (loss, loss / np.log(2))
        else:
            logger_info = "loss: %f, ppl: %.2f" % \
                          (loss, np.exp(loss))
        return logger_info

    vocab = get_lm_vocab(args)
    eval_loader = get_lm_data_loader(args, vocab, "valid")
    test_loader = get_lm_data_loader(args, vocab, "test")

    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']

        if args.dataset == 'wt103':
            cutoffs = [20000, 40000, 200000]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [60000, 100000, 640000]
            tie_projs += [False] * len(cutoffs)

    mem_transformer = MemTransformerLM(
        args.ntokens,
        args.n_layer,
        args.n_head,
        args.d_model,
        args.d_head,
        args.d_inner_hid,
        args.dropout,
        args.attn_dropout,
        tie_weight=args.tie_weight,
        d_embed=args.d_model,
        div_val=args.div_val,
        tie_projs=tie_projs,
        normalize_before=args.normalize_before,
        tgt_len=args.tgt_len,
        ext_len=args.ext_len,
        mem_len=args.mem_len,
        cutoffs=cutoffs,
        same_length=args.same_length,
        attn_type=args.attn_type,
        clamp_len=args.clamp_len,
        sample_softmax=args.sample_softmax)

    assert args.init_from_params, (
        "Please set init_from_params to load the infer model.")

    model_dict = paddle.load(
        os.path.join(args.init_from_params, "mem_transformer.pdparams"))
    mem_transformer.load_dict(model_dict)

    logger.info(
        "Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}".
        format(args.batch_size, args.tgt_len, args.ext_len, args.mem_len,
               args.clamp_len))

    mem_transformer.reset_length(args.tgt_len, args.ext_len, args.mem_len)

    test_loss = None
    valid_loss = None
    if args.mode == 'all':
        test_loss = _evaluate(test_loader)
        valid_loss = _evaluate(eval_loader)
    elif args.mode == 'valid':
        valid_loss = _evaluate(eval_loader)
    elif args.mode == 'test':
        test_loss = _evaluate(test_loader)

    logger_info = ''
    if valid_loss is not None:
        logger_info = logger_info + _logger(valid_loss)
    if test_loss is not None:
        logger_info = logger_info + _logger(test_loss)
    logger.info(logger_info)


if __name__ == "__main__":
    ARGS = parse_args()
    yaml_file = ARGS.config
    with open(yaml_file, 'rt') as f:
        args = AttrDict(yaml.safe_load(f))
        pprint(args)

    do_eval(args)
```
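
A note on the metrics: `_logger` above converts the averaged cross-entropy loss (in nats) into the number conventionally reported for each dataset, bits per character for the character-level corpora (enwik8, text8) and perplexity otherwise. A tiny standalone illustration of the same arithmetic, using a made-up loss value:

```python
import numpy as np

loss = 0.85  # hypothetical average cross-entropy in nats, not a measured result

bpc = loss / np.log(2)  # bits per character, reported for enwik8/text8
ppl = np.exp(loss)      # perplexity, reported for word-level data such as wt103

print("loss: %f, bpc: %f, ppl: %.2f" % (loss, bpc, ppl))
```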

PaddleNLP/examples/language_model/transformer-xl/gen_data.sh (new file, mode 100644)

```sh
echo "Downloading dataset..."

CUR_DIR=$PWD

mkdir -p gen_data
cd ./gen_data/

if [ ! -d "wikitext-103" ]; then
    echo "Downloading wikitext-103..."
    wget -O wikitext-103-v1.zip https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip
    echo "Unzip wikitext-103..."
    unzip wikitext-103-v1.zip
    cd wikitext-103
    # Rename
    mv wiki.train.tokens train.txt
    mv wiki.valid.tokens valid.txt
    mv wiki.test.tokens test.txt
    cd -
fi

if [ ! -d 'enwik8' ]; then
    mkdir -p enwik8
    cd enwik8
    echo "Downloading enwik8..."
    wget -O enwik8.zip http://mattmahoney.net/dc/enwik8.zip
    wget -O prep_enwik8.py https://raw.githubusercontent.com/salesforce/awd-lstm-lm/master/data/enwik8/prep_enwik8.py
    python3 prep_enwik8.py
    rm -f prep_enwik8.py
    cd -
fi

if [ ! -d 'text8' ]; then
    mkdir -p text8
    cd text8
    echo "Downloading text8..."
    wget -O text8.zip http://mattmahoney.net/dc/text8.zip
    python ${CUR_DIR}/utils/preprocess_text8.py 5000000
    cd -
fi

if [ ! -d 'one-billion-words' ]; then
    mkdir -p one-billion-words
    cd one-billion-words
    echo "Downloading one-billion-words..."
    wget -O 1-billion-word-language-modeling-benchmark-r13output.tar.gz http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz
    tar xzf 1-billion-word-language-modeling-benchmark-r13output.tar.gz

    dir="./1-billion-word-language-modeling-benchmark-r13output/heldout-monolingual.tokenized.shuffled/"
    cat ${dir}/news.en.heldout-00000-of-00050 > valid.txt
    cat ${dir}/news.en.heldout-00000-of-00050 > test.txt

    wget -O 1b_word_vocab.txt https://github.com/rafaljozefowicz/lm/raw/master/1b_word_vocab.txt
    cd -
fi

echo "All done. "
```

PaddleNLP/examples/language_model/transformer-xl/mem_transformer.py (new file, mode 100644)

(This diff is collapsed in the page capture; the file's 1,181 added lines are not shown here.)

PaddleNLP/examples/language_model/transformer-xl/reader.py (new file, mode 100644)

```python
import os

import numpy as np

from paddlenlp.data import Vocab

import paddle
from paddle.io import IterableDataset, DataLoader
import paddle.distributed as dist


class LMDataset(IterableDataset):
    def __init__(self, mode, vocab, path, dataset_name, batch_size, bptt,
                 ext_len, nranks, rank):
        assert (mode in ["train", "valid", "test"]
                ), "Parameter mode must be one of [train, valid, test]."

        super(LMDataset, self).__init__()
        self.vocab = vocab
        self.dataset_name = dataset_name

        if self.dataset_name in ["wt103"]:
            self.data = self.read_raw_data(
                filename=os.path.join(path, mode + ".txt"), ordered=True)
        elif self.dataset_name in ["enwik8", "text8"]:
            self.data = self.read_raw_data(
                filename=os.path.join(path, mode + ".txt"),
                ordered=True,
                add_eos=False)
        else:
            raise ValueError("Not supported dataset yet. ")

        self.rank = rank
        self.batch_size = batch_size
        batch_size *= nranks

        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.num_step = len(self.data) // batch_size
        data = self.data[:self.num_step * batch_size]
        self.data = data.reshape([batch_size, -1])

        # Number of samples
        self.num_samples = (self.num_step + self.bptt - 1) // self.bptt

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        for i in range(0, self.data.shape[1] - 1, self.bptt):
            seq_len = min(self.bptt, self.data.shape[1] - 1 - i)
            end_idx = i + seq_len
            beg_idx = max(0, i - self.ext_len)
            src = self.data[:, beg_idx:end_idx]
            target = self.data[:, i + 1:i + 1 + seq_len]

            # NOTE: `seq_len` will be transfered to numpy immediately
            # after returned by DataLoader. Hence, `seq_len` can be
            # yield as `int`. And the returned tensor `seq_len`'s shape
            # will be empty [].
            # However, if it's necessary to use `seq_len` as input for some
            # PaddlePaddle op, then it must be returned by `[seq_len]` whose
            # shape is [1], cause some op cannot use shape [] as input.
            yield [
                src[self.rank * self.batch_size:(self.rank + 1) *
                    self.batch_size],
                target[self.rank * self.batch_size:(self.rank + 1) *
                       self.batch_size], seq_len
            ]

    def read_raw_data(self,
                      filename,
                      ordered=False,
                      lower_case=True,
                      delimiter=None,
                      add_eos=True,
                      add_double_eos=False):
        assert os.path.exists(filename), "%s is not exist. " % filename

        data = []
        with open(filename, 'r', encoding='utf-8') as f:
            for line in f:
                tokens = LMDataset.tokenize(
                    line=line, delimiter=delimiter, lower_case=lower_case)
                if add_double_eos:  # for lm1b
                    tokens = [
                        self.vocab._identifiers_to_tokens['bos_token']
                    ] + tokens + [
                        self.vocab._identifiers_to_tokens['bos_token']
                    ]
                elif add_eos:
                    tokens = tokens + [
                        self.vocab._identifiers_to_tokens['eos_token']
                    ]
                data.append(
                    np.asarray(self.get_indices(tokens)).astype("int64"))

        if ordered:
            data = np.concatenate(data)

        return data

    def get_indices(self, tokens):
        return self.vocab.to_indices(tokens)

    @classmethod
    def get_vocab(cls,
                  files,
                  max_size=None,
                  min_freq=0,
                  lower_case=True,
                  delimiter=None,
                  unk_token=None,
                  pad_token=None,
                  bos_token=None,
                  eos_token=None,
                  **kwargs):
        return Vocab.build_vocab(
            cls.data_iterator(
                files=files, delimiter=delimiter, lower_case=lower_case),
            max_size=max_size,
            min_freq=min_freq,
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token)

    @classmethod
    def tokenize(cls, line, delimiter=None, lower_case=True):
        line = line.strip()
        if lower_case:
            line = line.lower()
        tokens = list(line) if delimiter == "" else line.split(delimiter)
        return tokens

    @classmethod
    def data_iterator(cls, files, delimiter=None, lower_case=True):
        if isinstance(files, str):
            files = [files]
        elif not isinstance(files, (list, tuple)):
            raise ValueError(
                "The parameter files must be a str or a list/tuple.")

        for fl in files:
            assert os.path.exists(fl), "%s is not exist. " % fl

            with open(fl, 'r', encoding='utf-8') as f:
                for line in f:
                    tokens = cls.tokenize(
                        line=line, delimiter=delimiter, lower_case=lower_case)
                    yield tokens


def get_lm_data_loader(args, vocab, mode="train"):
    lm_dataset = LMDataset(
        mode=mode,
        vocab=vocab,
        path=args.data,
        dataset_name=args.dataset,
        batch_size=args.batch_size
        if mode == "train" else args.eval_batch_size,
        bptt=args.tgt_len,
        ext_len=args.ext_len,
        nranks=dist.get_world_size() if mode == "train" else 1,
        rank=dist.get_rank() if mode == "train" else 0)

    data_loader = DataLoader(
        dataset=lm_dataset, batch_size=None, num_workers=0, return_list=True)

    return data_loader


def get_lm_vocab(args):
    kwargs = {"unk_token": "<unk>"}
    if args.token_delimiter == "None":
        kwargs["delimiter"] = None
    else:
        kwargs["delimiter"] = args.token_delimiter

    if args.dataset == "wt103":
        kwargs["eos_token"] = "<eos>"
        kwargs["lower_case"] = False

    if args.dataset in ["enwik8", "text8"]:
        files = [
            os.path.join(args.data, "train.txt"),
            os.path.join(args.data, "valid.txt"),
            os.path.join(args.data, "test.txt")
        ]
    elif args.dataset == "wt103":
        files = [os.path.join(args.data, "train.txt")]
    else:
        raise ValueError("Not supported dataset yet. ")

    vocab = LMDataset.get_vocab(files, **kwargs)
    args.ntokens = len(vocab)
    print("Finish processing vocabulary, and the size of vocabulary is {}".
          format(args.ntokens))

    return vocab
```
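
For context, train.py and eval.py drive this module in two steps: build the vocabulary first (which also fills in `args.ntokens`), then construct a DataLoader for the desired split. A minimal usage sketch, assuming a config has already been loaded into `args` as in train.py and the corresponding data exists under ./gen_data/:

```python
import yaml
from attrdict import AttrDict

from reader import get_lm_vocab, get_lm_data_loader

# Assumes ./gen_data/enwik8/ has been generated by gen_data.sh.
with open("./configs/enwik8.yaml", "rt") as f:
    args = AttrDict(yaml.safe_load(f))

vocab = get_lm_vocab(args)  # builds the vocabulary and sets args.ntokens
train_loader = get_lm_data_loader(args, vocab, "train")

# Each batch is [src, target, seq_len]; with ext_len 0, src and target both
# have shape [batch_size, seq_len].
for src, target, seq_len in train_loader:
    print(src.shape, target.shape, int(seq_len))
    break
```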

PaddleNLP/examples/language_model/transformer-xl/train.py (new file, mode 100644)

```python
import os
import time
import yaml
import logging
import argparse
import numpy as np
from pprint import pprint
from attrdict import AttrDict

import paddle
import paddle.nn as nn
import paddle.distributed as dist

from mem_transformer import MemTransformerLM
from reader import get_lm_vocab, get_lm_data_loader

FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="./configs/enwik8.yaml",
        type=str,
        help="Path of the config file. ")
    args = parser.parse_args()
    return args


def do_train(args):
    if args.use_gpu:
        rank = dist.get_rank()
        trainer_count = dist.get_world_size()
    else:
        rank = 0
        trainer_count = 1

    if trainer_count > 1:
        dist.init_parallel_env()

    random_seed = eval(str(args.random_seed))
    if random_seed is not None:
        paddle.seed(random_seed)

    vocab = get_lm_vocab(args)
    train_loader = get_lm_data_loader(args, vocab, "train")
    eval_loader = get_lm_data_loader(args, vocab, "valid")

    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']

        if args.dataset == 'wt103':
            cutoffs = [20000, 40000, 200000]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [60000, 100000, 640000]
            tie_projs += [False] * len(cutoffs)

    mem_transformer = MemTransformerLM(
        args.ntokens,
        args.n_layer,
        args.n_head,
        args.d_model,
        args.d_head,
        args.d_inner_hid,
        args.dropout,
        args.attn_dropout,
        tie_weight=args.tie_weight,
        d_embed=args.d_model,
        div_val=args.div_val,
        tie_projs=tie_projs,
        normalize_before=args.normalize_before,
        tgt_len=args.tgt_len,
        ext_len=args.ext_len,
        mem_len=args.mem_len,
        cutoffs=cutoffs,
        same_length=args.same_length,
        attn_type=args.attn_type,
        clamp_len=args.clamp_len,
        sample_softmax=args.sample_softmax)

    if args.scheduler == 'cosine':
        scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
            learning_rate=args.learning_rate,
            T_max=args.max_step,
            eta_min=args.eta_min)
    elif args.scheduler == 'noam':
        scheduler = paddle.optimizer.lr.NoamDecay(
            d_model=args.d_model,
            warmup_steps=args.warmup_steps,
            learning_rate=args.learning_rate)
    elif args.scheduler == 'dev_perf':
        # fluid api
        scheduler = paddle.fluid.dygraph.ReduceLROnPlateau(
            learning_rate=args.learning_rate,
            decay_rate=args.decay_rate,
            patience=args.patience,
            min_lr=args.lr_min)
    elif args.scheduler == 'constant':
        scheduler = args.learning_rate

    clip = paddle.nn.ClipGradByGlobalNorm(args.clip)
    if args.optim.lower() == 'momentum':
        optimizer = paddle.optimizer.Momentum(
            learning_rate=scheduler,
            parameters=mem_transformer.parameters(),
            momentum=args.mom,
            grad_clip=clip)
    elif args.optim.lower() == 'adam':
        optimizer = paddle.optimizer.Adam(
            learning_rate=scheduler,
            parameters=mem_transformer.parameters(),
            beta1=args.beta1,
            beta2=args.beta2,
            epsilon=eval(args.eps),
            grad_clip=clip)
    elif args.optim.lower() == 'adagrad':
        optimizer = paddle.optimizer.Adagrad(
            learning_rate=scheduler,
            parameters=mem_transformer.parameters(),
            grad_clip=clip)

    # Init from some checkpoint, to resume the previous training
    if args.init_from_checkpoint:
        model_dict = paddle.load(
            os.path.join(args.init_from_checkpoint,
                         "mem_transformer.pdparams"))
        opt_dict = paddle.load(
            os.path.join(args.init_from_checkpoint, "mem_transformer.pdopt"))
        mem_transformer.set_state_dict(model_dict)
        optimizer.set_state_dict(opt_dict)
        print("loaded from checkpoint.")
    # Init from some pretrain models, to better solve the current task
    if args.init_from_pretrain_model:
        model_dict = paddle.load(
            os.path.join(args.init_from_pretrain_model,
                         "mem_transformer.pdparams"))
        mem_transformer.set_state_dict(model_dict)
        print("loaded from pre-trained model.")

    if trainer_count > 1:
        mem_transformer = paddle.DataParallel(mem_transformer)

    step_idx = 0
    train_loss = 0.0

    log_start_time = time.time()

    for pass_id in range(args.epoch):
        batch_id = 0

        mems = tuple()
        for input_data in train_loader:
            (src, target, seq_len) = input_data
            ret = mem_transformer(src, target, *mems)
            loss = ret[0]
            mems = ret[1:]
            train_loss += loss.numpy()

            loss.backward()
            optimizer.step()
            optimizer.clear_grad()

            if step_idx > 0 and step_idx % args.print_step == 0 and rank == 0:
                cur_loss = train_loss / args.print_step
                elapsed = time.time() - log_start_time
                if args.scheduler == "constant":
                    lr = optimizer.get_lr()
                else:
                    lr = scheduler.get_lr()
                logger_info = "step_idx: %d, epoch: %d, batch: %d, learning rate: %.8f, " \
                              "speed: %f ms/batch, loss: %f" % \
                              (step_idx, pass_id, batch_id, lr,
                               elapsed * 1000.0 / args.print_step, cur_loss)
                if args.dataset in ["enwik8", "text8"]:
                    logger_info = logger_info + ", bpc: %f" % (cur_loss /
                                                               np.log(2))
                else:
                    logger_info = logger_info + ", ppl: %f" % (
                        np.exp(cur_loss))
                logger.info(logger_info)
                train_loss = 0.0
                log_start_time = time.time()

            if step_idx % args.save_step == 0 and step_idx != 0:
                # Do validation.
                mem_transformer.eval()

                # TODO(FrostML): simplify this.
                if args.mem_len == 0:
                    if dist.get_world_size() == 1:
                        mem_transformer.reset_length(
                            tgt_len=args.eval_tgt_len,
                            ext_len=args.ext_len + args.tgt_len -
                            args.eval_tgt_len,
                            mem_len=args.mem_len)
                    else:
                        mem_transformer._layers.reset_length(
                            tgt_len=args.eval_tgt_len,
                            ext_len=args.ext_len + args.tgt_len -
                            args.eval_tgt_len,
                            mem_len=args.mem_len)
                else:
                    if dist.get_world_size() == 1:
                        mem_transformer.reset_length(
                            tgt_len=args.eval_tgt_len,
                            ext_len=args.ext_len,
                            mem_len=args.mem_len + args.tgt_len -
                            args.eval_tgt_len)
                    else:
                        mem_transformer._layers.reset_length(
                            tgt_len=args.eval_tgt_len,
                            ext_len=args.ext_len,
                            mem_len=args.mem_len + args.tgt_len -
                            args.eval_tgt_len)

                total_len, total_loss = 0, 0.

                eval_mems = tuple()
                with paddle.no_grad():
                    for i, (src, target, seq_len) in enumerate(eval_loader):
                        if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                            break
                        ret = mem_transformer(src, target, *eval_mems)
                        loss, eval_mems = ret[0], ret[1:]
                        seq_len = seq_len.numpy()
                        eval_cur_loss = seq_len * loss.numpy()
                        total_loss += eval_cur_loss
                        total_len += seq_len
                    eval_loss = total_loss / total_len

                logger_info = "Validation, step_idx: %d, validation loss: %f" % \
                              (step_idx, eval_loss)
                if args.dataset in ['enwik8', 'text8']:
                    logger_info = logger_info + ", bpc: %f" % (eval_loss /
                                                               np.log(2))
                else:
                    logger_info = logger_info + ", ppl: %f" % (
                        np.exp(eval_loss))
                logger.info(logger_info)

                if args.save_model and rank == 0:
                    model_dir = os.path.join(args.save_model,
                                             "step_" + str(step_idx))
                    if not os.path.exists(model_dir):
                        os.makedirs(model_dir)
                    paddle.save(
                        mem_transformer.state_dict(),
                        os.path.join(model_dir, "mem_transformer.pdparams"))
                    paddle.save(
                        optimizer.state_dict(),
                        os.path.join(model_dir, "mem_transformer.pdopt"))

                if args.scheduler == 'dev_perf':
                    scheduler.step(eval_loss)

                # TODO(FrostML): simplify this.
                if dist.get_world_size() == 1:
                    mem_transformer.reset_length(
                        tgt_len=args.tgt_len,
                        ext_len=args.ext_len,
                        mem_len=args.mem_len)
                else:
                    mem_transformer._layers.reset_length(
                        tgt_len=args.tgt_len,
                        ext_len=args.ext_len,
                        mem_len=args.mem_len)

                mem_transformer.train()

            step_idx += 1
            batch_id += 1
            if args.scheduler in ['cosine', 'dev_perf']:
                if step_idx < args.warmup_steps:
                    curr_lr = args.learning_rate * step_idx / args.warmup_steps
                    scheduler.base_lr = curr_lr
                else:
                    if args.scheduler == 'cosine':
                        scheduler.step()
            elif args.scheduler == 'constant':
                if step_idx < args.warmup_steps:
                    curr_lr = args.learning_rate * step_idx / args.warmup_steps
                    optimizer.set_lr(curr_lr)
            elif args.scheduler == 'noam':
                scheduler.step()
            if step_idx >= args.max_step:
                break

    if args.save_model and rank == 0:
        model_dir = os.path.join(args.save_model, "step_final")
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        paddle.save(mem_transformer.state_dict(),
                    os.path.join(model_dir, "mem_transformer.pdparams"))
        paddle.save(optimizer.state_dict(),
                    os.path.join(model_dir, "mem_transformer.pdopt"))


if __name__ == "__main__":
    ARGS = parse_args()
    yaml_file = ARGS.config
    with open(yaml_file, 'rt') as f:
        args = AttrDict(yaml.safe_load(f))
        pprint(args)

    do_train(args)
```

PaddleNLP/examples/language_model/transformer-xl/utils/preprocess_text8.py (new file, mode 100644)

```python
import sys
import zipfile
import argparse

if __name__ == "__main__":
    data = zipfile.ZipFile("text8.zip").extractall()
    data = open("text8", "r", encoding="utf-8").read()

    num_test_char = int(sys.argv[1])

    train_data = data[:-2 * num_test_char]
    valid_data = data[-2 * num_test_char:-num_test_char]
    test_data = data[-num_test_char:]

    for files, data in [("train.txt", train_data), ("valid.txt", valid_data),
                        ("test.txt", test_data)]:
        data_str = " ".join(["_" if c == " " else c for c in data.strip()])
        with open(files, "w") as f:
            f.write(data_str)
        with open(files + ".raw", "w", encoding="utf-8") as fw:
            fw.write(data)
```
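
Taken together, gen_data.sh and this helper populate ./gen_data/ with the plain-text splits that reader.py expects. Based on the commands in gen_data.sh (and on the downloaded prep_enwik8.py helper, whose exact outputs are assumed here), the resulting layout is roughly:

```text
gen_data/
├── wikitext-103/        # train.txt, valid.txt, test.txt (renamed wiki.*.tokens)
├── enwik8/              # train.txt, valid.txt, test.txt produced by prep_enwik8.py
├── text8/               # train.txt, valid.txt, test.txt (plus *.raw) from preprocess_text8.py
└── one-billion-words/   # valid.txt, test.txt and 1b_word_vocab.txt
```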