# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import numpy as np
import datetime
from collections import deque


def set_paddle_flags(**kwargs):
    """Export each keyword argument as an environment variable.

    A variable that is already present in ``os.environ`` is left untouched,
    so a value supplied by the caller's shell takes precedence over the
    default passed here. Values are stored as ``str(value)``.
    """
    for flag_name in kwargs:
        if flag_name in os.environ:
            continue
        os.environ[flag_name] = str(kwargs[flag_name])


# NOTE(paddle-dev): All of these flags should be set before
# `import paddle`. Otherwise, it would not take any effect.
set_paddle_flags(
    FLAGS_eager_delete_tensor_gb=0,  # enable GC to save memory
)

# Everything below imports paddle (directly or through ppdet), so these
# imports must stay after the set_paddle_flags() call above.
from paddle import fluid

from ppdet.experimental import mixed_precision_context
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.data.data_feed import create_reader

from ppdet.utils.cli import ArgsParser, print_total_cfg
from ppdet.utils import dist_utils
from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results
from ppdet.utils.stats import TrainingStats
from ppdet.utils.check import check_gpu
import ppdet.utils.checkpoint as checkpoint
from ppdet.modeling.model_input import create_feed

import logging
# Root logging configuration for the script: INFO level, timestamped lines.
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)


