infer.py 2.6 KB
Newer Older
O
overlordmax 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
import os
import numpy as np
import paddle
import paddle.fluid as fluid
from net import ESMM
import args
import logging
import utils

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)

def set_zero(place):
    """Overwrite two named tensors in the global scope with int64 zeros.

    The tensors are looked up by the names 'auc_1.tmp_0' and 'auc_0.tmp_0'
    (presumably the accumulated AUC statistics -- confirm against the
    network definition). Each one is replaced by a zero array with the
    tensor's current dimensions.

    Args:
        place: Device place handed to ``tensor.set`` when writing the zeros.
    """
    for state_name in ('auc_1.tmp_0', 'auc_0.tmp_0'):
        tensor = fluid.global_scope().var(state_name).get_tensor()
        if not tensor:
            # Nothing to reset for a falsy tensor handle.
            continue
        zeros = np.zeros(tensor._get_dims(), dtype="int64")
        tensor.set(zeros, place)

O
overlordmax 已提交
24
def run_infer(args, model_path, test_data_path, vocab_size):
    """Run dataset-based inference for one saved model and report its AUCs.

    Builds the ESMM network inside fresh programs, loads the parameters
    stored under ``<model_path>/checkpoint``, zeroes the AUC state tensors
    via ``set_zero`` and then runs ``infer_from_dataset`` over the test
    files, fetching the CTR and CTCVR AUC variables.

    Args:
        args: Parsed arguments; ``embed_size``, ``batch_size`` and
            ``cpu_num`` are read here.
        model_path: Directory that contains the ``checkpoint`` files.
        test_data_path: Location of the test data handed to
            ``utils.get_dataset``.
        vocab_size: Vocabulary size forwarded to ``ESMM.net``.

    Returns:
        None. The fetched values are printed by the executor every 20
        batches (``print_period=20``).
    """
    # Inference always runs on CPU in this script.
    cpu_place = fluid.CPUPlace()
    model = ESMM()

    startup_program = fluid.framework.Program()
    test_program = fluid.framework.Program()

    with fluid.framework.program_guard(test_program, startup_program), \
            fluid.unique_name.guard():
        feeds = model.input_data()
        # The first returned value (average cost) is not needed here.
        _, auc_ctr, auc_ctcvr = model.net(feeds, vocab_size, args.embed_size)

        dataset, file_list = utils.get_dataset(
            feeds, test_data_path, args.batch_size, args.cpu_num)

        executor = fluid.Executor(cpu_place)
        checkpoint_prefix = os.path.join(model_path, "checkpoint")
        # Inside program_guard the default main program is test_program.
        fluid.load(fluid.default_main_program(), checkpoint_prefix, executor)

        # Clear the AUC state so every model is scored from a clean slate.
        set_zero(cpu_place)

        dataset.set_filelist(file_list)
        fetch_vars = [auc_ctr, auc_ctcvr]
        fetch_names = ["auc_ctr", "auc_ctcvr"]
        executor.infer_from_dataset(
            program=test_program,
            dataset=dataset,
            fetch_list=fetch_vars,
            fetch_info=fetch_names,
            print_period=20,
            debug=False)
                                       
if __name__ == "__main__":
    # Script entry point: parse the arguments, collect every saved model
    # directory directly below ``model_dir`` and evaluate each of them.

    # NOTE: this rebinds the module-level name ``args`` from the imported
    # module to the parsed argument namespace.
    args = args.parse_args()

    logger.info(
        "use_gpu: {}, epochs: {}, batch_size: {}, cpu_num: {}, model_dir: {}, test_data_path: {}, vocab_path: {}".format(
            args.use_gpu, args.epochs, args.batch_size, args.cpu_num,
            args.model_dir, args.test_data_path, args.vocab_path))

    # Only the immediate sub-directories of model_dir are considered. The
    # paths are joined onto model_dir, so a recursive walk would produce
    # non-existent paths for nested matches. ``next`` with a default keeps
    # the old behaviour of yielding nothing when model_dir is missing.
    _, sub_dirs, _ = next(os.walk(args.model_dir), (args.model_dir, [], []))
    # Sorted so that the evaluation order does not depend on the filesystem.
    model_list = [
        os.path.join(args.model_dir, name)
        for name in sorted(sub_dirs)
        if "epoch" in name
    ]
    if not model_list:
        logger.warning("No model directory containing 'epoch' found in {}".format(args.model_dir))

    vocab_size = utils.get_vocab_size(args.vocab_path)

    for model in model_list:
        logger.info("Test model {}".format(model))
        run_infer(args, model, args.test_data_path, vocab_size)
O
overlordmax 已提交
69
                
O
overlordmax 已提交
70