#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import time
import copy

import logging
import datetime

import paddle
from paddle.distributed import ParallelEnv

from ..datasets.builder import build_dataloader
from ..models.builder import build_model
from ..utils.visual import tensor2img, save_image
from ..utils.filesystem import makedirs, save, load
from ..utils.timer import TimeAverager
from ..utils.profiler import add_profiler_step


class IterLoader:
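    """Endless iterator over a dataloader.

    Restarts the underlying iterator and increments `epoch` whenever it
    is exhausted, so `next()` can be called across epoch boundaries.
    """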

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 1
        self._inner_iter = 0

    @property
    def epoch(self):
        return self._epoch

    def __next__(self):
        try:
            # On Windows (sys.platform is "win32"), roll the iterator over
            # one batch before the epoch ends.
            if sys.platform == "win32" and self._inner_iter == len(
                    self._dataloader) - 1:
                self._inner_iter = 0
                raise StopIteration
            data = next(self.iter_loader)
        except StopIteration:
            self._epoch += 1
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        self._inner_iter += 1
        return data

    def __len__(self):
        return len(self._dataloader)


class Trainer:
    """
    # trainer calling logic:
    #
    #                build_model                               ||    model(BaseModel)
    #                     |                                    ||
    #               build_dataloader                           ||    dataloader
    #                     |                                    ||
    #               model.setup_lr_schedulers                  ||    lr_scheduler
    #                     |                                    ||
    #               model.setup_optimizers                     ||    optimizers
    #                     |                                    ||
    #     train loop (model.setup_input + model.train_iter)    ||    train loop
    #                     |                                    ||
    #         print log (model.get_current_losses)             ||
    #                     |                                    ||
    #         save checkpoint (model.nets)                     \/
    """

    def __init__(self, cfg):
        # base config
        self.logger = logging.getLogger(__name__)
        self.cfg = cfg
        self.output_dir = cfg.output_dir
        self.max_eval_steps = cfg.model.get('max_eval_steps', None)

        self.local_rank = ParallelEnv().local_rank
        self.world_size = ParallelEnv().nranks
        self.log_interval = cfg.log_config.interval
        self.visual_interval = cfg.log_config.visiual_interval  # 'visiual' matches the config key spelling
        self.weight_interval = cfg.snapshot_config.interval

        self.start_epoch = 1
        self.current_epoch = 1
        self.current_iter = 1
        self.inner_iter = 1
        self.batch_id = 0
        self.global_steps = 0

        # build model
        self.model = build_model(cfg.model)
        # prepare for multi-GPU training
        if ParallelEnv().nranks > 1:
            self.distributed_data_parallel()

        # build metrics
        self.metrics = None
        self.is_save_img = True
        validate_cfg = cfg.get('validate', None)
        if validate_cfg and 'metrics' in validate_cfg:
            self.metrics = self.model.setup_metrics(validate_cfg['metrics'])
        if validate_cfg and 'save_img' in validate_cfg:
            self.is_save_img = validate_cfg['save_img']

        self.enable_visualdl = cfg.get('enable_visualdl', False)
        if self.enable_visualdl:
            import visualdl
            self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir)

        # evaluate only
        if not cfg.is_train:
            return

        # build train dataloader
        self.train_dataloader = build_dataloader(cfg.dataset.train)
        self.iters_per_epoch = len(self.train_dataloader)

        # build lr scheduler
        # TODO: is there a better way?
        if 'lr_scheduler' in cfg and 'iters_per_epoch' in cfg.lr_scheduler:
            cfg.lr_scheduler.iters_per_epoch = self.iters_per_epoch
        self.lr_schedulers = self.model.setup_lr_schedulers(cfg.lr_scheduler)

        # build optimizers
        self.optimizers = self.model.setup_optimizers(self.lr_schedulers,
                                                      cfg.optimizer)

        self.epochs = cfg.get('epochs', None)
        if self.epochs:
            self.total_iters = self.epochs * self.iters_per_epoch
            self.by_epoch = True
        else:
            self.by_epoch = False
            self.total_iters = cfg.total_iters

        if self.by_epoch:
            self.weight_interval *= self.iters_per_epoch

        self.validate_interval = -1
        if cfg.get('validate', None) is not None:
            self.validate_interval = cfg.validate.get('interval', -1)

        self.time_count = {}
        self.best_metric = {}
        self.model.set_total_iter(self.total_iters)
        self.profiler_options = cfg.profiler_options

    def distributed_data_parallel(self):
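        """Initialize the parallel env and wrap each net in paddle.DataParallel."""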
        paddle.distributed.init_parallel_env()
        find_unused_parameters = self.cfg.get('find_unused_parameters', False)
        for net_name, net in self.model.nets.items():
            self.model.nets[net_name] = paddle.DataParallel(
                net, find_unused_parameters=find_unused_parameters)

    def learning_rate_scheduler_step(self):
        if isinstance(self.model.lr_scheduler, dict):
            for lr_scheduler in self.model.lr_scheduler.values():
                lr_scheduler.step()
        elif isinstance(self.model.lr_scheduler,
                        paddle.optimizer.lr.LRScheduler):
            self.model.lr_scheduler.step()
        else:
            raise ValueError(
                'lr scheduler must be a dict or an instance of LRScheduler')

    def train(self):
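        """Run the main optimization loop: fetch a batch, run one train
        iteration (optionally under AMP), step the lr schedulers, and
        periodically log, visualize, validate, and save checkpoints."""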
        reader_cost_averager = TimeAverager()
        batch_cost_averager = TimeAverager()

        iter_loader = IterLoader(self.train_dataloader)

        # use amp
        if self.cfg.amp:
            self.logger.info('use AMP to train. AMP level = {}'.format(
                self.cfg.amp_level))
            assert self.cfg.model.name == 'MultiStageVSRModel', "AMP only supports the msvsr model"
            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
            # need to decorate model and optim if amp_level == 'O2'
            if self.cfg.amp_level == 'O2':
                # msvsr has only one generator and one optimizer
                self.model.nets['generator'], self.optimizers[
                    'optim'] = paddle.amp.decorate(
                        models=self.model.nets['generator'],
                        optimizers=self.optimizers['optim'],
                        level='O2',
                        save_dtype='float32')

        # set model.is_train = True
        self.model.setup_train_mode(is_train=True)
        while self.current_iter < (self.total_iters + 1):
            self.current_epoch = iter_loader.epoch
            self.inner_iter = self.current_iter % self.iters_per_epoch

            add_profiler_step(self.profiler_options)

            start_time = step_start_time = time.time()
            data = next(iter_loader)
            reader_cost_averager.record(time.time() - step_start_time)
            # unpack data from dataset and apply preprocessing
            # data input should be dict
            self.model.setup_input(data)

            if self.cfg.amp:
                self.model.train_iter_amp(self.optimizers, scaler,
                                          self.cfg.amp_level)  # amp train
            else:
                self.model.train_iter(self.optimizers)  # normal train

            batch_cost_averager.record(
                time.time() - step_start_time,
                num_samples=self.cfg['dataset']['train'].get('batch_size', 1))

            step_start_time = time.time()

            if self.current_iter % self.log_interval == 0:
                self.data_time = reader_cost_averager.get_average()
                self.step_time = batch_cost_averager.get_average()
                self.ips = batch_cost_averager.get_ips_average()
                self.print_log()

                reader_cost_averager.reset()
                batch_cost_averager.reset()

            if self.current_iter % self.visual_interval == 0 and self.local_rank == 0:
                self.visual('visual_train')

            self.learning_rate_scheduler_step()

            if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0:
                self.test()

            if self.current_iter % self.weight_interval == 0:
                self.save(self.current_iter, 'weight', keep=-1)
                self.save(self.current_iter)

            self.current_iter += 1

    def test(self):
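        """Evaluate on the test dataloader, accumulating metrics and
        optionally saving visual results."""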
        if not hasattr(self, 'test_dataloader'):
            self.test_dataloader = build_dataloader(self.cfg.dataset.test,
                                                    is_train=False)
        iter_loader = IterLoader(self.test_dataloader)
        if self.max_eval_steps is None:
            self.max_eval_steps = len(self.test_dataloader)

        if self.metrics:
            for metric in self.metrics.values():
                metric.reset()

        # set model.is_train = False
        self.model.setup_train_mode(is_train=False)

        for i in range(self.max_eval_steps):
            if self.max_eval_steps < self.log_interval or i % self.log_interval == 0:
                self.logger.info('Test iter: [%d/%d]' %
                                 (i * self.world_size,
                                  self.max_eval_steps * self.world_size))

            data = next(iter_loader)
            self.model.setup_input(data)
            self.model.test_iter(metrics=self.metrics)

            if self.is_save_img:
                visual_results = {}
                current_paths = self.model.get_image_paths()
                current_visuals = self.model.get_current_visuals()

                if len(current_visuals) > 0 and len(
                        list(current_visuals.values())[0].shape) == 4:
                    num_samples = list(current_visuals.values())[0].shape[0]
                else:
                    num_samples = 1

                for j in range(num_samples):
                    if j < len(current_paths):
                        short_path = os.path.basename(current_paths[j])
                        basename = os.path.splitext(short_path)[0]
                    else:
                        basename = '{:04d}_{:04d}'.format(i, j)
                    for k, img_tensor in current_visuals.items():
                        name = '%s_%s' % (basename, k)
                        if len(img_tensor.shape) == 4:
                            visual_results.update({name: img_tensor[j]})
                        else:
                            visual_results.update({name: img_tensor})

                self.visual('visual_test',
                            visual_results=visual_results,
                            step=self.batch_id,
                            is_save_image=True)

        if self.metrics:
            for metric_name, metric in self.metrics.items():
                self.logger.info("Metric {}: {:.4f}".format(
                    metric_name, metric.accumulate()))

    def print_log(self):
        losses = self.model.get_current_losses()

        message = ''
        if self.by_epoch:
            message += 'Epoch: %d/%d, iter: %d/%d ' % (
                self.current_epoch, self.epochs, self.inner_iter,
                self.iters_per_epoch)
        else:
            message += 'Iter: %d/%d ' % (self.current_iter, self.total_iters)

        message += f'lr: {self.current_learning_rate:.3e} '

        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)
            if self.enable_visualdl:
                self.vdl_logger.add_scalar(k, v, step=self.global_steps)

        if hasattr(self, 'step_time'):
            message += 'batch_cost: %.5f sec ' % self.step_time

        if hasattr(self, 'data_time'):
            message += 'reader_cost: %.5f sec ' % self.data_time

        if hasattr(self, 'ips'):
            message += 'ips: %.5f images/s ' % self.ips

        if hasattr(self, 'step_time'):
            eta = self.step_time * (self.total_iters - self.current_iter)
            eta = eta if eta > 0 else 0

            eta_str = str(datetime.timedelta(seconds=int(eta)))
            message += f'eta: {eta_str}'

        # print the message
        self.logger.info(message)

    @property
    def current_learning_rate(self):
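        # All optimizers are assumed to follow the same schedule, so the
        # lr of the first one is reported.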
        for optimizer in self.model.optimizers.values():
            return optimizer.get_lr()

    def visual(self,
               results_dir,
               visual_results=None,
               step=None,
               is_save_image=False):
        """
        Visualize images with visualdl or write them directly to a directory.

        Parameters:
            results_dir (str)     --  directory name in which images are saved
            visual_results (dict) --  dict of result images
            step (int)            --  global step, used by visualdl
            is_save_image (bool)  --  whether to write to the directory instead of visualdl
        """
        self.model.compute_visuals()

        if visual_results is None:
            visual_results = self.model.get_current_visuals()

        min_max = self.cfg.get('min_max', None)
        if min_max is None:
            min_max = (-1., 1.)

        image_num = self.cfg.get('image_num', None)
        if (image_num is None) or (not self.enable_visualdl):
            image_num = 1
        for label, image in visual_results.items():
            image_numpy = tensor2img(image, min_max, image_num)
            if (not is_save_image) and self.enable_visualdl:
                self.vdl_logger.add_image(
                    results_dir + '/' + label,
                    image_numpy,
                    step=step if step else self.global_steps,
                    dataformats="HWC" if image_num == 1 else "NCHW")
            else:
                if self.cfg.is_train:
                    if self.by_epoch:
                        msg = 'epoch%.3d_' % self.current_epoch
                    else:
                        msg = 'iter%.3d_' % self.current_iter
                else:
                    msg = ''
                makedirs(os.path.join(self.output_dir, results_dir))
                img_path = os.path.join(self.output_dir, results_dir,
                                        msg + '%s.png' % (label))
                save_image(image_numpy, img_path)

    def save(self, epoch, name='checkpoint', keep=1):
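        """Persist training state to self.output_dir.

        `epoch` is the global iteration count (train() passes
        self.current_iter). With name='weight' only network weights are
        saved; with name='checkpoint' optimizer state is saved as well so
        training can be resumed. keep > 0 prunes old snapshots; keep=-1
        keeps everything.
        """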
        if self.local_rank != 0:
            return

        assert name in ['checkpoint', 'weight']

        state_dicts = {}
        if self.by_epoch:
            save_filename = 'epoch_%s_%s.pdparams' % (
                epoch // self.iters_per_epoch, name)
        else:
            save_filename = 'iter_%s_%s.pdparams' % (epoch, name)

        os.makedirs(self.output_dir, exist_ok=True)
        save_path = os.path.join(self.output_dir, save_filename)
        for net_name, net in self.model.nets.items():
            state_dicts[net_name] = net.state_dict()

        if name == 'weight':
            save(state_dicts, save_path)
            return

        state_dicts['epoch'] = epoch

        for opt_name, opt in self.model.optimizers.items():
            state_dicts[opt_name] = opt.state_dict()

        save(state_dicts, save_path)

        if keep > 0:
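            # prune the snapshot saved `keep` save intervals ago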
            try:
                if self.by_epoch:
                    checkpoint_name_to_be_removed = os.path.join(
                        self.output_dir, 'epoch_%s_%s.pdparams' %
                        ((epoch - keep * self.weight_interval) //
                         self.iters_per_epoch, name))
                else:
                    checkpoint_name_to_be_removed = os.path.join(
                        self.output_dir, 'iter_%s_%s.pdparams' %
                        (epoch - keep * self.weight_interval, name))

                if os.path.exists(checkpoint_name_to_be_removed):
                    os.remove(checkpoint_name_to_be_removed)

            except Exception as e:
                self.logger.info('remove old checkpoints error: {}'.format(e))

    def resume(self, checkpoint_path):
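        """Restore nets, optimizers, and training progress from a checkpoint."""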
        state_dicts = load(checkpoint_path)
        if state_dicts.get('epoch', None) is not None:
            self.start_epoch = state_dicts['epoch'] + 1
            self.global_steps = self.iters_per_epoch * state_dicts['epoch']

            self.current_iter = state_dicts['epoch'] + 1

        for net_name, net in self.model.nets.items():
            net.set_state_dict(state_dicts[net_name])

        for opt_name, opt in self.model.optimizers.items():
            opt.set_state_dict(state_dicts[opt_name])

    def load(self, weight_path):
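        """Load pretrained network weights only (no optimizer state)."""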
        state_dicts = load(weight_path)

        def is_dict_in_dict_weight(state_dict):
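            """Return True if the checkpoint maps net names to per-net
            state dicts rather than being a single flat state dict."""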
            if isinstance(state_dict, dict) and len(state_dict) > 0:
                val = list(state_dict.values())[0]
                if isinstance(val, dict):
                    return True
                else:
                    return False
            else:
                return False

        if is_dict_in_dict_weight(state_dicts):
            for net_name, net in self.model.nets.items():
                if net_name in state_dicts:
                    net.set_state_dict(state_dicts[net_name])
                    self.logger.info(
                        'Loaded pretrained weight for net {}'.format(net_name))
                else:
                    self.logger.warning(
                        'Cannot find state dict of net {}. Skip loading pretrained weight for net {}'
                        .format(net_name, net_name))
        else:
            assert len(self.model.nets) == 1, (
                'checkpoint only contains the weight of one net, '
                'but model contains more than one net!')

            net_name, net = list(self.model.nets.items())[0]
            net.set_state_dict(state_dicts)
            self.logger.info(
                'Loaded pretrained weight for net {}'.format(net_name))

    def close(self):
        """
        Close file handlers and other resources when training finishes.
        """
        if self.enable_visualdl:
            self.vdl_logger.close()