diff --git a/demo/quick_start/api_predict.py b/demo/quick_start/api_predict.py index 9c224e3cdbab692cb18221aa193cbb9b699a3117..a1a9ef7bca3301f260a0ad899e46ec1b4a81b052 100755 --- a/demo/quick_start/api_predict.py +++ b/demo/quick_start/api_predict.py @@ -18,13 +18,12 @@ from optparse import OptionParser from py_paddle import swig_paddle, DataProviderConverter from paddle.trainer.PyDataProvider2 import sparse_binary_vector from paddle.trainer.config_parser import parse_config - - """ Usage: run following command to show help message. python api_predict.py -h """ + class QuickStartPrediction(): def __init__(self, train_conf, dict_file, model_dir=None, label_file=None): """ @@ -72,9 +71,7 @@ class QuickStartPrediction(): transform word into integer index according to the dictionary. """ words = data.strip().split() - word_slot = [ - self.word_dict[w] for w in words if w in self.word_dict - ] + word_slot = [self.word_dict[w] for w in words if w in self.word_dict] return word_slot def batch_predict(self, data_batch): @@ -84,6 +81,7 @@ class QuickStartPrediction(): print("predicting labels is:") print prob + def option_parser(): usage = "python predict.py -n config -w model_dir -d dictionary -i input_file " parser = OptionParser(usage="usage: %s [options]" % usage) @@ -144,5 +142,6 @@ def main(): print labels predict.batch_predict(batch) + if __name__ == '__main__': main()