infer.py 6.2 KB
Newer Older
C
Chengmo 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
16
import os
C
Chengmo 已提交
17
import time
C
Chengmo 已提交
18
import six
C
Chengmo 已提交
19 20 21
import numpy as np
import logging
import argparse
Q
Qiao Longfei 已提交
22 23
import paddle
import paddle.fluid as fluid
C
Chengmo 已提交
24 25
from network_conf import CTR
import feed_generator as generator
Q
Qiao Longfei 已提交
26

Y
Yibing Liu 已提交
27
# Install the root handler with a timestamped format, then expose the
# module-wide "fluid" logger at INFO level; every function in this file logs
# through `logger`.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
Q
Qiao Longfei 已提交
30 31


Q
Qiao Longfei 已提交
32
def parse_args():
    """Parse the command-line options of the CTR-DNN inference script.

    Options are read from ``sys.argv``.

    Returns:
        argparse.Namespace: with the attributes ``test_files_path``,
        ``model_path``, ``batch_size``, ``infer_epoch``, ``embedding_size``,
        ``sparse_feature_dim``, ``dense_feature_dim``, ``is_local`` and
        ``is_cloud``.
    """
    parser = argparse.ArgumentParser(
        description="PaddlePaddle CTR-DNN example")

    # -------------Data & Model Path-------------
    parser.add_argument(
        '--test_files_path',
        type=str,
        default='./test_data',
        help="The path of testing dataset")
    parser.add_argument(
        '--model_path',
        type=str,
        default='models',
        help='The path for model to store (default: models)')

    # -------------Running parameter-------------
    parser.add_argument(
        '--batch_size',
        type=int,
        default=1000,
        help="The size of mini-batch (default:1000)")
    parser.add_argument(
        '--infer_epoch',
        type=int,
        default=0,
        help='Specify which epoch to run infer')

    # -------------Network parameter-------------
    parser.add_argument(
        '--embedding_size',
        type=int,
        default=10,
        help="The size for embedding layer (default:10)")
    parser.add_argument(
        '--sparse_feature_dim',
        type=int,
        default=1000001,
        help='sparse feature hashing space for index processing')
    parser.add_argument(
        '--dense_feature_dim',
        type=int,
        default=13,
        help='dense feature shape')

    # -------------device parameter-------------
    # The default is 1, matching the help text. With both --is_local and
    # --is_cloud at 0 the inference routine takes neither loading branch and
    # would run without any parameters loaded.
    parser.add_argument(
        '--is_local',
        type=int,
        default=1,
        help='Local train or distributed train (default: 1)')
    parser.add_argument(
        '--is_cloud',
        type=int,
        default=0,
        help='Local train or distributed train on paddlecloud (default: 0)')

    return parser.parse_args()


C
Chengmo 已提交
91 92 93 94 95 96 97 98 99 100
def print_arguments(args):
    """
    Log every attribute of ``args`` as a ``name: value`` line, sorted by
    name, framed by a header and a footer line.
    """
    logger.info('-----------  Configuration Arguments -----------')
    entries = sorted(six.iteritems(vars(args)))
    for entry in entries:
        # Each entry is a (name, value) pair, fed straight to the format.
        logger.info('%s: %s' % entry)
    logger.info('------------------------------------------------')


