# Text Classification
Below is a brief overview of this example's directory structure:
```text
.
|-- README.md              # This file
|-- data_generator         # IMDB dataset generation tools
|   |-- IMDB.py            # Extends data_generator.py with IMDB-specific processing logic
|   |-- build_raw_data.py  # IMDB preprocessing; its output is read by splitfile.py. Format: word word ... | label
|   |-- data_generator.py  # Data generation framework used with AsyncExecutor
|   `-- splitfile.py       # Splits the files produced by build_raw_data.py; the parts are read by IMDB.py
|-- data_generator.sh      # Entry point for the IMDB dataset generation tools
|-- data_reader.py         # Data reading utilities used by the inference script
|-- infer.py               # Inference script
`-- train.py               # Training script
```
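For reference, build_raw_data.py emits one sample per line in the `word word ... | label` format noted above, and IMDB.py converts each sample into the MultiSlotDataFeed format that AsyncExecutor consumes (`ids_num id1 id2 ... 1 label`, as documented in data_generator.py). The words and ids below are invented purely for illustration:
```text
# produced by build_raw_data.py (tokens, then the label after "|")
this movie is great | 1
# produced by IMDB.py (4 word ids, then a one-element label slot)
4 12 356 9 874 1 1
```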
## Introduction
This directory contains scripts for training a text classification model with fluid.AsyncExecutor. The network definitions are reused from nets.py in the parent directory.
## Training
1. Run `sh data_generator.sh` to download the IMDB dataset and convert it into training data that AsyncExecutor can read.
2. Run `python train.py bow` to start training the model:
```bash
python train.py bow   # "bow" selects the network structure; replace with cnn, lstm, or gru
```
3. (Optional) To use a custom network structure, add it to [nets.py](../nets.py) and set the corresponding parameters in [train.py](./train.py), whose train() function is declared as follows (a sketch of a custom network is given after this list):
```python
def train(network,                 # network configuration (a function from nets.py)
          dict_dim,                # vocabulary size
          lr,                      # learning rate
          save_dirname,            # directory in which models are saved
          training_data_dirname,   # directory containing the training data
          pass_num,                # number of training passes
          thread_num,              # number of AsyncExecutor threads
          batch_size):             # number of samples per batch
```
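Any network added to nets.py has to follow the calling convention used by train.py: it is called as `network(data, label, dict_dim)` and must return `avg_cost`, `acc`, and `prediction`. The snippet below is a minimal, hypothetical sketch of such a function written against the fluid 1.x layers API; the name `my_custom_net` and the layer sizes are invented for illustration and do not reproduce the actual bow_net in nets.py:
```python
import paddle.fluid as fluid

def my_custom_net(data, label, dict_dim, emb_dim=128, hid_dim=128, class_dim=2):
    """Hypothetical custom network; must return (avg_cost, acc, prediction)."""
    # embed each word id, then sum-pool the sequence into a fixed-size vector
    emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
    pooled = fluid.layers.sequence_pool(input=emb, pool_type='sum')
    hidden = fluid.layers.fc(input=pooled, size=hid_dim, act='tanh')
    prediction = fluid.layers.fc(input=hidden, size=class_dim, act='softmax')
    # cross-entropy loss averaged over the batch, plus batch accuracy
    avg_cost = fluid.layers.mean(
        fluid.layers.cross_entropy(input=prediction, label=label))
    acc = fluid.layers.accuracy(input=prediction, label=label)
    return avg_cost, acc, prediction
```
After adding the function, register it in the network-selection branch of train.py's `__main__` block, next to bow, cnn, lstm, and gru.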
## Sample Training Output
```text
pass_id: 0 pass_time_cost 4.723438
pass_id: 1 pass_time_cost 3.867186
pass_id: 2 pass_time_cost 4.490111
pass_id: 3 pass_time_cost 4.573296
pass_id: 4 pass_time_cost 4.180547
pass_id: 5 pass_time_cost 4.214476
pass_id: 6 pass_time_cost 4.520387
pass_id: 7 pass_time_cost 4.149485
pass_id: 8 pass_time_cost 3.821354
pass_id: 9 pass_time_cost 5.136178
pass_id: 10 pass_time_cost 4.137318
pass_id: 11 pass_time_cost 3.943429
pass_id: 12 pass_time_cost 3.766478
pass_id: 13 pass_time_cost 4.235983
pass_id: 14 pass_time_cost 4.796462
pass_id: 15 pass_time_cost 4.668116
pass_id: 16 pass_time_cost 4.373798
pass_id: 17 pass_time_cost 4.298131
pass_id: 18 pass_time_cost 4.260021
pass_id: 19 pass_time_cost 4.244411
pass_id: 20 pass_time_cost 3.705138
pass_id: 21 pass_time_cost 3.728070
pass_id: 22 pass_time_cost 3.817919
pass_id: 23 pass_time_cost 4.698598
pass_id: 24 pass_time_cost 4.859262
pass_id: 25 pass_time_cost 5.725732
pass_id: 26 pass_time_cost 5.102599
pass_id: 27 pass_time_cost 3.876582
pass_id: 28 pass_time_cost 4.762538
pass_id: 29 pass_time_cost 3.797759
```
Unlike fluid.Executor, AsyncExecutor does not print accuracy at the end of each pass. To observe training progress, set the `debug` parameter of fluid.AsyncExecutor.run() to True; the fetch variables passed in will then be printed at the end of each pass:
```python
async_executor.run(
main_program,
dataset,
filelist,
thread_num,
[acc],
debug=True)
```
## Inference
1. Run `python infer.py bow_model` to start inference:
```bash
python infer.py bow_model   # bow_model is the directory containing the models to load
```
## Sample Inference Output
```text
model_path: bow_model/epoch0.model, avg_acc: 0.882600
model_path: bow_model/epoch1.model, avg_acc: 0.887920
model_path: bow_model/epoch2.model, avg_acc: 0.886920
model_path: bow_model/epoch3.model, avg_acc: 0.884720
model_path: bow_model/epoch4.model, avg_acc: 0.879760
model_path: bow_model/epoch5.model, avg_acc: 0.876920
model_path: bow_model/epoch6.model, avg_acc: 0.874160
model_path: bow_model/epoch7.model, avg_acc: 0.872000
model_path: bow_model/epoch8.model, avg_acc: 0.870360
model_path: bow_model/epoch9.model, avg_acc: 0.868480
model_path: bow_model/epoch10.model, avg_acc: 0.867240
model_path: bow_model/epoch11.model, avg_acc: 0.866200
model_path: bow_model/epoch12.model, avg_acc: 0.865560
model_path: bow_model/epoch13.model, avg_acc: 0.865160
model_path: bow_model/epoch14.model, avg_acc: 0.864480
model_path: bow_model/epoch15.model, avg_acc: 0.864240
model_path: bow_model/epoch16.model, avg_acc: 0.863800
model_path: bow_model/epoch17.model, avg_acc: 0.863520
model_path: bow_model/epoch18.model, avg_acc: 0.862760
model_path: bow_model/epoch19.model, avg_acc: 0.862680
model_path: bow_model/epoch20.model, avg_acc: 0.862240
model_path: bow_model/epoch21.model, avg_acc: 0.862280
model_path: bow_model/epoch22.model, avg_acc: 0.862080
model_path: bow_model/epoch23.model, avg_acc: 0.861560
model_path: bow_model/epoch24.model, avg_acc: 0.861280
model_path: bow_model/epoch25.model, avg_acc: 0.861160
model_path: bow_model/epoch26.model, avg_acc: 0.861080
model_path: bow_model/epoch27.model, avg_acc: 0.860920
model_path: bow_model/epoch28.model, avg_acc: 0.860800
model_path: bow_model/epoch29.model, avg_acc: 0.860760
```
Note: The steady drop in accuracy is caused by overfitting and can be ignored.
#!/usr/bin/env bash
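# ---- data_generator.sh: entry point for generating the IMDB dataset ----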
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pushd .
cd ./data_generator
# download the IMDB dataset from the official mirror if not already present
if [ ! -f aclImdb_v1.tar.gz ]; then
    wget "http://ai.stanford.edu/%7Eamaas/data/sentiment/aclImdb_v1.tar.gz"
fi
tar zxvf aclImdb_v1.tar.gz
mkdir train_data
python build_raw_data.py train | python splitfile.py 12 train_data
mkdir test_data
python build_raw_data.py test | python splitfile.py 12 test_data
# convert the raw data into the AsyncExecutor (MultiSlotDataFeed) format
python IMDB.py train_data
python IMDB.py test_data
mv ./output_dataset/train_data ../
mv ./output_dataset/test_data ../
cp aclImdb/imdb.vocab ../
rm -rf ./output_dataset
rm -rf train_data
rm -rf test_data
rm -rf aclImdb
popd
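# ---- data_generator/IMDB.py: IMDB-specific data generator (extends data_generator.py) ----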
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import os, sys
sys.path.append(os.path.abspath(os.path.join('..')))
from data_generator import MultiSlotDataGenerator
class IMDbDataGenerator(MultiSlotDataGenerator):
    def load_resource(self, dictfile):
        # build the word -> id vocabulary, one word per line
        self._vocab = {}
        wid = 0
        with open(dictfile) as f:
            for line in f:
                self._vocab[line.strip()] = wid
                wid += 1
        # reserve the id right after the vocabulary for unknown words
        self._unk_id = len(self._vocab)
        # split on punctuation and whitespace
        self._pattern = re.compile(r'(;|,|\.|\?|!|\s|\(|\))')
    def process(self, line):
        # everything before the last '|' is the review text; the last
        # field is the label
        sent = '|'.join(line.split('|')[:-1]).lower().replace("<br />",
                                                              " ").strip()
        label = [int(line.split('|')[-1])]
        # tokenize, dropping empty tokens and bare spaces
        words = [x for x in self._pattern.split(sent) if x and x != " "]
        # map words to ids, falling back to the unknown-word id
        feas = [
            self._vocab[x] if x in self._vocab else self._unk_id for x in words
        ]
        return ("words", feas), ("label", label)
imdb = IMDbDataGenerator()
imdb.load_resource("aclImdb/imdb.vocab")
# data from files
file_names = os.listdir(sys.argv[1])
filelist = []
for i in range(0, len(file_names)):
filelist.append(os.path.join(sys.argv[1], file_names[i]))
line_limit = 2500
process_num = 24
imdb.run_from_files(
filelist=filelist,
line_limit=line_limit,
process_num=process_num,
output_dir=('output_dataset/%s' % (sys.argv[1])))
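# ---- data_generator/data_generator.py: data generation framework for AsyncExecutor ----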
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import multiprocessing
__all__ = ['MultiSlotDataGenerator']
class DataGenerator(object):
def __init__(self):
self._proto_info = None
def _set_filelist(self, filelist):
if not isinstance(filelist, list) and not isinstance(filelist, tuple):
raise ValueError("filelist%s must be in list or tuple type" %
type(filelist))
if not filelist:
raise ValueError("filelist can not be empty")
self._filelist = filelist
def _set_process_num(self, process_num):
if not isinstance(process_num, int):
raise ValueError("process_num%s must be in int type" %
type(process_num))
if process_num < 1:
raise ValueError("process_num can not less than 1")
self._process_num = process_num
def _set_line_limit(self, line_limit):
if not isinstance(line_limit, int):
raise ValueError("line_limit%s must be in int type" %
type(line_limit))
if line_limit < 1:
raise ValueError("line_limit can not less than 1")
self._line_limit = line_limit
def _set_output_dir(self, output_dir):
if not isinstance(output_dir, str):
raise ValueError("output_dir%s must be in str type" %
type(output_dir))
if not output_dir:
raise ValueError("output_dir can not be empty")
self._output_dir = output_dir
def _set_output_prefix(self, output_prefix):
if not isinstance(output_prefix, str):
raise ValueError("output_prefix%s must be in str type" %
type(output_prefix))
self._output_prefix = output_prefix
def _set_output_fill_digit(self, output_fill_digit):
if not isinstance(output_fill_digit, int):
raise ValueError("output_fill_digit%s must be in int type" %
type(output_fill_digit))
if output_fill_digit < 1:
raise ValueError("output_fill_digit can not less than 1")
self._output_fill_digit = output_fill_digit
def _set_proto_filename(self, proto_filename):
if not isinstance(proto_filename, str):
raise ValueError("proto_filename%s must be in str type" %
type(proto_filename))
if not proto_filename:
raise ValueError("proto_filename can not be empty")
self._proto_filename = proto_filename
def _print_info(self):
'''
Print the configuration information
(Called only in the run_from_stdin function).
'''
sys.stderr.write("=" * 16 + " config " + "=" * 16 + "\n")
sys.stderr.write(" filelist size: %d\n" % len(self._filelist))
sys.stderr.write(" process num: %d\n" % self._process_num)
sys.stderr.write(" line limit: %d\n" % self._line_limit)
sys.stderr.write(" output dir: %s\n" % self._output_dir)
sys.stderr.write(" output prefix: %s\n" % self._output_prefix)
sys.stderr.write(" output fill digit: %d\n" % self._output_fill_digit)
sys.stderr.write(" proto filename: %s\n" % self._proto_filename)
sys.stderr.write("==== This may take a few minutes... ====\n")
def _get_output_filename(self, output_index, lock=None):
'''
This function is used to get the name of the output file and
update output_index.
Args:
output_index(manager.Value(i)): the index of output file.
            lock(manager.Lock): The lock for process safety.
Return:
Return the name(string) of output file.
'''
if lock is not None: lock.acquire()
file_index = output_index.value
output_index.value += 1
if lock is not None: lock.release()
filename = os.path.join(self._output_dir, self._output_prefix) \
+ str(file_index).zfill(self._output_fill_digit)
sys.stderr.write("[%d] write data to file: %s\n" %
(os.getpid(), filename))
return filename
def run_from_stdin(self,
is_local=True,
hadoop_host=None,
hadoop_ugi=None,
proto_path=None,
proto_filename="data_feed.proto"):
'''
This function reads the data row from stdin, parses it with the
process function, and further parses the return value of the
process function with the _gen_str function. The parsed data will
        be written to stdout and the corresponding protofile will be
        generated. If is_local is set to False, the protofile will be
uploaded to hadoop.
Args:
is_local(bool): Whether to execute locally. If it is False, the
protofile will be uploaded to hadoop. The
default value is True.
hadoop_host(str): The host name of the hadoop. It should be
in this format: "hdfs://${HOST}:${PORT}".
hadoop_ugi(str): The ugi of the hadoop. It should be in this
format: "${USERNAME},${PASSWORD}".
proto_path(str): The hadoop path you want to upload the
protofile to.
proto_filename(str): The name of protofile. The default value
is "data_feed.proto". It is not
recommended to modify it.
'''
if is_local:
print \
'''\033[1;34m=======================================================
            Note that the Python version in Hadoop may be
            inconsistent with the local version. Please check the
Python version of Hadoop to ensure that it is >= 2.7.
=======================================================\033[0m'''
else:
if hadoop_ugi is None or \
hadoop_host is None or \
proto_path is None:
raise ValueError(
"pls set hadoop_ugi, hadoop_host, and proto_path")
self._set_proto_filename(proto_filename)
for line in sys.stdin:
user_parsed_line = self.process(line)
sys.stdout.write(self._gen_str(user_parsed_line))
if self._proto_info is not None:
            # some tasks may not have received any input files
with open(self._proto_filename, "w") as f:
f.write(self._get_proto_desc(self._proto_info))
if is_local == False:
cmd = "$HADOOP_HOME/bin/hadoop fs" \
+ " -Dhadoop.job.ugi=" + hadoop_ugi \
+ " -Dfs.default.name=" + hadoop_host \
+ " -put " + self._proto_filename + " " + proto_path
os.system(cmd)
def run_from_files(self,
filelist,
line_limit,
process_num=1,
output_dir="./output_dataset",
output_prefix="part-",
output_fill_digit=8,
proto_filename="data_feed.proto"):
'''
        This function runs process_num processes to process the files in
        the filelist. It creates the output folder (output_dir) in the
        current directory and writes the processed data into it, at most
        line_limit lines per file; each output filename starts with
        output_prefix and ends with a suffix of output_fill_digit digits.
        A proto file named proto_filename is generated at the same time.
Args:
filelist(list or tuple): Files that need to be processed.
line_limit(int): Maximum number of data stored per file.
process_num(int): Number of processes running simultaneously.
output_dir(str): The name of the folder where the output
data file is stored.
output_prefix(str): The prefix of output data file.
output_fill_digit(int): The number of suffix numbers of the
output data file.
proto_filename(str): The name of protofile.
'''
self._set_filelist(filelist)
self._set_line_limit(line_limit)
self._set_process_num(min(process_num, len(filelist)))
self._set_output_dir(output_dir)
self._set_output_prefix(output_prefix)
self._set_output_fill_digit(output_fill_digit)
self._set_proto_filename(proto_filename)
self._print_info()
if not os.path.exists(self._output_dir):
os.makedirs(self._output_dir)
elif not os.path.isdir(self._output_dir):
raise ValueError("%s is not a directory" % self._output_dir)
processes = multiprocessing.Pool()
manager = multiprocessing.Manager()
output_index = manager.Value('i', 0)
file_queue = manager.Queue()
lock = manager.Lock()
remaining_queue = manager.Queue()
for file in self._filelist:
file_queue.put(file)
info_result = []
for i in range(self._process_num):
info_result.append(processes.apply_async(subprocess_wrapper, \
(self, file_queue, remaining_queue, output_index, lock, )))
processes.close()
processes.join()
infos = [
result.get() for result in info_result if result.get() is not None
]
proto_info = self._combine_infos(infos)
with open(os.path.join(self._output_dir, self._proto_filename),
"w") as f:
f.write(self._get_proto_desc(proto_info))
while not remaining_queue.empty():
with open(self._get_output_filename(output_index), "w") as f:
for i in range(min(self._line_limit, remaining_queue.qsize())):
f.write(remaining_queue.get(False))
def _subprocess(self, file_queue, remaining_queue, output_index, lock):
'''
        This function is called by multiple worker processes. It keeps
        fetching files from file_queue and processes them row by row,
        using the user-defined process() function and the _gen_str()
        function defined by the concrete subclass, and writes the
        processed data to output files of at most self._line_limit lines
        each. Once file_queue is drained, any buffered lines that do not
        fill a complete output file are stored in remaining_queue.
        Args:
            file_queue(manager.Queue): The queue containing all the file
                                       names to be processed.
            remaining_queue(manager.Queue): The queue holding leftover
                                            lines that did not fill a
                                            complete output file.
            output_index(manager.Value(i)): The index(suffix) of the
                                            output file.
            lock(manager.Lock): The lock for process safety.
Returns:
Return a proto_info which can be translated into a proto string.
'''
buffer = []
while not file_queue.empty():
try:
filename = file_queue.get(False)
except: # file_queue empty
break
with open(filename, 'r') as f:
for line in f:
buffer.append(self._gen_str(self.process(line)))
if len(buffer) == self._line_limit:
with open(
self._get_output_filename(output_index, lock),
"w") as wf:
for x in buffer:
wf.write(x)
buffer = []
if buffer:
for x in buffer:
remaining_queue.put(x)
return self._proto_info
def _gen_str(self, line):
'''
        Further processes the output of the user-overridden process()
        function into data that the datafeed can read directly, and
        updates the proto_info information.
Args:
line(str): the output of the process() function rewritten by user.
Returns:
Return a string data that can be read directly by the datafeed.
'''
raise NotImplementedError(
"pls use MultiSlotDataGenerator or PairWiseDataGenerator")
def _combine_infos(self, infos):
'''
This function is used to merge proto_info information from different
processes. In general, the proto_info of each process is consistent.
Args:
infos(list): the list of proto_infos from different processes.
Returns:
Return a unified proto_info.
'''
raise NotImplementedError(
"pls use MultiSlotDataGenerator or PairWiseDataGenerator")
def _get_proto_desc(self, proto_info):
'''
This function outputs the string of the proto file(can be directly
written to the file) according to the proto_info information.
Args:
proto_info: The proto information used to generate the proto
string. The type of the variable will be determined
by the subclass. In the MultiSlotDataGenerator,
proto_info variable is a list of tuple.
Returns:
Returns a string of the proto file.
'''
raise NotImplementedError(
"pls use MultiSlotDataGenerator or PairWiseDataGenerator")
def process(self, line):
'''
This function needs to be overridden by the user to process the
original data row into a list or tuple.
Args:
line(str): the original data row
Returns:
Returns the data processed by the user.
The data format is list or tuple:
[(name, [feasign, ...]), ...]
or ((name, [feasign, ...]), ...)
For example:
[("words", [1926, 08, 17]), ("label", [1])]
or (("words", [1926, 08, 17]), ("label", [1]))
Note:
The type of feasigns must be in int or float. Once the float
element appears in the feasign, the type of that slot will be
processed into a float.
'''
raise NotImplementedError(
"pls rewrite this function to return a list or tuple: " +
"[(name, [feasign, ...]), ...] or ((name, [feasign, ...]), ...)")
def subprocess_wrapper(instance, file_queue, remaining_queue, output_index,
lock):
'''
In order to use the class function as a process, you need to wrap it.
'''
return instance._subprocess(file_queue, remaining_queue, output_index, lock)
class MultiSlotDataGenerator(DataGenerator):
def _combine_infos(self, infos):
'''
This function is used to merge proto_info information from different
processes. In general, the proto_info of each process is consistent.
The type of input infos is list, and the type of element of infos is
tuple. The format of element of infos will be (name, type).
Args:
infos(list): the list of proto_infos from different processes.
Returns:
Return a unified proto_info.
Note:
            This function is only called by the run_from_files function, so
            when using the run_from_stdin function (usually used for hadoop),
            the output of the user-overridden process function must not give
            the same field both float and int values.
'''
proto_info = infos[0]
for info in infos:
for index, slot in enumerate(info):
name, type = slot
if name != proto_info[index][0]:
raise ValueError(
"combine infos error, pls contact the maintainer of this code~"
)
if type == "float" and proto_info[index][1] == "uint64":
proto_info[index] = (name, type)
return proto_info
def _get_proto_desc(self, proto_info):
'''
Generate a string of proto file based on the proto_info information.
The proto_info will be a list of tuples:
>>> [(Name, Type), ...]
The string of proto file will be in this format:
>>> name: "MultiSlotDataFeed"
>>> batch_size: 32
>>> multi_slot_desc {
>>> slots {
>>> name: Name
>>> type: Type
>>> is_dense: false
>>> is_used: false
>>> }
>>> }
Args:
proto_info(list): The proto information used to generate the
proto string.
Returns:
Returns a string of the proto file.
'''
proto_str = "name: \"MultiSlotDataFeed\"\n" \
+ "batch_size: 32\nmulti_slot_desc {\n"
for elem in proto_info:
proto_str += " slots {\n" \
+ " name: \"%s\"\n" % elem[0]\
+ " type: \"%s\"\n" % elem[1]\
+ " is_dense: false\n" \
+ " is_used: false\n" \
+ " }\n"
proto_str += "}"
return proto_str
def _gen_str(self, line):
'''
        Further processes the output of the user-overridden process()
        function into data that the MultiSlotDataFeed can read directly,
        and updates the proto_info information.
        The input line will be in this format:
        >>> [(name, [feasign, ...]), ...]
        >>> or ((name, [feasign, ...]), ...)
        The output will be in this format:
        >>> [ids_num id1 id2 ...] ...
        The proto_info will be in this format:
        >>> [(name, type), ...]
        For example, if the input is like this:
        >>> [("words", [1926, 8, 17]), ("label", [1])]
        >>> or (("words", [1926, 8, 17]), ("label", [1]))
        the output will be:
        >>> 3 1926 8 17 1 1
        the proto_info will be:
        >>> [("words", "uint64"), ("label", "uint64")]
Args:
line(str): the output of the process() function rewritten by user.
Returns:
Return a string data that can be read directly by the MultiSlotDataFeed.
'''
if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError(
"the output of process() must be in list or tuple type")
output = ""
if self._proto_info is None:
self._proto_info = []
for item in line:
name, elements = item
if not isinstance(name, str):
raise ValueError("name%s must be in str type" % type(name))
if not isinstance(elements, list):
raise ValueError("elements%s must be in list type" %
type(elements))
if not elements:
raise ValueError(
"the elements of each field can not be empty, you need padding it in process()."
)
self._proto_info.append((name, "uint64"))
if output:
output += " "
output += str(len(elements))
for elem in elements:
if isinstance(elem, float):
self._proto_info[-1] = (name, "float")
elif not isinstance(elem, int) and not isinstance(elem,
long):
raise ValueError(
"the type of element%s must be in int or float" %
type(elem))
output += " " + str(elem)
else:
            if len(line) != len(self._proto_info):
                raise ValueError(
                    "the field sets of two given lines are inconsistent.")
for index, item in enumerate(line):
name, elements = item
if not isinstance(name, str):
raise ValueError("name%s must be in str type" % type(name))
if not isinstance(elements, list):
raise ValueError("elements%s must be in list type" %
type(elements))
if not elements:
raise ValueError(
"the elements of each field can not be empty, you need padding it in process()."
)
                if name != self._proto_info[index][0]:
                    raise ValueError(
                        "the field names of two given lines do not match: "
                        "require<%s>, got<%s>." %
                        (self._proto_info[index][0], name))
if output:
output += " "
output += str(len(elements))
for elem in elements:
if self._proto_info[index][1] != "float":
if isinstance(elem, float):
self._proto_info[index] = (name, "float")
elif not isinstance(elem, int) and not isinstance(elem,
long):
raise ValueError(
"the type of element%s must be in int or float"
% type(elem))
output += " " + str(elem)
return output + "\n"
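# ---- data_generator/splitfile.py: splits raw data into parts for IMDB.py ----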
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Split file into parts
"""
import sys
import os
block = int(sys.argv[1])
datadir = sys.argv[2]
file_list = []
for i in range(block):
file_list.append(open(datadir + "/part-" + str(i), "w"))
id_ = 0
for line in sys.stdin:
file_list[id_ % block].write(line)
id_ += 1
for f in file_list:
f.close()
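# ---- data_reader.py: data reading utilities used by the inference script ----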
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import paddle
def parse_fields(fields):
    # each sample line is "ids_num id1 id2 ... 1 label": the first field is
    # the number of word ids, the last field is the label
    words_width = int(fields[0])
    words = fields[1:1 + words_width]
    label = fields[-1]
    return words, label
def imdb_data_feed_reader(data_dir, batch_size, buf_size):
"""
    Data feed reader for the IMDB dataset.
    The dataset has been converted from its original format into a format
    suitable for AsyncExecutor; see data_feed.proto for the data format.
"""
def reader():
for file in os.listdir(data_dir):
if file.endswith('.proto'):
continue
with open(os.path.join(data_dir, file), 'r') as f:
for line in f:
fields = line.split(' ')
words, label = parse_fields(fields)
yield words, label
test_reader = paddle.batch(
paddle.reader.shuffle(
reader, buf_size=buf_size), batch_size=batch_size)
return test_reader
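# ---- infer.py: inference script ----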
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

import paddle.fluid as fluid

import data_reader
def infer(test_reader, use_cuda, model_path=None):
"""
inference function
"""
if model_path is None:
print(str(model_path) + " cannot be found")
return
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.core.Scope()
with fluid.scope_guard(inference_scope):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(model_path, exe)
total_acc = 0.0
total_count = 0
for data in test_reader():
acc = exe.run(inference_program,
feed=utils.data2tensor(data, place),
fetch_list=fetch_targets,
return_numpy=True)
total_acc += acc[0] * len(data)
total_count += len(data)
avg_acc = total_acc / total_count
print("model_path: %s, avg_acc: %f" % (model_path, avg_acc))
if __name__ == "__main__":
    if __package__ is None:
        # make utils.py in the parent directory importable
        sys.path.append(
            os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import utils
batch_size = 128
model_path = sys.argv[1]
test_data_dirname = 'test_data'
if len(sys.argv) == 3:
test_data_dirname = sys.argv[2]
    test_reader = data_reader.imdb_data_feed_reader(
        test_data_dirname, batch_size, buf_size=500000)
models = os.listdir(model_path)
for i in range(0, len(models)):
epoch_path = "epoch" + str(i) + ".model"
epoch_path = os.path.join(model_path, epoch_path)
infer(test_reader, use_cuda=False, model_path=epoch_path)
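# ---- train.py: training script ----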
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import multiprocessing
import paddle
import paddle.fluid as fluid
def train(network, dict_dim, lr, save_dirname, training_data_dirname, pass_num,
thread_num, batch_size):
file_names = os.listdir(training_data_dirname)
filelist = []
for i in range(0, len(file_names)):
if file_names[i] == 'data_feed.proto':
continue
filelist.append(os.path.join(training_data_dirname, file_names[i]))
dataset = fluid.DataFeedDesc(
os.path.join(training_data_dirname, 'data_feed.proto'))
dataset.set_batch_size(
batch_size) # datafeed should be assigned a batch size
dataset.set_use_slots(['words', 'label'])
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
avg_cost, acc, prediction = network(data, label, dict_dim)
optimizer = fluid.optimizer.Adagrad(learning_rate=lr)
opt_ops, weight_and_grad = optimizer.minimize(avg_cost)
startup_program = fluid.default_startup_program()
main_program = fluid.default_main_program()
place = fluid.CPUPlace()
executor = fluid.Executor(place)
executor.run(startup_program)
    # AsyncExecutor reads the training files with multiple threads; set
    # debug=True to print the fetch variables ([acc]) after each pass
    async_executor = fluid.AsyncExecutor(place)
    for i in range(pass_num):
        pass_start = time.time()
        async_executor.run(main_program,
                           dataset,
                           filelist,
                           thread_num, [acc],
                           debug=False)
        print('pass_id: %u pass_time_cost %f' % (i, time.time() - pass_start))
fluid.io.save_inference_model('%s/epoch%d.model' % (save_dirname, i),
[data.name, label.name], [acc], executor)
if __name__ == "__main__":
    if __package__ is None:
        # make nets.py and utils.py in the parent directory importable
        sys.path.append(
            os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from nets import bow_net, cnn_net, lstm_net, gru_net
from utils import load_vocab
batch_size = 4
lr = 0.002
pass_num = 30
save_dirname = ""
thread_num = multiprocessing.cpu_count()
if sys.argv[1] == "bow":
network = bow_net
batch_size = 128
save_dirname = "bow_model"
elif sys.argv[1] == "cnn":
network = cnn_net
lr = 0.01
save_dirname = "cnn_model"
elif sys.argv[1] == "lstm":
network = lstm_net
lr = 0.05
save_dirname = "lstm_model"
elif sys.argv[1] == "gru":
network = gru_net
batch_size = 128
lr = 0.05
save_dirname = "gru_model"
training_data_dirname = 'train_data/'
    if len(sys.argv) >= 3:
        training_data_dirname = sys.argv[2]
    if len(sys.argv) >= 4:
        thread_num = min(thread_num, int(sys.argv[3]))
vocab = load_vocab('imdb.vocab')
dict_dim = len(vocab)
train(network, dict_dim, lr, save_dirname, training_data_dirname, pass_num,
thread_num, batch_size)