# local_train.py
from args import parse_args
import os
import paddle.fluid as fluid
import sys
from network_conf import ctr_deepfm_model
import time
import numpy
import pickle
9
import utils
C
CandyCaneLane 已提交
10 11 12 13


def train():
    """Train the DeepFM CTR model locally on CPU, saving the program each epoch.

    All settings come from ``parse_args()``. The network is built by
    ``ctr_deepfm_model`` and its loss is minimized with SGD plus L2 decay.
    Every entry of ``args.train_data_dir`` is fed through a fluid dataset
    whose pipe command is ``python criteo_reader.py <feat_dict>``.

    After each epoch the elapsed time is written to stderr and the main
    program is saved to ``<model_output_dir>/epoch_<n>`` (``n`` starts at 1).
    """
    args = parse_args()

    # CE mode: pin the random seed of both default programs to a fixed value.
    if args.enable_ce:
        SEED = 102
        fluid.default_main_program().random_seed = SEED
        fluid.default_startup_program().random_seed = SEED

    print('---------- Configuration Arguments ----------')
    for key, value in args.__dict__.items():
        print(key + ':' + str(value))

    # makedirs (rather than mkdir) so that a nested output path whose parent
    # directories do not exist yet is created instead of raising.
    if not os.path.isdir(args.model_output_dir):
        os.makedirs(args.model_output_dir)

    # The fourth return value (auc states) is not used by this script.
    loss, auc, data_list, _auc_states = ctr_deepfm_model(
        args.embedding_size, args.num_field, args.num_feat, args.layer_sizes,
        args.act, args.reg)

    optimizer = fluid.optimizer.SGD(
        learning_rate=args.lr,
        regularization=fluid.regularizer.L2DecayRegularizer(args.reg))
    optimizer.minimize(loss)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())

    dataset = fluid.DatasetFactory().create_dataset()
    dataset.set_use_var(data_list)
    # NOTE(review): feat_dict is interpolated unquoted; a path containing
    # whitespace would presumably split into several arguments - confirm.
    pipe_command = 'python criteo_reader.py {}'.format(args.feat_dict)
    dataset.set_pipe_command(pipe_command)
    dataset.set_batch_size(args.batch_size)
    dataset.set_thread(args.num_thread)

    # os.listdir returns entries in arbitrary order; sort so the file list
    # handed to the dataset is the same on every run and every machine.
    train_filelist = sorted(
        os.path.join(args.train_data_dir, x)
        for x in os.listdir(args.train_data_dir))

    print('---------------------------------------------')
    for epoch_id in range(args.num_epoch):
        start = time.time()
        dataset.set_filelist(train_filelist)
        exe.train_from_dataset(
            program=fluid.default_main_program(),
            dataset=dataset,
            fetch_list=[loss, auc],
            fetch_info=['epoch %d batch loss' % (epoch_id + 1), "auc"],
            print_period=1000,
            debug=False)
        sys.stderr.write('epoch%d is finished and takes %f s\n' % (
            (epoch_id + 1), time.time() - start))

        # One save location per epoch, numbered from 1.
        model_dir = os.path.join(args.model_output_dir,
                                 'epoch_' + str(epoch_id + 1))
        main_program = fluid.default_main_program()
        fluid.io.save(main_program, model_dir)
C
CandyCaneLane 已提交
66 67 68


if __name__ == '__main__':
    # Script entry point. check_version() lives in the project's utils
    # module; presumably it validates the installed Paddle version before
    # training starts - confirm against utils.py.
    utils.check_version()
    train()