PaddlePaddle / models
Commit a76dc125
Authored Jan 19, 2019 by guru4elephant; committed by GitHub on Jan 19, 2019
Merge pull request #1635 from wangguibao/text_classification_async
text_classification run with fluid.AsyncExecutor
Parents: 9a4f5786 62afaf75
Showing 8 changed files with 1011 additions and 0 deletions
fluid/PaddleNLP/text_classification/async_executor/README.md  (+130, -0)
fluid/PaddleNLP/text_classification/async_executor/data_generator.sh  (+43, -0)
fluid/PaddleNLP/text_classification/async_executor/data_generator/IMDB.py  (+60, -0)
fluid/PaddleNLP/text_classification/async_executor/data_generator/data_generator.py  (+508, -0)
fluid/PaddleNLP/text_classification/async_executor/data_generator/splitfile.py  (+29, -0)
fluid/PaddleNLP/text_classification/async_executor/data_reader.py  (+50, -0)
fluid/PaddleNLP/text_classification/async_executor/infer.py  (+79, -0)
fluid/PaddleNLP/text_classification/async_executor/train.py  (+112, -0)
fluid/PaddleNLP/text_classification/async_executor/README.md
0 → 100644
# Text Classification

Below is a brief directory structure and description for this example:

```text
.
|-- README.md             # README
|-- data_generator        # IMDB dataset generation tools
|   |-- IMDB.py           # Extends data_generator.py with the processing logic for the IMDB dataset
|   |-- build_raw_data.py # IMDB data preprocessing; its output is read by splitfile.py. Format: word word ... | label
|   |-- data_generator.py # Framework of the data generation tool that pairs with AsyncExecutor
|   `-- splitfile.py      # Splits the files produced by build_raw_data.py; its output is read by IMDB.py
|-- data_generator.sh     # Entry point of the IMDB dataset generation tools
|-- data_reader.py        # Data reading utility used by the inference script
|-- infer.py              # Inference script
`-- train.py              # Training script
```

## Introduction

This directory contains scripts for training a text classification model with fluid.AsyncExecutor. The network definitions are reused from nets.py in the parent directory.
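For orientation, here is a condensed sketch of the flow that train.py below implements. Unlike fluid.Executor, the data feed is described by a data_feed.proto file plus a list of pre-generated data files, which AsyncExecutor consumes with multiple threads. `run_one_pass` is an illustrative wrapper, not part of this commit, and it assumes the network and startup program have already been built as in train.py, so that `acc` is the accuracy variable to fetch:

```python
import paddle.fluid as fluid

def run_one_pass(filelist, thread_num, acc):
    """One AsyncExecutor pass over pre-generated data files (sketch)."""
    # The data feed is described by a proto file generated together with
    # the data, rather than by a Python reader
    dataset = fluid.DataFeedDesc('train_data/data_feed.proto')
    dataset.set_batch_size(128)
    dataset.set_use_slots(['words', 'label'])

    # AsyncExecutor reads the data files with thread_num threads
    async_executor = fluid.AsyncExecutor(fluid.CPUPlace())
    async_executor.run(fluid.default_main_program(), dataset, filelist,
                       thread_num, [acc], debug=False)
```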
## Training

1. Run `sh data_generator.sh` to download the IMDB dataset and convert it into training data that AsyncExecutor can read.

2. Run `python train.py bow` to start training the model.

```python
python train.py bow    # bow selects the network structure; it can be replaced with cnn, lstm or gru
```
3. (Optional) To customize the network structure, add it to [nets.py](../nets.py) yourself and set the corresponding parameters in [train.py](./train.py).
```python
def train(train_reader,     # training data
          word_dict,        # data dictionary
          network,          # model configuration
          use_cuda,         # whether to use GPU
          parallel,         # whether to train in parallel
          save_dirname,     # path for saving the model
          lr=0.2,           # learning rate
          batch_size=128,   # number of samples per batch
          pass_num=30):     # number of training passes
```
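As a reference point for this step, train.py in this directory calls the network as `network(data, label, dict_dim)` and expects it to return `(avg_cost, acc, prediction)`. A minimal hypothetical custom network following that interface might look like the sketch below; `my_custom_net` and its layer sizes are illustrative, not part of nets.py:

```python
import paddle.fluid as fluid

def my_custom_net(data, label, dict_dim, emb_dim=128, hid_dim=128):
    # Embed word ids and sum-pool them into a fixed-size sentence vector
    emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
    pooled = fluid.layers.sequence_pool(input=emb, pool_type='sum')
    # One hidden layer, then a binary softmax classifier
    hidden = fluid.layers.fc(input=pooled, size=hid_dim, act='tanh')
    prediction = fluid.layers.fc(input=hidden, size=2, act='softmax')
    cost = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    acc = fluid.layers.accuracy(input=prediction, label=label)
    return avg_cost, acc, prediction
```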
## Sample Training Results

```text
pass_id: 0 pass_time_cost 4.723438
pass_id: 1 pass_time_cost 3.867186
pass_id: 2 pass_time_cost 4.490111
pass_id: 3 pass_time_cost 4.573296
pass_id: 4 pass_time_cost 4.180547
pass_id: 5 pass_time_cost 4.214476
pass_id: 6 pass_time_cost 4.520387
pass_id: 7 pass_time_cost 4.149485
pass_id: 8 pass_time_cost 3.821354
pass_id: 9 pass_time_cost 5.136178
pass_id: 10 pass_time_cost 4.137318
pass_id: 11 pass_time_cost 3.943429
pass_id: 12 pass_time_cost 3.766478
pass_id: 13 pass_time_cost 4.235983
pass_id: 14 pass_time_cost 4.796462
pass_id: 15 pass_time_cost 4.668116
pass_id: 16 pass_time_cost 4.373798
pass_id: 17 pass_time_cost 4.298131
pass_id: 18 pass_time_cost 4.260021
pass_id: 19 pass_time_cost 4.244411
pass_id: 20 pass_time_cost 3.705138
pass_id: 21 pass_time_cost 3.728070
pass_id: 22 pass_time_cost 3.817919
pass_id: 23 pass_time_cost 4.698598
pass_id: 24 pass_time_cost 4.859262
pass_id: 25 pass_time_cost 5.725732
pass_id: 26 pass_time_cost 5.102599
pass_id: 27 pass_time_cost 3.876582
pass_id: 28 pass_time_cost 4.762538
pass_id: 29 pass_time_cost 3.797759
```
Unlike fluid.Executor, AsyncExecutor does not print accuracy at the end of each pass. To observe the training process, set the debug parameter of the fluid.AsyncExecutor.run() method to True; the fetch variables given in the argument list will then be printed at the end of each pass:
```python
async_executor.run(
    main_program,
    dataset,
    filelist,
    thread_num,
    [acc],        # fetch variables printed at the end of each pass when debug=True
    debug=True)
```
## Inference

1. Run `python infer.py bow_model` to start inference.

```python
python infer.py bow_model    # bow_model selects the trained model to load
```
## Sample Inference Results

```text
model_path: bow_model/epoch0.model, avg_acc: 0.882600
model_path: bow_model/epoch1.model, avg_acc: 0.887920
model_path: bow_model/epoch2.model, avg_acc: 0.886920
model_path: bow_model/epoch3.model, avg_acc: 0.884720
model_path: bow_model/epoch4.model, avg_acc: 0.879760
model_path: bow_model/epoch5.model, avg_acc: 0.876920
model_path: bow_model/epoch6.model, avg_acc: 0.874160
model_path: bow_model/epoch7.model, avg_acc: 0.872000
model_path: bow_model/epoch8.model, avg_acc: 0.870360
model_path: bow_model/epoch9.model, avg_acc: 0.868480
model_path: bow_model/epoch10.model, avg_acc: 0.867240
model_path: bow_model/epoch11.model, avg_acc: 0.866200
model_path: bow_model/epoch12.model, avg_acc: 0.865560
model_path: bow_model/epoch13.model, avg_acc: 0.865160
model_path: bow_model/epoch14.model, avg_acc: 0.864480
model_path: bow_model/epoch15.model, avg_acc: 0.864240
model_path: bow_model/epoch16.model, avg_acc: 0.863800
model_path: bow_model/epoch17.model, avg_acc: 0.863520
model_path: bow_model/epoch18.model, avg_acc: 0.862760
model_path: bow_model/epoch19.model, avg_acc: 0.862680
model_path: bow_model/epoch20.model, avg_acc: 0.862240
model_path: bow_model/epoch21.model, avg_acc: 0.862280
model_path: bow_model/epoch22.model, avg_acc: 0.862080
model_path: bow_model/epoch23.model, avg_acc: 0.861560
model_path: bow_model/epoch24.model, avg_acc: 0.861280
model_path: bow_model/epoch25.model, avg_acc: 0.861160
model_path: bow_model/epoch26.model, avg_acc: 0.861080
model_path: bow_model/epoch27.model, avg_acc: 0.860920
model_path: bow_model/epoch28.model, avg_acc: 0.860800
model_path: bow_model/epoch29.model, avg_acc: 0.860760
```
Note: the continued drop in accuracy is caused by overfitting and can be ignored.
fluid/PaddleNLP/text_classification/async_executor/data_generator.sh
0 → 100644
#!/usr/bin/env bash
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pushd .
cd ./data_generator

# wget "http://ai.stanford.edu/%7Eamaas/data/sentiment/aclImdb_v1.tar.gz"
if [ ! -f aclImdb_v1.tar.gz ]; then
    wget "http://10.64.74.104:8080/paddle/dataset/imdb/aclImdb_v1.tar.gz"
fi
tar zxvf aclImdb_v1.tar.gz

mkdir train_data
python build_raw_data.py train | python splitfile.py 12 train_data

mkdir test_data
python build_raw_data.py test | python splitfile.py 12 test_data

/opt/python27/bin/python IMDB.py train_data
/opt/python27/bin/python IMDB.py test_data

mv ./output_dataset/train_data ../
mv ./output_dataset/test_data ../
cp aclImdb/imdb.vocab ../

rm -rf ./output_dataset
rm -rf train_data
rm -rf test_data
rm -rf aclImdb
popd
fluid/PaddleNLP/text_classification/async_executor/data_generator/IMDB.py
0 → 100644
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import os, sys
sys.path.append(os.path.abspath(os.path.join('..')))
from data_generator import MultiSlotDataGenerator


class IMDbDataGenerator(MultiSlotDataGenerator):
    def load_resource(self, dictfile):
        self._vocab = {}
        wid = 0
        with open(dictfile) as f:
            for line in f:
                self._vocab[line.strip()] = wid
                wid += 1
        self._unk_id = len(self._vocab)
        self._pattern = re.compile(r'(;|,|\.|\?|!|\s|\(|\))')

    def process(self, line):
        send = '|'.join(line.split('|')[:-1]).lower().replace(
            "<br />", " ").strip()
        label = [int(line.split('|')[-1])]

        words = [x for x in self._pattern.split(send) if x and x != " "]
        feas = [
            self._vocab[x] if x in self._vocab else self._unk_id
            for x in words
        ]
        return ("words", feas), ("label", label)


imdb = IMDbDataGenerator()
imdb.load_resource("aclImdb/imdb.vocab")

# data from files
file_names = os.listdir(sys.argv[1])
filelist = []
for i in range(0, len(file_names)):
    filelist.append(os.path.join(sys.argv[1], file_names[i]))

line_limit = 2500
process_num = 24
imdb.run_from_files(
    filelist=filelist,
    line_limit=line_limit,
    process_num=process_num,
    output_dir=('output_dataset/%s' % (sys.argv[1])))
fluid/PaddleNLP/text_classification/async_executor/data_generator/data_generator.py
0 → 100644
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import multiprocessing

__all__ = ['MultiSlotDataGenerator']


class DataGenerator(object):
    def __init__(self):
        self._proto_info = None

    def _set_filelist(self, filelist):
        if not isinstance(filelist, list) and not isinstance(filelist, tuple):
            raise ValueError("filelist%s must be in list or tuple type" %
                             type(filelist))
        if not filelist:
            raise ValueError("filelist can not be empty")
        self._filelist = filelist

    def _set_process_num(self, process_num):
        if not isinstance(process_num, int):
            raise ValueError("process_num%s must be in int type" %
                             type(process_num))
        if process_num < 1:
            raise ValueError("process_num can not less than 1")
        self._process_num = process_num

    def _set_line_limit(self, line_limit):
        if not isinstance(line_limit, int):
            raise ValueError("line_limit%s must be in int type" %
                             type(line_limit))
        if line_limit < 1:
            raise ValueError("line_limit can not less than 1")
        self._line_limit = line_limit

    def _set_output_dir(self, output_dir):
        if not isinstance(output_dir, str):
            raise ValueError("output_dir%s must be in str type" %
                             type(output_dir))
        if not output_dir:
            raise ValueError("output_dir can not be empty")
        self._output_dir = output_dir

    def _set_output_prefix(self, output_prefix):
        if not isinstance(output_prefix, str):
            raise ValueError("output_prefix%s must be in str type" %
                             type(output_prefix))
        self._output_prefix = output_prefix

    def _set_output_fill_digit(self, output_fill_digit):
        if not isinstance(output_fill_digit, int):
            raise ValueError("output_fill_digit%s must be in int type" %
                             type(output_fill_digit))
        if output_fill_digit < 1:
            raise ValueError("output_fill_digit can not less than 1")
        self._output_fill_digit = output_fill_digit

    def _set_proto_filename(self, proto_filename):
        if not isinstance(proto_filename, str):
            raise ValueError("proto_filename%s must be in str type" %
                             type(proto_filename))
        if not proto_filename:
            raise ValueError("proto_filename can not be empty")
        self._proto_filename = proto_filename

    def _print_info(self):
        '''
        Print the configuration information
        (Called only in the run_from_stdin function).
        '''
        sys.stderr.write("=" * 16 + " config " + "=" * 16 + "\n")
        sys.stderr.write(" filelist size: %d\n" % len(self._filelist))
        sys.stderr.write(" process num: %d\n" % self._process_num)
        sys.stderr.write(" line limit: %d\n" % self._line_limit)
        sys.stderr.write(" output dir: %s\n" % self._output_dir)
        sys.stderr.write(" output prefix: %s\n" % self._output_prefix)
        sys.stderr.write(" output fill digit: %d\n" % self._output_fill_digit)
        sys.stderr.write(" proto filename: %s\n" % self._proto_filename)
        sys.stderr.write("==== This may take a few minutes... ====\n")

    def _get_output_filename(self, output_index, lock=None):
        '''
        This function is used to get the name of the output file and
        update output_index.

        Args:
            output_index(manager.Value(i)): the index of output file.
            lock(manager.Lock): The lock for processes safe.

        Return:
            Return the name(string) of output file.
        '''
        if lock is not None:
            lock.acquire()
        file_index = output_index.value
        output_index.value += 1
        if lock is not None:
            lock.release()
        filename = os.path.join(self._output_dir, self._output_prefix) \
                + str(file_index).zfill(self._output_fill_digit)
        sys.stderr.write("[%d] write data to file: %s\n" %
                         (os.getpid(), filename))
        return filename

    def run_from_stdin(self,
                       is_local=True,
                       hadoop_host=None,
                       hadoop_ugi=None,
                       proto_path=None,
                       proto_filename="data_feed.proto"):
        '''
        This function reads the data row from stdin, parses it with the
        process function, and further parses the return value of the
        process function with the _gen_str function. The parsed data will
        be wrote to stdout and the corresponding protofile will be
        generated. If local is set to False, the protofile will be
        uploaded to hadoop.

        Args:
            is_local(bool): Whether to execute locally. If it is False, the
                            protofile will be uploaded to hadoop. The
                            default value is True.
            hadoop_host(str): The host name of the hadoop. It should be
                              in this format: "hdfs://${HOST}:${PORT}".
            hadoop_ugi(str): The ugi of the hadoop. It should be in this
                             format: "${USERNAME},${PASSWORD}".
            proto_path(str): The hadoop path you want to upload the
                             protofile to.
            proto_filename(str): The name of protofile. The default value
                                 is "data_feed.proto". It is not
                                 recommended to modify it.
        '''
        if is_local:
            print \
'''\033[1;34m=======================================================
Pay attention to that the version of Python in Hadoop
may be inconsistent with local version. Please check the
Python version of Hadoop to ensure that it is >= 2.7.
=======================================================\033[0m'''
        else:
            if hadoop_ugi is None or \
               hadoop_host is None or \
               proto_path is None:
                raise ValueError(
                    "pls set hadoop_ugi, hadoop_host, and proto_path")
        self._set_proto_filename(proto_filename)
        for line in sys.stdin:
            user_parsed_line = self.process(line)
            sys.stdout.write(self._gen_str(user_parsed_line))
        if self._proto_info is not None:
            # maybe some task do not catch files
            with open(self._proto_filename, "w") as f:
                f.write(self._get_proto_desc(self._proto_info))
            if is_local == False:
                cmd = "$HADOOP_HOME/bin/hadoop fs" \
                        + " -Dhadoop.job.ugi=" + hadoop_ugi \
                        + " -Dfs.default.name=" + hadoop_host \
                        + " -put " + self._proto_filename + " " + proto_path
                os.system(cmd)

    def run_from_files(self,
                       filelist,
                       line_limit,
                       process_num=1,
                       output_dir="./output_dataset",
                       output_prefix="part-",
                       output_fill_digit=8,
                       proto_filename="data_feed.proto"):
        '''
        This function will run process_num processes to process the files
        in the filelist. It will create the output data folder(output_dir)
        in the current directory, and write the processed data into the
        output_dir folder(each file line_limit data, the prefix of filename
        is output_prefix, the suffix of filename is output_fill_digit
        numbers). And the proto_info is generated at the same time. the
        name of proto file will be proto_filename.

        Args:
            filelist(list or tuple): Files that need to be processed.
            line_limit(int): Maximum number of data stored per file.
            process_num(int): Number of processes running simultaneously.
            output_dir(str): The name of the folder where the output
                             data file is stored.
            output_prefix(str): The prefix of output data file.
            output_fill_digit(int): The number of suffix numbers of the
                                    output data file.
            proto_filename(str): The name of protofile.
        '''
        self._set_filelist(filelist)
        self._set_line_limit(line_limit)
        self._set_process_num(min(process_num, len(filelist)))
        self._set_output_dir(output_dir)
        self._set_output_prefix(output_prefix)
        self._set_output_fill_digit(output_fill_digit)
        self._set_proto_filename(proto_filename)
        self._print_info()

        if not os.path.exists(self._output_dir):
            os.makedirs(self._output_dir)
        elif not os.path.isdir(self._output_dir):
            raise ValueError("%s is not a directory" % self._output_dir)

        processes = multiprocessing.Pool()
        manager = multiprocessing.Manager()
        output_index = manager.Value('i', 0)
        file_queue = manager.Queue()
        lock = manager.Lock()
        remaining_queue = manager.Queue()
        for file in self._filelist:
            file_queue.put(file)

        info_result = []
        for i in range(self._process_num):
            info_result.append(processes.apply_async(subprocess_wrapper, \
                    (self, file_queue, remaining_queue, output_index, lock, )))
        processes.close()
        processes.join()

        infos = [
            result.get() for result in info_result
            if result.get() is not None
        ]
        proto_info = self._combine_infos(infos)
        with open(
                os.path.join(self._output_dir, self._proto_filename),
                "w") as f:
            f.write(self._get_proto_desc(proto_info))

        while not remaining_queue.empty():
            with open(self._get_output_filename(output_index), "w") as f:
                for i in range(min(self._line_limit,
                                   remaining_queue.qsize())):
                    f.write(remaining_queue.get(False))

    def _subprocess(self, file_queue, remaining_queue, output_index, lock):
        '''
        This function will be called by multiple processes. It is used to
        continuously fetch files from file_queue, using process() function
        (defined by user) and _gen_str() function(defined by concrete classes)
        to process data in units of rows. Write the processed data to the
        file(each file will be self._line_limit line). If the file in the
        file_queue has been consumed, but the file is not full, the data
        that is less than the self._line_limit line will be stored in the
        remaining_queue.

        Args:
            file_queue(manager.Queue): The queue contains all the file
                                       names to be processed.
            remaining_queue(manager.Queue): The queue contains the data that
                                            is less than the self._line_limit
                                            line.
            output_index(manager.Value(i)): The index(suffix) of the
                                            output file.
            lock(manager.Lock): The lock for processes safe.

        Returns:
            Return a proto_info which can be translated into a proto string.
        '''
        buffer = []
        while not file_queue.empty():
            try:
                filename = file_queue.get(False)
            except:  # file_queue empty
                break
            with open(filename, 'r') as f:
                for line in f:
                    buffer.append(self._gen_str(self.process(line)))
                    if len(buffer) == self._line_limit:
                        with open(
                                self._get_output_filename(output_index, lock),
                                "w") as wf:
                            for x in buffer:
                                wf.write(x)
                        buffer = []
        if buffer:
            for x in buffer:
                remaining_queue.put(x)
        return self._proto_info

    def _gen_str(self, line):
        '''
        Further processing the output of the process() function rewritten by
        user, outputting data that can be directly read by the datafeed, and
        updating proto_info infomation.

        Args:
            line(str): the output of the process() function rewritten by user.

        Returns:
            Return a string data that can be read directly by the datafeed.
        '''
        raise NotImplementedError(
            "pls use MultiSlotDataGenerator or PairWiseDataGenerator")

    def _combine_infos(self, infos):
        '''
        This function is used to merge proto_info information from different
        processes. In general, the proto_info of each process is consistent.

        Args:
            infos(list): the list of proto_infos from different processes.

        Returns:
            Return a unified proto_info.
        '''
        raise NotImplementedError(
            "pls use MultiSlotDataGenerator or PairWiseDataGenerator")

    def _get_proto_desc(self, proto_info):
        '''
        This function outputs the string of the proto file(can be directly
        written to the file) according to the proto_info information.

        Args:
            proto_info: The proto information used to generate the proto
                        string. The type of the variable will be determined
                        by the subclass. In the MultiSlotDataGenerator,
                        proto_info variable is a list of tuple.

        Returns:
            Returns a string of the proto file.
        '''
        raise NotImplementedError(
            "pls use MultiSlotDataGenerator or PairWiseDataGenerator")

    def process(self, line):
        '''
        This function needs to be overridden by the user to process the
        original data row into a list or tuple.

        Args:
            line(str): the original data row

        Returns:
            Returns the data processed by the user.
            The data format is list or tuple:
            [(name, [feasign, ...]), ...]
            or ((name, [feasign, ...]), ...)

            For example:
            [("words", [1926, 08, 17]), ("label", [1])]
            or (("words", [1926, 08, 17]), ("label", [1]))

        Note:
            The type of feasigns must be in int or float. Once the float
            element appears in the feasign, the type of that slot will be
            processed into a float.
        '''
        raise NotImplementedError(
            "pls rewrite this function to return a list or tuple: " +
            "[(name, [feasign, ...]), ...] or ((name, [feasign, ...]), ...)")


def subprocess_wrapper(instance, file_queue, remaining_queue, output_index,
                       lock):
    '''
    In order to use the class function as a process, you need to wrap it.
    '''
    return instance._subprocess(file_queue, remaining_queue, output_index,
                                lock)


class MultiSlotDataGenerator(DataGenerator):
    def _combine_infos(self, infos):
        '''
        This function is used to merge proto_info information from different
        processes. In general, the proto_info of each process is consistent.
        The type of input infos is list, and the type of element of infos is
        tuple. The format of element of infos will be (name, type).

        Args:
            infos(list): the list of proto_infos from different processes.

        Returns:
            Return a unified proto_info.

        Note:
            This function is only called by the run_from_files function, so
            when using the run_from_stdin function(usually used for hadoop),
            the output of the process function(rewritten by the user) does
            not allow that the same field to have both float and int type
            values.
        '''
        proto_info = infos[0]
        for info in infos:
            for index, slot in enumerate(info):
                name, type = slot
                if name != proto_info[index][0]:
                    raise ValueError(
                        "combine infos error, pls contact the maintainer of this code~"
                    )
                if type == "float" and proto_info[index][1] == "uint64":
                    proto_info[index] = (name, type)
        return proto_info

    def _get_proto_desc(self, proto_info):
        '''
        Generate a string of proto file based on the proto_info information.

        The proto_info will be a list of tuples:
        >>> [(Name, Type), ...]

        The string of proto file will be in this format:
        >>> name: "MultiSlotDataFeed"
        >>> batch_size: 32
        >>> multi_slot_desc {
        >>>     slots {
        >>>         name: Name
        >>>         type: Type
        >>>         is_dense: false
        >>>         is_used: false
        >>>     }
        >>> }

        Args:
            proto_info(list): The proto information used to generate the
                              proto string.

        Returns:
            Returns a string of the proto file.
        '''
        proto_str = "name: \"MultiSlotDataFeed\"\n" \
                + "batch_size: 32\nmulti_slot_desc {\n"
        for elem in proto_info:
            proto_str += "    slots {\n" \
                    + "        name: \"%s\"\n" % elem[0] \
                    + "        type: \"%s\"\n" % elem[1] \
                    + "        is_dense: false\n" \
                    + "        is_used: false\n" \
                    + "    }\n"
        proto_str += "}"
        return proto_str

    def _gen_str(self, line):
        '''
        Further processing the output of the process() function rewritten by
        user, outputting data that can be directly read by the MultiSlotDataFeed,
        and updating proto_info infomation.

        The input line will be in this format:
        >>> [(name, [feasign, ...]), ...]
        >>> or ((name, [feasign, ...]), ...)
        The output will be in this format:
        >>> [ids_num id1 id2 ...] ...
        The proto_info will be in this format:
        >>> [(name, type), ...]

        For example, if the input is like this:
        >>> [("words", [1926, 08, 17]), ("label", [1])]
        >>> or (("words", [1926, 08, 17]), ("label", [1]))
        the output will be:
        >>> 3 1234 2345 3456 1 1
        the proto_info will be:
        >>> [("words", "uint64"), ("label", "uint64")]

        Args:
            line(str): the output of the process() function rewritten by user.

        Returns:
            Return a string data that can be read directly by the
            MultiSlotDataFeed.
        '''
        if not isinstance(line, list) and not isinstance(line, tuple):
            raise ValueError(
                "the output of process() must be in list or tuple type")
        output = ""

        if self._proto_info is None:
            self._proto_info = []
            for item in line:
                name, elements = item
                if not isinstance(name, str):
                    raise ValueError("name%s must be in str type" %
                                     type(name))
                if not isinstance(elements, list):
                    raise ValueError("elements%s must be in list type" %
                                     type(elements))
                if not elements:
                    raise ValueError(
                        "the elements of each field can not be empty, you need padding it in process()."
                    )
                self._proto_info.append((name, "uint64"))
                if output:
                    output += " "
                output += str(len(elements))
                for elem in elements:
                    if isinstance(elem, float):
                        self._proto_info[-1] = (name, "float")
                    elif not isinstance(elem, int) and not isinstance(elem,
                                                                      long):
                        raise ValueError(
                            "the type of element%s must be in int or float" %
                            type(elem))
                    output += " " + str(elem)
        else:
            if len(line) != len(self._proto_info):
                raise ValueError(
                    "the complete field set of two given line are inconsistent."
                )
            for index, item in enumerate(line):
                name, elements = item
                if not isinstance(name, str):
                    raise ValueError("name%s must be in str type" %
                                     type(name))
                if not isinstance(elements, list):
                    raise ValueError("elements%s must be in list type" %
                                     type(elements))
                if not elements:
                    raise ValueError(
                        "the elements of each field can not be empty, you need padding it in process()."
                    )
                if name != self._proto_info[index][0]:
                    raise ValueError(
                        "the field name of two given line are not match: require<%s>, get<%s>."
                        % (self._proto_info[index][0], name))
                if output:
                    output += " "
                output += str(len(elements))
                for elem in elements:
                    if self._proto_info[index][1] != "float":
                        if isinstance(elem, float):
                            self._proto_info[index] = (name, "float")
                        elif not isinstance(elem, int) and not isinstance(
                                elem, long):
                            raise ValueError(
                                "the type of element%s must be in int or float"
                                % type(elem))
                    output += " " + str(elem)
        return output + "\n"
fluid/PaddleNLP/text_classification/async_executor/data_generator/splitfile.py
0 → 100644
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Split file into parts
"""
import
sys
import
os
block
=
int
(
sys
.
argv
[
1
])
datadir
=
sys
.
argv
[
2
]
file_list
=
[]
for
i
in
range
(
block
):
file_list
.
append
(
open
(
datadir
+
"/part-"
+
str
(
i
),
"w"
))
id_
=
0
for
line
in
sys
.
stdin
:
file_list
[
id_
%
block
].
write
(
line
)
id_
+=
1
for
f
in
file_list
:
f
.
close
()
fluid/PaddleNLP/text_classification/async_executor/data_reader.py
0 → 100644
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import paddle


def parse_fields(fields):
    words_width = int(fields[0])
    words = fields[1:1 + words_width]
    label = fields[-1]
    return words, label


def imdb_data_feed_reader(data_dir, batch_size, buf_size):
    """
    Data feed reader for IMDB dataset.
    This data set has been converted from original format to a format suitable
    for AsyncExecutor
    See data.proto for data format
    """

    def reader():
        for file in os.listdir(data_dir):
            if file.endswith('.proto'):
                continue
            with open(os.path.join(data_dir, file), 'r') as f:
                for line in f:
                    fields = line.split(' ')
                    words, label = parse_fields(fields)
                    yield words, label

    test_reader = paddle.batch(
        paddle.reader.shuffle(
            reader, buf_size=buf_size), batch_size=batch_size)
    return test_reader
fluid/PaddleNLP/text_classification/async_executor/infer.py
0 → 100644
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import unittest
import contextlib

import numpy as np
import paddle
import paddle.fluid as fluid

import data_reader


def infer(test_reader, use_cuda, model_path=None):
    """
    inference function
    """
    if model_path is None:
        print(str(model_path) + " cannot be found")
        return

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    inference_scope = fluid.core.Scope()
    with fluid.scope_guard(inference_scope):
        [inference_program, feed_target_names,
         fetch_targets] = fluid.io.load_inference_model(model_path, exe)

        total_acc = 0.0
        total_count = 0
        for data in test_reader():
            acc = exe.run(inference_program,
                          feed=utils.data2tensor(data, place),
                          fetch_list=fetch_targets,
                          return_numpy=True)
            total_acc += acc[0] * len(data)
            total_count += len(data)

        avg_acc = total_acc / total_count
        print("model_path: %s, avg_acc: %f" % (model_path, avg_acc))


if __name__ == "__main__":
    if __package__ is None:
        from os import sys, path
        sys.path.append(
            os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

    import utils
    batch_size = 128
    model_path = sys.argv[1]
    test_data_dirname = 'test_data'

    if len(sys.argv) == 3:
        test_data_dirname = sys.argv[2]

    test_reader = data_reader.imdb_data_feed_reader(
        test_data_dirname, batch_size, buf_size=500000)

    models = os.listdir(model_path)
    for i in range(0, len(models)):
        epoch_path = "epoch" + str(i) + ".model"
        epoch_path = os.path.join(model_path, epoch_path)
        infer(test_reader, use_cuda=False, model_path=epoch_path)
fluid/PaddleNLP/text_classification/async_executor/train.py
0 → 100644
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import multiprocessing

import paddle
import paddle.fluid as fluid


def train(network, dict_dim, lr, save_dirname, training_data_dirname,
          pass_num, thread_num, batch_size):
    file_names = os.listdir(training_data_dirname)
    filelist = []
    for i in range(0, len(file_names)):
        if file_names[i] == 'data_feed.proto':
            continue
        filelist.append(os.path.join(training_data_dirname, file_names[i]))

    dataset = fluid.DataFeedDesc(
        os.path.join(training_data_dirname, 'data_feed.proto'))
    dataset.set_batch_size(
        batch_size)  # datafeed should be assigned a batch size
    dataset.set_use_slots(['words', 'label'])

    data = fluid.layers.data(
        name="words", shape=[1], dtype="int64", lod_level=1)
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")

    avg_cost, acc, prediction = network(data, label, dict_dim)
    optimizer = fluid.optimizer.Adagrad(learning_rate=lr)
    opt_ops, weight_and_grad = optimizer.minimize(avg_cost)

    startup_program = fluid.default_startup_program()
    main_program = fluid.default_main_program()

    place = fluid.CPUPlace()
    executor = fluid.Executor(place)
    executor.run(startup_program)

    async_executor = fluid.AsyncExecutor(place)
    for i in range(pass_num):
        pass_start = time.time()
        async_executor.run(main_program,
                           dataset,
                           filelist,
                           thread_num, [acc],
                           debug=False)
        print('pass_id: %u pass_time_cost %f' % (i,
                                                 time.time() - pass_start))
        fluid.io.save_inference_model('%s/epoch%d.model' % (save_dirname, i),
                                      [data.name, label.name], [acc],
                                      executor)


if __name__ == "__main__":
    if __package__ is None:
        from os import sys, path
        sys.path.append(
            os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

    from nets import bow_net, cnn_net, lstm_net, gru_net
    from utils import load_vocab

    batch_size = 4
    lr = 0.002
    pass_num = 30
    save_dirname = ""
    thread_num = multiprocessing.cpu_count()

    if sys.argv[1] == "bow":
        network = bow_net
        batch_size = 128
        save_dirname = "bow_model"
    elif sys.argv[1] == "cnn":
        network = cnn_net
        lr = 0.01
        save_dirname = "cnn_model"
    elif sys.argv[1] == "lstm":
        network = lstm_net
        lr = 0.05
        save_dirname = "lstm_model"
    elif sys.argv[1] == "gru":
        network = gru_net
        batch_size = 128
        lr = 0.05
        save_dirname = "gru_model"

    training_data_dirname = 'train_data/'
    if len(sys.argv) == 3:
        training_data_dirname = sys.argv[2]

    if len(sys.argv) == 4:
        if thread_num >= int(sys.argv[3]):
            thread_num = int(sys.argv[3])

    vocab = load_vocab('imdb.vocab')
    dict_dim = len(vocab)

    train(network, dict_dim, lr, save_dirname, training_data_dirname,
          pass_num, thread_num, batch_size)