infer.py 6.2 KB
Newer Older
C
Chengmo 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
16
import os
C
Chengmo 已提交
17
import time
C
Chengmo 已提交
18
import six
C
Chengmo 已提交
19 20 21
import numpy as np
import logging
import argparse
Q
Qiao Longfei 已提交
22 23
import paddle
import paddle.fluid as fluid
C
Chengmo 已提交
24 25
from network_conf import CTR
import feed_generator as generator
Q
Qiao Longfei 已提交
26

Y
Yibing Liu 已提交
27
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
Q
Qiao Longfei 已提交
28 29
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
Q
Qiao Longfei 已提交
30 31


Q
Qiao Longfei 已提交
32
def parse_args(argv=None):
    """Build the CTR-DNN inference argument parser and parse arguments.

    Args:
        argv: optional list of argument strings to parse instead of the real
            command line (``sys.argv[1:]``). ``None`` keeps the original
            behaviour of reading the process command line.

    Returns:
        argparse.Namespace holding the data/model paths, running parameters,
        network parameters and device flags declared below.
    """
    parser = argparse.ArgumentParser(description="PaddlePaddle CTR-DNN example")
    # -------------Data & Model Path-------------
    parser.add_argument(
        '--test_files_path',
        type=str,
        default='./test_data',
        help="The path of testing dataset")
    parser.add_argument(
        '--model_path',
        type=str,
        default='models',
        help='The path for model to store (default: models)')

    # -------------Running parameter-------------
    parser.add_argument(
        '--batch_size',
        type=int,
        default=1000,
        help="The size of mini-batch (default:1000)")
    parser.add_argument(
        '--infer_epoch',
        type=int,
        default=0,
        help='Specify which epoch to run infer')
    # -------------Network parameter-------------
    parser.add_argument(
        '--embedding_size',
        type=int,
        default=10,
        help="The size for embedding layer (default:10)")
    parser.add_argument(
        '--sparse_feature_dim',
        type=int,
        default=1000001,
        help='sparse feature hashing space for index processing')
    parser.add_argument(
        '--dense_feature_dim', type=int, default=13, help='dense feature shape')

    # -------------device parameter-------------
    parser.add_argument(
        '--is_local',
        type=int,
        default=0,
        # The help text used to claim "(default: 1)" although the real
        # default is 0; %(default)s keeps the two in sync.
        help='Local train or distributed train (default: %(default)s)')
    parser.add_argument(
        '--is_cloud',
        type=int,
        default=0,
        help='Local train or distributed train on paddlecloud (default: 0)')

    return parser.parse_args(argv)


C
Chengmo 已提交
86 87 88 89 90 91 92 93 94 95
def print_arguments(args):
    """
    Log each parsed option as a ``name: value`` line.

    Options appear in alphabetical order of their names and are framed by
    a header and a footer separator line.
    """
    options = vars(args)
    logger.info('-----------  Configuration Arguments -----------')
    for name in sorted(options):
        logger.info('%s: %s' % (name, options[name]))
    logger.info('------------------------------------------------')


