import time
import os
import paddle.fluid as fluid
import numpy as np
from Queue import PriorityQueue
import logging
import argparse
import preprocess
from sklearn.metrics.pairwise import cosine_similarity

# Vocabulary lookup tables, populated by BuildWord_IdMap() before inference:
# word_to_id maps a token string to its integer id; id_to_word is the inverse.
word_to_id = dict()
id_to_word = dict()

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)


def parse_args():
    """Build and parse the command-line arguments for the infer script.

    Returns:
        argparse.Namespace with dict_path, model_output_dir, rank_num,
        infer_once, infer_during_train, test_acc, test_files_dir and
        test_batch_size fields.
    """
    parser = argparse.ArgumentParser(
        description="PaddlePaddle Word2vec infer example")
    parser.add_argument(
        '--dict_path',
        type=str,
        default='./data/1-billion_dict',
        help="The path of training dataset")
    parser.add_argument(
        '--model_output_dir',
        type=str,
        default='models',
        help="The path for model to store (with infer_once please set specify dir to models) (default: models)"
    )
    parser.add_argument(
        '--rank_num',
        type=int,
        default=4,
        help="find rank_num-nearest result for test (default: 4)")
    parser.add_argument(
        '--infer_once',
        action='store_true',
        required=False,
        default=False,
        help='if using infer_once, (default: False)')
    # NOTE(review): store_true combined with default=True means this flag can
    # never be switched off from the command line; kept for compatibility.
    parser.add_argument(
        '--infer_during_train',
        action='store_true',
        required=False,
        default=True,
        help='if using infer_during_train, (default: True)')
    # NOTE(review): same always-True caveat as --infer_during_train.
    parser.add_argument(
        '--test_acc',
        action='store_true',
        required=False,
        default=True,
        help='if using test_files , (default: True)')
    parser.add_argument(
        '--test_files_dir',
        type=str,
        default='test',
        help="The path for test_files (default: test)")
    parser.add_argument(
        '--test_batch_size',
        type=int,
        default=1000,
        help="test used batch size (default: 1000)")

    return parser.parse_args()


def BuildWord_IdMap(dict_path):
    """Fill the module-level word_to_id / id_to_word tables from the
    "<dict_path>_word_to_id_" file (one "word id" pair per line)."""
    with open(dict_path + "_word_to_id_", 'r') as f:
        for entry in f:
            fields = entry.split(' ')
            word = fields[0]
            idx = int(fields[1])
            word_to_id[word] = idx
            id_to_word[idx] = word


def inference_prog():
    """Register a dummy "embeding" parameter in the default program so that
    persisted model variables can later be loaded into it for testing."""
    fluid.layers.create_parameter(
        name="embeding", shape=[1, 1], dtype='float32')

def build_test_case_from_file(args, emb):
    """Scan every file under args.test_files_dir and build analogy test cases.

    Each usable line "a b c d" yields: the vector emb[a] - emb[b] + emb[c],
    a readable description string, the expected word id for d, and the three
    query word ids to exclude from the ranking.
    """
    logger.info("test files dir: {}".format(args.test_files_dir))
    current_list = os.listdir(args.test_files_dir)
    logger.info("test files list: {}".format(current_list))
    test_cases = list()
    test_labels = list()
    exclude_lists = list()
    for file_dir in current_list:
        with open(args.test_files_dir + "/" + file_dir, 'r') as f:
            for line_no, line in enumerate(f):
                if line_no == 0:
                    # skip the first line of every file
                    continue
                if ':' in line:
                    # lines containing ':' are section markers; log and skip
                    logger.info("{}".format(line))
                    continue
                line = preprocess.strip_lines(line, word_to_id)
                words = line.split()
                case_vec = emb[word_to_id[words[0]]] - emb[word_to_id[
                    words[1]]] + emb[word_to_id[words[2]]]
                case_desc = words[0] + " - " + words[1] + " + " + words[
                    2] + " = " + words[3]
                test_cases.append([case_vec, case_desc])
                test_labels.append(word_to_id[words[3]])
                exclude_lists.append([
                    word_to_id[words[0]], word_to_id[words[1]],
                    word_to_id[words[2]]
                ])
    return test_cases, test_labels, exclude_lists


