diff --git a/deepspeech/exps/deepspeech2/bin/deploy/runtime.py b/deepspeech/exps/deepspeech2/bin/deploy/runtime.py index f3125e04dad524071b5d8bd5efeb4201ebc81557..01f01b6515c518123f10b19f3c2d4272a2b55a67 100644 --- a/deepspeech/exps/deepspeech2/bin/deploy/runtime.py +++ b/deepspeech/exps/deepspeech2/bin/deploy/runtime.py @@ -18,8 +18,10 @@ import numpy as np import paddle from paddle.inference import Config from paddle.inference import create_predictor +from paddle.io import DataLoader from deepspeech.exps.deepspeech2.config import get_cfg_defaults +from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset from deepspeech.models.deepspeech2 import DeepSpeech2Model from deepspeech.training.cli import default_argument_parser @@ -78,26 +80,31 @@ def inference(config, args): def start_server(config, args): """Start the ASR server""" config.defrost() - config.data.manfiest = config.data.test_manifest - config.data.augmentation_config = "" - config.data.keep_transcription_text = True + config.data.manifest = config.data.test_manifest dataset = ManifestDataset.from_config(config) - model = DeepSpeech2Model.from_pretrained(dataset, config, + config.collator.augmentation_config = "" + config.collator.keep_transcription_text = True + config.collator.batch_size = 1 + config.collator.num_workers = 0 + collate_fn = SpeechCollator.from_config(config) + test_loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=0) + + model = DeepSpeech2Model.from_pretrained(test_loader, config, args.checkpoint_path) model.eval() # prepare ASR inference handler def file_to_transcript(filename): - feature = dataset.process_utterance(filename, "") - audio = np.array([feature[0]]).astype('float32') #[1, D, T] - audio_len = feature[0].shape[1] + feature = test_loader.collate_fn.process_utterance(filename, "") + audio = np.array([feature[0]]).astype('float32') #[1, T, D] + audio_len = feature[0].shape[0] audio_len = np.array([audio_len]).astype('int64') # [1] result_transcript = model.decode( paddle.to_tensor(audio), paddle.to_tensor(audio_len), - vocab_list=dataset.vocab_list, + vocab_list=test_loader.collate_fn.vocab_list, decoding_method=config.decoding.decoding_method, lang_model_path=config.decoding.lang_model_path, beam_alpha=config.decoding.alpha, @@ -138,7 +145,7 @@ if __name__ == "__main__": add_arg('host_ip', str, 'localhost', "Server's IP address.") - add_arg('host_port', int, 8086, "Server's IP port.") + add_arg('host_port', int, 8089, "Server's IP port.") add_arg('speech_save_dir', str, 'demo_cache', "Directory to save demo audios.") diff --git a/deepspeech/exps/deepspeech2/bin/deploy/server.py b/deepspeech/exps/deepspeech2/bin/deploy/server.py index b2ff37e067340a7c9128614cbc53206a4013fd63..b473a8fd44523cbe4c7a9aa31d9e4f8b5ea28dbd 100644 --- a/deepspeech/exps/deepspeech2/bin/deploy/server.py +++ b/deepspeech/exps/deepspeech2/bin/deploy/server.py @@ -16,8 +16,10 @@ import functools import numpy as np import paddle +from paddle.io import DataLoader from deepspeech.exps.deepspeech2.config import get_cfg_defaults +from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset from deepspeech.models.deepspeech2 import DeepSpeech2Model from deepspeech.training.cli import default_argument_parser @@ -31,26 +33,35 @@ from deepspeech.utils.utility import print_arguments def start_server(config, args): """Start the ASR server""" config.defrost() - config.data.manfiest = config.data.test_manifest - config.data.augmentation_config = "" - config.data.keep_transcription_text = True + config.data.manifest = config.data.test_manifest dataset = ManifestDataset.from_config(config) - model = DeepSpeech2Model.from_pretrained(dataset, config, + config.collator.augmentation_config = "" + config.collator.keep_transcription_text = True + config.collator.batch_size = 1 + config.collator.num_workers = 0 + collate_fn = SpeechCollator.from_config(config) + test_loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=0) + + model = DeepSpeech2Model.from_pretrained(test_loader, config, args.checkpoint_path) model.eval() # prepare ASR inference handler def file_to_transcript(filename): - feature = dataset.process_utterance(filename, "") - audio = np.array([feature[0]]).astype('float32') #[1, D, T] - audio_len = feature[0].shape[1] + feature = test_loader.collate_fn.process_utterance(filename, "") + audio = np.array([feature[0]]).astype('float32') #[1, T, D] + # audio = audio.swapaxes(1,2) + print('---file_to_transcript feature----') + print(audio.shape) + audio_len = feature[0].shape[0] + print(audio_len) audio_len = np.array([audio_len]).astype('int64') # [1] result_transcript = model.decode( paddle.to_tensor(audio), paddle.to_tensor(audio_len), - vocab_list=dataset.vocab_list, + vocab_list=test_loader.collate_fn.vocab_list, decoding_method=config.decoding.decoding_method, lang_model_path=config.decoding.lang_model_path, beam_alpha=config.decoding.alpha, @@ -91,7 +102,7 @@ if __name__ == "__main__": add_arg('host_ip', str, 'localhost', "Server's IP address.") - add_arg('host_port', int, 8086, "Server's IP port.") + add_arg('host_port', int, 8088, "Server's IP port.") add_arg('speech_save_dir', str, 'demo_cache', "Directory to save demo audios.") diff --git a/deepspeech/exps/deepspeech2/bin/tune.py b/deepspeech/exps/deepspeech2/bin/tune.py index 02e329a11d29fd042532ec9918b4a099875b96b3..f10dc27ced924abcb720e2f99f84bdfa015bf650 100644 --- a/deepspeech/exps/deepspeech2/bin/tune.py +++ b/deepspeech/exps/deepspeech2/bin/tune.py @@ -47,7 +47,7 @@ def tune(config, args): drop_last=False, collate_fn=SpeechCollator(keep_transcription_text=True)) - model = DeepSpeech2Model.from_pretrained(dev_dataset, config, + model = DeepSpeech2Model.from_pretrained(valid_loader, config, args.checkpoint_path) model.eval() diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index deb8752b72deb7efc52bd1f86becf6de7d4d0073..209e8b0233b660e9e079513f457f037d6486f1ce 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -318,7 +318,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): def export(self): infer_model = DeepSpeech2InferModel.from_pretrained( - self.test_loader.dataset, self.config, self.args.checkpoint_path) + self.test_loader, self.config, self.args.checkpoint_path) infer_model.eval() feat_dim = self.test_loader.collate_fn.feature_size static_model = paddle.jit.to_static( diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index ba7bc45c81a3ed526583bd4d01d0e63726366f7a..8802143d68fad58c5205907f5a3123438cd38fe7 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -574,15 +574,14 @@ class U2Tester(U2Trainer): List[paddle.static.InputSpec]: input spec. """ from deepspeech.models.u2 import U2InferModel - infer_model = U2InferModel.from_pretrained(self.test_loader.dataset, + infer_model = U2InferModel.from_pretrained(self.test_loader, self.config.model.clone(), self.args.checkpoint_path) feat_dim = self.test_loader.collate_fn.feature_size input_spec = [ - paddle.static.InputSpec( - shape=[None, feat_dim, None], - dtype='float32'), # audio, [B,D,T] - paddle.static.InputSpec(shape=[None], + paddle.static.InputSpec(shape=[1, None, feat_dim], + dtype='float32'), # audio, [B,T,D] + paddle.static.InputSpec(shape=[1], dtype='int64'), # audio_length, [B] ] return infer_model, input_spec diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index 1061f97cfabc8e259ceefc353b264bde3d21a758..2ef1196663edad598b2b927948645d0cc94d7bd4 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -154,8 +154,8 @@ class SpeechCollator(): random_seed (int, optional): for random generator. Defaults to 0. keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False. if ``keep_transcription_text`` is False, text is token ids else is raw string. - - Do augmentations + + Do augmentations Padding audio features with zeros to make them have the same shape (or a user-defined shape) within one batch. """ @@ -242,6 +242,7 @@ class SpeechCollator(): # specgram augment specgram = self._augmentation_pipeline.transform_feature(specgram) + specgram = specgram.transpose([1, 0]) return specgram, transcript_part def __call__(self, batch): @@ -269,8 +270,8 @@ class SpeechCollator(): #utt utts.append(utt) # audio - audios.append(audio.T) # [T, D] - audio_lens.append(audio.shape[1]) + audios.append(audio) # [T, D] + audio_lens.append(audio.shape[0]) # text # for training, text is token ids # else text is string, convert to unicode ord diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py index 0ff5514de03248f88a52760938d3540749b9ad31..d2c03a18e013ca722dcd8c99cc23d7719293e78f 100644 --- a/deepspeech/models/deepspeech2.py +++ b/deepspeech/models/deepspeech2.py @@ -198,11 +198,11 @@ class DeepSpeech2Model(nn.Layer): cutoff_top_n, num_processes) @classmethod - def from_pretrained(cls, dataset, config, checkpoint_path): + def from_pretrained(cls, dataloader, config, checkpoint_path): """Build a DeepSpeech2Model model from a pretrained model. Parameters ---------- - dataset: paddle.io.Dataset + dataloader: paddle.io.DataLoader config: yacs.config.CfgNode model configs @@ -215,8 +215,8 @@ class DeepSpeech2Model(nn.Layer): DeepSpeech2Model The model built from pretrained result. """ - model = cls(feat_size=dataset.feature_size, - dict_size=dataset.vocab_size, + model = cls(feat_size=dataloader.collate_fn.feature_size, + dict_size=dataloader.collate_fn.vocab_size, num_conv_layers=config.model.num_conv_layers, num_rnn_layers=config.model.num_rnn_layers, rnn_size=config.model.rnn_layer_size, diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py index 238e2d35c5492097868c2ff8ea1ff941bd27dc9e..23ae3423d677b22fc41f03552b95d41db4b3c435 100644 --- a/deepspeech/models/u2.py +++ b/deepspeech/models/u2.py @@ -876,11 +876,11 @@ class U2Model(U2BaseModel): return model @classmethod - def from_pretrained(cls, dataset, config, checkpoint_path): + def from_pretrained(cls, dataloader, config, checkpoint_path): """Build a DeepSpeech2Model model from a pretrained model. Args: - dataset (paddle.io.Dataset): not used. + dataloader (paddle.io.DataLoader): not used. config (yacs.config.CfgNode): model configs checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name @@ -888,8 +888,8 @@ class U2Model(U2BaseModel): DeepSpeech2Model: The model built from pretrained result. """ config.defrost() - config.input_dim = dataset.feature_size - config.output_dim = dataset.vocab_size + config.input_dim = dataloader.collate_fn.feature_size + config.output_dim = dataloader.collate_fn.vocab_size config.freeze() model = cls.from_config(config) diff --git a/deepspeech/utils/socket_server.py b/deepspeech/utils/socket_server.py index adcbf3bb297af3ce710a36814b19b0ac20babc08..45c659f6021043ca96424138f7f1f29c6649cba6 100644 --- a/deepspeech/utils/socket_server.py +++ b/deepspeech/utils/socket_server.py @@ -48,9 +48,9 @@ def warm_up_test(audio_process_handler, rng = random.Random(random_seed) samples = rng.sample(manifest, num_test_cases) for idx, sample in enumerate(samples): - print("Warm-up Test Case %d: %s", idx, sample['audio_filepath']) + print("Warm-up Test Case %d: %s" % (idx, sample['feat'])) start_time = time.time() - transcript = audio_process_handler(sample['audio_filepath']) + transcript = audio_process_handler(sample['feat']) finish_time = time.time() print("Response Time: %f, Transcript: %s" % (finish_time - start_time, transcript)) diff --git a/examples/aishell/s0/README.md b/examples/aishell/s0/README.md index ae3fb401ad110976671510b352ed59e02909eadb..6ce39b233f043768a06ee6de208336b6b8ffc502 100644 --- a/examples/aishell/s0/README.md +++ b/examples/aishell/s0/README.md @@ -2,10 +2,10 @@ ## Deepspeech2 -| Model | release | Config | Test set | Loss | CER | -| --- | --- | --- | --- | --- | --- | -| DeepSpeech2 58.4M | 2.1.0 | conf/deepspeech2.yaml + spec aug + new datapipe | test | 6.396368026733398 | 0.068382 | -| DeepSpeech2 58.4M | 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 | -| DeepSpeech2 58.4M | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 | -| DeepSpeech2 58.4M | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 | -| DeepSpeech2 58.4M | 1.8.5 | - | test | - | 0.080447 | +| Model | Params | Release | Config | Test set | Loss | CER | +| --- | --- | --- | --- | --- | --- | --- | +| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug + new datapipe | test | 6.396368026733398 | 0.068382,0.073507 | +| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 | +| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 | +| DeepSpeech2 | 58.4M | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 | +| DeepSpeech2 | 58.4M | 1.8.5 | - | test | - | 0.080447 | diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml index 8cc4c4c9c593b2079c6f2b61f5d2ae0420ace97c..1004fde0e227e4688849dc046d8fe5c839b747f6 100644 --- a/examples/aishell/s0/conf/deepspeech2.yaml +++ b/examples/aishell/s0/conf/deepspeech2.yaml @@ -10,8 +10,8 @@ data: min_output_input_ratio: 0.00 max_output_input_ratio: .inf - collator: + batch_size: 64 # one gpu mean_std_filepath: data/mean_std.json unit_type: char vocab_filepath: data/vocab.txt @@ -33,7 +33,6 @@ collator: sortagrad: True shuffle_method: batch_shuffle num_workers: 0 - batch_size: 64 # one gpu model: num_conv_layers: 2 diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh index 05829136ad2d7f325cb00c5426b067604b08655e..c9708dcc90df357568282ace7700ceb13b7883d3 100755 --- a/examples/aishell/s0/run.sh +++ b/examples/aishell/s0/run.sh @@ -31,10 +31,10 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # export ckpt avg_n - CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit + CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit fi diff --git a/examples/aishell/s1/README.md b/examples/aishell/s1/README.md index 601b0a8d0f1f19c3db3b6cb4466b09f6e3304cae..78e759c8f9cb93211dc1192aff779289111d3ec1 100644 --- a/examples/aishell/s1/README.md +++ b/examples/aishell/s1/README.md @@ -2,25 +2,26 @@ ## Conformer -| Model | Config | Augmentation| Test set | Decode method | Loss | WER | -| --- | --- | --- | --- | --- | --- | --- | -| conformer | conf/conformer.yaml | spec_aug + shift | test | attention | - | 0.059858 | -| conformer | conf/conformer.yaml | spec_aug + shift | test | ctc_greedy_search | - | 0.062311 | -| conformer | conf/conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | - | 0.062196 | -| conformer | conf/conformer.yaml | spec_aug + shift | test | attention_rescoring | - | 0.054694 | +| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER | +| --- | --- | --- | --- | --- | --- | --- | --- | +| conformer | 47.07M | conf/conformer.yaml | spec_aug + shift | test | attention | - | 0.059858 | +| conformer | 47.07M | conf/conformer.yaml | spec_aug + shift | test | ctc_greedy_search | - | 0.062311 | +| conformer | 47.07M | conf/conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | - | 0.062196 | +| conformer | 47.07M | conf/conformer.yaml | spec_aug + shift | test | attention_rescoring | - | 0.054694 | + ## Chunk Conformer -| Model | Config | Augmentation| Test set | Decode method | Chunk | Loss | WER | -| --- | --- | --- | --- | --- | --- | --- | --- | -| conformer | conf/chunk_conformer.yaml | spec_aug + shift | test | attention | 16 | - | 0.061939 | -| conformer | conf/chunk_conformer.yaml | spec_aug + shift | test | ctc_greedy_search | 16 | - | 0.070806 | -| conformer | conf/chunk_conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | 16 | - | 0.070739 | -| conformer | conf/chunk_conformer.yaml | spec_aug + shift | test | attention_rescoring | 16 | - | 0.059400 | +| Model | Params | Config | Augmentation| Test set | Decode method | Chunk | Loss | WER | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | +| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug + shift | test | attention | 16 | - | 0.061939 | +| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug + shift | test | ctc_greedy_search | 16 | - | 0.070806 | +| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | 16 | - | 0.070739 | +| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug + shift | test | attention_rescoring | 16 | - | 0.059400 | ## Transformer -| Model | Config | Augmentation| Test set | Decode method | Loss | WER | -| --- | --- | --- | --- | --- | --- | ---| -| transformer | conf/transformer.yaml | spec_aug + shift | test | attention | - | - | +| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER | +| --- | --- | --- | --- | --- | --- | --- | ---| +| transformer | - | conf/transformer.yaml | spec_aug + shift | test | attention | - | - | diff --git a/examples/dataset/aishell/aishell.py b/examples/dataset/aishell/aishell.py index a0cabe352d6154567c1e3b446f90041eda7d9e7c..b8aede2fc62815ce616d310886d917ae8c1eb3c0 100644 --- a/examples/dataset/aishell/aishell.py +++ b/examples/dataset/aishell/aishell.py @@ -60,7 +60,7 @@ def create_manifest(data_dir, manifest_path_prefix): if line == '': continue audio_id, text = line.split(' ', 1) - # remove withespace + # remove withespace, charactor text text = ''.join(text.split()) transcript_dict[audio_id] = text @@ -123,6 +123,8 @@ def main(): target_dir=args.target_dir, manifest_path=args.manifest_prefix) + print("Data download and manifest prepare done!") + if __name__ == '__main__': main() diff --git a/examples/dataset/thchs30/.gitignore b/examples/dataset/thchs30/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..47dd6268f424b07fdf4b593c0d43f72b030c3e47 --- /dev/null +++ b/examples/dataset/thchs30/.gitignore @@ -0,0 +1,5 @@ +*.tgz +manifest.* +data_thchs30 +resource +test-noise diff --git a/examples/dataset/thchs30/thchs30.py b/examples/dataset/thchs30/thchs30.py new file mode 100644 index 0000000000000000000000000000000000000000..225adb092686d2cd276297a760eac730995129ff --- /dev/null +++ b/examples/dataset/thchs30/thchs30.py @@ -0,0 +1,169 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Prepare THCHS-30 mandarin dataset + +Download, unpack and create manifest files. +Manifest file is a json-format file with each line containing the +meta data (i.e. audio filepath, transcript and audio duration) +of each audio file in the data set. +""" +import argparse +import codecs +import json +import os +from multiprocessing.pool import Pool +from pathlib import Path + +import soundfile + +from utils.utility import download +from utils.utility import unpack + +DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') + +URL_ROOT = 'http://www.openslr.org/resources/18' +# URL_ROOT = 'https://openslr.magicdatatech.com/resources/18' +DATA_URL = URL_ROOT + '/data_thchs30.tgz' +TEST_NOISE_URL = URL_ROOT + '/test-noise.tgz' +RESOURCE_URL = URL_ROOT + '/resource.tgz' +MD5_DATA = '2d2252bde5c8429929e1841d4cb95e90' +MD5_TEST_NOISE = '7e8a985fb965b84141b68c68556c2030' +MD5_RESOURCE = 'c0b2a565b4970a0c4fe89fefbf2d97e1' + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + "--target_dir", + default=DATA_HOME + "/THCHS30", + type=str, + help="Directory to save the dataset. (default: %(default)s)") +parser.add_argument( + "--manifest_prefix", + default="manifest", + type=str, + help="Filepath prefix for output manifests. (default: %(default)s)") +args = parser.parse_args() + + +def read_trn(filepath): + """read trn file. + word text in first line. + syllable text in second line. + phoneme text in third line. + + Args: + filepath (str): trn path. + + Returns: + list(str): (word, syllable, phone) + """ + texts = [] + with open(filepath, 'r') as f: + lines = f.read().split('\n') + # last line is `empty` + lines = lines[:3] + assert len(lines) == 3, lines + # charactor text, remove withespace + texts.append(''.join(lines[0].split())) + texts.extend(lines[1:]) + return texts + + +def resolve_symlink(filepath): + """resolve symlink which content is norm file. + + Args: + filepath (str): norm file symlink. + """ + sym_path = Path(filepath) + relative_link = sym_path.read_text().strip() + relative = Path(relative_link) + relpath = sym_path.parent / relative + return relpath.resolve() + + +def create_manifest(data_dir, manifest_path_prefix): + print("Creating manifest %s ..." % manifest_path_prefix) + json_lines = [] + data_types = ['train', 'dev', 'test'] + for dtype in data_types: + del json_lines[:] + audio_dir = os.path.join(data_dir, dtype) + for subfolder, _, filelist in sorted(os.walk(audio_dir)): + for fname in filelist: + file_path = os.path.join(subfolder, fname) + if file_path.endswith('.wav'): + audio_path = os.path.abspath(file_path) + text_path = resolve_symlink(audio_path + '.trn') + else: + continue + + assert os.path.exists(audio_path) and os.path.exists(text_path) + + audio_id = os.path.basename(audio_path)[:-4] + word_text, syllable_text, phone_text = read_trn(text_path) + audio_data, samplerate = soundfile.read(audio_path) + duration = float(len(audio_data) / samplerate) + + json_lines.append( + json.dumps( + { + 'utt': audio_id, + 'feat': audio_path, + 'feat_shape': (duration, ), # second + 'text': word_text, + 'syllable': syllable_text, + 'phone': phone_text, + }, + ensure_ascii=False)) + + manifest_path = manifest_path_prefix + '.' + dtype + with codecs.open(manifest_path, 'w', 'utf-8') as fout: + for line in json_lines: + fout.write(line + '\n') + + +def prepare_dataset(url, md5sum, target_dir, manifest_path, subset): + """Download, unpack and create manifest file.""" + datadir = os.path.join(target_dir, subset) + if not os.path.exists(datadir): + filepath = download(url, md5sum, target_dir) + unpack(filepath, target_dir) + else: + print("Skip downloading and unpacking. Data already exists in %s." % + target_dir) + + if subset == 'data_thchs30': + create_manifest(datadir, manifest_path) + + +def main(): + if args.target_dir.startswith('~'): + args.target_dir = os.path.expanduser(args.target_dir) + + tasks = [ + (DATA_URL, MD5_DATA, args.target_dir, args.manifest_prefix, + "data_thchs30"), + (TEST_NOISE_URL, MD5_TEST_NOISE, args.target_dir, args.manifest_prefix, + "test-noise"), + (RESOURCE_URL, MD5_RESOURCE, args.target_dir, args.manifest_prefix, + "resource"), + ] + with Pool(7) as pool: + pool.starmap(prepare_dataset, tasks) + + print("Data download and manifest prepare done!") + + +if __name__ == '__main__': + main() diff --git a/examples/librispeech/s0/README.md b/examples/librispeech/s0/README.md index 393dd4579c12d7eff271e71c74e83c1f5d0ed41b..76aa5e78a0d35b30a891f15ee983d0482f7560f6 100644 --- a/examples/librispeech/s0/README.md +++ b/examples/librispeech/s0/README.md @@ -2,8 +2,8 @@ ## Deepspeech2 -| Model | release | Config | Test set | Loss | WER | -| --- | --- | --- | --- | --- | --- | -| DeepSpeech2 | 2.1.0 | conf/deepspeech2.yaml | 15.184467315673828 | test-clean | 0.072154 | -| DeepSpeech2 | 2.0.0 | conf/deepspeech2.yaml | - | test-clean | 0.073973 | -| DeepSpeech2 | 1.8.5 | - | test-clean | - | 0.074939 | +| Model | Params | release | Config | Test set | Loss | WER | +| --- | --- | --- | --- | --- | --- | --- | +| DeepSpeech2 | 42.96M | 2.1.0 | conf/deepspeech2.yaml | 15.184467315673828 | test-clean | 0.072154 | +| DeepSpeech2 | 42.96M | 2.0.0 | conf/deepspeech2.yaml | - | test-clean | 0.073973 | +| DeepSpeech2 | 42.96M | 1.8.5 | - | test-clean | - | 0.074939 | diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml index d1746bff398ad224f28f082ea2dc227ff7624eb7..b419cbe267b11083c0de1b9079de5f650f787271 100644 --- a/examples/librispeech/s0/conf/deepspeech2.yaml +++ b/examples/librispeech/s0/conf/deepspeech2.yaml @@ -3,16 +3,21 @@ data: train_manifest: data/manifest.train dev_manifest: data/manifest.dev-clean test_manifest: data/manifest.test-clean - mean_std_filepath: data/mean_std.json - vocab_filepath: data/vocab.txt - augmentation_config: conf/augmentation.json - batch_size: 20 min_input_len: 0.0 max_input_len: 27.0 # second min_output_len: 0.0 max_output_len: .inf min_output_input_ratio: 0.00 max_output_input_ratio: .inf + +collator: + batch_size: 20 + mean_std_filepath: data/mean_std.json + unit_type: char + vocab_filepath: data/vocab.txt + augmentation_config: conf/augmentation.json + random_seed: 0 + spm_model_prefix: specgram_type: linear target_sample_rate: 16000 max_freq: None diff --git a/examples/librispeech/s1/README.md b/examples/librispeech/s1/README.md index 73f6156d99049083cfe3c6b00417204b547cee7b..5e23c0ab5f582d35cdb4dd41a9c3066d82f65b91 100644 --- a/examples/librispeech/s1/README.md +++ b/examples/librispeech/s1/README.md @@ -2,17 +2,17 @@ ## Conformer -| Model | Config | Augmentation| Test set | Decode method | Loss | WER | -| --- | --- | --- | --- | --- | --- | --- | -| conformer | conf/conformer.yaml | spec_aug + shift | test-all | attention | test-all 6.35 | 0.057117 | -| conformer | conf/conformer.yaml | spec_aug + shift | test-clean | attention | test-all 6.35 | 0.030162 | -| conformer | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | test-all 6.35 | 0.037910 | -| conformer | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | test-all 6.35 | 0.037761 | -| conformer | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | test-all 6.35 | 0.032115 | +| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER | +| --- | --- | --- | --- | --- | --- | --- | --- | +| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-all | attention | 6.35 | 0.057117 | +| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention | 6.35 | 0.030162 | +| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | 6.35 | 0.037910 | +| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | 6.35 | 0.037761 | +| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | 6.35 | 0.032115 | ## Transformer -| Model | Config | Augmentation| Test set | Decode method | Loss | WER | -| --- | --- | --- | --- | --- | --- | --- | -| transformer | conf/transformer.yaml | spec_aug + shift | test-all | attention | test-all 6.98 | 0.066500 | -| transformer | conf/transformer.yaml | spec_aug + shift | test-clean | attention | test-all 6.98 | 0.036 | +| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER | +| --- | --- | --- | --- | --- | --- | --- | --- | +| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-all | attention | 6.98 | 0.066500 | +| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean | attention | 6.98 | 0.036 | diff --git a/examples/librispeech/s1/conf/chunk_confermer.yaml b/examples/librispeech/s1/conf/chunk_confermer.yaml index ec945a188bd2f66c12b34dc8499612b38b0912c5..ef08daa84de2c62b1f52882f03fee8edc359e5c5 100644 --- a/examples/librispeech/s1/conf/chunk_confermer.yaml +++ b/examples/librispeech/s1/conf/chunk_confermer.yaml @@ -3,18 +3,20 @@ data: train_manifest: data/manifest.train dev_manifest: data/manifest.dev test_manifest: data/manifest.test - vocab_filepath: data/vocab.txt - unit_type: 'spm' - spm_model_prefix: 'data/bpe_unigram_5000' - mean_std_filepath: "" - augmentation_config: conf/augmentation.json - batch_size: 4 min_input_len: 0.5 max_input_len: 20.0 min_output_len: 0.0 max_output_len: 400.0 min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 + +collator: + vocab_filepath: data/vocab.txt + unit_type: 'spm' + spm_model_prefix: 'data/bpe_unigram_5000' + mean_std_filepath: "" + augmentation_config: conf/augmentation.json + batch_size: 16 raw_wav: True # use raw_wav or kaldi feature specgram_type: fbank #linear, mfcc, fbank feat_dim: 80 @@ -80,7 +82,7 @@ model: training: n_epoch: 120 - accum_grad: 1 + accum_grad: 8 global_grad_clip: 5.0 optim: adam optim_conf: diff --git a/examples/librispeech/s1/conf/chunk_transformer.yaml b/examples/librispeech/s1/conf/chunk_transformer.yaml index 3939ffc688e1de5dc66606328e48e2d69459b0b6..5ec2ad1263ed8037cb17436203a773b629764cbc 100644 --- a/examples/librispeech/s1/conf/chunk_transformer.yaml +++ b/examples/librispeech/s1/conf/chunk_transformer.yaml @@ -3,18 +3,20 @@ data: train_manifest: data/manifest.train dev_manifest: data/manifest.dev test_manifest: data/manifest.test - vocab_filepath: data/vocab.txt - unit_type: 'spm' - spm_model_prefix: 'data/bpe_unigram_5000' - mean_std_filepath: "" - augmentation_config: conf/augmentation.json - batch_size: 64 min_input_len: 0.5 # second max_input_len: 20.0 # second min_output_len: 0.0 # tokens max_output_len: 400.0 # tokens min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 + +collator: + vocab_filepath: data/vocab.txt + unit_type: 'spm' + spm_model_prefix: 'data/bpe_unigram_5000' + mean_std_filepath: "" + augmentation_config: conf/augmentation.json + batch_size: 64 raw_wav: True # use raw_wav or kaldi feature specgram_type: fbank #linear, mfcc, fbank feat_dim: 80 @@ -103,6 +105,6 @@ decoding: # >0: for decoding, use fixed chunk size as set. # 0: used for training, it's prohibited here. num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. - simulate_streaming: False # simulate streaming inference. Defaults to False. + simulate_streaming: true # simulate streaming inference. Defaults to False. diff --git a/examples/librispeech/s1/conf/conformer.yaml b/examples/librispeech/s1/conf/conformer.yaml index 8f8bf45398813179db88781dcfc5c71356295934..cce31b163cec5ba977e528f0a05d271bcfaedc1c 100644 --- a/examples/librispeech/s1/conf/conformer.yaml +++ b/examples/librispeech/s1/conf/conformer.yaml @@ -3,18 +3,20 @@ data: train_manifest: data/manifest.train dev_manifest: data/manifest.dev test_manifest: data/manifest.test-clean - vocab_filepath: data/vocab.txt - unit_type: 'spm' - spm_model_prefix: 'data/bpe_unigram_5000' - mean_std_filepath: "" - augmentation_config: conf/augmentation.json - batch_size: 16 min_input_len: 0.5 # seconds max_input_len: 20.0 # seconds min_output_len: 0.0 # tokens max_output_len: 400.0 # tokens min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 + +collator: + vocab_filepath: data/vocab.txt + unit_type: 'spm' + spm_model_prefix: 'data/bpe_unigram_5000' + mean_std_filepath: "" + augmentation_config: conf/augmentation.json + batch_size: 16 raw_wav: True # use raw_wav or kaldi feature specgram_type: fbank #linear, mfcc, fbank feat_dim: 80 diff --git a/examples/librispeech/s1/conf/transformer.yaml b/examples/librispeech/s1/conf/transformer.yaml index a094b0fba6088ced2252fc71963ed3afb9ca5c0f..8ea49477260e8135b7c9647b0f3aa28748c9e7cf 100644 --- a/examples/librispeech/s1/conf/transformer.yaml +++ b/examples/librispeech/s1/conf/transformer.yaml @@ -3,18 +3,20 @@ data: train_manifest: data/manifest.train dev_manifest: data/manifest.dev test_manifest: data/manifest.test-clean - vocab_filepath: data/vocab.txt - unit_type: 'spm' - spm_model_prefix: 'data/bpe_unigram_5000' - mean_std_filepath: "" - augmentation_config: conf/augmentation.json - batch_size: 64 min_input_len: 0.5 # second max_input_len: 20.0 # second min_output_len: 0.0 # tokens max_output_len: 400.0 # tokens min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 + +collator: + vocab_filepath: data/vocab.txt + unit_type: 'spm' + spm_model_prefix: 'data/bpe_unigram_5000' + mean_std_filepath: "" + augmentation_config: conf/augmentation.json + batch_size: 64 raw_wav: True # use raw_wav or kaldi feature specgram_type: fbank #linear, mfcc, fbank feat_dim: 80 diff --git a/speechnn/CMakeLists.txt b/speechnn/CMakeLists.txt index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..878374bab0143b8e3346e2f4122b33044a48e726 100644 --- a/speechnn/CMakeLists.txt +++ b/speechnn/CMakeLists.txt @@ -0,0 +1,77 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +project(deepspeech VERSION 0.1) + +set(CMAKE_VERBOSE_MAKEFILE on) +# set std-14 +set(CMAKE_CXX_STANDARD 14) + +# include file +include(FetchContent) +include(ExternalProject) +# fc_patch dir +set(FETCHCONTENT_QUIET off) +get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}") +set(FETCHCONTENT_BASE_DIR ${fc_patch}) + + +############################################################################### +# Option Configurations +############################################################################### +# option configurations +option(TEST_DEBUG "option for debug" OFF) + + +############################################################################### +# Include third party +############################################################################### +# #example for include third party +# FetchContent_Declare() +# # FetchContent_MakeAvailable was not added until CMake 3.14 +# FetchContent_MakeAvailable() +# include_directories() + +# ABSEIL-CPP +include(FetchContent) +FetchContent_Declare( + absl + GIT_REPOSITORY "https://github.com/abseil/abseil-cpp.git" + GIT_TAG "20210324.1" +) +FetchContent_MakeAvailable(absl) + +# libsndfile +include(FetchContent) +FetchContent_Declare( + libsndfile + GIT_REPOSITORY "https://github.com/libsndfile/libsndfile.git" + GIT_TAG "1.0.31" +) +FetchContent_MakeAvailable(libsndfile) + + +############################################################################### +# Add local library +############################################################################### +# system lib +find_package() +# if dir have CmakeLists.txt +add_subdirectory() +# if dir do not have CmakeLists.txt +add_library(lib_name STATIC file.cc) +target_link_libraries(lib_name item0 item1) +add_dependencies(lib_name depend-target) + + +############################################################################### +# Library installation +############################################################################### +install() + + +############################################################################### +# Build binary file +############################################################################### +add_executable() +target_link_libraries() + diff --git a/speechnn/core/decoder/CMakeLists.txt b/speechnn/core/decoder/CMakeLists.txt index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..259261bdf12272f824eab99288e24a40204db3d7 100644 --- a/speechnn/core/decoder/CMakeLists.txt +++ b/speechnn/core/decoder/CMakeLists.txt @@ -0,0 +1,2 @@ +aux_source_directory(. DIR_LIB_SRCS) +add_library(decoder STATIC ${DIR_LIB_SRCS})