# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from optparse import OptionParser

import numpy as np
from py_paddle import swig_paddle, DataProviderConverter
from paddle.trainer.PyDataProvider2 import integer_value_sequence
from paddle.trainer.config_parser import parse_config

"""
Usage: run the following command to show the help message.

  python predict.py -h
"""

# Index substituted for any word or predicate that is missing from the
# loaded dictionaries (see the `.get(..., UNK_IDX)` lookups below).
UNK_IDX = 0


class Prediction(object):
    """
    Label every sentence of a tab-separated data file with a trained network.

    The word, label and predicate dictionaries are loaded once at
    construction time; `predict` then converts the data file into integer
    slots, runs a forward pass and writes one labelled line per sentence.
    """

    def __init__(self, train_conf, dict_file, model_dir, label_file,
                 predicate_dict_file):
        """
        train_conf: trainer configure.
        dict_file: word dictionary file name.
        model_dir: directory of model.
        label_file: label dictionary file name.
        predicate_dict_file: predicate dictionary file name.
        """
        self.dict = {}
        self.labels = {}
        self.predicate_dict = {}
        self.labels_reverse = {}
        self.load_dict_label(dict_file, label_file, predicate_dict_file)

        len_dict = len(self.dict)
        len_label = len(self.labels)
        len_pred = len(self.predicate_dict)

        # The dictionary sizes are handed to the trainer config as a
        # comma-separated "key=value" argument string.
        conf = parse_config(
            train_conf,
            'dict_len=' + str(len_dict) +
            ',label_len=' + str(len_label) +
            ',pred_len=' + str(len_pred) +
            ',is_predict=True')
        self.network = swig_paddle.GradientMachine.createFromConfigProto(
            conf.model_config)
        self.network.loadParameters(model_dir)

        # One slot per value yielded by get_data(), in the same order:
        # words, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark.
        slots = [
            integer_value_sequence(len_dict),
            integer_value_sequence(len_pred),
            integer_value_sequence(len_dict),
            integer_value_sequence(len_dict),
            integer_value_sequence(len_dict),
            integer_value_sequence(len_dict),
            integer_value_sequence(len_dict),
            # NOTE(review): size 2 suggests the mark is a 0/1 flag per
            # token -- confirm against the training data provider.
            integer_value_sequence(2),
        ]
        self.converter = DataProviderConverter(slots)

    def load_dict_label(self, dict_file, label_file, predicate_dict_file):
        """
        Fill the word, label and predicate dictionaries from their files.

        Every file holds one entry per line; an entry's index is its
        zero-based line number. `self.labels_reverse` receives the inverse
        (index -> label) mapping. Each file is closed as soon as it is read.
        """
        with open(dict_file, 'r') as fin:
            for line_count, line in enumerate(fin):
                self.dict[line.strip()] = line_count

        with open(label_file, 'r') as fin:
            for line_count, line in enumerate(fin):
                label = line.strip()
                self.labels[label] = line_count
                self.labels_reverse[line_count] = label

        with open(predicate_dict_file, 'r') as fin:
            for line_count, line in enumerate(fin):
                self.predicate_dict[line.strip()] = line_count

    def get_data(self, data_file):
        """
        Get input data of paddle format.

        Each non-blank line of `data_file` must hold nine tab-separated
        fields: sentence, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2,
        mark and label. Yields one tuple of eight integer lists per line
        (the label field is not used for prediction). Blank lines are
        skipped, matching what predict() does when it re-reads the file.
        """
        with open(data_file, 'r') as fdata:
            for line in fdata:
                line = line.strip()
                if not line:
                    continue
                (sentence, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2,
                 mark, _) = line.split('\t')
                words = sentence.split()
                sen_len = len(words)

                word_slot = [self.dict.get(w, UNK_IDX) for w in words]
                # The predicate and the context words are single tokens that
                # get repeated once per word of the sentence.
                predicate_slot = [
                    self.predicate_dict.get(predicate, UNK_IDX)
                ] * sen_len
                ctx_n2_slot = [self.dict.get(ctx_n2, UNK_IDX)] * sen_len
                ctx_n1_slot = [self.dict.get(ctx_n1, UNK_IDX)] * sen_len
                ctx_0_slot = [self.dict.get(ctx_0, UNK_IDX)] * sen_len
                ctx_p1_slot = [self.dict.get(ctx_p1, UNK_IDX)] * sen_len
                ctx_p2_slot = [self.dict.get(ctx_p2, UNK_IDX)] * sen_len

                mark_slot = [int(w) for w in mark.split()]

                yield word_slot, predicate_slot, ctx_n2_slot, ctx_n1_slot, \
                      ctx_0_slot, ctx_p1_slot, ctx_p2_slot, mark_slot

    def predict(self, data_file, output_file):
        """
        data_file: file name of input data.
        output_file: file name the labelled sentences are written to.

        Writes one line per input sentence: the sentence, a tab, then the
        space-separated predicted labels.
        """
        batch = self.converter(self.get_data(data_file))
        output = self.network.forwardTest(batch)
        # Treated below as one flat list of label ids covering all
        # sentences back to back.
        lab = output[0]["id"].tolist()

        with open(data_file, 'r') as fin, open(output_file, 'w') as fout:
            index = 0
            for line in fin:
                # Skip blank lines exactly as get_data() does so that the
                # running offset into `lab` stays aligned.
                if not line.strip():
                    continue
                sen = line.split('\t')[0]
                len_sen = len(sen.split())
                line_labels = lab[index:index + len_sen]
                index += len_sen
                fout.write(sen + '\t' + ' '.join(
                    [self.labels_reverse[i] for i in line_labels]) + '\n')


