# train.py (1.5 KB) — training entry point for the ESMM model.
# Last committed by overlordmax.
import numpy as np
import os
import paddle.fluid as fluid
from net import ESMM
import paddle
import utils
import args

def train(args, vocab_size, train_data_path):
    """Train the ESMM model and save a checkpoint after every epoch.

    Args:
        args: Parsed command-line options. Reads ``batch_size``, ``cpu_num``,
            ``embed_size``, ``use_gpu``, ``epochs`` and ``model_dir``.
        vocab_size: Vocabulary size, forwarded to ``ESMM.net``.
        train_data_path: Training data location, forwarded to
            ``utils.get_dataset``.

    Side effects:
        After each epoch ``k`` (counted from 1) the default main program is
        written with ``fluid.io.save`` to
        ``<args.model_dir>/epoch_<k>/checkpoint``.
    """
    esmm_model = ESMM()
    inputs = esmm_model.input_data()

    dataset, file_list = utils.get_dataset(
        inputs, train_data_path, args.batch_size, args.cpu_num)

    avg_cost, auc_ctr, auc_ctcvr = esmm_model.net(
        inputs, vocab_size, args.embed_size)

    optimizer = fluid.optimizer.Adam()
    optimizer.minimize(avg_cost)

    # Pick the device first so the Executor is constructed in one place.
    if args.use_gpu:
        place = fluid.CUDAPlace(0)
        # NOTE(review): the GPU path reads data with a single dataset thread;
        # presumably deliberate — confirm before changing.
        dataset.set_thread(1)
    else:
        place = fluid.CPUPlace()
        dataset.set_thread(args.cpu_num)
    exe = fluid.Executor(place)

    exe.run(fluid.default_startup_program())

    # The same program is both trained and checkpointed.
    main_program = fluid.default_main_program()
    for epoch in range(args.epochs):
        # Hand the full file list to the dataset at the start of every epoch.
        dataset.set_filelist(file_list)
        # NOTE(review): the logged epoch label is 0-based, while the checkpoint
        # directory below is 1-based (epoch_1 holds the first epoch).
        exe.train_from_dataset(program=main_program,
                               dataset=dataset,
                               fetch_list=[avg_cost, auc_ctr, auc_ctcvr],
                               fetch_info=['epoch %d batch loss' % (epoch),
                                           "auc_ctr", "auc_ctcvr"],
                               print_period=20,
                               debug=False)
        model_dir = os.path.join(args.model_dir, 'epoch_' + str(epoch + 1),
                                 "checkpoint")
        fluid.io.save(main_program, model_dir)

if __name__ == "__main__":
    # Bind the parsed options to a new name rather than rebinding `args`,
    # which would shadow the imported `args` module at module level.
    parsed_args = args.parse_args()
    vocab_size = utils.get_vocab_size(parsed_args.vocab_path)
    train(parsed_args, vocab_size, parsed_args.train_data_path)