trainer.py 19.0 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

L
LielinJiang 已提交
15
import os
L
LielinJiang 已提交
16
import sys
L
LielinJiang 已提交
17
import time
L
LielinJiang 已提交
18
import copy
L
LielinJiang 已提交
19

L
LielinJiang 已提交
20
import logging
L
LielinJiang 已提交
21
import datetime
L
LielinJiang 已提交
22

L
LielinJiang 已提交
23
import paddle
L
LielinJiang 已提交
24
from paddle.distributed import ParallelEnv
L
LielinJiang 已提交
25 26 27 28

from ..datasets.builder import build_dataloader
from ..models.builder import build_model
from ..utils.visual import tensor2img, save_image
L
LielinJiang 已提交
29
from ..utils.filesystem import makedirs, save, load
30
from ..utils.timer import TimeAverager
L
lzzyzlbb 已提交
31
from ..utils.profiler import add_profiler_step
L
fix nan  
LielinJiang 已提交
32

33

34
class IterLoader:
    """Endless iterator over a dataloader that restarts it at every epoch end.

    ``next()`` keeps yielding batches forever: when the wrapped dataloader is
    exhausted, a fresh iterator is created, the epoch counter is bumped and
    the first batch of the new pass is returned. If the dataloader yields no
    batches at all, ``StopIteration`` propagates to the caller.

    Args:
        dataloader: a re-iterable object supporting ``len()``.
    """

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 1
        # number of batches handed out since the current pass started
        self._inner_iter = 0

    @property
    def epoch(self):
        """1-based index of the pass over the dataloader currently in use."""
        return self._epoch

    def __iter__(self):
        return self

    def __next__(self):
        try:
            # On Windows, start a new pass one batch early instead of reading
            # the last batch of the epoch (NOTE(review): presumably a
            # workaround for a Windows dataloader issue at epoch end -
            # confirm). ``sys.platform`` is "win32" on Windows; a comparison
            # with "Windows" never matches.
            if (sys.platform == "win32"
                    and self._inner_iter == len(self._dataloader) - 1):
                raise StopIteration
            data = next(self.iter_loader)
        except StopIteration:
            self._epoch += 1
            self._inner_iter = 0
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        self._inner_iter += 1
        return data

    def __len__(self):
        return len(self._dataloader)


class Trainer:
    r"""
    Orchestrates training, evaluation, logging and checkpointing of a model
    built from a config.

    ``__init__`` builds the model, the train dataloader, the lr schedulers and
    the optimizers. When AMP is enabled it also sets that up, and when more
    than one rank is used it wraps every net in ``paddle.DataParallel``.
    ``train`` runs the iteration loop, and ``test`` evaluates on the test
    dataset. ``save``, ``resume`` and ``load`` handle checkpoints and
    pretrained weights.

    # trainer calling logic:
    #
    #                build_model                               ||    model(BaseModel)
    #                     |                                    ||
    #               build_dataloader                           ||    dataloader
    #                     |                                    ||
    #               model.setup_lr_schedulers                  ||    lr_scheduler
    #                     |                                    ||
    #               model.setup_optimizers                     ||    optimizers
    #                     |                                    ||
    #     train loop (model.setup_input + model.train_iter)    ||    train loop
    #                     |                                    ||
    #         print log (model.get_current_losses)             ||
    #                     |                                    ||
    #         save checkpoint (model.nets)                     \/
    """
83

