train.py 4.1 KB
Newer Older
1
import os
P
peterzhang2029 已提交
2 3
import sys
import time
G
gmcather 已提交
4 5
import unittest
import contextlib
G
gmcather 已提交
6

Y
Yibing Liu 已提交
7
import paddle
G
gmcather 已提交
8 9
import paddle.fluid as fluid

G
gmcather 已提交
10 11 12 13 14 15
import utils
from nets import bow_net
from nets import cnn_net
from nets import lstm_net
from nets import gru_net

G
gmcather 已提交
16

G
gmcather 已提交
17
def train(train_reader,
          word_dict,
          network,
          use_cuda,
          parallel,
          save_dirname,
          lr=0.2,
          batch_size=128,
          pass_num=30):
    """
    Build the training program for `network`, run it for `pass_num` passes
    and save an inference model after every pass.

    Args:
        train_reader: callable; each call returns an iterable of batches.
            Every batch is handed to the DataFeeder and its `len()` is used
            as the number of samples in the batch.
        word_dict: vocabulary; only `len(word_dict)` is used and passed to
            `network` as its third argument.
        network: callable invoked as `network(words, label, dict_dim)` that
            returns a `(cost, acc, prediction)` tuple.
        use_cuda: run on `fluid.CUDAPlace(0)` when true, else on
            `fluid.CPUPlace()`.
        parallel: when true, wrap the network in `fluid.layers.ParallelDo`
            over `get_places(device_count=2)` and average the per-place
            cost and accuracy.
        save_dirname: directory prefix; the model of pass N is saved to
            `save_dirname + "/epoch" + str(N)`.
        lr: learning rate of the Adagrad optimizer.
        batch_size: not used inside this function; kept so existing callers
            that pass it keep working.
        pass_num: number of passes over `train_reader()`.

    Raises:
        ZeroDivisionError: if a pass sees no samples (for example
            `train_reader()` yields nothing).
    """
    words = fluid.layers.data(
        name="words", shape=[1], dtype="int64", lod_level=1)
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")

    if not parallel:
        cost, acc, _ = network(words, label, len(word_dict))
    else:
        places = fluid.layers.get_places(device_count=2)
        pd = fluid.layers.ParallelDo(places)
        with pd.do():
            cost, acc, _ = network(
                pd.read_input(words), pd.read_input(label), len(word_dict))
            pd.write_output(cost)
            pd.write_output(acc)

        # ParallelDo returns one value per place; reduce them to scalars.
        cost, acc = pd()
        cost = fluid.layers.mean(cost)
        acc = fluid.layers.mean(acc)

    optimizer = fluid.optimizer.Adagrad(learning_rate=lr)
    optimizer.minimize(cost)

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    feeder = fluid.DataFeeder(feed_list=[words, label], place=place)

    # For internal continuous evaluation: fix the seed so runs are comparable.
    ce_mode = "CE_MODE_X" in os.environ
    if ce_mode:
        fluid.default_startup_program().random_seed = 110

    exe.run(fluid.default_startup_program())

    # `range` instead of `xrange` so the loop also runs on Python 3.
    for pass_id in range(pass_num):
        pass_start = time.time()
        data_count, total_acc, total_cost = 0, 0.0, 0.0

        # The loop variable is deliberately not called `data`/`words` so it
        # cannot shadow the input layer defined above.
        for batch in train_reader():
            avg_cost_np, avg_acc_np = exe.run(fluid.default_main_program(),
                                              feed=feeder.feed(batch),
                                              fetch_list=[cost, acc])
            # Weight the per-batch means by batch length so the pass
            # averages are per-sample averages.
            batch_len = len(batch)
            total_acc += batch_len * avg_acc_np
            total_cost += batch_len * avg_cost_np
            data_count += batch_len

        avg_cost = total_cost / data_count
        avg_acc = total_acc / data_count
        print("pass_id: %d, avg_acc: %f, avg_cost: %f" %
              (pass_id, avg_acc, avg_cost))

        epoch_model = save_dirname + "/" + "epoch" + str(pass_id)
        fluid.io.save_inference_model(epoch_model, ["words", "label"], acc, exe)

        pass_end = time.time()
        # For internal continuous evaluation: emit tab-separated KPI lines.
        if ce_mode:
            print("kpis\ttrain_acc\t%f" % avg_acc)
            print("kpis\ttrain_cost\t%f" % avg_cost)
            print("kpis\ttrain_duration\t%f" % (pass_end - pass_start))

G
gmcather 已提交
88 89 90

def train_net():
    """
    Train the network named by the first command-line argument.

    Accepted names are "bow", "cnn", "lstm" and "gru". Any other name, or a
    missing argument, prints "network name cannot be found!" and exits with
    status 1 before any data is prepared.
    """
    # Per-network settings; everything not listed here is shared by all
    # networks (parallel=False, pass_num=30).
    net_configs = {
        "bow": {
            "network": bow_net,
            "use_cuda": False,
            "save_dirname": "bow_model",
            "lr": 0.002,
            "batch_size": 128,
        },
        "cnn": {
            "network": cnn_net,
            "use_cuda": True,
            "save_dirname": "cnn_model",
            "lr": 0.01,
            "batch_size": 4,
        },
        "lstm": {
            "network": lstm_net,
            "use_cuda": True,
            "save_dirname": "lstm_model",
            "lr": 0.05,
            "batch_size": 4,
        },
        "gru": {
            # Previously this entry trained lstm_net and saved the result
            # under "gru_model"; it must use the GRU network.
            "network": gru_net,
            "use_cuda": True,
            "save_dirname": "gru_model",
            "lr": 0.05,
            "batch_size": 128,
        },
    }

    # Treat a missing argument like an unknown name instead of letting
    # sys.argv[1] raise IndexError.
    net_name = sys.argv[1] if len(sys.argv) > 1 else None
    if net_name not in net_configs:
        print("network name cannot be found!")
        sys.exit(1)
    config = net_configs[net_name]

    # NOTE(review): the reader is always built with batch_size=128, while
    # "cnn" and "lstm" pass batch_size=4 to train() -- confirm which value
    # is intended for those networks.
    word_dict, train_reader, _ = utils.prepare_data(
        "imdb", self_dict=False, batch_size=128, buf_size=50000)

    train(
        train_reader,
        word_dict,
        config["network"],
        use_cuda=config["use_cuda"],
        parallel=False,
        save_dirname=config["save_dirname"],
        lr=config["lr"],
        pass_num=30,
        batch_size=config["batch_size"])
P
peterzhang2029 已提交
140

G
gmcather 已提交
141

G
gmcather 已提交
142 143
# Start training only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    train_net()