# local_train.py
from args import parse_args
import os
import paddle.fluid as fluid
import sys
import network_conf
import time
import utils


def train():
    """Build the network named by ``args.model_name`` and train it locally.

    Steps, all driven by the values returned from ``parse_args()``:

    1. Optionally pin the random seed of the default main and startup
       programs (``args.enable_ce``).
    2. Create ``args.model_output_dir`` if it does not exist yet.
    3. Look up the network builder on ``network_conf`` and attach an SGD
       optimizer with L2 decay to the returned loss.
    4. Feed every file found in ``args.train_data_dir`` through
       ``python criteo_reader.py`` as a fluid dataset.
    5. Run ``args.num_epoch`` epochs with ``train_from_dataset`` and save
       the persistable variables after each epoch under
       ``<model_output_dir>/epoch_<n>``.

    Raises:
        AttributeError: if ``network_conf`` has no attribute named
            ``args.model_name``.
    """
    args = parse_args()

    # add ce
    if args.enable_ce:
        SEED = 102
        fluid.default_main_program().random_seed = SEED
        fluid.default_startup_program().random_seed = SEED

    print(args)
    if not os.path.isdir(args.model_output_dir):
        # makedirs (not mkdir) so a nested output path works on first run.
        os.makedirs(args.model_output_dir)

    # Plain attribute lookup instead of eval(): the model name comes from the
    # command line and must not be executed as an arbitrary expression.
    build_network = getattr(network_conf, args.model_name)
    loss, auc, data_list, _auc_states = build_network(
        args.embedding_size, args.num_field, args.num_feat,
        args.layer_sizes_dnn, args.act, args.reg, args.layer_sizes_cin)

    optimizer = fluid.optimizer.SGD(
        learning_rate=args.lr,
        regularization=fluid.regularizer.L2DecayRegularizer(args.reg))
    optimizer.minimize(loss)

    dataset = fluid.DatasetFactory().create_dataset()
    dataset.set_use_var(data_list)
    dataset.set_pipe_command('python criteo_reader.py')
    dataset.set_batch_size(args.batch_size)
    # os.listdir() returns entries in arbitrary order; sort them so the file
    # list is the same from run to run.
    dataset.set_filelist([
        os.path.join(args.train_data_dir, file_name)
        for file_name in sorted(os.listdir(args.train_data_dir))
    ])

    if args.use_gpu == 1:
        exe = fluid.Executor(fluid.CUDAPlace(0))
        dataset.set_thread(1)
    else:
        exe = fluid.Executor(fluid.CPUPlace())
        dataset.set_thread(args.num_thread)
    exe.run(fluid.default_startup_program())

    for epoch_id in range(args.num_epoch):
        start = time.time()
        sys.stderr.write('\nepoch%d start ...\n' % (epoch_id + 1))
        exe.train_from_dataset(
            program=fluid.default_main_program(),
            dataset=dataset,
            fetch_list=[loss, auc],
            fetch_info=['loss', 'auc'],
            debug=False,
            print_period=args.print_steps)
        model_dir = os.path.join(args.model_output_dir,
                                 'epoch_' + str(epoch_id + 1))
        # The reported time covers training only; saving happens afterwards.
        sys.stderr.write('epoch%d is finished and takes %f s\n' % (
            (epoch_id + 1), time.time() - start))
        fluid.io.save_persistables(
            executor=exe,
            dirname=model_dir,
            main_program=fluid.default_main_program())


if __name__ == '__main__':
    # Script entry point: run the version check from utils first, then train.
    # NOTE(review): check_version is defined in utils (not shown here);
    # presumably it validates the installed PaddlePaddle version - confirm.
    utils.check_version()
    train()