def main():
    """Build the training (and optional evaluation) programs and train.

    Options are read from the module-level ``FLAGS`` namespace (parsed in the
    ``__main__`` section); ``FLAGS.dist`` is set here from the environment.
    The model configuration is loaded from ``FLAGS.config`` and overridden by
    ``FLAGS.opt``. Checkpoints are written to ``cfg.save_dir/<name>``, where
    ``<name>`` is the config file's basename up to its first ``.``.

    Raises:
        ValueError: if the config does not define ``architecture``, or if
            ``--fp16`` is combined with an ``affine_channel`` backbone.
    """
    env = os.environ
    # Distributed mode is detected from the launcher's environment variables.
    FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
    trainer_id = 0
    if FLAGS.dist:
        trainer_id = int(env['PADDLE_TRAINER_ID'])
        import random
        # Distinct, reproducible seed per trainer.
        local_seed = (99 + trainer_id)
        random.seed(local_seed)
        np.random.seed(local_seed)
    # Only this process prints the config, logs progress and saves checkpoints.
    is_master = not FLAGS.dist or trainer_id == 0

    cfg = load_config(FLAGS.config)
    if 'architecture' in cfg:
        main_arch = cfg.architecture
    else:
        raise ValueError("'architecture' not specified in config file.")

    merge_config(FLAGS.opt)

    if 'log_iter' not in cfg:
        cfg.log_iter = 20

    # Parameters that are excluded when loading pretrained weights.
    ignore_params = cfg.finetune_exclude_pretrained_params \
                 if 'finetune_exclude_pretrained_params' in cfg else []

    # check if set use_gpu=True in paddlepaddle cpu version
    check_gpu(cfg.use_gpu)
    if is_master:
        print_total_cfg(cfg)

    if cfg.use_gpu:
        devices_num = fluid.core.get_cuda_device_count()
    else:
        devices_num = int(os.environ.get('CPU_NUM', 1))

    if 'train_feed' not in cfg:
        train_feed = create(main_arch + 'TrainFeed')
    else:
        train_feed = create(cfg.train_feed)

    if FLAGS.eval:
        if 'eval_feed' not in cfg:
            eval_feed = create(main_arch + 'EvalFeed')
        else:
            eval_feed = create(cfg.eval_feed)

    device_id = int(env.get('FLAGS_selected_gpus', 0))
    place = fluid.CUDAPlace(device_id) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    lr_builder = create('LearningRate')
    optim_builder = create('OptimizerBuilder')

    # build program
    startup_prog = fluid.Program()
    train_prog = fluid.Program()
    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():
            model = create(main_arch)
            train_pyreader, feed_vars = create_feed(train_feed)

            if FLAGS.fp16 and (getattr(model.backbone, 'norm_type', None)
                               == 'affine_channel'):
                # Explicit raise instead of `assert`, which `python -O` strips.
                raise ValueError(
                    '--fp16 currently does not support affine channel, '
                    ' please modify backbone settings to use batch norm')

            with mixed_precision_context(FLAGS.loss_scale, FLAGS.fp16) as ctx:
                train_fetches = model.train(feed_vars)

                loss = train_fetches['loss']
                # With --fp16 the loss is multiplied by the loss-scale
                # variable before minimize() and divided back afterwards, so
                # the `loss` variable fetched/logged later is unscaled.
                if FLAGS.fp16:
                    loss *= ctx.get_loss_scale_var()
                lr = lr_builder()
                optimizer = optim_builder(lr)
                optimizer.minimize(loss)
                if FLAGS.fp16:
                    loss /= ctx.get_loss_scale_var()

    # parse train fetches; the learning rate is appended as the LAST fetch so
    # the training loop can separate it from the loss values.
    train_keys, train_values, _ = parse_fetches(train_fetches)
    train_values.append(lr)

    if FLAGS.eval:
        eval_prog = fluid.Program()
        with fluid.program_guard(eval_prog, startup_prog):
            with fluid.unique_name.guard():
                # NOTE: this rebinds `model` to the eval-graph instance; the
                # backbone/mask_head lookups below read from it. Presumably
                # identical to the train instance since both come from
                # create(main_arch) - confirm if the two can diverge.
                model = create(main_arch)
                eval_pyreader, feed_vars = create_feed(eval_feed)
                fetches = model.eval(feed_vars)
        eval_prog = eval_prog.clone(True)

        eval_reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir)
        eval_pyreader.decorate_sample_list_generator(eval_reader, place)

        # parse eval fetches; each metric needs extra fields from the feed
        extra_keys = {
            'COCO': ['im_info', 'im_id', 'im_shape'],
            'VOC': ['gt_box', 'gt_label', 'is_difficult'],
            'WIDERFACE': ['im_id', 'im_shape', 'gt_box'],
        }.get(cfg.metric, [])
        eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
                                                         extra_keys)

    # compile program for multi-devices
    build_strategy = fluid.BuildStrategy()
    build_strategy.fuse_all_optimizer_ops = False
    build_strategy.fuse_elewise_add_act_ops = True
    # only enable sync_bn in multi GPU devices
    sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn'
    build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \
        and cfg.use_gpu

    exec_strategy = fluid.ExecutionStrategy()
    # iteration number when CompiledProgram tries to drop local execution scopes.
    # Set it to be 1 to save memory usages, so that unused variables in
    # local execution scopes can be deleted after each iteration.
    exec_strategy.num_iteration_per_drop_scope = 1
    if FLAGS.dist:
        dist_utils.prepare_for_multi_process(exe, build_strategy, startup_prog,
                                             train_prog)
        exec_strategy.num_threads = 1

    exe.run(startup_prog)
    compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
        loss_name=loss.name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    if FLAGS.eval:
        compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog)

    fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'

    # Weight initialisation: resuming restores the step counter; otherwise
    # pretrained weights are loaded (fused into BN params for affine_channel
    # backbones when no parameters are excluded).
    start_iter = 0
    if FLAGS.resume_checkpoint:
        checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
        start_iter = checkpoint.global_step()
    elif cfg.pretrain_weights and fuse_bn and not ignore_params:
        checkpoint.load_and_fusebn(exe, train_prog, cfg.pretrain_weights)
    elif cfg.pretrain_weights:
        checkpoint.load_params(
            exe, train_prog, cfg.pretrain_weights, ignore_params=ignore_params)

    # Reader length is derived from the remaining iterations times the device
    # count (presumably one batch per device per iteration - confirm).
    train_reader = create_reader(train_feed, (cfg.max_iters - start_iter) *
                                 devices_num, FLAGS.dataset_dir)
    train_pyreader.decorate_sample_list_generator(train_reader, place)

    # whether output bbox is normalized in model output layer
    is_bbox_normalized = False
    if hasattr(model, 'is_bbox_normalized') and \
            callable(model.is_bbox_normalized):
        is_bbox_normalized = model.is_bbox_normalized()

    # if map_type not set, use default 11point, only use in VOC eval
    map_type = cfg.map_type if 'map_type' in cfg else '11point'

    train_stats = TrainingStats(cfg.log_smooth_window, train_keys)
    train_pyreader.start()
    end_time = time.time()

    cfg_name = os.path.basename(FLAGS.config).split('.')[0]
    save_dir = os.path.join(cfg.save_dir, cfg_name)
    # Per-iteration wall-clock times over the last log_smooth_window steps.
    time_stat = deque(maxlen=cfg.log_smooth_window)
    best_box_ap_list = [0.0, 0]  # [best map, iteration it was reached]

    # use tb-paddle to log data
    if FLAGS.use_tb:
        from tb_paddle import SummaryWriter
        tb_writer = SummaryWriter(FLAGS.tb_log_dir)
        tb_loss_step = 0
        tb_mAP_step = 0

    for it in range(start_iter, cfg.max_iters):
        # Time since the previous iteration started, smoothed for the ETA.
        start_time = end_time
        end_time = time.time()
        time_stat.append(end_time - start_time)
        time_cost = np.mean(time_stat)
        eta_sec = (cfg.max_iters - it) * time_cost
        eta = str(datetime.timedelta(seconds=int(eta_sec)))

        outs = exe.run(compiled_train_prog, fetch_list=train_values)
        # outs[-1] is the learning rate (appended last); the rest are losses.
        stats = {k: np.array(v).mean() for k, v in zip(train_keys, outs[:-1])}

        # use tb-paddle to log loss
        if FLAGS.use_tb:
            if it % cfg.log_iter == 0:
                for loss_name, loss_value in stats.items():
                    tb_writer.add_scalar(loss_name, loss_value, tb_loss_step)
                tb_loss_step += 1

        train_stats.update(stats)
        logs = train_stats.log()
        if it % cfg.log_iter == 0 and is_master:
            strs = 'iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'.format(
                it, np.mean(outs[-1]), logs, time_cost, eta)
            logger.info(strs)

        is_last_iter = it == cfg.max_iters - 1
        is_snapshot_iter = (it > 0 and it % cfg.snapshot_iter == 0) \
            or is_last_iter
        if is_snapshot_iter and is_master:
            save_name = "model_final" if is_last_iter else str(it)
            checkpoint.save(exe, train_prog, os.path.join(save_dir, save_name))

            if FLAGS.eval:
                # evaluation
                results = eval_run(exe, compiled_eval_prog, eval_pyreader,
                                   eval_keys, eval_values, eval_cls)
                resolution = None
                if 'mask' in results[0]:
                    resolution = model.mask_head.resolution
                box_ap_stats = eval_results(
                    results, eval_feed, cfg.metric, cfg.num_classes, resolution,
                    is_bbox_normalized, FLAGS.output_eval, map_type)

                # use tb_paddle to log mAP
                if FLAGS.use_tb:
                    tb_writer.add_scalar("mAP", box_ap_stats[0], tb_mAP_step)
                    tb_mAP_step += 1

                # Keep a separate "best_model" checkpoint for the best mAP.
                if box_ap_stats[0] > best_box_ap_list[0]:
                    best_box_ap_list[0] = box_ap_stats[0]
                    best_box_ap_list[1] = it
                    checkpoint.save(exe, train_prog,
                                    os.path.join(save_dir, "best_model"))
                logger.info("Best test box ap: {}, in iter: {}".format(
                    best_box_ap_list[0], best_box_ap_list[1]))

    train_pyreader.reset()


