From 3d2896493b90283b6b5b8e49869d90f656f39783 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Fri, 9 Dec 2016 14:00:04 +0800 Subject: [PATCH] follow comments --- demo/sentiment/predict.py | 62 ++++++++----------- .../sentiment_analysis/sentiment_analysis.md | 29 ++++----- .../sentiment_analysis/sentiment_analysis.md | 29 ++++----- 3 files changed, 57 insertions(+), 63 deletions(-) diff --git a/demo/sentiment/predict.py b/demo/sentiment/predict.py index e01dc6d2282..3920a1bade6 100755 --- a/demo/sentiment/predict.py +++ b/demo/sentiment/predict.py @@ -66,42 +66,27 @@ class SentimentPrediction(): for v in open(label_file, 'r'): self.label[int(v.split('\t')[1])] = v.split('\t')[0] - def get_data(self, data): + def get_index(self, data): """ - Get input data of paddle format. + transform word into integer index according to the dictionary. """ - for line in data: - words = line.strip().split() - word_slot = [ - self.word_dict[w] for w in words if w in self.word_dict - ] - if not word_slot: - print "all words are not in dictionary: %s", line - continue - yield [word_slot] - - def predict(self, batch_size): - - def batch_predict(batch_data): - input = self.converter(self.get_data(batch_data)) - output = self.network.forwardTest(input) - prob = output[0]["value"] - labs = np.argsort(-prob) - for idx, lab in enumerate(labs): - if self.label is None: - print("predicting label is %d" % (lab[0])) - else: - print("predicting label is %s" % - (self.label[lab[0]])) - - batch = [] - for line in sys.stdin: - batch.append(line) - if len(batch) == batch_size: - batch_predict(batch) - batch=[] - if len(batch) > 0: - batch_predict(batch) + words = data.strip().split() + word_slot = [ + self.word_dict[w] for w in words if w in self.word_dict + ] + return word_slot + + def batch_predict(self, data_batch): + input = self.converter(data_batch) + output = self.network.forwardTest(input) + prob = output[0]["value"] + labs = np.argsort(-prob) + for idx, lab in enumerate(labs): + if self.label is None: + print("predicting label is %d" % (lab[0])) + else: + print("predicting label is %s" % + (self.label[lab[0]])) def option_parser(): usage = "python predict.py -n config -w model_dir -d dictionary -i input_file " @@ -152,8 +137,15 @@ def main(): label = options.label swig_paddle.initPaddle("--use_gpu=0") predict = SentimentPrediction(train_conf, dict_file, model_path, label) - predict.predict(batch_size) + batch = [] + for line in sys.stdin: + batch.append([predict.get_index(line)]) + if len(batch) == batch_size: + predict.batch_predict(batch) + batch=[] + if len(batch) > 0: + predict.batch_predict(batch) if __name__ == '__main__': main() diff --git a/doc/tutorials/sentiment_analysis/sentiment_analysis.md b/doc/tutorials/sentiment_analysis/sentiment_analysis.md index c53952c544d..bb7681db44c 100644 --- a/doc/tutorials/sentiment_analysis/sentiment_analysis.md +++ b/doc/tutorials/sentiment_analysis/sentiment_analysis.md @@ -293,20 +293,21 @@ predict.sh: model=model_output/pass-00002/ config=trainer_config.py label=data/pre-imdb/labels.list -python predict.py \ - -n $config\ - -w $model \ - -b $label \ - -d data/pre-imdb/dict.txt \ - -i data/aclImdb/test/pos/10007_10.txt -``` - -* `predict.py`: predicting interface. -* -n $config : set network configure. -* -w $model: set model path. -* -b $label: set dictionary about corresponding relation between integer label and string label. -* -d data/pre-imdb/dict.txt: set dictionary. -* -i data/aclImdb/test/pos/10014_7.txt: set one example file to predict. +cat ./data/aclImdb/test/pos/10007_10.txt | python predict.py \ + --tconf=$config\ + --model=$model \ + --label=$label \ + --dict=./data/pre-imdb/dict.txt \ + --batch_size=1 +``` + +* `cat ./data/aclImdb/test/pos/10007_10.txt` : the input sample. +* `predict.py` : predicting interface. +* `--tconf=$config` : set network configure. +* ` --model=$model` : set model path. +* `--label=$label` : set dictionary about corresponding relation between integer label and string label. +* `--dict=data/pre-imdb/dict.txt` : set dictionary. +* `--batch_size=1` : set batch size. Note you should make sure the default model path `model_output/pass-00002` exists or change the model path. diff --git a/doc_cn/demo/sentiment_analysis/sentiment_analysis.md b/doc_cn/demo/sentiment_analysis/sentiment_analysis.md index b70f2d59675..ba307e97e30 100644 --- a/doc_cn/demo/sentiment_analysis/sentiment_analysis.md +++ b/doc_cn/demo/sentiment_analysis/sentiment_analysis.md @@ -291,20 +291,21 @@ predict.sh: model=model_output/pass-00002/ config=trainer_config.py label=data/pre-imdb/labels.list -python predict.py \ - -n $config\ - -w $model \ - -b $label \ - -d data/pre-imdb/dict.txt \ - -i data/aclImdb/test/pos/10007_10.txt -``` - -* `predict.py`: 预测接口脚本。 -* -n $config : 设置网络配置。 -* -w $model: 设置模型路径。 -* -b $label: 设置标签类别字典,这个字典是整数标签和字符串标签的一个对应。 -* -d data/pre-imdb/dict.txt: 设置字典文件。 -* -i data/aclImdb/test/pos/10014_7.txt: 设置一个要预测的示例文件。 +cat ./data/aclImdb/test/pos/10007_10.txt | python predict.py \ + --tconf=$config\ + --model=$model \ + --label=$label \ + --dict=./data/pre-imdb/dict.txt \ + --batch_size=1 +``` + +* `cat ./data/aclImdb/test/pos/10007_10.txt` : 输入预测样本。 +* `predict.py` : 预测接口脚本。 +* `--tconf=$config` : 设置网络配置。 +* `--model=$model` : 设置模型路径。 +* `--label=$label` : 设置标签类别字典,这个字典是整数标签和字符串标签的一个对应。 +* `--dict=data/pre-imdb/dict.txt` : 设置字典文件。 +* `--batch_size=1` : 设置batch size。 注意应该确保默认模型路径`model_output / pass-00002`存在或更改为其它模型路径。 -- GitLab