# Language Model
Below is a brief overview of this example's directory structure:
```text
.
├── README.md # documentation
├── train.py # training script
├── infer.py # inference script
└── utils.py # common utility functions
```
## Introduction
For an introduction to recurrent neural network language models, see the paper [Recurrent Neural Network Regularization](https://arxiv.org/abs/1409.2329). In this example we implement a GRU-RNN language model.
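At every position in a sentence, the model is trained to predict the next word. A minimal sketch of the (source, target) pairs the data readers produce (the word ids here are hypothetical; the real pairs are built from the PTB data in `utils.py`):

```python
# A sentence as word ids, e.g. "<s> the cat sat </s>" -> [0, 12, 7, 9, 1]
sentence = [0, 12, 7, 9, 1]

# The source sequence drops the last token; the target is the source
# shifted left by one, so at step t the model reads src[t] and is
# trained to assign high probability to dst[t], the next word.
src = sentence[:-1]  # [0, 12, 7, 9]
dst = sentence[1:]   # [12, 7, 9, 1]
```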
## Training
Run the command `python train.py` to start training the model:
```bash
python train.py
```
The currently supported parameters can be found in the `train_net` function of [train.py](./train.py):
```python
vocab, train_reader, test_reader = utils.prepare_data(
batch_size=20, # batch size
buffer_size=1000, # buffer size, default value is OK
word_freq_threshold=0) # vocabulary related parameter, and words with frequency below this value will be filtered
train(train_reader=train_reader,
vocab=vocab,
network=network,
hid_size=200, # embedding and hidden size
base_lr=1.0, # base learning rate
batch_size=20, # batch size, the same as that in prepare_data
pass_num=12, # the number of passes for training
use_cuda=True, # whether to use GPU card
parallel=False, # whether to be parallel
model_dir="model", # directory to save model
init_low_bound=-0.1, # uniform parameter initialization lower bound
init_high_bound=0.1) # uniform parameter initialization upper bound
```
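During training the script reports perplexity (ppl) every 100 batches. Perplexity is the exponential of the average per-word cross-entropy, the same quantity that `infer.py` accumulates over the test set:

$$\mathrm{ppl} = \exp\left(-\frac{1}{N}\sum_{i=1}^{N}\log p\left(w_i \mid w_{<i}\right)\right)$$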
## Customizing the Network Structure
The network structure can be adjusted in the `network` function of [train.py](./train.py); the current structure is as follows:
```python
emb = fluid.layers.embedding(input=src, size=[vocab_size, hid_size],
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(low=init_low_bound, high=init_high_bound),
learning_rate=emb_lr_x),
is_sparse=True)
fc0 = fluid.layers.fc(input=emb, size=hid_size * 3,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(low=init_low_bound, high=init_high_bound),
learning_rate=gru_lr_x))
gru_h0 = fluid.layers.dynamic_gru(input=fc0, size=hid_size,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(low=init_low_bound, high=init_high_bound),
learning_rate=gru_lr_x))
fc = fluid.layers.fc(input=gru_h0, size=vocab_size, act='softmax',
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(low=init_low_bound, high=init_high_bound),
learning_rate=fc_lr_x))
cost = fluid.layers.cross_entropy(input=fc, label=dst)
```
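Note that `fc0` projects the embedding to `hid_size * 3`: `dynamic_gru` expects the update-gate, reset-gate and candidate input projections of each step fused into one tensor. A rough numpy sketch of a single GRU step under that layout (the gate order and update convention here are illustrative assumptions, not Fluid's exact internals):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_proj, h_prev, w_u, w_r, w_c):
    """One GRU step; x_proj is the 3 * hid_size slice produced by fc0."""
    x_u, x_r, x_c = np.split(x_proj, 3)      # the three fused input blocks
    u = sigmoid(x_u + w_u.dot(h_prev))       # update gate
    r = sigmoid(x_r + w_r.dot(h_prev))       # reset gate
    c = np.tanh(x_c + w_c.dot(r * h_prev))   # candidate hidden state
    return u * h_prev + (1.0 - u) * c        # mix old state and candidate
```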
## Sample Training Results
The log from training on a single Tesla K40m GPU is shown below:
```text
epoch_1 start
step:100 ppl:771.053
step:200 ppl:449.597
step:300 ppl:642.654
step:400 ppl:458.128
step:500 ppl:510.912
step:600 ppl:451.545
step:700 ppl:364.404
step:800 ppl:324.272
step:900 ppl:360.797
step:1000 ppl:275.761
step:1100 ppl:294.599
step:1200 ppl:335.877
step:1300 ppl:185.262
step:1400 ppl:241.744
step:1500 ppl:211.507
step:1600 ppl:233.431
step:1700 ppl:298.767
step:1800 ppl:203.403
step:1900 ppl:158.828
step:2000 ppl:171.148
step:2100 ppl:280.884
epoch:1 num_steps:2104 time_cost(s):47.478780
model saved in model/epoch_1
epoch_2 start
step:100 ppl:238.099
step:200 ppl:136.527
step:300 ppl:204.184
step:400 ppl:252.886
step:500 ppl:177.377
step:600 ppl:197.688
step:700 ppl:131.650
step:800 ppl:223.906
step:900 ppl:144.785
step:1000 ppl:176.286
step:1100 ppl:148.158
step:1200 ppl:203.581
step:1300 ppl:168.208
step:1400 ppl:159.412
step:1500 ppl:114.032
step:1600 ppl:157.985
step:1700 ppl:147.743
step:1800 ppl:88.676
step:1900 ppl:141.962
step:2000 ppl:106.087
step:2100 ppl:122.709
epoch:2 num_steps:2104 time_cost(s):47.583789
model saved in model/epoch_2
...
```
## Inference
Run the command `python infer.py model_dir start_epoch last_epoch(inclusive)` to start inference, where `start_epoch` specifies the first epoch to evaluate and `last_epoch` the last epoch (inclusive), for example:
```bash
python infer.py model 1 12 # prediction from epoch 1 to epoch 12
```
## Sample Inference Results
```text
model:model/epoch_1 ppl:254.540 time_cost(s):3.29
model:model/epoch_2 ppl:177.671 time_cost(s):3.27
model:model/epoch_3 ppl:156.251 time_cost(s):3.27
model:model/epoch_4 ppl:139.036 time_cost(s):3.27
model:model/epoch_5 ppl:132.661 time_cost(s):3.27
model:model/epoch_6 ppl:130.092 time_cost(s):3.28
model:model/epoch_7 ppl:128.751 time_cost(s):3.27
model:model/epoch_8 ppl:125.411 time_cost(s):3.27
model:model/epoch_9 ppl:124.604 time_cost(s):3.28
model:model/epoch_10 ppl:124.754 time_cost(s):3.29
model:model/epoch_11 ppl:125.421 time_cost(s):3.27
model:model/epoch_12 ppl:125.676 time_cost(s):3.27
```
# infer.py
import sys
import time
import math

import numpy as np
import paddle.fluid as fluid

import utils

def infer(test_reader, use_cuda, model_path):
""" inference function """
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
with fluid.scope_guard(fluid.core.Scope()):
infer_program, feed_target_names, fetch_vars = fluid.io.load_inference_model(
model_path, exe)
accum_cost = 0.0
accum_words = 0
t0 = time.time()
for data in test_reader():
src_wordseq = utils.to_lodtensor(map(lambda x: x[0], data), place)
dst_wordseq = utils.to_lodtensor(map(lambda x: x[1], data), place)
avg_cost = exe.run(
infer_program,
feed={"src_wordseq": src_wordseq,
"dst_wordseq": dst_wordseq},
fetch_list=fetch_vars)
nwords = src_wordseq.lod()[0][-1]
cost = np.array(avg_cost) * nwords
accum_cost += cost
accum_words += nwords
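        # Corpus-level perplexity: the exponential of the total cross-entropy
        # divided by the total number of predicted words.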
ppl = math.exp(accum_cost / accum_words)
t1 = time.time()
print("model:%s ppl:%.3f time_cost(s):%.2f" %
(model_path, ppl, t1 - t0))
if __name__ == "__main__":
    if len(sys.argv) != 4:
        print("Usage: %s model_dir start_epoch last_epoch(inclusive)" %
              sys.argv[0])
        sys.exit(0)
    model_dir = sys.argv[1]
    try:
        start_index = int(sys.argv[2])
        last_index = int(sys.argv[3])
    except ValueError:
        print("Usage: %s model_dir start_epoch last_epoch(inclusive)" %
              sys.argv[0])
        sys.exit(-1)
vocab, train_reader, test_reader = utils.prepare_data(
batch_size=20, buffer_size=1000, word_freq_threshold=0)
for epoch in xrange(start_index, last_index + 1):
epoch_path = model_dir + "/epoch_" + str(epoch)
infer(test_reader=test_reader, use_cuda=True, model_path=epoch_path)
# train.py
import time
import math

import paddle.fluid as fluid

import utils

def network(src, dst, vocab_size, hid_size, init_low_bound, init_high_bound):
""" network definition """
emb_lr_x = 10.0
gru_lr_x = 1.0
fc_lr_x = 1.0
emb = fluid.layers.embedding(
input=src,
size=[vocab_size, hid_size],
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=init_low_bound, high=init_high_bound),
learning_rate=emb_lr_x),
is_sparse=True)
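    # dynamic_gru fuses its update-gate, reset-gate and candidate input
    # projections, so the layer feeding it must be hid_size * 3 wide.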
fc0 = fluid.layers.fc(input=emb,
size=hid_size * 3,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=init_low_bound, high=init_high_bound),
learning_rate=gru_lr_x))
gru_h0 = fluid.layers.dynamic_gru(
input=fc0,
size=hid_size,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=init_low_bound, high=init_high_bound),
learning_rate=gru_lr_x))
fc = fluid.layers.fc(input=gru_h0,
size=vocab_size,
act='softmax',
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=init_low_bound, high=init_high_bound),
learning_rate=fc_lr_x))
cost = fluid.layers.cross_entropy(input=fc, label=dst)
return cost
def train(train_reader,
vocab,
network,
hid_size,
base_lr,
batch_size,
pass_num,
use_cuda,
parallel,
model_dir,
init_low_bound=-0.04,
init_high_bound=0.04):
""" train network """
vocab_size = len(vocab)
src_wordseq = fluid.layers.data(
name="src_wordseq", shape=[1], dtype="int64", lod_level=1)
dst_wordseq = fluid.layers.data(
name="dst_wordseq", shape=[1], dtype="int64", lod_level=1)
avg_cost = None
if not parallel:
cost = network(src_wordseq, dst_wordseq, vocab_size, hid_size,
init_low_bound, init_high_bound)
avg_cost = fluid.layers.mean(x=cost)
else:
places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places)
with pd.do():
cost = network(
pd.read_input(src_wordseq),
pd.read_input(dst_wordseq), vocab_size, hid_size,
init_low_bound, init_high_bound)
pd.write_output(cost)
cost = pd()
avg_cost = fluid.layers.mean(x=cost)
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=fluid.layers.exponential_decay(
learning_rate=base_lr,
decay_steps=2100 * 4,
decay_rate=0.5,
staircase=True))
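    # With staircase decay the learning rate is halved every decay_steps
    # batches; at roughly 2100 batches per pass (see the training log),
    # base_lr is multiplied by 0.5 after every 4 passes.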
sgd_optimizer.minimize(avg_cost)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
total_time = 0.0
for pass_idx in xrange(pass_num):
epoch_idx = pass_idx + 1
print "epoch_%d start" % epoch_idx
t0 = time.time()
i = 0
for data in train_reader():
i += 1
lod_src_wordseq = utils.to_lodtensor(
map(lambda x: x[0], data), place)
lod_dst_wordseq = utils.to_lodtensor(
map(lambda x: x[1], data), place)
ret_avg_cost = exe.run(fluid.default_main_program(),
feed={
"src_wordseq": lod_src_wordseq,
"dst_wordseq": lod_dst_wordseq
},
fetch_list=[avg_cost],
use_program_cache=True)
avg_ppl = math.exp(ret_avg_cost[0])
if i % 100 == 0:
print "step:%d ppl:%.3f" % (i, avg_ppl)
t1 = time.time()
total_time += t1 - t0
print "epoch:%d num_steps:%d time_cost(s):%f" % (epoch_idx, i,
total_time / epoch_idx)
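        # Persist the program and parameters in inference format so that
        # infer.py can restore them via fluid.io.load_inference_model.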
save_dir = "%s/epoch_%d" % (model_dir, epoch_idx)
feed_var_names = ["src_wordseq", "dst_wordseq"]
fetch_vars = [avg_cost]
fluid.io.save_inference_model(save_dir, feed_var_names, fetch_vars, exe)
print("model saved in %s" % save_dir)
print("finish training")
def train_net():
""" do training """
batch_size = 20
vocab, train_reader, test_reader = utils.prepare_data(
batch_size=batch_size, buffer_size=1000, word_freq_threshold=0)
train(
train_reader=train_reader,
vocab=vocab,
network=network,
hid_size=200,
base_lr=1.0,
batch_size=batch_size,
pass_num=12,
use_cuda=True,
parallel=False,
model_dir="model",
init_low_bound=-0.1,
init_high_bound=0.1)
if __name__ == "__main__":
train_net()
# utils.py
import numpy as np
import paddle.fluid as fluid
import paddle.v2 as paddle

def to_lodtensor(data, place):
""" convert to LODtensor """
seq_lens = [len(seq) for seq in data]
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = fluid.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
def prepare_data(batch_size, buffer_size=1000, word_freq_threshold=0):
""" prepare the English Pann Treebank (PTB) data """
vocab = paddle.dataset.imikolov.build_dict(word_freq_threshold)
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.imikolov.train(
vocab,
buffer_size,
data_type=paddle.dataset.imikolov.DataType.SEQ),
buf_size=buffer_size),
batch_size)
test_reader = paddle.batch(
paddle.dataset.imikolov.test(
vocab, buffer_size, data_type=paddle.dataset.imikolov.DataType.SEQ),
batch_size)
return vocab, train_reader, test_reader