# infer.py -- PaddlePaddle word2vec analogy inference script.
import argparse
import contextlib
import math
import os
import sys
import time
import unittest

import numpy as np
import six
import paddle
import paddle.fluid as fluid

import net
import utils

# Python 2 only: switch the interpreter's default str/unicode encoding to
# utf-8. `reload` is a Python 2 builtin (hence the PY2 guard); reloading `sys`
# is the usual Py2 workaround to regain access to `sys.setdefaultencoding`.
# NOTE(review): this process-wide hack is discouraged; drop it together with
# Python 2 support.
if six.PY2:
    reload(sys)
    sys.setdefaultencoding('utf-8')


def parse_args():
    """Parse command-line options for the word2vec analogy inference script.

    Integer options are converted by argparse (``type=int``); ``--infer_epoch``
    and ``--infer_step`` are boolean flags that default to False.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser("PaddlePaddle Word2vec infer example")
    parser.add_argument(
        '--dict_path',
        type=str,
        default='./data/data_c/1-billion_dict_word_to_id_',
        help="The path of the word-to-id dictionary")
    parser.add_argument(
        '--infer_epoch',
        action='store_true',
        required=False,
        default=False,
        help='infer by epoch')
    parser.add_argument(
        '--infer_step',
        action='store_true',
        required=False,
        default=False,
        help='infer by step')
    parser.add_argument(
        '--test_dir', type=str, default='test_data', help='test file address')
    # Integer defaults are given as ints rather than strings; argparse would
    # convert string defaults through ``type`` anyway, so parsed values match.
    parser.add_argument(
        '--print_step',
        type=int,
        default=500000,
        help='print step; --infer_step loads checkpoints named '
        'batch-<batch * print_step>')
    parser.add_argument(
        '--start_index',
        type=int,
        default=0,
        help='first pass (epoch) index to evaluate')
    parser.add_argument(
        '--start_batch',
        type=int,
        default=1,
        help='first batch index to evaluate (--infer_step)')
    parser.add_argument(
        '--end_batch',
        type=int,
        default=13,
        help='end batch index, exclusive (--infer_step)')
    parser.add_argument(
        '--last_index',
        type=int,
        default=100,
        help='last pass (epoch) index to evaluate, inclusive')
    parser.add_argument(
        '--model_dir', type=str, default='model', help='model dir')
    parser.add_argument(
        '--use_cuda',
        type=int,
        default=0,
        help='whether use cuda (non-zero runs on GPU 0)')
    parser.add_argument(
        '--batch_size', type=int, default=5, help='batch_size')
    parser.add_argument(
        '--emb_size', type=int, default=64, help='embedding size')
    return parser.parse_args()


def infer_epoch(args, vocab_size, test_reader, use_cuda, i2w):
    """Evaluate word-analogy accuracy for every saved pass checkpoint.

    For each epoch in ``[args.start_index, args.last_index]`` (inclusive) the
    parameters stored under ``<args.model_dir>/pass-<epoch>`` are loaded into a
    clone of the inference program and the whole ``test_reader`` is scored.
    The running hit count is printed after every test batch and the accuracy
    after every epoch.

    Args:
        args: parsed options; reads ``emb_size``, ``start_index``,
            ``last_index`` and ``model_dir``.
        vocab_size: number of words in the vocabulary (size of ``all_label``).
        test_reader: callable returning an iterable of batches; each sample is
            indexed as ``dat[0..2]`` (analogy word ids a, b, c), ``dat[3]``
            (label, its first element is compared) and ``dat[4]`` (input word
            ids excluded from the candidates).
        use_cuda: run on ``CUDAPlace(0)`` when truthy, otherwise on CPU.
        i2w: id-to-word mapping; accepted for interface compatibility, unused.
    """
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    # Every candidate word id; identical for every batch, so build it once.
    all_label = np.arange(vocab_size).reshape(vocab_size).astype("int64")

    with fluid.scope_guard(fluid.Scope()):
        main_program = fluid.Program()
        with fluid.program_guard(main_program):
            values, pred = net.infer_network(vocab_size, args.emb_size)
            for epoch in range(args.start_index, args.last_index + 1):
                copy_program = main_program.clone()
                model_path = os.path.join(args.model_dir,
                                          "pass-" + str(epoch))
                fluid.io.load_params(
                    executor=exe, dirname=model_path, main_program=copy_program)
                accum_num = 0
                accum_num_sum = 0.0
                step_id = 0
                for data in test_reader():
                    step_id += 1
                    label = [dat[3] for dat in data]
                    input_word = [dat[4] for dat in data]
                    b_size = len(label)
                    wa = np.array([dat[0] for dat in data]).astype(
                        "int64").reshape(b_size)
                    wb = np.array([dat[1] for dat in data]).astype(
                        "int64").reshape(b_size)
                    wc = np.array([dat[2] for dat in data]).astype(
                        "int64").reshape(b_size)
                    para = exe.run(copy_program,
                                   feed={
                                       "analogy_a": wa,
                                       "analogy_b": wb,
                                       "analogy_c": wc,
                                       "all_label": all_label,
                                   },
                                   fetch_list=[pred.name, values],
                                   return_numpy=False)
                    pre = np.array(para[0])
                    for ii, gold in enumerate(label):
                        accum_num_sum += 1
                        for idx in pre[ii]:
                            # Candidates listed in the sample's input words
                            # are skipped ...
                            if int(idx) in input_word[ii]:
                                continue
                            # ... and only the first remaining one is judged.
                            if int(idx) == int(gold[0]):
                                accum_num += 1
                            break
                    print("step:%d %d " % (step_id, accum_num))

                # An empty reader would otherwise divide by zero.
                acc = accum_num / accum_num_sum if accum_num_sum else 0.0
                print("epoch:%d \t acc:%.3f " % (epoch, acc))


def infer_step(args, vocab_size, test_reader, use_cuda, i2w):
    """Evaluate word-analogy accuracy for every saved intra-pass checkpoint.

    For each epoch in ``[args.start_index, args.last_index]`` (inclusive) and
    each batch id in ``[args.start_batch, args.end_batch)`` (end exclusive),
    the parameters stored under
    ``<args.model_dir>/pass-<epoch>/batch-<batch_id * args.print_step>`` are
    loaded into a clone of the inference program and the whole
    ``test_reader`` is scored. The running hit count is printed after every
    test batch and the accuracy after every checkpoint.

    Args:
        args: parsed options; reads ``emb_size``, ``start_index``,
            ``last_index``, ``start_batch``, ``end_batch``, ``print_step`` and
            ``model_dir``.
        vocab_size: number of words in the vocabulary (size of ``all_label``).
        test_reader: callable returning an iterable of batches; each sample is
            indexed as ``dat[0..2]`` (analogy word ids a, b, c), ``dat[3]``
            (label, its first element is compared) and ``dat[4]`` (input word
            ids excluded from the candidates).
        use_cuda: run on ``CUDAPlace(0)`` when truthy, otherwise on CPU.
        i2w: id-to-word mapping; accepted for interface compatibility, unused.
    """
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    # Every candidate word id; identical for every batch, so build it once.
    # Cast explicitly: np.arange's default integer dtype is platform dependent
    # (it can be int32, e.g. on Windows with NumPy < 2).
    all_label = np.arange(vocab_size).reshape(vocab_size).astype("int64")

    with fluid.scope_guard(fluid.Scope()):
        main_program = fluid.Program()
        with fluid.program_guard(main_program):
            values, pred = net.infer_network(vocab_size, args.emb_size)
            for epoch in range(args.start_index, args.last_index + 1):
                for batchid in range(args.start_batch, args.end_batch):
                    ckpt_step = batchid * args.print_step
                    copy_program = main_program.clone()
                    model_path = os.path.join(args.model_dir,
                                              "pass-" + str(epoch),
                                              "batch-" + str(ckpt_step))
                    fluid.io.load_params(
                        executor=exe,
                        dirname=model_path,
                        main_program=copy_program)
                    accum_num = 0
                    accum_num_sum = 0.0
                    step_id = 0
                    for data in test_reader():
                        step_id += 1
                        label = [dat[3] for dat in data]
                        input_word = [dat[4] for dat in data]
                        b_size = len(label)
                        wa = np.array([dat[0] for dat in data]).astype(
                            "int64").reshape(b_size)
                        wb = np.array([dat[1] for dat in data]).astype(
                            "int64").reshape(b_size)
                        wc = np.array([dat[2] for dat in data]).astype(
                            "int64").reshape(b_size)
                        para = exe.run(
                            copy_program,
                            feed={
                                "analogy_a": wa,
                                "analogy_b": wb,
                                "analogy_c": wc,
                                "all_label": all_label,
                            },
                            fetch_list=[pred.name, values],
                            return_numpy=False)
                        pre = np.array(para[0])
                        for ii, gold in enumerate(label):
                            accum_num_sum += 1
                            for idx in pre[ii]:
                                # Candidates listed in the sample's input
                                # words are skipped ...
                                if int(idx) in input_word[ii]:
                                    continue
                                # ... and only the first remaining one is
                                # judged.
                                if int(idx) == int(gold[0]):
                                    accum_num += 1
                                break
                        print("step:%d %d " % (step_id, accum_num))

                    # An empty reader would otherwise divide by zero.
                    acc = accum_num / accum_num_sum if accum_num_sum else 0.0
                    # Include the checkpoint step: one line is printed per
                    # (epoch, batch) checkpoint.
                    print("epoch:%d \t batch:%d \t acc:%.3f " %
                          (epoch, ckpt_step, acc))


if __name__ == "__main__":
    utils.check_version()
    args = parse_args()
    # Kept as module-level globals for backward compatibility: code may read
    # start_index / last_index / model_dir directly instead of from `args`.
    start_index = args.start_index
    last_index = args.last_index
    test_dir = args.test_dir
    model_dir = args.model_dir
    batch_size = args.batch_size
    dict_path = args.dict_path
    use_cuda = bool(args.use_cuda)
    print("start index: ", start_index, " last_index:", last_index)
    vocab_size, test_reader, id2word = utils.prepare_data(
        test_dir, dict_path, batch_size=batch_size)
    print("vocab_size:", vocab_size)
    # Per-epoch evaluation is the default; --infer_epoch is never consulted,
    # only --infer_step switches to per-checkpoint evaluation.
    infer_fn = infer_step if args.infer_step else infer_epoch
    infer_fn(
        args,
        vocab_size,
        test_reader=test_reader,
        use_cuda=use_cuda,
        i2w=id2word)