84
    def __init__(self, cfg):
        """Build everything needed for training/evaluation from ``cfg``.

        Build order: model, metrics, VisualDL writer (optional), train
        dataloader, lr schedulers (whose config receives the dataloader
        length as ``iters_per_epoch`` when it declares that key), optimizers,
        AMP setup (when ``cfg.amp``), then DataParallel wrapping (when more
        than one rank). If ``cfg.is_train`` is false the constructor returns
        at that point. The training-only attributes (``epochs``,
        ``total_iters``, ``by_epoch``, ``validate_interval``, ...) are then
        left unset.

        Args:
            cfg: config object, accessed both as ``cfg.key`` and
                ``cfg.get(key)``.
        """
        # base config
        self.logger = logging.getLogger(__name__)
        self.cfg = cfg
        self.output_dir = cfg.output_dir
        self.max_eval_steps = cfg.model.get('max_eval_steps', None)

        # distributed context
        parallel_env = ParallelEnv()
        self.local_rank = parallel_env.local_rank
        self.world_size = parallel_env.nranks

        # logging / snapshot intervals; note the key read here is spelled
        # 'visiual_interval'
        log_config = cfg.log_config
        self.log_interval = log_config.interval
        self.visual_interval = log_config.visiual_interval
        self.weight_interval = cfg.snapshot_config.interval

        # progress counters
        self.start_epoch = self.current_epoch = 1
        self.current_iter = self.inner_iter = 1
        self.batch_id = self.global_steps = 0

        # build model
        self.model = build_model(cfg.model)

        # metrics and image saving come from the optional 'validate' section
        self.metrics = None
        self.is_save_img = True
        validate_cfg = cfg.get('validate', None)
        if validate_cfg:
            if 'metrics' in validate_cfg:
                self.metrics = self.model.setup_metrics(
                    validate_cfg['metrics'])
            if 'save_img' in validate_cfg:
                self.is_save_img = validate_cfg['save_img']

        self.enable_visualdl = cfg.get('enable_visualdl', False)
        if self.enable_visualdl:
            import visualdl
            self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir)

        # build train dataloader
        self.train_dataloader = build_dataloader(cfg.dataset.train)
        self.iters_per_epoch = len(self.train_dataloader)

        # build lr scheduler; a scheduler config declaring 'iters_per_epoch'
        # gets the real dataloader length
        # TODO: has a better way?
        if 'lr_scheduler' in cfg and 'iters_per_epoch' in cfg.lr_scheduler:
            cfg.lr_scheduler.iters_per_epoch = self.iters_per_epoch
        self.lr_schedulers = self.model.setup_lr_schedulers(cfg.lr_scheduler)

        # build optimizers
        self.optimizers = self.model.setup_optimizers(self.lr_schedulers,
                                                      cfg.optimizer)

        # setup amp train
        self.scalers = None
        if cfg.amp:
            self.scalers = self.setup_amp_train()

        # multiple gpus prepare
        if self.world_size > 1:
            self.distributed_data_parallel()

        # evaluate only: the rest is training configuration
        if not cfg.is_train:
            return

        self.epochs = cfg.get('epochs', None)
        self.by_epoch = bool(self.epochs)
        if self.by_epoch:
            self.total_iters = self.epochs * self.iters_per_epoch
            # snapshot interval is given in epochs here; convert to iterations
            self.weight_interval *= self.iters_per_epoch
        else:
            self.total_iters = cfg.total_iters

        # -1 (the default) disables periodic validation in train()
        self.validate_interval = (-1 if validate_cfg is None else
                                  validate_cfg.get('interval', -1))

        self.time_count = {}
        self.best_metric = {}
        self.model.set_total_iter(self.total_iters)
        self.profiler_options = cfg.profiler_options
L
LielinJiang 已提交
165

B
Birdylx 已提交
166
    def setup_amp_train(self):
        """Prepare automatic mixed precision (AMP) training.

        For ``cfg.amp_level == 'O2'`` every net in ``self.model.nets`` and
        every optimizer in ``self.optimizers`` is decorated with
        ``paddle.amp.decorate``. The decorated objects are written back into
        the same dicts under the same keys. For any level, one ``GradScaler``
        is created per optimizer.

        Returns:
            list of ``paddle.amp.GradScaler``, one per entry of
            ``self.optimizers``.
        """
        amp_level = self.cfg.amp_level
        self.logger.info('use AMP to train. AMP level = {}'.format(amp_level))

        # need to decorate model and optim if amp_level == 'O2'
        if amp_level == 'O2':
            net_names = list(self.model.nets.keys())
            opt_names = list(self.optimizers.keys())
            decorated_nets, decorated_opts = paddle.amp.decorate(
                models=[self.model.nets[n] for n in net_names],
                optimizers=[self.optimizers[n] for n in opt_names],
                level='O2',
                save_dtype='float32')
            for net_name, decorated in zip(net_names, decorated_nets):
                self.model.nets[net_name] = decorated
            for opt_name, decorated in zip(opt_names, decorated_opts):
                self.optimizers[opt_name] = decorated

        return [
            paddle.amp.GradScaler(init_loss_scaling=1024)
            for _ in self.optimizers
        ]
