train.py 4.1 KB
Newer Older
1
import os
M
minqiyang 已提交
2
import six
P
peterzhang2029 已提交
3 4
import sys
import time
G
gmcather 已提交
5 6
import unittest
import contextlib
G
gmcather 已提交
7

Y
Yibing Liu 已提交
8
import paddle
G
gmcather 已提交
9 10
import paddle.fluid as fluid

G
gmcather 已提交
11 12 13 14 15 16
import utils
from nets import bow_net
from nets import cnn_net
from nets import lstm_net
from nets import gru_net

G
gmcather 已提交
17

G
gmcather 已提交
18
def train(train_reader,
          word_dict,
          network,
          use_cuda,
          parallel,
          save_dirname,
          lr=0.2,
          batch_size=128,
          pass_num=30):
    """
    Build the classification network, train it with Adagrad and save an
    inference model after every pass.

    Args:
        train_reader: callable returning an iterable of mini-batches; each
            batch is converted with ``fluid.DataFeeder`` for the
            (words, label) inputs. Presumably produced by
            ``utils.prepare_data`` -- confirm against the caller.
        word_dict: vocabulary; only ``len(word_dict)`` is used, as the
            dictionary size passed to ``network``.
        network: callable ``network(words, label, dict_dim)`` returning
            ``(cost, acc, prediction)``.
        use_cuda: run on ``CUDAPlace(0)`` when True, otherwise on the CPU.
        parallel: when True, replicate the network with ``ParallelDo`` over
            two places and average cost/accuracy across them.
        save_dirname: directory under which one ``epoch<pass_id>`` model is
            saved per pass.
        lr: Adagrad learning rate.
        batch_size: currently unused -- the mini-batch size is determined by
            ``train_reader``. Kept for backward compatibility.
        pass_num: number of passes (epochs) over ``train_reader``.

    Raises:
        ValueError: if ``train_reader`` yields no samples during a pass.
    """
    words = fluid.layers.data(
        name="words", shape=[1], dtype="int64", lod_level=1)
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")

    if not parallel:
        cost, acc, prediction = network(words, label, len(word_dict))
    else:
        places = fluid.layers.device.get_places(device_count=2)
        pd = fluid.layers.ParallelDo(places)
        with pd.do():
            cost, acc, prediction = network(
                pd.read_input(words), pd.read_input(label), len(word_dict))
            pd.write_output(cost)
            pd.write_output(acc)

        # Merge the per-place outputs into a single averaged cost/accuracy.
        cost, acc = pd()
        cost = fluid.layers.mean(cost)
        acc = fluid.layers.mean(acc)

    optimizer = fluid.optimizer.Adagrad(learning_rate=lr)
    optimizer.minimize(cost)

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    feeder = fluid.DataFeeder(feed_list=[words, label], place=place)

    # For internal continuous evaluation: fix the startup seed so CE runs
    # are comparable.
    ce_mode = "CE_MODE_X" in os.environ
    if ce_mode:
        fluid.default_startup_program().random_seed = 110
    exe.run(fluid.default_startup_program())

    for pass_id in six.moves.xrange(pass_num):
        pass_start = time.time()
        sample_count, total_acc, total_cost = 0, 0.0, 0.0
        # Named `batch` (not `data`) so the `words` input layer is never
        # shadowed by the reader output.
        for batch in train_reader():
            avg_cost_np, avg_acc_np = exe.run(fluid.default_main_program(),
                                              feed=feeder.feed(batch),
                                              fetch_list=[cost, acc])
            # Weight each batch mean by the batch length so the pass average
            # is per-sample even when batches differ in size.
            batch_len = len(batch)
            total_acc += batch_len * avg_acc_np
            total_cost += batch_len * avg_cost_np
            sample_count += batch_len

        if sample_count == 0:
            raise ValueError(
                "train_reader yielded no samples in pass %d" % pass_id)
        avg_cost = total_cost / sample_count
        avg_acc = total_acc / sample_count
        print("pass_id: %d, avg_acc: %f, avg_cost: %f, pass_time_cost: %f" %
              (pass_id, avg_acc, avg_cost, time.time() - pass_start))

        epoch_model = os.path.join(save_dirname, "epoch" + str(pass_id))
        # NOTE(review): saved with both "words" and "label" as feeds and
        # accuracy as the fetch target (not `prediction`); presumably the
        # companion inference script evaluates accuracy -- confirm before
        # changing.
        fluid.io.save_inference_model(epoch_model, ["words", "label"], acc, exe)

        # Measured after saving, so the CE duration includes model export time.
        pass_end = time.time()
        # For internal continuous evaluation
        if ce_mode:
            print("kpis	train_acc	%f" % avg_acc)
            print("kpis	train_cost	%f" % avg_cost)
            print("kpis	train_duration	%f" % (pass_end - pass_start))

G
gmcather 已提交
89 90 91

def train_net():
    """
    Train the network named by the first command-line argument on IMDB.

    Usage: ``python train.py {bow|cnn|lstm|gru}``.

    Prints an error and exits with status 1 when the network name is
    missing or unknown. The name is checked before the dataset is prepared,
    so a typo fails fast.
    """
    # Per-network settings: (network fn, use_cuda, save_dirname, learning rate).
    configs = {
        "bow": (bow_net, False, "bow_model", 0.002),
        "cnn": (cnn_net, True, "cnn_model", 0.01),
        "lstm": (lstm_net, True, "lstm_model", 0.05),
        # Previously this entry trained lstm_net by mistake even though
        # gru_net is imported and the model is saved as "gru_model".
        "gru": (gru_net, True, "gru_model", 0.05),
    }

    net_name = sys.argv[1] if len(sys.argv) > 1 else None
    if net_name not in configs:
        print("network name cannot be found!")
        sys.exit(1)
    network, use_cuda, save_dirname, lr = configs[net_name]

    word_dict, train_reader, _ = utils.prepare_data(
        "imdb", self_dict=False, batch_size=128, buf_size=50000)

    # NOTE: train() does not use batch_size; the effective mini-batch size
    # is the batch_size=128 passed to utils.prepare_data above.
    train(
        train_reader,
        word_dict,
        network,
        use_cuda=use_cuda,
        parallel=False,
        save_dirname=save_dirname,
        lr=lr,
        pass_num=30,
        batch_size=4)
P
peterzhang2029 已提交
141

G
gmcather 已提交
142

G
gmcather 已提交
143 144
if __name__ == "__main__":
    # Script entry point: train the network selected on the command line.
    train_net()