# train.py
import numpy as np
import sys
import os
import argparse
import time

import paddle.v2 as paddle
import paddle.fluid as fluid

from config import TrainConfig as conf


def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        The argparse namespace; its ``dict_path`` attribute carries the
        value of the mandatory ``--dict_path`` option.
    """
    dict_path_options = {
        "type": str,
        "required": True,
        "help": "Path of the word dictionary.",
    }
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--dict_path', **dict_path_options)
    namespace = arg_parser.parse_args()
    return namespace


# Define to_lodtensor function to process the sequential data.
def to_lodtensor(data, place):
    """Pack a batch of sequences into a single ``fluid.LoDTensor``.

    Args:
        data: Batch of sequences. Each element must support ``len`` and the
            batch is concatenated along axis 0.
        place: Device place forwarded to ``LoDTensor.set``.

    Returns:
        A LoDTensor holding the concatenated values as an int64 column of
        shape ``[total_length, 1]``, with one LoD level made of the
        cumulative sequence offsets (starting at 0).
    """
    lengths = [len(seq) for seq in data]

    # Cumulative start offsets: [0, len0, len0 + len1, ...].
    offsets = [0]
    for length in lengths:
        offsets.append(offsets[-1] + length)

    flat = np.concatenate(data, axis=0).astype("int64")
    column = flat.reshape([len(flat), 1])

    tensor = fluid.LoDTensor()
    tensor.set(column, place)
    tensor.set_lod([offsets])
    return tensor


# Load the dictionary.
def load_vocab(filename):
    """Load the word dictionary from a text file with one word per line.

    Args:
        filename: Path of the dictionary file.

    Returns:
        A dict mapping each line (with surrounding whitespace stripped) to
        its 0-based line index. If the same stripped text occurs on several
        lines, the index of the last occurrence is kept.
    """
    with open(filename) as f:
        return {line.strip(): idx for idx, line in enumerate(f)}


# Define the convolution model.
def conv_net(dict_dim,
             window_size=3,
             emb_dim=128,
             num_filters=128,
             fc0_dim=96,
             class_dim=2):
    """Build the sequence-convolution classification network.

    The network is: embedding -> sequence convolution with tanh activation
    and max pooling -> fully connected layer -> softmax fully connected
    layer, trained with the mean cross-entropy cost.

    Args:
        dict_dim: Number of rows of the embedding table.
        window_size: Filter size of the sequence convolution.
        emb_dim: Width of the embedding vectors.
        num_filters: Number of convolution filters.
        fc0_dim: Output size of the first fully connected layer.
        class_dim: Output size of the softmax layer.

    Returns:
        A tuple ``(data, label, prediction, avg_cost)`` of the "words" input
        layer, the "label" input layer, the softmax output and the mean cost.
    """
    words = fluid.layers.data(
        name="words", shape=[1], dtype="int64", lod_level=1)
    target = fluid.layers.data(name="label", shape=[1], dtype="int64")

    embedded = fluid.layers.embedding(input=words, size=[dict_dim, emb_dim])
    pooled = fluid.nets.sequence_conv_pool(
        input=embedded,
        num_filters=num_filters,
        filter_size=window_size,
        act="tanh",
        pool_type="max")

    hidden = fluid.layers.fc(input=[pooled], size=fc0_dim)
    probabilities = fluid.layers.fc(
        input=[hidden], size=class_dim, act="softmax")

    per_sample_cost = fluid.layers.cross_entropy(
        input=probabilities, label=target)
    mean_cost = fluid.layers.mean(x=per_sample_cost)

    return words, target, probabilities, mean_cost


def main(dict_path):
    """Train the convolution model on the IMDB data set and report accuracy.

    The word dictionary is loaded from ``dict_path``, the network from
    ``conv_net`` is optimized with SGD for ``conf.num_passes`` passes, and
    after every pass the accuracy on the IMDB test set is printed. Training
    progress is printed every ``conf.log_period`` batches, and the accumulated
    training time (testing excluded) is printed at the end.

    Args:
        dict_path: Path of the word dictionary file (one word per line).
    """
    word_dict = load_vocab(dict_path)
    # Only append "<unk>" when the dictionary does not already define it.
    # Overwriting an existing entry would give it the id len(word_dict),
    # which is equal to dict_dim and therefore outside the embedding table.
    word_dict.setdefault("<unk>", len(word_dict))
    dict_dim = len(word_dict)
    print("The dictionary size is : %d" % dict_dim)

    data, label, prediction, avg_cost = conv_net(dict_dim)

    sgd_optimizer = fluid.optimizer.SGD(learning_rate=conf.learning_rate)
    sgd_optimizer.minimize(avg_cost)

    # Per-batch accuracy and batch size; both are fetched so the accuracy
    # can be averaged over a pass weighted by the batch size.
    batch_size_var = fluid.layers.create_tensor(dtype='int64')
    batch_acc_var = fluid.layers.accuracy(
        input=prediction, label=label, total=batch_size_var)

    inference_program = fluid.default_main_program().clone()
    with fluid.program_guard(inference_program):
        inference_program = fluid.io.get_inference_program(
            target_vars=[batch_acc_var, batch_size_var])

    # The training data set.
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.imdb.train(word_dict), buf_size=51200),
        batch_size=conf.batch_size)

    # The testing data set.
    test_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.imdb.test(word_dict), buf_size=51200),
        batch_size=conf.batch_size)

    if conf.use_gpu:
        place = fluid.CUDAPlace(0)
    else:
        place = fluid.CPUPlace()

    exe = fluid.Executor(place)

    feeder = fluid.DataFeeder(feed_list=[data, label], place=place)

    exe.run(fluid.default_startup_program())

    train_pass_acc_evaluator = fluid.average.WeightedAverage()
    test_pass_acc_evaluator = fluid.average.WeightedAverage()

    def test(exe):
        """Return the batch-size-weighted accuracy over the test reader."""
        test_pass_acc_evaluator.reset()
        for batch in test_reader():
            # Each sample is indexed as (word id sequence, label).
            input_seq = to_lodtensor([sample[0] for sample in batch], place)
            y_data = np.array([sample[1] for sample in batch]).astype("int64")
            y_data = y_data.reshape([-1, 1])
            b_acc, b_size = exe.run(inference_program,
                                    feed={"words": input_seq,
                                          "label": y_data},
                                    fetch_list=[batch_acc_var, batch_size_var])
            test_pass_acc_evaluator.add(value=b_acc, weight=b_size)
        test_acc = test_pass_acc_evaluator.eval()
        return test_acc

    total_time = 0.
    for pass_id in range(conf.num_passes):
        train_pass_acc_evaluator.reset()
        start_time = time.time()
        # The loop variable is named `batch` so that it does not rebind the
        # `data` layer variable created above.
        for batch_id, batch in enumerate(train_reader()):
            cost_val, acc_val, size_val = exe.run(
                fluid.default_main_program(),
                feed=feeder.feed(batch),
                fetch_list=[avg_cost, batch_acc_var, batch_size_var])
            train_pass_acc_evaluator.add(value=acc_val, weight=size_val)
            if batch_id and batch_id % conf.log_period == 0:
                print("Pass id: %d, batch id: %d, cost: %f, pass_acc: %f" %
                      (pass_id, batch_id, cost_val,
                       train_pass_acc_evaluator.eval()))
        end_time = time.time()
        total_time += (end_time - start_time)
        pass_test_acc = test(exe)
        print("Pass id: %d, test_acc: %f" % (pass_id, pass_test_acc))
    print("Total train time: %f" % (total_time))


# Script entry point: read --dict_path from the command line and train.
if __name__ == '__main__':
    main(parse_args().dict_path)