# train.py: entry point for training the ESMM model.
import numpy as np
import os
import paddle.fluid as fluid
from net import ESMM
import paddle
import utils
import args
import logging

# Format applied by logging.basicConfig to the root handler; this script's own
# messages are emitted through the "fluid" logger at INFO level.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT)

logger = logging.getLogger(name="fluid")
logger.setLevel(level=logging.INFO)

def train(args, vocab_size, train_data_path):
    """Build the ESMM network, train it for ``args.epochs`` passes and checkpoint each pass.

    Args:
        args: parsed command-line options. Reads ``batch_size``, ``cpu_num``,
            ``embed_size``, ``use_gpu``, ``epochs`` and ``model_dir``.
        vocab_size: vocabulary size passed through to ``ESMM.net``.
        train_data_path: location handed to ``utils.get_dataset``, which
            returns the fluid dataset together with the list of files to read.

    Side effects:
        Runs the startup program on a new executor (GPU 0 or CPU) and, after
        every epoch, saves the main program's state with ``fluid.io.save`` to
        ``<args.model_dir>/epoch_<k>/checkpoint`` where ``k`` is 1-based.
    """
    esmm_model = ESMM()
    inputs = esmm_model.input_data()

    dataset, file_list = utils.get_dataset(
        inputs, train_data_path, args.batch_size, args.cpu_num)

    # The network yields the loss to minimize plus two AUC metrics that are
    # fetched and printed while training.
    avg_cost, auc_ctr, auc_ctcvr = esmm_model.net(inputs, vocab_size, args.embed_size)

    optimizer = fluid.optimizer.Adam()
    optimizer.minimize(avg_cost)

    # NOTE(review): the explicit comparison with True is kept from the original
    # so behavior is unchanged; if args.py ever parses use_gpu as a string,
    # this branch would never be taken -- confirm the argument's type.
    if args.use_gpu == True:  # noqa: E712
        exe = fluid.Executor(fluid.CUDAPlace(0))
        # The GPU path reads the dataset with a single thread.
        dataset.set_thread(1)
    else:
        exe = fluid.Executor(fluid.CPUPlace())
        dataset.set_thread(args.cpu_num)

    exe.run(fluid.default_startup_program())

    # Loop-invariant: the default main program and fetch targets do not change
    # between epochs, so look them up once and reuse them for training and saving.
    main_program = fluid.default_main_program()
    fetch_list = [avg_cost, auc_ctr, auc_ctcvr]

    for epoch in range(args.epochs):
        # (Re)assign the file list before every pass over the data.
        dataset.set_filelist(file_list)
        exe.train_from_dataset(program=main_program,
                               dataset=dataset,
                               fetch_list=fetch_list,
                               fetch_info=['epoch %d batch loss' % (epoch), "auc_ctr", "auc_ctcvr"],
                               print_period=20,
                               debug=False)
        # Checkpoint directories are numbered from 1 (epoch_1, epoch_2, ...).
        checkpoint_path = os.path.join(args.model_dir, 'epoch_' + str(epoch + 1), "checkpoint")
        fluid.io.save(main_program, checkpoint_path)

if __name__ == "__main__":
    # Parse into a dedicated name rather than rebinding the imported ``args``
    # module (the original shadowed the module with the parsed namespace).
    cli_args = args.parse_args()
    logger.info("use_gpu: {}, epochs: {}, batch_size: {}, embed_size: {}, cpu_num: {}, model_dir: {}, train_data_path: {}, vocab_path: {}".format(
        cli_args.use_gpu, cli_args.epochs, cli_args.batch_size, cli_args.embed_size,
        cli_args.cpu_num, cli_args.model_dir, cli_args.train_data_path, cli_args.vocab_path))

    vocab_size = utils.get_vocab_size(cli_args.vocab_path)
    train(cli_args, vocab_size, cli_args.train_data_path)