train.py 4.4 KB
Newer Older
S
Superjom 已提交
1 2 3
#!/usr/bin/env python
# -*- coding: utf-8 -*-

S
Superjom 已提交
4
import argparse
S
Superjom 已提交
5 6 7 8
import logging
import paddle.v2 as paddle
from paddle.v2 import layer
from paddle.v2 import data_type as dtype
S
Superjom 已提交
9
from data_provider import field_index, detect_dataset, AvazuDataset
S
Superjom 已提交
10

S
Superjom 已提交
11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
# Command-line options for the CTR training script.
parser = argparse.ArgumentParser(description="PaddlePaddle CTR example")
parser.add_argument(
    '--train_data_path',
    type=str,
    required=True,
    help="path of training dataset")
parser.add_argument(
    '--batch_size',
    type=int,
    default=10000,
    help="size of mini-batch (default:10000)")
parser.add_argument(
    '--test_set_size',
    type=int,
    default=10000,
    help="size of the validation dataset(default: 10000)")
parser.add_argument(
    '--num_passes', type=int, default=10, help="number of passes to train")
# `--num_lines_to_detect` is the correctly spelled flag. The historical
# misspelling `--num_lines_to_detact` is kept as an alias so existing command
# lines keep working. `dest` is pinned so the value is still read as
# `args.num_lines_to_detact` by the rest of the script.
parser.add_argument(
    '--num_lines_to_detect',
    '--num_lines_to_detact',
    dest='num_lines_to_detact',
    type=int,
    default=500000,
    help="number of records to detect dataset's meta info")

# Parsed from sys.argv at import time; the rest of the script reads `args`.
args = parser.parse_args()

S
Superjom 已提交
37
# Widths of the DNN tower: the first entry sizes the input fc layer built on
# `dnn_merged_input`, and each following entry sizes one Relu fc layer.
dnn_layer_dims = [128, 64, 32, 1]

# Inspect up to `args.num_lines_to_detact` records of the training file to
# obtain the dataset's meta info. The 'dnn_input' and 'lr_input' entries are
# used below as the sparse input dimensions.
data_meta_info = detect_dataset(args.train_data_path, args.num_lines_to_detact)

logging.warning('detect categorical fields in dataset %s',
                args.train_data_path)
for key, item in data_meta_info.items():
    logging.warning('    - {}\t{}'.format(key, item))

# Single-trainer CPU run.
paddle.init(use_gpu=False, trainer_count=1)

# ==============================================================================
#                    input layers
# ==============================================================================
# Sparse binary feature vectors for the DNN and LR parts. Their dimensions
# come from the meta info detected above.
dnn_merged_input = layer.data(
    name='dnn_input',
    type=dtype.sparse_binary_vector(data_meta_info['dnn_input']))
lr_merged_input = layer.data(
    name='lr_input',
    type=dtype.sparse_binary_vector(data_meta_info['lr_input']))

# Click label: a dense vector of width 1.
click = layer.data(name='click', type=dtype.dense_vector(1))

S
Superjom 已提交
60

S
Superjom 已提交
61 62 63 64 65 66
# ==============================================================================
#                    network structure
# ==============================================================================
def build_dnn_submodel(dnn_layer_dims):
    """Build the DNN tower on top of the global ``dnn_merged_input`` layer.

    The first fc layer has width ``dnn_layer_dims[0]`` and uses the library's
    default activation. Every remaining width adds one Relu fc layer named
    ``'dnn-fc-<k>'``, where k counts from 0. The last layer also uses Relu.

    Args:
        dnn_layer_dims: sequence of layer widths; must have at least one entry.

    Returns:
        The top layer of the tower.
    """
    first_width, hidden_widths = dnn_layer_dims[0], dnn_layer_dims[1:]
    top = layer.fc(input=dnn_merged_input, size=first_width)
    for depth, width in enumerate(hidden_widths):
        top = layer.fc(
            input=top,
            size=width,
            act=paddle.activation.Relu(),
            name='dnn-fc-%d' % depth)
    return top


