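"""Trains the DeepFM model defined in network_conf.py with the PaddlePaddle
v2 API, periodically reporting the training cost and saving parameter
snapshots to the directory given by --model_output_dir."""
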
import os
import gzip
import logging
import argparse

import paddle.v2 as paddle

from network_conf import DeepFM
import reader

logging.basicConfig()
logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser(description="PaddlePaddle DeepFM example")
    parser.add_argument(
        '--train_data_path',
        type=str,
        required=True,
        help="path of training dataset")
    parser.add_argument(
        '--test_data_path',
        type=str,
        required=True,
        help="path of testing dataset")
    parser.add_argument(
        '--batch_size',
        type=int,
        default=1000,
        help="size of mini-batch (default:1000)")
    parser.add_argument(
        '--num_passes',
        type=int,
        default=10,
        help="number of passes to train (default: 10)")
    parser.add_argument(
        '--factor_size',
        type=int,
        default=10,
        help="the factor size for the factorization machine (default:10)")
    parser.add_argument(
        '--model_output_dir',
        type=str,
        default='models',
        help='directory to save the trained models (default: models)')

    return parser.parse_args()


def train():
    args = parse_args()

    if not os.path.isdir(args.model_output_dir):
        os.mkdir(args.model_output_dir)

    paddle.init(use_gpu=False, trainer_count=1)

    optimizer = paddle.optimizer.Adam(learning_rate=1e-4)

    model = DeepFM(args.factor_size)

    params = paddle.parameters.create(model)

    trainer = paddle.trainer.SGD(
        cost=model, parameters=params, update_equation=optimizer)

    dataset = reader.Dataset()

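    # The event handler logs the training cost every 100 batches; every
    # 10000 batches it also evaluates on the test set (when a test path is
    # given) and saves a gzip-compressed snapshot of the model parameters.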
    def __event_handler__(event):
        if isinstance(event, paddle.event.EndIteration):
            num_samples = event.batch_id * args.batch_size
            if event.batch_id % 100 == 0:
                logger.warning("Pass %d, Batch %d, Samples %d, Cost %f, %s" %
                               (event.pass_id, event.batch_id, num_samples,
                                event.cost, event.metrics))

            if event.batch_id % 10000 == 0:
                if args.test_data_path:
                    result = trainer.test(
                        reader=paddle.batch(
                            dataset.test(args.test_data_path),
                            batch_size=args.batch_size),
                        feeding=reader.feeding)
                    logger.warning("Test %d-%d, Cost %f, %s" %
                                   (event.pass_id, event.batch_id, result.cost,
                                    result.metrics))

                path = "{}/model-pass-{}-batch-{}.tar.gz".format(
                    args.model_output_dir, event.pass_id, event.batch_id)
                with gzip.open(path, 'w') as f:
                    trainer.save_parameter_to_tar(f)

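    # Train on shuffled mini-batches of the training data for the requested
    # number of passes.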
    trainer.train(
        reader=paddle.batch(
            paddle.reader.shuffle(
                dataset.train(args.train_data_path),
                buf_size=args.batch_size * 10000),
            batch_size=args.batch_size),
        feeding=reader.feeding,
        event_handler=__event_handler__,
        num_passes=args.num_passes)


if __name__ == '__main__':
    train()