From 92d1d08b9a9c8fbe96457024d60b60240fa3bc79 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Tue, 5 Jul 2022 09:23:01 +0000 Subject: [PATCH] fix scripts --- examples/wenetspeech/asr1/conf/conformer.yaml | 20 +- examples/wenetspeech/asr1/local/data.sh | 4 +- paddlespeech/audio/streamdata/__init__.py | 17 +- paddlespeech/audio/streamdata/autodecode.py | 8 +- paddlespeech/audio/streamdata/compat.py | 24 +- paddlespeech/audio/streamdata/filters.py | 28 +- paddlespeech/audio/streamdata/tariterators.py | 6 +- paddlespeech/audio/text/text_featurizer.py | 235 +++++++++++ paddlespeech/audio/text/utility.py | 393 ++++++++++++++++++ paddlespeech/s2t/exps/deepspeech2/model.py | 2 +- paddlespeech/s2t/exps/u2/model.py | 185 +-------- paddlespeech/s2t/exps/u2_kaldi/model.py | 116 ++---- paddlespeech/s2t/exps/u2_st/model.py | 93 +---- paddlespeech/s2t/io/dataloader.py | 177 +++++++- paddlespeech/s2t/models/u2/u2.py | 6 +- paddlespeech/s2t/models/u2_st/u2_st.py | 4 +- 16 files changed, 901 insertions(+), 417 deletions(-) create mode 100644 paddlespeech/audio/text/text_featurizer.py create mode 100644 paddlespeech/audio/text/utility.py diff --git a/examples/wenetspeech/asr1/conf/conformer.yaml b/examples/wenetspeech/asr1/conf/conformer.yaml index 013c3e0c..d1ac20b9 100644 --- a/examples/wenetspeech/asr1/conf/conformer.yaml +++ b/examples/wenetspeech/asr1/conf/conformer.yaml @@ -52,6 +52,7 @@ test_manifest: data/test_meeting/data.list use_stream_data: True unit_type: 'char' vocab_filepath: data/lang_char/vocab.txt +preprocess_config: conf/preprocess.yaml cmvn_file: data/mean_std.json spm_model_prefix: '' feat_dim: 80 @@ -65,30 +66,17 @@ maxlen_in: 1200 # if input length(number of frames) > maxlen-in, data is automa minlen_out: 0 maxlen_out: 150 # if output length(number of tokens) > maxlen-out, data is automatically removed resample_rate: 16000 -shuffle_size: 1500 -sort_size: 1000 +shuffle_size: 1500 # read number of 'shuffle_size' data as a chunk, shuffle the data in the chunk +sort_size: 1000 # read number of 'sort_size' data as a chunk, sort the data in the chunk num_workers: 8 prefetch_factor: 10 dist_sampler: True num_encs: 1 -augment_conf: - max_w: 80 - w_inplace: True - w_mode: "PIL" - max_f: 30 - num_f_mask: 2 - f_inplace: True - f_replace_with_zero: False - max_t: 40 - num_t_mask: 2 - t_inplace: True - t_replace_with_zero: False - ########################################### # Training # ########################################### -n_epoch: 30 +n_epoch: 32 accum_grad: 32 global_grad_clip: 5.0 log_interval: 100 diff --git a/examples/wenetspeech/asr1/local/data.sh b/examples/wenetspeech/asr1/local/data.sh index b3472a8f..62579ba3 100755 --- a/examples/wenetspeech/asr1/local/data.sh +++ b/examples/wenetspeech/asr1/local/data.sh @@ -90,8 +90,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then for x in $dev_set $test_sets ${train_set}; do dst=$shards_dir/$x mkdir -p $dst - utils/make_filted_shard_list.py --resample 16000 --num_utts_per_shard 1000 \ - --do_filter --num_node 1 --num_gpus_per_node 8 \ + utils/make_filted_shard_list.py --num_node 1 --num_gpus_per_node 8 --num_utts_per_shard 1000 \ + --do_filter --resample 16000 \ --num_threads 32 --segments data/$x/segments \ data/$x/wav.scp data/$x/text \ $(realpath $dst) data/$x/data.list diff --git a/paddlespeech/audio/streamdata/__init__.py b/paddlespeech/audio/streamdata/__init__.py index 1acd898a..753fcc11 100644 --- a/paddlespeech/audio/streamdata/__init__.py +++ b/paddlespeech/audio/streamdata/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # See the LICENSE file for licensing terms (BSD-style). # Modified from https://github.com/webdataset/webdataset # @@ -26,7 +27,7 @@ from .filters import ( pipelinefilter, rename, rename_keys, - rsample, + audio_resample, select, shuffle, slice, @@ -34,14 +35,14 @@ from .filters import ( transform_with, unbatched, xdecode, - data_filter, - tokenize, - resample, - compute_fbank, - spec_aug, + audio_data_filter, + audio_tokenize, + audio_resample, + audio_compute_fbank, + audio_spec_aug, sort, - padding, - cmvn, + audio_padding, + audio_cmvn, placeholder, ) from .handlers import ( diff --git a/paddlespeech/audio/streamdata/autodecode.py b/paddlespeech/audio/streamdata/autodecode.py index 8c74b685..ca0e2ea2 100644 --- a/paddlespeech/audio/streamdata/autodecode.py +++ b/paddlespeech/audio/streamdata/autodecode.py @@ -291,12 +291,12 @@ def torch_video(key, data): ################################################################ -# paddleaudio +# paddlespeech.audio ################################################################ def paddle_audio(key, data): - """Decode audio using the paddleaudio library. + """Decode audio using the paddlespeech.audio library. :param key: file name extension :param data: data to be decoded @@ -305,13 +305,13 @@ def paddle_audio(key, data): if extension not in ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]: return None - import paddleaudio + import paddlespeech.audio with tempfile.TemporaryDirectory() as dirname: fname = os.path.join(dirname, f"file.{extension}") with open(fname, "wb") as stream: stream.write(data) - return paddleaudio.load(fname) + return paddlespeech.audio.load(fname) ################################################################ diff --git a/paddlespeech/audio/streamdata/compat.py b/paddlespeech/audio/streamdata/compat.py index 11308d03..deda5338 100644 --- a/paddlespeech/audio/streamdata/compat.py +++ b/paddlespeech/audio/streamdata/compat.py @@ -78,29 +78,29 @@ class FluidInterface: def xdecode(self, *args, **kw): return self.compose(filters.xdecode(*args, **kw)) - def data_filter(self, *args, **kw): - return self.compose(filters.data_filter(*args, **kw)) + def audio_data_filter(self, *args, **kw): + return self.compose(filters.audio_data_filter(*args, **kw)) - def tokenize(self, *args, **kw): - return self.compose(filters.tokenize(*args, **kw)) + def audio_tokenize(self, *args, **kw): + return self.compose(filters.audio_tokenize(*args, **kw)) def resample(self, *args, **kw): return self.compose(filters.resample(*args, **kw)) - def compute_fbank(self, *args, **kw): - return self.compose(filters.compute_fbank(*args, **kw)) + def audio_compute_fbank(self, *args, **kw): + return self.compose(filters.audio_compute_fbank(*args, **kw)) - def spec_aug(self, *args, **kw): - return self.compose(filters.spec_aug(*args, **kw)) + def audio_spec_aug(self, *args, **kw): + return self.compose(filters.audio_spec_aug(*args, **kw)) def sort(self, size=500): return self.compose(filters.sort(size)) - def padding(self): - return self.compose(filters.padding()) + def audio_padding(self): + return self.compose(filters.audio_padding()) - def cmvn(self, cmvn_file): - return self.compose(filters.cmvn(cmvn_file)) + def audio_cmvn(self, cmvn_file): + return self.compose(filters.audio_cmvn(cmvn_file)) class WebDataset(DataPipeline, FluidInterface): """Small fluid-interface wrapper for DataPipeline.""" diff --git a/paddlespeech/audio/streamdata/filters.py b/paddlespeech/audio/streamdata/filters.py index 0ade66f9..82b9c6ba 100644 --- a/paddlespeech/audio/streamdata/filters.py +++ b/paddlespeech/audio/streamdata/filters.py @@ -579,7 +579,7 @@ xdecode = pipelinefilter(_xdecode) -def _data_filter(source, +def _audio_data_filter(source, frame_shift=10, max_length=10240, min_length=10, @@ -629,9 +629,9 @@ def _data_filter(source, continue yield sample -data_filter = pipelinefilter(_data_filter) +audio_data_filter = pipelinefilter(_audio_data_filter) -def _tokenize(source, +def _audio_tokenize(source, symbol_table, bpe_model=None, non_lang_syms=None, @@ -693,9 +693,9 @@ def _tokenize(source, sample['label'] = label yield sample -tokenize = pipelinefilter(_tokenize) +audio_tokenize = pipelinefilter(_audio_tokenize) -def _resample(source, resample_rate=16000): +def _audio_resample(source, resample_rate=16000): """ Resample data. Inplace operation. @@ -718,9 +718,9 @@ def _resample(source, resample_rate=16000): )) yield sample -resample = pipelinefilter(_resample) +audio_resample = pipelinefilter(_audio_resample) -def _compute_fbank(source, +def _audio_compute_fbank(source, num_mel_bins=80, frame_length=25, frame_shift=10, @@ -756,9 +756,9 @@ def _compute_fbank(source, yield dict(fname=sample['fname'], label=sample['label'], feat=mat) -compute_fbank = pipelinefilter(_compute_fbank) +audio_compute_fbank = pipelinefilter(_audio_compute_fbank) -def _spec_aug(source, +def _audio_spec_aug(source, max_w=5, w_inplace=True, w_mode="PIL", @@ -799,7 +799,7 @@ def _spec_aug(source, sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32) yield sample -spec_aug = pipelinefilter(_spec_aug) +audio_spec_aug = pipelinefilter(_audio_spec_aug) def _sort(source, sort_size=500): @@ -881,7 +881,7 @@ def dynamic_batched(source, max_frames_in_batch=12000): yield buf -def _padding(source): +def _audio_padding(source): """ Padding the data into training data Args: @@ -914,9 +914,9 @@ def _padding(source): yield (sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths) -padding = pipelinefilter(_padding) +audio_padding = pipelinefilter(_audio_padding) -def _cmvn(source, cmvn_file): +def _audio_cmvn(source, cmvn_file): global_cmvn = GlobalCMVN(cmvn_file) for batch in source: sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths = batch @@ -926,7 +926,7 @@ def _cmvn(source, cmvn_file): yield (sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths) -cmvn = pipelinefilter(_cmvn) +audio_cmvn = pipelinefilter(_audio_cmvn) def _placeholder(source): for data in source: diff --git a/paddlespeech/audio/streamdata/tariterators.py b/paddlespeech/audio/streamdata/tariterators.py index 2c1daae1..b1616918 100644 --- a/paddlespeech/audio/streamdata/tariterators.py +++ b/paddlespeech/audio/streamdata/tariterators.py @@ -21,7 +21,7 @@ trace = False meta_prefix = "__" meta_suffix = "__" -from ... import audio as paddleaudio +import paddlespeech import paddle import numpy as np @@ -118,7 +118,7 @@ def tar_file_iterator( assert pos > 0 prefix, postfix = name[:pos], name[pos + 1:] if postfix == 'wav': - waveform, sample_rate = paddleaudio.load(stream.extractfile(tarinfo), normal=False) + waveform, sample_rate = paddlespeech.audio.load(stream.extractfile(tarinfo), normal=False) result = dict(fname=prefix, wav=waveform, sample_rate = sample_rate) else: txt = stream.extractfile(tarinfo).read().decode('utf8').strip() @@ -167,7 +167,7 @@ def tar_file_and_group_iterator( if postfix == 'txt': example['txt'] = file_obj.read().decode('utf8').strip() elif postfix in AUDIO_FORMAT_SETS: - waveform, sample_rate = paddleaudio.load(file_obj, normal=False) + waveform, sample_rate = paddlespeech.audio.load(file_obj, normal=False) waveform = paddle.to_tensor(np.expand_dims(np.array(waveform),0), dtype=paddle.float32) example['wav'] = waveform diff --git a/paddlespeech/audio/text/text_featurizer.py b/paddlespeech/audio/text/text_featurizer.py new file mode 100644 index 00000000..91c4d75c --- /dev/null +++ b/paddlespeech/audio/text/text_featurizer.py @@ -0,0 +1,235 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains the text featurizer class.""" +from pprint import pformat +from typing import Union + +import sentencepiece as spm + +from .utility import BLANK +from .utility import EOS +from .utility import load_dict +from .utility import MASKCTC +from .utility import SOS +from .utility import SPACE +from .utility import UNK +from ..utils.log import Logger + +logger = Logger(__name__) + +__all__ = ["TextFeaturizer"] + + +class TextFeaturizer(): + def __init__(self, unit_type, vocab, spm_model_prefix=None, maskctc=False): + """Text featurizer, for processing or extracting features from text. + + Currently, it supports char/word/sentence-piece level tokenizing and conversion into + a list of token indices. Note that the token indexing order follows the + given vocabulary file. + + Args: + unit_type (str): unit type, e.g. char, word, spm + vocab Option[str, list]: Filepath to load vocabulary for token indices conversion, or vocab list. + spm_model_prefix (str, optional): spm model prefix. Defaults to None. + """ + assert unit_type in ('char', 'spm', 'word') + self.unit_type = unit_type + self.unk = UNK + self.maskctc = maskctc + + if vocab: + self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id, self.blank_id = self._load_vocabulary_from_file( + vocab, maskctc) + self.vocab_size = len(self.vocab_list) + else: + logger.warning("TextFeaturizer: not have vocab file or vocab list.") + + if unit_type == 'spm': + spm_model = spm_model_prefix + '.model' + self.sp = spm.SentencePieceProcessor() + self.sp.Load(spm_model) + + def tokenize(self, text, replace_space=True): + if self.unit_type == 'char': + tokens = self.char_tokenize(text, replace_space) + elif self.unit_type == 'word': + tokens = self.word_tokenize(text) + else: # spm + tokens = self.spm_tokenize(text) + return tokens + + def detokenize(self, tokens): + if self.unit_type == 'char': + text = self.char_detokenize(tokens) + elif self.unit_type == 'word': + text = self.word_detokenize(tokens) + else: # spm + text = self.spm_detokenize(tokens) + return text + + def featurize(self, text): + """Convert text string to a list of token indices. + + Args: + text (str): Text to process. + + Returns: + List[int]: List of token indices. + """ + tokens = self.tokenize(text) + ids = [] + for token in tokens: + if token not in self.vocab_dict: + logger.debug(f"Text Token: {token} -> {self.unk}") + token = self.unk + ids.append(self.vocab_dict[token]) + return ids + + def defeaturize(self, idxs): + """Convert a list of token indices to text string, + ignore index after eos_id. + + Args: + idxs (List[int]): List of token indices. + + Returns: + str: Text. + """ + tokens = [] + for idx in idxs: + if idx == self.eos_id: + break + tokens.append(self._id2token[idx]) + text = self.detokenize(tokens) + return text + + def char_tokenize(self, text, replace_space=True): + """Character tokenizer. + + Args: + text (str): text string. + replace_space (bool): False only used by build_vocab.py. + + Returns: + List[str]: tokens. + """ + text = text.strip() + if replace_space: + text_list = [SPACE if item == " " else item for item in list(text)] + else: + text_list = list(text) + return text_list + + def char_detokenize(self, tokens): + """Character detokenizer. + + Args: + tokens (List[str]): tokens. + + Returns: + str: text string. + """ + tokens = [t.replace(SPACE, " ") for t in tokens] + return "".join(tokens) + + def word_tokenize(self, text): + """Word tokenizer, separate by .""" + return text.strip().split() + + def word_detokenize(self, tokens): + """Word detokenizer, separate by .""" + return " ".join(tokens) + + def spm_tokenize(self, text): + """spm tokenize. + + Args: + text (str): text string. + + Returns: + List[str]: sentence pieces str code + """ + stats = {"num_empty": 0, "num_filtered": 0} + + def valid(line): + return True + + def encode(l): + return self.sp.EncodeAsPieces(l) + + def encode_line(line): + line = line.strip() + if len(line) > 0: + line = encode(line) + if valid(line): + return line + else: + stats["num_filtered"] += 1 + else: + stats["num_empty"] += 1 + return None + + enc_line = encode_line(text) + return enc_line + + def spm_detokenize(self, tokens, input_format='piece'): + """spm detokenize. + + Args: + ids (List[str]): tokens. + + Returns: + str: text + """ + if input_format == "piece": + + def decode(l): + return "".join(self.sp.DecodePieces(l)) + elif input_format == "id": + + def decode(l): + return "".join(self.sp.DecodeIds(l)) + + return decode(tokens) + + def _load_vocabulary_from_file(self, vocab: Union[str, list], + maskctc: bool): + """Load vocabulary from file.""" + if isinstance(vocab, list): + vocab_list = vocab + else: + vocab_list = load_dict(vocab, maskctc) + assert vocab_list is not None + logger.debug(f"Vocab: {pformat(vocab_list)}") + + id2token = dict( + [(idx, token) for (idx, token) in enumerate(vocab_list)]) + token2id = dict( + [(token, idx) for (idx, token) in enumerate(vocab_list)]) + + blank_id = vocab_list.index(BLANK) if BLANK in vocab_list else -1 + maskctc_id = vocab_list.index(MASKCTC) if MASKCTC in vocab_list else -1 + unk_id = vocab_list.index(UNK) if UNK in vocab_list else -1 + eos_id = vocab_list.index(EOS) if EOS in vocab_list else -1 + sos_id = vocab_list.index(SOS) if SOS in vocab_list else -1 + space_id = vocab_list.index(SPACE) if SPACE in vocab_list else -1 + + logger.info(f"BLANK id: {blank_id}") + logger.info(f"UNK id: {unk_id}") + logger.info(f"EOS id: {eos_id}") + logger.info(f"SOS id: {sos_id}") + logger.info(f"SPACE id: {space_id}") + logger.info(f"MASKCTC id: {maskctc_id}") + return token2id, id2token, vocab_list, unk_id, eos_id, blank_id diff --git a/paddlespeech/audio/text/utility.py b/paddlespeech/audio/text/utility.py new file mode 100644 index 00000000..d35785db --- /dev/null +++ b/paddlespeech/audio/text/utility.py @@ -0,0 +1,393 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains data helper functions.""" +import json +import math +import tarfile +from collections import namedtuple +from typing import List +from typing import Optional +from typing import Text + +import jsonlines +import numpy as np + +from paddlespeech.s2t.utils.log import Log + +logger = Log(__name__).getlog() + +__all__ = [ + "load_dict", "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs", + "max_dbfs", "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS", + "EOS", "UNK", "BLANK", "MASKCTC", "SPACE", "convert_samples_to_float32", + "convert_samples_from_float32" +] + +IGNORE_ID = -1 +# `sos` and `eos` using same token +SOS = "" +EOS = SOS +UNK = "" +BLANK = "" +MASKCTC = "" +SPACE = "" + + +def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]: + if dict_path is None: + return None + + with open(dict_path, "r") as f: + dictionary = f.readlines() + # first token is `` + # multi line: ` 0\n` + # one line: `` + # space is relpace with + char_list = [entry[:-1].split(" ")[0] for entry in dictionary] + if BLANK not in char_list: + char_list.insert(0, BLANK) + if EOS not in char_list: + char_list.append(EOS) + # for non-autoregressive maskctc model + if maskctc and MASKCTC not in char_list: + char_list.append(MASKCTC) + return char_list + + +def read_manifest( + manifest_path, + max_input_len=float('inf'), + min_input_len=0.0, + max_output_len=float('inf'), + min_output_len=0.0, + max_output_input_ratio=float('inf'), + min_output_input_ratio=0.0, ): + """Load and parse manifest file. + + Args: + manifest_path ([type]): Manifest file to load and parse. + max_input_len ([type], optional): maximum output seq length, + in seconds for raw wav, in frame numbers for feature data. + Defaults to float('inf'). + min_input_len (float, optional): minimum input seq length, + in seconds for raw wav, in frame numbers for feature data. + Defaults to 0.0. + max_output_len (float, optional): maximum input seq length, + in modeling units. Defaults to 500.0. + min_output_len (float, optional): minimum input seq length, + in modeling units. Defaults to 0.0. + max_output_input_ratio (float, optional): + maximum output seq length/output seq length ratio. Defaults to 10.0. + min_output_input_ratio (float, optional): + minimum output seq length/output seq length ratio. Defaults to 0.05. + + Raises: + IOError: If failed to parse the manifest. + + Returns: + List[dict]: Manifest parsing results. + """ + manifest = [] + with jsonlines.open(manifest_path, 'r') as reader: + for json_data in reader: + feat_len = json_data["input"][0]["shape"][ + 0] if "input" in json_data and "shape" in json_data["input"][ + 0] else 1.0 + token_len = json_data["output"][0]["shape"][ + 0] if "output" in json_data and "shape" in json_data["output"][ + 0] else 1.0 + conditions = [ + feat_len >= min_input_len, + feat_len <= max_input_len, + token_len >= min_output_len, + token_len <= max_output_len, + token_len / feat_len >= min_output_input_ratio, + token_len / feat_len <= max_output_input_ratio, + ] + if all(conditions): + manifest.append(json_data) + return manifest + + +# Tar File read +TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object']) + + +def parse_tar(file): + """Parse a tar file to get a tarfile object + and a map containing tarinfoes + """ + result = {} + f = tarfile.open(file) + for tarinfo in f.getmembers(): + result[tarinfo.name] = tarinfo + return f, result + + +def subfile_from_tar(file, local_data=None): + """Get subfile object from tar. + + tar:tarpath#filename + + It will return a subfile object from tar file + and cached tar file info for next reading request. + """ + tarpath, filename = file.split(':', 1)[1].split('#', 1) + + if local_data is None: + local_data = TarLocalData(tar2info={}, tar2object={}) + + assert isinstance(local_data, TarLocalData) + + if 'tar2info' not in local_data.__dict__: + local_data.tar2info = {} + if 'tar2object' not in local_data.__dict__: + local_data.tar2object = {} + + if tarpath not in local_data.tar2info: + fobj, infos = parse_tar(tarpath) + local_data.tar2info[tarpath] = infos + local_data.tar2object[tarpath] = fobj + else: + fobj = local_data.tar2object[tarpath] + infos = local_data.tar2info[tarpath] + return fobj.extractfile(infos[filename]) + + +def rms_to_db(rms: float): + """Root Mean Square to dB. + + Args: + rms ([float]): root mean square + + Returns: + float: dB + """ + return 20.0 * math.log10(max(1e-16, rms)) + + +def rms_to_dbfs(rms: float): + """Root Mean Square to dBFS. + https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/ + Audio is mix of sine wave, so 1 amp sine wave's Full scale is 0.7071, equal to -3.0103dB. + + dB = dBFS + 3.0103 + dBFS = db - 3.0103 + e.g. 0 dB = -3.0103 dBFS + + Args: + rms ([float]): root mean square + + Returns: + float: dBFS + """ + return rms_to_db(rms) - 3.0103 + + +def max_dbfs(sample_data: np.ndarray): + """Peak dBFS based on the maximum energy sample. + + Args: + sample_data ([np.ndarray]): float array, [-1, 1]. + + Returns: + float: dBFS + """ + # Peak dBFS based on the maximum energy sample. Will prevent overdrive if used for normalization. + return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data)))) + + +def mean_dbfs(sample_data): + """Peak dBFS based on the RMS energy. + + Args: + sample_data ([np.ndarray]): float array, [-1, 1]. + + Returns: + float: dBFS + """ + return rms_to_dbfs( + math.sqrt(np.mean(np.square(sample_data, dtype=np.float64)))) + + +def gain_db_to_ratio(gain_db: float): + """dB to ratio + + Args: + gain_db (float): gain in dB + + Returns: + float: scale in amp + """ + return math.pow(10.0, gain_db / 20.0) + + +def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103): + """Nomalize audio to dBFS. + + Args: + sample_data (np.ndarray): input wave samples, [-1, 1]. + dbfs (float, optional): target dBFS. Defaults to -3.0103. + + Returns: + np.ndarray: normalized wave + """ + return np.maximum( + np.minimum(sample_data * gain_db_to_ratio(dbfs - max_dbfs(sample_data)), + 1.0), -1.0) + + +def _load_json_cmvn(json_cmvn_file): + """ Load the json format cmvn stats file and calculate cmvn + + Args: + json_cmvn_file: cmvn stats file in json format + + Returns: + a numpy array of [means, vars] + """ + with open(json_cmvn_file) as f: + cmvn_stats = json.load(f) + + means = cmvn_stats['mean_stat'] + variance = cmvn_stats['var_stat'] + count = cmvn_stats['frame_num'] + for i in range(len(means)): + means[i] /= count + variance[i] = variance[i] / count - means[i] * means[i] + if variance[i] < 1.0e-20: + variance[i] = 1.0e-20 + variance[i] = 1.0 / math.sqrt(variance[i]) + cmvn = np.array([means, variance]) + return cmvn + + +def _load_kaldi_cmvn(kaldi_cmvn_file): + """ Load the kaldi format cmvn stats file and calculate cmvn + + Args: + kaldi_cmvn_file: kaldi text style global cmvn file, which + is generated by: + compute-cmvn-stats --binary=false scp:feats.scp global_cmvn + + Returns: + a numpy array of [means, vars] + """ + means = [] + variance = [] + with open(kaldi_cmvn_file, 'r') as fid: + # kaldi binary file start with '\0B' + if fid.read(2) == '\0B': + logger.error('kaldi cmvn binary file is not supported, please ' + 'recompute it by: compute-cmvn-stats --binary=false ' + ' scp:feats.scp global_cmvn') + sys.exit(1) + fid.seek(0) + arr = fid.read().split() + assert (arr[0] == '[') + assert (arr[-2] == '0') + assert (arr[-1] == ']') + feat_dim = int((len(arr) - 2 - 2) / 2) + for i in range(1, feat_dim + 1): + means.append(float(arr[i])) + count = float(arr[feat_dim + 1]) + for i in range(feat_dim + 2, 2 * feat_dim + 2): + variance.append(float(arr[i])) + + for i in range(len(means)): + means[i] /= count + variance[i] = variance[i] / count - means[i] * means[i] + if variance[i] < 1.0e-20: + variance[i] = 1.0e-20 + variance[i] = 1.0 / math.sqrt(variance[i]) + cmvn = np.array([means, variance]) + return cmvn + + +def load_cmvn(cmvn_file: str, filetype: str): + """load cmvn from file. + + Args: + cmvn_file (str): cmvn path. + filetype (str): file type, optional[npz, json, kaldi]. + + Raises: + ValueError: file type not support. + + Returns: + Tuple[np.ndarray, np.ndarray]: mean, istd + """ + assert filetype in ['npz', 'json', 'kaldi'], filetype + filetype = filetype.lower() + if filetype == "json": + cmvn = _load_json_cmvn(cmvn_file) + elif filetype == "kaldi": + cmvn = _load_kaldi_cmvn(cmvn_file) + elif filetype == "npz": + eps = 1e-14 + npzfile = np.load(cmvn_file) + mean = np.squeeze(npzfile["mean"]) + std = np.squeeze(npzfile["std"]) + istd = 1 / (std + eps) + cmvn = [mean, istd] + else: + raise ValueError(f"cmvn file type no support: {filetype}") + return cmvn[0], cmvn[1] + + +def convert_samples_to_float32(samples): + """Convert sample type to float32. + + Audio sample type is usually integer or float-point. + Integers will be scaled to [-1, 1] in float32. + + PCM16 -> PCM32 + """ + float32_samples = samples.astype('float32') + if samples.dtype in np.sctypes['int']: + bits = np.iinfo(samples.dtype).bits + float32_samples *= (1. / 2**(bits - 1)) + elif samples.dtype in np.sctypes['float']: + pass + else: + raise TypeError("Unsupported sample type: %s." % samples.dtype) + return float32_samples + + +def convert_samples_from_float32(samples, dtype): + """Convert sample type from float32 to dtype. + + Audio sample type is usually integer or float-point. For integer + type, float32 will be rescaled from [-1, 1] to the maximum range + supported by the integer type. + + PCM32 -> PCM16 + """ + dtype = np.dtype(dtype) + output_samples = samples.copy() + if dtype in np.sctypes['int']: + bits = np.iinfo(dtype).bits + output_samples *= (2**(bits - 1) / 1.) + min_val = np.iinfo(dtype).min + max_val = np.iinfo(dtype).max + output_samples[output_samples > max_val] = max_val + output_samples[output_samples < min_val] = min_val + elif samples.dtype in np.sctypes['float']: + min_val = np.finfo(dtype).min + max_val = np.finfo(dtype).max + output_samples[output_samples > max_val] = max_val + output_samples[output_samples < min_val] = min_val + else: + raise TypeError("Unsupported sample type: %s." % samples.dtype) + return output_samples.astype(dtype) diff --git a/paddlespeech/s2t/exps/deepspeech2/model.py b/paddlespeech/s2t/exps/deepspeech2/model.py index 511997a7..7ab8cf85 100644 --- a/paddlespeech/s2t/exps/deepspeech2/model.py +++ b/paddlespeech/s2t/exps/deepspeech2/model.py @@ -23,7 +23,7 @@ import paddle from paddle import distributed as dist from paddle import inference -from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.audio.text.text_featurizer import TextFeaturizer from paddlespeech.s2t.io.dataloader import BatchDataLoader from paddlespeech.s2t.models.ds2 import DeepSpeech2InferModel from paddlespeech.s2t.models.ds2 import DeepSpeech2Model diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index d6c68f96..cdad3b8f 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -27,6 +27,7 @@ from paddle import distributed as dist from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.io.dataloader import BatchDataLoader from paddlespeech.s2t.io.dataloader import StreamDataLoader +from paddlespeech.s2t.io.dataloader import DataLoaderFactory from paddlespeech.s2t.models.u2 import U2Model from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.reporter import ObsScope @@ -134,7 +135,8 @@ class U2Trainer(Trainer): msg = f"Valid: Rank: {dist.get_rank()}, " msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) - #msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) + if not self.use_streamdata: + msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in valid_dump.items()) logger.info(msg) @@ -195,7 +197,6 @@ class U2Trainer(Trainer): except Exception as e: logger.error(e) raise e - with Timer("Eval Time Cost: {}"): total_loss, num_seen_utts = self.valid() if dist.get_world_size() > 1: @@ -224,186 +225,14 @@ class U2Trainer(Trainer): config = self.config.clone() self.use_streamdata = config.get("use_stream_data", False) if self.train: - # train/valid dataset, return token ids - if self.use_streamdata: - self.train_loader = StreamDataLoader( - manifest_file=config.train_manifest, - train_mode=True, - unit_type=config.unit_type, - batch_size=config.batch_size, - num_mel_bins=config.feat_dim, - frame_length=config.window_ms, - frame_shift=config.stride_ms, - dither=config.dither, - minlen_in=config.minlen_in, - maxlen_in=config.maxlen_in, - minlen_out=config.minlen_out, - maxlen_out=config.maxlen_out, - resample_rate=config.resample_rate, - augment_conf=config.augment_conf, # dict - shuffle_size=config.shuffle_size, - sort_size=config.sort_size, - n_iter_processes=config.num_workers, - prefetch_factor=config.prefetch_factor, - dist_sampler=config.get('dist_sampler', False), - cmvn_file=config.cmvn_file, - vocab_filepath=config.vocab_filepath, - ) - self.valid_loader = StreamDataLoader( - manifest_file=config.dev_manifest, - train_mode=False, - unit_type=config.unit_type, - batch_size=config.batch_size, - num_mel_bins=config.feat_dim, - frame_length=config.window_ms, - frame_shift=config.stride_ms, - dither=config.dither, - minlen_in=config.minlen_in, - maxlen_in=config.maxlen_in, - minlen_out=config.minlen_out, - maxlen_out=config.maxlen_out, - resample_rate=config.resample_rate, - augment_conf=config.augment_conf, # dict - shuffle_size=config.shuffle_size, - sort_size=config.sort_size, - n_iter_processes=config.num_workers, - prefetch_factor=config.prefetch_factor, - dist_sampler=config.get('dist_sampler', False), - cmvn_file=config.cmvn_file, - vocab_filepath=config.vocab_filepath, - ) - else: - self.train_loader = BatchDataLoader( - json_file=config.train_manifest, - train_mode=True, - sortagrad=config.sortagrad, - batch_size=config.batch_size, - maxlen_in=config.maxlen_in, - maxlen_out=config.maxlen_out, - minibatches=config.minibatches, - mini_batch_size=self.args.ngpu, - batch_count=config.batch_count, - batch_bins=config.batch_bins, - batch_frames_in=config.batch_frames_in, - batch_frames_out=config.batch_frames_out, - batch_frames_inout=config.batch_frames_inout, - preprocess_conf=config.preprocess_config, - n_iter_processes=config.num_workers, - subsampling_factor=1, - num_encs=1, - dist_sampler=config.get('dist_sampler', False), - shortest_first=False) - - self.valid_loader = BatchDataLoader( - json_file=config.dev_manifest, - train_mode=False, - sortagrad=False, - batch_size=config.batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=self.args.ngpu, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config.preprocess_config, - n_iter_processes=config.num_workers, - subsampling_factor=1, - num_encs=1, - dist_sampler=config.get('dist_sampler', False), - shortest_first=False) + self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args) + self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args) logger.info("Setup train/valid Dataloader!") else: decode_batch_size = config.get('decode', dict()).get( 'decode_batch_size', 1) - # test dataset, return raw text - if self.use_streamdata: - self.test_loader = StreamDataLoader( - manifest_file=config.test_manifest, - train_mode=False, - unit_type=config.unit_type, - batch_size=config.batch_size, - num_mel_bins=config.feat_dim, - frame_length=config.window_ms, - frame_shift=config.stride_ms, - dither=0.0, - minlen_in=0.0, - maxlen_in=float('inf'), - minlen_out=0, - maxlen_out=float('inf'), - resample_rate=config.resample_rate, - augment_conf=config.augment_conf, # dict - shuffle_size=config.shuffle_size, - sort_size=config.sort_size, - n_iter_processes=config.num_workers, - prefetch_factor=config.prefetch_factor, - dist_sampler=config.get('dist_sampler', False), - cmvn_file=config.cmvn_file, - vocab_filepath=config.vocab_filepath, - ) - self.align_loader = StreamDataLoader( - manifest_file=config.test_manifest, - train_mode=False, - unit_type=config.unit_type, - batch_size=config.batch_size, - num_mel_bins=config.feat_dim, - frame_length=config.window_ms, - frame_shift=config.stride_ms, - dither=0.0, - minlen_in=0.0, - maxlen_in=float('inf'), - minlen_out=0, - maxlen_out=float('inf'), - resample_rate=config.resample_rate, - augment_conf=config.augment_conf, # dict - shuffle_size=config.shuffle_size, - sort_size=config.sort_size, - n_iter_processes=config.num_workers, - prefetch_factor=config.prefetch_factor, - dist_sampler=config.get('dist_sampler', False), - cmvn_file=config.cmvn_file, - vocab_filepath=config.vocab_filepath, - ) - else: - self.test_loader = BatchDataLoader( - json_file=config.test_manifest, - train_mode=False, - sortagrad=False, - batch_size=decode_batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config.preprocess_config, - n_iter_processes=1, - subsampling_factor=1, - num_encs=1) - - self.align_loader = BatchDataLoader( - json_file=config.test_manifest, - train_mode=False, - sortagrad=False, - batch_size=decode_batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config.preprocess_config, - n_iter_processes=1, - subsampling_factor=1, - num_encs=1) + self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args) + self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args) logger.info("Setup test/align Dataloader!") def setup_model(self): diff --git a/paddlespeech/s2t/exps/u2_kaldi/model.py b/paddlespeech/s2t/exps/u2_kaldi/model.py index bc995977..cb015c11 100644 --- a/paddlespeech/s2t/exps/u2_kaldi/model.py +++ b/paddlespeech/s2t/exps/u2_kaldi/model.py @@ -25,7 +25,7 @@ from paddle import distributed as dist from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.frontend.utility import load_dict -from paddlespeech.s2t.io.dataloader import BatchDataLoader +from paddlespeech.s2t.io.dataloader import DataLoaderFactory from paddlespeech.s2t.models.u2 import U2Model from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.scheduler import LRSchedulerFactory @@ -104,7 +104,8 @@ class U2Trainer(Trainer): @paddle.no_grad() def valid(self): self.model.eval() - logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") valid_losses = defaultdict(list) num_seen_utts = 1 total_loss = 0.0 @@ -131,7 +132,8 @@ class U2Trainer(Trainer): msg = f"Valid: Rank: {dist.get_rank()}, " msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) - msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) + if not self.use_streamdata: + msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in valid_dump.items()) logger.info(msg) @@ -150,8 +152,8 @@ class U2Trainer(Trainer): # paddle.jit.save(script_model, script_model_path) self.before_train() - - logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.n_epoch: with Timer("Epoch-Train Time Cost: {}"): self.model.train() @@ -162,7 +164,8 @@ class U2Trainer(Trainer): msg = "Train: Rank: {}, ".format(dist.get_rank()) msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) - msg += "batch : {}/{}, ".format(batch_index + 1, + if not self.use_streamdata: + msg += "batch : {}/{}, ".format(batch_index + 1, len(self.train_loader)) msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) msg += "data time: {:>.3f}s, ".format(dataload_time) @@ -198,87 +201,23 @@ class U2Trainer(Trainer): self.new_epoch() def setup_dataloader(self): - config = self.config.clone() - # train/valid dataset, return token ids - self.train_loader = BatchDataLoader( - json_file=config.train_manifest, - train_mode=True, - sortagrad=False, - batch_size=config.batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=self.args.ngpu, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config.preprocess_config, - n_iter_processes=config.num_workers, - subsampling_factor=1, - num_encs=1) - - self.valid_loader = BatchDataLoader( - json_file=config.dev_manifest, - train_mode=False, - sortagrad=False, - batch_size=config.batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=self.args.ngpu, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=None, - n_iter_processes=config.num_workers, - subsampling_factor=1, - num_encs=1) - - decode_batch_size = config.get('decode', dict()).get( - 'decode_batch_size', 1) - # test dataset, return raw text - self.test_loader = BatchDataLoader( - json_file=config.test_manifest, - train_mode=False, - sortagrad=False, - batch_size=decode_batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=None, - n_iter_processes=1, - subsampling_factor=1, - num_encs=1) - - self.align_loader = BatchDataLoader( - json_file=config.test_manifest, - train_mode=False, - sortagrad=False, - batch_size=decode_batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=None, - n_iter_processes=1, - subsampling_factor=1, - num_encs=1) - logger.info("Setup train/valid/test/align Dataloader!") + self.use_streamdata = config.get("use_stream_data", False) + if self.train: + config = self.config.clone() + self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args) + config = self.config.clone() + config['preprocess_config'] = None + self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args) + logger.info("Setup train/valid Dataloader!") + else: + config = self.config.clone() + config['preprocess_config'] = None + self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args) + config = self.config.clone() + config['preprocess_config'] = None + self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args) + logger.info("Setup test/align Dataloader!") + def setup_model(self): config = self.config @@ -406,7 +345,8 @@ class U2Tester(U2Trainer): def test(self): assert self.args.result_file self.model.eval() - logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") stride_ms = self.config.stride_ms error_rate_type = None diff --git a/paddlespeech/s2t/exps/u2_st/model.py b/paddlespeech/s2t/exps/u2_st/model.py index 6a32eda7..60382543 100644 --- a/paddlespeech/s2t/exps/u2_st/model.py +++ b/paddlespeech/s2t/exps/u2_st/model.py @@ -25,7 +25,7 @@ import paddle from paddle import distributed as dist from paddlespeech.s2t.frontend.featurizer import TextFeaturizer -from paddlespeech.s2t.io.dataloader import BatchDataLoader +from paddlespeech.s2t.io.dataloader import DataLoaderFactory from paddlespeech.s2t.models.u2_st import U2STModel from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.reporter import ObsScope @@ -120,7 +120,8 @@ class U2STTrainer(Trainer): @paddle.no_grad() def valid(self): self.model.eval() - logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") valid_losses = defaultdict(list) num_seen_utts = 1 total_loss = 0.0 @@ -153,7 +154,8 @@ class U2STTrainer(Trainer): msg = f"Valid: Rank: {dist.get_rank()}, " msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) - msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) + if not self.use_streamdata: + msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in valid_dump.items()) logger.info(msg) @@ -172,8 +174,8 @@ class U2STTrainer(Trainer): # paddle.jit.save(script_model, script_model_path) self.before_train() - - logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.n_epoch: with Timer("Epoch-Train Time Cost: {}"): self.model.train() @@ -191,7 +193,8 @@ class U2STTrainer(Trainer): self.train_batch(batch_index, batch, msg) self.after_train_batch() report('iter', batch_index + 1) - report('total', len(self.train_loader)) + if not self.use_streamdata: + report('total', len(self.train_loader)) report('reader_cost', dataload_time) observation['batch_cost'] = observation[ 'reader_cost'] + observation['step_cost'] @@ -241,79 +244,18 @@ class U2STTrainer(Trainer): load_transcript = True if config.model_conf.asr_weight > 0 else False + config = self.config.clone() + config['load_transcript'] = load_transcript + self.use_streamdata = config.get("use_stream_data", False) if self.train: - # train/valid dataset, return token ids - self.train_loader = BatchDataLoader( - json_file=config.train_manifest, - train_mode=True, - sortagrad=False, - batch_size=config.batch_size, - maxlen_in=config.maxlen_in, - maxlen_out=config.maxlen_out, - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config. - preprocess_config, # aug will be off when train_mode=False - n_iter_processes=config.num_workers, - subsampling_factor=1, - load_aux_output=load_transcript, - num_encs=1, - dist_sampler=True) - - self.valid_loader = BatchDataLoader( - json_file=config.dev_manifest, - train_mode=False, - sortagrad=False, - batch_size=config.batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config. - preprocess_config, # aug will be off when train_mode=False - n_iter_processes=config.num_workers, - subsampling_factor=1, - load_aux_output=load_transcript, - num_encs=1, - dist_sampler=False) + self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args) + self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args) logger.info("Setup train/valid Dataloader!") else: - # test dataset, return raw text - decode_batch_size = config.get('decode', dict()).get( - 'decode_batch_size', 1) - self.test_loader = BatchDataLoader( - json_file=config.test_manifest, - train_mode=False, - sortagrad=False, - batch_size=decode_batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config. - preprocess_config, # aug will be off when train_mode=False - n_iter_processes=config.num_workers, - subsampling_factor=1, - num_encs=1, - dist_sampler=False) - + self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args) logger.info("Setup test Dataloader!") + def setup_model(self): config = self.config model_conf = config @@ -468,7 +410,8 @@ class U2STTester(U2STTrainer): def test(self): assert self.args.result_file self.model.eval() - logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") decode_cfg = self.config.decode bleu_func = bleu_score.char_bleu if decode_cfg.error_rate_type == 'char-bleu' else bleu_score.bleu diff --git a/paddlespeech/s2t/io/dataloader.py b/paddlespeech/s2t/io/dataloader.py index cb466ecb..83183024 100644 --- a/paddlespeech/s2t/io/dataloader.py +++ b/paddlespeech/s2t/io/dataloader.py @@ -30,9 +30,10 @@ from paddlespeech.s2t.io.reader import LoadInputsAndTargets from paddlespeech.s2t.utils.log import Log import paddlespeech.audio.streamdata as streamdata -from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.audio.text.text_featurizer import TextFeaturizer +from yacs.config import CfgNode -__all__ = ["BatchDataLoader"] +__all__ = ["BatchDataLoader", "StreamDataLoader"] logger = Log(__name__).getlog() @@ -60,12 +61,36 @@ def batch_collate(x): """ return x[0] +def read_preprocess_cfg(preprocess_conf_file): + augment_conf = dict() + preprocess_cfg = CfgNode(new_allowed=True) + preprocess_cfg.merge_from_file(preprocess_conf_file) + for idx, process in enumerate(preprocess_cfg["process"]): + opts = dict(process) + process_type = opts.pop("type") + if process_type == 'time_warp': + augment_conf['max_w'] = process['max_time_warp'] + augment_conf['w_inplace'] = process['inplace'] + augment_conf['w_mode'] = process['mode'] + if process_type == 'freq_mask': + augment_conf['max_f'] = process['F'] + augment_conf['num_f_mask'] = process['n_mask'] + augment_conf['f_inplace'] = process['inplace'] + augment_conf['f_replace_with_zero'] = process['replace_with_zero'] + if process_type == 'time_mask': + augment_conf['max_t'] = process['T'] + augment_conf['num_t_mask'] = process['n_mask'] + augment_conf['t_inplace'] = process['inplace'] + augment_conf['t_replace_with_zero'] = process['replace_with_zero'] + return augment_conf + class StreamDataLoader(): def __init__(self, manifest_file: str, train_mode: bool, unit_type: str='char', batch_size: int=0, + preprocess_conf=None, num_mel_bins=80, frame_length=25, frame_shift=10, @@ -75,7 +100,6 @@ class StreamDataLoader(): minlen_out: float=0.0, maxlen_out: float=float('inf'), resample_rate: int=16000, - augment_conf: dict=None, shuffle_size: int=10000, sort_size: int=1000, n_iter_processes: int=1, @@ -95,12 +119,27 @@ class StreamDataLoader(): self.feat_dim = num_mel_bins self.vocab_size = text_featurizer.vocab_size + augment_conf = read_preprocess_cfg(preprocess_conf) + # The list of shard shardlist = [] with open(manifest_file, "r") as f: for line in f.readlines(): shardlist.append(line.strip()) - + world_size = 1 + try: + world_size = paddle.distributed.get_world_size() + except Exception as e: + logger.warninig(e) + logger.warninig("can not get world_size using paddle.distributed.get_world_size(), use world_size=1") + assert(len(shardlist) >= world_size, "the length of shard list should >= number of gpus/xpus/...") + + update_n_iter_processes = int(max(min(len(shardlist)/world_size - 1, self.n_iter_processes), 0)) + logger.info(f"update_n_iter_processes {update_n_iter_processes}") + if update_n_iter_processes != self.n_iter_processes: + self.n_iter_processes = update_n_iter_processes + logger.info(f"change nun_workers to {self.n_iter_processes}") + if self.dist_sampler: base_dataset = streamdata.DataPipeline( streamdata.SimpleShardList(shardlist), @@ -116,16 +155,16 @@ class StreamDataLoader(): ) self.dataset = base_dataset.append_list( - streamdata.tokenize(symbol_table), - streamdata.data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_in), - streamdata.resample(resample_rate=resample_rate), - streamdata.compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither), - streamdata.spec_aug(**augment_conf) if train_mode else streamdata.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80) + streamdata.audio_tokenize(symbol_table), + streamdata.audio_data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_out), + streamdata.audio_resample(resample_rate=resample_rate), + streamdata.audio_compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither), + streamdata.audio_spec_aug(**augment_conf) if train_mode else streamdata.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80) streamdata.shuffle(shuffle_size), streamdata.sort(sort_size=sort_size), streamdata.batched(batch_size), - streamdata.padding(), - streamdata.cmvn(cmvn_file) + streamdata.audio_padding(), + streamdata.audio_cmvn(cmvn_file) ) if paddle.__version__ >= '2.3.2': @@ -295,3 +334,119 @@ class BatchDataLoader(): echo += f"shortest_first: {self.shortest_first}, " echo += f"file: {self.json_file}" return echo + + +class DataLoaderFactory(): + @staticmethod + def get_dataloader(mode: str, config, args): + config = config.clone() + use_streamdata = config.get("use_stream_data", False) + if use_streamdata: + if mode == 'train': + config['manifest'] = config.train_manifest + config['train_mode'] = True + elif mode == 'valid': + config['manifest'] = config.dev_manifest + config['train_mode'] = False + elif model == 'test' or mode == 'align': + config['manifest'] = config.test_manifest + config['train_mode'] = False + config['dither'] = 0.0 + config['minlen_in'] = 0.0 + config['maxlen_in'] = float('inf') + config['minlen_out'] = 0 + config['maxlen_out'] = float('inf') + config['dist_sampler'] = False + else: + raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'") + return StreamDataLoader( + manifest_file=config.manifest, + train_mode=config.train_mode, + unit_type=config.unit_type, + preprocess_conf=config.preprocess_config, + batch_size=config.batch_size, + num_mel_bins=config.feat_dim, + frame_length=config.window_ms, + frame_shift=config.stride_ms, + dither=config.dither, + minlen_in=config.minlen_in, + maxlen_in=config.maxlen_in, + minlen_out=config.minlen_out, + maxlen_out=config.maxlen_out, + resample_rate=config.resample_rate, + shuffle_size=config.shuffle_size, + sort_size=config.sort_size, + n_iter_processes=config.num_workers, + prefetch_factor=config.prefetch_factor, + dist_sampler=config.dist_sampler, + cmvn_file=config.cmvn_file, + vocab_filepath=config.vocab_filepath, + ) + else: + if mode == 'train': + config['manifest'] = config.train_manifest + config['train_mode'] = True + config['mini_batch_size'] = args.ngpu + config['subsampling_factor'] = 1 + config['num_encs'] = 1 + elif mode == 'valid': + config['manifest'] = config.dev_manifest + config['train_mode'] = False + config['sortagrad'] = False + config['maxlen_in'] = float('inf') + config['maxlen_out'] = float('inf') + config['minibatches'] = 0 + config['mini_batch_size'] = args.ngpu + config['batch_count'] = 'auto' + config['batch_bins'] = 0 + config['batch_frames_in'] = 0 + config['batch_frames_out'] = 0 + config['batch_frames_inout'] = 0 + config['subsampling_factor'] = 1 + config['num_encs'] = 1 + config['shortest_first'] = False + elif mode == 'test' or mode == 'align': + config['manifest'] = config.test_manifest + config['train_mode'] = False + config['sortagrad'] = False + config['batch_size'] = config.get('decode', dict()).get( + 'decode_batch_size', 1) + config['maxlen_in'] = float('inf') + config['maxlen_out'] = float('inf') + config['minibatches'] = 0 + config['mini_batch_size'] = 1 + config['batch_count'] = 'auto' + config['batch_bins'] = 0 + config['batch_frames_in'] = 0 + config['batch_frames_out'] = 0 + config['batch_frames_inout'] = 0 + config['num_workers'] = 1 + config['subsampling_factor'] = 1 + config['num_encs'] = 1 + config['dist_sampler'] = False + config['shortest_first'] = False + else: + raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'") + + return BatchDataLoader( + json_file=config.manifest, + train_mode=config.train_mode, + sortagrad=config.sortagrad, + batch_size=config.batch_size, + maxlen_in=config.maxlen_in, + maxlen_out=config.maxlen_out, + minibatches=config.minibatches, + mini_batch_size=config.mini_batch_size, + batch_count=config.batch_count, + batch_bins=config.batch_bins, + batch_frames_in=config.batch_frames_in, + batch_frames_out=config.batch_frames_out, + batch_frames_inout=config.batch_frames_inout, + preprocess_conf=config.preprocess_config, + n_iter_processes=config.num_workers, + subsampling_factor=config.subsampling_factor, + load_aux_output=config.get('load_transcript', None), + num_encs=config.num_encs, + dist_sampler=config.dist_sampler, + shortest_first=config.shortest_first) + diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index b4b61666..e3d0edb7 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -48,9 +48,9 @@ from paddlespeech.s2t.utils import checkpoint from paddlespeech.s2t.utils import layer_tools from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank from paddlespeech.s2t.utils.log import Log -from paddlespeech.s2t.utils.tensor_utils import add_sos_eos -from paddlespeech.s2t.utils.tensor_utils import pad_sequence -from paddlespeech.s2t.utils.tensor_utils import th_accuracy +from paddlespeech.audio.utils.tensor_utils import add_sos_eos +from paddlespeech.audio.utils.tensor_utils import pad_sequence +from paddlespeech.audio.utils.tensor_utils import th_accuracy from paddlespeech.s2t.utils.utility import log_add from paddlespeech.s2t.utils.utility import UpdateConfig diff --git a/paddlespeech/s2t/models/u2_st/u2_st.py b/paddlespeech/s2t/models/u2_st/u2_st.py index 6447753c..00ded912 100644 --- a/paddlespeech/s2t/models/u2_st/u2_st.py +++ b/paddlespeech/s2t/models/u2_st/u2_st.py @@ -38,8 +38,8 @@ from paddlespeech.s2t.modules.mask import subsequent_mask from paddlespeech.s2t.utils import checkpoint from paddlespeech.s2t.utils import layer_tools from paddlespeech.s2t.utils.log import Log -from paddlespeech.s2t.utils.tensor_utils import add_sos_eos -from paddlespeech.s2t.utils.tensor_utils import th_accuracy +from paddlespeech.audio.utils.tensor_utils import add_sos_eos +from paddlespeech.audio.utils.tensor_utils import th_accuracy from paddlespeech.s2t.utils.utility import UpdateConfig __all__ = ["U2STModel", "U2STInferModel"] -- GitLab