# trainer.py
import os
import time
import logging
import paddle
from paddle import ParallelEnv, DataParallel
from ..datasets.builder import build_dataloader
from ..models.builder import build_model
from ..utils.visual import tensor2img, save_image
from ..utils.filesystem import save, load, makedirs


class Trainer:
    """Run training, testing, logging, visualisation and checkpointing.

    The trainer owns a train dataloader and a model, both built from ``cfg``.
    The model is expected to expose ``model_names`` (sub-networks stored as
    ``net<name>`` attributes), ``optimizer_names`` (optimizers stored as
    attributes of the same name) and ``optimizers``.
    """

    def __init__(self, cfg):
        """Build the dataloader and model and cache run settings from ``cfg``.

        Args:
            cfg: configuration object. The attributes read here are
                ``dataset.train``, ``optimizer``, ``output_dir``, ``epochs``,
                ``snapshot_config.interval``, ``log_config.interval`` and
                ``log_config.visiual_interval``.
        """
        # build train dataloader
        self.train_dataloader = build_dataloader(cfg.dataset.train)

        # The scheduler config must know the number of steps per epoch before
        # the model is built from cfg.
        if 'lr_scheduler' in cfg.optimizer:
            cfg.optimizer.lr_scheduler.step_per_epoch = len(
                self.train_dataloader)

        # build model
        self.model = build_model(cfg)

        # multiple gpus prepare
        if ParallelEnv().nranks > 1:
            self.distributed_data_parallel()

        self.logger = logging.getLogger(__name__)

        # base config
        self.output_dir = cfg.output_dir
        self.epochs = cfg.epochs
        self.start_epoch = 0
        self.current_epoch = 0
        self.batch_id = 0
        self.weight_interval = cfg.snapshot_config.interval
        self.log_interval = cfg.log_config.interval
        # NOTE: 'visiual_interval' is the spelling of the config key itself.
        self.visual_interval = cfg.log_config.visiual_interval
        self.cfg = cfg

        self.local_rank = ParallelEnv().local_rank

        # time count
        self.time_count = {}

    def _iter_networks(self):
        """Yield ``(attribute_name, network)`` for each string entry of
        ``self.model.model_names``; non-string entries are skipped."""
        for name in self.model.model_names:
            if isinstance(name, str):
                attr = 'net' + name
                yield attr, getattr(self.model, attr)

    def _iter_optimizers(self):
        """Yield ``(attribute_name, optimizer)`` for each string entry of
        ``self.model.optimizer_names``; non-string entries are skipped."""
        for name in self.model.optimizer_names:
            if isinstance(name, str):
                yield name, getattr(self.model, name)

    def distributed_data_parallel(self):
        """Wrap every sub-network of the model in ``DataParallel``."""
        strategy = paddle.prepare_context()
        for attr, net in self._iter_networks():
            setattr(self.model, attr, DataParallel(net, strategy))

    def train(self):
        """Train from ``self.start_epoch`` up to ``self.epochs``.

        Each step feeds one batch to the model and optimises it. Logging and
        visualisation happen every ``log_interval`` / ``visual_interval``
        steps. After each epoch a rolling checkpoint is saved, plus a
        weight-only snapshot every ``weight_interval`` epochs.
        """
        for epoch in range(self.start_epoch, self.epochs):
            self.current_epoch = epoch
            start_time = step_start_time = time.time()

            for i, data in enumerate(self.train_dataloader):
                # Timestamp taken right after the loader returned a batch.
                data_time = time.time()
                self.batch_id = i

                # unpack data from dataset and apply preprocessing
                # data input should be dict
                self.model.set_input(data)
                self.model.optimize_parameters()

                self.data_time = data_time - step_start_time
                self.step_time = time.time() - step_start_time

                if i % self.log_interval == 0:
                    self.print_log()

                if i % self.visual_interval == 0:
                    self.visual('visual_train')

                step_start_time = time.time()

            self.logger.info('train one epoch time: %s',
                             time.time() - start_time)
            if epoch % self.weight_interval == 0:
                self.save(epoch, 'weight', keep=-1)
            self.save(epoch)

    def test(self):
        """Run the model over the test set and save its visual outputs.

        The test dataloader is built lazily on first use from
        ``self.cfg.dataset.test``.
        """
        if not hasattr(self, 'test_dataloader'):
            self.test_dataloader = build_dataloader(self.cfg.dataset.test,
                                                    is_train=False)

        # data[0]: img, data[1]: img path index
        # test batch size must be 1
        for i, data in enumerate(self.test_dataloader):
            self.batch_id = i

            self.model.set_input(data)
            self.model.test()

            current_paths = self.model.get_image_paths()
            current_visuals = self.model.get_current_visuals()

            # Key each output image as '<input file stem>_<visual label>'.
            visual_results = {}
            for j, path in enumerate(current_paths):
                basename = os.path.splitext(os.path.basename(path))[0]
                for k, img_tensor in current_visuals.items():
                    visual_results['%s_%s' % (basename, k)] = img_tensor[j]

            self.visual('visual_test', visual_results=visual_results)

            if i % self.log_interval == 0:
                self.logger.info('Test iter: [%d/%d]', i,
                                 len(self.test_dataloader))

    def print_log(self):
        """Log epoch, iteration, learning rate, losses and step timings."""
        losses = self.model.get_current_losses()

        parts = [
            'Epoch: %d, iters: %d ' % (self.current_epoch, self.batch_id),
            '%s: %.6f ' % ('lr', self.current_learning_rate),
        ]
        parts.extend('%s: %.3f ' % (k, v) for k, v in losses.items())

        # Timings only exist once train() has completed at least one step.
        if hasattr(self, 'data_time'):
            parts.append('reader cost: %.5fs ' % self.data_time)

        if hasattr(self, 'step_time'):
            parts.append('batch cost: %.5fs' % self.step_time)

        # print the message
        self.logger.info(''.join(parts))

    @property
    def current_learning_rate(self):
        """Learning rate reported by the model's first optimizer."""
        return self.model.optimizers[0].get_lr()

    def visual(self, results_dir, visual_results=None):
        """Save visual results as PNG files under ``output_dir/results_dir``.

        Args:
            results_dir: sub-directory of ``self.output_dir`` to write into.
            visual_results: mapping of label to image tensor. When ``None``,
                the model's current visuals are used.
        """
        self.model.compute_visuals()

        if visual_results is None:
            visual_results = self.model.get_current_visuals()

        # Training images are prefixed with the epoch so that later epochs do
        # not overwrite earlier ones.
        if self.cfg.isTrain:
            msg = 'epoch%.3d_' % self.current_epoch
        else:
            msg = ''

        makedirs(os.path.join(self.output_dir, results_dir))
        for label, image in visual_results.items():
            image_numpy = tensor2img(image)
            img_path = os.path.join(self.output_dir, results_dir,
                                    msg + '%s.png' % (label))
            save_image(image_numpy, img_path)

    def save(self, epoch, name='checkpoint', keep=1):
        """Write model state to ``output_dir/epoch_<epoch>_<name>.pkl``.

        Only the process with ``local_rank == 0`` writes; other ranks return
        immediately.

        Args:
            epoch: epoch number, used in the file name and stored in
                checkpoints.
            name: ``'weight'`` stores only the sub-network state dicts;
                ``'checkpoint'`` additionally stores the epoch and the
                optimizer state dicts.
            keep: for checkpoints, when positive, the checkpoint file of
                epoch ``epoch - keep`` is removed after saving.
        """
        if self.local_rank != 0:
            return

        assert name in ['checkpoint', 'weight']

        save_filename = 'epoch_%s_%s.pkl' % (epoch, name)
        save_path = os.path.join(self.output_dir, save_filename)

        state_dicts = {}
        for attr, net in self._iter_networks():
            state_dicts[attr] = net.state_dict()

        if name == 'weight':
            save(state_dicts, save_path)
            return

        state_dicts['epoch'] = epoch

        for opt_name, opt in self._iter_optimizers():
            state_dicts[opt_name] = opt.state_dict()

        save(state_dicts, save_path)

        if keep > 0:
            # Best-effort cleanup: a failure to remove an old checkpoint is
            # logged and must not interrupt training.
            try:
                checkpoint_name_to_be_removed = os.path.join(
                    self.output_dir,
                    'epoch_%s_%s.pkl' % (epoch - keep, name))
                if os.path.exists(checkpoint_name_to_be_removed):
                    os.remove(checkpoint_name_to_be_removed)

            except Exception as e:
                self.logger.info('remove old checkpoints error: %s', e)

    def resume(self, checkpoint_path):
        """Restore networks, optimizers and the start epoch from a checkpoint.

        Args:
            checkpoint_path: path of a file written by ``save`` with
                ``name='checkpoint'``.
        """
        state_dicts = load(checkpoint_path)

        # Continue with the epoch after the one stored in the checkpoint.
        saved_epoch = state_dicts.get('epoch')
        if saved_epoch is not None:
            self.start_epoch = saved_epoch + 1

        for attr, net in self._iter_networks():
            net.set_dict(state_dicts[attr])

        for opt_name, opt in self._iter_optimizers():
            opt.set_dict(state_dicts[opt_name])

    def load(self, weight_path):
        """Load only the sub-network weights from ``weight_path``.

        Args:
            weight_path: path of a file holding a ``net<name>`` entry for
                every sub-network of the model.
        """
        state_dicts = load(weight_path)

        for attr, net in self._iter_networks():
            net.set_dict(state_dicts[attr])