提交 aaecfcc4 编写于 作者: D dangqingqing

Support predicting the samples from sys.stdin

上级 db379811
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os, sys
import numpy as np
from optparse import OptionParser
from py_paddle import swig_paddle, DataProviderConverter
......@@ -66,12 +66,11 @@ class SentimentPrediction():
for v in open(label_file, 'r'):
self.label[int(v.split('\t')[1])] = v.split('\t')[0]
def get_data(self, data_file):
def get_data(self, data):
"""
Get input data of paddle format.
"""
with open(data_file, 'r') as fdata:
for line in fdata:
for line in data:
words = line.strip().split()
word_slot = [
self.word_dict[w] for w in words if w in self.word_dict
......@@ -81,20 +80,28 @@ class SentimentPrediction():
continue
yield [word_slot]
def predict(self, data_file):
"""
data_file: file name of input data.
"""
input = self.converter(self.get_data(data_file))
def predict(self, batch_size):
def batch_predict(batch_data):
input = self.converter(self.get_data(batch_data))
output = self.network.forwardTest(input)
prob = output[0]["value"]
lab = np.argsort(-prob)
labs = np.argsort(-prob)
for idx, lab in enumerate(labs):
if self.label is None:
print("%s: predicting label is %d" % (data_file, lab[0][0]))
print("predicting label is %d" % (lab[0]))
else:
print("%s: predicting label is %s" %
(data_file, self.label[lab[0][0]]))
print("predicting label is %s" %
(self.label[lab[0]]))
batch = []
for line in sys.stdin:
batch.append(line)
if len(batch) == batch_size:
batch_predict(batch)
batch=[]
if len(batch) > 0:
batch_predict(batch)
def option_parser():
usage = "python predict.py -n config -w model_dir -d dictionary -i input_file "
......@@ -119,11 +126,13 @@ def option_parser():
default=None,
help="dictionary file")
parser.add_option(
"-i",
"--data",
"-c",
"--batch_size",
type="int",
action="store",
dest="data",
help="data file to predict")
dest="batch_size",
default=1,
help="the batch size for prediction")
parser.add_option(
"-w",
"--model",
......@@ -137,13 +146,13 @@ def option_parser():
def main():
options, args = option_parser()
train_conf = options.train_conf
data = options.data
batch_size = options.batch_size
dict_file = options.dict_file
model_path = options.model_path
label = options.label
swig_paddle.initPaddle("--use_gpu=0")
predict = SentimentPrediction(train_conf, dict_file, model_path, label)
predict.predict(data)
predict.predict(batch_size)
if __name__ == '__main__':
......
......@@ -19,9 +19,9 @@ set -e
model=model_output/pass-00002/
config=trainer_config.py
label=data/pre-imdb/labels.list
python predict.py \
-n $config\
-w $model \
-b $label \
-d ./data/pre-imdb/dict.txt \
-i ./data/aclImdb/test/pos/10007_10.txt
cat ./data/aclImdb/test/pos/10007_10.txt | python predict.py \
--tconf=$config\
--model=$model \
--label=$label \
--dict=./data/pre-imdb/dict.txt \
--batch_size=1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册