B
Birdylx 已提交
190

191
    def distributed_data_parallel(self):
        """Initialise the parallel environment and wrap every net for DDP.

        Each entry of ``self.model.nets`` is replaced in place by a
        ``paddle.DataParallel`` wrapper. ``find_unused_parameters`` comes from
        the config and defaults to ``False``.
        """
        paddle.distributed.init_parallel_env()
        unused_params = self.cfg.get('find_unused_parameters', False)
        nets = self.model.nets
        for net_name in list(nets):
            nets[net_name] = paddle.DataParallel(
                nets[net_name], find_unused_parameters=unused_params)
197

L
LielinJiang 已提交
198 199 200 201 202 203 204 205 206 207 208
    def learning_rate_scheduler_step(self):
        if isinstance(self.model.lr_scheduler, dict):
            for lr_scheduler in self.model.lr_scheduler.values():
                lr_scheduler.step()
        elif isinstance(self.model.lr_scheduler,
                        paddle.optimizer.lr.LRScheduler):
            self.model.lr_scheduler.step()
        else:
            raise ValueError(
                'lr schedulter must be a dict or an instance of LRScheduler')

L
LielinJiang 已提交
209
    def train(self):
        """Run the training loop until ``self.total_iters`` iterations are done.

        Each iteration fetches one batch and runs one optimisation step (AMP
        when ``cfg.amp`` is set, otherwise normal). The train dataloader is
        restarted at epoch boundaries by ``IterLoader``. At the configured
        intervals the loop then:

        * logs progress;
        * writes training visuals (rank 0 only);
        * steps the lr scheduler(s);
        * validates;
        * saves weights and checkpoints.

        ``self.global_steps`` follows ``self.current_iter`` so that VisualDL
        output is logged at the actual training step.
        """
        reader_cost_averager = TimeAverager()
        batch_cost_averager = TimeAverager()
        # samples per step, used for the images/s (ips) statistic
        batch_size = self.cfg['dataset']['train'].get('batch_size', 1)

        iter_loader = IterLoader(self.train_dataloader)

        # set model.is_train = True
        self.model.setup_train_mode(is_train=True)
        while self.current_iter < (self.total_iters + 1):
            self.current_epoch = iter_loader.epoch
            self.inner_iter = self.current_iter % max(self.iters_per_epoch, 1)
            # step used by print_log/visual for VisualDL records
            self.global_steps = self.current_iter

            add_profiler_step(self.profiler_options)

            step_start_time = time.time()
            data = next(iter_loader)
            reader_cost_averager.record(time.time() - step_start_time)
            # unpack data from dataset and apply preprocessing
            # data input should be dict
            self.model.setup_input(data)

            if self.cfg.amp:
                self.model.train_iter_amp(self.optimizers, self.scalers,
                                          self.cfg.amp_level)  # amp train
            else:
                self.model.train_iter(self.optimizers)  # norm train

            # batch cost covers both data loading and the training step
            batch_cost_averager.record(time.time() - step_start_time,
                                       num_samples=batch_size)

            if self.current_iter % self.log_interval == 0:
                self.data_time = reader_cost_averager.get_average()
                self.step_time = batch_cost_averager.get_average()
                self.ips = batch_cost_averager.get_ips_average()
                self.print_log()

                reader_cost_averager.reset()
                batch_cost_averager.reset()

            if self.current_iter % self.visual_interval == 0 and self.local_rank == 0:
                self.visual('visual_train')

            self.learning_rate_scheduler_step()

            # a non-positive interval disables validation (0 would otherwise
            # raise ZeroDivisionError in the modulo)
            if self.validate_interval > 0 and self.current_iter % self.validate_interval == 0:
                self.test()
                # test() switches the model to is_train=False; switch back
                self.model.setup_train_mode(is_train=True)

            if self.current_iter % self.weight_interval == 0:
                self.save(self.current_iter, 'weight', keep=-1)
                self.save(self.current_iter)

            self.current_iter += 1
