train.py 13.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

Q
qingqing01 已提交
19 20 21 22 23 24
import os, sys
# add python path of PadleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
if parent_path not in sys.path:
    sys.path.append(parent_path)

25 26
import time
import numpy as np
X
xiegegege 已提交
27
import random
28
import datetime
29
import six
30
from collections import deque
H
hysunflower 已提交
31
from paddle.fluid import profiler
32 33

from paddle import fluid
34 35
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
from paddle.fluid.optimizer import ExponentialMovingAverage
36 37

from ppdet.experimental import mixed_precision_context
38
from ppdet.core.workspace import load_config, merge_config, create
39
from ppdet.data.reader import create_reader
40

41
from ppdet.utils import dist_utils
42 43 44
from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results
from ppdet.utils.stats import TrainingStats
from ppdet.utils.cli import ArgsParser
45
from ppdet.utils.check import check_gpu, check_version, check_config
46 47 48 49 50 51 52 53 54
import ppdet.utils.checkpoint as checkpoint

import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)


def main():
    """Train a detection model as configured by the module-level ``FLAGS``.

    ``FLAGS`` is parsed in the ``__main__`` section. This function:

    * loads ``FLAGS.config`` (plus ``FLAGS.opt`` overrides);
    * builds the training program (and, when ``FLAGS.eval`` is set, an
      evaluation program);
    * resumes from ``FLAGS.resume_checkpoint`` or loads
      ``cfg.pretrain_weights``;
    * runs iterations from the start step up to ``cfg.max_iters``.

    A log line is written every ``cfg.log_iter`` iterations. A checkpoint is
    saved every ``cfg.snapshot_iter`` iterations and at the final iteration;
    with ``FLAGS.eval`` the model is evaluated at each save and the best
    result is also kept as ``best_model``.

    Distributed mode is enabled when ``PADDLE_TRAINER_ID`` and
    ``PADDLE_TRAINERS_NUM`` are both set in the environment and more than one
    trainer is requested. In that mode only trainer 0 writes the periodic log
    line and saves checkpoints.

    Raises:
        NotImplementedError: if the config sets ``save_prediction_only``.
    """
    env = os.environ
    FLAGS.dist = 'PADDLE_TRAINER_ID' in env \
                    and 'PADDLE_TRAINERS_NUM' in env \
                    and int(env['PADDLE_TRAINERS_NUM']) > 1
    num_trainers = int(env.get('PADDLE_TRAINERS_NUM', 1))
    trainer_id = 0
    if FLAGS.dist:
        trainer_id = int(env['PADDLE_TRAINER_ID'])
        # Seed Python's and NumPy's RNGs differently (but reproducibly) per
        # trainer.
        local_seed = (99 + trainer_id)
        random.seed(local_seed)
        np.random.seed(local_seed)
    # Only the single local process, or trainer 0 when distributed, logs
    # progress and writes checkpoints.
    is_main_trainer = not FLAGS.dist or trainer_id == 0

    if FLAGS.enable_ce:
        random.seed(0)
        np.random.seed(0)

    cfg = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    check_config(cfg)
    # check if set use_gpu=True in paddlepaddle cpu version
    check_gpu(cfg.use_gpu)
    # check if paddlepaddle version is satisfied
    check_version()

    save_only = getattr(cfg, 'save_prediction_only', False)
    if save_only:
        raise NotImplementedError('The config file only support prediction,'
                                  ' training stage is not implemented now')
    main_arch = cfg.architecture

    if cfg.use_gpu:
        devices_num = fluid.core.get_cuda_device_count()
    else:
        devices_num = int(os.environ.get('CPU_NUM', 1))

    if 'FLAGS_selected_gpus' in env:
        device_id = int(env['FLAGS_selected_gpus'])
    else:
        device_id = 0
    place = fluid.CUDAPlace(device_id) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    lr_builder = create('LearningRate')
    optim_builder = create('OptimizerBuilder')

    # build program
    startup_prog = fluid.Program()
    train_prog = fluid.Program()
    if FLAGS.enable_ce:
        startup_prog.random_seed = 1000
        train_prog.random_seed = 1000
    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():
            model = create(main_arch)
            if FLAGS.fp16:
                assert (getattr(model.backbone, 'norm_type', None)
                        != 'affine_channel'), \
                    '--fp16 currently does not support affine channel, ' \
                    ' please modify backbone settings to use batch norm'

            with mixed_precision_context(FLAGS.loss_scale, FLAGS.fp16) as ctx:
                inputs_def = cfg['TrainReader']['inputs_def']
                feed_vars, train_loader = model.build_inputs(**inputs_def)
                train_fetches = model.train(feed_vars)
                loss = train_fetches['loss']
                # Under fp16, minimize() runs on the loss multiplied by the
                # loss-scale variable.
                if FLAGS.fp16:
                    loss *= ctx.get_loss_scale_var()
                lr = lr_builder()
                optimizer = optim_builder(lr)
                optimizer.minimize(loss)

                # Divide the scale back out, so `loss` (whose name is later
                # passed to with_data_parallel) is in unscaled units.
                if FLAGS.fp16:
                    loss /= ctx.get_loss_scale_var()

            if 'use_ema' in cfg and cfg['use_ema']:
                global_steps = _decay_step_counter()
                ema = ExponentialMovingAverage(
                    cfg['ema_decay'], thres_steps=global_steps)
                ema.update()

    # parse train fetches; the learning rate is appended last so that
    # outs[-1] below is the lr and outs[:-1] line up with train_keys
    train_keys, train_values, _ = parse_fetches(train_fetches)
    train_values.append(lr)

    if FLAGS.eval:
        eval_prog = fluid.Program()
        with fluid.program_guard(eval_prog, startup_prog):
            with fluid.unique_name.guard():
                model = create(main_arch)
                inputs_def = cfg['EvalReader']['inputs_def']
                feed_vars, eval_loader = model.build_inputs(**inputs_def)
                fetches = model.eval(feed_vars)
        eval_prog = eval_prog.clone(True)

        eval_reader = create_reader(cfg.EvalReader, devices_num=1)
        # When iterable mode, set set_sample_list_generator(eval_reader, place)
        eval_loader.set_sample_list_generator(eval_reader)

        # parse eval fetches; each metric needs its own extra input fields
        extra_keys = []
        if cfg.metric == 'COCO':
            extra_keys = ['im_info', 'im_id', 'im_shape']
        elif cfg.metric == 'VOC':
            extra_keys = ['gt_bbox', 'gt_class', 'is_difficult']
        elif cfg.metric == 'WIDERFACE':
            extra_keys = ['im_id', 'im_shape', 'gt_bbox']
        eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
                                                         extra_keys)

    # compile program for multi-devices
    build_strategy = fluid.BuildStrategy()
    build_strategy.fuse_all_optimizer_ops = False
    # only enable sync_bn in multi GPU devices
    sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn'
    build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \
        and cfg.use_gpu

    exec_strategy = fluid.ExecutionStrategy()
    # iteration number when CompiledProgram tries to drop local execution scopes.
    # Set it to be 1 to save memory usages, so that unused variables in
    # local execution scopes can be deleted after each iteration.
    exec_strategy.num_iteration_per_drop_scope = 1
    if FLAGS.dist:
        dist_utils.prepare_for_multi_process(exe, build_strategy, startup_prog,
                                             train_prog)
        exec_strategy.num_threads = 1

    exe.run(startup_prog)
    compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
        loss_name=loss.name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    if FLAGS.eval:
        compiled_eval_prog = fluid.CompiledProgram(eval_prog)

    fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'

    ignore_params = cfg.finetune_exclude_pretrained_params \
                 if 'finetune_exclude_pretrained_params' in cfg else []

    start_iter = 0
    if FLAGS.resume_checkpoint:
        checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
        start_iter = checkpoint.global_step()
    elif cfg.pretrain_weights and fuse_bn and not ignore_params:
        checkpoint.load_and_fusebn(exe, train_prog, cfg.pretrain_weights)
    elif cfg.pretrain_weights:
        checkpoint.load_params(
            exe, train_prog, cfg.pretrain_weights, ignore_params=ignore_params)

    train_reader = create_reader(
        cfg.TrainReader, (cfg.max_iters - start_iter) * devices_num,
        cfg,
        devices_num=devices_num,
        num_trainers=num_trainers)
    # When iterable mode, set set_sample_list_generator(train_reader, place)
    train_loader.set_sample_list_generator(train_reader)

    # whether output bbox is normalized in model output layer
    is_bbox_normalized = False
    if hasattr(model, 'is_bbox_normalized') and \
            callable(model.is_bbox_normalized):
        is_bbox_normalized = model.is_bbox_normalized()

    # if map_type not set, use default 11point, only use in VOC eval
    map_type = cfg.map_type if 'map_type' in cfg else '11point'

    train_stats = TrainingStats(cfg.log_smooth_window, train_keys)
    train_loader.start()
    start_time = time.time()
    end_time = time.time()

    cfg_name = os.path.basename(FLAGS.config).split('.')[0]
    save_dir = os.path.join(cfg.save_dir, cfg_name)
    # sliding window of per-iteration wall times used for batch_cost / eta
    time_stat = deque(maxlen=cfg.log_smooth_window)
    best_box_ap_list = [0.0, 0]  # [map, iter]

    # use VisualDL to log data
    if FLAGS.use_vdl:
        assert six.PY3, "VisualDL requires Python >= 3.5"
        from visualdl import LogWriter
        vdl_writer = LogWriter(FLAGS.vdl_log_dir)
        vdl_loss_step = 0
        vdl_mAP_step = 0

    for it in range(start_iter, cfg.max_iters):
        start_time = end_time
        end_time = time.time()
        time_stat.append(end_time - start_time)
        time_cost = np.mean(time_stat)
        eta_sec = (cfg.max_iters - it) * time_cost
        eta = str(datetime.timedelta(seconds=int(eta_sec)))
        outs = exe.run(compiled_train_prog, fetch_list=train_values)
        stats = {k: np.array(v).mean() for k, v in zip(train_keys, outs[:-1])}

        # use vdl-paddle to log loss
        if FLAGS.use_vdl:
            if it % cfg.log_iter == 0:
                for loss_name, loss_value in stats.items():
                    vdl_writer.add_scalar(loss_name, loss_value, vdl_loss_step)
                vdl_loss_step += 1

        train_stats.update(stats)
        logs = train_stats.log()
        if it % cfg.log_iter == 0 and is_main_trainer:
            ips = float(cfg['TrainReader']['batch_size']) / time_cost
            strs = 'iter: {}, lr: {:.6f}, {}, batch_cost: {:.5f} s, eta: {}, ips: {:.5f} images/sec'.format(
                it, np.mean(outs[-1]), logs, time_cost, eta, ips)
            logger.info(strs)

        # NOTE : profiler tools, used for benchmark
        if FLAGS.is_profiler and it == 5:
            profiler.start_profiler("All")
        elif FLAGS.is_profiler and it == 10:
            profiler.stop_profiler("total", FLAGS.profiler_path)
            # Leaving early: reset the loader as the normal exit path does.
            train_loader.reset()
            return

        is_last_iter = it == cfg.max_iters - 1
        if (it > 0 and it % cfg.snapshot_iter == 0 or is_last_iter) \
                and is_main_trainer:
            save_name = "model_final" if is_last_iter else str(it)
            # Swap in the EMA weights for saving/evaluation; they are swapped
            # back out via ema.restore_program at the end of this branch.
            if 'use_ema' in cfg and cfg['use_ema']:
                exe.run(ema.apply_program)
            checkpoint.save(exe, train_prog, os.path.join(save_dir, save_name))

            if FLAGS.eval:
                # evaluation
                resolution = None
                if 'Mask' in cfg.architecture:
                    resolution = model.mask_head.resolution
                results = eval_run(
                    exe,
                    compiled_eval_prog,
                    eval_loader,
                    eval_keys,
                    eval_values,
                    eval_cls,
                    cfg,
                    resolution=resolution)
                box_ap_stats = eval_results(
                    results, cfg.metric, cfg.num_classes, resolution,
                    is_bbox_normalized, FLAGS.output_eval, map_type,
                    cfg['EvalReader']['dataset'])

                # use vdl_paddle to log mAP
                if FLAGS.use_vdl:
                    vdl_writer.add_scalar("mAP", box_ap_stats[0], vdl_mAP_step)
                    vdl_mAP_step += 1

                if box_ap_stats[0] > best_box_ap_list[0]:
                    best_box_ap_list[0] = box_ap_stats[0]
                    best_box_ap_list[1] = it
                    checkpoint.save(exe, train_prog,
                                    os.path.join(save_dir, "best_model"))
                logger.info("Best test box ap: {}, in iter: {}".format(
                    best_box_ap_list[0], best_box_ap_list[1]))

            if 'use_ema' in cfg and cfg['use_ema']:
                exe.run(ema.restore_program)

    train_loader.reset()