# Configure the LR submodel.
def build_lr_submodel():
    """Build the LR submodel: one size-1 fc layer named ``'lr'``.

    The layer is fed by the global ``lr_merged_input`` layer.

    NOTE(review): a Relu activation is applied here, which clips negative
    scores to 0 before they reach the combining layer. Confirm this is
    intended for the "LR" part.

    Returns:
        The ``'lr'`` fc layer.
    """
    return layer.fc(
        input=lr_merged_input,
        size=1,
        name='lr',
        act=paddle.activation.Relu())


# Combine the DNN and LR submodels.
def combine_submodels(dnn, lr):
    """Merge the DNN and LR submodel outputs into the final prediction layer.

    The two layers are concatenated and fed into a size-1 fc layer named
    ``'output'`` with a sigmoid activation.

    Args:
        dnn: top layer of the DNN submodel.
        lr: output layer of the LR submodel.

    Returns:
        The ``'output'`` fc layer.
    """
    return layer.fc(
        input=layer.concat(input=[dnn, lr]),
        size=1,
        name='output',
        # Sigmoid keeps the prediction in (0, 1) so it can be read as a CTR.
        act=paddle.activation.Sigmoid())


# Assemble the network. The DNN tower is built before the LR part, matching
# the original order in case creation order affects auto-generated names.
dnn = build_dnn_submodel(dnn_layer_dims)
lr = build_lr_submodel()
output = combine_submodels(dnn, lr)

# ==============================================================================
#                   cost and train period
# ==============================================================================
# Cross-entropy between the sigmoid prediction and the `click` label.
classification_cost = paddle.layer.multi_binary_label_cross_entropy_cost(
    input=output,
    label=click)

params = paddle.parameters.create(classification_cost)

# Only the momentum coefficient is given; every other optimizer setting
# (e.g. the learning rate) is left at the library default.
optimizer = paddle.optimizer.Momentum(momentum=0.01)

trainer = paddle.trainer.SGD(
    cost=classification_cost,
    parameters=params,
    update_equation=optimizer)

# `args.test_set_size` records are reserved for the test split. Which records
# are chosen is decided inside AvazuDataset.
dataset = AvazuDataset(
    args.train_data_path,
    n_records_as_test=args.test_set_size)
S
Superjom 已提交
115

S
Superjom 已提交
116

S
Superjom 已提交
117 118
def event_handler(event):
    """Log training progress and periodically evaluate on the test split.

    Passed to ``trainer.train`` as its event handler. Only
    ``paddle.event.EndIteration`` events are handled; all others are ignored.

    * Every 100 batches: log the pass id, the number of samples seen so far
      in the pass, and the mini-batch cost.
    * Every 1000 batches: run ``trainer.test`` over ``dataset.test``, batched
      by ``args.batch_size`` and fed via ``field_index``, and log its cost.

    Args:
        event: a paddle v2 training event.
    """
    if not isinstance(event, paddle.event.EndIteration):
        return

    if event.batch_id % 100 == 0:
        # EndIteration fires after batch `batch_id` has been consumed.
        # Assuming batch ids are zero-based per pass (paddle.v2 SGD.train
        # enumerates the reader), batch_id + 1 batches have been seen.
        # The count is approximate if a pass ends with a short batch.
        num_samples = (event.batch_id + 1) * args.batch_size
        logging.warning("Pass %d, Samples %d, Cost %f",
                        event.pass_id, num_samples, event.cost)

    if event.batch_id % 1000 == 0:
        result = trainer.test(
            reader=paddle.batch(dataset.test, batch_size=args.batch_size),
            feeding=field_index)
        logging.warning("Test %d-%d, Cost %f",
                        event.pass_id, event.batch_id, result.cost)
S
Superjom 已提交
130 131 132 133


# Train for `args.num_passes` passes over a shuffled, batched view of the
# training split, reporting progress through `event_handler`.
# NOTE(review): the shuffle buffer (500 records) is much smaller than the
# default mini-batch size (10000), so the shuffle is presumably only local.
# Confirm this is intended.
trainer.train(
    reader=paddle.batch(
        paddle.reader.shuffle(dataset.train, buf_size=500),
        batch_size=args.batch_size,
    ),
    feeding=field_index,
    event_handler=event_handler,
    num_passes=args.num_passes,
)