提交 b3f0f3d2 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #766 from qingqing01/sentiment

Support predicting the samples from sys.stdin
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os, sys
import numpy as np import numpy as np
from optparse import OptionParser from optparse import OptionParser
from py_paddle import swig_paddle, DataProviderConverter from py_paddle import swig_paddle, DataProviderConverter
...@@ -66,35 +66,27 @@ class SentimentPrediction(): ...@@ -66,35 +66,27 @@ class SentimentPrediction():
for v in open(label_file, 'r'): for v in open(label_file, 'r'):
self.label[int(v.split('\t')[1])] = v.split('\t')[0] self.label[int(v.split('\t')[1])] = v.split('\t')[0]
def get_data(self, data_file): def get_index(self, data):
""" """
Get input data of paddle format. transform word into integer index according to the dictionary.
""" """
with open(data_file, 'r') as fdata: words = data.strip().split()
for line in fdata: word_slot = [
words = line.strip().split() self.word_dict[w] for w in words if w in self.word_dict
word_slot = [ ]
self.word_dict[w] for w in words if w in self.word_dict return word_slot
]
if not word_slot:
print "all words are not in dictionary: %s", line
continue
yield [word_slot]
def predict(self, data_file): def batch_predict(self, data_batch):
""" input = self.converter(data_batch)
data_file: file name of input data.
"""
input = self.converter(self.get_data(data_file))
output = self.network.forwardTest(input) output = self.network.forwardTest(input)
prob = output[0]["value"] prob = output[0]["value"]
lab = np.argsort(-prob) labs = np.argsort(-prob)
if self.label is None: for idx, lab in enumerate(labs):
print("%s: predicting label is %d" % (data_file, lab[0][0])) if self.label is None:
else: print("predicting label is %d" % (lab[0]))
print("%s: predicting label is %s" % else:
(data_file, self.label[lab[0][0]])) print("predicting label is %s" %
(self.label[lab[0]]))
def option_parser(): def option_parser():
usage = "python predict.py -n config -w model_dir -d dictionary -i input_file " usage = "python predict.py -n config -w model_dir -d dictionary -i input_file "
...@@ -119,11 +111,13 @@ def option_parser(): ...@@ -119,11 +111,13 @@ def option_parser():
default=None, default=None,
help="dictionary file") help="dictionary file")
parser.add_option( parser.add_option(
"-i", "-c",
"--data", "--batch_size",
type="int",
action="store", action="store",
dest="data", dest="batch_size",
help="data file to predict") default=1,
help="the batch size for prediction")
parser.add_option( parser.add_option(
"-w", "-w",
"--model", "--model",
...@@ -137,14 +131,21 @@ def option_parser(): ...@@ -137,14 +131,21 @@ def option_parser():
def main(): def main():
options, args = option_parser() options, args = option_parser()
train_conf = options.train_conf train_conf = options.train_conf
data = options.data batch_size = options.batch_size
dict_file = options.dict_file dict_file = options.dict_file
model_path = options.model_path model_path = options.model_path
label = options.label label = options.label
swig_paddle.initPaddle("--use_gpu=0") swig_paddle.initPaddle("--use_gpu=0")
predict = SentimentPrediction(train_conf, dict_file, model_path, label) predict = SentimentPrediction(train_conf, dict_file, model_path, label)
predict.predict(data)
batch = []
for line in sys.stdin:
batch.append([predict.get_index(line)])
if len(batch) == batch_size:
predict.batch_predict(batch)
batch=[]
if len(batch) > 0:
predict.batch_predict(batch)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
...@@ -19,9 +19,9 @@ set -e ...@@ -19,9 +19,9 @@ set -e
model=model_output/pass-00002/ model=model_output/pass-00002/
config=trainer_config.py config=trainer_config.py
label=data/pre-imdb/labels.list label=data/pre-imdb/labels.list
python predict.py \ cat ./data/aclImdb/test/pos/10007_10.txt | python predict.py \
-n $config\ --tconf=$config\
-w $model \ --model=$model \
-b $label \ --label=$label \
-d ./data/pre-imdb/dict.txt \ --dict=./data/pre-imdb/dict.txt \
-i ./data/aclImdb/test/pos/10007_10.txt --batch_size=1
...@@ -293,20 +293,21 @@ predict.sh: ...@@ -293,20 +293,21 @@ predict.sh:
model=model_output/pass-00002/ model=model_output/pass-00002/
config=trainer_config.py config=trainer_config.py
label=data/pre-imdb/labels.list label=data/pre-imdb/labels.list
python predict.py \ cat ./data/aclImdb/test/pos/10007_10.txt | python predict.py \
-n $config\ --tconf=$config\
-w $model \ --model=$model \
-b $label \ --label=$label \
-d data/pre-imdb/dict.txt \ --dict=./data/pre-imdb/dict.txt \
-i data/aclImdb/test/pos/10007_10.txt --batch_size=1
``` ```
* `predict.py`: predicting interface. * `cat ./data/aclImdb/test/pos/10007_10.txt` : the input sample.
* -n $config : set network configure. * `predict.py` : predicting interface.
* -w $model: set model path. * `--tconf=$config` : set network configure.
* -b $label: set dictionary about corresponding relation between integer label and string label. * ` --model=$model` : set model path.
* -d data/pre-imdb/dict.txt: set dictionary. * `--label=$label` : set dictionary about corresponding relation between integer label and string label.
* -i data/aclImdb/test/pos/10014_7.txt: set one example file to predict. * `--dict=data/pre-imdb/dict.txt` : set dictionary.
* `--batch_size=1` : set batch size.
Note you should make sure the default model path `model_output/pass-00002` Note you should make sure the default model path `model_output/pass-00002`
exists or change the model path. exists or change the model path.
......
...@@ -291,20 +291,21 @@ predict.sh: ...@@ -291,20 +291,21 @@ predict.sh:
model=model_output/pass-00002/ model=model_output/pass-00002/
config=trainer_config.py config=trainer_config.py
label=data/pre-imdb/labels.list label=data/pre-imdb/labels.list
python predict.py \ cat ./data/aclImdb/test/pos/10007_10.txt | python predict.py \
-n $config\ --tconf=$config\
-w $model \ --model=$model \
-b $label \ --label=$label \
-d data/pre-imdb/dict.txt \ --dict=./data/pre-imdb/dict.txt \
-i data/aclImdb/test/pos/10007_10.txt --batch_size=1
``` ```
* `predict.py`: 预测接口脚本。 * `cat ./data/aclImdb/test/pos/10007_10.txt` : 输入预测样本。
* -n $config : 设置网络配置。 * `predict.py` : 预测接口脚本。
* -w $model: 设置模型路径。 * `--tconf=$config` : 设置网络配置。
* -b $label: 设置标签类别字典,这个字典是整数标签和字符串标签的一个对应。 * `--model=$model` : 设置模型路径。
* -d data/pre-imdb/dict.txt: 设置字典文件。 * `--label=$label` : 设置标签类别字典,这个字典是整数标签和字符串标签的一个对应。
* -i data/aclImdb/test/pos/10014_7.txt: 设置一个要预测的示例文件。 * `--dict=data/pre-imdb/dict.txt` : 设置字典文件。
* `--batch_size=1` : 设置batch size。
注意应该确保默认模型路径`model_output / pass-00002`存在或更改为其它模型路径。 注意应该确保默认模型路径`model_output / pass-00002`存在或更改为其它模型路径。
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册