infer.py 1.9 KB
Newer Older
P
pakchoi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
# -*- encoding:utf-8 -*-
import numpy as np
import glob
import gzip
import paddle.v2 as paddle
from nce_conf import network_conf


def main():
    """Run next-word inference on a few imikolov test samples and print it.

    Steps:
      1. Initialise Paddle on CPU and build the imikolov word dictionary.
      2. Build the inference network through ``network_conf``.
      3. Load parameters from the archive whose path sorts last among
         ``./models/*``.
      4. Take the first 100 five-grams from the imikolov test reader, feed
         the first four word ids of each to the network and print the
         input words, the ground-truth fifth word and the predicted word.

    Raises:
        IOError: if nothing matches ``./models/*``.
    """
    paddle.init(use_gpu=False, trainer_count=1)
    word_dict = paddle.dataset.imikolov.build_dict()
    dict_size = len(word_dict)

    prediction_layer = network_conf(
        is_train=False,
        hidden_size=128,
        embedding_size=512,
        dict_size=dict_size)

    models_list = sorted(glob.glob('./models/*'))
    if not models_list:
        # Fail with an explicit message instead of a bare IndexError below.
        raise IOError("no saved model found under './models/'; "
                      "train the network before running inference")

    # Paths are ordered lexicographically; the last one is the one loaded.
    with gzip.open(models_list[-1], 'r') as f:
        parameters = paddle.parameters.Parameters.from_tar(f)

    # Reverse mapping (word id -> word) for printing readable results.
    idx_word_dict = dict((v, k) for k, v in word_dict.items())

    # Choose 100 samples from the test set to show how to infer.
    sample_count = 100
    infer_data = []
    infer_data_label = []
    for item in paddle.dataset.imikolov.test(word_dict, 5)():
        # The first four ids are the context, the fifth is the target word.
        infer_data.append(item[:4])
        infer_data_label.append(item[4])
        if len(infer_data_label) == sample_count:
            break

    # Maps input layer names to column positions inside each sample.
    # NOTE(review): samples in infer_data only have columns 0-3, so the
    # 'fifthw' entry has no matching column -- presumably it is ignored
    # because the inference network has no such input; confirm against
    # nce_conf.
    feeding = {
        'firstw': 0,
        'secondw': 1,
        'thirdw': 2,
        'fourthw': 3,
        'fifthw': 4
    }

    predictions = paddle.infer(
        output_layer=prediction_layer,
        parameters=parameters,
        input=infer_data,
        feeding=feeding,
        field=['value'])

    for i, (prob, data, label) in enumerate(
            zip(predictions, infer_data, infer_data_label)):
        # Single-argument print(...) calls behave the same on Python 2 and 3.
        print('--------------------------')
        print("No.%d Input: " % (i + 1) +
              ' '.join(idx_word_dict[word_id] for word_id in data))
        print('Ground Truth Output: ' + idx_word_dict[label])
        # argmax picks the highest-scoring word id directly, without
        # sorting the whole score vector.
        print('Predict Output: ' + idx_word_dict[prob.argmax(axis=0)])
        print('')


# Run the inference demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()