trainer.py 14.3 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

L
LielinJiang 已提交
15 16
import os
import time
L
LielinJiang 已提交
17
import copy
L
LielinJiang 已提交
18

L
LielinJiang 已提交
19
import logging
L
LielinJiang 已提交
20
import datetime
L
LielinJiang 已提交
21

L
LielinJiang 已提交
22
import paddle
L
LielinJiang 已提交
23
from paddle.distributed import ParallelEnv
L
LielinJiang 已提交
24 25 26 27

from ..datasets.builder import build_dataloader
from ..models.builder import build_model
from ..utils.visual import tensor2img, save_image
L
LielinJiang 已提交
28
from ..utils.filesystem import makedirs, save, load
29
from ..utils.timer import TimeAverager
L
LielinJiang 已提交
30

L
fix nan  
LielinJiang 已提交
31

32 33 34 35 36
class IterLoader:
    """Endless iterator over a dataloader that tracks which pass it is on.

    When the wrapped dataloader is exhausted, a fresh iterator over it is
    created and the epoch counter (starting at 1) is incremented, so callers
    can keep calling ``next()`` for as many iterations as they need.
    """

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 1

    @property
    def epoch(self):
        """1-based index of the pass over the dataloader currently being read."""
        return self._epoch

    def __iter__(self):
        # Complete the iterator protocol so the loader also works with `for`,
        # `itertools.islice`, etc., not only with explicit `next()` calls.
        return self

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            # Current pass exhausted: start a new pass over the dataloader.
            # If the dataloader is empty this re-raises StopIteration.
            self._epoch += 1
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)
        return data

    def __len__(self):
        """Number of batches in one pass of the wrapped dataloader."""
        return len(self._dataloader)


class Trainer:
    """
    # trainer calling logic:
    #
    #                build_model                               ||    model(BaseModel)
    #                     |                                    ||
    #               build_dataloader                           ||    dataloader
    #                     |                                    ||
    #               model.setup_lr_schedulers                  ||    lr_scheduler
    #                     |                                    ||
    #               model.setup_optimizers                     ||    optimizers
    #                     |                                    ||
    #     train loop (model.setup_input + model.train_iter)    ||    train loop
    #                     |                                    ||
    #         print log (model.get_current_losses)             ||
    #                     |                                    ||
    #         save checkpoint (model.nets)                     \/
    """
    def __init__(self, cfg):
L
LielinJiang 已提交
75

L
LielinJiang 已提交
76
        # build model
77
        self.model = build_model(cfg.model)
78 79 80
        # multiple gpus prepare
        if ParallelEnv().nranks > 1:
            self.distributed_data_parallel()
L
LielinJiang 已提交
81

82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
        # build train dataloader
        self.train_dataloader = build_dataloader(cfg.dataset.train)
        self.iters_per_epoch = len(self.train_dataloader)

        # build lr scheduler
        # TODO: has a better way?
        if 'lr_scheduler' in cfg and 'iters_per_epoch' in cfg.lr_scheduler:
            cfg.lr_scheduler.iters_per_epoch = self.iters_per_epoch
        self.lr_schedulers = self.model.setup_lr_schedulers(cfg.lr_scheduler)

        # build optimizers
        self.optimizers = self.model.setup_optimizers(self.lr_schedulers,
                                                      cfg.optimizer)

        # build metrics
        self.metrics = None
        validate_cfg = cfg.get('validate', None)
        if validate_cfg and 'metrics' in validate_cfg:
            self.metrics = self.model.setup_metrics(validate_cfg['metrics'])

L
LielinJiang 已提交
102
        self.logger = logging.getLogger(__name__)
郑启航 已提交
103 104 105 106
        self.enable_visualdl = cfg.get('enable_visualdl', False)
        if self.enable_visualdl:
            import visualdl
            self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir)
107

L
LielinJiang 已提交
108 109
        # base config
        self.output_dir = cfg.output_dir
110 111 112 113 114 115 116 117
        self.epochs = cfg.get('epochs', None)
        if self.epochs:
            self.total_iters = self.epochs * self.iters_per_epoch
            self.by_epoch = True
        else:
            self.by_epoch = False
            self.total_iters = cfg.total_iters