def build_small_test_case(emb):
    """Return a small built-in set of analogy cases (a - b + c = d) and the
    expected word ids, used when no test-file accuracy run is requested."""
    analogies = [
        ('boy', 'girl', 'aunt', 'uncle'),
        ('brother', 'sister', 'sisters', 'brothers'),
        ('king', 'queen', 'woman', 'man'),
        ('reluctant', 'reluctantly', 'slowly', 'slow'),
        ('old', 'older', 'deeper', 'deep'),
    ]
    cases = []
    labels = []
    for a, b, c, d in analogies:
        vec = emb[word_to_id[a]] - emb[word_to_id[b]] + emb[word_to_id[c]]
        cases.append([vec, "{} - {} + {} = {}".format(a, b, c, d)])
        labels.append(word_to_id[d])
    return cases, labels


def build_test_case(args, emb):
    """Dispatch to the file-based cases (3-tuple) or the built-in small set
    (2-tuple) depending on args.test_acc."""
    if not args.test_acc:
        return build_small_test_case(emb)
    return build_test_case_from_file(args, emb)


def inference_test(scope, model_dir, args):
    """Load the trained embedding from `scope` and report analogy-test results.

    Args:
        scope: fluid Scope holding the persisted "embeding" parameter.
        model_dir: directory the model was loaded from (used for logging only).
        args: parsed CLI args; args.test_acc selects exact top-1 accuracy mode,
            otherwise the args.rank_num nearest words are reported.
    """
    BuildWord_IdMap(args.dict_path)
    logger.info("model_dir is: {}".format(model_dir + "/"))
    emb = np.array(scope.find_var("embeding").get_tensor())
    logger.info("inference result: ====================")
    if args.test_acc:
        test_cases, test_labels, exclude_lists = build_test_case(args, emb)
    else:
        test_cases, test_labels = build_test_case(args, emb)
        exclude_lists = [[-1]]
    if not test_labels:
        # avoid ZeroDivisionError below when no test case was built
        logger.info("No test cases found; nothing to evaluate")
        return
    accual_rank = 1 if args.test_acc else args.rank_num
    correct_num = 0
    for i in range(len(test_labels)):
        if args.test_acc:
            pq = topK(
                accual_rank,
                emb,
                test_cases[i][0],
                exclude_lists[i],
                is_acc=True)
        else:
            pq = topK(
                accual_rank,
                emb,
                test_cases[i][0],
                exclude_lists[0],
                is_acc=False)
        logger.info("Test result for {}".format(test_cases[i][1]))
        for j in range(accual_rank):
            pq_tmps = pq.get()
            # the queue yields entries lowest-similarity first, so the last
            # item popped is the nearest word — that is the one we score
            if (j == accual_rank - 1) and (pq_tmps.id == test_labels[i]):
                correct_num += 1
            logger.info("{} nearest is {}, rate is {}".format(
                accual_rank - j, id_to_word[pq_tmps.id], pq_tmps.priority))
    # float() avoids Python 2 integer division truncating the accuracy to 0
    acc = float(correct_num) / len(test_labels)
    # was "{} / {}}": the unmatched '}' made str.format raise ValueError
    logger.info("Test acc is: {}, there are {} / {}".format(
        acc, correct_num, len(test_labels)))


