local_train.py 2.9 KB
Newer Older
1 2 3 4 5
from __future__ import print_function, absolute_import, division
import os
import random
import sys
import time
6
from collections import OrderedDict
7 8 9 10 11

import paddle.fluid as fluid

from config import parse_args
from network import DCN
12
import utils
13 14 15 16 17 18 19 20 21 22 23
"""
train DCN model
"""


def train(args):
    """train and save DCN model

    :param args: hyperparams of model
    :return:
    """
24 25 26 27 28 29
    # ce
    if args.enable_ce:
        SEED = 102
        fluid.default_main_program().random_seed = SEED
        fluid.default_startup_program().random_seed = SEED

30 31 32 33 34
    cat_feat_dims_dict = OrderedDict()
    for line in open(args.cat_feat_num):
        spls = line.strip().split()
        assert len(spls) == 2
        cat_feat_dims_dict[spls[0]] = int(spls[1])
35
    dcn_model = DCN(args.cross_num, args.dnn_hidden_units, args.l2_reg_cross,
36 37
                    args.use_bn, args.clip_by_norm, cat_feat_dims_dict,
                    args.is_sparse)
38 39 40 41 42 43
    dcn_model.build_network()
    dcn_model.backward(args.lr)

    # config dataset
    dataset = fluid.DatasetFactory().create_dataset()
    dataset.set_use_var(dcn_model.data_list)
44
    pipe_command = 'python reader.py {}'.format(args.vocab_dir)
45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
    dataset.set_pipe_command(pipe_command)
    dataset.set_batch_size(args.batch_size)
    dataset.set_thread(args.num_thread)
    train_filelist = [
        os.path.join(args.train_data_dir, fname)
        for fname in next(os.walk(args.train_data_dir))[2]
    ]
    dataset.set_filelist(train_filelist)
    num_epoch = args.num_epoch
    if args.steps:
        epoch = args.steps * args.batch_size / 41000000
        full_epoch = int(epoch // 1)
        last_epoch = epoch % 1
        train_filelists = [train_filelist for _ in range(full_epoch)] + [
            random.sample(train_filelist, int(
                len(train_filelist) * last_epoch))
        ]
        num_epoch = full_epoch + 1
    print("train epoch: {}".format(num_epoch))

    # Executor
    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())

    for epoch_id in range(num_epoch):
        start = time.time()
        sys.stderr.write('\nepoch%d start ...\n' % (epoch_id + 1))
        dataset.set_filelist(train_filelists[epoch_id])
        exe.train_from_dataset(
            program=fluid.default_main_program(),
            dataset=dataset,
            fetch_list=[
                dcn_model.loss, dcn_model.avg_logloss, dcn_model.auc_var
            ],
            fetch_info=['total_loss', 'avg_logloss', 'auc'],
            debug=False,
            print_period=args.print_steps)
82
        model_dir = os.path.join(args.model_output_dir,
83
                                 'epoch_' + str(epoch_id + 1), "checkpoint")
84 85
        sys.stderr.write('epoch%d is finished and takes %f s\n' % (
            (epoch_id + 1), time.time() - start))
86
        fluid.save(fluid.default_main_program(), model_dir)
87 88 89 90 91


if __name__ == '__main__':
    args = parse_args()
    print(args)
92
    utils.check_version()
93
    train(args)