# text_classification_dnn.py: DNN text classifier on the IMDB dataset (PaddlePaddle v2 API).
import sys
import math
import paddle.v2 as paddle
import gzip


def fc_net(dict_dim, class_dim=2, emb_dim=28):
    """
    dnn network definition

    Network layout: word-id sequence -> embedding -> max pooling ->
    two tanh fully-connected hidden layers -> softmax output, plus a
    classification cost computed against the label input.

    :param dict_dim: size of word dictionary
    :type dict_dim: int
    :param class_dim: number of instance classes
    :type class_dim: int
    :param emb_dim: embedding vector dimension
    :type emb_dim: int
    :return: ``(cost, output, lbl)``: the classification cost layer, the
             softmax output layer and the label data layer
    :rtype: tuple
    """
    # input layers: a sequence of word ids (ids < dict_dim) and an integer
    # class label (< class_dim)
    data = paddle.layer.data("word",
                             paddle.data_type.integer_value_sequence(dict_dim))
    lbl = paddle.layer.data("label", paddle.data_type.integer_value(class_dim))

    # embedding layer
    emb = paddle.layer.embedding(input=data, size=emb_dim)

    # max pooling reduces the embedded sequence to one vector
    seq_pool = paddle.layer.pooling(
        input=emb, pooling_type=paddle.pooling.Max())

    # two hidden layers; each layer's weights are initialised with
    # std = 1 / sqrt(layer size)
    hd_layer_size = [28, 8]
    hd_layer_init_std = [1.0 / math.sqrt(s) for s in hd_layer_size]
    hd1 = paddle.layer.fc(
        input=seq_pool,
        size=hd_layer_size[0],
        act=paddle.activation.Tanh(),
        param_attr=paddle.attr.Param(initial_std=hd_layer_init_std[0]))
    hd2 = paddle.layer.fc(
        input=hd1,
        size=hd_layer_size[1],
        act=paddle.activation.Tanh(),
        param_attr=paddle.attr.Param(initial_std=hd_layer_init_std[1]))

    # output layer: one softmax probability per class
    output = paddle.layer.fc(
        input=hd2,
        size=class_dim,
        act=paddle.activation.Softmax(),
        param_attr=paddle.attr.Param(initial_std=1.0 / math.sqrt(class_dim)))

    cost = paddle.layer.classification_cost(input=output, label=lbl)

    return cost, output, lbl


def train_dnn_model(num_pass):
    """
    train dnn model

    Builds the network with ``fc_net`` and trains it on the IMDB training
    set with Adam. Prints progress while training. After every pass it
    evaluates on the test set and writes the parameters to
    ``dnn_params_pass<pass_id>.tar.gz``.

    :param num_pass: train pass number
    :type num_pass: int
    """

    # load word dictionary
    print('load dictionary...')
    word_dict = paddle.dataset.imdb.word_dict()

    dict_dim = len(word_dict)
    class_dim = 2
    # define data reader
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            lambda: paddle.dataset.imdb.train(word_dict), buf_size=1000),
        batch_size=100)
    test_reader = paddle.batch(
        lambda: paddle.dataset.imdb.test(word_dict), batch_size=100)

    # Maps data-layer names to positions inside each reader sample. This is
    # defined before event_handler (which uses it for testing) so the closure
    # does not depend on a name that is assigned later.
    feeding = {'word': 0, 'label': 1}

    # network config
    cost, output, label = fc_net(dict_dim, class_dim=class_dim)

    # create parameters
    parameters = paddle.parameters.create(cost)
    # create optimizer
    adam_optimizer = paddle.optimizer.Adam(
        learning_rate=1e-3,
        regularization=paddle.optimizer.L2Regularization(rate=1e-3),
        model_average=paddle.optimizer.ModelAverage(average_window=0.5))

    # add auc evaluator
    paddle.evaluator.auc(input=output, label=label)

    # create trainer
    trainer = paddle.trainer.SGD(
        cost=cost, parameters=parameters, update_equation=adam_optimizer)

    # Define end batch and end pass event handler
    def event_handler(event):
        """Log batch progress; at the end of each pass, test and save."""
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 100 == 0:
                print("\nPass %d, Batch %d, Cost %f, %s" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics))
            else:
                # one dot per batch keeps the log compact between reports
                sys.stdout.write('.')
                sys.stdout.flush()
        if isinstance(event, paddle.event.EndPass):
            result = trainer.test(reader=test_reader, feeding=feeding)
            print("\nTest with Pass %d, %s" % (event.pass_id, result.metrics))
            with gzip.open("dnn_params_pass" + str(event.pass_id) + ".tar.gz",
                           'w') as f:
                parameters.to_tar(f)

    # begin training network
    trainer.train(
        reader=train_reader,
        event_handler=event_handler,
        feeding=feeding,
        num_passes=num_pass)

    print("Training finished.")


def dnn_infer(file_name):
    """
    predict instance labels by dnn network

    Rebuilds the network with ``fc_net``, loads its parameters from a
    gzip-compressed tar file, and runs inference over the whole IMDB test
    set. For each instance it prints the predicted class probabilities
    followed by the ground-truth label.

    :param file_name: network parameter file (as written by training)
    :type file_name: str
    """

    print("Begin to predict...")

    word_dict = paddle.dataset.imdb.word_dict()
    dict_dim = len(word_dict)
    class_dim = 2

    _, output, _ = fc_net(dict_dim, class_dim=class_dim)
    # Close the compressed parameter file once loaded; the original left the
    # handle open for the rest of the process.
    with gzip.open(file_name) as param_file:
        parameters = paddle.parameters.Parameters.from_tar(param_file)

    infer_data = []
    infer_data_label = []
    for item in paddle.dataset.imdb.test(word_dict):
        # item[0] is fed as the "word" input and item[1] is the label; this
        # matches the feeding order used during training.
        infer_data.append([item[0]])
        infer_data_label.append(item[1])

    predictions = paddle.infer(
        output_layer=output,
        parameters=parameters,
        input=infer_data,
        field=['value'])
    for prob, label in zip(predictions, infer_data_label):
        # "%s %s" gives the same output as the old Py2 `print prob, label`
        print("%s %s" % (prob, label))


if __name__ == "__main__":
    # Train on CPU with one trainer, then predict with the parameters saved
    # after the final pass.
    paddle.init(use_gpu=False, trainer_count=1)
    num_pass = 5
    train_dnn_model(num_pass=num_pass)
    # Pass ids start at 0, so the last checkpoint written is num_pass - 1.
    param_file_name = "dnn_params_pass" + str(num_pass - 1) + ".tar.gz"
    dnn_infer(file_name=param_file_name)