C
Chengmo 已提交
96
def run_infer(args, model_path):
    """Evaluate one saved CTR-DNN model on the test files and log loss/AUC.

    Builds the CTR network into a fresh ``test_program`` and loads parameters
    from ``model_path``. The load uses ``fluid.io.load_persistables`` when
    ``args.is_cloud`` is set, or ``fluid.load`` of ``<model_path>/checkpoint``
    when ``args.is_local`` is set. It then resets the AUC state variables and
    runs every mini-batch read from ``args.test_files_path``.

    Args:
        args: parsed command-line namespace (see ``parse_args``).
        model_path: directory of the model to evaluate; a summary is also
            written to ``<model_path>/infer_result.log``.

    Returns:
        dict with keys ``'loss'`` (mean of the per-sample batch losses) and
        ``'auc'`` (the AUC value fetched on the last batch).
    """
    place = fluid.CPUPlace()
    dataset = generator.CriteoDataset(args.sparse_feature_dim)
    # Skip sub-directories and other non-file entries; sorting makes the
    # batch order (and the per-batch log lines) reproducible between runs.
    file_list = sorted(
        os.path.join(args.test_files_path, x)
        for x in os.listdir(args.test_files_path)
        if os.path.isfile(os.path.join(args.test_files_path, x)))
    test_reader = fluid.io.batch(
        dataset.test(file_list), batch_size=args.batch_size)
    startup_program = fluid.framework.Program()
    test_program = fluid.framework.Program()
    ctr_model = CTR()

    def set_zero():
        """Zero the AUC accumulator variables in the global scope.

        NOTE(review): assumes the AUC states created by ``ctr_model.net`` get
        the auto-generated names below -- confirm against network_conf.
        """
        auc_states_names = [
            '_generated_var_0', '_generated_var_1', '_generated_var_2',
            '_generated_var_3'
        ]
        for name in auc_states_names:
            param = fluid.global_scope().var(name).get_tensor()
            if param:
                param_array = np.zeros(param._get_dims()).astype("int64")
                param.set(param_array, place)

    with fluid.framework.program_guard(test_program, startup_program):
        with fluid.unique_name.guard():
            inputs = ctr_model.input_data(args)
            loss, auc_var = ctr_model.net(inputs, args)

            exe = fluid.Executor(place)
            feeder = fluid.DataFeeder(feed_list=inputs, place=place)

            if args.is_cloud:
                fluid.io.load_persistables(
                    executor=exe,
                    dirname=model_path,
                    main_program=fluid.default_main_program())
            elif args.is_local:
                fluid.load(fluid.default_main_program(),
                           os.path.join(model_path, "checkpoint"), exe)
            else:
                # Both flags default to 0; make the "nothing loaded" case
                # visible instead of silently evaluating unloaded parameters.
                logger.warning(
                    "Neither --is_cloud nor --is_local is set; no parameters "
                    "were loaded from %s", model_path)
            # Loaded AUC statistics must not leak into this evaluation.
            set_zero()

            infer_auc = 0
            batch_losses = []
            for batch_id, data in enumerate(test_reader()):
                loss_val, auc_val = exe.run(test_program,
                                            feed=feeder.feed(data),
                                            fetch_list=[loss, auc_var])
                infer_auc = auc_val
                # Normalise by the real sample count: fluid.io.batch keeps a
                # smaller final batch, so args.batch_size over-divides it.
                # NOTE(review): assumes `loss` is summed over the batch.
                batch_loss = loss_val / len(data)
                batch_losses.append(batch_loss)
                if batch_id % 100 == 0:
                    logger.info("TEST --> batch: {} loss: {} auc: {}".format(
                        batch_id, batch_loss, auc_val))

            if batch_losses:
                infer_loss = np.mean(batch_losses)
            else:
                logger.warning("No test batches were read from %s",
                               args.test_files_path)
                infer_loss = float('nan')
            infer_result = {'loss': infer_loss, 'auc': infer_auc}
            log_path = os.path.join(model_path, 'infer_result.log')
            logger.info(str(infer_result))
            with open(log_path, 'w') as f:
                f.write(str(infer_result))
            logger.info("Inference complete")
    return infer_result
Q
Qiao Longfei 已提交
162

C
Chengmo 已提交
163 164 165

def _find_model_dirs(model_path, infer_epoch):
    """Return the sorted paths of saved-model directories for one epoch.

    Walks ``model_path`` recursively and keeps every directory whose name
    contains "epoch" and whose last ``_``-separated component parses as the
    integer ``infer_epoch`` (e.g. ``epoch_3`` for ``infer_epoch == 3``).
    Names whose suffix is not an integer (e.g. ``epoch_best``) are skipped
    instead of aborting the scan with ``ValueError``.
    """
    found = []
    for root, dir_names, _ in os.walk(model_path):
        for name in dir_names:
            if "epoch" not in name:
                continue
            try:
                epoch = int(name.split('_')[-1])
            except ValueError:
                continue
            if epoch == infer_epoch:
                # Join with the directory actually being walked so matches
                # below the top level still get a valid path.
                found.append(os.path.join(root, name))
    return sorted(found)


if __name__ == "__main__":
    args = parse_args()
    print_arguments(args)
    model_list = _find_model_dirs(args.model_path, args.infer_epoch)

    if not model_list:
        logger.warning(
            "There is no satisfactory model {} at path {}, please check your start command & env. ".
            format("epoch_" + str(args.infer_epoch), args.model_path))

    for model in model_list:
        logger.info("Test model {}".format(model))
        run_infer(args, model)