trainer.py 17.7 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

L
LielinJiang 已提交
15
import os
L
LielinJiang 已提交
16
import sys
L
LielinJiang 已提交
17
import time
L
LielinJiang 已提交
18
import copy
L
LielinJiang 已提交
19

L
LielinJiang 已提交
20
import logging
L
LielinJiang 已提交
21
import datetime
L
LielinJiang 已提交
22

L
LielinJiang 已提交
23
import paddle
L
LielinJiang 已提交
24
from paddle.distributed import ParallelEnv
L
LielinJiang 已提交
25 26 27 28

from ..datasets.builder import build_dataloader
from ..models.builder import build_model
from ..utils.visual import tensor2img, save_image
L
LielinJiang 已提交
29
from ..utils.filesystem import makedirs, save, load
30
from ..utils.timer import TimeAverager
L
lzzyzlbb 已提交
31
from ..utils.profiler import add_profiler_step
L
fix nan  
LielinJiang 已提交
32

33

34 35 36 37 38
class IterLoader:
    """Wrap a dataloader so that ``next()`` never runs out of data.

    When the wrapped loader is exhausted, a fresh iterator is created from it
    and the epoch counter is advanced, so callers can drive training purely by
    iteration count.

    Args:
        dataloader: a re-iterable object supporting ``iter()`` and ``len()``
            (e.g. a paddle ``DataLoader``).
    """
    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        # 1-based epoch counter, advanced every time the loader is restarted.
        self._epoch = 1
        # Number of batches handed out; only reset by the Windows branch below.
        self._inner_iter = 0

    @property
    def epoch(self):
        """Current (1-based) epoch number."""
        return self._epoch

    def __iter__(self):
        # Make this a proper iterator so it also works in for-loops and with
        # itertools, not only with explicit next() calls.
        return self

    def __next__(self):
        """Return the next batch, restarting the dataloader when exhausted."""
        try:
            # NOTE(review): sys.platform is "win32" on Windows, never
            # "Windows", so this early-restart branch never runs. If enabled
            # it would also skip the last batch of every epoch. Left as-is;
            # confirm whether the Windows workaround is still wanted before
            # changing the comparison.
            if sys.platform == "Windows" and self._inner_iter == len(
                    self._dataloader) - 1:
                self._inner_iter = 0
                raise StopIteration
            data = next(self.iter_loader)
        except StopIteration:
            self._epoch += 1
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        self._inner_iter += 1
        return data

    def __len__(self):
        """Number of batches in one pass over the wrapped dataloader."""
        return len(self._dataloader)


