# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import time
import six
import numpy as np
import math
import argparse
import logging
import paddle.fluid as fluid
import paddle
import time
import reader as reader
from nets import MultiviewSimnet, SimpleEncoderFactory

# Set the log line format (timestamp - level - message) via basicConfig, then
# fetch the logger named "fluid" and let INFO-and-above records through.
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)


Z
zhang wenhui 已提交
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
def check_version():
    """
    Log an error and exit when the installed PaddlePaddle does not satisfy
    the 1.6.0 version requirement.

    Raises:
        SystemExit: with status 1 when ``fluid.require_version('1.6.0')``
            raises any exception.
    """
    # Parenthesised implicit concatenation instead of backslash continuations:
    # the original ended with a dangling "\" that only parsed because a blank
    # line happened to follow it.
    err = ("PaddlePaddle version 1.6 or higher is required, "
           "or a suitable develop version is satisfied as well. \n"
           "Please make sure the version is good with your code.")

    try:
        fluid.require_version('1.6.0')
    except Exception:
        # Any failure of the version probe is treated as "unsupported install";
        # report the fixed guidance message and stop the script.
        logger.error(err)
        sys.exit(1)


F
frankwhzhang 已提交
50 51
def parse_args(argv=None):
    """Define the multi-view simnet command-line options and parse them.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse falls back to ``sys.argv[1:]``, which is what
            existing zero-argument callers rely on.

    Returns:
        argparse.Namespace with one attribute per option declared below.
    """
    parser = argparse.ArgumentParser("multi-view simnet")
    parser.add_argument("--train_file", type=str, help="Training file")
    parser.add_argument("--valid_file", type=str, help="Validation file")
    parser.add_argument(
        "--epochs", type=int, default=10, help="Number of epochs for training")
    parser.add_argument(
        "--model_dir",
        type=str,
        default='model_output',
        help="Model output folder")
    parser.add_argument(
        "--query_slots", type=int, default=1, help="Number of query slots")
    parser.add_argument(
        "--title_slots", type=int, default=1, help="Number of title slots")
    parser.add_argument(
        "--query_encoder",
        type=str,
        default="bow",
        help="Encoder module for slot encoding")
    parser.add_argument(
        "--title_encoder",
        type=str,
        default="bow",
        help="Encoder module for slot encoding")
    parser.add_argument(
        "--query_encode_dim",
        type=int,
        default=128,
        help="Dimension of query encoder output")
    parser.add_argument(
        "--title_encode_dim",
        type=int,
        default=128,
        help="Dimension of title encoder output")
    parser.add_argument(
        "--batch_size", type=int, default=128, help="Batch size for training")
    parser.add_argument(
        "--embedding_dim",
        type=int,
        default=128,
        help="Default Dimension of Embedding")
    parser.add_argument(
        "--sparse_feature_dim",
        type=int,
        default=1000001,
        help="Sparse feature hashing space for index processing")
    parser.add_argument(
        "--hidden_size", type=int, default=128, help="Hidden dim")
    return parser.parse_args(argv)


def start_infer(args, model_path):
    """Run a saved inference model over the synthetic validation data on CPU.

    Each batch is fed through the loaded program and the fetched loss and
    correct-count values are logged.

    Args:
        args: Parsed options; this function reads ``sparse_feature_dim``,
            ``query_slots``, ``title_slots`` and ``batch_size``.
        model_path: Directory passed to ``fluid.io.load_inference_model``.
    """
    dataset = reader.SyntheticDataset(args.sparse_feature_dim, args.query_slots,
                                      args.title_slots)
    test_reader = fluid.io.batch(
        fluid.io.shuffle(
            dataset.valid(), buf_size=args.batch_size * 100),
        batch_size=args.batch_size)
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Load and run inside a freshly created Scope rather than the default one.
    with fluid.scope_guard(fluid.Scope()):
        # Honour the caller-supplied model_path; the previous code ignored the
        # parameter and always read args.model_dir.
        infer_program, feed_target_names, fetch_vars = fluid.io.load_inference_model(
            model_path, exe)
        feeder = fluid.DataFeeder(
            program=infer_program, feed_list=feed_target_names, place=place)
        for batch_id, data in enumerate(test_reader()):
            step_id = batch_id + 1
            loss_val, correct_val = exe.run(infer_program,
                                            feed=feeder.feed(data),
                                            fetch_list=fetch_vars)
            # Divide by the real number of samples in this batch: the batched
            # reader is not told to drop a short final batch, so it can hold
            # fewer than args.batch_size items.
            # NOTE(review): assumes correct_val is the count of correct
            # predictions in the batch, as the original division implied.
            # NOTE(review): the "TRAIN" label is kept verbatim although this
            # is the inference path - confirm before renaming.
            logger.info("TRAIN --> pass: {} batch_id: {} avg_cost: {}, acc: {}"
                        .format(step_id, batch_id, loss_val,
                                float(correct_val) / len(data)))

Y
Yibing Liu 已提交
128

F
frankwhzhang 已提交
129 130 131 132
def main():
    """Parse the command-line options and run inference from the model directory."""
    cli_args = parse_args()
    model_dir = cli_args.model_dir
    start_infer(cli_args, model_dir)

Y
Yibing Liu 已提交
133

F
frankwhzhang 已提交
134
if __name__ == "__main__":
    # Verify the PaddlePaddle version first; check_version() exits the
    # process on failure, so main() only runs on a supported install.
    check_version()
    main()