trainer.py 15.6 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

L
LielinJiang 已提交
15 16
import os
import time
L
LielinJiang 已提交
17
import copy
L
LielinJiang 已提交
18

L
LielinJiang 已提交
19
import logging
L
LielinJiang 已提交
20
import datetime
L
LielinJiang 已提交
21

L
LielinJiang 已提交
22
import paddle
L
LielinJiang 已提交
23
from paddle.distributed import ParallelEnv
L
LielinJiang 已提交
24 25 26 27

from ..datasets.builder import build_dataloader
from ..models.builder import build_model
from ..utils.visual import tensor2img, save_image
L
LielinJiang 已提交
28
from ..utils.filesystem import makedirs, save, load
29
from ..utils.timer import TimeAverager
L
LielinJiang 已提交
30

L
fix nan  
LielinJiang 已提交
31

32 33 34 35 36
class IterLoader:
    """Endless wrapper that restarts a dataloader each time it is exhausted.

    ``next()`` yields batches from the wrapped dataloader. When it runs dry,
    a fresh iterator is created and the 1-based ``epoch`` counter advances.
    """
    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(dataloader)
        self._epoch = 1

    @property
    def epoch(self):
        """1-based number of the pass currently being read."""
        return self._epoch

    def __next__(self):
        try:
            return next(self.iter_loader)
        except StopIteration:
            pass
        # Current pass exhausted: begin a new pass over the dataloader.
        self._epoch += 1
        self.iter_loader = iter(self._dataloader)
        return next(self.iter_loader)

    def __len__(self):
        """Number of batches in one pass of the wrapped dataloader."""
        return len(self._dataloader)