def str2bool(value):
    """Parse a command-line boolean such as ``True``/``False``/``1``/``0``.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"`` and ``"0"``) into ``True``, so boolean options need an explicit
    parser. Matching is case-insensitive and ignores surrounding whitespace.

    Args:
        value: the raw option string (a ``bool`` is returned unchanged).

    Returns:
        bool: the parsed value.

    Raises:
        ValueError: if ``value`` is not a recognised boolean spelling;
            argparse reports this as an invalid argument value.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError('invalid boolean value: {!r}'.format(value))


if __name__ == '__main__':
    parser = ArgsParser()
    parser.add_argument(
        "-r",
        "--resume_checkpoint",
        default=None,
        type=str,
        help="Checkpoint path for resuming training.")
    parser.add_argument(
        "--fp16",
        action='store_true',
        default=False,
        help="Enable mixed precision training.")
    parser.add_argument(
        "--loss_scale",
        default=8.,
        type=float,
        help="Mixed precision training loss scale.")
    parser.add_argument(
        "--eval",
        action='store_true',
        default=False,
        help="Whether to perform evaluation in train")
    parser.add_argument(
        "--output_eval",
        default=None,
        type=str,
        help="Evaluation directory, default is current directory.")
    # str2bool instead of type=bool: bool("False") would be True
    parser.add_argument(
        "--use_vdl",
        type=str2bool,
        default=False,
        help="whether to record the data to VisualDL.")
    parser.add_argument(
        '--vdl_log_dir',
        type=str,
        default="vdl_log_dir/scalar",
        help='VisualDL logging directory for scalar.')
    parser.add_argument(
        "--enable_ce",
        type=str2bool,
        default=False,
        help="If set True, enable continuous evaluation job. "
        "This flag is only used for internal test.")

    # NOTE: args for profiler tools, used for benchmark
    parser.add_argument(
        '--is_profiler',
        type=int,
        default=0,
        help='The switch of profiler tools. (used for benchmark)')
    parser.add_argument(
        '--profiler_path',
        type=str,
        default="./detection.profiler",
        help='The profiler output file path. (used for benchmark)')
    FLAGS = parser.parse_args()
    main()