未验证 提交 a76dc125 编写于 作者: G guru4elephant 提交者: GitHub

Merge pull request #1635 from wangguibao/text_classification_async

text_classification run with fluid.AsyncExecutor
# 文本分类
以下是本例的简要目录结构及说明:
```text
.
|-- README.md # README
|-- data_generator # IMDB数据集生成工具
| |-- IMDB.py # 在data_generator.py基础上扩展IMDB数据集处理逻辑
| |-- build_raw_data.py # IMDB数据预处理,其产出被splitfile.py读取。格式:word word ... | label
| |-- data_generator.py # 与AsyncExecutor配套的数据生成工具框架
| `-- splitfile.py # 将build_raw_data.py生成的文件切分,其产出被IMDB.py读取
|-- data_generator.sh # IMDB数据集生成工具入口
|-- data_reader.py # 预测脚本使用的数据读取工具
|-- infer.py # 预测脚本
`-- train.py # 训练脚本
```
## 简介
本目录包含用fluid.AsyncExecutor训练文本分类任务的脚本。网络模型定义沿用自父目录nets.py
## 训练
1. 运行命令 `sh data_generator.sh`,下载IMDB数据集,并转化成适合AsyncExecutor读取的训练数据
2. 运行命令 `python train.py bow` 开始训练模型。
```python
python train.py bow # bow指定网络结构,可替换成cnn, lstm, gru
```
3. (可选)想自定义网络结构,需在[nets.py](../nets.py)中自行添加,并设置[train.py](./train.py)中的相应参数。
```python
def train(train_reader, # 训练数据
word_dict, # 数据字典
network, # 模型配置
use_cuda, # 是否用GPU
parallel, # 是否并行
save_dirname, # 保存模型路径
lr=0.2, # 学习率大小
batch_size=128, # 每个batch的样本数
pass_num=30): # 训练的轮数
```
## 训练结果示例
```text
pass_id: 0 pass_time_cost 4.723438
pass_id: 1 pass_time_cost 3.867186
pass_id: 2 pass_time_cost 4.490111
pass_id: 3 pass_time_cost 4.573296
pass_id: 4 pass_time_cost 4.180547
pass_id: 5 pass_time_cost 4.214476
pass_id: 6 pass_time_cost 4.520387
pass_id: 7 pass_time_cost 4.149485
pass_id: 8 pass_time_cost 3.821354
pass_id: 9 pass_time_cost 5.136178
pass_id: 10 pass_time_cost 4.137318
pass_id: 11 pass_time_cost 3.943429
pass_id: 12 pass_time_cost 3.766478
pass_id: 13 pass_time_cost 4.235983
pass_id: 14 pass_time_cost 4.796462
pass_id: 15 pass_time_cost 4.668116
pass_id: 16 pass_time_cost 4.373798
pass_id: 17 pass_time_cost 4.298131
pass_id: 18 pass_time_cost 4.260021
pass_id: 19 pass_time_cost 4.244411
pass_id: 20 pass_time_cost 3.705138
pass_id: 21 pass_time_cost 3.728070
pass_id: 22 pass_time_cost 3.817919
pass_id: 23 pass_time_cost 4.698598
pass_id: 24 pass_time_cost 4.859262
pass_id: 25 pass_time_cost 5.725732
pass_id: 26 pass_time_cost 5.102599
pass_id: 27 pass_time_cost 3.876582
pass_id: 28 pass_time_cost 4.762538
pass_id: 29 pass_time_cost 3.797759
```
与fluid.Executor不同,AsyncExecutor在每个pass结束不会将accuracy打印出来。为了观察训练过程,可以将fluid.AsyncExecutor.run()方法的Debug参数设为True,这样每个pass结束会把参数指定的fetch variable打印出来:
```
async_executor.run(
main_program,
dataset,
filelist,
thread_num,
[acc],
debug=True)
```
## 预测
1. 运行命令 `python infer.py bow_model`, 开始预测。
```python
python infer.py bow_model # bow_model指定需要导入的模型
```
## 预测结果示例
```text
model_path: bow_model/epoch0.model, avg_acc: 0.882600
model_path: bow_model/epoch1.model, avg_acc: 0.887920
model_path: bow_model/epoch2.model, avg_acc: 0.886920
model_path: bow_model/epoch3.model, avg_acc: 0.884720
model_path: bow_model/epoch4.model, avg_acc: 0.879760
model_path: bow_model/epoch5.model, avg_acc: 0.876920
model_path: bow_model/epoch6.model, avg_acc: 0.874160
model_path: bow_model/epoch7.model, avg_acc: 0.872000
model_path: bow_model/epoch8.model, avg_acc: 0.870360
model_path: bow_model/epoch9.model, avg_acc: 0.868480
model_path: bow_model/epoch10.model, avg_acc: 0.867240
model_path: bow_model/epoch11.model, avg_acc: 0.866200
model_path: bow_model/epoch12.model, avg_acc: 0.865560
model_path: bow_model/epoch13.model, avg_acc: 0.865160
model_path: bow_model/epoch14.model, avg_acc: 0.864480
model_path: bow_model/epoch15.model, avg_acc: 0.864240
model_path: bow_model/epoch16.model, avg_acc: 0.863800
model_path: bow_model/epoch17.model, avg_acc: 0.863520
model_path: bow_model/epoch18.model, avg_acc: 0.862760
model_path: bow_model/epoch19.model, avg_acc: 0.862680
model_path: bow_model/epoch20.model, avg_acc: 0.862240
model_path: bow_model/epoch21.model, avg_acc: 0.862280
model_path: bow_model/epoch22.model, avg_acc: 0.862080
model_path: bow_model/epoch23.model, avg_acc: 0.861560
model_path: bow_model/epoch24.model, avg_acc: 0.861280
model_path: bow_model/epoch25.model, avg_acc: 0.861160
model_path: bow_model/epoch26.model, avg_acc: 0.861080
model_path: bow_model/epoch27.model, avg_acc: 0.860920
model_path: bow_model/epoch28.model, avg_acc: 0.860800
model_path: bow_model/epoch29.model, avg_acc: 0.860760
```
注:过拟合导致acc持续下降,请忽略
#!/usr/bin/env bash
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pushd .
cd ./data_generator
# wget "http://ai.stanford.edu/%7Eamaas/data/sentiment/aclImdb_v1.tar.gz"
if [ ! -f aclImdb_v1.tar.gz ]; then
wget "http://10.64.74.104:8080/paddle/dataset/imdb/aclImdb_v1.tar.gz"
fi
tar zxvf aclImdb_v1.tar.gz
mkdir train_data
python build_raw_data.py train | python splitfile.py 12 train_data
mkdir test_data
python build_raw_data.py test | python splitfile.py 12 test_data
/opt/python27/bin/python IMDB.py train_data
/opt/python27/bin/python IMDB.py test_data
mv ./output_dataset/train_data ../
mv ./output_dataset/test_data ../
cp aclImdb/imdb.vocab ../
rm -rf ./output_dataset
rm -rf train_data
rm -rf test_data
rm -rf aclImdb
popd
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import os, sys
sys.path.append(os.path.abspath(os.path.join('..')))
from data_generator import MultiSlotDataGenerator
class IMDbDataGenerator(MultiSlotDataGenerator):
def load_resource(self, dictfile):
self._vocab = {}
wid = 0
with open(dictfile) as f:
for line in f:
self._vocab[line.strip()] = wid
wid += 1
self._unk_id = len(self._vocab)
self._pattern = re.compile(r'(;|,|\.|\?|!|\s|\(|\))')
def process(self, line):
send = '|'.join(line.split('|')[:-1]).lower().replace("<br />",
" ").strip()
label = [int(line.split('|')[-1])]
words = [x for x in self._pattern.split(send) if x and x != " "]
feas = [
self._vocab[x] if x in self._vocab else self._unk_id for x in words
]
return ("words", feas), ("label", label)
imdb = IMDbDataGenerator()
imdb.load_resource("aclImdb/imdb.vocab")
# data from files
file_names = os.listdir(sys.argv[1])
filelist = []
for i in range(0, len(file_names)):
filelist.append(os.path.join(sys.argv[1], file_names[i]))
line_limit = 2500
process_num = 24
imdb.run_from_files(
filelist=filelist,
line_limit=line_limit,
process_num=process_num,
output_dir=('output_dataset/%s' % (sys.argv[1])))
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Split file into parts
"""
import sys
import os
block = int(sys.argv[1])
datadir = sys.argv[2]
file_list = []
for i in range(block):
file_list.append(open(datadir + "/part-" + str(i), "w"))
id_ = 0
for line in sys.stdin:
file_list[id_ % block].write(line)
id_ += 1
for f in file_list:
f.close()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import paddle
def parse_fields(fields):
words_width = int(fields[0])
words = fields[1:1 + words_width]
label = fields[-1]
return words, label
def imdb_data_feed_reader(data_dir, batch_size, buf_size):
"""
Data feed reader for IMDB dataset.
This data set has been converted from original format to a format suitable
for AsyncExecutor
See data.proto for data format
"""
def reader():
for file in os.listdir(data_dir):
if file.endswith('.proto'):
continue
with open(os.path.join(data_dir, file), 'r') as f:
for line in f:
fields = line.split(' ')
words, label = parse_fields(fields)
yield words, label
test_reader = paddle.batch(
paddle.reader.shuffle(
reader, buf_size=buf_size), batch_size=batch_size)
return test_reader
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import unittest
import contextlib
import numpy as np
import paddle
import paddle.fluid as fluid
import data_reader
def infer(test_reader, use_cuda, model_path=None):
"""
inference function
"""
if model_path is None:
print(str(model_path) + " cannot be found")
return
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.core.Scope()
with fluid.scope_guard(inference_scope):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(model_path, exe)
total_acc = 0.0
total_count = 0
for data in test_reader():
acc = exe.run(inference_program,
feed=utils.data2tensor(data, place),
fetch_list=fetch_targets,
return_numpy=True)
total_acc += acc[0] * len(data)
total_count += len(data)
avg_acc = total_acc / total_count
print("model_path: %s, avg_acc: %f" % (model_path, avg_acc))
if __name__ == "__main__":
if __package__ is None:
from os import sys, path
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import utils
batch_size = 128
model_path = sys.argv[1]
test_data_dirname = 'test_data'
if len(sys.argv) == 3:
test_data_dirname = sys.argv[2]
test_reader = data_reader.imdb_data_feed_reader(
'test_data', batch_size, buf_size=500000)
models = os.listdir(model_path)
for i in range(0, len(models)):
epoch_path = "epoch" + str(i) + ".model"
epoch_path = os.path.join(model_path, epoch_path)
infer(test_reader, use_cuda=False, model_path=epoch_path)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import multiprocessing
import paddle
import paddle.fluid as fluid
def train(network, dict_dim, lr, save_dirname, training_data_dirname, pass_num,
thread_num, batch_size):
file_names = os.listdir(training_data_dirname)
filelist = []
for i in range(0, len(file_names)):
if file_names[i] == 'data_feed.proto':
continue
filelist.append(os.path.join(training_data_dirname, file_names[i]))
dataset = fluid.DataFeedDesc(
os.path.join(training_data_dirname, 'data_feed.proto'))
dataset.set_batch_size(
batch_size) # datafeed should be assigned a batch size
dataset.set_use_slots(['words', 'label'])
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
avg_cost, acc, prediction = network(data, label, dict_dim)
optimizer = fluid.optimizer.Adagrad(learning_rate=lr)
opt_ops, weight_and_grad = optimizer.minimize(avg_cost)
startup_program = fluid.default_startup_program()
main_program = fluid.default_main_program()
place = fluid.CPUPlace()
executor = fluid.Executor(place)
executor.run(startup_program)
async_executor = fluid.AsyncExecutor(place)
for i in range(pass_num):
pass_start = time.time()
async_executor.run(main_program,
dataset,
filelist,
thread_num, [acc],
debug=False)
print('pass_id: %u pass_time_cost %f' % (i, time.time() - pass_start))
fluid.io.save_inference_model('%s/epoch%d.model' % (save_dirname, i),
[data.name, label.name], [acc], executor)
if __name__ == "__main__":
if __package__ is None:
from os import sys, path
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from nets import bow_net, cnn_net, lstm_net, gru_net
from utils import load_vocab
batch_size = 4
lr = 0.002
pass_num = 30
save_dirname = ""
thread_num = multiprocessing.cpu_count()
if sys.argv[1] == "bow":
network = bow_net
batch_size = 128
save_dirname = "bow_model"
elif sys.argv[1] == "cnn":
network = cnn_net
lr = 0.01
save_dirname = "cnn_model"
elif sys.argv[1] == "lstm":
network = lstm_net
lr = 0.05
save_dirname = "lstm_model"
elif sys.argv[1] == "gru":
network = gru_net
batch_size = 128
lr = 0.05
save_dirname = "gru_model"
training_data_dirname = 'train_data/'
if len(sys.argv) == 3:
training_data_dirname = sys.argv[2]
if len(sys.argv) == 4:
if thread_num >= int(sys.argv[3]):
thread_num = int(sys.argv[3])
vocab = load_vocab('imdb.vocab')
dict_dim = len(vocab)
train(network, dict_dim, lr, save_dirname, training_data_dirname, pass_num,
thread_num, batch_size)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册