提交 69055698 编写于 作者: H Hui Zhang

transformer using batch data loader

上级 3f611c75
...@@ -66,7 +66,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -66,7 +66,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for dataset in train dev test; do for dataset in train dev test; do
{ {
python3 ${MAIN_ROOT}/utils/format_data.py \ python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \ --cmvn_path "data/mean_std.json" \
--unit_type "char" \ --unit_type "char" \
--vocab_path="data/vocab.txt" \ --vocab_path="data/vocab.txt" \
......
...@@ -67,7 +67,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -67,7 +67,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for dataset in train dev test; do for dataset in train dev test; do
{ {
python3 ${MAIN_ROOT}/utils/format_data.py \ python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \ --cmvn_path "data/mean_std.json" \
--unit_type "char" \ --unit_type "char" \
--vocab_path="data/vocab.txt" \ --vocab_path="data/vocab.txt" \
......
...@@ -55,7 +55,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -55,7 +55,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for dataset in train dev test; do for dataset in train dev test; do
{ {
python3 ${MAIN_ROOT}/utils/format_data.py \ python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \ --cmvn_path "data/mean_std.json" \
--unit_type "char" \ --unit_type "char" \
--vocab_path="data/vocab.txt" \ --vocab_path="data/vocab.txt" \
......
...@@ -89,25 +89,28 @@ def create_manifest(data_dir, manifest_path): ...@@ -89,25 +89,28 @@ def create_manifest(data_dir, manifest_path):
text_filepath = os.path.join(subfolder, text_filelist[0]) text_filepath = os.path.join(subfolder, text_filelist[0])
for line in io.open(text_filepath, encoding="utf8"): for line in io.open(text_filepath, encoding="utf8"):
segments = line.strip().split() segments = line.strip().split()
n_token = len(segments[1:])
text = ' '.join(segments[1:]).lower() text = ' '.join(segments[1:]).lower()
audio_filepath = os.path.abspath( audio_filepath = os.path.abspath(
os.path.join(subfolder, segments[0] + '.flac')) os.path.join(subfolder, segments[0] + '.flac'))
audio_data, samplerate = soundfile.read(audio_filepath) audio_data, samplerate = soundfile.read(audio_filepath)
duration = float(len(audio_data)) / samplerate duration = float(len(audio_data)) / samplerate
utt = os.path.splitext(os.path.basename(audio_filepath))[0]
utt2spk = '-'.join(utt.split('-')[:2])
json_lines.append( json_lines.append(
json.dumps({ json.dumps({
'utt': 'utt': utt,
os.path.splitext(os.path.basename(audio_filepath))[0], 'utt2spk': utt2spk,
'feat': 'feat': audio_filepath,
audio_filepath, 'feat_shape': (duration, ), # second
'feat_shape': (duration, ), #second 'text': text,
'text':
text
})) }))
total_sec += duration total_sec += duration
total_text += len(text) total_text += n_token
total_num += 1 total_num += 1
with codecs.open(manifest_path, 'w', 'utf-8') as out_file: with codecs.open(manifest_path, 'w', 'utf-8') as out_file:
......
...@@ -81,7 +81,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -81,7 +81,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for set in train dev test dev-clean dev-other test-clean test-other; do for set in train dev test dev-clean dev-other test-clean test-other; do
{ {
python3 ${MAIN_ROOT}/utils/format_data.py \ python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \ --cmvn_path "data/mean_std.json" \
--unit_type ${unit_type} \ --unit_type ${unit_type} \
--vocab_path="data/vocab.txt" \ --vocab_path="data/vocab.txt" \
......
...@@ -88,7 +88,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -88,7 +88,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for set in train dev test dev-clean dev-other test-clean test-other; do for set in train dev test dev-clean dev-other test-clean test-other; do
{ {
python3 ${MAIN_ROOT}/utils/format_data.py \ python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \ --cmvn_path "data/mean_std.json" \
--unit_type "spm" \ --unit_type "spm" \
--spm_model_prefix ${bpeprefix} \ --spm_model_prefix ${bpeprefix} \
......
...@@ -50,7 +50,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -50,7 +50,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for dataset in train dev test; do for dataset in train dev test; do
{ {
python3 ${MAIN_ROOT}/utils/format_data.py \ python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.npz" \ --cmvn_path "data/mean_std.npz" \
--unit_type "char" \ --unit_type "char" \
--vocab_path="data/vocab.txt" \ --vocab_path="data/vocab.txt" \
......
...@@ -65,7 +65,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -65,7 +65,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for set in train dev test dev-clean dev-other test-clean test-other; do for set in train dev test dev-clean dev-other test-clean test-other; do
{ {
python3 ${MAIN_ROOT}/utils/format_data.py \ python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.npz" \ --cmvn_path "data/mean_std.npz" \
--unit_type ${unit_type} \ --unit_type ${unit_type} \
--vocab_path="data/vocab.txt" \ --vocab_path="data/vocab.txt" \
......
...@@ -63,7 +63,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -63,7 +63,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for set in train dev test dev-clean dev-other test-clean test-other; do for set in train dev test dev-clean dev-other test-clean test-other; do
{ {
python3 ${MAIN_ROOT}/utils/format_data.py \ python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.npz" \ --cmvn_path "data/mean_std.npz" \
--unit_type ${unit_type} \ --unit_type ${unit_type} \
--vocab_path="data/vocab.txt" \ --vocab_path="data/vocab.txt" \
......
...@@ -89,7 +89,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -89,7 +89,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for set in train dev test; do for set in train dev test; do
{ {
python3 ${MAIN_ROOT}/utils/format_triplet_data.py \ python3 ${MAIN_ROOT}/utils/format_triplet_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \ --cmvn_path "data/mean_std.json" \
--unit_type "spm" \ --unit_type "spm" \
--spm_model_prefix ${bpeprefix} \ --spm_model_prefix ${bpeprefix} \
......
...@@ -66,7 +66,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ...@@ -66,7 +66,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
for set in train dev test; do for set in train dev test; do
{ {
python3 ${MAIN_ROOT}/utils/format_data.py \ python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \ --cmvn_path "data/mean_std.json" \
--unit_type ${unit_type} \ --unit_type ${unit_type} \
--vocab_path="data/vocab.txt" \ --vocab_path="data/vocab.txt" \
......
...@@ -63,7 +63,6 @@ fi ...@@ -63,7 +63,6 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size # format manifest with tokenids, vocab size
python3 ${MAIN_ROOT}/utils/format_data.py \ python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \ --cmvn_path "data/mean_std.json" \
--unit_type ${unit_type} \ --unit_type ${unit_type} \
--vocab_path="data/vocab.txt" \ --vocab_path="data/vocab.txt" \
......
...@@ -15,7 +15,7 @@ collator: ...@@ -15,7 +15,7 @@ collator:
vocab_filepath: data/vocab.txt vocab_filepath: data/vocab.txt
unit_type: 'spm' unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_200' spm_model_prefix: 'data/bpe_unigram_200'
augmentation_config: conf/augmentation.json augmentation_config: conf/preprocess.yaml
batch_size: 4 batch_size: 4
raw_wav: True # use raw_wav or kaldi feature raw_wav: True # use raw_wav or kaldi feature
spectrum_type: fbank #linear, mfcc, fbank spectrum_type: fbank #linear, mfcc, fbank
......
...@@ -15,7 +15,7 @@ collator: ...@@ -15,7 +15,7 @@ collator:
vocab_filepath: data/vocab.txt vocab_filepath: data/vocab.txt
unit_type: 'spm' unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_200' spm_model_prefix: 'data/bpe_unigram_200'
augmentation_config: conf/augmentation.json augmentation_config: conf/preprocess.yaml
batch_size: 4 batch_size: 4
raw_wav: True # use raw_wav or kaldi feature raw_wav: True # use raw_wav or kaldi feature
spectrum_type: fbank #linear, mfcc, fbank spectrum_type: fbank #linear, mfcc, fbank
......
...@@ -15,7 +15,7 @@ collator: ...@@ -15,7 +15,7 @@ collator:
vocab_filepath: data/vocab.txt vocab_filepath: data/vocab.txt
unit_type: 'spm' unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_200' spm_model_prefix: 'data/bpe_unigram_200'
augmentation_config: conf/augmentation.json augmentation_config: conf/preprocess.yaml
batch_size: 4 batch_size: 4
raw_wav: True # use raw_wav or kaldi feature raw_wav: True # use raw_wav or kaldi feature
spectrum_type: fbank #linear, mfcc, fbank spectrum_type: fbank #linear, mfcc, fbank
......
process:
# extract kaldi fbank from PCM
- type: "fbank_kaldi"
fs: 16000
n_mels: 80
n_shift: 160
win_length: 400
dither: true
# these three processes are a.k.a. SpecAugument
- type: "time_warp"
max_time_warp: 5
inplace: true
mode: "PIL"
- type: "freq_mask"
F: 30
n_mask: 2
inplace: true
replace_with_zero: false
- type: "time_mask"
T: 40
n_mask: 2
inplace: true
replace_with_zero: false
...@@ -15,7 +15,7 @@ collator: ...@@ -15,7 +15,7 @@ collator:
vocab_filepath: data/vocab.txt vocab_filepath: data/vocab.txt
unit_type: 'spm' unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_200' spm_model_prefix: 'data/bpe_unigram_200'
augmentation_config: conf/augmentation.json augmentation_config: conf/preprocess.yaml
batch_size: 4 batch_size: 4
raw_wav: True # use raw_wav or kaldi feature raw_wav: True # use raw_wav or kaldi feature
spectrum_type: fbank #linear, mfcc, fbank spectrum_type: fbank #linear, mfcc, fbank
......
...@@ -69,7 +69,6 @@ fi ...@@ -69,7 +69,6 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size # format manifest with tokenids, vocab size
python3 ${MAIN_ROOT}/utils/format_data.py \ python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.json" \ --cmvn_path "data/mean_std.json" \
--unit_type "spm" \ --unit_type "spm" \
--spm_model_prefix ${bpeprefix} \ --spm_model_prefix ${bpeprefix} \
......
...@@ -27,7 +27,9 @@ from paddle import distributed as dist ...@@ -27,7 +27,9 @@ from paddle import distributed as dist
from paddle.io import DataLoader from paddle.io import DataLoader
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.io.collator import SpeechCollator
from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.io.dataset import ManifestDataset from paddlespeech.s2t.io.dataset import ManifestDataset
from paddlespeech.s2t.io.sampler import SortagradBatchSampler from paddlespeech.s2t.io.sampler import SortagradBatchSampler
from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler
...@@ -247,92 +249,103 @@ class U2Trainer(Trainer): ...@@ -247,92 +249,103 @@ class U2Trainer(Trainer):
def setup_dataloader(self): def setup_dataloader(self):
config = self.config.clone() config = self.config.clone()
config.defrost()
config.collator.keep_transcription_text = False
# train/valid dataset, return token ids if self.train:
config.data.manifest = config.data.train_manifest # train/valid dataset, return token ids
train_dataset = ManifestDataset.from_config(config) self.train_loader = BatchDataLoader(
json_file=config.data.train_manifest,
config.data.manifest = config.data.dev_manifest train_mode=True,
dev_dataset = ManifestDataset.from_config(config) sortagrad=False,
collate_fn_train = SpeechCollator.from_config(config)
config.collator.augmentation_config = ""
collate_fn_dev = SpeechCollator.from_config(config)
if self.parallel:
batch_sampler = SortagradDistributedBatchSampler(
train_dataset,
batch_size=config.collator.batch_size, batch_size=config.collator.batch_size,
num_replicas=None, maxlen_in=float('inf'),
rank=None, maxlen_out=float('inf'),
shuffle=True, minibatches=0,
drop_last=True, mini_batch_size=self.args.nprocs,
sortagrad=config.collator.sortagrad, batch_count='auto',
shuffle_method=config.collator.shuffle_method) batch_bins=0,
else: batch_frames_in=0,
batch_sampler = SortagradBatchSampler( batch_frames_out=0,
train_dataset, batch_frames_inout=0,
shuffle=True, preprocess_conf=config.collator.
augmentation_config, # aug will be off when train_mode=False
n_iter_processes=config.collator.num_workers,
subsampling_factor=1,
num_encs=1)
self.valid_loader = BatchDataLoader(
json_file=config.data.dev_manifest,
train_mode=False,
sortagrad=False,
batch_size=config.collator.batch_size, batch_size=config.collator.batch_size,
drop_last=True, maxlen_in=float('inf'),
sortagrad=config.collator.sortagrad, maxlen_out=float('inf'),
shuffle_method=config.collator.shuffle_method) minibatches=0,
self.train_loader = DataLoader( mini_batch_size=self.args.nprocs,
train_dataset, batch_count='auto',
batch_sampler=batch_sampler, batch_bins=0,
collate_fn=collate_fn_train, batch_frames_in=0,
num_workers=config.collator.num_workers, ) batch_frames_out=0,
self.valid_loader = DataLoader( batch_frames_inout=0,
dev_dataset, preprocess_conf=config.collator.
batch_size=config.collator.batch_size, augmentation_config, # aug will be off when train_mode=False
shuffle=False, n_iter_processes=config.collator.num_workers,
drop_last=False, subsampling_factor=1,
collate_fn=collate_fn_dev, num_encs=1)
num_workers=config.collator.num_workers, ) logger.info("Setup train/valid Dataloader!")
else:
# test dataset, return raw text # test dataset, return raw text
config.data.manifest = config.data.test_manifest self.test_loader = BatchDataLoader(
# filter test examples, will cause less examples, but no mismatch with training json_file=config.data.test_manifest,
# and can use large batch size , save training time, so filter test egs now. train_mode=False,
config.data.min_input_len = 0.0 # second sortagrad=False,
config.data.max_input_len = float('inf') # second batch_size=config.decoding.batch_size,
config.data.min_output_len = 0.0 # tokens maxlen_in=float('inf'),
config.data.max_output_len = float('inf') # tokens maxlen_out=float('inf'),
config.data.min_output_input_ratio = 0.00 minibatches=0,
config.data.max_output_input_ratio = float('inf') mini_batch_size=1,
batch_count='auto',
test_dataset = ManifestDataset.from_config(config) batch_bins=0,
# return text ord id batch_frames_in=0,
config.collator.keep_transcription_text = True batch_frames_out=0,
config.collator.augmentation_config = "" batch_frames_inout=0,
self.test_loader = DataLoader( preprocess_conf=config.collator.
test_dataset, augmentation_config, # aug will be off when train_mode=False
batch_size=config.decoding.batch_size, n_iter_processes=1,
shuffle=False, subsampling_factor=1,
drop_last=False, num_encs=1)
collate_fn=SpeechCollator.from_config(config),
num_workers=config.collator.num_workers, ) self.align_loader = BatchDataLoader(
# return text token id json_file=config.data.test_manifest,
config.collator.keep_transcription_text = False train_mode=False,
self.align_loader = DataLoader( sortagrad=False,
test_dataset, batch_size=config.decoding.batch_size,
batch_size=config.decoding.batch_size, maxlen_in=float('inf'),
shuffle=False, maxlen_out=float('inf'),
drop_last=False, minibatches=0,
collate_fn=SpeechCollator.from_config(config), mini_batch_size=1,
num_workers=config.collator.num_workers, ) batch_count='auto',
logger.info("Setup train/valid/test/align Dataloader!") batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.collator.
augmentation_config, # aug will be off when train_mode=False
n_iter_processes=1,
subsampling_factor=1,
num_encs=1)
logger.info("Setup test/align Dataloader!")
def setup_model(self): def setup_model(self):
config = self.config config = self.config
model_conf = config.model model_conf = config.model
with UpdateConfig(model_conf): with UpdateConfig(model_conf):
model_conf.input_dim = self.train_loader.collate_fn.feature_size if self.train:
model_conf.output_dim = self.train_loader.collate_fn.vocab_size model_conf.input_dim = self.train_loader.feat_dim
model_conf.output_dim = self.train_loader.vocab_size
else:
model_conf.input_dim = self.test_loader.feat_dim
model_conf.output_dim = self.test_loader.vocab_size
model = U2Model.from_config(model_conf) model = U2Model.from_config(model_conf)
...@@ -341,6 +354,11 @@ class U2Trainer(Trainer): ...@@ -341,6 +354,11 @@ class U2Trainer(Trainer):
logger.info(f"{model}") logger.info(f"{model}")
layer_tools.print_params(model, logger.info) layer_tools.print_params(model, logger.info)
self.model = model
logger.info("Setup model!")
if not self.train:
return
train_config = config.training train_config = config.training
optim_type = train_config.optim optim_type = train_config.optim
...@@ -381,10 +399,9 @@ class U2Trainer(Trainer): ...@@ -381,10 +399,9 @@ class U2Trainer(Trainer):
optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler) optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler)
optimizer = OptimizerFactory.from_args(optim_type, optimzer_args) optimizer = OptimizerFactory.from_args(optim_type, optimzer_args)
self.model = model
self.optimizer = optimizer self.optimizer = optimizer
self.lr_scheduler = lr_scheduler self.lr_scheduler = lr_scheduler
logger.info("Setup model/optimizer/lr_scheduler!") logger.info("Setup optimizer/lr_scheduler!")
class U2Tester(U2Trainer): class U2Tester(U2Trainer):
...@@ -419,14 +436,19 @@ class U2Tester(U2Trainer): ...@@ -419,14 +436,19 @@ class U2Tester(U2Trainer):
def __init__(self, config, args): def __init__(self, config, args):
super().__init__(config, args) super().__init__(config, args)
self.text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
self.vocab_list = self.text_feature.vocab_list
def ordid2token(self, texts, texts_len): def id2token(self, texts, texts_len, text_feature):
""" ord() id to chr() chr """ """ ord() id to chr() chr """
trans = [] trans = []
for text, n in zip(texts, texts_len): for text, n in zip(texts, texts_len):
n = n.numpy().item() n = n.numpy().item()
ids = text[:n] ids = text[:n]
trans.append(''.join([chr(i) for i in ids])) trans.append(text_feature.defeaturize(ids.numpy().tolist()))
return trans return trans
def compute_metrics(self, def compute_metrics(self,
...@@ -442,12 +464,11 @@ class U2Tester(U2Trainer): ...@@ -442,12 +464,11 @@ class U2Tester(U2Trainer):
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
start_time = time.time() start_time = time.time()
text_feature = self.test_loader.collate_fn.text_feature target_transcripts = self.id2token(texts, texts_len, self.text_feature)
target_transcripts = self.ordid2token(texts, texts_len)
result_transcripts, result_tokenids = self.model.decode( result_transcripts, result_tokenids = self.model.decode(
audio, audio,
audio_len, audio_len,
text_feature=text_feature, text_feature=self.text_feature,
decoding_method=cfg.decoding_method, decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path, lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha, beam_alpha=cfg.alpha,
...@@ -497,7 +518,7 @@ class U2Tester(U2Trainer): ...@@ -497,7 +518,7 @@ class U2Tester(U2Trainer):
self.model.eval() self.model.eval()
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
stride_ms = self.test_loader.collate_fn.stride_ms stride_ms = self.config.collator.stride_ms
error_rate_type = None error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0 num_frames = 0.0
...@@ -556,8 +577,8 @@ class U2Tester(U2Trainer): ...@@ -556,8 +577,8 @@ class U2Tester(U2Trainer):
def align(self): def align(self):
ctc_utils.ctc_align( ctc_utils.ctc_align(
self.model, self.align_loader, self.config.decoding.batch_size, self.model, self.align_loader, self.config.decoding.batch_size,
self.align_loader.collate_fn.stride_ms, self.config.collator.stride_ms,
self.align_loader.collate_fn.vocab_list, self.args.result_file) self.vocab_list, self.args.result_file)
def load_inferspec(self): def load_inferspec(self):
"""infer model and input spec. """infer model and input spec.
......
...@@ -392,6 +392,7 @@ class U2Tester(U2Trainer): ...@@ -392,6 +392,7 @@ class U2Tester(U2Trainer):
unit_type=self.config.collator.unit_type, unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath, vocab_filepath=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix) spm_model_prefix=self.config.collator.spm_model_prefix)
self.vocab_list = self.text_feature.vocab_list
def id2token(self, texts, texts_len, text_feature): def id2token(self, texts, texts_len, text_feature):
""" ord() id to chr() chr """ """ ord() id to chr() chr """
...@@ -529,8 +530,8 @@ class U2Tester(U2Trainer): ...@@ -529,8 +530,8 @@ class U2Tester(U2Trainer):
def align(self): def align(self):
ctc_utils.ctc_align( ctc_utils.ctc_align(
self.model, self.align_loader, self.config.decoding.batch_size, self.model, self.align_loader, self.config.decoding.batch_size,
self.align_loader.collate_fn.stride_ms, self.config.collator.stride_ms,
self.align_loader.collate_fn.vocab_list, self.args.result_file) self.vocab_list, self.args.result_file)
def load_inferspec(self): def load_inferspec(self):
"""infer model and input spec. """infer model and input spec.
......
...@@ -207,34 +207,16 @@ class AudioDataset(Dataset): ...@@ -207,34 +207,16 @@ class AudioDataset(Dataset):
if sort: if sort:
data = sorted(data, key=lambda x: x["feat_shape"][0]) data = sorted(data, key=lambda x: x["feat_shape"][0])
if raw_wav: if raw_wav:
assert data[0]['feat'].split(':')[0].splitext()[-1] not in ('.ark', path_suffix = data[0]['feat'].split(':')[0].splitext()[-1]
'.scp') assert path_suffix not in ('.ark', '.scp')
data = map(lambda x: (float(x['feat_shape'][0]) * 1000 / stride_ms)) # m second to n frame
data = list(
map(lambda x: (float(x['feat_shape'][0]) * 1000 / stride_ms),
data))
self.input_dim = data[0]['feat_shape'][1] self.input_dim = data[0]['feat_shape'][1]
self.output_dim = data[0]['token_shape'][1] self.output_dim = data[0]['token_shape'][1]
# with open(data_file, 'r') as f:
# for line in f:
# arr = line.strip().split('\t')
# if len(arr) != 7:
# continue
# key = arr[0].split(':')[1]
# tokenid = arr[5].split(':')[1]
# output_dim = int(arr[6].split(':')[1].split(',')[1])
# if raw_wav:
# wav_path = ':'.join(arr[1].split(':')[1:])
# duration = int(float(arr[2].split(':')[1]) * 1000 / 10)
# data.append((key, wav_path, duration, tokenid))
# else:
# feat_ark = ':'.join(arr[1].split(':')[1:])
# feat_info = arr[2].split(':')[1].split(',')
# feat_dim = int(feat_info[1].strip())
# num_frames = int(feat_info[0].strip())
# data.append((key, feat_ark, num_frames, tokenid))
# self.input_dim = feat_dim
# self.output_dim = output_dim
valid_data = [] valid_data = []
for i in range(len(data)): for i in range(len(data)):
length = data[i]['feat_shape'][0] length = data[i]['feat_shape'][0]
...@@ -242,17 +224,17 @@ class AudioDataset(Dataset): ...@@ -242,17 +224,17 @@ class AudioDataset(Dataset):
# remove too lang or too short utt for both input and output # remove too lang or too short utt for both input and output
# to prevent from out of memory # to prevent from out of memory
if length > max_length or length < min_length: if length > max_length or length < min_length:
# logging.warn('ignore utterance {} feature {}'.format(
# data[i][0], length))
pass pass
elif token_length > token_max_length or token_length < token_min_length: elif token_length > token_max_length or token_length < token_min_length:
pass pass
else: else:
valid_data.append(data[i]) valid_data.append(data[i])
logger.info(f"raw dataset len: {len(data)}")
data = valid_data data = valid_data
num_data = len(data)
logger.info(f"dataset len after filter: {num_data}")
self.minibatch = [] self.minibatch = []
num_data = len(data)
# Dynamic batch size # Dynamic batch size
if batch_type == 'dynamic': if batch_type == 'dynamic':
assert (max_frames_in_batch > 0) assert (max_frames_in_batch > 0)
...@@ -277,7 +259,9 @@ class AudioDataset(Dataset): ...@@ -277,7 +259,9 @@ class AudioDataset(Dataset):
cur = end cur = end
def __len__(self): def __len__(self):
"""number of example(batch)"""
return len(self.minibatch) return len(self.minibatch)
def __getitem__(self, idx): def __getitem__(self, idx):
"""batch example of idx"""
return self.minibatch[idx] return self.minibatch[idx]
...@@ -18,8 +18,10 @@ import kaldiio ...@@ -18,8 +18,10 @@ import kaldiio
import numpy as np import numpy as np
import soundfile import soundfile
from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline as Transformation from .utility import feat_type
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
# from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline as Transformation
__all__ = ["LoadInputsAndTargets"] __all__ = ["LoadInputsAndTargets"]
...@@ -322,20 +324,7 @@ class LoadInputsAndTargets(): ...@@ -322,20 +324,7 @@ class LoadInputsAndTargets():
"Not supported: loader_type={}".format(filetype)) "Not supported: loader_type={}".format(filetype))
def file_type(self, filepath): def file_type(self, filepath):
suffix = filepath.split(":")[0].split('.')[-1].lower() return feat_type(filepath)
if suffix == 'ark':
return 'mat'
elif suffix == 'scp':
return 'scp'
elif suffix == 'npy':
return 'npy'
elif suffix == 'npz':
return 'npz'
elif suffix in ['wav', 'flac']:
# PCM16
return 'sound'
else:
raise ValueError(f"Not support filetype: {suffix}")
class SoundHDF5File(): class SoundHDF5File():
......
...@@ -17,7 +17,7 @@ import numpy as np ...@@ -17,7 +17,7 @@ import numpy as np
from paddlespeech.s2t.utils.log import Log from paddlespeech.s2t.utils.log import Log
__all__ = ["pad_list", "pad_sequence"] __all__ = ["pad_list", "pad_sequence", "feat_type"]
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
...@@ -85,3 +85,20 @@ def pad_sequence(sequences: List[np.ndarray], ...@@ -85,3 +85,20 @@ def pad_sequence(sequences: List[np.ndarray],
out_tensor[:length, i, ...] = tensor out_tensor[:length, i, ...] = tensor
return out_tensor return out_tensor
def feat_type(filepath):
suffix = filepath.split(":")[0].split('.')[-1].lower()
if suffix == 'ark':
return 'mat'
elif suffix == 'scp':
return 'scp'
elif suffix == 'npy':
return 'npy'
elif suffix == 'npz':
return 'npz'
elif suffix in ['wav', 'flac']:
# PCM16
return 'sound'
else:
raise ValueError(f"Not support filetype: {suffix}")
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# Modified from espnet(https://github.com/espnet/espnet) # Modified from espnet(https://github.com/espnet/espnet)
import librosa import librosa
import numpy as np import numpy as np
from python_speech_features import logfbank
def stft(x, def stft(x,
...@@ -304,3 +305,85 @@ class IStft(): ...@@ -304,3 +305,85 @@ class IStft():
win_length=self.win_length, win_length=self.win_length,
window=self.window, window=self.window,
center=self.center, ) center=self.center, )
class LogMelSpectrogramKaldi():
def __init__(
self,
fs=16000,
n_mels=80,
n_fft=512, # fft point
n_shift=160, # unit:sample, 10ms
win_length=400, # unit:sample, 25ms
window="povey",
fmin=20,
fmax=None,
eps=1e-10,
dither=False):
self.fs = fs
self.n_mels = n_mels
self.n_fft = n_fft
if n_shift > win_length:
raise ValueError("Stride size must not be greater than "
"window size.")
self.n_shift = n_shift / fs # unit: ms
self.win_length = win_length / fs # unit: ms
self.window = window
self.fmin = fmin
if fmax is None:
fmax_ = fmax if fmax else self.fs / 2
elif fmax > int(self.fs / 2):
raise ValueError("fmax must not be greater than half of "
"sample rate.")
self.fmax = fmax_
self.eps = eps
self.remove_dc_offset = True
self.preemph = 0.97
self.dither = dither
def __repr__(self):
return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, "
"n_shift={n_shift}, win_length={win_length}, window={window}, "
"fmin={fmin}, fmax={fmax}, eps={eps}))".format(
name=self.__class__.__name__,
fs=self.fs,
n_mels=self.n_mels,
n_fft=self.n_fft,
n_shift=self.n_shift,
win_length=self.win_length,
window=self.window,
fmin=self.fmin,
fmax=self.fmax,
eps=self.eps, ))
def __call__(self, x):
"""
Args:
x (np.ndarray): shape (Ti,)
Raises:
ValueError: not support (Ti, C)
Returns:
np.ndarray: (T, D)
"""
if x.ndim != 1:
raise ValueError("Not support x: [Time, Channel]")
if x.dtype == np.int16:
x = x / 2**(16 - 1)
return logfbank(
signal=x,
samplerate=self.fs,
winlen=self.win_length, # unit ms
winstep=self.n_shift, # unit ms
nfilt=self.n_mels,
nfft=self.n_fft,
lowfreq=self.fmin,
highfreq=self.fmax,
dither=self.dither,
remove_dc_offset=self.remove_dc_offset,
preemph=self.preemph,
wintype=self.window)
...@@ -45,6 +45,7 @@ import_alias = dict( ...@@ -45,6 +45,7 @@ import_alias = dict(
stft2fbank="paddlespeech.s2t.transform.spectrogram:Stft2LogMelSpectrogram", stft2fbank="paddlespeech.s2t.transform.spectrogram:Stft2LogMelSpectrogram",
wpe="paddlespeech.s2t.transform.wpe:WPE", wpe="paddlespeech.s2t.transform.wpe:WPE",
channel_selector="paddlespeech.s2t.transform.channel_selector:ChannelSelector", channel_selector="paddlespeech.s2t.transform.channel_selector:ChannelSelector",
fbank_kaldi="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogramKaldi",
) )
......
...@@ -20,13 +20,13 @@ import json ...@@ -20,13 +20,13 @@ import json
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import load_cmvn from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.frontend.utility import read_manifest from paddlespeech.s2t.frontend.utility import read_manifest
from paddlespeech.s2t.io.utility import feat_type
from paddlespeech.s2t.utils.utility import add_arguments from paddlespeech.s2t.utils.utility import add_arguments
from paddlespeech.s2t.utils.utility import print_arguments from paddlespeech.s2t.utils.utility import print_arguments
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser) add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable # yapf: disable
add_arg('feat_type', str, "raw", "speech feature type, e.g. raw(wav, flac), mat(ark), scp")
add_arg('cmvn_path', str, add_arg('cmvn_path', str,
'examples/librispeech/data/mean_std.json', 'examples/librispeech/data/mean_std.json',
"Filepath of cmvn.") "Filepath of cmvn.")
...@@ -62,24 +62,64 @@ def main(): ...@@ -62,24 +62,64 @@ def main():
vocab_size = text_feature.vocab_size vocab_size = text_feature.vocab_size
print(f"Vocab size: {vocab_size}") print(f"Vocab size: {vocab_size}")
# josnline like this
# {
# "input": [{"name": "input1", "shape": (100, 83), "feat": "xxx.ark:123"}],
# "output": [{"name":"target1", "shape": (40, 5002), "text": "a b c de"}],
# "utt2spk": "111-2222",
# "utt": "111-2222-333"
# }
count = 0 count = 0
for manifest_path in args.manifest_paths: for manifest_path in args.manifest_paths:
manifest_jsons = read_manifest(manifest_path) manifest_jsons = read_manifest(manifest_path)
for line_json in manifest_jsons: for line_json in manifest_jsons:
output_json = {
"input": [],
"output": [],
'utt': line_json['utt'],
'utt2spk': line_json.get('utt2spk', 'global'),
}
# output
line = line_json['text'] line = line_json['text']
tokens = text_feature.tokenize(line) if isinstance(line, str):
tokenids = text_feature.featurize(line) # only one target
line_json['token'] = tokens tokens = text_feature.tokenize(line)
line_json['token_id'] = tokenids tokenids = text_feature.featurize(line)
line_json['token_shape'] = (len(tokenids), vocab_size) output_json['output'].append({
feat_shape = line_json['feat_shape'] 'name': 'traget1',
assert isinstance(feat_shape, (list, tuple)), type(feat_shape) 'shape': (len(tokenids), vocab_size),
if args.feat_type == 'raw': 'text': line,
feat_shape.append(feat_dim) 'token': ' '.join(tokens),
line_json['filetype'] = 'sound' 'tokenid': ' '.join(map(str, tokenids)),
else: # kaldi })
raise NotImplementedError('no support kaldi feat now!') else:
fout.write(json.dumps(line_json) + '\n') # isinstance(line, list), multi target
raise NotImplementedError("not support multi output now!")
# input
line = line_json['feat']
if isinstance(line, str):
# only one input
feat_shape = line_json['feat_shape']
assert isinstance(feat_shape, (list, tuple)), type(feat_shape)
filetype = feat_type(line)
if filetype == 'sound':
feat_shape.append(feat_dim)
else: # kaldi
raise NotImplementedError('no support kaldi feat now!')
output_json['input'].append({
"name": "input1",
"shape": feat_shape,
"feat": line,
"filetype": filetype,
})
else:
# isinstance(line, list), multi input
raise NotImplementedError("not support multi input now!")
fout.write(json.dumps(output_json) + '\n')
count += 1 count += 1
print(f"Examples number: {count}") print(f"Examples number: {count}")
......
...@@ -20,13 +20,13 @@ import json ...@@ -20,13 +20,13 @@ import json
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import load_cmvn from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.frontend.utility import read_manifest from paddlespeech.s2t.frontend.utility import read_manifest
from paddlespeech.s2t.io.utility import feat_type
from paddlespeech.s2t.utils.utility import add_arguments from paddlespeech.s2t.utils.utility import add_arguments
from paddlespeech.s2t.utils.utility import print_arguments from paddlespeech.s2t.utils.utility import print_arguments
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser) add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable # yapf: disable
add_arg('feat_type', str, "raw", "speech feature type, e.g. raw(wav, flac), kaldi")
add_arg('cmvn_path', str, add_arg('cmvn_path', str,
'examples/librispeech/data/mean_std.json', 'examples/librispeech/data/mean_std.json',
"Filepath of cmvn.") "Filepath of cmvn.")
...@@ -79,9 +79,11 @@ def main(): ...@@ -79,9 +79,11 @@ def main():
line_json['token1'] = tokens line_json['token1'] = tokens
line_json['token_id1'] = tokenids line_json['token_id1'] = tokenids
line_json['token_shape1'] = (len(tokenids), vocab_size) line_json['token_shape1'] = (len(tokenids), vocab_size)
feat_shape = line_json['feat_shape'] feat_shape = line_json['feat_shape']
assert isinstance(feat_shape, (list, tuple)), type(feat_shape) assert isinstance(feat_shape, (list, tuple)), type(feat_shape)
if args.feat_type == 'raw': filetype = feat_type(line_json['feat'])
if filetype == 'sound':
feat_shape.append(feat_dim) feat_shape.append(feat_dim)
else: # kaldi else: # kaldi
raise NotImplementedError('no support kaldi feat now!') raise NotImplementedError('no support kaldi feat now!')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册