# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains DeepSpeech2 and DeepSpeech2Online model."""
import os
import time
from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional

import jsonlines
import numpy as np
import paddle
from paddle import distributed as dist
from paddle import inference
from paddle.io import DataLoader
from yacs.config import CfgNode

from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.ds2 import DeepSpeech2InferModel
from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.reporter import report
from deepspeech.training.trainer import Trainer
from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools
from deepspeech.utils.log import Autolog
from deepspeech.utils.log import Log
from deepspeech.utils.utility import UpdateConfig

logger = Log(__name__).getlog()


class DeepSpeech2Trainer(Trainer):
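    """Trainer for the DeepSpeech2 offline and online models.

    Builds the manifest-based dataloaders, the DS2 model, and an Adam
    optimizer with exponential LR decay, and runs the per-batch training
    step with gradient accumulation.

    A minimal driver sketch (assuming the base `Trainer` exposes the usual
    `run()` loop; the real entry points are the exps/bin scripts):

        exp = DeepSpeech2Trainer(config, args)
        exp.setup()
        exp.run()
    """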
    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        # training config
        default = CfgNode(
            dict(
                lr=5e-4,  # learning rate
                lr_decay=1.0,  # learning rate decay
                weight_decay=1e-6,  # the coeff of weight decay
                global_grad_clip=5.0,  # the global norm clip
                n_epoch=50,  # train epochs
            ))

        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    def __init__(self, config, args):
        super().__init__(config, args)

    def train_batch(self, batch_index, batch_data, msg):
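        """Run one training step on a batch.

        Backpropagates every batch but only steps the optimizer every
        `accum_grad` batches, skipping DDP gradient sync (model.no_sync)
        on accumulation-only batches.
        """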
        batch_size = self.config.collator.batch_size
        accum_grad = self.config.training.accum_grad

        start = time.time()

        # forward
        utt, audio, audio_len, text, text_len = batch_data
        loss = self.model(audio, audio_len, text, text_len)
        losses_np = {
            'train_loss': float(loss),
        }

        # loss backward
        if (batch_index + 1) % accum_grad != 0:
            # Disable gradient synchronizations across DDP processes.
            # Within this context, gradients will be accumulated on module
            # variables, which will later be synchronized.
            context = self.model.no_sync
        else:
            # Used for single gpu training and DDP gradient synchronization
            # processes.
            context = nullcontext

        with context():
            loss.backward()
            layer_tools.print_grads(self.model, print_func=None)

        # optimizer step
        if (batch_index + 1) % accum_grad == 0:
            self.optimizer.step()
            self.optimizer.clear_grad()
            self.iteration += 1

        iteration_time = time.time() - start

        for k, v in losses_np.items():
            report(k, v)
        report("batch_size", batch_size)
        report("accum", accum_grad)
        report("step_cost", iteration_time)

        if dist.get_rank() == 0 and self.visualizer:
            for k, v in losses_np.items():
                # `step -1` since we update `step` after optimizer.step().
                self.visualizer.add_scalar("train/{}".format(k), v,
                                           self.iteration - 1)

    @paddle.no_grad()
    def valid(self):
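        """Run one evaluation pass over the dev set and return
        (total_loss, num_seen_utts) for an utterance-weighted average loss.
        """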
        logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
        self.model.eval()
        valid_losses = defaultdict(list)
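        # Seed with 1 instead of 0 so total_loss / num_seen_utts below can
        # never divide by zero (at the cost of a negligible bias).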
        num_seen_utts = 1
        total_loss = 0.0
        for i, batch in enumerate(self.valid_loader):
            utt, audio, audio_len, text, text_len = batch
            loss = self.model(audio, audio_len, text, text_len)
            if paddle.isfinite(loss):
                num_utts = batch[1].shape[0]
                num_seen_utts += num_utts
                total_loss += float(loss) * num_utts
                valid_losses['val_loss'].append(float(loss))

            if (i + 1) % self.config.training.log_interval == 0:
                valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
                valid_dump['val_history_loss'] = total_loss / num_seen_utts

                # logging
                msg = f"Valid: Rank: {dist.get_rank()}, "
                msg += "epoch: {}, ".format(self.epoch)
                msg += "step: {}, ".format(self.iteration)
                msg += "batch : {}/{}, ".format(i + 1, len(self.valid_loader))
                msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                 for k, v in valid_dump.items())
                logger.info(msg)

        logger.info('Rank {} Val info val_loss {}'.format(
            dist.get_rank(), total_loss / num_seen_utts))
        return total_loss, num_seen_utts

    def setup_model(self):
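        """Build the offline or online DS2 model per --model_type, wrap it
        in DataParallel when running distributed, and create the Adam
        optimizer with global-norm gradient clipping and exponential LR decay.
        """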
        config = self.config.clone()
        with UpdateConfig(config):
            config.model.feat_size = self.train_loader.collate_fn.feature_size
            config.model.dict_size = self.train_loader.collate_fn.vocab_size

        if self.args.model_type == 'offline':
            model = DeepSpeech2Model.from_config(config.model)
        elif self.args.model_type == 'online':
            model = DeepSpeech2ModelOnline.from_config(config.model)
        else:
            raise Exception("wrong model type")
        if self.parallel:
            model = paddle.DataParallel(model)

        logger.info(f"{model}")
        layer_tools.print_params(model, logger.info)

        grad_clip = ClipGradByGlobalNormWithLog(
            config.training.global_grad_clip)
        lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
            learning_rate=config.training.lr,
            gamma=config.training.lr_decay,
            verbose=True)
        optimizer = paddle.optimizer.Adam(
            learning_rate=lr_scheduler,
            parameters=model.parameters(),
            weight_decay=paddle.regularizer.L2Decay(
                config.training.weight_decay),
            grad_clip=grad_clip)

        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        logger.info("Setup model/optimizer/lr_scheduler!")

    def setup_dataloader(self):
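        """Build the train/valid/test dataloaders from the manifest configs.

        Training uses SortaGrad batch sampling (distributed when parallel);
        the dev and test collators disable augmentation, and the test
        collator keeps raw transcription text for metric computation.
        """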
        config = self.config.clone()
        config.defrost()
        config.collator.keep_transcription_text = False

        config.data.manifest = config.data.train_manifest
        train_dataset = ManifestDataset.from_config(config)

        config.data.manifest = config.data.dev_manifest
        dev_dataset = ManifestDataset.from_config(config)

        config.data.manifest = config.data.test_manifest
        test_dataset = ManifestDataset.from_config(config)

        if self.parallel:
            batch_sampler = SortagradDistributedBatchSampler(
                train_dataset,
                batch_size=config.collator.batch_size,
                num_replicas=None,
                rank=None,
                shuffle=True,
                drop_last=True,
                sortagrad=config.collator.sortagrad,
                shuffle_method=config.collator.shuffle_method)
        else:
            batch_sampler = SortagradBatchSampler(
                train_dataset,
                shuffle=True,
                batch_size=config.collator.batch_size,
                drop_last=True,
                sortagrad=config.collator.sortagrad,
                shuffle_method=config.collator.shuffle_method)

        collate_fn_train = SpeechCollator.from_config(config)

        config.collator.augmentation_config = ""
        collate_fn_dev = SpeechCollator.from_config(config)

        config.collator.keep_transcription_text = True
        config.collator.augmentation_config = ""
        collate_fn_test = SpeechCollator.from_config(config)

        self.train_loader = DataLoader(
            train_dataset,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn_train,
            num_workers=config.collator.num_workers)
        self.valid_loader = DataLoader(
            dev_dataset,
            batch_size=config.collator.batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=collate_fn_dev)
        self.test_loader = DataLoader(
            test_dataset,
            batch_size=config.decoding.batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=collate_fn_test)
        logger.info("Setup train/valid/test  Dataloader!")


class DeepSpeech2Tester(DeepSpeech2Trainer):
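    """Evaluates a trained DeepSpeech2 checkpoint: decodes the test set,
    reports WER/CER, and exports the model to a static graph.

    Typical driver sketch:

        exp = DeepSpeech2Tester(config, args)
        exp.setup()
        exp.run_test()
    """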
    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        # testing config
        default = CfgNode(
            dict(
                alpha=2.5,  # Coef of LM for beam search.
                beta=0.3,  # Coef of WC for beam search.
                cutoff_prob=1.0,  # Cutoff probability for pruning.
                cutoff_top_n=40,  # Cutoff number for pruning.
                lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm',  # Filepath for language model.
                decoding_method='ctc_beam_search',  # Decoding method. Options: ctc_beam_search, ctc_greedy
                error_rate_type='wer',  # Error rate type for evaluation. Options: wer, cer
                num_proc_bsearch=8,  # Number of CPU processes for beam search.
                beam_size=500,  # Beam search width.
                batch_size=128,  # decoding batch size
            ))

        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    def __init__(self, config, args):
        super().__init__(config, args)

    def ordid2token(self, texts, texts_len):
        """ ord() id to chr() chr """
        trans = []
        for text, n in zip(texts, texts_len):
            n = n.numpy().item()
            ids = text[:n]
            trans.append(''.join([chr(i) for i in ids]))
        return trans

    def compute_metrics(self,
                        utts,
                        audio,
                        audio_len,
                        texts,
                        texts_len,
                        fout=None):
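        """Decode one batch and score the hypotheses against the references.

        Returns accumulated error statistics (errors_sum, len_refs, num_ins)
        plus the batch error rate, optionally writing per-utterance results
        to `fout` as jsonlines.
        """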
        cfg = self.config.decoding
        errors_sum, len_refs, num_ins = 0.0, 0, 0
        errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
        error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer

        vocab_list = self.test_loader.collate_fn.vocab_list

        target_transcripts = self.ordid2token(texts, texts_len)

        result_transcripts = self.compute_result_transcripts(audio, audio_len,
                                                             vocab_list, cfg)
        for utt, target, result in zip(utts, target_transcripts,
                                       result_transcripts):
            errors, len_ref = errors_func(target, result)
            errors_sum += errors
            len_refs += len_ref
            num_ins += 1
            if fout:
                fout.write({"utt": utt, "ref": target, "hyp": result})
            logger.info(f"Utt: {utt}")
            logger.info(f"Ref: {target}")
            logger.info(f"Hyp: {result}")
            logger.info("Current error rate [%s] = %f" %
                        (cfg.error_rate_type, error_rate_func(target, result)))

        return dict(
            errors_sum=errors_sum,
            len_refs=len_refs,
            num_ins=num_ins,
            error_rate=errors_sum / len_refs,
            error_rate_type=cfg.error_rate_type)

    def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
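        # The AutoLog stamps around model.decode() feed the timing report
        # printed by self.autolog.report() at the end of test().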
        self.autolog.times.start()
        self.autolog.times.stamp()
        result_transcripts = self.model.decode(
            audio,
            audio_len,
            vocab_list,
            decoding_method=cfg.decoding_method,
            lang_model_path=cfg.lang_model_path,
            beam_alpha=cfg.alpha,
            beam_beta=cfg.beta,
            beam_size=cfg.beam_size,
            cutoff_prob=cfg.cutoff_prob,
            cutoff_top_n=cfg.cutoff_top_n,
            num_processes=cfg.num_proc_bsearch)
        self.autolog.times.stamp()
        self.autolog.times.stamp()
        self.autolog.times.end()
        return result_transcripts

    @mp_tools.rank_zero_only
    @paddle.no_grad()
    def test(self):
        logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
        self.autolog = Autolog(
            batch_size=self.config.decoding.batch_size,
            model_name="deepspeech2",
            model_precision="fp32").getlog()
        self.model.eval()
        cfg = self.config
        error_rate_type = None
        errors_sum, len_refs, num_ins = 0.0, 0, 0
        with jsonlines.open(self.args.result_file, 'w') as fout:
            for i, batch in enumerate(self.test_loader):
                utts, audio, audio_len, texts, texts_len = batch
                metrics = self.compute_metrics(utts, audio, audio_len, texts,
                                               texts_len, fout)
                errors_sum += metrics['errors_sum']
                len_refs += metrics['len_refs']
                num_ins += metrics['num_ins']
                error_rate_type = metrics['error_rate_type']
                logger.info("Error rate [%s] (%d/?) = %f" %
                            (error_rate_type, num_ins, errors_sum / len_refs))

        # logging
        msg = "Test: "
        msg += "epoch: {}, ".format(self.epoch)
        msg += "step: {}, ".format(self.iteration)
        msg += "Final error rate [%s] (%d/%d) = %f" % (
            error_rate_type, num_ins, num_ins, errors_sum / len_refs)
        logger.info(msg)
        self.autolog.report()

    def run_test(self):
        self.resume_or_scratch()
        try:
            self.test()
        except KeyboardInterrupt:
            exit(-1)

    def export(self):
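        """Convert the trained dygraph model to a static graph with
        paddle.jit.save so it can be served via the Paddle Inference API.
        """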
        if self.args.model_type == 'offline':
            infer_model = DeepSpeech2InferModel.from_pretrained(
                self.test_loader, self.config, self.args.checkpoint_path)
        elif self.args.model_type == 'online':
            infer_model = DeepSpeech2InferModelOnline.from_pretrained(
                self.test_loader, self.config, self.args.checkpoint_path)
        else:
            raise Exception("wrong model type")

        infer_model.eval()
        feat_dim = self.test_loader.collate_fn.feature_size
        static_model = infer_model.export()
        logger.info(f"Export code: {static_model.forward.code}")
        paddle.jit.save(static_model, self.args.export_path)

    def run_export(self):
        try:
            self.export()
        except KeyboardInterrupt:
            exit(-1)

    def setup(self):
        """Setup the experiment.
        """
        paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')

        self.setup_output_dir()
        self.setup_checkpointer()

        self.setup_dataloader()
        self.setup_model()

        self.iteration = 0
        self.epoch = 0

    def setup_output_dir(self):
        """Create a directory used for output.
        """
        # output dir
        if self.args.output:
            output_dir = Path(self.args.output).expanduser()
            output_dir.mkdir(parents=True, exist_ok=True)
        else:
            output_dir = Path(
                self.args.checkpoint_path).expanduser().parent.parent
            output_dir.mkdir(parents=True, exist_ok=True)

        self.output_dir = output_dir


class DeepSpeech2ExportTester(DeepSpeech2Tester):
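    """Runs the test set through the exported static-graph model with the
    Paddle Inference predictor instead of the dygraph model; the online
    variant is fed chunk by chunk with explicit RNN state hand-off.
    """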
    def __init__(self, config, args):
        super().__init__(config, args)

    def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
        if self.args.model_type == "online":
            output_probs, output_lens = self.static_forward_online(audio,
                                                                   audio_len)
        elif self.args.model_type == "offline":
            output_probs, output_lens = self.static_forward_offline(audio,
                                                                    audio_len)
        else:
            raise Exception("wrong model type")

        self.predictor.clear_intermediate_tensor()
        self.predictor.try_shrink_memory()

        self.model.decoder.init_decode(cfg.alpha, cfg.beta, cfg.lang_model_path,
                                       vocab_list, cfg.decoding_method)

        result_transcripts = self.model.decoder.decode_probs(
            output_probs, output_lens, vocab_list, cfg.decoding_method,
            cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size,
            cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch)

        return result_transcripts

    def static_forward_online(self, audio, audio_len,
                              decoder_chunk_size: int=1):
        """
        Parameters
        ----------
            audio (Tensor): shape[B, T, D]
            audio_len (Tensor): shape[B]
            decoder_chunk_size(int)
        Returns
        -------
            output_probs(numpy.array): shape[B, T, vocab_size]
            output_lens(numpy.array): shape[B]
        """
        output_probs_list = []
        output_lens_list = []
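        # Chunk geometry: each chunk covers the conv receptive field plus
        # (decoder_chunk_size - 1) extra subsampled steps; consecutive
        # chunks advance by chunk_stride input frames.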
        subsampling_rate = self.model.encoder.conv.subsampling_rate
        receptive_field_length = self.model.encoder.conv.receptive_field_length
        chunk_stride = subsampling_rate * decoder_chunk_size
        chunk_size = (decoder_chunk_size - 1
                      ) * subsampling_rate + receptive_field_length

        x_batch = audio.numpy()
        batch_size, Tmax, x_dim = x_batch.shape
        x_len_batch = audio_len.numpy().astype(np.int64)
        if (Tmax - chunk_size) % chunk_stride != 0:
            padding_len_batch = chunk_stride - (
                Tmax - chunk_size
            ) % chunk_stride  # The length of padding for the batch
        else:
            padding_len_batch = 0
        x_list = np.split(x_batch, batch_size, axis=0)
        x_len_list = np.split(x_len_batch, batch_size, axis=0)

        for x, x_len in zip(x_list, x_len_list):
            self.autolog.times.start()
            self.autolog.times.stamp()
            x_len = x_len[0]
            assert chunk_size <= x_len

            if (x_len - chunk_size) % chunk_stride != 0:
                padding_len_x = chunk_stride - (x_len - chunk_size
                                                ) % chunk_stride
            else:
                padding_len_x = 0

            padding = np.zeros(
                (x.shape[0], padding_len_x, x.shape[2]), dtype=x.dtype)
            padded_x = np.concatenate([x, padding], axis=1)

            # Integer division is exact here: the padding above made
            # (x_len + padding_len_x - chunk_size) a multiple of chunk_stride.
            num_chunk = (x_len + padding_len_x - chunk_size) // chunk_stride + 1

            chunk_state_h_box = np.zeros(
                (self.config.model.num_rnn_layers, 1,
                 self.config.model.rnn_layer_size),
                dtype=x.dtype)
            chunk_state_c_box = np.zeros(
                (self.config.model.num_rnn_layers, 1,
                 self.config.model.rnn_layer_size),
                dtype=x.dtype)

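            # The exported online model takes four inputs, in this order:
            # the audio chunk, its length, and the previous RNN hidden and
            # cell states.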
            input_names = self.predictor.get_input_names()
            audio_handle = self.predictor.get_input_handle(input_names[0])
            audio_len_handle = self.predictor.get_input_handle(input_names[1])
            h_box_handle = self.predictor.get_input_handle(input_names[2])
            c_box_handle = self.predictor.get_input_handle(input_names[3])

            probs_chunk_list = []
            probs_chunk_lens_list = []
            for i in range(0, num_chunk):
                start = i * chunk_stride
                end = start + chunk_size
                x_chunk = padded_x[:, start:end, :]
                if x_len < i * chunk_stride:
                    x_chunk_lens = 0
                else:
                    x_chunk_lens = min(x_len - i * chunk_stride, chunk_size)

                # Stop early: the remaining frames cannot yield even one output step.
                if x_chunk_lens < receptive_field_length:
                    break
                x_chunk_lens = np.array([x_chunk_lens])
                audio_handle.reshape(x_chunk.shape)
                audio_handle.copy_from_cpu(x_chunk)

                audio_len_handle.reshape(x_chunk_lens.shape)
                audio_len_handle.copy_from_cpu(x_chunk_lens)

                h_box_handle.reshape(chunk_state_h_box.shape)
                h_box_handle.copy_from_cpu(chunk_state_h_box)

                c_box_handle.reshape(chunk_state_c_box.shape)
                c_box_handle.copy_from_cpu(chunk_state_c_box)

                output_names = self.predictor.get_output_names()
                output_handle = self.predictor.get_output_handle(
                    output_names[0])
                output_lens_handle = self.predictor.get_output_handle(
                    output_names[1])
                output_state_h_handle = self.predictor.get_output_handle(
                    output_names[2])
                output_state_c_handle = self.predictor.get_output_handle(
                    output_names[3])
                self.predictor.run()
                output_chunk_probs = output_handle.copy_to_cpu()
                output_chunk_lens = output_lens_handle.copy_to_cpu()
                chunk_state_h_box = output_state_h_handle.copy_to_cpu()
                chunk_state_c_box = output_state_c_handle.copy_to_cpu()

                probs_chunk_list.append(output_chunk_probs)
                probs_chunk_lens_list.append(output_chunk_lens)
            output_probs = np.concatenate(probs_chunk_list, axis=1)
            output_lens = np.sum(probs_chunk_lens_list, axis=0)
            vocab_size = output_probs.shape[2]
            output_probs_padding_len = Tmax + padding_len_batch - output_probs.shape[1]
            output_probs_padding = np.zeros(
                (1, output_probs_padding_len, vocab_size),
                dtype=output_probs.dtype)  # prob padding for one utterance
            output_probs = np.concatenate(
                [output_probs, output_probs_padding], axis=1)
            output_probs_list.append(output_probs)
            output_lens_list.append(output_lens)
            self.autolog.times.stamp()
            self.autolog.times.stamp()
            self.autolog.times.end()
        output_probs = np.concatenate(output_probs_list, axis=0)
        output_lens = np.concatenate(output_lens_list, axis=0)
        return output_probs, output_lens

    def static_forward_offline(self, audio, audio_len):
        """
        Parameters
        ----------
            audio (Tensor): shape[B, T, D]
            audio_len (Tensor): shape[B]

        Returns
        -------
            output_probs(numpy.array): shape[B, T, vocab_size]
            output_lens(numpy.array): shape[B]
        """
        x = audio.numpy()
        x_len = audio_len.numpy().astype(np.int64)

        input_names = self.predictor.get_input_names()
        audio_handle = self.predictor.get_input_handle(input_names[0])
        audio_len_handle = self.predictor.get_input_handle(input_names[1])

        audio_handle.reshape(x.shape)
        audio_handle.copy_from_cpu(x)

        audio_len_handle.reshape(x_len.shape)
        audio_len_handle.copy_from_cpu(x_len)

        self.autolog.times.start()
        self.autolog.times.stamp()
        self.predictor.run()
        self.autolog.times.stamp()
        self.autolog.times.stamp()
        self.autolog.times.end()

        output_names = self.predictor.get_output_names()
        output_handle = self.predictor.get_output_handle(output_names[0])
        output_lens_handle = self.predictor.get_output_handle(output_names[1])
        output_probs = output_handle.copy_to_cpu()
        output_lens = output_lens_handle.copy_to_cpu()
        return output_probs, output_lens

    def run_test(self):
        try:
            self.test()
        except KeyboardInterrupt:
            exit(-1)

    def setup(self):
        """Setup the experiment.
        """
        paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')

        self.setup_output_dir()

        self.setup_dataloader()
        self.setup_model()

        self.iteration = 0
        self.epoch = 0

    def setup_output_dir(self):
        """Create a directory used for output.
        """
        # output dir
        if self.args.output:
            output_dir = Path(self.args.output).expanduser()
            output_dir.mkdir(parents=True, exist_ok=True)
        else:
            output_dir = Path(self.args.export_path).expanduser().parent.parent
            output_dir.mkdir(parents=True, exist_ok=True)

        self.output_dir = output_dir

    def setup_model(self):
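        """Build the dygraph model via super() (its decoder is still used
        for CTC decoding), then create a Paddle Inference predictor from the
        exported .pdmodel/.pdiparams files, enabling GPU execution when
        CUDA_VISIBLE_DEVICES selects a device.
        """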
        super().setup_model()
        predictor_config = inference.Config(
            self.args.export_path + ".pdmodel",
            self.args.export_path + ".pdiparams")
        # os.environ.get avoids a KeyError when CUDA_VISIBLE_DEVICES is unset.
        if os.environ.get('CUDA_VISIBLE_DEVICES', '').strip() != '':
            # 100 MB initial GPU memory pool on device 0.
            predictor_config.enable_use_gpu(100, 0)
            predictor_config.enable_memory_optim()
        self.predictor = inference.create_predictor(predictor_config)