class Trainer:
    r"""Drive training, evaluation, visualization and checkpointing of a model.

    Trainer calling logic::

                     build_model                               ||    model(BaseModel)
                          |                                    ||
                    build_dataloader                           ||    dataloader
                          |                                    ||
                    model.setup_lr_schedulers                  ||    lr_scheduler
                          |                                    ||
                    model.setup_optimizers                     ||    optimizers
                          |                                    ||
          train loop (model.setup_input + model.train_iter)    ||    train loop
                          |                                    ||
              print log (model.get_current_losses)             ||
                          |                                    ||
              save checkpoint (model.nets)                     \/
    """
    def __init__(self, cfg):
        """Build the model, metrics and, when training, data/optimizer state.

        Args:
            cfg: experiment config; read with both attribute access
                (``cfg.output_dir``) and ``cfg.get(...)``.
        """
        # base config
        self.logger = logging.getLogger(__name__)
        self.cfg = cfg
        self.output_dir = cfg.output_dir
        self.max_eval_steps = cfg.model.get('max_eval_steps', None)

        self.local_rank = ParallelEnv().local_rank
        self.world_size = ParallelEnv().nranks
        self.log_interval = cfg.log_config.interval
        # 'visiual_interval' (sic) is the key name read from the config.
        self.visual_interval = cfg.log_config.visiual_interval
        # Counted in iterations; scaled from epochs below when training by epoch.
        self.weight_interval = cfg.snapshot_config.interval

        self.start_epoch = 1
        self.current_epoch = 1
        self.current_iter = 1
        self.inner_iter = 1
        self.batch_id = 0
        # Step used as the x-axis for VisualDL logging.
        self.global_steps = 0

        # build model
        self.model = build_model(cfg.model)
        # multiple gpus prepare
        if self.world_size > 1:
            self.distributed_data_parallel()

        # build metrics
        self.metrics = None
        self.is_save_img = True
        validate_cfg = cfg.get('validate', None)
        if validate_cfg and 'metrics' in validate_cfg:
            self.metrics = self.model.setup_metrics(validate_cfg['metrics'])
        if validate_cfg and 'save_img' in validate_cfg:
            self.is_save_img = validate_cfg['save_img']

        self.enable_visualdl = cfg.get('enable_visualdl', False)
        if self.enable_visualdl:
            import visualdl
            self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir)

        # evaluate only
        if not cfg.is_train:
            return

        # build train dataloader
        self.train_dataloader = build_dataloader(cfg.dataset.train)
        self.iters_per_epoch = len(self.train_dataloader)

        # build lr scheduler
        # TODO: is there a better way to hand iters_per_epoch to the scheduler?
        if 'lr_scheduler' in cfg and 'iters_per_epoch' in cfg.lr_scheduler:
            cfg.lr_scheduler.iters_per_epoch = self.iters_per_epoch
        self.lr_schedulers = self.model.setup_lr_schedulers(cfg.lr_scheduler)

        # build optimizers
        self.optimizers = self.model.setup_optimizers(self.lr_schedulers,
                                                      cfg.optimizer)

        self.epochs = cfg.get('epochs', None)
        if self.epochs:
            self.by_epoch = True
            self.total_iters = self.epochs * self.iters_per_epoch
            # When training by epoch the snapshot interval is given in epochs;
            # everything in the loop counts iterations.
            self.weight_interval *= self.iters_per_epoch
        else:
            self.by_epoch = False
            self.total_iters = cfg.total_iters

        self.validate_interval = -1
        if cfg.get('validate', None) is not None:
            self.validate_interval = cfg.validate.get('interval', -1)

        self.time_count = {}
        self.best_metric = {}
        self.model.set_total_iter(self.total_iters)
        self.profiler_options = cfg.profiler_options

    def distributed_data_parallel(self):
        """Initialize the parallel env and wrap every net in DataParallel."""
        paddle.distributed.init_parallel_env()
        find_unused_parameters = self.cfg.get('find_unused_parameters', False)
        # Iterate over a snapshot because entries are replaced in the loop.
        for net_name, net in list(self.model.nets.items()):
            self.model.nets[net_name] = paddle.DataParallel(
                net, find_unused_parameters=find_unused_parameters)

    def learning_rate_scheduler_step(self):
        """Advance the model's lr scheduler(s) by one step.

        Raises:
            ValueError: if ``model.lr_scheduler`` is neither a dict of
                schedulers nor a single ``LRScheduler``.
        """
        if isinstance(self.model.lr_scheduler, dict):
            for lr_scheduler in self.model.lr_scheduler.values():
                lr_scheduler.step()
        elif isinstance(self.model.lr_scheduler,
                        paddle.optimizer.lr.LRScheduler):
            self.model.lr_scheduler.step()
        else:
            raise ValueError(
                'lr schedulter must be a dict or an instance of LRScheduler')

    def train(self):
        """Run the training loop until ``total_iters`` iterations are done.

        Each iteration fetches one batch and runs ``model.train_iter``. Then,
        at their configured intervals, it logs, visualizes, steps the lr
        schedulers, validates and saves weights/checkpoints.
        """
        reader_cost_averager = TimeAverager()
        batch_cost_averager = TimeAverager()

        iter_loader = IterLoader(self.train_dataloader)

        # set model.is_train = True
        self.model.setup_train_mode(is_train=True)
        while self.current_iter < (self.total_iters + 1):
            # Derive the position from current_iter so it stays correct after
            # resume() and at epoch boundaries (1-based, never 0).
            self.current_epoch = (
                (self.current_iter - 1) // self.iters_per_epoch + 1)
            self.inner_iter = (
                (self.current_iter - 1) % self.iters_per_epoch + 1)
            # VisualDL step; previously never advanced during training.
            self.global_steps = self.current_iter

            add_profiler_step(self.profiler_options)

            step_start_time = time.time()
            data = next(iter_loader)
            reader_cost_averager.record(time.time() - step_start_time)
            # unpack data from dataset and apply preprocessing
            # data input should be dict
            self.model.setup_input(data)
            self.model.train_iter(self.optimizers)

            batch_cost_averager.record(
                time.time() - step_start_time,
                num_samples=self.cfg['dataset']['train'].get('batch_size', 1))

            if self.current_iter % self.log_interval == 0:
                self.data_time = reader_cost_averager.get_average()
                self.step_time = batch_cost_averager.get_average()
                self.ips = batch_cost_averager.get_ips_average()
                self.print_log()

                reader_cost_averager.reset()
                batch_cost_averager.reset()

            if self.current_iter % self.visual_interval == 0 and self.local_rank == 0:
                self.visual('visual_train')

            self.learning_rate_scheduler_step()

            # "> 0" (not "> -1") so an interval of 0 disables validation
            # instead of raising ZeroDivisionError.
            if self.validate_interval > 0 and self.current_iter % self.validate_interval == 0:
                self.test()
                # test() switches the model to eval mode; switch back before
                # the next training iteration.
                self.model.setup_train_mode(is_train=True)

            if self.current_iter % self.weight_interval == 0:
                self.save(self.current_iter, 'weight', keep=-1)
                self.save(self.current_iter)

            self.current_iter += 1

    @staticmethod
    def _num_visual_samples(current_visuals):
        """Return the number of samples held in a dict of visual tensors.

        The batch size is the leading dim of the first visual when it is
        4-D; otherwise (or when there are no visuals) one sample is assumed.
        """
        if current_visuals:
            first = next(iter(current_visuals.values()))
            # Compare the rank, not the shape itself: ``shape == 4`` was
            # always False, so only the first sample of a batch was saved.
            if len(first.shape) == 4:
                return first.shape[0]
        return 1

    def _collect_test_visuals(self, step_idx):
        """Split the model's current visuals into one named image per sample.

        Names are ``<file basename>_<visual key>``. When no path exists for a
        sample, ``<step_idx>_<sample idx>`` (zero-padded to 4 digits) is used.
        """
        current_paths = self.model.get_image_paths()
        current_visuals = self.model.get_current_visuals()

        visual_results = {}
        for j in range(self._num_visual_samples(current_visuals)):
            if j < len(current_paths):
                short_path = os.path.basename(current_paths[j])
                basename = os.path.splitext(short_path)[0]
            else:
                basename = '{:04d}_{:04d}'.format(step_idx, j)
            for k, img_tensor in current_visuals.items():
                name = '%s_%s' % (basename, k)
                if len(img_tensor.shape) == 4:
                    visual_results[name] = img_tensor[j]
                else:
                    visual_results[name] = img_tensor
        return visual_results

    def test(self):
        """Evaluate the model on the test dataset.

        Runs ``max_eval_steps`` batches (default: one pass over the test
        loader), updates the configured metrics, optionally saves the visual
        results of every sample, and logs the accumulated metric values.
        Leaves the model with ``setup_train_mode(is_train=False)``.
        """
        if not hasattr(self, 'test_dataloader'):
            self.test_dataloader = build_dataloader(self.cfg.dataset.test,
                                                    is_train=False)
        iter_loader = IterLoader(self.test_dataloader)
        if self.max_eval_steps is None:
            self.max_eval_steps = len(self.test_dataloader)

        if self.metrics:
            for metric in self.metrics.values():
                metric.reset()

        # set model.is_train = False
        self.model.setup_train_mode(is_train=False)

        for i in range(self.max_eval_steps):
            if self.max_eval_steps < self.log_interval or i % self.log_interval == 0:
                self.logger.info('Test iter: [%d/%d]' %
                                 (i * self.world_size,
                                  self.max_eval_steps * self.world_size))

            data = next(iter_loader)
            self.model.setup_input(data)
            self.model.test_iter(metrics=self.metrics)

            if self.is_save_img:
                visual_results = self._collect_test_visuals(i)
                self.visual('visual_test',
                            visual_results=visual_results,
                            step=self.batch_id,
                            is_save_image=True)

        if self.metrics:
            for metric_name, metric in self.metrics.items():
                self.logger.info("Metric {}: {:.4f}".format(
                    metric_name, metric.accumulate()))

    def print_log(self):
        """Log losses, lr, timing and ETA for the current iteration.

        Losses are also written to VisualDL at ``global_steps`` when enabled.
        """
        losses = self.model.get_current_losses()

        message = ''
        if self.by_epoch:
            message += 'Epoch: %d/%d, iter: %d/%d ' % (
                self.current_epoch, self.epochs, self.inner_iter,
                self.iters_per_epoch)
        else:
            message += 'Iter: %d/%d ' % (self.current_iter, self.total_iters)

        message += f'lr: {self.current_learning_rate:.3e} '

        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)
            if self.enable_visualdl:
                self.vdl_logger.add_scalar(k, v, step=self.global_steps)

        # Timing attributes only exist once train() has filled them in.
        if hasattr(self, 'step_time'):
            message += 'batch_cost: %.5f sec ' % self.step_time

        if hasattr(self, 'data_time'):
            message += 'reader_cost: %.5f sec ' % self.data_time

        if hasattr(self, 'ips'):
            message += 'ips: %.5f images/s ' % self.ips

        if hasattr(self, 'step_time'):
            eta = self.step_time * (self.total_iters - self.current_iter)
            eta = eta if eta > 0 else 0

            eta_str = str(datetime.timedelta(seconds=int(eta)))
            message += f'eta: {eta_str}'

        # print the message
        self.logger.info(message)

    @property
    def current_learning_rate(self):
        """Learning rate of the first optimizer (None if there is none)."""
        for optimizer in self.model.optimizers.values():
            return optimizer.get_lr()

    def visual(self,
               results_dir,
               visual_results=None,
               step=None,
               is_save_image=False):
        """Visualize images via VisualDL or write them to disk.

        Parameters:
            results_dir (str)     --  VisualDL tag prefix, or sub-directory
                                      of ``output_dir`` when saving to disk
            visual_results (dict) --  label -> image tensor; defaults to
                                      ``model.get_current_visuals()``
            step (int)            --  VisualDL step; ``global_steps`` is used
                                      when it is falsy
            is_save_image (bool)  --  whether to write to disk even when
                                      VisualDL is enabled
        """
        self.model.compute_visuals()

        if visual_results is None:
            visual_results = self.model.get_current_visuals()

        min_max = self.cfg.get('min_max', None)
        if min_max is None:
            min_max = (-1., 1.)

        image_num = self.cfg.get('image_num', None)
        if (image_num is None) or (not self.enable_visualdl):
            image_num = 1
        for label, image in visual_results.items():
            image_numpy = tensor2img(image, min_max, image_num)
            if (not is_save_image) and self.enable_visualdl:
                self.vdl_logger.add_image(
                    results_dir + '/' + label,
                    image_numpy,
                    step=step if step else self.global_steps,
                    dataformats="HWC" if image_num == 1 else "NCHW")
            else:
                if self.cfg.is_train:
                    if self.by_epoch:
                        msg = 'epoch%.3d_' % self.current_epoch
                    else:
                        msg = 'iter%.3d_' % self.current_iter
                else:
                    msg = ''
                makedirs(os.path.join(self.output_dir, results_dir))
                img_path = os.path.join(self.output_dir, results_dir,
                                        msg + '%s.png' % (label))
                save_image(image_numpy, img_path)

    def save(self, epoch, name='checkpoint', keep=1):
        """Save net weights (``name='weight'``) or a full training checkpoint.

        Only rank 0 writes files.

        Args:
            epoch: iteration counter at save time (``train`` passes
                ``current_iter``); converted to an epoch number for the file
                name when training by epoch.
            name: ``'weight'`` stores only the nets' state dicts;
                ``'checkpoint'`` also stores optimizer states and the counter
                under the ``'epoch'`` key.
            keep: for checkpoints, delete the file saved ``keep`` snapshot
                intervals earlier; values <= 0 keep every file.
        """
        if self.local_rank != 0:
            return

        assert name in ['checkpoint', 'weight']

        state_dicts = {}
        if self.by_epoch:
            save_filename = 'epoch_%s_%s.pdparams' % (
                epoch // self.iters_per_epoch, name)
        else:
            save_filename = 'iter_%s_%s.pdparams' % (epoch, name)

        os.makedirs(self.output_dir, exist_ok=True)
        save_path = os.path.join(self.output_dir, save_filename)
        for net_name, net in self.model.nets.items():
            state_dicts[net_name] = net.state_dict()

        if name == 'weight':
            save(state_dicts, save_path)
            return

        state_dicts['epoch'] = epoch

        for opt_name, opt in self.model.optimizers.items():
            state_dicts[opt_name] = opt.state_dict()

        save(state_dicts, save_path)

        if keep > 0:
            # Best-effort cleanup: a failure is logged, not raised.
            try:
                if self.by_epoch:
                    checkpoint_name_to_be_removed = os.path.join(
                        self.output_dir, 'epoch_%s_%s.pdparams' %
                        ((epoch - keep * self.weight_interval) //
                         self.iters_per_epoch, name))
                else:
                    checkpoint_name_to_be_removed = os.path.join(
                        self.output_dir, 'iter_%s_%s.pdparams' %
                        (epoch - keep * self.weight_interval, name))

                if os.path.exists(checkpoint_name_to_be_removed):
                    os.remove(checkpoint_name_to_be_removed)

            except Exception as e:
                self.logger.info('remove old checkpoints error: {}'.format(e))

    def resume(self, checkpoint_path):
        """Restore nets and optimizers from a checkpoint written by ``save``.

        The checkpoint's ``'epoch'`` entry holds the iteration at which it was
        saved (``train`` passes ``current_iter`` to ``save``); training
        continues from the following iteration.
        """
        state_dicts = load(checkpoint_path)
        if state_dicts.get('epoch', None) is not None:
            saved_iter = state_dicts['epoch']
            self.start_epoch = saved_iter + 1
            # Already an iteration count: multiplying by iters_per_epoch (as
            # before) put the VisualDL step far past the real position.
            self.global_steps = saved_iter
            self.current_iter = saved_iter + 1

        for net_name, net in self.model.nets.items():
            net.set_state_dict(state_dicts[net_name])

        for opt_name, opt in self.model.optimizers.items():
            opt.set_state_dict(state_dicts[opt_name])

    def load(self, weight_path):
        """Load pretrained weights into the model's nets.

        Accepts either a dict of per-net state dicts (keyed by net name;
        missing nets are skipped with a warning) or a single state dict, which
        is only allowed when the model has exactly one net.
        """
        state_dicts = load(weight_path)

        def is_dict_in_dict_weight(state_dict):
            # True when the file maps net names to per-net state dicts.
            if isinstance(state_dict, dict) and len(state_dict) > 0:
                return isinstance(next(iter(state_dict.values())), dict)
            return False

        if is_dict_in_dict_weight(state_dicts):
            for net_name, net in self.model.nets.items():
                if net_name in state_dicts:
                    net.set_state_dict(state_dicts[net_name])
                    self.logger.info(
                        'Loaded pretrained weight for net {}'.format(net_name))
                else:
                    self.logger.warning(
                        'Can not find state dict of net {}. Skip load pretrained weight for net {}'
                        .format(net_name, net_name))
        else:
            # The message was previously built with a backslash continuation
            # inside the literal, which embedded a long run of spaces.
            assert len(self.model.nets) == 1, (
                'checkpoint only contain weight of one net, '
                'but model contains more than one net!')

            net_name, net = list(self.model.nets.items())[0]
            net.set_state_dict(state_dicts)
            self.logger.info(
                'Loaded pretrained weight for net {}'.format(net_name))

    def close(self):
        """Release resources (the VisualDL writer) once training is finished."""
        if self.enable_visualdl:
            self.vdl_logger.close()