infer.py 8.5 KB
Newer Older
Z
zhangwenhui03 已提交
1 2
import argparse
import sys
3
import time
Z
zhangwenhui03 已提交
4 5 6
import math
import unittest
import contextlib
7
import numpy as np
Z
zhangwenhui03 已提交
8 9 10 11 12
import six
import paddle.fluid as fluid
import paddle
import net
import utils
13 14 15


def parse_args(argv=None):
    """Parse command-line options for the word2vec analogy inference script.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser("PaddlePaddle Word2vec infer example")
    parser.add_argument(
        '--dict_path',
        type=str,
        default='./data/data_c/1-billion_dict_word_to_id_',
        help="The path of dict")
    parser.add_argument(
        '--infer_epoch', action='store_true', help='infer by epoch')
    parser.add_argument(
        '--infer_step', action='store_true', help='infer by step')
    parser.add_argument(
        '--test_dir', type=str, default='test_data', help='test file address')
    parser.add_argument(
        '--print_step',
        type=int,
        default=500000,
        help='step interval between saved batch checkpoints')
    parser.add_argument(
        '--start_index', type=int, default=0, help='first epoch to evaluate')
    parser.add_argument(
        '--start_batch',
        type=int,
        default=1,
        help='first batch checkpoint index (inclusive)')
    parser.add_argument(
        '--end_batch',
        type=int,
        default=13,
        help='last batch checkpoint index (exclusive)')
    parser.add_argument(
        '--last_index',
        type=int,
        default=100,
        help='last epoch to evaluate (inclusive)')
    parser.add_argument(
        '--model_dir', type=str, default='model', help='model dir')
    parser.add_argument(
        '--use_cuda',
        type=int,
        default=0,
        help='whether use cuda (non-zero enables)')
    parser.add_argument(
        '--batch_size', type=int, default=5, help='batch_size')
    parser.add_argument(
        '--emb_size', type=int, default=64, help='embedding size')
    return parser.parse_args(argv)


def _int64_column(data, col):
    """Return field ``col`` of every sample in ``data`` as a 1-D int64 array."""
    column = [sample[col] for sample in data]
    return np.array(column).astype("int64").reshape(len(column))


def _count_correct(pre, label, input_word):
    """Count samples whose first non-input prediction equals the label.

    For sample ``ii`` the candidate ids in ``pre[ii]`` are scanned in order.
    Ids contained in ``input_word[ii]`` are skipped (presumably the analogy's
    question words). Only the first remaining id is compared against
    ``label[ii][0]``.
    """
    correct = 0
    for ii, gold in enumerate(label):
        for idx in pre[ii]:
            if int(idx) in input_word[ii]:
                continue
            if int(idx) == int(gold[0]):
                correct += 1
            break
    return correct


def _evaluate(exe, program, test_reader, vocab_size, pred, values):
    """Run ``program`` over every batch from ``test_reader`` and tally hits.

    Each batch is indexed as a sequence of samples: fields 0-2 are fed as
    ``analogy_a``/``analogy_b``/``analogy_c``, field 3 is the label (its
    element ``[0]`` is compared) and field 4 holds ids to skip.

    Returns:
        (accum_num, accum_num_sum): the number of correct samples and the
        total number of samples seen (the latter as a float).
    """
    # The candidate-label feed is identical for every batch, so build it once.
    # It is cast to int64 explicitly so the dtype does not depend on the
    # platform's default integer size.
    all_label = np.arange(vocab_size).reshape(vocab_size).astype("int64")
    accum_num = 0
    accum_num_sum = 0.0
    for step_id, data in enumerate(test_reader(), 1):
        label = [dat[3] for dat in data]
        input_word = [dat[4] for dat in data]
        para = exe.run(
            program,
            feed={
                "analogy_a": _int64_column(data, 0),
                "analogy_b": _int64_column(data, 1),
                "analogy_c": _int64_column(data, 2),
                "all_label": all_label,
            },
            fetch_list=[pred.name, values],
            return_numpy=False)
        accum_num += _count_correct(np.array(para[0]), label, input_word)
        accum_num_sum += len(label)
        print("step:%d %d " % (step_id, accum_num))
    return accum_num, accum_num_sum


def _accuracy(accum_num, accum_num_sum):
    """Return ``accum_num / accum_num_sum``, or 0.0 if no samples were seen."""
    # An empty test reader previously raised ZeroDivisionError after all the
    # checkpoint loading and inference work had already been done.
    if not accum_num_sum:
        return 0.0
    return 1.0 * accum_num / accum_num_sum


def infer_epoch(args, vocab_size, test_reader, use_cuda, i2w):
    """Report analogy accuracy for each per-epoch checkpoint.

    For every epoch in ``[args.start_index, args.last_index]`` the parameters
    in ``<args.model_dir>/pass-<epoch>`` are loaded into a clone of the
    inference program and evaluated over ``test_reader()``.

    Args:
        args: parsed command-line namespace (see ``parse_args``).
        vocab_size: number of vocabulary ids to score as candidates.
        test_reader: callable returning an iterable of batches.
        use_cuda: run on ``CUDAPlace(0)`` when truthy, else on the CPU.
        i2w: id-to-word mapping; accepted for interface compatibility but
            not used here.
    """
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    with fluid.scope_guard(fluid.Scope()):
        main_program = fluid.Program()
        with fluid.program_guard(main_program):
            values, pred = net.infer_network(vocab_size, args.emb_size)
            for epoch in range(args.start_index, args.last_index + 1):
                copy_program = main_program.clone()
                model_path = args.model_dir + "/pass-" + str(epoch)
                fluid.io.load_params(
                    executor=exe, dirname=model_path, main_program=copy_program)
                accum_num, accum_num_sum = _evaluate(
                    exe, copy_program, test_reader, vocab_size, pred, values)
                print("epoch:%d \t acc:%.3f " %
                      (epoch, _accuracy(accum_num, accum_num_sum)))


def infer_step(args, vocab_size, test_reader, use_cuda, i2w):
    """Report analogy accuracy for each intermediate (per-batch) checkpoint.

    For every epoch in ``[args.start_index, args.last_index]`` and every batch
    id in ``[args.start_batch, args.end_batch)``, the parameters in
    ``<args.model_dir>/pass-<epoch>/batch-<batchid * args.print_step>`` are
    loaded into a clone of the inference program and evaluated over
    ``test_reader()``.

    Args:
        args: parsed command-line namespace (see ``parse_args``).
        vocab_size: number of vocabulary ids to score as candidates.
        test_reader: callable returning an iterable of batches.
        use_cuda: run on ``CUDAPlace(0)`` when truthy, else on the CPU.
        i2w: id-to-word mapping; accepted for interface compatibility but
            not used here.
    """
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    with fluid.scope_guard(fluid.Scope()):
        main_program = fluid.Program()
        with fluid.program_guard(main_program):
            values, pred = net.infer_network(vocab_size, args.emb_size)
            for epoch in range(args.start_index, args.last_index + 1):
                for batchid in range(args.start_batch, args.end_batch):
                    copy_program = main_program.clone()
                    model_path = args.model_dir + "/pass-" + str(epoch) + (
                        '/batch-' + str(batchid * args.print_step))
                    fluid.io.load_params(
                        executor=exe,
                        dirname=model_path,
                        main_program=copy_program)
                    accum_num, accum_num_sum = _evaluate(
                        exe, copy_program, test_reader, vocab_size, pred,
                        values)
                    print("epoch:%d \t acc:%.3f " %
                          (epoch, _accuracy(accum_num, accum_num_sum)))


if __name__ == "__main__":
    # Version check supplied by the local utils module
    # (presumably validates the installed PaddlePaddle - confirm in utils).
    utils.check_version()
    args = parse_args()

    # NOTE: start_index, last_index and model_dir stay module-level names on
    # purpose: the inference functions may read them as globals.
    start_index = args.start_index
    last_index = args.last_index
    model_dir = args.model_dir
    test_dir = args.test_dir
    dict_path = args.dict_path
    batch_size = args.batch_size
    use_cuda = bool(args.use_cuda)

    print("start index: ", start_index, " last_index:", last_index)
    vocab_size, test_reader, id2word = utils.prepare_data(
        test_dir, dict_path, batch_size=batch_size)
    print("vocab_size:", vocab_size)

    # --infer_step selects per-batch checkpoints; otherwise evaluate per epoch.
    run_inference = infer_step if args.infer_step else infer_epoch
    run_inference(
        args,
        vocab_size,
        test_reader=test_reader,
        use_cuda=use_cuda,
        i2w=id2word)