L
LielinJiang 已提交
118 119
        self.start_epoch = 1
        self.current_epoch = 1
120 121
        self.current_iter = 1
        self.inner_iter = 1
L
LielinJiang 已提交
122
        self.batch_id = 0
郑启航 已提交
123
        self.global_steps = 0
L
LielinJiang 已提交
124 125 126
        self.weight_interval = cfg.snapshot_config.interval
        self.log_interval = cfg.log_config.interval
        self.visual_interval = cfg.log_config.visiual_interval
L
LielinJiang 已提交
127 128 129
        self.validate_interval = -1
        if cfg.get('validate', None) is not None:
            self.validate_interval = cfg.validate.get('interval', -1)
L
LielinJiang 已提交
130 131 132
        self.cfg = cfg

        self.local_rank = ParallelEnv().local_rank
133 134

        self.time_count = {}
L
LielinJiang 已提交
135 136
        self.best_metric = {}

137
    def distributed_data_parallel(self):
L
LielinJiang 已提交
138
        strategy = paddle.distributed.prepare_context()
139 140
        for net_name, net in self.model.nets.items():
            self.model.nets[net_name] = paddle.DataParallel(net, strategy)
141

L
LielinJiang 已提交
142
    def train(self):
143 144
        reader_cost_averager = TimeAverager()
        batch_cost_averager = TimeAverager()
L
LielinJiang 已提交
145

146
        iter_loader = IterLoader(self.train_dataloader)
L
LielinJiang 已提交
147

148 149 150
        while self.current_iter < (self.total_iters + 1):
            self.current_epoch = iter_loader.epoch
            self.inner_iter = self.current_iter % self.iters_per_epoch
L
LielinJiang 已提交
151

152 153 154 155 156 157 158 159 160 161 162
            start_time = step_start_time = time.time()
            data = next(iter_loader)
            reader_cost_averager.record(time.time() - step_start_time)
            # unpack data from dataset and apply preprocessing
            # data input should be dict
            self.model.setup_input(data)
            self.model.train_iter(self.optimizers)

            batch_cost_averager.record(time.time() - step_start_time,
                                       num_samples=self.cfg.get(
                                           'batch_size', 1))
163 164 165

            step_start_time = time.time()

166 167 168 169 170 171 172 173 174 175 176 177 178
            if self.current_iter % self.log_interval == 0:
                self.data_time = reader_cost_averager.get_average()
                self.step_time = batch_cost_averager.get_average()
                self.ips = batch_cost_averager.get_ips_average()
                self.print_log()

                reader_cost_averager.reset()
                batch_cost_averager.reset()

            if self.current_iter % self.visual_interval == 0:
                self.visual('visual_train')

            self.model.lr_scheduler.step()
L
LielinJiang 已提交
179

180 181 182 183 184 185
            if self.by_epoch:
                temp = self.current_epoch
            else:
                temp = self.current_iter
            if self.validate_interval > -1 and temp % self.validate_interval == 0:
                self.test()
L
fix nan  
LielinJiang 已提交
186

187 188 189
            if temp % self.weight_interval == 0:
                self.save(temp, 'weight', keep=-1)
                self.save(temp)
L
LielinJiang 已提交
190

191
            self.current_iter += 1
L
LielinJiang 已提交
192

L
LielinJiang 已提交
193 194
    def test(self):
        if not hasattr(self, 'test_dataloader'):
195
            self.test_dataloader = build_dataloader(self.cfg.dataset.test,
196 197 198 199 200 201
                                                    is_train=False,
                                                    distributed=False)

        if self.metrics:
            for metric in self.metrics.values():
                metric.reset()
L
LielinJiang 已提交
202 203 204 205

        # data[0]: img, data[1]: img path index
        # test batch size must be 1
        for i, data in enumerate(self.test_dataloader):
L
LielinJiang 已提交
206

207 208
            self.model.setup_input(data)
            self.model.test_iter(metrics=self.metrics)
L
LielinJiang 已提交
209 210

            visual_results = {}