def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool(string)``, which is True for
    every non-empty string, so ``--use_tb False`` would *enable* logging.
    This parser keeps ``--use_tb True`` working and makes ``False`` mean
    False. An empty string stays False, as it was with ``bool``.

    Raises:
        ValueError: for unrecognised text; argparse reports it as an
            invalid argument value.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ValueError("expected a boolean value, got {!r}".format(value))


if __name__ == '__main__':
    parser = ArgsParser()
    parser.add_argument(
        "-r",
        "--resume_checkpoint",
        default=None,
        type=str,
        help="Checkpoint path for resuming training.")
    parser.add_argument(
        "--fp16",
        action='store_true',
        default=False,
        help="Enable mixed precision training.")
    parser.add_argument(
        "--loss_scale",
        default=8.,
        type=float,
        help="Mixed precision training loss scale.")
    parser.add_argument(
        "--eval",
        action='store_true',
        default=False,
        help="Whether to perform evaluation in train")
    parser.add_argument(
        "--output_eval",
        default=None,
        type=str,
        help="Evaluation directory, default is current directory.")
    parser.add_argument(
        "-d",
        "--dataset_dir",
        default=None,
        type=str,
        help="Dataset path, same as DataFeed.dataset.dataset_dir")
    parser.add_argument(
        "--use_tb",
        type=str2bool,
        default=False,
        help="whether to record the data to Tensorboard.")
    parser.add_argument(
        '--tb_log_dir',
        type=str,
        default="tb_log_dir/scalar",
        help='Tensorboard logging directory for scalar.')
    FLAGS = parser.parse_args()
    main()