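"""Run inference for the CTR-DNN model defined in network_conf.py.

The script loads a saved inference model with fluid.io.load_inference_model
and logs loss and AUC over the given dataset.

Example invocation (a sketch; the model and data paths below are
illustrative placeholders, not files shipped with this example):

    python infer.py \
        --model_path models/pass-0 \
        --data_path data/valid.txt
"""
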
import argparse
import logging

import numpy as np
import paddle
import paddle.fluid as fluid

import reader
from network_conf import ctr_dnn_model


logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser(
        description="PaddlePaddle CTR-DNN example")
    parser.add_argument(
        '--model_path',
        type=str,
        required=True,
        help="The path of model parameters gz file")
    parser.add_argument(
        '--data_path',
        type=str,
        required=True,
        help="The path of the dataset to infer")
    parser.add_argument(
        '--embedding_size',
        type=int,
        default=10,
        help="The size for embedding layer (default:10)")
    parser.add_argument(
        '--batch_size',
        type=int,
        default=1000,
        help="The size of mini-batch (default:1000)")

    return parser.parse_args()


def infer():
    args = parse_args()

    place = fluid.CPUPlace()
    inference_scope = fluid.core.Scope()

    dataset = reader.Dataset()
    test_reader = paddle.batch(
        dataset.train([args.data_path]), batch_size=args.batch_size)

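    # Rebuild the network here only to obtain data_list for the DataFeeder;
    # inference itself runs the program loaded from args.model_path below.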
    startup_program = fluid.framework.Program()
    test_program = fluid.framework.Program()
    with fluid.framework.program_guard(test_program, startup_program):
        loss, data_list, auc_var, batch_auc_var = ctr_dnn_model(
            args.embedding_size)

    exe = fluid.Executor(place)

    feeder = fluid.DataFeeder(feed_list=data_list, place=place)

    with fluid.scope_guard(inference_scope):
        [inference_program, _, fetch_targets] = fluid.io.load_inference_model(
            args.model_path, exe)

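        # Zero out the accumulated AUC statistics so results from earlier runs
        # do not leak into this pass; the '_generated_var_*' names below are
        # the AUC op's state variables created in this scope.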
        def set_zero(var_name):
            param = inference_scope.var(var_name).get_tensor()
            param_array = np.zeros(param._get_dims()).astype("int64")
            param.set(param_array, place)

        auc_states_names = ['_generated_var_2', '_generated_var_3']
        for name in auc_states_names:
            set_zero(name)

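        # Run inference batch by batch; the fetched targets correspond to the
        # loss and AUC values logged below.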
        for batch_id, data in enumerate(test_reader()):
            loss_val, auc_val = exe.run(inference_program,
                                        feed=feeder.feed(data),
                                        fetch_list=fetch_targets)
            if batch_id % 100 == 0:
                logger.info("TEST --> batch: {} loss: {} auc: {}".format(
                    batch_id, loss_val / args.batch_size, auc_val))


if __name__ == '__main__':
    infer()