From aaecfcc47f0319cb74b80a7e095a8406ce8354ff Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Thu, 8 Dec 2016 11:03:21 +0800 Subject: [PATCH] Support predicting the samples from sys.stdin --- demo/sentiment/predict.py | 73 ++++++++++++++++++++++----------------- demo/sentiment/predict.sh | 12 +++---- 2 files changed, 47 insertions(+), 38 deletions(-) diff --git a/demo/sentiment/predict.py b/demo/sentiment/predict.py index bc0f6f31264..e01dc6d2282 100755 --- a/demo/sentiment/predict.py +++ b/demo/sentiment/predict.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os +import os, sys import numpy as np from optparse import OptionParser from py_paddle import swig_paddle, DataProviderConverter @@ -66,35 +66,42 @@ class SentimentPrediction(): for v in open(label_file, 'r'): self.label[int(v.split('\t')[1])] = v.split('\t')[0] - def get_data(self, data_file): + def get_data(self, data): """ Get input data of paddle format. """ - with open(data_file, 'r') as fdata: - for line in fdata: - words = line.strip().split() - word_slot = [ - self.word_dict[w] for w in words if w in self.word_dict - ] - if not word_slot: - print "all words are not in dictionary: %s", line - continue - yield [word_slot] - - def predict(self, data_file): - """ - data_file: file name of input data. - """ - input = self.converter(self.get_data(data_file)) - output = self.network.forwardTest(input) - prob = output[0]["value"] - lab = np.argsort(-prob) - if self.label is None: - print("%s: predicting label is %d" % (data_file, lab[0][0])) - else: - print("%s: predicting label is %s" % - (data_file, self.label[lab[0][0]])) + for line in data: + words = line.strip().split() + word_slot = [ + self.word_dict[w] for w in words if w in self.word_dict + ] + if not word_slot: + print "all words are not in dictionary: %s", line + continue + yield [word_slot] + + def predict(self, batch_size): + + def batch_predict(batch_data): + input = self.converter(self.get_data(batch_data)) + output = self.network.forwardTest(input) + prob = output[0]["value"] + labs = np.argsort(-prob) + for idx, lab in enumerate(labs): + if self.label is None: + print("predicting label is %d" % (lab[0])) + else: + print("predicting label is %s" % + (self.label[lab[0]])) + batch = [] + for line in sys.stdin: + batch.append(line) + if len(batch) == batch_size: + batch_predict(batch) + batch=[] + if len(batch) > 0: + batch_predict(batch) def option_parser(): usage = "python predict.py -n config -w model_dir -d dictionary -i input_file " @@ -119,11 +126,13 @@ def option_parser(): default=None, help="dictionary file") parser.add_option( - "-i", - "--data", + "-c", + "--batch_size", + type="int", action="store", - dest="data", - help="data file to predict") + dest="batch_size", + default=1, + help="the batch size for prediction") parser.add_option( "-w", "--model", @@ -137,13 +146,13 @@ def option_parser(): def main(): options, args = option_parser() train_conf = options.train_conf - data = options.data + batch_size = options.batch_size dict_file = options.dict_file model_path = options.model_path label = options.label swig_paddle.initPaddle("--use_gpu=0") predict = SentimentPrediction(train_conf, dict_file, model_path, label) - predict.predict(data) + predict.predict(batch_size) if __name__ == '__main__': diff --git a/demo/sentiment/predict.sh b/demo/sentiment/predict.sh index 053f23e491a..219d2d20250 100755 --- a/demo/sentiment/predict.sh +++ b/demo/sentiment/predict.sh @@ -19,9 +19,9 @@ set -e model=model_output/pass-00002/ config=trainer_config.py label=data/pre-imdb/labels.list -python predict.py \ - -n $config\ - -w $model \ - -b $label \ - -d ./data/pre-imdb/dict.txt \ - -i ./data/aclImdb/test/pos/10007_10.txt +cat ./data/aclImdb/test/pos/10007_10.txt | python predict.py \ + --tconf=$config\ + --model=$model \ + --label=$label \ + --dict=./data/pre-imdb/dict.txt \ + --batch_size=1 -- GitLab