# coding:utf-8
# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle
import time
from collections import defaultdict
from typing import Any, Callable, Generic, List

import paddle
from paddle.distributed import ParallelEnv
from visualdl import LogWriter

from paddlehub.utils.log import logger
from paddlehub.utils.utils import Timer


class Trainer(object):
    '''
    Model trainer

    Args:
        model(paddle.nn.Layer) : Model to train or evaluate. It must implement `training_step` (and
            `validation_step` if evaluation is used); the trainer calls them with `(batch, batch_idx)`.
        strategy(paddle.optimizer.Optimizer) : Optimizer strategy.
        use_vdl(bool) : Whether to use visualdl to record training data.
        checkpoint_dir(str) : Directory where the checkpoint is saved, and the trainer will restore the
            state and model parameters from the checkpoint. Defaults to `ckpt_<timestamp>`.
        compare_metrics(callable) : The method of comparing the model metrics. If not specified, the main
            metric return by `validation_step` will be used for comparison by default, the larger the
            value, the better the effect. This method will affect the saving of the best model. If the
            default behavior does not meet your requirements, please pass in a custom method.

            Example:
                .. code-block:: python

                    def compare_metrics(old_metric: dict, new_metric: dict):
                        mainkey = list(new_metric.keys())[0]
                        return old_metric[mainkey] < new_metric[mainkey]

        use_gpu(bool) : Whether to run on the GPU selected by `ParallelEnv().dev_id` (default True);
            if False, run on the CPU.
    '''

    def __init__(self,
                 model: paddle.nn.Layer,
                 strategy: paddle.optimizer.Optimizer,
                 use_vdl: bool = True,
                 checkpoint_dir: str = None,
                 compare_metrics: Callable = None,
                 use_gpu: bool = True):
        self.nranks = ParallelEnv().nranks
        self.local_rank = ParallelEnv().local_rank
        self.model = model
        self.optimizer = strategy
        self.use_gpu = use_gpu
        self.checkpoint_dir = checkpoint_dir if checkpoint_dir else 'ckpt_{}'.format(time.time())

        # Only rank 0 writes checkpoints / logs, so only rank 0 creates the directory.
        if self.local_rank == 0 and not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        self.use_vdl = use_vdl
        if self.local_rank == 0 and self.use_vdl:
            vdl_dir = os.path.join(self.checkpoint_dir, 'visualization')
            self.log_writer = LogWriter(vdl_dir)

        self.current_epoch = 0
        self.best_metrics = defaultdict(int)

        if self.nranks > 1:
            context = paddle.distributed.init_parallel_env()
            self.model = paddle.DataParallel(self.model, context)

        self.compare_metrics = self._compare_metrics if not compare_metrics else compare_metrics

        self._load_checkpoint()

    def _setup_place(self):
        '''Switch to dygraph mode on the configured device and return that device place.'''
        place = paddle.CUDAPlace(ParallelEnv().dev_id) if self.use_gpu else paddle.CPUPlace()
        paddle.disable_static(place)
        return place

    def _load_checkpoint(self):
        '''Restore the newest `epoch_<N>` checkpoint (model parameters and best metrics), if one exists.'''
        max_epoch = -1

        # The directory is created by rank 0 only, so it may be missing on the other ranks.
        if os.path.isdir(self.checkpoint_dir):
            for file in os.listdir(self.checkpoint_dir):
                if not file.startswith('epoch_'):
                    continue

                _epoch = file.split('_')[-1]
                # isdecimal (unlike isdigit) only accepts characters that int() can parse.
                if not _epoch.isdecimal():
                    continue

                max_epoch = max(max_epoch, int(_epoch))

        if max_epoch == -1:
            if self.local_rank == 0:
                logger.warning('PaddleHub model checkpoint not found, start from scratch...')
            return

        # load best metrics
        self._load_metrics()

        self.current_epoch = max_epoch
        metric_msg = ['{}={:.4f}'.format(metric, value) for metric, value in self.best_metrics.items()]
        metric_msg = ' '.join(metric_msg)
        if self.local_rank == 0:
            logger.info('PaddleHub model checkpoint loaded. current_epoch={} [{}]'.format(
                self.current_epoch, metric_msg))

        # load model from checkpoint
        model_path = os.path.join(self.checkpoint_dir, '{}_{}'.format('epoch', self.current_epoch), 'model')
        state_dict, _ = paddle.load(model_path)
        self.model.set_dict(state_dict)

    def _save_checkpoint(self):
        '''Save the model parameters to `<checkpoint_dir>/epoch_<current_epoch>/model`.'''
        model_path = os.path.join(self.checkpoint_dir, '{}_{}'.format('epoch', self.current_epoch), 'model')
        logger.info('Saving model checkpoint to {}'.format(model_path))
        self.save_model(model_path)

    def save_model(self, save_dir: str):
        '''Save the model's state dict to the path `save_dir` with `paddle.save`.'''
        paddle.save(self.model.state_dict(), save_dir)

    def _save_metrics(self):
        '''Persist the best metrics to `<checkpoint_dir>/metrics.pkl`.'''
        with open(os.path.join(self.checkpoint_dir, 'metrics.pkl'), 'wb') as file:
            pickle.dump(self.best_metrics, file)

    def _load_metrics(self):
        '''Load the best metrics from `<checkpoint_dir>/metrics.pkl`; keep the current ones if it is absent.'''
        metrics_path = os.path.join(self.checkpoint_dir, 'metrics.pkl')
        # metrics.pkl is only written when a best model is saved (i.e. evaluation ran), whereas epoch
        # checkpoints are saved regardless, so a run trained without an eval dataset has no metrics file.
        if not os.path.exists(metrics_path):
            return
        # NOTE: pickle can execute arbitrary code on load; only resume from trusted checkpoint directories.
        with open(metrics_path, 'rb') as file:
            self.best_metrics = pickle.load(file)

    def train(self,
              train_dataset: paddle.io.Dataset,
              epochs: int = 1,
              batch_size: int = 1,
              num_workers: int = 0,
              eval_dataset: paddle.io.Dataset = None,
              log_interval: int = 10,
              save_interval: int = 10):
        '''
        Train a model with specific config.

        Args:
            train_dataset(paddle.io.Dataset) : Dataset to train the model
            epochs(int) : Number of training loops, default is 1.
            batch_size(int) : Batch size of per step, default is 1.
            num_workers(int) : Number of subprocess to load data, default is 0.
            eval_dataset(paddle.io.Dataset) : The validation dataset, default is None. If set, the Trainer will
                execute evaluate function every `save_interval` epochs.
            log_interval(int) : Log the train information every `log_interval` steps.
            save_interval(int) : Save the checkpoint every `save_interval` epochs.
        '''
        place = self._setup_place()

        batch_sampler = paddle.io.DistributedBatchSampler(
            train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
        loader = paddle.io.DataLoader(
            train_dataset, batch_sampler=batch_sampler, places=place, num_workers=num_workers, return_list=True)

        steps_per_epoch = len(batch_sampler)
        timer = Timer(steps_per_epoch * epochs)
        timer.start()

        for _ in range(epochs):
            self.current_epoch += 1
            avg_loss = 0
            avg_metrics = defaultdict(int)
            self.model.train()

            for batch_idx, batch in enumerate(loader):
                loss, metrics = self.training_step(batch, batch_idx)
                self.optimizer_step(self.current_epoch, batch_idx, self.optimizer, loss)
                self.optimizer_zero_grad(self.current_epoch, batch_idx, self.optimizer)

                # accumulate loss and metrics; they are averaged over `log_interval` steps when logged
                avg_loss += loss.numpy()[0]
                for metric, value in metrics.items():
                    avg_metrics[metric] += value.numpy()[0]

                timer.count()

                if (batch_idx + 1) % log_interval == 0 and self.local_rank == 0:
                    lr = self.optimizer.get_lr()
                    avg_loss /= log_interval
                    if self.use_vdl:
                        self.log_writer.add_scalar(tag='TRAIN/loss', step=timer.current_step, value=avg_loss)

                    print_msg = 'Epoch={}/{}, Step={}/{}'.format(self.current_epoch, epochs, batch_idx + 1,
                                                                 steps_per_epoch)
                    print_msg += ' loss={:.4f}'.format(avg_loss)

                    for metric, value in avg_metrics.items():
                        value /= log_interval
                        if self.use_vdl:
                            self.log_writer.add_scalar(
                                tag='TRAIN/{}'.format(metric), step=timer.current_step, value=value)
                        print_msg += ' {}={:.4f}'.format(metric, value)

                    print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format(lr, timer.timing, timer.eta)

                    logger.train(print_msg)

                    avg_loss = 0
                    avg_metrics = defaultdict(int)

                # After the last step of every `save_interval`-th epoch, rank 0 evaluates and checkpoints.
                is_last_step = batch_idx + 1 == steps_per_epoch
                if self.current_epoch % save_interval == 0 and is_last_step and self.local_rank == 0:
                    # `is not None` avoids calling Dataset.__len__ just for a truthiness test.
                    if eval_dataset is not None:
                        self._evaluate_and_save_best(eval_dataset, batch_size, num_workers, timer.current_step)
                    self._save_checkpoint()

    def _evaluate_and_save_best(self, eval_dataset: paddle.io.Dataset, batch_size: int, num_workers: int,
                                step: int):
        '''Evaluate, log the results to VisualDL at `step`, and save `best_model` if the metrics improved.'''
        result = self.evaluate(eval_dataset, batch_size, num_workers)
        eval_loss = result.get('loss', None)
        eval_metrics = result.get('metrics', {})
        if self.use_vdl:
            # `is not None` so that a legitimate loss of 0.0 is still recorded.
            if eval_loss is not None:
                self.log_writer.add_scalar(tag='EVAL/loss', step=step, value=eval_loss)

            for metric, value in eval_metrics.items():
                self.log_writer.add_scalar(tag='EVAL/{}'.format(metric), step=step, value=value)

        if not self.best_metrics or self.compare_metrics(self.best_metrics, eval_metrics):
            self.best_metrics = eval_metrics
            best_model_path = os.path.join(self.checkpoint_dir, 'best_model')
            self.save_model(best_model_path)
            self._save_metrics()

            metric_msg = ['{}={:.4f}'.format(metric, value) for metric, value in self.best_metrics.items()]
            metric_msg = ' '.join(metric_msg)
            logger.eval('Saving best model to {} [best {}]'.format(best_model_path, metric_msg))

    def evaluate(self, eval_dataset: paddle.io.Dataset, batch_size: int = 1, num_workers: int = 0):
        '''
        Run evaluation and returns metrics.

        Args:
            eval_dataset(paddle.io.Dataset) : The validation dataset
            batch_size(int) : Batch size of per step, default is 1.
            num_workers(int) : Number of subprocess to load data, default is 0.

        Returns:
            dict : `{'metrics': {...}}`, plus a `'loss'` entry when `validation_step` reports a loss. Values
                are averages weighted by `batch[0].shape[0]` (assumed to be the batch size).
        '''
        place = self._setup_place()

        # During training evaluation runs on rank 0 only, so iterate the whole dataset; a
        # DistributedBatchSampler would hand this process only its 1/nranks shard.
        batch_sampler = paddle.io.BatchSampler(eval_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

        loader = paddle.io.DataLoader(
            eval_dataset, batch_sampler=batch_sampler, places=place, num_workers=num_workers, return_list=True)

        self.model.eval()
        avg_loss = num_samples = 0
        has_loss = False
        sum_metrics = defaultdict(int)
        avg_metrics = defaultdict(int)

        # No gradients are needed for evaluation; skip building the autograd graph.
        with paddle.no_grad():
            for batch_idx, batch in enumerate(loader):
                result = self.validation_step(batch, batch_idx)
                loss = result.get('loss', None)
                metrics = result.get('metrics', {})
                bs = batch[0].shape[0]
                num_samples += bs

                # `is not None`: a zero loss is still a loss.
                if loss is not None:
                    has_loss = True
                    avg_loss += loss.numpy()[0] * bs

                for metric, value in metrics.items():
                    sum_metrics[metric] += value.numpy()[0] * bs

        # print avg metrics and loss; with an empty loader nothing is accumulated, so nothing is divided
        print_msg = '[Evaluation result]'
        if has_loss:
            avg_loss /= num_samples
            print_msg += ' avg_loss={:.4f}'.format(avg_loss)

        for metric, value in sum_metrics.items():
            avg_metrics[metric] = value / num_samples
            print_msg += ' avg_{}={:.4f}'.format(metric, avg_metrics[metric])

        logger.eval(print_msg)

        if has_loss:
            return {'loss': avg_loss, 'metrics': avg_metrics}
        return {'metrics': avg_metrics}

    def training_step(self, batch: List[paddle.Tensor], batch_idx: int):
        '''
        One step for training, which should be called as forward computation.

        Calls the model's `training_step`, then back-propagates the returned loss.

        Args:
            batch(list[paddle.Tensor]) : The one batch data
            batch_idx(int) : The index of batch.

        Returns:
            tuple : `(loss, metrics)` taken from the dict returned by the model's `training_step`.

        Raises:
            RuntimeError : If the model's result is not a dict or has no `loss` entry.
        '''
        # Under DataParallel the user's layer (and its training_step) lives in `_layers`.
        if self.nranks > 1:
            result = self.model._layers.training_step(batch, batch_idx)
        else:
            result = self.model.training_step(batch, batch_idx)

        # process result
        if not isinstance(result, dict):
            raise RuntimeError('`training_step` must return a dict, but got {}.'.format(type(result)))

        loss = result.get('loss', None)
        # `is None` rather than falsiness, so that a legitimate zero loss is accepted.
        if loss is None:
            raise RuntimeError('The dict returned by `training_step` must contain a `loss` entry.')

        metrics = result.get('metrics', {})

        # back prop
        if self.nranks > 1:
            self.model.scale_loss(loss)
            loss.backward()
            self.model.apply_collective_grads()
        else:
            loss.backward()

        return loss, metrics

    def validation_step(self, batch: Any, batch_idx: int):
        '''
        One step for validation, which should be called as forward computation.

        Args:
            batch(list[paddle.Tensor]) : The one batch data
            batch_idx(int) : The index of batch.

        Returns:
            The model's `validation_step` result; `evaluate` reads its optional `loss` and `metrics` entries.
        '''
        if self.nranks > 1:
            result = self.model._layers.validation_step(batch, batch_idx)
        else:
            result = self.model.validation_step(batch, batch_idx)
        return result

    def optimizer_step(self, epoch_idx: int, batch_idx: int, optimizer: paddle.optimizer.Optimizer,
                       loss: paddle.Tensor):
        '''
        One step for optimize.

        Args:
            epoch_idx(int) : The index of epoch.
            batch_idx(int) : The index of batch.
            optimizer(paddle.optimizer.Optimizer) : Optimizer used.
            loss(paddle.Tensor) : Loss tensor.
        '''
        self.optimizer.step()
        # `get_lr()` returns the current learning-rate *value* (a float), so a scheduler could never be
        # detected from it; hand over the optimizer's learning-rate object (float or scheduler) instead.
        learning_rate = getattr(self.optimizer, '_learning_rate', None)
        if learning_rate is None:
            learning_rate = self.optimizer.get_lr()
        self.learning_rate_step(epoch_idx, batch_idx, learning_rate, loss)

    def learning_rate_step(self, epoch_idx: int, batch_idx: int, learning_rate: Any, loss: paddle.Tensor):
        '''
        Advance the learning-rate scheduler, if one is used.

        Args:
            epoch_idx(int) : The index of epoch.
            batch_idx(int) : The index of batch.
            learning_rate(float|paddle.optimizer._LRScheduler) : The optimizer's learning rate; only a
                scheduler is stepped.
            loss(paddle.Tensor) : Loss tensor.
        '''
        if isinstance(learning_rate, paddle.optimizer._LRScheduler):
            learning_rate.step()

    def optimizer_zero_grad(self, epoch_idx: int, batch_idx: int, optimizer: paddle.optimizer.Optimizer):
        '''
        One step for clear gradients.

        Args:
            epoch_idx(int) : The index of epoch.
            batch_idx(int) : The index of batch.
            optimizer(paddle.optimizer.Optimizer) : Optimizer used.
        '''
        self.model.clear_gradients()

    def _compare_metrics(self, old_metric: dict, new_metric: dict):
        '''
        Return True if `new_metric` is better than `old_metric`, judged by the first key of `new_metric`
        (larger is better). Returns False when `new_metric` is empty, since there is nothing to compare.
        '''
        if not new_metric:
            return False
        mainkey = next(iter(new_metric))
        return old_metric[mainkey] < new_metric[mainkey]