infer.py 9.3 KB
Newer Older
S
slf12 已提交
1 2 3 4
import argparse
import sys
import time
import numpy as np
Z
zhouzj 已提交
5
import os
S
slf12 已提交
6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
import paddle
import net
import utils
sys.path.append(sys.path[0] + "/../../../")
from paddleslim.quant import quant_embedding


def _str_to_bool(value):
    """Convert a command-line token such as ``"True"`` or ``"0"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(token)``, which is ``True``
    for every non-empty string, so ``--emb_quant False`` would *enable*
    quantization. This converter understands the usual spellings and
    rejects anything else with a clear error.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string maps to False, matching the old ``bool("")`` result.
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got %r" % (value, ))


def parse_args():
    """Build the CLI of the word2vec inference example and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser("PaddlePaddle Word2vec infer example")
    parser.add_argument(
        '--dict_path',
        type=str,
        default='./data/data_c/1-billion_dict_word_to_id_',
        help="The path of the dictionary")
    parser.add_argument(
        '--infer_epoch',
        action='store_true',
        required=False,
        default=False,
        help='infer by epoch')
    parser.add_argument(
        '--infer_step',
        action='store_true',
        required=False,
        default=False,
        help='infer by step')
    parser.add_argument(
        '--test_dir', type=str, default='test_data', help='test file address')
    parser.add_argument(
        '--print_step', type=int, default=500000, help='print step')
    parser.add_argument(
        '--start_index', type=int, default=0, help='start index')
    parser.add_argument(
        '--start_batch', type=int, default=1, help='start batch')
    parser.add_argument(
        '--end_batch', type=int, default=13, help='end batch (exclusive)')
    parser.add_argument(
        '--last_index', type=int, default=100, help='last index')
    parser.add_argument(
        '--model_dir', type=str, default='model', help='model dir')
    parser.add_argument(
        '--use_cuda', type=int, default=0, help='whether use cuda')
    parser.add_argument(
        '--batch_size', type=int, default=5, help='batch_size')
    parser.add_argument(
        '--emb_size', type=int, default=64, help='embedding size')
    parser.add_argument(
        '--emb_quant',
        # ``type=bool`` would turn the string "False" into True.
        type=_str_to_bool,
        default=False,
        help='whether to quant embedding parameter')
    args = parser.parse_args()
    return args


def infer_epoch(args, vocab_size, test_reader, use_cuda, i2w):
    """Evaluate analogy accuracy for each per-epoch checkpoint.

    For every epoch in ``[args.start_index, args.last_index]`` the checkpoint
    ``<args.model_dir>/pass-<epoch>`` is loaded into a fresh clone of the
    inference program and run over the whole test set. When
    ``args.emb_quant`` is set, ``quant_embedding`` is applied first and the
    resulting program is saved to ``./output_quant/pass-<epoch>``.

    A sample counts as correct when, after skipping predicted ids that occur
    in the sample's input words, the first remaining prediction equals the
    sample's label.

    Args:
        args: options from ``parse_args``; reads ``emb_size``,
            ``start_index``, ``last_index``, ``model_dir`` and ``emb_quant``.
        vocab_size (int): vocabulary size; every id in ``range(vocab_size)``
            is fed as ``all_label``.
        test_reader: callable returning an iterable of batches. Each sample
            ``dat`` is read as ``dat[0]``/``dat[1]``/``dat[2]`` (fed as
            analogy_a/b/c), ``dat[3]`` (label; its first element is compared)
            and ``dat[4]`` (input word ids skipped among predictions).
        use_cuda (bool): run on ``CUDAPlace(0)`` instead of the CPU.
        i2w: id-to-word mapping; accepted for interface compatibility, not
            used.
    """
    place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    emb_size = args.emb_size
    # Identical for every batch, so build the candidate-label array once.
    all_label = np.arange(vocab_size).reshape(vocab_size, 1).astype("int64")

    def to_column(batch, pos):
        """Stack field ``pos`` of every sample into an int64 (N, 1) array."""
        return np.array([dat[pos] for dat in batch]).astype("int64").reshape(
            -1, 1)

    with paddle.static.scope_guard(paddle.static.Scope()):
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program):
            values, pred = net.infer_network(vocab_size, emb_size)
            # Take the range and model dir from ``args``: module-level names
            # only exist when this file is executed as a script.
            for epoch in range(args.start_index, args.last_index + 1):
                copy_program = main_program.clone()
                model_path = os.path.join(args.model_dir,
                                          "pass-" + str(epoch))
                paddle.static.load(copy_program, model_path, exe)
                if args.emb_quant:
                    config = {
                        'quantize_op_types': 'lookup_table',
                        'lookup_table': {
                            'quantize_type': 'abs_max'
                        },
                    }
                    copy_program = quant_embedding(copy_program, place, config)
                    paddle.static.save(copy_program,
                                       './output_quant/pass-' + str(epoch))

                accum_num = 0
                accum_num_sum = 0.0
                step_id = 0
                for data in test_reader():
                    step_id += 1
                    label = [dat[3] for dat in data]
                    input_word = [dat[4] for dat in data]
                    para = exe.run(
                        copy_program,
                        feed={
                            "analogy_a": to_column(data, 0),
                            "analogy_b": to_column(data, 1),
                            "analogy_c": to_column(data, 2),
                            "all_label": all_label,
                        },
                        fetch_list=[pred.name, values],
                        return_numpy=False)
                    pre = np.array(para[0])
                    for ii, sample_label in enumerate(label):
                        accum_num_sum += 1
                        for idx in pre[ii]:
                            idx = int(idx)
                            # Skip predictions that are among this sample's
                            # input words and look at the next candidate.
                            if idx in input_word[ii]:
                                continue
                            if idx == int(sample_label[0]):
                                accum_num += 1
                            break
                    print("step:%d %d " % (step_id, accum_num))

                # An empty test set would otherwise raise ZeroDivisionError.
                acc = accum_num / accum_num_sum if accum_num_sum else 0.0
                print("epoch:%d \t acc:%.3f " % (epoch, acc))