L
LielinJiang 已提交
211 212 213
            current_paths = self.model.get_image_paths()
            current_visuals = self.model.get_current_visuals()

L
LielinJiang 已提交
214 215 216 217 218 219 220 221 222 223 224 225
            if len(current_visuals) > 0 and list(
                    current_visuals.values())[0].shape == 4:
                num_samples = list(current_visuals.values())[0].shape[0]
            else:
                num_samples = 1

            for j in range(num_samples):
                if j < len(current_paths):
                    short_path = os.path.basename(current_paths[j])
                    basename = os.path.splitext(short_path)[0]
                else:
                    basename = '{:04d}_{:04d}'.format(i, j)
L
LielinJiang 已提交
226 227
                for k, img_tensor in current_visuals.items():
                    name = '%s_%s' % (basename, k)
L
LielinJiang 已提交
228 229 230 231
                    if len(img_tensor.shape) == 4:
                        visual_results.update({name: img_tensor[j]})
                    else:
                        visual_results.update({name: img_tensor})
L
LielinJiang 已提交
232

郑启航 已提交
233 234 235 236
            self.visual('visual_test',
                        visual_results=visual_results,
                        step=self.batch_id,
                        is_save_image=True)
L
LielinJiang 已提交
237

L
LielinJiang 已提交
238
            if i % self.log_interval == 0:
239 240
                self.logger.info('Test iter: [%d/%d]' %
                                 (i, len(self.test_dataloader)))
L
LielinJiang 已提交
241

242 243 244 245 246
        if self.metrics:
            for metric_name, metric in self.metrics.items():
                self.logger.info("Metric {}: {:.4f}".format(
                    metric_name, metric.accumulate()))

L
LielinJiang 已提交
247 248
    def print_log(self):
        losses = self.model.get_current_losses()
L
LielinJiang 已提交
249

250 251 252 253 254 255 256 257 258
        message = ''
        if self.by_epoch:
            message += 'Epoch: %d/%d, iter: %d/%d ' % (
                self.current_epoch, self.epochs, self.inner_iter,
                self.iters_per_epoch)
        else:
            message += 'Iter: %d/%d ' % (self.current_iter, self.total_iters)

        message += f'lr: {self.current_learning_rate:.3e} '
L
LielinJiang 已提交
259 260 261

        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)
郑启航 已提交
262 263
            if self.enable_visualdl:
                self.vdl_logger.add_scalar(k, v, step=self.global_steps)
L
LielinJiang 已提交
264

265 266 267
        if hasattr(self, 'step_time'):
            message += 'batch_cost: %.5f sec ' % self.step_time

268
        if hasattr(self, 'data_time'):
269
            message += 'reader_cost: %.5f sec ' % self.data_time
270

271
        if hasattr(self, 'ips'):
L
LielinJiang 已提交
272 273 274
            message += 'ips: %.5f images/s ' % self.ips

        if hasattr(self, 'step_time'):
275
            eta = self.step_time * (self.total_iters - self.current_iter - 1)
L
LielinJiang 已提交
276 277
            eta_str = str(datetime.timedelta(seconds=int(eta)))
            message += f'eta: {eta_str}'
278

L
LielinJiang 已提交
279 280 281 282 283
        # print the message
        self.logger.info(message)

    @property
    def current_learning_rate(self):
L
LielinJiang 已提交
284 285
        for optimizer in self.model.optimizers.values():
            return optimizer.get_lr()
L
LielinJiang 已提交
286

郑启航 已提交
287 288 289 290 291 292 293 294 295 296 297 298 299 300
    def visual(self,
               results_dir,
               visual_results=None,
               step=None,
               is_save_image=False):
        """
        visual the images, use visualdl or directly write to the directory

        Parameters:
            results_dir (str)     --  directory name which contains saved images
            visual_results (dict) --  the results images dict
            step (int)            --  global steps, used in visualdl
            is_save_image (bool)  --  weather write to the directory or visualdl
        """
L
LielinJiang 已提交
301 302 303 304 305
        self.model.compute_visuals()

        if visual_results is None:
            visual_results = self.model.get_current_visuals()

L
LielinJiang 已提交
306 307 308
        min_max = self.cfg.get('min_max', None)
        if min_max is None:
            min_max = (-1., 1.)
