#!/usr/bin/env python
# coding: utf-8
from __future__ import print_function, absolute_import, division
import os
import random
import sys
import time
from collections import OrderedDict

import paddle.fluid as fluid

from config import parse_args
from network import DCN
import utils
"""
train DCN model
"""


def train(args):
    """Train the DCN model with the fluid Dataset pipe reader and save
    per-epoch checkpoints.

    :param args: hyperparams of model (see config.parse_args)
    :return: None; persists parameters under args.model_output_dir/epoch_<k>
    """
    # Load "<feature_name> <cardinality>" pairs describing the categorical
    # features; `with` guarantees the file handle is closed.
    cat_feat_dims_dict = OrderedDict()
    with open(args.cat_feat_num) as fin:
        for line in fin:
            spls = line.strip().split()
            assert len(spls) == 2
            cat_feat_dims_dict[spls[0]] = int(spls[1])

    dcn_model = DCN(args.cross_num, args.dnn_hidden_units, args.l2_reg_cross,
                    args.use_bn, args.clip_by_norm, cat_feat_dims_dict,
                    args.is_sparse)
    dcn_model.build_network()
    dcn_model.backward(args.lr)

    # Config dataset: samples are produced by reader.py through a pipe.
    dataset = fluid.DatasetFactory().create_dataset()
    dataset.set_use_var(dcn_model.data_list)
    pipe_command = 'python reader.py {}'.format(args.vocab_dir)
    dataset.set_pipe_command(pipe_command)
    dataset.set_batch_size(args.batch_size)
    dataset.set_thread(args.num_thread)
    train_filelist = [
        os.path.join(args.train_data_dir, fname)
        for fname in next(os.walk(args.train_data_dir))[2]
    ]
    dataset.set_filelist(train_filelist)
    num_epoch = args.num_epoch
    # One filelist per epoch. Default: the full list each epoch. (The
    # original left train_filelists undefined on this path, so the epoch
    # loop below raised NameError whenever args.steps was not set.)
    train_filelists = [train_filelist for _ in range(num_epoch)]
    if args.steps:
        # Convert a step budget into (possibly fractional) epochs.
        # NOTE(review): 41000000 is presumably the total training-sample
        # count of the dataset (Criteo-sized) -- confirm against the data.
        epoch = args.steps * args.batch_size / 41000000
        full_epoch = int(epoch)
        last_epoch = epoch % 1
        # Full epochs see every file; the final partial epoch samples a
        # proportional random subset of files.
        train_filelists = [train_filelist for _ in range(full_epoch)] + [
            random.sample(train_filelist, int(
                len(train_filelist) * last_epoch))
        ]
        num_epoch = full_epoch + 1
    print("train epoch: {}".format(num_epoch))

    # Executor on CPU; train_from_dataset streams batches from the pipe.
    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())

    for epoch_id in range(num_epoch):
        start = time.time()
        sys.stderr.write('\nepoch%d start ...\n' % (epoch_id + 1))
        dataset.set_filelist(train_filelists[epoch_id])
        exe.train_from_dataset(
            program=fluid.default_main_program(),
            dataset=dataset,
            fetch_list=[
                dcn_model.loss, dcn_model.avg_logloss, dcn_model.auc_var
            ],
            fetch_info=['total_loss', 'avg_logloss', 'auc'],
            debug=False,
            print_period=args.print_steps)
        model_dir = args.model_output_dir + '/epoch_' + str(epoch_id + 1)
        sys.stderr.write('epoch%d is finished and takes %f s\n' % (
            (epoch_id + 1), time.time() - start))
        fluid.io.save_persistables(
            executor=exe,
            dirname=model_dir,
            main_program=fluid.default_main_program())


if __name__ == '__main__':
    # Script entry point: parse hyperparameters, verify the installed
    # Paddle version is compatible, then run training.
    args = parse_args()
    print(args)
    utils.check_version()
    train(args)