def infer_step(args, vocab_size, test_reader, use_cuda, i2w):
    """Evaluate analogy accuracy for each intermediate (per-batch) checkpoint.

    For every epoch in ``[args.start_index, args.last_index]`` and every
    ``batchid`` in ``range(args.start_batch, args.end_batch)``, the
    checkpoint ``<args.model_dir>/pass-<epoch>/batch-<batchid * print_step>``
    is loaded into a fresh clone of the inference program and run over the
    whole test set; the accuracy is printed after each checkpoint.

    A sample counts as correct when, after skipping predicted ids that occur
    in the sample's input words, the first remaining prediction equals the
    sample's label.

    Args:
        args: options from ``parse_args``; reads ``emb_size``,
            ``start_index``, ``last_index``, ``start_batch``, ``end_batch``,
            ``print_step`` and ``model_dir``.
        vocab_size (int): vocabulary size; every id in ``range(vocab_size)``
            is fed as ``all_label``.
        test_reader: callable returning an iterable of batches. Each sample
            ``dat`` is read as ``dat[0]``/``dat[1]``/``dat[2]`` (fed as
            analogy_a/b/c), ``dat[3]`` (label; its first element is compared)
            and ``dat[4]`` (input word ids skipped among predictions).
        use_cuda (bool): run on ``CUDAPlace(0)`` instead of the CPU.
        i2w: id-to-word mapping; accepted for interface compatibility, not
            used.
    """
    place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    emb_size = args.emb_size
    # Explicit int64 like the other feeds: np.arange's default integer type
    # is platform dependent. Identical for every batch, so build it once.
    all_label = np.arange(vocab_size).reshape(vocab_size, 1).astype("int64")

    def to_column(batch, pos):
        """Stack field ``pos`` of every sample into an int64 (N, 1) array."""
        return np.array([dat[pos] for dat in batch]).astype("int64").reshape(
            -1, 1)

    with paddle.static.scope_guard(paddle.static.Scope()):
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program):
            values, pred = net.infer_network(vocab_size, emb_size)
            # Take the range and model dir from ``args``: module-level names
            # only exist when this file is executed as a script.
            for epoch in range(args.start_index, args.last_index + 1):
                for batchid in range(args.start_batch, args.end_batch):
                    copy_program = main_program.clone()
                    model_path = os.path.join(
                        args.model_dir, "pass-" + str(epoch),
                        "batch-" + str(batchid * args.print_step))
                    paddle.static.load(copy_program, model_path, exe)

                    accum_num = 0
                    accum_num_sum = 0.0
                    step_id = 0
                    for data in test_reader():
                        step_id += 1
                        label = [dat[3] for dat in data]
                        input_word = [dat[4] for dat in data]
                        para = exe.run(
                            copy_program,
                            feed={
                                "analogy_a": to_column(data, 0),
                                "analogy_b": to_column(data, 1),
                                "analogy_c": to_column(data, 2),
                                "all_label": all_label,
                            },
                            fetch_list=[pred.name, values],
                            return_numpy=False)
                        pre = np.array(para[0])
                        for ii, sample_label in enumerate(label):
                            accum_num_sum += 1
                            for idx in pre[ii]:
                                idx = int(idx)
                                # Skip predictions that are among this
                                # sample's input words.
                                if idx in input_word[ii]:
                                    continue
                                if idx == int(sample_label[0]):
                                    accum_num += 1
                                break
                        print("step:%d %d " % (step_id, accum_num))

                    # An empty test set would otherwise divide by zero.
                    acc = accum_num / accum_num_sum if accum_num_sum else 0.0
                    print("epoch:%d \t acc:%.3f " % (epoch, acc))


if __name__ == "__main__":
    # The inference code below builds programs with the paddle.static API.
    paddle.enable_static()
    args = parse_args()

    # These names are kept at module level on purpose: other code in this
    # file may look them up as globals.
    start_index, last_index = args.start_index, args.last_index
    test_dir, model_dir = args.test_dir, args.model_dir
    batch_size = args.batch_size
    dict_path = args.dict_path
    use_cuda = bool(args.use_cuda)
    print("start index: ", start_index, " last_index:", last_index)

    vocab_size, test_reader, id2word = utils.prepare_data(
        test_dir, dict_path, batch_size=batch_size)
    print("vocab_size:", vocab_size)

    # --infer_step selects per-batch checkpoints; otherwise evaluate epochs.
    run_inference = infer_step if args.infer_step else infer_epoch
    run_inference(
        args,
        vocab_size,
        test_reader=test_reader,
        use_cuda=use_cuda,
        i2w=id2word)