def option_parser():
    """
    Parse the command line of predict.py.

    Returns the `(options, args)` pair produced by
    `OptionParser.parse_args()`. Every option that declares a default
    declares `None`.
    """
    # The usage line lists every option the script needs, including the
    # output file (-o).
    usage = ("python predict.py -c config -w model_dir "
             "-d word dictionary -l label_file -i input_file  -p pred_dict_file"
             " -o output_file")
    parser = OptionParser(usage="usage: %s [options]" % usage)
    parser.add_option(
        "-c",
        "--tconf",
        action="store",
        dest="train_conf",
        help="network config")
    parser.add_option(
        "-d",
        "--dict",
        action="store",
        dest="dict_file",
        help="dictionary file")
    parser.add_option(
        "-l",
        "--label",
        action="store",
        dest="label_file",
        default=None,
        help="label file")
    parser.add_option(
        "-p",
        "--predict_dict_file",
        action="store",
        dest="predict_dict_file",
        default=None,
        help="predict_dict_file")
    parser.add_option(
        "-i",
        "--data",
        action="store",
        dest="data_file",
        help="data file to predict")
    parser.add_option(
        "-w",
        "--model",
        action="store",
        dest="model_path",
        default=None,
        help="model path")
    parser.add_option(
        "-o",
        "--output_file",
        action="store",
        dest="output_file",
        default=None,
        help="output file")
    return parser.parse_args()


def main():
    """
    Script entry point: parse the options, load the model, run prediction.

    Exits with a message naming the missing options when any of the file
    arguments that this script opens itself was not supplied.
    """
    options, _ = option_parser()
    train_conf = options.train_conf
    data_file = options.data_file
    dict_file = options.dict_file
    model_path = options.model_path
    label_file = options.label_file
    predict_dict_file = options.predict_dict_file
    output_file = options.output_file

    # These five paths are passed straight to open(); leaving one at its
    # None default would otherwise surface as an opaque TypeError after the
    # model has already been loaded.
    file_options = (
        ("-d/--dict", dict_file),
        ("-l/--label", label_file),
        ("-p/--predict_dict_file", predict_dict_file),
        ("-i/--data", data_file),
        ("-o/--output_file", output_file),
    )
    missing = [flag for flag, value in file_options if value is None]
    if missing:
        raise SystemExit("missing required option(s): " + ", ".join(missing) +
                         " (run with -h for help)")

    swig_paddle.initPaddle("--use_gpu=0")
    predictor = Prediction(train_conf, dict_file, model_path, label_file,
                           predict_dict_file)
    predictor.predict(data_file, output_file)


if __name__ == '__main__':
    main()