predict.py 4.7 KB
Newer Older
1
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
Z
zhangjinchao01 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os, sys
Z
zhangjinchao01 已提交
16 17
import numpy as np
from optparse import OptionParser
18 19
from py_paddle import swig_paddle, DataProviderConverter
from paddle.trainer.PyDataProvider2 import integer_value_sequence
Z
zhangjinchao01 已提交
20 21 22
from paddle.trainer.config_parser import parse_config
"""
Usage: run the following command to show the help message.
23
  python predict.py -h
Z
zhangjinchao01 已提交
24 25
"""

26

Z
zhangjinchao01 已提交
27
class SentimentPrediction(object):
    """
    Predict labels for text with a network loaded through py_paddle.

    The word dictionary, the optional label map and the network parameters
    are loaded once in the constructor. Afterwards get_index() converts a
    line of text into word ids and batch_predict() prints one predicted
    label per sample of a batch.
    """

    def __init__(self, train_conf, dict_file, model_dir=None, label_file=None):
        """
        train_conf: trainer configure.
        dict_file: word dictionary file name.
        model_dir: directory of model. When None, the directory that
                   contains train_conf is used.
        label_file: optional label file (see load_label). When None,
                    predictions are printed as integer ids.
        """
        self.train_conf = train_conf
        self.dict_file = dict_file
        self.word_dict = {}
        self.dict_dim = self.load_dict()
        self.model_dir = model_dir
        if model_dir is None:
            self.model_dir = os.path.dirname(train_conf)

        self.label = None
        if label_file is not None:
            self.load_label(label_file)

        # The config is parsed with the "is_predict=1" argument string.
        conf = parse_config(train_conf, "is_predict=1")
        self.network = swig_paddle.GradientMachine.createFromConfigProto(
            conf.model_config)
        self.network.loadParameters(self.model_dir)
        # One input slot, declared as integer_value_sequence(dict_dim).
        input_types = [integer_value_sequence(self.dict_dim)]
        self.converter = DataProviderConverter(input_types)

    def load_dict(self):
        """
        Load dictionary from self.dict_file.

        The word is the first tab-separated field of each line and its id is
        the zero-based line number. Returns the number of distinct words.
        """
        # Use a context manager so the file handle is closed deterministically.
        with open(self.dict_file, 'r') as dict_f:
            for line_count, line in enumerate(dict_f):
                self.word_dict[line.strip().split('\t')[0]] = line_count
        return len(self.word_dict)

    def load_label(self, label_file):
        """
        Load label.

        Every non-blank line is "<name>\\t<id>"; the result is stored in
        self.label as a dict mapping the integer id to the name.
        """
        self.label = {}
        with open(label_file, 'r') as label_f:
            for line in label_f:
                # Skip blank lines (e.g. a trailing empty line) instead of
                # failing with IndexError on the missing id column.
                if not line.strip():
                    continue
                fields = line.split('\t')
                self.label[int(fields[1])] = fields[0]

    def get_index(self, data):
        """
        transform word into integer index according to the dictionary.

        data is split on whitespace; words that are not in the dictionary
        are dropped, so the returned list may be empty.
        """
        words = data.strip().split()
        word_slot = [self.word_dict[w] for w in words if w in self.word_dict]
        return word_slot

    def batch_predict(self, data_batch):
        """
        Run the network on data_batch and print one label per sample.

        data_batch is passed to the converter built in __init__. If a label
        map was loaded the label name is printed, otherwise the integer id.
        """
        net_input = self.converter(data_batch)
        output = self.network.forwardTest(net_input)
        # presumably a (batch, num_classes) score array -- TODO confirm
        prob = output[0]["value"]
        # Negating before argsort orders indices by descending score, so the
        # first index of each row is the highest-scoring class.
        labs = np.argsort(-prob)
        for lab in labs:
            if self.label is None:
                print("predicting label is %d" % (lab[0]))
            else:
                print("predicting label is %s" % (self.label[lab[0]]))

Z
zhangjinchao01 已提交
88 89 90 91

def option_parser(args=None):
    """
    Define the command-line options of predict.py and parse them.

    args: optional list of argument strings. When None, OptionParser falls
          back to sys.argv[1:], which is the original behaviour.
    Returns the (options, args) pair produced by OptionParser.parse_args.
    """
    # Sentences are read from stdin, so the usage shows a redirect rather
    # than a (non-existent) input-file option.
    usage = ("python predict.py -n config -d dictionary [-w model_dir] "
             "[-b label_file] [-c batch_size] < input_file")
    parser = OptionParser(usage="usage: %s" % usage)
    parser.add_option(
        "-n",
        "--tconf",
        action="store",
        dest="train_conf",
        help="network config")
    parser.add_option(
        "-d",
        "--dict",
        action="store",
        dest="dict_file",
        help="dictionary file")
    parser.add_option(
        "-b",
        "--label",
        action="store",
        dest="label",
        default=None,
        help="label file")
    parser.add_option(
        "-c",
        "--batch_size",
        type="int",
        action="store",
        dest="batch_size",
        default=1,
        help="the batch size for prediction")
    parser.add_option(
        "-w",
        "--model",
        action="store",
        dest="model_path",
        default=None,
        help="model path")
    return parser.parse_args(args)

128

Z
zhangjinchao01 已提交
129 130 131
def main():
    """
    Read sentences from stdin and print a predicted label for each one.

    Every input line is converted to word ids. Lines with at least one known
    word are collected into batches of options.batch_size samples and passed
    to SentimentPrediction.batch_predict; lines without any known word are
    reported and skipped. A final partial batch is predicted after the input
    is exhausted.
    """
    options, _ = option_parser()
    batch_size = options.batch_size
    swig_paddle.initPaddle("--use_gpu=0")
    predict = SentimentPrediction(options.train_conf, options.dict_file,
                                  options.model_path, options.label)

    batch = []
    for line in sys.stdin:
        words = predict.get_index(line)
        if words:
            # Each sample is a list holding one word-id sequence.
            batch.append([words])
        else:
            # Strip the line so its trailing newline does not split the
            # bracketed message across two output lines.
            print('All the words in [%s] are not in the dictionary.' %
                  line.strip())
        if len(batch) == batch_size:
            predict.batch_predict(batch)
            batch = []
    # Flush the samples left over when the input size is not a multiple of
    # the batch size.
    if batch:
        predict.batch_predict(batch)
151

Y
Yu Yang 已提交
152

Z
zhangjinchao01 已提交
153 154
if __name__ == "__main__":
    # Only start predicting when the file is executed as a script.
    main()