# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
from collections import OrderedDict
from contextlib import contextmanager
from pathlib import Path

import paddle
from paddle import distributed as dist
from tensorboardX import SummaryWriter

from deepspeech.training.reporter import ObsScope
from deepspeech.training.reporter import report
from deepspeech.training.timer import Timer
from deepspeech.utils import mp_tools
from deepspeech.utils import profiler
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
from deepspeech.utils.utility import all_version
from deepspeech.utils.utility import seed_all
from deepspeech.utils.utility import UpdateConfig

__all__ = ["Trainer"]

logger = Log(__name__).getlog()


class Trainer():
    """
    An experiment template that structures the training code and takes care
    of saving, loading, logging and visualization. It's intended to be
    flexible and simple.

    So it only handles the output directory (creating the output and
    checkpoint directories, dumping the config in use, and creating the
    visualizer and logger) in a standard way without enforcing any
    input-output protocols on the model and dataloader. It leaves the main
    part for the user to implement (set up the model, criterion and
    optimizer, define a training step, define a validation function and
    customize all the text and visual logs).
    It does not eliminate much boilerplate code; users still have to write
    the forward/backward/update steps manually, but they are free to add
    non-standard behaviors if needed.
    There are some conventions to follow:
    1. The experiment should have ``model``, ``optimizer``, ``train_loader``,
    ``valid_loader``, ``config`` and ``args`` attributes.
    2. The config should have a ``training`` field, which has
    ``valid_interval``, ``save_interval`` and ``max_iteration`` keys. These
    are used as triggers to invoke validation, checkpointing and stopping the
    experiment.
    3. There are four methods, namely ``train_batch``, ``valid``,
    ``setup_model`` and ``setup_dataloader``, that should be implemented.
    Feel free to add/overwrite other methods and standalone functions if you
    need.

    Parameters
    ----------
    config: yacs.config.CfgNode
        The configuration used for the experiment.

    args: argparse.Namespace
        The parsed command line arguments.
    Examples
    --------
    >>> def main_sp(config, args):
    >>>     exp = Trainer(config, args)
    >>>     exp.setup()
    >>>     exp.run()
    >>>
    >>> config = get_cfg_defaults()
    >>> parser = default_argument_parser()
    >>> args = parser.parse_args()
    >>> if args.config:
    >>>     config.merge_from_file(args.config)
    >>> if args.opts:
    >>>     config.merge_from_list(args.opts)
    >>> config.freeze()
    >>>
    >>> if args.nprocs > 0:
    >>>     dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
    >>> else:
    >>>     main_sp(config, args)
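
    A minimal subclass sketch (illustrative only; the method bodies depend
    on the concrete model and data pipeline):
    >>> class MyTrainer(Trainer):
    >>>     def setup_model(self):
    >>>         ...  # build self.model, self.optimizer, self.lr_scheduler
    >>>     def setup_dataloader(self):
    >>>         ...  # build self.train_loader and self.valid_loader
    >>>     def train_batch(self, batch_index, batch, msg):
    >>>         ...  # forward/backward/update for one batch
    >>>     @paddle.no_grad()
    >>>     def valid(self):
    >>>         ...  # return (total_loss, num_seen_utts)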
    """

    def __init__(self, config, args):
        self.config = config
        self.args = args
        self.optimizer = None
        self.visualizer = None
        self.output_dir = None
        self.checkpoint_dir = None
        self.iteration = 0
        self.epoch = 0
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        self._train = True

        # print dependency versions
        all_version()
        logger.info(f"Rank: {self.rank}/{self.world_size}")

        # set device
        paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
        if self.parallel:
            self.init_parallel()

        self.checkpoint = Checkpoint(
            kbest_n=self.config.training.checkpoint.kbest_n,
            latest_n=self.config.training.checkpoint.latest_n)
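        # kbest_n keeps the best checkpoints by the recorded metric (e.g.
        # val_loss); latest_n keeps the most recent ones.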

        # set random seed if needed
        if args.seed:
            seed_all(args.seed)
            logger.info(f"Set seed {args.seed}")

        # profiler and benchmark options
        if hasattr(self.args,
                   "benchmark_batch_size") and self.args.benchmark_batch_size:
            with UpdateConfig(self.config):
                self.config.collator.batch_size = self.args.benchmark_batch_size
                self.config.training.log_interval = 1
            logger.info(
                f"Benchmark reset batch-size: {self.args.benchmark_batch_size}")

    @property
    def train(self):
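        """bool: whether the trainer is in training mode (the ``eval()``
        context manager temporarily switches it off)."""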
        return self._train

    @contextmanager
    def eval(self):
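        """Context manager that switches the trainer to evaluation mode and
        restores training mode on exit."""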
        self._train = False
        try:
            yield
        finally:
            self._train = True

    def setup(self):
        """Setup the experiment.
        """
        self.setup_output_dir()
        self.dump_config()
        self.setup_visualizer()

        self.setup_dataloader()
        self.setup_model()

        self.iteration = 0
        self.epoch = 0

    @property
    def parallel(self):
        """A flag indicating whether the experiment should run with
        multiprocessing.
        """
        return self.args.nprocs > 1

    def init_parallel(self):
        """Init environment for multiprocess training.
        """
        dist.init_parallel_env()

    @mp_tools.rank_zero_only
    def save(self, tag=None, infos: dict=None):
        """Save checkpoint (model parameters and optimizer states).

        Args:
            tag (int or str, optional): None to tag with the current step,
                otherwise use the given tag, e.g. the epoch. Defaults to None.
            infos (dict, optional): meta data to save. Defaults to None.
        """

        infos = infos if infos else dict()
        infos.update({
            "step": self.iteration,
            "epoch": self.epoch,
            "lr": self.optimizer.get_lr()
        })
        self.checkpoint.save_parameters(self.checkpoint_dir, self.iteration
                                        if tag is None else tag, self.model,
                                        self.optimizer, infos)

    def resume_or_scratch(self):
        """Resume from latest checkpoint at checkpoints in the output
        the output directory, or load a specified checkpoint.

        If ``args.checkpoint_path`` is not None, load the checkpoint, else
        resume training.
        """
        scratch = None
        infos = self.checkpoint.load_latest_parameters(
            self.model,
            self.optimizer,
            checkpoint_dir=self.checkpoint_dir,
            checkpoint_path=self.args.checkpoint_path)
        if infos:
            # just restore ckpt
            # lr will restore from the optimizer ckpt
            self.iteration = infos["step"]
            self.epoch = infos["epoch"]
            scratch = False
            logger.info(
                f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!")
        else:
            self.iteration = 0
            self.epoch = 0
            scratch = True
            logger.info("Init from scratch!")
        return scratch

    def maybe_batch_sampler_step(self):
        """ batch_sampler seed by epoch """
        if hasattr(self.train_loader, "batch_sampler"):
            batch_sampler = self.train_loader.batch_sampler
            if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
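                # set_epoch reseeds the sampler's shuffle so each epoch
                # iterates the data in a different order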
                batch_sampler.set_epoch(self.epoch)

    def before_train(self):
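        """Restore from a checkpoint if one is available, otherwise save the
        initial model, then seed the batch sampler for the coming epoch."""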
        from_scratch = self.resume_or_scratch()
        if from_scratch:
            # scratch: save init model, i.e. 0 epoch
            self.save(tag='init', infos=None)
        else:
            # resume: train next_epoch and next_iteration
            self.epoch += 1
            self.iteration += 1
            logger.info(
                f"Resume train: epoch {self.epoch }, step {self.iteration}!")

        self.maybe_batch_sampler_step()

    def new_epoch(self):
        """Reset the train loader seed and increment `epoch`.
        """
        # `iteration` is incremented inside the train step
        self.epoch += 1
        self.maybe_batch_sampler_step()

    def after_train_batch(self):
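        """Hook run after every training batch; in benchmark mode it stops
        the process once ``args.benchmark_max_step`` is reached."""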
        if getattr(self.args, "benchmark_max_step", None) and self.iteration > self.args.benchmark_max_step:
            profiler.add_profiler_step(self.args.profiler_options)
            logger.info(
                f"Reach benchmark-max-step: {self.args.benchmark_max_step}")
            sys.exit(
                f"Reach benchmark-max-step: {self.args.benchmark_max_step}")

    def do_train(self):
        """The training process control by epoch."""
        self.before_train()

        logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
        while self.epoch < self.config.training.n_epoch:
            with Timer("Epoch-Train Time Cost: {}"):
                self.model.train()
                try:
                    data_start_time = time.time()
                    for batch_index, batch in enumerate(self.train_loader):
                        dataload_time = time.time() - data_start_time
                        msg = "Train:"
                        observation = OrderedDict()
                        with ObsScope(observation):
                            report("Rank", dist.get_rank())
                            report("epoch", self.epoch)
                            report('step', self.iteration)
                            report("lr", self.lr_scheduler())
                            self.train_batch(batch_index, batch, msg)
                            self.after_train_batch()
                            report('iter', batch_index + 1)
                            report('total', len(self.train_loader))
                            report('reader_cost', dataload_time)
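                        # 'step_cost' and 'batch_size' are expected to be
                        # reported from inside train_batch() via report()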
                        observation['batch_cost'] = observation[
                            'reader_cost'] + observation['step_cost']
                        observation['samples'] = observation['batch_size']
                        observation['ips[sent./sec]'] = observation[
                            'batch_size'] / observation['batch_cost']
                        for k, v in observation.items():
                            msg += f" {k}: "
                            msg += f"{v:>.8f}" if isinstance(v,
                                                             float) else f"{v}"
                            msg += ","
                        msg = msg[:-1]  # remove the last ","
                        logger.info(msg)
                        data_start_time = time.time()
                except Exception as e:
                    logger.error(e)
                    raise  # bare raise preserves the original traceback

            with Timer("Eval Time Cost: {}"):
                total_loss, num_seen_utts = self.valid()
                if dist.get_world_size() > 1:
                    num_seen_utts = paddle.to_tensor(num_seen_utts)
                    # the default operator in all_reduce function is sum.
                    dist.all_reduce(num_seen_utts)
                    total_loss = paddle.to_tensor(total_loss)
                    dist.all_reduce(total_loss)
                    cv_loss = total_loss / num_seen_utts
                    cv_loss = float(cv_loss)
                else:
                    cv_loss = total_loss / num_seen_utts

            logger.info(f'Epoch {self.epoch} Val info val_loss {cv_loss}')
            if self.visualizer:
                self.visualizer.add_scalars(
                    'epoch', {'cv_loss': cv_loss,
                              'lr': self.lr_scheduler()}, self.epoch)

            # after epoch
            self.save(tag=self.epoch, infos={'val_loss': cv_loss})
            # step lr every epoch
            self.lr_scheduler.step()
            self.new_epoch()

    def run(self):
        """The routine of the experiment after setup. This method is intended
        to be used by the user.
        """
        try:
            with Timer("Training Done: {}"):
                self.do_train()
        except KeyboardInterrupt:
            sys.exit(-1)
        finally:
            self.destroy()

    def restore(self):
        """Resume from latest checkpoint at checkpoints in the output
        directory or load a specified checkpoint.

        If ``args.checkpoint_path`` is not None, load the checkpoint, else
        resume training.
        """
        assert self.args.checkpoint_path
        infos = self.checkpoint.load_latest_parameters(
            self.model, checkpoint_path=self.args.checkpoint_path)
        return infos

    def run_test(self):
        """Do Test/Decode"""
        try:
            with Timer("Test/Decode Done: {}"):
                with self.eval():
                    self.restore()
                    self.test()
        except KeyboardInterrupt:
            sys.exit(-1)

    def run_export(self):
        """Do Model Export"""
        try:
            with Timer("Export Done: {}"):
                with self.eval():
                    self.restore()
                    self.export()
        except KeyboardInterrupt:
            sys.exit(-1)

    def run_align(self):
        """Do CTC alignment"""
        try:
            with Timer("Align Done: {}"):
                with self.eval():
                    self.restore()
                    self.align()
        except KeyboardInterrupt:
            sys.exit(-1)

    def setup_output_dir(self):
        """Create a directory used for output.
        """
        if self.args.output:
            output_dir = Path(self.args.output).expanduser()
        elif self.args.checkpoint_path:
            output_dir = Path(
                self.args.checkpoint_path).expanduser().parent.parent
        else:
            raise ValueError(
                "one of args.output or args.checkpoint_path is required")
        self.output_dir = output_dir
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.checkpoint_dir = self.output_dir / "checkpoints"
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        self.log_dir = output_dir / "log"
        self.log_dir.mkdir(parents=True, exist_ok=True)

        self.test_dir = output_dir / "test"
        self.test_dir.mkdir(parents=True, exist_ok=True)

        self.decode_dir = output_dir / "decode"
        self.decode_dir.mkdir(parents=True, exist_ok=True)

        self.export_dir = output_dir / "export"
        self.export_dir.mkdir(parents=True, exist_ok=True)

        self.visual_dir = output_dir / "visual"
        self.visual_dir.mkdir(parents=True, exist_ok=True)

        self.config_dir = output_dir / "conf"
        self.config_dir.mkdir(parents=True, exist_ok=True)

    @mp_tools.rank_zero_only
    def destroy(self):
        """Close visualizer to avoid hanging after training"""
        # https://github.com/pytorch/fairseq/issues/2357
        if self.visualizer:
            self.visualizer.close()

    @mp_tools.rank_zero_only
    def setup_visualizer(self):
        """Initialize a visualizer to log the experiment.

        The visual log is saved in the ``visual`` subdirectory of the output
        directory.

        Notes
        -----
        Only the main process owns a visualizer. Using multiple visualizers
        in multiple processes to write to the same log file may cause
        unexpected behaviors.
        """
        # visualizer
        visualizer = SummaryWriter(logdir=str(self.visual_dir))
        self.visualizer = visualizer

    @mp_tools.rank_zero_only
    def dump_config(self):
        """Save the configuration used for this experiment.

        It is saved to ``config.yaml`` in the ``conf`` subdirectory of the
        output directory at the beginning of the experiment.
        """
        config_file = self.config_dir / "config.yaml"
        if self.train and config_file.exists():
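            # rotate the existing config so configs from earlier runs are kept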
            time_stamp = time.strftime("%Y_%m_%d_%H_%M_%S", time.gmtime())
            target_path = self.config_dir / ".".join(
                [time_stamp, "config.yaml"])
            config_file.rename(target_path)

        with open(config_file, 'wt') as f:
            print(self.config, file=f)

    def train_batch(self):
        """The training loop. A subclass should implement this method.
        """
        raise NotImplementedError("train_batch should be implemented.")

    @paddle.no_grad()
    def valid(self):
        """The validation. A subclass should implement this method.
        """
        raise NotImplementedError("valid should be implemented.")

    @paddle.no_grad()
    def test(self):
        """The test. A subclass should implement this method in Tester.
        """
        raise NotImplementedError("test should be implemented.")

    @paddle.no_grad()
    def export(self):
        """The test. A subclass should implement this method in Tester.
        """
        raise NotImplementedError("export should be implemented.")

    @paddle.no_grad()
    def align(self):
        """The align. A subclass should implement this method in Tester.
        """
        raise NotImplementedError("align should be implemented.")

    def setup_model(self):
        """Setup model, criterion and optimizer, etc. A subclass should
        implement this method.
        """
        raise NotImplementedError("setup_model should be implemented.")

    def setup_dataloader(self):
        """Setup training dataloader and validation dataloader. A subclass
        should implement this method.
        """
        raise NotImplementedError("setup_dataloader should be implemented.")