Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
3d289649
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3d289649
编写于
12月 09, 2016
作者:
D
dangqingqing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
follow comments
上级
aaecfcc4
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
57 addition
and
63 deletion
+57
-63
demo/sentiment/predict.py
demo/sentiment/predict.py
+27
-35
doc/tutorials/sentiment_analysis/sentiment_analysis.md
doc/tutorials/sentiment_analysis/sentiment_analysis.md
+15
-14
doc_cn/demo/sentiment_analysis/sentiment_analysis.md
doc_cn/demo/sentiment_analysis/sentiment_analysis.md
+15
-14
未找到文件。
demo/sentiment/predict.py
浏览文件 @
3d289649
...
...
@@ -66,24 +66,18 @@ class SentimentPrediction():
for
v
in
open
(
label_file
,
'r'
):
self
.
label
[
int
(
v
.
split
(
'
\t
'
)[
1
])]
=
v
.
split
(
'
\t
'
)[
0
]
def
get_
data
(
self
,
data
):
def
get_
index
(
self
,
data
):
"""
Get input data of paddle format
.
transform word into integer index according to the dictionary
.
"""
for
line
in
data
:
words
=
line
.
strip
().
split
()
words
=
data
.
strip
().
split
()
word_slot
=
[
self
.
word_dict
[
w
]
for
w
in
words
if
w
in
self
.
word_dict
]
if
not
word_slot
:
print
"all words are not in dictionary: %s"
,
line
continue
yield
[
word_slot
]
return
word_slot
def
predict
(
self
,
batch_size
):
def
batch_predict
(
batch_data
):
input
=
self
.
converter
(
self
.
get_data
(
batch_data
))
def
batch_predict
(
self
,
data_batch
):
input
=
self
.
converter
(
data_batch
)
output
=
self
.
network
.
forwardTest
(
input
)
prob
=
output
[
0
][
"value"
]
labs
=
np
.
argsort
(
-
prob
)
...
...
@@ -94,15 +88,6 @@ class SentimentPrediction():
print
(
"predicting label is %s"
%
(
self
.
label
[
lab
[
0
]]))
batch
=
[]
for
line
in
sys
.
stdin
:
batch
.
append
(
line
)
if
len
(
batch
)
==
batch_size
:
batch_predict
(
batch
)
batch
=
[]
if
len
(
batch
)
>
0
:
batch_predict
(
batch
)
def
option_parser
():
usage
=
"python predict.py -n config -w model_dir -d dictionary -i input_file "
parser
=
OptionParser
(
usage
=
"usage: %s [options]"
%
usage
)
...
...
@@ -152,8 +137,15 @@ def main():
label
=
options
.
label
swig_paddle
.
initPaddle
(
"--use_gpu=0"
)
predict
=
SentimentPrediction
(
train_conf
,
dict_file
,
model_path
,
label
)
predict
.
predict
(
batch_size
)
batch
=
[]
for
line
in
sys
.
stdin
:
batch
.
append
([
predict
.
get_index
(
line
)])
if
len
(
batch
)
==
batch_size
:
predict
.
batch_predict
(
batch
)
batch
=
[]
if
len
(
batch
)
>
0
:
predict
.
batch_predict
(
batch
)
if
__name__
==
'__main__'
:
main
()
doc/tutorials/sentiment_analysis/sentiment_analysis.md
浏览文件 @
3d289649
...
...
@@ -293,20 +293,21 @@ predict.sh:
model=model_output/pass-00002/
config=trainer_config.py
label=data/pre-imdb/labels.list
python predict.py
\
-n $config
\
-w $model
\
-b $label
\
-d data/pre-imdb/dict.txt
\
-i data/aclImdb/test/pos/10007_10.txt
```
* `predict.py`: predicting interface.
* -n $config : set network configure.
* -w $model: set model path.
* -b $label: set dictionary about corresponding relation between integer label and string label.
* -d data/pre-imdb/dict.txt: set dictionary.
* -i data/aclImdb/test/pos/10014_7.txt: set one example file to predict.
cat ./data/aclImdb/test/pos/10007_10.txt | python predict.py
\
--tconf=$config
\
--model=$model
\
--label=$label
\
--dict=./data/pre-imdb/dict.txt
\
--batch_size=1
```
* `cat ./data/aclImdb/test/pos/10007_10.txt` : the input sample.
* `predict.py` : predicting interface.
* `--tconf=$config` : set network configure.
* ` --model=$model` : set model path.
* `--label=$label` : set dictionary about corresponding relation between integer label and string label.
* `--dict=data/pre-imdb/dict.txt` : set dictionary.
* `--batch_size=1` : set batch size.
Note you should make sure the default model path `model_output/pass-00002`
exists or change the model path.
...
...
doc_cn/demo/sentiment_analysis/sentiment_analysis.md
浏览文件 @
3d289649
...
...
@@ -291,20 +291,21 @@ predict.sh:
model=model_output/pass-00002/
config=trainer_config.py
label=data/pre-imdb/labels.list
python predict.py \
-n $config\
-w $model \
-b $label \
-d data/pre-imdb/dict.txt \
-i data/aclImdb/test/pos/10007_10.txt
```
*
`predict.py`
: 预测接口脚本。
*
-n $config : 设置网络配置。
*
-w $model: 设置模型路径。
*
-b $label: 设置标签类别字典,这个字典是整数标签和字符串标签的一个对应。
*
-d data/pre-imdb/dict.txt: 设置字典文件。
*
-i data/aclImdb/test/pos/10014_7.txt: 设置一个要预测的示例文件。
cat ./data/aclImdb/test/pos/10007_10.txt | python predict.py \
--tconf=$config\
--model=$model \
--label=$label \
--dict=./data/pre-imdb/dict.txt \
--batch_size=1
```
*
`cat ./data/aclImdb/test/pos/10007_10.txt`
: 输入预测样本。
*
`predict.py`
: 预测接口脚本。
*
`--tconf=$config`
: 设置网络配置。
*
`--model=$model`
: 设置模型路径。
*
`--label=$label`
: 设置标签类别字典,这个字典是整数标签和字符串标签的一个对应。
*
`--dict=data/pre-imdb/dict.txt`
: 设置字典文件。
*
`--batch_size=1`
: 设置batch size。
注意应该确保默认模型路径
`model_output / pass-00002`
存在或更改为其它模型路径。
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录