From 8e73d1841b8950e9b9141aa0aef0130b1376d936 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 5 Oct 2021 01:37:28 +0000 Subject: [PATCH] tiny/s0/s1 can run all --- .../exps/deepspeech2/bin/deploy/runtime.py | 25 +- .../exps/deepspeech2/bin/deploy/server.py | 29 +- deepspeech/exps/deepspeech2/model.py | 432 +++++++++++++----- deepspeech/exps/u2/bin/alignment.py | 3 + deepspeech/exps/u2/bin/export.py | 3 + deepspeech/exps/u2/bin/test.py | 3 + deepspeech/exps/u2/bin/train.py | 4 +- deepspeech/exps/u2/model.py | 90 ++-- .../frontend/featurizer/text_featurizer.py | 112 ++--- deepspeech/frontend/utility.py | 159 +++++-- deepspeech/io/collator.py | 12 +- deepspeech/io/dataset.py | 3 +- deepspeech/models/ds2/conv.py | 14 +- deepspeech/models/ds2/deepspeech2.py | 28 +- deepspeech/models/ds2/rnn.py | 12 +- deepspeech/models/ds2_online/deepspeech2.py | 96 ++-- deepspeech/training/trainer.py | 176 ++++--- deepspeech/utils/log.py | 6 +- examples/dataset/mini_librispeech/.gitignore | 1 + .../mini_librispeech/mini_librispeech.py | 21 + examples/librispeech/s1/local/align.sh | 32 ++ examples/librispeech/s1/local/data.sh | 2 +- .../librispeech/s1/local/download_lm_en.sh | 2 +- examples/librispeech/s1/local/export.sh | 8 +- examples/librispeech/s1/local/test.sh | 106 +++-- examples/librispeech/s1/local/train.sh | 25 +- examples/tiny/s0/conf/deepspeech2.yaml | 3 + examples/tiny/s0/conf/deepspeech2_online.yaml | 72 +++ examples/tiny/s0/local/download_lm_en.sh | 7 +- examples/tiny/s0/local/export.sh | 17 +- examples/tiny/s0/local/test.sh | 17 +- examples/tiny/s0/local/train.sh | 38 +- examples/tiny/s0/path.sh | 2 +- examples/tiny/s0/run.sh | 11 +- examples/tiny/s1/conf/transformer.yaml | 2 + examples/tiny/s1/local/align.sh | 32 ++ examples/tiny/s1/local/data.sh | 2 +- examples/tiny/s1/local/export.sh | 8 +- examples/tiny/s1/local/test.sh | 57 ++- examples/tiny/s1/local/train.sh | 39 +- examples/tiny/s1/run.sh | 12 +- 41 files changed, 1178 insertions(+), 545 deletions(-) create mode 100755 examples/librispeech/s1/local/align.sh create mode 100644 examples/tiny/s0/conf/deepspeech2_online.yaml create mode 100755 examples/tiny/s1/local/align.sh diff --git a/deepspeech/exps/deepspeech2/bin/deploy/runtime.py b/deepspeech/exps/deepspeech2/bin/deploy/runtime.py index 5677d4cf..21ffa6bf 100644 --- a/deepspeech/exps/deepspeech2/bin/deploy/runtime.py +++ b/deepspeech/exps/deepspeech2/bin/deploy/runtime.py @@ -18,8 +18,10 @@ import numpy as np import paddle from paddle.inference import Config from paddle.inference import create_predictor +from paddle.io import DataLoader from deepspeech.exps.deepspeech2.config import get_cfg_defaults +from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset from deepspeech.models.ds2 import DeepSpeech2Model from deepspeech.training.cli import default_argument_parser @@ -78,26 +80,31 @@ def inference(config, args): def start_server(config, args): """Start the ASR server""" config.defrost() - config.data.manfiest = config.data.test_manifest - config.data.augmentation_config = "" - config.data.keep_transcription_text = True + config.data.manifest = config.data.test_manifest dataset = ManifestDataset.from_config(config) - model = DeepSpeech2Model.from_pretrained(dataset, config, + config.collator.augmentation_config = "" + config.collator.keep_transcription_text = True + config.collator.batch_size = 1 + config.collator.num_workers = 0 + collate_fn = SpeechCollator.from_config(config) + test_loader = DataLoader(dataset, collate_fn=collate_fn, 
num_workers=0) + + model = DeepSpeech2Model.from_pretrained(test_loader, config, args.checkpoint_path) model.eval() # prepare ASR inference handler def file_to_transcript(filename): - feature = dataset.process_utterance(filename, "") - audio = np.array([feature[0]]).astype('float32') #[1, D, T] - audio_len = feature[0].shape[1] + feature = test_loader.collate_fn.process_utterance(filename, "") + audio = np.array([feature[0]]).astype('float32') #[1, T, D] + audio_len = feature[0].shape[0] audio_len = np.array([audio_len]).astype('int64') # [1] result_transcript = model.decode( paddle.to_tensor(audio), paddle.to_tensor(audio_len), - vocab_list=dataset.vocab_list, + vocab_list=test_loader.collate_fn.vocab_list, decoding_method=config.decoding.decoding_method, lang_model_path=config.decoding.lang_model_path, beam_alpha=config.decoding.alpha, @@ -138,7 +145,7 @@ if __name__ == "__main__": add_arg('host_ip', str, 'localhost', "Server's IP address.") - add_arg('host_port', int, 8086, "Server's IP port.") + add_arg('host_port', int, 8089, "Server's IP port.") add_arg('speech_save_dir', str, 'demo_cache', "Directory to save demo audios.") diff --git a/deepspeech/exps/deepspeech2/bin/deploy/server.py b/deepspeech/exps/deepspeech2/bin/deploy/server.py index 0e1211b0..583e9095 100644 --- a/deepspeech/exps/deepspeech2/bin/deploy/server.py +++ b/deepspeech/exps/deepspeech2/bin/deploy/server.py @@ -16,8 +16,10 @@ import functools import numpy as np import paddle +from paddle.io import DataLoader from deepspeech.exps.deepspeech2.config import get_cfg_defaults +from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset from deepspeech.models.ds2 import DeepSpeech2Model from deepspeech.training.cli import default_argument_parser @@ -31,26 +33,35 @@ from deepspeech.utils.utility import print_arguments def start_server(config, args): """Start the ASR server""" config.defrost() - config.data.manfiest = config.data.test_manifest - config.data.augmentation_config = "" - config.data.keep_transcription_text = True + config.data.manifest = config.data.test_manifest dataset = ManifestDataset.from_config(config) - model = DeepSpeech2Model.from_pretrained(dataset, config, + config.collator.augmentation_config = "" + config.collator.keep_transcription_text = True + config.collator.batch_size = 1 + config.collator.num_workers = 0 + collate_fn = SpeechCollator.from_config(config) + test_loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=0) + + model = DeepSpeech2Model.from_pretrained(test_loader, config, args.checkpoint_path) model.eval() # prepare ASR inference handler def file_to_transcript(filename): - feature = dataset.process_utterance(filename, "") - audio = np.array([feature[0]]).astype('float32') #[1, D, T] - audio_len = feature[0].shape[1] + feature = test_loader.collate_fn.process_utterance(filename, "") + audio = np.array([feature[0]]).astype('float32') #[1, T, D] + # audio = audio.swapaxes(1,2) + print('---file_to_transcript feature----') + print(audio.shape) + audio_len = feature[0].shape[0] + print(audio_len) audio_len = np.array([audio_len]).astype('int64') # [1] result_transcript = model.decode( paddle.to_tensor(audio), paddle.to_tensor(audio_len), - vocab_list=dataset.vocab_list, + vocab_list=test_loader.collate_fn.vocab_list, decoding_method=config.decoding.decoding_method, lang_model_path=config.decoding.lang_model_path, beam_alpha=config.decoding.alpha, @@ -91,7 +102,7 @@ if __name__ == "__main__": add_arg('host_ip', str, 'localhost', "Server's IP 
address.") - add_arg('host_port', int, 8086, "Server's IP port.") + add_arg('host_port', int, 8088, "Server's IP port.") add_arg('speech_save_dir', str, 'demo_cache', "Directory to save demo audios.") diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 05add5bc..d3b321d7 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -12,26 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. """Contains DeepSpeech2 model.""" +import os import time from collections import defaultdict +from contextlib import nullcontext from pathlib import Path +from typing import Optional +import jsonlines import numpy as np import paddle from paddle import distributed as dist +from paddle import inference from paddle.io import DataLoader +from yacs.config import CfgNode +from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset from deepspeech.io.sampler import SortagradBatchSampler from deepspeech.io.sampler import SortagradDistributedBatchSampler from deepspeech.models.ds2 import DeepSpeech2InferModel from deepspeech.models.ds2 import DeepSpeech2Model +from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline +from deepspeech.models.ds2_online import DeepSpeech2ModelOnline from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog +from deepspeech.training.reporter import report from deepspeech.training.trainer import Trainer from deepspeech.utils import error_rate from deepspeech.utils import layer_tools from deepspeech.utils import mp_tools +from deepspeech.utils.log import Autolog from deepspeech.utils.log import Log from deepspeech.utils.utility import UpdateConfig @@ -42,9 +53,9 @@ class DeepSpeech2Trainer(Trainer): def __init__(self, config, args): super().__init__(config, args) - def train_batch(self, batch_index, batch_data, msg): + def train_batch(self, batch_index, batch, msg): start = time.time() - loss = self.model(*batch_data) + loss = self.model(*batch) loss.backward() layer_tools.print_grads(self.model, print_func=None) self.optimizer.step() @@ -176,7 +187,7 @@ class DeepSpeech2Trainer(Trainer): sortagrad=config.data.sortagrad, shuffle_method=config.data.shuffle_method) - collate_fn = SpeechCollator(keep_transcription_text=False) + collate_fn = SpeechCollator(keep_transcription_text=False, return_utts=False) self.train_loader = DataLoader( train_dataset, batch_sampler=batch_sampler, @@ -190,10 +201,55 @@ class DeepSpeech2Trainer(Trainer): collate_fn=collate_fn) logger.info("Setup train/valid Dataloader!") + config.data.manifest = config.data.test_manifest + config.data.keep_transcription_text = True + config.data.augmentation_config = "" + # filter test examples, will cause less examples, but no mismatch with training + # and can use large batch size , save training time, so filter test egs now. 
+ # config.data.min_input_len = 0.0 # second + # config.data.max_input_len = float('inf') # second + # config.data.min_output_len = 0.0 # tokens + # config.data.max_output_len = float('inf') # tokens + # config.data.min_output_input_ratio = 0.00 + # config.data.max_output_input_ratio = float('inf') + test_dataset = ManifestDataset.from_config(config) + + # return text ord id + self.test_loader = DataLoader( + test_dataset, + batch_size=config.decoding.batch_size, + shuffle=False, + drop_last=False, + collate_fn=SpeechCollator(keep_transcription_text=True, return_utts=True)) + logger.info("Setup test Dataloader!") + class DeepSpeech2Tester(DeepSpeech2Trainer): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # testing config + default = CfgNode( + dict( + alpha=2.5, # Coef of LM for beam search. + beta=0.3, # Coef of WC for beam search. + cutoff_prob=1.0, # Cutoff probability for pruning. + cutoff_top_n=40, # Cutoff number for pruning. + lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. + decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy + error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' + num_proc_bsearch=8, # # of CPUs for beam search. + beam_size=500, # Beam search width. + batch_size=128, # decoding batch size + )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + def __init__(self, config, args): super().__init__(config, args) + self._text_featurizer = TextFeaturizer( + unit_type=config.data.unit_type, vocab_filepath=None) def ordid2token(self, texts, texts_len): """ ord() id to chr() chr """ @@ -204,15 +260,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): trans.append(''.join([chr(i) for i in ids])) return trans - def compute_metrics(self, audio, audio_len, texts, texts_len): - cfg = self.config.decoding - errors_sum, len_refs, num_ins = 0.0, 0, 0 - errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors - error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer - - vocab_list = self.test_loader.dataset.vocab_list + def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): + self.autolog.times.start() + self.autolog.times.stamp() - target_transcripts = self.ordid2token(texts, texts_len) result_transcripts = self.model.decode( audio, audio_len, @@ -225,14 +276,48 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): cutoff_prob=cfg.cutoff_prob, cutoff_top_n=cfg.cutoff_top_n, num_processes=cfg.num_proc_bsearch) + #replace the with ' ' + result_transcripts = [ + self._text_featurizer.detokenize(sentence) + for sentence in result_transcripts + ] + + self.autolog.times.stamp() + self.autolog.times.stamp() + self.autolog.times.end() + return result_transcripts + + def compute_metrics(self, + utts, + audio, + audio_len, + texts, + texts_len, + fout=None): + cfg = self.config.decoding + errors_sum, len_refs, num_ins = 0.0, 0, 0 + errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors + error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer - for target, result in zip(target_transcripts, result_transcripts): + #vocab_list = self.test_loader.collate_fn.vocab_list + vocab_list = self.test_loader.dataset.vocab_list + + target_transcripts = self.ordid2token(texts, texts_len) + + result_transcripts = self.compute_result_transcripts(audio, audio_len, + vocab_list, cfg) + + 
for utt, target, result in zip(utts, target_transcripts, + result_transcripts): errors, len_ref = errors_func(target, result) errors_sum += errors len_refs += len_ref num_ins += 1 - logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % - (target, result)) + if fout: + fout.write({"utt": utt, "ref": target, "hyp": result}) + logger.info(f"Utt: {utt}") + logger.info(f"Ref: {target}") + logger.info(f"Hyp: {result}") logger.info("Current error rate [%s] = %f" % (cfg.error_rate_type, error_rate_func(target, result))) @@ -247,19 +332,25 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): @paddle.no_grad() def test(self): logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") + self.autolog = Autolog( + batch_size=self.config.decoding.batch_size, + model_name="deepspeech2", + model_precision="fp32").getlog() self.model.eval() cfg = self.config error_rate_type = None errors_sum, len_refs, num_ins = 0.0, 0, 0 - - for i, batch in enumerate(self.test_loader): - metrics = self.compute_metrics(*batch) - errors_sum += metrics['errors_sum'] - len_refs += metrics['len_refs'] - num_ins += metrics['num_ins'] - error_rate_type = metrics['error_rate_type'] - logger.info("Error rate [%s] (%d/?) = %f" % - (error_rate_type, num_ins, errors_sum / len_refs)) + with jsonlines.open(self.args.result_file, 'w') as fout: + for i, batch in enumerate(self.test_loader): + audio, audio_len, texts, texts_len, utts = batch + metrics = self.compute_metrics(utts, audio, audio_len, texts, + texts_len, fout) + errors_sum += metrics['errors_sum'] + len_refs += metrics['len_refs'] + num_ins += metrics['num_ins'] + error_rate_type = metrics['error_rate_type'] + logger.info("Error rate [%s] (%d/?) = %f" % + (error_rate_type, num_ins, errors_sum / len_refs)) # logging msg = "Test: " @@ -268,101 +359,234 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): msg += "Final error rate [%s] (%d/%d) = %f" % ( error_rate_type, num_ins, num_ins, errors_sum / len_refs) logger.info(msg) - - def run_test(self): - self.resume_or_scratch() - try: - self.test() - except KeyboardInterrupt: - exit(-1) + self.autolog.report() def export(self): - infer_model = DeepSpeech2InferModel.from_pretrained( - self.test_loader.dataset, self.config, self.args.checkpoint_path) + if self.args.model_type == 'offline': + infer_model = DeepSpeech2InferModel.from_pretrained( + self.test_loader, self.config, self.args.checkpoint_path) + elif self.args.model_type == 'online': + infer_model = DeepSpeech2InferModelOnline.from_pretrained( + self.test_loader, self.config, self.args.checkpoint_path) + else: + raise Exception("wrong model type") + infer_model.eval() + #feat_dim = self.test_loader.collate_fn.feature_size feat_dim = self.test_loader.dataset.feature_size - static_model = paddle.jit.to_static( - infer_model, - input_spec=[ - paddle.static.InputSpec( - shape=[None, None, feat_dim], - dtype='float32'), # audio, [B,T,D] - paddle.static.InputSpec(shape=[None], - dtype='int64'), # audio_length, [B] - ]) + static_model = infer_model.export() logger.info(f"Export code: {static_model.forward.code}") paddle.jit.save(static_model, self.args.export_path) - def run_export(self): - try: - self.export() - except KeyboardInterrupt: - exit(-1) - - def setup(self): - """Setup the experiment. 
- """ - paddle.set_device(self.args.device) - self.setup_output_dir() - self.setup_checkpointer() +class DeepSpeech2ExportTester(DeepSpeech2Tester): + def __init__(self, config, args): + super().__init__(config, args) - self.setup_dataloader() - self.setup_model() + def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): + if self.args.model_type == "online": + output_probs, output_lens = self.static_forward_online(audio, + audio_len) + elif self.args.model_type == "offline": + output_probs, output_lens = self.static_forward_offline(audio, + audio_len) + else: + raise Exception("wrong model type") - self.iteration = 0 - self.epoch = 0 + self.predictor.clear_intermediate_tensor() + self.predictor.try_shrink_memory() - def setup_model(self): - config = self.config - model = DeepSpeech2Model( - feat_size=self.test_loader.dataset.feature_size, - dict_size=self.test_loader.dataset.vocab_size, - num_conv_layers=config.model.num_conv_layers, - num_rnn_layers=config.model.num_rnn_layers, - rnn_size=config.model.rnn_layer_size, - use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights) - self.model = model - logger.info("Setup model!") + self.model.decoder.init_decode(cfg.alpha, cfg.beta, cfg.lang_model_path, + vocab_list, cfg.decoding_method) - def setup_dataloader(self): - config = self.config.clone() - config.defrost() - # return raw text - - config.data.manifest = config.data.test_manifest - config.data.keep_transcription_text = True - config.data.augmentation_config = "" - # filter test examples, will cause less examples, but no mismatch with training - # and can use large batch size , save training time, so filter test egs now. - # config.data.min_input_len = 0.0 # second - # config.data.max_input_len = float('inf') # second - # config.data.min_output_len = 0.0 # tokens - # config.data.max_output_len = float('inf') # tokens - # config.data.min_output_input_ratio = 0.00 - # config.data.max_output_input_ratio = float('inf') - test_dataset = ManifestDataset.from_config(config) + result_transcripts = self.model.decoder.decode_probs( + output_probs, output_lens, vocab_list, cfg.decoding_method, + cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size, + cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch) + #replace the with ' ' + result_transcripts = [ + self._text_featurizer.detokenize(sentence) + for sentence in result_transcripts + ] - # return text ord id - self.test_loader = DataLoader( - test_dataset, - batch_size=config.decoding.batch_size, - shuffle=False, - drop_last=False, - collate_fn=SpeechCollator(keep_transcription_text=True)) - logger.info("Setup test Dataloader!") + return result_transcripts - def setup_output_dir(self): - """Create a directory used for output. 
+ def static_forward_online(self, audio, audio_len, + decoder_chunk_size: int=1): + """ + Parameters + ---------- + audio (Tensor): shape[B, T, D] + audio_len (Tensor): shape[B] + decoder_chunk_size(int) + Returns + ------- + output_probs(numpy.array): shape[B, T, vocab_size] + output_lens(numpy.array): shape[B] """ - # output dir - if self.args.output: - output_dir = Path(self.args.output).expanduser() - output_dir.mkdir(parents=True, exist_ok=True) + output_probs_list = [] + output_lens_list = [] + subsampling_rate = self.model.encoder.conv.subsampling_rate + receptive_field_length = self.model.encoder.conv.receptive_field_length + chunk_stride = subsampling_rate * decoder_chunk_size + chunk_size = (decoder_chunk_size - 1 + ) * subsampling_rate + receptive_field_length + + x_batch = audio.numpy() + batch_size, Tmax, x_dim = x_batch.shape + x_len_batch = audio_len.numpy().astype(np.int64) + if (Tmax - chunk_size) % chunk_stride != 0: + padding_len_batch = chunk_stride - ( + Tmax - chunk_size + ) % chunk_stride # The length of padding for the batch else: - output_dir = Path( - self.args.checkpoint_path).expanduser().parent.parent - output_dir.mkdir(parents=True, exist_ok=True) + padding_len_batch = 0 + x_list = np.split(x_batch, batch_size, axis=0) + x_len_list = np.split(x_len_batch, batch_size, axis=0) + + for x, x_len in zip(x_list, x_len_list): + self.autolog.times.start() + self.autolog.times.stamp() + x_len = x_len[0] + assert (chunk_size <= x_len) + + if (x_len - chunk_size) % chunk_stride != 0: + padding_len_x = chunk_stride - (x_len - chunk_size + ) % chunk_stride + else: + padding_len_x = 0 + + padding = np.zeros( + (x.shape[0], padding_len_x, x.shape[2]), dtype=x.dtype) + padded_x = np.concatenate([x, padding], axis=1) + + num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1 + num_chunk = int(num_chunk) + + chunk_state_h_box = np.zeros( + (self.config.model.num_rnn_layers, 1, + self.config.model.rnn_layer_size), + dtype=x.dtype) + chunk_state_c_box = np.zeros( + (self.config.model.num_rnn_layers, 1, + self.config.model.rnn_layer_size), + dtype=x.dtype) + + input_names = self.predictor.get_input_names() + audio_handle = self.predictor.get_input_handle(input_names[0]) + audio_len_handle = self.predictor.get_input_handle(input_names[1]) + h_box_handle = self.predictor.get_input_handle(input_names[2]) + c_box_handle = self.predictor.get_input_handle(input_names[3]) + + probs_chunk_list = [] + probs_chunk_lens_list = [] + for i in range(0, num_chunk): + start = i * chunk_stride + end = start + chunk_size + x_chunk = padded_x[:, start:end, :] + if x_len < i * chunk_stride: + x_chunk_lens = 0 + else: + x_chunk_lens = min(x_len - i * chunk_stride, chunk_size) + + if (x_chunk_lens < + receptive_field_length): #means the number of input frames in the chunk is not enough for predicting one prob + break + x_chunk_lens = np.array([x_chunk_lens]) + audio_handle.reshape(x_chunk.shape) + audio_handle.copy_from_cpu(x_chunk) + + audio_len_handle.reshape(x_chunk_lens.shape) + audio_len_handle.copy_from_cpu(x_chunk_lens) + + h_box_handle.reshape(chunk_state_h_box.shape) + h_box_handle.copy_from_cpu(chunk_state_h_box) + + c_box_handle.reshape(chunk_state_c_box.shape) + c_box_handle.copy_from_cpu(chunk_state_c_box) + + output_names = self.predictor.get_output_names() + output_handle = self.predictor.get_output_handle( + output_names[0]) + output_lens_handle = self.predictor.get_output_handle( + output_names[1]) + output_state_h_handle = self.predictor.get_output_handle( + 
output_names[2]) + output_state_c_handle = self.predictor.get_output_handle( + output_names[3]) + self.predictor.run() + output_chunk_probs = output_handle.copy_to_cpu() + output_chunk_lens = output_lens_handle.copy_to_cpu() + chunk_state_h_box = output_state_h_handle.copy_to_cpu() + chunk_state_c_box = output_state_c_handle.copy_to_cpu() + + probs_chunk_list.append(output_chunk_probs) + probs_chunk_lens_list.append(output_chunk_lens) + output_probs = np.concatenate(probs_chunk_list, axis=1) + output_lens = np.sum(probs_chunk_lens_list, axis=0) + vocab_size = output_probs.shape[2] + output_probs_padding_len = Tmax + padding_len_batch - output_probs.shape[ + 1] + output_probs_padding = np.zeros( + (1, output_probs_padding_len, vocab_size), + dtype=output_probs. + dtype) # The prob padding for a piece of utterance + output_probs = np.concatenate( + [output_probs, output_probs_padding], axis=1) + output_probs_list.append(output_probs) + output_lens_list.append(output_lens) + self.autolog.times.stamp() + self.autolog.times.stamp() + self.autolog.times.end() + output_probs = np.concatenate(output_probs_list, axis=0) + output_lens = np.concatenate(output_lens_list, axis=0) + return output_probs, output_lens + + def static_forward_offline(self, audio, audio_len): + """ + Parameters + ---------- + audio (Tensor): shape[B, T, D] + audio_len (Tensor): shape[B] + + Returns + ------- + output_probs(numpy.array): shape[B, T, vocab_size] + output_lens(numpy.array): shape[B] + """ + x = audio.numpy() + x_len = audio_len.numpy().astype(np.int64) + + input_names = self.predictor.get_input_names() + audio_handle = self.predictor.get_input_handle(input_names[0]) + audio_len_handle = self.predictor.get_input_handle(input_names[1]) - self.output_dir = output_dir + audio_handle.reshape(x.shape) + audio_handle.copy_from_cpu(x) + + audio_len_handle.reshape(x_len.shape) + audio_len_handle.copy_from_cpu(x_len) + + self.autolog.times.start() + self.autolog.times.stamp() + self.predictor.run() + self.autolog.times.stamp() + self.autolog.times.stamp() + self.autolog.times.end() + + output_names = self.predictor.get_output_names() + output_handle = self.predictor.get_output_handle(output_names[0]) + output_lens_handle = self.predictor.get_output_handle(output_names[1]) + output_probs = output_handle.copy_to_cpu() + output_lens = output_lens_handle.copy_to_cpu() + return output_probs, output_lens + + def setup_model(self): + super().setup_model() + infer_config = inference.Config(self.args.export_path + ".pdmodel", + self.args.export_path + ".pdiparams") + if (os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''): + infer_config.enable_use_gpu(100, 0) + infer_config.enable_memory_optim() + infer_predictor = inference.create_predictor(infer_config) + self.predictor = infer_predictor diff --git a/deepspeech/exps/u2/bin/alignment.py b/deepspeech/exps/u2/bin/alignment.py index c1c9582f..cef9d1ab 100644 --- a/deepspeech/exps/u2/bin/alignment.py +++ b/deepspeech/exps/u2/bin/alignment.py @@ -30,6 +30,9 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + # save asr result to + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") args = parser.parse_args() print_arguments(args, globals()) diff --git a/deepspeech/exps/u2/bin/export.py b/deepspeech/exps/u2/bin/export.py index 292c7838..3dc41b70 100644 --- a/deepspeech/exps/u2/bin/export.py +++ b/deepspeech/exps/u2/bin/export.py @@ -30,6 +30,9 @@ def main(config, args): if __name__ == "__main__": parser = 
default_argument_parser() + # save jit model to + parser.add_argument( + "--export_path", type=str, help="path of the jit model to save") args = parser.parse_args() print_arguments(args, globals()) diff --git a/deepspeech/exps/u2/bin/test.py b/deepspeech/exps/u2/bin/test.py index c47f932c..f6127675 100644 --- a/deepspeech/exps/u2/bin/test.py +++ b/deepspeech/exps/u2/bin/test.py @@ -34,6 +34,9 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + # save asr result to + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") args = parser.parse_args() print_arguments(args, globals()) diff --git a/deepspeech/exps/u2/bin/train.py b/deepspeech/exps/u2/bin/train.py index 9dd0041d..17fb08a6 100644 --- a/deepspeech/exps/u2/bin/train.py +++ b/deepspeech/exps/u2/bin/train.py @@ -22,6 +22,8 @@ from deepspeech.exps.u2.model import U2Trainer as Trainer from deepspeech.training.cli import default_argument_parser from deepspeech.utils.utility import print_arguments +# from deepspeech.exps.u2.trainer import U2Trainer as Trainer + def main_sp(config, args): exp = Trainer(config, args) @@ -30,7 +32,7 @@ def main_sp(config, args): def main(config, args): - if args.device == "gpu" and args.nprocs > 1: + if args.nprocs > 0: dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) else: main_sp(config, args) diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index f166a071..af84d9cf 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -73,11 +73,11 @@ class U2Trainer(Trainer): def __init__(self, config, args): super().__init__(config, args) - def train_batch(self, batch_index, batch_data, msg): + def train_batch(self, batch_index, batch, msg): train_conf = self.config.training start = time.time() - loss, attention_loss, ctc_loss = self.model(*batch_data) + loss, attention_loss, ctc_loss = self.model(*batch) # loss div by `batch_size * accum_grad` loss /= train_conf.accum_grad loss.backward() @@ -219,7 +219,7 @@ class U2Trainer(Trainer): config.data.augmentation_config = "" dev_dataset = ManifestDataset.from_config(config) - collate_fn = SpeechCollator(keep_transcription_text=False) + collate_fn = SpeechCollator(keep_transcription_text=False, return_utts=False) if self.parallel: batch_sampler = SortagradDistributedBatchSampler( train_dataset, @@ -269,7 +269,7 @@ class U2Trainer(Trainer): batch_size=config.decoding.batch_size, shuffle=False, drop_last=False, - collate_fn=SpeechCollator(keep_transcription_text=True)) + collate_fn=SpeechCollator(keep_transcription_text=True, return_utts=True)) logger.info("Setup train/valid/test Dataloader!") def setup_model(self): @@ -345,7 +345,7 @@ class U2Tester(U2Trainer): decoding_chunk_size=-1, # decoding chunk size. Defaults to -1. # <0: for decoding, use full chunk. # >0: for decoding, use fixed chunk size as set. - # 0: used for training, it's prohibited here. + # 0: used for training, it's prohibited here. num_decoding_left_chunks=-1, # number of left chunks for decoding. Defaults to -1. simulate_streaming=False, # simulate streaming inference. Defaults to False. 
)) @@ -428,7 +428,7 @@ class U2Tester(U2Trainer): num_time = 0.0 with open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): - metrics = self.compute_metrics(*batch, fout=fout) + metrics = self.compute_metrics(*batch[:-1], fout=fout) num_frames += metrics['num_frames'] num_time += metrics["decode_time"] errors_sum += metrics['errors_sum'] @@ -476,12 +476,12 @@ class U2Tester(U2Trainer): }) f.write(data + '\n') - def run_test(self): - self.resume_or_scratch() - try: - self.test() - except KeyboardInterrupt: - sys.exit(-1) + # def run_test(self): + # self.resume_or_scratch() + # try: + # self.test() + # except KeyboardInterrupt: + # sys.exit(-1) def load_inferspec(self): """infer model and input spec. @@ -512,36 +512,36 @@ class U2Tester(U2Trainer): logger.info(f"Export code: {static_model.forward.code}") paddle.jit.save(static_model, self.args.export_path) - def run_export(self): - try: - self.export() - except KeyboardInterrupt: - sys.exit(-1) - - def setup(self): - """Setup the experiment. - """ - paddle.set_device(self.args.device) - - self.setup_output_dir() - self.setup_checkpointer() - - self.setup_dataloader() - self.setup_model() - - self.iteration = 0 - self.epoch = 0 - - def setup_output_dir(self): - """Create a directory used for output. - """ - # output dir - if self.args.output: - output_dir = Path(self.args.output).expanduser() - output_dir.mkdir(parents=True, exist_ok=True) - else: - output_dir = Path( - self.args.checkpoint_path).expanduser().parent.parent - output_dir.mkdir(parents=True, exist_ok=True) - - self.output_dir = output_dir + # def run_export(self): + # try: + # self.export() + # except KeyboardInterrupt: + # sys.exit(-1) + + # def setup(self): + # """Setup the experiment. + # """ + # paddle.set_device(self.args.device) + + # self.setup_output_dir() + # self.setup_checkpointer() + + # self.setup_dataloader() + # self.setup_model() + + # self.iteration = 0 + # self.epoch = 0 + + # def setup_output_dir(self): + # """Create a directory used for output. + # """ + # # output dir + # if self.args.output: + # output_dir = Path(self.args.output).expanduser() + # output_dir.mkdir(parents=True, exist_ok=True) + # else: + # output_dir = Path( + # self.args.checkpoint_path).expanduser().parent.parent + # output_dir.mkdir(parents=True, exist_ok=True) + + # self.output_dir = output_dir diff --git a/deepspeech/frontend/featurizer/text_featurizer.py b/deepspeech/frontend/featurizer/text_featurizer.py index 1ba6ac7f..dcbc363f 100644 --- a/deepspeech/frontend/featurizer/text_featurizer.py +++ b/deepspeech/frontend/featurizer/text_featurizer.py @@ -14,12 +14,27 @@ """Contains the text featurizer class.""" import sentencepiece as spm -from deepspeech.frontend.utility import EOS -from deepspeech.frontend.utility import UNK +from ..utility import EOS +from ..utility import SPACE +from ..utility import UNK +from ..utility import SOS +from ..utility import BLANK +from ..utility import MASKCTC +from ..utility import load_dict +from deepspeech.utils.log import Log -class TextFeaturizer(object): - def __init__(self, unit_type, vocab_filepath, spm_model_prefix=None): +logger = Log(__name__).getlog() + +__all__ = ["TextFeaturizer"] + + +class TextFeaturizer(): + def __init__(self, + unit_type, + vocab_filepath, + spm_model_prefix=None, + maskctc=False): """Text featurizer, for processing or extracting features from text. 
Currently, it supports char/word/sentence-piece level tokenizing and conversion into @@ -34,20 +49,21 @@ class TextFeaturizer(object): assert unit_type in ('char', 'spm', 'word') self.unit_type = unit_type self.unk = UNK + self.maskctc = maskctc + if vocab_filepath: - self._vocab_dict, self._id2token, self._vocab_list = self._load_vocabulary_from_file( - vocab_filepath) - self.unk_id = self._vocab_list.index(self.unk) - self.eos_id = self._vocab_list.index(EOS) + self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id = self._load_vocabulary_from_file( + vocab_filepath, maskctc) + self.vocab_size = len(self.vocab_list) if unit_type == 'spm': spm_model = spm_model_prefix + '.model' self.sp = spm.SentencePieceProcessor() self.sp.Load(spm_model) - def tokenize(self, text): + def tokenize(self, text, replace_space=True): if self.unit_type == 'char': - tokens = self.char_tokenize(text) + tokens = self.char_tokenize(text, replace_space) elif self.unit_type == 'word': tokens = self.word_tokenize(text) else: # spm @@ -67,27 +83,27 @@ class TextFeaturizer(object): """Convert text string to a list of token indices. Args: - text (str): Text to process. - + text (str): Text. + Returns: List[int]: List of token indices. """ tokens = self.tokenize(text) ids = [] for token in tokens: - token = token if token in self._vocab_dict else self.unk - ids.append(self._vocab_dict[token]) + token = token if token in self.vocab_dict else self.unk + ids.append(self.vocab_dict[token]) return ids def defeaturize(self, idxs): """Convert a list of token indices to text string, - ignore index after eos_id. + ignore index after eos_id. Args: idxs (List[int]): List of token indices. Returns: - str: Text to process. + str: Text. """ tokens = [] for idx in idxs: @@ -97,43 +113,22 @@ class TextFeaturizer(object): text = self.detokenize(tokens) return text - @property - def vocab_size(self): - """Return the vocabulary size. - - :return: Vocabulary size. - :rtype: int - """ - return len(self._vocab_list) - - @property - def vocab_list(self): - """Return the vocabulary in list. - - Returns: - List[str]: tokens. - """ - return self._vocab_list - - @property - def vocab_dict(self): - """Return the vocabulary in dict. - - Returns: - Dict[str, int]: token str -> int - """ - return self._vocab_dict - - def char_tokenize(self, text): + def char_tokenize(self, text, replace_space=True): """Character tokenizer. Args: text (str): text string. + replace_space (bool): False only used by build_vocab.py. Returns: List[str]: tokens. """ - return list(text.strip()) + text = text.strip() + if replace_space: + text_list = [SPACE if item == " " else item for item in list(text)] + else: + text_list = list(text) + return text_list def char_detokenize(self, tokens): """Character detokenizer. @@ -144,6 +139,7 @@ class TextFeaturizer(object): Returns: str: text string. 
""" + tokens = tokens.replace(SPACE, " ") return "".join(tokens) def word_tokenize(self, text): @@ -206,14 +202,28 @@ class TextFeaturizer(object): return decode(tokens) - def _load_vocabulary_from_file(self, vocab_filepath): + def _load_vocabulary_from_file(self, vocab_filepath: str, maskctc: bool): """Load vocabulary from file.""" - vocab_lines = [] - with open(vocab_filepath, 'r', encoding='utf-8') as file: - vocab_lines.extend(file.readlines()) - vocab_list = [line[:-1] for line in vocab_lines] + vocab_list = load_dict(vocab_filepath, maskctc) + assert vocab_list is not None + logger.info(f"Vocab: {vocab_list}") + id2token = dict( [(idx, token) for (idx, token) in enumerate(vocab_list)]) token2id = dict( [(token, idx) for (idx, token) in enumerate(vocab_list)]) - return token2id, id2token, vocab_list + + blank_id = vocab_list.index(BLANK) if BLANK in vocab_list else -1 + maskctc_id = vocab_list.index(MASKCTC) if MASKCTC in vocab_list else -1 + unk_id = vocab_list.index(UNK) if UNK in vocab_list else -1 + eos_id = vocab_list.index(EOS) if EOS in vocab_list else -1 + sos_id = vocab_list.index(SOS) if SOS in vocab_list else -1 + space_id = vocab_list.index(SPACE) if SPACE in vocab_list else -1 + + logger.info(f"UNK id: {unk_id}") + logger.info(f"EOS id: {eos_id}") + logger.info(f"SOS id: {sos_id}") + logger.info(f"SPACE id: {space_id}") + logger.info(f"BLANK id: {blank_id}") + logger.info(f"MASKCTC id: {maskctc_id}") + return token2id, id2token, vocab_list, unk_id, eos_id diff --git a/deepspeech/frontend/utility.py b/deepspeech/frontend/utility.py index b2dd9601..f83f1d4e 100644 --- a/deepspeech/frontend/utility.py +++ b/deepspeech/frontend/utility.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """Contains data helper functions.""" -import codecs import json import math +import tarfile +from collections import namedtuple +from typing import List +from typing import Optional +from typing import Text +import jsonlines import numpy as np from deepspeech.utils.log import Log @@ -23,16 +28,40 @@ from deepspeech.utils.log import Log logger = Log(__name__).getlog() __all__ = [ - "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs", "max_dbfs", - "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS", "EOS", "UNK", - "BLANK" + "load_dict", "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs", + "max_dbfs", "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS", + "EOS", "UNK", "BLANK", "MASKCTC", "SPACE" ] IGNORE_ID = -1 -SOS = "" +# `sos` and `eos` using same token +SOS = "" EOS = SOS UNK = "" BLANK = "" +MASKCTC = "" +SPACE = "" + + +def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]: + if dict_path is None: + return None + + with open(dict_path, "r") as f: + dictionary = f.readlines() + # first token is `` + # multi line: ` 0\n` + # one line: `` + # space is relpace with + char_list = [entry[:-1].split(" ")[0] for entry in dictionary] + if BLANK not in char_list: + char_list.insert(0, BLANK) + if EOS not in char_list: + char_list.append(EOS) + # for non-autoregressive maskctc model + if maskctc and MASKCTC not in char_list: + char_list.append(MASKCTC) + return char_list def read_manifest( @@ -47,12 +76,20 @@ def read_manifest( Args: manifest_path ([type]): Manifest file to load and parse. - max_input_len ([type], optional): maximum output seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf'). 
-        min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
-        max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0.
-        min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0.
-        max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0.
-        min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05.
+        max_input_len ([type], optional): maximum input seq length,
+            in seconds for raw wav, in frame numbers for feature data.
+            Defaults to float('inf').
+        min_input_len (float, optional): minimum input seq length,
+            in seconds for raw wav, in frame numbers for feature data.
+            Defaults to 0.0.
+        max_output_len (float, optional): maximum output seq length,
+            in modeling units. Defaults to 500.0.
+        min_output_len (float, optional): minimum output seq length,
+            in modeling units. Defaults to 0.0.
+        max_output_input_ratio (float, optional):
+            maximum output seq length / input seq length ratio. Defaults to 10.0.
+        min_output_input_ratio (float, optional):
+            minimum output seq length / input seq length ratio. Defaults to 0.05.
 
     Raises:
         IOError: If failed to parse the manifest.
@@ -62,29 +99,70 @@
     """
     manifest = []
-    for json_line in codecs.open(manifest_path, 'r', 'utf-8'):
-        try:
-            json_data = json.loads(json_line)
-        except Exception as e:
-            raise IOError("Error reading manifest: %s" % str(e))
-
-        feat_len = json_data["feat_shape"][
-            0] if 'feat_shape' in json_data else 1.0
-        token_len = json_data["token_shape"][
-            0] if 'token_shape' in json_data else 1.0
-        conditions = [
-            feat_len >= min_input_len,
-            feat_len <= max_input_len,
-            token_len >= min_output_len,
-            token_len <= max_output_len,
-            token_len / feat_len >= min_output_input_ratio,
-            token_len / feat_len <= max_output_input_ratio,
-        ]
-        if all(conditions):
-            manifest.append(json_data)
+    with jsonlines.open(manifest_path, 'r') as reader:
+        for json_data in reader:
+            feat_len = json_data["feat_shape"][
+                0] if 'feat_shape' in json_data else 1.0
+            token_len = json_data["token_shape"][
+                0] if 'token_shape' in json_data else 1.0
+            conditions = [
+                feat_len >= min_input_len,
+                feat_len <= max_input_len,
+                token_len >= min_output_len,
+                token_len <= max_output_len,
+                token_len / feat_len >= min_output_input_ratio,
+                token_len / feat_len <= max_output_input_ratio,
+            ]
+            if all(conditions):
+                manifest.append(json_data)
     return manifest
 
 
+# Tar File read
+TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
+
+
+def parse_tar(file):
+    """Parse a tar file to get a tarfile object
+       and a map containing tarinfos
+    """
+    result = {}
+    f = tarfile.open(file)
+    for tarinfo in f.getmembers():
+        result[tarinfo.name] = tarinfo
+    return f, result
+
+
+def subfile_from_tar(file, local_data=None):
+    """Get subfile object from tar.
+
+    tar:tarpath#filename
+
+    It returns a subfile object from the tar file
+    and caches the tar file info for the next reading request.
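+
+    Example (hypothetical paths)::
+
+        fobj = subfile_from_tar('tar:/data/feats.tar#utt_001.ark')
+        raw = fobj.read()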
+ """ + tarpath, filename = file.split(':', 1)[1].split('#', 1) + + if local_data is None: + local_data = TarLocalData(tar2info={}, tar2object={}) + + assert isinstance(local_data, TarLocalData) + + if 'tar2info' not in local_data.__dict__: + local_data.tar2info = {} + if 'tar2object' not in local_data.__dict__: + local_data.tar2object = {} + + if tarpath not in local_data.tar2info: + fobj, infos = parse_tar(tarpath) + local_data.tar2info[tarpath] = infos + local_data.tar2object[tarpath] = fobj + else: + fobj = local_data.tar2object[tarpath] + infos = local_data.tar2info[tarpath] + return fobj.extractfile(infos[filename]) + + def rms_to_db(rms: float): """Root Mean Square to dB. @@ -101,7 +179,7 @@ def rms_to_dbfs(rms: float): """Root Mean Square to dBFS. https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/ Audio is mix of sine wave, so 1 amp sine wave's Full scale is 0.7071, equal to -3.0103dB. - + dB = dBFS + 3.0103 dBFS = db - 3.0103 e.g. 0 dB = -3.0103 dBFS @@ -116,26 +194,26 @@ def rms_to_dbfs(rms: float): def max_dbfs(sample_data: np.ndarray): - """Peak dBFS based on the maximum energy sample. + """Peak dBFS based on the maximum energy sample. Args: sample_data ([np.ndarray]): float array, [-1, 1]. Returns: - float: dBFS + float: dBFS """ # Peak dBFS based on the maximum energy sample. Will prevent overdrive if used for normalization. return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data)))) def mean_dbfs(sample_data): - """Peak dBFS based on the RMS energy. + """Peak dBFS based on the RMS energy. Args: sample_data ([np.ndarray]): float array, [-1, 1]. Returns: - float: dBFS + float: dBFS """ return rms_to_dbfs( math.sqrt(np.mean(np.square(sample_data, dtype=np.float64)))) @@ -155,7 +233,7 @@ def gain_db_to_ratio(gain_db: float): def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103): """Nomalize audio to dBFS. - + Args: sample_data (np.ndarray): input wave samples, [-1, 1]. dbfs (float, optional): target dBFS. Defaults to -3.0103. @@ -254,6 +332,13 @@ def load_cmvn(cmvn_file: str, filetype: str): cmvn = _load_json_cmvn(cmvn_file) elif filetype == "kaldi": cmvn = _load_kaldi_cmvn(cmvn_file) + elif filetype == "npz": + eps = 1e-14 + npzfile = np.load(cmvn_file) + mean = np.squeeze(npzfile["mean"]) + std = np.squeeze(npzfile["std"]) + istd = 1 / (std + eps) + cmvn = [mean, istd] else: raise ValueError(f"cmvn file type no support: {filetype}") return cmvn[0], cmvn[1] diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index 7f019039..280a4073 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -23,7 +23,7 @@ logger = Log(__name__).getlog() class SpeechCollator(): - def __init__(self, keep_transcription_text=True): + def __init__(self, keep_transcription_text=True, return_utts=False): """ Padding audio features with zeros to make them have the same shape (or a user-defined shape) within one bach. @@ -31,6 +31,7 @@ class SpeechCollator(): if ``keep_transcription_text`` is False, text is token ids else is raw string. 
""" self._keep_transcription_text = keep_transcription_text + self.return_utts = return_utts def __call__(self, batch): """batch examples @@ -51,7 +52,9 @@ class SpeechCollator(): audio_lens = [] texts = [] text_lens = [] - for audio, text in batch: + utts = [] + for utt, audio, text in batch: + utts.append(utt) # audio audios.append(audio.T) # [T, D] audio_lens.append(audio.shape[1]) @@ -75,4 +78,7 @@ class SpeechCollator(): padded_texts = pad_sequence( texts, padding_value=IGNORE_ID).astype(np.int64) text_lens = np.array(text_lens).astype(np.int64) - return padded_audios, audio_lens, padded_texts, text_lens + if self.return_utts: + return padded_audios, audio_lens, padded_texts, text_lens, utts + else: + return padded_audios, audio_lens, padded_texts, text_lens \ No newline at end of file diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index fe53d8e3..c11047f1 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -347,4 +347,5 @@ class ManifestDataset(Dataset): def __getitem__(self, idx): instance = self._manifest[idx] - return self.process_utterance(instance["feat"], instance["text"]) + feat, text = self.process_utterance(instance["feat"], instance["text"]) + return instance["utt"], feat, text diff --git a/deepspeech/models/ds2/conv.py b/deepspeech/models/ds2/conv.py index 111f5d3b..365c4a68 100644 --- a/deepspeech/models/ds2/conv.py +++ b/deepspeech/models/ds2/conv.py @@ -26,9 +26,9 @@ __all__ = ['ConvStack', "conv_output_size"] def conv_output_size(I, F, P, S): # https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters # Output size after Conv: - # By noting I the length of the input volume size, - # F the length of the filter, - # P the amount of zero padding, + # By noting I the length of the input volume size, + # F the length of the filter, + # P the amount of zero padding, # S the stride, # then the output size O of the feature map along that dimension is given by: # O = (I - F + Pstart + Pend) // S + 1 @@ -45,7 +45,7 @@ def conv_output_size(I, F, P, S): # https://fomoro.com/research/article/receptive-field-calculator # https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters # https://distill.pub/2019/computing-receptive-fields/ -# Rl-1 = Sl * Rl + (Kl - Sl) +# Rl-1 = Sl * Rl + (Kl - Sl) class ConvBn(nn.Layer): @@ -58,8 +58,8 @@ class ConvBn(nn.Layer): :type num_channels_in: int :param num_channels_out: Number of output channels. :type num_channels_out: int - :param stride: The x dimension of the stride. Or input a tuple for two - image dimension. + :param stride: The x dimension of the stride. Or input a tuple for two + image dimension. :type stride: int|tuple|list :param padding: The x dimension of the padding. Or input a tuple for two image dimension. @@ -114,7 +114,7 @@ class ConvBn(nn.Layer): masks = make_non_pad_mask(x_len) #[B, T] masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T] # TODO(Hui Zhang): not support bool multiply - masks = masks.type_as(x) + masks = masks.astype(x.dtype) x = x.multiply(masks) return x, x_len diff --git a/deepspeech/models/ds2/deepspeech2.py b/deepspeech/models/ds2/deepspeech2.py index 96730f80..a2aa31f7 100644 --- a/deepspeech/models/ds2/deepspeech2.py +++ b/deepspeech/models/ds2/deepspeech2.py @@ -219,15 +219,17 @@ class DeepSpeech2Model(nn.Layer): The model built from pretrained result. 
""" model = cls( - feat_size=dataloader.collate_fn.feature_size, - dict_size=dataloader.collate_fn.vocab_size, + #feat_size=dataloader.collate_fn.feature_size, + feat_size=dataloader.dataset.feature_size, + #dict_size=dataloader.collate_fn.vocab_size, + dict_size=dataloader.dataset.vocab_size, num_conv_layers=config.model.num_conv_layers, num_rnn_layers=config.model.num_rnn_layers, rnn_size=config.model.rnn_layer_size, use_gru=config.model.use_gru, share_rnn_weights=config.model.share_rnn_weights, blank_id=config.model.blank_id, - ctc_grad_norm_type=config.ctc_grad_norm_type, ) + ctc_grad_norm_type=config.model.ctc_grad_norm_type, ) infos = Checkpoint().load_parameters( model, checkpoint_path=checkpoint_path) logger.info(f"checkpoint info: {infos}") @@ -260,24 +262,8 @@ class DeepSpeech2Model(nn.Layer): class DeepSpeech2InferModel(DeepSpeech2Model): - def __init__(self, - feat_size, - dict_size, - num_conv_layers=2, - num_rnn_layers=3, - rnn_size=1024, - use_gru=False, - share_rnn_weights=True, - blank_id=0): - super().__init__( - feat_size=feat_size, - dict_size=dict_size, - num_conv_layers=num_conv_layers, - num_rnn_layers=num_rnn_layers, - rnn_size=rnn_size, - use_gru=use_gru, - share_rnn_weights=share_rnn_weights, - blank_id=blank_id) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def forward(self, audio, audio_len): """export model function diff --git a/deepspeech/models/ds2/rnn.py b/deepspeech/models/ds2/rnn.py index 29bd2883..2c189962 100644 --- a/deepspeech/models/ds2/rnn.py +++ b/deepspeech/models/ds2/rnn.py @@ -29,13 +29,13 @@ __all__ = ['RNNStack'] class RNNCell(nn.RNNCellBase): r""" - Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it + Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it computes the outputs and updates states. The formula used is as follows: .. math:: h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh}) y_{t} & = h_{t} - + where :math:`act` is for :attr:`activation`. """ @@ -92,7 +92,7 @@ class RNNCell(nn.RNNCellBase): class GRUCell(nn.RNNCellBase): r""" - Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, + Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, it computes the outputs and updates states. The formula for GRU used is as follows: .. math:: @@ -101,8 +101,8 @@ class GRUCell(nn.RNNCellBase): \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc})) h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t} y_{t} & = h_{t} - - where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise + + where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise multiplication operator. """ @@ -309,6 +309,6 @@ class RNNStack(nn.Layer): masks = make_non_pad_mask(x_len) #[B, T] masks = masks.unsqueeze(-1) # [B, T, 1] # TODO(Hui Zhang): not support bool multiply - masks = masks.type_as(x) + masks = masks.astype(x.dtype) x = x.multiply(masks) return x, x_len diff --git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py index 29d207c4..52e0c7b1 100644 --- a/deepspeech/models/ds2_online/deepspeech2.py +++ b/deepspeech/models/ds2_online/deepspeech2.py @@ -255,22 +255,24 @@ class DeepSpeech2ModelOnline(nn.Layer): fc_layers_size_list=[512, 256], use_gru=True, #Use gru if set True. Use simple rnn if set False. 
blank_id=0, # index of blank in vocob.txt - )) + ctc_grad_norm_type='instance', )) if config is not None: config.merge_from_other_cfg(default) return default - def __init__(self, - feat_size, - dict_size, - num_conv_layers=2, - num_rnn_layers=4, - rnn_size=1024, - rnn_direction='forward', - num_fc_layers=2, - fc_layers_size_list=[512, 256], - use_gru=False, - blank_id=0): + def __init__( + self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=4, + rnn_size=1024, + rnn_direction='forward', + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=False, + blank_id=0, + ctc_grad_norm_type='instance', ): super().__init__() self.encoder = CRNNEncoder( feat_size=feat_size, @@ -290,7 +292,7 @@ class DeepSpeech2ModelOnline(nn.Layer): dropout_rate=0.0, reduction=True, # sum batch_average=True, # sum / batch_size - grad_norm_type='instance') + grad_norm_type=ctc_grad_norm_type) def forward(self, audio, audio_len, text, text_len): """Compute Model loss @@ -348,16 +350,18 @@ class DeepSpeech2ModelOnline(nn.Layer): DeepSpeech2ModelOnline The model built from pretrained result. """ - model = cls(feat_size=dataloader.collate_fn.feature_size, - dict_size=dataloader.collate_fn.vocab_size, - num_conv_layers=config.model.num_conv_layers, - num_rnn_layers=config.model.num_rnn_layers, - rnn_size=config.model.rnn_layer_size, - rnn_direction=config.model.rnn_direction, - num_fc_layers=config.model.num_fc_layers, - fc_layers_size_list=config.model.fc_layers_size_list, - use_gru=config.model.use_gru, - blank_id=config.model.blank_id) + model = cls( + feat_size=dataloader.collate_fn.feature_size, + dict_size=dataloader.collate_fn.vocab_size, + num_conv_layers=config.model.num_conv_layers, + num_rnn_layers=config.model.num_rnn_layers, + rnn_size=config.model.rnn_layer_size, + rnn_direction=config.model.rnn_direction, + num_fc_layers=config.model.num_fc_layers, + fc_layers_size_list=config.model.fc_layers_size_list, + use_gru=config.model.use_gru, + blank_id=config.model.blank_id, + ctc_grad_norm_type=config.model.ctc_grad_norm_type, ) infos = Checkpoint().load_parameters( model, checkpoint_path=checkpoint_path) logger.info(f"checkpoint info: {infos}") @@ -376,42 +380,24 @@ class DeepSpeech2ModelOnline(nn.Layer): DeepSpeech2ModelOnline The model built from config. 
""" - model = cls(feat_size=config.feat_size, - dict_size=config.dict_size, - num_conv_layers=config.num_conv_layers, - num_rnn_layers=config.num_rnn_layers, - rnn_size=config.rnn_layer_size, - rnn_direction=config.rnn_direction, - num_fc_layers=config.num_fc_layers, - fc_layers_size_list=config.fc_layers_size_list, - use_gru=config.use_gru, - blank_id=config.blank_id) + model = cls( + feat_size=config.feat_size, + dict_size=config.dict_size, + num_conv_layers=config.num_conv_layers, + num_rnn_layers=config.num_rnn_layers, + rnn_size=config.rnn_layer_size, + rnn_direction=config.rnn_direction, + num_fc_layers=config.num_fc_layers, + fc_layers_size_list=config.fc_layers_size_list, + use_gru=config.use_gru, + blank_id=config.blank_id, + ctc_grad_norm_type=config.ctc_grad_norm_type, ) return model class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline): - def __init__(self, - feat_size, - dict_size, - num_conv_layers=2, - num_rnn_layers=4, - rnn_size=1024, - rnn_direction='forward', - num_fc_layers=2, - fc_layers_size_list=[512, 256], - use_gru=False, - blank_id=0): - super().__init__( - feat_size=feat_size, - dict_size=dict_size, - num_conv_layers=num_conv_layers, - num_rnn_layers=num_rnn_layers, - rnn_size=rnn_size, - rnn_direction=rnn_direction, - num_fc_layers=num_fc_layers, - fc_layers_size_list=fc_layers_size_list, - use_gru=use_gru, - blank_id=blank_id) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box, chunk_state_c_box): diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index a40eb365..6ae5cf60 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import time +from contextlib import contextmanager from pathlib import Path import paddle @@ -29,37 +30,37 @@ logger = Log(__name__).getlog() class Trainer(): """ - An experiment template in order to structure the training code and take - care of saving, loading, logging, visualization stuffs. It's intended to - be flexible and simple. - - So it only handles output directory (create directory for the output, - create a checkpoint directory, dump the config in use and create + An experiment template in order to structure the training code and take + care of saving, loading, logging, visualization stuffs. It's intended to + be flexible and simple. + + So it only handles output directory (create directory for the output, + create a checkpoint directory, dump the config in use and create visualizer and logger) in a standard way without enforcing any - input-output protocols to the model and dataloader. It leaves the main - part for the user to implement their own (setup the model, criterion, - optimizer, define a training step, define a validation function and + input-output protocols to the model and dataloader. It leaves the main + part for the user to implement their own (setup the model, criterion, + optimizer, define a training step, define a validation function and customize all the text and visual logs). - It does not save too much boilerplate code. The users still have to write - the forward/backward/update mannually, but they are free to add + It does not save too much boilerplate code. The users still have to write + the forward/backward/update mannually, but they are free to add non-standard behaviors if needed. We have some conventions to follow. - 1. 
Experiment should have ``model``, ``optimizer``, ``train_loader`` and + 1. Experiment should have ``model``, ``optimizer``, ``train_loader`` and ``valid_loader``, ``config`` and ``args`` attributes. - 2. The config should have a ``training`` field, which has - ``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is - used as the trigger to invoke validation, checkpointing and stop of the + 2. The config should have a ``training`` field, which has + ``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is + used as the trigger to invoke validation, checkpointing and stop of the experiment. - 3. There are four methods, namely ``train_batch``, ``valid``, + 3. There are four methods, namely ``train_batch``, ``valid``, ``setup_model`` and ``setup_dataloader`` that should be implemented. - Feel free to add/overwrite other methods and standalone functions if you + Feel free to add/overwrite other methods and standalone functions if you need. - + Parameters ---------- config: yacs.config.CfgNode The configuration used for the experiment. - + args: argparse.Namespace The parsed command line arguments. Examples -------- >>> def main_sp(config, args): >>> exp = Trainer(config, args) >>> exp.setup() >>> exp.run() - >>> + >>> >>> config = get_cfg_defaults() >>> parser = default_argument_parser() >>> args = parser.parse_args() - >>> if args.config: + >>> if args.config: >>> config.merge_from_file(args.config) >>> if args.opts: >>> config.merge_from_list(args.opts) >>> config.freeze() - >>> - >>> if args.nprocs > 1 and args.device == "gpu": + >>> + >>> if args.nprocs > 0: >>> dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) >>> else: >>> main_sp(config, args) @@ -93,18 +94,24 @@ class Trainer(): self.checkpoint_dir = None self.iteration = 0 self.epoch = 0 + self._train = True - def setup(self): - """Setup the experiment. - """ - paddle.set_device(self.args.device) + paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu') if self.parallel: self.init_parallel() + @contextmanager + def eval(self): + self._train = False + yield + self._train = True + + def setup(self): + """Setup the experiment. + """ self.setup_output_dir() self.dump_config() self.setup_visualizer() - self.setup_checkpointer() self.setup_dataloader() self.setup_model() @@ -114,10 +121,10 @@ class Trainer(): @property def parallel(self): - """A flag indicating whether the experiment should run with + """A flag indicating whether the experiment should run with multiprocessing. """ - return self.args.device == "gpu" and self.args.nprocs > 1 + return self.args.nprocs > 1 def init_parallel(self): """Init environment for multiprocess training. @@ -144,9 +151,9 @@ class Trainer(): self.optimizer, infos) def resume_or_scratch(self): - """Resume from the latest checkpoint under the ``checkpoints`` dir of the output + """Resume from the latest checkpoint under the ``checkpoints`` dir of the output directory or load a specified checkpoint. - + If ``args.checkpoint_path`` is not None, load the checkpoint, else resume training.
""" @@ -158,8 +165,8 @@ class Trainer(): checkpoint_path=self.args.checkpoint_path) if infos: # restore from ckpt - self.iteration = infos["step"] - self.epoch = infos["epoch"] + self.iteration = infos["step"] + 1 + self.epoch = infos["epoch"] + 1 scratch = False else: self.iteration = 0 @@ -237,31 +244,61 @@ class Trainer(): try: self.train() except KeyboardInterrupt: - self.save() exit(-1) finally: self.destory() - logger.info("Training Done.") + logger.info("Train Done.") + + def run_test(self): + """Do Test/Decode""" + with self.eval(): + self.resume_or_scratch() + try: + self.test() + except KeyboardInterrupt: + exit(-1) + logger.info("Test/Decode Done.") + + def run_export(self): + """Do Model Export""" + with self.eval(): + try: + self.export() + except KeyboardInterrupt: + exit(-1) + logger.info("Export Done.") def setup_output_dir(self): """Create a directory used for output. """ - # output dir - output_dir = Path(self.args.output).expanduser() - output_dir.mkdir(parents=True, exist_ok=True) - + if self.args.output: + output_dir = Path(self.args.output).expanduser() + elif self.args.checkpoint_path: + output_dir = Path( + self.args.checkpoint_path).expanduser().parent.parent self.output_dir = output_dir + self.output_dir.mkdir(parents=True, exist_ok=True) - def setup_checkpointer(self): - """Create a directory used to save checkpoints into. - - It is "checkpoints" inside the output directory. - """ - # checkpoint dir - checkpoint_dir = self.output_dir / "checkpoints" - checkpoint_dir.mkdir(exist_ok=True) + self.checkpoint_dir = self.output_dir / "checkpoints" + self.checkpoint_dir.mkdir(parents=True, exist_ok=True) + + self.log_dir = output_dir / "log" + self.log_dir.mkdir(parents=True, exist_ok=True) + + self.test_dir = output_dir / "test" + self.test_dir.mkdir(parents=True, exist_ok=True) + + self.decode_dir = output_dir / "decode" + self.decode_dir.mkdir(parents=True, exist_ok=True) - self.checkpoint_dir = checkpoint_dir + self.export_dir = output_dir / "export" + self.export_dir.mkdir(parents=True, exist_ok=True) + + self.visual_dir = output_dir / "visual" + self.visual_dir.mkdir(parents=True, exist_ok=True) + + self.config_dir = output_dir / "conf" + self.config_dir.mkdir(parents=True, exist_ok=True) @mp_tools.rank_zero_only def destory(self): @@ -273,27 +310,34 @@ class Trainer(): @mp_tools.rank_zero_only def setup_visualizer(self): """Initialize a visualizer to log the experiment. - + The visual log is saved in the output directory. - + Notes ------ - Only the main process has a visualizer with it. Use multiple - visualizers in multiprocess to write to a same log file may cause + Only the main process has a visualizer with it. Use multiple + visualizers in multiprocess to write to a same log file may cause unexpected behaviors. """ # visualizer - visualizer = SummaryWriter(logdir=str(self.output_dir)) + visualizer = SummaryWriter(logdir=str(self.visual_dir)) self.visualizer = visualizer @mp_tools.rank_zero_only def dump_config(self): - """Save the configuration used for this experiment. - - It is saved in to ``config.yaml`` in the output directory at the + """Save the configuration used for this experiment. + + It is saved in to ``config.yaml`` in the output directory at the beginning of the experiment. 
""" - with open(self.output_dir / "config.yaml", 'wt') as f: + config_file = self.config_dir / "config.yaml" + if self._train and config_file.exists(): + time_stamp = time.strftime("%Y_%m_%d_%H_%M_%s", time.gmtime()) + target_path = self.config_dir / ".".join( + [time_stamp, "config.yaml"]) + config_file.rename(target_path) + + with open(config_file, 'wt') as f: print(self.config, file=f) def train_batch(self): @@ -307,14 +351,26 @@ class Trainer(): """ raise NotImplementedError("valid should be implemented.") + @paddle.no_grad() + def test(self): + """The test. A subclass should implement this method in Tester. + """ + raise NotImplementedError("test should be implemented.") + + @paddle.no_grad() + def export(self): + """The test. A subclass should implement this method in Tester. + """ + raise NotImplementedError("export should be implemented.") + def setup_model(self): - """Setup model, criterion and optimizer, etc. A subclass should + """Setup model, criterion and optimizer, etc. A subclass should implement this method. """ raise NotImplementedError("setup_model should be implemented.") def setup_dataloader(self): - """Setup training dataloader and validation dataloader. A subclass + """Setup training dataloader and validation dataloader. A subclass should implement this method. """ raise NotImplementedError("setup_dataloader should be implemented.") diff --git a/deepspeech/utils/log.py b/deepspeech/utils/log.py index 7e8de600..1790efdb 100644 --- a/deepspeech/utils/log.py +++ b/deepspeech/utils/log.py @@ -120,14 +120,15 @@ class Autolog: model_precision="fp32"): import auto_log pid = os.getpid() - if (os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''): + if os.environ.get('CUDA_VISIBLE_DEVICES', None): gpu_id = int(os.environ['CUDA_VISIBLE_DEVICES'].split(',')[0]) infer_config = inference.Config() infer_config.enable_use_gpu(100, gpu_id) else: gpu_id = None infer_config = inference.Config() - autolog = auto_log.AutoLogger( + + self.autolog = auto_log.AutoLogger( model_name=model_name, model_precision=model_precision, batch_size=batch_size, @@ -139,7 +140,6 @@ class Autolog: gpu_ids=gpu_id, time_keys=['preprocess_time', 'inference_time', 'postprocess_time'], warmup=0) - self.autolog = autolog def getlog(self): return self.autolog diff --git a/examples/dataset/mini_librispeech/.gitignore b/examples/dataset/mini_librispeech/.gitignore index 61f54c96..7fbcfd65 100644 --- a/examples/dataset/mini_librispeech/.gitignore +++ b/examples/dataset/mini_librispeech/.gitignore @@ -2,3 +2,4 @@ dev-clean/ manifest.dev-clean manifest.train-clean train-clean/ +*.meta diff --git a/examples/dataset/mini_librispeech/mini_librispeech.py b/examples/dataset/mini_librispeech/mini_librispeech.py index f5bc1393..65fee81a 100644 --- a/examples/dataset/mini_librispeech/mini_librispeech.py +++ b/examples/dataset/mini_librispeech/mini_librispeech.py @@ -58,6 +58,10 @@ def create_manifest(data_dir, manifest_path): """ print("Creating manifest %s ..." 
% manifest_path) json_lines = [] + total_sec = 0.0 + total_text = 0 + total_num = 0 + for subfolder, _, filelist in sorted(os.walk(data_dir)): text_filelist = [ filename for filename in filelist if filename.endswith('trans.txt') @@ -80,10 +84,27 @@ 'text': text })) + + total_sec += duration + total_text += len(text) + total_num += 1 + with codecs.open(manifest_path, 'w', 'utf-8') as out_file: for line in json_lines: out_file.write(line + '\n') + subset = os.path.splitext(manifest_path)[1][1:] + manifest_dir = os.path.dirname(manifest_path) + data_dir_name = os.path.split(data_dir)[-1] + meta_path = os.path.join(manifest_dir, data_dir_name) + '.meta' + with open(meta_path, 'w') as f: + print(f"{subset}:", file=f) + print(f"{total_num} utts", file=f) + print(f"{total_sec / (60*60)} h", file=f) + print(f"{total_text} text", file=f) + print(f"{total_text / total_sec} text/sec", file=f) + print(f"{total_sec / total_num} sec/utt", file=f) + def prepare_dataset(url, md5sum, target_dir, manifest_path): """Download, unpack and create summary manifest file. diff --git a/examples/librispeech/s1/local/align.sh b/examples/librispeech/s1/local/align.sh new file mode 100755 index 00000000..279461aa --- /dev/null +++ b/examples/librispeech/s1/local/align.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +if [ $# != 2 ];then + echo "usage: ${0} config_path ckpt_path_prefix" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +config_path=$1 +ckpt_prefix=$2 + +batch_size=1 +output_dir=${ckpt_prefix} +mkdir -p ${output_dir} + +# align dump in `result_file` +# .tier, .TextGrid dump in `dir of result_file` +python3 -u ${BIN_DIR}/alignment.py \ +--nproc ${ngpu} \ +--config ${config_path} \ +--result_file ${output_dir}/${type}.align \ +--checkpoint_path ${ckpt_prefix} \ +--opts decoding.batch_size ${batch_size} + +if [ $? -ne 0 ]; then + echo "Failed in ctc alignment!" + exit 1 +fi + +exit 0 diff --git a/examples/librispeech/s1/local/data.sh b/examples/librispeech/s1/local/data.sh index 96924e35..2b6af229 100755 --- a/examples/librispeech/s1/local/data.sh +++ b/examples/librispeech/s1/local/data.sh @@ -1,4 +1,4 @@ -#! /usr/bin/env bash +#!/bin/bash stage=-1 stop_stage=100 diff --git a/examples/librispeech/s1/local/download_lm_en.sh b/examples/librispeech/s1/local/download_lm_en.sh index 05ea793f..dc1bdf66 100755 --- a/examples/librispeech/s1/local/download_lm_en.sh +++ b/examples/librispeech/s1/local/download_lm_en.sh @@ -1,4 +1,4 @@ -#! /usr/bin/env bash +#!/bin/bash . ${MAIN_ROOT}/utils/utility.sh diff --git a/examples/librispeech/s1/local/export.sh b/examples/librispeech/s1/local/export.sh index 1b19d572..b562218e 100755 --- a/examples/librispeech/s1/local/export.sh +++ b/examples/librispeech/s1/local/export.sh @@ -1,4 +1,4 @@ -#! /usr/bin/env bash +#!/bin/bash if [ $# != 3 ];then echo "usage: $0 config_path ckpt_prefix jit_model_path" @@ -12,13 +12,7 @@ config_path=$1 ckpt_path_prefix=$2 jit_model_export_path=$3 -device=gpu -if [ ngpu == 0 ];then - device=cpu -fi - python3 -u ${BIN_DIR}/export.py \ ---device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --checkpoint_path ${ckpt_path_prefix} \ diff --git a/examples/librispeech/s1/local/test.sh b/examples/librispeech/s1/local/test.sh index 8c323e00..7f48d3d5 100755 --- a/examples/librispeech/s1/local/test.sh +++ b/examples/librispeech/s1/local/test.sh @@ -1,19 +1,40 @@ -#!
/usr/bin/env bash +#!/bin/bash -if [ $# != 2 ];then - echo "usage: ${0} config_path ckpt_path_prefix" +set -e + +expdir=exp +datadir=data +nj=32 + +lmtag= + +recog_set="test-clean test-other dev-clean dev-other" +recog_set="test-clean" + +# bpemode (unigram or bpe) +nbpe=5000 +bpemode=unigram +bpeprefix="data/bpe_${bpemode}_${nbpe}" +bpemodel=${bpeprefix}.model + +if [ $# != 3 ];then + echo "usage: ${0} config_path dict_path ckpt_path_prefix" exit -1 fi ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." -device=gpu -if [ ngpu == 0 ];then - device=cpu -fi config_path=$1 -ckpt_prefix=$2 +dict=$2 +ckpt_prefix=$3 + +chunk_mode=false +if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then + chunk_mode=true +fi +echo "chunk mode ${chunk_mode}" + # download language model #bash local/download_lm_en.sh @@ -21,39 +42,46 @@ ckpt_prefix=$2 # exit 1 #fi -for type in attention ctc_greedy_search; do - echo "decoding ${type}" - batch_size=64 - python3 -u ${BIN_DIR}/test.py \ - --device ${device} \ - --nproc 1 \ - --config ${config_path} \ - --result_file ${ckpt_prefix}.${type}.rsl \ - --checkpoint_path ${ckpt_prefix} \ - --opts decoding.decoding_method ${type} decoding.batch_size ${batch_size} - - if [ $? -ne 0 ]; then - echo "Failed in evaluation!" - exit 1 - fi -done +pids=() # initialize pids + +for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring; do +( + for rtask in ${recog_set}; do + ( + decode_dir=decode_${rtask}_${dmethd}_$(basename ${config_path%.*})_${lmtag} + feat_recog_dir=${datadir} + mkdir -p ${expdir}/${decode_dir} + mkdir -p ${feat_recog_dir} + + # split data + split_json.sh ${feat_recog_dir}/manifest.${rtask} ${nj} + + #### use CPU for decoding + ngpu=0 + + # set batchsize 0 to disable batch decoding + batch_size=1 + ${decode_cmd} JOB=1:${nj} ${expdir}/${decode_dir}/log/decode.JOB.log \ + python3 -u ${BIN_DIR}/test.py \ + --nproc ${ngpu} \ + --config ${config_path} \ + --result_file ${expdir}/${decode_dir}/data.JOB.json \ + --checkpoint_path ${ckpt_prefix} \ + --opts decoding.decoding_method ${dmethd} \ + --opts decoding.batch_size ${batch_size} \ + --opts data.test_manifest ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask} + + score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer true ${expdir}/${decode_dir} ${dict} -for type in ctc_prefix_beam_search attention_rescoring; do - echo "decoding ${type}" - batch_size=1 - python3 -u ${BIN_DIR}/test.py \ - --device ${device} \ - --nproc 1 \ - --config ${config_path} \ - --result_file ${ckpt_prefix}.${type}.rsl \ - --checkpoint_path ${ckpt_prefix} \ - --opts decoding.decoding_method ${type} decoding.batch_size ${batch_size} - - if [ $? -ne 0 ]; then - echo "Failed in evaluation!" - exit 1 - fi + ) & + pids+=($!) # store background pids + done +) & +pids+=($!) # store background pids done +i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done +[ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." && false +echo "Finished" exit 0 diff --git a/examples/librispeech/s1/local/train.sh b/examples/librispeech/s1/local/train.sh index a4218aa8..906a329d 100755 --- a/examples/librispeech/s1/local/train.sh +++ b/examples/librispeech/s1/local/train.sh @@ -1,4 +1,4 @@ -#! /usr/bin/env bash +#!/bin/bash if [ $# != 2 ];then echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" @@ -11,19 +11,28 @@ echo "using $ngpu gpus..."
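The rewritten librispeech s1 test.sh above fans out one background job per (decoding method, test set) pair and defers failure handling until every job has finished, instead of decoding serially and aborting on the first error. A condensed sketch of that pattern; the `do_decode` worker here is a hypothetical stand-in for the real `test.py` invocation:

```bash
#!/bin/bash
# Sketch of the background-job bookkeeping used by the new decode loop.
do_decode() { echo "decoding $1 with $2"; sleep 1; }  # hypothetical worker

pids=()  # initialize pids
for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring; do
    for rtask in test-clean; do
        do_decode "${rtask}" "${dmethd}" &
        pids+=($!)  # store background pids
    done
done

# wait on every job and count failures rather than exiting on the first one
i=0
for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done
[ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." && exit 1
echo "Finished"
```

Note that the real script also forces `ngpu=0` inside the loop, so each decode job runs on CPU and the `nj`-way manifest splits can be processed in parallel without competing for GPUs.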
config_path=$1 ckpt_name=$2 -device=gpu -if [ ngpu == 0 ];then - device=cpu +mkdir -p exp + +# seed may break model convergence +seed=0 +if [ ${seed} != 0 ]; then + #export FLAGS_cudnn_deterministic=True + echo "None" fi -echo "using ${device}..." -mkdir -p exp +# export FLAGS_cudnn_exhaustive_search=true +# export FLAGS_conv_workspace_size_limit=4000 python3 -u ${BIN_DIR}/train.py \ ---device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--seed ${seed} + +if [ ${seed} != 0 ]; then + #unset FLAGS_cudnn_deterministic + echo "None" +fi if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml index 3f52da7f..594ec579 100644 --- a/examples/tiny/s0/conf/deepspeech2.yaml +++ b/examples/tiny/s0/conf/deepspeech2.yaml @@ -4,6 +4,7 @@ data: dev_manifest: data/manifest.tiny test_manifest: data/manifest.tiny mean_std_filepath: data/mean_std.json + unit_type: char vocab_filepath: data/vocab.txt augmentation_config: conf/augmentation.json batch_size: 4 @@ -35,6 +36,8 @@ model: rnn_layer_size: 2048 use_gru: False share_rnn_weights: True + blank_id: 0 + ctc_grad_norm_type: instance training: n_epoch: 20 diff --git a/examples/tiny/s0/conf/deepspeech2_online.yaml b/examples/tiny/s0/conf/deepspeech2_online.yaml new file mode 100644 index 00000000..7e30409f --- /dev/null +++ b/examples/tiny/s0/conf/deepspeech2_online.yaml @@ -0,0 +1,72 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.tiny + dev_manifest: data/manifest.tiny + test_manifest: data/manifest.tiny + min_input_len: 0.0 + max_input_len: 30.0 + min_output_len: 0.0 + max_output_len: 400.0 + min_output_input_ratio: 0.05 + max_output_input_ratio: 10.0 + + +collator: + mean_std_filepath: data/mean_std.json + unit_type: char + vocab_filepath: data/vocab.txt + augmentation_config: conf/augmentation.json + random_seed: 0 + spm_model_prefix: + spectrum_type: linear + feat_dim: + delta_delta: False + stride_ms: 10.0 + window_ms: 20.0 + n_fft: None + max_freq: None + target_sample_rate: 16000 + use_dB_normalization: True + target_dB: -20 + dither: 1.0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 0 + batch_size: 4 + +model: + num_conv_layers: 2 + num_rnn_layers: 4 + rnn_layer_size: 2048 + rnn_direction: forward + num_fc_layers: 2 + fc_layers_size_list: 512, 256 + use_gru: True + blank_id: 0 + ctc_grad_norm_type: instance + +training: + n_epoch: 10 + accum_grad: 1 + lr: 1e-5 + lr_decay: 1.0 + weight_decay: 1e-06 + global_grad_clip: 5.0 + log_interval: 1 + checkpoint: + kbest_n: 3 + latest_n: 2 + + +decoding: + batch_size: 128 + error_rate_type: wer + decoding_method: ctc_beam_search + lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm + alpha: 2.5 + beta: 0.3 + beam_size: 500 + cutoff_prob: 1.0 + cutoff_top_n: 40 + num_proc_bsearch: 8 diff --git a/examples/tiny/s0/local/download_lm_en.sh b/examples/tiny/s0/local/download_lm_en.sh index 05ea793f..a647d3bc 100755 --- a/examples/tiny/s0/local/download_lm_en.sh +++ b/examples/tiny/s0/local/download_lm_en.sh @@ -1,4 +1,4 @@ -#! /usr/bin/env bash +#!/bin/bash . ${MAIN_ROOT}/utils/utility.sh @@ -9,6 +9,11 @@ URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm MD5="099a601759d467cd0a8523ff939819c5" TARGET=${DIR}/common_crawl_00.prune01111.trie.klm +if [ -e $TARGET ];then + echo "$TARGET exists." + exit 0 +fi + echo "Download language model ..." 
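The new conf/deepspeech2_online.yaml above configures the streaming (forward-only RNN) variant of DeepSpeech2 and is meant to be paired with `model_type` set to `online` rather than the default `offline`. Assuming the Kaldi-style parse_options.sh semantics that s0's run.sh uses (any variable assigned before the source line can be overridden with a matching `--flag value`), a hypothetical invocation would be:

```bash
# Hypothetical: run the tiny/s0 recipe with the streaming DS2 model.
# --conf_path and --model_type override the defaults declared in run.sh
# (conf/deepspeech2.yaml and offline, respectively).
cd examples/tiny/s0
bash run.sh --conf_path conf/deepspeech2_online.yaml --model_type online
```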
download $URL $MD5 $TARGET if [ $? -ne 0 ]; then diff --git a/examples/tiny/s0/local/export.sh b/examples/tiny/s0/local/export.sh index 1b19d572..a5e62c28 100755 --- a/examples/tiny/s0/local/export.sh +++ b/examples/tiny/s0/local/export.sh @@ -1,7 +1,7 @@ -#! /usr/bin/env bash +#!/bin/bash -if [ $# != 3 ];then - echo "usage: $0 config_path ckpt_prefix jit_model_path" +if [ $# != 4 ];then + echo "usage: $0 config_path ckpt_prefix jit_model_path model_type" exit -1 fi @@ -11,19 +11,14 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_path_prefix=$2 jit_model_export_path=$3 - -device=gpu -if [ ngpu == 0 ];then - device=cpu -fi +model_type=$4 python3 -u ${BIN_DIR}/export.py \ ---device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --checkpoint_path ${ckpt_path_prefix} \ ---export_path ${jit_model_export_path} - +--export_path ${jit_model_export_path} \ +--model_type ${model_type} if [ $? -ne 0 ]; then echo "Failed in export!" diff --git a/examples/tiny/s0/local/test.sh b/examples/tiny/s0/local/test.sh index 79e05838..4d00f30b 100755 --- a/examples/tiny/s0/local/test.sh +++ b/examples/tiny/s0/local/test.sh @@ -1,19 +1,16 @@ -#! /usr/bin/env bash +#!/bin/bash -if [ $# != 2 ];then - echo "usage: ${0} config_path ckpt_path_prefix" +if [ $# != 3 ];then + echo "usage: ${0} config_path ckpt_path_prefix model_type" exit -1 fi ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." -device=gpu -if [ ngpu == 0 ];then - device=cpu -fi config_path=$1 ckpt_prefix=$2 +model_type=$3 # download language model bash local/download_lm_en.sh @@ -22,11 +19,11 @@ if [ $? -ne 0 ]; then fi python3 -u ${BIN_DIR}/test.py \ ---device ${device} \ ---nproc 1 \ +--nproc ${ngpu} \ --config ${config_path} \ --result_file ${ckpt_prefix}.rsl \ ---checkpoint_path ${ckpt_prefix} +--checkpoint_path ${ckpt_prefix} \ +--model_type ${model_type} if [ $? -ne 0 ]; then echo "Failed in evaluation!" diff --git a/examples/tiny/s0/local/train.sh b/examples/tiny/s0/local/train.sh index f8c9dbc0..5b87780a 100755 --- a/examples/tiny/s0/local/train.sh +++ b/examples/tiny/s0/local/train.sh @@ -1,28 +1,42 @@ -#! /usr/bin/env bash +#!/bin/bash -if [ $# != 2 ];then - echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" - exit -1 -fi +profiler_options= + +# seed may break model convergence +seed=0 + +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." -config_path=$1 -ckpt_name=$2 +if [ ${seed} != 0 ]; then + export FLAGS_cudnn_deterministic=True + echo "using seed $seed & FLAGS_cudnn_deterministic=True ..." +fi -device=gpu -if [ ngpu == 0 ];then - device=cpu +if [ $# != 3 ];then + echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name model_type" + exit -1 fi +config_path=$1 +ckpt_name=$2 +model_type=$3 + mkdir -p exp python3 -u ${BIN_DIR}/train.py \ ---device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--model_type ${model_type} \ +--profiler-options "${profiler_options}" \ +--seed ${seed} + +if [ ${seed} != 0 ]; then + unset FLAGS_cudnn_deterministic +fi if [ $? -ne 0 ]; then echo "Failed in training!" 
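Because the reworked tiny/s0 train.sh sources `${MAIN_ROOT}/utils/parse_options.sh` before it checks `$#`, optional flags now have to precede the three positional arguments. A hypothetical call (the config, checkpoint name, and model type values are examples):

```bash
# parse_options.sh consumes the leading --key value pairs (e.g. --seed),
# then the positional args are validated: config, ckpt name, model type.
CUDA_VISIBLE_DEVICES=0,1 ./local/train.sh --seed 1234 \
    conf/deepspeech2.yaml ds2_tiny offline
```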
diff --git a/examples/tiny/s0/path.sh b/examples/tiny/s0/path.sh index 777da29e..8a9345f2 100644 --- a/examples/tiny/s0/path.sh +++ b/examples/tiny/s0/path.sh @@ -1,4 +1,4 @@ -export MAIN_ROOT=${PWD}/../../../ +export MAIN_ROOT=`realpath ${PWD}/../../../` export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} export LC_ALL=C diff --git a/examples/tiny/s0/run.sh b/examples/tiny/s0/run.sh index d4961adb..f39fb3fa 100755 --- a/examples/tiny/s0/run.sh +++ b/examples/tiny/s0/run.sh @@ -7,11 +7,12 @@ stage=0 stop_stage=100 conf_path=conf/deepspeech2.yaml avg_num=1 +model_type=offline source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; avg_ckpt=avg_${avg_num} -ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') +ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') ###ckpt = deepspeech2 echo "checkpoint name ${ckpt}" if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then @@ -21,20 +22,20 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh best exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type} || exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # export ckpt avg_n - CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit + CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type} fi diff --git a/examples/tiny/s1/conf/transformer.yaml b/examples/tiny/s1/conf/transformer.yaml index 39f5e99b..e060a90f 100644 --- a/examples/tiny/s1/conf/transformer.yaml +++ b/examples/tiny/s1/conf/transformer.yaml @@ -65,6 +65,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/tiny/s1/local/align.sh b/examples/tiny/s1/local/align.sh new file mode 100755 index 00000000..279461aa --- /dev/null +++ b/examples/tiny/s1/local/align.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +if [ $# != 2 ];then + echo "usage: ${0} config_path ckpt_path_prefix" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +config_path=$1 +ckpt_prefix=$2 + +batch_size=1 +output_dir=${ckpt_prefix} +mkdir -p ${output_dir} + +# align dump in `result_file` +# .tier, .TextGrid dump in `dir of result_file` +python3 -u ${BIN_DIR}/alignment.py \ +--nproc ${ngpu} \ +--config ${config_path} \ +--result_file ${output_dir}/${type}.align \ +--checkpoint_path ${ckpt_prefix} \ +--opts decoding.batch_size ${batch_size} + +if [ $? -ne 0 ]; then + echo "Failed in ctc alignment!" + exit 1 +fi + +exit 0 diff --git a/examples/tiny/s1/local/data.sh b/examples/tiny/s1/local/data.sh index 5822dc92..b5dbd581 100755 --- a/examples/tiny/s1/local/data.sh +++ b/examples/tiny/s1/local/data.sh @@ -1,4 +1,4 @@ -#! 
/usr/bin/env bash +#!/bin/bash stage=-1 stop_stage=100 diff --git a/examples/tiny/s1/local/export.sh b/examples/tiny/s1/local/export.sh index 1b19d572..b562218e 100755 --- a/examples/tiny/s1/local/export.sh +++ b/examples/tiny/s1/local/export.sh @@ -1,4 +1,4 @@ -#! /usr/bin/env bash +#!/bin/bash if [ $# != 3 ];then echo "usage: $0 config_path ckpt_prefix jit_model_path" @@ -12,13 +12,7 @@ config_path=$1 ckpt_path_prefix=$2 jit_model_export_path=$3 -device=gpu -if [ ngpu == 0 ];then - device=cpu -fi - python3 -u ${BIN_DIR}/export.py \ ---device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --checkpoint_path ${ckpt_path_prefix} \ diff --git a/examples/tiny/s1/local/test.sh b/examples/tiny/s1/local/test.sh index 240a63b0..34088ce9 100755 --- a/examples/tiny/s1/local/test.sh +++ b/examples/tiny/s1/local/test.sh @@ -1,4 +1,4 @@ -#! /usr/bin/env bash +#!/bin/bash if [ $# != 2 ];then echo "usage: ${0} config_path ckpt_path_prefix" @@ -8,30 +8,57 @@ fi ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." -device=gpu -if [ ngpu == 0 ];then - device=cpu -fi config_path=$1 ckpt_prefix=$2 +chunk_mode=false +if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then + chunk_mode=true +fi + # download language model #bash local/download_lm_en.sh #if [ $? -ne 0 ]; then # exit 1 #fi -python3 -u ${BIN_DIR}/test.py \ ---device ${device} \ ---nproc 1 \ ---config ${config_path} \ ---result_file ${ckpt_prefix}.rsl \ ---checkpoint_path ${ckpt_prefix} +for type in attention ctc_greedy_search; do + echo "decoding ${type}" + if [ ${chunk_mode} == true ];then + # stream decoding only support batchsize=1 + batch_size=1 + else + batch_size=64 + fi + python3 -u ${BIN_DIR}/test.py \ + --nproc ${ngpu} \ + --config ${config_path} \ + --result_file ${ckpt_prefix}.${type}.rsl \ + --checkpoint_path ${ckpt_prefix} \ + --opts decoding.decoding_method ${type} \ + --opts decoding.batch_size ${batch_size} -if [ $? -ne 0 ]; then - echo "Failed in evaluation!" - exit 1 -fi + if [ $? -ne 0 ]; then + echo "Failed in evaluation!" + exit 1 + fi +done + +for type in ctc_prefix_beam_search attention_rescoring; do + echo "decoding ${type}" + batch_size=1 + python3 -u ${BIN_DIR}/test.py \ + --nproc ${ngpu} \ + --config ${config_path} \ + --result_file ${ckpt_prefix}.${type}.rsl \ + --checkpoint_path ${ckpt_prefix} \ + --opts decoding.decoding_method ${type} \ + --opts decoding.batch_size ${batch_size} + if [ $? -ne 0 ]; then + echo "Failed in evaluation!" + exit 1 + fi +done exit 0 diff --git a/examples/tiny/s1/local/train.sh b/examples/tiny/s1/local/train.sh index f8c9dbc0..71af3a00 100755 --- a/examples/tiny/s1/local/train.sh +++ b/examples/tiny/s1/local/train.sh @@ -1,28 +1,45 @@ -#! /usr/bin/env bash +#!/bin/bash + +profiler_options= +benchmark_batch_size=0 +benchmark_max_step=0 + +# seed may break model convergence +seed=0 + +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +if [ ${seed} != 0 ]; then + export FLAGS_cudnn_deterministic=True + echo "using seed $seed & FLAGS_cudnn_deterministic=True ..." +fi if [ $# != 2 ];then echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" exit -1 fi -ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') -echo "using $ngpu gpus..." 
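The seed handling added to the train.sh scripts above follows a single pattern: a non-zero `--seed` turns on deterministic cuDNN kernels for the duration of the training command only (the comment warns that pinning the seed may break model convergence). Condensed, the toggle looks like this:

```bash
#!/bin/bash
# Reproducibility toggle used by the reworked train.sh scripts.
seed=1234   # 0 disables deterministic mode
if [ ${seed} != 0 ]; then
    export FLAGS_cudnn_deterministic=True
    echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
fi

python3 -u "${BIN_DIR}/train.py" --seed ${seed} "$@"   # training command

if [ ${seed} != 0 ]; then
    unset FLAGS_cudnn_deterministic   # don't leak the flag into later stages
fi
```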
- config_path=$1 ckpt_name=$2 -device=gpu -if [ ngpu == 0 ];then - device=cpu -fi - mkdir -p exp python3 -u ${BIN_DIR}/train.py \ ---device ${device} \ +--seed ${seed} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--profiler-options "${profiler_options}" \ +--benchmark-batch-size ${benchmark_batch_size} \ +--benchmark-max-step ${benchmark_max_step} + + +if [ ${seed} != 0 ]; then + unset FLAGS_cudnn_deterministic +fi if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/tiny/s1/run.sh b/examples/tiny/s1/run.sh index f7e41a33..6580afed 100755 --- a/examples/tiny/s1/run.sh +++ b/examples/tiny/s1/run.sh @@ -20,20 +20,26 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=4,5,6,7 ./local/train.sh ${conf_path} ${ckpt} + ./local/train.sh ${conf_path} ${ckpt} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh best exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES= ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # ctc alignment of test data + CUDA_VISIBLE_DEVICES= ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 +fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then # export ckpt avg_n CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit fi + -- GitLab
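With the CTC alignment step inserted as stage 4, the JIT export in tiny/s1's run.sh moves down to stage 5. Assuming run.sh sources utils/parse_options.sh the way the s0 recipe does, any slice of the pipeline can be replayed by overriding `stage` and `stop_stage`, for example:

```bash
# Hypothetical: re-run only the new CTC alignment stage of tiny/s1 ...
cd examples/tiny/s1
bash run.sh --stage 4 --stop_stage 4

# ... or everything from checkpoint averaging through export.
bash run.sh --stage 2 --stop_stage 5
```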