train.py 4.0 KB
Newer Older
1
import os
M
minqiyang 已提交
2
import six
P
peterzhang2029 已提交
3 4
import sys
import time
G
gmcather 已提交
5 6
import unittest
import contextlib
G
gmcather 已提交
7

Y
Yibing Liu 已提交
8
import paddle
G
gmcather 已提交
9 10
import paddle.fluid as fluid

G
gmcather 已提交
11 12 13 14 15 16
import utils
from nets import bow_net
from nets import cnn_net
from nets import lstm_net
from nets import gru_net

G
gmcather 已提交
17

G
gmcather 已提交
18
def train(train_reader,
          word_dict,
          network,
          use_cuda,
          parallel,
          save_dirname,
          lr=0.2,
          pass_num=30):
    """
    Build the training program around ``network`` and run it.

    After every pass the sample-weighted average accuracy and cost are
    printed and an inference model is saved under
    ``<save_dirname>/epoch<pass_id>``.

    Args:
        train_reader: zero-argument callable; each call returns an iterable
            of mini-batches that are fed to the ``words`` and ``label``
            inputs through ``fluid.DataFeeder``. ``len(batch)`` is used as
            the number of samples in the batch.
        word_dict: vocabulary; only ``len(word_dict)`` is used and it is
            forwarded to ``network``.
        network: callable ``network(words, label, dict_size)`` returning
            ``(cost, acc, prediction)``.
        use_cuda: run on ``fluid.CUDAPlace(0)`` when true, otherwise on
            ``fluid.CPUPlace()``.
        parallel: when true, wrap ``network`` in ``ParallelDo`` over the
            places returned by ``get_places(device_count=2)`` and average
            the per-place cost and accuracy.
        save_dirname: directory under which the per-pass models are saved.
        lr: learning rate handed to the Adagrad optimizer.
        pass_num: number of passes over ``train_reader``.

    Raises:
        ValueError: if ``train_reader`` yields no samples during a pass, so
            no average can be computed.
    """
    words = fluid.layers.data(
        name="words", shape=[1], dtype="int64", lod_level=1)
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")

    if not parallel:
        cost, acc, _ = network(words, label, len(word_dict))
    else:
        places = fluid.layers.device.get_places(device_count=2)
        pd = fluid.layers.ParallelDo(places)
        with pd.do():
            cost, acc, _ = network(
                pd.read_input(words), pd.read_input(label), len(word_dict))
            pd.write_output(cost)
            pd.write_output(acc)

        # Reduce the per-place outputs to a single scalar each.
        cost, acc = pd()
        cost = fluid.layers.mean(cost)
        acc = fluid.layers.mean(acc)

    optimizer = fluid.optimizer.Adagrad(learning_rate=lr)
    optimizer.minimize(cost)

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    feeder = fluid.DataFeeder(feed_list=[words, label], place=place)

    # For internal continuous evaluation: fix the seed and emit KPI lines.
    ce_mode = "CE_MODE_X" in os.environ
    if ce_mode:
        fluid.default_startup_program().random_seed = 110
    exe.run(fluid.default_startup_program())

    for pass_id in six.moves.xrange(pass_num):
        pass_start = time.time()
        data_count, total_acc, total_cost = 0, 0.0, 0.0
        # ``batch`` deliberately does not reuse the name of the input layer.
        for batch in train_reader():
            avg_cost_np, avg_acc_np = exe.run(fluid.default_main_program(),
                                              feed=feeder.feed(batch),
                                              fetch_list=[cost, acc])
            # Weight the batch means by batch size so that a smaller batch
            # does not skew the pass average.
            batch_size = len(batch)
            total_acc += batch_size * avg_acc_np
            total_cost += batch_size * avg_cost_np
            data_count += batch_size

        if data_count == 0:
            raise ValueError(
                "train_reader yielded no samples in pass %d" % pass_id)

        avg_cost = total_cost / data_count
        avg_acc = total_acc / data_count
        print("pass_id: %d, avg_acc: %f, avg_cost: %f, pass_time_cost: %f" %
              (pass_id, avg_acc, avg_cost, time.time() - pass_start))

        epoch_model = os.path.join(save_dirname, "epoch" + str(pass_id))
        fluid.io.save_inference_model(epoch_model, ["words", "label"], acc, exe)

        # Taken after saving, so the reported duration includes the save.
        pass_end = time.time()
        if ce_mode:
            print("kpis\ttrain_acc\t%f" % avg_acc)
            print("kpis\ttrain_cost\t%f" % avg_cost)
            print("kpis\ttrain_duration\t%f" % (pass_end - pass_start))

G
gmcather 已提交
88 89 90

def train_net():
    """
    Train the network named by the first command-line argument.

    ``sys.argv[1]`` must be one of ``bow``, ``cnn``, ``lstm`` or ``gru``.
    When the argument is missing or names no known network, a message is
    printed and the process exits with status 1 before any data is
    prepared. Otherwise the ``"imdb"`` data is obtained from
    ``utils.prepare_data`` and handed to ``train``.
    """
    # Checked first so a bare invocation ends with a usage line rather than
    # an IndexError on sys.argv[1].
    if len(sys.argv) < 2:
        print("usage: train.py <bow|cnn|lstm|gru>")
        sys.exit(1)

    # network name -> (network builder, use_cuda, save_dirname, lr)
    configs = {
        "bow": (bow_net, False, "bow_model", 0.002),
        "cnn": (cnn_net, True, "cnn_model", 0.01),
        "lstm": (lstm_net, True, "lstm_model", 0.05),
        "gru": (gru_net, True, "gru_model", 0.05),
    }

    name = sys.argv[1]
    if name not in configs:
        print("network name cannot be found!")
        sys.exit(1)
    network, use_cuda, save_dirname, lr = configs[name]

    # The test reader is returned by prepare_data but not used for training.
    word_dict, train_reader, _ = utils.prepare_data(
        "imdb", self_dict=False, batch_size=128, buf_size=50000)

    train(
        train_reader,
        word_dict,
        network,
        use_cuda=use_cuda,
        parallel=False,
        save_dirname=save_dirname,
        lr=lr,
        pass_num=30)
P
peterzhang2029 已提交
136

G
gmcather 已提交
137

G
gmcather 已提交
138 139
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    train_net()