L
LielinJiang 已提交
264

L
LielinJiang 已提交
265 266
    def test(self):
        """Evaluate the model on ``cfg.dataset.test``.

        The test dataloader is built on first use and cached. The method
        runs ``self.max_eval_steps`` iterations, which defaults to one full
        pass over the test dataloader. Each iteration updates
        ``self.metrics`` through ``model.test_iter`` and, when
        ``self.is_save_img`` is set, saves the result images of every sample.
        Afterwards each metric's accumulated value is logged. The model is
        left with ``setup_train_mode(is_train=False)``.
        """
        if not hasattr(self, 'test_dataloader'):
            self.test_dataloader = build_dataloader(self.cfg.dataset.test,
                                                    is_train=False)
        iter_loader = IterLoader(self.test_dataloader)
        if self.max_eval_steps is None:
            self.max_eval_steps = len(self.test_dataloader)

        if self.metrics:
            for metric in self.metrics.values():
                metric.reset()

        # set model.is_train = False
        self.model.setup_train_mode(is_train=False)

        for i in range(self.max_eval_steps):
            if self.max_eval_steps < self.log_interval or i % self.log_interval == 0:
                self.logger.info('Test iter: [%d/%d]' %
                                 (i * self.world_size,
                                  self.max_eval_steps * self.world_size))

            data = next(iter_loader)
            self.model.setup_input(data)
            self.model.test_iter(metrics=self.metrics)

            if self.is_save_img:
                visual_results = {}
                current_paths = self.model.get_image_paths()
                current_visuals = self.model.get_current_visuals()

                # When the first visual is 4-D, its first dimension is the
                # number of samples to save (the rank must be tested with
                # len(shape); comparing the shape object itself with 4 is
                # never true).
                visuals = list(current_visuals.values())
                if visuals and len(visuals[0].shape) == 4:
                    num_samples = visuals[0].shape[0]
                else:
                    num_samples = 1

                for j in range(num_samples):
                    if j < len(current_paths):
                        short_path = os.path.basename(current_paths[j])
                        basename = os.path.splitext(short_path)[0]
                    else:
                        basename = '{:04d}_{:04d}'.format(i, j)
                    for k, img_tensor in current_visuals.items():
                        name = '%s_%s' % (basename, k)
                        if len(img_tensor.shape) == 4:
                            visual_results[name] = img_tensor[j]
                        else:
                            visual_results[name] = img_tensor

                self.visual('visual_test',
                            visual_results=visual_results,
                            step=self.batch_id,
                            is_save_image=True)

        if self.metrics:
            for metric_name, metric in self.metrics.items():
                self.logger.info("Metric {}: {:.4f}".format(
                    metric_name, metric.accumulate()))

L
LielinJiang 已提交
324 325
    def print_log(self):
        losses = self.model.get_current_losses()
L
LielinJiang 已提交
326

327 328 329 330 331 332 333 334 335
        message = ''
        if self.by_epoch:
            message += 'Epoch: %d/%d, iter: %d/%d ' % (
                self.current_epoch, self.epochs, self.inner_iter,
                self.iters_per_epoch)
        else:
            message += 'Iter: %d/%d ' % (self.current_iter, self.total_iters)

        message += f'lr: {self.current_learning_rate:.3e} '
L
LielinJiang 已提交
336 337 338

        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)
郑启航 已提交
339 340
            if self.enable_visualdl:
                self.vdl_logger.add_scalar(k, v, step=self.global_steps)
L
LielinJiang 已提交
341

342 343 344
        if hasattr(self, 'step_time'):
            message += 'batch_cost: %.5f sec ' % self.step_time

345
        if hasattr(self, 'data_time'):
