trainer.py 13.6 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

L
LielinJiang 已提交
15 16
import os
import time
L
LielinJiang 已提交
17
import copy
L
LielinJiang 已提交
18

L
LielinJiang 已提交
19
import logging
L
LielinJiang 已提交
20
import datetime
L
LielinJiang 已提交
21

L
LielinJiang 已提交
22
import paddle
L
LielinJiang 已提交
23
from paddle.distributed import ParallelEnv
L
LielinJiang 已提交
24 25 26 27

from ..datasets.builder import build_dataloader
from ..models.builder import build_model
from ..utils.visual import tensor2img, save_image
L
LielinJiang 已提交
28
from ..utils.filesystem import makedirs, save, load
29
from ..utils.timer import TimeAverager
L
LielinJiang 已提交
30
from ..metric.psnr_ssim import calculate_psnr, calculate_ssim
L
LielinJiang 已提交
31

L
fix nan  
LielinJiang 已提交
32

L
LielinJiang 已提交
33 34 35 36 37
class Trainer:
    """Epoch-based driver for training, validating and testing a model.

    The trainer builds the train dataloader and the model from ``cfg``,
    runs the optimisation loop, periodically logs and visualises results,
    and saves / restores checkpoints under ``cfg.output_dir``.
    """

    def __init__(self, cfg):
        """Build the dataloader and model and read the run settings.

        Parameters:
            cfg -- configuration object. This method reads
                   ``dataset.train``, ``optimizer``, ``output_dir``,
                   ``epochs``, ``snapshot_config.interval``,
                   ``log_config.interval``, ``log_config.visiual_interval``
                   and the optional ``enable_visualdl`` / ``validate``
                   entries from it.
        """
        # build train dataloader
        self.train_dataloader = build_dataloader(cfg.dataset.train)

        # The scheduler config is told the epoch length before the model is
        # built (presumably build_model creates the scheduler from it --
        # TODO confirm).
        if 'lr_scheduler' in cfg.optimizer:
            cfg.optimizer.lr_scheduler.step_per_epoch = len(
                self.train_dataloader)

        # build model
        self.model = build_model(cfg)
        # multiple gpus prepare
        if ParallelEnv().nranks > 1:
            self.distributed_data_parallel()

        self.logger = logging.getLogger(__name__)
        self.enable_visualdl = cfg.get('enable_visualdl', False)
        if self.enable_visualdl:
            # imported lazily so visualdl is only required when enabled
            import visualdl
            self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir)

        # base config
        self.output_dir = cfg.output_dir
        self.epochs = cfg.epochs
        self.start_epoch = 1
        self.current_epoch = 1
        self.batch_id = 0
        self.global_steps = 0
        self.weight_interval = cfg.snapshot_config.interval
        self.log_interval = cfg.log_config.interval
        # NOTE: 'visiual_interval' is the (misspelled) key the config uses.
        self.visual_interval = cfg.log_config.visiual_interval
        # -1 means "validation disabled"
        self.validate_interval = -1
        if cfg.get('validate', None) is not None:
            self.validate_interval = cfg.validate.get('interval', -1)
        self.cfg = cfg

        self.local_rank = ParallelEnv().local_rank

        # time count
        self.steps_per_epoch = len(self.train_dataloader)
        self.total_steps = self.epochs * self.steps_per_epoch

        self.time_count = {}
        self.best_metric = {}

    def distributed_data_parallel(self):
        """Wrap every network of the model in ``paddle.DataParallel``."""
        strategy = paddle.distributed.prepare_context()
        for net_name, net in self.model.nets.items():
            self.model.nets[net_name] = paddle.DataParallel(net, strategy)

    def _should_validate(self, epoch):
        """Return True when validation is due at the end of ``epoch``.

        Validation runs on epochs that are a positive multiple of
        ``self.validate_interval``; a non-positive interval disables it
        (and avoids a modulo-by-zero).
        """
        interval = self.validate_interval
        return interval > 0 and epoch % interval == 0

    def train(self):
        """Run the training loop from ``start_epoch`` to ``epochs`` inclusive.

        Per batch: feed the data to the model, optimise, record timings and
        log / visualise at the configured intervals. Per epoch: optionally
        validate, step the learning-rate scheduler and save checkpoints.
        """
        reader_cost_averager = TimeAverager()
        batch_cost_averager = TimeAverager()
        # loop invariant: number of samples accounted per recorded batch
        batch_size = self.cfg.get('batch_size', 1)

        for epoch in range(self.start_epoch, self.epochs + 1):
            self.current_epoch = epoch
            start_time = step_start_time = time.time()
            for i, data in enumerate(self.train_dataloader):
                # time spent waiting for the dataloader
                reader_cost_averager.record(time.time() - step_start_time)

                self.batch_id = i
                # unpack data from dataset and apply preprocessing
                # data input should be dict
                self.model.set_input(data)
                self.model.optimize_parameters()

                batch_cost_averager.record(time.time() - step_start_time,
                                           num_samples=batch_size)
                if i % self.log_interval == 0:
                    self.data_time = reader_cost_averager.get_average()
                    self.step_time = batch_cost_averager.get_average()
                    self.ips = batch_cost_averager.get_ips_average()
                    self.print_log()

                    reader_cost_averager.reset()
                    batch_cost_averager.reset()

                if i % self.visual_interval == 0:
                    self.visual('visual_train')
                self.global_steps += 1
                step_start_time = time.time()

            self.logger.info(
                'train one epoch use time: {:.3f} seconds.'.format(time.time() -
                                                                   start_time))
            # Validate on every `validate_interval`-th epoch. The previous
            # truthiness test on the remainder skipped exactly those epochs.
            if self._should_validate(epoch):
                self.validate()
            self.model.lr_scheduler.step()
            if epoch % self.weight_interval == 0:
                self.save(epoch, 'weight', keep=-1)
            self.save(epoch)

    def _collect_visual_results(self, current_paths, current_visuals):
        """Map '<image basename>_<visual name>' to the per-sample tensor.

        Parameters:
            current_paths (sequence)  --  image paths of the current batch
            current_visuals (dict)    --  visual name -> batched images,
                                          indexed by sample position
        """
        visual_results = {}
        for j, path in enumerate(current_paths):
            basename = os.path.splitext(os.path.basename(path))[0]
            for k, img_tensor in current_visuals.items():
                visual_results['%s_%s' % (basename, k)] = img_tensor[j]
        return visual_results

    def validate(self):
        """Run the model over the validation set and log averaged metrics.

        The dataloader is built on first use from ``cfg.dataset.val``.
        For each sample, every metric among 'psnr' / 'ssim' listed in
        ``cfg.validate.metrics`` is accumulated from the 'output' and 'gt'
        visuals, then divided by the validation dataset size.
        """
        if not hasattr(self, 'val_dataloader'):
            self.val_dataloader = build_dataloader(self.cfg.dataset.val,
                                                   is_train=False)

        metric_fns = (('psnr', calculate_psnr), ('ssim', calculate_ssim))
        metric_result = {}

        for i, data in enumerate(self.val_dataloader):
            self.batch_id = i

            self.model.set_input(data)
            self.model.test()

            current_paths = self.model.get_image_paths()
            current_visuals = self.model.get_current_visuals()
            visual_results = self._collect_visual_results(
                current_paths, current_visuals)

            for j in range(len(current_paths)):
                for metric_name, metric_fn in metric_fns:
                    if metric_name not in self.cfg.validate.metrics:
                        continue
                    value = metric_fn(
                        tensor2img(current_visuals['output'][j], (0., 1.)),
                        tensor2img(current_visuals['gt'][j], (0., 1.)),
                        **getattr(self.cfg.validate.metrics, metric_name))
                    if metric_name in metric_result:
                        metric_result[metric_name] += value
                    else:
                        metric_result[metric_name] = value

            self.visual('visual_val',
                        visual_results=visual_results,
                        step=self.batch_id)

            if i % self.log_interval == 0:
                self.logger.info('val iter: [%d/%d]' %
                                 (i, len(self.val_dataloader)))

        # turn the accumulated sums into per-sample averages
        for metric_name in metric_result:
            metric_result[metric_name] /= len(self.val_dataloader.dataset)

        self.logger.info('Epoch {} validate end: {}'.format(
            self.current_epoch, metric_result))

    def test(self):
        """Run the model over the test set and write the visuals to disk.

        The dataloader is built on first use from ``cfg.dataset.test``.
        """
        if not hasattr(self, 'test_dataloader'):
            self.test_dataloader = build_dataloader(self.cfg.dataset.test,
                                                    is_train=False)

        # data[0]: img, data[1]: img path index
        # test batch size must be 1
        for i, data in enumerate(self.test_dataloader):
            self.batch_id = i

            self.model.set_input(data)
            self.model.test()

            current_paths = self.model.get_image_paths()
            current_visuals = self.model.get_current_visuals()
            visual_results = self._collect_visual_results(
                current_paths, current_visuals)

            self.visual('visual_test',
                        visual_results=visual_results,
                        step=self.batch_id,
                        is_save_image=True)

            if i % self.log_interval == 0:
                self.logger.info('Test iter: [%d/%d]' %
                                 (i, len(self.test_dataloader)))

    def print_log(self):
        """Log one line with epoch, iteration, lr, losses and timing info.

        Losses are also written to VisualDL when it is enabled. Timing
        fields are only included once ``train`` has recorded them.
        """
        losses = self.model.get_current_losses()
        message = 'Epoch: %d, iters: %d ' % (self.current_epoch, self.batch_id)

        # The property yields None when the model has no optimizer; skip the
        # field instead of failing on the float format.
        learning_rate = self.current_learning_rate
        if learning_rate is not None:
            message += '%s: %.6f ' % ('lr', learning_rate)

        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)
            if self.enable_visualdl:
                self.vdl_logger.add_scalar(k, v, step=self.global_steps)

        if hasattr(self, 'step_time'):
            message += 'batch_cost: %.5f sec ' % self.step_time

        if hasattr(self, 'data_time'):
            message += 'reader_cost: %.5f sec ' % self.data_time

        if hasattr(self, 'ips'):
            message += 'ips: %.5f images/s ' % self.ips

        if hasattr(self, 'step_time'):
            # remaining steps times the average step time
            cur_step = self.steps_per_epoch * (self.current_epoch -
                                               1) + self.batch_id
            eta = self.step_time * (self.total_steps - cur_step - 1)
            eta_str = str(datetime.timedelta(seconds=int(eta)))
            message += f'eta: {eta_str}'

        # print the message
        self.logger.info(message)

    @property
    def current_learning_rate(self):
        """Learning rate of the first optimizer, or None if there is none."""
        for optimizer in self.model.optimizers.values():
            return optimizer.get_lr()
        return None

    def visual(self,
               results_dir,
               visual_results=None,
               step=None,
               is_save_image=False):
        """
        visual the images, use visualdl or directly write to the directory

        Parameters:
            results_dir (str)     --  directory name which contains saved images
            visual_results (dict) --  the results images dict; defaults to the
                                      model's current visuals
            step (int)            --  step used in visualdl; defaults to the
                                      global step count when None
            is_save_image (bool)  --  whether write to the directory or visualdl
        """
        self.model.compute_visuals()

        if visual_results is None:
            visual_results = self.model.get_current_visuals()

        min_max = self.cfg.get('min_max', None)
        if min_max is None:
            min_max = (-1., 1.)
        image_num = self.cfg.get('image_num', None)
        if (image_num is None) or (not self.enable_visualdl):
            image_num = 1
        # An explicit step of 0 (first batch) is valid and must not fall
        # back to the global step.
        vdl_step = step if step is not None else self.global_steps
        for label, image in visual_results.items():
            image_numpy = tensor2img(image, min_max, image_num)
            if (not is_save_image) and self.enable_visualdl:
                self.vdl_logger.add_image(
                    results_dir + '/' + label,
                    image_numpy,
                    step=vdl_step,
                    dataformats="HWC" if image_num == 1 else "NCHW")
            else:
                if self.cfg.is_train:
                    msg = 'epoch%.3d_' % self.current_epoch
                else:
                    msg = ''
                makedirs(os.path.join(self.output_dir, results_dir))
                img_path = os.path.join(self.output_dir, results_dir,
                                        msg + '%s.png' % (label))
                save_image(image_numpy, img_path)

    def save(self, epoch, name='checkpoint', keep=1):
        """Save network (and, for checkpoints, optimizer) state to disk.

        Only the process with ``local_rank == 0`` writes anything.

        Parameters:
            epoch (int)  --  epoch number, used in the file name
            name (str)   --  'weight' saves network states only;
                             'checkpoint' also saves the epoch and the
                             optimizer states
            keep (int)   --  when > 0, the checkpoint saved ``keep`` epochs
                             earlier is removed after saving
        """
        if self.local_rank != 0:
            return

        assert name in ['checkpoint', 'weight']

        state_dicts = {}
        save_filename = 'epoch_%s_%s.pdparams' % (epoch, name)
        save_path = os.path.join(self.output_dir, save_filename)
        for net_name, net in self.model.nets.items():
            state_dicts[net_name] = net.state_dict()

        if name == 'weight':
            save(state_dicts, save_path)
            return

        state_dicts['epoch'] = epoch

        for opt_name, opt in self.model.optimizers.items():
            state_dicts[opt_name] = opt.state_dict()

        save(state_dicts, save_path)

        if keep > 0:
            # Best effort: failing to delete an old checkpoint is logged but
            # must not interrupt training.
            try:
                checkpoint_name_to_be_removed = os.path.join(
                    self.output_dir,
                    'epoch_%s_%s.pdparams' % (epoch - keep, name))
                if os.path.exists(checkpoint_name_to_be_removed):
                    os.remove(checkpoint_name_to_be_removed)

            except Exception as e:
                self.logger.info('remove old checkpoints error: {}'.format(e))

    def resume(self, checkpoint_path):
        """Restore networks, optimizers and progress from a checkpoint.

        When the checkpoint stores an 'epoch', training restarts at the
        following epoch and ``global_steps`` is set accordingly.
        """
        state_dicts = load(checkpoint_path)
        if state_dicts.get('epoch', None) is not None:
            self.start_epoch = state_dicts['epoch'] + 1
            self.global_steps = self.steps_per_epoch * state_dicts['epoch']

        for net_name, net in self.model.nets.items():
            net.set_state_dict(state_dicts[net_name])

        for opt_name, opt in self.model.optimizers.items():
            opt.set_state_dict(state_dicts[opt_name])

    def load(self, weight_path):
        """Load network weights only (no optimizer state) from a file."""
        state_dicts = load(weight_path)

        for net_name, net in self.model.nets.items():
            net.set_state_dict(state_dicts[net_name])

    def close(self):
        """
        when finish the training need close file handler or other.

        """
        if self.enable_visualdl:
            self.vdl_logger.close()