Unverified commit 00d76542, authored by Jackwaterveg, committed by GitHub

Merge pull request #864 from PaddlePaddle/collator

refactor st and asr collator
deepspeech/exps/u2_st/model.py (U2STTrainer):

```diff
@@ -28,10 +28,8 @@
 from paddle import distributed as dist
 from paddle.io import DataLoader
 from yacs.config import CfgNode
-from deepspeech.io.collator_st import KaldiPrePorocessedCollator
-from deepspeech.io.collator_st import SpeechCollator
-from deepspeech.io.collator_st import TripletKaldiPrePorocessedCollator
-from deepspeech.io.collator_st import TripletSpeechCollator
+from deepspeech.io.collator import SpeechCollator
+from deepspeech.io.collator import TripletSpeechCollator
 from deepspeech.io.dataset import ManifestDataset
 from deepspeech.io.dataset import TripletManifestDataset
 from deepspeech.io.sampler import SortagradBatchSampler
@@ -258,22 +256,13 @@ class U2STTrainer(Trainer):
         config.data.manifest = config.data.dev_manifest
         dev_dataset = Dataset.from_config(config)
 
-        if config.collator.raw_wav:
-            if config.model.model_conf.asr_weight > 0.:
-                Collator = TripletSpeechCollator
-                TestCollator = SpeechCollator
-            else:
-                TestCollator = Collator = SpeechCollator
-            # Not yet implement the mtl loader for raw_wav.
-        else:
-            if config.model.model_conf.asr_weight > 0.:
-                Collator = TripletKaldiPrePorocessedCollator
-                TestCollator = KaldiPrePorocessedCollator
-            else:
-                TestCollator = Collator = KaldiPrePorocessedCollator
+        if config.model.model_conf.asr_weight > 0.:
+            Collator = TripletSpeechCollator
+            TestCollator = SpeechCollator
+        else:
+            TestCollator = Collator = SpeechCollator
 
         collate_fn_train = Collator.from_config(config)
         config.collator.augmentation_config = ""
         collate_fn_dev = Collator.from_config(config)
```
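With the `raw_wav`/Kaldi split removed, collator selection now depends only on the ASR multi-task weight. A minimal sketch of the selection rule with stand-in classes (`pick_collators` is illustrative, not part of the repo):

```python
# Joint ST+ASR training (asr_weight > 0) needs triplet batches
# (audio, translation, transcription); evaluation always uses the plain collator.
class SpeechCollator: ...
class TripletSpeechCollator(SpeechCollator): ...

def pick_collators(asr_weight: float):
    if asr_weight > 0.:
        return TripletSpeechCollator, SpeechCollator  # train, test
    return SpeechCollator, SpeechCollator

assert pick_collators(0.5) == (TripletSpeechCollator, SpeechCollator)
assert pick_collators(0.0) == (SpeechCollator, SpeechCollator)
```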
deepspeech/frontend/audio.py (AudioSegment):

```diff
@@ -24,8 +24,10 @@
 import soundfile
 import soxbindings as sox
 from scipy import signal
 
+from .utility import subfile_from_tar
+
-class AudioSegment(object):
+class AudioSegment():
     """Monaural audio segment abstraction.
 
     :param samples: Audio samples [num_samples x num_channels].
@@ -68,16 +70,20 @@ class AudioSegment(object):
             self.duration, self.rms_db))
 
     @classmethod
-    def from_file(cls, file):
+    def from_file(cls, file, infos=None):
         """Create audio segment from audio file.
 
-        :param filepath: Filepath or file object to audio file.
-        :type filepath: str|file
-        :return: Audio segment instance.
-        :rtype: AudioSegment
+        Args:
+            file (str|file): Filepath or file object to audio file.
+            infos (TarLocalData, optional): tar2object and tar2info caches. Defaults to None.
+
+        Returns:
+            AudioSegment: Audio segment instance.
         """
         if isinstance(file, str) and re.findall(r".seqbin_\d+$", file):
             return cls.from_sequence_file(file)
+        elif isinstance(file, str) and file.startswith('tar:'):
+            return cls.from_file(subfile_from_tar(file, infos))
         else:
             samples, sample_rate = soundfile.read(file, dtype='float32')
             return cls(samples, sample_rate)
```
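After this change `AudioSegment.from_file` recognizes three path forms. A standalone mirror of the dispatch order, with made-up file names:

```python
import re

def dispatch(file):
    # same checks, in the same order, as AudioSegment.from_file
    if isinstance(file, str) and re.findall(r".seqbin_\d+$", file):
        return 'sequence-file'
    elif isinstance(file, str) and file.startswith('tar:'):
        return 'tar-subfile'
    return 'soundfile.read'

assert dispatch('feats.seqbin_3') == 'sequence-file'
assert dispatch('tar:corpus.tar#utt1.wav') == 'tar-subfile'
assert dispatch('utt1.wav') == 'soundfile.read'
```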
deepspeech/frontend/featurizer/speech_featurizer.py (SpeechFeaturizer):

```diff
@@ -64,8 +64,12 @@ class SpeechFeaturizer():
                  target_sample_rate=16000,
                  use_dB_normalization=True,
                  target_dB=-20,
-                 dither=1.0):
-        self._audio_featurizer = AudioFeaturizer(
+                 dither=1.0,
+                 maskctc=False):
+        self.stride_ms = stride_ms
+        self.window_ms = window_ms
+
+        self.audio_feature = AudioFeaturizer(
             specgram_type=specgram_type,
             feat_dim=feat_dim,
             delta_delta=delta_delta,
@@ -77,8 +81,12 @@ class SpeechFeaturizer():
             use_dB_normalization=use_dB_normalization,
             target_dB=target_dB,
             dither=dither)
-        self._text_featurizer = TextFeaturizer(unit_type, vocab_filepath,
-                                               spm_model_prefix)
+
+        self.text_feature = TextFeaturizer(
+            unit_type=unit_type,
+            vocab_filepath=vocab_filepath,
+            spm_model_prefix=spm_model_prefix,
+            maskctc=maskctc)
 
     def featurize(self, speech_segment, keep_transcription_text):
         """Extract features for speech segment.
@@ -94,60 +102,33 @@ class SpeechFeaturizer():
         Returns:
             tuple: 1) spectrogram audio feature in 2darray, 2) list of token indices.
         """
-        spec_feature = self._audio_featurizer.featurize(speech_segment)
+        spec_feature = self.audio_feature.featurize(speech_segment)
         if keep_transcription_text:
             return spec_feature, speech_segment.transcript
         if speech_segment.has_token:
             text_ids = speech_segment.token_ids
         else:
-            text_ids = self._text_featurizer.featurize(
-                speech_segment.transcript)
+            text_ids = self.text_feature.featurize(speech_segment.transcript)
         return spec_feature, text_ids
 
-    @property
-    def vocab_size(self):
-        """Return the vocabulary size.
-        Returns:
-            int: Vocabulary size.
-        """
-        return self._text_featurizer.vocab_size
-
-    @property
-    def vocab_list(self):
-        """Return the vocabulary in list.
-        Returns:
-            List[str]:
-        """
-        return self._text_featurizer.vocab_list
-
-    @property
-    def vocab_dict(self):
-        """Return the vocabulary in dict.
-        Returns:
-            Dict[str, int]:
-        """
-        return self._text_featurizer.vocab_dict
-
-    @property
-    def feature_size(self):
-        """Return the audio feature size.
-        Returns:
-            int: audio feature size.
-        """
-        return self._audio_featurizer.feature_size
-
-    @property
-    def stride_ms(self):
-        """Time length in `ms` unit per frame.
-        Returns:
-            float: time(ms)/frame
-        """
-        return self._audio_featurizer.stride_ms
-
-    @property
-    def text_feature(self):
-        """Return the text feature object.
-        Returns:
-            TextFeaturizer: object.
-        """
-        return self._text_featurizer
+    def text_featurize(self, text, keep_transcription_text):
+        """Extract features for a transcript: keep the original text string,
+        or convert it to a list of token indices.
+
+        Args:
+            text (str): text.
+            keep_transcription_text (bool): if True, keep the raw text; otherwise return token ids.
+
+        Returns:
+            (str|List[int]): text, or list of token indices.
+        """
+        if keep_transcription_text:
+            return text
+        text_ids = self.text_feature.featurize(text)
+        return text_ids
```
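The new `text_featurize` is the text-only half of `featurize`; the triplet collator uses it to tokenize the second text stream without re-extracting audio features. A toy round trip, faking `TextFeaturizer` with a char table since the real one needs a vocab file:

```python
class FakeTextFeature:
    vocab = {'h': 1, 'i': 2}
    def featurize(self, text):
        return [self.vocab[ch] for ch in text]

def text_featurize(text_feature, text, keep_transcription_text):
    if keep_transcription_text:
        return text                       # raw string, e.g. for decoding/metrics
    return text_feature.featurize(text)  # token ids for training

assert text_featurize(FakeTextFeature(), 'hi', True) == 'hi'
assert text_featurize(FakeTextFeature(), 'hi', False) == [1, 2]
```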
deepspeech/frontend/speech.py (SpeechSegment):

```diff
@@ -68,7 +68,12 @@ class SpeechSegment(AudioSegment):
         return not self.__eq__(other)
 
     @classmethod
-    def from_file(cls, filepath, transcript, tokens=None, token_ids=None):
+    def from_file(cls,
+                  filepath,
+                  transcript,
+                  tokens=None,
+                  token_ids=None,
+                  infos=None):
         """Create speech segment from audio file and corresponding transcript.
 
         Args:
@@ -76,12 +81,12 @@ class SpeechSegment(AudioSegment):
             transcript (str): Transcript text for the speech.
             tokens (List[str], optional): text tokens. Defaults to None.
             token_ids (List[int], optional): text token ids. Defaults to None.
+            infos (TarLocalData, optional): tar2object and tar2info caches. Defaults to None.
 
         Returns:
             SpeechSegment: Speech segment instance.
         """
-        audio = AudioSegment.from_file(filepath)
+        audio = AudioSegment.from_file(filepath, infos)
         return cls(audio.samples, audio.sample_rate, transcript, tokens,
                    token_ids)
```
deepspeech/frontend/utility.py:

```diff
@@ -14,6 +14,7 @@
 """Contains data helper functions."""
 import json
 import math
+import tarfile
 from typing import List
 from typing import Optional
 from typing import Text
@@ -112,6 +113,51 @@ def read_manifest(
     return manifest
 
 
+# Tar file read
+TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
+
+
+def parse_tar(file):
+    """Parse a tar file into a tarfile object and a map of its tarinfos."""
+    result = {}
+    f = tarfile.open(file)
+    for tarinfo in f.getmembers():
+        result[tarinfo.name] = tarinfo
+    return f, result
+
+
+def subfile_from_tar(file, local_data=None):
+    """Get a subfile object from a tar archive, given a path of the form
+    `tar:tarpath#filename`.
+
+    Returns the subfile object and caches the opened tar file and its
+    member info for the next read request.
+    """
+    tarpath, filename = file.split(':', 1)[1].split('#', 1)
+
+    if local_data is None:
+        local_data = TarLocalData(tar2info={}, tar2object={})
+    assert isinstance(local_data, TarLocalData)
+
+    if 'tar2info' not in local_data.__dict__:
+        local_data.tar2info = {}
+    if 'tar2object' not in local_data.__dict__:
+        local_data.tar2object = {}
+
+    if tarpath not in local_data.tar2info:
+        fobj, infos = parse_tar(tarpath)
+        local_data.tar2info[tarpath] = infos
+        local_data.tar2object[tarpath] = fobj
+    else:
+        fobj = local_data.tar2object[tarpath]
+        infos = local_data.tar2info[tarpath]
+    return fobj.extractfile(infos[filename])
+
+
 def rms_to_db(rms: float):
     """Root Mean Square to dB.
```
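A runnable sketch of the `tar:tarpath#filename` convention and the cache layout: build a tiny in-memory tar, then resolve a member the way `subfile_from_tar` does (member names invented):

```python
import io
import tarfile
from collections import namedtuple

TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])

buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode='w') as t:
    payload = b'fake-wav-bytes'
    info = tarfile.TarInfo('utt1.wav')
    info.size = len(payload)
    t.addfile(info, io.BytesIO(payload))
buf.seek(0)

uri = 'tar:train.tar#utt1.wav'
tarpath, filename = uri.split(':', 1)[1].split('#', 1)

local = TarLocalData(tar2info={}, tar2object={})
if tarpath not in local.tar2info:              # first request: open + index members
    fobj = tarfile.open(fileobj=buf)           # the real code opens tarpath on disk
    local.tar2info[tarpath] = {m.name: m for m in fobj.getmembers()}
    local.tar2object[tarpath] = fobj
fobj = local.tar2object[tarpath]               # later requests hit the cache
assert fobj.extractfile(local.tar2info[tarpath][filename]).read() == payload
```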
deepspeech/io/collator.py:

```diff
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import io
-from collections import namedtuple
 from typing import Optional
 
 import numpy as np
@@ -23,96 +22,17 @@ from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
 from deepspeech.frontend.normalizer import FeatureNormalizer
 from deepspeech.frontend.speech import SpeechSegment
 from deepspeech.frontend.utility import IGNORE_ID
+from deepspeech.frontend.utility import TarLocalData
+from deepspeech.io.reader import LoadInputsAndTargets
 from deepspeech.io.utility import pad_list
 from deepspeech.utils.log import Log
 
-__all__ = ["SpeechCollator"]
+__all__ = ["SpeechCollator", "TripletSpeechCollator"]
 
 logger = Log(__name__).getlog()
 
-# namedtuple need global for pickle.
-TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
-
-
-class SpeechCollator():
-    @classmethod
-    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
-        default = CfgNode(
-            dict(
-                augmentation_config="",
-                random_seed=0,
-                mean_std_filepath="",
-                unit_type="char",
-                vocab_filepath="",
-                spm_model_prefix="",
-                specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
-                feat_dim=0,  # 'mfcc', 'fbank'
-                delta_delta=False,  # 'mfcc', 'fbank'
-                stride_ms=10.0,  # ms
-                window_ms=20.0,  # ms
-                n_fft=None,  # fft points
-                max_freq=None,  # None for samplerate/2
-                target_sample_rate=16000,  # target sample rate
-                use_dB_normalization=True,
-                target_dB=-20,
-                dither=1.0,  # feature dither
-                keep_transcription_text=False))
-
-        if config is not None:
-            config.merge_from_other_cfg(default)
-        return default
-
-    @classmethod
-    def from_config(cls, config):
-        """Build a SpeechCollator object from a config.
-
-        Args:
-            config (yacs.config.CfgNode): configs object.
-
-        Returns:
-            SpeechCollator: collator object.
-        """
-        assert 'augmentation_config' in config.collator
-        assert 'keep_transcription_text' in config.collator
-        assert 'mean_std_filepath' in config.collator
-        assert 'vocab_filepath' in config.collator
-        assert 'specgram_type' in config.collator
-        assert 'n_fft' in config.collator
-        assert config.collator
-
-        if isinstance(config.collator.augmentation_config, (str, bytes)):
-            if config.collator.augmentation_config:
-                aug_file = io.open(
-                    config.collator.augmentation_config,
-                    mode='r',
-                    encoding='utf8')
-            else:
-                aug_file = io.StringIO(initial_value='{}', newline='')
-        else:
-            aug_file = config.collator.augmentation_config
-            assert isinstance(aug_file, io.StringIO)
-
-        speech_collator = cls(
-            aug_file=aug_file,
-            random_seed=0,
-            mean_std_filepath=config.collator.mean_std_filepath,
-            unit_type=config.collator.unit_type,
-            vocab_filepath=config.collator.vocab_filepath,
-            spm_model_prefix=config.collator.spm_model_prefix,
-            specgram_type=config.collator.specgram_type,
-            feat_dim=config.collator.feat_dim,
-            delta_delta=config.collator.delta_delta,
-            stride_ms=config.collator.stride_ms,
-            window_ms=config.collator.window_ms,
-            n_fft=config.collator.n_fft,
-            max_freq=config.collator.max_freq,
-            target_sample_rate=config.collator.target_sample_rate,
-            use_dB_normalization=config.collator.use_dB_normalization,
-            target_dB=config.collator.target_dB,
-            dither=config.collator.dither,
-            keep_transcription_text=config.collator.keep_transcription_text)
-        return speech_collator
-
+class SpeechCollatorBase():
     def __init__(
             self,
             aug_file,
```
```diff
@@ -121,7 +41,7 @@ class SpeechCollator():
             spm_model_prefix,
             random_seed=0,
             unit_type="char",
-            specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
+            spectrum_type='linear',  # 'linear', 'mfcc', 'fbank'
             feat_dim=0,  # 'mfcc', 'fbank'
             delta_delta=False,  # 'mfcc', 'fbank'
             stride_ms=10.0,  # ms
@@ -146,7 +66,7 @@ class SpeechCollator():
             n_fft (int, optional): fft points for rfft. Defaults to None.
             max_freq (int, optional): max cut freq. Defaults to None.
             target_sample_rate (int, optional): target sample rate which is used for training. Defaults to 16000.
-            specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
+            spectrum_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
             feat_dim (int, optional): audio feature dim, used by 'mfcc' or 'fbank'. Defaults to None.
             delta_delta (bool, optional): audio feature with delta-delta, used by 'fbank' or 'mfcc'. Defaults to False.
             use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
@@ -159,23 +79,27 @@ class SpeechCollator():
         Padding audio features with zeros to make them have the same shape (or
         a user-defined shape) within one batch.
         """
-        self._keep_transcription_text = keep_transcription_text
+        self.keep_transcription_text = keep_transcription_text
+        self.stride_ms = stride_ms
+        self.window_ms = window_ms
+        self.feat_dim = feat_dim
+
+        self.loader = LoadInputsAndTargets()
 
         # only for tar filetype
         self._local_data = TarLocalData(tar2info={}, tar2object={})
 
-        self._augmentation_pipeline = AugmentationPipeline(
+        self.augmentation = AugmentationPipeline(
             augmentation_config=aug_file.read(), random_seed=random_seed)
 
         self._normalizer = FeatureNormalizer(
             mean_std_filepath) if mean_std_filepath else None
 
-        self._stride_ms = stride_ms
-        self._target_sample_rate = target_sample_rate
-
         self._speech_featurizer = SpeechFeaturizer(
             unit_type=unit_type,
             vocab_filepath=vocab_filepath,
             spm_model_prefix=spm_model_prefix,
-            specgram_type=specgram_type,
+            spectrum_type=spectrum_type,
             feat_dim=feat_dim,
             delta_delta=delta_delta,
             stride_ms=stride_ms,
@@ -187,33 +111,11 @@ class SpeechCollator():
             target_dB=target_dB,
             dither=dither)
 
-    def _parse_tar(self, file):
-        """Parse a tar file to get a tarfile object
-        and a map containing tarinfoes
-        """
-        result = {}
-        f = tarfile.open(file)
-        for tarinfo in f.getmembers():
-            result[tarinfo.name] = tarinfo
-        return f, result
-
-    def _subfile_from_tar(self, file):
-        """Get subfile object from tar.
-
-        It will return a subfile object from tar file
-        and cached tar file info for next reading request.
-        """
-        tarpath, filename = file.split(':', 1)[1].split('#', 1)
-        if 'tar2info' not in self._local_data.__dict__:
-            self._local_data.tar2info = {}
-        if 'tar2object' not in self._local_data.__dict__:
-            self._local_data.tar2object = {}
-        if tarpath not in self._local_data.tar2info:
-            object, infoes = self._parse_tar(tarpath)
-            self._local_data.tar2info[tarpath] = infoes
-            self._local_data.tar2object[tarpath] = object
-        return self._local_data.tar2object[tarpath].extractfile(
-            self._local_data.tar2info[tarpath][filename])
+        self.feature_size = self._speech_featurizer.audio_feature.feature_size
+        self.text_feature = self._speech_featurizer.text_feature
+        self.vocab_dict = self.text_feature.vocab_dict
+        self.vocab_list = self.text_feature.vocab_list
+        self.vocab_size = self.text_feature.vocab_size
 
     def process_utterance(self, audio_file, transcript):
         """Load, augment, featurize and normalize for speech data.
```
```diff
@@ -226,23 +128,36 @@ class SpeechCollator():
         :param audio_file: Filepath or file object of audio file.
         :type audio_file: str | file
         :param transcript: Transcription text.
         :type transcript: str
         :return: Tuple of audio feature tensor and data of transcription part,
                  where transcription part could be token ids or text.
         :rtype: tuple of (2darray, list)
         """
-        if isinstance(audio_file, str) and audio_file.startswith('tar:'):
-            speech_segment = SpeechSegment.from_file(
-                self._subfile_from_tar(audio_file), transcript)
-        else:
-            speech_segment = SpeechSegment.from_file(audio_file, transcript)
+        filetype = self.loader.file_type(audio_file)
 
-        # audio augment
-        self._augmentation_pipeline.transform_audio(speech_segment)
+        if filetype != 'sound':
+            spectrum = self.loader._get_from_loader(audio_file, filetype)
+            feat_dim = spectrum.shape[1]
+            assert feat_dim == self.feat_dim, f"expect feat dim {self.feat_dim}, but got {feat_dim}"
 
-        specgram, transcript_part = self._speech_featurizer.featurize(
-            speech_segment, self._keep_transcription_text)
-        if self._normalizer:
-            specgram = self._normalizer.apply(specgram)
+            if self.keep_transcription_text:
+                transcript_part = transcript
+            else:
+                text_ids = self.text_feature.featurize(transcript)
+                transcript_part = text_ids
+        else:
+            # read audio
+            speech_segment = SpeechSegment.from_file(
+                audio_file, transcript, infos=self._local_data)
+            # audio augment
+            self.augmentation.transform_audio(speech_segment)
 
-        # specgram augment
-        specgram = self._augmentation_pipeline.transform_feature(specgram)
-        return specgram, transcript_part
+            # extract speech feature
+            spectrum, transcript_part = self._speech_featurizer.featurize(
+                speech_segment, self.keep_transcription_text)
+            # CMVN spectrum
+            if self._normalizer:
+                spectrum = self._normalizer.apply(spectrum)
 
+            # spectrum augment
+            spectrum = self.augmentation.transform_feature(spectrum)
+        return spectrum, transcript_part
```
```diff
@@ -272,16 +187,14 @@ class SpeechCollator():
             audios.append(audio)  # [T, D]
             audio_lens.append(audio.shape[0])
             # text
-            # for training, text is token ids
-            # else text is string, convert to unicode ord
+            # for training, text is token ids; else text is a string, converted to unicode ord
             tokens = []
-            if self._keep_transcription_text:
+            if self.keep_transcription_text:
                 assert isinstance(text, str), (type(text), text)
                 tokens = [ord(t) for t in text]
             else:
                 tokens = text  # token ids
-            tokens = tokens if isinstance(tokens, np.ndarray) else np.array(
-                tokens, dtype=np.int64)
+            tokens = np.array(tokens, dtype=np.int64)
             texts.append(tokens)
             text_lens.append(tokens.shape[0])
```
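When `keep_transcription_text` is true (e.g. at test time), strings ride through the int64 batch as unicode code points and can be recovered with `chr()`. A minimal round trip:

```python
import numpy as np

text = "今天 hi"
tokens = np.array([ord(t) for t in text], dtype=np.int64)  # what the collator stores
assert ''.join(chr(int(t)) for t in tokens) == text        # recovery on the consumer side
```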
```diff
@@ -292,26 +205,162 @@ class SpeechCollator():
         olens = np.array(text_lens).astype(np.int64)
         return utts, xs_pad, ilens, ys_pad, olens
 
-    @property
-    def vocab_size(self):
-        return self._speech_featurizer.vocab_size
-
-    @property
-    def vocab_list(self):
-        return self._speech_featurizer.vocab_list
-
-    @property
-    def vocab_dict(self):
-        return self._speech_featurizer.vocab_dict
-
-    @property
-    def text_feature(self):
-        return self._speech_featurizer.text_feature
-
-    @property
-    def feature_size(self):
-        return self._speech_featurizer.feature_size
-
-    @property
-    def stride_ms(self):
-        return self._speech_featurizer.stride_ms
+
+class SpeechCollator(SpeechCollatorBase):
+    @classmethod
+    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
+        default = CfgNode(
+            dict(
+                augmentation_config="",
+                random_seed=0,
+                mean_std_filepath="",
+                unit_type="char",
+                vocab_filepath="",
+                spm_model_prefix="",
+                spectrum_type='linear',  # 'linear', 'mfcc', 'fbank'
+                feat_dim=0,  # 'mfcc', 'fbank'
+                delta_delta=False,  # 'mfcc', 'fbank'
+                stride_ms=10.0,  # ms
+                window_ms=20.0,  # ms
+                n_fft=None,  # fft points
+                max_freq=None,  # None for samplerate/2
+                target_sample_rate=16000,  # target sample rate
+                use_dB_normalization=True,
+                target_dB=-20,
+                dither=1.0,  # feature dither
+                keep_transcription_text=False))
+
+        if config is not None:
+            config.merge_from_other_cfg(default)
+        return default
+
+    @classmethod
+    def from_config(cls, config):
+        """Build a SpeechCollator object from a config.
+
+        Args:
+            config (yacs.config.CfgNode): configs object.
+
+        Returns:
+            SpeechCollator: collator object.
+        """
+        assert 'augmentation_config' in config.collator
+        assert 'keep_transcription_text' in config.collator
+        assert 'mean_std_filepath' in config.collator
+        assert 'vocab_filepath' in config.collator
+        assert 'spectrum_type' in config.collator
+        assert 'n_fft' in config.collator
+        assert config.collator
+
+        if isinstance(config.collator.augmentation_config, (str, bytes)):
+            if config.collator.augmentation_config:
+                aug_file = io.open(
+                    config.collator.augmentation_config,
+                    mode='r',
+                    encoding='utf8')
+            else:
+                aug_file = io.StringIO(initial_value='{}', newline='')
+        else:
+            aug_file = config.collator.augmentation_config
+            assert isinstance(aug_file, io.StringIO)
+
+        speech_collator = cls(
+            aug_file=aug_file,
+            random_seed=0,
+            mean_std_filepath=config.collator.mean_std_filepath,
+            unit_type=config.collator.unit_type,
+            vocab_filepath=config.collator.vocab_filepath,
+            spm_model_prefix=config.collator.spm_model_prefix,
+            spectrum_type=config.collator.spectrum_type,
+            feat_dim=config.collator.feat_dim,
+            delta_delta=config.collator.delta_delta,
+            stride_ms=config.collator.stride_ms,
+            window_ms=config.collator.window_ms,
+            n_fft=config.collator.n_fft,
+            max_freq=config.collator.max_freq,
+            target_sample_rate=config.collator.target_sample_rate,
+            use_dB_normalization=config.collator.use_dB_normalization,
+            target_dB=config.collator.target_dB,
+            dither=config.collator.dither,
+            keep_transcription_text=config.collator.keep_transcription_text)
+        return speech_collator
+
+
+class TripletSpeechCollator(SpeechCollator):
+    def process_utterance(self, audio_file, translation, transcript):
+        """Load, augment, featurize and normalize for speech data.
+
+        :param audio_file: Filepath or file object of audio file.
+        :type audio_file: str | file
+        :param translation: translation text.
+        :type translation: str
+        :param transcript: transcription text.
+        :type transcript: str
+        :return: Tuple of audio feature tensor and data of translation and
+                 transcription parts, which could be token ids or text.
+        :rtype: tuple of (2darray, list, list)
+        """
+        spectrum, translation_part = super().process_utterance(audio_file,
+                                                               translation)
+        transcript_part = self._speech_featurizer.text_featurize(
+            transcript, self.keep_transcription_text)
+        return spectrum, translation_part, transcript_part
+
+    def __call__(self, batch):
+        """batch examples
+
+        Args:
+            batch ([List]): batch is (utt, audio, translation, transcription)
+                audio (np.ndarray) shape (T, D)
+                translation (List[int] or str): shape (U,)
+                transcription (List[int] or str): shape (V,)
+
+        Returns:
+            tuple: batched data as
+                audio : (B, Tmax, D)
+                audio_lens: (B,)
+                texts : ((B, Umax), (B, Vmax))
+                text_lens: ((B,), (B,))
+        """
+        audios = []
+        audio_lens = []
+        translation_text = []
+        translation_text_lens = []
+        transcription_text = []
+        transcription_text_lens = []
+
+        utts = []
+        for utt, audio, translation, transcription in batch:
+            audio, translation, transcription = self.process_utterance(
+                audio, translation, transcription)
+            #utt
+            utts.append(utt)
+            # audio
+            audios.append(audio)  # [T, D]
+            audio_lens.append(audio.shape[0])
+            # text
+            # for training, text is token ids; else text is a string, converted to unicode ord
+            tokens = [[], []]
+            for idx, text in enumerate([translation, transcription]):
+                if self.keep_transcription_text:
+                    assert isinstance(text, str), (type(text), text)
+                    tokens[idx] = [ord(t) for t in text]
+                else:
+                    tokens[idx] = text  # token ids
+                tokens[idx] = np.array(tokens[idx], dtype=np.int64)
+            translation_text.append(tokens[0])
+            translation_text_lens.append(tokens[0].shape[0])
+            transcription_text.append(tokens[1])
+            transcription_text_lens.append(tokens[1].shape[0])
+
+        padded_audios = pad_sequence(
+            audios, padding_value=0.0).astype(np.float32)  #[B, T, D]
+        audio_lens = np.array(audio_lens).astype(np.int64)
+        padded_translation = pad_sequence(
+            translation_text, padding_value=IGNORE_ID).astype(np.int64)
+        translation_lens = np.array(translation_text_lens).astype(np.int64)
+        padded_transcription = pad_sequence(
+            transcription_text, padding_value=IGNORE_ID).astype(np.int64)
+        transcription_lens = np.array(transcription_text_lens).astype(np.int64)
+        return utts, padded_audios, audio_lens, (
+            padded_translation, padded_transcription), (translation_lens,
+                                                        transcription_lens)
```
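Shape sketch of the triplet collator's output, with a simplified `pad_sequence` standing in for the helper in `deepspeech.io.utility`:

```python
import numpy as np

IGNORE_ID = -1

def pad_sequence(seqs, padding_value=0.0):
    """Right-pad a list of [T, ...] arrays to a common length (simplified)."""
    maxlen = max(s.shape[0] for s in seqs)
    out = np.full((len(seqs), maxlen) + seqs[0].shape[1:], padding_value,
                  dtype=seqs[0].dtype)
    for i, s in enumerate(seqs):
        out[i, :s.shape[0]] = s
    return out

audios = [np.zeros((5, 80), np.float32), np.zeros((3, 80), np.float32)]
translations = [np.array([4, 7, 9]), np.array([2])]
transcriptions = [np.array([1, 3]), np.array([6, 8, 5])]

padded_audios = pad_sequence(audios, 0.0)                  # (B, Tmax, D) = (2, 5, 80)
texts = (pad_sequence(translations, IGNORE_ID),            # (B, Umax), -1 padded
         pad_sequence(transcriptions, IGNORE_ID))          # (B, Vmax), -1 padded
lens = (np.array([3, 1]), np.array([2, 3]))
print(padded_audios.shape, texts[0], texts[1], lens, sep='\n')
```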
(The diff for one file is collapsed.)
deepspeech/io/reader.py (LoadInputsAndTargets):

```diff
@@ -321,6 +321,22 @@ class LoadInputsAndTargets():
             raise NotImplementedError(
                 "Not supported: loader_type={}".format(filetype))
 
+    def file_type(self, filepath):
+        suffix = filepath.split(":")[0].split('.')[1]
+        if suffix == 'ark':
+            return 'mat'
+        elif suffix == 'scp':
+            return 'scp'
+        elif suffix == 'npy':
+            return 'npy'
+        elif suffix == 'npz':
+            return 'npz'
+        elif suffix in ['wav', 'flac']:
+            # PCM16
+            return 'sound'
+        else:
+            raise ValueError(f"Not support filetype: {suffix}")
+
 
 class SoundHDF5File():
     """Collecting sound files to a HDF5 file
```
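A quick standalone check of the suffix → loader-type mapping (re-implemented here for illustration); note that a Kaldi-style `feats.ark:41` keeps its suffix because the `:offset` part is split off first:

```python
def file_type(filepath):
    # same suffix extraction as LoadInputsAndTargets.file_type
    suffix = filepath.split(":")[0].split('.')[1]
    return {'ark': 'mat', 'scp': 'scp', 'npy': 'npy', 'npz': 'npz',
            'wav': 'sound', 'flac': 'sound'}[suffix]

assert file_type('feats.ark:41') == 'mat'
assert file_type('feats.scp') == 'scp'
assert file_type('utt1.flac') == 'sound'
```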