提交 8c7cc72b 编写于 作者: Q qiaolongfei 提交者: Yu Yang

add python api_predict for quick start

上级 c91b7906
...@@ -18,13 +18,12 @@ from optparse import OptionParser ...@@ -18,13 +18,12 @@ from optparse import OptionParser
from py_paddle import swig_paddle, DataProviderConverter from py_paddle import swig_paddle, DataProviderConverter
from paddle.trainer.PyDataProvider2 import sparse_binary_vector from paddle.trainer.PyDataProvider2 import sparse_binary_vector
from paddle.trainer.config_parser import parse_config from paddle.trainer.config_parser import parse_config
""" """
Usage: run following command to show help message. Usage: run following command to show help message.
python api_predict.py -h python api_predict.py -h
""" """
class QuickStartPrediction(): class QuickStartPrediction():
def __init__(self, train_conf, dict_file, model_dir=None, label_file=None): def __init__(self, train_conf, dict_file, model_dir=None, label_file=None):
""" """
...@@ -72,9 +71,7 @@ class QuickStartPrediction(): ...@@ -72,9 +71,7 @@ class QuickStartPrediction():
transform word into integer index according to the dictionary. transform word into integer index according to the dictionary.
""" """
words = data.strip().split() words = data.strip().split()
word_slot = [ word_slot = [self.word_dict[w] for w in words if w in self.word_dict]
self.word_dict[w] for w in words if w in self.word_dict
]
return word_slot return word_slot
def batch_predict(self, data_batch): def batch_predict(self, data_batch):
...@@ -84,6 +81,7 @@ class QuickStartPrediction(): ...@@ -84,6 +81,7 @@ class QuickStartPrediction():
print("predicting labels is:") print("predicting labels is:")
print prob print prob
def option_parser(): def option_parser():
usage = "python predict.py -n config -w model_dir -d dictionary -i input_file " usage = "python predict.py -n config -w model_dir -d dictionary -i input_file "
parser = OptionParser(usage="usage: %s [options]" % usage) parser = OptionParser(usage="usage: %s [options]" % usage)
...@@ -144,5 +142,6 @@ def main(): ...@@ -144,5 +142,6 @@ def main():
print labels print labels
predict.batch_predict(batch) predict.batch_predict(batch)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册