From fa70a024a78ca3bcf30fb10d1cceb0c572e9949e Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 5 Oct 2021 10:41:20 +0000 Subject: [PATCH] model;new updater;benchmark;chians; can run libri/s1 --- .../exps/deepspeech2/bin/deploy/runtime.py | 2 +- .../exps/deepspeech2/bin/deploy/server.py | 2 +- deepspeech/exps/deepspeech2/bin/export.py | 8 +- deepspeech/exps/deepspeech2/bin/test.py | 8 +- .../exps/deepspeech2/bin/test_export.py | 57 ++ deepspeech/exps/deepspeech2/bin/test_hub.py | 206 ++++ deepspeech/exps/deepspeech2/bin/train.py | 7 +- deepspeech/exps/deepspeech2/bin/tune.py | 2 +- deepspeech/exps/deepspeech2/config.py | 17 +- deepspeech/exps/deepspeech2/model.py | 35 +- deepspeech/exps/u2/bin/train.py | 2 +- .../frontend/featurizer/audio_featurizer.py | 32 +- .../frontend/featurizer/speech_featurizer.py | 14 +- .../frontend/featurizer/text_featurizer.py | 6 +- deepspeech/frontend/utility.py | 13 +- deepspeech/io/__init__.py | 4 +- deepspeech/io/dataset.py | 10 +- deepspeech/models/ds2/__init__.py | 17 + deepspeech/models/ds2/conv.py | 171 ++++ deepspeech/models/ds2/deepspeech2.py | 306 ++++++ deepspeech/models/ds2/rnn.py | 314 ++++++ deepspeech/models/ds2_online/__init__.py | 17 + deepspeech/models/ds2_online/conv.py | 33 + deepspeech/models/ds2_online/deepspeech2.py | 438 ++++++++ deepspeech/models/u2/__init__.py | 19 + deepspeech/models/u2/u2.py | 951 ++++++++++++++++++ deepspeech/models/u2/updater.py | 149 +++ deepspeech/models/u2_st.py | 733 ++++++++++++++ deepspeech/modules/ctc.py | 29 +- deepspeech/modules/loss.py | 8 +- deepspeech/training/cli.py | 112 ++- deepspeech/training/extensions/__init__.py | 41 + deepspeech/training/extensions/evaluator.py | 101 ++ deepspeech/training/extensions/extension.py | 52 + deepspeech/training/extensions/snapshot.py | 133 +++ deepspeech/training/extensions/visualizer.py | 39 + deepspeech/training/gradclip.py | 7 +- deepspeech/training/optimizer.py | 121 +++ deepspeech/training/reporter.py | 144 +++ deepspeech/training/scheduler.py | 66 +- deepspeech/training/timer.py | 50 + deepspeech/training/trainer.py | 178 ++-- deepspeech/training/triggers/__init__.py | 28 + .../training/triggers/interval_trigger.py | 38 + deepspeech/training/triggers/limit_trigger.py | 31 + deepspeech/training/triggers/time_trigger.py | 32 + deepspeech/training/updaters/__init__.py | 13 + .../training/updaters/standard_updater.py | 195 ++++ deepspeech/training/updaters/trainer.py | 184 ++++ deepspeech/training/updaters/updater.py | 84 ++ deepspeech/utils/checkpoint.py | 14 +- env.sh | 2 +- examples/aishell/s0/conf/deepspeech2.yaml | 2 +- examples/aishell/s0/local/data.sh | 2 +- examples/aishell/s1/conf/chunk_conformer.yaml | 2 +- examples/aishell/s1/conf/conformer.yaml | 2 +- examples/aishell/s1/local/data.sh | 2 +- examples/librispeech/s0/conf/deepspeech2.yaml | 2 +- examples/librispeech/s0/local/data.sh | 2 +- .../librispeech/s1/conf/chunk_confermer.yaml | 2 +- .../s1/conf/chunk_transformer.yaml | 2 +- examples/librispeech/s1/conf/conformer.yaml | 2 +- examples/librispeech/s1/conf/transformer.yaml | 8 +- examples/librispeech/s1/local/data.sh | 2 +- examples/librispeech/s1/local/train.sh | 22 +- examples/tiny/s0/conf/deepspeech2.yaml | 2 +- examples/tiny/s0/local/data.sh | 2 +- examples/tiny/s1/conf/chunk_confermer.yaml | 2 +- examples/tiny/s1/conf/chunk_transformer.yaml | 2 +- examples/tiny/s1/conf/conformer.yaml | 2 +- examples/tiny/s1/conf/transformer.yaml | 2 +- examples/tiny/s1/local/data.sh | 2 +- examples/tiny/s1/test.profile | Bin 0 -> 130998 bytes 
tests/benchmark/.gitignore | 2 + tests/benchmark/README.md | 11 + tests/benchmark/run_all.sh | 49 + tests/benchmark/run_benchmark.sh | 57 ++ tests/chains/README.md | 9 + tests/chains/ds2_params_lite_train_infer.txt | 51 + tests/chains/ds2_params_whole_train_infer.txt | 51 + tests/chains/lite_train_infer.sh | 5 + tests/chains/prepare.sh | 84 ++ tests/chains/speedyspeech_params_lite.txt | 51 + tests/chains/test.sh | 371 +++++++ tests/chains/whole_train_infer.sh | 5 + tests/deepspeech2_model_test.py | 2 +- tests/deepspeech2_online_model_test.py | 186 ++++ 87 files changed, 6047 insertions(+), 228 deletions(-) create mode 100644 deepspeech/exps/deepspeech2/bin/test_export.py create mode 100644 deepspeech/exps/deepspeech2/bin/test_hub.py create mode 100644 deepspeech/models/ds2/__init__.py create mode 100644 deepspeech/models/ds2/conv.py create mode 100644 deepspeech/models/ds2/deepspeech2.py create mode 100644 deepspeech/models/ds2/rnn.py create mode 100644 deepspeech/models/ds2_online/__init__.py create mode 100644 deepspeech/models/ds2_online/conv.py create mode 100644 deepspeech/models/ds2_online/deepspeech2.py create mode 100644 deepspeech/models/u2/__init__.py create mode 100644 deepspeech/models/u2/u2.py create mode 100644 deepspeech/models/u2/updater.py create mode 100644 deepspeech/models/u2_st.py create mode 100644 deepspeech/training/extensions/__init__.py create mode 100644 deepspeech/training/extensions/evaluator.py create mode 100644 deepspeech/training/extensions/extension.py create mode 100644 deepspeech/training/extensions/snapshot.py create mode 100644 deepspeech/training/extensions/visualizer.py create mode 100644 deepspeech/training/optimizer.py create mode 100644 deepspeech/training/reporter.py create mode 100644 deepspeech/training/timer.py create mode 100644 deepspeech/training/triggers/__init__.py create mode 100644 deepspeech/training/triggers/interval_trigger.py create mode 100644 deepspeech/training/triggers/limit_trigger.py create mode 100644 deepspeech/training/triggers/time_trigger.py create mode 100644 deepspeech/training/updaters/__init__.py create mode 100644 deepspeech/training/updaters/standard_updater.py create mode 100644 deepspeech/training/updaters/trainer.py create mode 100644 deepspeech/training/updaters/updater.py create mode 100644 examples/tiny/s1/test.profile create mode 100644 tests/benchmark/.gitignore create mode 100644 tests/benchmark/README.md create mode 100755 tests/benchmark/run_all.sh create mode 100755 tests/benchmark/run_benchmark.sh create mode 100644 tests/chains/README.md create mode 100644 tests/chains/ds2_params_lite_train_infer.txt create mode 100644 tests/chains/ds2_params_whole_train_infer.txt create mode 100644 tests/chains/lite_train_infer.sh create mode 100644 tests/chains/prepare.sh create mode 100644 tests/chains/speedyspeech_params_lite.txt create mode 100644 tests/chains/test.sh create mode 100644 tests/chains/whole_train_infer.sh create mode 100644 tests/deepspeech2_online_model_test.py diff --git a/deepspeech/exps/deepspeech2/bin/deploy/runtime.py b/deepspeech/exps/deepspeech2/bin/deploy/runtime.py index f3125e04..5677d4cf 100644 --- a/deepspeech/exps/deepspeech2/bin/deploy/runtime.py +++ b/deepspeech/exps/deepspeech2/bin/deploy/runtime.py @@ -21,7 +21,7 @@ from paddle.inference import create_predictor from deepspeech.exps.deepspeech2.config import get_cfg_defaults from deepspeech.io.dataset import ManifestDataset -from deepspeech.models.deepspeech2 import DeepSpeech2Model +from deepspeech.models.ds2 import 
DeepSpeech2Model from deepspeech.training.cli import default_argument_parser from deepspeech.utils.socket_server import AsrRequestHandler from deepspeech.utils.socket_server import AsrTCPServer diff --git a/deepspeech/exps/deepspeech2/bin/deploy/server.py b/deepspeech/exps/deepspeech2/bin/deploy/server.py index b2ff37e0..0e1211b0 100644 --- a/deepspeech/exps/deepspeech2/bin/deploy/server.py +++ b/deepspeech/exps/deepspeech2/bin/deploy/server.py @@ -19,7 +19,7 @@ import paddle from deepspeech.exps.deepspeech2.config import get_cfg_defaults from deepspeech.io.dataset import ManifestDataset -from deepspeech.models.deepspeech2 import DeepSpeech2Model +from deepspeech.models.ds2 import DeepSpeech2Model from deepspeech.training.cli import default_argument_parser from deepspeech.utils.socket_server import AsrRequestHandler from deepspeech.utils.socket_server import AsrTCPServer diff --git a/deepspeech/exps/deepspeech2/bin/export.py b/deepspeech/exps/deepspeech2/bin/export.py index a1607d58..ab5251d5 100644 --- a/deepspeech/exps/deepspeech2/bin/export.py +++ b/deepspeech/exps/deepspeech2/bin/export.py @@ -30,11 +30,17 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + # save jit model to + parser.add_argument( + "--export_path", type=str, help="path of the jit model to save") + parser.add_argument( + "--model_type", type=str, default='offline', help="offline/online") args = parser.parse_args() + print("model_type:{}".format(args.model_type)) print_arguments(args) # https://yaml.org/type/float.html - config = get_cfg_defaults() + config = get_cfg_defaults(args.model_type) if args.config: config.merge_from_file(args.config) if args.opts: diff --git a/deepspeech/exps/deepspeech2/bin/test.py b/deepspeech/exps/deepspeech2/bin/test.py index f4edf08a..7fbdab45 100644 --- a/deepspeech/exps/deepspeech2/bin/test.py +++ b/deepspeech/exps/deepspeech2/bin/test.py @@ -30,11 +30,17 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + parser.add_argument( + "--model_type", type=str, default='offline', help='offline/online') + # save asr result to + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") args = parser.parse_args() print_arguments(args, globals()) + print("model_type:{}".format(args.model_type)) # https://yaml.org/type/float.html - config = get_cfg_defaults() + config = get_cfg_defaults(args.model_type) if args.config: config.merge_from_file(args.config) if args.opts: diff --git a/deepspeech/exps/deepspeech2/bin/test_export.py b/deepspeech/exps/deepspeech2/bin/test_export.py new file mode 100644 index 00000000..be1a8479 --- /dev/null +++ b/deepspeech/exps/deepspeech2/bin/test_export.py @@ -0,0 +1,57 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Evaluation for DeepSpeech2 model.""" +from deepspeech.exps.deepspeech2.config import get_cfg_defaults +from deepspeech.exps.deepspeech2.model import DeepSpeech2ExportTester as ExportTester +from deepspeech.training.cli import default_argument_parser +from deepspeech.utils.utility import print_arguments + + +def main_sp(config, args): + exp = ExportTester(config, args) + exp.setup() + exp.run_test() + + +def main(config, args): + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + # save asr result to + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") + #load jit model from + parser.add_argument( + "--export_path", type=str, help="path of the jit model to save") + parser.add_argument( + "--model_type", type=str, default='offline', help='offline/online') + args = parser.parse_args() + print_arguments(args, globals()) + print("model_type:{}".format(args.model_type)) + + # https://yaml.org/type/float.html + config = get_cfg_defaults(args.model_type) + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + main(config, args) diff --git a/deepspeech/exps/deepspeech2/bin/test_hub.py b/deepspeech/exps/deepspeech2/bin/test_hub.py new file mode 100644 index 00000000..1cf24bb0 --- /dev/null +++ b/deepspeech/exps/deepspeech2/bin/test_hub.py @@ -0,0 +1,206 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Evaluation for DeepSpeech2 model.""" +import os +import sys +from pathlib import Path + +import paddle +import soundfile + +from deepspeech.exps.deepspeech2.config import get_cfg_defaults +from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer +from deepspeech.io.collator import SpeechCollator +from deepspeech.models.ds2 import DeepSpeech2Model +from deepspeech.models.ds2_online import DeepSpeech2ModelOnline +from deepspeech.training.cli import default_argument_parser +from deepspeech.utils import mp_tools +from deepspeech.utils.checkpoint import Checkpoint +from deepspeech.utils.log import Log +from deepspeech.utils.utility import print_arguments +from deepspeech.utils.utility import UpdateConfig + +logger = Log(__name__).getlog() + + +class DeepSpeech2Tester_hub(): + def __init__(self, config, args): + self.args = args + self.config = config + self.audio_file = args.audio_file + self.collate_fn_test = SpeechCollator.from_config(config) + self._text_featurizer = TextFeaturizer( + unit_type=config.collator.unit_type, vocab_filepath=None) + + def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): + result_transcripts = self.model.decode( + audio, + audio_len, + vocab_list, + decoding_method=cfg.decoding_method, + lang_model_path=cfg.lang_model_path, + beam_alpha=cfg.alpha, + beam_beta=cfg.beta, + beam_size=cfg.beam_size, + cutoff_prob=cfg.cutoff_prob, + cutoff_top_n=cfg.cutoff_top_n, + num_processes=cfg.num_proc_bsearch) + #replace the '' with ' ' + result_transcripts = [ + self._text_featurizer.detokenize(sentence) + for sentence in result_transcripts + ] + + return result_transcripts + + @mp_tools.rank_zero_only + @paddle.no_grad() + def test(self): + self.model.eval() + cfg = self.config + audio_file = self.audio_file + collate_fn_test = self.collate_fn_test + audio, _ = collate_fn_test.process_utterance( + audio_file=audio_file, transcript=" ") + audio_len = audio.shape[0] + audio = paddle.to_tensor(audio, dtype='float32') + audio_len = paddle.to_tensor(audio_len) + audio = paddle.unsqueeze(audio, axis=0) + vocab_list = collate_fn_test.vocab_list + result_transcripts = self.compute_result_transcripts( + audio, audio_len, vocab_list, cfg.decoding) + logger.info("result_transcripts: " + result_transcripts[0]) + + def run_test(self): + self.resume() + try: + self.test() + except KeyboardInterrupt: + exit(-1) + + def setup(self): + """Setup the experiment. + """ + paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu') + + self.setup_output_dir() + self.setup_checkpointer() + + self.setup_model() + + def setup_output_dir(self): + """Create a directory used for output. + """ + # output dir + if self.args.output: + output_dir = Path(self.args.output).expanduser() + output_dir.mkdir(parents=True, exist_ok=True) + else: + output_dir = Path( + self.args.checkpoint_path).expanduser().parent.parent + output_dir.mkdir(parents=True, exist_ok=True) + self.output_dir = output_dir + + def setup_model(self): + config = self.config.clone() + with UpdateConfig(config): + config.model.feat_size = self.collate_fn_test.feature_size + config.model.dict_size = self.collate_fn_test.vocab_size + + if self.args.model_type == 'offline': + model = DeepSpeech2Model.from_config(config.model) + elif self.args.model_type == 'online': + model = DeepSpeech2ModelOnline.from_config(config.model) + else: + raise Exception("wrong model type") + + self.model = model + + def setup_checkpointer(self): + """Create a directory used to save checkpoints into. 
+ + It is "checkpoints" inside the output directory. + """ + # checkpoint dir + checkpoint_dir = self.output_dir / "checkpoints" + checkpoint_dir.mkdir(exist_ok=True) + + self.checkpoint_dir = checkpoint_dir + + self.checkpoint = Checkpoint( + kbest_n=self.config.training.checkpoint.kbest_n, + latest_n=self.config.training.checkpoint.latest_n) + + def resume(self): + """Resume from the checkpoint at checkpoints in the output + directory or load a specified checkpoint. + """ + params_path = self.args.checkpoint_path + ".pdparams" + model_dict = paddle.load(params_path) + self.model.set_state_dict(model_dict) + + +def check(audio_file): + logger.info("checking the audio file format......") + try: + sig, sample_rate = soundfile.read(audio_file) + except Exception as e: + logger.error(str(e)) + logger.error( + "can not open the wav file, please check the audio file format") + sys.exit(-1) + logger.info("The sample rate is %d" % sample_rate) + assert (sample_rate == 16000) + logger.info("The audio file format is right") + + +def main_sp(config, args): + exp = DeepSpeech2Tester_hub(config, args) + exp.setup() + exp.run_test() + + +def main(config, args): + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + parser.add_argument( + "--model_type", type=str, default='offline', help='offline/online') + parser.add_argument("--audio_file", type=str, help='audio file path') + # save asr result to + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") + args = parser.parse_args() + print_arguments(args, globals()) + if not os.path.isfile(args.audio_file): + print("Please input the audio file path") + sys.exit(-1) + check(args.audio_file) + print("model_type:{}".format(args.model_type)) + + # https://yaml.org/type/float.html + config = get_cfg_defaults(args.model_type) + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + main(config, args) diff --git a/deepspeech/exps/deepspeech2/bin/train.py b/deepspeech/exps/deepspeech2/bin/train.py index 5e5c1e2a..02aefe3d 100644 --- a/deepspeech/exps/deepspeech2/bin/train.py +++ b/deepspeech/exps/deepspeech2/bin/train.py @@ -27,7 +27,7 @@ def main_sp(config, args): def main(config, args): - if args.device == "gpu" and args.nprocs > 1: + if args.nprocs > 0: dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) else: main_sp(config, args) @@ -35,11 +35,14 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() + parser.add_argument( + "--model_type", type=str, default='offline', help='offline/online') args = parser.parse_args() + print("model_type:{}".format(args.model_type)) print_arguments(args, globals()) # https://yaml.org/type/float.html - config = get_cfg_defaults() + config = get_cfg_defaults(args.model_type) if args.config: config.merge_from_file(args.config) if args.opts: diff --git a/deepspeech/exps/deepspeech2/bin/tune.py b/deepspeech/exps/deepspeech2/bin/tune.py index 02e329a1..c933e501 100644 --- a/deepspeech/exps/deepspeech2/bin/tune.py +++ b/deepspeech/exps/deepspeech2/bin/tune.py @@ -21,7 +21,7 @@ from paddle.io import DataLoader from deepspeech.exps.deepspeech2.config import get_cfg_defaults from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset -from deepspeech.models.deepspeech2 import DeepSpeech2Model +from 
deepspeech.models.ds2 import DeepSpeech2Model from deepspeech.training.cli import default_argument_parser from deepspeech.utils import error_rate from deepspeech.utils.utility import add_arguments diff --git a/deepspeech/exps/deepspeech2/config.py b/deepspeech/exps/deepspeech2/config.py index a8d452a9..633e38ff 100644 --- a/deepspeech/exps/deepspeech2/config.py +++ b/deepspeech/exps/deepspeech2/config.py @@ -13,7 +13,7 @@ # limitations under the License. from yacs.config import CfgNode as CN -from deepspeech.models.deepspeech2 import DeepSpeech2Model +from deepspeech.models.ds2 import DeepSpeech2Model _C = CN() _C.data = CN( @@ -32,7 +32,7 @@ _C.data = CN( window_ms=20.0, # ms n_fft=None, # fft points max_freq=None, # None for samplerate/2 - specgram_type='linear', # 'linear', 'mfcc', 'fbank' + spectrum_type='linear', # 'linear', 'mfcc', 'fbank' feat_dim=0, # 'mfcc', 'fbank' delat_delta=False, # 'mfcc', 'fbank' target_sample_rate=16000, # target sample rate @@ -46,16 +46,7 @@ _C.data = CN( shuffle_method="batch_shuffle", # 'batch_shuffle', 'instance_shuffle' )) -_C.model = CN( - dict( - num_conv_layers=2, #Number of stacking convolution layers. - num_rnn_layers=3, #Number of stacking RNN layers. - rnn_layer_size=1024, #RNN layer size (number of RNN cells). - use_gru=True, #Use gru if set True. Use simple rnn if set False. - share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported. - )) - -DeepSpeech2Model.params(_C.model) +_C.model = DeepSpeech2Model.params() _C.training = CN( dict( @@ -81,7 +72,7 @@ _C.decoding = CN( )) -def get_cfg_defaults(): +def get_cfg_defaults(model_type): """Get a yacs CfgNode object with default values for my_project.""" # Return a clone so that the defaults will not be altered # This is for the "local variable" use pattern diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 8e8a1824..05add5bc 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -25,14 +25,15 @@ from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset from deepspeech.io.sampler import SortagradBatchSampler from deepspeech.io.sampler import SortagradDistributedBatchSampler -from deepspeech.models.deepspeech2 import DeepSpeech2InferModel -from deepspeech.models.deepspeech2 import DeepSpeech2Model +from deepspeech.models.ds2 import DeepSpeech2InferModel +from deepspeech.models.ds2 import DeepSpeech2Model from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.training.trainer import Trainer from deepspeech.utils import error_rate from deepspeech.utils import layer_tools from deepspeech.utils import mp_tools from deepspeech.utils.log import Log +from deepspeech.utils.utility import UpdateConfig logger = Log(__name__).getlog() @@ -98,15 +99,27 @@ class DeepSpeech2Trainer(Trainer): return total_loss, num_seen_utts def setup_model(self): - config = self.config - model = DeepSpeech2Model( - feat_size=self.train_loader.dataset.feature_size, - dict_size=self.train_loader.dataset.vocab_size, - num_conv_layers=config.model.num_conv_layers, - num_rnn_layers=config.model.num_rnn_layers, - rnn_size=config.model.rnn_layer_size, - use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights) + #config = self.config + #model = DeepSpeech2Model( + # feat_size=self.train_loader.dataset.feature_size, + # 
dict_size=self.train_loader.dataset.vocab_size, + # num_conv_layers=config.model.num_conv_layers, + # num_rnn_layers=config.model.num_rnn_layers, + # rnn_size=config.model.rnn_layer_size, + # use_gru=config.model.use_gru, + # share_rnn_weights=config.model.share_rnn_weights) + + config = self.config.clone() + with UpdateConfig(config): + config.model.feat_size = self.train_loader.dataset.feature_size + config.model.dict_size = self.train_loader.dataset.vocab_size + + if self.args.model_type == 'offline': + model = DeepSpeech2Model.from_config(config.model) + elif self.args.model_type == 'online': + model = DeepSpeech2ModelOnline.from_config(config.model) + else: + raise Exception("wrong model type") if self.parallel: model = paddle.DataParallel(model) diff --git a/deepspeech/exps/u2/bin/train.py b/deepspeech/exps/u2/bin/train.py index 9dd0041d..f07fc2ea 100644 --- a/deepspeech/exps/u2/bin/train.py +++ b/deepspeech/exps/u2/bin/train.py @@ -30,7 +30,7 @@ def main_sp(config, args): def main(config, args): - if args.device == "gpu" and args.nprocs > 1: + if args.nprocs > 0: dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) else: main_sp(config, args) diff --git a/deepspeech/frontend/featurizer/audio_featurizer.py b/deepspeech/frontend/featurizer/audio_featurizer.py index 11c1fa2d..2f3163fa 100644 --- a/deepspeech/frontend/featurizer/audio_featurizer.py +++ b/deepspeech/frontend/featurizer/audio_featurizer.py @@ -24,15 +24,15 @@ class AudioFeaturizer(object): Currently, it supports feature types of linear spectrogram and mfcc. - :param specgram_type: Specgram feature type. Options: 'linear'. - :type specgram_type: str + :param spectrum_type: Specgram feature type. Options: 'linear'. + :type spectrum_type: str :param stride_ms: Striding size (in milliseconds) for generating frames. :type stride_ms: float :param window_ms: Window size (in milliseconds) for generating frames. :type window_ms: float - :param max_freq: When specgram_type is 'linear', only FFT bins + :param max_freq: When spectrum_type is 'linear', only FFT bins corresponding to frequencies between [0, max_freq] are - returned; when specgram_type is 'mfcc', max_feq is the + returned; when spectrum_type is 'mfcc', max_feq is the highest band edge of mel filters. 
:types max_freq: None|float :param target_sample_rate: Audio are resampled (if upsampling or @@ -47,7 +47,7 @@ class AudioFeaturizer(object): """ def __init__(self, - specgram_type: str='linear', + spectrum_type: str='linear', feat_dim: int=None, delta_delta: bool=False, stride_ms=10.0, @@ -58,7 +58,7 @@ class AudioFeaturizer(object): use_dB_normalization=True, target_dB=-20, dither=1.0): - self._specgram_type = specgram_type + self._spectrum_type = spectrum_type # mfcc and fbank using `feat_dim` self._feat_dim = feat_dim # mfcc and fbank using `delta-delta` @@ -113,27 +113,27 @@ class AudioFeaturizer(object): def feature_size(self): """audio feature size""" feat_dim = 0 - if self._specgram_type == 'linear': + if self._spectrum_type == 'linear': fft_point = self._window_ms if self._fft_point is None else self._fft_point feat_dim = int(fft_point * (self._target_sample_rate / 1000) / 2 + 1) - elif self._specgram_type == 'mfcc': + elif self._spectrum_type == 'mfcc': # mfcc, delta, delta-delta feat_dim = int(self._feat_dim * 3) if self._delta_delta else int(self._feat_dim) - elif self._specgram_type == 'fbank': + elif self._spectrum_type == 'fbank': # fbank, delta, delta-delta feat_dim = int(self._feat_dim * 3) if self._delta_delta else int(self._feat_dim) else: - raise ValueError("Unknown specgram_type %s. " - "Supported values: linear." % self._specgram_type) + raise ValueError("Unknown spectrum_type %s. " + "Supported values: linear." % self._spectrum_type) return feat_dim def _compute_specgram(self, audio_segment): """Extract various audio features.""" sample_rate = audio_segment.sample_rate - if self._specgram_type == 'linear': + if self._spectrum_type == 'linear': samples = audio_segment.samples return self._compute_linear_specgram( samples, @@ -141,7 +141,7 @@ class AudioFeaturizer(object): stride_ms=self._stride_ms, window_ms=self._window_ms, max_freq=self._max_freq) - elif self._specgram_type == 'mfcc': + elif self._spectrum_type == 'mfcc': samples = audio_segment.to('int16') return self._compute_mfcc( samples, @@ -152,7 +152,7 @@ class AudioFeaturizer(object): max_freq=self._max_freq, dither=self._dither, delta_delta=self._delta_delta) - elif self._specgram_type == 'fbank': + elif self._spectrum_type == 'fbank': samples = audio_segment.to('int16') return self._compute_fbank( samples, @@ -164,8 +164,8 @@ class AudioFeaturizer(object): dither=self._dither, delta_delta=self._delta_delta) else: - raise ValueError("Unknown specgram_type %s. " - "Supported values: linear." % self._specgram_type) + raise ValueError("Unknown spectrum_type %s. " + "Supported values: linear." % self._spectrum_type) def _compute_linear_specgram(self, samples, diff --git a/deepspeech/frontend/featurizer/speech_featurizer.py b/deepspeech/frontend/featurizer/speech_featurizer.py index e6761cb5..50856e16 100644 --- a/deepspeech/frontend/featurizer/speech_featurizer.py +++ b/deepspeech/frontend/featurizer/speech_featurizer.py @@ -27,16 +27,16 @@ class SpeechFeaturizer(object): :param vocab_filepath: Filepath to load vocabulary for token indices conversion. - :type specgram_type: str - :param specgram_type: Specgram feature type. Options: 'linear', 'mfcc'. - :type specgram_type: str + :type spectrum_type: str + :param spectrum_type: Specgram feature type. Options: 'linear', 'mfcc'. + :type spectrum_type: str :param stride_ms: Striding size (in milliseconds) for generating frames. :type stride_ms: float :param window_ms: Window size (in milliseconds) for generating frames. 
:type window_ms: float - :param max_freq: When specgram_type is 'linear', only FFT bins + :param max_freq: When spectrum_type is 'linear', only FFT bins corresponding to frequencies between [0, max_freq] are - returned; when specgram_type is 'mfcc', max_freq is the + returned; when spectrum_type is 'mfcc', max_freq is the highest band edge of mel filters. :types max_freq: None|float :param target_sample_rate: Speech are resampled (if upsampling or @@ -54,7 +54,7 @@ class SpeechFeaturizer(object): unit_type, vocab_filepath, spm_model_prefix=None, - specgram_type='linear', + spectrum_type='linear', feat_dim=None, delta_delta=False, stride_ms=10.0, @@ -66,7 +66,7 @@ class SpeechFeaturizer(object): target_dB=-20, dither=1.0): self._audio_featurizer = AudioFeaturizer( - specgram_type=specgram_type, + spectrum_type=spectrum_type, feat_dim=feat_dim, delta_delta=delta_delta, stride_ms=stride_ms, diff --git a/deepspeech/frontend/featurizer/text_featurizer.py b/deepspeech/frontend/featurizer/text_featurizer.py index 1ba6ac7f..7e16480f 100644 --- a/deepspeech/frontend/featurizer/text_featurizer.py +++ b/deepspeech/frontend/featurizer/text_featurizer.py @@ -45,7 +45,7 @@ class TextFeaturizer(object): self.sp = spm.SentencePieceProcessor() self.sp.Load(spm_model) - def tokenize(self, text): + def tokenize(self, text, replace_space=True): if self.unit_type == 'char': tokens = self.char_tokenize(text) elif self.unit_type == 'word': @@ -68,7 +68,7 @@ class TextFeaturizer(object): Args: text (str): Text to process. - + Returns: List[int]: List of token indices. """ @@ -81,7 +81,7 @@ class TextFeaturizer(object): def defeaturize(self, idxs): """Convert a list of token indices to text string, - ignore index after eos_id. + ignore index after eos_id. Args: idxs (List[int]): List of token indices. diff --git a/deepspeech/frontend/utility.py b/deepspeech/frontend/utility.py index b2dd9601..8fab236c 100644 --- a/deepspeech/frontend/utility.py +++ b/deepspeech/frontend/utility.py @@ -32,6 +32,7 @@ IGNORE_ID = -1 SOS = "" EOS = SOS UNK = "" +SPACE = " " BLANK = "" @@ -101,7 +102,7 @@ def rms_to_dbfs(rms: float): """Root Mean Square to dBFS. https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/ Audio is mix of sine wave, so 1 amp sine wave's Full scale is 0.7071, equal to -3.0103dB. - + dB = dBFS + 3.0103 dBFS = db - 3.0103 e.g. 0 dB = -3.0103 dBFS @@ -116,26 +117,26 @@ def rms_to_dbfs(rms: float): def max_dbfs(sample_data: np.ndarray): - """Peak dBFS based on the maximum energy sample. + """Peak dBFS based on the maximum energy sample. Args: sample_data ([np.ndarray]): float array, [-1, 1]. Returns: - float: dBFS + float: dBFS """ # Peak dBFS based on the maximum energy sample. Will prevent overdrive if used for normalization. return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data)))) def mean_dbfs(sample_data): - """Peak dBFS based on the RMS energy. + """Peak dBFS based on the RMS energy. Args: sample_data ([np.ndarray]): float array, [-1, 1]. Returns: - float: dBFS + float: dBFS """ return rms_to_dbfs( math.sqrt(np.mean(np.square(sample_data, dtype=np.float64)))) @@ -155,7 +156,7 @@ def gain_db_to_ratio(gain_db: float): def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103): """Nomalize audio to dBFS. - + Args: sample_data (np.ndarray): input wave samples, [-1, 1]. dbfs (float, optional): target dBFS. Defaults to -3.0103. 
diff --git a/deepspeech/io/__init__.py b/deepspeech/io/__init__.py index e180f18e..884e76e5 100644 --- a/deepspeech/io/__init__.py +++ b/deepspeech/io/__init__.py @@ -35,7 +35,7 @@ def create_dataloader(manifest_path, stride_ms=10.0, window_ms=20.0, max_freq=None, - specgram_type='linear', + spectrum_type='linear', feat_dim=None, delta_delta=False, use_dB_normalization=True, @@ -64,7 +64,7 @@ def create_dataloader(manifest_path, stride_ms=stride_ms, window_ms=window_ms, max_freq=max_freq, - specgram_type=specgram_type, + spectrum_type=spectrum_type, feat_dim=feat_dim, delta_delta=delta_delta, use_dB_normalization=use_dB_normalization, diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index fba5f7c6..fe53d8e3 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -63,7 +63,7 @@ class ManifestDataset(Dataset): n_fft=None, # fft points max_freq=None, # None for samplerate/2 raw_wav=True, # use raw_wav or kaldi feature - specgram_type='linear', # 'linear', 'mfcc', 'fbank' + spectrum_type='linear', # 'linear', 'mfcc', 'fbank' feat_dim=0, # 'mfcc', 'fbank' delta_delta=False, # 'mfcc', 'fbank' dither=1.0, # feature dither @@ -124,7 +124,7 @@ class ManifestDataset(Dataset): n_fft=config.data.n_fft, max_freq=config.data.max_freq, target_sample_rate=config.data.target_sample_rate, - specgram_type=config.data.specgram_type, + spectrum_type=config.data.spectrum_type, feat_dim=config.data.feat_dim, delta_delta=config.data.delta_delta, dither=config.data.dither, @@ -152,7 +152,7 @@ class ManifestDataset(Dataset): n_fft=None, max_freq=None, target_sample_rate=16000, - specgram_type='linear', + spectrum_type='linear', feat_dim=None, delta_delta=False, dither=1.0, @@ -180,7 +180,7 @@ class ManifestDataset(Dataset): n_fft (int, optional): fft points for rfft. Defaults to None. max_freq (int, optional): max cut freq. Defaults to None. target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000. - specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'. + spectrum_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'. feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None. delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False. use_dB_normalization (bool, optional): do dB normalization. Defaults to True. @@ -200,7 +200,7 @@ class ManifestDataset(Dataset): unit_type=unit_type, vocab_filepath=vocab_filepath, spm_model_prefix=spm_model_prefix, - specgram_type=specgram_type, + spectrum_type=spectrum_type, feat_dim=feat_dim, delta_delta=delta_delta, stride_ms=stride_ms, diff --git a/deepspeech/models/ds2/__init__.py b/deepspeech/models/ds2/__init__.py new file mode 100644 index 00000000..39bea5bf --- /dev/null +++ b/deepspeech/models/ds2/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from .deepspeech2 import DeepSpeech2InferModel +from .deepspeech2 import DeepSpeech2Model + +__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel'] diff --git a/deepspeech/models/ds2/conv.py b/deepspeech/models/ds2/conv.py new file mode 100644 index 00000000..111f5d3b --- /dev/null +++ b/deepspeech/models/ds2/conv.py @@ -0,0 +1,171 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from paddle import nn +from paddle.nn import functional as F + +from deepspeech.modules.activation import brelu +from deepspeech.modules.mask import make_non_pad_mask +from deepspeech.utils.log import Log + +logger = Log(__name__).getlog() + +__all__ = ['ConvStack', "conv_output_size"] + + +def conv_output_size(I, F, P, S): + # https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters + # Output size after Conv: + # By noting I the length of the input volume size, + # F the length of the filter, + # P the amount of zero padding, + # S the stride, + # then the output size O of the feature map along that dimension is given by: + # O = (I - F + Pstart + Pend) // S + 1 + # When Pstart == Pend == P, we can replace Pstart + Pend by 2P. + # When Pstart == Pend == 0 + # O = (I - F - S) // S + # https://iq.opengenus.org/output-size-of-convolution/ + # Output height = (Input height + padding height top + padding height bottom - kernel height) / (stride height) + 1 + # Output width = (Output width + padding width right + padding width left - kernel width) / (stride width) + 1 + return (I - F + 2 * P - S) // S + + +# receptive field calculator +# https://fomoro.com/research/article/receptive-field-calculator +# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters +# https://distill.pub/2019/computing-receptive-fields/ +# Rl-1 = Sl * Rl + (Kl - Sl) + + +class ConvBn(nn.Layer): + """Convolution layer with batch normalization. + + :param kernel_size: The x dimension of a filter kernel. Or input a tuple for + two image dimension. + :type kernel_size: int|tuple|list + :param num_channels_in: Number of input channels. + :type num_channels_in: int + :param num_channels_out: Number of output channels. + :type num_channels_out: int + :param stride: The x dimension of the stride. Or input a tuple for two + image dimension. + :type stride: int|tuple|list + :param padding: The x dimension of the padding. Or input a tuple for two + image dimension. + :type padding: int|tuple|list + :param act: Activation type, relu|brelu + :type act: string + :return: Batch norm layer after convolution layer. 
+ :rtype: Variable + + """ + + def __init__(self, num_channels_in, num_channels_out, kernel_size, stride, + padding, act): + + super().__init__() + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + + self.conv = nn.Conv2D( + num_channels_in, + num_channels_out, + kernel_size=kernel_size, + stride=stride, + padding=padding, + weight_attr=None, + bias_attr=False, + data_format='NCHW') + + self.bn = nn.BatchNorm2D( + num_channels_out, + weight_attr=None, + bias_attr=None, + data_format='NCHW') + self.act = F.relu if act == 'relu' else brelu + + def forward(self, x, x_len): + """ + x(Tensor): audio, shape [B, C, D, T] + """ + x = self.conv(x) + x = self.bn(x) + x = self.act(x) + + x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1] + ) // self.stride[1] + 1 + + # reset padding part to 0 + masks = make_non_pad_mask(x_len) #[B, T] + masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T] + # TODO(Hui Zhang): not support bool multiply + masks = masks.type_as(x) + x = x.multiply(masks) + + return x, x_len + + +class ConvStack(nn.Layer): + """Convolution group with stacked convolution layers. + + :param feat_size: audio feature dim. + :type feat_size: int + :param num_stacks: Number of stacked convolution layers. + :type num_stacks: int + """ + + def __init__(self, feat_size, num_stacks): + super().__init__() + self.feat_size = feat_size # D + self.num_stacks = num_stacks + + self.conv_in = ConvBn( + num_channels_in=1, + num_channels_out=32, + kernel_size=(41, 11), #[D, T] + stride=(2, 3), + padding=(20, 5), + act='brelu') + + out_channel = 32 + convs = [ + ConvBn( + num_channels_in=32, + num_channels_out=out_channel, + kernel_size=(21, 11), + stride=(2, 1), + padding=(10, 5), + act='brelu') for i in range(num_stacks - 1) + ] + self.conv_stack = nn.LayerList(convs) + + # conv output feat_dim + output_height = (feat_size - 1) // 2 + 1 + for i in range(self.num_stacks - 1): + output_height = (output_height - 1) // 2 + 1 + self.output_height = out_channel * output_height + + def forward(self, x, x_len): + """ + x: shape [B, C, D, T] + x_len : shape [B] + """ + x, x_len = self.conv_in(x, x_len) + for i, conv in enumerate(self.conv_stack): + x, x_len = conv(x, x_len) + return x, x_len diff --git a/deepspeech/models/ds2/deepspeech2.py b/deepspeech/models/ds2/deepspeech2.py new file mode 100644 index 00000000..96730f80 --- /dev/null +++ b/deepspeech/models/ds2/deepspeech2.py @@ -0,0 +1,306 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Deepspeech2 ASR Model""" +from typing import Optional + +import paddle +from paddle import nn +from yacs.config import CfgNode + +from deepspeech.models.ds2.conv import ConvStack +from deepspeech.models.ds2.rnn import RNNStack +from deepspeech.modules.ctc import CTCDecoder +from deepspeech.utils import layer_tools +from deepspeech.utils.checkpoint import Checkpoint +from deepspeech.utils.log import Log + +logger = Log(__name__).getlog() + +__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel'] + + +class CRNNEncoder(nn.Layer): + def __init__(self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + use_gru=False, + share_rnn_weights=True): + super().__init__() + self.rnn_size = rnn_size + self.feat_size = feat_size # 161 for linear + self.dict_size = dict_size + + self.conv = ConvStack(feat_size, num_conv_layers) + + i_size = self.conv.output_height # H after conv stack + self.rnn = RNNStack( + i_size=i_size, + h_size=rnn_size, + num_stacks=num_rnn_layers, + use_gru=use_gru, + share_rnn_weights=share_rnn_weights) + + @property + def output_size(self): + return self.rnn_size * 2 + + def forward(self, audio, audio_len): + """Compute Encoder outputs + + Args: + audio (Tensor): [B, Tmax, D] + text (Tensor): [B, Umax] + audio_len (Tensor): [B] + text_len (Tensor): [B] + Returns: + x (Tensor): encoder outputs, [B, T, D] + x_lens (Tensor): encoder length, [B] + """ + # [B, T, D] -> [B, D, T] + audio = audio.transpose([0, 2, 1]) + # [B, D, T] -> [B, C=1, D, T] + x = audio.unsqueeze(1) + x_lens = audio_len + + # convolution group + x, x_lens = self.conv(x, x_lens) + + # convert data from convolution feature map to sequence of vectors + #B, C, D, T = paddle.shape(x) # not work under jit + x = x.transpose([0, 3, 1, 2]) #[B, T, C, D] + #x = x.reshape([B, T, C * D]) #[B, T, C*D] # not work under jit + x = x.reshape([0, 0, -1]) #[B, T, C*D] + + # remove padding part + x, x_lens = self.rnn(x, x_lens) #[B, T, D] + return x, x_lens + + +class DeepSpeech2Model(nn.Layer): + """The DeepSpeech2 network structure. + + :param audio_data: Audio spectrogram data layer. + :type audio_data: Variable + :param text_data: Transcription text data layer. + :type text_data: Variable + :param audio_len: Valid sequence length data layer. + :type audio_len: Variable + :param masks: Masks data layer to reset padding. + :type masks: Variable + :param dict_size: Dictionary size for tokenized transcription. + :type dict_size: int + :param num_conv_layers: Number of stacking convolution layers. + :type num_conv_layers: int + :param num_rnn_layers: Number of stacking RNN layers. + :type num_rnn_layers: int + :param rnn_size: RNN layer size (dimension of RNN cells). + :type rnn_size: int + :param use_gru: Use gru if set True. Use simple rnn if set False. + :type use_gru: bool + :param share_rnn_weights: Whether to share input-hidden weights between + forward and backward direction RNNs. + It is only available when use_gru=False. + :type share_weights: bool + :return: A tuple of an output unnormalized log probability layer ( + before softmax) and a ctc cost layer. + :rtype: tuple of LayerOutput + """ + + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + default = CfgNode( + dict( + num_conv_layers=2, #Number of stacking convolution layers. + num_rnn_layers=3, #Number of stacking RNN layers. + rnn_layer_size=1024, #RNN layer size (number of RNN cells). + use_gru=True, #Use gru if set True. Use simple rnn if set False. 
+ share_rnn_weights=True, #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported. + ctc_grad_norm_type='instance', )) + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + use_gru=False, + share_rnn_weights=True, + blank_id=0, + ctc_grad_norm_type='instance'): + super().__init__() + self.encoder = CRNNEncoder( + feat_size=feat_size, + dict_size=dict_size, + num_conv_layers=num_conv_layers, + num_rnn_layers=num_rnn_layers, + rnn_size=rnn_size, + use_gru=use_gru, + share_rnn_weights=share_rnn_weights) + assert (self.encoder.output_size == rnn_size * 2) + + self.decoder = CTCDecoder( + odim=dict_size, # is in vocab + enc_n_units=self.encoder.output_size, + blank_id=blank_id, + dropout_rate=0.0, + reduction=True, # sum + batch_average=True, # sum / batch_size + grad_norm_type=ctc_grad_norm_type) + + def forward(self, audio, audio_len, text, text_len): + """Compute Model loss + + Args: + audio (Tenosr): [B, T, D] + audio_len (Tensor): [B] + text (Tensor): [B, U] + text_len (Tensor): [B] + + Returns: + loss (Tenosr): [1] + """ + eouts, eouts_len = self.encoder(audio, audio_len) + loss = self.decoder(eouts, eouts_len, text, text_len) + return loss + + @paddle.no_grad() + def decode(self, audio, audio_len, vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, + cutoff_top_n, num_processes): + # init once + # decoders only accept string encoded in utf-8 + self.decoder.init_decode( + beam_alpha=beam_alpha, + beam_beta=beam_beta, + lang_model_path=lang_model_path, + vocab_list=vocab_list, + decoding_method=decoding_method) + + eouts, eouts_len = self.encoder(audio, audio_len) + probs = self.decoder.softmax(eouts) + return self.decoder.decode_probs( + probs.numpy(), eouts_len, vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, + cutoff_top_n, num_processes) + + @classmethod + def from_pretrained(cls, dataloader, config, checkpoint_path): + """Build a DeepSpeech2Model model from a pretrained model. + Parameters + ---------- + dataloader: paddle.io.DataLoader + + config: yacs.config.CfgNode + model configs + + checkpoint_path: Path or str + the path of pretrained model checkpoint, without extension name + + Returns + ------- + DeepSpeech2Model + The model built from pretrained result. + """ + model = cls( + feat_size=dataloader.collate_fn.feature_size, + dict_size=dataloader.collate_fn.vocab_size, + num_conv_layers=config.model.num_conv_layers, + num_rnn_layers=config.model.num_rnn_layers, + rnn_size=config.model.rnn_layer_size, + use_gru=config.model.use_gru, + share_rnn_weights=config.model.share_rnn_weights, + blank_id=config.model.blank_id, + ctc_grad_norm_type=config.ctc_grad_norm_type, ) + infos = Checkpoint().load_parameters( + model, checkpoint_path=checkpoint_path) + logger.info(f"checkpoint info: {infos}") + layer_tools.summary(model) + return model + + @classmethod + def from_config(cls, config): + """Build a DeepSpeec2Model from config + Parameters + + config: yacs.config.CfgNode + config.model + Returns + ------- + DeepSpeech2Model + The model built from config. 
+ """ + model = cls( + feat_size=config.feat_size, + dict_size=config.dict_size, + num_conv_layers=config.num_conv_layers, + num_rnn_layers=config.num_rnn_layers, + rnn_size=config.rnn_layer_size, + use_gru=config.use_gru, + share_rnn_weights=config.share_rnn_weights, + blank_id=config.blank_id, + ctc_grad_norm_type=config.ctc_grad_norm_type, ) + return model + + +class DeepSpeech2InferModel(DeepSpeech2Model): + def __init__(self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + use_gru=False, + share_rnn_weights=True, + blank_id=0): + super().__init__( + feat_size=feat_size, + dict_size=dict_size, + num_conv_layers=num_conv_layers, + num_rnn_layers=num_rnn_layers, + rnn_size=rnn_size, + use_gru=use_gru, + share_rnn_weights=share_rnn_weights, + blank_id=blank_id) + + def forward(self, audio, audio_len): + """export model function + + Args: + audio (Tensor): [B, T, D] + audio_len (Tensor): [B] + + Returns: + probs: probs after softmax + """ + eouts, eouts_len = self.encoder(audio, audio_len) + probs = self.decoder.softmax(eouts) + return probs, eouts_len + + def export(self): + static_model = paddle.jit.to_static( + self, + input_spec=[ + paddle.static.InputSpec( + shape=[None, None, self.encoder.feat_size], + dtype='float32'), # audio, [B,T,D] + paddle.static.InputSpec(shape=[None], + dtype='int64'), # audio_length, [B] + ]) + return static_model diff --git a/deepspeech/models/ds2/rnn.py b/deepspeech/models/ds2/rnn.py new file mode 100644 index 00000000..29bd2883 --- /dev/null +++ b/deepspeech/models/ds2/rnn.py @@ -0,0 +1,314 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import paddle +from paddle import nn +from paddle.nn import functional as F +from paddle.nn import initializer as I + +from deepspeech.modules.activation import brelu +from deepspeech.modules.mask import make_non_pad_mask +from deepspeech.utils.log import Log + +logger = Log(__name__).getlog() + +__all__ = ['RNNStack'] + + +class RNNCell(nn.RNNCellBase): + r""" + Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it + computes the outputs and updates states. + The formula used is as follows: + .. math:: + h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh}) + y_{t} & = h_{t} + + where :math:`act` is for :attr:`activation`. 
+ """ + + def __init__(self, + hidden_size: int, + activation="tanh", + weight_ih_attr=None, + weight_hh_attr=None, + bias_ih_attr=None, + bias_hh_attr=None, + name=None): + super().__init__() + std = 1.0 / math.sqrt(hidden_size) + self.weight_hh = self.create_parameter( + (hidden_size, hidden_size), + weight_hh_attr, + default_initializer=I.Uniform(-std, std)) + self.bias_ih = None + self.bias_hh = self.create_parameter( + (hidden_size, ), + bias_hh_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std)) + + self.hidden_size = hidden_size + if activation not in ["tanh", "relu", "brelu"]: + raise ValueError( + "activation for SimpleRNNCell should be tanh or relu, " + "but get {}".format(activation)) + self.activation = activation + self._activation_fn = paddle.tanh \ + if activation == "tanh" \ + else F.relu + if activation == 'brelu': + self._activation_fn = brelu + + def forward(self, inputs, states=None): + if states is None: + states = self.get_initial_states(inputs, self.state_shape) + pre_h = states + i2h = inputs + if self.bias_ih is not None: + i2h += self.bias_ih + h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True) + if self.bias_hh is not None: + h2h += self.bias_hh + h = self._activation_fn(i2h + h2h) + return h, h + + @property + def state_shape(self): + return (self.hidden_size, ) + + +class GRUCell(nn.RNNCellBase): + r""" + Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, + it computes the outputs and updates states. + The formula for GRU used is as follows: + .. math:: + r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr}) + z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz}) + \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc})) + h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t} + y_{t} & = h_{t} + + where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise + multiplication operator. 
+ """ + + def __init__(self, + input_size: int, + hidden_size: int, + weight_ih_attr=None, + weight_hh_attr=None, + bias_ih_attr=None, + bias_hh_attr=None, + name=None): + super().__init__() + std = 1.0 / math.sqrt(hidden_size) + self.weight_hh = self.create_parameter( + (3 * hidden_size, hidden_size), + weight_hh_attr, + default_initializer=I.Uniform(-std, std)) + self.bias_ih = None + self.bias_hh = self.create_parameter( + (3 * hidden_size, ), + bias_hh_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std)) + + self.hidden_size = hidden_size + self.input_size = input_size + self._gate_activation = F.sigmoid + self._activation = paddle.tanh + + def forward(self, inputs, states=None): + if states is None: + states = self.get_initial_states(inputs, self.state_shape) + + pre_hidden = states + x_gates = inputs + if self.bias_ih is not None: + x_gates = x_gates + self.bias_ih + h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True) + if self.bias_hh is not None: + h_gates = h_gates + self.bias_hh + + x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1) + h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1) + + r = self._gate_activation(x_r + h_r) + z = self._gate_activation(x_z + h_z) + c = self._activation(x_c + r * h_c) # apply reset gate after mm + h = (pre_hidden - c) * z + c + # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru + + return h, h + + @property + def state_shape(self): + r""" + The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch + size would be automatically inserted into shape). The shape corresponds + to the shape of :math:`h_{t-1}`. + """ + return (self.hidden_size, ) + + +class BiRNNWithBN(nn.Layer): + """Bidirectonal simple rnn layer with sequence-wise batch normalization. + The batch normalization is only performed on input-state weights. + + :param size: Dimension of RNN cells. + :type size: int + :param share_weights: Whether to share input-hidden weights between + forward and backward directional RNNs. + :type share_weights: bool + :return: Bidirectional simple rnn layer. + :rtype: Variable + """ + + def __init__(self, i_size: int, h_size: int, share_weights: bool): + super().__init__() + self.share_weights = share_weights + if self.share_weights: + #input-hidden weights shared between bi-directional rnn. 
+            self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
+            # batch norm is only performed on input-state projection
+            self.fw_bn = nn.BatchNorm1D(
+                h_size, bias_attr=None, data_format='NLC')
+            self.bw_fc = self.fw_fc
+            self.bw_bn = self.fw_bn
+        else:
+            self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
+            self.fw_bn = nn.BatchNorm1D(
+                h_size, bias_attr=None, data_format='NLC')
+            self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)
+            self.bw_bn = nn.BatchNorm1D(
+                h_size, bias_attr=None, data_format='NLC')
+
+        self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')
+        self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')
+        self.fw_rnn = nn.RNN(
+            self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
+        self.bw_rnn = nn.RNN(
+            self.bw_cell, is_reverse=True, time_major=False) #[B, T, D]
+
+    def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
+        # x, shape [B, T, D]
+        fw_x = self.fw_bn(self.fw_fc(x))
+        bw_x = self.bw_bn(self.bw_fc(x))
+        fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
+        bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
+        x = paddle.concat([fw_x, bw_x], axis=-1)
+        return x, x_len
+
+
+class BiGRUWithBN(nn.Layer):
+    """Bidirectional gru layer with sequence-wise batch normalization.
+    The batch normalization is only performed on input-state weights.
+
+    :param name: Name of the layer.
+    :type name: string
+    :param input: Input layer.
+    :type input: Variable
+    :param size: Dimension of GRU cells.
+    :type size: int
+    :param act: Activation type.
+    :type act: string
+    :return: Bidirectional GRU layer.
+    :rtype: Variable
+    """
+
+    def __init__(self, i_size: int, h_size: int):
+        super().__init__()
+        hidden_size = h_size * 3
+
+        self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
+        self.fw_bn = nn.BatchNorm1D(
+            hidden_size, bias_attr=None, data_format='NLC')
+        self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
+        self.bw_bn = nn.BatchNorm1D(
+            hidden_size, bias_attr=None, data_format='NLC')
+
+        self.fw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
+        self.bw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
+        self.fw_rnn = nn.RNN(
+            self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
+        self.bw_rnn = nn.RNN(
+            self.bw_cell, is_reverse=True, time_major=False) #[B, T, D]
+
+    def forward(self, x, x_len):
+        # x, shape [B, T, D]
+        fw_x = self.fw_bn(self.fw_fc(x))
+        bw_x = self.bw_bn(self.bw_fc(x))
+        fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
+        bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
+        x = paddle.concat([fw_x, bw_x], axis=-1)
+        return x, x_len
+
+
+class RNNStack(nn.Layer):
+    """RNN group with stacked bidirectional simple RNN or GRU layers.
+
+    :param input: Input layer.
+    :type input: Variable
+    :param size: Dimension of RNN cells in each layer.
+    :type size: int
+    :param num_stacks: Number of stacked rnn layers.
+    :type num_stacks: int
+    :param use_gru: Use gru if set True. Use simple rnn if set False.
+    :type use_gru: bool
+    :param share_rnn_weights: Whether to share input-hidden weights between
+                              forward and backward directional RNNs.
+                              It is only available when use_gru=False.
+    :type share_weights: bool
+    :return: Output layer of the RNN group.
+ :rtype: Variable + """ + + def __init__(self, + i_size: int, + h_size: int, + num_stacks: int, + use_gru: bool, + share_rnn_weights: bool): + super().__init__() + rnn_stacks = [] + for i in range(num_stacks): + if use_gru: + #default:GRU using tanh + rnn_stacks.append(BiGRUWithBN(i_size=i_size, h_size=h_size)) + else: + rnn_stacks.append( + BiRNNWithBN( + i_size=i_size, + h_size=h_size, + share_weights=share_rnn_weights)) + i_size = h_size * 2 + + self.rnn_stacks = nn.ModuleList(rnn_stacks) + + def forward(self, x: paddle.Tensor, x_len: paddle.Tensor): + """ + x: shape [B, T, D] + x_len: shpae [B] + """ + for i, rnn in enumerate(self.rnn_stacks): + x, x_len = rnn(x, x_len) + masks = make_non_pad_mask(x_len) #[B, T] + masks = masks.unsqueeze(-1) # [B, T, 1] + # TODO(Hui Zhang): not support bool multiply + masks = masks.type_as(x) + x = x.multiply(masks) + return x, x_len diff --git a/deepspeech/models/ds2_online/__init__.py b/deepspeech/models/ds2_online/__init__.py new file mode 100644 index 00000000..255000ee --- /dev/null +++ b/deepspeech/models/ds2_online/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .deepspeech2 import DeepSpeech2InferModelOnline +from .deepspeech2 import DeepSpeech2ModelOnline + +__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline'] diff --git a/deepspeech/models/ds2_online/conv.py b/deepspeech/models/ds2_online/conv.py new file mode 100644 index 00000000..4a6fd5ab --- /dev/null +++ b/deepspeech/models/ds2_online/conv.py @@ -0,0 +1,33 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
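# The RNNStack.forward above re-masks the sequence after every stacked layer:
# make_non_pad_mask turns the per-utterance lengths into a [B, T] mask, which is then
# broadcast over the feature axis to zero out padded frames. A minimal NumPy sketch of
# that masking (the helper name and shapes below are illustrative, not the repo API):
import numpy as np

def non_pad_mask(lengths, max_len):
    # 1.0 for valid frames, 0.0 for padding
    steps = np.arange(max_len)[None, :]                       # [1, T]
    return (steps < np.asarray(lengths)[:, None]).astype(np.float32)

x = np.ones((2, 5, 3), dtype=np.float32)                      # [B, T, D] layer output
x_len = [5, 3]                                                # valid lengths
x = x * non_pad_mask(x_len, x.shape[1])[:, :, None]           # [B, T, 1] broadcast
assert x[1, 3:].sum() == 0.0                                  # padded frames zeroed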
+import paddle + +from deepspeech.modules.subsampling import Conv2dSubsampling4 + + +class Conv2dSubsampling4Online(Conv2dSubsampling4): + def __init__(self, idim: int, odim: int, dropout_rate: float): + super().__init__(idim, odim, dropout_rate, None) + self.output_dim = ((idim - 1) // 2 - 1) // 2 * odim + self.receptive_field_length = 2 * ( + 3 - 1) + 3 # stride_1 * (kernel_size_2 - 1) + kerel_size_1 + + def forward(self, x: paddle.Tensor, + x_len: paddle.Tensor) -> [paddle.Tensor, paddle.Tensor]: + x = x.unsqueeze(1) # (b, c=1, t, f) + x = self.conv(x) + #b, c, t, f = paddle.shape(x) #not work under jit + x = x.transpose([0, 2, 1, 3]).reshape([0, 0, -1]) + x_len = ((x_len - 1) // 2 - 1) // 2 + return x, x_len diff --git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py new file mode 100644 index 00000000..29d207c4 --- /dev/null +++ b/deepspeech/models/ds2_online/deepspeech2.py @@ -0,0 +1,438 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Deepspeech2 ASR Online Model""" +from typing import Optional + +import paddle +import paddle.nn.functional as F +from paddle import nn +from yacs.config import CfgNode + +from deepspeech.models.ds2_online.conv import Conv2dSubsampling4Online +from deepspeech.modules.ctc import CTCDecoder +from deepspeech.utils import layer_tools +from deepspeech.utils.checkpoint import Checkpoint +from deepspeech.utils.log import Log +logger = Log(__name__).getlog() + +__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline'] + + +class CRNNEncoder(nn.Layer): + def __init__(self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=4, + rnn_size=1024, + rnn_direction='forward', + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=False): + super().__init__() + self.rnn_size = rnn_size + self.feat_size = feat_size # 161 for linear + self.dict_size = dict_size + self.num_rnn_layers = num_rnn_layers + self.num_fc_layers = num_fc_layers + self.rnn_direction = rnn_direction + self.fc_layers_size_list = fc_layers_size_list + self.use_gru = use_gru + self.conv = Conv2dSubsampling4Online(feat_size, 32, dropout_rate=0.0) + + self.output_dim = self.conv.output_dim + + i_size = self.conv.output_dim + self.rnn = nn.LayerList() + self.layernorm_list = nn.LayerList() + self.fc_layers_list = nn.LayerList() + if rnn_direction == 'bidirect' or rnn_direction == 'bidirectional': + layernorm_size = 2 * rnn_size + elif rnn_direction == 'forward': + layernorm_size = rnn_size + else: + raise Exception("Wrong rnn direction") + for i in range(0, num_rnn_layers): + if i == 0: + rnn_input_size = i_size + else: + rnn_input_size = layernorm_size + if use_gru is True: + self.rnn.append( + nn.GRU( + input_size=rnn_input_size, + hidden_size=rnn_size, + num_layers=1, + direction=rnn_direction)) + else: + self.rnn.append( + nn.LSTM( + input_size=rnn_input_size, + hidden_size=rnn_size, + num_layers=1, + direction=rnn_direction)) + 
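# Quick sanity check (illustrative, not part of this patch) of the two formulas used by
# Conv2dSubsampling4Online above: the flattened feature width after the two stride-2
# convolutions, and the matching output length computed in forward().
def subsampled_shape(idim, odim, t):
    feat_dim = ((idim - 1) // 2 - 1) // 2 * odim   # channels * reduced frequency bins
    out_len = ((t - 1) // 2 - 1) // 2              # frames after ~4x time subsampling
    return feat_dim, out_len

# e.g. 161 linear-spectrogram bins, 32 conv channels, a 100-frame utterance:
assert subsampled_shape(161, 32, 100) == (1248, 24)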
self.layernorm_list.append(nn.LayerNorm(layernorm_size)) + self.output_dim = layernorm_size + + fc_input_size = layernorm_size + for i in range(self.num_fc_layers): + self.fc_layers_list.append( + nn.Linear(fc_input_size, fc_layers_size_list[i])) + fc_input_size = fc_layers_size_list[i] + self.output_dim = fc_layers_size_list[i] + + @property + def output_size(self): + return self.output_dim + + def forward(self, x, x_lens, init_state_h_box=None, init_state_c_box=None): + """Compute Encoder outputs + + Args: + x (Tensor): [B, T, D] + x_lens (Tensor): [B] + init_state_h_box(Tensor): init_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + init_state_c_box(Tensor): init_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + Return: + x (Tensor): encoder outputs, [B, T, D] + x_lens (Tensor): encoder length, [B] + final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + """ + if init_state_h_box is not None: + init_state_list = None + + if self.use_gru is True: + init_state_h_list = paddle.split( + init_state_h_box, self.num_rnn_layers, axis=0) + init_state_list = init_state_h_list + else: + init_state_h_list = paddle.split( + init_state_h_box, self.num_rnn_layers, axis=0) + init_state_c_list = paddle.split( + init_state_c_box, self.num_rnn_layers, axis=0) + init_state_list = [(init_state_h_list[i], init_state_c_list[i]) + for i in range(self.num_rnn_layers)] + else: + init_state_list = [None] * self.num_rnn_layers + + x, x_lens = self.conv(x, x_lens) + final_chunk_state_list = [] + for i in range(0, self.num_rnn_layers): + x, final_state = self.rnn[i](x, init_state_list[i], + x_lens) #[B, T, D] + final_chunk_state_list.append(final_state) + x = self.layernorm_list[i](x) + + for i in range(self.num_fc_layers): + x = self.fc_layers_list[i](x) + x = F.relu(x) + + if self.use_gru is True: + final_chunk_state_h_box = paddle.concat( + final_chunk_state_list, axis=0) + final_chunk_state_c_box = init_state_c_box + else: + final_chunk_state_h_list = [ + final_chunk_state_list[i][0] for i in range(self.num_rnn_layers) + ] + final_chunk_state_c_list = [ + final_chunk_state_list[i][1] for i in range(self.num_rnn_layers) + ] + final_chunk_state_h_box = paddle.concat( + final_chunk_state_h_list, axis=0) + final_chunk_state_c_box = paddle.concat( + final_chunk_state_c_list, axis=0) + + return x, x_lens, final_chunk_state_h_box, final_chunk_state_c_box + + def forward_chunk_by_chunk(self, x, x_lens, decoder_chunk_size=8): + """Compute Encoder outputs + + Args: + x (Tensor): [B, T, D] + x_lens (Tensor): [B] + decoder_chunk_size: The chunk size of decoder + Returns: + eouts_list (List of Tensor): The list of encoder outputs in chunk_size: [B, chunk_size, D] * num_chunks + eouts_lens_list (List of Tensor): The list of encoder length in chunk_size: [B] * num_chunks + final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size] + """ + subsampling_rate = self.conv.subsampling_rate + receptive_field_length = self.conv.receptive_field_length + chunk_size = (decoder_chunk_size - 1 + ) * subsampling_rate + receptive_field_length + chunk_stride = subsampling_rate * decoder_chunk_size + max_len = 
x.shape[1] + assert (chunk_size <= max_len) + + eouts_chunk_list = [] + eouts_chunk_lens_list = [] + if (max_len - chunk_size) % chunk_stride != 0: + padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride + else: + padding_len = 0 + padding = paddle.zeros((x.shape[0], padding_len, x.shape[2])) + padded_x = paddle.concat([x, padding], axis=1) + num_chunk = (max_len + padding_len - chunk_size) / chunk_stride + 1 + num_chunk = int(num_chunk) + chunk_state_h_box = None + chunk_state_c_box = None + final_state_h_box = None + final_state_c_box = None + for i in range(0, num_chunk): + start = i * chunk_stride + end = start + chunk_size + x_chunk = padded_x[:, start:end, :] + + x_len_left = paddle.where(x_lens - i * chunk_stride < 0, + paddle.zeros_like(x_lens), + x_lens - i * chunk_stride) + x_chunk_len_tmp = paddle.ones_like(x_lens) * chunk_size + x_chunk_lens = paddle.where(x_len_left < x_chunk_len_tmp, + x_len_left, x_chunk_len_tmp) + + eouts_chunk, eouts_chunk_lens, chunk_state_h_box, chunk_state_c_box = self.forward( + x_chunk, x_chunk_lens, chunk_state_h_box, chunk_state_c_box) + + eouts_chunk_list.append(eouts_chunk) + eouts_chunk_lens_list.append(eouts_chunk_lens) + final_state_h_box = chunk_state_h_box + final_state_c_box = chunk_state_c_box + return eouts_chunk_list, eouts_chunk_lens_list, final_state_h_box, final_state_c_box + + +class DeepSpeech2ModelOnline(nn.Layer): + """The DeepSpeech2 network structure for online. + + :param audio: Audio spectrogram data layer. + :type audio: Variable + :param text: Transcription text data layer. + :type text: Variable + :param audio_len: Valid sequence length data layer. + :type audio_len: Variable + :param feat_size: feature size for audio. + :type feat_size: int + :param dict_size: Dictionary size for tokenized transcription. + :type dict_size: int + :param num_conv_layers: Number of stacking convolution layers. + :type num_conv_layers: int + :param num_rnn_layers: Number of stacking RNN layers. + :type num_rnn_layers: int + :param rnn_size: RNN layer size (dimension of RNN cells). + :type rnn_size: int + :param num_fc_layers: Number of stacking FC layers. + :type num_fc_layers: int + :param fc_layers_size_list: The list of FC layer sizes. + :type fc_layers_size_list: [int,] + :param use_gru: Use gru if set True. Use simple rnn if set False. + :type use_gru: bool + :return: A tuple of an output unnormalized log probability layer ( + before softmax) and a ctc cost layer. + :rtype: tuple of LayerOutput + """ + + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + default = CfgNode( + dict( + num_conv_layers=2, #Number of stacking convolution layers. + num_rnn_layers=4, #Number of stacking RNN layers. + rnn_layer_size=1024, #RNN layer size (number of RNN cells). + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=True, #Use gru if set True. Use simple rnn if set False. 
+ blank_id=0, # index of blank in vocob.txt + )) + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=4, + rnn_size=1024, + rnn_direction='forward', + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=False, + blank_id=0): + super().__init__() + self.encoder = CRNNEncoder( + feat_size=feat_size, + dict_size=dict_size, + num_conv_layers=num_conv_layers, + num_rnn_layers=num_rnn_layers, + rnn_direction=rnn_direction, + num_fc_layers=num_fc_layers, + fc_layers_size_list=fc_layers_size_list, + rnn_size=rnn_size, + use_gru=use_gru) + + self.decoder = CTCDecoder( + odim=dict_size, # is in vocab + enc_n_units=self.encoder.output_size, + blank_id=blank_id, + dropout_rate=0.0, + reduction=True, # sum + batch_average=True, # sum / batch_size + grad_norm_type='instance') + + def forward(self, audio, audio_len, text, text_len): + """Compute Model loss + + Args: + audio (Tenosr): [B, T, D] + audio_len (Tensor): [B] + text (Tensor): [B, U] + text_len (Tensor): [B] + + Returns: + loss (Tenosr): [1] + """ + eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder( + audio, audio_len, None, None) + loss = self.decoder(eouts, eouts_len, text, text_len) + return loss + + @paddle.no_grad() + def decode(self, audio, audio_len, vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, + cutoff_top_n, num_processes): + # init once + # decoders only accept string encoded in utf-8 + self.decoder.init_decode( + beam_alpha=beam_alpha, + beam_beta=beam_beta, + lang_model_path=lang_model_path, + vocab_list=vocab_list, + decoding_method=decoding_method) + + eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder( + audio, audio_len, None, None) + probs = self.decoder.softmax(eouts) + return self.decoder.decode_probs( + probs.numpy(), eouts_len, vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, + cutoff_top_n, num_processes) + + @classmethod + def from_pretrained(cls, dataloader, config, checkpoint_path): + """Build a DeepSpeech2Model model from a pretrained model. + Parameters + ---------- + dataloader: paddle.io.DataLoader + + config: yacs.config.CfgNode + model configs + + checkpoint_path: Path or str + the path of pretrained model checkpoint, without extension name + + Returns + ------- + DeepSpeech2ModelOnline + The model built from pretrained result. + """ + model = cls(feat_size=dataloader.collate_fn.feature_size, + dict_size=dataloader.collate_fn.vocab_size, + num_conv_layers=config.model.num_conv_layers, + num_rnn_layers=config.model.num_rnn_layers, + rnn_size=config.model.rnn_layer_size, + rnn_direction=config.model.rnn_direction, + num_fc_layers=config.model.num_fc_layers, + fc_layers_size_list=config.model.fc_layers_size_list, + use_gru=config.model.use_gru, + blank_id=config.model.blank_id) + infos = Checkpoint().load_parameters( + model, checkpoint_path=checkpoint_path) + logger.info(f"checkpoint info: {infos}") + layer_tools.summary(model) + return model + + @classmethod + def from_config(cls, config): + """Build a DeepSpeec2ModelOnline from config + Parameters + + config: yacs.config.CfgNode + config.model + Returns + ------- + DeepSpeech2ModelOnline + The model built from config. 
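# Worked numbers (illustrative, not part of this patch) for the window/stride/padding
# arithmetic in forward_chunk_by_chunk above. subsampling_rate=4 and
# receptive_field_length=7 mirror Conv2dSubsampling4Online; the other values are examples.
def chunk_plan(max_len, decoder_chunk_size=8, subsampling_rate=4,
               receptive_field_length=7):
    chunk_size = (decoder_chunk_size - 1) * subsampling_rate + receptive_field_length
    chunk_stride = subsampling_rate * decoder_chunk_size
    rem = (max_len - chunk_size) % chunk_stride
    padding_len = chunk_stride - rem if rem != 0 else 0
    num_chunk = (max_len + padding_len - chunk_size) // chunk_stride + 1
    return chunk_size, chunk_stride, padding_len, num_chunk

# a 100-frame utterance with an 8-step decoder chunk needs 35-frame windows, hop 32,
# 31 frames of zero padding, and is processed as 4 chunks with RNN state carried over:
assert chunk_plan(100) == (35, 32, 31, 4)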
+ """ + model = cls(feat_size=config.feat_size, + dict_size=config.dict_size, + num_conv_layers=config.num_conv_layers, + num_rnn_layers=config.num_rnn_layers, + rnn_size=config.rnn_layer_size, + rnn_direction=config.rnn_direction, + num_fc_layers=config.num_fc_layers, + fc_layers_size_list=config.fc_layers_size_list, + use_gru=config.use_gru, + blank_id=config.blank_id) + return model + + +class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline): + def __init__(self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=4, + rnn_size=1024, + rnn_direction='forward', + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=False, + blank_id=0): + super().__init__( + feat_size=feat_size, + dict_size=dict_size, + num_conv_layers=num_conv_layers, + num_rnn_layers=num_rnn_layers, + rnn_size=rnn_size, + rnn_direction=rnn_direction, + num_fc_layers=num_fc_layers, + fc_layers_size_list=fc_layers_size_list, + use_gru=use_gru, + blank_id=blank_id) + + def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box, + chunk_state_c_box): + eouts_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box = self.encoder( + audio_chunk, audio_chunk_lens, chunk_state_h_box, chunk_state_c_box) + probs_chunk = self.decoder.softmax(eouts_chunk) + return probs_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box + + def export(self): + static_model = paddle.jit.to_static( + self, + input_spec=[ + paddle.static.InputSpec( + shape=[None, None, + self.encoder.feat_size], #[B, chunk_size, feat_dim] + dtype='float32'), + paddle.static.InputSpec(shape=[None], + dtype='int64'), # audio_length, [B] + paddle.static.InputSpec( + shape=[None, None, None], dtype='float32'), + paddle.static.InputSpec( + shape=[None, None, None], dtype='float32') + ]) + return static_model diff --git a/deepspeech/models/u2/__init__.py b/deepspeech/models/u2/__init__.py new file mode 100644 index 00000000..a9010f1d --- /dev/null +++ b/deepspeech/models/u2/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .u2 import U2InferModel +from .u2 import U2Model +from .updater import U2Evaluator +from .updater import U2Updater + +__all__ = ["U2Model", "U2InferModel", "U2Evaluator", "U2Updater"] diff --git a/deepspeech/models/u2/u2.py b/deepspeech/models/u2/u2.py new file mode 100644 index 00000000..e6cd7b5c --- /dev/null +++ b/deepspeech/models/u2/u2.py @@ -0,0 +1,951 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""U2 ASR Model +Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition +(https://arxiv.org/pdf/2012.05481.pdf) +""" +import sys +import time +from collections import defaultdict +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import paddle +from paddle import jit +from paddle import nn +from yacs.config import CfgNode + +from deepspeech.frontend.utility import IGNORE_ID +from deepspeech.frontend.utility import load_cmvn +from deepspeech.modules.cmvn import GlobalCMVN +from deepspeech.modules.ctc import CTCDecoder +from deepspeech.modules.decoder import TransformerDecoder +from deepspeech.modules.encoder import ConformerEncoder +from deepspeech.modules.encoder import TransformerEncoder +from deepspeech.modules.loss import LabelSmoothingLoss +from deepspeech.modules.mask import make_pad_mask +from deepspeech.modules.mask import mask_finished_preds +from deepspeech.modules.mask import mask_finished_scores +from deepspeech.modules.mask import subsequent_mask +from deepspeech.utils import checkpoint +from deepspeech.utils import layer_tools +from deepspeech.utils.ctc_utils import remove_duplicates_and_blank +from deepspeech.utils.log import Log +from deepspeech.utils.tensor_utils import add_sos_eos +from deepspeech.utils.tensor_utils import pad_sequence +from deepspeech.utils.tensor_utils import th_accuracy +from deepspeech.utils.utility import log_add +from deepspeech.utils.utility import UpdateConfig + +__all__ = ["U2Model", "U2InferModel"] + +logger = Log(__name__).getlog() + + +class U2BaseModel(nn.Layer): + """CTC-Attention hybrid Encoder-Decoder model""" + + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # network architecture + default = CfgNode() + # allow add new item when merge_with_file + default.cmvn_file = "" + default.cmvn_file_type = "json" + default.input_dim = 0 + default.output_dim = 0 + # encoder related + default.encoder = 'transformer' + default.encoder_conf = CfgNode( + dict( + output_size=256, # dimension of attention + attention_heads=4, + linear_units=2048, # the number of units of position-wise feed forward + num_blocks=12, # the number of encoder blocks + dropout_rate=0.1, + positional_dropout_rate=0.1, + attention_dropout_rate=0.0, + input_layer='conv2d', # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before=True, + # use_cnn_module=True, + # cnn_module_kernel=15, + # activation_type='swish', + # pos_enc_layer_type='rel_pos', + # selfattention_layer_type='rel_selfattn', + )) + # decoder related + default.decoder = 'transformer' + default.decoder_conf = CfgNode( + dict( + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + positional_dropout_rate=0.1, + self_attention_dropout_rate=0.0, + src_attention_dropout_rate=0.0, )) + # hybrid CTC/attention + default.model_conf = CfgNode( + dict( + ctc_weight=0.3, + lsm_weight=0.1, # label smoothing option + length_normalized_loss=False, )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, + vocab_size: int, + encoder: TransformerEncoder, + decoder: TransformerDecoder, + ctc: CTCDecoder, + ctc_weight: float=0.5, + ignore_id: int=IGNORE_ID, + lsm_weight: float=0.0, + length_normalized_loss: bool=False, + **kwargs): + assert 0.0 <= ctc_weight <= 1.0, ctc_weight + + super().__init__() + # note that 
eos is the same as sos (equivalent ID) + self.sos = vocab_size - 1 + self.eos = vocab_size - 1 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.ctc_weight = ctc_weight + + self.encoder = encoder + self.decoder = decoder + self.ctc = ctc + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, ) + + def forward( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + text: paddle.Tensor, + text_lengths: paddle.Tensor, + ) -> Tuple[Optional[paddle.Tensor], Optional[paddle.Tensor], Optional[ + paddle.Tensor]]: + """Frontend + Encoder + Decoder + Calc loss + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + Returns: + total_loss, attention_loss, ctc_loss + """ + assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified + assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == + text_lengths.shape[0]), (speech.shape, speech_lengths.shape, + text.shape, text_lengths.shape) + # 1. Encoder + start = time.time() + encoder_out, encoder_mask = self.encoder(speech, speech_lengths) + encoder_time = time.time() - start + #logger.debug(f"encoder time: {encoder_time}") + #TODO(Hui Zhang): sum not support bool type + #encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] + encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum( + 1) #[B, 1, T] -> [B] + + # 2a. Attention-decoder branch + loss_att = None + if self.ctc_weight != 1.0: + start = time.time() + loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask, + text, text_lengths) + decoder_time = time.time() - start + #logger.debug(f"decoder time: {decoder_time}") + + # 2b. CTC branch + loss_ctc = None + if self.ctc_weight != 0.0: + start = time.time() + loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, + text_lengths) + ctc_time = time.time() - start + #logger.debug(f"ctc time: {ctc_time}") + + if loss_ctc is None: + loss = loss_att + elif loss_att is None: + loss = loss_ctc + else: + loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + return loss, loss_att, loss_ctc + + def _calc_att_loss( + self, + encoder_out: paddle.Tensor, + encoder_mask: paddle.Tensor, + ys_pad: paddle.Tensor, + ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]: + """Calc attention loss. + + Args: + encoder_out (paddle.Tensor): [B, Tmax, D] + encoder_mask (paddle.Tensor): [B, 1, Tmax] + ys_pad (paddle.Tensor): [B, Umax] + ys_pad_lens (paddle.Tensor): [B] + + Returns: + Tuple[paddle.Tensor, float]: attention_loss, accuracy rate + """ + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, + self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad, + ys_in_lens) + + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, ) + return loss_att, acc_att + + def _forward_encoder( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + simulate_streaming: bool=False, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Encoder pass. + + Args: + speech (paddle.Tensor): [B, Tmax, D] + speech_lengths (paddle.Tensor): [B] + decoding_chunk_size (int, optional): chuck size. 
Defaults to -1. + num_decoding_left_chunks (int, optional): nums chunks. Defaults to -1. + simulate_streaming (bool, optional): streaming or not. Defaults to False. + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: + encoder hiddens (B, Tmax, D), + encoder hiddens mask (B, 1, Tmax). + """ + # Let's assume B = batch_size + # 1. Encoder + if simulate_streaming and decoding_chunk_size > 0: + encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk( + speech, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks + ) # (B, maxlen, encoder_dim) + else: + encoder_out, encoder_mask = self.encoder( + speech, + speech_lengths, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks + ) # (B, maxlen, encoder_dim) + return encoder_out, encoder_mask + + def recognize( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + beam_size: int=10, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + simulate_streaming: bool=False, ) -> paddle.Tensor: + """ Apply beam search on attention decoder + Args: + speech (paddle.Tensor): (batch, max_len, feat_dim) + speech_length (paddle.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + Returns: + paddle.Tensor: decoding result, (batch, max_result_len) + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + device = speech.place + batch_size = speech.shape[0] + + # Let's assume B = batch_size and N = beam_size + # 1. Encoder + encoder_out, encoder_mask = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) + maxlen = encoder_out.shape[1] + encoder_dim = encoder_out.shape[2] + running_size = batch_size * beam_size + encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view( + running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim) + encoder_mask = encoder_mask.unsqueeze(1).repeat( + 1, beam_size, 1, 1).view(running_size, 1, + maxlen) # (B*N, 1, max_len) + + hyps = paddle.ones( + [running_size, 1], dtype=paddle.long).fill_(self.sos) # (B*N, 1) + # log scale score + scores = paddle.to_tensor( + [0.0] + [-float('inf')] * (beam_size - 1), dtype=paddle.float) + scores = scores.to(device).repeat(batch_size).unsqueeze(1).to( + device) # (B*N, 1) + end_flag = paddle.zeros_like(scores, dtype=paddle.bool) # (B*N, 1) + cache: Optional[List[paddle.Tensor]] = None + # 2. 
Decoder forward step by step + for i in range(1, maxlen + 1): + # Stop if all batch and all beam produce eos + # TODO(Hui Zhang): if end_flag.sum() == running_size: + if end_flag.cast(paddle.int64).sum() == running_size: + break + + # 2.1 Forward decoder step + hyps_mask = subsequent_mask(i).unsqueeze(0).repeat( + running_size, 1, 1).to(device) # (B*N, i, i) + # logp: (B*N, vocab) + logp, cache = self.decoder.forward_one_step( + encoder_out, encoder_mask, hyps, hyps_mask, cache) + + # 2.2 First beam prune: select topk best prob at current time + top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N) + top_k_logp = mask_finished_scores(top_k_logp, end_flag) + top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos) + + # 2.3 Seconde beam prune: select topk score with history + scores = scores + top_k_logp # (B*N, N), broadcast add + scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N) + scores, offset_k_index = scores.topk(k=beam_size) # (B, N) + scores = scores.view(-1, 1) # (B*N, 1) + + # 2.4. Compute base index in top_k_index, + # regard top_k_index as (B*N*N),regard offset_k_index as (B*N), + # then find offset_k_index in top_k_index + base_k_index = paddle.arange(batch_size).view(-1, 1).repeat( + 1, beam_size) # (B, N) + base_k_index = base_k_index * beam_size * beam_size + best_k_index = base_k_index.view(-1) + offset_k_index.view( + -1) # (B*N) + + # 2.5 Update best hyps + best_k_pred = paddle.index_select( + top_k_index.view(-1), index=best_k_index, axis=0) # (B*N) + best_hyps_index = best_k_index // beam_size + last_best_k_hyps = paddle.index_select( + hyps, index=best_hyps_index, axis=0) # (B*N, i) + hyps = paddle.cat( + (last_best_k_hyps, best_k_pred.view(-1, 1)), + dim=1) # (B*N, i+1) + + # 2.6 Update end flag + end_flag = paddle.eq(hyps[:, -1], self.eos).view(-1, 1) + + # 3. Select best of best + scores = scores.view(batch_size, beam_size) + # TODO: length normalization + best_index = paddle.argmax(scores, axis=-1).long() # (B) + best_hyps_index = best_index + paddle.arange( + batch_size, dtype=paddle.long) * beam_size + best_hyps = paddle.index_select(hyps, index=best_hyps_index, axis=0) + best_hyps = best_hyps[:, 1:] + return best_hyps + + def ctc_greedy_search( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + simulate_streaming: bool=False, ) -> List[List[int]]: + """ Apply CTC greedy search + Args: + speech (paddle.Tensor): (batch, max_len, feat_dim) + speech_length (paddle.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. 
+ 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + Returns: + List[List[int]]: best path result + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + batch_size = speech.shape[0] + + # Let's assume B = batch_size + # encoder_out: (B, maxlen, encoder_dim) + # encoder_mask: (B, 1, Tmax) + encoder_out, encoder_mask = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, simulate_streaming) + maxlen = encoder_out.shape[1] + # (TODO Hui Zhang): bool no support reduce_sum + # encoder_out_lens = encoder_mask.squeeze(1).sum(1) + encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1) + ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size) + + topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1) + topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen) + pad_mask = make_pad_mask(encoder_out_lens) # (B, maxlen) + topk_index = topk_index.masked_fill_(pad_mask, self.eos) # (B, maxlen) + + hyps = [hyp.tolist() for hyp in topk_index] + hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps] + return hyps + + def _ctc_prefix_beam_search( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + beam_size: int, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + simulate_streaming: bool=False, + blank_id: int=0, ) -> Tuple[List[Tuple[int, float]], paddle.Tensor]: + """ CTC prefix beam search inner implementation + Args: + speech (paddle.Tensor): (batch, max_len, feat_dim) + speech_length (paddle.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + Returns: + List[Tuple[int, float]]: nbest results, (N,1), (text, likelihood) + paddle.Tensor: encoder output, (1, max_len, encoder_dim), + it will be used for rescoring in attention rescoring mode + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + batch_size = speech.shape[0] + # For CTC prefix beam search, we only support batch_size=1 + assert batch_size == 1 + + # Let's assume B = batch_size and N = beam_size + # 1. Encoder forward and get CTC score + encoder_out, encoder_mask = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) + maxlen = encoder_out.shape[1] + ctc_probs = self.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size) + ctc_probs = ctc_probs.squeeze(0) + + # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) + # blank_ending_score and none_blank_ending_score in ln domain + cur_hyps = [(tuple(), (0.0, -float('inf')))] + # 2. 
CTC beam search step by step + for t in range(0, maxlen): + logp = ctc_probs[t] # (vocab_size,) + # key: prefix, value (pb, pnb), default value(-inf, -inf) + next_hyps = defaultdict(lambda: (-float('inf'), -float('inf'))) + # 2.1 First beam prune: select topk best + top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,) + for s in top_k_index: + s = s.item() + ps = logp[s].item() + for prefix, (pb, pnb) in cur_hyps: + last = prefix[-1] if len(prefix) > 0 else None + if s == blank_id: # blank + n_pb, n_pnb = next_hyps[prefix] + n_pb = log_add([n_pb, pb + ps, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + elif s == last: + # Update *ss -> *s; + n_pb, n_pnb = next_hyps[prefix] + n_pnb = log_add([n_pnb, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + # Update *s-s -> *ss, - is for blank + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + else: + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps, pnb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + + # 2.2 Second beam prune + next_hyps = sorted( + next_hyps.items(), + key=lambda x: log_add(list(x[1])), + reverse=True) + cur_hyps = next_hyps[:beam_size] + + hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps] + return hyps, encoder_out + + def ctc_prefix_beam_search( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + beam_size: int, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + simulate_streaming: bool=False, ) -> List[int]: + """ Apply CTC prefix beam search + Args: + speech (paddle.Tensor): (batch, max_len, feat_dim) + speech_length (paddle.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + Returns: + List[int]: CTC prefix beam search nbest results + """ + hyps, _ = self._ctc_prefix_beam_search( + speech, speech_lengths, beam_size, decoding_chunk_size, + num_decoding_left_chunks, simulate_streaming) + return hyps[0][0] + + def attention_rescoring( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + beam_size: int, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + ctc_weight: float=0.0, + simulate_streaming: bool=False, ) -> List[int]: + """ Apply attention rescoring decoding, CTC prefix beam search + is applied first to get nbest, then we resoring the nbest on + attention decoder with corresponding encoder out + Args: + speech (paddle.Tensor): (batch, max_len, feat_dim) + speech_length (paddle.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. 
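# The prefix beam search above keeps every probability in the log domain and merges
# paths with log_add, while ctc_greedy_search collapses repeats and drops blanks.
# A minimal self-contained sketch of both helpers, assuming the usual definitions
# (the repo's own utilities may differ in detail):
import math
from typing import List

def log_add(args: List[float]) -> float:
    # log(sum(exp(args))), stabilized by subtracting the max
    if all(a == -float('inf') for a in args):
        return -float('inf')
    a_max = max(args)
    return a_max + math.log(sum(math.exp(a - a_max) for a in args))

def collapse_ctc(path: List[int], blank_id: int = 0) -> List[int]:
    # merge consecutive repeats, then remove blanks (greedy CTC decoding)
    out, prev = [], None
    for tok in path:
        if tok != prev and tok != blank_id:
            out.append(tok)
        prev = tok
    return out

assert abs(log_add([math.log(0.2), math.log(0.3)]) - math.log(0.5)) < 1e-9
assert collapse_ctc([0, 3, 3, 0, 0, 5, 5, 5, 0]) == [3, 5]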
+ 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + Returns: + List[int]: Attention rescoring result + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + device = speech.place + batch_size = speech.shape[0] + # For attention rescoring we only support batch_size=1 + assert batch_size == 1 + + # len(hyps) = beam_size, encoder_out: (1, maxlen, encoder_dim) + hyps, encoder_out = self._ctc_prefix_beam_search( + speech, speech_lengths, beam_size, decoding_chunk_size, + num_decoding_left_chunks, simulate_streaming) + assert len(hyps) == beam_size + + hyps_pad = pad_sequence([ + paddle.to_tensor(hyp[0], place=device, dtype=paddle.long) + for hyp in hyps + ], True, self.ignore_id) # (beam_size, max_hyps_len) + hyps_lens = paddle.to_tensor( + [len(hyp[0]) for hyp in hyps], place=device, + dtype=paddle.long) # (beam_size,) + hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id) + hyps_lens = hyps_lens + 1 # Add at begining + + encoder_out = encoder_out.repeat(beam_size, 1, 1) + encoder_mask = paddle.ones( + (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) + decoder_out, _ = self.decoder( + encoder_out, encoder_mask, hyps_pad, + hyps_lens) # (beam_size, max_hyps_len, vocab_size) + # ctc score in ln domain + decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) + decoder_out = decoder_out.numpy() + + # Only use decoder score for rescoring + best_score = -float('inf') + best_index = 0 + # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size + for i, hyp in enumerate(hyps): + score = 0.0 + for j, w in enumerate(hyp[0]): + score += decoder_out[i][j][w] + # last decoder output token is `eos`, for laste decoder input token. + score += decoder_out[i][len(hyp[0])][self.eos] + # add ctc score (which in ln domain) + score += hyp[1] * ctc_weight + if score > best_score: + best_score = score + best_index = i + return hyps[best_index][0] + + #@jit.to_static + def subsampling_rate(self) -> int: + """ Export interface for c++ call, return subsampling_rate of the + model + """ + return self.encoder.embed.subsampling_rate + + #@jit.to_static + def right_context(self) -> int: + """ Export interface for c++ call, return right_context of the model + """ + return self.encoder.embed.right_context + + #@jit.to_static + def sos_symbol(self) -> int: + """ Export interface for c++ call, return sos symbol id of the model + """ + return self.sos + + #@jit.to_static + def eos_symbol(self) -> int: + """ Export interface for c++ call, return eos symbol id of the model + """ + return self.eos + + @jit.to_static + def forward_encoder_chunk( + self, + xs: paddle.Tensor, + offset: int, + required_cache_size: int, + subsampling_cache: Optional[paddle.Tensor]=None, + elayers_output_cache: Optional[List[paddle.Tensor]]=None, + conformer_cnn_cache: Optional[List[paddle.Tensor]]=None, + ) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[ + paddle.Tensor]]: + """ Export interface for c++ call, give input chunk xs, and return + output from time 0 to current chunk. + Args: + xs (paddle.Tensor): chunk input + subsampling_cache (Optional[paddle.Tensor]): subsampling cache + elayers_output_cache (Optional[List[paddle.Tensor]]): + transformer/conformer encoder layers output cache + conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer + cnn cache + Returns: + paddle.Tensor: output, it ranges from time 0 to current chunk. 
+ paddle.Tensor: subsampling cache + List[paddle.Tensor]: attention cache + List[paddle.Tensor]: conformer cnn cache + """ + return self.encoder.forward_chunk( + xs, offset, required_cache_size, subsampling_cache, + elayers_output_cache, conformer_cnn_cache) + + # @jit.to_static + def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: + """ Export interface for c++ call, apply linear transform and log + softmax before ctc + Args: + xs (paddle.Tensor): encoder output, (B, T, D) + Returns: + paddle.Tensor: activation before ctc + """ + return self.ctc.log_softmax(xs) + + @jit.to_static + def forward_attention_decoder( + self, + hyps: paddle.Tensor, + hyps_lens: paddle.Tensor, + encoder_out: paddle.Tensor, ) -> paddle.Tensor: + """ Export interface for c++ call, forward decoder with multiple + hypothesis from ctc prefix beam search and one encoder output + Args: + hyps (paddle.Tensor): hyps from ctc prefix beam search, already + pad sos at the begining, (B, T) + hyps_lens (paddle.Tensor): length of each hyp in hyps, (B) + encoder_out (paddle.Tensor): corresponding encoder output, (B=1, T, D) + Returns: + paddle.Tensor: decoder output, (B, L) + """ + assert encoder_out.shape[0] == 1 + num_hyps = hyps.shape[0] + assert hyps_lens.shape[0] == num_hyps + encoder_out = encoder_out.repeat(num_hyps, 1, 1) + # (B, 1, T) + encoder_mask = paddle.ones( + [num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool) + # (num_hyps, max_hyps_len, vocab_size) + decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps, + hyps_lens) + decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1) + return decoder_out + + @paddle.no_grad() + def decode(self, + feats: paddle.Tensor, + feats_lengths: paddle.Tensor, + text_feature: Dict[str, int], + decoding_method: str, + lang_model_path: str, + beam_alpha: float, + beam_beta: float, + beam_size: int, + cutoff_prob: float, + cutoff_top_n: int, + num_processes: int, + ctc_weight: float=0.0, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + simulate_streaming: bool=False): + """u2 decoding. + + Args: + feats (Tenosr): audio features, (B, T, D) + feats_lengths (Tenosr): (B) + text_feature (TextFeaturizer): text feature object. + decoding_method (str): decoding mode, e.g. + 'attention', 'ctc_greedy_search', + 'ctc_prefix_beam_search', 'attention_rescoring' + lang_model_path (str): lm path. + beam_alpha (float): lm weight. + beam_beta (float): length penalty. + beam_size (int): beam size for search + cutoff_prob (float): for prune. + cutoff_top_n (int): for prune. + num_processes (int): + ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0. + decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here. + num_decoding_left_chunks (int, optional): + number of left chunks for decoding. Defaults to -1. + simulate_streaming (bool, optional): simulate streaming inference. Defaults to False. + + Raises: + ValueError: when not support decoding_method. + + Returns: + List[List[int]]: transcripts. 
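# attention_rescoring above re-scores the CTC prefix beam search n-best with the
# attention decoder: the winning hypothesis maximizes the summed decoder log-probs
# (including the trailing eos) plus ctc_weight times the CTC prefix score. A small
# sketch of that selection rule over precomputed scores (names are illustrative):
from typing import List, Tuple

def pick_rescored(hyps: List[Tuple[List[int], float]],
                  att_logp: List[List[float]],
                  ctc_weight: float = 0.5) -> List[int]:
    best_i, best_score = 0, -float('inf')
    for i, (tokens, ctc_score) in enumerate(hyps):
        score = sum(att_logp[i][:len(tokens) + 1])   # token log-probs + eos log-prob
        score += ctc_weight * ctc_score              # CTC prefix score in ln domain
        if score > best_score:
            best_i, best_score = i, score
    return hyps[best_i][0]

hyps = [([5, 7], -3.0), ([5, 9], -2.5)]              # (token_ids, ctc_score) per beam
att = [[-0.1, -0.2, -0.05], [-0.4, -0.9, -0.1]]      # gathered decoder log-probs
assert pick_rescored(hyps, att, ctc_weight=0.5) == [5, 7]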
+ """ + batch_size = feats.shape[0] + if decoding_method in ['ctc_prefix_beam_search', + 'attention_rescoring'] and batch_size > 1: + logger.fatal( + f'decoding mode {decoding_method} must be running with batch_size == 1' + ) + sys.exit(1) + + if decoding_method == 'attention': + hyps = self.recognize( + feats, + feats_lengths, + beam_size=beam_size, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks, + simulate_streaming=simulate_streaming) + hyps = [hyp.tolist() for hyp in hyps] + elif decoding_method == 'ctc_greedy_search': + hyps = self.ctc_greedy_search( + feats, + feats_lengths, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks, + simulate_streaming=simulate_streaming) + # ctc_prefix_beam_search and attention_rescoring only return one + # result in List[int], change it to List[List[int]] for compatible + # with other batch decoding mode + elif decoding_method == 'ctc_prefix_beam_search': + assert feats.shape[0] == 1 + hyp = self.ctc_prefix_beam_search( + feats, + feats_lengths, + beam_size, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks, + simulate_streaming=simulate_streaming) + hyps = [hyp] + elif decoding_method == 'attention_rescoring': + assert feats.shape[0] == 1 + hyp = self.attention_rescoring( + feats, + feats_lengths, + beam_size, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks, + ctc_weight=ctc_weight, + simulate_streaming=simulate_streaming) + hyps = [hyp] + else: + raise ValueError(f"Not support decoding method: {decoding_method}") + + res = [text_feature.defeaturize(hyp) for hyp in hyps] + return res + + +class U2Model(U2BaseModel): + def __init__(self, configs: dict): + vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs) + + super().__init__( + vocab_size=vocab_size, + encoder=encoder, + decoder=decoder, + ctc=ctc, + **configs['model_conf']) + + @classmethod + def _init_from_config(cls, configs: dict): + """init sub module for model. + + Args: + configs (dict): config dict. + + Raises: + ValueError: raise when using not support encoder type. 
+ + Returns: + int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc + """ + # cmvn + if configs['cmvn_file'] is not None: + mean, istd = load_cmvn(configs['cmvn_file'], + configs['cmvn_file_type']) + global_cmvn = GlobalCMVN( + paddle.to_tensor(mean, dtype=paddle.float), + paddle.to_tensor(istd, dtype=paddle.float)) + else: + global_cmvn = None + + # input & output dim + input_dim = configs['input_dim'] + vocab_size = configs['output_dim'] + assert input_dim != 0, input_dim + assert vocab_size != 0, vocab_size + + # encoder + encoder_type = configs.get('encoder', 'transformer') + logger.info(f"U2 Encoder type: {encoder_type}") + if encoder_type == 'transformer': + encoder = TransformerEncoder( + input_dim, global_cmvn=global_cmvn, **configs['encoder_conf']) + elif encoder_type == 'conformer': + encoder = ConformerEncoder( + input_dim, global_cmvn=global_cmvn, **configs['encoder_conf']) + else: + raise ValueError(f"not support encoder type:{encoder_type}") + + # decoder + decoder = TransformerDecoder(vocab_size, + encoder.output_size(), + **configs['decoder_conf']) + + # ctc decoder and ctc loss + model_conf = configs['model_conf'] + ctc = CTCDecoder( + odim=vocab_size, + enc_n_units=encoder.output_size(), + blank_id=0, + dropout_rate=model_conf['ctc_dropoutrate'], + reduction=True, # sum + batch_average=True, # sum / batch_size + grad_norm_type=model_conf['ctc_grad_norm_type']) + + return vocab_size, encoder, decoder, ctc + + @classmethod + def from_config(cls, configs: dict): + """init model. + + Args: + configs (dict): config dict. + + Raises: + ValueError: raise when using not support encoder type. + + Returns: + nn.Layer: U2Model + """ + model = cls(configs) + return model + + @classmethod + def from_pretrained(cls, dataloader, config, checkpoint_path): + """Build a DeepSpeech2Model model from a pretrained model. + + Args: + dataloader (paddle.io.DataLoader): not used. + config (yacs.config.CfgNode): model configs + checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name + + Returns: + DeepSpeech2Model: The model built from pretrained result. + """ + with UpdateConfig(config): + config.input_dim = dataloader.collate_fn.feature_size + config.output_dim = dataloader.collate_fn.vocab_size + + model = cls.from_config(config) + + if checkpoint_path: + infos = checkpoint.Checkpoint().load_parameters( + model, checkpoint_path=checkpoint_path) + logger.info(f"checkpoint info: {infos}") + layer_tools.summary(model) + return model + + +class U2InferModel(U2Model): + def __init__(self, configs: dict): + super().__init__(configs) + + def forward(self, + feats, + feats_lengths, + decoding_chunk_size=-1, + num_decoding_left_chunks=-1, + simulate_streaming=False): + """export model function + + Args: + feats (Tensor): [B, T, D] + feats_lengths (Tensor): [B] + + Returns: + List[List[int]]: best path result + """ + return self.ctc_greedy_search( + feats, + feats_lengths, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks, + simulate_streaming=simulate_streaming) diff --git a/deepspeech/models/u2/updater.py b/deepspeech/models/u2/updater.py new file mode 100644 index 00000000..7b70ca04 --- /dev/null +++ b/deepspeech/models/u2/updater.py @@ -0,0 +1,149 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
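# A minimal sketch of the config dict consumed by U2Model._init_from_config above.
# The key names are taken from that code; every value here is a placeholder (the real
# settings live in the example yaml configs, and input_dim / output_dim are filled
# from the dataloader in from_pretrained).
u2_configs = {
    'cmvn_file': None,          # path to cmvn stats, or None to skip GlobalCMVN
    'cmvn_file_type': 'json',
    'input_dim': 80,
    'output_dim': 5000,
    'encoder': 'conformer',     # 'transformer' or 'conformer'
    'encoder_conf': dict(output_size=256, attention_heads=4, linear_units=2048,
                         num_blocks=12, dropout_rate=0.1),
    'decoder': 'transformer',
    'decoder_conf': dict(attention_heads=4, linear_units=2048, num_blocks=6,
                         dropout_rate=0.1),
    'model_conf': dict(ctc_weight=0.3, lsm_weight=0.1, length_normalized_loss=False,
                       ctc_dropoutrate=0.0, ctc_grad_norm_type='instance'),
}
# model = U2Model.from_config(u2_configs)   # needs the deepspeech package installed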
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from contextlib import nullcontext + +import paddle +from paddle import distributed as dist + +from deepspeech.training.extensions.evaluator import StandardEvaluator +from deepspeech.training.reporter import report +from deepspeech.training.timer import Timer +from deepspeech.training.updaters.standard_updater import StandardUpdater +from deepspeech.utils import layer_tools +from deepspeech.utils.log import Log + +logger = Log(__name__).getlog() + + +class U2Evaluator(StandardEvaluator): + def __init__(self, model, dataloader): + super().__init__(model, dataloader) + self.msg = "" + self.num_seen_utts = 0 + self.total_loss = 0.0 + + def evaluate_core(self, batch): + self.msg = "Valid: Rank: {}, ".format(dist.get_rank()) + losses_dict = {} + + loss, attention_loss, ctc_loss = self.model(*batch[1:]) + if paddle.isfinite(loss): + num_utts = batch[1].shape[0] + self.num_seen_utts += num_utts + self.total_loss += float(loss) * num_utts + + losses_dict['loss'] = float(loss) + if attention_loss: + losses_dict['att_loss'] = float(attention_loss) + if ctc_loss: + losses_dict['ctc_loss'] = float(ctc_loss) + + for k, v in losses_dict.items(): + report("eval/" + k, v) + + self.msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_dict.items()) + logger.info(self.msg) + return self.total_loss, self.num_seen_utts + + +class U2Updater(StandardUpdater): + def __init__(self, + model, + optimizer, + scheduler, + dataloader, + init_state=None, + accum_grad=1, + **kwargs): + super().__init__( + model, optimizer, scheduler, dataloader, init_state=init_state) + self.accum_grad = accum_grad + self.forward_count = 0 + self.msg = "" + + def update_core(self, batch): + """One Step + + Args: + batch (List[Object]): utts, xs, xlens, ys, ylens + """ + losses_dict = {} + self.msg = "Rank: {}, ".format(dist.get_rank()) + + # forward + batch_size = batch[1].shape[0] + loss, attention_loss, ctc_loss = self.model(*batch[1:]) + # loss div by `batch_size * accum_grad` + loss /= self.accum_grad + + # loss backward + if (self.forward_count + 1) != self.accum_grad: + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + context = self.model.no_sync + else: + # Used for single gpu training and DDP gradient synchronization + # processes. 
+ context = nullcontext + + with context(): + loss.backward() + layer_tools.print_grads(self.model, print_func=None) + + # loss info + losses_dict['loss'] = float(loss) * self.accum_grad + if attention_loss: + losses_dict['att_loss'] = float(attention_loss) + if ctc_loss: + losses_dict['ctc_loss'] = float(ctc_loss) + # report loss + for k, v in losses_dict.items(): + report("train/" + k, v) + # loss msg + self.msg += "batch size: {}, ".format(batch_size) + self.msg += "accum: {}, ".format(self.accum_grad) + self.msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_dict.items()) + + # Truncate the graph + loss.detach() + + # update parameters + self.forward_count += 1 + if self.forward_count != self.accum_grad: + return + self.forward_count = 0 + + self.optimizer.step() + self.optimizer.clear_grad() + self.scheduler.step() + + def update(self): + # model is default in train mode + + # training for a step is implemented here + with Timer("data time cost:{}"): + batch = self.read_batch() + with Timer("step time cost:{}"): + self.update_core(batch) + + # #iterations with accum_grad > 1 + # Ref.: https://github.com/espnet/espnet/issues/777 + if self.forward_count == 0: + self.state.iteration += 1 + if self.updates_per_epoch is not None: + if self.state.iteration % self.updates_per_epoch == 0: + self.state.epoch += 1 diff --git a/deepspeech/models/u2_st.py b/deepspeech/models/u2_st.py new file mode 100644 index 00000000..bf98423d --- /dev/null +++ b/deepspeech/models/u2_st.py @@ -0,0 +1,733 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
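# U2Updater.update_core above implements gradient accumulation: the loss is divided by
# accum_grad, DDP gradient sync is skipped (model.no_sync) for all but the last
# micro-batch, and the optimizer/scheduler only step once every accum_grad batches.
# A minimal single-process sketch of that bookkeeping (no paddle, illustrative names):
class AccumSchedule:
    def __init__(self, accum_grad: int):
        self.accum_grad = accum_grad
        self.forward_count = 0
        self.steps_taken = 0

    def on_batch(self) -> bool:
        # here the scaled loss would be backpropagated, accumulating gradients
        self.forward_count += 1
        if self.forward_count != self.accum_grad:
            return False              # keep accumulating, no optimizer.step()
        self.forward_count = 0
        self.steps_taken += 1         # optimizer.step(); clear_grad(); scheduler.step()
        return True

sched = AccumSchedule(accum_grad=4)
assert [sched.on_batch() for _ in range(8)] == [False, False, False, True] * 2
assert sched.steps_taken == 2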
+"""U2 ASR Model +Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition +(https://arxiv.org/pdf/2012.05481.pdf) +""" +import time +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import paddle +from paddle import jit +from paddle import nn +from yacs.config import CfgNode + +from deepspeech.frontend.utility import IGNORE_ID +from deepspeech.frontend.utility import load_cmvn +from deepspeech.modules.cmvn import GlobalCMVN +from deepspeech.modules.ctc import CTCDecoder +from deepspeech.modules.decoder import TransformerDecoder +from deepspeech.modules.encoder import ConformerEncoder +from deepspeech.modules.encoder import TransformerEncoder +from deepspeech.modules.loss import LabelSmoothingLoss +from deepspeech.modules.mask import mask_finished_preds +from deepspeech.modules.mask import mask_finished_scores +from deepspeech.modules.mask import subsequent_mask +from deepspeech.utils import checkpoint +from deepspeech.utils import layer_tools +from deepspeech.utils.log import Log +from deepspeech.utils.tensor_utils import add_sos_eos +from deepspeech.utils.tensor_utils import th_accuracy +from deepspeech.utils.utility import UpdateConfig + +__all__ = ["U2STModel", "U2STInferModel"] + +logger = Log(__name__).getlog() + + +class U2STBaseModel(nn.Layer): + """CTC-Attention hybrid Encoder-Decoder model""" + + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # network architecture + default = CfgNode() + # allow add new item when merge_with_file + default.cmvn_file = "" + default.cmvn_file_type = "json" + default.input_dim = 0 + default.output_dim = 0 + # encoder related + default.encoder = 'transformer' + default.encoder_conf = CfgNode( + dict( + output_size=256, # dimension of attention + attention_heads=4, + linear_units=2048, # the number of units of position-wise feed forward + num_blocks=12, # the number of encoder blocks + dropout_rate=0.1, + positional_dropout_rate=0.1, + attention_dropout_rate=0.0, + input_layer='conv2d', # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before=True, + # use_cnn_module=True, + # cnn_module_kernel=15, + # activation_type='swish', + # pos_enc_layer_type='rel_pos', + # selfattention_layer_type='rel_selfattn', + )) + # decoder related + default.decoder = 'transformer' + default.decoder_conf = CfgNode( + dict( + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + positional_dropout_rate=0.1, + self_attention_dropout_rate=0.0, + src_attention_dropout_rate=0.0, )) + # hybrid CTC/attention + default.model_conf = CfgNode( + dict( + asr_weight=0.0, + ctc_weight=0.0, + lsm_weight=0.1, # label smoothing option + length_normalized_loss=False, )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, + vocab_size: int, + encoder: TransformerEncoder, + st_decoder: TransformerDecoder, + decoder: TransformerDecoder=None, + ctc: CTCDecoder=None, + ctc_weight: float=0.0, + asr_weight: float=0.0, + ignore_id: int=IGNORE_ID, + lsm_weight: float=0.0, + length_normalized_loss: bool=False, + **kwargs): + assert 0.0 <= ctc_weight <= 1.0, ctc_weight + + super().__init__() + # note that eos is the same as sos (equivalent ID) + self.sos = vocab_size - 1 + self.eos = vocab_size - 1 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.ctc_weight = ctc_weight + self.asr_weight = asr_weight + + self.encoder = encoder + self.st_decoder = st_decoder + 
self.decoder = decoder + self.ctc = ctc + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, ) + + def forward( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + text: paddle.Tensor, + text_lengths: paddle.Tensor, + asr_text: paddle.Tensor=None, + asr_text_lengths: paddle.Tensor=None, + ) -> Tuple[Optional[paddle.Tensor], Optional[paddle.Tensor], Optional[ + paddle.Tensor]]: + """Frontend + Encoder + Decoder + Calc loss + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + Returns: + total_loss, attention_loss, ctc_loss + """ + assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified + assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == + text_lengths.shape[0]), (speech.shape, speech_lengths.shape, + text.shape, text_lengths.shape) + # 1. Encoder + start = time.time() + encoder_out, encoder_mask = self.encoder(speech, speech_lengths) + encoder_time = time.time() - start + #logger.debug(f"encoder time: {encoder_time}") + #TODO(Hui Zhang): sum not support bool type + #encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] + encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum( + 1) #[B, 1, T] -> [B] + + # 2a. ST-decoder branch + start = time.time() + loss_st, acc_st = self._calc_st_loss(encoder_out, encoder_mask, text, + text_lengths) + decoder_time = time.time() - start + + loss_asr_att = None + loss_asr_ctc = None + # 2b. ASR Attention-decoder branch + if self.asr_weight > 0.: + if self.ctc_weight != 1.0: + start = time.time() + loss_asr_att, acc_att = self._calc_att_loss( + encoder_out, encoder_mask, asr_text, asr_text_lengths) + decoder_time = time.time() - start + + # 2c. CTC branch + if self.ctc_weight != 0.0: + start = time.time() + loss_asr_ctc = self.ctc(encoder_out, encoder_out_lens, asr_text, + asr_text_lengths) + ctc_time = time.time() - start + + if loss_asr_ctc is None: + loss_asr = loss_asr_att + elif loss_asr_att is None: + loss_asr = loss_asr_ctc + else: + loss_asr = self.ctc_weight * loss_asr_ctc + (1 - self.ctc_weight + ) * loss_asr_att + loss = self.asr_weight * loss_asr + (1 - self.asr_weight) * loss_st + else: + loss = loss_st + return loss, loss_st, loss_asr_att, loss_asr_ctc + + def _calc_st_loss( + self, + encoder_out: paddle.Tensor, + encoder_mask: paddle.Tensor, + ys_pad: paddle.Tensor, + ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]: + """Calc attention loss. + + Args: + encoder_out (paddle.Tensor): [B, Tmax, D] + encoder_mask (paddle.Tensor): [B, 1, Tmax] + ys_pad (paddle.Tensor): [B, Umax] + ys_pad_lens (paddle.Tensor): [B] + + Returns: + Tuple[paddle.Tensor, float]: attention_loss, accuracy rate + """ + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, + self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.st_decoder(encoder_out, encoder_mask, ys_in_pad, + ys_in_lens) + + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, ) + return loss_att, acc_att + + def _calc_att_loss( + self, + encoder_out: paddle.Tensor, + encoder_mask: paddle.Tensor, + ys_pad: paddle.Tensor, + ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]: + """Calc attention loss. 
+ + Args: + encoder_out (paddle.Tensor): [B, Tmax, D] + encoder_mask (paddle.Tensor): [B, 1, Tmax] + ys_pad (paddle.Tensor): [B, Umax] + ys_pad_lens (paddle.Tensor): [B] + + Returns: + Tuple[paddle.Tensor, float]: attention_loss, accuracy rate + """ + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, + self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad, + ys_in_lens) + + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, ) + return loss_att, acc_att + + def _forward_encoder( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + simulate_streaming: bool=False, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Encoder pass. + + Args: + speech (paddle.Tensor): [B, Tmax, D] + speech_lengths (paddle.Tensor): [B] + decoding_chunk_size (int, optional): chuck size. Defaults to -1. + num_decoding_left_chunks (int, optional): nums chunks. Defaults to -1. + simulate_streaming (bool, optional): streaming or not. Defaults to False. + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: + encoder hiddens (B, Tmax, D), + encoder hiddens mask (B, 1, Tmax). + """ + # Let's assume B = batch_size + # 1. Encoder + if simulate_streaming and decoding_chunk_size > 0: + encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk( + speech, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks + ) # (B, maxlen, encoder_dim) + else: + encoder_out, encoder_mask = self.encoder( + speech, + speech_lengths, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks + ) # (B, maxlen, encoder_dim) + return encoder_out, encoder_mask + + def translate( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + beam_size: int=10, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + simulate_streaming: bool=False, ) -> paddle.Tensor: + """ Apply beam search on attention decoder + Args: + speech (paddle.Tensor): (batch, max_len, feat_dim) + speech_length (paddle.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + Returns: + paddle.Tensor: decoding result, (batch, max_result_len) + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + device = speech.place + batch_size = speech.shape[0] + + # Let's assume B = batch_size and N = beam_size + # 1. 
Encoder + encoder_out, encoder_mask = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) + maxlen = encoder_out.shape[1] + encoder_dim = encoder_out.shape[2] + running_size = batch_size * beam_size + encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view( + running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim) + encoder_mask = encoder_mask.unsqueeze(1).repeat( + 1, beam_size, 1, 1).view(running_size, 1, + maxlen) # (B*N, 1, max_len) + + hyps = paddle.ones( + [running_size, 1], dtype=paddle.long).fill_(self.sos) # (B*N, 1) + # log scale score + scores = paddle.to_tensor( + [0.0] + [-float('inf')] * (beam_size - 1), dtype=paddle.float) + scores = scores.to(device).repeat(batch_size).unsqueeze(1).to( + device) # (B*N, 1) + end_flag = paddle.zeros_like(scores, dtype=paddle.bool) # (B*N, 1) + cache: Optional[List[paddle.Tensor]] = None + # 2. Decoder forward step by step + for i in range(1, maxlen + 1): + # Stop if all batch and all beam produce eos + # TODO(Hui Zhang): if end_flag.sum() == running_size: + if end_flag.cast(paddle.int64).sum() == running_size: + break + + # 2.1 Forward decoder step + hyps_mask = subsequent_mask(i).unsqueeze(0).repeat( + running_size, 1, 1).to(device) # (B*N, i, i) + # logp: (B*N, vocab) + logp, cache = self.st_decoder.forward_one_step( + encoder_out, encoder_mask, hyps, hyps_mask, cache) + + # 2.2 First beam prune: select topk best prob at current time + top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N) + top_k_logp = mask_finished_scores(top_k_logp, end_flag) + top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos) + + # 2.3 Seconde beam prune: select topk score with history + scores = scores + top_k_logp # (B*N, N), broadcast add + scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N) + scores, offset_k_index = scores.topk(k=beam_size) # (B, N) + scores = scores.view(-1, 1) # (B*N, 1) + + # 2.4. Compute base index in top_k_index, + # regard top_k_index as (B*N*N),regard offset_k_index as (B*N), + # then find offset_k_index in top_k_index + base_k_index = paddle.arange(batch_size).view(-1, 1).repeat( + 1, beam_size) # (B, N) + base_k_index = base_k_index * beam_size * beam_size + best_k_index = base_k_index.view(-1) + offset_k_index.view( + -1) # (B*N) + + # 2.5 Update best hyps + best_k_pred = paddle.index_select( + top_k_index.view(-1), index=best_k_index, axis=0) # (B*N) + best_hyps_index = best_k_index // beam_size + last_best_k_hyps = paddle.index_select( + hyps, index=best_hyps_index, axis=0) # (B*N, i) + hyps = paddle.cat( + (last_best_k_hyps, best_k_pred.view(-1, 1)), + dim=1) # (B*N, i+1) + + # 2.6 Update end flag + end_flag = paddle.eq(hyps[:, -1], self.eos).view(-1, 1) + + # 3. 
Select best of best + scores = scores.view(batch_size, beam_size) + # TODO: length normalization + best_index = paddle.argmax(scores, axis=-1).long() # (B) + best_hyps_index = best_index + paddle.arange( + batch_size, dtype=paddle.long) * beam_size + best_hyps = paddle.index_select(hyps, index=best_hyps_index, axis=0) + best_hyps = best_hyps[:, 1:] + return best_hyps + + # @jit.to_static + def subsampling_rate(self) -> int: + """ Export interface for c++ call, return subsampling_rate of the + model + """ + return self.encoder.embed.subsampling_rate + + # @jit.to_static + def right_context(self) -> int: + """ Export interface for c++ call, return right_context of the model + """ + return self.encoder.embed.right_context + + # @jit.to_static + def sos_symbol(self) -> int: + """ Export interface for c++ call, return sos symbol id of the model + """ + return self.sos + + # @jit.to_static + def eos_symbol(self) -> int: + """ Export interface for c++ call, return eos symbol id of the model + """ + return self.eos + + @jit.to_static + def forward_encoder_chunk( + self, + xs: paddle.Tensor, + offset: int, + required_cache_size: int, + subsampling_cache: Optional[paddle.Tensor]=None, + elayers_output_cache: Optional[List[paddle.Tensor]]=None, + conformer_cnn_cache: Optional[List[paddle.Tensor]]=None, + ) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[ + paddle.Tensor]]: + """ Export interface for c++ call, give input chunk xs, and return + output from time 0 to current chunk. + Args: + xs (paddle.Tensor): chunk input + subsampling_cache (Optional[paddle.Tensor]): subsampling cache + elayers_output_cache (Optional[List[paddle.Tensor]]): + transformer/conformer encoder layers output cache + conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer + cnn cache + Returns: + paddle.Tensor: output, it ranges from time 0 to current chunk. 
+ paddle.Tensor: subsampling cache
+ List[paddle.Tensor]: attention cache
+ List[paddle.Tensor]: conformer cnn cache
+ """
+ return self.encoder.forward_chunk(
+ xs, offset, required_cache_size, subsampling_cache,
+ elayers_output_cache, conformer_cnn_cache)
+
+ # @jit.to_static
+ def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
+ """ Export interface for c++ call, apply linear transform and log
+ softmax before ctc
+ Args:
+ xs (paddle.Tensor): encoder output
+ Returns:
+ paddle.Tensor: activation before ctc
+ """
+ return self.ctc.log_softmax(xs)
+
+ @jit.to_static
+ def forward_attention_decoder(
+ self,
+ hyps: paddle.Tensor,
+ hyps_lens: paddle.Tensor,
+ encoder_out: paddle.Tensor, ) -> paddle.Tensor:
+ """ Export interface for c++ call, forward decoder with multiple
+ hypotheses from ctc prefix beam search and one encoder output
+ Args:
+ hyps (paddle.Tensor): hyps from ctc prefix beam search, already
+ padded with sos at the beginning, (B, T)
+ hyps_lens (paddle.Tensor): length of each hyp in hyps, (B)
+ encoder_out (paddle.Tensor): corresponding encoder output, (B=1, T, D)
+ Returns:
+ paddle.Tensor: decoder output, (B, L)
+ """
+ assert encoder_out.shape[0] == 1
+ num_hyps = hyps.shape[0]
+ assert hyps_lens.shape[0] == num_hyps
+ encoder_out = encoder_out.repeat(num_hyps, 1, 1)
+ # (B, 1, T)
+ encoder_mask = paddle.ones(
+ [num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool)
+ # (num_hyps, max_hyps_len, vocab_size)
+ decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
+ hyps_lens)
+ decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1)
+ return decoder_out
+
+ @paddle.no_grad()
+ def decode(self,
+ feats: paddle.Tensor,
+ feats_lengths: paddle.Tensor,
+ text_feature: Dict[str, int],
+ decoding_method: str,
+ lang_model_path: str,
+ beam_alpha: float,
+ beam_beta: float,
+ beam_size: int,
+ cutoff_prob: float,
+ cutoff_top_n: int,
+ num_processes: int,
+ ctc_weight: float=0.0,
+ decoding_chunk_size: int=-1,
+ num_decoding_left_chunks: int=-1,
+ simulate_streaming: bool=False):
+ """u2 decoding.
+
+ Args:
+ feats (Tensor): audio features, (B, T, D)
+ feats_lengths (Tensor): (B)
+ text_feature (TextFeaturizer): text feature object.
+ decoding_method (str): decoding mode, e.g.
+ 'fullsentence',
+ 'simultaneous'
+ lang_model_path (str): lm path.
+ beam_alpha (float): lm weight.
+ beam_beta (float): length penalty.
+ beam_size (int): beam size for search
+ cutoff_prob (float): for prune.
+ cutoff_top_n (int): for prune.
+ num_processes (int):
+ ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0.
+ decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ 0: used for training, it's prohibited here.
+ num_decoding_left_chunks (int, optional):
+ number of left chunks for decoding. Defaults to -1.
+ simulate_streaming (bool, optional): simulate streaming inference. Defaults to False.
+
+ Raises:
+ ValueError: when the decoding_method is not supported.
+
+ Returns:
+ List[List[int]]: transcripts.
+ """ + batch_size = feats.shape[0] + + if decoding_method == 'fullsentence': + hyps = self.translate( + feats, + feats_lengths, + beam_size=beam_size, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks, + simulate_streaming=simulate_streaming) + hyps = [hyp.tolist() for hyp in hyps] + else: + raise ValueError(f"Not support decoding method: {decoding_method}") + + res = [text_feature.defeaturize(hyp) for hyp in hyps] + return res + + +class U2STModel(U2STBaseModel): + def __init__(self, configs: dict): + vocab_size, encoder, decoder = U2STModel._init_from_config(configs) + + if isinstance(decoder, Tuple): + st_decoder, asr_decoder, ctc = decoder + super().__init__( + vocab_size=vocab_size, + encoder=encoder, + st_decoder=st_decoder, + decoder=asr_decoder, + ctc=ctc, + **configs['model_conf']) + else: + super().__init__( + vocab_size=vocab_size, + encoder=encoder, + st_decoder=decoder, + **configs['model_conf']) + + @classmethod + def _init_from_config(cls, configs: dict): + """init sub module for model. + + Args: + configs (dict): config dict. + + Raises: + ValueError: raise when using not support encoder type. + + Returns: + int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc + """ + if configs['cmvn_file'] is not None: + mean, istd = load_cmvn(configs['cmvn_file'], + configs['cmvn_file_type']) + global_cmvn = GlobalCMVN( + paddle.to_tensor(mean, dtype=paddle.float), + paddle.to_tensor(istd, dtype=paddle.float)) + else: + global_cmvn = None + + input_dim = configs['input_dim'] + vocab_size = configs['output_dim'] + assert input_dim != 0, input_dim + assert vocab_size != 0, vocab_size + + encoder_type = configs.get('encoder', 'transformer') + logger.info(f"U2 Encoder type: {encoder_type}") + if encoder_type == 'transformer': + encoder = TransformerEncoder( + input_dim, global_cmvn=global_cmvn, **configs['encoder_conf']) + elif encoder_type == 'conformer': + encoder = ConformerEncoder( + input_dim, global_cmvn=global_cmvn, **configs['encoder_conf']) + else: + raise ValueError(f"not support encoder type:{encoder_type}") + + st_decoder = TransformerDecoder(vocab_size, + encoder.output_size(), + **configs['decoder_conf']) + + asr_weight = configs['model_conf']['asr_weight'] + logger.info(f"ASR Joint Training Weight: {asr_weight}") + + if asr_weight > 0.: + decoder = TransformerDecoder(vocab_size, + encoder.output_size(), + **configs['decoder_conf']) + # ctc decoder and ctc loss + model_conf = configs['model_conf'] + ctc = CTCDecoder( + odim=vocab_size, + enc_n_units=encoder.output_size(), + blank_id=0, + dropout_rate=model_conf['ctc_dropoutrate'], + reduction=True, # sum + batch_average=True, # sum / batch_size + grad_norm_type=model_conf['ctc_grad_norm_type']) + + return vocab_size, encoder, (st_decoder, decoder, ctc) + else: + return vocab_size, encoder, st_decoder + + @classmethod + def from_config(cls, configs: dict): + """init model. + + Args: + configs (dict): config dict. + + Raises: + ValueError: raise when using not support encoder type. + + Returns: + nn.Layer: U2STModel + """ + model = cls(configs) + return model + + @classmethod + def from_pretrained(cls, dataloader, config, checkpoint_path): + """Build a DeepSpeech2Model model from a pretrained model. + + Args: + dataloader (paddle.io.DataLoader): not used. + config (yacs.config.CfgNode): model configs + checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name + + Returns: + DeepSpeech2Model: The model built from pretrained result. 
+ """ + with UpdateConfig(config): + config.input_dim = dataloader.collate_fn.feature_size + config.output_dim = dataloader.collate_fn.vocab_size + + model = cls.from_config(config) + + if checkpoint_path: + infos = checkpoint.load_parameters( + model, checkpoint_path=checkpoint_path) + logger.info(f"checkpoint info: {infos}") + layer_tools.summary(model) + return model + + +class U2STInferModel(U2STModel): + def __init__(self, configs: dict): + super().__init__(configs) + + def forward(self, + feats, + feats_lengths, + decoding_chunk_size=-1, + num_decoding_left_chunks=-1, + simulate_streaming=False): + """export model function + + Args: + feats (Tensor): [B, T, D] + feats_lengths (Tensor): [B] + + Returns: + List[List[int]]: best path result + """ + return self.translate( + feats, + feats_lengths, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks, + simulate_streaming=simulate_streaming) diff --git a/deepspeech/modules/ctc.py b/deepspeech/modules/ctc.py index 31e489a3..551bbf67 100644 --- a/deepspeech/modules/ctc.py +++ b/deepspeech/modules/ctc.py @@ -16,15 +16,19 @@ from paddle import nn from paddle.nn import functional as F from typeguard import check_argument_types -from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch -from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder -from deepspeech.decoders.swig_wrapper import Scorer from deepspeech.modules.loss import CTCLoss from deepspeech.utils import ctc_utils from deepspeech.utils.log import Log logger = Log(__name__).getlog() +try: + from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch # noqa: F401 + from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder # noqa: F401 + from deepspeech.decoders.swig_wrapper import Scorer # noqa: F401 +except Exception as e: + logger.info("ctcdecoder not installed!") + __all__ = ['CTCDecoder'] @@ -35,7 +39,8 @@ class CTCDecoder(nn.Layer): blank_id=0, dropout_rate: float=0.0, reduction: bool=True, - batch_average: bool=True): + batch_average: bool=True, + grad_norm_type: str="instance"): """CTC decoder Args: @@ -44,19 +49,21 @@ class CTCDecoder(nn.Layer): dropout_rate (float): dropout rate (0.0 ~ 1.0) reduction (bool): reduce the CTC loss into a scalar, True for 'sum' or 'none' batch_average (bool): do batch dim wise average. + grad_norm_type (str): one of 'instance', 'batch', 'frame', None. """ assert check_argument_types() super().__init__() self.blank_id = blank_id self.odim = odim - self.dropout_rate = dropout_rate + self.dropout = nn.Dropout(dropout_rate) self.ctc_lo = nn.Linear(enc_n_units, self.odim) reduction_type = "sum" if reduction else "none" self.criterion = CTCLoss( blank=self.blank_id, reduction=reduction_type, - batch_average=batch_average) + batch_average=batch_average, + grad_norm_type=grad_norm_type) # CTCDecoder LM Score handle self._ext_scorer = None @@ -72,7 +79,7 @@ class CTCDecoder(nn.Layer): Returns: loss (Tenosr): ctc loss value, scalar. 
""" - logits = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate)) + logits = self.ctc_lo(self.dropout(hs_pad)) loss = self.criterion(logits, ys_pad, hlens, ys_lens) return loss @@ -132,7 +139,7 @@ class CTCDecoder(nn.Layer): results = [] for i, probs in enumerate(probs_split): output_transcription = ctc_greedy_decoder( - probs_seq=probs, vocabulary=vocab_list) + probs_seq=probs, vocabulary=vocab_list, blank_id=self.blank_id) results.append(output_transcription) return results @@ -212,13 +219,15 @@ class CTCDecoder(nn.Layer): num_processes=num_processes, ext_scoring_func=self._ext_scorer, cutoff_prob=cutoff_prob, - cutoff_top_n=cutoff_top_n) + cutoff_top_n=cutoff_top_n, + blank_id=self.blank_id) results = [result[0][1] for result in beam_search_results] return results def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list, decoding_method): + if decoding_method == "ctc_beam_search": self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path, vocab_list) @@ -229,7 +238,7 @@ class CTCDecoder(nn.Layer): """ctc decoding with probs. Args: - probs (Tenosr): activation after softmax + probs (Tenosr): activation after softmax logits_lens (Tenosr): audio output lens vocab_list ([type]): [description] decoding_method ([type]): [description] diff --git a/deepspeech/modules/loss.py b/deepspeech/modules/loss.py index 3e441bbb..517e1d44 100644 --- a/deepspeech/modules/loss.py +++ b/deepspeech/modules/loss.py @@ -23,7 +23,7 @@ __all__ = ['CTCLoss', "LabelSmoothingLoss"] class CTCLoss(nn.Layer): - def __init__(self, blank=0, reduction='sum', batch_average=False): + def __init__(self, blank=0, reduction='sum', batch_average=False, grad_norm_type=None): super().__init__() # last token id as blank id self.loss = nn.CTCLoss(blank=blank, reduction=reduction) @@ -89,8 +89,8 @@ class LabelSmoothingLoss(nn.Layer): size (int): the number of class padding_idx (int): padding class id which will be ignored for loss smoothing (float): smoothing rate (0.0 means the conventional CE) - normalize_length (bool): - True, normalize loss by sequence length; + normalize_length (bool): + True, normalize loss by sequence length; False, normalize loss by batch size. Defaults to False. """ @@ -107,7 +107,7 @@ class LabelSmoothingLoss(nn.Layer): The model outputs and data labels tensors are flatten to (batch*seqlen, class) shape and a mask is applied to the padding part which should not be calculated for loss. - + Args: x (paddle.Tensor): prediction (batch, seqlen, class) target (paddle.Tensor): diff --git a/deepspeech/training/cli.py b/deepspeech/training/cli.py index b83d989d..e079293c 100644 --- a/deepspeech/training/cli.py +++ b/deepspeech/training/cli.py @@ -14,25 +14,39 @@ import argparse +class ExtendAction(argparse.Action): + """ + [Since Python 3.8, the "extend" is available directly in stdlib] + (https://docs.python.org/3.8/library/argparse.html#action). + If you only have to support 3.8+ then defining it yourself is no longer required. + Usage of stdlib "extend" action is exactly the same way as this answer originally described: + """ + + def __call__(self, parser, namespace, values, option_string=None): + items = getattr(namespace, self.dest) or [] + items.extend(values) + setattr(namespace, self.dest, items) + + def default_argument_parser(): r"""A simple yet genral argument parser for experiments with parakeet. - - This is used in examples with parakeet. And it is intended to be used by - other experiments with parakeet. 
It requires a minimal set of command line
+
+ This is used in examples with parakeet, and it is intended to be used by
+ other experiments with parakeet. It requires a minimal set of command line
arguments to start a training script.
-
- The ``--config`` and ``--opts`` are used for overwrite the deault
+
+ The ``--config`` and ``--opts`` are used to overwrite the default
configuration.
-
- The ``--data`` and ``--output`` specifies the data path and output path.
- Resuming training from existing progress at the output directory is the
+
+ The ``--output`` specifies the output path to save checkpoints and logs.
+ Resuming training from existing progress at the output directory is the
intended default behavior.
-
+
The ``--checkpoint_path`` specifies the checkpoint to load from.
-
- The ``--device`` and ``--nprocs`` specifies how to run the training.
-
-
+
+ The ``--nprocs`` specifies how to run the training.
+
+
See Also
--------
parakeet.training.experiment
@@ -42,33 +56,53 @@ def default_argument_parser():
the parser
"""
parser = argparse.ArgumentParser()
+ parser.register('action', 'extend', ExtendAction)
- # yapf: disable
- # data and output
- parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.")
- parser.add_argument("--dump-config", metavar="FILE", help="dump config to yaml file.")
- # parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.")
- parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoint and logs.")
-
- # load from saved checkpoint
- parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load")
-
- # save jit model to
- parser.add_argument("--export_path", type=str, help="path of the jit model to save")
-
- # save asr result to
- parser.add_argument("--result_file", type=str, help="path of save the asr result")
-
- # running
- parser.add_argument("--device", type=str, default='gpu', choices=["cpu", "gpu"],
- help="device type to use, cpu and gpu are supported.")
- parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.")
+ train_group = parser.add_argument_group(
+ title='Train Options', description=None)
+ train_group.add_argument(
+ "--seed",
+ type=int,
+ default=None,
+ help="seed to use for paddle, np and random. None or 0 for random, else set seed."
+ )
+ train_group.add_argument(
+ "--nprocs",
+ type=int,
+ default=1,
+ help="number of parallel processes.
0 for cpu.") + train_group.add_argument( + "--config", metavar="CONFIG_FILE", help="config file.") + train_group.add_argument( + "--output", metavar="CKPT_DIR", help="path to save checkpoint.") + train_group.add_argument( + "--checkpoint_path", type=str, help="path to load checkpoint") + train_group.add_argument( + "--opts", + action='extend', + nargs=2, + metavar=('key', 'val'), + help="overwrite --config field, passing (KEY VALUE) pairs") + train_group.add_argument( + "--dump-config", metavar="FILE", help="dump config to `this` file.") - # overwrite extra config and default config - # parser.add_argument("--opts", nargs=argparse.REMAINDER, - # help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") - parser.add_argument("--opts", type=str, default=[], nargs='+', - help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") - # yapd: enable + profile_group = parser.add_argument_group( + title='Benchmark Options', description=None) + profile_group.add_argument( + '--profiler-options', + type=str, + default=None, + help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".' + ) + profile_group.add_argument( + '--benchmark-batch-size', + type=int, + default=None, + help='batch size for benchmark.') + profile_group.add_argument( + '--benchmark-max-step', + type=int, + default=None, + help='max iteration for benchmark.') return parser diff --git a/deepspeech/training/extensions/__init__.py b/deepspeech/training/extensions/__init__.py new file mode 100644 index 00000000..6ad04155 --- /dev/null +++ b/deepspeech/training/extensions/__init__.py @@ -0,0 +1,41 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable + +from .extension import Extension + + +def make_extension(trigger: Callable=None, + default_name: str=None, + priority: int=None, + finalizer: Callable=None, + initializer: Callable=None, + on_error: Callable=None): + """Make an Extension-like object by injecting required attributes to it. + """ + if trigger is None: + trigger = Extension.trigger + if priority is None: + priority = Extension.priority + + def decorator(ext): + ext.trigger = trigger + ext.default_name = default_name or ext.__name__ + ext.priority = priority + ext.finalize = finalizer + ext.on_error = on_error + ext.initialize = initializer + return ext + + return decorator diff --git a/deepspeech/training/extensions/evaluator.py b/deepspeech/training/extensions/evaluator.py new file mode 100644 index 00000000..1026a4ec --- /dev/null +++ b/deepspeech/training/extensions/evaluator.py @@ -0,0 +1,101 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict + +import paddle +from paddle import distributed as dist +from paddle.io import DataLoader +from paddle.nn import Layer + +from . import extension +from ..reporter import DictSummary +from ..reporter import ObsScope +from ..reporter import report +from ..timer import Timer +from deepspeech.utils.log import Log +logger = Log(__name__).getlog() + + +class StandardEvaluator(extension.Extension): + + trigger = (1, 'epoch') + default_name = 'validation' + priority = extension.PRIORITY_WRITER + + name = None + + def __init__(self, model: Layer, dataloader: DataLoader): + # it is designed to hold multiple models + models = {"main": model} + self.models: Dict[str, Layer] = models + self.model = model + + # dataloaders + self.dataloader = dataloader + + def evaluate_core(self, batch): + # compute + self.model(batch) # you may report here + return + + def evaluate_sync(self, data): + # dist sync `evaluate_core` outputs + if data is None: + return + + numerator, denominator = data + if dist.get_world_size() > 1: + numerator = paddle.to_tensor(numerator) + denominator = paddle.to_tensor(denominator) + # the default operator in all_reduce function is sum. + dist.all_reduce(numerator) + dist.all_reduce(denominator) + value = numerator / denominator + value = float(value) + else: + value = numerator / denominator + # used for `snapshort` to do kbest save. + report("VALID/LOSS", value) + logger.info(f"Valid: all-reduce loss {value}") + + def evaluate(self): + # switch to eval mode + for model in self.models.values(): + model.eval() + + # to average evaluation metrics + summary = DictSummary() + for batch in self.dataloader: + observation = {} + with ObsScope(observation): + # main evaluation computation here. + with paddle.no_grad(): + self.evaluate_sync(self.evaluate_core(batch)) + summary.add(observation) + summary = summary.compute_mean() + + # switch to train mode + for model in self.models.values(): + model.train() + return summary + + def __call__(self, trainer=None): + # evaluate and report the averaged metric to current observation + # if it is used to extend a trainer, the metrics is reported to + # to observation of the trainer + # or otherwise, you can use your own observation + with Timer("Eval Time Cost: {}"): + summary = self.evaluate() + for k, v in summary.items(): + report(k, v) diff --git a/deepspeech/training/extensions/extension.py b/deepspeech/training/extensions/extension.py new file mode 100644 index 00000000..02f92495 --- /dev/null +++ b/deepspeech/training/extensions/extension.py @@ -0,0 +1,52 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+PRIORITY_WRITER = 300
+PRIORITY_EDITOR = 200
+PRIORITY_READER = 100
+
+
+class Extension():
+ """Extension to customize the behavior of Trainer."""
+ trigger = (1, 'iteration')
+ priority = PRIORITY_READER
+ name = None
+
+ @property
+ def default_name(self):
+ """Default name of the extension, class name by default."""
+ return type(self).__name__
+
+ def __call__(self, trainer):
+ """Main action of the extension. After each update, it is executed
+ when the trigger fires."""
+ raise NotImplementedError(
+ 'Extension implementation must override __call__.')
+
+ def initialize(self, trainer):
+ """Action that is executed once to get the correct trainer state.
+ It is called before training normally, but if the trainer restores
+ states with a Snapshot extension, this method should also be called.
+ """
+ pass
+
+ def on_error(self, trainer, exc, tb):
+ """Handles the error raised during training before finalization.
+ """
+ pass
+
+ def finalize(self, trainer):
+ """Action that is executed when training is done.
+ For example, visualizers would need to be closed.
+ """
+ pass
diff --git a/deepspeech/training/extensions/snapshot.py b/deepspeech/training/extensions/snapshot.py
new file mode 100644
index 00000000..e81eb97f
--- /dev/null
+++ b/deepspeech/training/extensions/snapshot.py
@@ -0,0 +1,133 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from datetime import datetime
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+
+import jsonlines
+
+from . import extension
+from ..reporter import get_observations
+from ..updaters.trainer import Trainer
+from deepspeech.utils.log import Log
+from deepspeech.utils.mp_tools import rank_zero_only
+
+logger = Log(__name__).getlog()
+
+
+def load_records(records_fp):
+ """Load record files (json lines)."""
+ with jsonlines.open(records_fp, 'r') as reader:
+ records = list(reader)
+ return records
+
+
+class Snapshot(extension.Extension):
+ """An extension to make a snapshot of the updater object inside
+ the trainer. It is done by calling the updater's `save` method.
+ An Updater saves its state_dict by default, which contains the
+ updater state (i.e. epoch and iteration) and all the model
+ parameters and optimizer states. If the updater inside the trainer
+ subclasses StandardUpdater, everything is good to go.
+ Parameters
+ ----------
+ checkpoint_dir : Union[str, Path]
+ The directory to save checkpoints into.
+ """ + + trigger = (1, 'epoch') + priority = -100 + default_name = "snapshot" + + def __init__(self, + mode='latest', + max_size: int=5, + indicator=None, + less_better=True, + snapshot_on_error: bool=False): + self.records: List[Dict[str, Any]] = [] + assert mode in ('latest', 'kbest'), mode + if mode == 'kbest': + assert indicator is not None + self.mode = mode + self.indicator = indicator + self.less_is_better = less_better + self.max_size = max_size + self._snapshot_on_error = snapshot_on_error + self._save_all = (max_size == -1) + self.checkpoint_dir = None + + def initialize(self, trainer: Trainer): + """Setting up this extention.""" + self.checkpoint_dir = trainer.out / "checkpoints" + + # load existing records + record_path: Path = self.checkpoint_dir / "records.jsonl" + if record_path.exists(): + self.records = load_records(record_path) + ckpt_path = self.records[-1]['path'] + logger.info(f"Loading from an existing checkpoint {ckpt_path}") + trainer.updater.load(ckpt_path) + + def on_error(self, trainer, exc, tb): + if self._snapshot_on_error: + self.save_checkpoint_and_update(trainer, 'latest') + + def __call__(self, trainer: Trainer): + self.save_checkpoint_and_update(trainer, self.mode) + + def full(self): + """Whether the number of snapshots it keeps track of is greater + than the max_size.""" + return (not self._save_all) and len(self.records) > self.max_size + + @rank_zero_only + def save_checkpoint_and_update(self, trainer: Trainer, mode: str): + """Saving new snapshot and remove the oldest snapshot if needed.""" + iteration = trainer.updater.state.iteration + epoch = trainer.updater.state.epoch + num = epoch if self.trigger[1] == 'epoch' else iteration + path = self.checkpoint_dir / f"{num}.np" + + # add the new one + trainer.updater.save(path) + record = { + "time": str(datetime.now()), + 'path': str(path.resolve()), # use absolute path + 'iteration': iteration, + 'epoch': epoch, + 'indicator': get_observations()[self.indicator] + } + self.records.append(record) + + # remove the earist + if self.full(): + if mode == 'kbest': + self.records = sorted( + self.records, + key=lambda record: record['indicator'], + reverse=not self.less_is_better) + eariest_record = self.records[0] + os.remove(eariest_record["path"]) + self.records.pop(0) + + # update the record file + record_path = self.checkpoint_dir / "records.jsonl" + with jsonlines.open(record_path, 'w') as writer: + for record in self.records: + # jsonlines.open may return a Writer or a Reader + writer.write(record) # pylint: disable=no-member diff --git a/deepspeech/training/extensions/visualizer.py b/deepspeech/training/extensions/visualizer.py new file mode 100644 index 00000000..e5f456ca --- /dev/null +++ b/deepspeech/training/extensions/visualizer.py @@ -0,0 +1,39 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from visualdl import LogWriter + +from . 
import extension +from ..updaters.trainer import Trainer + + +class VisualDL(extension.Extension): + """A wrapper of visualdl log writer. It assumes that the metrics to be visualized + are all scalars which are recorded into the `.observation` dictionary of the + trainer object. The dictionary is created for each step, thus the visualdl log + writer uses the iteration from the updater's `iteration` as the global step to + add records. + """ + trigger = (1, 'iteration') + default_name = 'visualdl' + priority = extension.PRIORITY_READER + + def __init__(self, output_dir): + self.writer = LogWriter(str(output_dir)) + + def __call__(self, trainer: Trainer): + for k, v in trainer.observation.items(): + self.writer.add_scalar(k, v, step=trainer.updater.state.iteration) + + def finalize(self, trainer): + self.writer.close() diff --git a/deepspeech/training/gradclip.py b/deepspeech/training/gradclip.py index d0f9803d..87b36aca 100644 --- a/deepspeech/training/gradclip.py +++ b/deepspeech/training/gradclip.py @@ -27,6 +27,9 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm): def __init__(self, clip_norm): super().__init__(clip_norm) + def __repr__(self): + return f"{self.__class__.__name__}(global_clip_norm={self.clip_norm})" + @imperative_base.no_grad def _dygraph_clip(self, params_grads): params_and_grads = [] @@ -44,7 +47,7 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm): sum_square = layers.reduce_sum(square) sum_square_list.append(sum_square) - # debug log + # debug log, not dump all since slow down train process if i < 10: logger.debug( f"Grad Before Clip: {p.name}: {float(sum_square.sqrt()) }") @@ -73,7 +76,7 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm): new_grad = layers.elementwise_mul(x=g, y=clip_var) params_and_grads.append((p, new_grad)) - # debug log + # debug log, not dump all since slow down train process if i < 10: logger.debug( f"Grad After Clip: {p.name}: {float(new_grad.square().sum().sqrt())}" diff --git a/deepspeech/training/optimizer.py b/deepspeech/training/optimizer.py new file mode 100644 index 00000000..db7069c9 --- /dev/null +++ b/deepspeech/training/optimizer.py @@ -0,0 +1,121 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
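# ---------------------------------------------------------------------------
# NOTE: `ClipGradByGlobalNormWithLog` above is a drop-in replacement for
# `paddle.nn.ClipGradByGlobalNorm` that also logs the norms of the first few
# gradients before and after clipping. A hedged usage sketch; the linear
# model and random batch are placeholders, not part of this patch.
# ---------------------------------------------------------------------------
import paddle

from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog

model = paddle.nn.Linear(8, 2)                       # placeholder model
clip = ClipGradByGlobalNormWithLog(clip_norm=5.0)    # logs grad norms while clipping
optimizer = paddle.optimizer.Adam(
    parameters=model.parameters(), grad_clip=clip)

loss = model(paddle.randn([4, 8])).mean()            # placeholder loss
loss.backward()
optimizer.step()          # clipping (and the debug log) happens inside step()
optimizer.clear_grad()
# ---------------------------------------------------------------------------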
+from typing import Any +from typing import Dict +from typing import Text + +import paddle +from paddle.optimizer import Optimizer +from paddle.regularizer import L2Decay + +from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog +from deepspeech.utils.dynamic_import import dynamic_import +from deepspeech.utils.dynamic_import import instance_class +from deepspeech.utils.log import Log + +__all__ = ["OptimizerFactory"] + +logger = Log(__name__).getlog() + +OPTIMIZER_DICT = { + "sgd": "paddle.optimizer:SGD", + "momentum": "paddle.optimizer:Momentum", + "adadelta": "paddle.optimizer:Adadelta", + "adam": "paddle.optimizer:Adam", + "adamw": "paddle.optimizer:AdamW", +} + + +def register_optimizer(cls): + """Register optimizer.""" + alias = cls.__name__.lower() + OPTIMIZER_DICT[cls.__name__.lower()] = cls.__module__ + ":" + cls.__name__ + return cls + + +@register_optimizer +class Noam(paddle.optimizer.Adam): + """Seem to: espnet/nets/pytorch_backend/transformer/optimizer.py """ + + def __init__(self, + learning_rate=0, + beta1=0.9, + beta2=0.98, + epsilon=1e-9, + parameters=None, + weight_decay=None, + grad_clip=None, + lazy_mode=False, + multi_precision=False, + name=None): + super().__init__( + learning_rate=learning_rate, + beta1=beta1, + beta2=beta2, + epsilon=epsilon, + parameters=parameters, + weight_decay=weight_decay, + grad_clip=grad_clip, + lazy_mode=lazy_mode, + multi_precision=multi_precision, + name=name) + + def __repr__(self): + echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> " + echo += f"learning_rate: {self._learning_rate}, " + echo += f"(beta1: {self._beta1} beta2: {self._beta2}), " + echo += f"epsilon: {self._epsilon}" + + +def dynamic_import_optimizer(module): + """Import Optimizer class dynamically. + + Args: + module (str): module_name:class_name or alias in `OPTIMIZER_DICT` + + Returns: + type: Optimizer class + + """ + module_class = dynamic_import(module, OPTIMIZER_DICT) + assert issubclass(module_class, + Optimizer), f"{module} does not implement Optimizer" + return module_class + + +class OptimizerFactory(): + @classmethod + def from_args(cls, name: str, args: Dict[Text, Any]): + assert "parameters" in args, "parameters not in args." + assert "learning_rate" in args, "learning_rate not in args." + + grad_clip = ClipGradByGlobalNormWithLog( + args['grad_clip']) if "grad_clip" in args else None + weight_decay = L2Decay( + args['weight_decay']) if "weight_decay" in args else None + if weight_decay: + logger.info(f'') + if grad_clip: + logger.info(f'') + + module_class = dynamic_import_optimizer(name.lower()) + args.update({"grad_clip": grad_clip, "weight_decay": weight_decay}) + opt = instance_class(module_class, args) + if "__repr__" in vars(opt): + logger.info(f"{opt}") + else: + logger.info( + f" LR: {args['learning_rate']}" + ) + return opt diff --git a/deepspeech/training/reporter.py b/deepspeech/training/reporter.py new file mode 100644 index 00000000..7afc33f3 --- /dev/null +++ b/deepspeech/training/reporter.py @@ -0,0 +1,144 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import math +from collections import defaultdict + +OBSERVATIONS = None + + +@contextlib.contextmanager +def ObsScope(observations): + # make `observation` the target to report to. + # it is basically a dictionary that stores temporary observations + global OBSERVATIONS + old = OBSERVATIONS + OBSERVATIONS = observations + + try: + yield + finally: + OBSERVATIONS = old + + +def get_observations(): + global OBSERVATIONS + return OBSERVATIONS + + +def report(name, value): + # a simple function to report named value + # you can use it everywhere, it will get the default target and writ to it + # you can think of it as std.out + observations = get_observations() + if observations is None: + return + else: + observations[name] = value + + +class Summary(): + """Online summarization of a sequence of scalars. + Summary computes the statistics of given scalars online. + """ + + def __init__(self): + self._x = 0.0 + self._x2 = 0.0 + self._n = 0 + + def add(self, value, weight=1): + """Adds a scalar value. + Args: + value: Scalar value to accumulate. It is either a NumPy scalar or + a zero-dimensional array (on CPU or GPU). + weight: An optional weight for the value. It is a NumPy scalar or + a zero-dimensional array (on CPU or GPU). + Default is 1 (integer). + """ + self._x += weight * value + self._x2 += weight * value * value + self._n += weight + + def compute_mean(self): + """Computes the mean.""" + x, n = self._x, self._n + return x / n + + def make_statistics(self): + """Computes and returns the mean and standard deviation values. + Returns: + tuple: Mean and standard deviation values. + """ + x, n = self._x, self._n + mean = x / n + var = self._x2 / n - mean * mean + std = math.sqrt(var) + return mean, std + + +class DictSummary(): + """Online summarization of a sequence of dictionaries. + ``DictSummary`` computes the statistics of a given set of scalars online. + It only computes the statistics for scalar values and variables of scalar + values in the dictionaries. + """ + + def __init__(self): + self._summaries = defaultdict(Summary) + + def add(self, d): + """Adds a dictionary of scalars. + Args: + d (dict): Dictionary of scalars to accumulate. Only elements of + scalars, zero-dimensional arrays, and variables of + zero-dimensional arrays are accumulated. When the value + is a tuple, the second element is interpreted as a weight. + """ + summaries = self._summaries + for k, v in d.items(): + w = 1 + if isinstance(v, tuple): + v = v[0] + w = v[1] + summaries[k].add(v, weight=w) + + def compute_mean(self): + """Creates a dictionary of mean values. + It returns a single dictionary that holds a mean value for each entry + added to the summary. + Returns: + dict: Dictionary of mean values. + """ + return { + name: summary.compute_mean() + for name, summary in self._summaries.items() + } + + def make_statistics(self): + """Creates a dictionary of statistics. + It returns a single dictionary that holds mean and standard deviation + values for every entry added to the summary. 
For an entry of name + ``'key'``, these values are added to the dictionary by names ``'key'`` + and ``'key.std'``, respectively. + Returns: + dict: Dictionary of statistics of all entries. + """ + stats = {} + for name, summary in self._summaries.items(): + mean, std = summary.make_statistics() + stats[name] = mean + stats[name + '.std'] = std + + return stats diff --git a/deepspeech/training/scheduler.py b/deepspeech/training/scheduler.py index d3613028..bb53281a 100644 --- a/deepspeech/training/scheduler.py +++ b/deepspeech/training/scheduler.py @@ -11,18 +11,37 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any +from typing import Dict +from typing import Text from typing import Union from paddle.optimizer.lr import LRScheduler from typeguard import check_argument_types +from deepspeech.utils.dynamic_import import dynamic_import +from deepspeech.utils.dynamic_import import instance_class from deepspeech.utils.log import Log -__all__ = ["WarmupLR"] +__all__ = ["WarmupLR", "LRSchedulerFactory"] logger = Log(__name__).getlog() +SCHEDULER_DICT = { + "noam": "paddle.optimizer.lr:NoamDecay", + "expdecaylr": "paddle.optimizer.lr:ExponentialDecay", + "piecewisedecay": "paddle.optimizer.lr:PiecewiseDecay", +} + +def register_scheduler(cls): + """Register scheduler.""" + alias = cls.__name__.lower() + SCHEDULER_DICT[cls.__name__.lower()] = cls.__module__ + ":" + cls.__name__ + return cls + + +@register_scheduler class WarmupLR(LRScheduler): """The WarmupLR scheduler This scheduler is almost same as NoamLR Scheduler except for following @@ -40,7 +59,8 @@ class WarmupLR(LRScheduler): warmup_steps: Union[int, float]=25000, learning_rate=1.0, last_epoch=-1, - verbose=False): + verbose=False, + **kwargs): assert check_argument_types() self.warmup_steps = warmup_steps super().__init__(learning_rate, last_epoch, verbose) @@ -64,3 +84,45 @@ class WarmupLR(LRScheduler): None ''' self.step(epoch=step) + + +@register_scheduler +class ConstantLR(LRScheduler): + """ + Args: + learning_rate (float): The initial learning rate. It is a python float number. + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` . + + Returns: + ``ConstantLR`` instance to schedule learning rate. + """ + + def __init__(self, learning_rate, last_epoch=-1, verbose=False): + super().__init__(learning_rate, last_epoch, verbose) + + def get_lr(self): + return self.base_lr + + +def dynamic_import_scheduler(module): + """Import Scheduler class dynamically. + + Args: + module (str): module_name:class_name or alias in `SCHEDULER_DICT` + + Returns: + type: Scheduler class + + """ + module_class = dynamic_import(module, SCHEDULER_DICT) + assert issubclass(module_class, + LRScheduler), f"{module} does not implement LRScheduler" + return module_class + + +class LRSchedulerFactory(): + @classmethod + def from_args(cls, name: str, args: Dict[Text, Any]): + module_class = dynamic_import_scheduler(name.lower()) + return instance_class(module_class, args) diff --git a/deepspeech/training/timer.py b/deepspeech/training/timer.py new file mode 100644 index 00000000..2ca9d638 --- /dev/null +++ b/deepspeech/training/timer.py @@ -0,0 +1,50 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime +import time + +from deepspeech.utils.log import Log + +__all__ = ["Timer"] + +logger = Log(__name__).getlog() + + +class Timer(): + """To be used like this: + with Timer("Message") as value: + do some thing + """ + + def __init__(self, message=None): + self.message = message + + def duration(self) -> str: + elapsed_time = time.time() - self.start + time_str = str(datetime.timedelta(seconds=elapsed_time)) + return time_str + + def __enter__(self): + self.start = time.time() + return self + + def __exit__(self, type, value, traceback): + if self.message: + logger.info(self.message.format(self.duration())) + + def __call__(self) -> float: + return time.time() - self.start + + def __str__(self): + return self.duration() diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 11e5f214..70d7ec1f 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import time +from contextlib import contextmanager from pathlib import Path import paddle @@ -29,37 +30,37 @@ logger = Log(__name__).getlog() class Trainer(): """ - An experiment template in order to structure the training code and take - care of saving, loading, logging, visualization stuffs. It's intended to - be flexible and simple. - - So it only handles output directory (create directory for the output, - create a checkpoint directory, dump the config in use and create + An experiment template in order to structure the training code and take + care of saving, loading, logging, visualization stuffs. It's intended to + be flexible and simple. + + So it only handles output directory (create directory for the output, + create a checkpoint directory, dump the config in use and create visualizer and logger) in a standard way without enforcing any - input-output protocols to the model and dataloader. It leaves the main - part for the user to implement their own (setup the model, criterion, - optimizer, define a training step, define a validation function and + input-output protocols to the model and dataloader. It leaves the main + part for the user to implement their own (setup the model, criterion, + optimizer, define a training step, define a validation function and customize all the text and visual logs). - It does not save too much boilerplate code. The users still have to write - the forward/backward/update mannually, but they are free to add + It does not save too much boilerplate code. The users still have to write + the forward/backward/update mannually, but they are free to add non-standard behaviors if needed. We have some conventions to follow. - 1. Experiment should have ``model``, ``optimizer``, ``train_loader`` and + 1. Experiment should have ``model``, ``optimizer``, ``train_loader`` and ``valid_loader``, ``config`` and ``args`` attributes. - 2. 
The config should have a ``training`` field, which has - ``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is - used as the trigger to invoke validation, checkpointing and stop of the + 2. The config should have a ``training`` field, which has + ``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is + used as the trigger to invoke validation, checkpointing and stop of the experiment. - 3. There are four methods, namely ``train_batch``, ``valid``, + 3. There are four methods, namely ``train_batch``, ``valid``, ``setup_model`` and ``setup_dataloader`` that should be implemented. - Feel free to add/overwrite other methods and standalone functions if you + Feel free to add/overwrite other methods and standalone functions if you need. - + Parameters ---------- config: yacs.config.CfgNode The configuration used for the experiment. - + args: argparse.Namespace The parsed command line arguments. Examples @@ -68,17 +69,17 @@ class Trainer(): >>> exp = Trainer(config, args) >>> exp.setup() >>> exp.run() - >>> + >>> >>> config = get_cfg_defaults() >>> parser = default_argument_parser() >>> args = parser.parse_args() - >>> if args.config: + >>> if args.config: >>> config.merge_from_file(args.config) >>> if args.opts: >>> config.merge_from_list(args.opts) >>> config.freeze() - >>> - >>> if args.nprocs > 1 and args.device == "gpu": + >>> + >>> if args.nprocs > 0: >>> dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) >>> else: >>> main_sp(config, args) @@ -93,18 +94,24 @@ class Trainer(): self.checkpoint_dir = None self.iteration = 0 self.epoch = 0 + self._train = True - def setup(self): - """Setup the experiment. - """ - paddle.set_device(self.args.device) + paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu') if self.parallel: self.init_parallel() + @contextmanager + def eval(self): + self._train = False + yield + self._train = True + + def setup(self): + """Setup the experiment. + """ self.setup_output_dir() self.dump_config() self.setup_visualizer() - self.setup_checkpointer() self.setup_dataloader() self.setup_model() @@ -114,10 +121,10 @@ class Trainer(): @property def parallel(self): - """A flag indicating whether the experiment should run with + """A flag indicating whether the experiment should run with multiprocessing. """ - return self.args.device == "gpu" and self.args.nprocs > 1 + return self.args.nprocs > 1 def init_parallel(self): """Init environment for multiprocess training. @@ -139,14 +146,14 @@ class Trainer(): "epoch": self.epoch, "lr": self.optimizer.get_lr() }) - checkpoint.save_parameters(self.checkpoint_dir, self.iteration + Checkpoint().save_parameters(self.checkpoint_dir, self.iteration if tag is None else tag, self.model, self.optimizer, infos) def resume_or_scratch(self): - """Resume from latest checkpoint at checkpoints in the output + """Resume from latest checkpoint at checkpoints in the output directory or load a specified checkpoint. - + If ``args.checkpoint_path`` is not None, load the checkpoint, else resume training. 
""" @@ -158,8 +165,8 @@ class Trainer(): checkpoint_path=self.args.checkpoint_path) if infos: # restore from ckpt - self.iteration = infos["step"] - self.epoch = infos["epoch"] + self.iteration = infos["step"] + 1 + self.epoch = infos["epoch"] + 1 scratch = False else: self.iteration = 0 @@ -237,31 +244,61 @@ class Trainer(): try: self.train() except KeyboardInterrupt: - self.save() exit(-1) finally: self.destory() - logger.info("Training Done.") + logger.info("Train Done.") + + def run_test(self): + """Do Test/Decode""" + with self.eval(): + self.resume_or_scratch() + try: + self.test() + except KeyboardInterrupt: + exit(-1) + logger.info("Test/Decode Done.") + + def run_export(self): + """Do Model Export""" + with self.eval(): + try: + self.export() + except KeyboardInterrupt: + exit(-1) + logger.info("Export Done.") def setup_output_dir(self): """Create a directory used for output. """ - # output dir - output_dir = Path(self.args.output).expanduser() - output_dir.mkdir(parents=True, exist_ok=True) - + if self.args.output: + output_dir = Path(self.args.output).expanduser() + elif self.args.checkpoint_path: + output_dir = Path( + self.args.checkpoint_path).expanduser().parent.parent self.output_dir = output_dir + self.output_dir.mkdir(parents=True, exist_ok=True) - def setup_checkpointer(self): - """Create a directory used to save checkpoints into. - - It is "checkpoints" inside the output directory. - """ - # checkpoint dir - checkpoint_dir = self.output_dir / "checkpoints" - checkpoint_dir.mkdir(exist_ok=True) + self.checkpoint_dir = self.output_dir / "checkpoints" + self.checkpoint_dir.mkdir(parents=True, exist_ok=True) + + self.log_dir = output_dir / "log" + self.log_dir.mkdir(parents=True, exist_ok=True) + + self.test_dir = output_dir / "test" + self.test_dir.mkdir(parents=True, exist_ok=True) + + self.decode_dir = output_dir / "decode" + self.decode_dir.mkdir(parents=True, exist_ok=True) - self.checkpoint_dir = checkpoint_dir + self.export_dir = output_dir / "export" + self.export_dir.mkdir(parents=True, exist_ok=True) + + self.visual_dir = output_dir / "visual" + self.visual_dir.mkdir(parents=True, exist_ok=True) + + self.config_dir = output_dir / "conf" + self.config_dir.mkdir(parents=True, exist_ok=True) @mp_tools.rank_zero_only def destory(self): @@ -273,27 +310,34 @@ class Trainer(): @mp_tools.rank_zero_only def setup_visualizer(self): """Initialize a visualizer to log the experiment. - + The visual log is saved in the output directory. - + Notes ------ - Only the main process has a visualizer with it. Use multiple - visualizers in multiprocess to write to a same log file may cause + Only the main process has a visualizer with it. Use multiple + visualizers in multiprocess to write to a same log file may cause unexpected behaviors. """ # visualizer - visualizer = SummaryWriter(logdir=str(self.output_dir)) + visualizer = SummaryWriter(logdir=str(self.visual_dir)) self.visualizer = visualizer @mp_tools.rank_zero_only def dump_config(self): - """Save the configuration used for this experiment. - - It is saved in to ``config.yaml`` in the output directory at the + """Save the configuration used for this experiment. + + It is saved in to ``config.yaml`` in the output directory at the beginning of the experiment. 
""" - with open(self.output_dir / "config.yaml", 'wt') as f: + config_file = self.config_dir / "config.yaml" + if self._train and config_file.exists(): + time_stamp = time.strftime("%Y_%m_%d_%H_%M_%s", time.gmtime()) + target_path = self.config_dir / ".".join( + [time_stamp, "config.yaml"]) + config_file.rename(target_path) + + with open(config_file, 'wt') as f: print(self.config, file=f) def train_batch(self): @@ -307,14 +351,26 @@ class Trainer(): """ raise NotImplementedError("valid should be implemented.") + @paddle.no_grad() + def test(self): + """The test. A subclass should implement this method in Tester. + """ + raise NotImplementedError("test should be implemented.") + + @paddle.no_grad() + def export(self): + """The test. A subclass should implement this method in Tester. + """ + raise NotImplementedError("export should be implemented.") + def setup_model(self): - """Setup model, criterion and optimizer, etc. A subclass should + """Setup model, criterion and optimizer, etc. A subclass should implement this method. """ raise NotImplementedError("setup_model should be implemented.") def setup_dataloader(self): - """Setup training dataloader and validation dataloader. A subclass + """Setup training dataloader and validation dataloader. A subclass should implement this method. """ raise NotImplementedError("setup_dataloader should be implemented.") diff --git a/deepspeech/training/triggers/__init__.py b/deepspeech/training/triggers/__init__.py new file mode 100644 index 00000000..1a7c4292 --- /dev/null +++ b/deepspeech/training/triggers/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .interval_trigger import IntervalTrigger + + +def never_fail_trigger(trainer): + return False + + +def get_trigger(trigger): + if trigger is None: + return never_fail_trigger + if callable(trigger): + return trigger + else: + trigger = IntervalTrigger(*trigger) + return trigger diff --git a/deepspeech/training/triggers/interval_trigger.py b/deepspeech/training/triggers/interval_trigger.py new file mode 100644 index 00000000..1e04afad --- /dev/null +++ b/deepspeech/training/triggers/interval_trigger.py @@ -0,0 +1,38 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +class IntervalTrigger(): + """A Predicate to do something every N cycle.""" + + def __init__(self, period: int, unit: str): + if unit not in ("iteration", "epoch"): + raise ValueError("unit should be 'iteration' or 'epoch'") + if period <= 0: + raise ValueError("period should be a positive integer.") + self.period = period + self.unit = unit + self.last_index = None + + def __call__(self, trainer): + if self.last_index is None: + last_index = getattr(trainer.updater.state, self.unit) + self.last_index = last_index + + last_index = self.last_index + index = getattr(trainer.updater.state, self.unit) + fire = index // self.period != last_index // self.period + + self.last_index = index + return fire diff --git a/deepspeech/training/triggers/limit_trigger.py b/deepspeech/training/triggers/limit_trigger.py new file mode 100644 index 00000000..ecd527ac --- /dev/null +++ b/deepspeech/training/triggers/limit_trigger.py @@ -0,0 +1,31 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class LimitTrigger(): + """A Predicate to decide whether to stop.""" + + def __init__(self, limit: int, unit: str): + if unit not in ("iteration", "epoch"): + raise ValueError("unit should be 'iteration' or 'epoch'") + if limit <= 0: + raise ValueError("limit should be a positive integer.") + self.limit = limit + self.unit = unit + + def __call__(self, trainer): + state = trainer.updater.state + index = getattr(state, self.unit) + fire = index >= self.limit + return fire diff --git a/deepspeech/training/triggers/time_trigger.py b/deepspeech/training/triggers/time_trigger.py new file mode 100644 index 00000000..ea8fe562 --- /dev/null +++ b/deepspeech/training/triggers/time_trigger.py @@ -0,0 +1,32 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class TimeTrigger(): + """Trigger based on a fixed time interval. + This trigger accepts iterations with a given interval time. + Args: + period (float): Interval time. It is given in seconds. 
+    """
+
+    def __init__(self, period):
+        self._period = period
+        self._next_time = self._period
+
+    def __call__(self, trainer):
+        if self._next_time < trainer.elapsed_time:
+            self._next_time += self._period
+            return True
+        else:
+            return False
diff --git a/deepspeech/training/updaters/__init__.py b/deepspeech/training/updaters/__init__.py
new file mode 100644
index 00000000..185a92b8
--- /dev/null
+++ b/deepspeech/training/updaters/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/deepspeech/training/updaters/standard_updater.py b/deepspeech/training/updaters/standard_updater.py
new file mode 100644
index 00000000..10c99e7f
--- /dev/null
+++ b/deepspeech/training/updaters/standard_updater.py
@@ -0,0 +1,196 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict
+from typing import Optional
+
+import paddle
+from paddle.io import DataLoader
+from paddle.io import DistributedBatchSampler
+from paddle.nn import Layer
+from paddle.optimizer import Optimizer
+from paddle.optimizer.lr import LRScheduler
+
+from deepspeech.training.reporter import report
+from deepspeech.training.timer import Timer
+from deepspeech.training.updaters.updater import UpdaterBase
+from deepspeech.training.updaters.updater import UpdaterState
+from deepspeech.utils.log import Log
+
+__all__ = ["StandardUpdater"]
+
+logger = Log(__name__).getlog()
+
+
+class StandardUpdater(UpdaterBase):
+    """An over-simplified example. Things may not be that simple, but
+    you can subclass it to fit your needs.
+    """
+
+    def __init__(self,
+                 model: Layer,
+                 optimizer: Optimizer,
+                 scheduler: LRScheduler,
+                 dataloader: DataLoader,
+                 init_state: Optional[UpdaterState]=None):
+        super().__init__(init_state)
+        # it is designed to hold multiple models
+        models = {"main": model}
+        self.models: Dict[str, Layer] = models
+        self.model = model
+
+        # it is designed to hold multiple optimizers
+        optimizers = {"main": optimizer}
+        self.optimizer = optimizer
+        self.optimizers: Dict[str, Optimizer] = optimizers
+
+        # it is designed to hold multiple schedulers
+        schedulers = {"main": scheduler}
+        self.scheduler = scheduler
+        self.schedulers: Dict[str, LRScheduler] = schedulers
+
+        # dataloaders
+        self.dataloader = dataloader
+
+        self.train_iterator = iter(dataloader)
+
+    def update(self):
+        # We increase the iteration index after updating and before running
+        # the extensions. Here are the reasons.
+
+        # 1. Snapshotting (as well as other extensions, like the visualizer)
+        #    is executed after a step of updating;
+        # 2. We decide to increase the iteration index after updating and
+        #    before any extension is executed.
+        # 3. We do not increase the iteration after the extensions because we
+        #    prefer a consistent resume behavior: when loading from a
+        #    `snapshot_iter_100.pdz`, the next step to train is `101`,
+        #    naturally. But if the iteration were increased after the
+        #    extensions (including snapshot), then a `snapshot_iter_99` would
+        #    be loaded. You would need an extra increase of the iteration
+        #    index before training to avoid repeating iteration `99`, which
+        #    has already been done before snapshotting.
+        # 4. Thus the iteration index represents "how many steps have been
+        #    done so far."
+        # NOTE: use report to capture the correct value. If you want to
+        # report the learning rate used for a step, you must report it before
+        # the learning rate scheduler's step() has been called. In paddle's
+        # convention, we do not use an extension to change the learning rate,
+        # so if you want to report it, do it in the updater.
+
+        # Then here comes the next question: when is the proper time to
+        # increase the epoch index? Since all extensions are executed after
+        # updating, after updating is also the proper time to increase the
+        # epoch index.
+        # 1. If we increased the epoch index before updating, then an
+        #    extension based on epoch would miss the correct timing. It could
+        #    only be triggered after an extra update.
+        # 2. Theoretically, when an epoch is done, the epoch index should be
+        #    increased. So it is increased after updating.
+        # 3. Thus, the epoch index represents "how many epochs have been
+        #    done." So it starts from 0.
+
+        # switch to training mode
+        for model in self.models.values():
+            model.train()
+
+        # training for a step is implemented here
+        with Timer("data time cost:{}"):
+            batch = self.read_batch()
+        with Timer("step time cost:{}"):
+            self.update_core(batch)
+
+        self.state.iteration += 1
+        if self.updates_per_epoch is not None:
+            if self.state.iteration % self.updates_per_epoch == 0:
+                self.state.epoch += 1
+
+    def update_core(self, batch):
+        """A simple case for a training step. Basic assumptions are:
+        Single model;
+        Single optimizer;
+        Single scheduler, and the learning rate is updated each step;
+        A batch from the dataloader is just the input of the model;
+        The model returns a single loss, or a dict containing several losses.
+        Parameters are updated at every batch; no gradient accumulation.
+ """ + loss = self.model(*batch) + + if isinstance(loss, paddle.Tensor): + loss_dict = {"main": loss} + else: + # Dict[str, Tensor] + loss_dict = loss + if "main" not in loss_dict: + main_loss = 0 + for loss_item in loss.values(): + main_loss += loss_item + loss_dict["main"] = main_loss + + for name, loss_item in loss_dict.items(): + report(name, float(loss_item)) + + self.optimizer.clear_grad() + loss_dict["main"].backward() + self.optimizer.step() + self.scheduler.step() + + @property + def updates_per_epoch(self): + """Number of steps per epoch, + determined by the length of the dataloader.""" + length_of_dataloader = None + try: + length_of_dataloader = len(self.dataloader) + except TypeError: + logger.debug("This dataloader has no __len__.") + finally: + return length_of_dataloader + + def new_epoch(self): + """Start a new epoch.""" + # NOTE: all batch sampler for distributed training should + # subclass DistributedBatchSampler and implement `set_epoch` method + if hasattr(self.dataloader, "batch_sampler"): + batch_sampler = self.dataloader.batch_sampler + if isinstance(batch_sampler, DistributedBatchSampler): + batch_sampler.set_epoch(self.state.epoch) + self.train_iterator = iter(self.dataloader) + + def read_batch(self): + """Read a batch from the data loader, auto renew when data is exhausted.""" + try: + batch = next(self.train_iterator) + except StopIteration: + self.new_epoch() + batch = next(self.train_iterator) + return batch + + def state_dict(self): + """State dict of a Updater, model, optimizers/schedulers + and updater state are included.""" + state_dict = super().state_dict() + for name, model in self.models.items(): + state_dict[f"{name}_params"] = model.state_dict() + for name, optim in self.optimizers.items(): + state_dict[f"{name}_optimizer"] = optim.state_dict() + return state_dict + + def set_state_dict(self, state_dict): + """Set state dict for a Updater. Parameters of models, states for + optimizers/schedulers and UpdaterState are restored.""" + for name, model in self.models.items(): + model.set_state_dict(state_dict[f"{name}_params"]) + for name, optim in self.optimizers.items(): + optim.set_state_dict(state_dict[f"{name}_optimizer"]) + super().set_state_dict(state_dict) diff --git a/deepspeech/training/updaters/trainer.py b/deepspeech/training/updaters/trainer.py new file mode 100644 index 00000000..07769465 --- /dev/null +++ b/deepspeech/training/updaters/trainer.py @@ -0,0 +1,184 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
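To make the contract that update_core assumes concrete, here is a minimal usage sketch of the StandardUpdater defined above (illustrative only; ToyModel, the random tensors, and the hyperparameters are assumptions, not part of this patch). The model consumes an unpacked batch and returns a loss; ObsScope collects whatever report records during the step.

# minimal sketch, assuming a toy model and random data (not part of this patch)
import paddle
from paddle.io import DataLoader, TensorDataset

from deepspeech.training.reporter import ObsScope
from deepspeech.training.updaters.standard_updater import StandardUpdater


class ToyModel(paddle.nn.Layer):
    """Returns a scalar loss, which is what StandardUpdater.update_core expects."""

    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(8, 1)

    def forward(self, feats, labels):
        return paddle.nn.functional.mse_loss(self.linear(feats), labels)


model = ToyModel()
dataset = TensorDataset([paddle.randn([32, 8]), paddle.randn([32, 1])])
dataloader = DataLoader(dataset, batch_size=4)
scheduler = paddle.optimizer.lr.LambdaDecay(
    learning_rate=1e-3, lr_lambda=lambda step: 1.0)
optimizer = paddle.optimizer.Adam(
    learning_rate=scheduler, parameters=model.parameters())

updater = StandardUpdater(model, optimizer, scheduler, dataloader)

observation = {}
with ObsScope(observation):      # `report` calls inside update_core land here
    updater.update()             # one read_batch + forward/backward/step
print(observation)               # e.g. {"main": <loss value>}
print(updater.state.iteration)   # 1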
+import sys +import traceback +from collections import OrderedDict +from pathlib import Path +from typing import Callable +from typing import List +from typing import Union + +import six +import tqdm + +from deepspeech.training.extensions.extension import Extension +from deepspeech.training.extensions.extension import PRIORITY_READER +from deepspeech.training.reporter import ObsScope +from deepspeech.training.triggers import get_trigger +from deepspeech.training.triggers.limit_trigger import LimitTrigger +from deepspeech.training.updaters.updater import UpdaterBase + + +class _ExtensionEntry(): + def __init__(self, extension, trigger, priority): + self.extension = extension + self.trigger = trigger + self.priority = priority + + +class Trainer(): + def __init__(self, + updater: UpdaterBase, + stop_trigger: Callable=None, + out: Union[str, Path]='result', + extensions: List[Extension]=None): + self.updater = updater + self.extensions = OrderedDict() + self.stop_trigger = LimitTrigger(*stop_trigger) + self.out = Path(out) + self.observation = None + + self._done = False + if extensions: + for ext in extensions: + self.extend(ext) + + @property + def is_before_training(self): + return self.updater.state.iteration == 0 + + def extend(self, extension, name=None, trigger=None, priority=None): + # get name for the extension + # argument \ + # -> extention's name \ + # -> default_name (class name, when it is an object) \ + # -> function name when it is a function \ + # -> error + + if name is None: + name = getattr(extension, 'name', None) + if name is None: + name = getattr(extension, 'default_name', None) + if name is None: + name = getattr(extension, '__name__', None) + if name is None: + raise ValueError("Name is not given for the extension.") + if name == 'training': + raise ValueError("training is a reserved name.") + + if trigger is None: + trigger = getattr(extension, 'trigger', (1, 'iteration')) + trigger = get_trigger(trigger) + + if priority is None: + priority = getattr(extension, 'priority', PRIORITY_READER) + + # add suffix to avoid nameing conflict + ordinal = 0 + modified_name = name + while modified_name in self.extensions: + ordinal += 1 + modified_name = f"{name}_{ordinal}" + extension.name = modified_name + + self.extensions[modified_name] = _ExtensionEntry(extension, trigger, + priority) + + def get_extension(self, name): + """get extension by name.""" + extensions = self.extensions + if name in extensions: + return extensions[name].extension + else: + raise ValueError(f'extension {name} not found') + + def run(self): + if self._done: + raise RuntimeError("Training is already done!.") + + self.out.mkdir(parents=True, exist_ok=True) + + # sort extensions by priorities once + extension_order = sorted( + self.extensions.keys(), + key=lambda name: self.extensions[name].priority, + reverse=True) + extensions = [(name, self.extensions[name]) for name in extension_order] + + # initializing all extensions + for name, entry in extensions: + if hasattr(entry.extension, "initialize"): + entry.extension.initialize(self) + + update = self.updater.update # training step + stop_trigger = self.stop_trigger + + # display only one progress bar + max_iteration = None + if isinstance(stop_trigger, LimitTrigger): + if stop_trigger.unit == 'epoch': + max_epoch = self.stop_trigger.limit + updates_per_epoch = getattr(self.updater, "updates_per_epoch", + None) + max_iteration = max_epoch * updates_per_epoch if updates_per_epoch else None + else: + max_iteration = self.stop_trigger.limit + + p = 
tqdm.tqdm(initial=self.updater.state.iteration, total=max_iteration)
+
+        try:
+            while not stop_trigger(self):
+                self.observation = {}
+                # set observation as the `report` target
+                # you can use `report` freely in Updater.update()
+
+                # updating parameters and state
+                with ObsScope(self.observation):
+                    update()
+                    p.update()
+
+                # execute extensions when necessary
+                for name, entry in extensions:
+                    if entry.trigger(self):
+                        entry.extension(self)
+
+                # print("###", self.observation)
+        except Exception as e:
+            f = sys.stderr
+            f.write(f"Exception in main training loop: {e}\n")
+            f.write("Traceback (most recent call last):\n")
+            traceback.print_tb(sys.exc_info()[2])
+            f.write(
+                "Trainer extensions will try to handle the exception. Then all extensions will finalize."
+            )
+
+            # capture the exception in the main training loop
+            exc_info = sys.exc_info()
+
+            # try to handle it
+            for name, entry in extensions:
+                if hasattr(entry.extension, "on_error"):
+                    try:
+                        entry.extension.on_error(self, e, sys.exc_info()[2])
+                    except Exception as ee:
+                        f.write(f"Exception in error handler: {ee}\n")
+                        f.write('Traceback (most recent call last):\n')
+                        traceback.print_tb(sys.exc_info()[2])
+
+            # re-raise the exception in the main training loop
+            six.reraise(*exc_info)
+        finally:
+            for name, entry in extensions:
+                if hasattr(entry.extension, "finalize"):
+                    entry.extension.finalize(self)
diff --git a/deepspeech/training/updaters/updater.py b/deepspeech/training/updaters/updater.py
new file mode 100644
index 00000000..e5dd6556
--- /dev/null
+++ b/deepspeech/training/updaters/updater.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+
+import paddle
+
+from deepspeech.utils.log import Log
+
+__all__ = ["UpdaterBase", "UpdaterState"]
+
+logger = Log(__name__).getlog()
+
+
+@dataclass
+class UpdaterState:
+    iteration: int = 0
+    epoch: int = 0
+
+
+class UpdaterBase():
+    """An updater is the abstraction of how a model is trained given the
+    dataloader and the optimizer.
+    The `update_core` method is a step in the training loop with only the necessary
+    operations (get a batch, forward and backward, update the parameters).
+    Everything else is implemented as extensions. Visualization, saving, loading and
+    periodic validation and evaluation are not considered here.
+    But even in this simplest case, things are not that simple. There is an
+    attempt to standardize this process so that it requires only the model and
+    the dataset and does all the rest automatically. But this may hurt flexibility.
+    If we assume a batch yielded from the dataloader is just the input to the
+    model, we will find that some models require more arguments, or just some
+    keyword arguments. So this prevents us from over-simplifying it.
+    From another perspective, the batch may include not just the input, but
+    also the target. But the model's forward method may just need the input.
+    We can pass a dict or a super-long tuple to the model and let it pick what
+    it really needs. But this is an abuse of a lazy interface.
+    After all, we care about how a model is trained, not just how the model is
+    used for inference. We want to control how a model is trained; we just
+    don't want it to be tangled up with other auxiliary code.
+    So the best practice is to define a model and define an updater for it.
+    """
+
+    def __init__(self, init_state=None):
+        # init state
+        if init_state is None:
+            self.state = UpdaterState()
+        else:
+            self.state = init_state
+
+    def update(self, batch):
+        raise NotImplementedError(
+            "Implement your own `update` method for training a step.")
+
+    def state_dict(self):
+        state_dict = {
+            "epoch": self.state.epoch,
+            "iteration": self.state.iteration,
+        }
+        return state_dict
+
+    def set_state_dict(self, state_dict):
+        self.state.epoch = state_dict["epoch"]
+        self.state.iteration = state_dict["iteration"]
+
+    def save(self, path):
+        logger.debug(f"Saving to {path}.")
+        archive = self.state_dict()
+        paddle.save(archive, str(path))
+
+    def load(self, path):
+        logger.debug(f"Loading from {path}.")
+        archive = paddle.load(str(path))
+        self.set_state_dict(archive)
diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py
index 8e31edfa..796cafe0 100644
--- a/deepspeech/utils/checkpoint.py
+++ b/deepspeech/utils/checkpoint.py
@@ -39,13 +39,13 @@ class Checkpoint():
         self.latest_n = latest_n
         self._save_all = (kbest_n == -1)
 
-    def add_checkpoint(self,
-                       checkpoint_dir,
-                       tag_or_iteration: Union[int, Text],
-                       model: paddle.nn.Layer,
-                       optimizer: Optimizer=None,
-                       infos: dict=None,
-                       metric_type="val_loss"):
+    def save_parameters(self,
+                        checkpoint_dir,
+                        tag_or_iteration: Union[int, Text],
+                        model: paddle.nn.Layer,
+                        optimizer: Optimizer=None,
+                        infos: dict=None,
+                        metric_type="val_loss"):
         """Save checkpoint in best_n and latest_n.
Args: diff --git a/env.sh b/env.sh index 461586e7..e782d815 100644 --- a/env.sh +++ b/env.sh @@ -1,6 +1,6 @@ export MAIN_ROOT=${PWD} -export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:/usr/local/bin:${PATH} +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}:/usr/local/bin export LC_ALL=C # Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml index 8b08ee30..1f61afe7 100644 --- a/examples/aishell/s0/conf/deepspeech2.yaml +++ b/examples/aishell/s0/conf/deepspeech2.yaml @@ -13,7 +13,7 @@ data: max_output_len: .inf min_output_input_ratio: 0.00 max_output_input_ratio: .inf - specgram_type: linear + spectrum_type: linear target_sample_rate: 16000 max_freq: None n_fft: None diff --git a/examples/aishell/s0/local/data.sh b/examples/aishell/s0/local/data.sh index 2f09b14a..3f0ed0dc 100755 --- a/examples/aishell/s0/local/data.sh +++ b/examples/aishell/s0/local/data.sh @@ -46,7 +46,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then num_workers=$(nproc) python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ --manifest_path="data/manifest.train.raw" \ - --specgram_type="linear" \ + --spectrum_type="linear" \ --delta_delta=false \ --stride_ms=10.0 \ --window_ms=20.0 \ diff --git a/examples/aishell/s1/conf/chunk_conformer.yaml b/examples/aishell/s1/conf/chunk_conformer.yaml index 904624c3..3984a7fe 100644 --- a/examples/aishell/s1/conf/chunk_conformer.yaml +++ b/examples/aishell/s1/conf/chunk_conformer.yaml @@ -15,7 +15,7 @@ data: min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 raw_wav: True # use raw_wav or kaldi feature - specgram_type: fbank #linear, mfcc, fbank + spectrum_type: fbank #linear, mfcc, fbank feat_dim: 80 delta_delta: False dither: 1.0 diff --git a/examples/aishell/s1/conf/conformer.yaml b/examples/aishell/s1/conf/conformer.yaml index b880f858..51bd1ad4 100644 --- a/examples/aishell/s1/conf/conformer.yaml +++ b/examples/aishell/s1/conf/conformer.yaml @@ -15,7 +15,7 @@ data: min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 raw_wav: True # use raw_wav or kaldi feature - specgram_type: fbank #linear, mfcc, fbank + spectrum_type: fbank #linear, mfcc, fbank feat_dim: 80 delta_delta: False dither: 1.0 diff --git a/examples/aishell/s1/local/data.sh b/examples/aishell/s1/local/data.sh index c6abce3b..ed58bb6f 100755 --- a/examples/aishell/s1/local/data.sh +++ b/examples/aishell/s1/local/data.sh @@ -46,7 +46,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then num_workers=$(nproc) python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ --manifest_path="data/manifest.train.raw" \ - --specgram_type="fbank" \ + --spectrum_type="fbank" \ --feat_dim=80 \ --delta_delta=false \ --stride_ms=10.0 \ diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml index d1746bff..30178d2f 100644 --- a/examples/librispeech/s0/conf/deepspeech2.yaml +++ b/examples/librispeech/s0/conf/deepspeech2.yaml @@ -13,7 +13,7 @@ data: max_output_len: .inf min_output_input_ratio: 0.00 max_output_input_ratio: .inf - specgram_type: linear + spectrum_type: linear target_sample_rate: 16000 max_freq: None n_fft: None diff --git a/examples/librispeech/s0/local/data.sh b/examples/librispeech/s0/local/data.sh index 921f1f49..8d09baf6 100755 --- a/examples/librispeech/s0/local/data.sh +++ b/examples/librispeech/s0/local/data.sh @@ -62,7 +62,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ 
--manifest_path="data/manifest.train.raw" \ --num_samples=2000 \ - --specgram_type="linear" \ + --spectrum_type="linear" \ --delta_delta=false \ --sample_rate=16000 \ --stride_ms=10.0 \ diff --git a/examples/librispeech/s1/conf/chunk_confermer.yaml b/examples/librispeech/s1/conf/chunk_confermer.yaml index ec945a18..db0d937c 100644 --- a/examples/librispeech/s1/conf/chunk_confermer.yaml +++ b/examples/librispeech/s1/conf/chunk_confermer.yaml @@ -16,7 +16,7 @@ data: min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 raw_wav: True # use raw_wav or kaldi feature - specgram_type: fbank #linear, mfcc, fbank + spectrum_type: fbank #linear, mfcc, fbank feat_dim: 80 delta_delta: False dither: 1.0 diff --git a/examples/librispeech/s1/conf/chunk_transformer.yaml b/examples/librispeech/s1/conf/chunk_transformer.yaml index 3939ffc6..8441de9c 100644 --- a/examples/librispeech/s1/conf/chunk_transformer.yaml +++ b/examples/librispeech/s1/conf/chunk_transformer.yaml @@ -16,7 +16,7 @@ data: min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 raw_wav: True # use raw_wav or kaldi feature - specgram_type: fbank #linear, mfcc, fbank + spectrum_type: fbank #linear, mfcc, fbank feat_dim: 80 delta_delta: False dither: 1.0 diff --git a/examples/librispeech/s1/conf/conformer.yaml b/examples/librispeech/s1/conf/conformer.yaml index 8f8bf453..3cdde4a4 100644 --- a/examples/librispeech/s1/conf/conformer.yaml +++ b/examples/librispeech/s1/conf/conformer.yaml @@ -16,7 +16,7 @@ data: min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 raw_wav: True # use raw_wav or kaldi feature - specgram_type: fbank #linear, mfcc, fbank + spectrum_type: fbank #linear, mfcc, fbank feat_dim: 80 delta_delta: False dither: 1.0 diff --git a/examples/librispeech/s1/conf/transformer.yaml b/examples/librispeech/s1/conf/transformer.yaml index a094b0fb..49baecf9 100644 --- a/examples/librispeech/s1/conf/transformer.yaml +++ b/examples/librispeech/s1/conf/transformer.yaml @@ -8,7 +8,7 @@ data: spm_model_prefix: 'data/bpe_unigram_5000' mean_std_filepath: "" augmentation_config: conf/augmentation.json - batch_size: 64 + batch_size: 32 min_input_len: 0.5 # second max_input_len: 20.0 # second min_output_len: 0.0 # tokens @@ -16,7 +16,7 @@ data: min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 raw_wav: True # use raw_wav or kaldi feature - specgram_type: fbank #linear, mfcc, fbank + spectrum_type: fbank #linear, mfcc, fbank feat_dim: 80 delta_delta: False dither: 1.0 @@ -65,13 +65,15 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false training: n_epoch: 120 - accum_grad: 2 + accum_grad: 4 global_grad_clip: 5.0 optim: adam optim_conf: diff --git a/examples/librispeech/s1/local/data.sh b/examples/librispeech/s1/local/data.sh index fbdd17d5..96924e35 100755 --- a/examples/librispeech/s1/local/data.sh +++ b/examples/librispeech/s1/local/data.sh @@ -68,7 +68,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ --manifest_path="data/manifest.train.raw" \ --num_samples=-1 \ - --specgram_type="fbank" \ + --spectrum_type="fbank" \ --feat_dim=80 \ --delta_delta=false \ --sample_rate=16000 \ diff --git a/examples/librispeech/s1/local/train.sh b/examples/librispeech/s1/local/train.sh index a4218aa8..5eebfc82 100755 --- a/examples/librispeech/s1/local/train.sh +++ b/examples/librispeech/s1/local/train.sh @@ -1,4 +1,4 @@ -#! 
/usr/bin/env bash +#!/bin/bash if [ $# != 2 ];then echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" @@ -11,19 +11,23 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_name=$2 -device=gpu -if [ ngpu == 0 ];then - device=cpu -fi -echo "using ${device}..." - mkdir -p exp +# seed may break model convergence +seed=0 +if [ ${seed} != 0 ]; then + echo "None" +fi + python3 -u ${BIN_DIR}/train.py \ ---device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ ---output exp/${ckpt_name} +--output exp/${ckpt_name} \ +--seed ${seed} + +if [ ${seed} != 0 ]; then + echo "None" +fi if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml index dd9ce51f..3f52da7f 100644 --- a/examples/tiny/s0/conf/deepspeech2.yaml +++ b/examples/tiny/s0/conf/deepspeech2.yaml @@ -13,7 +13,7 @@ data: max_output_len: 400.0 min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 - specgram_type: linear + spectrum_type: linear target_sample_rate: 16000 max_freq: None n_fft: None diff --git a/examples/tiny/s0/local/data.sh b/examples/tiny/s0/local/data.sh index 727a3da9..e2bfffc7 100755 --- a/examples/tiny/s0/local/data.sh +++ b/examples/tiny/s0/local/data.sh @@ -46,7 +46,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ --manifest_path="data/manifest.tiny.raw" \ --num_samples=64 \ - --specgram_type="linear" \ + --spectrum_type="linear" \ --delta_delta=false \ --sample_rate=16000 \ --stride_ms=10.0 \ diff --git a/examples/tiny/s1/conf/chunk_confermer.yaml b/examples/tiny/s1/conf/chunk_confermer.yaml index 79006626..cc9a4525 100644 --- a/examples/tiny/s1/conf/chunk_confermer.yaml +++ b/examples/tiny/s1/conf/chunk_confermer.yaml @@ -16,7 +16,7 @@ data: min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 raw_wav: True # use raw_wav or kaldi feature - specgram_type: fbank #linear, mfcc, fbank + spectrum_type: fbank #linear, mfcc, fbank feat_dim: 80 delta_delta: False dither: 1.0 diff --git a/examples/tiny/s1/conf/chunk_transformer.yaml b/examples/tiny/s1/conf/chunk_transformer.yaml index aa2b145a..da7341fe 100644 --- a/examples/tiny/s1/conf/chunk_transformer.yaml +++ b/examples/tiny/s1/conf/chunk_transformer.yaml @@ -16,7 +16,7 @@ data: min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 raw_wav: True # use raw_wav or kaldi feature - specgram_type: fbank #linear, mfcc, fbank + spectrum_type: fbank #linear, mfcc, fbank feat_dim: 80 delta_delta: False dither: 1.0 diff --git a/examples/tiny/s1/conf/conformer.yaml b/examples/tiny/s1/conf/conformer.yaml index 3813daa0..b00da663 100644 --- a/examples/tiny/s1/conf/conformer.yaml +++ b/examples/tiny/s1/conf/conformer.yaml @@ -16,7 +16,7 @@ data: min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 raw_wav: True # use raw_wav or kaldi feature - specgram_type: fbank #linear, mfcc, fbank + spectrum_type: fbank #linear, mfcc, fbank feat_dim: 80 delta_delta: False dither: 1.0 diff --git a/examples/tiny/s1/conf/transformer.yaml b/examples/tiny/s1/conf/transformer.yaml index 35c11731..39f5e99b 100644 --- a/examples/tiny/s1/conf/transformer.yaml +++ b/examples/tiny/s1/conf/transformer.yaml @@ -16,7 +16,7 @@ data: min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 raw_wav: True # use raw_wav or kaldi feature - specgram_type: fbank #linear, mfcc, fbank + spectrum_type: fbank #linear, mfcc, fbank feat_dim: 80 delta_delta: False dither: 1.0 diff --git a/examples/tiny/s1/local/data.sh b/examples/tiny/s1/local/data.sh 
index deff91e0..5822dc92 100755 --- a/examples/tiny/s1/local/data.sh +++ b/examples/tiny/s1/local/data.sh @@ -51,7 +51,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ --manifest_path="data/manifest.tiny.raw" \ --num_samples=64 \ - --specgram_type="fbank" \ + --spectrum_type="fbank" \ --feat_dim=80 \ --delta_delta=false \ --sample_rate=16000 \ diff --git a/examples/tiny/s1/test.profile b/examples/tiny/s1/test.profile new file mode 100644 index 0000000000000000000000000000000000000000..c64affa2343aae4b5f32f60a642cc47761e73123 GIT binary patch literal 130998 zcmcG13!IHr_y1sK7~_7;{Z`~U?&LDhh;j+J7NOIbIcJ`Em^o)WXJ*ip6iMzOT~G+2 z6e1-`H&QCOghLWar&2`7_5WRaKYKsVbDo*=sAu}@Rf`TEqsWikRYfijN*LqUM+Gw7Dc*!`UT@~WOs_X3 zt!c||Eb(27(F7N1UcRF?Tcu5(neK9-+li`tNV5yzX7!u|knRlxTv=&u(sN4-{W3D^Uf%e; zqKc1Py%d_aQ6YYkFm@_#R{0neepO3xr@MUqEb4@<#AyW-s(ag=#JNK4 zR3XlHu3F?Cd0&PiKZtX80dns1C8TEiQnFy<-02CLFEcxf)NV8y55{>jlD!^}=2O8t zk&_QicHaM&9(VgZqu=NH^b}88lwfg zu8eGE0-|6*_1b7ty}QV|^j~JLRpf^p1@v=1Fa<+PAUoMEm$Fx)`%Q21tfGn!NsB2v z@sorp7+Nq%f);IB6HEzIsOLTlygQ+JktF#cXwmi+2MiVJ6!$YNLVta$$PdgK1*yI^ z>*>nQ($WK@z8Jev!OAXoW~SHYaV2Z+fJ^ha3?qwKewQ60?i2%6bMVfZ4gxuP4sl(&`^Y6(3}r z9{JCV6A?h2dx3hUxJeZHGXhO^wr|lmxdQ&wEJs%qtC1D=%+YO%Dn5ws)c=fcH>`7x zn=CZVOD|Gy$75O&wAp%HP$WsSv)}R-6cL^nH-(nb^7xy_Hr08zW|AsCi1XzCjPn$K zMh53PI|f|GrfWX0Tepp0L5QTt&hV``L7ZVe1#7$GF|w#oW4rHsBT(WOMSc*^!wQgR z!}A4m?sSzDyc#KZEGGxdMj+*XXVyD56O*pNY%iD@jS6+~lzU2aoc_5YKZx0aG}zM3 z_i35g=~AnkQ~#DQ-Tf>-7&liRr6&H{*z?wBiYh)}`;y4^O=Wt)nXwIoyVp;MJF+aXRKw<~2sl~6PR*^XRZ_L| zbeG#V&Xul>3R|NM891~;PXFzSDn6+53ertBIxmHmOWYiVUL#$NOU(vRM3Al51R0VZ zJJ+plp$|N|KhvL0%K~j%Y4kf4YPBAf7d%k>gd#tv8`_at9LB;oSk$Nlnj+M()70V? zXcDVMEj1QRQlS=6R=YoHuwIcL)B_Lw2eyJ%C1fh*k@#a6C1CScj0RSyHQw4^^TOTh z6#3y?xP4uH&O>*;l&+}a!!C(iv7A}DmhM84Z678RNB#Nh6Ys-`7a!C=J*j^>Q2#_z zHQO)|aI%7(q}I34IazSlMj7U2Hsj35k3A|)Ss?A+c;H)H{OAk!b<-LRtu=6NS3{U|hn-tKL@R2P&%gL3^5pU`aNSi#Ea7 zCNkbssPCRnTixK?4n=<08M_rTsC!u~m#O`amC*nd>KkM0o%^EJVMTsOdC$eq!J2w3m;c>rmq?%)ei-pz%tim{@_B4Nt{_BG zWM}wR^u%HwEq-pHK9fq`|F%a_#fQ|BZe$*d8ntY5xB!)P>_pvSFQm9>I9kFWKOT#p zyXvXYiYh*cBpOwr8ziV$$O(09ViJ(Dik+fcSqjkZ2+bh%Xa=D|4LsHMuE+LlROAQI z!^}f}Mh_~C%OtlrLKQnb3Af^#;{|Dn_Var{-4Xb|Y>X$fh8nWx)O%|4@%@zs^xq@j)~dN4ID* zXc{bn3bxTIz$Cn#qnV6GGIokH-MTwH-J4FjL++-&gwtcH z45w$Upl5BOy{B$=o>9~N$?kM!>dZyKmY;vyQUBlmiYh*6RI`QChu9hlmbUt7*f_F{ z@bjM*`{}#ws4G4or-F=9@hK`*cV~F%Ln*aOZU$tOPAW@1TvIBYo zY&bTl?dNB>t%PMK$sQ!q*g%M}^l)nLkE^PC3n{AjATBllR69ppGx+FG5vschg}Su< z<;Uigy_hIJAo!N~62REZ)cq+qLjf`p*aX9NA-e(@ju5z0>1;_VR1tid!!s{F8C9Wk zR1ar9yknUea(0Nt3`tfh(QFz$#;uX_YxcOqZc&dD@1cLGm6DhT(!FkMdk5S(UR&RD zUD+na#m+?8R56m;N2qRV$pJ4~Vyod`8J9a{bhd`WRajlZN6e=&j%9aZuKr2B4G0+7 zz(istjs}!}jL+-&KiG(XSx8G@)=XBdy->OM!BY4BaV}9Uni$)-@5WRmSN|}aS_h24 z89KwMb>YAP^um=lwSvbyFiSxT7=hWM&M{1rxjJ(&IgdsuvLi1*XT-c=s*FHu%)4hW zkLfRh=1>O|0^HPN2e>5X3iKoxi(UzMv%uw2s!oN0*fKLoiaKKy24fJ0(d}xCTI`Qj zdhE~8Vr>Wr3C4QdUThTRp^la|Bv~|4xcXZAG-hc7?|`nqC#r{{Z=0|DNPYShTC(AE zz(|}2ffGa4=N;q2v}5biO|HbDEr$P)sEUtVeMdNA@sos+h|Mmt68pwOd-|73ErJ7x zX(y&^yi|d@^s=`gd<`DQO7^>T&yWN!jx1uzgp=kPqYFoMC5U%wL70CUt2jMJt|(^` zRX#L|D3VF)b$u5J#mc0F-(J4U!e`2-F}>YQtur5yn{A6heF|ks}%B@BvdCoisqfC{e}Nvg}U~E(=X5dzP%zpNYH}|LeSX3&(^aO=zy8ZRsdQH zuoaA^f4~|o0tr%VO^D_g&4W(DGSIBiyy^}q?JWM8 z73z0UM+?VIbt&p+of>POxAS1oJH{r*3M$2iQ<$g!N}qK!6KDo(&=Z4jsE}ek52lR*v{!f2#N(W4jjr!JWloqCPBL zgi0jSL#Px&fH+G30R_gY3GZ$>xOt;h0loY7Sk^tSEg6b%GQ%1KK&>$ zl8+JDohTcnt}Kw%iM1I4%R!Wt6>7h>KNtT?$C^oe z*gG+r>cVP7s6fI#29~vKw5Fe+&WW+WKi}||+P|hft*F1Rn0KJpukXqZj07Wbwkp9m z@OiWh8wVB*;k{{@kHa{ix;AxjRFJK;Lu>jki~~ypV)R=qltTp_Fx0yRvOjO>T^GMr 
zz?J}QK_zTD&TH6)i=^XJ>-Y^FmlY$`wydifb~0`zG(g1$wf(8J8!-IMScFwe!9OLB z^rmECQEnYz4y+_GIlzKN@}>fot@xI?_s`0x!X{W6;2To=kKzFJ*N81NfG&N!U(*CgiwvGL8J~yrmPg>*R;VyKsWh_o!F* zhaIDLP-)}JIa1a=yKDHdzsD--tX6#vj&He{adlVzTv^N#JDMM`Aq zSjB^ML@f}W(AvTyVugC9*PK(A6MHLa%*Q7Ve-JxESye=T6;(oyQ@xr9x@WmO-c&ah z+q^mpA!_b)4ZK1f9t-pELIb?Ex%met!kp6N zXc91FD++p2cV$x?hj#Z;H5S3BO&rEz!Q|Bu4hq52Md@@DU7h`5&%KM&N-1iW_X-#P z=X@PzGOpogZ4HTeJH8OG#Ex~8UWnk^j5+_&@M*D%`oXx>U44_QDX;|$d_940W#9~w zgQwLP+^0X?qy0Ga zDmzEdIB>wzKw~7C9K={i%6&TKVv(Ac1WC2oK&zi`90oGX@{Y|ybcmdz?nHHr=EIH@ zh#-k@J{?FU5vE+MM94uSDw@2GH9Hj`ANvz zXjvakgrLJs(;%9cbl6#HBByT;BEWq97EJYHyy=!1JVrYjhkl{LSQ&3pCSR2g z4?VIyIB_zwIU4~)0Zk|o+92YktqjH^JV$d0(+z81OWoifp}hLq0R!BZ>*Y&7!rp0o zylRz5E#m^2=~|ZV_hUU9Do!eL8~#~3bn=5-N2(BNNj%15)@0MZo<)&;s*+DV;*{!X zD)nnT&6lX?*0CcHW2x7qYxUocx4oQbjjr%|xx5~h7a3HUwP&h@v;%b#*Pi$&DT<~o z*4!dqEyF#Eb^)|>Xj2H?9i`VH_V@rfou~=}DaoYf^qPRN4LTn4Ihm-AxGi+g>G3hl z?y`3F!D|Q>g`?vg!<-zGG)EoT#uTEhoJ+72phXhnEGC47LX>_7ktn7>BZ3ZdqSqp_ zq8SHPHe{UZG=&WrRXIdeERa&1!s?K458o1DmPE}VqAF0#92*C`nz1_W4ge z7H1TP>af|24GZ?cLr`sjGFS%Q2QDP$XaU@YHM~eulk!}aF?kT3dK1PMqm2_N9Safl zqdo$VVg9Q-%sRPsG_X`uJ-dZ>7hE2YUqvAXwC-FTWU3}-Zt zk&4B;*wv8AOoLl*z-cs>8oU{-6oui))!Pw_b@V%6Bu+A8dOeQQR6J>2`RbqqA`}MB z2rv@oQ5vkYa=`vrjKwJx$|SW)%q;UP(7&cHqmRUo%jw;`6sAmd10*vG}h7rO?jcxjpD#)kV;y`;>xCk9E z5-mwYPb7G&Zz9FPKaOF3a^Q5pNSq`yJ&EHqSre+Oa?6SsjgT=3Oz*;JjsP$5FCtMC z=a-bS18kc=g0|$L)|~V(jY9i|4Xx-fG(P_rc+Qq^f18)izPo4CL`4-Jfg?@cEG>f~)A}gDD?DmoacZ&CLlxfTV=PWb)EV`p$$5^H92%JJ1rp z_{hhFPW!U2vT_>k$%+r^w(e*kMDYbSk8-bqA_%5eETCnzNMvCI3%_1a(@cEW$od-T z5;Q8*sH+24cRPgpXyPMR@5MC8Lc_Oevw@2pIEN1y}tPxxMh`)FOVSEzn9&dV<+s3Rp>IMx?t3Cf~1UJP;9+rd?o7BH zhkMG6_3lJHm4kQ*oC>wq+V%_o-H))Q_^@+QHE&M%8E!xij?^*ieC)mH3mRd4w+Dc>$Xs@>u6#lGXb+tvhvRHMS-rTJ#Cv zLyXQmCS?IH+v!B@ROZLpV1#tOVd{z0rnp;0^A~ILNow;ct^#?qDd0t$mSbyHc}ZJV zJb3a<&jtfnZx+az1lw6Y^8nvKlGhRxwk~myA;=B|NL0mXpBYm%g%zyS<2UxFiVqsd zQ%Eq?@h|$caOnsz2``$_zaw}?mKDETg>#>YqCO67a=-xRav3lBAw!e`RKAVk$jWC9E)U;h(bMO z?^^LY;068==|?cyo8zVXXD9ZCG#6Bnt=$tD`4*u2MvPi(4$nR81`A%y&w>7PD?VvC_YGRVg2ar7sPwH zyw@h`GawS1YRC|s_y9)M(=?7n6dLPXrbzve5q+lvc zNwf9B{2^6lJofODidr=3$o@CJ2(od=n8{=k(6fEG^%tPMCR39gm4w*X?F4-Vv1heA za?r-LKDck(Z{3q}^_Rj$;D8Z0StMR>AC6O=+MZve^S*c8txdv7;q({7@jKupe%hze zn}_4)V#Rh?a`L&+87qe{^=zBJdRlalJ0A%Ci(^S4y37?UItE!uac0scN=Gx2EDBfY z9m2^HCnwY=npJICbNSsc;VTKo+C2^!!ucqnTYw^1(UnflKp-v4OBwhO4tHY-fexdX z9xK$vRX!`4bG3n@)+^nul6Jf)6E0y0^&n2-u{_4?PTm!{HxcT|mfxc0-$r3;$2Inj zoq@;iSJVo%i^DuV#7rEcG?(BXBY3O(bm*7u=O^8fh^6d^Df?@=dno$j;UIrUf6{l| zNJPCZ9-nq$Y@I&r_;7(_{IGcdI#v;of$0b|Cc%`(jEpU#_I^6#2GM zQ7GL+0NQXmUBenD{Zz%%2Y-u~SR3y7qOW-d=b*YesU4GDv z!0>M(nD;pT2=KsPFdm6bN4X?u1{_Re7sr3QvNI`MYLkwf1jPqYzm*7IgaRA?hD5FV zjRI5^A0#1GXACt}f|^RD9DHz#Y>ja>u}p@!`Zg{eQ~VAXLQZaQ=m#Kos^`GM$>o}z zY@w(ncJBOa%~LSddNYpGoU$A+0;l!>FN{4Uk-Tc-vT-9XzyP?mY8Oq)-rhtM|c$Gc_01raN#&$h~9QLHP^`yp^u^}S@lrCEfKTu*c_a%kJ+Cyq6#r7`sC( zYkVY{NV*+wNxC?4%Cu$IupfZXR5LW@fFYc#2;FKb-XM7EGEAJ&QV*1UCcXAXT2#EwahfxUJTLXV z?}OdLUw13|at>k|Bx$RE`=Q|U-9LqTe3~=6L424?-ZstwX{s5>5GpVNFxQ~v#IPpQsaRhvZ8loC*LtwulfUK6Z zQDBp}fR`1Cb*iWzqv9tLl{Hf)ub^%JOL8t1<%Uu<^l@9k9 zDs5}eQD3wdVopOmHm;C|IpFg8$>y01zWwX>qaN*to%AhpPQL$P=}i#%aiYztJq{Sc zJpxdn9O_0F<3(#7^I#M!)V&j*8n9vgMMZw(>TZV9vT>&VHT~U2&+buF@d1>8$@i?a z43EH2(C*(S$#+GuLjAYDTAUJmV8jjrrA0q^nNBiEXlSDB{ z*2&iaYEwd7kt6F&h5Ay;xCc8w(pOR6`)d3Tp>I-@RYoluyoe2_WToIvo;Ot+>q_>z zGoX-^)DuD)o?1GJ6>8l+TZ(jQ_^+bQNZRt2t9gYa(ePzLQLIzCJ0qFW+C*z-5=%v~ zLe2Q`wd;?5ejZt>R^R(}Y+5NcdcA8ykC*pm9#flo5U*mYKd#g*c$bQZThWd|j zm`ZFefj(lem?m#f`= zGoD$ZWz#T)*1vf0vElbAYSY3gj~AKCgV=vj-FZ~7{}_$X%9T00F}(M2%1pm=_n)(q 
zCv>ZaD~CVC^eWzDJ6S5~v(Q7pV7G6UB|aBzWkqdL;<>#C&TV9JyMk(ZRVs|3?ZW`j zs8C}|%t>6c^i@Uedg9sI-<+so4($jqqz<&U(Wx(-Uew)VH;n%5Aq=|{Es8vUHn+9J zX_+e>Fog3PLN|emC0t^bamw<7S)#|e`gZ`dkL7R}!l?l&bO%YE;|$??HNqh@7tDM5 z?YFa6bwN~5*?wf>q)+ndg|%6Uo$i~BHT$mLr*v&Qd%dDx`qXh?yY+)o?5DH^b{ao(m%_-~F8R-U3 z-YGlRLbU^iaGoY~v}qL3;{>EixPC5=4!pSC`GkMCNhstlm`wG!NwthT^ zY@$Q{4j74(oId>q(O`|TaBGPDf$GCM{L(39Cvg7B=`{P@0Ymb?HN{Hz2k3ex3kbk6 zdo*;2BckMp7Si%3j&#R#kZhg_vu2hEB=O4NV~N+|sA=QXiC>Q=n480C(qNSzUwI)) z)&04RlLzFo%lu3)^rA%g64DDReg{0svOpdTq??IqHV(d7k7tGGw571?IBy@Nmj+oz zbo-2mLg35}BwdmS{)HrGP)(ruzza!-a!ve}5ygXPxVy( zirnXTdq!k1gzB2~W(+QOsOZpp6}8f&zY}NvvKGU%0?^uca==LJ6G1qz8xPvKj9gqM zc4@Tv)-5jc>1D#PJ75HMjUwm4c+PUR4{jbP-$bIK!q`?go@iW`A>6ld+!0_T?nVH` z9t5R`bm6LZGCT5h2O8wpj=0|7awXpfJBhPr2J#3n5;^siUWJH>;l3a1+g2k#?3I`U zi&%o?gI7%5Vp{~&!-I3p$&8ES0h8@c%3g}~yN31m96Kc1!)FF+j2MINZTk*JnEcD~usS<{%W7q*hJ zhJ>!ZMz0(^iY$~DPr8>pP(yl47AL_0L-J$pbuQCA<5piF?=32uOKeP-<|>;{3KkfN zd$E+Y)!=jA2YqwzTMGs}-1vRyvqRVp7~s6z=cXJMdE7V^j|pY#>?{(`-sCd>g=W&{ zi1Xq(Nd1}YG8i|z(mB(CfSW$lPsdYm%7+}L-$@2_!!^Z$o@7=BP@&0y50>DRB6%k7 z5R)#sJtYP%g{10HMM|fk*AH*V0WTYxMGeK_9trZixsV&i+C{RK*c}(G=wt}>hWWN4 z;AJb6-!WK%@@E**r!7#^f#Z`og6-l%G!(i>4K?NsMSv&0=>0$Ig*)4$`R!^EYC#Pz zL18n(@{*|dp!uFUjLl2JpqK_(fEmXGPO%)xO$mMs;b#)GO58DsNDvJFom9yoaR)p} zoZ(+i_?aF!@#EM6uZTRT0{D$;D6fwX9aDE6pVKkDZmV`!P#`5$=KN$9 zNTzC`WRQY8%3f(7&|HLivbk2&W@{8So#yH25LC%Nu5T&)T?%W5MU5L{_e>9gMjL(y zyucs&jOdExAxK*jM{7=+@K{_M#qgd+n5?`{RZ0aikuJ@Yf|VObRneN>j<{(b&JFiJ zYbz=pzOMX;OE$%Xt+yn2D}D#Oz#rNN(rhs$tCRE3Q}+k{u3KJF#RrWq);V7jM^*V% z5oWCMx*yIIv0Z37q+nAGQCp&st$z}S(ZX7F_CAsF40 zuA*BTa8!^DWKk>Gq>xsl%-t|JTSI%NC%#p3u*pChVxTQCZ0&Lfyu?3-@Q>yAO@Rtm zQ=QeX)v~F_vzbZ}0~+nX1f%JWhK&J71=#>Q=aCfN1Kju`&$c`*tSg*olL!Zlz{!Zv z?+(Yw4I<9HOn#$C?LU75(N3HW7{V#{i}=X53VO(~#+(0n8Atd&?f%)k9!pi=bc8q@ zFcRl#67KI@xTdvgvhk>YL)XGd>Zp>-FVuawE4)F+2B8B+;=BiJJL&W%_vl;SH&{_e zUPx-(+yiggiPHfiarVp;Cl5Whf``tsakbA2rvpafq)k@+L^$tl>VGzTz(#De-2G6> z?l>2=!C^$0{J$W0hvGWmC4SRVTl=eOz}Rw4hYZU4hs(7#uJXr%8hq zC-2{f6*awZiLYMeG#BXWPxhl6Mvhml2aU9Hm?W?_jYDwen%O$A9kfYan$t|S1m*^Y z!*ovh0XW>A`}L!uZ;uH6Y|ebyfJc;?(i`P*a>!9nwnOKR^sjk8o9)Qe;Q5wQf1dw6 z+7bP1(q#>fIN&9IqD${ieQ4!FWU^4BYfY^9I`QE=*WiEwPHfDQKkZW!xlrnH!cC(i zy(;ctA7W7t=Dj~l(-RY@)p9%FiSUPjRUbwLgSQec1t~I3x zao`j=r_}O$s!r@%tTTksgDbs>(*YxJhECCxR7OOBkeZ$yjk8TR?aS|dddM|wkFropCme2Yh#AJ+RFPz1W%i=1t_z*ZF!AP8!^Tdg>059EJ9ypgj{Z*s% zwq@CA06-dNAnN3QMI_+Yv%E+-XWgxLU_WNn6`Ag#Ye{LR&`V^PZ5G^R`${}PO64Ln*AGvzJ zaP<{Gsp5k);YUJe=5n^;lmRQgvrF4J4xA1c!pTkrFCy*wm}#u_tb%w4h&wwtx2|s? 
z80$%6DL(;2IHwRg+NpC6UFkP*V`ZMzWe0ur&p3W_7(3u4e)G$+fm0Uz0oo~mGHIqsR)UpjbD|cC0VK1!wPn^)> z$qR!$aTm<^Aou=3h*f9H&8^{72`~vyRMCYg&4(gv!Daf1P2V2wU$6>q+Ods*Y6saI z%BrQIS}-0rIy6s9wve{@H;#I;IjW}*r~bG4H&@;&c74;F5!ivwN-B5ZG!7v<@2NOo zB+k)<^9d9-oF(3R|Fs)GBjUL~y8p#OOSzxNXlJ_s(Z(e$lUYXv*#K40ALc}x17T#j zEp=uTr(-lbyhrxoAY`uIEnKl3Fu-|}xg_M}P*rz(Qqr|d=o~xR;~c`mA{i!=5}`Vd zSUO;clAGJzMFNr(e|oy{wqmF%>HQMYFxxRIVVZ>fkzR!v1QtJ8O}m}};fC4+U6>Y{ z4iVoI!8X*gR?YujU%xHWLaL@OMzvrZJN`|3ozTGuwWvE~E#1qz(}d?PPdPCs1x`}|adqRiUxPSe!|!Tzk#f;G0o0HS&zTXamS~05($m7#Tm~tBzNh}N13l;h zont%P0Yf;+Wz|2VZZdpllT#TCIQ(JvYyT8q0Gx&k8UY443+B!bK2LA)Wvl+vRAI-8qc`2P1DojyHsr1xv zH8H@E)yw#k_)U7O`1$0DpakEXLdr&+w&ZMs#Q`HaF!VFgz+Gp(0)10q98yC_j5{}l z*dh+xNjDcaG1B~=+6~C7zZ?VkF(UGu8j;yEF!h7Vr*pEEVR!Agatv+Z7M$@G~ z8+*J9Efr(~R?ZgVxE_t9oRZV-TJj2tk6Sk@LqQ>$%y@zwiG0Yre2 zblyjFEu_?w6TfV=!A_daM5n=2 znM0I6s8Lf<*YrurZJKr({2G_BBEU$T>^cLZ!HQGrjCkuRkE{_67=cq~i=3t5be28Q z=I!6!tdD5+`Mxnn&MnM~P6rJ2%Ygj*<>u47j&r93=zV1VJPCiwvnaM;bw6{j^RP^O zY9k-VF+)k#`BLacqVPFQiE|JfbrjiwNjej#!)v31gK@}gjJ6+zcVW$-EDebGps~kl z-HED{6BEZga4?!DgRcV-=b5wq2e$oB^|WZfAX%)Ik&&I{PR917F;s$}!q{vtqnHmT z2_iG`v;HIy*i0>7KH+W2Z+?2{v^$}KVo=9g(Vu4EA$s73zgxgpA;oZ1mCpQCDyJCO zLzhSdv8lS(i!16@-4Amo&YOIcB|gZd_7MxFoQ2U>pQ>>q=BcQj=q^VH+5saql{WCj z3}NDw-7B_tX~eP-XpR;u)ed;k*V;CsI^KosKN?>fs5SiSqF^7 zT@#>~Qt)I7!m#}9apbi^zBd((Q1|kC8JsxlNet0Q1Q=?B#ekl)2*@nJc}U38=~)Kpsb0v%JCtUpTS9sU;k-v_WivyV_klrAgl%i${x2&bNf6-PA!Ux zA)!+SZF)zdL5lD(5RUhW*;OdMmZ8H>rZB^Feji(cgbCj|A3EUC(&e@{9eRWf^dV0Q zPJL1oXr!fo%?{vGI*x#f4>Ci~gK>?vDuU7d;K0}ngMgD2h0?$YbeoC67ah%_eREEvtzJLo=ybX`?~Pe8R1srOEfF ze(o>c_#QQKI zu7{A1#X;|=P)9ED=p)rMMeUO|XUKp*ak)vTJXuCc1Eo&T0Cdp42QwWb%sZ7NPTi1t9=<>ognZOUdq7M>Oj!5EJ> zIh*_fqLLK`skWt~V~BE@!ig0h`g^Evqf-1LK}d@0AeYb-fCkI!UKPnqVfBrjr|$-5D@osdif3* zVy!#G$*z*TE1J#NP1NX(v|H0eEkGBZzR@7I_6b-35e1DU$F zSwFhwsQawp%t$CD3HK2ly;0aKSI^R2-a`xDf%83T*Xn#a zU?fh|SW;BinIfUVP+%Q~V@h2q&BL*q+A+tluqWl^ZOCpX4N&pEa?-0O#cv`7bCO z2`I5{plQxaBg8l?WNQD`Wi1}+KDeu*ijQ1lF!`A~d6tzdIN`2G@Xuz26#T{SHs=-%nA??;loV*d1Re z`i~5YCBMppb$)KJKM_X?8jY@TV(PR#inxSo`n%GTTAuPvs|$AwQ^baEW1@=+JDVrao)N z+laLSNvLYDCZD=?BO;I^PS$Dj!I|g`Q){HF#_z3?1AGX87;zC`XwdHf>)OkNa}>eH zXn(-VWrccS_=@@;SJQD#c4VJ(GY9?voIIOABp8X4ZMot>G8}M0-m6T?J>Ha>p~CR& zEL)Rg62HeU+G}W@_B^Qy_NAgDRid_DCEfgoiVIvpO@2juT)lVxu60}9Xa@RUAsDMn z9552+V#2v394C*K;$iXn*S1!u*&0V03>vfy2$I8F@jKuJ{t%I(e?f}pJYR!lAGbH; z`|jk*Y4Rlvnt%~FEi_nhavjEuxunsGJ@>b#V;BzIaKK2M%tz3gYQ@X?mnO8G@8n7BTRr`%Jb<3dGP3sakUW`iPFqja2hbk&4uwAXZnkdZPA zjLaFPLjA{c{kwKr2Sv^LGiPh-dv+7JMv&%eC0Z{=-9&}DwAiY{haN=EyeW6q_fL4~ zV*)=-;1n(q*rQl|6e>@JTMK9L%g3#Fc_@4;$7OsuS4W>Xwx`9Zp?O@!mj@?jJY2G0 z)schPz6O)(6C_n*a=fb!6p9M9!mbyeYV*yziu&9OZPsmj@D6DJ+0XgFh@R5vroAx! 
zExMi)&)2-yKs=w@>ui^452?vGnq^yRZ?g(8H7X}Qm!h;Q?O$eTu`XeOQkBKc*dkpqo|Az5MrSU zcV?x9SQB!L;CU2Bd#3BHP*Fr%cMRe7tasoweo9T4R{{s9Nan;ii^==OVj`-d3C8ln zf+4i@w8qFpkX0gnCtDWytybTZwPe zEGOH=T4;JJzP3Pgv-JMV8gdR=TWvCr5{0LX9ZRlS{?@K8=t?0tL7GZy=w5X693I~_ly{eW}W9c2L@e95D$+>nj^(h zMue2y?ZBD5YV_bN%?}Vb3ojSQ*#)UI3MWChm8&w!E(FD3djc5Z>0{Kd4^iYLjY~s> zG+Wd!NE>(GqB~Ehd;m;f6i1$YD%ddrx*w6t_>U(ZC^cvJTkxs< z`rX=FRv(~%g5t2jqZPbB@OH3y2c9V6_~|_+FL@G<*_6W9v`~9TyS+~&J>&VR5@mz< zLYxSOuwDZwYrVABKc!QZcjVt`Ar|W&t!Jzs$8W8#4m{!Kn-up$>=i4U;`aP&ir+G$ z#O4M!VsBFLw5hSH1b%ZQbl^4qX&jz&-I9Gl4-Aptz3AINzZr9oL!-iJUc$v9W|@}|&HSm2#8$xDaV;Ep z;4f%d=w_+rNOl|k+NUQ{Sxh86u|NgJri4UYuh`<|a4cRpDhEdQ(a@Z1#o`crMn*5e zoFps`g5;I2$PNqtjAB z=)luJ?53H%7O2@CVvoJx>shQuFiuc52S(%McfR*yacUYCo@}_H_SP~OTq)vl=u|LD z#M$>mERV8o)2WU?p&&p*(0jgJvTI}c;J_1p9-(Ln@u3;#vlYvA>sO=qlz$yD7^eT- zb6%}WW`X8m)UTw#816T@zsjo{Hvy*P2RD51Lzh_|+lrYLI_HRH5i3RsFIyQ4B7T7c zs0v#pK+JZav8akBp3DqcS?X-E6;Y{pQue9ew?2+&2V=J6%&)Z##Cre)@g|eWwqlRA zAK!fujBu;;Io1B&t)APzhO0}?|K?j9CQt5LT5d?YyW;du(5w$>aJD~N|KLaC^iTC% Wb15`uO2m#Ogq#0K@NFq=+x|b9$9Y)* literal 0 HcmV?d00001 diff --git a/tests/benchmark/.gitignore b/tests/benchmark/.gitignore new file mode 100644 index 00000000..7d166b06 --- /dev/null +++ b/tests/benchmark/.gitignore @@ -0,0 +1,2 @@ +old-pd_env.txt +pd_env.txt diff --git a/tests/benchmark/README.md b/tests/benchmark/README.md new file mode 100644 index 00000000..d21999ab --- /dev/null +++ b/tests/benchmark/README.md @@ -0,0 +1,11 @@ +# Benchmark Test + +## Data + +* Aishell + +## Docker + +``` +registry.baidubce.com/paddlepaddle/paddle 2.1.1-gpu-cuda10.2-cudnn7 59d5ec1de486 +``` diff --git a/tests/benchmark/run_all.sh b/tests/benchmark/run_all.sh new file mode 100755 index 00000000..6f707cdc --- /dev/null +++ b/tests/benchmark/run_all.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +CUR_DIR=${PWD} +ROOT_DIR=../../ + +# 提供可稳定复现性能的脚本,默认在标准docker环境内py37执行: +# collect env info +bash ${ROOT_DIR}/utils/pd_env_collect.sh +#cat pd_env.txt + + +# 1 安装该模型需要的依赖 (如需开启优化策略请注明) +#pushd ${ROOT_DIR}/tools; make; popd +#source ${ROOT_DIR}/tools/venv/bin/activate +#pushd ${ROOT_DIR}; bash setup.sh; popd + + +# 2 拷贝该模型需要数据、预训练模型 + +# 执行目录:需说明 +#pushd ${ROOT_DIR}/examples/aishell/s1 +pushd ${ROOT_DIR}/examples/tiny/s1 + +mkdir -p exp/log +. 
+#bash local/data.sh &> exp/log/data.log
+
+# 3. Batch run (if batch running is not convenient, move steps 1 and 2 into each single model)
+
+model_mode_list=(conformer transformer)
+fp_item_list=(fp32)
+bs_item_list=(32 64 96)
+for model_mode in ${model_mode_list[@]}; do
+    for fp_item in ${fp_item_list[@]}; do
+        for bs_item in ${bs_item_list[@]}
+        do
+        echo "index is speed, 1gpus, begin, ${model_mode}"
+        run_mode=sp
+        CUDA_VISIBLE_DEVICES=0 bash ${CUR_DIR}/run_benchmark.sh ${run_mode} ${bs_item} ${fp_item} 500 ${model_mode}     #  (5min)
+        sleep 60
+        echo "index is speed, 8gpus, run_mode is multi_process, begin, ${model_mode}"
+        run_mode=mp
+        CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash ${CUR_DIR}/run_benchmark.sh ${run_mode} ${bs_item} ${fp_item} 500 ${model_mode}
+        sleep 60
+        done
+    done
+done
+
+popd # aishell/s1
diff --git a/tests/benchmark/run_benchmark.sh b/tests/benchmark/run_benchmark.sh
new file mode 100755
index 00000000..bd4655d1
--- /dev/null
+++ b/tests/benchmark/run_benchmark.sh
@@ -0,0 +1,57 @@
+#!/bin/bash
+
+set -xe
+
+# Usage example: CUDA_VISIBLE_DEVICES=0 bash run_benchmark.sh ${run_mode} ${bs_item} ${fp_item} 500 ${model_mode}
+# Parameter description
+function _set_params(){
+    run_mode=${1:-"sp"}            # sp: single GPU | mp: multiple GPUs
+    batch_size=${2:-"64"}
+    fp_item=${3:-"fp32"}           # fp32|fp16
+    max_iter=${4:-"500"}           # optional; modify the code if training should stop early
+    model_name=${5:-"model_name"}
+    run_log_path=${TRAIN_LOG_DIR:-$(pwd)}  # TRAIN_LOG_DIR is set later by QA
+
+# Do not modify anything below
+    device=${CUDA_VISIBLE_DEVICES//,/ }
+    arr=(${device})
+    num_gpu_devices=${#arr[*]}
+    log_file=${run_log_path}/${model_name}_${run_mode}_bs${batch_size}_${fp_item}_${num_gpu_devices}
+}
+
+function _train(){
+    echo "Train on ${num_gpu_devices} GPUs"
+    echo "current CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES, gpus=$num_gpu_devices, batch_size=$batch_size"
+
+    train_cmd="--benchmark-batch-size ${batch_size}
+    --benchmark-max-step ${max_iter}
+    conf/${model_name}.yaml ${model_name}"
+
+    case ${run_mode} in
+    sp) train_cmd="bash local/train.sh "${train_cmd}"" ;;
+    mp)
+        train_cmd="bash local/train.sh "${train_cmd}"" ;;
+    *) echo "choose run_mode(sp or mp)"; exit 1;
+    esac
+
+    # Do not modify anything below
+    timeout 15m ${train_cmd} > ${log_file} 2>&1
+    if [ $? 
-ne 0 ];then + echo -e "${model_name}, FAIL" + export job_fail_flag=1 + else + echo -e "${model_name}, SUCCESS" + export job_fail_flag=0 + fi + + trap 'for pid in $(jobs -pr); do kill -KILL $pid; done' INT QUIT TERM + + if [ $run_mode = "mp" -a -d mylog ]; then + rm ${log_file} + cp mylog/workerlog.0 ${log_file} + fi +} + +_set_params $@ +_train + diff --git a/tests/chains/README.md b/tests/chains/README.md new file mode 100644 index 00000000..1719c40a --- /dev/null +++ b/tests/chains/README.md @@ -0,0 +1,9 @@ +For lite\_train\_infer, Run +``` +bash lite_train_infer.sh +``` + +For whole\_train\_infer, Run +``` +bash whole_train_infer.sh +``` diff --git a/tests/chains/ds2_params_lite_train_infer.txt b/tests/chains/ds2_params_lite_train_infer.txt new file mode 100644 index 00000000..70d54f8b --- /dev/null +++ b/tests/chains/ds2_params_lite_train_infer.txt @@ -0,0 +1,51 @@ +===========================train_params=========================== +model_name:deepspeech2 +python:python3.7 +gpu_list:0|0,1 +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +## +trainer:norm_train +norm_train: ../../../deepspeech/exps/deepspeech2/bin/train.py --config conf/deepspeech2.yaml --model_type offline --profiler-options "" --output exp/deepspeech_tiny --seed 0 +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval: ../../../deepspeech/exps/deepspeech2/bin/test.py --nproc 1 --config conf/deepspeech2.yaml --checkpoint_path exp/deepspeech_tiny/checkpoints/9 --result_file tests/9.rsl --model_type offline +null:null +## +===========================infer_params=========================== +null:null +null:null +norm_export: ../../../deepspeech/exps/deepspeech2/bin/export.py --nproc 1 --config conf/deepspeech2.yaml --model_type offline --checkpoint_path exp/deepspeech_tiny/checkpoints/9 --export_path exp/deepspeech_tiny/checkpoints/9.jit +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +## +infer_model:null +infer_export:null +infer_quant:null +inference:null +--use_gpu:null +--enable_mkldnn:null +--cpu_threads:null +--rec_batch_num:null +--use_tensorrt:null +--precision:null +--det_model_dir:null +--image_dir:null +--save_log_path:null +--benchmark:null +null:null diff --git a/tests/chains/ds2_params_whole_train_infer.txt b/tests/chains/ds2_params_whole_train_infer.txt new file mode 100644 index 00000000..90ce7d89 --- /dev/null +++ b/tests/chains/ds2_params_whole_train_infer.txt @@ -0,0 +1,51 @@ +===========================train_params=========================== +model_name:deepspeech2 +python:python3.7 +gpu_list:0,1|0 +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +## +trainer:norm_train +norm_train: ../../../deepspeech/exps/deepspeech2/bin/train.py --config conf/deepspeech2.yaml --model_type offline --profiler-options "" --output exp/deepspeech_whole --seed 0 +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval: ../../../deepspeech/exps/deepspeech2/bin/test.py --nproc 1 --config conf/deepspeech2.yaml --result_file tests/49.rsl --checkpoint_path exp/deepspeech_whole/checkpoints/49 --model_type offline +null:null +## +===========================infer_params=========================== +null:null +null:null +norm_export: ../../../deepspeech/exps/deepspeech2/bin/export.py --nproc 1 --config 
conf/deepspeech2.yaml --model_type offline --checkpoint_path exp/deepspeech_whole/checkpoints/49 --export_path exp/deepspeech_whole/checkpoints/49.jit +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +## +infer_model:null +infer_export:null +infer_quant:null +inference:null +--use_gpu:null +--enable_mkldnn:null +--cpu_threads:null +--rec_batch_num:null +--use_tensorrt:null +--precision:null +--det_model_dir:null +--image_dir:null +--save_log_path:null +--benchmark:null +null:null diff --git a/tests/chains/lite_train_infer.sh b/tests/chains/lite_train_infer.sh new file mode 100644 index 00000000..76b22a38 --- /dev/null +++ b/tests/chains/lite_train_infer.sh @@ -0,0 +1,5 @@ +bash prepare.sh ds2_params_lite_train_infer.txt lite_train_infer +cd ../../examples/tiny/s0 +source path.sh +bash ../../../tests/chains/test.sh ../../../tests/chains/ds2_params_lite_train_infer.txt lite_train_infer +cd ../../../tests/chains diff --git a/tests/chains/prepare.sh b/tests/chains/prepare.sh new file mode 100644 index 00000000..73a30283 --- /dev/null +++ b/tests/chains/prepare.sh @@ -0,0 +1,84 @@ +#!/bin/bash +FILENAME=$1 +# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer'] +MODE=$2 + +dataline=$(cat ${FILENAME}) + +# parser params +IFS=$'\n' +lines=(${dataline}) +function func_parser_key(){ + strs=$1 + IFS=":" + array=(${strs}) + tmp=${array[0]} + echo ${tmp} +} +function func_parser_value(){ + strs=$1 + IFS=":" + array=(${strs}) + tmp=${array[1]} + echo ${tmp} +} +IFS=$'\n' +# The training params +model_name=$(func_parser_value "${lines[1]}") + +trainer_list=$(func_parser_value "${lines[14]}") + +# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer'] +MODE=$2 + +if [ ${MODE} = "lite_train_infer" ];then + # pretrain lite train data + curPath=$(readlink -f "$(dirname "$0")") + cd ${curPath}/../../examples/tiny/s0 + source path.sh + # download audio data + bash ./local/data.sh || exit -1 + # download language model + bash local/download_lm_en.sh + if [ $? -ne 0 ]; then + exit 1 + fi + cd ${curPath} + +elif [ ${MODE} = "whole_train_infer" ];then + curPath=$(readlink -f "$(dirname "$0")") + cd ${curPath}/../../examples/aishell/s0 + source path.sh + # download audio data + bash ./local/data.sh || exit -1 + # download language model + bash local/download_lm_ch.sh + if [ $? -ne 0 ]; then + exit 1 + fi + cd ${curPath} +elif [ ${MODE} = "whole_infer" ];then + curPath=$(readlink -f "$(dirname "$0")") + cd ${curPath}/../../examples/aishell/s0 + source path.sh + # download audio data + bash ./local/data.sh || exit -1 + # download language model + bash local/download_lm_ch.sh + if [ $? -ne 0 ]; then + exit 1 + fi + cd ${curPath} +else + curPath=$(readlink -f "$(dirname "$0")") + cd ${curPath}/../../examples/aishell/s0 + source path.sh + # download audio data + bash ./local/data.sh || exit -1 + # download language model + bash local/download_lm_ch.sh + if [ $? 
-ne 0 ]; then + exit 1 + fi + cd ${curPath} +fi diff --git a/tests/chains/speedyspeech_params_lite.txt b/tests/chains/speedyspeech_params_lite.txt new file mode 100644 index 00000000..c1cfb8f5 --- /dev/null +++ b/tests/chains/speedyspeech_params_lite.txt @@ -0,0 +1,51 @@ +===========================train_params=========================== +model_name:speedyspeech +python:python3.7 +gpu_list:1|0,1 +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +## +trainer:norm_train +norm_train:../examples/speedyspeech/baker/train.py --train-metadata=train_data/mini_BZNSYP/train/norm/metadata.jsonl --dev-metadata=train_data/mini_BZNSYP/dev/norm/metadata.jsonl --config=lite_train_infer.yaml --output-dir=exp/default +null:null +null:null +null:null +null:null +null:null +## +===========================eval_params=========================== +eval:../examples/speedyspeech/baker/synthesize_e2e.py --speedyspeech-config=../examples/speedyspeech/baker/conf/default.yaml --speedyspeech-checkpoint=exp/default/checkpoints/snapshot_iter_90.pdz --speedyspeech-stat=pretrain_models/speedyspeech_baker_ckpt_0.4/speedy_speech_stats.npy --pwg-config=../examples/parallelwave_gan/baker/conf/default.yaml --pwg-checkpoint=pretrain_models/pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz --pwg-stat=pretrain_models/pwg_baker_ckpt_0.4/pwg_stats.npy --text=../examples/speedyspeech/baker/sentences.txt --output-dir=e2e --inference-dir=inference --device="gpu" --phones-dict=../examples/speedyspeech/baker/phones.txt --tones-dict=../examples/speedyspeech/baker/tones.txt +null:null +## +===========================infer_params=========================== +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +## +null:null +null:null +null:null +inference:../examples/speedyspeech/baker/inference.py --inference-dir=pretrain_models/speedyspeech_pwg_inference_0.4 --text=../examples/speedyspeech/baker/sentences.txt --output-dir=inference_out --enable-auto-log --phones-dict=../examples/speedyspeech/baker/phones.txt --tones-dict=../examples/speedyspeech/baker/tones.txt --output-dir=e2e --inference-dir=inference +--use_gpu:True +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null +null:null diff --git a/tests/chains/test.sh b/tests/chains/test.sh new file mode 100644 index 00000000..0b2b4f58 --- /dev/null +++ b/tests/chains/test.sh @@ -0,0 +1,371 @@ +#!/bin/bash +# usage: bash test.sh ***.txt MODE + +FILENAME=$1 +# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer'] +MODE=$2 + +dataline=$(cat ${FILENAME}) + +# parser params +IFS=$'\n' +lines=(${dataline}) + +function func_parser_key(){ + strs=$1 + IFS=":" + array=(${strs}) + tmp=${array[0]} + echo ${tmp} +} +function func_parser_value(){ + strs=$1 + IFS=":" + array=(${strs}) + tmp=${array[1]} + echo ${tmp} +} +function func_set_params(){ + key=$1 + value=$2 + if [ ${key} = "null" ];then + echo " " + elif [[ ${value} = "null" ]] || [[ ${value} = " " ]] || [ ${#value} -le 0 ];then + echo " " + else + echo "${key}=${value}" + fi +} +function func_parser_params(){ + strs=$1 + IFS=":" + array=(${strs}) + key=${array[0]} + tmp=${array[1]} + IFS="|" + res="" + for _params in ${tmp[*]}; do + IFS="=" + array=(${_params}) + mode=${array[0]} + value=${array[1]} + if [[ ${mode} = ${MODE} ]]; then + IFS="|" + #echo $(func_set_params "${mode}" "${value}") + echo $value + break + fi + IFS="|" + done + echo ${res} +} +function status_check(){ + last_status=$1 # 
the exit code + run_command=$2 + run_log=$3 + if [ $last_status -eq 0 ]; then + echo -e "\033[33m Run successfully with command - ${run_command}! \033[0m" | tee -a ${run_log} + else + echo -e "\033[33m Run failed with command - ${run_command}! \033[0m" | tee -a ${run_log} + fi +} + +IFS=$'\n' +# The training params +model_name=$(func_parser_value "${lines[1]}") +python=$(func_parser_value "${lines[2]}") +gpu_list=$(func_parser_value "${lines[3]}") +train_use_gpu_key=$(func_parser_key "${lines[4]}") +train_use_gpu_value=$(func_parser_value "${lines[4]}") +autocast_list=$(func_parser_value "${lines[5]}") +autocast_key=$(func_parser_key "${lines[5]}") +epoch_key=$(func_parser_key "${lines[6]}") +epoch_num=$(func_parser_params "${lines[6]}") +save_model_key=$(func_parser_key "${lines[7]}") +train_batch_key=$(func_parser_key "${lines[8]}") +train_batch_value=$(func_parser_params "${lines[8]}") +pretrain_model_key=$(func_parser_key "${lines[9]}") +pretrain_model_value=$(func_parser_value "${lines[9]}") +train_model_name=$(func_parser_value "${lines[10]}") +train_infer_img_dir=$(func_parser_value "${lines[11]}") +train_param_key1=$(func_parser_key "${lines[12]}") +train_param_value1=$(func_parser_value "${lines[12]}") + +trainer_list=$(func_parser_value "${lines[14]}") +trainer_norm=$(func_parser_key "${lines[15]}") +norm_trainer=$(func_parser_value "${lines[15]}") +pact_key=$(func_parser_key "${lines[16]}") +pact_trainer=$(func_parser_value "${lines[16]}") +fpgm_key=$(func_parser_key "${lines[17]}") +fpgm_trainer=$(func_parser_value "${lines[17]}") +distill_key=$(func_parser_key "${lines[18]}") +distill_trainer=$(func_parser_value "${lines[18]}") +trainer_key1=$(func_parser_key "${lines[19]}") +trainer_value1=$(func_parser_value "${lines[19]}") +trainer_key2=$(func_parser_key "${lines[20]}") +trainer_value2=$(func_parser_value "${lines[20]}") + +eval_py=$(func_parser_value "${lines[23]}") +eval_key1=$(func_parser_key "${lines[24]}") +eval_value1=$(func_parser_value "${lines[24]}") + +save_infer_key=$(func_parser_key "${lines[27]}") +export_weight=$(func_parser_key "${lines[28]}") +norm_export=$(func_parser_value "${lines[29]}") +pact_export=$(func_parser_value "${lines[30]}") +fpgm_export=$(func_parser_value "${lines[31]}") +distill_export=$(func_parser_value "${lines[32]}") +export_key1=$(func_parser_key "${lines[33]}") +export_value1=$(func_parser_value "${lines[33]}") +export_key2=$(func_parser_key "${lines[34]}") +export_value2=$(func_parser_value "${lines[34]}") + +# parser inference model +infer_model_dir_list=$(func_parser_value "${lines[36]}") +infer_export_list=$(func_parser_value "${lines[37]}") +infer_is_quant=$(func_parser_value "${lines[38]}") +# parser inference +inference_py=$(func_parser_value "${lines[39]}") +use_gpu_key=$(func_parser_key "${lines[40]}") +use_gpu_list=$(func_parser_value "${lines[40]}") +use_mkldnn_key=$(func_parser_key "${lines[41]}") +use_mkldnn_list=$(func_parser_value "${lines[41]}") +cpu_threads_key=$(func_parser_key "${lines[42]}") +cpu_threads_list=$(func_parser_value "${lines[42]}") +batch_size_key=$(func_parser_key "${lines[43]}") +batch_size_list=$(func_parser_value "${lines[43]}") +use_trt_key=$(func_parser_key "${lines[44]}") +use_trt_list=$(func_parser_value "${lines[44]}") +precision_key=$(func_parser_key "${lines[45]}") +precision_list=$(func_parser_value "${lines[45]}") +infer_model_key=$(func_parser_key "${lines[46]}") +image_dir_key=$(func_parser_key "${lines[47]}") +infer_img_dir=$(func_parser_value "${lines[47]}") 
+save_log_key=$(func_parser_key "${lines[48]}") +benchmark_key=$(func_parser_key "${lines[49]}") +benchmark_value=$(func_parser_value "${lines[49]}") +infer_key1=$(func_parser_key "${lines[50]}") +infer_value1=$(func_parser_value "${lines[50]}") + +LOG_PATH="./tests/output" +mkdir -p ${LOG_PATH} +status_log="${LOG_PATH}/results.log" + + +function func_inference(){ + IFS='|' + _python=$1 + _script=$2 + _model_dir=$3 + _log_path=$4 + _img_dir=$5 + _flag_quant=$6 + # inference + for use_gpu in ${use_gpu_list[*]}; do + if [ ${use_gpu} = "False" ] || [ ${use_gpu} = "cpu" ]; then + for use_mkldnn in ${use_mkldnn_list[*]}; do + if [ ${use_mkldnn} = "False" ] && [ ${_flag_quant} = "True" ]; then + continue + fi + for threads in ${cpu_threads_list[*]}; do + for batch_size in ${batch_size_list[*]}; do + _save_log_path="${_log_path}/infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_${batch_size}.log" + set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}") + set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}") + set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}") + set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}") + set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}") + set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}") + command="${_python} ${_script} > ${_save_log_path} 2>&1 " + eval $command + last_status=${PIPESTATUS[0]} + eval "cat ${_save_log_path}" + status_check $last_status "${command}" "${status_log}" + done + done + done + elif [ ${use_gpu} = "True" ] || [ ${use_gpu} = "gpu" ]; then + for use_trt in ${use_trt_list[*]}; do + for precision in ${precision_list[*]}; do + if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then + continue + fi + if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then + continue + fi + if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [ ${_flag_quant} = "True" ]; then + continue + fi + for batch_size in ${batch_size_list[*]}; do + _save_log_path="${_log_path}/infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log" + set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}") + set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}") + set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}") + set_tensorrt=$(func_set_params "${use_trt_key}" "${use_trt}") + set_precision=$(func_set_params "${precision_key}" "${precision}") + set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}") + set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}") + command="${_python} ${_script} > ${_save_log_path} 2>&1 " + eval $command + last_status=${PIPESTATUS[0]} + eval "cat ${_save_log_path}" + status_check $last_status "${command}" "${status_log}" + + done + done + done + else + echo "Does not support hardware other than CPU and GPU Currently!" 
+ fi + done +} + +if [ ${MODE} = "infer" ]; then + GPUID=$3 + if [ ${#GPUID} -le 0 ];then + env=" " + else + env="export CUDA_VISIBLE_DEVICES=${GPUID}" + fi + # set CUDA_VISIBLE_DEVICES + eval $env + export Count=0 + IFS="|" + infer_run_exports=(${infer_export_list}) + infer_quant_flag=(${infer_is_quant}) + for infer_model in ${infer_model_dir_list[*]}; do + # run export + if [ ${infer_run_exports[Count]} != "null" ];then + set_export_weight=$(func_set_params "${export_weight}" "${infer_model}") + set_save_infer_key=$(func_set_params "${save_infer_key}" "${infer_model}") + export_cmd="${python} ${norm_export} ${set_export_weight} ${set_save_infer_key}" + eval $export_cmd + status_export=$? + if [ ${status_export} = 0 ];then + status_check $status_export "${export_cmd}" "${status_log}" + fi + fi + #run inference + is_quant=${infer_quant_flag[Count]} + func_inference "${python}" "${inference_py}" "${infer_model}" "${LOG_PATH}" "${infer_img_dir}" ${is_quant} + Count=$(($Count + 1)) + done + +else + IFS="|" + export Count=0 + USE_GPU_KEY=(${train_use_gpu_value}) + for gpu in ${gpu_list[*]}; do + use_gpu=${USE_GPU_KEY[Count]} + Count=$(($Count + 1)) + if [ ${gpu} = "-1" ];then + env="" + elif [ ${#gpu} -le 1 ];then + env="export CUDA_VISIBLE_DEVICES=${gpu}" + eval ${env} + elif [ ${#gpu} -le 15 ];then + IFS="," + array=(${gpu}) + env="export CUDA_VISIBLE_DEVICES=${array[0]}" + IFS="|" + else + IFS=";" + array=(${gpu}) + ips=${array[0]} + gpu=${array[1]} + IFS="|" + env=" " + fi + for autocast in ${autocast_list[*]}; do + for trainer in ${trainer_list[*]}; do + flag_quant=False + if [ ${trainer} = ${pact_key} ]; then + run_train=${pact_trainer} + run_export=${pact_export} + flag_quant=True + elif [ ${trainer} = "${fpgm_key}" ]; then + run_train=${fpgm_trainer} + run_export=${fpgm_export} + elif [ ${trainer} = "${distill_key}" ]; then + run_train=${distill_trainer} + run_export=${distill_export} + elif [ ${trainer} = ${trainer_key1} ]; then + run_train=${trainer_value1} + run_export=${export_value1} + elif [[ ${trainer} = ${trainer_key2} ]]; then + run_train=${trainer_value2} + run_export=${export_value2} + else + run_train=${norm_trainer} + run_export=${norm_export} + fi + + if [ ${run_train} = "null" ]; then + continue + fi + + set_autocast=$(func_set_params "${autocast_key}" "${autocast}") + set_epoch=$(func_set_params "${epoch_key}" "${epoch_num}") + set_pretrain=$(func_set_params "${pretrain_model_key}" "${pretrain_model_value}") + set_batchsize=$(func_set_params "${train_batch_key}" "${train_batch_value}") + set_train_params1=$(func_set_params "${train_param_key1}" "${train_param_value1}") + set_use_gpu=$(func_set_params "${train_use_gpu_key}" "${use_gpu}") + save_log="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}" + + # load pretrain from norm training if current trainer is pact or fpgm trainer + if [ ${trainer} = ${pact_key} ] || [ ${trainer} = ${fpgm_key} ]; then + set_pretrain="${load_norm_train_model}" + fi + + set_save_model=$(func_set_params "${save_model_key}" "${save_log}") + if [ ${#gpu} -le 2 ];then # train with cpu or single gpu + cmd="${python} ${run_train} " + elif [ ${#gpu} -le 15 ];then # train with multi-gpu + gsu=${gpu//,/ } + nump=`echo $gsu | wc -w` + cmd="${python} ${run_train} --nproc=$nump" + else # train with multi-machine + cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_autocast} ${set_batchsize} ${set_train_params1}" + fi + # run train + # eval "unset 
CUDA_VISIBLE_DEVICES" + eval $cmd + status_check $? "${cmd}" "${status_log}" + + set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${train_model_name}") + # save norm trained models to set pretrain for pact training and fpgm training + if [ ${trainer} = ${trainer_norm} ]; then + load_norm_train_model=${set_eval_pretrain} + fi + # run eval + if [ ${eval_py} != "null" ]; then + IFS="," + array=(${gpu}) + IFS="|" + env="export CUDA_VISIBLE_DEVICES=${array[0]}" + eval $env + set_eval_params1=$(func_set_params "${eval_key1}" "${eval_value1}") + eval_cmd="${python} ${eval_py} ${set_eval_pretrain} ${set_use_gpu} ${set_eval_params1}" + eval $eval_cmd + status_check $? "${eval_cmd}" "${status_log}" + fi + # run export model + if [ ${run_export} != "null" ]; then + # run export model + save_infer_path="${save_log}" + set_export_weight=$(func_set_params "${export_weight}" "${save_log}/${train_model_name}") + set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_path}") + export_cmd="${python} ${run_export} ${set_export_weight} ${set_save_infer_key}" + eval $export_cmd + status_check $? "${export_cmd}" "${status_log}" + + #run inference + eval $env + save_infer_path="${save_log}" + func_inference "${python}" "${inference_py}" "${save_infer_path}" "${LOG_PATH}" "${train_infer_img_dir}" "${flag_quant}" + #eval "unset CUDA_VISIBLE_DEVICES" + fi + done # done with: for trainer in ${trainer_list[*]}; do + done # done with: for autocast in ${autocast_list[*]}; do + done # done with: for gpu in ${gpu_list[*]}; do +fi # end if [ ${MODE} = "infer" ]; then diff --git a/tests/chains/whole_train_infer.sh b/tests/chains/whole_train_infer.sh new file mode 100644 index 00000000..496041a7 --- /dev/null +++ b/tests/chains/whole_train_infer.sh @@ -0,0 +1,5 @@ +bash prepare.sh ds2_params_whole_train_infer.txt whole_train_infer +cd ../../examples/aishell/s0 +source path.sh +bash ../../../tests/chains/test.sh ../../../tests/chains/ds2_params_whole_train_infer.txt whole_train_infer +cd ../../../tests/chains diff --git a/tests/deepspeech2_model_test.py b/tests/deepspeech2_model_test.py index 1776736f..00df8195 100644 --- a/tests/deepspeech2_model_test.py +++ b/tests/deepspeech2_model_test.py @@ -16,7 +16,7 @@ import unittest import numpy as np import paddle -from deepspeech.models.deepspeech2 import DeepSpeech2Model +from deepspeech.models.ds2 import DeepSpeech2Model class TestDeepSpeech2Model(unittest.TestCase): diff --git a/tests/deepspeech2_online_model_test.py b/tests/deepspeech2_online_model_test.py new file mode 100644 index 00000000..6264070b --- /dev/null +++ b/tests/deepspeech2_online_model_test.py @@ -0,0 +1,186 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import unittest + +import numpy as np +import paddle + +from deepspeech.models.ds2_online import DeepSpeech2ModelOnline + + +class TestDeepSpeech2ModelOnline(unittest.TestCase): + def setUp(self): + paddle.set_device('cpu') + + self.batch_size = 2 + self.feat_dim = 161 + max_len = 210 + + # (B, T, D) + audio = np.random.randn(self.batch_size, max_len, self.feat_dim) + audio_len = np.random.randint(max_len, size=self.batch_size) + audio_len[-1] = max_len + # (B, U) + text = np.array([[1, 2], [1, 2]]) + text_len = np.array([2] * self.batch_size) + + self.audio = paddle.to_tensor(audio, dtype='float32') + self.audio_len = paddle.to_tensor(audio_len, dtype='int64') + self.text = paddle.to_tensor(text, dtype='int32') + self.text_len = paddle.to_tensor(text_len, dtype='int64') + + def test_ds2_1(self): + model = DeepSpeech2ModelOnline( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=False) + loss = model(self.audio, self.audio_len, self.text, self.text_len) + self.assertEqual(loss.numel(), 1) + + def test_ds2_2(self): + model = DeepSpeech2ModelOnline( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=True) + loss = model(self.audio, self.audio_len, self.text, self.text_len) + self.assertEqual(loss.numel(), 1) + + def test_ds2_3(self): + model = DeepSpeech2ModelOnline( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=False) + loss = model(self.audio, self.audio_len, self.text, self.text_len) + self.assertEqual(loss.numel(), 1) + + def test_ds2_4(self): + model = DeepSpeech2ModelOnline( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=True) + loss = model(self.audio, self.audio_len, self.text, self.text_len) + self.assertEqual(loss.numel(), 1) + + def test_ds2_5(self): + model = DeepSpeech2ModelOnline( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=False) + loss = model(self.audio, self.audio_len, self.text, self.text_len) + self.assertEqual(loss.numel(), 1) + + def test_ds2_6(self): + model = DeepSpeech2ModelOnline( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + rnn_direction='bidirect', + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=False) + loss = model(self.audio, self.audio_len, self.text, self.text_len) + self.assertEqual(loss.numel(), 1) + + def test_ds2_7(self): + use_gru = False + model = DeepSpeech2ModelOnline( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=1, + rnn_size=1024, + rnn_direction='forward', + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=use_gru) + model.eval() + paddle.device.set_device("cpu") + de_ch_size = 8 + + eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder( + self.audio, self.audio_len) + eouts_by_chk_list, eouts_lens_by_chk_list, final_state_h_box_chk, final_state_c_box_chk = model.encoder.forward_chunk_by_chunk( + self.audio, self.audio_len, de_ch_size) + eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1) + eouts_lens_by_chk = 
paddle.add_n(eouts_lens_by_chk_list) + decode_max_len = eouts.shape[1] + eouts_by_chk = eouts_by_chk[:, :decode_max_len, :] + self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True) + self.assertEqual( + paddle.allclose(final_state_h_box, final_state_h_box_chk), True) + if use_gru is False: + self.assertEqual( + paddle.allclose(final_state_c_box, final_state_c_box_chk), True) + + def test_ds2_8(self): + use_gru = True + model = DeepSpeech2ModelOnline( + feat_size=self.feat_dim, + dict_size=10, + num_conv_layers=2, + num_rnn_layers=1, + rnn_size=1024, + rnn_direction='forward', + num_fc_layers=2, + fc_layers_size_list=[512, 256], + use_gru=use_gru) + model.eval() + paddle.device.set_device("cpu") + de_ch_size = 8 + + eouts, eouts_lens, final_state_h_box, final_state_c_box = model.encoder( + self.audio, self.audio_len) + eouts_by_chk_list, eouts_lens_by_chk_list, final_state_h_box_chk, final_state_c_box_chk = model.encoder.forward_chunk_by_chunk( + self.audio, self.audio_len, de_ch_size) + eouts_by_chk = paddle.concat(eouts_by_chk_list, axis=1) + eouts_lens_by_chk = paddle.add_n(eouts_lens_by_chk_list) + decode_max_len = eouts.shape[1] + eouts_by_chk = eouts_by_chk[:, :decode_max_len, :] + self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True) + self.assertEqual( + paddle.allclose(final_state_h_box, final_state_h_box_chk), True) + if use_gru is False: + self.assertEqual( + paddle.allclose(final_state_c_box, final_state_c_box_chk), True) + + +if __name__ == '__main__': + unittest.main() -- GitLab
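
For reference, a single benchmark run can also be launched by hand instead of through run_all.sh. The sketch below is not part of the patch: it assumes the benchmark scripts above are used unchanged, that the command is started from an example directory such as examples/tiny/s1 (as run_all.sh does), and the values conformer / batch size 32 / fp32 / 500 steps are only illustrative.

```bash
# manual single-run sketch for tests/benchmark/run_benchmark.sh
# argument order: run_mode batch_size fp_item max_iter model_name
cd examples/tiny/s1
source path.sh
mkdir -p exp/log

# single-GPU (sp) run; logs go to $TRAIN_LOG_DIR, or the current directory if unset
CUDA_VISIBLE_DEVICES=0 bash ../../../tests/benchmark/run_benchmark.sh sp 32 fp32 500 conformer

# multi-GPU (mp) run on eight cards with the same settings
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash ../../../tests/benchmark/run_benchmark.sh mp 32 fp32 500 conformer
```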
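
The new deepspeech2_online_model_test.py checks that the streaming encoder of DeepSpeech2ModelOnline produces the same outputs whether an utterance is encoded in a single pass or chunk by chunk. A minimal usage sketch of that interface follows; it reuses the constructor arguments and random toy features from the unit test, so it is an illustration of the API exercised there rather than a supported recipe.

```python
import numpy as np
import paddle

from deepspeech.models.ds2_online import DeepSpeech2ModelOnline

paddle.set_device('cpu')

# toy batch of random features: (batch, time, feat_dim), as in the unit test
batch_size, max_len, feat_dim = 2, 210, 161
audio = paddle.to_tensor(
    np.random.randn(batch_size, max_len, feat_dim), dtype='float32')
audio_len = paddle.to_tensor([max_len] * batch_size, dtype='int64')

model = DeepSpeech2ModelOnline(
    feat_size=feat_dim,
    dict_size=10,
    num_conv_layers=2,
    num_rnn_layers=1,
    rnn_size=1024,
    rnn_direction='forward',
    num_fc_layers=2,
    fc_layers_size_list=[512, 256],
    use_gru=False)
model.eval()

# full-utterance encoding
eouts, eouts_lens, state_h, state_c = model.encoder(audio, audio_len)

# streaming-style encoding, 8 decoder frames per chunk
eouts_chunks, eouts_lens_chunks, state_h_chk, state_c_chk = \
    model.encoder.forward_chunk_by_chunk(audio, audio_len, 8)
eouts_streamed = paddle.concat(eouts_chunks, axis=1)[:, :eouts.shape[1], :]

# the two paths should agree (this is what test_ds2_7 asserts)
print(paddle.allclose(eouts_streamed, eouts))
```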