# trainer.py
import os
import time

import logging
import paddle

from paddle.distributed import ParallelEnv

from ..datasets.builder import build_dataloader
from ..models.builder import build_model
from ..utils.visual import tensor2img, save_image
from ..utils.filesystem import save, load, makedirs


class Trainer:
    """Orchestrate training, testing, visualization and checkpointing.

    Builds the train dataloader and the model from ``cfg``. When more than
    one rank is running, wraps the model's networks in
    ``paddle.DataParallel``. Exposes ``train``/``test`` loops plus
    ``save``/``resume``/``load`` helpers for network and optimizer state
    dicts.
    """

    def __init__(self, cfg):
        # build train dataloader
        self.train_dataloader = build_dataloader(cfg.dataset.train)

        # The scheduler config is given the number of iterations per epoch
        # here, before build_model(cfg) below (which presumably reads it).
        if 'lr_scheduler' in cfg.optimizer:
            cfg.optimizer.lr_scheduler.step_per_epoch = len(
                self.train_dataloader)

        # build model
        self.model = build_model(cfg)
        # multiple gpus prepare
        if ParallelEnv().nranks > 1:
            self.distributed_data_parallel()

        self.logger = logging.getLogger(__name__)

        # base config
        self.output_dir = cfg.output_dir
        self.epochs = cfg.epochs
        self.start_epoch = 0
        self.current_epoch = 0
        self.batch_id = 0
        self.weight_interval = cfg.snapshot_config.interval
        self.log_interval = cfg.log_config.interval
        # NOTE: the config key is spelled 'visiual_interval' (sic); it is
        # read verbatim because renaming it would presumably break configs.
        self.visual_interval = cfg.log_config.visiual_interval
        self.cfg = cfg

        self.local_rank = ParallelEnv().local_rank

        # time count (not read anywhere in this class)
        self.time_count = {}

    def distributed_data_parallel(self):
        """Wrap each model network in ``paddle.DataParallel`` in place.

        For every string entry ``name`` in ``self.model.model_names``, the
        attribute ``net<name>`` on the model is replaced by its
        data-parallel wrapper.
        """
        strategy = paddle.prepare_context()
        for name in self.model.model_names:
            if isinstance(name, str):
                net = getattr(self.model, 'net' + name)
                setattr(self.model, 'net' + name,
                        paddle.DataParallel(net, strategy))

    def train(self):
        """Run the training loop from ``start_epoch`` up to ``epochs``.

        Per batch: feed the data to the model, optimize, and record the
        reader and batch timings. Log every ``log_interval`` iterations and
        write visuals every ``visual_interval`` iterations.

        Per epoch: step the model's LR scheduler if it has one. Save a
        ``weight`` snapshot every ``weight_interval`` epochs (never pruned)
        and a rolling ``checkpoint``.
        """
        for epoch in range(self.start_epoch, self.epochs):
            self.current_epoch = epoch
            start_time = step_start_time = time.time()
            for i, data in enumerate(self.train_dataloader):
                # Time at which the dataloader handed over this batch; the
                # gap from step_start_time is the reader cost.
                data_time = time.time()
                self.batch_id = i
                # unpack data from dataset and apply preprocessing
                # data input should be dict
                self.model.set_input(data)
                self.model.optimize_parameters()

                self.data_time = data_time - step_start_time
                self.step_time = time.time() - step_start_time
                if i % self.log_interval == 0:
                    self.print_log()

                if i % self.visual_interval == 0:
                    self.visual('visual_train')

                step_start_time = time.time()
            self.logger.info('train one epoch time: {}'.format(time.time() -
                                                               start_time))
            # The scheduler is optional in the config (see __init__), so do
            # not assume the model exposes one; crashing here would only
            # happen after a whole epoch of work.
            lr_scheduler = getattr(self.model, 'lr_scheduler', None)
            if lr_scheduler is not None:
                lr_scheduler.step()

            if epoch % self.weight_interval == 0:
                self.save(epoch, 'weight', keep=-1)
            self.save(epoch)

    def test(self):
        """Run inference over the test set and write the resulting visuals.

        The test dataloader is built lazily on first use (``is_train=False``).
        Each visual is saved under the key ``<image basename>_<visual name>``.
        """
        if not hasattr(self, 'test_dataloader'):
            self.test_dataloader = build_dataloader(self.cfg.dataset.test,
                                                    is_train=False)

        # data[0]: img, data[1]: img path index
        # test batch size must be 1
        for i, data in enumerate(self.test_dataloader):
            self.batch_id = i

            self.model.set_input(data)
            self.model.test()

            visual_results = {}
            current_paths = self.model.get_image_paths()
            current_visuals = self.model.get_current_visuals()

            for j, path in enumerate(current_paths):
                basename = os.path.splitext(os.path.basename(path))[0]
                for k, img_tensor in current_visuals.items():
                    name = '%s_%s' % (basename, k)
                    visual_results[name] = img_tensor[j]

            self.visual('visual_test', visual_results=visual_results)

            if i % self.log_interval == 0:
                self.logger.info('Test iter: [%d/%d]' %
                                 (i, len(self.test_dataloader)))

    def print_log(self):
        """Log epoch, iteration, learning rate, losses and timing info."""
        losses = self.model.get_current_losses()
        message = 'Epoch: %d, iters: %d ' % (self.current_epoch, self.batch_id)

        message += '%s: %.6f ' % ('lr', self.current_learning_rate)

        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)

        # Timings exist only once train() has processed at least one batch.
        if hasattr(self, 'data_time'):
            message += 'reader cost: %.5fs ' % self.data_time

        if hasattr(self, 'step_time'):
            message += 'batch cost: %.5fs' % self.step_time

        # print the message
        self.logger.info(message)

    @property
    def current_learning_rate(self):
        """Learning rate reported by the model's first optimizer."""
        return self.model.optimizers[0].get_lr()

    def visual(self, results_dir, visual_results=None):
        """Save visuals as PNG files under ``output_dir/results_dir``.

        Args:
            results_dir: sub-directory of ``output_dir`` to write into.
            visual_results: mapping of label -> image tensor. Defaults to
                the model's current visuals.

        When ``cfg.isTrain`` is set, file names get an ``epochNNN_`` prefix.
        """
        self.model.compute_visuals()

        if visual_results is None:
            visual_results = self.model.get_current_visuals()

        if self.cfg.isTrain:
            msg = 'epoch%.3d_' % self.current_epoch
        else:
            msg = ''

        makedirs(os.path.join(self.output_dir, results_dir))
        for label, image in visual_results.items():
            image_numpy = tensor2img(image)
            img_path = os.path.join(self.output_dir, results_dir,
                                    msg + '%s.png' % (label))
            save_image(image_numpy, img_path)

    def save(self, epoch, name='checkpoint', keep=1):
        """Save model state to ``output_dir/epoch_<epoch>_<name>.pkl``.

        Only local rank 0 writes; other ranks return immediately.

        Args:
            epoch: epoch number embedded in the file name (and stored in
                checkpoints).
            name: ``'weight'`` saves network state dicts only;
                ``'checkpoint'`` additionally stores ``epoch`` and every
                optimizer's state dict.
            keep: for checkpoints, when > 0, the checkpoint written for
                ``epoch - keep`` is deleted after saving. Failures to delete
                are logged, not raised.
        """
        if self.local_rank != 0:
            return

        assert name in ['checkpoint', 'weight']

        state_dicts = {}
        save_filename = 'epoch_%s_%s.pkl' % (epoch, name)
        save_path = os.path.join(self.output_dir, save_filename)
        for net_name in self.model.model_names:
            if isinstance(net_name, str):
                net = getattr(self.model, 'net' + net_name)
                state_dicts['net' + net_name] = net.state_dict()

        if name == 'weight':
            save(state_dicts, save_path)
            return

        state_dicts['epoch'] = epoch

        for opt_name in self.model.optimizer_names:
            if isinstance(opt_name, str):
                opt = getattr(self.model, opt_name)
                state_dicts[opt_name] = opt.state_dict()

        save(state_dicts, save_path)

        if keep > 0:
            checkpoint_name_to_be_removed = os.path.join(
                self.output_dir, 'epoch_%s_%s.pkl' % (epoch - keep, name))
            # Pruning is best-effort: a filesystem error must not abort
            # training, but it should be visible in the logs.
            try:
                if os.path.exists(checkpoint_name_to_be_removed):
                    os.remove(checkpoint_name_to_be_removed)
            except OSError as e:
                self.logger.warning(
                    'remove old checkpoints error: {}'.format(e))

    def resume(self, checkpoint_path):
        """Restore networks, optimizers and the start epoch from a checkpoint.

        Training continues from the epoch after the one stored in the
        checkpoint (when an ``epoch`` entry is present).
        """
        state_dicts = load(checkpoint_path)
        if state_dicts.get('epoch', None) is not None:
            self.start_epoch = state_dicts['epoch'] + 1

        for name in self.model.model_names:
            if isinstance(name, str):
                net = getattr(self.model, 'net' + name)
                net.set_dict(state_dicts['net' + name])

        for name in self.model.optimizer_names:
            if isinstance(name, str):
                opt = getattr(self.model, name)
                opt.set_dict(state_dicts[name])

    def load(self, weight_path):
        """Load network weights only (no optimizer state, no epoch)."""
        state_dicts = load(weight_path)

        for name in self.model.model_names:
            if isinstance(name, str):
                net = getattr(self.model, 'net' + name)
                net.set_dict(state_dicts['net' + name])