# GRU4REC
The following is a brief directory structure and description of this example:
```text
.
├── README.md            # documentation
├── train.py             # training script
├── infer.py             # inference script
├── utils.py             # common utility functions
├── convert_format.py    # data format conversion
├── small_train.txt      # small sample training set
└── small_test.txt       # small sample test set
```
## Introduction
The GRU4REC model is described in the paper [Session-based Recommendations with Recurrent Neural Networks](https://arxiv.org/abs/1511.06939). This example implements the GRU4REC model.
## RSC15 Data Download and Preprocessing
Run the following commands to download the RSC15 dataset from the official site:
```
curl -Lo yoochoose-data.7z https://s3-eu-west-1.amazonaws.com/yc-rdata/yoochoose-data.7z
7z x yoochoose-data.7z
```
For GRU4REC's data filtering, download the preprocessing script from [https://github.com/hidasib/GRU4Rec/blob/master/examples/rsc15/preprocess.py](https://github.com/hidasib/GRU4Rec/blob/master/examples/rsc15/preprocess.py).
Before running it, update the file paths at lines 12-13 of that script:
line 12: `PATH_TO_ORIGINAL_DATA = './'`
line 13: `PATH_TO_PROCESSED_DATA = './'`
Note that the script must be run with Python 3:
```
python preprocess.py
```
The generated data has the following format:
```
SessionId ItemId Time
1 214536502 1396839069.277
1 214536500 1396839249.868
1 214536506 1396839286.998
1 214577561 1396839420.306
2 214662742 1396850197.614
2 214662742 1396850239.373
2 214825110 1396850317.446
2 214757390 1396850390.71
2 214757407 1396850438.247
```
The data format then needs to be converted; run the following script:
```
python convert_format.py
```
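For intuition, `convert_format.py` (included in full later in this commit) groups consecutive rows that share a SessionId and writes their ItemIds as one space-separated line. A minimal sketch of the same logic, using made-up rows:
```python
import itertools

# Toy (SessionId, ItemId) rows; the real input also carries a Time column
# and a header row, which convert_format.py skips.
rows = [("1", "214536502"), ("1", "214536500"), ("2", "214662742")]
for _sess, clicks in itertools.groupby(rows, key=lambda r: r[0]):
    print(" ".join(item for _, item in clicks))
# 214536502 214536500
# 214662742
```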
The training and testing data for the model look as follows; each line is one user's item sequence in time order:
```
214536502 214536500 214536506 214577561
214662742 214662742 214825110 214757390 214757407 214551617
214716935 214774687 214832672
214836765 214706482
214701242 214826623
214826835 214826715
214838855 214838855
214576500 214576500 214576500
214821275 214821275 214821371 214821371 214821371 214717089 214563337 214706462 214717436 214743335 214826837 214819762
214717867 214717867
```
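Each such line is turned into a (source, target) pair for next-item prediction: the source sequence drops the last item and the target drops the first, exactly as `reader_creator` in utils.py (included later in this commit) does. A minimal sketch, using the first sample line above:
```python
# One session line from the converted data.
line = "214536502 214536500 214536506 214577561"
items = line.strip().split()
src_seq = items[:-1]  # ['214536502', '214536500', '214536506']
trg_seq = items[1:]   # ['214536500', '214536506', '214577561']
# At step i the model sees src_seq[i] and should predict trg_seq[i].
print(src_seq, trg_seq)
```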
## Training
GPU environment (default configuration)
Run `CUDA_VISIBLE_DEVICES=0 python train.py train_file test_file` to start training the model.
```bash
CUDA_VISIBLE_DEVICES=0 python train.py small_train.txt small_test.txt
```
CPU environment
Run `python train.py train_file test_file` to start training the model.
```bash
python train.py small_train.txt small_test.txt
```
The currently supported parameters can be found in the `train_net` function in [train.py](./train.py):
```python
batch_size = 50  # batch size; 500 is recommended
args = parse_args()
vocab, train_reader, test_reader = utils.prepare_data(
    train_file, test_file, batch_size=batch_size * get_cards(args),
    buffer_size=1000, word_freq_threshold=0)  # buffer_size: window for length-sorted batching
train(
    train_reader=train_reader,
    vocab=vocab,
    network=network,
    hid_size=100,  # embedding and hidden size
    base_lr=0.01,  # base learning rate
    batch_size=batch_size,
    pass_num=10,  # the number of passes for training
    use_cuda=True,  # whether to use GPU card
    parallel=False,  # whether to train in parallel
    model_dir="model_recall20",  # directory to save model
    init_low_bound=-0.1,  # uniform parameter initialization lower bound
    init_high_bound=0.1)  # uniform parameter initialization upper bound
```
## Customizing the Network Structure
You can adjust the network structure in the `network` function in [train.py](./train.py). The current structure is as follows:
```python
emb = fluid.layers.embedding(
input=src,
size=[vocab_size, hid_size],
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=init_low_bound, high=init_high_bound),
learning_rate=emb_lr_x),
is_sparse=True)
fc0 = fluid.layers.fc(input=emb,
size=hid_size * 3,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=init_low_bound, high=init_high_bound),
learning_rate=gru_lr_x))
gru_h0 = fluid.layers.dynamic_gru(
input=fc0,
size=hid_size,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=init_low_bound, high=init_high_bound),
learning_rate=gru_lr_x))
fc = fluid.layers.fc(input=gru_h0,
size=vocab_size,
act='softmax',
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=init_low_bound, high=init_high_bound),
learning_rate=fc_lr_x))
cost = fluid.layers.cross_entropy(input=fc, label=dst)
acc = fluid.layers.accuracy(input=fc, label=dst, k=20)
```
## Sample Training Results
The training log on a single Tesla K40m GPU card looks like this:
```text
epoch_1 start
step:100 ppl:441.468
step:200 ppl:311.043
step:300 ppl:218.952
step:400 ppl:186.172
step:500 ppl:188.600
step:600 ppl:131.213
step:700 ppl:165.770
step:800 ppl:164.414
step:900 ppl:156.470
step:1000 ppl:174.201
step:1100 ppl:118.619
step:1200 ppl:122.635
step:1300 ppl:118.220
step:1400 ppl:90.372
step:1500 ppl:135.018
step:1600 ppl:114.327
step:1700 ppl:141.806
step:1800 ppl:93.416
step:1900 ppl:92.897
step:2000 ppl:121.703
step:2100 ppl:96.288
step:2200 ppl:88.355
step:2300 ppl:101.737
step:2400 ppl:95.934
step:2500 ppl:86.158
step:2600 ppl:80.925
step:2700 ppl:202.219
step:2800 ppl:106.828
step:2900 ppl:91.458
step:3000 ppl:105.988
step:3100 ppl:87.067
step:3200 ppl:92.651
step:3300 ppl:101.145
step:3400 ppl:91.247
step:3500 ppl:107.656
step:3600 ppl:89.410
...
...
step:15700 ppl:76.819
step:15800 ppl:62.257
step:15900 ppl:81.735
epoch:1 num_steps:15907 time_cost(s):4154.096032
model saved in model_recall20/epoch_1
...
```
## Prediction
Run `CUDA_VISIBLE_DEVICES=0 python infer.py model_dir start_epoch last_epoch(inclusive) train_file test_file` to start prediction, where start_epoch is the first epoch to evaluate and last_epoch (inclusive) the last, for example:
```bash
CUDA_VISIBLE_DEVICES=0 python infer.py model_recall20 1 10 small_train.txt small_test.txt  # predict from epoch 1 to epoch 10
```
## Sample Prediction Results
```text
model:model_r@20/epoch_1 recall@20:0.613 time_cost(s):12.23
model:model_r@20/epoch_2 recall@20:0.647 time_cost(s):12.33
model:model_r@20/epoch_3 recall@20:0.662 time_cost(s):12.38
model:model_r@20/epoch_4 recall@20:0.669 time_cost(s):12.21
model:model_r@20/epoch_5 recall@20:0.673 time_cost(s):12.17
model:model_r@20/epoch_6 recall@20:0.675 time_cost(s):12.26
model:model_r@20/epoch_7 recall@20:0.677 time_cost(s):12.25
model:model_r@20/epoch_8 recall@20:0.679 time_cost(s):12.37
model:model_r@20/epoch_9 recall@20:0.680 time_cost(s):12.22
model:model_r@20/epoch_10 recall@20:0.681 time_cost(s):12.2
```
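The per-epoch recall@20 above is not a plain average over batches: infer.py (below) weights each batch's accuracy@20 by the number of target items it contains. A minimal sketch, with made-up batch numbers:
```python
# (batch accuracy@20, number of target items) pairs; values are made up.
batches = [(0.60, 120), (0.65, 80), (0.70, 40)]
num_hit = sum(acc * n for acc, n in batches)  # accum_num_recall in infer.py
num_total = sum(n for _, n in batches)        # accum_num_sum in infer.py
print("recall@20: %.3f" % (num_hit / num_total))
```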
# convert_format.py: collapse the preprocessed RSC15 click log (one
# "SessionId ItemId Time" row per click) into one session per line.
def convert_format(input, output):
with open(input) as rf:
with open(output, "w") as wf:
last_sess = -1
sign = 1
i = 0
for l in rf:
i = i + 1
                if i == 1:  # skip the header row
                    continue
                if (i % 1000000 == 1):  # progress log every million rows
                    print(i)
tokens = l.strip().split()
                if (int(tokens[0]) != last_sess):  # a new session starts
                    if (sign):  # the very first session: no leading newline
                        sign = 0
                        wf.write(tokens[1] + " ")
else:
wf.write("\n" + tokens[1] + " ")
last_sess = int(tokens[0])
                else:  # same session: append the item
                    wf.write(tokens[1] + " ")
input = "rsc15_train_tr.txt"
output = "rsc15_train_tr_paddle.txt"
input2 = "rsc15_test.txt"
output2 = "rsc15_test_paddle.txt"
convert_format(input, output)
convert_format(input2, output2)
# infer.py: load saved inference models and report recall@20 on the test set.
import sys
import time

import numpy as np
import six
import paddle
import paddle.fluid as fluid

import utils
def infer(test_reader, use_cuda, model_path):
""" inference function """
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
with fluid.scope_guard(fluid.core.Scope()):
infer_program, feed_target_names, fetch_vars = fluid.io.load_inference_model(
model_path, exe)
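        # accum_num_recall counts correctly recalled target items and
        # accum_num_sum counts all target items, so their final ratio is a
        # length-weighted recall@20 over the whole test set.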
accum_num_recall = 0.0
accum_num_sum = 0.0
t0 = time.time()
step_id = 0
for data in test_reader():
step_id += 1
src_wordseq = utils.to_lodtensor([dat[0] for dat in data], place)
label_data = [dat[1] for dat in data]
dst_wordseq = utils.to_lodtensor(label_data, place)
para = exe.run(
infer_program,
feed={"src_wordseq": src_wordseq,
"dst_wordseq": dst_wordseq},
fetch_list=fetch_vars,
return_numpy=False)
            acc_ = para[1]._get_float_element(0)  # batch-level accuracy@20
data_length = len(
np.concatenate(
label_data, axis=0).astype("int64"))
accum_num_sum += (data_length)
accum_num_recall += (data_length * acc_)
if step_id % 100 == 0:
print("step:%d " % (step_id), accum_num_recall / accum_num_sum)
t1 = time.time()
print("model:%s recall@20:%.3f time_cost(s):%.2f" %
(model_path, accum_num_recall / accum_num_sum, t1 - t0))
if __name__ == "__main__":
    if len(sys.argv) != 6:
        print(
            "Usage: %s model_dir start_epoch last_epoch(inclusive) train_file test_file"
            % sys.argv[0])
        exit(0)
train_file = ""
test_file = ""
model_dir = sys.argv[1]
try:
start_index = int(sys.argv[2])
last_index = int(sys.argv[3])
train_file = sys.argv[4]
test_file = sys.argv[5]
    except ValueError:
        print(
            "Usage: %s model_dir start_epoch last_epoch(inclusive) train_file test_file"
            % sys.argv[0])
        exit(-1)
vocab, train_reader, test_reader = utils.prepare_data(
train_file,
test_file,
batch_size=5,
buffer_size=1000,
word_freq_threshold=0)
    for epoch in six.moves.xrange(start_index, last_index + 1):
epoch_path = model_dir + "/epoch_" + str(epoch)
infer(test_reader=test_reader, use_cuda=True, model_path=epoch_path)
214586805 214509260
214857547 214857268 214857260
214859848 214857787
214687963 214531502 214687963
214696532 214859034 214858850
214857570 214857810 214857568 214857787 214857182
214857562 214857570 214857562 214857568
214859132 214545928 214859132 214551913
214858843 214859859 214858912 214858691 214859900
214561888 214561888
214688430 214688435 214688430
214536302 214531376 214531659 214531440 214531466 214513382 214550996
214854930 214854930
214858856 214690775 214859306
214859872 214858912 214858689
214859310 214859338 214859338 214859942 214859293 214859889 214859338 214859889 214859075 214859338 214859338 214859889
214574906 214574906
214859342 214859342 214858777 214851155 214851152 214572433
214537127 214857257
214857570 214857570 214857568 214857562 214857015
214854352 214854352 214854354
214738466 214855010 214857605 214856552 214574906 214857765 214849299
214858365 214859900 214859126 214858689 214859126 214859126 214857759 214858850 214859895 214859300
214857260 214561481 214848995 214849052 214865212
214857596 214819412 214819412
214849342 214849342
214859902 214854845 214854845 214854825
214859306 214859126 214859126
214644962 214644960 214644958
214696432 214696434
214708372 214508287 214684093
214857015 214857015 214858847 214690130
214858787 214859855
214858847 214696532 214859304 214854845
214586805 214586805
214857568 214857570
214696532 214858850 214859034 214569238 214568120 214854165 214684785 214854262 214567327
214602729 214857568 214857596
214859122 214858687 214859122 214859872
214555607 214836225 214836225 214836223
214849299 214829724 214855010 214829801 214574906 214586722 214684307 214857570
214859872 214695525
214845947 214586722 214829801
214829312 214546123
214849055 214849052
214509260 214587932 214596435 214644960 214696432 214696434 214545928 214857030 214636329 214832604 214574906
214586805 214586805
214587932 214587932
214857568 214857549 214854894
214836819 214836819 214595855 214595855
214858787 214858787
214854860 214857701
214848750 214643908
214858847 214859872 214859038 214859855 214690130
214847780 214696817 214717305
214509260 214509260
214853122 214853122 214853122 214853323
214858847 214858631 214858691
214859859 214819807 214853072 214853072 214819730
214820450 214705115 214586805
214858787 214859036
214829842 214864967
214846033 214850949
214587932 214586805 214509260 214696432 214855110 214545928
214858856 214859081 214859306 214858854
214690839 214690839 214711277 214839607 214582942 214582942
214857030 214832604
214857570 214855046 214859870 214577475 214858687 214656380
214854845 214854845 214854684 214859893 214854845 214854778
214850630 214848159 214848159 214848159 214848159 214848159 214848159 214848159
214856248 214856248
214858365 214858905 214858905
214712274 214855046
214845947 214845947 214831946 214717511 214846014 214854729
214561462 214561462 214561481 214561481
214836819 214853250
214858854 214859915 214859306 214854300
214857660 214857787 214539307 214855010 214855046 214849299 214856981 214849055
214855046 214854877 214568102 214539523 214579762 214539347 214641127 214600995 214833733 214600995 214684633 214645121 214658040 214712276 214857660 214687895 214854313 214857517
214845962 214853165 214846119
214854146 214859034
214819412 214819412 214819412 214819412
214849747 214578350 214561991
214854341 214854341
214644855 214644857 214531153
214644960 214862167
214640490 214600918 214600922
214854710 214857759 214859306
214858843 214859297 214858631 214859117 214858689 214858912 214859902 214690127
214586805 214586805
214859306 214859306 214859126
214859034 214696532 214858850 214859126 214859859 214859034 214859859 214858850
214857782 214849048 214857787
214854148 214857787 214854877
214858631 214858631 214690127 214859034 214858850 214859117 214858631 214859300 214858843 214859859 214859859
214646036 214646036
214858847 214858631 214690127 214859297
214861603 214700002 214700000 214835117 214700000 214857830 214700000 214712235 214700000 214700002 214510700 214835713 214712235 214853321
214854855 214854815 214854815
214857185 214854637 214829765 214848384 214829765 214856546 214848596 214835167 214563335 214553837 214536185 214855982 214845515 214550844 214712006
214536502 214536500 214536506 214577561
214662742 214662742 214825110 214757390 214757407 214551617
214716935 214774687 214832672
214836765 214706482
214701242 214826623
214826835 214826715
214838855 214838855
214576500 214576500 214576500
214821275 214821275 214821371 214821371 214821371 214717089 214563337 214706462 214717436 214743335 214826837 214819762
214717867 214717867
214836761 214684513 214836761
214577732 214587013 214577732
214826897 214820441
214684093 214684093 214684093
214561790 214561790 214611457 214611457
214577732 214577732
214838503 214838503 214838503 214838503 214838503 214548744
214718203 214718203 214718203 214718203
214837485 214837485 214837485 214837487 214837487 214821315 214586711 214821305 214821307 214844357 214821341 214821309 214551617 214551617 214612920 214837487
214613743 214613743 214539110 214539110
214827028 214827017 214537796 214840762 214707930 214707930 214585652 214536197 214536195 214646169
214579288 214714790 214676070 214601407
214532036 214700432
214836789 214836789 214710804
214537967 214537967
214718246 214826835
214835257 214835265
214834865 214571188 214571188 214571188 214820225 214820225 214820225 214820225 214820225 214820225 214820225 214820225 214706441 214706441 214706441 214706441
214652878 214716737 214652878
214684721 214680356
214551594 214586970
214826769 214537967
214819745 214819745
214691587 214587915
214821277 214821277 214821277 214821277 214821277
214716932 214716932 214716932 214716932 214716932 214716932
214712235 214581489 214602605
214820441 214826897 214826702 214684513 214838100 214544357 214551626 214691484
214545935 214819438 214839907 214835917 214836210
214698491 214523692
214695307 214695305 214538317 214677448
214819468 214716977 214716977 214716977 214716977 214716939
214544355 214601212 214601212 214601212
214716982 214716984
214844248 214844248
214515834 214515830
214717318 214717318
214832557 214559660 214559660 214819520 214586540
214587797 214835775 214844109
214714794 214601407 214826619 214746427 214821300 214717562 214826927 214748334 214826908 214800262
214709645 214709645 214709645 214709645 214709645
214532072 214532070
214827022 214840419
214716984 214832657
214662975 214537779 214840762
214821277 214821277 214821277
214748300 214748293
214826955 214826606 214687642
214832559 214832559 214832559 214821017 214821017 214572234 214826715 214826715
214509135 214536853 214509133 214509135 214509135 214509135 214717877 214826615 214716982
214819472 214687685
214821285 214821285 214826801 214826801
214826705 214826705
214668590 214826872
214652220 214840483 214840483 214717286 214558807 214821300 214826908 214826908 214826908 214554637 214819430 214819430 214826837 214826837 214820392 214820392 214586694 214819376 214553844 214601229 214555500 214695127 214819760 214717850 214718385 214743369 214743369
214648475 214648340 214648438 214648455 214712936 214712887 214696149 214717097 214534352 214534352 214717097
214560099 214560099 214560099 214832750 214560099
214685621 214684093 214546097 214685623
214819685 214839907 214839905 214811752
214717007 214717003 214716928
214820842 214819490
214555869 214537185
214840599 214835735
214838100 214706216
214829737 214821315
214748293 214748293
214712272 214820450
214821380 214821380
214826799 214827005 214718390 214718396 214826627
214841060 214841060
214687768 214706445
214811752 214811754
214594678 214594680 214594680
214821369 214821369 214697771 214697512 214697413 214697409 214652409 214537127 214537127 214820237 214820237 214709645 214699213 214820237 214820237 214820237 214709645 214537127
214554358 214716950
214821275 214829741
214829741 214820842 214821279 214703790
214716954 214838366
214821022 214820814
214684721 214821369 214826833 214819472
214821315 214821305
214826702 214821275
214717847 214819719 214748336
214536440 214536437
214512416 214512416
214839313 214839313 214839313
214826705 214826705
214510044 214510044 214510044 214582387 214537535 214584812 214537535 214584810
214600989 214704180
214705693 214696824 214705682 214696817 214705691 214705693 214711710 214705691 214705691 214687539 214705687 214744796 214681648 214717307 214577750 214650382 214744796 214696817 214705682 214711710
# train.py: build and train the GRU4REC network.
import os
import sys
import time
import argparse

import numpy as np
import six
import paddle
import paddle.fluid as fluid

import utils

SEED = 102
def parse_args():
parser = argparse.ArgumentParser("gru4rec benchmark.")
parser.add_argument('train_file')
parser.add_argument('test_file')
parser.add_argument(
'--enable_ce',
action='store_true',
help='If set, run \
the task with continuous evaluation logs.')
parser.add_argument(
'--num_devices', type=int, default=1, help='Number of GPU devices')
args = parser.parse_args()
return args
def network(src, dst, vocab_size, hid_size, init_low_bound, init_high_bound):
""" network definition """
emb_lr_x = 10.0
gru_lr_x = 1.0
fc_lr_x = 1.0
emb = fluid.layers.embedding(
input=src,
size=[vocab_size, hid_size],
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=init_low_bound, high=init_high_bound),
learning_rate=emb_lr_x),
is_sparse=True)
fc0 = fluid.layers.fc(input=emb,
size=hid_size * 3,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=init_low_bound, high=init_high_bound),
learning_rate=gru_lr_x))
gru_h0 = fluid.layers.dynamic_gru(
input=fc0,
size=hid_size,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=init_low_bound, high=init_high_bound),
learning_rate=gru_lr_x))
fc = fluid.layers.fc(input=gru_h0,
size=vocab_size,
act='softmax',
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=init_low_bound, high=init_high_bound),
learning_rate=fc_lr_x))
cost = fluid.layers.cross_entropy(input=fc, label=dst)
acc = fluid.layers.accuracy(input=fc, label=dst, k=20)
return cost, acc
def train(train_reader,
vocab,
network,
hid_size,
base_lr,
batch_size,
pass_num,
use_cuda,
parallel,
model_dir,
init_low_bound=-0.04,
init_high_bound=0.04):
""" train network """
args = parse_args()
if args.enable_ce:
# random seed must set before configuring the network.
fluid.default_startup_program().random_seed = SEED
vocab_size = len(vocab)
# Input data
src_wordseq = fluid.layers.data(
name="src_wordseq", shape=[1], dtype="int64", lod_level=1)
dst_wordseq = fluid.layers.data(
name="dst_wordseq", shape=[1], dtype="int64", lod_level=1)
# Train program
avg_cost = None
cost, acc = network(src_wordseq, dst_wordseq, vocab_size, hid_size,
init_low_bound, init_high_bound)
avg_cost = fluid.layers.mean(x=cost)
    # Optimizer to minimize the loss
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=base_lr)
sgd_optimizer.minimize(avg_cost)
# Initialize executor
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if parallel:
train_exe = fluid.ParallelExecutor(
use_cuda=use_cuda, loss_name=avg_cost.name)
else:
train_exe = exe
total_time = 0.0
fetch_list = [avg_cost.name]
for pass_idx in six.moves.xrange(pass_num):
epoch_idx = pass_idx + 1
print "epoch_%d start" % epoch_idx
t0 = time.time()
i = 0
newest_ppl = 0
for data in train_reader():
i += 1
lod_src_wordseq = utils.to_lodtensor([dat[0] for dat in data],
place)
lod_dst_wordseq = utils.to_lodtensor([dat[1] for dat in data],
place)
ret_avg_cost = train_exe.run(feed={
"src_wordseq": lod_src_wordseq,
"dst_wordseq": lod_dst_wordseq
},
fetch_list=fetch_list)
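            # Perplexity is exp(cross-entropy); with ParallelExecutor the
            # fetched cost may hold one value per device, hence the mean.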
avg_ppl = np.exp(ret_avg_cost[0])
newest_ppl = np.mean(avg_ppl)
if i % 10 == 0:
print("step:%d ppl:%.3f" % (i, newest_ppl))
t1 = time.time()
total_time += t1 - t0
print("epoch:%d num_steps:%d time_cost(s):%f" %
(epoch_idx, i, total_time / epoch_idx))
if pass_idx == pass_num - 1 and args.enable_ce:
#Note: The following logs are special for CE monitoring.
#Other situations do not need to care about these logs.
            gpu_num = get_cards(args)
if gpu_num == 1:
print("kpis rsc15_pass_duration %s" %
(total_time / epoch_idx))
print("kpis rsc15_avg_ppl %s" % newest_ppl)
else:
print("kpis rsc15_pass_duration_card%s %s" % \
(gpu_num, total_time / epoch_idx))
print("kpis rsc15_avg_ppl_card%s %s" %
(gpu_num, newest_ppl))
save_dir = "%s/epoch_%d" % (model_dir, epoch_idx)
feed_var_names = ["src_wordseq", "dst_wordseq"]
fetch_vars = [avg_cost, acc]
fluid.io.save_inference_model(save_dir, feed_var_names, fetch_vars, exe)
print("model saved in %s" % save_dir)
print("finish training")
def get_cards(args):
if args.enable_ce:
cards = os.environ.get('CUDA_VISIBLE_DEVICES')
num = len(cards.split(","))
return num
else:
return args.num_devices
def train_net():
""" do training """
args = parse_args()
train_file = args.train_file
test_file = args.test_file
batch_size = 50
vocab, train_reader, test_reader = utils.prepare_data(
train_file, test_file,batch_size=batch_size * get_cards(args),\
buffer_size=1000, word_freq_threshold=0)
train(
train_reader=train_reader,
vocab=vocab,
network=network,
hid_size=100,
base_lr=0.01,
batch_size=batch_size,
pass_num=10,
use_cuda=True,
parallel=False,
model_dir="model_recall20",
init_low_bound=-0.1,
init_high_bound=0.1)
if __name__ == "__main__":
train_net()
# utils.py: data preparation helpers shared by train.py and infer.py.
import collections

import numpy as np
import six
import paddle
import paddle.fluid as fluid
def to_lodtensor(data, place):
""" convert to LODtensor """
seq_lens = [len(seq) for seq in data]
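    # LoD offsets mark where each sequence starts and ends in the flattened
    # batch, e.g. sequence lengths [3, 2] give lod = [0, 3, 5].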
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = fluid.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
def prepare_data(train_filename,
test_filename,
batch_size,
buffer_size=1000,
word_freq_threshold=0,
enable_ce=False):
""" prepare the English Pann Treebank (PTB) data """
print("start constuct word dict")
vocab = build_dict(word_freq_threshold, train_filename, test_filename)
print("construct word dict done\n")
if enable_ce:
train_reader = paddle.batch(
train(
train_filename, vocab, buffer_size, data_type=DataType.SEQ),
batch_size)
else:
train_reader = sort_batch(
paddle.reader.shuffle(
train(
train_filename, vocab, buffer_size, data_type=DataType.SEQ),
buf_size=buffer_size),
batch_size,
batch_size * 20)
test_reader = sort_batch(
test(
test_filename, vocab, buffer_size, data_type=DataType.SEQ),
batch_size,
batch_size * 20)
return vocab, train_reader, test_reader
def sort_batch(reader, batch_size, sort_group_size, drop_last=False):
"""
Create a batched reader.
:param reader: the data reader to read from.
:type reader: callable
:param batch_size: size of each mini-batch
:type batch_size: int
:param sort_group_size: size of partial sorted batch
:type sort_group_size: int
:param drop_last: drop the last batch, if the size of last batch is not equal to batch_size.
:type drop_last: bool
:return: the batched reader.
:rtype: callable
"""
def batch_reader():
r = reader()
b = []
for instance in r:
b.append(instance)
if len(b) == sort_group_size:
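                # Sort a window of sort_group_size instances by sequence
                # length so each emitted batch contains sequences of
                # similar length.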
sortl = sorted(b, key=lambda x: len(x[0]), reverse=True)
b = []
c = []
for sort_i in sortl:
c.append(sort_i)
if (len(c) == batch_size):
yield c
c = []
        if not drop_last and len(b) != 0:
            sortl = sorted(b, key=lambda x: len(x[0]), reverse=True)
            c = []
            for sort_i in sortl:
                c.append(sort_i)
                if (len(c) == batch_size):
                    yield c
                    c = []
            if len(c) != 0:  # also emit the final partial batch
                yield c
# Batch size check
batch_size = int(batch_size)
    if batch_size <= 0:
        raise ValueError("batch_size should be a positive integer value, "
                         "but got batch_size={}".format(batch_size))
return batch_reader
class DataType(object):
SEQ = 2
def word_count(input_file, word_freq=None):
"""
compute word count from corpus
"""
if word_freq is None:
word_freq = collections.defaultdict(int)
for l in input_file:
for w in l.strip().split():
word_freq[w] += 1
return word_freq
def build_dict(min_word_freq=50, train_filename="", test_filename=""):
"""
Build a word dictionary from the corpus, Keys of the dictionary are words,
and values are zero-based IDs of these words.
"""
with open(train_filename) as trainf:
with open(test_filename) as testf:
word_freq = word_count(testf, word_count(trainf))
word_freq = [x for x in six.iteritems(word_freq) if x[1] > min_word_freq]
word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*word_freq_sorted))
word_idx = dict(list(zip(words, six.moves.range(len(words)))))
return word_idx
def reader_creator(filename, word_idx, n, data_type):
def reader():
with open(filename) as f:
for l in f:
if DataType.SEQ == data_type:
l = l.strip().split()
l = [word_idx.get(w) for w in l]
src_seq = l[:len(l) - 1]
trg_seq = l[1:]
if n > 0 and len(src_seq) > n: continue
yield src_seq, trg_seq
else:
assert False, 'error data type'
return reader
def train(filename, word_idx, n, data_type=DataType.SEQ):
return reader_creator(filename, word_idx, n, data_type)
def test(filename, word_idx, n, data_type=DataType.SEQ):
return reader_creator(filename, word_idx, n, data_type)