#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
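
"""Train an RCNN detection model (ResNet-50 backbone) with PaddlePaddle Fluid.

Supports single- and multi-GPU training, optional multi-process distributed
training (PADDLE_TRAINERS_NUM > 1), and piecewise learning-rate decay with
linear warmup.
"""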

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
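

# Export default values for Paddle's FLAGS_* settings via environment
# variables; flags the user has already set are left untouched.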
def set_paddle_flags(flags):
    for key, value in flags.items():
        if os.environ.get(key, None) is None:
            os.environ[key] = str(value)

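
# NOTE: these flags only take effect if they are set before `paddle.fluid`
# is imported, which is why this call precedes the paddle imports below.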
set_paddle_flags({
    'FLAGS_conv_workspace_size_limit': 500,
    'FLAGS_eager_delete_tensor_gb': 0,  # enable gc
    'FLAGS_memory_fraction_of_eager_deletion': 1,
    'FLAGS_fraction_of_gpu_memory_to_use': 0.98
})

import sys
import subprocess  # used by get_device_num() to count GPUs via `nvidia-smi -L`
import numpy as np
import time
import shutil
from utility import parse_args, print_arguments, SmoothedValue, TrainingStats, now_time, check_gpu
import collections

import paddle
import paddle.fluid as fluid
import reader
import models.model_builder as model_builder
import models.resnet as resnet
from learning_rate import exponential_with_warmup_decay
from config import cfg
import dist_utils

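# Number of trainer processes; set by the distributed launch module, 1 for
# plain single-process training.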
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))


def get_device_num():
    # NOTE(zcd): for multi-process training, each process uses one GPU card.
    if num_trainers > 1: return 1
    visible_device = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    if visible_device:
        device_num = len(visible_device.split(','))
    else:
        device_num = subprocess.check_output(
            ['nvidia-smi', '-L']).decode().count('\n')
    return device_num


def train():
    learning_rate = cfg.learning_rate
    image_shape = [3, cfg.TRAIN.max_size, cfg.TRAIN.max_size]

    if cfg.enable_ce:
        fluid.default_startup_program().random_seed = 1000
        fluid.default_main_program().random_seed = 1000
        import random
        random.seed(0)
        np.random.seed(0)

    devices_num = get_device_num()
    total_batch_size = devices_num * cfg.TRAIN.im_per_batch

    use_random = True
    if cfg.enable_ce:
        use_random = False
    model = model_builder.RCNN(
        add_conv_body_func=resnet.add_ResNet50_conv4_body,
        add_roi_box_head_func=resnet.add_ResNet_roi_conv5_head,
        use_pyreader=cfg.use_pyreader,
        use_random=use_random)
    model.build_model(image_shape)
    losses, keys = model.loss()
    loss = losses[0]
    fetch_list = losses

    boundaries = cfg.lr_steps
    gamma = cfg.lr_gamma
    step_num = len(cfg.lr_steps)
    values = [learning_rate * (gamma**i) for i in range(step_num + 1)]

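    # Piecewise-constant decay: the learning rate is multiplied by gamma at
    # each boundary in cfg.lr_steps, with linear warmup over the first
    # cfg.warm_up_iter iterations.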
    lr = exponential_with_warmup_decay(
        learning_rate=learning_rate,
        boundaries=boundaries,
        values=values,
        warmup_iter=cfg.warm_up_iter,
        warmup_factor=cfg.warm_up_factor)
    optimizer = fluid.optimizer.Momentum(
        learning_rate=lr,
        regularization=fluid.regularizer.L2Decay(cfg.weight_decay),
        momentum=cfg.momentum)
    optimizer.minimize(loss)
    fetch_list = fetch_list + [lr]

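    # Mark the fetched variables persistable so eager deletion / GC cannot
    # free them before their values are read back each iteration.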
    for var in fetch_list:
        var.persistable = True

    # fluid.memory_optimize(fluid.default_main_program(), skip_opt_set=set(fetch_list))
    gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
    place = fluid.CUDAPlace(gpu_id) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    if cfg.pretrained_model:

        def if_exist(var):
            return os.path.exists(os.path.join(cfg.pretrained_model, var.name))

        fluid.io.load_vars(exe, cfg.pretrained_model, predicate=if_exist)

    if cfg.parallel:
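        # Multi-card setup: use in-place execution rather than Paddle's
        # memory_optimize pass, and drop local scopes every 10 iterations.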
        build_strategy = fluid.BuildStrategy()
        build_strategy.memory_optimize = False
        build_strategy.enable_inplace = True
        exec_strategy = fluid.ExecutionStrategy()
        exec_strategy.num_iteration_per_drop_scope = 10

        if num_trainers > 1 and cfg.use_gpu:
            dist_utils.prepare_for_multi_process(exe, build_strategy,
                                                 fluid.default_main_program())
            # NOTE: multi-process training runs fastest with num_threads = 1.
            exec_strategy.num_threads = 1

        train_exe = fluid.ParallelExecutor(
            use_cuda=bool(cfg.use_gpu),
            loss_name=loss.name,
            build_strategy=build_strategy,
            exec_strategy=exec_strategy)
    else:
        train_exe = exe

    shuffle = True
    if cfg.enable_ce:
        shuffle = False
    # NOTE: for multi-process training, every process must generate batches in
    # the same order, so shuffle with a fixed seed instead of disabling it.
    shuffle_seed = None
    if num_trainers > 1:
        shuffle_seed = 1
    if cfg.use_pyreader:
        train_reader = reader.train(
            batch_size=cfg.TRAIN.im_per_batch,
            total_batch_size=total_batch_size,
            padding_total=cfg.TRAIN.padding_minibatch,
            shuffle=shuffle,
            shuffle_seed=shuffle_seed)
        if num_trainers > 1:
            assert shuffle_seed is not None, \
                "If num_trainers > 1, the shuffle_seed must be set, because " \
                "the order of batch data generated by the reader " \
                "must be the same in the respective processes."
            train_reader = fluid.contrib.reader.distributed_batch_reader(
                train_reader)
        py_reader = model.py_reader
        py_reader.decorate_paddle_reader(train_reader)
    else:
        if num_trainers > 1: shuffle = False
        train_reader = reader.train(
            batch_size=total_batch_size, shuffle=shuffle)
        feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())

    def save_model(postfix):
        model_path = os.path.join(cfg.model_save_dir, postfix)
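        # Remove any stale snapshot with the same name before saving.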
        if os.path.isdir(model_path):
            shutil.rmtree(model_path)
        fluid.io.save_persistables(exe, model_path)

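    # Training loop for the py_reader path: batches are fed asynchronously
    # until max_iter is reached or the reader is exhausted.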
    def train_loop_pyreader():
        py_reader.start()
        train_stats = TrainingStats(cfg.log_window, keys)
        try:
            start_time = time.time()
            prev_start_time = start_time
            start = start_time  # wall-clock start of the whole loop, for CE stats
            for iter_id in range(cfg.max_iter):
                prev_start_time = start_time
                start_time = time.time()
                outs = train_exe.run(fetch_list=[v.name for v in fetch_list])
                stats = {k: np.array(v).mean() for k, v in zip(keys, outs[:-1])}
                train_stats.update(stats)
                logs = train_stats.log()
                strs = '{}, iter: {}, lr: {:.5f}, {}, time: {:.3f}'.format(
                    now_time(), iter_id,
                    np.mean(outs[-1]), logs, start_time - prev_start_time)
                print(strs)
                sys.stdout.flush()
                if (iter_id + 1) % cfg.TRAIN.snapshot_iter == 0:
                    save_model("model_iter{}".format(iter_id))
            end_time = time.time()
            total_time = end_time - start
            last_loss = np.array(outs[0]).mean()
            if cfg.enable_ce:
                gpu_num = devices_num
                epoch_idx = iter_id + 1
                loss = last_loss
                print("kpis\teach_pass_duration_card%s\t%s" %
                      (gpu_num, total_time / epoch_idx))
                print("kpis\ttrain_loss_card%s\t%s" % (gpu_num, loss))
        except (StopIteration, fluid.core.EOFException):
            py_reader.reset()

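    # Training loop for the plain-feed path: batches from train_reader are
    # fed synchronously through a DataFeeder.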
    def train_loop():
        start_time = time.time()
        prev_start_time = start_time
        start = start_time
        train_stats = TrainingStats(cfg.log_window, keys)
        for iter_id, data in enumerate(train_reader()):
            prev_start_time = start_time
            start_time = time.time()
            outs = train_exe.run(fetch_list=[v.name for v in fetch_list],
                                 feed=feeder.feed(data))
            stats = {k: np.array(v).mean() for k, v in zip(keys, outs[:-1])}
            train_stats.update(stats)
            logs = train_stats.log()
            strs = '{}, iter: {}, lr: {:.5f}, {}, time: {:.3f}'.format(
                now_time(), iter_id,
                np.mean(outs[-1]), logs, start_time - prev_start_time)
            print(strs)
            sys.stdout.flush()
            if (iter_id + 1) % cfg.TRAIN.snapshot_iter == 0:
                save_model("model_iter{}".format(iter_id))
            if (iter_id + 1) == cfg.max_iter:
                break
        end_time = time.time()
        total_time = end_time - start
        last_loss = np.array(outs[0]).mean()
        # only for ce
        if cfg.enable_ce:
            gpu_num = devices_num
            epoch_idx = iter_id + 1
            loss = last_loss
            print("kpis\teach_pass_duration_card%s\t%s" %
                  (gpu_num, total_time / epoch_idx))
            print("kpis\ttrain_loss_card%s\t%s" % (gpu_num, loss))

    if cfg.use_pyreader:
        train_loop_pyreader()
    else:
        train_loop()
    save_model('model_final')


if __name__ == '__main__':
    args = parse_args()
    print_arguments(args)
    check_gpu(args.use_gpu)
    train()