#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import paddle.v2 as paddle
from paddle.v2 import layer
from paddle.v2 import data_type as dtype

# Dataset helpers from the data_provider module.
from data_provider import field_index, detect_dataset, AvazuDataset

# Presumably the size of the id-feature space; it is not referenced again in
# the visible part of this script -- TODO confirm whether it is still needed.
id_features_space = 100000
# Layer sizes handed to build_dnn_submodel.
dnn_layer_dims = [128, 64, 32, 1]
train_data_path = './train.txt'
# Metadata returned by detect_dataset; its 'dnn_input' and 'lr_input' entries
# are used to size the sparse input layers. The meaning of the second argument
# is defined in data_provider.
data_meta_info = detect_dataset(train_data_path, 500000)
batch_size = 10000
test_set_size = 10000

logging.warning('detect categorical fields in dataset %s' % train_data_path)
for key, item in data_meta_info.items():
    logging.warning('    - {}\t{}'.format(key, item))

paddle.init(use_gpu=False, trainer_count=1)

# ==============================================================================
#                    input layers
# ==============================================================================
# Sparse binary input read by the DNN submodel; its dimension comes from the
# 'dnn_input' entry of the detected dataset metadata.
dnn_merged_input = layer.data(
    name='dnn_input',
    type=dtype.sparse_binary_vector(data_meta_info['dnn_input']),
)

# Sparse binary input read by the LR submodel; its dimension comes from the
# 'lr_input' entry of the detected dataset metadata.
lr_merged_input = layer.data(
    name='lr_input',
    type=dtype.sparse_binary_vector(data_meta_info['lr_input']),
)

# One-dimensional dense label used as the target of the cost layer.
click = layer.data(name='click', type=dtype.dense_vector(1))



# ==============================================================================
#                    network structure
# ==============================================================================
def build_dnn_submodel(dnn_layer_dims):
    """Stack fully connected layers on top of the DNN input layer.

    The first entry of ``dnn_layer_dims`` sizes a fully connected layer over
    ``dnn_merged_input`` that is created without an explicit activation.
    Every following entry adds one ReLU layer named ``dnn-fc-<i>``, where
    ``i`` counts the additional layers starting from 0.

    Args:
        dnn_layer_dims: sequence of layer sizes; it must contain at least one
            entry, otherwise indexing the first size raises ``IndexError``.

    Returns:
        The last layer of the stack.
    """
    _input_layer = layer.fc(input=dnn_merged_input, size=dnn_layer_dims[0])
    for i, dim in enumerate(dnn_layer_dims[1:]):
        # Each new layer consumes the output of the previous one.
        _input_layer = layer.fc(
            input=_input_layer,
            size=dim,
            act=paddle.activation.Relu(),
            name='dnn-fc-%d' % i)
    return _input_layer


# config LR submodel
def build_lr_submodel():
    """Create the LR submodel.

    Returns:
        A single-unit fully connected layer named ``lr`` with ReLU
        activation, fed by ``lr_merged_input``.
    """
    return layer.fc(
        input=lr_merged_input,
        size=1,
        name='lr',
        act=paddle.activation.Relu(),
    )


# combine DNN and LR submodels
def combine_submodels(dnn, lr):
    """Concatenate both submodel outputs and project them to one value.

    Args:
        dnn: output layer of the DNN submodel.
        lr: output layer of the LR submodel.

    Returns:
        A single-unit fully connected layer named ``output`` with sigmoid
        activation.
    """
    merged = layer.concat(input=[dnn, lr])
    # The sigmoid keeps the result between 0 and 1 so that it can be read as
    # an approximate click-through rate.
    return layer.fc(
        input=merged,
        size=1,
        name='output',
        act=paddle.activation.Sigmoid(),
    )


# Build both submodels, then join them into the final prediction layer.
dnn = build_dnn_submodel(dnn_layer_dims=dnn_layer_dims)
lr = build_lr_submodel()
output = combine_submodels(dnn=dnn, lr=lr)

# ==============================================================================
#                   cost and train period
# ==============================================================================
# Cross-entropy cost between the sigmoid output and the click label.
classification_cost = paddle.layer.multi_binary_label_cross_entropy_cost(
    input=output, label=click)

# Parameters for every layer reachable from the cost layer.
params = paddle.parameters.create(classification_cost)

optimizer = paddle.optimizer.Momentum(momentum=0.01)

trainer = paddle.trainer.SGD(
    cost=classification_cost, parameters=params, update_equation=optimizer)

# Exposes the `train` and `test` readers used below; presumably
# `test_set_size` records are held out for testing -- see data_provider.
dataset = AvazuDataset(train_data_path, n_records_as_test=test_set_size)


def event_handler(event):
    """Log training cost and periodic test cost at the end of iterations.

    Only ``paddle.event.EndIteration`` events are handled; any other event
    is ignored. The training cost is logged whenever ``batch_id`` is a
    multiple of 100. Whenever it is a multiple of 1000, the trainer is also
    run over ``dataset.test`` and the resulting cost is logged.

    Args:
        event: event object passed in by ``trainer.train``.
    """
    if not isinstance(event, paddle.event.EndIteration):
        return

    if event.batch_id % 100 == 0:
        # Only needed for the log line, so compute it here rather than on
        # every iteration.
        num_samples = event.batch_id * batch_size
        logging.warning("Pass %d, Samples %d, Cost %f" %
                        (event.pass_id, num_samples, event.cost))

    if event.batch_id % 1000 == 0:
        result = trainer.test(
            reader=paddle.batch(dataset.test, batch_size=1000),
            feeding=field_index)
        logging.warning("Test %d-%d, Cost %f" %
                        (event.pass_id, event.batch_id, result.cost))


# Run training: shuffled batches of `batch_size` records for 100 passes,
# reporting progress through `event_handler`.
# NOTE(review): the shuffle buffer (500) is much smaller than the batch size
# (10000); confirm that this amount of shuffling is intended.
trainer.train(
    reader=paddle.batch(
        paddle.reader.shuffle(dataset.train, buf_size=500),
        batch_size=batch_size),
    feeding=field_index,
    event_handler=event_handler,
    num_passes=100)