# infer.py (3.0 KB)
# Committed by: overlordmax
import numpy as np
import os
import paddle.fluid as fluid
from net import wide_deep
import logging
import paddle
import args
import utils
import time
# Root handler prints "timestamp - LEVEL - message" lines.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
# Named logger used by this script; INFO level so logger.info() output is shown.
logger = logging.getLogger(name="fluid")
logger.setLevel(level=logging.INFO)

def set_zero(var_name, scope=None, place=None, param_type="int64"):
    """
    Overwrite the tensor held by a Variable with zeros of the same shape.

    Args:
        var_name(str): name of the Variable to reset.
        scope(Scope|None): Scope holding the Variable. ``None`` (the default)
            resolves to ``fluid.global_scope()`` at call time, so a scope
            activated with ``fluid.scope_guard`` is honoured.
        place(Place|None): device the zero array is written to. ``None``
            (the default) means ``fluid.CPUPlace()``.
        param_type(str): numpy dtype of the zero array, default ``"int64"``.
    """
    # Resolve defaults per call: a default of ``fluid.global_scope()`` in the
    # signature would be evaluated once at import and ignore scope_guard.
    if scope is None:
        scope = fluid.global_scope()
    if place is None:
        place = fluid.CPUPlace()

    tensor = scope.var(var_name).get_tensor()
    # Allocate directly in the requested dtype (no float64 temporary + cast).
    zeros = np.zeros(tensor._get_dims(), dtype=param_type)
    tensor.set(zeros, place)


def run_infer(args, test_data_path):
    """
    Evaluate a saved wide&deep checkpoint on a test set and log acc/auc.

    Builds the network from ``wide_deep().model`` in a fresh program and
    scope, loads ``<args.model_dir>/epoch_<args.test_epoch>/checkpoint``
    into it, and runs it batch by batch over the reader produced by
    ``utils.CriteoDataset().test(test_data_path)``. Per-batch and mean
    metrics are written with ``logger.info``.

    Args:
        args: parsed command-line namespace; this function reads
            ``batch_size``, ``model_dir``, ``test_epoch``, ``use_gpu``,
            ``hidden1_units``, ``hidden2_units`` and ``hidden3_units``.
        test_data_path(str): path passed to the dataset's ``test`` reader.
    """
    wide_deep_model = wide_deep()

    test_data_generator = utils.CriteoDataset()
    test_reader = paddle.batch(test_data_generator.test(test_data_path),
                               batch_size=args.batch_size)

    inference_scope = fluid.Scope()
    startup_program = fluid.framework.Program()
    test_program = fluid.framework.Program()

    cur_model_path = os.path.join(args.model_dir, 'epoch_' + str(args.test_epoch), "checkpoint")

    with fluid.scope_guard(inference_scope):
        with fluid.framework.program_guard(test_program, startup_program):
            inputs = wide_deep_model.input_data()
            place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
            loss, acc, auc, batch_auc, auc_states = wide_deep_model.model(
                inputs, args.hidden1_units, args.hidden2_units, args.hidden3_units)

            exe = fluid.Executor(place)
            exe.run(startup_program)

            # Under program_guard the default main program *is* test_program;
            # name it explicitly so the load target is unambiguous.
            fluid.load(test_program, cur_model_path, exe)
            feeder = fluid.DataFeeder(feed_list=inputs, place=place)

            # Reset the auc state variables returned by the model (presumably
            # the auc op's running counters) so the metric covers only this run.
            for var in auc_states:
                set_zero(var.name, scope=inference_scope, place=place)

            acc_values = []
            auc_values = []
            for batch_id, data in enumerate(test_reader()):
                begin = time.time()
                acc_val, auc_val = exe.run(program=test_program,
                                           feed=feeder.feed(data),
                                           fetch_list=[acc.name, auc.name],
                                           return_numpy=True)
                acc_scalar = np.array(acc_val)[0]
                auc_scalar = np.array(auc_val)[0]
                acc_values.append(acc_scalar)
                auc_values.append(auc_scalar)
                end = time.time()
                logger.info("batch_id: {}, batch_time: {:.5f}s, acc: {:.5f}, auc: {:.5f}".format(
                            batch_id, end-begin, acc_scalar, auc_scalar))

            # np.mean([]) would log nan and emit a RuntimeWarning.
            if not acc_values:
                logger.warning("no test batches were read from {}; nothing to evaluate".format(
                               test_data_path))
                return

            logger.info("mean_acc:{:.5f}, mean_auc:{:.5f}".format(np.mean(acc_values), np.mean(auc_values)))
if __name__ == "__main__":
    # Bind the parsed namespace to its own name so the imported ``args``
    # module is not shadowed at module level.
    cli_args = args.parse_args()
    run_infer(cli_args, cli_args.test_data_path)