提交 69055698 编写于 作者: H Hui Zhang

transformer using batch data loader

上级 3f611c75
......@@ -66,7 +66,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for dataset in train dev test; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \
--unit_type "char" \
--vocab_path="data/vocab.txt" \
......
......@@ -67,7 +67,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for dataset in train dev test; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \
--unit_type "char" \
--vocab_path="data/vocab.txt" \
......
......@@ -55,7 +55,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for dataset in train dev test; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \
--unit_type "char" \
--vocab_path="data/vocab.txt" \
......
......@@ -89,25 +89,28 @@ def create_manifest(data_dir, manifest_path):
text_filepath = os.path.join(subfolder, text_filelist[0])
for line in io.open(text_filepath, encoding="utf8"):
segments = line.strip().split()
n_token = len(segments[1:])
text = ' '.join(segments[1:]).lower()
audio_filepath = os.path.abspath(
os.path.join(subfolder, segments[0] + '.flac'))
audio_data, samplerate = soundfile.read(audio_filepath)
duration = float(len(audio_data)) / samplerate
utt = os.path.splitext(os.path.basename(audio_filepath))[0]
utt2spk = '-'.join(utt.split('-')[:2])
json_lines.append(
json.dumps({
'utt':
os.path.splitext(os.path.basename(audio_filepath))[0],
'feat':
audio_filepath,
'feat_shape': (duration, ), #second
'text':
text
'utt': utt,
'utt2spk': utt2spk,
'feat': audio_filepath,
'feat_shape': (duration, ), # second
'text': text,
}))
total_sec += duration
total_text += len(text)
total_text += n_token
total_num += 1
with codecs.open(manifest_path, 'w', 'utf-8') as out_file:
......
......@@ -81,7 +81,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for set in train dev test dev-clean dev-other test-clean test-other; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \
--unit_type ${unit_type} \
--vocab_path="data/vocab.txt" \
......
......@@ -88,7 +88,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for set in train dev test dev-clean dev-other test-clean test-other; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \
--unit_type "spm" \
--spm_model_prefix ${bpeprefix} \
......
......@@ -50,7 +50,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for dataset in train dev test; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.npz" \
--unit_type "char" \
--vocab_path="data/vocab.txt" \
......
......@@ -65,7 +65,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for set in train dev test dev-clean dev-other test-clean test-other; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.npz" \
--unit_type ${unit_type} \
--vocab_path="data/vocab.txt" \
......
......@@ -63,7 +63,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for set in train dev test dev-clean dev-other test-clean test-other; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.npz" \
--unit_type ${unit_type} \
--vocab_path="data/vocab.txt" \
......
......@@ -89,7 +89,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for set in train dev test; do
{
python3 ${MAIN_ROOT}/utils/format_triplet_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \
--unit_type "spm" \
--spm_model_prefix ${bpeprefix} \
......
......@@ -66,7 +66,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for set in train dev test; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \
--unit_type ${unit_type} \
--vocab_path="data/vocab.txt" \
......
......@@ -63,7 +63,6 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \
--unit_type ${unit_type} \
--vocab_path="data/vocab.txt" \
......
......@@ -15,7 +15,7 @@ collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_200'
augmentation_config: conf/augmentation.json
augmentation_config: conf/preprocess.yaml
batch_size: 4
raw_wav: True # use raw_wav or kaldi feature
spectrum_type: fbank #linear, mfcc, fbank
......
......@@ -15,7 +15,7 @@ collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_200'
augmentation_config: conf/augmentation.json
augmentation_config: conf/preprocess.yaml
batch_size: 4
raw_wav: True # use raw_wav or kaldi feature
spectrum_type: fbank #linear, mfcc, fbank
......
......@@ -15,7 +15,7 @@ collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_200'
augmentation_config: conf/augmentation.json
augmentation_config: conf/preprocess.yaml
batch_size: 4
raw_wav: True # use raw_wav or kaldi feature
spectrum_type: fbank #linear, mfcc, fbank
......
process:
# extract kaldi fbank from PCM
- type: "fbank_kaldi"
fs: 16000
n_mels: 80
n_shift: 160
win_length: 400
dither: true
# these three processes are a.k.a. SpecAugument
- type: "time_warp"
max_time_warp: 5
inplace: true
mode: "PIL"
- type: "freq_mask"
F: 30
n_mask: 2
inplace: true
replace_with_zero: false
- type: "time_mask"
T: 40
n_mask: 2
inplace: true
replace_with_zero: false
......@@ -15,7 +15,7 @@ collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_200'
augmentation_config: conf/augmentation.json
augmentation_config: conf/preprocess.yaml
batch_size: 4
raw_wav: True # use raw_wav or kaldi feature
spectrum_type: fbank #linear, mfcc, fbank
......
......@@ -69,7 +69,6 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \
--unit_type "spm" \
--spm_model_prefix ${bpeprefix} \
......
......@@ -27,7 +27,9 @@ from paddle import distributed as dist
from paddle.io import DataLoader
from yacs.config import CfgNode
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.collator import SpeechCollator
from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.io.dataset import ManifestDataset
from paddlespeech.s2t.io.sampler import SortagradBatchSampler
from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler
......@@ -247,92 +249,103 @@ class U2Trainer(Trainer):
def setup_dataloader(self):
config = self.config.clone()
config.defrost()
config.collator.keep_transcription_text = False
# train/valid dataset, return token ids
config.data.manifest = config.data.train_manifest
train_dataset = ManifestDataset.from_config(config)
config.data.manifest = config.data.dev_manifest
dev_dataset = ManifestDataset.from_config(config)
collate_fn_train = SpeechCollator.from_config(config)
config.collator.augmentation_config = ""
collate_fn_dev = SpeechCollator.from_config(config)
if self.parallel:
batch_sampler = SortagradDistributedBatchSampler(
train_dataset,
if self.train:
# train/valid dataset, return token ids
self.train_loader = BatchDataLoader(
json_file=config.data.train_manifest,
train_mode=True,
sortagrad=False,
batch_size=config.collator.batch_size,
num_replicas=None,
rank=None,
shuffle=True,
drop_last=True,
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
else:
batch_sampler = SortagradBatchSampler(
train_dataset,
shuffle=True,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=self.args.nprocs,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.collator.
augmentation_config, # aug will be off when train_mode=False
n_iter_processes=config.collator.num_workers,
subsampling_factor=1,
num_encs=1)
self.valid_loader = BatchDataLoader(
json_file=config.data.dev_manifest,
train_mode=False,
sortagrad=False,
batch_size=config.collator.batch_size,
drop_last=True,
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
self.train_loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn_train,
num_workers=config.collator.num_workers, )
self.valid_loader = DataLoader(
dev_dataset,
batch_size=config.collator.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn_dev,
num_workers=config.collator.num_workers, )
# test dataset, return raw text
config.data.manifest = config.data.test_manifest
# filter test examples, will cause less examples, but no mismatch with training
# and can use large batch size , save training time, so filter test egs now.
config.data.min_input_len = 0.0 # second
config.data.max_input_len = float('inf') # second
config.data.min_output_len = 0.0 # tokens
config.data.max_output_len = float('inf') # tokens
config.data.min_output_input_ratio = 0.00
config.data.max_output_input_ratio = float('inf')
test_dataset = ManifestDataset.from_config(config)
# return text ord id
config.collator.keep_transcription_text = True
config.collator.augmentation_config = ""
self.test_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator.from_config(config),
num_workers=config.collator.num_workers, )
# return text token id
config.collator.keep_transcription_text = False
self.align_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator.from_config(config),
num_workers=config.collator.num_workers, )
logger.info("Setup train/valid/test/align Dataloader!")
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=self.args.nprocs,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.collator.
augmentation_config, # aug will be off when train_mode=False
n_iter_processes=config.collator.num_workers,
subsampling_factor=1,
num_encs=1)
logger.info("Setup train/valid Dataloader!")
else:
# test dataset, return raw text
self.test_loader = BatchDataLoader(
json_file=config.data.test_manifest,
train_mode=False,
sortagrad=False,
batch_size=config.decoding.batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.collator.
augmentation_config, # aug will be off when train_mode=False
n_iter_processes=1,
subsampling_factor=1,
num_encs=1)
self.align_loader = BatchDataLoader(
json_file=config.data.test_manifest,
train_mode=False,
sortagrad=False,
batch_size=config.decoding.batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.collator.
augmentation_config, # aug will be off when train_mode=False
n_iter_processes=1,
subsampling_factor=1,
num_encs=1)
logger.info("Setup test/align Dataloader!")
def setup_model(self):
config = self.config
model_conf = config.model
with UpdateConfig(model_conf):
model_conf.input_dim = self.train_loader.collate_fn.feature_size
model_conf.output_dim = self.train_loader.collate_fn.vocab_size
if self.train:
model_conf.input_dim = self.train_loader.feat_dim
model_conf.output_dim = self.train_loader.vocab_size
else:
model_conf.input_dim = self.test_loader.feat_dim
model_conf.output_dim = self.test_loader.vocab_size
model = U2Model.from_config(model_conf)
......@@ -341,6 +354,11 @@ class U2Trainer(Trainer):
logger.info(f"{model}")
layer_tools.print_params(model, logger.info)
self.model = model
logger.info("Setup model!")
if not self.train:
return
train_config = config.training
optim_type = train_config.optim
......@@ -381,10 +399,9 @@ class U2Trainer(Trainer):
optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler)
optimizer = OptimizerFactory.from_args(optim_type, optimzer_args)
self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
logger.info("Setup model/optimizer/lr_scheduler!")
logger.info("Setup optimizer/lr_scheduler!")
class U2Tester(U2Trainer):
......@@ -419,14 +436,19 @@ class U2Tester(U2Trainer):
def __init__(self, config, args):
super().__init__(config, args)
self.text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
self.vocab_list = self.text_feature.vocab_list
def ordid2token(self, texts, texts_len):
def id2token(self, texts, texts_len, text_feature):
""" ord() id to chr() chr """
trans = []
for text, n in zip(texts, texts_len):
n = n.numpy().item()
ids = text[:n]
trans.append(''.join([chr(i) for i in ids]))
trans.append(text_feature.defeaturize(ids.numpy().tolist()))
return trans
def compute_metrics(self,
......@@ -442,12 +464,11 @@ class U2Tester(U2Trainer):
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
start_time = time.time()
text_feature = self.test_loader.collate_fn.text_feature
target_transcripts = self.ordid2token(texts, texts_len)
target_transcripts = self.id2token(texts, texts_len, self.text_feature)
result_transcripts, result_tokenids = self.model.decode(
audio,
audio_len,
text_feature=text_feature,
text_feature=self.text_feature,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
......@@ -497,7 +518,7 @@ class U2Tester(U2Trainer):
self.model.eval()
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
stride_ms = self.test_loader.collate_fn.stride_ms
stride_ms = self.config.collator.stride_ms
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0
......@@ -556,8 +577,8 @@ class U2Tester(U2Trainer):
def align(self):
ctc_utils.ctc_align(
self.model, self.align_loader, self.config.decoding.batch_size,
self.align_loader.collate_fn.stride_ms,
self.align_loader.collate_fn.vocab_list, self.args.result_file)
self.config.collator.stride_ms,
self.vocab_list, self.args.result_file)
def load_inferspec(self):
"""infer model and input spec.
......
......@@ -392,6 +392,7 @@ class U2Tester(U2Trainer):
unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
self.vocab_list = self.text_feature.vocab_list
def id2token(self, texts, texts_len, text_feature):
""" ord() id to chr() chr """
......@@ -529,8 +530,8 @@ class U2Tester(U2Trainer):
def align(self):
ctc_utils.ctc_align(
self.model, self.align_loader, self.config.decoding.batch_size,
self.align_loader.collate_fn.stride_ms,
self.align_loader.collate_fn.vocab_list, self.args.result_file)
self.config.collator.stride_ms,
self.vocab_list, self.args.result_file)
def load_inferspec(self):
"""infer model and input spec.
......
......@@ -207,34 +207,16 @@ class AudioDataset(Dataset):
if sort:
data = sorted(data, key=lambda x: x["feat_shape"][0])
if raw_wav:
assert data[0]['feat'].split(':')[0].splitext()[-1] not in ('.ark',
'.scp')
data = map(lambda x: (float(x['feat_shape'][0]) * 1000 / stride_ms))
path_suffix = data[0]['feat'].split(':')[0].splitext()[-1]
assert path_suffix not in ('.ark', '.scp')
# m second to n frame
data = list(
map(lambda x: (float(x['feat_shape'][0]) * 1000 / stride_ms),
data))
self.input_dim = data[0]['feat_shape'][1]
self.output_dim = data[0]['token_shape'][1]
# with open(data_file, 'r') as f:
# for line in f:
# arr = line.strip().split('\t')
# if len(arr) != 7:
# continue
# key = arr[0].split(':')[1]
# tokenid = arr[5].split(':')[1]
# output_dim = int(arr[6].split(':')[1].split(',')[1])
# if raw_wav:
# wav_path = ':'.join(arr[1].split(':')[1:])
# duration = int(float(arr[2].split(':')[1]) * 1000 / 10)
# data.append((key, wav_path, duration, tokenid))
# else:
# feat_ark = ':'.join(arr[1].split(':')[1:])
# feat_info = arr[2].split(':')[1].split(',')
# feat_dim = int(feat_info[1].strip())
# num_frames = int(feat_info[0].strip())
# data.append((key, feat_ark, num_frames, tokenid))
# self.input_dim = feat_dim
# self.output_dim = output_dim
valid_data = []
for i in range(len(data)):
length = data[i]['feat_shape'][0]
......@@ -242,17 +224,17 @@ class AudioDataset(Dataset):
# remove too lang or too short utt for both input and output
# to prevent from out of memory
if length > max_length or length < min_length:
# logging.warn('ignore utterance {} feature {}'.format(
# data[i][0], length))
pass
elif token_length > token_max_length or token_length < token_min_length:
pass
else:
valid_data.append(data[i])
logger.info(f"raw dataset len: {len(data)}")
data = valid_data
num_data = len(data)
logger.info(f"dataset len after filter: {num_data}")
self.minibatch = []
num_data = len(data)
# Dynamic batch size
if batch_type == 'dynamic':
assert (max_frames_in_batch > 0)
......@@ -277,7 +259,9 @@ class AudioDataset(Dataset):
cur = end
def __len__(self):
"""number of example(batch)"""
return len(self.minibatch)
def __getitem__(self, idx):
"""batch example of idx"""
return self.minibatch[idx]
......@@ -18,8 +18,10 @@ import kaldiio
import numpy as np
import soundfile
from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline as Transformation
from .utility import feat_type
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.log import Log
# from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline as Transformation
__all__ = ["LoadInputsAndTargets"]
......@@ -322,20 +324,7 @@ class LoadInputsAndTargets():
"Not supported: loader_type={}".format(filetype))
def file_type(self, filepath):
suffix = filepath.split(":")[0].split('.')[-1].lower()
if suffix == 'ark':
return 'mat'
elif suffix == 'scp':
return 'scp'
elif suffix == 'npy':
return 'npy'
elif suffix == 'npz':
return 'npz'
elif suffix in ['wav', 'flac']:
# PCM16
return 'sound'
else:
raise ValueError(f"Not support filetype: {suffix}")
return feat_type(filepath)
class SoundHDF5File():
......
......@@ -17,7 +17,7 @@ import numpy as np
from paddlespeech.s2t.utils.log import Log
__all__ = ["pad_list", "pad_sequence"]
__all__ = ["pad_list", "pad_sequence", "feat_type"]
logger = Log(__name__).getlog()
......@@ -85,3 +85,20 @@ def pad_sequence(sequences: List[np.ndarray],
out_tensor[:length, i, ...] = tensor
return out_tensor
def feat_type(filepath):
suffix = filepath.split(":")[0].split('.')[-1].lower()
if suffix == 'ark':
return 'mat'
elif suffix == 'scp':
return 'scp'
elif suffix == 'npy':
return 'npy'
elif suffix == 'npz':
return 'npz'
elif suffix in ['wav', 'flac']:
# PCM16
return 'sound'
else:
raise ValueError(f"Not support filetype: {suffix}")
......@@ -14,6 +14,7 @@
# Modified from espnet(https://github.com/espnet/espnet)
import librosa
import numpy as np
from python_speech_features import logfbank
def stft(x,
......@@ -304,3 +305,85 @@ class IStft():
win_length=self.win_length,
window=self.window,
center=self.center, )
class LogMelSpectrogramKaldi():
def __init__(
self,
fs=16000,
n_mels=80,
n_fft=512, # fft point
n_shift=160, # unit:sample, 10ms
win_length=400, # unit:sample, 25ms
window="povey",
fmin=20,
fmax=None,
eps=1e-10,
dither=False):
self.fs = fs
self.n_mels = n_mels
self.n_fft = n_fft
if n_shift > win_length:
raise ValueError("Stride size must not be greater than "
"window size.")
self.n_shift = n_shift / fs # unit: ms
self.win_length = win_length / fs # unit: ms
self.window = window
self.fmin = fmin
if fmax is None:
fmax_ = fmax if fmax else self.fs / 2
elif fmax > int(self.fs / 2):
raise ValueError("fmax must not be greater than half of "
"sample rate.")
self.fmax = fmax_
self.eps = eps
self.remove_dc_offset = True
self.preemph = 0.97
self.dither = dither
def __repr__(self):
return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"n_shift={n_shift}, win_length={win_length}, window={window}, "
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
name=self.__class__.__name__,
fs=self.fs,
n_mels=self.n_mels,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
fmin=self.fmin,
fmax=self.fmax,
eps=self.eps, ))
def __call__(self, x):
"""
Args:
x (np.ndarray): shape (Ti,)
Raises:
ValueError: not support (Ti, C)
Returns:
np.ndarray: (T, D)
"""
if x.ndim != 1:
raise ValueError("Not support x: [Time, Channel]")
if x.dtype == np.int16:
x = x / 2**(16 - 1)
return logfbank(
signal=x,
samplerate=self.fs,
winlen=self.win_length, # unit ms
winstep=self.n_shift, # unit ms
nfilt=self.n_mels,
nfft=self.n_fft,
lowfreq=self.fmin,
highfreq=self.fmax,
dither=self.dither,
remove_dc_offset=self.remove_dc_offset,
preemph=self.preemph,
wintype=self.window)
......@@ -45,6 +45,7 @@ import_alias = dict(
stft2fbank="paddlespeech.s2t.transform.spectrogram:Stft2LogMelSpectrogram",
wpe="paddlespeech.s2t.transform.wpe:WPE",
channel_selector="paddlespeech.s2t.transform.channel_selector:ChannelSelector",
fbank_kaldi="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogramKaldi",
)
......
......@@ -20,13 +20,13 @@ import json
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.frontend.utility import read_manifest
from paddlespeech.s2t.io.utility import feat_type
from paddlespeech.s2t.utils.utility import add_arguments
from paddlespeech.s2t.utils.utility import print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('feat_type', str, "raw", "speech feature type, e.g. raw(wav, flac), mat(ark), scp")
add_arg('cmvn_path', str,
'examples/librispeech/data/mean_std.json',
"Filepath of cmvn.")
......@@ -62,24 +62,64 @@ def main():
vocab_size = text_feature.vocab_size
print(f"Vocab size: {vocab_size}")
# josnline like this
# {
# "input": [{"name": "input1", "shape": (100, 83), "feat": "xxx.ark:123"}],
# "output": [{"name":"target1", "shape": (40, 5002), "text": "a b c de"}],
# "utt2spk": "111-2222",
# "utt": "111-2222-333"
# }
count = 0
for manifest_path in args.manifest_paths:
manifest_jsons = read_manifest(manifest_path)
for line_json in manifest_jsons:
output_json = {
"input": [],
"output": [],
'utt': line_json['utt'],
'utt2spk': line_json.get('utt2spk', 'global'),
}
# output
line = line_json['text']
tokens = text_feature.tokenize(line)
tokenids = text_feature.featurize(line)
line_json['token'] = tokens
line_json['token_id'] = tokenids
line_json['token_shape'] = (len(tokenids), vocab_size)
feat_shape = line_json['feat_shape']
assert isinstance(feat_shape, (list, tuple)), type(feat_shape)
if args.feat_type == 'raw':
feat_shape.append(feat_dim)
line_json['filetype'] = 'sound'
else: # kaldi
raise NotImplementedError('no support kaldi feat now!')
fout.write(json.dumps(line_json) + '\n')
if isinstance(line, str):
# only one target
tokens = text_feature.tokenize(line)
tokenids = text_feature.featurize(line)
output_json['output'].append({
'name': 'traget1',
'shape': (len(tokenids), vocab_size),
'text': line,
'token': ' '.join(tokens),
'tokenid': ' '.join(map(str, tokenids)),
})
else:
# isinstance(line, list), multi target
raise NotImplementedError("not support multi output now!")
# input
line = line_json['feat']
if isinstance(line, str):
# only one input
feat_shape = line_json['feat_shape']
assert isinstance(feat_shape, (list, tuple)), type(feat_shape)
filetype = feat_type(line)
if filetype == 'sound':
feat_shape.append(feat_dim)
else: # kaldi
raise NotImplementedError('no support kaldi feat now!')
output_json['input'].append({
"name": "input1",
"shape": feat_shape,
"feat": line,
"filetype": filetype,
})
else:
# isinstance(line, list), multi input
raise NotImplementedError("not support multi input now!")
fout.write(json.dumps(output_json) + '\n')
count += 1
print(f"Examples number: {count}")
......
......@@ -20,13 +20,13 @@ import json
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.frontend.utility import read_manifest
from paddlespeech.s2t.io.utility import feat_type
from paddlespeech.s2t.utils.utility import add_arguments
from paddlespeech.s2t.utils.utility import print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('feat_type', str, "raw", "speech feature type, e.g. raw(wav, flac), kaldi")
add_arg('cmvn_path', str,
'examples/librispeech/data/mean_std.json',
"Filepath of cmvn.")
......@@ -79,9 +79,11 @@ def main():
line_json['token1'] = tokens
line_json['token_id1'] = tokenids
line_json['token_shape1'] = (len(tokenids), vocab_size)
feat_shape = line_json['feat_shape']
assert isinstance(feat_shape, (list, tuple)), type(feat_shape)
if args.feat_type == 'raw':
filetype = feat_type(line_json['feat'])
if filetype == 'sound':
feat_shape.append(feat_dim)
else: # kaldi
raise NotImplementedError('no support kaldi feat now!')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册