infer.py 4.0 KB
Newer Older
F
frankwhzhang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import time
import six
import numpy as np
import math
import argparse
import logging
import paddle.fluid as fluid
import paddle
import time
import reader as reader
from nets import MultiviewSimnet, SimpleEncoderFactory

# Configure the logging output format (timestamp - level - message) and create
# the module logger, named "fluid", with its level set to INFO.
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)


def parse_args(argv=None):
    """Build the multi-view simnet command-line parser and parse arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads the arguments from ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding every option declared below.
    """
    parser = argparse.ArgumentParser("multi-view simnet")

    # Input data locations.
    parser.add_argument("--train_file", type=str, help="Training file")
    parser.add_argument("--valid_file", type=str, help="Validation file")
    parser.add_argument(
        "--epochs", type=int, default=10, help="Number of epochs for training")
    parser.add_argument(
        "--model_dir",
        type=str,
        default='model_output',
        help="Model output folder")

    # Slot counts for the query and title views.
    parser.add_argument(
        "--query_slots", type=int, default=1, help="Number of query slots")
    parser.add_argument(
        "--title_slots", type=int, default=1, help="Number of title slots")

    # Encoder selection and encoder output sizes.
    parser.add_argument(
        "--query_encoder",
        type=str,
        default="bow",
        help="Encoder module for slot encoding")
    parser.add_argument(
        "--title_encoder",
        type=str,
        default="bow",
        help="Encoder module for slot encoding")
    parser.add_argument(
        "--query_encode_dim",
        type=int,
        default=128,
        help="Dimension of query encoder output")
    parser.add_argument(
        "--title_encode_dim",
        type=int,
        default=128,
        help="Dimension of title encoder output")

    # Batch and model dimensions.
    parser.add_argument(
        "--batch_size", type=int, default=128, help="Batch size for training")
    parser.add_argument(
        "--embedding_dim",
        type=int,
        default=128,
        help="Default Dimension of Embedding")
    parser.add_argument(
        "--sparse_feature_dim",
        type=int,
        default=1000001,
        help="Sparse feature hashing space for index processing")
    parser.add_argument(
        "--hidden_size", type=int, default=128, help="Hidden dim")
    return parser.parse_args(argv)


def start_infer(args, model_path):
    """Run a saved inference model over the synthetic validation reader.

    Loads the inference program from ``model_path``, feeds it shuffled batches
    from ``reader.SyntheticDataset(...).valid()`` on CPU and logs the fetched
    loss and per-batch accuracy for every batch.

    Args:
        args: Parsed command-line namespace. Reads ``sparse_feature_dim``,
            ``query_slots``, ``title_slots`` and ``batch_size``.
        model_path: Directory handed to ``fluid.io.load_inference_model``.
    """
    dataset = reader.SyntheticDataset(args.sparse_feature_dim, args.query_slots,
                                      args.title_slots)
    shuffled = paddle.reader.shuffle(
        dataset.valid(), buf_size=args.batch_size * 100)
    test_reader = paddle.batch(shuffled, batch_size=args.batch_size)

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    # A fresh scope keeps the loaded variables separate from the global scope.
    with fluid.scope_guard(fluid.Scope()):
        # Honour the caller-supplied path instead of silently ignoring it.
        infer_program, feed_target_names, fetch_vars = (
            fluid.io.load_inference_model(model_path, exe))
        feeder = fluid.DataFeeder(
            program=infer_program, feed_list=feed_target_names, place=place)

        for batch_id, data in enumerate(test_reader()):
            step_id = batch_id + 1
            loss_val, correct_val = exe.run(infer_program,
                                            feed=feeder.feed(data),
                                            fetch_list=fetch_vars)
            # Divide by the real batch length: the final batch may hold fewer
            # than args.batch_size samples (assumes each batch is a sized
            # sequence of samples, as yielded by paddle.batch).
            accuracy = float(correct_val) / len(data)
            # NOTE(review): the "TRAIN" tag is kept unchanged for log
            # compatibility even though this loop performs inference.
            logger.info("TRAIN --> pass: {} batch_id: {} avg_cost: {}, acc: {}"
                        .format(step_id, batch_id, loss_val, accuracy))

Y
Yibing Liu 已提交
112

F
frankwhzhang 已提交
113 114 115 116
def main():
    """Parse the command line and run inference from the model directory."""
    cli_args = parse_args()
    model_dir = cli_args.model_dir
    start_infer(cli_args, model_dir)

Y
Yibing Liu 已提交
117

F
frankwhzhang 已提交
118 119
# Run inference only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()