346
            message += 'reader_cost: %.5f sec ' % self.data_time
347

348
        if hasattr(self, 'ips'):
L
LielinJiang 已提交
349 350 351
            message += 'ips: %.5f images/s ' % self.ips

        if hasattr(self, 'step_time'):
L
LielinJiang 已提交
352 353 354
            eta = self.step_time * (self.total_iters - self.current_iter)
            eta = eta if eta > 0 else 0

L
LielinJiang 已提交
355 356
            eta_str = str(datetime.timedelta(seconds=int(eta)))
            message += f'eta: {eta_str}'
357

L
LielinJiang 已提交
358 359 360 361 362
        # print the message
        self.logger.info(message)

    @property
    def current_learning_rate(self):
L
LielinJiang 已提交
363 364
        for optimizer in self.model.optimizers.values():
            return optimizer.get_lr()
L
LielinJiang 已提交
365

郑启航 已提交
366 367 368 369 370 371 372 373 374 375 376 377 378
    def visual(self,
               results_dir,
               visual_results=None,
               step=None,
               is_save_image=False):
        """
        visual the images, use visualdl or directly write to the directory
        Parameters:
            results_dir (str)     --  directory name which contains saved images
            visual_results (dict) --  the results images dict
            step (int)            --  global steps, used in visualdl
            is_save_image (bool)  --  weather write to the directory or visualdl
        """
L
LielinJiang 已提交
379 380 381 382 383
        self.model.compute_visuals()

        if visual_results is None:
            visual_results = self.model.get_current_visuals()

L
LielinJiang 已提交
384 385 386
        min_max = self.cfg.get('min_max', None)
        if min_max is None:
            min_max = (-1., 1.)
387

郑启航 已提交
388 389 390
        image_num = self.cfg.get('image_num', None)
        if (image_num is None) or (not self.enable_visualdl):
            image_num = 1
L
LielinJiang 已提交
391
        for label, image in visual_results.items():
郑启航 已提交
392 393 394 395 396 397 398 399 400
            image_numpy = tensor2img(image, min_max, image_num)
            if (not is_save_image) and self.enable_visualdl:
                self.vdl_logger.add_image(
                    results_dir + '/' + label,
                    image_numpy,
                    step=step if step else self.global_steps,
                    dataformats="HWC" if image_num == 1 else "NCHW")
            else:
                if self.cfg.is_train:
W
wangna11BD 已提交
401 402 403 404
                    if self.by_epoch:
                        msg = 'epoch%.3d_' % self.current_epoch
                    else:
                        msg = 'iter%.3d_' % self.current_iter
郑启航 已提交
405 406 407 408 409 410
                else:
                    msg = ''
                makedirs(os.path.join(self.output_dir, results_dir))
                img_path = os.path.join(self.output_dir, results_dir,
                                        msg + '%s.png' % (label))
                save_image(image_numpy, img_path)
L
LielinJiang 已提交
411 412 413 414

    def save(self, epoch, name='checkpoint', keep=1):
        if self.local_rank != 0:
            return
L
LielinJiang 已提交
415

L
LielinJiang 已提交
416 417 418
        assert name in ['checkpoint', 'weight']

        state_dicts = {}
