train.py 3.7 KB
Newer Older
S
Superjom 已提交
1
import argparse
S
Superjom 已提交
2
import gzip
S
Superjom 已提交
3

S
Superjom 已提交
4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
import reader
import paddle.v2 as paddle
from utils import logger, ModelType
from network_conf import CTRmodel


def parse_args():
    """Parse the command-line arguments of the CTR training script.

    Required options are ``--train_data_path`` and ``--data_meta_file``.
    ``--model_type`` is optional and falls back to
    ``ModelType.CLASSIFICATION``, as its help text advertises.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    parser = argparse.ArgumentParser(description="PaddlePaddle CTR example")
    parser.add_argument(
        '--train_data_path',
        type=str,
        required=True,
        help="path of training dataset")
    parser.add_argument(
        '--test_data_path', type=str, help='path of testing dataset')
    parser.add_argument(
        '--batch_size',
        type=int,
        default=10000,
        help="size of mini-batch (default:10000)")
    parser.add_argument(
        '--num_passes', type=int, default=10, help="number of passes to train")
    parser.add_argument(
        '--model_output_prefix',
        type=str,
        default='./ctr_models',
        help='prefix of path for model to store (default: ./ctr_models)')
    parser.add_argument(
        '--data_meta_file',
        type=str,
        required=True,
        help='path of data meta info file', )
    # Not marked ``required``: a required option never uses its default, which
    # would contradict the "(default classification)" promise in the help text.
    parser.add_argument(
        '--model_type',
        type=int,
        default=ModelType.CLASSIFICATION,
        help='model type, classification: %d, regression %d (default classification)'
        % (ModelType.CLASSIFICATION, ModelType.REGRESSION))

    return parser.parse_args()
S
Superjom 已提交
45

S
Superjom 已提交
46

S
Superjom 已提交
47
# Passed as the first positional argument to CTRmodel in train(); presumably
# the width of each DNN layer, ending in a single output unit -- confirm
# against network_conf.
dnn_layer_dims = [128, 64, 32, 1]
S
Superjom 已提交
48 49 50 51 52 53

# ==============================================================================
#                   cost and train period
# ==============================================================================


S
Superjom 已提交
54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
def train():
    """Build the CTR model from the command-line options and train it.

    Every 100 batches the current pass, sample count, cost and metrics are
    logged. Every 1000 batches the model is evaluated on the test set (only
    when ``--test_data_path`` was given) and the parameters are written to a
    gzip-compressed tar file whose name starts with ``--model_output_prefix``.
    """
    args = parse_args()
    args.model_type = ModelType(args.model_type)
    paddle.init(use_gpu=False, trainer_count=1)
    dnn_input_dim, lr_input_dim = reader.load_data_meta(args.data_meta_file)

    # create ctr model.
    model = CTRmodel(
        dnn_layer_dims,
        dnn_input_dim,
        lr_input_dim,
        model_type=args.model_type,
        is_infer=False)

    params = paddle.parameters.create(model.train_cost)
    optimizer = paddle.optimizer.AdaGrad()

    trainer = paddle.trainer.SGD(cost=model.train_cost,
                                 parameters=params,
                                 update_equation=optimizer)

    dataset = reader.Dataset()

    def __event_handler__(event):
        """Log progress, run periodic tests and save checkpoints."""
        if not isinstance(event, paddle.event.EndIteration):
            return

        num_samples = event.batch_id * args.batch_size
        if event.batch_id % 100 == 0:
            logger.warning("Pass %d, Samples %d, Cost %f, %s" % (
                event.pass_id, num_samples, event.cost, event.metrics))

        if event.batch_id % 1000 == 0:
            checkpoint = "{}-pass-{}-batch-{}".format(
                args.model_output_prefix, event.pass_id, event.batch_id)

            if args.test_data_path:
                result = trainer.test(
                    reader=paddle.batch(
                        dataset.test(args.test_data_path),
                        batch_size=args.batch_size),
                    feeding=reader.feeding_index)
                logger.warning("Test %d-%d, Cost %f, %s" %
                               (event.pass_id, event.batch_id, result.cost,
                                result.metrics))
                # The test cost is only part of the file name when a test
                # actually ran; without a test set there is no ``result`` and
                # referencing it would raise UnboundLocalError.
                checkpoint += "-test-{}".format(result.cost)

            path = checkpoint + ".tar.gz"
            with gzip.open(path, 'w') as f:
                trainer.save_parameter_to_tar(f)

    trainer.train(
        reader=paddle.batch(
            paddle.reader.shuffle(
                dataset.train(args.train_data_path), buf_size=500),
            batch_size=args.batch_size),
        feeding=reader.feeding_index,
        event_handler=__event_handler__,
        num_passes=args.num_passes)


# Start training only when this file is executed as a script.
if __name__ == '__main__':
    train()