# infer.py - evaluate a saved DeepFM CTR checkpoint on the Criteo test files (CPU only).
import logging
import numpy as np
import pickle

# Hide all GPUs so this example runs on CPU; set before paddle is imported.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
import paddle
import paddle.fluid as fluid

from args import parse_args
from criteo_reader import CriteoDataset
from network_conf import ctr_deepfm_model
import utils

# The 'fluid' logger reports INFO and above; records are rendered by the root
# handler as "<time> - <level> - <message>".
logger = logging.getLogger('fluid')
logger.setLevel(logging.INFO)
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')


def infer():
    """Evaluate a saved DeepFM checkpoint on the test files and log loss/AUC.

    Configuration comes from ``parse_args()``: ``test_data_dir``,
    ``feat_dict``, ``batch_size``, ``model_output_dir``, ``test_epoch`` and the
    network hyper-parameters forwarded to ``ctr_deepfm_model``. The checkpoint
    is loaded from ``<model_output_dir>/epoch_<test_epoch>`` and executed on
    ``fluid.CPUPlace()`` inside a dedicated scope.

    After every batch the running loss (accumulated fetched loss divided by the
    number of instances seen so far) and the fetched AUC value are logged.
    Nothing is returned.
    """
    args = parse_args()

    place = fluid.CPUPlace()
    inference_scope = fluid.Scope()

    # os.listdir() returns entries in arbitrary order; sorting makes the batch
    # order, and therefore the per-batch log output, reproducible.
    test_files = [
        os.path.join(args.test_data_dir, x)
        for x in sorted(os.listdir(args.test_data_dir))
    ]
    criteo_dataset = CriteoDataset()
    criteo_dataset.setup(args.feat_dict)
    test_reader = fluid.io.batch(
        criteo_dataset.test(test_files), batch_size=args.batch_size)

    startup_program = fluid.framework.Program()
    test_program = fluid.framework.Program()
    # str() keeps the path construction working whether test_epoch was parsed
    # as a string or as an integer.
    cur_model_path = os.path.join(args.model_output_dir,
                                  'epoch_' + str(args.test_epoch))

    with fluid.scope_guard(inference_scope):
        with fluid.framework.program_guard(test_program, startup_program):
            loss, auc, data_list, auc_states = ctr_deepfm_model(
                args.embedding_size, args.num_field, args.num_feat,
                args.layer_sizes, args.act, args.reg)

            exe = fluid.Executor(place)
            feeder = fluid.DataFeeder(feed_list=data_list, place=place)
            # NOTE(review): inside program_guard the default main program is
            # expected to be test_program - confirm against the Paddle docs.
            main_program = fluid.default_main_program()
            fluid.load(main_program, cur_model_path, exe)
            for var in auc_states:  # reset auc states
                set_zero(var.name, scope=inference_scope, place=place)

            loss_all = 0
            num_ins = 0
            for batch_id, data_test in enumerate(test_reader()):
                loss_val, auc_val = exe.run(test_program,
                                            feed=feeder.feed(data_test),
                                            fetch_list=[loss.name, auc.name])
                num_ins += len(data_test)
                loss_all += loss_val
                logger.info('TEST --> batch: {} loss: {} auc_val: {}'.format(
                    batch_id, loss_all / num_ins, auc_val))

            print(
                'The last log info is the total Logloss and AUC for all test data. '
            )


def set_zero(var_name, scope=None, place=None, param_type="int64"):
    """
    Set tensor of a Variable to zero.

    The tensor keeps its current dimensions; its contents are replaced by an
    all-zero array of dtype ``param_type``.

    Args:
        var_name(str): name of Variable
        scope(Scope): Scope object; when None, fluid.global_scope() is looked
            up at call time
        place(Place): Place object; when None, a fluid.CPUPlace() is created
            at call time
        param_type(str): param data type, default is int64
    """
    # Resolve the defaults per call rather than once at definition time, so the
    # default scope is the one that is global when the function is called.
    if scope is None:
        scope = fluid.global_scope()
    if place is None:
        place = fluid.CPUPlace()

    param = scope.var(var_name).get_tensor()
    # Build the zeros directly in the target dtype instead of converting after.
    param_array = np.zeros(param._get_dims(), dtype=param_type)
    param.set(param_array, place)


if __name__ == '__main__':
    # Script entry point: run the project's version check first, then perform
    # inference. (check_version lives in the project-local utils module;
    # presumably it validates the installed paddle version - confirm there.)
    utils.check_version()
    infer()