L
LielinJiang 已提交
419 420 421 422 423 424
        if self.by_epoch:
            save_filename = 'epoch_%s_%s.pdparams' % (
                epoch // self.iters_per_epoch, name)
        else:
            save_filename = 'iter_%s_%s.pdparams' % (epoch, name)

L
lijianshe02 已提交
425
        os.makedirs(self.output_dir, exist_ok=True)
L
LielinJiang 已提交
426
        save_path = os.path.join(self.output_dir, save_filename)
L
LielinJiang 已提交
427 428
        for net_name, net in self.model.nets.items():
            state_dicts[net_name] = net.state_dict()
L
LielinJiang 已提交
429 430 431 432 433 434 435

        if name == 'weight':
            save(state_dicts, save_path)
            return

        state_dicts['epoch'] = epoch

L
LielinJiang 已提交
436 437
        for opt_name, opt in self.model.optimizers.items():
            state_dicts[opt_name] = opt.state_dict()
L
LielinJiang 已提交
438 439 440 441 442

        save(state_dicts, save_path)

        if keep > 0:
            try:
L
LielinJiang 已提交
443 444 445 446 447 448 449 450 451 452
                if self.by_epoch:
                    checkpoint_name_to_be_removed = os.path.join(
                        self.output_dir, 'epoch_%s_%s.pdparams' %
                        ((epoch - keep * self.weight_interval) //
                         self.iters_per_epoch, name))
                else:
                    checkpoint_name_to_be_removed = os.path.join(
                        self.output_dir, 'iter_%s_%s.pdparams' %
                        (epoch - keep * self.weight_interval, name))

L
LielinJiang 已提交
453 454 455 456 457 458 459 460 461 462
                if os.path.exists(checkpoint_name_to_be_removed):
                    os.remove(checkpoint_name_to_be_removed)

            except Exception as e:
                self.logger.info('remove old checkpoints error: {}'.format(e))

    def resume(self, checkpoint_path):
        """Restore nets, optimizers and progress counters from a checkpoint.

        The checkpoint is expected to hold one state dict per net and per
        optimizer, plus an optional ``'epoch'`` entry. ``train`` calls
        ``save`` with the current iteration number, so ``'epoch'`` holds the
        last saved iteration. Training resumes at the iteration after it.

        Parameters:
            checkpoint_path (str) -- file passed to the module-level ``load``.

        A net or optimizer without an entry in the checkpoint causes the
        lookup to fail (``KeyError`` for a plain dict).
        """
        state_dicts = load(checkpoint_path)
        saved_step = state_dicts.get('epoch', None)
        if saved_step is not None:
            self.start_epoch = saved_step + 1
            # 'epoch' already counts iterations, so it is the global step
            # itself; scaling it by iters_per_epoch overshoots.
            self.global_steps = saved_step
            self.current_iter = saved_step + 1

        for net_name, net in self.model.nets.items():
            net.set_state_dict(state_dicts[net_name])

        for opt_name, opt in self.model.optimizers.items():
            opt.set_state_dict(state_dicts[opt_name])
L
LielinJiang 已提交
472 473 474

    def load(self, weight_path):
        """Load pretrained weights into the model's nets.

        Two layouts of the loaded object are accepted:

        * a dict whose first value is itself a dict. It is treated as a
          mapping from net name to state dict. Every net found in it is
          loaded, and missing nets are skipped with a warning.
        * anything else, which is loaded as the state dict of the single net.
          This is only allowed when the model has exactly one net.

        Parameters:
            weight_path (str) -- file passed to the module-level ``load``.
        """
        state_dicts = load(weight_path)

        def is_dict_in_dict_weight(state_dict):
            # only the first value is inspected
            if not (isinstance(state_dict, dict) and len(state_dict) > 0):
                return False
            return isinstance(next(iter(state_dict.values())), dict)

        if is_dict_in_dict_weight(state_dicts):
            for net_name, net in self.model.nets.items():
                if net_name in state_dicts:
                    net.set_state_dict(state_dicts[net_name])
                    self.logger.info(
                        'Loaded pretrained weight for net {}'.format(net_name))
                else:
                    self.logger.warning(
                        'Can not find state dict of net {}. Skip load pretrained weight for net {}'
                        .format(net_name, net_name))
        else:
            # the message is built from adjacent literals; a backslash
            # continuation inside the literal would embed the indentation
            assert len(self.model.nets) == 1, (
                'checkpoint only contains weight of one net, '
                'but model contains {} nets!'.format(len(self.model.nets)))

            net_name, net = list(self.model.nets.items())[0]
            net.set_state_dict(state_dicts)
            self.logger.info(
                'Loaded pretrained weight for net {}'.format(net_name))
郑启航 已提交
505 506 507 508 509 510

    def close(self):
        """
        when finish the training need close file handler or other.
        """
        if self.enable_visualdl:
W
wangna11BD 已提交
511
            self.vdl_logger.close()