309

郑启航 已提交
310 311 312
        image_num = self.cfg.get('image_num', None)
        if (image_num is None) or (not self.enable_visualdl):
            image_num = 1
L
LielinJiang 已提交
313
        for label, image in visual_results.items():
郑启航 已提交
314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329
            image_numpy = tensor2img(image, min_max, image_num)
            if (not is_save_image) and self.enable_visualdl:
                self.vdl_logger.add_image(
                    results_dir + '/' + label,
                    image_numpy,
                    step=step if step else self.global_steps,
                    dataformats="HWC" if image_num == 1 else "NCHW")
            else:
                if self.cfg.is_train:
                    msg = 'epoch%.3d_' % self.current_epoch
                else:
                    msg = ''
                makedirs(os.path.join(self.output_dir, results_dir))
                img_path = os.path.join(self.output_dir, results_dir,
                                        msg + '%s.png' % (label))
                save_image(image_numpy, img_path)
L
LielinJiang 已提交
330 331 332 333

    def save(self, epoch, name='checkpoint', keep=1):
        if self.local_rank != 0:
            return
L
LielinJiang 已提交
334

L
LielinJiang 已提交
335 336 337
        assert name in ['checkpoint', 'weight']

        state_dicts = {}
L
LielinJiang 已提交
338
        save_filename = 'epoch_%s_%s.pdparams' % (epoch, name)
L
LielinJiang 已提交
339
        save_path = os.path.join(self.output_dir, save_filename)
L
LielinJiang 已提交
340 341
        for net_name, net in self.model.nets.items():
            state_dicts[net_name] = net.state_dict()
L
LielinJiang 已提交
342 343 344 345 346 347 348

        if name == 'weight':
            save(state_dicts, save_path)
            return

        state_dicts['epoch'] = epoch

L
LielinJiang 已提交
349 350
        for opt_name, opt in self.model.optimizers.items():
            state_dicts[opt_name] = opt.state_dict()
L
LielinJiang 已提交
351 352 353 354 355

        save(state_dicts, save_path)

        if keep > 0:
            try:
L
LielinJiang 已提交
356
                checkpoint_name_to_be_removed = os.path.join(
L
LielinJiang 已提交
357 358
                    self.output_dir,
                    'epoch_%s_%s.pdparams' % (epoch - keep, name))
L
LielinJiang 已提交
359 360 361 362 363 364 365 366 367 368
                if os.path.exists(checkpoint_name_to_be_removed):
                    os.remove(checkpoint_name_to_be_removed)

            except Exception as e:
                self.logger.info('remove old checkpoints error: {}'.format(e))

    def resume(self, checkpoint_path):
        state_dicts = load(checkpoint_path)
        if state_dicts.get('epoch', None) is not None:
            self.start_epoch = state_dicts['epoch'] + 1
郑启航 已提交
369
            self.global_steps = self.steps_per_epoch * state_dicts['epoch']
L
LielinJiang 已提交
370

L
LielinJiang 已提交
371
        for net_name, net in self.model.nets.items():
372
            net.set_state_dict(state_dicts[net_name])
L
LielinJiang 已提交
373

L
LielinJiang 已提交
374
        for opt_name, opt in self.model.optimizers.items():
375
            opt.set_state_dict(state_dicts[opt_name])
L
LielinJiang 已提交
376 377 378

    def load(self, weight_path):
        state_dicts = load(weight_path)
L
LielinJiang 已提交
379

L
LielinJiang 已提交
380
        for net_name, net in self.model.nets.items():
381 382 383 384 385 386 387 388
            if net_name in state_dicts:
                net.set_state_dict(state_dicts[net_name])
                self.logger.info(
                    'Loaded pretrained weight for net {}'.format(net_name))
            else:
                self.logger.warning(
                    'Can not find state dict of net {}. Skip load pretrained weight for net {}'
                    .format(net_name, net_name))
郑启航 已提交
389 390 391 392 393 394 395 396

    def close(self):
        """
        when finish the training need close file handler or other.

        """
        if self.enable_visualdl:
            self.vdl_logger.close()