# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains U2 model."""
import json
import os
import sys
import time
from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional

import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from yacs.config import CfgNode

from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.u2 import U2Model
from deepspeech.training.optimizer import OptimizerFactory
from deepspeech.training.scheduler import LRSchedulerFactory
from deepspeech.training.timer import Timer
from deepspeech.training.trainer import Trainer
from deepspeech.utils import ctc_utils
from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log

logger = Log(__name__).getlog()


class U2Trainer(Trainer):
    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        """Build the default training configuration.

        Args:
            config: Optional node; when given, the defaults are merged into
                it in place through ``merge_from_other_cfg``.

        Returns:
            The newly created node that holds the default training options.
        """
        # general training options
        defaults = CfgNode({
            'n_epoch': 50,  # train epochs
            'log_interval': 100,  # steps
            'accum_grad': 1,  # accum grad by # steps
            'global_grad_clip': 5.0,  # the global norm clip
        })

        # optimizer
        defaults.optim = 'adam'
        defaults.optim_conf = CfgNode({
            'lr': 5e-4,  # learning rate
            'weight_decay': 1e-6,  # the coeff of weight decay
        })

        # learning-rate scheduler
        defaults.scheduler = 'warmuplr'
        defaults.scheduler_conf = CfgNode({
            'warmup_steps': 25000,
            'lr_decay': 1.0,  # learning rate decay
        })

        if config is None:
            return defaults
        config.merge_from_other_cfg(defaults)
        return defaults

    def __init__(self, config, args):
        """Forward ``config`` and ``args`` unchanged to the ``Trainer`` constructor."""
        super().__init__(config, args)

    def train_batch(self, batch_index, batch_data, msg):
        """Run forward/backward on one batch and step the optimizer when due.

        Gradients are accumulated over ``config.training.accum_grad``
        batches: the optimizer, the LR scheduler and ``self.iteration`` only
        advance on the last batch of each accumulation window.

        Args:
            batch_index: Zero-based index of the batch within the epoch.
            batch_data: Tuple ``(utt, audio, audio_len, text, text_len)``.
            msg: Log-line prefix; timing and loss fields are appended to it
                and the line is emitted every ``log_interval`` batches.
        """
        train_conf = self.config.training
        start = time.time()

        # forward
        utt, audio, audio_len, text, text_len = batch_data
        loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
                                                    text_len)

        # The loss is divided by `accum_grad` before backward; the value kept
        # for logging is multiplied back so it stays comparable across
        # different `accum_grad` settings.
        loss /= train_conf.accum_grad
        losses_np = {'loss': float(loss) * train_conf.accum_grad}
        # Truthiness test: a falsy sub-loss (e.g. None) is simply not logged.
        if attention_loss:
            losses_np['att_loss'] = float(attention_loss)
        if ctc_loss:
            losses_np['ctc_loss'] = float(ctc_loss)

        # loss backward
        is_update_step = (batch_index + 1) % train_conf.accum_grad == 0
        if not is_update_step and self.parallel:
            # Disable gradient synchronizations across DDP processes.
            # Within this context, gradients will be accumulated on module
            # variables, which will later be synchronized.
            context = self.model.no_sync
        else:
            # Used for single gpu training, for the DDP synchronization step,
            # and when using cpu w/o DDP (model does not have `no_sync`).
            context = nullcontext
        with context():
            loss.backward()
            layer_tools.print_grads(self.model, print_func=None)

        # optimizer step
        if is_update_step:
            self.optimizer.step()
            self.optimizer.clear_grad()
            self.lr_scheduler.step()
            self.iteration += 1

        iteration_time = time.time() - start

        if (batch_index + 1) % train_conf.log_interval == 0:
            msg += "train time: {:>.3f}s, ".format(iteration_time)
            msg += "batch size: {}, ".format(self.config.collator.batch_size)
            msg += "accum: {}, ".format(train_conf.accum_grad)
            msg += ', '.join('{}: {:>.6f}'.format(k, v)
                             for k, v in losses_np.items())
            logger.info(msg)

            if dist.get_rank() == 0 and self.visualizer:
                losses_np_v = losses_np.copy()
                losses_np_v.update({"lr": self.lr_scheduler()})
                self.visualizer.add_scalars("step", losses_np_v,
                                            self.iteration - 1)

    @paddle.no_grad()
    def valid(self):
        """Evaluate the model on ``self.valid_loader`` without gradients.

        Batches whose loss is not finite are skipped. Running averages are
        logged every ``config.training.log_interval`` batches.

        Returns:
            tuple: ``(total_loss, num_seen_utts)`` where ``total_loss`` is the
            sum of ``loss * batch_size`` over the counted batches and
            ``num_seen_utts`` is the utterance counter (which starts at 1).
        """
        self.model.eval()
        logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
        valid_losses = defaultdict(list)
        # NOTE(review): the counter starts at 1, presumably so the divisions
        # below cannot hit zero when no batch has a finite loss -- confirm.
        num_seen_utts = 1
        total_loss = 0.0
        for i, batch in enumerate(self.valid_loader):
            utt, audio, audio_len, text, text_len = batch
            loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
                                                        text_len)
            if paddle.isfinite(loss):
                # batch[1] is the audio entry; its first dimension is used as
                # the number of utterances in the batch.
                num_utts = batch[1].shape[0]
                num_seen_utts += num_utts
                total_loss += float(loss) * num_utts
                valid_losses['val_loss'].append(float(loss))
                if attention_loss:
                    valid_losses['val_att_loss'].append(float(attention_loss))
                if ctc_loss:
                    valid_losses['val_ctc_loss'].append(float(ctc_loss))

            if (i + 1) % self.config.training.log_interval == 0:
                valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
                valid_dump['val_history_loss'] = total_loss / num_seen_utts

                # logging
                msg = f"Valid: Rank: {dist.get_rank()}, "
                msg += "epoch: {}, ".format(self.epoch)
                msg += "step: {}, ".format(self.iteration)
                msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
                msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                 for k, v in valid_dump.items())
                logger.info(msg)

        logger.info('Rank {} Val info val_loss {}'.format(
            dist.get_rank(), total_loss / num_seen_utts))
        return total_loss, num_seen_utts

    def train(self):
        """Run the epoch loop: train, validate, log and checkpoint.

        Training resumes from a checkpoint when ``resume_or_scratch`` finds
        one; otherwise an ``init`` checkpoint is saved first. After every
        epoch the validation loss is computed (summed across ranks when the
        world size is greater than one), logged, and stored with the epoch
        checkpoint.
        """
        # !!!IMPORTANT!!!
        # Try to export the model by script, if fails, we should refine
        # the code to satisfy the script export requirements
        # script_model = paddle.jit.to_static(self.model)
        # script_model_path = str(self.checkpoint_dir / 'init')
        # paddle.jit.save(script_model, script_model_path)

        from_scratch = self.resume_or_scratch()
        if from_scratch:
            # save init model, i.e. 0 epoch
            self.save(tag='init', infos=None)

        # lr will restore from optimizer ckpt
        # self.lr_scheduler.step(self.iteration)
        if self.parallel and hasattr(self.train_loader, 'batch_sampler'):
            self.train_loader.batch_sampler.set_epoch(self.epoch)

        logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
        while self.epoch < self.config.training.n_epoch:
            with Timer("Epoch-Train Time Cost: {}"):
                self.model.train()
                try:
                    data_start_time = time.time()
                    for batch_index, batch in enumerate(self.train_loader):
                        dataload_time = time.time() - data_start_time
                        msg = "Train: Rank: {}, ".format(dist.get_rank())
                        msg += "epoch: {}, ".format(self.epoch)
                        msg += "step: {}, ".format(self.iteration)
                        msg += "batch : {}/{}, ".format(batch_index + 1,
                                                        len(self.train_loader))
                        msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
                        msg += "data time: {:>.3f}s, ".format(dataload_time)
                        self.train_batch(batch_index, batch, msg)
                        self.after_train_batch()
                        data_start_time = time.time()
                except Exception as e:
                    logger.error(e)
                    # bare raise keeps the original traceback untouched
                    raise

            with Timer("Eval Time Cost: {}"):
                total_loss, num_seen_utts = self.valid()
                if dist.get_world_size() > 1:
                    num_seen_utts = paddle.to_tensor(num_seen_utts)
                    # the default operator in all_reduce function is sum.
                    dist.all_reduce(num_seen_utts)
                    total_loss = paddle.to_tensor(total_loss)
                    dist.all_reduce(total_loss)
                    cv_loss = total_loss / num_seen_utts
                    cv_loss = float(cv_loss)
                else:
                    cv_loss = total_loss / num_seen_utts

            logger.info(
                'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
            if self.visualizer:
                self.visualizer.add_scalars(
                    'epoch', {'cv_loss': cv_loss,
                              'lr': self.lr_scheduler()}, self.epoch)
            self.save(tag=self.epoch, infos={'val_loss': cv_loss})
            self.new_epoch()

    def setup_dataloader(self):
        """Create the train, valid, test and align data loaders.

        Works on a defrosted clone of ``self.config`` that is mutated step by
        step, so the order of the assignments and ``from_config`` calls below
        matters. Sets ``self.train_loader``, ``self.valid_loader``,
        ``self.test_loader`` and ``self.align_loader``.
        """
        config = self.config.clone()
        config.defrost()
        config.collator.keep_transcription_text = False

        # train/valid dataset, return token ids
        config.data.manifest = config.data.train_manifest
        train_dataset = ManifestDataset.from_config(config)

        config.data.manifest = config.data.dev_manifest
        dev_dataset = ManifestDataset.from_config(config)

        collate_fn_train = SpeechCollator.from_config(config)

        # the dev collator is built with an empty augmentation config
        config.collator.augmentation_config = ""
        collate_fn_dev = SpeechCollator.from_config(config)

        if self.parallel:
            batch_sampler = SortagradDistributedBatchSampler(
                train_dataset,
                batch_size=config.collator.batch_size,
                num_replicas=None,
                rank=None,
                shuffle=True,
                drop_last=True,
                sortagrad=config.collator.sortagrad,
                shuffle_method=config.collator.shuffle_method)
        else:
            batch_sampler = SortagradBatchSampler(
                train_dataset,
                shuffle=True,
                batch_size=config.collator.batch_size,
                drop_last=True,
                sortagrad=config.collator.sortagrad,
                shuffle_method=config.collator.shuffle_method)
        self.train_loader = DataLoader(
            train_dataset,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn_train,
            num_workers=config.collator.num_workers, )
        self.valid_loader = DataLoader(
            dev_dataset,
            batch_size=config.collator.batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=collate_fn_dev)

        # test dataset, return raw text
        config.data.manifest = config.data.test_manifest
        # Open every length/ratio bound for the test set -- presumably so that
        # no test example is filtered out (TODO confirm against
        # ManifestDataset).
        config.data.min_input_len = 0.0  # second
        config.data.max_input_len = float('inf')  # second
        config.data.min_output_len = 0.0  # tokens
        config.data.max_output_len = float('inf')  # tokens
        config.data.min_output_input_ratio = 0.00
        config.data.max_output_input_ratio = float('inf')

        test_dataset = ManifestDataset.from_config(config)
        # return text ord id
        config.collator.keep_transcription_text = True
        config.collator.augmentation_config = ""
        self.test_loader = DataLoader(
            test_dataset,
            batch_size=config.decoding.batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=SpeechCollator.from_config(config))
        # return text token id
        config.collator.keep_transcription_text = False
        self.align_loader = DataLoader(
            test_dataset,
            batch_size=config.decoding.batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=SpeechCollator.from_config(config))
        logger.info("Setup train/valid/test/align Dataloader!")

    def setup_model(self):
        """Build the model, LR scheduler and optimizer.

        Reads the feature and vocabulary sizes from the train loader's
        collator, so ``setup_dataloader`` must have run first. Sets
        ``self.model``, ``self.optimizer`` and ``self.lr_scheduler``.
        """
        config = self.config
        model_conf = config.model
        model_conf.defrost()
        model_conf.input_dim = self.train_loader.collate_fn.feature_size
        model_conf.output_dim = self.train_loader.collate_fn.vocab_size
        model_conf.freeze()
        model = U2Model.from_config(model_conf)

        if self.parallel:
            model = paddle.DataParallel(model)

        logger.info(f"{model}")
        layer_tools.print_params(model, logger.info)

        train_config = config.training
        optim_type = train_config.optim
        optim_conf = train_config.optim_conf
        scheduler_type = train_config.scheduler
        scheduler_conf = train_config.scheduler_conf

        scheduler_args = {
            "learning_rate": optim_conf.lr,
            "verbose": False,
            "warmup_steps": scheduler_conf.warmup_steps,
            "gamma": scheduler_conf.lr_decay,
            "d_model": model_conf.encoder_conf.output_size,
        }
        lr_scheduler = LRSchedulerFactory.from_args(scheduler_type,
                                                    scheduler_args)

        def optimizer_args(
                config,
                parameters,
                lr_scheduler=None, ):
            """Collect the keyword arguments handed to the optimizer factory."""
            train_config = config.training
            optim_type = train_config.optim
            optim_conf = train_config.optim_conf
            is_noam = optim_type == 'noam'
            return {
                "grad_clip": train_config.global_grad_clip,
                "weight_decay": optim_conf.weight_decay,
                "learning_rate": lr_scheduler
                if lr_scheduler else optim_conf.lr,
                "parameters": parameters,
                "epsilon": 1e-9 if is_noam else None,
                "beta1": 0.9 if is_noam else None,
                # This key used to be misspelled "beat2", so the value could
                # never reach an optimizer's `beta2` argument.
                "beta2": 0.98 if is_noam else None,
            }

        optimizer_kwargs = optimizer_args(config, model.parameters(),
                                          lr_scheduler)
        optimizer = OptimizerFactory.from_args(optim_type, optimizer_kwargs)

        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        logger.info("Setup model/optimizer/lr_scheduler!")


class U2Tester(U2Trainer):
    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        """Build the default decoding configuration.

        Args:
            config: Optional node; when given, the defaults are merged into
                it in place through ``merge_from_other_cfg``.

        Returns:
            The newly created node that holds the default decoding options.
        """
        # decoding config
        default = CfgNode(
            dict(
                alpha=2.5,  # Coef of LM for beam search.
                beta=0.3,  # Coef of WC for beam search.
                cutoff_prob=1.0,  # Cutoff probability for pruning.
                cutoff_top_n=40,  # Cutoff number for pruning.
                lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm',  # Filepath for language model.
                decoding_method='attention',  # Decoding method. Options: 'attention', 'ctc_greedy_search',
                # 'ctc_prefix_beam_search', 'attention_rescoring'
                error_rate_type='wer',  # Error rate type for evaluation. Options `wer`, 'cer'
                num_proc_bsearch=8,  # # of CPUs for beam search.
                beam_size=10,  # Beam search width.
                batch_size=16,  # decoding batch size
                ctc_weight=0.0,  # ctc weight for attention rescoring decode mode.
                decoding_chunk_size=-1,  # decoding chunk size. Defaults to -1.
                # <0: for decoding, use full chunk.
                # >0: for decoding, use fixed chunk size as set.
                # 0: used for training, it's prohibited here.
                num_decoding_left_chunks=-1,  # number of left chunks for decoding. Defaults to -1.
                simulate_streaming=False,  # simulate streaming inference. Defaults to False.
            ))

        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    def __init__(self, config, args):
        """Forward ``config`` and ``args`` unchanged to the ``U2Trainer`` constructor."""
        super().__init__(config, args)

    def ordid2token(self, texts, texts_len):
        """Turn batches of ``ord()`` ids back into strings with ``chr()``.

        Args:
            texts: Iterable of id sequences, one per utterance.
            texts_len: Iterable of length holders; each must support
                ``.numpy().item()`` and gives the number of valid leading ids
                of the matching entry in ``texts``.

        Returns:
            list[str]: One decoded string per utterance.
        """
        # Lengths are unwrapped lazily, one per zipped pair.
        lengths = (length.numpy().item() for length in texts_len)
        return [
            ''.join(map(chr, ids[:count])) for ids, count in zip(texts, lengths)
        ]

    def compute_metrics(self,
                        utts,
                        audio,
                        audio_len,
                        texts,
                        texts_len,
                        fout=None):
        """Decode one batch and score it against the reference transcripts.

        Args:
            utts: Utterance ids, one per example.
            audio: Batched audio features passed to ``self.model.decode``.
            audio_len: Per-example audio lengths; their sum is reported as
                ``num_frames``.
            texts: Reference transcripts encoded as ``ord()`` ids.
            texts_len: Per-example reference lengths.
            fout: Optional writable text file; when truthy, one
                ``"<utt> <hypothesis>"`` line is written per example.

        Returns:
            dict: ``errors_sum``, ``len_refs``, ``num_ins``, ``error_rate``,
            ``error_rate_type``, ``num_frames`` and ``decode_time`` (seconds,
            measured with ``time.time()``).

        Raises:
            ZeroDivisionError: If the total reference length is zero.
        """
        cfg = self.config.decoding
        errors_sum, len_refs, num_ins = 0.0, 0, 0
        if cfg.error_rate_type == 'cer':
            errors_func = error_rate.char_errors
            error_rate_func = error_rate.cer
        else:
            errors_func = error_rate.word_errors
            error_rate_func = error_rate.wer

        start_time = time.time()
        text_feature = self.test_loader.collate_fn.text_feature
        target_transcripts = self.ordid2token(texts, texts_len)
        result_transcripts = self.model.decode(
            audio,
            audio_len,
            text_feature=text_feature,
            decoding_method=cfg.decoding_method,
            lang_model_path=cfg.lang_model_path,
            beam_alpha=cfg.alpha,
            beam_beta=cfg.beta,
            beam_size=cfg.beam_size,
            cutoff_prob=cfg.cutoff_prob,
            cutoff_top_n=cfg.cutoff_top_n,
            num_processes=cfg.num_proc_bsearch,
            ctc_weight=cfg.ctc_weight,
            decoding_chunk_size=cfg.decoding_chunk_size,
            num_decoding_left_chunks=cfg.num_decoding_left_chunks,
            simulate_streaming=cfg.simulate_streaming)
        decode_time = time.time() - start_time

        for utt, target, result in zip(utts, target_transcripts,
                                       result_transcripts):
            errors, len_ref = errors_func(target, result)
            errors_sum += errors
            len_refs += len_ref
            num_ins += 1
            if fout:
                fout.write(utt + " " + result + "\n")
            logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
                        (target, result))
            logger.info("One example error rate [%s] = %f" %
                        (cfg.error_rate_type, error_rate_func(target, result)))

        return dict(
            errors_sum=errors_sum,
            len_refs=len_refs,
            num_ins=num_ins,  # num examples
            error_rate=errors_sum / len_refs,
            error_rate_type=cfg.error_rate_type,
            num_frames=audio_len.sum().numpy().item(),
            decode_time=decode_time)

    @mp_tools.rank_zero_only
    @paddle.no_grad()
    def test(self):
        """Decode the test set, write hypotheses and a JSON summary.

        Hypotheses go to ``self.args.result_file`` (one line per utterance,
        written by ``compute_metrics``). A one-line JSON summary goes to the
        same path with its extension replaced by ``.err``.

        The real-time factor is decode time divided by audio duration, both
        in seconds. Decode time comes from ``time.time()`` and is already in
        seconds; audio duration is ``num_frames * stride_ms`` and is in
        milliseconds, so it is divided by 1000 before the ratio is taken.

        Raises:
            AssertionError: If ``self.args.result_file`` is not set.
        """
        assert self.args.result_file
        self.model.eval()
        logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")

        stride_ms = self.test_loader.collate_fn.stride_ms
        error_rate_type = None
        errors_sum, len_refs, num_ins = 0.0, 0, 0
        num_frames = 0.0
        num_time = 0.0  # accumulated decode time, seconds
        with open(self.args.result_file, 'w') as fout:
            for batch in self.test_loader:
                metrics = self.compute_metrics(*batch, fout=fout)
                num_frames += metrics['num_frames']
                num_time += metrics["decode_time"]
                errors_sum += metrics['errors_sum']
                len_refs += metrics['len_refs']
                num_ins += metrics['num_ins']
                error_rate_type = metrics['error_rate_type']
                audio_seconds = (num_frames * stride_ms) / 1000.0
                rtf = num_time / audio_seconds
                logger.info(
                    "RTF: %f, Error rate [%s] (%d/?) = %f" %
                    (rtf, error_rate_type, num_ins, errors_sum / len_refs))

        audio_seconds = (num_frames * stride_ms) / 1000.0
        rtf = num_time / audio_seconds
        msg = "Test: "
        msg += "epoch: {}, ".format(self.epoch)
        msg += "step: {}, ".format(self.iteration)
        msg += "RTF: {}, ".format(rtf)
        msg += "Final error rate [%s] (%d/%d) = %f" % (
            error_rate_type, num_ins, num_ins, errors_sum / len_refs)
        logger.info(msg)

        # test meta results
        err_meta_path = os.path.splitext(self.args.result_file)[0] + '.err'
        with open(err_meta_path, 'w') as f:
            data = json.dumps({
                "epoch": self.epoch,
                "step": self.iteration,
                "rtf": rtf,
                # the error-rate value is stored under its type name
                error_rate_type: errors_sum / len_refs,
                "dataset_hour": audio_seconds / 3600.0,
                # num_time is already in seconds
                "process_hour": num_time / 3600.0,
                "num_examples": num_ins,
                "err_sum": errors_sum,
                "ref_len": len_refs,
                "decode_method": self.config.decoding.decoding_method,
            })
            f.write(data + '\n')

    def run_test(self):
        """Restore a checkpoint (if any) and run ``test``.

        A ``KeyboardInterrupt`` raised while testing ends the process with
        exit status -1.
        """
        self.resume_or_scratch()
        try:
            self.test()
        except KeyboardInterrupt:
            # same effect as sys.exit(-1)
            raise SystemExit(-1)

    @paddle.no_grad()
    def align(self):
        """Force-align every utterance of ``self.align_loader`` with CTC.

        For each utterance the frame-level alignment is appended to
        ``self.args.result_file``, and a ``<key>.tier`` and a
        ``<key>.TextGrid`` file are written into an ``align`` directory next
        to the result file. The directory is created if it does not exist.

        Exits the process with status 1 when the decoding batch size is
        greater than one.

        Raises:
            AssertionError: If ``self.args.result_file`` is unset or does not
                end with ``.align``.
        """
        if self.config.decoding.batch_size > 1:
            logger.fatal('alignment mode must be running with batch_size == 1')
            sys.exit(1)

        # xxx.align
        assert self.args.result_file and self.args.result_file.endswith(
            '.align')

        self.model.eval()
        logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")

        stride_ms = self.align_loader.collate_fn.stride_ms
        token_dict = self.align_loader.collate_fn.vocab_list
        # stride is in milliseconds; 25ms window, 10ms stride
        second_per_frame = 1. / (1000. / stride_ms)

        # The per-utterance files are opened for writing below, so the
        # directory has to exist first.
        align_output_path = os.path.join(
            os.path.dirname(self.args.result_file), "align")
        os.makedirs(align_output_path, exist_ok=True)

        with open(self.args.result_file, 'w') as fout:
            # one example in batch
            for batch in self.align_loader:
                key, feat, feats_length, target, target_length = batch

                # 1. Encoder
                encoder_out, encoder_mask = self.model._forward_encoder(
                    feat, feats_length)  # (B, maxlen, encoder_dim)
                ctc_probs = self.model.ctc.log_softmax(
                    encoder_out)  # (1, maxlen, vocab_size)

                # 2. alignment
                ctc_probs = ctc_probs.squeeze(0)
                target = target.squeeze(0)
                alignment = ctc_utils.forced_align(ctc_probs, target)
                # Values are interpolated into the message: passing them as
                # extra logger arguments without placeholders drops them.
                logger.info(f"align ids: {key[0]} {alignment}")
                fout.write('{} {}\n'.format(key[0], alignment))

                # 3. gen praat
                # segment alignment
                align_segs = text_grid.segment_alignment(alignment)
                logger.info(f"align tokens: {key[0]} {align_segs}")
                # IntervalTier, List["start end token\n"]
                subsample = utility.get_subsample(self.config)
                tierformat = text_grid.align_to_tierformat(
                    align_segs, subsample, token_dict)

                # write tier
                tier_path = os.path.join(align_output_path, key[0] + ".tier")
                with open(tier_path, 'w') as f:
                    f.writelines(tierformat)

                # write textgrid
                textgrid_path = os.path.join(align_output_path,
                                             key[0] + ".TextGrid")
                second_per_example = (
                    len(alignment) + 1) * subsample * second_per_frame
                text_grid.generate_textgrid(
                    maxtime=second_per_example,
                    intervals=tierformat,
                    output=textgrid_path)

    def run_align(self):
        """Restore a checkpoint (if any) and run ``align``.

        A ``KeyboardInterrupt`` raised while aligning ends the process with
        exit status -1.
        """
        self.resume_or_scratch()
        try:
            self.align()
        except KeyboardInterrupt:
            # same effect as sys.exit(-1)
            raise SystemExit(-1)

    def load_inferspec(self):
        """infer model and input spec.

        The model is loaded with ``U2InferModel.from_pretrained`` from
        ``self.args.checkpoint_path``; the feature dimension of the input
        spec comes from the test loader's collator.

        Returns:
            nn.Layer: inference model
            List[paddle.static.InputSpec]: input spec.
        """
        from deepspeech.models.u2 import U2InferModel
        infer_model = U2InferModel.from_pretrained(self.test_loader,
                                                   self.config.model.clone(),
                                                   self.args.checkpoint_path)
        feat_dim = self.test_loader.collate_fn.feature_size
        input_spec = [
            paddle.static.InputSpec(shape=[1, None, feat_dim],
                                    dtype='float32'),  # audio, [B,T,D]
            paddle.static.InputSpec(shape=[1],
                                    dtype='int64'),  # audio_length, [B]
        ]
        return infer_model, input_spec

    def export(self):
        """Convert the inference model to a static graph and save it.

        The model and its input spec come from ``load_inferspec``; the result
        is written to ``self.args.export_path``.
        """
        model, spec = self.load_inferspec()
        assert isinstance(spec, list), type(spec)
        model.eval()
        exported = paddle.jit.to_static(model, input_spec=spec)
        logger.info(f"Export code: {exported.forward.code}")
        paddle.jit.save(exported, self.args.export_path)

    def run_export(self):
        """Run ``export``.

        A ``KeyboardInterrupt`` raised while exporting ends the process with
        exit status -1.
        """
        try:
            self.export()
        except KeyboardInterrupt:
            # same effect as sys.exit(-1)
            raise SystemExit(-1)

    def setup(self):
        """Setup the experiment.

        Selects the device from ``self.args.device``, prepares the output
        directory and the checkpointer, builds the data loaders and the
        model, and resets the ``iteration`` and ``epoch`` counters to 0.
        """
        paddle.set_device(self.args.device)

        self.setup_output_dir()
        self.setup_checkpointer()

        # The data loaders must exist before the model is built: setup_model
        # reads the feature and vocabulary sizes from the train loader.
        self.setup_dataloader()
        self.setup_model()

        self.iteration = 0
        self.epoch = 0

    def setup_output_dir(self):
        """Create a directory used for output.

        Uses ``self.args.output`` when it is set; otherwise the directory two
        levels above ``self.args.checkpoint_path``. The chosen path is
        created (with parents) and stored in ``self.output_dir``.
        """
        if self.args.output:
            target = Path(self.args.output).expanduser()
        else:
            checkpoint = Path(self.args.checkpoint_path).expanduser()
            target = checkpoint.parent.parent

        target.mkdir(parents=True, exist_ok=True)
        self.output_dir = target