import time
import os
import paddle.fluid as fluid
import numpy as np
import logging
import argparse
import preprocess
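
# Offline inference for the word2vec example: load the saved "embeding" parameter and
# evaluate word-analogy test cases (a - b + c ~= d) by cosine similarity over the vocabulary.
#
# Example usage (paths are illustrative, flags are defined in parse_args below):
#   python infer.py --infer_once --model_output_dir models/pass-0 --dict_path ./data/1-billion_dict
#   python infer.py --infer_during_train --model_output_dir models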

word_to_id = dict()
id_to_word = dict()

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser(
        description="PaddlePaddle Word2vec infer example")
    parser.add_argument(
        '--dict_path',
        type=str,
        default='./data/1-billion_dict',
        help="The path of training dataset")
    parser.add_argument(
        '--model_output_dir',
        type=str,
        default='models',
        help="The path for model to store (with infer_once please set specify dir to models) (default: models)"
    )
    parser.add_argument(
        '--rank_num',
        type=int,
        default=4,
        help="find rank_num-nearest result for test (default: 4)")
    parser.add_argument(
        '--infer_once',
        action='store_true',
        required=False,
        default=False,
        help='run inference once on a single saved model directory (default: False)')
    parser.add_argument(
        '--infer_during_train',
        action='store_true',
        required=False,
        default=True,
        help='keep polling for and evaluating new models saved during training (default: True)')
    parser.add_argument(
        '--test_acc',
        action='store_true',
        required=False,
        default=False,
        help='evaluate analogy accuracy on the files under --test_files_dir (default: False)')
    parser.add_argument(
        '--test_files_dir',
        type=str,
        default='test',
        help="The path for test_files) (default: test)")
    parser.add_argument(
        '--test_batch_size',
        type=int,
        default=1000,
        help="test used batch size (default: 1000)")

    return parser.parse_args()


def BuildWord_IdMap(dict_path):
    # each line of the "<dict_path>_word_to_id_" file is "<word> <id>"
    with open(dict_path + "_word_to_id_", 'r') as f:
        for line in f:
            fields = line.split(' ')
            word_to_id[fields[0]] = int(fields[1])
            id_to_word[int(fields[1])] = fields[0]


def inference_prog():  # just to create program for test
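    # The [1, 1] shape is only a placeholder: creating a parameter named "embeding"
    # registers it in the default program so fluid.io.load_persistables can fill it
    # with the trained embedding table from the checkpoint.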
    fluid.layers.create_parameter(
        shape=[1, 1], dtype='float32', name="embeding")


def build_test_case_from_file(args, emb):
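    # Each test file holds analogy questions "a b c d" (evaluated as a - b + c ~= d);
    # lines containing ':' are treated as section headers and only logged.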
    logger.info("test files dir: {}".format(args.test_files_dir))
    current_list = os.listdir(args.test_files_dir)
    logger.info("test files list: {}".format(current_list))
    test_cases = list()
    test_labels = list()
    test_case_descs = list()
    exclude_lists = list()
    for file_dir in current_list:
        with open(args.test_files_dir + "/" + file_dir, 'r') as f:
            for line in f:
                if ':' in line:
                    # lines containing ':' are section headers; just log them
                    logger.info("{}".format(line))
                else:
                    line = preprocess.strip_lines(line, word_to_id)
                    words = line.split()
                    # analogy a - b + c ~= d: build the query vector from the first three words
                    test_case = emb[word_to_id[words[0]]] - emb[word_to_id[
                        words[1]]] + emb[word_to_id[words[2]]]
                    test_case_desc = words[0] + " - " + words[1] + " + " + words[
                        2] + " = " + words[3]
                    test_cases.append(test_case)
                    test_case_descs.append(test_case_desc)
                    test_labels.append(word_to_id[words[3]])
                    # the three query words are excluded from the ranking
                    exclude_lists.append([
                        word_to_id[words[0]], word_to_id[words[1]],
                        word_to_id[words[2]]
                    ])
    # normalize once after all files are read so every query vector has unit length
    test_cases = norm(np.array(test_cases))
    return test_cases, test_case_descs, test_labels, exclude_lists


def build_small_test_case(emb):
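    # Hand-written analogy cases plus a few single-word nearest-neighbour checks,
    # used when --test_acc is not set.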
    emb1 = emb[word_to_id['boy']] - emb[word_to_id['girl']] + emb[word_to_id[
        'aunt']]
    desc1 = "boy - girl + aunt = uncle"
    label1 = word_to_id["uncle"]
    emb2 = emb[word_to_id['brother']] - emb[word_to_id['sister']] + emb[
        word_to_id['sisters']]
    desc2 = "brother - sister + sisters = brothers"
    label2 = word_to_id["brothers"]
    emb3 = emb[word_to_id['king']] - emb[word_to_id['queen']] + emb[word_to_id[
        'woman']]
    desc3 = "king - queen + woman = man"
    label3 = word_to_id["man"]
    emb4 = emb[word_to_id['reluctant']] - emb[word_to_id['reluctantly']] + emb[
        word_to_id['slowly']]
    desc4 = "reluctant - reluctantly + slowly = slow"
    label4 = word_to_id["slow"]
    emb5 = emb[word_to_id['old']] - emb[word_to_id['older']] + emb[word_to_id[
        'deeper']]
    desc5 = "old - older + deeper = deep"
    label5 = word_to_id["deep"]

    emb6 = emb[word_to_id['boy']]
    desc6 = "boy"
    label6 = word_to_id["boy"]
    emb7 = emb[word_to_id['king']]
    desc7 = "king"
    label7 = word_to_id["king"]
    emb8 = emb[word_to_id['sun']]
    desc8 = "sun"
    label8 = word_to_id["sun"]
    emb9 = emb[word_to_id['key']]
    desc9 = "key"
    label9 = word_to_id["key"]
    test_cases = [emb1, emb2, emb3, emb4, emb5, emb6, emb7, emb8, emb9]
    test_case_desc = [
        desc1, desc2, desc3, desc4, desc5, desc6, desc7, desc8, desc9
    ]
    test_labels = [
        label1, label2, label3, label4, label5, label6, label7, label8, label9
    ]
    return norm(np.array(test_cases)), test_case_desc, test_labels


def build_test_case(args, emb):
    if args.test_acc:
        return build_test_case_from_file(args, emb)
    else:
        return build_small_test_case(emb)


def norm(x):
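    # L2-normalize each row so that dot products become cosine similarities.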
    y = np.linalg.norm(x, axis=1, keepdims=True)
    return x / y


def inference_test(scope, model_dir, args):
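    # Rank the whole vocabulary by cosine similarity for every test case and
    # report the nearest words together with the analogy accuracy.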
    BuildWord_IdMap(args.dict_path)
    logger.info("model_dir is: {}".format(model_dir + "/"))
    emb = np.array(scope.find_var("embeding").get_tensor())
    x = norm(emb)
    logger.info("inference result: ====================")
    test_cases = None
    test_case_desc = list()
    test_labels = list()
    exclude_lists = list()
    if args.test_acc:
        test_cases, test_case_desc, test_labels, exclude_lists = build_test_case(
            args, emb)
    else:
        test_cases, test_case_desc, test_labels = build_test_case(args, emb)
        exclude_lists = [[-1]]  # dummy entry: nothing is excluded for the demo cases
    actual_rank = 1 if args.test_acc else args.rank_num
    correct_num = 0
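    # both matrices are row-normalized, so this dot product is the cosine similarity
    # between every test case and every word in the vocabulary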
    cosine_similarity_matrix = np.dot(test_cases, x.T)
    results = topKs(actual_rank, cosine_similarity_matrix, exclude_lists,
                    args.test_acc)
    for i in range(len(test_labels)):
        logger.info("Test result for {}".format(test_case_desc[i]))
        result = results[i]
        for j in range(actual_rank):
            if result[j][1] == test_labels[i]:  # the predicted word matches the expected answer
                correct_num += 1
            logger.info("{} nearest is {}, rate is {}".format(j, id_to_word[
                result[j][1]], result[j][0]))
    logger.info("Test acc is: {}, there are {} / {}".format(correct_num / len(
        test_labels), correct_num, len(test_labels)))


def topK(k, cosine_similarity_list, exclude_list, is_acc=False):
    if k == 1 and is_acc:  # fast path for accuracy: only the best non-excluded word matters
        best_sim = float('-inf')
        best_id = 0
        for i in range(len(cosine_similarity_list)):
            if cosine_similarity_list[i] >= best_sim and (i not in exclude_list):
                best_sim = cosine_similarity_list[i]
                best_id = i
        return [[best_sim, best_id]]
    else:
        # general path: pick the k largest similarities (exclude_list is not applied here)
        result = list()
        result_index = np.argpartition(cosine_similarity_list, -k)[-k:]
        for index in result_index:
            result.append([cosine_similarity_list[index], index])
        result.sort(reverse=True)
        return result


def topKs(k, cosine_similarity_matrix, exclude_lists, is_acc=False):
    # run topK row by row; in accuracy mode every row has its own exclusion list
    result_queues = list()
    for i in range(cosine_similarity_matrix.shape[0]):
        if is_acc:
            tmp_pq = topK(k, cosine_similarity_matrix[i], exclude_lists[i],
                          is_acc)
        else:
            tmp_pq = topK(k, cosine_similarity_matrix[i], exclude_lists[0],
                          is_acc)
        result_queues.append(tmp_pq)
    return result_queues


def infer_during_train(args):
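    # Poll the model output directory and run inference on every newly finished checkpoint.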
    model_file_list = list()
    exe = fluid.Executor(fluid.CPUPlace())
    Scope = fluid.Scope()
    inference_prog()
    solved_new = True
    while True:
        time.sleep(60)  # poll for new checkpoints once a minute
        current_list = os.listdir(args.model_output_dir)
        if set(model_file_list) == set(current_list):
            if solved_new:
                solved_new = False
                logger.info("No New models created")
            pass
        else:
            solved_new = True
            increment_models = list()
            for f in current_list:
                if f not in model_file_list:
                    increment_models.append(f)
            logger.info("increment_models is : {}".format(increment_models))
            for model in increment_models:
                model_dir = args.model_output_dir + "/" + model
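                # a "_success" file marks a checkpoint whose save has finished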
                if os.path.exists(model_dir + "/_success"):
                    logger.info("using models from " + model_dir)
                    with fluid.scope_guard(Scope):
                        fluid.io.load_persistables(
                            executor=exe, dirname=model_dir + "/")
                        inference_test(Scope, model_dir, args)
            model_file_list = current_list


def infer_once(args):
    # only run when the model save has finished (marked by the _success file)
    if os.path.exists(args.model_output_dir + "/_success"):
        logger.info("using models from " + args.model_output_dir)
        exe = fluid.Executor(fluid.CPUPlace())
        Scope = fluid.Scope()
        inference_prog()
        with fluid.scope_guard(Scope):
            fluid.io.load_persistables(
                executor=exe, dirname=args.model_output_dir + "/")
            inference_test(Scope, args.model_output_dir, args)
    else:
        logger.info("Wrong Directory or save model failed!")


if __name__ == '__main__':
    args = parse_args()
    # when using --infer_once, point --model_output_dir at a specific model directory
    if args.infer_once:
        infer_once(args)
    elif args.infer_during_train:
        infer_during_train(args)
    else:
        pass