class Trainer:
    r"""
    Build a model from a config and drive its training loop, evaluation,
    visualization and checkpointing.

    # trainer calling logic:
    #
    #                build_model                               ||    model(BaseModel)
    #                     |                                    ||
    #               build_dataloader                           ||    dataloader
    #                     |                                    ||
    #               model.setup_lr_schedulers                  ||    lr_scheduler
    #                     |                                    ||
    #               model.setup_optimizers                     ||    optimizers
    #                     |                                    ||
    #     train loop (model.setup_input + model.train_iter)    ||    train loop
    #                     |                                    ||
    #         print log (model.get_current_losses)             ||
    #                     |                                    ||
    #         save checkpoint (model.nets)                     \/
    """
    def __init__(self, cfg):
        """Set up the model, metrics and optional VisualDL writer.

        When ``cfg.is_train`` is truthy, this also builds the train
        dataloader, lr schedulers and optimizers, and derives the
        iteration-based intervals.

        Args:
            cfg: experiment config, read via attribute access and ``.get``.
        """
        # base config
        self.logger = logging.getLogger(__name__)
        self.cfg = cfg
        self.output_dir = cfg.output_dir
        self.max_eval_steps = cfg.model.get('max_eval_steps', None)

        self.local_rank = ParallelEnv().local_rank
        self.log_interval = cfg.log_config.interval
        # NOTE: 'visiual_interval' (sic) is the key existing configs provide.
        self.visual_interval = cfg.log_config.visiual_interval
        self.weight_interval = cfg.snapshot_config.interval

        self.start_epoch = 1
        self.current_epoch = 1
        self.current_iter = 1
        self.inner_iter = 1
        self.batch_id = 0
        self.global_steps = 0

        # build model
        self.model = build_model(cfg.model)
        # multiple gpus prepare
        if ParallelEnv().nranks > 1:
            self.distributed_data_parallel()

        # build metrics
        self.metrics = None
        validate_cfg = cfg.get('validate', None)
        if validate_cfg and 'metrics' in validate_cfg:
            self.metrics = self.model.setup_metrics(validate_cfg['metrics'])

        self.enable_visualdl = cfg.get('enable_visualdl', False)
        if self.enable_visualdl:
            import visualdl
            self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir)

        # evaluate only
        if not cfg.is_train:
            return

        # build train dataloader
        self.train_dataloader = build_dataloader(cfg.dataset.train)
        self.iters_per_epoch = len(self.train_dataloader)

        # build lr scheduler
        # TODO: has a better way?
        if 'lr_scheduler' in cfg and 'iters_per_epoch' in cfg.lr_scheduler:
            cfg.lr_scheduler.iters_per_epoch = self.iters_per_epoch
        self.lr_schedulers = self.model.setup_lr_schedulers(cfg.lr_scheduler)

        # build optimizers
        self.optimizers = self.model.setup_optimizers(self.lr_schedulers,
                                                      cfg.optimizer)

        self.epochs = cfg.get('epochs', None)
        if self.epochs:
            self.total_iters = self.epochs * self.iters_per_epoch
            self.by_epoch = True
        else:
            self.by_epoch = False
            self.total_iters = cfg.total_iters

        # When training by epoch, the snapshot interval is given in epochs;
        # convert it to iterations because the loop counts iterations.
        if self.by_epoch:
            self.weight_interval *= self.iters_per_epoch

        self.validate_interval = -1
        if cfg.get('validate', None) is not None:
            self.validate_interval = cfg.validate.get('interval', -1)

        self.time_count = {}
        self.best_metric = {}

    def distributed_data_parallel(self):
        """Initialise the parallel env and wrap every net in DataParallel."""
        paddle.distributed.init_parallel_env()
        for net_name, net in self.model.nets.items():
            self.model.nets[net_name] = paddle.DataParallel(net)

    def learning_rate_scheduler_step(self):
        """Advance the model's lr scheduler(s) by one step.

        Raises:
            ValueError: if ``model.lr_scheduler`` is neither a dict of
                schedulers nor a single ``LRScheduler``.
        """
        if isinstance(self.model.lr_scheduler, dict):
            for lr_scheduler in self.model.lr_scheduler.values():
                lr_scheduler.step()
        elif isinstance(self.model.lr_scheduler,
                        paddle.optimizer.lr.LRScheduler):
            self.model.lr_scheduler.step()
        else:
            raise ValueError(
                'lr scheduler must be a dict or an instance of LRScheduler')

    def train(self):
        """Train from ``current_iter`` up to ``total_iters`` (inclusive).

        Each iteration feeds one batch to the model and steps the lr
        schedulers. Logging, visualization, validation and checkpointing
        run at their configured iteration intervals.
        """
        reader_cost_averager = TimeAverager()
        batch_cost_averager = TimeAverager()

        iter_loader = IterLoader(self.train_dataloader)

        while self.current_iter <= self.total_iters:
            # Derive progress counters from the iteration so they stay right
            # after resume(): the epoch is 1-based and inner_iter runs
            # 1..iters_per_epoch.
            self.current_epoch = (
                (self.current_iter - 1) // self.iters_per_epoch + 1)
            self.inner_iter = (
                (self.current_iter - 1) % self.iters_per_epoch + 1)
            # Step value used for VisualDL records.
            self.global_steps = self.current_iter

            step_start_time = time.time()
            data = next(iter_loader)
            reader_cost_averager.record(time.time() - step_start_time)
            # unpack data from dataset and apply preprocessing
            # data input should be dict
            self.model.setup_input(data)
            self.model.train_iter(self.optimizers)

            batch_cost_averager.record(time.time() - step_start_time,
                                       num_samples=self.cfg.get(
                                           'batch_size', 1))

            if self.current_iter % self.log_interval == 0:
                self.data_time = reader_cost_averager.get_average()
                self.step_time = batch_cost_averager.get_average()
                self.ips = batch_cost_averager.get_ips_average()
                self.print_log()

                reader_cost_averager.reset()
                batch_cost_averager.reset()

            if self.current_iter % self.visual_interval == 0:
                self.visual('visual_train')

            self.learning_rate_scheduler_step()

            # An interval <= 0 disables validation (0 would divide by zero).
            if (self.validate_interval > 0
                    and self.current_iter % self.validate_interval == 0):
                self.test()

            if self.current_iter % self.weight_interval == 0:
                self.save(self.current_iter, 'weight', keep=-1)
                self.save(self.current_iter)

            self.current_iter += 1

    @staticmethod
    def _num_visual_samples(current_visuals):
        """Return how many per-sample images ``current_visuals`` holds.

        The batch size is read from the first visual when it is 4-D
        (batched). Otherwise everything is treated as a single sample.
        """
        if not current_visuals:
            return 1
        first = next(iter(current_visuals.values()))
        if len(first.shape) == 4:
            return first.shape[0]
        return 1

    def test(self):
        """Evaluate on ``cfg.dataset.test``.

        Per-sample visuals are saved and, when configured, the accumulated
        metric values are logged. At most ``max_eval_steps`` batches are
        evaluated (defaults to one pass over the test dataloader).
        """
        if not hasattr(self, 'test_dataloader'):
            self.test_dataloader = build_dataloader(self.cfg.dataset.test,
                                                    is_train=False,
                                                    distributed=False)
        iter_loader = IterLoader(self.test_dataloader)
        if self.max_eval_steps is None:
            self.max_eval_steps = len(self.test_dataloader)

        if self.metrics:
            for metric in self.metrics.values():
                metric.reset()

        for i in range(self.max_eval_steps):
            data = next(iter_loader)
            self.model.setup_input(data)
            self.model.test_iter(metrics=self.metrics)

            visual_results = {}
            current_paths = self.model.get_image_paths()
            current_visuals = self.model.get_current_visuals()

            num_samples = self._num_visual_samples(current_visuals)

            for j in range(num_samples):
                # Name outputs after the source image when its path is
                # known, otherwise after the (batch, sample) index.
                if j < len(current_paths):
                    short_path = os.path.basename(current_paths[j])
                    basename = os.path.splitext(short_path)[0]
                else:
                    basename = '{:04d}_{:04d}'.format(i, j)

                for k, img_tensor in current_visuals.items():
                    name = '%s_%s' % (basename, k)
                    if len(img_tensor.shape) == 4:
                        visual_results.update({name: img_tensor[j]})
                    else:
                        visual_results.update({name: img_tensor})

            self.visual('visual_test',
                        visual_results=visual_results,
                        step=self.batch_id,
                        is_save_image=True)

            if i % self.log_interval == 0:
                self.logger.info('Test iter: [%d/%d]' %
                                 (i, self.max_eval_steps))

        if self.metrics:
            for metric_name, metric in self.metrics.items():
                self.logger.info("Metric {}: {:.4f}".format(
                    metric_name, metric.accumulate()))

    def print_log(self):
        """Log progress, learning rate, losses, timing and ETA.

        Losses are also written to VisualDL (at ``global_steps``) when it
        is enabled.
        """
        losses = self.model.get_current_losses()

        message = ''
        if self.by_epoch:
            message += 'Epoch: %d/%d, iter: %d/%d ' % (
                self.current_epoch, self.epochs, self.inner_iter,
                self.iters_per_epoch)
        else:
            message += 'Iter: %d/%d ' % (self.current_iter, self.total_iters)

        lr = self.current_learning_rate
        if lr is not None:
            message += f'lr: {lr:.3e} '

        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)
            if self.enable_visualdl:
                self.vdl_logger.add_scalar(k, v, step=self.global_steps)

        if hasattr(self, 'step_time'):
            message += 'batch_cost: %.5f sec ' % self.step_time

        if hasattr(self, 'data_time'):
            message += 'reader_cost: %.5f sec ' % self.data_time

        if hasattr(self, 'ips'):
            message += 'ips: %.5f images/s ' % self.ips

        if hasattr(self, 'step_time'):
            # The loop runs through total_iters inclusive, so this many
            # iterations remain after the current one.
            remaining_iters = max(self.total_iters - self.current_iter, 0)
            eta = self.step_time * remaining_iters
            eta_str = str(datetime.timedelta(seconds=int(eta)))
            message += f'eta: {eta_str}'

        # print the message
        self.logger.info(message)

    @property
    def current_learning_rate(self):
        """Learning rate of the first optimizer, or ``None`` if there is none."""
        for optimizer in self.model.optimizers.values():
            return optimizer.get_lr()
        return None

    def visual(self,
               results_dir,
               visual_results=None,
               step=None,
               is_save_image=False):
        """
        Visualize images, either via VisualDL or by writing them to disk.

        Parameters:
            results_dir (str)     --  directory name which contains saved images
            visual_results (dict) --  the results images dict; defaults to
                                      ``model.get_current_visuals()``
            step (int)            --  step used by VisualDL; defaults to
                                      ``global_steps``
            is_save_image (bool)  --  whether to write to the directory
                                      instead of VisualDL
        """
        self.model.compute_visuals()

        if visual_results is None:
            visual_results = self.model.get_current_visuals()

        min_max = self.cfg.get('min_max', None)
        if min_max is None:
            min_max = (-1., 1.)

        image_num = self.cfg.get('image_num', None)
        if (image_num is None) or (not self.enable_visualdl):
            image_num = 1
        for label, image in visual_results.items():
            image_numpy = tensor2img(image, min_max, image_num)
            if (not is_save_image) and self.enable_visualdl:
                self.vdl_logger.add_image(
                    results_dir + '/' + label,
                    image_numpy,
                    step=step if step is not None else self.global_steps,
                    dataformats="HWC" if image_num == 1 else "NCHW")
            else:
                if self.cfg.is_train:
                    msg = 'epoch%.3d_' % self.current_epoch
                else:
                    msg = ''
                makedirs(os.path.join(self.output_dir, results_dir))
                img_path = os.path.join(self.output_dir, results_dir,
                                        msg + '%s.png' % (label))
                save_image(image_numpy, img_path)

    def _checkpoint_filename(self, iteration, name):
        """File name used for the snapshot taken at ``iteration``."""
        if self.by_epoch:
            return 'epoch_%s_%s.pdparams' % (iteration // self.iters_per_epoch,
                                             name)
        return 'iter_%s_%s.pdparams' % (iteration, name)

    def save(self, epoch, name='checkpoint', keep=1):
        """Save network (and, for checkpoints, optimizer) state on rank 0.

        Args:
            epoch (int): iteration counter. It is used in the file name and
                stored under the ``'epoch'`` key of checkpoints. Despite the
                parameter name, ``train`` passes ``current_iter``.
            name (str): ``'checkpoint'`` (nets, optimizers and counter) or
                ``'weight'`` (nets only).
            keep (int): when > 0, the same-named file written
                ``keep * weight_interval`` iterations earlier is deleted
                (best effort).
        """
        if self.local_rank != 0:
            return

        assert name in ['checkpoint', 'weight']

        os.makedirs(self.output_dir, exist_ok=True)
        save_path = os.path.join(self.output_dir,
                                 self._checkpoint_filename(epoch, name))

        state_dicts = {
            net_name: net.state_dict()
            for net_name, net in self.model.nets.items()
        }

        if name == 'weight':
            save(state_dicts, save_path)
            return

        state_dicts['epoch'] = epoch
        for opt_name, opt in self.model.optimizers.items():
            state_dicts[opt_name] = opt.state_dict()

        save(state_dicts, save_path)

        if keep > 0:
            old_path = os.path.join(
                self.output_dir,
                self._checkpoint_filename(epoch - keep * self.weight_interval,
                                          name))
            try:
                if os.path.exists(old_path):
                    os.remove(old_path)
            except OSError as e:
                # Cleanup is best effort; a failure must not stop training.
                self.logger.info('remove old checkpoints error: {}'.format(e))

    def resume(self, checkpoint_path):
        """Restore nets, optimizers and the iteration counter from a checkpoint."""
        state_dicts = load(checkpoint_path)
        if state_dicts.get('epoch', None) is not None:
            # ``save`` stores the iteration counter under the 'epoch' key.
            saved_iter = state_dicts['epoch']
            self.start_epoch = saved_iter + 1
            self.global_steps = saved_iter
            self.current_iter = saved_iter + 1

        for net_name, net in self.model.nets.items():
            net.set_state_dict(state_dicts[net_name])

        for opt_name, opt in self.model.optimizers.items():
            opt.set_state_dict(state_dicts[opt_name])

    def load(self, weight_path):
        """Load per-net weights, skipping (with a warning) nets not in the file."""
        state_dicts = load(weight_path)

        for net_name, net in self.model.nets.items():
            if net_name in state_dicts:
                net.set_state_dict(state_dicts[net_name])
                self.logger.info(
                    'Loaded pretrained weight for net {}'.format(net_name))
            else:
                self.logger.warning(
                    'Can not find state dict of net {}. Skip load pretrained weight for net {}'
                    .format(net_name, net_name))

    def close(self):
        """
        Release resources held by the trainer (the VisualDL writer, if any)
        once training is finished.
        """
        if self.enable_visualdl:
            self.vdl_logger.close()