# local_train.py -- last committed by CandyCaneLane.
from args import parse_args
import os
import paddle.fluid as fluid
import sys
from network_conf import ctr_deepfm_model
import time
import numpy
import pickle
import utils


def train():
    """Train the DeepFM CTR model on the local CPU, checkpointing every epoch.

    Reads its configuration from ``parse_args()``, builds the network with
    ``ctr_deepfm_model``, and minimizes the returned loss with SGD plus L2
    weight decay. Training files are streamed through ``criteo_reader.py``
    using a fluid Dataset. After each epoch, all persistable variables are
    saved to ``<model_output_dir>/epoch_<N>`` (N is 1-based).

    Raises:
        ValueError: if ``args.train_data_dir`` contains no regular files.
    """
    args = parse_args()
    print('---------- Configuration Arguments ----------')
    # Sorted so the dump is stable between runs (plain dict iteration order
    # is not guaranteed on older Python versions).
    for key, value in sorted(vars(args).items()):
        print(key + ':' + str(value))

    # makedirs (not mkdir) so a nested output path whose parent directories
    # do not exist yet is created instead of raising.
    if not os.path.isdir(args.model_output_dir):
        os.makedirs(args.model_output_dir)

    # The fourth return value (AUC state variables) is not used here.
    loss, auc, data_list, _auc_states = ctr_deepfm_model(
        args.embedding_size, args.num_field, args.num_feat, args.layer_sizes,
        args.act, args.reg)

    # NOTE(review): args.reg is passed both to the model builder above and to
    # the optimizer's L2 decay below; confirm the model does not already apply
    # the same penalty, otherwise it is counted twice.
    optimizer = fluid.optimizer.SGD(
        learning_rate=args.lr,
        regularization=fluid.regularizer.L2DecayRegularizer(args.reg))
    optimizer.minimize(loss)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())

    # Every training file is piped through criteo_reader.py, which is given
    # the feature-dictionary path as its only command-line argument.
    dataset = fluid.DatasetFactory().create_dataset()
    dataset.set_use_var(data_list)
    pipe_command = 'python criteo_reader.py {}'.format(args.feat_dict)
    dataset.set_pipe_command(pipe_command)
    dataset.set_batch_size(args.batch_size)
    dataset.set_thread(args.num_thread)

    # Sorted for a reproducible file order (os.listdir order is arbitrary).
    # Only regular files are kept, so subdirectories are never handed to the
    # reader.
    candidates = (os.path.join(args.train_data_dir, name)
                  for name in os.listdir(args.train_data_dir))
    train_filelist = sorted(path for path in candidates
                            if os.path.isfile(path))
    if not train_filelist:
        # Without this check the loop would "train" on no data and still save
        # untrained checkpoints as if they were real epochs.
        raise ValueError(
            'no training files found in {}'.format(args.train_data_dir))

    print('---------------------------------------------')
    for epoch_id in range(args.num_epoch):
        epoch = epoch_id + 1  # 1-based number used in logs and dir names
        start = time.time()
        dataset.set_filelist(train_filelist)
        exe.train_from_dataset(
            program=fluid.default_main_program(),
            dataset=dataset,
            fetch_list=[loss, auc],
            fetch_info=['epoch %d batch loss' % epoch, "auc"],
            print_period=1000,
            debug=False)
        sys.stderr.write('epoch%d is finished and takes %f s\n' % (
            epoch, time.time() - start))
        model_dir = os.path.join(args.model_output_dir,
                                 'epoch_' + str(epoch))
        fluid.io.save_persistables(
            executor=exe,
            dirname=model_dir,
            main_program=fluid.default_main_program())


if __name__ == '__main__':
    # NOTE(review): presumably checks that the installed Paddle version is
    # supported before any training starts; confirm in utils.py.
    utils.check_version()
    train()