C
Chengmo 已提交
101
def run_infer(args, model_path):
    """Run the CTR network over the test files using one saved model.

    Builds the test program, loads parameters from ``model_path``, feeds
    every batch read from ``args.test_files_path`` and writes the final
    result to ``<model_path>/infer_result.log``.

    Args:
        args: Parsed command-line arguments. This function reads
            ``sparse_feature_dim``, ``test_files_path``, ``batch_size``,
            ``is_cloud`` and ``is_local``; ``args`` is also forwarded to
            ``CTR.input_data`` and ``CTR.net``.
        model_path: Directory holding the saved model to evaluate.

    Returns:
        dict: ``{'loss': ..., 'auc': ...}`` where ``loss`` is the mean of the
        per-batch ``loss / args.batch_size`` values and ``auc`` is the AUC
        value fetched on the last batch. When the reader yields no batch,
        ``loss`` is NaN and ``auc`` is 0.
    """
    place = fluid.CPUPlace()
    dataset = generator.CriteoDataset(args.sparse_feature_dim)
    file_list = [
        os.path.join(args.test_files_path, x)
        for x in os.listdir(args.test_files_path)
    ]
    test_reader = paddle.batch(dataset.test(file_list),
                               batch_size=args.batch_size)
    startup_program = fluid.framework.Program()
    test_program = fluid.framework.Program()
    ctr_model = CTR()

    def set_zero():
        """Overwrite the listed global-scope tensors with int64 zeros."""
        # NOTE(review): these auto-generated names presumably belong to the
        # AUC accumulator states created by the network -- confirm against
        # network_conf before renaming or reordering layers.
        auc_states_names = [
            '_generated_var_0', '_generated_var_1', '_generated_var_2',
            '_generated_var_3'
        ]
        for name in auc_states_names:
            param = fluid.global_scope().var(name).get_tensor()
            if param:
                param_array = np.zeros(param._get_dims()).astype("int64")
                param.set(param_array, place)

    with fluid.framework.program_guard(test_program, startup_program):
        with fluid.unique_name.guard():
            inputs = ctr_model.input_data(args)
            loss, auc_var = ctr_model.net(inputs, args)

            exe = fluid.Executor(place)
            feeder = fluid.DataFeeder(feed_list=inputs, place=place)

            if args.is_cloud:
                fluid.io.load_persistables(
                    executor=exe,
                    dirname=model_path,
                    main_program=fluid.default_main_program())
            elif args.is_local:
                fluid.load(fluid.default_main_program(),
                           os.path.join(model_path, "checkpoint"), exe)
            else:
                # Neither loading branch ran: make that visible instead of
                # silently evaluating without the saved parameters.
                logger.warning(
                    "Neither --is_cloud nor --is_local is set; no parameters "
                    "were loaded from %s", model_path)
            set_zero()

            infer_auc = 0
            batch_losses = []
            for batch_id, data in enumerate(test_reader()):
                loss_val, auc_val = exe.run(test_program,
                                            feed=feeder.feed(data),
                                            fetch_list=[loss, auc_var])
                # Keep only the most recently fetched AUC value.
                infer_auc = auc_val
                batch_losses.append(loss_val / args.batch_size)
                if batch_id % 100 == 0:
                    logger.info("TEST --> batch: {} loss: {} auc: {}".format(
                        batch_id, loss_val / args.batch_size, auc_val))

            if batch_losses:
                infer_loss = np.mean(batch_losses)
            else:
                # np.mean([]) would also give NaN, but with a RuntimeWarning
                # that hides the real cause: there was nothing to evaluate.
                logger.warning("No test batch was read from %s",
                               args.test_files_path)
                infer_loss = np.float64("nan")

            infer_result = {}
            infer_result['loss'] = infer_loss
            infer_result['auc'] = infer_auc
            log_path = os.path.join(model_path, 'infer_result.log')
            logger.info(str(infer_result))
            with open(log_path, 'w') as f:
                f.write(str(infer_result))
            logger.info("Inference complete")
    return infer_result
Q
Qiao Longfei 已提交
166

C
Chengmo 已提交
167 168 169

def find_epoch_models(model_root, infer_epoch):
    """Collect the saved-model directories that belong to one epoch.

    Walks ``model_root`` recursively and keeps every directory whose name
    contains ``"epoch"`` and whose last ``_``-separated field is the integer
    ``infer_epoch`` (for example ``epoch_3`` for ``infer_epoch=3``).

    Args:
        model_root: Directory to search.
        infer_epoch: Epoch number to look for.

    Returns:
        list[str]: Paths of the matching directories; empty when nothing
        matches or ``model_root`` does not exist.
    """
    matches = []
    for parent, subdirs, _ in os.walk(model_root):
        for name in subdirs:
            if "epoch" not in name:
                continue
            try:
                epoch = int(name.split('_')[-1])
            except ValueError:
                # e.g. "epoch_best": not a numbered epoch directory.
                continue
            if epoch == infer_epoch:
                # Join with the directory being walked, not the search root,
                # so matches found in nested directories get a valid path.
                matches.append(os.path.join(parent, name))
    return matches


if __name__ == "__main__":
    args = parse_args()
    print_arguments(args)
    model_list = find_epoch_models(args.model_path, args.infer_epoch)

    if not model_list:
        logger.warning("There is no satisfactory model {} at path {}, please check your start command & env. ".format(
            "epoch_" + str(args.infer_epoch), args.model_path))

    for model in model_list:
        logger.info("Test model {}".format(model))
        run_infer(args, model)