未验证 提交 00d76542 编写于 作者: J Jackwaterveg 提交者: GitHub

Merge pull request #864 from PaddlePaddle/collator

refactor st and asr collator
......@@ -28,10 +28,8 @@ from paddle import distributed as dist
from paddle.io import DataLoader
from yacs.config import CfgNode
from deepspeech.io.collator_st import KaldiPrePorocessedCollator
from deepspeech.io.collator_st import SpeechCollator
from deepspeech.io.collator_st import TripletKaldiPrePorocessedCollator
from deepspeech.io.collator_st import TripletSpeechCollator
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.collator import TripletSpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.io.dataset import TripletManifestDataset
from deepspeech.io.sampler import SortagradBatchSampler
......@@ -258,22 +256,13 @@ class U2STTrainer(Trainer):
config.data.manifest = config.data.dev_manifest
dev_dataset = Dataset.from_config(config)
if config.collator.raw_wav:
if config.model.model_conf.asr_weight > 0.:
Collator = TripletSpeechCollator
TestCollator = SpeechCollator
else:
TestCollator = Collator = SpeechCollator
# Not yet implement the mtl loader for raw_wav.
else:
if config.model.model_conf.asr_weight > 0.:
Collator = TripletKaldiPrePorocessedCollator
TestCollator = KaldiPrePorocessedCollator
else:
TestCollator = Collator = KaldiPrePorocessedCollator
collate_fn_train = Collator.from_config(config)
config.collator.augmentation_config = ""
collate_fn_dev = Collator.from_config(config)
......
......@@ -24,8 +24,10 @@ import soundfile
import soxbindings as sox
from scipy import signal
from .utility import subfile_from_tar
class AudioSegment(object):
class AudioSegment():
"""Monaural audio segment abstraction.
:param samples: Audio samples [num_samples x num_channels].
......@@ -68,16 +70,20 @@ class AudioSegment(object):
self.duration, self.rms_db))
@classmethod
def from_file(cls, file):
def from_file(cls, file, infos=None):
"""Create audio segment from audio file.
:param filepath: Filepath or file object to audio file.
:type filepath: str|file
:return: Audio segment instance.
:rtype: AudioSegment
Args:
filepath (str|file): Filepath or file object to audio file.
infos (TarLocalData, optional): tar2obj and tar2infos. Defaults to None.
Returns:
AudioSegment: Audio segment instance.
"""
if isinstance(file, str) and re.findall(r".seqbin_\d+$", file):
return cls.from_sequence_file(file)
elif isinstance(file, str) and file.startswith('tar:'):
return cls.from_file(subfile_from_tar(file, infos))
else:
samples, sample_rate = soundfile.read(file, dtype='float32')
return cls(samples, sample_rate)
......
......@@ -64,8 +64,12 @@ class SpeechFeaturizer():
target_sample_rate=16000,
use_dB_normalization=True,
target_dB=-20,
dither=1.0):
self._audio_featurizer = AudioFeaturizer(
dither=1.0,
maskctc=False):
self.stride_ms = stride_ms
self.window_ms = window_ms
self.audio_feature = AudioFeaturizer(
specgram_type=specgram_type,
feat_dim=feat_dim,
delta_delta=delta_delta,
......@@ -77,8 +81,12 @@ class SpeechFeaturizer():
use_dB_normalization=use_dB_normalization,
target_dB=target_dB,
dither=dither)
self._text_featurizer = TextFeaturizer(unit_type, vocab_filepath,
spm_model_prefix)
self.text_feature = TextFeaturizer(
unit_type=unit_type,
vocab_filepath=vocab_filepath,
spm_model_prefix=spm_model_prefix,
maskctc=maskctc)
def featurize(self, speech_segment, keep_transcription_text):
"""Extract features for speech segment.
......@@ -94,60 +102,33 @@ class SpeechFeaturizer():
Returns:
tuple: 1) spectrogram audio feature in 2darray, 2) list oftoken indices.
"""
spec_feature = self._audio_featurizer.featurize(speech_segment)
spec_feature = self.audio_feature.featurize(speech_segment)
if keep_transcription_text:
return spec_feature, speech_segment.transcript
if speech_segment.has_token:
text_ids = speech_segment.token_ids
else:
text_ids = self._text_featurizer.featurize(
speech_segment.transcript)
text_ids = self.text_feature.featurize(speech_segment.transcript)
return spec_feature, text_ids
@property
def vocab_size(self):
"""Return the vocabulary size.
Returns:
int: Vocabulary size.
"""
return self._text_featurizer.vocab_size
@property
def vocab_list(self):
"""Return the vocabulary in list.
Returns:
List[str]:
"""
return self._text_featurizer.vocab_list
def text_featurize(self, text, keep_transcription_text):
"""Extract features for speech segment.
@property
def vocab_dict(self):
"""Return the vocabulary in dict.
Returns:
Dict[str, int]:
"""
return self._text_featurizer.vocab_dict
1. For audio parts, extract the audio features.
2. For transcript parts, keep the original text or convert text string
to a list of token indices in char-level.
@property
def feature_size(self):
"""Return the audio feature size.
Returns:
int: audio feature size.
"""
return self._audio_featurizer.feature_size
Args:
text (str): text.
keep_transcription_text (bool): True, keep transcript text, False, token ids
@property
def stride_ms(self):
"""time length in `ms` unit per frame
Returns:
float: time(ms)/frame
(str|List[int]): text, or list of token indices.
"""
return self._audio_featurizer.stride_ms
if keep_transcription_text:
return text
@property
def text_feature(self):
"""Return the text feature object.
Returns:
TextFeaturizer: object.
"""
return self._text_featurizer
text_ids = self.text_feature.featurize(text)
return text_ids
......@@ -68,7 +68,12 @@ class SpeechSegment(AudioSegment):
return not self.__eq__(other)
@classmethod
def from_file(cls, filepath, transcript, tokens=None, token_ids=None):
def from_file(cls,
filepath,
transcript,
tokens=None,
token_ids=None,
infos=None):
"""Create speech segment from audio file and corresponding transcript.
Args:
......@@ -76,12 +81,12 @@ class SpeechSegment(AudioSegment):
transcript (str): Transcript text for the speech.
tokens (List[str], optional): text tokens. Defaults to None.
token_ids (List[int], optional): text token ids. Defaults to None.
infos (TarLocalData, optional): tar2obj and tar2infos. Defaults to None.
Returns:
SpeechSegment: Speech segment instance.
"""
audio = AudioSegment.from_file(filepath)
audio = AudioSegment.from_file(filepath, infos)
return cls(audio.samples, audio.sample_rate, transcript, tokens,
token_ids)
......
......@@ -14,6 +14,7 @@
"""Contains data helper functions."""
import json
import math
import tarfile
from typing import List
from typing import Optional
from typing import Text
......@@ -112,6 +113,51 @@ def read_manifest(
return manifest
# Tar File read
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
def parse_tar(file):
"""Parse a tar file to get a tarfile object
and a map containing tarinfoes
"""
result = {}
f = tarfile.open(file)
for tarinfo in f.getmembers():
result[tarinfo.name] = tarinfo
return f, result
def subfile_from_tar(file, local_data=None):
"""Get subfile object from tar.
tar:tarpath#filename
It will return a subfile object from tar file
and cached tar file info for next reading request.
"""
tarpath, filename = file.split(':', 1)[1].split('#', 1)
if local_data is None:
local_data = TarLocalData(tar2info={}, tar2object={})
assert isinstance(local_data, TarLocalData)
if 'tar2info' not in local_data.__dict__:
local_data.tar2info = {}
if 'tar2object' not in local_data.__dict__:
local_data.tar2object = {}
if tarpath not in local_data.tar2info:
fobj, infos = parse_tar(tarpath)
local_data.tar2info[tarpath] = infos
local_data.tar2object[tarpath] = fobj
else:
fobj = local_data.tar2object[tarpath]
infos = local_data.tar2info[tarpath]
return fobj.extractfile(infos[filename])
def rms_to_db(rms: float):
"""Root Mean Square to dB.
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
from collections import namedtuple
from typing import Optional
import numpy as np
......@@ -23,96 +22,17 @@ from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
from deepspeech.frontend.normalizer import FeatureNormalizer
from deepspeech.frontend.speech import SpeechSegment
from deepspeech.frontend.utility import IGNORE_ID
from deepspeech.frontend.utility import TarLocalData
from deepspeech.io.reader import LoadInputsAndTargets
from deepspeech.io.utility import pad_list
from deepspeech.utils.log import Log
__all__ = ["SpeechCollator"]
__all__ = ["SpeechCollator", "TripletSpeechCollator"]
logger = Log(__name__).getlog()
# namedtupe need global for pickle.
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
class SpeechCollator():
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
default = CfgNode(
dict(
augmentation_config="",
random_seed=0,
mean_std_filepath="",
unit_type="char",
vocab_filepath="",
spm_model_prefix="",
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delta_delta=False, # 'mfcc', 'fbank'
stride_ms=10.0, # ms
window_ms=20.0, # ms
n_fft=None, # fft points
max_freq=None, # None for samplerate/2
target_sample_rate=16000, # target sample rate
use_dB_normalization=True,
target_dB=-20,
dither=1.0, # feature dither
keep_transcription_text=False))
if config is not None:
config.merge_from_other_cfg(default)
return default
@classmethod
def from_config(cls, config):
"""Build a SpeechCollator object from a config.
Args:
config (yacs.config.CfgNode): configs object.
Returns:
SpeechCollator: collator object.
"""
assert 'augmentation_config' in config.collator
assert 'keep_transcription_text' in config.collator
assert 'mean_std_filepath' in config.collator
assert 'vocab_filepath' in config.collator
assert 'specgram_type' in config.collator
assert 'n_fft' in config.collator
assert config.collator
if isinstance(config.collator.augmentation_config, (str, bytes)):
if config.collator.augmentation_config:
aug_file = io.open(
config.collator.augmentation_config,
mode='r',
encoding='utf8')
else:
aug_file = io.StringIO(initial_value='{}', newline='')
else:
aug_file = config.collator.augmentation_config
assert isinstance(aug_file, io.StringIO)
speech_collator = cls(
aug_file=aug_file,
random_seed=0,
mean_std_filepath=config.collator.mean_std_filepath,
unit_type=config.collator.unit_type,
vocab_filepath=config.collator.vocab_filepath,
spm_model_prefix=config.collator.spm_model_prefix,
specgram_type=config.collator.specgram_type,
feat_dim=config.collator.feat_dim,
delta_delta=config.collator.delta_delta,
stride_ms=config.collator.stride_ms,
window_ms=config.collator.window_ms,
n_fft=config.collator.n_fft,
max_freq=config.collator.max_freq,
target_sample_rate=config.collator.target_sample_rate,
use_dB_normalization=config.collator.use_dB_normalization,
target_dB=config.collator.target_dB,
dither=config.collator.dither,
keep_transcription_text=config.collator.keep_transcription_text)
return speech_collator
class SpeechCollatorBase():
def __init__(
self,
aug_file,
......@@ -121,7 +41,7 @@ class SpeechCollator():
spm_model_prefix,
random_seed=0,
unit_type="char",
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
spectrum_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delta_delta=False, # 'mfcc', 'fbank'
stride_ms=10.0, # ms
......@@ -146,7 +66,7 @@ class SpeechCollator():
n_fft (int, optional): fft points for rfft. Defaults to None.
max_freq (int, optional): max cut freq. Defaults to None.
target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000.
specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
spectrum_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None.
delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False.
use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
......@@ -159,23 +79,27 @@ class SpeechCollator():
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one batch.
"""
self._keep_transcription_text = keep_transcription_text
self.keep_transcription_text = keep_transcription_text
self.stride_ms = stride_ms
self.window_ms = window_ms
self.feat_dim = feat_dim
self.loader = LoadInputsAndTargets()
# only for tar filetype
self._local_data = TarLocalData(tar2info={}, tar2object={})
self._augmentation_pipeline = AugmentationPipeline(
self.augmentation = AugmentationPipeline(
augmentation_config=aug_file.read(), random_seed=random_seed)
self._normalizer = FeatureNormalizer(
mean_std_filepath) if mean_std_filepath else None
self._stride_ms = stride_ms
self._target_sample_rate = target_sample_rate
self._speech_featurizer = SpeechFeaturizer(
unit_type=unit_type,
vocab_filepath=vocab_filepath,
spm_model_prefix=spm_model_prefix,
specgram_type=specgram_type,
spectrum_type=spectrum_type,
feat_dim=feat_dim,
delta_delta=delta_delta,
stride_ms=stride_ms,
......@@ -187,33 +111,11 @@ class SpeechCollator():
target_dB=target_dB,
dither=dither)
def _parse_tar(self, file):
"""Parse a tar file to get a tarfile object
and a map containing tarinfoes
"""
result = {}
f = tarfile.open(file)
for tarinfo in f.getmembers():
result[tarinfo.name] = tarinfo
return f, result
def _subfile_from_tar(self, file):
"""Get subfile object from tar.
It will return a subfile object from tar file
and cached tar file info for next reading request.
"""
tarpath, filename = file.split(':', 1)[1].split('#', 1)
if 'tar2info' not in self._local_data.__dict__:
self._local_data.tar2info = {}
if 'tar2object' not in self._local_data.__dict__:
self._local_data.tar2object = {}
if tarpath not in self._local_data.tar2info:
object, infoes = self._parse_tar(tarpath)
self._local_data.tar2info[tarpath] = infoes
self._local_data.tar2object[tarpath] = object
return self._local_data.tar2object[tarpath].extractfile(
self._local_data.tar2info[tarpath][filename])
self.feature_size = self._speech_featurizer.audio_feature.feature_size
self.text_feature = self._speech_featurizer.text_feature
self.vocab_dict = self.text_feature.vocab_dict
self.vocab_list = self.text_feature.vocab_list
self.vocab_size = self.text_feature.vocab_size
def process_utterance(self, audio_file, transcript):
"""Load, augment, featurize and normalize for speech data.
......@@ -226,23 +128,36 @@ class SpeechCollator():
where transcription part could be token ids or text.
:rtype: tuple of (2darray, list)
"""
if isinstance(audio_file, str) and audio_file.startswith('tar:'):
speech_segment = SpeechSegment.from_file(
self._subfile_from_tar(audio_file), transcript)
else:
speech_segment = SpeechSegment.from_file(audio_file, transcript)
filetype = self.loader.file_type(audio_file)
if filetype != 'sound':
spectrum = self.loader._get_from_loader(audio_file, filetype)
feat_dim = spectrum.shape[1]
assert feat_dim == self.feat_dim, f"expect feat dim {self.feat_dim}, but got {feat_dim}"
if self.keep_transcription_text:
transcript_part = transcript
else:
text_ids = self.text_feature.featurize(transcript)
transcript_part = text_ids
else:
# read audio
speech_segment = SpeechSegment.from_file(
audio_file, transcript, infos=self._local_data)
# audio augment
self._augmentation_pipeline.transform_audio(speech_segment)
self.augmentation.transform_audio(speech_segment)
# extract speech feature
spectrum, transcript_part = self._speech_featurizer.featurize(
speech_segment, self.keep_transcription_text)
specgram, transcript_part = self._speech_featurizer.featurize(
speech_segment, self._keep_transcription_text)
# CMVN spectrum
if self._normalizer:
specgram = self._normalizer.apply(specgram)
spectrum = self._normalizer.apply(spectrum)
# specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram)
return specgram, transcript_part
# spectrum augment
spectrum = self.augmentation.transform_feature(spectrum)
return spectrum, transcript_part
def __call__(self, batch):
"""batch examples
......@@ -272,16 +187,14 @@ class SpeechCollator():
audios.append(audio) # [T, D]
audio_lens.append(audio.shape[0])
# text
# for training, text is token ids
# else text is string, convert to unicode ord
# for training, text is token ids, else text is string, convert to unicode ord
tokens = []
if self._keep_transcription_text:
if self.keep_transcription_text:
assert isinstance(text, str), (type(text), text)
tokens = [ord(t) for t in text]
else:
tokens = text # token ids
tokens = tokens if isinstance(tokens, np.ndarray) else np.array(
tokens, dtype=np.int64)
tokens = np.array(tokens, dtype=np.int64)
texts.append(tokens)
text_lens.append(tokens.shape[0])
......@@ -292,26 +205,162 @@ class SpeechCollator():
olens = np.array(text_lens).astype(np.int64)
return utts, xs_pad, ilens, ys_pad, olens
@property
def vocab_size(self):
return self._speech_featurizer.vocab_size
@property
def vocab_list(self):
return self._speech_featurizer.vocab_list
class SpeechCollator(SpeechCollatorBase):
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
default = CfgNode(
dict(
augmentation_config="",
random_seed=0,
mean_std_filepath="",
unit_type="char",
vocab_filepath="",
spm_model_prefix="",
spectrum_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delta_delta=False, # 'mfcc', 'fbank'
stride_ms=10.0, # ms
window_ms=20.0, # ms
n_fft=None, # fft points
max_freq=None, # None for samplerate/2
target_sample_rate=16000, # target sample rate
use_dB_normalization=True,
target_dB=-20,
dither=1.0, # feature dither
keep_transcription_text=False))
if config is not None:
config.merge_from_other_cfg(default)
return default
@classmethod
def from_config(cls, config):
"""Build a SpeechCollator object from a config.
Args:
config (yacs.config.CfgNode): configs object.
Returns:
SpeechCollator: collator object.
"""
assert 'augmentation_config' in config.collator
assert 'keep_transcription_text' in config.collator
assert 'mean_std_filepath' in config.collator
assert 'vocab_filepath' in config.collator
assert 'spectrum_type' in config.collator
assert 'n_fft' in config.collator
assert config.collator
if isinstance(config.collator.augmentation_config, (str, bytes)):
if config.collator.augmentation_config:
aug_file = io.open(
config.collator.augmentation_config,
mode='r',
encoding='utf8')
else:
aug_file = io.StringIO(initial_value='{}', newline='')
else:
aug_file = config.collator.augmentation_config
assert isinstance(aug_file, io.StringIO)
speech_collator = cls(
aug_file=aug_file,
random_seed=0,
mean_std_filepath=config.collator.mean_std_filepath,
unit_type=config.collator.unit_type,
vocab_filepath=config.collator.vocab_filepath,
spm_model_prefix=config.collator.spm_model_prefix,
spectrum_type=config.collator.spectrum_type,
feat_dim=config.collator.feat_dim,
delta_delta=config.collator.delta_delta,
stride_ms=config.collator.stride_ms,
window_ms=config.collator.window_ms,
n_fft=config.collator.n_fft,
max_freq=config.collator.max_freq,
target_sample_rate=config.collator.target_sample_rate,
use_dB_normalization=config.collator.use_dB_normalization,
target_dB=config.collator.target_dB,
dither=config.collator.dither,
keep_transcription_text=config.collator.keep_transcription_text)
return speech_collator
class TripletSpeechCollator(SpeechCollator):
def process_utterance(self, audio_file, translation, transcript):
"""Load, augment, featurize and normalize for speech data.
:param audio_file: Filepath or file object of audio file.
:type audio_file: str | file
:param translation: translation text.
:type translation: str
:return: Tuple of audio feature tensor and data of translation part,
where translation part could be token ids or text.
:rtype: tuple of (2darray, list)
"""
spectrum, translation_part = super().process_utterance(audio_file,
translation)
transcript_part = self._speech_featurizer.text_featurize(
transcript, self.keep_transcription_text)
return spectrum, translation_part, transcript_part
@property
def vocab_dict(self):
return self._speech_featurizer.vocab_dict
def __call__(self, batch):
"""batch examples
@property
def text_feature(self):
return self._speech_featurizer.text_feature
Args:
batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (T, D)
text (List[int] or str): shape (U,)
@property
def feature_size(self):
return self._speech_featurizer.feature_size
Returns:
tuple(audio, text, audio_lens, text_lens): batched data.
audio : (B, Tmax, D)
audio_lens: (B)
text : (B, Umax)
text_lens: (B)
"""
audios = []
audio_lens = []
translation_text = []
translation_text_lens = []
transcription_text = []
transcription_text_lens = []
@property
def stride_ms(self):
return self._speech_featurizer.stride_ms
utts = []
for utt, audio, translation, transcription in batch:
audio, translation, transcription = self.process_utterance(
audio, translation, transcription)
#utt
utts.append(utt)
# audio
audios.append(audio) # [T, D]
audio_lens.append(audio.shape[0])
# text
# for training, text is token ids
# else text is string, convert to unicode ord
tokens = [[], []]
for idx, text in enumerate([translation, transcription]):
if self.keep_transcription_text:
assert isinstance(text, str), (type(text), text)
tokens[idx] = [ord(t) for t in text]
else:
tokens[idx] = text # token ids
tokens[idx] = np.array(tokens[idx], dtype=np.int64)
translation_text.append(tokens[0])
translation_text_lens.append(tokens[0].shape[0])
transcription_text.append(tokens[1])
transcription_text_lens.append(tokens[1].shape[0])
padded_audios = pad_sequence(
audios, padding_value=0.0).astype(np.float32) #[B, T, D]
audio_lens = np.array(audio_lens).astype(np.int64)
padded_translation = pad_sequence(
translation_text, padding_value=IGNORE_ID).astype(np.int64)
translation_lens = np.array(translation_text_lens).astype(np.int64)
padded_transcription = pad_sequence(
transcription_text, padding_value=IGNORE_ID).astype(np.int64)
transcription_lens = np.array(transcription_text_lens).astype(np.int64)
return utts, padded_audios, audio_lens, (
padded_translation, padded_transcription), (translation_lens,
transcription_lens)
此差异已折叠。
......@@ -321,6 +321,22 @@ class LoadInputsAndTargets():
raise NotImplementedError(
"Not supported: loader_type={}".format(filetype))
def file_type(self, filepath):
suffix = filepath.split(":")[0].split('.')[1]
if suffix == 'ark':
return 'mat'
elif suffix == 'scp':
return 'scp'
elif suffix == 'npy':
return 'npy'
elif suffix == 'npz':
return 'npz'
elif suffix in ['wav', 'flac']:
# PCM16
return 'sound'
else:
raise ValueError(f"Not support filetype: {suffix}")
class SoundHDF5File():
"""Collecting sound files to a HDF5 file
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册