Commit 92d1d08b, authored by H huangyuxin

fix scripts

Parent 6ec69212
......@@ -52,6 +52,7 @@ test_manifest: data/test_meeting/data.list
use_stream_data: True
unit_type: 'char'
vocab_filepath: data/lang_char/vocab.txt
preprocess_config: conf/preprocess.yaml
cmvn_file: data/mean_std.json
spm_model_prefix: ''
feat_dim: 80
......@@ -65,30 +66,17 @@ maxlen_in: 1200 # if input length(number of frames) > maxlen-in, data is automa
minlen_out: 0
maxlen_out: 150 # if output length(number of tokens) > maxlen-out, data is automatically removed
resample_rate: 16000
shuffle_size: 1500
sort_size: 1000
shuffle_size: 1500 # read 'shuffle_size' samples as a chunk, then shuffle within the chunk
sort_size: 1000 # read 'sort_size' samples as a chunk, then sort within the chunk
num_workers: 8
prefetch_factor: 10
dist_sampler: True
num_encs: 1
augment_conf:
max_w: 80
w_inplace: True
w_mode: "PIL"
max_f: 30
num_f_mask: 2
f_inplace: True
f_replace_with_zero: False
max_t: 40
num_t_mask: 2
t_inplace: True
t_replace_with_zero: False
###########################################
# Training #
###########################################
n_epoch: 30
n_epoch: 32
accum_grad: 32
global_grad_clip: 5.0
log_interval: 100
......
......@@ -90,8 +90,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
for x in $dev_set $test_sets ${train_set}; do
dst=$shards_dir/$x
mkdir -p $dst
utils/make_filted_shard_list.py --resample 16000 --num_utts_per_shard 1000 \
--do_filter --num_node 1 --num_gpus_per_node 8 \
utils/make_filted_shard_list.py --num_node 1 --num_gpus_per_node 8 --num_utts_per_shard 1000 \
--do_filter --resample 16000 \
--num_threads 32 --segments data/$x/segments \
data/$x/wav.scp data/$x/text \
$(realpath $dst) data/$x/data.list
......
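For orientation, a small sketch (file path is the test manifest from the recipe config) of how the data.list written by this stage is consumed: as shown later in this diff, StreamDataLoader reads it as one shard (tar) path per line.

shardlist = []
with open("data/test_meeting/data.list", "r") as f:   # manifest produced by the stage above
    for line in f.readlines():
        shardlist.append(line.strip())                 # each line is the path of one tar shard
print(f"{len(shardlist)} shards")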
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
......@@ -26,7 +27,7 @@ from .filters import (
pipelinefilter,
rename,
rename_keys,
rsample,
audio_resample,
select,
shuffle,
slice,
......@@ -34,14 +35,14 @@ from .filters import (
transform_with,
unbatched,
xdecode,
data_filter,
tokenize,
resample,
compute_fbank,
spec_aug,
audio_data_filter,
audio_tokenize,
audio_resample,
audio_compute_fbank,
audio_spec_aug,
sort,
padding,
cmvn,
audio_padding,
audio_cmvn,
placeholder,
)
from .handlers import (
......
......@@ -291,12 +291,12 @@ def torch_video(key, data):
################################################################
# paddleaudio
# paddlespeech.audio
################################################################
def paddle_audio(key, data):
"""Decode audio using the paddleaudio library.
"""Decode audio using the paddlespeech.audio library.
:param key: file name extension
:param data: data to be decoded
......@@ -305,13 +305,13 @@ def paddle_audio(key, data):
if extension not in ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]:
return None
import paddleaudio
import paddlespeech.audio
with tempfile.TemporaryDirectory() as dirname:
fname = os.path.join(dirname, f"file.{extension}")
with open(fname, "wb") as stream:
stream.write(data)
return paddleaudio.load(fname)
return paddlespeech.audio.load(fname)
################################################################
......
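A minimal usage sketch of the renamed decoder (the module path and the local wav file are assumptions):

from paddlespeech.audio.streamdata.autodecode import paddle_audio  # assumed module path

with open("sample.wav", "rb") as f:        # any local wav file (hypothetical)
    raw = f.read()
waveform, sample_rate = paddle_audio("utt1.wav", raw)  # the key only supplies the extension
print(sample_rate)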
......@@ -78,29 +78,29 @@ class FluidInterface:
def xdecode(self, *args, **kw):
return self.compose(filters.xdecode(*args, **kw))
def data_filter(self, *args, **kw):
return self.compose(filters.data_filter(*args, **kw))
def audio_data_filter(self, *args, **kw):
return self.compose(filters.audio_data_filter(*args, **kw))
def tokenize(self, *args, **kw):
return self.compose(filters.tokenize(*args, **kw))
def audio_tokenize(self, *args, **kw):
return self.compose(filters.audio_tokenize(*args, **kw))
def resample(self, *args, **kw):
return self.compose(filters.resample(*args, **kw))
def compute_fbank(self, *args, **kw):
return self.compose(filters.compute_fbank(*args, **kw))
def audio_compute_fbank(self, *args, **kw):
return self.compose(filters.audio_compute_fbank(*args, **kw))
def spec_aug(self, *args, **kw):
return self.compose(filters.spec_aug(*args, **kw))
def audio_spec_aug(self, *args, **kw):
return self.compose(filters.audio_spec_aug(*args, **kw))
def sort(self, size=500):
return self.compose(filters.sort(size))
def padding(self):
return self.compose(filters.padding())
def audio_padding(self):
return self.compose(filters.audio_padding())
def cmvn(self, cmvn_file):
return self.compose(filters.cmvn(cmvn_file))
def audio_cmvn(self, cmvn_file):
return self.compose(filters.audio_cmvn(cmvn_file))
class WebDataset(DataPipeline, FluidInterface):
"""Small fluid-interface wrapper for DataPipeline."""
......
......@@ -579,7 +579,7 @@ xdecode = pipelinefilter(_xdecode)
def _data_filter(source,
def _audio_data_filter(source,
frame_shift=10,
max_length=10240,
min_length=10,
......@@ -629,9 +629,9 @@ def _data_filter(source,
continue
yield sample
data_filter = pipelinefilter(_data_filter)
audio_data_filter = pipelinefilter(_audio_data_filter)
def _tokenize(source,
def _audio_tokenize(source,
symbol_table,
bpe_model=None,
non_lang_syms=None,
......@@ -693,9 +693,9 @@ def _tokenize(source,
sample['label'] = label
yield sample
tokenize = pipelinefilter(_tokenize)
audio_tokenize = pipelinefilter(_audio_tokenize)
def _resample(source, resample_rate=16000):
def _audio_resample(source, resample_rate=16000):
""" Resample data.
Inplace operation.
......@@ -718,9 +718,9 @@ def _resample(source, resample_rate=16000):
))
yield sample
resample = pipelinefilter(_resample)
audio_resample = pipelinefilter(_audio_resample)
def _compute_fbank(source,
def _audio_compute_fbank(source,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
......@@ -756,9 +756,9 @@ def _compute_fbank(source,
yield dict(fname=sample['fname'], label=sample['label'], feat=mat)
compute_fbank = pipelinefilter(_compute_fbank)
audio_compute_fbank = pipelinefilter(_audio_compute_fbank)
def _spec_aug(source,
def _audio_spec_aug(source,
max_w=5,
w_inplace=True,
w_mode="PIL",
......@@ -799,7 +799,7 @@ def _spec_aug(source,
sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32)
yield sample
spec_aug = pipelinefilter(_spec_aug)
audio_spec_aug = pipelinefilter(_audio_spec_aug)
def _sort(source, sort_size=500):
......@@ -881,7 +881,7 @@ def dynamic_batched(source, max_frames_in_batch=12000):
yield buf
def _padding(source):
def _audio_padding(source):
""" Padding the data into training data
Args:
......@@ -914,9 +914,9 @@ def _padding(source):
yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
label_lengths)
padding = pipelinefilter(_padding)
audio_padding = pipelinefilter(_audio_padding)
def _cmvn(source, cmvn_file):
def _audio_cmvn(source, cmvn_file):
global_cmvn = GlobalCMVN(cmvn_file)
for batch in source:
sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths = batch
......@@ -926,7 +926,7 @@ def _cmvn(source, cmvn_file):
yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
label_lengths)
cmvn = pipelinefilter(_cmvn)
audio_cmvn = pipelinefilter(_audio_cmvn)
def _placeholder(source):
for data in source:
......
......@@ -21,7 +21,7 @@ trace = False
meta_prefix = "__"
meta_suffix = "__"
from ... import audio as paddleaudio
import paddlespeech
import paddle
import numpy as np
......@@ -118,7 +118,7 @@ def tar_file_iterator(
assert pos > 0
prefix, postfix = name[:pos], name[pos + 1:]
if postfix == 'wav':
waveform, sample_rate = paddleaudio.load(stream.extractfile(tarinfo), normal=False)
waveform, sample_rate = paddlespeech.audio.load(stream.extractfile(tarinfo), normal=False)
result = dict(fname=prefix, wav=waveform, sample_rate = sample_rate)
else:
txt = stream.extractfile(tarinfo).read().decode('utf8').strip()
......@@ -167,7 +167,7 @@ def tar_file_and_group_iterator(
if postfix == 'txt':
example['txt'] = file_obj.read().decode('utf8').strip()
elif postfix in AUDIO_FORMAT_SETS:
waveform, sample_rate = paddleaudio.load(file_obj, normal=False)
waveform, sample_rate = paddlespeech.audio.load(file_obj, normal=False)
waveform = paddle.to_tensor(np.expand_dims(np.array(waveform),0), dtype=paddle.float32)
example['wav'] = waveform
......
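For reference, a self-contained sketch (utterance id and transcript are made up) that builds a toy shard with the `<utt>.wav` / `<utt>.txt` layout these tar iterators group by prefix:

import io
import tarfile
import wave

import numpy as np

# 1 second of 16 kHz silence as the toy wav payload
buf = io.BytesIO()
with wave.open(buf, "wb") as w:
    w.setnchannels(1)
    w.setsampwidth(2)
    w.setframerate(16000)
    w.writeframes(np.zeros(16000, dtype="int16").tobytes())

payloads = [("utt1.wav", buf.getvalue()), ("utt1.txt", "toy transcript".encode("utf8"))]
with tarfile.open("toy_shard.tar", "w") as tar:
    for name, payload in payloads:
        info = tarfile.TarInfo(name=name)
        info.size = len(payload)
        tar.addfile(info, io.BytesIO(payload))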
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the text featurizer class."""
from pprint import pformat
from typing import Union
import sentencepiece as spm
from .utility import BLANK
from .utility import EOS
from .utility import load_dict
from .utility import MASKCTC
from .utility import SOS
from .utility import SPACE
from .utility import UNK
from ..utils.log import Logger
logger = Logger(__name__)
__all__ = ["TextFeaturizer"]
class TextFeaturizer():
def __init__(self, unit_type, vocab, spm_model_prefix=None, maskctc=False):
"""Text featurizer, for processing or extracting features from text.
Currently, it supports char/word/sentence-piece level tokenizing and conversion into
a list of token indices. Note that the token indexing order follows the
given vocabulary file.
Args:
unit_type (str): unit type, e.g. char, word, spm
vocab Option[str, list]: Filepath to load vocabulary for token indices conversion, or vocab list.
spm_model_prefix (str, optional): spm model prefix. Defaults to None.
"""
assert unit_type in ('char', 'spm', 'word')
self.unit_type = unit_type
self.unk = UNK
self.maskctc = maskctc
if vocab:
self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id, self.blank_id = self._load_vocabulary_from_file(
vocab, maskctc)
self.vocab_size = len(self.vocab_list)
else:
logger.warning("TextFeaturizer: not have vocab file or vocab list.")
if unit_type == 'spm':
spm_model = spm_model_prefix + '.model'
self.sp = spm.SentencePieceProcessor()
self.sp.Load(spm_model)
def tokenize(self, text, replace_space=True):
if self.unit_type == 'char':
tokens = self.char_tokenize(text, replace_space)
elif self.unit_type == 'word':
tokens = self.word_tokenize(text)
else: # spm
tokens = self.spm_tokenize(text)
return tokens
def detokenize(self, tokens):
if self.unit_type == 'char':
text = self.char_detokenize(tokens)
elif self.unit_type == 'word':
text = self.word_detokenize(tokens)
else: # spm
text = self.spm_detokenize(tokens)
return text
def featurize(self, text):
"""Convert text string to a list of token indices.
Args:
text (str): Text to process.
Returns:
List[int]: List of token indices.
"""
tokens = self.tokenize(text)
ids = []
for token in tokens:
if token not in self.vocab_dict:
logger.debug(f"Text Token: {token} -> {self.unk}")
token = self.unk
ids.append(self.vocab_dict[token])
return ids
def defeaturize(self, idxs):
"""Convert a list of token indices to text string,
ignore index after eos_id.
Args:
idxs (List[int]): List of token indices.
Returns:
str: Text.
"""
tokens = []
for idx in idxs:
if idx == self.eos_id:
break
tokens.append(self._id2token[idx])
text = self.detokenize(tokens)
return text
def char_tokenize(self, text, replace_space=True):
"""Character tokenizer.
Args:
text (str): text string.
replace_space (bool): Whether to replace spaces with <space>; set to False only by build_vocab.py.
Returns:
List[str]: tokens.
"""
text = text.strip()
if replace_space:
text_list = [SPACE if item == " " else item for item in list(text)]
else:
text_list = list(text)
return text_list
def char_detokenize(self, tokens):
"""Character detokenizer.
Args:
tokens (List[str]): tokens.
Returns:
str: text string.
"""
tokens = [t.replace(SPACE, " ") for t in tokens]
return "".join(tokens)
def word_tokenize(self, text):
"""Word tokenizer, separate by <space>."""
return text.strip().split()
def word_detokenize(self, tokens):
"""Word detokenizer, separate by <space>."""
return " ".join(tokens)
def spm_tokenize(self, text):
"""spm tokenize.
Args:
text (str): text string.
Returns:
List[str]: sentence pieces str code
"""
stats = {"num_empty": 0, "num_filtered": 0}
def valid(line):
return True
def encode(l):
return self.sp.EncodeAsPieces(l)
def encode_line(line):
line = line.strip()
if len(line) > 0:
line = encode(line)
if valid(line):
return line
else:
stats["num_filtered"] += 1
else:
stats["num_empty"] += 1
return None
enc_line = encode_line(text)
return enc_line
def spm_detokenize(self, tokens, input_format='piece'):
"""spm detokenize.
Args:
ids (List[str]): tokens.
Returns:
str: text
"""
if input_format == "piece":
def decode(l):
return "".join(self.sp.DecodePieces(l))
elif input_format == "id":
def decode(l):
return "".join(self.sp.DecodeIds(l))
return decode(tokens)
def _load_vocabulary_from_file(self, vocab: Union[str, list],
maskctc: bool):
"""Load vocabulary from file."""
if isinstance(vocab, list):
vocab_list = vocab
else:
vocab_list = load_dict(vocab, maskctc)
assert vocab_list is not None
logger.debug(f"Vocab: {pformat(vocab_list)}")
id2token = dict(
[(idx, token) for (idx, token) in enumerate(vocab_list)])
token2id = dict(
[(token, idx) for (idx, token) in enumerate(vocab_list)])
blank_id = vocab_list.index(BLANK) if BLANK in vocab_list else -1
maskctc_id = vocab_list.index(MASKCTC) if MASKCTC in vocab_list else -1
unk_id = vocab_list.index(UNK) if UNK in vocab_list else -1
eos_id = vocab_list.index(EOS) if EOS in vocab_list else -1
sos_id = vocab_list.index(SOS) if SOS in vocab_list else -1
space_id = vocab_list.index(SPACE) if SPACE in vocab_list else -1
logger.info(f"BLANK id: {blank_id}")
logger.info(f"UNK id: {unk_id}")
logger.info(f"EOS id: {eos_id}")
logger.info(f"SOS id: {sos_id}")
logger.info(f"SPACE id: {space_id}")
logger.info(f"MASKCTC id: {maskctc_id}")
return token2id, id2token, vocab_list, unk_id, eos_id, blank_id
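A minimal usage sketch of the featurizer (the toy vocab list and text are made up; the import path is the new paddlespeech.audio location used elsewhere in this commit):

from paddlespeech.audio.text.text_featurizer import TextFeaturizer

vocab = ["<blank>", "<unk>", "我", "们", "<space>", "<eos>"]  # hypothetical toy vocab
featurizer = TextFeaturizer(unit_type='char', vocab=vocab)
ids = featurizer.featurize("我们 我")   # -> [2, 3, 4, 2]; unseen chars fall back to <unk>
text = featurizer.defeaturize(ids)      # -> "我们 我" (stops at <eos> if present)
print(ids, text)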
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains data helper functions."""
import json
import math
import sys
import tarfile
from collections import namedtuple
from typing import List
from typing import Optional
from typing import Text
import jsonlines
import numpy as np
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
__all__ = [
"load_dict", "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs",
"max_dbfs", "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS",
"EOS", "UNK", "BLANK", "MASKCTC", "SPACE", "convert_samples_to_float32",
"convert_samples_from_float32"
]
IGNORE_ID = -1
# `sos` and `eos` using same token
SOS = "<eos>"
EOS = SOS
UNK = "<unk>"
BLANK = "<blank>"
MASKCTC = "<mask>"
SPACE = "<space>"
def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
if dict_path is None:
return None
with open(dict_path, "r") as f:
dictionary = f.readlines()
# first token is `<blank>`
# multi line: `<blank> 0\n`
# one line: `<blank>`
# space is replaced with <space>
char_list = [entry[:-1].split(" ")[0] for entry in dictionary]
if BLANK not in char_list:
char_list.insert(0, BLANK)
if EOS not in char_list:
char_list.append(EOS)
# for non-autoregressive maskctc model
if maskctc and MASKCTC not in char_list:
char_list.append(MASKCTC)
return char_list
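# Illustrative sketch (toy file, not part of this change): load_dict accepts one token per
# line, optionally followed by its index, and appends <blank>/<eos> when missing.
with open("toy_vocab.txt", "w", encoding="utf8") as f:
    f.write("<blank> 0\n我 1\n好 2\n")
print(load_dict("toy_vocab.txt"))  # ['<blank>', '我', '好', '<eos>']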
def read_manifest(
manifest_path,
max_input_len=float('inf'),
min_input_len=0.0,
max_output_len=float('inf'),
min_output_len=0.0,
max_output_input_ratio=float('inf'),
min_output_input_ratio=0.0, ):
"""Load and parse manifest file.
Args:
manifest_path ([type]): Manifest file to load and parse.
max_input_len (float, optional): maximum input seq length,
in seconds for raw wav, in frame numbers for feature data.
Defaults to float('inf').
min_input_len (float, optional): minimum input seq length,
in seconds for raw wav, in frame numbers for feature data.
Defaults to 0.0.
max_output_len (float, optional): maximum output seq length,
in modeling units. Defaults to float('inf').
min_output_len (float, optional): minimum output seq length,
in modeling units. Defaults to 0.0.
max_output_input_ratio (float, optional):
maximum output seq length / input seq length ratio. Defaults to float('inf').
min_output_input_ratio (float, optional):
minimum output seq length / input seq length ratio. Defaults to 0.0.
Raises:
IOError: If failed to parse the manifest.
Returns:
List[dict]: Manifest parsing results.
"""
manifest = []
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
feat_len = json_data["input"][0]["shape"][
0] if "input" in json_data and "shape" in json_data["input"][
0] else 1.0
token_len = json_data["output"][0]["shape"][
0] if "output" in json_data and "shape" in json_data["output"][
0] else 1.0
conditions = [
feat_len >= min_input_len,
feat_len <= max_input_len,
token_len >= min_output_len,
token_len <= max_output_len,
token_len / feat_len >= min_output_input_ratio,
token_len / feat_len <= max_output_input_ratio,
]
if all(conditions):
manifest.append(json_data)
return manifest
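# Illustrative sketch (field values are made up) of the jsonlines entries read_manifest expects:
entry = {
    "utt": "utt1",
    "input": [{"name": "input1", "shape": [4.2, 80]}],                    # duration (s) or frames, feat dim
    "output": [{"name": "target1", "shape": [15, 4233], "text": "..."}],  # token count, vocab size
}
with jsonlines.open("toy_manifest.jsonl", mode="w") as writer:
    writer.write(entry)
print(read_manifest("toy_manifest.jsonl", max_input_len=20.0))  # the toy entry passes all filters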
# Tar File read
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
def parse_tar(file):
"""Parse a tar file to get a tarfile object
and a map containing tarinfoes
"""
result = {}
f = tarfile.open(file)
for tarinfo in f.getmembers():
result[tarinfo.name] = tarinfo
return f, result
def subfile_from_tar(file, local_data=None):
"""Get subfile object from tar.
tar:tarpath#filename
It will return a subfile object from tar file
and cached tar file info for next reading request.
"""
tarpath, filename = file.split(':', 1)[1].split('#', 1)
if local_data is None:
local_data = TarLocalData(tar2info={}, tar2object={})
assert isinstance(local_data, TarLocalData)
if 'tar2info' not in local_data.__dict__:
local_data.tar2info = {}
if 'tar2object' not in local_data.__dict__:
local_data.tar2object = {}
if tarpath not in local_data.tar2info:
fobj, infos = parse_tar(tarpath)
local_data.tar2info[tarpath] = infos
local_data.tar2object[tarpath] = fobj
else:
fobj = local_data.tar2object[tarpath]
infos = local_data.tar2info[tarpath]
return fobj.extractfile(infos[filename])
def rms_to_db(rms: float):
"""Root Mean Square to dB.
Args:
rms ([float]): root mean square
Returns:
float: dB
"""
return 20.0 * math.log10(max(1e-16, rms))
def rms_to_dbfs(rms: float):
"""Root Mean Square to dBFS.
https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/
Audio is treated as a mix of sine waves; a full-scale (amplitude 1.0) sine has an RMS of 0.7071, i.e. -3.0103 dB.
dB = dBFS + 3.0103
dBFS = dB - 3.0103
e.g. 0 dB RMS corresponds to -3.0103 dBFS
Args:
rms ([float]): root mean square
Returns:
float: dBFS
"""
return rms_to_db(rms) - 3.0103
def max_dbfs(sample_data: np.ndarray):
"""Peak dBFS based on the maximum energy sample.
Args:
sample_data ([np.ndarray]): float array, [-1, 1].
Returns:
float: dBFS
"""
# Peak dBFS based on the maximum energy sample. Will prevent overdrive if used for normalization.
return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data))))
def mean_dbfs(sample_data):
"""Peak dBFS based on the RMS energy.
Args:
sample_data ([np.ndarray]): float array, [-1, 1].
Returns:
float: dBFS
"""
return rms_to_dbfs(
math.sqrt(np.mean(np.square(sample_data, dtype=np.float64))))
def gain_db_to_ratio(gain_db: float):
"""dB to ratio
Args:
gain_db (float): gain in dB
Returns:
float: scale in amp
"""
return math.pow(10.0, gain_db / 20.0)
def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103):
"""Nomalize audio to dBFS.
Args:
sample_data (np.ndarray): input wave samples, [-1, 1].
dbfs (float, optional): target dBFS. Defaults to -3.0103.
Returns:
np.ndarray: normalized wave
"""
return np.maximum(
np.minimum(sample_data * gain_db_to_ratio(dbfs - max_dbfs(sample_data)),
1.0), -1.0)
def _load_json_cmvn(json_cmvn_file):
""" Load the json format cmvn stats file and calculate cmvn
Args:
json_cmvn_file: cmvn stats file in json format
Returns:
a numpy array of [means, vars]
"""
with open(json_cmvn_file) as f:
cmvn_stats = json.load(f)
means = cmvn_stats['mean_stat']
variance = cmvn_stats['var_stat']
count = cmvn_stats['frame_num']
for i in range(len(means)):
means[i] /= count
variance[i] = variance[i] / count - means[i] * means[i]
if variance[i] < 1.0e-20:
variance[i] = 1.0e-20
variance[i] = 1.0 / math.sqrt(variance[i])
cmvn = np.array([means, variance])
return cmvn
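# Illustrative sketch (statistics are made up) of the json format consumed above, matching
# the cmvn_file: data/mean_std.json entry in the recipe config:
_toy_stats = {
    "mean_stat": [12000.0] * 80,   # per-dimension sum over all frames
    "var_stat": [200000.0] * 80,   # per-dimension sum of squares
    "frame_num": 1000,
}
with open("toy_mean_std.json", "w") as f:
    json.dump(_toy_stats, f)
print(_load_json_cmvn("toy_mean_std.json").shape)  # (2, 80): row 0 = means, row 1 = 1 / std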
def _load_kaldi_cmvn(kaldi_cmvn_file):
""" Load the kaldi format cmvn stats file and calculate cmvn
Args:
kaldi_cmvn_file: kaldi text style global cmvn file, which
is generated by:
compute-cmvn-stats --binary=false scp:feats.scp global_cmvn
Returns:
a numpy array of [means, vars]
"""
means = []
variance = []
with open(kaldi_cmvn_file, 'r') as fid:
# kaldi binary file start with '\0B'
if fid.read(2) == '\0B':
logger.error('kaldi cmvn binary file is not supported, please '
'recompute it by: compute-cmvn-stats --binary=false '
' scp:feats.scp global_cmvn')
sys.exit(1)
fid.seek(0)
arr = fid.read().split()
assert (arr[0] == '[')
assert (arr[-2] == '0')
assert (arr[-1] == ']')
feat_dim = int((len(arr) - 2 - 2) / 2)
for i in range(1, feat_dim + 1):
means.append(float(arr[i]))
count = float(arr[feat_dim + 1])
for i in range(feat_dim + 2, 2 * feat_dim + 2):
variance.append(float(arr[i]))
for i in range(len(means)):
means[i] /= count
variance[i] = variance[i] / count - means[i] * means[i]
if variance[i] < 1.0e-20:
variance[i] = 1.0e-20
variance[i] = 1.0 / math.sqrt(variance[i])
cmvn = np.array([means, variance])
return cmvn
def load_cmvn(cmvn_file: str, filetype: str):
"""load cmvn from file.
Args:
cmvn_file (str): cmvn path.
filetype (str): file type, optional[npz, json, kaldi].
Raises:
ValueError: file type not support.
Returns:
Tuple[np.ndarray, np.ndarray]: mean, istd
"""
assert filetype in ['npz', 'json', 'kaldi'], filetype
filetype = filetype.lower()
if filetype == "json":
cmvn = _load_json_cmvn(cmvn_file)
elif filetype == "kaldi":
cmvn = _load_kaldi_cmvn(cmvn_file)
elif filetype == "npz":
eps = 1e-14
npzfile = np.load(cmvn_file)
mean = np.squeeze(npzfile["mean"])
std = np.squeeze(npzfile["std"])
istd = 1 / (std + eps)
cmvn = [mean, istd]
else:
raise ValueError(f"cmvn file type no support: {filetype}")
return cmvn[0], cmvn[1]
def convert_samples_to_float32(samples):
"""Convert sample type to float32.
Audio sample type is usually integer or floating-point.
Integers will be scaled to [-1, 1] in float32.
PCM16 -> PCM32
"""
float32_samples = samples.astype('float32')
if samples.dtype in np.sctypes['int']:
bits = np.iinfo(samples.dtype).bits
float32_samples *= (1. / 2**(bits - 1))
elif samples.dtype in np.sctypes['float']:
pass
else:
raise TypeError("Unsupported sample type: %s." % samples.dtype)
return float32_samples
def convert_samples_from_float32(samples, dtype):
"""Convert sample type from float32 to dtype.
Audio sample type is usually integer or floating-point. For integer
type, float32 will be rescaled from [-1, 1] to the maximum range
supported by the integer type.
PCM32 -> PCM16
"""
dtype = np.dtype(dtype)
output_samples = samples.copy()
if dtype in np.sctypes['int']:
bits = np.iinfo(dtype).bits
output_samples *= (2**(bits - 1) / 1.)
min_val = np.iinfo(dtype).min
max_val = np.iinfo(dtype).max
output_samples[output_samples > max_val] = max_val
output_samples[output_samples < min_val] = min_val
elif samples.dtype in np.sctypes['float']:
min_val = np.finfo(dtype).min
max_val = np.finfo(dtype).max
output_samples[output_samples > max_val] = max_val
output_samples[output_samples < min_val] = min_val
else:
raise TypeError("Unsupported sample type: %s." % samples.dtype)
return output_samples.astype(dtype)
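# Illustrative round trip for the two converters above (PCM16 <-> float32):
_pcm16 = np.array([0, 16384, -32768], dtype='int16')
_f32 = convert_samples_to_float32(_pcm16)            # [0.0, 0.5, -1.0]
_back = convert_samples_from_float32(_f32, 'int16')  # [0, 16384, -32768]
print(_f32, _back)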
......@@ -23,7 +23,7 @@ import paddle
from paddle import distributed as dist
from paddle import inference
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.audio.text.text_featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.models.ds2 import DeepSpeech2InferModel
from paddlespeech.s2t.models.ds2 import DeepSpeech2Model
......
......@@ -27,6 +27,7 @@ from paddle import distributed as dist
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.io.dataloader import StreamDataLoader
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.optimizer import OptimizerFactory
from paddlespeech.s2t.training.reporter import ObsScope
......@@ -134,7 +135,8 @@ class U2Trainer(Trainer):
msg = f"Valid: Rank: {dist.get_rank()}, "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
#msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
if not self.use_streamdata:
msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items())
logger.info(msg)
......@@ -195,7 +197,6 @@ class U2Trainer(Trainer):
except Exception as e:
logger.error(e)
raise e
with Timer("Eval Time Cost: {}"):
total_loss, num_seen_utts = self.valid()
if dist.get_world_size() > 1:
......@@ -224,186 +225,14 @@ class U2Trainer(Trainer):
config = self.config.clone()
self.use_streamdata = config.get("use_stream_data", False)
if self.train:
# train/valid dataset, return token ids
if self.use_streamdata:
self.train_loader = StreamDataLoader(
manifest_file=config.train_manifest,
train_mode=True,
unit_type=config.unit_type,
batch_size=config.batch_size,
num_mel_bins=config.feat_dim,
frame_length=config.window_ms,
frame_shift=config.stride_ms,
dither=config.dither,
minlen_in=config.minlen_in,
maxlen_in=config.maxlen_in,
minlen_out=config.minlen_out,
maxlen_out=config.maxlen_out,
resample_rate=config.resample_rate,
augment_conf=config.augment_conf, # dict
shuffle_size=config.shuffle_size,
sort_size=config.sort_size,
n_iter_processes=config.num_workers,
prefetch_factor=config.prefetch_factor,
dist_sampler=config.get('dist_sampler', False),
cmvn_file=config.cmvn_file,
vocab_filepath=config.vocab_filepath,
)
self.valid_loader = StreamDataLoader(
manifest_file=config.dev_manifest,
train_mode=False,
unit_type=config.unit_type,
batch_size=config.batch_size,
num_mel_bins=config.feat_dim,
frame_length=config.window_ms,
frame_shift=config.stride_ms,
dither=config.dither,
minlen_in=config.minlen_in,
maxlen_in=config.maxlen_in,
minlen_out=config.minlen_out,
maxlen_out=config.maxlen_out,
resample_rate=config.resample_rate,
augment_conf=config.augment_conf, # dict
shuffle_size=config.shuffle_size,
sort_size=config.sort_size,
n_iter_processes=config.num_workers,
prefetch_factor=config.prefetch_factor,
dist_sampler=config.get('dist_sampler', False),
cmvn_file=config.cmvn_file,
vocab_filepath=config.vocab_filepath,
)
else:
self.train_loader = BatchDataLoader(
json_file=config.train_manifest,
train_mode=True,
sortagrad=config.sortagrad,
batch_size=config.batch_size,
maxlen_in=config.maxlen_in,
maxlen_out=config.maxlen_out,
minibatches=config.minibatches,
mini_batch_size=self.args.ngpu,
batch_count=config.batch_count,
batch_bins=config.batch_bins,
batch_frames_in=config.batch_frames_in,
batch_frames_out=config.batch_frames_out,
batch_frames_inout=config.batch_frames_inout,
preprocess_conf=config.preprocess_config,
n_iter_processes=config.num_workers,
subsampling_factor=1,
num_encs=1,
dist_sampler=config.get('dist_sampler', False),
shortest_first=False)
self.valid_loader = BatchDataLoader(
json_file=config.dev_manifest,
train_mode=False,
sortagrad=False,
batch_size=config.batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=self.args.ngpu,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.preprocess_config,
n_iter_processes=config.num_workers,
subsampling_factor=1,
num_encs=1,
dist_sampler=config.get('dist_sampler', False),
shortest_first=False)
self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
logger.info("Setup train/valid Dataloader!")
else:
decode_batch_size = config.get('decode', dict()).get(
'decode_batch_size', 1)
# test dataset, return raw text
if self.use_streamdata:
self.test_loader = StreamDataLoader(
manifest_file=config.test_manifest,
train_mode=False,
unit_type=config.unit_type,
batch_size=config.batch_size,
num_mel_bins=config.feat_dim,
frame_length=config.window_ms,
frame_shift=config.stride_ms,
dither=0.0,
minlen_in=0.0,
maxlen_in=float('inf'),
minlen_out=0,
maxlen_out=float('inf'),
resample_rate=config.resample_rate,
augment_conf=config.augment_conf, # dict
shuffle_size=config.shuffle_size,
sort_size=config.sort_size,
n_iter_processes=config.num_workers,
prefetch_factor=config.prefetch_factor,
dist_sampler=config.get('dist_sampler', False),
cmvn_file=config.cmvn_file,
vocab_filepath=config.vocab_filepath,
)
self.align_loader = StreamDataLoader(
manifest_file=config.test_manifest,
train_mode=False,
unit_type=config.unit_type,
batch_size=config.batch_size,
num_mel_bins=config.feat_dim,
frame_length=config.window_ms,
frame_shift=config.stride_ms,
dither=0.0,
minlen_in=0.0,
maxlen_in=float('inf'),
minlen_out=0,
maxlen_out=float('inf'),
resample_rate=config.resample_rate,
augment_conf=config.augment_conf, # dict
shuffle_size=config.shuffle_size,
sort_size=config.sort_size,
n_iter_processes=config.num_workers,
prefetch_factor=config.prefetch_factor,
dist_sampler=config.get('dist_sampler', False),
cmvn_file=config.cmvn_file,
vocab_filepath=config.vocab_filepath,
)
else:
self.test_loader = BatchDataLoader(
json_file=config.test_manifest,
train_mode=False,
sortagrad=False,
batch_size=decode_batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.preprocess_config,
n_iter_processes=1,
subsampling_factor=1,
num_encs=1)
self.align_loader = BatchDataLoader(
json_file=config.test_manifest,
train_mode=False,
sortagrad=False,
batch_size=decode_batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.preprocess_config,
n_iter_processes=1,
subsampling_factor=1,
num_encs=1)
self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
logger.info("Setup test/align Dataloader!")
def setup_model(self):
......
......@@ -25,7 +25,7 @@ from paddle import distributed as dist
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import load_dict
from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.optimizer import OptimizerFactory
from paddlespeech.s2t.training.scheduler import LRSchedulerFactory
......@@ -104,7 +104,8 @@ class U2Trainer(Trainer):
@paddle.no_grad()
def valid(self):
self.model.eval()
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
if not self.use_streamdata:
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list)
num_seen_utts = 1
total_loss = 0.0
......@@ -131,7 +132,8 @@ class U2Trainer(Trainer):
msg = f"Valid: Rank: {dist.get_rank()}, "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
if not self.use_streamdata:
msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items())
logger.info(msg)
......@@ -150,8 +152,8 @@ class U2Trainer(Trainer):
# paddle.jit.save(script_model, script_model_path)
self.before_train()
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
if not self.use_streamdata:
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"):
self.model.train()
......@@ -162,7 +164,8 @@ class U2Trainer(Trainer):
msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "batch : {}/{}, ".format(batch_index + 1,
if not self.use_streamdata:
msg += "batch : {}/{}, ".format(batch_index + 1,
len(self.train_loader))
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
msg += "data time: {:>.3f}s, ".format(dataload_time)
......@@ -198,87 +201,23 @@ class U2Trainer(Trainer):
self.new_epoch()
def setup_dataloader(self):
config = self.config.clone()
# train/valid dataset, return token ids
self.train_loader = BatchDataLoader(
json_file=config.train_manifest,
train_mode=True,
sortagrad=False,
batch_size=config.batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=self.args.ngpu,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.preprocess_config,
n_iter_processes=config.num_workers,
subsampling_factor=1,
num_encs=1)
self.valid_loader = BatchDataLoader(
json_file=config.dev_manifest,
train_mode=False,
sortagrad=False,
batch_size=config.batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=self.args.ngpu,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=None,
n_iter_processes=config.num_workers,
subsampling_factor=1,
num_encs=1)
decode_batch_size = config.get('decode', dict()).get(
'decode_batch_size', 1)
# test dataset, return raw text
self.test_loader = BatchDataLoader(
json_file=config.test_manifest,
train_mode=False,
sortagrad=False,
batch_size=decode_batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=None,
n_iter_processes=1,
subsampling_factor=1,
num_encs=1)
self.align_loader = BatchDataLoader(
json_file=config.test_manifest,
train_mode=False,
sortagrad=False,
batch_size=decode_batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=None,
n_iter_processes=1,
subsampling_factor=1,
num_encs=1)
logger.info("Setup train/valid/test/align Dataloader!")
self.use_streamdata = config.get("use_stream_data", False)
if self.train:
config = self.config.clone()
self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
config = self.config.clone()
config['preprocess_config'] = None
self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
logger.info("Setup train/valid Dataloader!")
else:
config = self.config.clone()
config['preprocess_config'] = None
self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
config = self.config.clone()
config['preprocess_config'] = None
self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
logger.info("Setup test/align Dataloader!")
def setup_model(self):
config = self.config
......@@ -406,7 +345,8 @@ class U2Tester(U2Trainer):
def test(self):
assert self.args.result_file
self.model.eval()
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
if not self.use_streamdata:
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
stride_ms = self.config.stride_ms
error_rate_type = None
......
......@@ -25,7 +25,7 @@ import paddle
from paddle import distributed as dist
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
from paddlespeech.s2t.models.u2_st import U2STModel
from paddlespeech.s2t.training.optimizer import OptimizerFactory
from paddlespeech.s2t.training.reporter import ObsScope
......@@ -120,7 +120,8 @@ class U2STTrainer(Trainer):
@paddle.no_grad()
def valid(self):
self.model.eval()
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
if not self.use_streamdata:
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list)
num_seen_utts = 1
total_loss = 0.0
......@@ -153,7 +154,8 @@ class U2STTrainer(Trainer):
msg = f"Valid: Rank: {dist.get_rank()}, "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
if not self.use_streamdata:
msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items())
logger.info(msg)
......@@ -172,8 +174,8 @@ class U2STTrainer(Trainer):
# paddle.jit.save(script_model, script_model_path)
self.before_train()
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
if not self.use_streamdata:
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"):
self.model.train()
......@@ -191,7 +193,8 @@ class U2STTrainer(Trainer):
self.train_batch(batch_index, batch, msg)
self.after_train_batch()
report('iter', batch_index + 1)
report('total', len(self.train_loader))
if not self.use_streamdata:
report('total', len(self.train_loader))
report('reader_cost', dataload_time)
observation['batch_cost'] = observation[
'reader_cost'] + observation['step_cost']
......@@ -241,79 +244,18 @@ class U2STTrainer(Trainer):
load_transcript = True if config.model_conf.asr_weight > 0 else False
config = self.config.clone()
config['load_transcript'] = load_transcript
self.use_streamdata = config.get("use_stream_data", False)
if self.train:
# train/valid dataset, return token ids
self.train_loader = BatchDataLoader(
json_file=config.train_manifest,
train_mode=True,
sortagrad=False,
batch_size=config.batch_size,
maxlen_in=config.maxlen_in,
maxlen_out=config.maxlen_out,
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.
preprocess_config, # aug will be off when train_mode=False
n_iter_processes=config.num_workers,
subsampling_factor=1,
load_aux_output=load_transcript,
num_encs=1,
dist_sampler=True)
self.valid_loader = BatchDataLoader(
json_file=config.dev_manifest,
train_mode=False,
sortagrad=False,
batch_size=config.batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.
preprocess_config, # aug will be off when train_mode=False
n_iter_processes=config.num_workers,
subsampling_factor=1,
load_aux_output=load_transcript,
num_encs=1,
dist_sampler=False)
self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
logger.info("Setup train/valid Dataloader!")
else:
# test dataset, return raw text
decode_batch_size = config.get('decode', dict()).get(
'decode_batch_size', 1)
self.test_loader = BatchDataLoader(
json_file=config.test_manifest,
train_mode=False,
sortagrad=False,
batch_size=decode_batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.
preprocess_config, # aug will be off when train_mode=False
n_iter_processes=config.num_workers,
subsampling_factor=1,
num_encs=1,
dist_sampler=False)
self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
logger.info("Setup test Dataloader!")
def setup_model(self):
config = self.config
model_conf = config
......@@ -468,7 +410,8 @@ class U2STTester(U2STTrainer):
def test(self):
assert self.args.result_file
self.model.eval()
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
if not self.use_streamdata:
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
decode_cfg = self.config.decode
bleu_func = bleu_score.char_bleu if decode_cfg.error_rate_type == 'char-bleu' else bleu_score.bleu
......
......@@ -30,9 +30,10 @@ from paddlespeech.s2t.io.reader import LoadInputsAndTargets
from paddlespeech.s2t.utils.log import Log
import paddlespeech.audio.streamdata as streamdata
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.audio.text.text_featurizer import TextFeaturizer
from yacs.config import CfgNode
__all__ = ["BatchDataLoader"]
__all__ = ["BatchDataLoader", "StreamDataLoader"]
logger = Log(__name__).getlog()
......@@ -60,12 +61,36 @@ def batch_collate(x):
"""
return x[0]
def read_preprocess_cfg(preprocess_conf_file):
augment_conf = dict()
preprocess_cfg = CfgNode(new_allowed=True)
preprocess_cfg.merge_from_file(preprocess_conf_file)
for idx, process in enumerate(preprocess_cfg["process"]):
opts = dict(process)
process_type = opts.pop("type")
if process_type == 'time_warp':
augment_conf['max_w'] = process['max_time_warp']
augment_conf['w_inplace'] = process['inplace']
augment_conf['w_mode'] = process['mode']
if process_type == 'freq_mask':
augment_conf['max_f'] = process['F']
augment_conf['num_f_mask'] = process['n_mask']
augment_conf['f_inplace'] = process['inplace']
augment_conf['f_replace_with_zero'] = process['replace_with_zero']
if process_type == 'time_mask':
augment_conf['max_t'] = process['T']
augment_conf['num_t_mask'] = process['n_mask']
augment_conf['t_inplace'] = process['inplace']
augment_conf['t_replace_with_zero'] = process['replace_with_zero']
return augment_conf
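# Illustrative sketch (not part of this change): the conf/preprocess.yaml layout assumed by
# read_preprocess_cfg above, reusing the values of the augment_conf block that this commit
# removes from the recipe config.
#
# process:
#   - type: time_warp
#     max_time_warp: 80
#     inplace: true
#     mode: PIL
#   - type: freq_mask
#     F: 30
#     n_mask: 2
#     inplace: true
#     replace_with_zero: false
#   - type: time_mask
#     T: 40
#     n_mask: 2
#     inplace: true
#     replace_with_zero: false
#
# read_preprocess_cfg("conf/preprocess.yaml") then returns:
#   {'max_w': 80, 'w_inplace': True, 'w_mode': 'PIL',
#    'max_f': 30, 'num_f_mask': 2, 'f_inplace': True, 'f_replace_with_zero': False,
#    'max_t': 40, 'num_t_mask': 2, 't_inplace': True, 't_replace_with_zero': False}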
class StreamDataLoader():
def __init__(self,
manifest_file: str,
train_mode: bool,
unit_type: str='char',
batch_size: int=0,
preprocess_conf=None,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
......@@ -75,7 +100,6 @@ class StreamDataLoader():
minlen_out: float=0.0,
maxlen_out: float=float('inf'),
resample_rate: int=16000,
augment_conf: dict=None,
shuffle_size: int=10000,
sort_size: int=1000,
n_iter_processes: int=1,
......@@ -95,12 +119,27 @@ class StreamDataLoader():
self.feat_dim = num_mel_bins
self.vocab_size = text_featurizer.vocab_size
augment_conf = read_preprocess_cfg(preprocess_conf)
# The list of shard
shardlist = []
with open(manifest_file, "r") as f:
for line in f.readlines():
shardlist.append(line.strip())
world_size = 1
try:
world_size = paddle.distributed.get_world_size()
except Exception as e:
logger.warning(e)
logger.warning("cannot get world_size via paddle.distributed.get_world_size(); falling back to world_size=1")
assert len(shardlist) >= world_size, "the length of the shard list should be >= the number of gpus/xpus/..."
update_n_iter_processes = int(max(min(len(shardlist)/world_size - 1, self.n_iter_processes), 0))
logger.info(f"update_n_iter_processes {update_n_iter_processes}")
if update_n_iter_processes != self.n_iter_processes:
self.n_iter_processes = update_n_iter_processes
logger.info(f"change nun_workers to {self.n_iter_processes}")
if self.dist_sampler:
base_dataset = streamdata.DataPipeline(
streamdata.SimpleShardList(shardlist),
......@@ -116,16 +155,16 @@ class StreamDataLoader():
)
self.dataset = base_dataset.append_list(
streamdata.tokenize(symbol_table),
streamdata.data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_in),
streamdata.resample(resample_rate=resample_rate),
streamdata.compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither),
streamdata.spec_aug(**augment_conf) if train_mode else streamdata.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
streamdata.audio_tokenize(symbol_table),
streamdata.audio_data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_out),
streamdata.audio_resample(resample_rate=resample_rate),
streamdata.audio_compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither),
streamdata.audio_spec_aug(**augment_conf) if train_mode else streamdata.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
streamdata.shuffle(shuffle_size),
streamdata.sort(sort_size=sort_size),
streamdata.batched(batch_size),
streamdata.padding(),
streamdata.cmvn(cmvn_file)
streamdata.audio_padding(),
streamdata.audio_cmvn(cmvn_file)
)
if paddle.__version__ >= '2.3.2':
......@@ -295,3 +334,119 @@ class BatchDataLoader():
echo += f"shortest_first: {self.shortest_first}, "
echo += f"file: {self.json_file}"
return echo
class DataLoaderFactory():
@staticmethod
def get_dataloader(mode: str, config, args):
config = config.clone()
use_streamdata = config.get("use_stream_data", False)
if use_streamdata:
if mode == 'train':
config['manifest'] = config.train_manifest
config['train_mode'] = True
elif mode == 'valid':
config['manifest'] = config.dev_manifest
config['train_mode'] = False
elif mode == 'test' or mode == 'align':
config['manifest'] = config.test_manifest
config['train_mode'] = False
config['dither'] = 0.0
config['minlen_in'] = 0.0
config['maxlen_in'] = float('inf')
config['minlen_out'] = 0
config['maxlen_out'] = float('inf')
config['dist_sampler'] = False
else:
raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'")
return StreamDataLoader(
manifest_file=config.manifest,
train_mode=config.train_mode,
unit_type=config.unit_type,
preprocess_conf=config.preprocess_config,
batch_size=config.batch_size,
num_mel_bins=config.feat_dim,
frame_length=config.window_ms,
frame_shift=config.stride_ms,
dither=config.dither,
minlen_in=config.minlen_in,
maxlen_in=config.maxlen_in,
minlen_out=config.minlen_out,
maxlen_out=config.maxlen_out,
resample_rate=config.resample_rate,
shuffle_size=config.shuffle_size,
sort_size=config.sort_size,
n_iter_processes=config.num_workers,
prefetch_factor=config.prefetch_factor,
dist_sampler=config.dist_sampler,
cmvn_file=config.cmvn_file,
vocab_filepath=config.vocab_filepath,
)
else:
if mode == 'train':
config['manifest'] = config.train_manifest
config['train_mode'] = True
config['mini_batch_size'] = args.ngpu
config['subsampling_factor'] = 1
config['num_encs'] = 1
elif mode == 'valid':
config['manifest'] = config.dev_manifest
config['train_mode'] = False
config['sortagrad'] = False
config['maxlen_in'] = float('inf')
config['maxlen_out'] = float('inf')
config['minibatches'] = 0
config['mini_batch_size'] = args.ngpu
config['batch_count'] = 'auto'
config['batch_bins'] = 0
config['batch_frames_in'] = 0
config['batch_frames_out'] = 0
config['batch_frames_inout'] = 0
config['subsampling_factor'] = 1
config['num_encs'] = 1
config['shortest_first'] = False
elif mode == 'test' or mode == 'align':
config['manifest'] = config.test_manifest
config['train_mode'] = False
config['sortagrad'] = False
config['batch_size'] = config.get('decode', dict()).get(
'decode_batch_size', 1)
config['maxlen_in'] = float('inf')
config['maxlen_out'] = float('inf')
config['minibatches'] = 0
config['mini_batch_size'] = 1
config['batch_count'] = 'auto'
config['batch_bins'] = 0
config['batch_frames_in'] = 0
config['batch_frames_out'] = 0
config['batch_frames_inout'] = 0
config['num_workers'] = 1
config['subsampling_factor'] = 1
config['num_encs'] = 1
config['dist_sampler'] = False
config['shortest_first'] = False
else:
raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'")
return BatchDataLoader(
json_file=config.manifest,
train_mode=config.train_mode,
sortagrad=config.sortagrad,
batch_size=config.batch_size,
maxlen_in=config.maxlen_in,
maxlen_out=config.maxlen_out,
minibatches=config.minibatches,
mini_batch_size=config.mini_batch_size,
batch_count=config.batch_count,
batch_bins=config.batch_bins,
batch_frames_in=config.batch_frames_in,
batch_frames_out=config.batch_frames_out,
batch_frames_inout=config.batch_frames_inout,
preprocess_conf=config.preprocess_config,
n_iter_processes=config.num_workers,
subsampling_factor=config.subsampling_factor,
load_aux_output=config.get('load_transcript', None),
num_encs=config.num_encs,
dist_sampler=config.dist_sampler,
shortest_first=config.shortest_first)
......@@ -48,9 +48,9 @@ from paddlespeech.s2t.utils import checkpoint
from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
from paddlespeech.s2t.utils.tensor_utils import pad_sequence
from paddlespeech.s2t.utils.tensor_utils import th_accuracy
from paddlespeech.audio.utils.tensor_utils import add_sos_eos
from paddlespeech.audio.utils.tensor_utils import pad_sequence
from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.utils.utility import log_add
from paddlespeech.s2t.utils.utility import UpdateConfig
......
......@@ -38,8 +38,8 @@ from paddlespeech.s2t.modules.mask import subsequent_mask
from paddlespeech.s2t.utils import checkpoint
from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
from paddlespeech.s2t.utils.tensor_utils import th_accuracy
from paddlespeech.audio.utils.tensor_utils import add_sos_eos
from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ["U2STModel", "U2STInferModel"]
......