infer.py 4.0 KB
Newer Older
H
hetianjian 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#  Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

H
hetianjian 已提交
15 16 17 18 19 20 21
import argparse
import logging
import numpy as np
import os
import paddle
import paddle.fluid as fluid
import reader
H
hutuxian 已提交
22
import network
H
hutuxian 已提交
23
import sys
H
hetianjian 已提交
24 25 26 27 28 29 30 31 32 33 34 35

# Module-wide logger for this script: obtained by name and opened up to
# INFO so the per-epoch evaluation results are emitted.
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)

# Root handler format: timestamp, severity, then the message text.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')


def parse_args(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse falls back to sys.argv[1:], which is what the
            script entry point relies on.

    Returns:
        argparse.Namespace carrying model_path, test_path, config_path,
        use_cuda, batch_size, start_index, last_index, hidden_size and step.
    """
    parser = argparse.ArgumentParser(description="PaddlePaddle DIN example")
    parser.add_argument(
        '--model_path', type=str, default='./saved_model/',
        help="path of model parameters")
    parser.add_argument(
        '--test_path', type=str, default='./data/diginetica/test.txt',
        help='dir of test file')
    parser.add_argument(
        '--config_path', type=str, default='./data/diginetica/config.txt',
        help='dir of config')
    # use_cuda is an int flag: any non-zero value selects the GPU in infer().
    parser.add_argument(
        '--use_cuda', type=int, default=1, help='whether to use gpu')
    parser.add_argument(
        '--batch_size', type=int, default=100, help='input batch size')
    # Defaults are real ints (not strings) so they match the declared type
    # without depending on argparse converting string defaults.
    parser.add_argument(
        '--start_index', type=int, default=0, help='start index')
    parser.add_argument(
        '--last_index', type=int, default=10, help='end index')
    parser.add_argument(
        '--hidden_size', type=int, default=100, help='hidden state size')
    parser.add_argument(
        '--step', type=int, default=1, help='gnn propagation steps')
    return parser.parse_args(argv)


H
hutuxian 已提交
53
def infer(args):
    """Evaluate each saved epoch checkpoint on the test set and log the result.

    Builds the network once, then for every epoch index in
    [args.start_index, args.last_index] loads the persistable parameters
    found in the ``epoch_<index>`` directory under args.model_path into a
    test-mode clone of the default main program, runs the whole test reader
    through it and logs the fetched loss and accuracy averaged over the
    number of batches. Epochs whose checkpoint directory does not exist are
    reported and skipped.

    Args:
        args: Parsed options. Reads batch_size, config_path, test_path,
            use_cuda, hidden_size, step, model_path, start_index and
            last_index.
    """
    batch_size = args.batch_size
    items_num = reader.read_config(args.config_path)
    test_data = reader.Data(args.test_path, False)
    place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    # The fourth returned value (feed variables) is not needed here because
    # input is fed through py_reader.
    loss, acc, py_reader, _ = network.network(
        items_num, args.hidden_size, args.step, batch_size)
    exe.run(fluid.default_startup_program())
    infer_program = fluid.default_main_program().clone(for_test=True)

    for epoch_num in range(args.start_index, args.last_index + 1):
        # os.path.join keeps the path valid whether or not --model_path ends
        # with a separator.
        model_path = os.path.join(args.model_path, "epoch_" + str(epoch_num))
        # Guard clause instead of a blanket `except ValueError`, so that a
        # ValueError raised while loading or running is not misreported as a
        # missing checkpoint.
        if not os.path.exists(model_path):
            logger.info("TEST --> error: there is no model in " + model_path)
            continue
        fluid.io.load_persistables(
            executor=exe, dirname=model_path, main_program=infer_program)

        loss_sum = 0.0
        acc_sum = 0.0
        count = 0
        py_reader.decorate_paddle_reader(
            test_data.reader(batch_size, batch_size * 20, False))
        py_reader.start()
        try:
            # Runs until the reader raises EOFException.
            while True:
                res = exe.run(infer_program,
                              fetch_list=[loss.name, acc.name],
                              use_program_cache=True)
                loss_sum += res[0]
                acc_sum += res[1]
                count += 1
        except fluid.core.EOFException:
            # Reset so the reader can be started again for the next epoch.
            py_reader.reset()

        if count == 0:
            # No batch was produced; avoid dividing by zero below.
            logger.info("TEST --> error: no test batch was read from " +
                        args.test_path)
            continue
        logger.info("TEST --> loss: %.4lf, Recall@20: %.4lf" %
                    (loss_sum / count, acc_sum / count))
H
hetianjian 已提交
89 90


H
hutuxian 已提交
91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
def check_version():
    """
    Ensure the installed paddlepaddle satisfies the minimum version.

    Calls fluid.require_version('1.6.0'); if that raises, an error is
    logged and the process exits with status 1.
    """
    try:
        fluid.require_version('1.6.0')
    except Exception:
        # The message is only needed on the failure path, so build it here.
        message = ("PaddlePaddle version 1.6 or higher is required, "
                   "or a suitable develop version is satisfied as well. \n"
                   "Please make sure the version is good with your code.")
        logger.error(message)
        sys.exit(1)

H
hetianjian 已提交
106
if __name__ == "__main__":
    # Validate the PaddlePaddle version first, then parse the command line
    # and run inference with the resulting options.
    check_version()
    infer(parse_args())