class PQ_Entry(object):
    """Priority-queue entry pairing a cosine-similarity score with a word id.

    Entries order by `priority` (the similarity), so a PriorityQueue pops the
    least-similar entry first.
    """

    def __init__(self, cos_similarity, id):
        self.priority = cos_similarity
        self.id = id

    # Python 2 ordering hook used by Queue.PriorityQueue.
    def __cmp__(self, other):
        return cmp(self.priority, other.priority)

    # Rich comparisons so ordering also works on Python 3, where __cmp__
    # (and the cmp builtin) no longer exist.
    def __lt__(self, other):
        return self.priority < other.priority

    def __eq__(self, other):
        return self.priority == other.priority


def topK(k, emb, test_emb, exclude_list, is_acc=False):
    """Return a PriorityQueue holding the k rows of `emb` most similar to
    `test_emb` (by cosine similarity).

    The queue pops lowest-similarity entries first, so callers receive results
    in ascending similarity order. When is_acc is True the word ids in
    exclude_list (the three query words of the analogy) are skipped.

    Fix: removed the drain loop that ran on the freshly constructed queue —
    a new PriorityQueue is always empty, so the loop was dead code, and its
    `except Empty` referenced a name that was never imported (NameError if
    ever reached).
    """
    pq = PriorityQueue(k + 1)

    # Vocabulary no bigger than k: rank everything, no trimming needed.
    if len(emb) <= k:
        for i in range(len(emb)):
            x = cosine_similarity([emb[i]], [test_emb])
            pq.put(PQ_Entry(x, i))
        return pq

    for i in range(len(emb)):
        if is_acc and (i in exclude_list):
            continue
        x = cosine_similarity([emb[i]], [test_emb])
        pq_e = PQ_Entry(x, i)
        # Keep the queue bounded: once full, drop the current minimum
        # (lowest similarity) before inserting the new entry.
        if pq.full():
            pq.get()
        pq.put(pq_e)
    # The queue was sized k + 1; pop once more so exactly k entries remain.
    pq.get()
    return pq


def infer_during_train(args):
    """Poll args.model_output_dir forever, running inference on each newly
    finished (marked "_success") model directory as training produces them."""
    model_file_list = list()
    exe = fluid.Executor(fluid.CPUPlace())
    Scope = fluid.Scope()
    inference_prog()
    solved_new = True
    while True:
        time.sleep(60)
        current_list = os.listdir(args.model_output_dir)
        # logger.info("current_list is : {}".format(current_list))
        # logger.info("model_file_list is : {}".format(model_file_list))
        if set(model_file_list) == set(current_list):
            # Nothing new since the last poll; log that only once per idle
            # stretch so the log is not flooded.
            if solved_new:
                solved_new = False
                logger.info("No New models created")
            continue
        solved_new = True
        increment_models = [
            f for f in current_list if f not in model_file_list
        ]
        logger.info("increment_models is : {}".format(increment_models))
        for model in increment_models:
            model_dir = args.model_output_dir + "/" + model
            # the "_success" marker means training finished writing this one
            if os.path.exists(model_dir + "/_success"):
                logger.info("using models from " + model_dir)
                with fluid.scope_guard(Scope):
                    fluid.io.load_persistables(
                        executor=exe, dirname=model_dir + "/")
                    inference_test(Scope, model_dir, args)
        model_file_list = current_list


def infer_once(args):
    """Run inference a single time against a finished model directory."""
    # The "_success" marker file indicates training completed this model;
    # without it there is nothing to do.
    if not os.path.exists(args.model_output_dir + "/_success"):
        return
    logger.info("using models from " + args.model_output_dir)
    exe = fluid.Executor(fluid.CPUPlace())
    Scope = fluid.Scope()
    inference_prog()
    with fluid.scope_guard(Scope):
        fluid.io.load_persistables(
            executor=exe, dirname=args.model_output_dir + "/")
        inference_test(Scope, args.model_output_dir, args)


if __name__ == '__main__':
    args = parse_args()
    # When using --infer_once, point --model_output_dir at the specific
    # finished model directory.
    if args.infer_once:
        infer_once(args)
    elif args.infer_during_train:
        infer_during_train(args)
    else:
        pass