diff --git a/.mergify.yml b/.mergify.yml index 03e57e14b7bdec0e85b981f4cfa7d7d92e936402..6ec28ae81a0bcede7563888f4d4e99782f6a59a8 100644 --- a/.mergify.yml +++ b/.mergify.yml @@ -39,12 +39,30 @@ pull_request_rules: actions: label: remove: ["conflicts"] - - name: "auto add label=enhancement" + - name: "auto add label=S2T" conditions: - files~=^deepspeech/ actions: label: - add: ["enhancement"] + add: ["S2T"] + - name: "auto add label=T2S" + conditions: + - files~=^parakeet/ + actions: + label: + add: ["T2S"] + - name: "auto add label=Audio" + conditions: + - files~=^paddleaudio/ + actions: + label: + add: ["Audio"] + - name: "auto add label=TextProcess" + conditions: + - files~=^text_processing/ + actions: + label: + add: ["TextProcess"] - name: "auto add label=Example" conditions: - files~=^examples/ diff --git a/README.md b/README.md index 468f42a61b9c310f96eba5c6a0b13bc12faa7282..c501e0c3723ea9420dbeabb246720017c7806cf9 100644 --- a/README.md +++ b/README.md @@ -10,10 +10,9 @@ English | [简体中文](README_ch.md)

- Quick Start - | Tutorials - | Models List - + Quick Start + | Tutorials + | Models List

------------------------------------------------------------------------------------ @@ -27,37 +26,31 @@ how they can install it, how they can use it --> -**PaddleSpeech** is an open-source toolkit on [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) platform for two critical tasks in Speech - **Automatic Speech Recognition (ASR)** and **Text-To-Speech Synthesis (TTS)**, with modules involving state-of-art and influential models. +**PaddleSpeech** is an open-source toolkit on the [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) platform for a variety of critical speech tasks, with state-of-the-art and influential models. -Via the easy-to-use, efficient, flexible and scalable implementation, our vision is to empower both industrial application and academic research, including training, inference & testing module, and deployment. Besides, this toolkit also features at: -- **Fast and Light-weight**: we provide a high-speed and ultra-lightweight model that is convenient for industrial deployment. +Via the easy-to-use, efficient, flexible and scalable implementation, our vision is to empower both industrial applications and academic research, covering training, inference & testing modules, and the deployment process. More specifically, this toolkit features: +- **Fast and Light-weight**: we provide high-speed and ultra-lightweight models that are convenient for industrial deployment. - **Rule-based Chinese frontend**: our frontend contains Text Normalization (TN) and Grapheme-to-Phoneme (G2P, including Polyphone and Tone Sandhi). Moreover, we use self-defined linguistic rules to adapt Chinese context. -**Varieties of Functions that Vitalize Research**: - - *Integration of mainstream models and datasets*: the toolkit implements modules that participate in the whole pipeline of both ASR and TTS, and uses datasets like LibriSpeech, LJSpeech, AIShell, etc. See also [model lists](#models-list) for more details. - - *Support of ASR streaming and non-streaming data*: This toolkit contains non-streaming/streaming models like [DeepSpeech2](http://proceedings.mlr.press/v48/amodei16.pdf), [Transformer](https://arxiv.org/abs/1706.03762), [Conformer](https://arxiv.org/abs/2005.08100) and [U2](https://arxiv.org/pdf/2012.05481.pdf). +- **Varieties of Functions that Vitalize both Industry and Academia**: + - *Implementation of critical audio tasks*: this toolkit contains audio functions like Speech Translation (ST), Automatic Speech Recognition (ASR), Text-To-Speech Synthesis (TTS), Voice Cloning (VC), Punctuation Restoration, etc. + - *Integration of mainstream models and datasets*: the toolkit implements modules that participate in the whole pipeline of the speech tasks, and uses mainstream datasets like LibriSpeech, LJSpeech, AIShell, CSMSC, etc. See also [model lists](#models-list) for more details. + - *Cross-domain application*: as an extension of traditional audio tasks, we combine the aforementioned tasks with other fields like NLP. Let's install PaddleSpeech with only a few lines of code! >Note: The official name is still deepspeech. 2021/10/26 -``` shell -# 1. Install essential libraries and paddlepaddle first. -# install prerequisites -sudo apt-get install -y sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev libsndfile1 -# `pip install paddlepaddle-gpu` instead if you are using GPU. -pip install paddlepaddle - -# 2.Then install PaddleSpeech. +If you are using Ubuntu, PaddleSpeech can be set up with pip installation (with root privileges).
+```shell git clone https://github.com/PaddlePaddle/DeepSpeech.git cd DeepSpeech pip install -e . ``` - ## Table of Contents The contents of this README is as follow: -- [Alternative Installation](#installation) +- [Alternative Installation](#alternative-installation) - [Quick Start](#quick-start) - [Models List](#models-list) - [Tutorials](#tutorials) @@ -75,12 +68,15 @@ The base environment in this page is If you want to set up PaddleSpeech in other environment, please see the [ASR installation](docs/source/asr/install.md) and [TTS installation](docs/source/tts/install.md) documents for all the alternatives. ## Quick Start +> Note: the current links to `English ASR` and `English TTS` are not valid. -> Note: `ckptfile` should be replaced by real path that represents files or folders later. Similarly, `exp/default` is the folder that contains the pretrained models. +Take a quick test of our functions: [English ASR](link/hubdetail?name=deepspeech2_aishell&en_category=AutomaticSpeechRecognition) and [English TTS](link/hubdetail?name=fastspeech2_baker&en_category=TextToSpeech) by typing a message or uploading your own audio file. -Try a tiny ASR DeepSpeech2 model training on toy set of LibriSpeech: +Developers can try our models with only a few lines of code. -```shell +A tiny **ASR** DeepSpeech2 model training on a toy subset of LibriSpeech: + +```bash cd examples/tiny/s0/ # source the environment source path.sh @@ -90,28 +86,50 @@ bash local/data.sh bash local/test.sh conf/deepspeech2.yaml ckptfile offline ``` -For TTS, try FastSpeech2 on LJSpeech: -- Download LJSpeech-1.1 from the [ljspeech official website](https://keithito.com/LJ-Speech-Dataset/) and our prepared durations for fastspeech2 [ljspeech_alignment](https://paddlespeech.bj.bcebos.com/MFA/LJSpeech-1.1/ljspeech_alignment.tar.gz). -- Assume your path to the dataset is `~/datasets/LJSpeech-1.1` and `./ljspeech_alignment` accordingly, preprocess your data and then use our pretrained model to synthesize: -```shell -bash ./local/preprocess.sh conf/default.yaml -bash ./local/synthesize_e2e.sh conf/default.yaml exp/default ckptfile -``` +For **TTS**, try the pretrained FastSpeech2 + Parallel WaveGAN on CSMSC: +```bash +cd examples/csmsc/tts3 +# download the pretrained models and unzip them +wget https://paddlespeech.bj.bcebos.com/Parakeet/pwg_baker_ckpt_0.4.zip +unzip pwg_baker_ckpt_0.4.zip +wget https://paddlespeech.bj.bcebos.com/Parakeet/fastspeech2_nosil_baker_ckpt_0.4.zip +unzip fastspeech2_nosil_baker_ckpt_0.4.zip +# source the environment +source path.sh +# run end-to-end synthesis +FLAGS_allocator_strategy=naive_best_fit \ +FLAGS_fraction_of_gpu_memory_to_use=0.01 \ +python3 ${BIN_DIR}/synthesize_e2e.py \ + --fastspeech2-config=fastspeech2_nosil_baker_ckpt_0.4/default.yaml \ + --fastspeech2-checkpoint=fastspeech2_nosil_baker_ckpt_0.4/snapshot_iter_76000.pdz \ + --fastspeech2-stat=fastspeech2_nosil_baker_ckpt_0.4/speech_stats.npy \ + --pwg-config=pwg_baker_ckpt_0.4/pwg_default.yaml \ + --pwg-checkpoint=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \ + --pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \ + --text=${BIN_DIR}/../sentences.txt \ + --output-dir=exp/default/test_e2e \ + --inference-dir=exp/default/inference \ + --device="gpu" \ + --phones-dict=fastspeech2_nosil_baker_ckpt_0.4/phone_id_map.txt +``` If you want to try more functions like training and tuning, please see [ASR getting started](docs/source/asr/getting_started.md) and [TTS Basic Use](/docs/source/tts/basic_usage.md).
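The TTS quick start above chains two stages: FastSpeech2 is the acoustic model that maps phone IDs to a mel spectrogram, and Parallel WaveGAN is the vocoder that maps the mel spectrogram to a waveform. Below is a minimal editorial sketch of that data flow, not the toolkit's actual API; the stub functions are hypothetical stand-ins for the real models, and the shapes follow the CSMSC configs in this repo (24 kHz audio, 80 mel bins, hop size 300):

```python
import numpy as np

# Hypothetical stand-ins for the real models (illustration only).
def fastspeech2_infer(phone_ids):
    # Acoustic model: phone IDs -> mel spectrogram [n_frames, 80].
    return np.zeros((len(phone_ids) * 10, 80), dtype=np.float32)

def pwgan_infer(mel):
    # Vocoder: mel spectrogram [n_frames, 80] -> waveform [n_frames * hop].
    return np.zeros(mel.shape[0] * 300, dtype=np.float32)

phone_ids = [5, 12, 7, 30]  # hypothetical output of the text frontend
wav = pwgan_infer(fastspeech2_infer(phone_ids))
print(len(wav) / 24000, "seconds of audio")  # 300 samples per frame at 24 kHz
```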
## Models List +PaddleSpeech supports a series of the most popular models, summarized in [released models](./docs/source/released_model.md) together with the available pretrained models. - -PaddleSpeech ASR supports a lot of mainstream models, which are summarized as follow. For more information, please refer to [ASR Models](./docs/source/asr/released_model.md). +The ASR module contains an *Acoustic Model* and a *Language Model*, with the following details: +> Note: The `Link` should be a code path rather than a download link. + + @@ -125,7 +143,7 @@ The current hyperlinks redirect to [Previous Parakeet](https://github.com/Paddle - + @@ -200,7 +218,7 @@ PaddleSpeech TTS mainly contains three modules: *Text Frontend*, *Acoustic Model @@ -208,41 +226,41 @@ PaddleSpeech TTS mainly contains three modules: *Text Frontend*, *Acoustic Model - + - + @@ -250,26 +268,26 @@ PaddleSpeech TTS mainly contains three modules: *Text Frontend*, *Acoustic Model @@ -277,14 +295,14 @@ PaddleSpeech TTS mainly contains three modules: *Text Frontend*, *Acoustic Model diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 7b929f8b701b57ce399013b36fcda867d1287fff..56743629b6b1cfac1e0ca1e57ab3748290789cad 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -21,11 +21,6 @@ from typing import Optional import jsonlines import numpy as np import paddle -from paddle import distributed as dist -from paddle import inference -from paddle.io import DataLoader -from yacs.config import CfgNode - from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset @@ -44,6 +39,10 @@ from deepspeech.utils import mp_tools from deepspeech.utils.log import Autolog from deepspeech.utils.log import Log from deepspeech.utils.utility import UpdateConfig +from paddle import distributed as dist +from paddle import inference +from paddle.io import DataLoader +from yacs.config import CfgNode logger = Log(__name__).getlog() @@ -153,8 +152,12 @@ class DeepSpeech2Trainer(Trainer): def setup_model(self): config = self.config.clone() with UpdateConfig(config): - config.model.feat_size = self.train_loader.collate_fn.feature_size - config.model.dict_size = self.train_loader.collate_fn.vocab_size + if self.train: + config.model.feat_size = self.train_loader.collate_fn.feature_size + config.model.dict_size = self.train_loader.collate_fn.vocab_size + else: + config.model.feat_size = self.test_loader.collate_fn.feature_size + config.model.dict_size = self.test_loader.collate_fn.vocab_size if self.args.model_type == 'offline': model = DeepSpeech2Model.from_config(config.model) @@ -167,6 +170,11 @@ class DeepSpeech2Trainer(Trainer): logger.info(f"{model}") layer_tools.print_params(model, logger.info) + self.model = model + logger.info("Setup model!") + + if not self.train: + return grad_clip = ClipGradByGlobalNormWithLog( config.training.global_grad_clip) @@ -180,74 +188,76 @@ class DeepSpeech2Trainer(Trainer): weight_decay=paddle.regularizer.L2Decay( config.training.weight_decay), grad_clip=grad_clip) - - self.model = model self.optimizer = optimizer self.lr_scheduler = lr_scheduler - logger.info("Setup model/optimizer/lr_scheduler!") + logger.info("Setup optimizer/lr_scheduler!") def setup_dataloader(self): config = self.config.clone() config.defrost() - config.collator.keep_transcription_text = False - - config.data.manifest = config.data.train_manifest - train_dataset =
ManifestDataset.from_config(config) - - config.data.manifest = config.data.dev_manifest - dev_dataset = ManifestDataset.from_config(config) - - config.data.manifest = config.data.test_manifest - test_dataset = ManifestDataset.from_config(config) - - if self.parallel: - batch_sampler = SortagradDistributedBatchSampler( + if self.train: + # train + config.data.manifest = config.data.train_manifest + train_dataset = ManifestDataset.from_config(config) + if self.parallel: + batch_sampler = SortagradDistributedBatchSampler( + train_dataset, + batch_size=config.collator.batch_size, + num_replicas=None, + rank=None, + shuffle=True, + drop_last=True, + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) + else: + batch_sampler = SortagradBatchSampler( + train_dataset, + shuffle=True, + batch_size=config.collator.batch_size, + drop_last=True, + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) + + config.collator.keep_transcription_text = False + collate_fn_train = SpeechCollator.from_config(config) + self.train_loader = DataLoader( train_dataset, - batch_size=config.collator.batch_size, - num_replicas=None, - rank=None, - shuffle=True, - drop_last=True, - sortagrad=config.collator.sortagrad, - shuffle_method=config.collator.shuffle_method) + batch_sampler=batch_sampler, + collate_fn=collate_fn_train, + num_workers=config.collator.num_workers) + + # dev + config.data.manifest = config.data.dev_manifest + dev_dataset = ManifestDataset.from_config(config) + + config.collator.augmentation_config = "" + config.collator.keep_transcription_text = False + collate_fn_dev = SpeechCollator.from_config(config) + self.valid_loader = DataLoader( + dev_dataset, + batch_size=int(config.collator.batch_size), + shuffle=False, + drop_last=False, + collate_fn=collate_fn_dev, + num_workers=config.collator.num_workers) + logger.info("Setup train/valid Dataloader!") else: - batch_sampler = SortagradBatchSampler( - train_dataset, - shuffle=True, - batch_size=config.collator.batch_size, - drop_last=True, - sortagrad=config.collator.sortagrad, - shuffle_method=config.collator.shuffle_method) - - collate_fn_train = SpeechCollator.from_config(config) - - config.collator.augmentation_config = "" - collate_fn_dev = SpeechCollator.from_config(config) - - config.collator.keep_transcription_text = True - config.collator.augmentation_config = "" - collate_fn_test = SpeechCollator.from_config(config) - - self.train_loader = DataLoader( - train_dataset, - batch_sampler=batch_sampler, - collate_fn=collate_fn_train, - num_workers=config.collator.num_workers) - self.valid_loader = DataLoader( - dev_dataset, - batch_size=int(config.collator.batch_size), - shuffle=False, - drop_last=False, - collate_fn=collate_fn_dev, - num_workers=config.collator.num_workers) - self.test_loader = DataLoader( - test_dataset, - batch_size=config.decoding.batch_size, - shuffle=False, - drop_last=False, - collate_fn=collate_fn_test, - num_workers=config.collator.num_workers) - logger.info("Setup train/valid/test Dataloader!") + # test + config.data.manifest = config.data.test_manifest + test_dataset = ManifestDataset.from_config(config) + + config.collator.augmentation_config = "" + config.collator.keep_transcription_text = True + collate_fn_test = SpeechCollator.from_config(config) + + self.test_loader = DataLoader( + test_dataset, + batch_size=config.decoding.batch_size, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_test, + num_workers=config.collator.num_workers) + 
logger.info("Setup test Dataloader!") class DeepSpeech2Tester(DeepSpeech2Trainer): @@ -401,6 +411,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): class DeepSpeech2ExportTester(DeepSpeech2Tester): def __init__(self, config, args): super().__init__(config, args) + self.apply_static = True def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): if self.args.model_type == "online": diff --git a/deepspeech/exps/lm/transformer/__init__.py b/deepspeech/exps/lm/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..185a92b8d94d3426d616c0624f0f2ee04339349e --- /dev/null +++ b/deepspeech/exps/lm/transformer/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/deepspeech/exps/lm/transformer/bin/cacu_perplexity.py b/deepspeech/exps/lm/transformer/bin/cacu_perplexity.py new file mode 100644 index 0000000000000000000000000000000000000000..29a880f88fb308888a406d09623b65d6298ec68f --- /dev/null +++ b/deepspeech/exps/lm/transformer/bin/cacu_perplexity.py @@ -0,0 +1,82 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
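+# Usage sketch (editorial note, not part of the original patch): this script is
+# typically driven by examples/librispeech/s2/local/cacu_perplexity.sh, roughly:
+#   python cacu_perplexity.py \
+#       --rnnlm exp/lm/transformer/transformerLM.pdparams \
+#       --rnnlm-conf conf/lm/transformer.yaml \
+#       --vocab_path data/lang_char/train_960_unigram5000_units_lower.txt \
+#       --bpeprefix data/lang_char/train_960_unigram5000 \
+#       --text_path data/test_clean/text_lower \
+#       --output_dir exp/lm/transformer/perplexity --ngpu 0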
+import sys + +import configargparse + + +def get_parser(): + """Get default arguments.""" + parser = configargparse.ArgumentParser( + description="The parser for calculating the perplexity of the transformer language model", + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=configargparse.ArgumentDefaultsHelpFormatter, ) + + parser.add_argument( + "--rnnlm", type=str, default=None, help="RNNLM model file to read") + + parser.add_argument( + "--rnnlm-conf", + type=str, + default=None, + help="RNNLM model config file to read") + + parser.add_argument( + "--vocab_path", + type=str, + default=None, + help="vocab path for token2id") + + parser.add_argument( + "--bpeprefix", + type=str, + default=None, + help="The path of the bpe prefix for loading") + + parser.add_argument( + "--text_path", + type=str, + default=None, + help="The path of the text file for testing") + + parser.add_argument( + "--ngpu", + type=int, + default=0, + help="The number of GPUs to use, 0 to use CPU instead") + + parser.add_argument( + "--dtype", + choices=("float16", "float32", "float64"), + default="float32", + help="Float precision (only available in --api v2)", ) + + parser.add_argument( + "--output_dir", + type=str, + default=".", + help="The output directory to store the sentence PPL") + + return parser + + +def main(args): + parser = get_parser() + args = parser.parse_args(args) + from deepspeech.exps.lm.transformer.lm_cacu_perplexity import run_get_perplexity + run_get_perplexity(args) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/deepspeech/exps/lm/transformer/lm_cacu_perplexity.py b/deepspeech/exps/lm/transformer/lm_cacu_perplexity.py new file mode 100644 index 0000000000000000000000000000000000000000..b63bcd08b9aff1944d759b6afc099abfbc532b8e --- /dev/null +++ b/deepspeech/exps/lm/transformer/lm_cacu_perplexity.py @@ -0,0 +1,132 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
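+# Editorial note (not part of the original patch) on the math below: with total
+# negative log-likelihood NLL in nats over N tokens, perplexity is
+# PPL = exp(NLL / N). For an arbitrary log base b, cacu_perplexity() uses the
+# equivalent identity PPL = b ** (NLL / (N * ln b)), since exp(x) = b ** (x / ln b);
+# when log_base is None, the natural base e is used and reported.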
+# Calculate the PPL of the LM model +import os + +import numpy as np +import paddle +from paddle.io import DataLoader +from yacs.config import CfgNode + +from deepspeech.io.collator import TextCollatorSpm +from deepspeech.io.dataset import TextDataset +from deepspeech.models.lm_interface import dynamic_import_lm +from deepspeech.utils.log import Log + +logger = Log(__name__).getlog() + + +def get_config(config_path): + confs = CfgNode(new_allowed=True) + confs.merge_from_file(config_path) + return confs + + +def load_trained_lm(args): + lm_config = get_config(args.rnnlm_conf) + lm_model_module = lm_config.model_module + lm_class = dynamic_import_lm(lm_model_module) + lm = lm_class(**lm_config.model) + model_dict = paddle.load(args.rnnlm) + lm.set_state_dict(model_dict) + return lm, lm_config + + +def write_dict_into_file(ppl_dict, name): + with open(name, "w") as f: + for key in ppl_dict.keys(): + f.write(key + " " + ppl_dict[key] + "\n") + return + + +def cacu_perplexity( + lm_model, + lm_config, + args, + log_base=None, ): + unit_type = lm_config.data.unit_type + batch_size = lm_config.decoding.batch_size + num_workers = lm_config.decoding.num_workers + text_file_path = args.text_path + + total_nll = 0.0 + total_ntokens = 0 + ppl_dict = {} + len_dict = {} + text_dataset = TextDataset.from_file(text_file_path) + collate_fn_text = TextCollatorSpm( + unit_type=unit_type, + vocab_filepath=args.vocab_path, + spm_model_prefix=args.bpeprefix) + train_loader = DataLoader( + text_dataset, + batch_size=batch_size, + collate_fn=collate_fn_text, + num_workers=num_workers) + + logger.info("start calculating PPL ...") + for i, (keys, ys_input_pad, ys_output_pad, + y_lens) in enumerate(train_loader()): + + ys_input_pad = paddle.to_tensor(ys_input_pad) + ys_output_pad = paddle.to_tensor(ys_output_pad) + _, unused_logp, unused_count, nll, nll_count = lm_model.forward( + ys_input_pad, ys_output_pad) + nll = nll.numpy() + nll_count = nll_count.numpy() + for key, _nll, ntoken in zip(keys, nll, nll_count): + if log_base is None: + utt_ppl = np.exp(_nll / ntoken) + else: + utt_ppl = log_base**(_nll / ntoken / np.log(log_base)) + + # Write the PPL of each utterance for debugging or analysis + ppl_dict[key] = str(utt_ppl) + len_dict[key] = str(ntoken) + + total_nll += nll.sum() + total_ntokens += nll_count.sum() + logger.info("Current total nll: " + str(total_nll)) + logger.info("Current total tokens: " + str(total_ntokens)) + write_dict_into_file(ppl_dict, os.path.join(args.output_dir, "uttPPL")) + write_dict_into_file(len_dict, os.path.join(args.output_dir, "uttLEN")) + if log_base is None: + ppl = np.exp(total_nll / total_ntokens) + else: + ppl = log_base**(total_nll / total_ntokens / np.log(log_base)) + + if log_base is None: + log_base = np.e + + return ppl, log_base + + +def run_get_perplexity(args): + if args.ngpu > 1: + raise NotImplementedError("only single GPU decoding is supported") + if args.ngpu == 1: + device = "gpu:0" + else: + device = "cpu" + paddle.set_device(device) + dtype = getattr(paddle, args.dtype) + logger.info(f"Decoding device={device}, dtype={dtype}") + lm_model, lm_config = load_trained_lm(args) + lm_model.to(device=device, dtype=dtype) + lm_model.eval() + PPL, log_base = cacu_perplexity(lm_model, lm_config, args, None) + logger.info("Final PPL: " + str(PPL)) + logger.info("The log base is: " + str("%.2f" % log_base)) diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index
7806aaa491bcde1c26969ead6e4c8032e6aec665..e47a59edaf0435578b57edfc37222acca7df2de2 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -172,7 +172,7 @@ class U2Trainer(Trainer): dist.get_rank(), total_loss / num_seen_utts)) return total_loss, num_seen_utts - def train(self): + def do_train(self): """The training process control by step.""" # !!!IMPORTANT!!! # Try to export the model by script, if fails, we should refine diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py index f86243269ef5eec4955e82a155897c5882c45938..663c36d8b41f01d73cac5f9cabfee3fe99021144 100644 --- a/deepspeech/exps/u2_kaldi/model.py +++ b/deepspeech/exps/u2_kaldi/model.py @@ -173,7 +173,7 @@ class U2Trainer(Trainer): dist.get_rank(), total_loss / num_seen_utts)) return total_loss, num_seen_utts - def train(self): + def do_train(self): """The training process control by step.""" # !!!IMPORTANT!!! # Try to export the model by script, if fails, we should refine diff --git a/deepspeech/exps/u2_st/model.py b/deepspeech/exps/u2_st/model.py index c5df44c6704678c0642a995b44c324693bd4e4b4..1f638e64c082f8e8bb7bd9fc8c4be7a2b53f529d 100644 --- a/deepspeech/exps/u2_st/model.py +++ b/deepspeech/exps/u2_st/model.py @@ -184,7 +184,7 @@ class U2STTrainer(Trainer): dist.get_rank(), total_loss / num_seen_utts)) return total_loss, num_seen_utts - def train(self): + def do_train(self): """The training process control by step.""" # !!!IMPORTANT!!! # Try to export the model by script, if fails, we should refine diff --git a/deepspeech/frontend/featurizer/text_featurizer.py b/deepspeech/frontend/featurizer/text_featurizer.py index a6834ebc6f4349590d0df737ee84b1aa0c7c5dec..c596bd43b1fe003e00ba04ca9604771a4dd572c9 100644 --- a/deepspeech/frontend/featurizer/text_featurizer.py +++ b/deepspeech/frontend/featurizer/text_featurizer.py @@ -53,7 +53,7 @@ class TextFeaturizer(): self.maskctc = maskctc if vocab_filepath: - self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id = self._load_vocabulary_from_file( + self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id, self.blank_id = self._load_vocabulary_from_file( vocab_filepath, maskctc) self.vocab_size = len(self.vocab_list) @@ -227,4 +227,4 @@ class TextFeaturizer(): logger.info(f"SOS id: {sos_id}") logger.info(f"SPACE id: {space_id}") logger.info(f"MASKCTC id: {maskctc_id}") - return token2id, id2token, vocab_list, unk_id, eos_id + return token2id, id2token, vocab_list, unk_id, eos_id, blank_id diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index b523dfc8e52e7bf0f785f4a5a02e535820154228..5391260eeab8deb6eae2124bf3f30cc993df849e 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -19,6 +19,7 @@ from yacs.config import CfgNode from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer +from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer from deepspeech.frontend.normalizer import FeatureNormalizer from deepspeech.frontend.speech import SpeechSegment from deepspeech.frontend.utility import IGNORE_ID @@ -33,7 +34,7 @@ logger = Log(__name__).getlog() def _tokenids(text, keep_transcription_text): - # for training text is token ids + # for training text is token ids tokens = text # token ids if keep_transcription_text: @@ -45,6 +46,43 @@ def _tokenids(text, keep_transcription_text): return tokens +class TextCollatorSpm(): + def __init__(self, 
unit_type, vocab_filepath, spm_model_prefix): + assert (vocab_filepath is not None) + self.text_featurizer = TextFeaturizer( + unit_type=unit_type, + vocab_filepath=vocab_filepath, + spm_model_prefix=spm_model_prefix) + self.eos_id = self.text_featurizer.eos_id + self.blank_id = self.text_featurizer.blank_id + + def __call__(self, batch): + """ + return type [List, np.array [B, T], np.array [B, T], np.array[B]] + """ + keys = [] + texts = [] + texts_input = [] + texts_output = [] + text_lens = [] + + for idx, item in enumerate(batch): + key = item.split(" ")[0].strip() + text = " ".join(item.split(" ")[1:]) + keys.append(key) + token_ids = self.text_featurizer.featurize(text) + texts_input.append( + np.array([self.eos_id] + token_ids).astype(np.int64)) + texts_output.append( + np.array(token_ids + [self.eos_id]).astype(np.int64)) + text_lens.append(len(token_ids) + 1) + + ys_input_pad = pad_list(texts_input, self.blank_id).astype(np.int64) + ys_output_pad = pad_list(texts_output, self.blank_id).astype(np.int64) + y_lens = np.array(text_lens).astype(np.int64) + return keys, ys_input_pad, ys_output_pad, y_lens + + class SpeechCollatorBase(): def __init__( self, diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index 1945c5f7259cf429d9e343e0c4cf909497cfb165..7c1010025551e76a072f68447ecc59006cc6e310 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -24,6 +24,25 @@ __all__ = ["ManifestDataset", "TransformDataset"] logger = Log(__name__).getlog() +class TextDataset(Dataset): + @classmethod + def from_file(cls, file_path): + dataset = cls(file_path) + return dataset + + def __init__(self, file_path): + self._manifest = [] + with open(file_path) as f: + for line in f: + self._manifest.append(line.strip()) + + def __len__(self): + return len(self._manifest) + + def __getitem__(self, idx): + return self._manifest[idx] + + class ManifestDataset(Dataset): @classmethod def params(cls, config: Optional[CfgNode]=None) -> CfgNode: diff --git a/deepspeech/models/lm/transformer.py b/deepspeech/models/lm/transformer.py index 35ecf6785d7882f00651958dd8b7f3e6ae7df24b..19e2b758a5a702f052d2fc6e3ab1b41258e5c16e 100644 --- a/deepspeech/models/lm/transformer.py +++ b/deepspeech/models/lm/transformer.py @@ -111,6 +111,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): in perplexity: p(t)^{-n} = exp(-log p(t) / n) """ + batch_size = x.size(0) xm = x != 0 xlen = xm.sum(axis=1) if self.embed_drop is not None: @@ -121,11 +122,13 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): y = self.decoder(h) loss = F.cross_entropy( y.view(-1, y.shape[-1]), t.view(-1), reduction="none") - mask = xm.to(dtype=loss.dtype) + mask = xm.to(loss.dtype) logp = loss * mask.view(-1) + nll = logp.view(batch_size, -1).sum(-1) + nll_count = mask.sum(-1) logp = logp.sum() count = mask.sum() - return logp / count, logp, count + return logp / count, logp, count, nll, nll_count # beam search API (see ScorerInterface) def score(self, y: paddle.Tensor, state: Any, diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 2c2389203a8989331d43d21edae80f24f55d6c8f..ddde1e885c2cf33f6dc2af13e90a96315780340c 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -18,9 +18,6 @@ from contextlib import contextmanager from pathlib import Path import paddle -from paddle import distributed as dist -from tensorboardX import SummaryWriter - from deepspeech.training.reporter import ObsScope from deepspeech.training.reporter import report 
from deepspeech.training.timer import Timer @@ -31,6 +28,8 @@ from deepspeech.utils.log import Log from deepspeech.utils.utility import all_version from deepspeech.utils.utility import seed_all from deepspeech.utils.utility import UpdateConfig +from paddle import distributed as dist +from tensorboardX import SummaryWriter __all__ = ["Trainer"] @@ -134,6 +133,10 @@ class Trainer(): logger.info( f"Benchmark reset batch-size: {self.args.benchmark_batch_size}") + @property + def train(self): + return self._train + @contextmanager def eval(self): self._train = False @@ -248,7 +251,7 @@ class Trainer(): sys.exit( f"Reach benchmark-max-step: {self.args.benchmark_max_step}") - def train(self): + def do_train(self): """The training process control by epoch.""" self.before_train() @@ -321,7 +324,7 @@ class Trainer(): """ try: with Timer("Training Done: {}"): - self.train() + self.do_train() except KeyboardInterrupt: exit(-1) finally: @@ -344,8 +347,12 @@ class Trainer(): try: with Timer("Test/Decode Done: {}"): with self.eval(): - self.restore() - self.test() + if hasattr(self, + "apply_static") and self.apply_static is True: + self.test() + else: + self.restore() + self.test() except KeyboardInterrupt: exit(-1) @@ -377,6 +384,8 @@ class Trainer(): elif self.args.checkpoint_path: output_dir = Path( self.args.checkpoint_path).expanduser().parent.parent + elif self.args.export_path: + output_dir = Path(self.args.export_path).expanduser().parent.parent self.output_dir = output_dir self.output_dir.mkdir(parents=True, exist_ok=True) @@ -432,7 +441,7 @@ class Trainer(): beginning of the experiment. """ config_file = self.config_dir / "config.yaml" - if self._train and config_file.exists(): + if self.train and config_file.exists(): time_stamp = time.strftime("%Y_%m_%d_%H_%M_%s", time.gmtime()) target_path = self.config_dir / ".".join( [time_stamp, "config.yaml"]) diff --git a/examples/aishell/s0/local/test.sh b/examples/aishell/s0/local/test.sh index 2ae0740b3e8d44ab03e45f4c1b5dbb945657705e..d539ac4943039fe6c33eb1373985aa98617a587f 100755 --- a/examples/aishell/s0/local/test.sh +++ b/examples/aishell/s0/local/test.sh @@ -13,7 +13,7 @@ ckpt_prefix=$2 model_type=$3 # download language model -bash local/download_lm_ch.sh +bash local/download_lm_ch.sh > /dev/null 2>&1 if [ $? -ne 0 ]; then exit 1 fi diff --git a/examples/aishell/s0/local/test_export.sh b/examples/aishell/s0/local/test_export.sh index a9a6b122df8055f872f9f0a68717b57241d99359..f0a30ce56fbc4cc43b559295fba7ef3ac3b3be26 100755 --- a/examples/aishell/s0/local/test_export.sh +++ b/examples/aishell/s0/local/test_export.sh @@ -13,7 +13,7 @@ jit_model_export_path=$2 model_type=$3 # download language model -bash local/download_lm_ch.sh +bash local/download_lm_ch.sh > /dev/null 2>&1 if [ $? -ne 0 ]; then exit 1 fi diff --git a/examples/csmsc/tts2/README.md b/examples/csmsc/tts2/README.md index e73f81fa95189d0184f471d5be12d4a2b6adc6d0..de9e488c8d4096deb53a1601b171ee0ce5b5fd55 100644 --- a/examples/csmsc/tts2/README.md +++ b/examples/csmsc/tts2/README.md @@ -19,7 +19,7 @@ Run the command below to 4. synthesize wavs. - synthesize waveform from `metadata.jsonl`. - synthesize waveform from text file. -6. inference using static model. +5. inference using static model. 
```bash ./run.sh ``` diff --git a/examples/csmsc/tts3/README.md b/examples/csmsc/tts3/README.md index 42f33faac1de5ce0c19843764db2e86c7938c1ba..7eeb14fc533fcf2df2603a7d943216350505c6b7 100644 --- a/examples/csmsc/tts3/README.md +++ b/examples/csmsc/tts3/README.md @@ -19,6 +19,7 @@ Run the command below to 4. synthesize wavs. - synthesize waveform from `metadata.jsonl`. - synthesize waveform from text file. +5. inference using static model. ```bash ./run.sh ``` @@ -189,6 +190,13 @@ optional arguments: 5. `--output-dir` is the directory to save synthesized audio files. 6. `--device is` the type of device to run synthesis, 'cpu' and 'gpu' are supported. 'gpu' is recommended for faster synthesis. +### Inference +After synthesizing, we will get the static models of fastspeech2 and pwgan in `${train_output_path}/inference`. +`./local/inference.sh` calls `${BIN_DIR}/inference.py`, which provides a paddle static model inference example for fastspeech2 + pwgan synthesis. +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} +``` + ## Pretrained Model Pretrained FastSpeech2 model with no silence in the edge of audios. [fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/fastspeech2_nosil_baker_ckpt_0.4.zip) @@ -215,6 +223,7 @@ python3 ${BIN_DIR}/synthesize_e2e.py \ --pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \ --text=${BIN_DIR}/../sentences.txt \ --output-dir=exp/default/test_e2e \ + --inference-dir=exp/default/inference \ --device="gpu" \ --phones-dict=fastspeech2_nosil_baker_ckpt_0.4/phone_id_map.txt ``` diff --git a/examples/csmsc/tts3/local/inference.sh b/examples/csmsc/tts3/local/inference.sh new file mode 100755 index 0000000000000000000000000000000000000000..cab72547c76a52e7b610eea6ad340b16a72216fa --- /dev/null +++ b/examples/csmsc/tts3/local/inference.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +train_output_path=$1 + +python3 ${BIN_DIR}/inference.py \ + --inference-dir=${train_output_path}/inference \ + --text=${BIN_DIR}/../sentences.txt \ + --output-dir=${train_output_path}/pd_infer_out \ + --phones-dict=dump/phone_id_map.txt diff --git a/examples/csmsc/tts3/local/synthesize_e2e.sh b/examples/csmsc/tts3/local/synthesize_e2e.sh index 8c9755dd0d1115b788ad724871da22c6d791005f..b654274319b6abb55e736f0b2c4aa75e7b07c3bf 100755 --- a/examples/csmsc/tts3/local/synthesize_e2e.sh +++ b/examples/csmsc/tts3/local/synthesize_e2e.sh @@ -15,5 +15,6 @@ python3 ${BIN_DIR}/synthesize_e2e.py \ --pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \ --text=${BIN_DIR}/../sentences.txt \ --output-dir=${train_output_path}/test_e2e \ + --inference-dir=${train_output_path}/inference \ --device="gpu" \ --phones-dict=dump/phone_id_map.txt diff --git a/examples/csmsc/tts3/run.sh b/examples/csmsc/tts3/run.sh index f45ddab060710e7a8545d3eb90d5628f3d76c4bb..718d6076041f8e8855bad5801383e691f175eee3 100755 --- a/examples/csmsc/tts3/run.sh +++ b/examples/csmsc/tts3/run.sh @@ -35,3 +35,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # synthesize_e2e, vocoder is pwgan CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 fi + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # inference with static model + CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} || exit -1 +fi diff --git a/examples/csmsc/voc3/README.md b/examples/csmsc/voc3/README.md new file mode 100644 index 0000000000000000000000000000000000000000..780a8ccdb0095865840c48f7d96d44465a23c534 --- /dev/null +++
b/examples/csmsc/voc3/README.md @@ -0,0 +1,127 @@ +# Multi Band MelGAN with CSMSC +This example contains code used to train a [Multi Band MelGAN](https://arxiv.org/abs/2005.05106) model with [Chinese Standard Mandarin Speech Corpus](https://www.data-baker.com/open_source.html). +## Dataset +### Download and Extract the dataset +Download CSMSC from the [official website](https://www.data-baker.com/data/index/source) and extract it to `~/datasets`. Then the dataset is in directory `~/datasets/BZNSYP`. + +### Get MFA results for silence trim +We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence at the edges of the audio. +You can download it from here: [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your own MFA model by referring to the [use_mfa example](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/use_mfa) in our repo. + +## Get Started +Assume the path to the dataset is `~/datasets/BZNSYP`. +Assume the path to the MFA result of CSMSC is `./baker_alignment_tone`. +Run the command below to +1. **source path**. +2. preprocess the dataset. +3. train the model. +4. synthesize wavs. + - synthesize waveform from `metadata.jsonl`. +```bash +./run.sh +``` +### Preprocess the dataset +```bash +./local/preprocess.sh ${conf_path} +``` +When it is done, a `dump` folder is created in the current directory. The structure of the dump folder is listed below. + +```text +dump +├── dev +│ ├── norm +│ └── raw +├── test +│ ├── norm +│ └── raw +└── train + ├── norm + ├── raw + └── feats_stats.npy +``` +The dataset is split into 3 parts, namely `train`, `dev` and `test`, each of which contains a `norm` and `raw` subfolder. The `raw` folder contains the log magnitude of the mel spectrogram of each utterance, while the `norm` folder contains the normalized spectrograms. The statistics used to normalize the spectrograms are computed from the training set, which is located in `dump/train/feats_stats.npy`. + +There is also a `metadata.jsonl` in each subfolder. It is a table-like file that contains the id and the path to the spectrogram of each utterance. + +### Train the model +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} +``` +`./local/train.sh` calls `${BIN_DIR}/train.py`. +Here's the complete help message. + +```text +usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA] + [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR] + [--device DEVICE] [--nprocs NPROCS] [--verbose VERBOSE] + [--batch-size BATCH_SIZE] [--max-iter MAX_ITER] + [--run-benchmark RUN_BENCHMARK] + [--profiler_options PROFILER_OPTIONS] + +Train a ParallelWaveGAN model. + +optional arguments: + -h, --help show this help message and exit + --config CONFIG config file to overwrite default config. + --train-metadata TRAIN_METADATA + training data. + --dev-metadata DEV_METADATA + dev data. + --output-dir OUTPUT_DIR + output dir. + --device DEVICE device type to use. + --nprocs NPROCS number of processes. + --verbose VERBOSE verbose. + +benchmark: + arguments related to benchmark. + + --batch-size BATCH_SIZE + batch size. + --max-iter MAX_ITER train max steps. + --run-benchmark RUN_BENCHMARK + runing benchmark or not, if True, use the --batch-size + and --max-iter. + --profiler_options PROFILER_OPTIONS + The option of profiler, which should be in format + "key1=value1;key2=value2;key3=value3". +``` + +1.
`--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`. +2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder. +3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory. +4. `--device` is the type of the device to run the experiment, 'cpu' and 'gpu' are supported. +5. `--nprocs` is the number of processes to run in parallel, note that nprocs > 1 is only supported when `--device` is 'gpu'. + +### Synthesize +`./local/synthesize.sh` calls `${BIN_DIR}/synthesize.py`, which can synthesize waveform from `metadata.jsonl`. +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} +``` +```text +usage: synthesize.py [-h] [--config CONFIG] [--checkpoint CHECKPOINT] + [--test-metadata TEST_METADATA] [--output-dir OUTPUT_DIR] + [--device DEVICE] [--verbose VERBOSE] + +Synthesize with parallel wavegan. + +optional arguments: + -h, --help show this help message and exit + --config CONFIG parallel wavegan config file. + --checkpoint CHECKPOINT + snapshot to load. + --test-metadata TEST_METADATA + dev data. + --output-dir OUTPUT_DIR + output dir. + --device DEVICE device to run. + --verbose VERBOSE verbose. +``` + +1. `--config` is the parallel wavegan config file. You should use the same config with which the model is trained. +2. `--checkpoint` is the checkpoint to load. Pick one of the checkpoints from `checkpoints` inside the training output directory. +3. `--test-metadata` is the metadata of the test dataset. Use the `metadata.jsonl` in the `dev/norm` subfolder from the processed directory. +4. `--output-dir` is the directory to save the synthesized audio files. +5. `--device` is the type of device to run synthesis, 'cpu' and 'gpu' are supported. + +## Pretrained Models diff --git a/examples/csmsc/voc3/conf/default.yaml b/examples/csmsc/voc3/conf/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cc27220fca0a4d6dcad4aff463df4a6bfcc0c1eb --- /dev/null +++ b/examples/csmsc/voc3/conf/default.yaml @@ -0,0 +1,139 @@ +# This is the hyperparameter configuration file for MelGAN. +# Please make sure this is adjusted for the CSMSC dataset. If you want to +# apply it to another dataset, you might need to carefully change some parameters. +# This configuration requires ~ 8GB memory and will finish within 7 days on Titan V. + +# This configuration is based on full-band MelGAN but the hop size and sampling +# rate are different from the paper (16kHz vs 24kHz). The number of iterations +# is not shown in the paper so currently we train 1M iterations (not sure whether +# this is enough to converge). The optimizer setting is based on @dathudeptrai's advice. +# https://github.com/kan-bayashi/ParallelWaveGAN/issues/143#issuecomment-632539906 + +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +fs: 24000 # Sampling rate. +n_fft: 2048 # FFT size. (in samples) +n_shift: 300 # Hop size. (in samples) +win_length: 1200 # Window length. (in samples) + # If set to null, it will be the same as fft_size. +window: "hann" # Window function. +n_mels: 80 # Number of mel basis. +fmin: 80 # Minimum freq in mel basis calculation. (Hz) +fmax: 7600 # Maximum frequency in mel basis calculation.
(Hz) + +########################################################### +# GENERATOR NETWORK ARCHITECTURE SETTING # +########################################################### +generator_params: + in_channels: 80 # Number of input channels. + out_channels: 4 # Number of output channels. + kernel_size: 7 # Kernel size of initial and final conv layers. + channels: 384 # Initial number of channels for conv layers. + upsample_scales: [5, 5, 3] # List of Upsampling scales. + stack_kernel_size: 3 # Kernel size of dilated conv layers in residual stack. + stacks: 4 # Number of stacks in a single residual stack module. + use_weight_norm: True # Whether to use weight normalization. + use_causal_conv: False # Whether to use causal convolution. + use_final_nonlinear_activation: True + + +########################################################### +# DISCRIMINATOR NETWORK ARCHITECTURE SETTING # +########################################################### +discriminator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + scales: 3 # Number of multi-scales. + downsample_pooling: "AvgPool1D" # Pooling type for the input downsampling. + downsample_pooling_params: # Parameters of the above pooling function. + kernel_size: 4 + stride: 2 + padding: 1 + exclusive: True + kernel_sizes: [5, 3] # List of kernel size. + channels: 16 # Number of channels of the initial conv layer. + max_downsample_channels: 512 # Maximum number of channels of downsampling layers. + downsample_scales: [4, 4, 4] # List of downsampling scales. + nonlinear_activation: "LeakyReLU" # Nonlinear activation function. + nonlinear_activation_params: # Parameters of nonlinear activation function. + negative_slope: 0.2 + use_weight_norm: True # Whether to use weight norm. + + +########################################################### +# STFT LOSS SETTING # +########################################################### +use_stft_loss: true +stft_loss_params: + fft_sizes: [1024, 2048, 512] # List of FFT sizes for STFT-based loss. + hop_sizes: [120, 240, 50] # List of hop sizes for STFT-based loss. + win_lengths: [600, 1200, 240] # List of window lengths for STFT-based loss. + window: "hann" # Window function for STFT-based loss. +use_subband_stft_loss: true +subband_stft_loss_params: + fft_sizes: [384, 683, 171] # List of FFT sizes for STFT-based loss. + hop_sizes: [30, 60, 10] # List of hop sizes for STFT-based loss. + win_lengths: [150, 300, 60] # List of window lengths for STFT-based loss. + window: "hann" # Window function for STFT-based loss. + +########################################################### +# ADVERSARIAL LOSS SETTING # +########################################################### +use_feat_match_loss: false # Whether to use feature matching loss. +lambda_adv: 2.5 # Loss balancing coefficient for adversarial loss. + +########################################################### +# DATA LOADER SETTING # +########################################################### +batch_size: 64 # Batch size. +batch_max_steps: 16200 # Length of each audio in batch. Make sure it is divisible by hop_size. +num_workers: 2 # Number of workers in DataLoader. + +########################################################### +# OPTIMIZER & SCHEDULER SETTING # +########################################################### +generator_optimizer_params: + epsilon: 1.0e-7 # Generator's epsilon. + weight_decay: 0.0 # Generator's weight decay coefficient. + +generator_grad_norm: -1 # Generator's gradient norm.
+generator_scheduler_params: + learning_rate: 1.0e-3 # Generator's learning rate. + gamma: 0.5 # Generator's scheduler gamma. + milestones: # At each milestone, lr will be multiplied by gamma. + - 100000 + - 200000 + - 300000 + - 400000 + - 500000 + - 600000 +discriminator_optimizer_params: + epsilon: 1.0e-7 # Discriminator's epsilon. + weight_decay: 0.0 # Discriminator's weight decay coefficient. + +discriminator_grad_norm: -1 # Discriminator's gradient norm. +discriminator_scheduler_params: + learning_rate: 1.0e-3 # Discriminator's learning rate. + gamma: 0.5 # Discriminator's scheduler gamma. + milestones: # At each milestone, lr will be multiplied by gamma. + - 100000 + - 200000 + - 300000 + - 400000 + - 500000 + - 600000 + +########################################################### +# INTERVAL SETTING # +########################################################### +discriminator_train_start_steps: 200000 # Number of steps to start to train discriminator. +train_max_steps: 1000000 # Number of training steps. +save_interval_steps: 5000 # Interval steps to save checkpoint. +eval_interval_steps: 1000 # Interval steps to evaluate the network. + +########################################################### +# OTHER SETTING # +########################################################### +num_snapshots: 10 # max number of snapshots to keep while training +seed: 42 # random seed for paddle, random, and np.random \ No newline at end of file diff --git a/examples/csmsc/voc3/local/preprocess.sh b/examples/csmsc/voc3/local/preprocess.sh new file mode 100755 index 0000000000000000000000000000000000000000..61d6d62bef566d385c4d3d2407ce437ec6d8e9ad --- /dev/null +++ b/examples/csmsc/voc3/local/preprocess.sh @@ -0,0 +1,55 @@ +#!/bin/bash + +stage=0 +stop_stage=100 + +config_path=$1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # get durations from MFA's result + echo "Generate durations.txt from MFA results ..." + python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \ + --inputdir=./baker_alignment_tone \ + --output=durations.txt \ + --config=${config_path} +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # extract features + echo "Extract features ..." + python3 ${BIN_DIR}/../preprocess.py \ + --rootdir=~/datasets/BZNSYP/ \ + --dataset=baker \ + --dumpdir=dump \ + --dur-file=durations.txt \ + --config=${config_path} \ + --cut-sil=True \ + --num-cpu=20 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # get features' stats(mean and std) + echo "Get features' stats ..." + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="feats" +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # normalize, dev and test should use train's stats + echo "Normalize ..." 
+ + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --dumpdir=dump/train/norm \ + --stats=dump/train/feats_stats.npy + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump/dev/raw/metadata.jsonl \ + --dumpdir=dump/dev/norm \ + --stats=dump/train/feats_stats.npy + + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump/test/raw/metadata.jsonl \ + --dumpdir=dump/test/norm \ + --stats=dump/train/feats_stats.npy +fi diff --git a/examples/csmsc/voc3/local/synthesize.sh b/examples/csmsc/voc3/local/synthesize.sh new file mode 100755 index 0000000000000000000000000000000000000000..9f904ac0c6e7006ab40c3d8aaa7c457ad1495b36 --- /dev/null +++ b/examples/csmsc/voc3/local/synthesize.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 + +FLAGS_allocator_strategy=naive_best_fit \ +FLAGS_fraction_of_gpu_memory_to_use=0.01 \ +python3 ${BIN_DIR}/synthesize.py \ + --config=${config_path} \ + --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --test-metadata=dump/test/norm/metadata.jsonl \ + --output-dir=${train_output_path}/test diff --git a/examples/csmsc/voc3/local/train.sh b/examples/csmsc/voc3/local/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..1ef860c36a527aa361dc3a081febae20557e34c7 --- /dev/null +++ b/examples/csmsc/voc3/local/train.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 + +FLAGS_cudnn_exhaustive_search=true \ +FLAGS_conv_workspace_size_limit=4000 \ +python ${BIN_DIR}/train.py \ + --train-metadata=dump/train/norm/metadata.jsonl \ + --dev-metadata=dump/dev/norm/metadata.jsonl \ + --config=${config_path} \ + --output-dir=${train_output_path} \ + --nprocs=1 diff --git a/examples/csmsc/voc3/path.sh b/examples/csmsc/voc3/path.sh new file mode 100755 index 0000000000000000000000000000000000000000..f6b9fe61a918a75ed2438719c236a7c2be4645c8 --- /dev/null +++ b/examples/csmsc/voc3/path.sh @@ -0,0 +1,13 @@ +#!/bin/bash +export MAIN_ROOT=`realpath ${PWD}/../../../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +export PYTHONDONTWRITEBYTECODE=1 +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + +MODEL=multi_band_melgan +export BIN_DIR=${MAIN_ROOT}/parakeet/exps/gan_vocoder/${MODEL} \ No newline at end of file diff --git a/examples/csmsc/voc3/run.sh b/examples/csmsc/voc3/run.sh new file mode 100755 index 0000000000000000000000000000000000000000..360f6ec2a23379dc3477042c878617537e092081 --- /dev/null +++ b/examples/csmsc/voc3/run.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +set -e +source path.sh + +gpus=0,1 +stage=0 +stop_stage=100 + +conf_path=conf/default.yaml +train_output_path=exp/default +ckpt_name=snapshot_iter_50000.pdz + +# with the following command, you can choose the stage range you want to run +# such as `./run.sh --stage 0 --stop-stage 0` +# this cannot be mixed with `$1`, `$2` ...
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + ./local/preprocess.sh ${conf_path} || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `train_output_path/checkpoints/` dir + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # synthesize + CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 +fi diff --git a/examples/librispeech/s0/local/test.sh b/examples/librispeech/s0/local/test.sh index 4d00f30b852da5a370f5d4934f3caadd2b833c00..25dd04374acb02256fc6efc5a6b4d572569efb3a 100755 --- a/examples/librispeech/s0/local/test.sh +++ b/examples/librispeech/s0/local/test.sh @@ -13,7 +13,7 @@ ckpt_prefix=$2 model_type=$3 # download language model -bash local/download_lm_en.sh +bash local/download_lm_en.sh > /dev/null 2>&1 if [ $? -ne 0 ]; then exit 1 fi diff --git a/examples/librispeech/s2/conf/lm/transformer.yaml b/examples/librispeech/s2/conf/lm/transformer.yaml index 4349f795ba56ea35eb256b62ab167df9ba6cb006..826f08020d54b7a79218cf548dca1dd8d901e0df 100644 --- a/examples/librispeech/s2/conf/lm/transformer.yaml +++ b/examples/librispeech/s2/conf/lm/transformer.yaml @@ -1,4 +1,8 @@ model_module: transformer + +data: + unit_type: spm + model: n_vocab: 5002 pos_enc: null @@ -11,3 +15,7 @@ model: emb_dropout_rate: 0.0 att_dropout_rate: 0.0 tie_weights: False + +decoding: + batch_size: 30 + num_workers: 2 diff --git a/examples/librispeech/s2/local/cacu_perplexity.sh b/examples/librispeech/s2/local/cacu_perplexity.sh new file mode 100755 index 0000000000000000000000000000000000000000..a77a6de3a3ac4a6205f00eeab2652b592716d6b6 --- /dev/null +++ b/examples/librispeech/s2/local/cacu_perplexity.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +set -e + +stage=-1 +stop_stage=100 + +expdir=exp +datadir=data + +ngpu=0 + +# lm params +rnnlm_config_path=conf/lm/transformer.yaml +lmexpdir=exp/lm/transformer +lang_model=transformerLM.pdparams + +# data path +test_set=${datadir}/test_clean/text +test_set_lower=${datadir}/test_clean/text_lower +train_set=train_960 + +# bpemode (unigram or bpe) +nbpe=5000 +bpemode=unigram +bpeprefix=${datadir}/lang_char/${train_set}_${bpemode}${nbpe} +bpemodel=${bpeprefix}.model + +vocabfile=${bpeprefix}_units.txt +vocabfile_lower=${bpeprefix}_units_lower.txt + +output_dir=${expdir}/lm/transformer/perplexity + +mkdir -p ${output_dir} + +# Transform the data from upper case to lower case +if [ -f ${vocabfile} ]; then + tr A-Z a-z < ${vocabfile} > ${vocabfile_lower} +fi + +if [ -f ${test_set} ]; then + tr A-Z a-z < ${test_set} > ${test_set_lower} +fi + +python ${LM_BIN_DIR}/cacu_perplexity.py \ + --rnnlm ${lmexpdir}/${lang_model} \ + --rnnlm-conf ${rnnlm_config_path} \ + --vocab_path ${vocabfile_lower} \ + --bpeprefix ${bpeprefix} \ + --text_path ${test_set_lower} \ + --output_dir ${output_dir} \ + --ngpu ${ngpu} + diff --git a/examples/librispeech/s2/run.sh b/examples/librispeech/s2/run.sh index 146f133d8c310f3ff5a05aaed54f8cce369e0886..e014c2a93d6c0a0f4d37d85bb86ea409aaacffc9 100755 --- a/examples/librispeech/s2/run.sh +++ b/examples/librispeech/s2/run.sh @@ -51,3 +51,7 @@ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then # export ckpt avg_n CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit fi + +if [ ${stage} -le 7 ] && [
${stop_stage} -ge 7 ]; then
+    CUDA_VISIBLE_DEVICES= ./local/cacu_perplexity.sh || exit -1
+fi
diff --git a/examples/other/1xt2x/README.md b/examples/other/1xt2x/README.md
index 1f5fe8e3b17af649d361be4250cdd94c1795e00a..a278d366510b4205a0aa4180947dae6845f334a9 100644
--- a/examples/other/1xt2x/README.md
+++ b/examples/other/1xt2x/README.md
@@ -2,10 +2,18 @@
 Convert Deepspeech 1.8 released model to 2.x.
 
-## Model
+## Model source directory
 * Deepspeech2x
 
-## Exp
-* baidu_en8k
+## Experiment directory
 * aishell
 * librispeech
+* baidu_en8k
+
+## Released models
+
+Acoustic Model | Training Data | Hours of Speech | Token Type | CER | WER
+:-------------:| :------------:| :---------------: | :---------: | :---: | :----:
+Ds2 Offline Aishell 1xt2x model | Aishell Dataset | 151 h | Char-based | 0.080447 |
+Ds2 Offline Librispeech 1xt2x model | Librispeech Dataset | 960 h | Word-based | | 0.068548
+Ds2 Offline Baidu en8k 1xt2x model | Baidu Internal English Dataset | 8628 h | Word-based | | 0.054112
diff --git a/examples/other/1xt2x/aishell/local/test.sh b/examples/other/1xt2x/aishell/local/test.sh
index 2ae0740b3e8d44ab03e45f4c1b5dbb945657705e..d539ac4943039fe6c33eb1373985aa98617a587f 100755
--- a/examples/other/1xt2x/aishell/local/test.sh
+++ b/examples/other/1xt2x/aishell/local/test.sh
@@ -13,7 +13,7 @@ ckpt_prefix=$2
 model_type=$3
 
 # download language model
-bash local/download_lm_ch.sh
+bash local/download_lm_ch.sh > /dev/null 2>&1
 if [ $? -ne 0 ]; then
     exit 1
 fi
diff --git a/examples/other/1xt2x/baidu_en8k/local/test.sh b/examples/other/1xt2x/baidu_en8k/local/test.sh
index 4d00f30b852da5a370f5d4934f3caadd2b833c00..25dd04374acb02256fc6efc5a6b4d572569efb3a 100755
--- a/examples/other/1xt2x/baidu_en8k/local/test.sh
+++ b/examples/other/1xt2x/baidu_en8k/local/test.sh
@@ -13,7 +13,7 @@ ckpt_prefix=$2
 model_type=$3
 
 # download language model
-bash local/download_lm_en.sh
+bash local/download_lm_en.sh > /dev/null 2>&1
 if [ $? -ne 0 ]; then
     exit 1
 fi
diff --git a/examples/other/1xt2x/librispeech/local/test.sh b/examples/other/1xt2x/librispeech/local/test.sh
index 4d00f30b852da5a370f5d4934f3caadd2b833c00..25dd04374acb02256fc6efc5a6b4d572569efb3a 100755
--- a/examples/other/1xt2x/librispeech/local/test.sh
+++ b/examples/other/1xt2x/librispeech/local/test.sh
@@ -13,7 +13,7 @@ ckpt_prefix=$2
 model_type=$3
 
 # download language model
-bash local/download_lm_en.sh
+bash local/download_lm_en.sh > /dev/null 2>&1
 if [ $?
-ne 0 ]; then exit 1 fi diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml index 621b372cbb932a732c63b109ec4ed57c47791b8d..58899a1568e3fd61ba23aaf1cb83347428a7f40d 100644 --- a/examples/tiny/s0/conf/deepspeech2.yaml +++ b/examples/tiny/s0/conf/deepspeech2.yaml @@ -45,7 +45,7 @@ model: ctc_grad_norm_type: null training: - n_epoch: 10 + n_epoch: 5 accum_grad: 1 lr: 1e-5 lr_decay: 0.8 diff --git a/examples/tiny/s0/conf/deepspeech2_online.yaml b/examples/tiny/s0/conf/deepspeech2_online.yaml index 5a8294adb780b32503bb46cfbc80c43b1700b1eb..334b1d31ce21ab95c3099c76caf9cdd36c61cd92 100644 --- a/examples/tiny/s0/conf/deepspeech2_online.yaml +++ b/examples/tiny/s0/conf/deepspeech2_online.yaml @@ -47,7 +47,7 @@ model: ctc_grad_norm_type: null training: - n_epoch: 10 + n_epoch: 5 accum_grad: 1 lr: 1e-5 lr_decay: 1.0 diff --git a/examples/tiny/s0/local/test.sh b/examples/tiny/s0/local/test.sh index 4d00f30b852da5a370f5d4934f3caadd2b833c00..25dd04374acb02256fc6efc5a6b4d572569efb3a 100755 --- a/examples/tiny/s0/local/test.sh +++ b/examples/tiny/s0/local/test.sh @@ -13,7 +13,7 @@ ckpt_prefix=$2 model_type=$3 # download language model -bash local/download_lm_en.sh +bash local/download_lm_en.sh > /dev/null 2>&1 if [ $? -ne 0 ]; then exit 1 fi diff --git a/examples/tiny/s1/conf/chunk_confermer.yaml b/examples/tiny/s1/conf/chunk_confermer.yaml index b14b4b21218012f56a9df73b7dc31da8c271ee6e..c518666977faef8c0862be3e7c7f4d5b5244a5fc 100644 --- a/examples/tiny/s1/conf/chunk_confermer.yaml +++ b/examples/tiny/s1/conf/chunk_confermer.yaml @@ -83,7 +83,7 @@ model: training: - n_epoch: 20 + n_epoch: 5 accum_grad: 1 global_grad_clip: 5.0 optim: adam diff --git a/examples/tiny/s1/conf/chunk_transformer.yaml b/examples/tiny/s1/conf/chunk_transformer.yaml index 38edbf35816a6bd73af9eeea8051c4b580ebb5b1..29c30b262048b46bf08d132aebbb24bd7186bf71 100644 --- a/examples/tiny/s1/conf/chunk_transformer.yaml +++ b/examples/tiny/s1/conf/chunk_transformer.yaml @@ -76,7 +76,7 @@ model: training: - n_epoch: 20 + n_epoch: 5 accum_grad: 1 global_grad_clip: 5.0 optim: adam diff --git a/examples/tiny/s1/conf/conformer.yaml b/examples/tiny/s1/conf/conformer.yaml index 0b06b2b72feb890d886aade48d3449785fa4b375..8487da771930e6f615ac9fe0e718bab310f66970 100644 --- a/examples/tiny/s1/conf/conformer.yaml +++ b/examples/tiny/s1/conf/conformer.yaml @@ -79,7 +79,7 @@ model: training: - n_epoch: 20 + n_epoch: 5 accum_grad: 4 global_grad_clip: 5.0 optim: adam diff --git a/examples/tiny/s1/conf/transformer.yaml b/examples/tiny/s1/conf/transformer.yaml index 1c6f9e022a44e108c5f6d1d6d81cd743a8448863..cc9b5c5158adf2ca74ccf715e6edaf61cb320953 100644 --- a/examples/tiny/s1/conf/transformer.yaml +++ b/examples/tiny/s1/conf/transformer.yaml @@ -73,7 +73,7 @@ model: training: - n_epoch: 21 + n_epoch: 5 accum_grad: 1 global_grad_clip: 5.0 optim: adam diff --git a/parakeet/data/batch.py b/parakeet/data/batch.py index 515074d14d560e8131c45cc937bbd3d7bd931ea8..5e7ac3996ccf94d6ba6f57b0b8cd91f8025ec883 100644 --- a/parakeet/data/batch.py +++ b/parakeet/data/batch.py @@ -53,8 +53,8 @@ def batch_text_id(minibatch, pad_id=0, dtype=np.int64): peek_example = minibatch[0] assert len(peek_example.shape) == 1, "text example is an 1D tensor" - lengths = [example.shape[0] for example in minibatch - ] # assume (channel, n_samples) or (n_samples, ) + lengths = [example.shape[0] for example in + minibatch] # assume (channel, n_samples) or (n_samples, ) max_len = np.max(lengths) batch = [] diff --git 
a/parakeet/datasets/vocoder_batch_fn.py b/parakeet/datasets/vocoder_batch_fn.py
index 30adb142d0c6db69247301cd7d1177d7736cc698..2de4fb124e1e50a7c5481366c8cec675922d8a98 100644
--- a/parakeet/datasets/vocoder_batch_fn.py
+++ b/parakeet/datasets/vocoder_batch_fn.py
@@ -107,8 +107,13 @@ class Clip(object):
             features, this process will be needed.
         """
-        if len(x) < c.shape[1] * self.hop_size:
-            x = np.pad(x, (0, c.shape[1] * self.hop_size - len(x)), mode="edge")
+        if len(x) < c.shape[0] * self.hop_size:
+            x = np.pad(x, (0, c.shape[0] * self.hop_size - len(x)), mode="edge")
+        elif len(x) > c.shape[0] * self.hop_size:
+            print(
+                f"wave length: ({len(x)}), mel length: ({c.shape[0]}), hop size: ({self.hop_size})"
+            )
+            x = x[:c.shape[0] * self.hop_size]
 
         # check the legnth is valid
         assert len(x) == c.shape[
diff --git a/parakeet/exps/fastspeech2/inference.py b/parakeet/exps/fastspeech2/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..4367608874ad3ca5fbacb7f6c473a775acb6eed6
--- /dev/null
+++ b/parakeet/exps/fastspeech2/inference.py
@@ -0,0 +1,133 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+from pathlib import Path
+
+import soundfile as sf
+from paddle import inference
+
+from parakeet.frontend.zh_frontend import Frontend
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Paddle Inference with fastspeech2 & parallel wavegan.")
+    parser.add_argument(
+        "--inference-dir", type=str, help="dir to save inference models")
+    parser.add_argument(
+        "--text",
+        type=str,
+        help="text to synthesize, a 'utt_id sentence' pair per line")
+    parser.add_argument("--output-dir", type=str, help="output dir")
+    parser.add_argument(
+        "--enable-auto-log", action="store_true", help="use auto log")
+    parser.add_argument(
+        "--phones-dict",
+        type=str,
+        default="phones.txt",
+        help="phone vocabulary file.")
+
+    args, _ = parser.parse_known_args()
+
+    frontend = Frontend(phone_vocab_path=args.phones_dict)
+    print("frontend done!")
+
+    fastspeech2_config = inference.Config(
+        str(Path(args.inference_dir) / "fastspeech2.pdmodel"),
+        str(Path(args.inference_dir) / "fastspeech2.pdiparams"))
+    fastspeech2_config.enable_use_gpu(50, 0)
+    # enable_memory_optim() must stay commented out here, otherwise it causes OOM
+    # fastspeech2_config.enable_memory_optim()
+    fastspeech2_predictor = inference.create_predictor(fastspeech2_config)
+
+    pwg_config = inference.Config(
+        str(Path(args.inference_dir) / "pwg.pdmodel"),
+        str(Path(args.inference_dir) / "pwg.pdiparams"))
+    pwg_config.enable_use_gpu(100, 0)
+    pwg_config.enable_memory_optim()
+    pwg_predictor = inference.create_predictor(pwg_config)
+
+    if args.enable_auto_log:
+        import auto_log
+        os.makedirs("output", exist_ok=True)
+        pid = os.getpid()
+        logger = auto_log.AutoLogger(
+            model_name="fastspeech2",
+            model_precision='float32',
+            batch_size=1,
+            data_shape="dynamic",
+            save_path="./output/auto_log.log",
+            inference_config=fastspeech2_config,
+            pids=pid,
+
process_name=None, + gpu_ids=0, + time_keys=['preprocess_time', 'inference_time', 'postprocess_time'], + warmup=0) + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + sentences = [] + + with open(args.text, 'rt') as f: + for line in f: + utt_id, sentence = line.strip().split() + sentences.append((utt_id, sentence)) + + for utt_id, sentence in sentences: + if args.enable_auto_log: + logger.times.start() + input_ids = frontend.get_input_ids(sentence, merge_sentences=True) + phone_ids = input_ids["phone_ids"] + phones = phone_ids[0].numpy() + + if args.enable_auto_log: + logger.times.stamp() + + input_names = fastspeech2_predictor.get_input_names() + phones_handle = fastspeech2_predictor.get_input_handle(input_names[0]) + + phones_handle.reshape(phones.shape) + phones_handle.copy_from_cpu(phones) + + fastspeech2_predictor.run() + output_names = fastspeech2_predictor.get_output_names() + output_handle = fastspeech2_predictor.get_output_handle(output_names[0]) + output_data = output_handle.copy_to_cpu() + + input_names = pwg_predictor.get_input_names() + mel_handle = pwg_predictor.get_input_handle(input_names[0]) + mel_handle.reshape(output_data.shape) + mel_handle.copy_from_cpu(output_data) + + pwg_predictor.run() + output_names = pwg_predictor.get_output_names() + output_handle = pwg_predictor.get_output_handle(output_names[0]) + wav = output_data = output_handle.copy_to_cpu() + + if args.enable_auto_log: + logger.times.stamp() + + sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=24000) + + if args.enable_auto_log: + logger.times.end(stamp=True) + print(f"{utt_id} done!") + + if args.enable_auto_log: + logger.report() + + +if __name__ == "__main__": + main() diff --git a/parakeet/exps/fastspeech2/synthesize_e2e.py b/parakeet/exps/fastspeech2/synthesize_e2e.py index dd1b57c8a9400f869c14268798a85c190bb599db..9c036e9fc91f5335e74c841d04813dd1ef0f3187 100644 --- a/parakeet/exps/fastspeech2/synthesize_e2e.py +++ b/parakeet/exps/fastspeech2/synthesize_e2e.py @@ -13,12 +13,15 @@ # limitations under the License. 
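+# note: besides end-to-end synthesis, this script now also saves static-graph
+# inference models via paddle.jit and reloads them before use (see --inference-dir)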
import argparse import logging +import os from pathlib import Path import numpy as np import paddle import soundfile as sf import yaml +from paddle import jit +from paddle.static import InputSpec from yacs.config import CfgNode from parakeet.frontend.zh_frontend import Frontend @@ -74,7 +77,21 @@ def evaluate(args, fastspeech2_config, pwg_config): pwg_normalizer = ZScore(mu, std) fastspeech2_inference = FastSpeech2Inference(fastspeech2_normalizer, model) + fastspeech2_inference.eval() + fastspeech2_inference = jit.to_static( + fastspeech2_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)]) + paddle.jit.save(fastspeech2_inference, + os.path.join(args.inference_dir, "fastspeech2")) + fastspeech2_inference = paddle.jit.load( + os.path.join(args.inference_dir, "fastspeech2")) pwg_inference = PWGInference(pwg_normalizer, vocoder) + pwg_inference.eval() + pwg_inference = jit.to_static( + pwg_inference, input_spec=[ + InputSpec([-1, 80], dtype=paddle.float32), + ]) + paddle.jit.save(pwg_inference, os.path.join(args.inference_dir, "pwg")) + pwg_inference = paddle.jit.load(os.path.join(args.inference_dir, "pwg")) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) @@ -135,6 +152,8 @@ def main(): type=str, help="text to synthesize, a 'utt_id sentence' pair per line.") parser.add_argument("--output-dir", type=str, help="output dir.") + parser.add_argument( + "--inference-dir", type=str, help="dir to save inference models") parser.add_argument( "--device", type=str, default="gpu", help="device type to use.") parser.add_argument("--verbose", type=int, default=1, help="verbose.") diff --git a/parakeet/exps/fastspeech2/train.py b/parakeet/exps/fastspeech2/train.py index 59b1ea3af5eb017dce1f20749097a3181359f876..47ad1b4dac0ca66bea057a67de26d28659d75f81 100644 --- a/parakeet/exps/fastspeech2/train.py +++ b/parakeet/exps/fastspeech2/train.py @@ -25,7 +25,6 @@ from paddle import DataParallel from paddle import distributed as dist from paddle.io import DataLoader from paddle.io import DistributedBatchSampler -from visualdl import LogWriter from yacs.config import CfgNode from parakeet.datasets.am_batch_fn import fastspeech2_multi_spk_batch_fn @@ -160,8 +159,7 @@ def train_sp(args, config): if dist.get_rank() == 0: trainer.extend(evaluator, trigger=(1, "epoch")) - writer = LogWriter(str(output_dir)) - trainer.extend(VisualDL(writer), trigger=(1, "iteration")) + trainer.extend(VisualDL(output_dir), trigger=(1, "iteration")) trainer.extend( Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch')) # print(trainer.extensions) diff --git a/parakeet/exps/gan_vocoder/multi_band_melgan/__init__.py b/parakeet/exps/gan_vocoder/multi_band_melgan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..185a92b8d94d3426d616c0624f0f2ee04339349e --- /dev/null +++ b/parakeet/exps/gan_vocoder/multi_band_melgan/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/parakeet/exps/gan_vocoder/multi_band_melgan/synthesize.py b/parakeet/exps/gan_vocoder/multi_band_melgan/synthesize.py
new file mode 100644
index 0000000000000000000000000000000000000000..00b1b96c8befe556c3faff800ca4fc56e60b7dab
--- /dev/null
+++ b/parakeet/exps/gan_vocoder/multi_band_melgan/synthesize.py
@@ -0,0 +1,98 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+from pathlib import Path
+
+import jsonlines
+import numpy as np
+import paddle
+import soundfile as sf
+import yaml
+from paddle import distributed as dist
+from timer import timer
+from yacs.config import CfgNode
+
+from parakeet.datasets.data_table import DataTable
+from parakeet.models.melgan import MelGANGenerator
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Synthesize with multi band melgan.")
+    parser.add_argument(
+        "--config", type=str, help="multi band melgan config file.")
+    parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
+    parser.add_argument("--test-metadata", type=str, help="test data.")
+    parser.add_argument("--output-dir", type=str, help="output dir.")
+    parser.add_argument(
+        "--device", type=str, default="gpu", help="device to run.")
+    parser.add_argument("--verbose", type=int, default=1, help="verbose.")
+
+    args = parser.parse_args()
+
+    with open(args.config) as f:
+        config = CfgNode(yaml.safe_load(f))
+
+    print("========Args========")
+    print(yaml.safe_dump(vars(args)))
+    print("========Config========")
+    print(config)
+    print(
+        f"master sees the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
+    )
+
+    paddle.set_device(args.device)
+    generator = MelGANGenerator(**config["generator_params"])
+    state_dict = paddle.load(args.checkpoint)
+    generator.set_state_dict(state_dict["generator_params"])
+
+    generator.remove_weight_norm()
+    generator.eval()
+    with jsonlines.open(args.test_metadata, 'r') as reader:
+        metadata = list(reader)
+
+    test_dataset = DataTable(
+        metadata,
+        fields=['utt_id', 'feats'],
+        converters={
+            'utt_id': None,
+            'feats': np.load,
+        })
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    N = 0
+    T = 0
+    for example in test_dataset:
+        utt_id = example['utt_id']
+        mel = example['feats']
+        mel = paddle.to_tensor(mel)  # (T, C)
+        with timer() as t:
+            with paddle.no_grad():
+                wav = generator.inference(c=mel)
+            wav = wav.numpy()
+            N += wav.size
+            T += t.elapse
+            speed = wav.size / t.elapse
+            rtf = config.fs / speed
+        print(
+            f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
+ ) + sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs) + print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }") + + +if __name__ == "__main__": + main() diff --git a/parakeet/exps/gan_vocoder/multi_band_melgan/train.py b/parakeet/exps/gan_vocoder/multi_band_melgan/train.py new file mode 100644 index 0000000000000000000000000000000000000000..c03fb354b9ac0845b31c293569a877afb8591ce8 --- /dev/null +++ b/parakeet/exps/gan_vocoder/multi_band_melgan/train.py @@ -0,0 +1,269 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import logging +import os +import shutil +from pathlib import Path + +import jsonlines +import numpy as np +import paddle +import yaml +from paddle import DataParallel +from paddle import distributed as dist +from paddle import nn +from paddle.io import DataLoader +from paddle.io import DistributedBatchSampler +from paddle.optimizer import Adam +from paddle.optimizer.lr import MultiStepDecay +from yacs.config import CfgNode + +from parakeet.datasets.data_table import DataTable +from parakeet.datasets.vocoder_batch_fn import Clip +from parakeet.models.melgan import MBMelGANEvaluator +from parakeet.models.melgan import MBMelGANUpdater +from parakeet.models.melgan import MelGANGenerator +from parakeet.models.melgan import MelGANMultiScaleDiscriminator +from parakeet.modules.adversarial_loss import DiscriminatorAdversarialLoss +from parakeet.modules.adversarial_loss import GeneratorAdversarialLoss +from parakeet.modules.pqmf import PQMF +from parakeet.modules.stft_loss import MultiResolutionSTFTLoss +from parakeet.training.extensions.snapshot import Snapshot +from parakeet.training.extensions.visualizer import VisualDL +from parakeet.training.seeding import seed_everything +from parakeet.training.trainer import Trainer + + +def train_sp(args, config): + # decides device type and whether to run in parallel + # setup running environment correctly + world_size = paddle.distributed.get_world_size() + if not paddle.is_compiled_with_cuda(): + paddle.set_device("cpu") + else: + paddle.set_device("gpu") + if world_size > 1: + paddle.distributed.init_parallel_env() + + # set the random seed, it is a must for multiprocess training + seed_everything(config.seed) + + print( + f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}", + ) + + # dataloader has been too verbose + logging.getLogger("DataLoader").disabled = True + + # construct dataset for training and validation + with jsonlines.open(args.train_metadata, 'r') as reader: + train_metadata = list(reader) + train_dataset = DataTable( + data=train_metadata, + fields=["wave", "feats"], + converters={ + "wave": np.load, + "feats": np.load, + }, ) + with jsonlines.open(args.dev_metadata, 'r') as reader: + dev_metadata = list(reader) + dev_dataset = DataTable( + data=dev_metadata, + fields=["wave", "feats"], + converters={ + "wave": np.load, + "feats": np.load, + }, ) + + # collate function and dataloader + 
train_sampler = DistributedBatchSampler( + train_dataset, + batch_size=config.batch_size, + shuffle=True, + drop_last=True) + dev_sampler = DistributedBatchSampler( + dev_dataset, + batch_size=config.batch_size, + shuffle=False, + drop_last=False) + print("samplers done!") + + if "aux_context_window" in config.generator_params: + aux_context_window = config.generator_params.aux_context_window + else: + aux_context_window = 0 + train_batch_fn = Clip( + batch_max_steps=config.batch_max_steps, + hop_size=config.n_shift, + aux_context_window=aux_context_window) + + train_dataloader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=train_batch_fn, + num_workers=config.num_workers) + + dev_dataloader = DataLoader( + dev_dataset, + batch_sampler=dev_sampler, + collate_fn=train_batch_fn, + num_workers=config.num_workers) + print("dataloaders done!") + + generator = MelGANGenerator(**config["generator_params"]) + discriminator = MelGANMultiScaleDiscriminator( + **config["discriminator_params"]) + if world_size > 1: + generator = DataParallel(generator) + discriminator = DataParallel(discriminator) + print("models done!") + criterion_stft = MultiResolutionSTFTLoss(**config["stft_loss_params"]) + criterion_sub_stft = MultiResolutionSTFTLoss( + **config["subband_stft_loss_params"]) + criterion_gen_adv = GeneratorAdversarialLoss() + criterion_dis_adv = DiscriminatorAdversarialLoss() + # define special module for subband processing + criterion_pqmf = PQMF(subbands=config["generator_params"]["out_channels"]) + print("criterions done!") + + lr_schedule_g = MultiStepDecay(**config["generator_scheduler_params"]) + # Compared to multi_band_melgan.v1 config, Adam optimizer without gradient norm is used + generator_grad_norm = config["generator_grad_norm"] + gradient_clip_g = nn.ClipGradByGlobalNorm( + generator_grad_norm) if generator_grad_norm > 0 else None + print("gradient_clip_g:", gradient_clip_g) + + optimizer_g = Adam( + learning_rate=lr_schedule_g, + grad_clip=gradient_clip_g, + parameters=generator.parameters(), + **config["generator_optimizer_params"]) + lr_schedule_d = MultiStepDecay(**config["discriminator_scheduler_params"]) + discriminator_grad_norm = config["discriminator_grad_norm"] + gradient_clip_d = nn.ClipGradByGlobalNorm( + discriminator_grad_norm) if discriminator_grad_norm > 0 else None + print("gradient_clip_d:", gradient_clip_d) + optimizer_d = Adam( + learning_rate=lr_schedule_d, + grad_clip=gradient_clip_d, + parameters=discriminator.parameters(), + **config["discriminator_optimizer_params"]) + print("optimizers done!") + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + if dist.get_rank() == 0: + config_name = args.config.split("/")[-1] + # copy conf to output_dir + shutil.copyfile(args.config, output_dir / config_name) + + updater = MBMelGANUpdater( + models={ + "generator": generator, + "discriminator": discriminator, + }, + optimizers={ + "generator": optimizer_g, + "discriminator": optimizer_d, + }, + criterions={ + "stft": criterion_stft, + "sub_stft": criterion_sub_stft, + "gen_adv": criterion_gen_adv, + "dis_adv": criterion_dis_adv, + "pqmf": criterion_pqmf + }, + schedulers={ + "generator": lr_schedule_g, + "discriminator": lr_schedule_d, + }, + dataloader=train_dataloader, + discriminator_train_start_steps=config.discriminator_train_start_steps, + lambda_adv=config.lambda_adv, + output_dir=output_dir) + + evaluator = MBMelGANEvaluator( + models={ + "generator": generator, + "discriminator": discriminator, + }, + 
criterions={
+            "stft": criterion_stft,
+            "sub_stft": criterion_sub_stft,
+            "gen_adv": criterion_gen_adv,
+            "dis_adv": criterion_dis_adv,
+            "pqmf": criterion_pqmf
+        },
+        dataloader=dev_dataloader,
+        lambda_adv=config.lambda_adv,
+        output_dir=output_dir)
+
+    trainer = Trainer(
+        updater,
+        stop_trigger=(config.train_max_steps, "iteration"),
+        out=output_dir)
+
+    if dist.get_rank() == 0:
+        trainer.extend(
+            evaluator, trigger=(config.eval_interval_steps, 'iteration'))
+        trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
+        trainer.extend(
+            Snapshot(max_size=config.num_snapshots),
+            trigger=(config.save_interval_steps, 'iteration'))
+
+    print("Trainer Done!")
+    trainer.run()
+
+
+def main():
+    # parse args and config and redirect to train_sp
+
+    parser = argparse.ArgumentParser(
+        description="Train a Multi-Band MelGAN model.")
+    parser.add_argument(
+        "--config", type=str, help="config file to overwrite default config.")
+    parser.add_argument("--train-metadata", type=str, help="training data.")
+    parser.add_argument("--dev-metadata", type=str, help="dev data.")
+    parser.add_argument("--output-dir", type=str, help="output dir.")
+    parser.add_argument(
+        "--device", type=str, default="gpu", help="device type to use.")
+    parser.add_argument(
+        "--nprocs", type=int, default=1, help="number of processes.")
+    parser.add_argument("--verbose", type=int, default=1, help="verbose.")
+
+    args = parser.parse_args()
+    if args.device == "cpu" and args.nprocs > 1:
+        raise RuntimeError("Multiprocess training on CPU is not supported.")
+
+    with open(args.config, 'rt') as f:
+        config = CfgNode(yaml.safe_load(f))
+
+    print("========Args========")
+    print(yaml.safe_dump(vars(args)))
+    print("========Config========")
+    print(config)
+    print(
+        f"master sees the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
+    )
+
+    # dispatch
+    if args.nprocs > 1:
+        dist.spawn(train_sp, (args, config), nprocs=args.nprocs)
+    else:
+        train_sp(args, config)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/parakeet/exps/gan_vocoder/parallelwave_gan/synthesize.py b/parakeet/exps/gan_vocoder/parallelwave_gan/synthesize.py
index 9129caa54ebab33e726d5ea215c9ff222d5f22a6..2400e00b4eb131b1b029d3e2c4b3a45f01a81a23 100644
--- a/parakeet/exps/gan_vocoder/parallelwave_gan/synthesize.py
+++ b/parakeet/exps/gan_vocoder/parallelwave_gan/synthesize.py
@@ -86,8 +86,9 @@ def main():
             N += wav.size
             T += t.elapse
             speed = wav.size / t.elapse
+            rtf = config.fs / speed
             print(
-                f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {config.fs / speed}."
+                f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
) sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs) print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }") diff --git a/parakeet/exps/gan_vocoder/parallelwave_gan/train.py b/parakeet/exps/gan_vocoder/parallelwave_gan/train.py index 7a16ca597ed49014011648a04d298a7a5906ef43..ad50b65c71b324079d85d58f2d85d894ca145708 100644 --- a/parakeet/exps/gan_vocoder/parallelwave_gan/train.py +++ b/parakeet/exps/gan_vocoder/parallelwave_gan/train.py @@ -28,7 +28,6 @@ from paddle.io import DataLoader from paddle.io import DistributedBatchSampler from paddle.optimizer import Adam # No RAdaom from paddle.optimizer.lr import StepDecay -from visualdl import LogWriter from yacs.config import CfgNode from parakeet.datasets.data_table import DataTable @@ -193,8 +192,7 @@ def train_sp(args, config): if dist.get_rank() == 0: trainer.extend( evaluator, trigger=(config.eval_interval_steps, 'iteration')) - writer = LogWriter(str(trainer.out)) - trainer.extend(VisualDL(writer), trigger=(1, 'iteration')) + trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration')) trainer.extend( Snapshot(max_size=config.num_snapshots), trigger=(config.save_interval_steps, 'iteration')) diff --git a/parakeet/exps/speedyspeech/inference.py b/parakeet/exps/speedyspeech/inference.py index bf144d760f597cb72479beb87ea4f752eb97500b..77a90915b8186bcc915151c8b86ce672773f6cce 100644 --- a/parakeet/exps/speedyspeech/inference.py +++ b/parakeet/exps/speedyspeech/inference.py @@ -96,8 +96,8 @@ def main(): input_ids = frontend.get_input_ids( sentence, merge_sentences=True, get_tone_ids=True) - phone_ids = input_ids["phone_ids"] - tone_ids = input_ids["tone_ids"] + phone_ids = input_ids["phone_ids"].numpy() + tone_ids = input_ids["tone_ids"].numpy() phones = phone_ids[0] tones = tone_ids[0] diff --git a/parakeet/exps/speedyspeech/train.py b/parakeet/exps/speedyspeech/train.py index ea9fe20d7ffa0d903d52245740e8ae1c4e4a46b0..6a4bf59e1aeeb0918202bca0bcc3d9cf8bd63211 100644 --- a/parakeet/exps/speedyspeech/train.py +++ b/parakeet/exps/speedyspeech/train.py @@ -25,7 +25,6 @@ from paddle import DataParallel from paddle import distributed as dist from paddle.io import DataLoader from paddle.io import DistributedBatchSampler -from visualdl import LogWriter from yacs.config import CfgNode from parakeet.datasets.am_batch_fn import speedyspeech_batch_fn @@ -153,8 +152,7 @@ def train_sp(args, config): if dist.get_rank() == 0: trainer.extend(evaluator, trigger=(1, "epoch")) - writer = LogWriter(str(output_dir)) - trainer.extend(VisualDL(writer), trigger=(1, "iteration")) + trainer.extend(VisualDL(output_dir), trigger=(1, "iteration")) trainer.extend( Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch')) trainer.run() diff --git a/parakeet/exps/tacotron2/ljspeech.py b/parakeet/exps/tacotron2/ljspeech.py index 20dc29d37c8d13ec12623e58e7883fefa17e3e78..59c855eb6ca172b8c46ca43b5f0bb741ada74089 100644 --- a/parakeet/exps/tacotron2/ljspeech.py +++ b/parakeet/exps/tacotron2/ljspeech.py @@ -67,19 +67,16 @@ class LJSpeechCollector(object): # Sort by text_len in descending order texts = [ - i - for i, _ in sorted( + i for i, _ in sorted( zip(texts, text_lens), key=lambda x: x[1], reverse=True) ] mels = [ - i - for i, _ in sorted( + i for i, _ in sorted( zip(mels, text_lens), key=lambda x: x[1], reverse=True) ] mel_lens = [ - i - for i, _ in sorted( + i for i, _ in sorted( zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True) ] diff --git a/parakeet/exps/transformer_tts/train.py b/parakeet/exps/transformer_tts/train.py 
index fdaff347521e073af5f51c93d0810242021743c8..bf0663908e9d5a2c83c0ed2bce96cb5a697c3c57 100644 --- a/parakeet/exps/transformer_tts/train.py +++ b/parakeet/exps/transformer_tts/train.py @@ -25,7 +25,6 @@ from paddle import DataParallel from paddle import distributed as dist from paddle.io import DataLoader from paddle.io import DistributedBatchSampler -from visualdl import LogWriter from yacs.config import CfgNode from parakeet.datasets.am_batch_fn import transformer_single_spk_batch_fn @@ -148,8 +147,7 @@ def train_sp(args, config): if dist.get_rank() == 0: trainer.extend(evaluator, trigger=(1, "epoch")) - writer = LogWriter(str(output_dir)) - trainer.extend(VisualDL(writer), trigger=(1, "iteration")) + trainer.extend(VisualDL(output_dir), trigger=(1, "iteration")) trainer.extend( Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch')) # print(trainer.extensions) diff --git a/parakeet/models/fastspeech2/fastspeech2.py b/parakeet/models/fastspeech2/fastspeech2.py index bde3a82ba95d76c192b8b07ad492926ff8b72210..192517b168a6311764ab768cfc78cfb03829aeb2 100644 --- a/parakeet/models/fastspeech2/fastspeech2.py +++ b/parakeet/models/fastspeech2/fastspeech2.py @@ -341,6 +341,7 @@ class FastSpeech2(nn.Layer): Tensor speech_lengths, modified if reduction_factor > 1 """ + # input of embedding must be int64 xs = paddle.cast(text, 'int64') ilens = paddle.cast(text_lengths, 'int64') @@ -388,7 +389,6 @@ class FastSpeech2(nn.Layer): tone_id=None) -> Sequence[paddle.Tensor]: # forward encoder x_masks = self._source_mask(ilens) - # (B, Tmax, adim) hs, _ = self.encoder(xs, x_masks) @@ -405,7 +405,6 @@ class FastSpeech2(nn.Layer): if tone_id is not None: tone_embs = self.tone_embedding_table(tone_id) hs = self._integrate_with_tone_embed(hs, tone_embs) - # forward duration predictor and variance predictors d_masks = make_pad_mask(ilens) @@ -437,6 +436,7 @@ class FastSpeech2(nn.Layer): e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose( (0, 2, 1)) hs = hs + e_embs + p_embs + # (B, Lmax, adim) hs = self.length_regulator(hs, d_outs, alpha) else: @@ -447,6 +447,7 @@ class FastSpeech2(nn.Layer): e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose( (0, 2, 1)) hs = hs + e_embs + p_embs + # (B, Lmax, adim) hs = self.length_regulator(hs, ds) @@ -461,9 +462,11 @@ class FastSpeech2(nn.Layer): else: h_masks = None # (B, Lmax, adim) + zs, _ = self.decoder(hs, h_masks) # (B, Lmax, odim) - before_outs = self.feat_out(zs).reshape((zs.shape[0], -1, self.odim)) + before_outs = self.feat_out(zs).reshape( + (paddle.shape(zs)[0], -1, self.odim)) # postnet -> (B, Lmax//r * r, odim) if self.postnet is None: @@ -526,8 +529,8 @@ class FastSpeech2(nn.Layer): d = paddle.cast(durations, 'int64') p, e = pitch, energy # setup batch axis - ilens = paddle.to_tensor( - [x.shape[0]], dtype=paddle.int64, place=x.place) + ilens = paddle.shape(x)[0] + xs, ys = x.unsqueeze(0), None if y is not None: diff --git a/parakeet/models/melgan/__init__.py b/parakeet/models/melgan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4f557db6d155df5f12ae1de5a350f9d523f5ed9 --- /dev/null +++ b/parakeet/models/melgan/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .melgan import *
+from .multi_band_melgan_updater import *
diff --git a/parakeet/models/melgan/melgan.py b/parakeet/models/melgan/melgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f002b80c6dfe17b973bc16b4d2831975222fdc2
--- /dev/null
+++ b/parakeet/models/melgan/melgan.py
@@ -0,0 +1,553 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MelGAN Modules."""
+from typing import Any
+from typing import Dict
+from typing import List
+
+import numpy as np
+import paddle
+from paddle import nn
+
+from parakeet.modules.causal_conv import CausalConv1D
+from parakeet.modules.causal_conv import CausalConv1DTranspose
+from parakeet.modules.nets_utils import initialize
+from parakeet.modules.pqmf import PQMF
+from parakeet.modules.residual_stack import ResidualStack
+
+
+class MelGANGenerator(nn.Layer):
+    """MelGAN generator module."""
+
+    def __init__(
+            self,
+            in_channels: int=80,
+            out_channels: int=1,
+            kernel_size: int=7,
+            channels: int=512,
+            bias: bool=True,
+            upsample_scales: List[int]=[8, 8, 2, 2],
+            stack_kernel_size: int=3,
+            stacks: int=3,
+            nonlinear_activation: str="LeakyReLU",
+            nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2},
+            pad: str="Pad1D",
+            pad_params: Dict[str, Any]={"mode": "reflect"},
+            use_final_nonlinear_activation: bool=True,
+            use_weight_norm: bool=True,
+            use_causal_conv: bool=False,
+            init_type: str="xavier_uniform", ):
+        """Initialize MelGANGenerator module.
+        Parameters
+        ----------
+        in_channels : int
+            Number of input channels.
+        out_channels : int
+            Number of output channels;
+            in multi-band melgan, the number of sub-bands equals out_channels.
+        kernel_size : int
+            Kernel size of initial and final conv layer.
+        channels : int
+            Initial number of channels for conv layer.
+        bias : bool
+            Whether to add bias parameter in convolution layers.
+        upsample_scales : List[int]
+            List of upsampling scales.
+        stack_kernel_size : int
+            Kernel size of dilated conv layers in residual stack.
+        stacks : int
+            Number of stacks in a single residual stack.
+        nonlinear_activation : str
+            Nonlinear activation used in the upsample network, by default "LeakyReLU".
+        nonlinear_activation_params : Dict[str, Any]
+            Parameters passed to the nonlinear activation in the upsample network,
+            by default {"negative_slope": 0.2}.
+        pad : str
+            Padding function module name before dilated convolution layer.
+        pad_params : dict
+            Hyperparameters for padding function.
+        use_final_nonlinear_activation : bool
+            Whether to apply a final nonlinear activation (Tanh) at the output layer.
+ use_weight_norm : bool + Whether to use weight norm. + If set to true, it will be applied to all of the conv layers. + use_causal_conv : bool + Whether to use causal convolution. + """ + super().__init__() + + # check hyper parameters is valid + assert channels >= np.prod(upsample_scales) + assert channels % (2**len(upsample_scales)) == 0 + if not use_causal_conv: + assert (kernel_size - 1 + ) % 2 == 0, "Not support even number kernel size." + + # initialize parameters + initialize(self, init_type) + + layers = [] + if not use_causal_conv: + layers += [ + getattr(paddle.nn, pad)((kernel_size - 1) // 2, **pad_params), + nn.Conv1D(in_channels, channels, kernel_size, bias_attr=bias), + ] + else: + layers += [ + CausalConv1D( + in_channels, + channels, + kernel_size, + bias=bias, + pad=pad, + pad_params=pad_params, ), + ] + + for i, upsample_scale in enumerate(upsample_scales): + # add upsampling layer + layers += [ + getattr(nn, nonlinear_activation)(**nonlinear_activation_params) + ] + if not use_causal_conv: + layers += [ + nn.Conv1DTranspose( + channels // (2**i), + channels // (2**(i + 1)), + upsample_scale * 2, + stride=upsample_scale, + padding=upsample_scale // 2 + upsample_scale % 2, + output_padding=upsample_scale % 2, + bias_attr=bias, ) + ] + else: + layers += [ + CausalConv1DTranspose( + channels // (2**i), + channels // (2**(i + 1)), + upsample_scale * 2, + stride=upsample_scale, + bias=bias, ) + ] + + # add residual stack + for j in range(stacks): + layers += [ + ResidualStack( + kernel_size=stack_kernel_size, + channels=channels // (2**(i + 1)), + dilation=stack_kernel_size**j, + bias=bias, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + pad=pad, + pad_params=pad_params, + use_causal_conv=use_causal_conv, ) + ] + + # add final layer + layers += [ + getattr(nn, nonlinear_activation)(**nonlinear_activation_params) + ] + if not use_causal_conv: + layers += [ + getattr(nn, pad)((kernel_size - 1) // 2, **pad_params), + nn.Conv1D( + channels // (2**(i + 1)), + out_channels, + kernel_size, + bias_attr=bias), + ] + else: + layers += [ + CausalConv1D( + channels // (2**(i + 1)), + out_channels, + kernel_size, + bias=bias, + pad=pad, + pad_params=pad_params, ), + ] + if use_final_nonlinear_activation: + layers += [nn.Tanh()] + + # define the model as a single function + self.melgan = nn.Sequential(*layers) + nn.initializer.set_global_initializer(None) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # reset parameters + self.reset_parameters() + + # initialize pqmf for multi-band melgan inference + if out_channels > 1: + self.pqmf = PQMF(subbands=out_channels) + else: + self.pqmf = None + + def forward(self, c): + """Calculate forward propagation. + Parameters + ---------- + c : Tensor + Input tensor (B, in_channels, T). + Returns + ---------- + Tensor + Output tensor (B, out_channels, T ** prod(upsample_scales)). + """ + out = self.melgan(c) + return out + + def apply_weight_norm(self): + """Recursively apply weight normalization to all the Convolution layers + in the sublayers. + """ + + def _apply_weight_norm(layer): + if isinstance(layer, (nn.Conv1D, nn.Conv2D, nn.Conv1DTranspose)): + nn.utils.weight_norm(layer) + + self.apply(_apply_weight_norm) + + def remove_weight_norm(self): + """Recursively remove weight normalization from all the Convolution + layers in the sublayers. 
+        """
+
+        def _remove_weight_norm(layer):
+            try:
+                nn.utils.remove_weight_norm(layer)
+            except ValueError:
+                pass
+
+        self.apply(_remove_weight_norm)
+
+    def reset_parameters(self):
+        """Reset parameters.
+        This initialization follows official implementation manner.
+        https://github.com/descriptinc/melgan-neurips/blob/master/mel2wav/modules.py
+        """
+
+        # define a normal distribution with float parameters
+        dist = paddle.distribution.Normal(loc=0.0, scale=0.02)
+
+        def _reset_parameters(m):
+            if isinstance(m, nn.Conv1D) or isinstance(m, nn.Conv1DTranspose):
+                w = dist.sample(m.weight.shape)
+                m.weight.set_value(w)
+
+        self.apply(_reset_parameters)
+
+    def inference(self, c):
+        """Perform inference.
+        Parameters
+        ----------
+        c : Union[Tensor, ndarray]
+            Input tensor (T, in_channels).
+        Returns
+        ----------
+        Tensor
+            Output tensor (out_channels * T ** prod(upsample_scales), 1).
+        """
+        if not isinstance(c, paddle.Tensor):
+            c = paddle.to_tensor(c, dtype="float32")
+        # pseudo batch
+        c = c.transpose([1, 0]).unsqueeze(0)
+        # (B, out_channels, T ** prod(upsample_scales))
+        out = self.melgan(c)
+        if self.pqmf is not None:
+            # (B, 1, out_channels * T ** prod(upsample_scales))
+            out = self.pqmf.synthesis(out)
+        out = out.squeeze(0).transpose([1, 0])
+        return out
+
+
+class MelGANDiscriminator(nn.Layer):
+    """MelGAN discriminator module."""
+
+    def __init__(
+            self,
+            in_channels: int=1,
+            out_channels: int=1,
+            kernel_sizes: List[int]=[5, 3],
+            channels: int=16,
+            max_downsample_channels: int=1024,
+            bias: bool=True,
+            downsample_scales: List[int]=[4, 4, 4, 4],
+            nonlinear_activation: str="LeakyReLU",
+            nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2},
+            pad: str="Pad1D",
+            pad_params: Dict[str, Any]={"mode": "reflect"}, ):
+        """Initialize MelGAN discriminator module.
+        Parameters
+        ----------
+        in_channels : int
+            Number of input channels.
+        out_channels : int
+            Number of output channels.
+        kernel_sizes : List[int]
+            List of two kernel sizes. The prod will be used for the first conv layer,
+            and the first and the second kernel sizes will be used for the last two layers.
+            For example if kernel_sizes = [5, 3], the first layer kernel size will be 5 * 3 = 15,
+            the last two layers' kernel sizes will be 5 and 3, respectively.
+        channels : int
+            Initial number of channels for conv layer.
+        max_downsample_channels : int
+            Maximum number of channels for downsampling layers.
+        bias : bool
+            Whether to add bias parameter in convolution layers.
+        downsample_scales : List[int]
+            List of downsampling scales.
+        nonlinear_activation : str
+            Activation function module name.
+        nonlinear_activation_params : dict
+            Hyperparameters for activation function.
+        pad : str
+            Padding function module name before dilated convolution layer.
+        pad_params : dict
+            Hyperparameters for padding function.
+        """
+        super().__init__()
+        self.layers = nn.LayerList()
+
+        # check kernel size is valid
+        assert len(kernel_sizes) == 2
+        assert kernel_sizes[0] % 2 == 1
+        assert kernel_sizes[1] % 2 == 1
+
+        # add first layer
+        self.layers.append(
+            nn.Sequential(
+                getattr(nn, pad)((np.prod(kernel_sizes) - 1) // 2,
+                                 **pad_params),
+                nn.Conv1D(
+                    in_channels,
+                    channels,
+                    int(np.prod(kernel_sizes)),
+                    bias_attr=bias),
+                getattr(nn, nonlinear_activation)(
+                    **nonlinear_activation_params), ))
+
+        # add downsample layers
+        in_chs = channels
+        for downsample_scale in downsample_scales:
+            out_chs = min(in_chs * downsample_scale, max_downsample_channels)
+            self.layers.append(
+                nn.Sequential(
+                    nn.Conv1D(
+                        in_chs,
+                        out_chs,
+                        kernel_size=downsample_scale * 10 + 1,
+                        stride=downsample_scale,
+                        padding=downsample_scale * 5,
+                        groups=in_chs // 4,
+                        bias_attr=bias, ),
+                    getattr(nn, nonlinear_activation)(
+                        **nonlinear_activation_params), ))
+            in_chs = out_chs
+
+        # add final layers
+        out_chs = min(in_chs * 2, max_downsample_channels)
+        self.layers.append(
+            nn.Sequential(
+                nn.Conv1D(
+                    in_chs,
+                    out_chs,
+                    kernel_sizes[0],
+                    padding=(kernel_sizes[0] - 1) // 2,
+                    bias_attr=bias, ),
+                getattr(nn, nonlinear_activation)(
+                    **nonlinear_activation_params), ))
+        self.layers.append(
+            nn.Conv1D(
+                out_chs,
+                out_channels,
+                kernel_sizes[1],
+                padding=(kernel_sizes[1] - 1) // 2,
+                bias_attr=bias, ), )
+
+    def forward(self, x):
+        """Calculate forward propagation.
+        Parameters
+        ----------
+        x : Tensor
+            Input noise signal (B, 1, T).
+        Returns
+        ----------
+        List
+            List of output tensors of each layer (for feat_match_loss).
+        """
+        outs = []
+        for f in self.layers:
+            x = f(x)
+            outs += [x]
+
+        return outs
+
+
+class MelGANMultiScaleDiscriminator(nn.Layer):
+    """MelGAN multi-scale discriminator module."""
+
+    def __init__(
+            self,
+            in_channels: int=1,
+            out_channels: int=1,
+            scales: int=3,
+            downsample_pooling: str="AvgPool1D",
+            # follow the official implementation setting
+            downsample_pooling_params: Dict[str, Any]={
+                "kernel_size": 4,
+                "stride": 2,
+                "padding": 1,
+                "exclusive": True,
+            },
+            kernel_sizes: List[int]=[5, 3],
+            channels: int=16,
+            max_downsample_channels: int=1024,
+            bias: bool=True,
+            downsample_scales: List[int]=[4, 4, 4, 4],
+            nonlinear_activation: str="LeakyReLU",
+            nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2},
+            pad: str="Pad1D",
+            pad_params: Dict[str, Any]={"mode": "reflect"},
+            use_weight_norm: bool=True,
+            init_type: str="xavier_uniform", ):
+        """Initialize MelGAN multi-scale discriminator module.
+        Parameters
+        ----------
+        in_channels : int
+            Number of input channels.
+        out_channels : int
+            Number of output channels.
+        scales : int
+            Number of multi-scales.
+        downsample_pooling : str
+            Pooling module name for downsampling of the inputs.
+        downsample_pooling_params : dict
+            Parameters for the above pooling module.
+        kernel_sizes : List[int]
+            List of two kernel sizes. The prod will be used for the first conv layer,
+            and the first and the second kernel sizes will be used for the last two layers.
+        channels : int
+            Initial number of channels for conv layer.
+        max_downsample_channels : int
+            Maximum number of channels for downsampling layers.
+        bias : bool
+            Whether to add bias parameter in convolution layers.
+        downsample_scales : List[int]
+            List of downsampling scales.
+        nonlinear_activation : str
+            Activation function module name.
+        nonlinear_activation_params : dict
+            Hyperparameters for activation function.
+        pad : str
+            Padding function module name before dilated convolution layer.
+        pad_params : dict
+            Hyperparameters for padding function.
+        use_weight_norm : bool
+            Whether to use weight norm.
+        """
+        super().__init__()
+        # initialize parameters
+        initialize(self, init_type)
+
+        self.discriminators = nn.LayerList()
+
+        # add discriminators
+        for _ in range(scales):
+            self.discriminators.append(
+                MelGANDiscriminator(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    kernel_sizes=kernel_sizes,
+                    channels=channels,
+                    max_downsample_channels=max_downsample_channels,
+                    bias=bias,
+                    downsample_scales=downsample_scales,
+                    nonlinear_activation=nonlinear_activation,
+                    nonlinear_activation_params=nonlinear_activation_params,
+                    pad=pad,
+                    pad_params=pad_params, ))
+        self.pooling = getattr(nn, downsample_pooling)(
+            **downsample_pooling_params)
+
+        nn.initializer.set_global_initializer(None)
+
+        # apply weight norm
+        if use_weight_norm:
+            self.apply_weight_norm()
+
+        # reset parameters
+        self.reset_parameters()
+
+    def forward(self, x):
+        """Calculate forward propagation.
+        Parameters
+        ----------
+        x : Tensor
+            Input noise signal (B, 1, T).
+        Returns
+        ----------
+        List
+            List of list of each discriminator outputs, which consists of each layer output tensors.
+        """
+        outs = []
+        for f in self.discriminators:
+            outs += [f(x)]
+            x = self.pooling(x)
+
+        return outs
+
+    def apply_weight_norm(self):
+        """Recursively apply weight normalization to all the Convolution layers
+        in the sublayers.
+        """
+
+        def _apply_weight_norm(layer):
+            if isinstance(layer, (nn.Conv1D, nn.Conv2D, nn.Conv1DTranspose)):
+                nn.utils.weight_norm(layer)
+
+        self.apply(_apply_weight_norm)
+
+    def remove_weight_norm(self):
+        """Recursively remove weight normalization from all the Convolution
+        layers in the sublayers.
+        """
+
+        def _remove_weight_norm(layer):
+            try:
+                nn.utils.remove_weight_norm(layer)
+            except ValueError:
+                pass
+
+        self.apply(_remove_weight_norm)
+
+    def reset_parameters(self):
+        """Reset parameters.
+        This initialization follows official implementation manner.
+        https://github.com/descriptinc/melgan-neurips/blob/master/mel2wav/modules.py
+        """
+
+        # define a normal distribution with float parameters
+        dist = paddle.distribution.Normal(loc=0.0, scale=0.02)
+
+        def _reset_parameters(m):
+            if isinstance(m, nn.Conv1D) or isinstance(m, nn.Conv1DTranspose):
+                w = dist.sample(m.weight.shape)
+                m.weight.set_value(w)
+
+        self.apply(_reset_parameters)
diff --git a/parakeet/models/melgan/multi_band_melgan_updater.py b/parakeet/models/melgan/multi_band_melgan_updater.py
new file mode 100644
index 0000000000000000000000000000000000000000..0783cb9749ab70434a5395ad07d5a71956195e64
--- /dev/null
+++ b/parakeet/models/melgan/multi_band_melgan_updater.py
@@ -0,0 +1,245 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
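+"""Updater and evaluator for Multi-Band MelGAN: the updater alternates
+generator/discriminator steps and reports STFT, sub-band STFT and adversarial
+losses; the evaluator mirrors the same losses on the dev set."""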
+import logging +from typing import Dict + +import paddle +from paddle import distributed as dist +from paddle.io import DataLoader +from paddle.nn import Layer +from paddle.optimizer import Optimizer +from paddle.optimizer.lr import LRScheduler + +from parakeet.training.extensions.evaluator import StandardEvaluator +from parakeet.training.reporter import report +from parakeet.training.updaters.standard_updater import StandardUpdater +from parakeet.training.updaters.standard_updater import UpdaterState +logging.basicConfig( + format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s', + datefmt='[%Y-%m-%d %H:%M:%S]') +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class MBMelGANUpdater(StandardUpdater): + def __init__(self, + models: Dict[str, Layer], + optimizers: Dict[str, Optimizer], + criterions: Dict[str, Layer], + schedulers: Dict[str, LRScheduler], + dataloader: DataLoader, + discriminator_train_start_steps: int, + lambda_adv: float, + output_dir=None): + self.models = models + self.generator: Layer = models['generator'] + self.discriminator: Layer = models['discriminator'] + + self.optimizers = optimizers + self.optimizer_g: Optimizer = optimizers['generator'] + self.optimizer_d: Optimizer = optimizers['discriminator'] + + self.criterions = criterions + self.criterion_stft = criterions['stft'] + self.criterion_sub_stft = criterions['sub_stft'] + self.criterion_pqmf = criterions['pqmf'] + self.criterion_gen_adv = criterions["gen_adv"] + self.criterion_dis_adv = criterions["dis_adv"] + + self.schedulers = schedulers + self.scheduler_g = schedulers['generator'] + self.scheduler_d = schedulers['discriminator'] + + self.dataloader = dataloader + + self.discriminator_train_start_steps = discriminator_train_start_steps + self.lambda_adv = lambda_adv + self.state = UpdaterState(iteration=0, epoch=0) + + self.train_iterator = iter(self.dataloader) + + log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) + self.filehandler = logging.FileHandler(str(log_file)) + logger.addHandler(self.filehandler) + self.logger = logger + self.msg = "" + + def update_core(self, batch): + self.msg = "Rank: {}, ".format(dist.get_rank()) + losses_dict = {} + + # parse batch + wav, mel = batch + # Generator + # (B, out_channels, T ** prod(upsample_scales) + wav_ = self.generator(mel) + wav_mb_ = wav_ + # (B, 1, out_channels*T ** prod(upsample_scales) + wav_ = self.criterion_pqmf.synthesis(wav_mb_) + + # initialize + gen_loss = 0.0 + + # full band Multi-resolution stft loss + sc_loss, mag_loss = self.criterion_stft(wav_, wav) + # for balancing with subband stft loss + # Eq.(9) in paper + gen_loss += 0.5 * (sc_loss + mag_loss) + report("train/spectral_convergence_loss", float(sc_loss)) + report("train/log_stft_magnitude_loss", float(mag_loss)) + losses_dict["spectral_convergence_loss"] = float(sc_loss) + losses_dict["log_stft_magnitude_loss"] = float(mag_loss) + + # sub band Multi-resolution stft loss + # (B, subbands, T // subbands) + wav_mb = self.criterion_pqmf.analysis(wav) + sub_sc_loss, sub_mag_loss = self.criterion_sub_stft(wav_mb_, wav_mb) + # Eq.(9) in paper + gen_loss += 0.5 * (sub_sc_loss + sub_mag_loss) + report("train/sub_spectral_convergence_loss", float(sub_sc_loss)) + report("train/sub_log_stft_magnitude_loss", float(sub_mag_loss)) + losses_dict["sub_spectral_convergence_loss"] = float(sub_sc_loss) + losses_dict["sub_log_stft_magnitude_loss"] = float(sub_mag_loss) + + ## Adversarial loss + if self.state.iteration > 
self.discriminator_train_start_steps:
+            p_ = self.discriminator(wav_)
+            adv_loss = self.criterion_gen_adv(p_)
+
+            report("train/adversarial_loss", float(adv_loss))
+            losses_dict["adversarial_loss"] = float(adv_loss)
+            gen_loss += self.lambda_adv * adv_loss
+
+        report("train/generator_loss", float(gen_loss))
+        losses_dict["generator_loss"] = float(gen_loss)
+
+        self.optimizer_g.clear_grad()
+        gen_loss.backward()
+
+        self.optimizer_g.step()
+        self.scheduler_g.step()
+
+        # Discriminator
+        if self.state.iteration > self.discriminator_train_start_steps:
+            # re-compute wav_, which leads to better quality
+            with paddle.no_grad():
+                wav_ = self.generator(mel)
+            wav_ = self.criterion_pqmf.synthesis(wav_)
+            p = self.discriminator(wav)
+            p_ = self.discriminator(wav_.detach())
+            real_loss, fake_loss = self.criterion_dis_adv(p_, p)
+            dis_loss = real_loss + fake_loss
+            report("train/real_loss", float(real_loss))
+            report("train/fake_loss", float(fake_loss))
+            report("train/discriminator_loss", float(dis_loss))
+            losses_dict["real_loss"] = float(real_loss)
+            losses_dict["fake_loss"] = float(fake_loss)
+            losses_dict["discriminator_loss"] = float(dis_loss)
+
+            self.optimizer_d.clear_grad()
+            dis_loss.backward()
+
+            self.optimizer_d.step()
+            self.scheduler_d.step()
+
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+
+
+class MBMelGANEvaluator(StandardEvaluator):
+    def __init__(self,
+                 models,
+                 criterions,
+                 dataloader,
+                 lambda_adv,
+                 output_dir=None):
+        self.models = models
+        self.generator = models['generator']
+        self.discriminator = models['discriminator']
+
+        self.criterions = criterions
+        self.criterion_stft = criterions['stft']
+        self.criterion_sub_stft = criterions['sub_stft']
+        self.criterion_pqmf = criterions['pqmf']
+        self.criterion_gen_adv = criterions["gen_adv"]
+        self.criterion_dis_adv = criterions["dis_adv"]
+
+        self.dataloader = dataloader
+        self.lambda_adv = lambda_adv
+
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
+
+    def evaluate_core(self, batch):
+        # logging.debug("Evaluate: ")
+        self.msg = "Evaluate: "
+        losses_dict = {}
+
+        wav, mel = batch
+        # Generator
+        # (B, out_channels, T ** prod(upsample_scales))
+        wav_ = self.generator(mel)
+        wav_mb_ = wav_
+        # (B, 1, out_channels * T ** prod(upsample_scales))
+        wav_ = self.criterion_pqmf.synthesis(wav_mb_)
+
+        ## Adversarial loss
+        p_ = self.discriminator(wav_)
+        adv_loss = self.criterion_gen_adv(p_)
+
+        report("eval/adversarial_loss", float(adv_loss))
+        losses_dict["adversarial_loss"] = float(adv_loss)
+        gen_loss = self.lambda_adv * adv_loss
+
+        # Multi-resolution stft loss
+        sc_loss, mag_loss = self.criterion_stft(wav_, wav)
+        # Eq.(9) in paper
+        gen_loss += 0.5 * (sc_loss + mag_loss)
+        report("eval/spectral_convergence_loss", float(sc_loss))
+        report("eval/log_stft_magnitude_loss", float(mag_loss))
+        losses_dict["spectral_convergence_loss"] = float(sc_loss)
+        losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
+
+        # sub band Multi-resolution stft loss
+        # (B, subbands, T // subbands)
+        wav_mb = self.criterion_pqmf.analysis(wav)
+        sub_sc_loss, sub_mag_loss = self.criterion_sub_stft(wav_mb_, wav_mb)
+        # Eq.(9) in paper
+        gen_loss += 0.5 * (sub_sc_loss + sub_mag_loss)
+        report("eval/sub_spectral_convergence_loss", float(sub_sc_loss))
+        report("eval/sub_log_stft_magnitude_loss", float(sub_mag_loss))
+        losses_dict["sub_spectral_convergence_loss"] =
+
+
+class MBMelGANEvaluator(StandardEvaluator):
+    def __init__(self,
+                 models,
+                 criterions,
+                 dataloader,
+                 lambda_adv,
+                 output_dir=None):
+        self.models = models
+        self.generator = models['generator']
+        self.discriminator = models['discriminator']
+
+        self.criterions = criterions
+        self.criterion_stft = criterions['stft']
+        self.criterion_sub_stft = criterions['sub_stft']
+        self.criterion_pqmf = criterions['pqmf']
+        self.criterion_gen_adv = criterions["gen_adv"]
+        self.criterion_dis_adv = criterions["dis_adv"]
+
+        self.dataloader = dataloader
+        self.lambda_adv = lambda_adv
+
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
+
+    def evaluate_core(self, batch):
+        self.msg = "Evaluate: "
+        losses_dict = {}
+
+        wav, mel = batch
+        # Generator
+        # (B, out_channels, T * prod(upsample_scales))
+        wav_ = self.generator(mel)
+        wav_mb_ = wav_
+        # (B, 1, out_channels * T * prod(upsample_scales))
+        wav_ = self.criterion_pqmf.synthesis(wav_mb_)
+
+        ## Adversarial loss
+        p_ = self.discriminator(wav_)
+        adv_loss = self.criterion_gen_adv(p_)
+
+        report("eval/adversarial_loss", float(adv_loss))
+        losses_dict["adversarial_loss"] = float(adv_loss)
+        gen_loss = self.lambda_adv * adv_loss
+
+        # multi-resolution stft loss
+        sc_loss, mag_loss = self.criterion_stft(wav_, wav)
+        # Eq.(9) in paper
+        gen_loss += 0.5 * (sc_loss + mag_loss)
+        report("eval/spectral_convergence_loss", float(sc_loss))
+        report("eval/log_stft_magnitude_loss", float(mag_loss))
+        losses_dict["spectral_convergence_loss"] = float(sc_loss)
+        losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
+
+        # sub band multi-resolution stft loss
+        # (B, subbands, T // subbands)
+        wav_mb = self.criterion_pqmf.analysis(wav)
+        sub_sc_loss, sub_mag_loss = self.criterion_sub_stft(wav_mb_, wav_mb)
+        # Eq.(9) in paper
+        gen_loss += 0.5 * (sub_sc_loss + sub_mag_loss)
+        report("eval/sub_spectral_convergence_loss", float(sub_sc_loss))
+        report("eval/sub_log_stft_magnitude_loss", float(sub_mag_loss))
+        losses_dict["sub_spectral_convergence_loss"] = float(sub_sc_loss)
+        losses_dict["sub_log_stft_magnitude_loss"] = float(sub_mag_loss)
+
+        report("eval/generator_loss", float(gen_loss))
+        losses_dict["generator_loss"] = float(gen_loss)
+
+        # Discriminator
+        p = self.discriminator(wav)
+        real_loss, fake_loss = self.criterion_dis_adv(p_, p)
+        dis_loss = real_loss + fake_loss
+        report("eval/real_loss", float(real_loss))
+        report("eval/fake_loss", float(fake_loss))
+        report("eval/discriminator_loss", float(dis_loss))
+
+        losses_dict["real_loss"] = float(real_loss)
+        losses_dict["fake_loss"] = float(fake_loss)
+        losses_dict["discriminator_loss"] = float(dis_loss)
+
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+        self.logger.info(self.msg)
diff --git a/parakeet/models/parallel_wavegan/parallel_wavegan.py b/parakeet/models/parallel_wavegan/parallel_wavegan.py
index bb21465304fc062e7de6b9840e3dc2fdc2be1b3a..fe4ec355139d7048e6336f8678e751d489269b67 100644
--- a/parakeet/models/parallel_wavegan/parallel_wavegan.py
+++ b/parakeet/models/parallel_wavegan/parallel_wavegan.py
@@ -498,7 +498,6 @@ class PWGGenerator(nn.Layer):
     def inference(self, c=None):
         """Waveform generation. This function is used for single instance inference.
-
         Parameters
         ----------
         c : Tensor, optional
@@ -506,12 +505,12 @@ class PWGGenerator(nn.Layer):
         x : Tensor, optional
             Shape (T, C_in), the noise waveform, by default None
             If not provided, a sample is drawn from a gaussian distribution.
-
         Returns
         -------
         Tensor
             Shape (T, C_out), the generated waveform
         """
+        # when converting to a static graph, x can not be an input; see https://github.com/PaddlePaddle/Parakeet/pull/132/files
         x = paddle.randn(
             [1, self.in_channels, paddle.shape(c)[0] * self.upsample_factor])
         c = paddle.transpose(c, [1, 0]).unsqueeze(0)  # pseudo batch
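The comment added to `PWGGenerator.inference` concerns dygraph-to-static export: the noise input `x` cannot be a graph input, so it is drawn inside the function. A minimal export sketch under that constraint might look like the following; the instance name, the mel-channel count of 80, and the output path are assumptions, not part of this patch:

```python
import paddle

# Hypothetical export of the mel-conditioned inference path; `generator` is
# assumed to be a trained PWGGenerator and 80 an assumed mel-channel count.
static_inference = paddle.jit.to_static(
    generator.inference,
    input_spec=[paddle.static.InputSpec(shape=[None, 80], dtype="float32")])
paddle.jit.save(static_inference, "exported/pwg_inference")
```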
+ """ + if isinstance(outputs, (tuple, list)): + adv_loss = 0.0 + for i, outputs_ in enumerate(outputs): + if isinstance(outputs_, (tuple, list)): + # case including feature maps + outputs_ = outputs_[-1] + adv_loss += self.criterion(outputs_) + if self.average_by_discriminators: + adv_loss /= i + 1 + else: + adv_loss = self.criterion(outputs) + + return adv_loss + + def _mse_loss(self, x): + return F.mse_loss(x, paddle.ones_like(x)) + + def _hinge_loss(self, x): + return -x.mean() + + +class DiscriminatorAdversarialLoss(nn.Layer): + """Discriminator adversarial loss module.""" + + def __init__( + self, + average_by_discriminators=True, + loss_type="mse", ): + """Initialize DiscriminatorAversarialLoss module.""" + super().__init__() + self.average_by_discriminators = average_by_discriminators + assert loss_type in ["mse"], f"{loss_type} is not supported." + if loss_type == "mse": + self.fake_criterion = self._mse_fake_loss + self.real_criterion = self._mse_real_loss + + def forward(self, outputs_hat, outputs): + """Calcualate discriminator adversarial loss. + Parameters + ---------- + outputs_hat : Tensor or list + Discriminator outputs or list of + discriminator outputs calculated from generator outputs. + outputs : Tensor or list + Discriminator outputs or list of + discriminator outputs calculated from groundtruth. + Returns + ---------- + Tensor + Discriminator real loss value. + Tensor + Discriminator fake loss value. + """ + if isinstance(outputs, (tuple, list)): + real_loss = 0.0 + fake_loss = 0.0 + for i, (outputs_hat_, + outputs_) in enumerate(zip(outputs_hat, outputs)): + if isinstance(outputs_hat_, (tuple, list)): + # case including feature maps + outputs_hat_ = outputs_hat_[-1] + outputs_ = outputs_[-1] + real_loss += self.real_criterion(outputs_) + fake_loss += self.fake_criterion(outputs_hat_) + if self.average_by_discriminators: + fake_loss /= i + 1 + real_loss /= i + 1 + else: + real_loss = self.real_criterion(outputs) + fake_loss = self.fake_criterion(outputs_hat) + + return real_loss, fake_loss + + def _mse_real_loss(self, x): + return F.mse_loss(x, paddle.ones_like(x)) + + def _mse_fake_loss(self, x): + return F.mse_loss(x, paddle.zeros_like(x)) diff --git a/parakeet/modules/causal_conv.py b/parakeet/modules/causal_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..c0dd5b28c96ce8bfd4edd52a27d687ff5901f27f --- /dev/null +++ b/parakeet/modules/causal_conv.py @@ -0,0 +1,81 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Causal convolusion layer modules.""" +import paddle + + +class CausalConv1D(paddle.nn.Layer): + """CausalConv1D module with customized initialization.""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dilation=1, + bias=True, + pad="Pad1D", + pad_params={"value": 0.0}, ): + """Initialize CausalConv1d module.""" + super().__init__() + self.pad = getattr(paddle.nn, pad)((kernel_size - 1) * dilation, + **pad_params) + self.conv = paddle.nn.Conv1D( + in_channels, + out_channels, + kernel_size, + dilation=dilation, + bias_attr=bias) + + def forward(self, x): + """Calculate forward propagation. + Parameters + ---------- + x : Tensor + Input tensor (B, in_channels, T). + Returns + ---------- + Tensor + Output tensor (B, out_channels, T). + """ + return self.conv(self.pad(x))[:, :, :x.shape[2]] + + +class CausalConv1DTranspose(paddle.nn.Layer): + """CausalConv1DTranspose module with customized initialization.""" + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + bias=True): + """Initialize CausalConvTranspose1d module.""" + super().__init__() + self.deconv = paddle.nn.Conv1DTranspose( + in_channels, out_channels, kernel_size, stride, bias_attr=bias) + self.stride = stride + + def forward(self, x): + """Calculate forward propagation. + Parameters + ---------- + x : Tensor + Input tensor (B, in_channels, T_in). + Returns + ---------- + Tensor + Output tensor (B, out_channels, T_out). + """ + return self.deconv(x)[:, :, :-self.stride] diff --git a/parakeet/modules/fastspeech2_predictor/length_regulator.py b/parakeet/modules/fastspeech2_predictor/length_regulator.py index e5195e53637a8a6596ff8217271784e9b5e2c84d..a4d508add1aa1661fb127ee16806e3ca108e2736 100644 --- a/parakeet/modules/fastspeech2_predictor/length_regulator.py +++ b/parakeet/modules/fastspeech2_predictor/length_regulator.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Length regulator related modules.""" -import numpy as np import paddle from paddle import nn @@ -49,11 +48,10 @@ class LengthRegulator(nn.Layer): encodings: (B, T, C) durations: (B, T) """ - batch_size, t_enc = durations.shape - durations = durations.numpy() - slens = np.sum(durations, -1) - t_dec = np.max(slens) - M = np.zeros([batch_size, t_dec, t_enc]) + batch_size, t_enc = paddle.shape(durations) + slens = durations.sum(-1) + t_dec = slens.max() + M = paddle.zeros([batch_size, t_dec, t_enc]) for i in range(batch_size): k = 0 for j in range(t_enc): @@ -61,7 +59,6 @@ class LengthRegulator(nn.Layer): if d >= 1: M[i, k:k + d, j] = 1 k += d - M = paddle.to_tensor(M, dtype=encodings.dtype) encodings = paddle.matmul(M, encodings) return encodings @@ -82,6 +79,7 @@ class LengthRegulator(nn.Layer): Tensor replicated input tensor based on durations (B, T*, D). 
""" + if alpha != 1.0: assert alpha > 0 ds = paddle.round(ds.cast(dtype=paddle.float32) * alpha) diff --git a/parakeet/modules/fastspeech2_transformer/attention.py b/parakeet/modules/fastspeech2_transformer/attention.py index ae941a79aa445dc5f6eff50c0f4d4be0c1eddf5b..0bac47426d38ca002b65b6dd88e17033d8a88fa3 100644 --- a/parakeet/modules/fastspeech2_transformer/attention.py +++ b/parakeet/modules/fastspeech2_transformer/attention.py @@ -106,13 +106,11 @@ class MultiHeadedAttention(nn.Layer): n_batch = value.shape[0] softmax = paddle.nn.Softmax(axis=-1) if mask is not None: - mask = mask.unsqueeze(1) mask = paddle.logical_not(mask) - min_value = float( - numpy.finfo( - paddle.to_tensor(0, dtype=scores.dtype).numpy().dtype).min) - + # assume scores.dtype==paddle.float32, we only use "float32" here + dtype = str(scores.dtype).split(".")[-1] + min_value = numpy.finfo(dtype).min scores = masked_fill(scores, mask, min_value) # (batch, head, time1, time2) self.attn = softmax(scores) diff --git a/parakeet/modules/fastspeech2_transformer/embedding.py b/parakeet/modules/fastspeech2_transformer/embedding.py index 6c1c7245f52f6ffe9d8b5489f0daa4c9036bc56d..1dfd6dfdc000ed899d448a4fb17a6fd3686e013d 100644 --- a/parakeet/modules/fastspeech2_transformer/embedding.py +++ b/parakeet/modules/fastspeech2_transformer/embedding.py @@ -31,9 +31,16 @@ class PositionalEncoding(nn.Layer): Maximum input length. reverse : bool Whether to reverse the input position. + type : str + dtype of param """ - def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False): + def __init__(self, + d_model, + dropout_rate, + max_len=5000, + dtype="float32", + reverse=False): """Construct an PositionalEncoding object.""" super(PositionalEncoding, self).__init__() self.d_model = d_model @@ -41,20 +48,21 @@ class PositionalEncoding(nn.Layer): self.xscale = math.sqrt(self.d_model) self.dropout = nn.Dropout(p=dropout_rate) self.pe = None - self.extend_pe(paddle.expand(paddle.to_tensor(0.0), (1, max_len))) + self.dtype = dtype + self.extend_pe(paddle.expand(paddle.zeros([1]), (1, max_len))) def extend_pe(self, x): """Reset the positional encodings.""" - - pe = paddle.zeros([x.shape[1], self.d_model]) + x_shape = paddle.shape(x) + pe = paddle.zeros([x_shape[1], self.d_model]) if self.reverse: position = paddle.arange( - x.shape[1] - 1, -1, -1.0, dtype=paddle.float32).unsqueeze(1) + x_shape[1] - 1, -1, -1.0, dtype=self.dtype).unsqueeze(1) else: position = paddle.arange( - 0, x.shape[1], dtype=paddle.float32).unsqueeze(1) + 0, x_shape[1], dtype=self.dtype).unsqueeze(1) div_term = paddle.exp( - paddle.arange(0, self.d_model, 2, dtype=paddle.float32) * + paddle.arange(0, self.d_model, 2, dtype=self.dtype) * -(math.log(10000.0) / self.d_model)) pe[:, 0::2] = paddle.sin(position * div_term) pe[:, 1::2] = paddle.cos(position * div_term) @@ -75,7 +83,8 @@ class PositionalEncoding(nn.Layer): Encoded tensor (batch, time, `*`). """ self.extend_pe(x) - x = x * self.xscale + self.pe[:, :x.shape[1]] + T = paddle.shape(x)[1] + x = x * self.xscale + self.pe[:, :T] return self.dropout(x) @@ -92,21 +101,26 @@ class ScaledPositionalEncoding(PositionalEncoding): Dropout rate. max_len : int Maximum input length. 
+    dtype : str
+        dtype of param
     """
 
-    def __init__(self, d_model, dropout_rate, max_len=5000):
+    def __init__(self, d_model, dropout_rate, max_len=5000, dtype="float32"):
         """Initialize class."""
         super().__init__(
-            d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
-        x = paddle.ones([1], dtype="float32")
+            d_model=d_model,
+            dropout_rate=dropout_rate,
+            max_len=max_len,
+            dtype=dtype)
+        x = paddle.ones([1], dtype=self.dtype)
         self.alpha = paddle.create_parameter(
             shape=x.shape,
-            dtype=str(x.numpy().dtype),
+            dtype=self.dtype,
             default_initializer=paddle.nn.initializer.Assign(x))
 
     def reset_parameters(self):
         """Reset parameters."""
-        self.alpha = paddle.to_tensor(1.0)
+        self.alpha = paddle.ones([1])
 
     def forward(self, x):
         """Add positional encoding.
@@ -115,12 +129,12 @@ class ScaledPositionalEncoding(PositionalEncoding):
         ----------
         x : paddle.Tensor
             Input tensor (batch, time, `*`).
-
         Returns
         ----------
         paddle.Tensor
             Encoded tensor (batch, time, `*`).
         """
         self.extend_pe(x)
-        x = x + self.alpha * self.pe[:, :x.shape[1]]
+        T = paddle.shape(x)[1]
+        x = x + self.alpha * self.pe[:, :T]
         return self.dropout(x)
diff --git a/parakeet/modules/fastspeech2_transformer/encoder.py b/parakeet/modules/fastspeech2_transformer/encoder.py
index 630b50ff5e96dd94acb3f1267398933b547859dd..996e9dee08aab5069f22d0e144f886687cc077dd 100644
--- a/parakeet/modules/fastspeech2_transformer/encoder.py
+++ b/parakeet/modules/fastspeech2_transformer/encoder.py
@@ -185,6 +185,7 @@ class Encoder(nn.Layer):
         paddle.Tensor
             Mask tensor (#batch, time).
         """
+
         xs = self.embed(xs)
         xs, masks = self.encoders(xs, masks)
         if self.normalize_before:
diff --git a/parakeet/modules/layer_norm.py b/parakeet/modules/layer_norm.py
index 3bab823f299917ec507f423607e844f00ecfdac4..a1c775fc8bfa6d1b720e58cf4db223f01c558c5d 100644
--- a/parakeet/modules/layer_norm.py
+++ b/parakeet/modules/layer_norm.py
@@ -44,6 +44,7 @@ class LayerNorm(paddle.nn.LayerNorm):
         paddle.Tensor
             Normalized tensor.
         """
+
         if self.dim == -1:
             return super(LayerNorm, self).forward(x)
         else:
@@ -54,9 +55,12 @@ class LayerNorm(paddle.nn.LayerNorm):
 
             orig_perm = list(range(len_dim))
             new_perm = orig_perm[:]
-            new_perm[self.dim], new_perm[len_dim -
-                                         1] = new_perm[len_dim -
-                                                       1], new_perm[self.dim]
+            # Python-style tuple swapping is not supported when converting
+            # dygraph to a static graph, so swap the items explicitly:
+            # new_perm[self.dim], new_perm[len_dim - 1] = new_perm[len_dim - 1], new_perm[self.dim]
+            temp = new_perm[self.dim]
+            new_perm[self.dim] = new_perm[len_dim - 1]
+            new_perm[len_dim - 1] = temp
 
             return paddle.transpose(
                 super(LayerNorm, self).forward(paddle.transpose(x, new_perm)),
diff --git a/parakeet/modules/masked_fill.py b/parakeet/modules/masked_fill.py
index 34230f1c43999f64ee1c33f049c58fd557e722e4..b322225479c843672073ec5567fa4403137bdb26 100644
--- a/parakeet/modules/masked_fill.py
+++ b/parakeet/modules/masked_fill.py
@@ -25,12 +25,24 @@ def is_broadcastable(shp1, shp2):
     return True
 
 
+# assume that len(shp1) == len(shp2)
+def broadcast_shape(shp1, shp2):
+    result = []
+    for a, b in zip(shp1[::-1], shp2[::-1]):
+        result.append(max(a, b))
+    return result[::-1]
+
+
 def masked_fill(xs: paddle.Tensor,
                 mask: paddle.Tensor,
                 value: Union[float, int]):
-    assert is_broadcastable(xs.shape, mask.shape) is True
-    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
+    # the following lines are commented out for converting dygraph to static graph
+    # assert is_broadcastable(xs.shape, mask.shape) is True
+    # bshape = paddle.broadcast_shape(xs.shape, mask.shape)
+    bshape = broadcast_shape(xs.shape, mask.shape)
+    mask.stop_gradient = True
     mask = mask.broadcast_to(bshape)
+
     trues = paddle.ones_like(xs) * value
     mask = mask.cast(dtype=paddle.bool)
     xs = paddle.where(mask, trues, xs)
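Since `paddle.broadcast_shape` is avoided here for the sake of static-graph export, `broadcast_shape` assumes both tensors have the same rank. A small illustrative check; the shapes and fill value below are assumptions:

```python
import paddle

from parakeet.modules.masked_fill import masked_fill

scores = paddle.rand([2, 4, 10, 10])                  # e.g. (batch, head, time1, time2)
mask = paddle.ones([2, 1, 1, 10], dtype=paddle.bool)  # same rank as scores

# positions where mask is True are overwritten with -1e9
filled = masked_fill(scores, mask, -1e9)
```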
diff --git a/parakeet/modules/nets_utils.py b/parakeet/modules/nets_utils.py
index 47eae65d68b00be7d3e8ab5aa4661c5fb028213a..0696335a5c5a4a9ae424a2c6e8f72ef480b5eaa5 100644
--- a/parakeet/modules/nets_utils.py
+++ b/parakeet/modules/nets_utils.py
@@ -56,7 +56,7 @@ def make_pad_mask(lengths, length_dim=-1):
 
     Parameters
     ----------
-    lengths : LongTensor or List
+    lengths : LongTensor
         Batch of lengths (B,).
 
     Returns
@@ -77,17 +77,11 @@ def make_pad_mask(lengths, length_dim=-1):
     if length_dim == 0:
         raise ValueError("length_dim cannot be 0: {}".format(length_dim))
 
-    if not isinstance(lengths, list):
-        lengths = lengths.tolist()
-    bs = int(len(lengths))
-
-    maxlen = int(max(lengths))
-
+    bs = paddle.shape(lengths)[0]
+    maxlen = lengths.max()
     seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
     seq_range_expand = seq_range.unsqueeze(0).expand([bs, maxlen])
-
-    seq_length_expand = paddle.to_tensor(
-        lengths, dtype=seq_range_expand.dtype).unsqueeze(-1)
+    seq_length_expand = lengths.unsqueeze(-1)
     mask = seq_range_expand >= seq_length_expand
 
     return mask
diff --git a/parakeet/modules/pqmf.py b/parakeet/modules/pqmf.py
new file mode 100644
index 0000000000000000000000000000000000000000..275addd2fe6524aa652a515ace6038e51606fc49
--- /dev/null
+++ b/parakeet/modules/pqmf.py
@@ -0,0 +1,140 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Pseudo QMF modules."""
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+from scipy.signal import kaiser
+
+
+def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0):
+    """Design prototype filter for PQMF.
+    This method is based on `A Kaiser window approach for the design of prototype
+    filters of cosine modulated filterbanks`_.
+    Parameters
+    ----------
+    taps : int
+        The number of filter taps.
+    cutoff_ratio : float
+        Cut-off frequency ratio.
+    beta : float
+        Beta coefficient for kaiser window.
+    Returns
+    ----------
+    ndarray
+        Impulse response of prototype filter (taps + 1,).
+    .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
+        https://ieeexplore.ieee.org/abstract/document/681427
+    """
+    # check the arguments are valid
+    assert taps % 2 == 0, "The number of taps must be an even number."
+    assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."
+
+    # make initial filter
+    omega_c = np.pi * cutoff_ratio
+    with np.errstate(invalid="ignore"):
+        h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / (
+            np.pi * (np.arange(taps + 1) - 0.5 * taps))
+    h_i[taps //
+        2] = np.cos(0) * cutoff_ratio  # fix nan due to indeterminate form
+
+    # apply kaiser window
+    w = kaiser(taps + 1, beta)
+    h = h_i * w
+
+    return h
+
+
+class PQMF(paddle.nn.Layer):
+    """PQMF module.
+    This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
+    .. _`Near-perfect-reconstruction pseudo-QMF banks`:
+        https://ieeexplore.ieee.org/document/258122
+    """
+
+    def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0):
+        """Initialize PQMF module.
+        The cutoff_ratio and beta parameters are optimized for #subbands = 4.
+        See discussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195.
+        Parameters
+        ----------
+        subbands : int
+            The number of subbands.
+        taps : int
+            The number of filter taps.
+        cutoff_ratio : float
+            Cut-off frequency ratio.
+        beta : float
+            Beta coefficient for kaiser window.
+        """
+        super(PQMF, self).__init__()
+
+        # build analysis & synthesis filter coefficients
+        h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
+        h_analysis = np.zeros((subbands, len(h_proto)))
+        h_synthesis = np.zeros((subbands, len(h_proto)))
+        for k in range(subbands):
+            h_analysis[k] = (
+                2 * h_proto * np.cos((2 * k + 1) * (np.pi / (2 * subbands)) * (
+                    np.arange(taps + 1) - (taps / 2)) + (-1)**k * np.pi / 4))
+            h_synthesis[k] = (
+                2 * h_proto * np.cos((2 * k + 1) * (np.pi / (2 * subbands)) * (
+                    np.arange(taps + 1) - (taps / 2)) - (-1)**k * np.pi / 4))
+
+        # convert to tensor
+        self.analysis_filter = paddle.to_tensor(
+            h_analysis, dtype="float32").unsqueeze(1)
+        self.synthesis_filter = paddle.to_tensor(
+            h_synthesis, dtype="float32").unsqueeze(0)
+
+        # filter for downsampling & upsampling
+        updown_filter = paddle.zeros(
+            (subbands, subbands, subbands), dtype="float32")
+        for k in range(subbands):
+            updown_filter[k, k, 0] = 1.0
+        self.updown_filter = updown_filter
+        self.subbands = subbands
+
+        # keep padding info
+        self.pad_fn = paddle.nn.Pad1D(taps // 2, mode='constant', value=0.0)
+
+    def analysis(self, x):
+        """Analysis with PQMF.
+        Parameters
+        ----------
+        x : Tensor
+            Input tensor (B, 1, T).
+        Returns
+        ----------
+        Tensor
+            Output tensor (B, subbands, T // subbands).
+        """
+        x = F.conv1d(self.pad_fn(x), self.analysis_filter)
+        return F.conv1d(x, self.updown_filter, stride=self.subbands)
+
+    def synthesis(self, x):
+        """Synthesis with PQMF.
+        Parameters
+        ----------
+        x : Tensor
+            Input tensor (B, subbands, T // subbands).
+        Returns
+        ----------
+        Tensor
+            Output tensor (B, 1, T).
+        """
+        x = F.conv1d_transpose(
+            x, self.updown_filter * self.subbands, stride=self.subbands)
+        return F.conv1d(self.pad_fn(x), self.synthesis_filter)
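Because PQMF is a near-perfect-reconstruction filterbank, a round-trip through `analysis` and `synthesis` is a useful smoke test; the input signal below is an arbitrary assumption:

```python
import paddle

from parakeet.modules.pqmf import PQMF

pqmf = PQMF(subbands=4)
wav = paddle.randn([1, 1, 8192])     # (B, 1, T), T divisible by subbands

subband_wav = pqmf.analysis(wav)     # (1, 4, 2048)
recon = pqmf.synthesis(subband_wav)  # (1, 1, 8192), close to wav up to filter delay
```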
diff --git a/parakeet/modules/residual_stack.py b/parakeet/modules/residual_stack.py
new file mode 100644
index 0000000000000000000000000000000000000000..135c32e5770994b63bc0353c74c6089fc2992e83
--- /dev/null
+++ b/parakeet/modules/residual_stack.py
@@ -0,0 +1,109 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Residual stack module in MelGAN."""
+from typing import Any
+from typing import Dict
+
+from paddle import nn
+
+from parakeet.modules.causal_conv import CausalConv1D
+
+
+class ResidualStack(nn.Layer):
+    """Residual stack module introduced in MelGAN."""
+
+    def __init__(
+            self,
+            kernel_size: int=3,
+            channels: int=32,
+            dilation: int=1,
+            bias: bool=True,
+            nonlinear_activation: str="LeakyReLU",
+            nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2},
+            pad: str="Pad1D",
+            pad_params: Dict[str, Any]={"mode": "reflect"},
+            use_causal_conv: bool=False, ):
+        """Initialize ResidualStack module.
+        Parameters
+        ----------
+        kernel_size : int
+            Kernel size of dilation convolution layer.
+        channels : int
+            Number of channels of convolution layers.
+        dilation : int
+            Dilation factor.
+        bias : bool
+            Whether to add bias parameter in convolution layers.
+        nonlinear_activation : str
+            Activation function module name.
+        nonlinear_activation_params : Dict[str, Any]
+            Hyperparameters for activation function.
+        pad : str
+            Padding function module name before dilated convolution layer.
+        pad_params : Dict[str, Any]
+            Hyperparameters for padding function.
+        use_causal_conv : bool
+            Whether to use causal convolution.
+        """
+        super().__init__()
+
+        # define residual stack part
+        if not use_causal_conv:
+            assert (kernel_size - 1
+                    ) % 2 == 0, "Even kernel sizes are not supported."
+            self.stack = nn.Sequential(
+                getattr(nn, nonlinear_activation)(
+                    **nonlinear_activation_params),
+                getattr(nn, pad)((kernel_size - 1) // 2 * dilation,
+                                 **pad_params),
+                nn.Conv1D(
+                    channels,
+                    channels,
+                    kernel_size,
+                    dilation=dilation,
+                    bias_attr=bias),
+                getattr(nn, nonlinear_activation)(
+                    **nonlinear_activation_params),
+                nn.Conv1D(channels, channels, 1, bias_attr=bias), )
+        else:
+            self.stack = nn.Sequential(
+                getattr(nn, nonlinear_activation)(
+                    **nonlinear_activation_params),
+                CausalConv1D(
+                    channels,
+                    channels,
+                    kernel_size,
+                    dilation=dilation,
+                    bias=bias,
+                    pad=pad,
+                    pad_params=pad_params, ),
+                getattr(nn, nonlinear_activation)(
+                    **nonlinear_activation_params),
+                nn.Conv1D(channels, channels, 1, bias_attr=bias), )
+
+        # define extra layer for skip connection
+        self.skip_layer = nn.Conv1D(channels, channels, 1, bias_attr=bias)
+
+    def forward(self, c):
+        """Calculate forward propagation.
+        Parameters
+        ----------
+        c : Tensor
+            Input tensor (B, channels, T).
+        Returns
+        ----------
+        Tensor
+            Output tensor (B, channels, T).
+        """
+        return self.stack(c) + self.skip_layer(c)
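In MelGAN the generator interleaves upsampling layers with several of these stacks at growing dilation. A hedged sketch of that pattern; the counts below follow the MelGAN paper's generator and are not taken from this patch:

```python
from paddle import nn

from parakeet.modules.residual_stack import ResidualStack

# three stacks with dilations 1, 3, 9, as in the MelGAN paper's generator
stacks = nn.Sequential(*[
    ResidualStack(kernel_size=3, channels=32, dilation=3**i)
    for i in range(3)
])
```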
+ """ + return self.stack(c) + self.skip_layer(c) diff --git a/parakeet/modules/stft_loss.py b/parakeet/modules/stft_loss.py index 1f400b461aeafc21d0426463c633a3ebf65c93a2..8af55ab1439a61b4ef60b5ede567ef6bc1304c23 100644 --- a/parakeet/modules/stft_loss.py +++ b/parakeet/modules/stft_loss.py @@ -51,7 +51,7 @@ def stft(x, # calculate window window = signal.get_window(window, win_length, fftbins=True) window = paddle.to_tensor(window) - x_stft = paddle.tensor.signal.stft( + x_stft = paddle.signal.stft( x, fft_size, hop_length, diff --git a/parakeet/training/extensions/visualizer.py b/parakeet/training/extensions/visualizer.py index 1c66ad8dd2990cacf36824b2268d38d86954452d..bc62c97603a937578de58c2e5a5d90b65c572802 100644 --- a/parakeet/training/extensions/visualizer.py +++ b/parakeet/training/extensions/visualizer.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from visualdl import LogWriter + from parakeet.training import extension from parakeet.training.trainer import Trainer @@ -26,8 +28,8 @@ class VisualDL(extension.Extension): default_name = 'visualdl' priority = extension.PRIORITY_READER - def __init__(self, writer): - self.writer = writer + def __init__(self, logdir): + self.writer = LogWriter(str(logdir)) def __call__(self, trainer: Trainer): for k, v in trainer.observation.items(): diff --git a/text_processing/.gitignore b/text_processing/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..e400141b24bb752d016fae6d650f59504f0bc6b1 --- /dev/null +++ b/text_processing/.gitignore @@ -0,0 +1,7 @@ +data +glove +.pyc +checkpoints +epoch +__pycache__ +glove.840B.300d.zip diff --git a/text_processing/README.md b/text_processing/README.md new file mode 100644 index 0000000000000000000000000000000000000000..294af01d17554b0e035cbb3b0cce09aea9f39625 --- /dev/null +++ b/text_processing/README.md @@ -0,0 +1,25 @@ +# PaddleSpeechTask +A speech library to deal with a series of related front-end and back-end tasks + +## 环境 +- python==3.6.13 +- paddle==2.1.1 + +## 中/英文文本加标点任务 punctuation restoration: + +### 数据集: data +- 中文数据来源:data/chinese +1.iwlst2012zh +2.平凡的世界 + +- 英文数据来源: data/english +1.iwlst2012en + +- iwlst2012数据获取过程见data/README.md + +### 模型:speechtask/punctuation_restoration/model +1.BLSTM模型 + +2.BertLinear模型 + +3.BertBLSTM模型 diff --git a/text_processing/examples/punctuation_restoration/chinese/README.md b/text_processing/examples/punctuation_restoration/chinese/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1fcd954ca9bee8db56843a1adde526ebd6bd8fb8 --- /dev/null +++ b/text_processing/examples/punctuation_restoration/chinese/README.md @@ -0,0 +1,35 @@ +# 中文实验例程 +## 测试数据: +- IWLST2012中文:test2012 + +## 运行代码 +- 运行 `run.sh 0 0 conf/train_conf/bertBLSTM_zh.yaml 1 conf/data_conf/chinese.yaml ` + +## 实验结果: +- BertLinear + - 实验配置:conf/train_conf/bertLinear_zh.yaml + - 测试结果 + + | | COMMA | PERIOD | QUESTION | OVERALL | + |-----------|-----------|-----------|-----------|--------- | + |Precision | 0.425665 | 0.335190 | 0.698113 | 0.486323 | + |Recall | 0.511278 | 0.572108 | 0.787234 | 0.623540 | + |F1 | 0.464560 | 0.422717 | 0.740000 | 0.542426 | + +- BertBLSTM + - 实验配置:conf/train_conf/bertBLSTM_zh.yaml + - 测试结果 avg_1 + + | | COMMA | PERIOD | QUESTION | OVERALL | + |-----------|-----------|-----------|-----------|--------- | + |Precision | 0.469484 | 0.550604 | 0.801887 | 0.607325 | + |Recall | 
diff --git a/text_processing/examples/punctuation_restoration/chinese/conf/blstm.yaml b/text_processing/examples/punctuation_restoration/chinese/conf/blstm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9b1a2e010b147310057098545516ca38b70f8e07
--- /dev/null
+++ b/text_processing/examples/punctuation_restoration/chinese/conf/blstm.yaml
@@ -0,0 +1,34 @@
+data:
+  language: chinese
+  raw_path: /data4/mahaoxin/PaddleSpeechTask/data/chinese/PFDSJ  # path to raw dataset
+  raw_train_file: train
+  raw_dev_file: dev
+  raw_test_file: test
+  vocab_file: vocab
+  punc_file: punc_vocab
+  save_path: data/PFDSJ  # path to save dataset
+  seq_len: 100
+  batch_size: 10
+  sortagrad: True
+  shuffle_method: batch_shuffle
+  num_workers: 0
+
+model_type: blstm
+model_params:
+  vocab_size: 3751
+  embedding_size: 200
+  hidden_size: 100
+  num_layers: 3
+  num_class: 5
+  init_scale: 0.1
+
+training:
+  n_epoch: 32
+  lr: !!float 1e-4
+  lr_decay: 1.0
+  weight_decay: !!float 1e-06
+  global_grad_clip: 5.0
+  log_interval: 10
+
+
+
diff --git a/text_processing/examples/punctuation_restoration/chinese/conf/data_conf/chinese.yaml b/text_processing/examples/punctuation_restoration/chinese/conf/data_conf/chinese.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..191bfd3e6085a9b49fc74bda66e2e1bb67102100
--- /dev/null
+++ b/text_processing/examples/punctuation_restoration/chinese/conf/data_conf/chinese.yaml
@@ -0,0 +1,7 @@
+type: chinese
+raw_path: /data4/mahaoxin/PaddleSpeechTask/data/chinese/iwslt2012_zh  # path to raw dataset
+raw_train_file: iwslt2012_train_zh
+raw_dev_file: iwslt2010_dev_zh
+raw_test_file: biaobei_asr
+punc_file: punc_vocab
+save_path: data/iwslt2012_zh  # path to save dataset
\ No newline at end of file
diff --git a/text_processing/examples/punctuation_restoration/chinese/conf/train_conf/bertBLSTM_zh.yaml b/text_processing/examples/punctuation_restoration/chinese/conf/train_conf/bertBLSTM_zh.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d1f58aac1f35e681d7ac4ea017105e2ad4ad817e
--- /dev/null
+++ b/text_processing/examples/punctuation_restoration/chinese/conf/train_conf/bertBLSTM_zh.yaml
@@ -0,0 +1,49 @@
+data:
+  dataset_type: Bert
+  train_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/train
+  dev_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/dev
+  test_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/test2012_revise
+  data_params:
+    pretrained_token: bert-base-chinese
+    punc_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/punc_vocab
+    seq_len: 100
+  batch_size: 64
+  sortagrad: True
+  shuffle_method: batch_shuffle
+  num_workers: 0
+
+checkpoint:
+  kbest_n: 5
+  latest_n: 10
+  metric_type: F1
+
+
+model_type: BertBLSTM
+model_params:
+  pretrained_token: bert-base-chinese
+  output_size: 4
+  dropout: 0.0
+  bert_size: 768
+  blstm_size: 128
+  num_blstm_layers: 2
+  init_scale: 0.1
+
+# model_type: BertChLinear
+# model_params:
+#   pretrained_token: bert-base-chinese
+#   output_size: 4
+#   dropout: 0.0
+#   bert_size: 768
+
+training:
+  n_epoch: 100
+  lr: !!float 1e-5
+  lr_decay: 1.0
+  weight_decay: !!float 1e-06
+  global_grad_clip: 5.0
+  log_interval: 10
+  log_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/log/bertBLSTM_zh0812.log
+
+testing:
+  log_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/log/test_bertBLSTM_zh0812.log
+
diff --git a/text_processing/examples/punctuation_restoration/chinese/conf/train_conf/bertLinear_zh.yaml b/text_processing/examples/punctuation_restoration/chinese/conf/train_conf/bertLinear_zh.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c422e840e64f856f24ee6f760b1945e9ae02d252
--- /dev/null
+++ b/text_processing/examples/punctuation_restoration/chinese/conf/train_conf/bertLinear_zh.yaml
@@ -0,0 +1,42 @@
+data:
+  dataset_type: Bert
+  train_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/train
+  dev_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/dev
+  test_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/test2012
+  data_params:
+    pretrained_token: bert-base-chinese
+    punc_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/data/iwslt2012_zh/punc_vocab
+    seq_len: 100
+  batch_size: 32
+  sortagrad: True
+  shuffle_method: batch_shuffle
+  num_workers: 0
+
+checkpoint:
+  kbest_n: 10
+  latest_n: 10
+  metric_type: F1
+
+
+model_type: BertLinear
+model_params:
+  pretrained_token: bert-base-uncased
+  output_size: 4
+  dropout: 0.2
+  bert_size: 768
+  hiddensize: 1568
+
+
+training:
+  n_epoch: 50
+  lr: !!float 1e-5
+  lr_decay: 1.0
+  weight_decay: !!float 1e-06
+  global_grad_clip: 5.0
+  log_interval: 10
+  log_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/log/train_linear0812.log
+
+testing:
+  log_interval: 10
+  log_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/chinese/log/test_linear0812.log
+
diff --git a/text_processing/examples/punctuation_restoration/chinese/local/avg.sh b/text_processing/examples/punctuation_restoration/chinese/local/avg.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b8c14c6623cc30ca167cfb26fcc9ade349cad288
--- /dev/null
+++ b/text_processing/examples/punctuation_restoration/chinese/local/avg.sh
@@ -0,0 +1,23 @@
+#! /usr/bin/env bash
+
+if [ $# != 2 ]; then
+    echo "usage: ${0} ckpt_dir avg_num"
+    exit -1
+fi
+
+ckpt_dir=${1}
+average_num=${2}
+decode_checkpoint=${ckpt_dir}/avg_${average_num}.pdparams
+
+python3 -u ${BIN_DIR}/avg_model.py \
+--dst_model ${decode_checkpoint} \
+--ckpt_dir ${ckpt_dir} \
+--num ${average_num} \
+--val_best
+
+if [ $? -ne 0 ]; then
+    echo "Failed in avg ckpt!"
+    exit 1
+fi
+
+exit 0
\ No newline at end of file
diff --git a/text_processing/examples/punctuation_restoration/chinese/local/data.sh b/text_processing/examples/punctuation_restoration/chinese/local/data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..aff7203ccb698d26c7ff22bc1ff91f3dc3a76735
--- /dev/null
+++ b/text_processing/examples/punctuation_restoration/chinese/local/data.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+if [ $# != 1 ];then
+    echo "usage: ${0} data_pre_conf"
+    echo $1
+    exit -1
+fi
+
+data_pre_conf=$1
+
+python3 -u ${BIN_DIR}/pre_data.py \
+--config ${data_pre_conf}
+
+if [ $? -ne 0 ]; then
+    echo "Failed in data preparation!"
+    exit 1
+fi
+
+exit 0
diff --git a/text_processing/examples/punctuation_restoration/chinese/local/test.sh b/text_processing/examples/punctuation_restoration/chinese/local/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6db75ca2ad347f292443b1baf980430df00dd568
--- /dev/null
+++ b/text_processing/examples/punctuation_restoration/chinese/local/test.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+if [ $# != 2 ];then
+    echo "usage: ${0} config_path ckpt_path_prefix"
+    exit -1
+fi
+
+ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+echo "using $ngpu gpus..."
+
+device=gpu
+if [ ${ngpu} == 0 ];then
+    device=cpu
+fi
+config_path=$1
+ckpt_prefix=$2
+
+
+python3 -u ${BIN_DIR}/test.py \
+--device ${device} \
+--nproc 1 \
+--config ${config_path} \
+--result_file ${ckpt_prefix}.rsl \
+--checkpoint_path ${ckpt_prefix}
+
+if [ $? -ne 0 ]; then
+    echo "Failed in evaluation!"
+    exit 1
+fi
+
+exit 0
diff --git a/text_processing/examples/punctuation_restoration/chinese/local/train.sh b/text_processing/examples/punctuation_restoration/chinese/local/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f6bd2c98359b67292f9654c69ed11fc1a6720046
--- /dev/null
+++ b/text_processing/examples/punctuation_restoration/chinese/local/train.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+if [ $# != 2 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+    exit -1
+fi
+
+ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+echo "using $ngpu gpus..."
+
+config_path=$1
+ckpt_name=$2
+
+device=gpu
+if [ ${ngpu} == 0 ];then
+    device=cpu
+fi
+
+mkdir -p exp
+
+python3 -u ${BIN_DIR}/train.py \
+--device ${device} \
+--nproc ${ngpu} \
+--config ${config_path} \
+--output exp/${ckpt_name}
+
+if [ $? -ne 0 ]; then
+    echo "Failed in training!"
+    exit 1
+fi
+
+exit 0
diff --git a/text_processing/examples/punctuation_restoration/chinese/path.sh b/text_processing/examples/punctuation_restoration/chinese/path.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8154cc78f89edeacc184d26fac1eaf579b43764b
--- /dev/null
+++ b/text_processing/examples/punctuation_restoration/chinese/path.sh
@@ -0,0 +1,13 @@
+export MAIN_ROOT=${PWD}/../../../
+
+export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
+export LC_ALL=C
+
+# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
+
+export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
+
+
+export BIN_DIR=${MAIN_ROOT}/speechtask/punctuation_restoration/bin
diff --git a/text_processing/examples/punctuation_restoration/chinese/run.sh b/text_processing/examples/punctuation_restoration/chinese/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..bb3d25d4b73ca3d023054f6f0598494a69e488e6
--- /dev/null
+++ b/text_processing/examples/punctuation_restoration/chinese/run.sh
@@ -0,0 +1,47 @@
+#!/bin/bash
+set -e
+source path.sh
+
+
+## stage, gpu, data_pre_config, train_config, avg_num
+if [ $# -lt 4 ]; then
+    echo "usage: bash ./run.sh stage gpu train_config avg_num data_config"
+    echo "eg: bash ./run.sh 0 0 train_config 1 data_config "
+    exit -1
+fi
+
+stage=$1
+stop_stage=100
+gpus=$2
+conf_path=$3
+avg_num=$4
+avg_ckpt=avg_${avg_num}
+ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
+echo "checkpoint name ${ckpt}"
+
+if [ $stage -le 0 ]; then
+    if [ $# -eq 5 ]; then
+        data_pre_conf=$5
+        # prepare data
+        bash ./local/data.sh ${data_pre_conf} || exit -1
+    else
+        echo "data_pre_conf does not exist!"
+        exit -1
+    fi
+fi
+
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # train model, all `ckpt` under `exp` dir
+    CUDA_VISIBLE_DEVICES=${gpus} bash ./local/train.sh ${conf_path} ${ckpt}
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # avg n best model
+    bash ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num}
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    # test ckpt avg_n
+    CUDA_VISIBLE_DEVICES=${gpus} bash ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+fi
diff --git a/text_processing/examples/punctuation_restoration/english/README.md b/text_processing/examples/punctuation_restoration/english/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7955bb7d597a4c3492dc078a38b9e05eba45001b
--- /dev/null
+++ b/text_processing/examples/punctuation_restoration/english/README.md
@@ -0,0 +1,23 @@
+# English example recipe
+## Test data:
+- IWSLT2012 English: test2011
+
+## How to run
+- Run `run.sh 0 0 conf/train_conf/bertBLSTM_base_en.yaml 1 conf/data_conf/english.yaml`
+
+
+## Results reported in related work:
+> * Nagy, Attila, Bence Bial, and Judit Ács. "Automatic punctuation restoration with BERT models." arXiv preprint arXiv:2101.07343 (2021)*
+>
+
+
+## Results:
+- BertBLSTM
+  - Config: conf/train_conf/bertLinear_en.yaml
+  - Test results: exp/bertLinear_enRe/checkpoints/3.pdparams
+
+    |           | COMMA    | PERIOD   | QUESTION | OVERALL  |
+    |-----------|----------|----------|----------|----------|
+    |Precision  |0.667910  |0.715778  |0.822222  |0.735304  |
+    |Recall     |0.755274  |0.868188  |0.804348  |0.809270  |
+    |F1         |0.708911  |0.784651  |0.813187  |0.768916  |
diff --git a/text_processing/examples/punctuation_restoration/english/conf/data_conf/english.yaml b/text_processing/examples/punctuation_restoration/english/conf/data_conf/english.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..44834f28c3e78e6d4ade7c9b203d78a45d85019b
--- /dev/null
+++ b/text_processing/examples/punctuation_restoration/english/conf/data_conf/english.yaml
@@ -0,0 +1,7 @@
+type: english
+raw_path: /data4/mahaoxin/PaddleSpeechTask/data/english/iwslt2012_en  # path to raw dataset
+raw_train_file: iwslt2012_train_en
+raw_dev_file: iwslt2010_dev_en
+raw_test_file: iwslt2011_test_en
+punc_file: punc_vocab
+save_path: data/iwslt2012_en  # path to save dataset
\ No newline at end of file
diff --git a/text_processing/examples/punctuation_restoration/english/conf/train_conf/bertBLSTM_base_en.yaml b/text_processing/examples/punctuation_restoration/english/conf/train_conf/bertBLSTM_base_en.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7f4383d481eca71c4e37acf4fcd9b684e923d3d8
--- /dev/null
+++ b/text_processing/examples/punctuation_restoration/english/conf/train_conf/bertBLSTM_base_en.yaml
@@ -0,0 +1,47 @@
+data:
+  dataset_type: Bert
+  train_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/train
+  dev_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/dev
+  test_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/test2011
+  data_params:
+    pretrained_token: bert-base-uncased  # english
+    punc_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/punc_vocab
+    seq_len: 50
+  batch_size: 32
+  sortagrad: True
+  shuffle_method: batch_shuffle
+  num_workers: 0
+
+checkpoint:
+  kbest_n: 10
+  latest_n: 10
+
+model_type: BertBLSTM
+model_params:
+  pretrained_token: bert-base-uncased
+  output_size: 4
+  dropout: 0.0
+  bert_size: 768
+  blstm_size: 128
+  num_blstm_layers: 2
+  init_scale: 0.2
+# model_type: BertChLinear
+# model_params:
+#   pretrained_token: bert-large-uncased
+#   output_size: 4
+#   dropout: 0.0
+#   bert_size: 768
+
+training:
+  n_epoch: 100
+  lr: !!float 1e-5
+  lr_decay: 1.0
+  weight_decay: !!float 1e-06
+  global_grad_clip: 5.0
+  log_interval: 10
+  log_path: log/bertBLSTM_base0812.log
+
+testing:
+  log_path: log/testbertBLSTM_base0812.log
+
+
diff --git a/text_processing/examples/punctuation_restoration/english/conf/train_conf/bertLinear_en.yaml b/text_processing/examples/punctuation_restoration/english/conf/train_conf/bertLinear_en.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8cac9889477f2e45c559cc18517b2b5773984500
--- /dev/null
+++ b/text_processing/examples/punctuation_restoration/english/conf/train_conf/bertLinear_en.yaml
@@ -0,0 +1,39 @@
+data:
+  dataset_type: Bert
+  train_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/train
+  dev_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/dev
+  test_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/test2011
+  data_params:
+    pretrained_token: bert-base-uncased  # english
+    punc_path: /data4/mahaoxin/PaddleSpeechTask/examples/punctuation_restoration/english/data/iwslt2012_en/punc_vocab
+    seq_len: 100
+  batch_size: 32
+  sortagrad: True
+  shuffle_method: batch_shuffle
+  num_workers: 0
+
+checkpoint:
+  kbest_n: 10
+  latest_n: 10
+
+model_type: BertLinear
+model_params:
+  pretrained_token: bert-base-uncased
+  output_size: 4
+  dropout: 0.2
+  bert_size: 768
+  hiddensize: 1568
+
+training:
+  n_epoch: 20
+  lr: !!float 1e-5
+  lr_decay: 1.0
+  weight_decay: !!float 1e-06
+  global_grad_clip: 3.0
+  log_interval: 10
+  log_path: log/train_linear0820.log
+
+testing:
+  log_path: log/test2011_linear0820.log
+
+
diff --git a/text_processing/examples/punctuation_restoration/english/local/avg.sh b/text_processing/examples/punctuation_restoration/english/local/avg.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b8c14c6623cc30ca167cfb26fcc9ade349cad288
--- /dev/null
+++ b/text_processing/examples/punctuation_restoration/english/local/avg.sh
@@ -0,0 +1,23 @@
+#! /usr/bin/env bash
+
+if [ $# != 2 ]; then
+    echo "usage: ${0} ckpt_dir avg_num"
+    exit -1
+fi
+
+ckpt_dir=${1}
+average_num=${2}
+decode_checkpoint=${ckpt_dir}/avg_${average_num}.pdparams
+
+python3 -u ${BIN_DIR}/avg_model.py \
+--dst_model ${decode_checkpoint} \
+--ckpt_dir ${ckpt_dir} \
+--num ${average_num} \
+--val_best
+
+if [ $? -ne 0 ]; then
+    echo "Failed in avg ckpt!"
+    exit 1
+fi
+
+exit 0
\ No newline at end of file
diff --git a/text_processing/examples/punctuation_restoration/english/local/data.sh b/text_processing/examples/punctuation_restoration/english/local/data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1b0c62b175aa46d3bc1544ef1e7fbec4db4d717f
--- /dev/null
+++ b/text_processing/examples/punctuation_restoration/english/local/data.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+
+if [ $# != 1 ];then
+    echo "usage: ${0} config_path"
+    exit -1
+fi
+
+config_path=$1
+
+python3 -u ${BIN_DIR}/pre_data.py \
+--config ${config_path}
+
+if [ $? -ne 0 ]; then
+    echo "Failed in data preparation!"
+    exit 1
+fi
+
+exit 0
diff --git a/text_processing/examples/punctuation_restoration/english/local/test.sh b/text_processing/examples/punctuation_restoration/english/local/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6db75ca2ad347f292443b1baf980430df00dd568
--- /dev/null
+++ b/text_processing/examples/punctuation_restoration/english/local/test.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+if [ $# != 2 ];then
+    echo "usage: ${0} config_path ckpt_path_prefix"
+    exit -1
+fi
+
+ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+echo "using $ngpu gpus..."
+
+device=gpu
+if [ ${ngpu} == 0 ];then
+    device=cpu
+fi
+config_path=$1
+ckpt_prefix=$2
+
+
+python3 -u ${BIN_DIR}/test.py \
+--device ${device} \
+--nproc 1 \
+--config ${config_path} \
+--result_file ${ckpt_prefix}.rsl \
+--checkpoint_path ${ckpt_prefix}
+
+if [ $? -ne 0 ]; then
+    echo "Failed in evaluation!"
+    exit 1
+fi
+
+exit 0
diff --git a/text_processing/examples/punctuation_restoration/english/local/train.sh b/text_processing/examples/punctuation_restoration/english/local/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f6bd2c98359b67292f9654c69ed11fc1a6720046
--- /dev/null
+++ b/text_processing/examples/punctuation_restoration/english/local/train.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+if [ $# != 2 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+    exit -1
+fi
+
+ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+echo "using $ngpu gpus..."
+
+config_path=$1
+ckpt_name=$2
+
+device=gpu
+if [ ${ngpu} == 0 ];then
+    device=cpu
+fi
+
+mkdir -p exp
+
+python3 -u ${BIN_DIR}/train.py \
+--device ${device} \
+--nproc ${ngpu} \
+--config ${config_path} \
+--output exp/${ckpt_name}
+
+if [ $? -ne 0 ]; then
+    echo "Failed in training!"
+    exit 1
+fi
+
+exit 0
diff --git a/text_processing/examples/punctuation_restoration/english/path.sh b/text_processing/examples/punctuation_restoration/english/path.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8154cc78f89edeacc184d26fac1eaf579b43764b
--- /dev/null
+++ b/text_processing/examples/punctuation_restoration/english/path.sh
@@ -0,0 +1,13 @@
+export MAIN_ROOT=${PWD}/../../../
+
+export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
+export LC_ALL=C
+
+# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
+
+export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
+
+
+export BIN_DIR=${MAIN_ROOT}/speechtask/punctuation_restoration/bin
diff --git a/text_processing/examples/punctuation_restoration/english/run.sh b/text_processing/examples/punctuation_restoration/english/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..bb3d25d4b73ca3d023054f6f0598494a69e488e6
--- /dev/null
+++ b/text_processing/examples/punctuation_restoration/english/run.sh
@@ -0,0 +1,47 @@
+#!/bin/bash
+set -e
+source path.sh
+
+
+## stage, gpu, data_pre_config, train_config, avg_num
+if [ $# -lt 4 ]; then
+    echo "usage: bash ./run.sh stage gpu train_config avg_num data_config"
+    echo "eg: bash ./run.sh 0 0 train_config 1 data_config "
+    exit -1
+fi
+
+stage=$1
+stop_stage=100
+gpus=$2
+conf_path=$3
+avg_num=$4
+avg_ckpt=avg_${avg_num}
+ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
+echo "checkpoint name ${ckpt}"
+
+if [ $stage -le 0 ]; then
+    if [ $# -eq 5 ]; then
+        data_pre_conf=$5
+        # prepare data
+        bash ./local/data.sh ${data_pre_conf} || exit -1
+    else
+        echo "data_pre_conf does not exist!"
+        exit -1
+    fi
+fi
+
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # train model, all `ckpt` under `exp` dir
+    CUDA_VISIBLE_DEVICES=${gpus} bash ./local/train.sh ${conf_path} ${ckpt}
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # avg n best model
+    bash ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num}
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    # test ckpt avg_n
+    CUDA_VISIBLE_DEVICES=${gpus} bash ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+fi
diff --git a/text_processing/requirements.txt b/text_processing/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..685ab029eda6011f567837387a0768202e5a8b0c
--- /dev/null
+++ b/text_processing/requirements.txt
@@ -0,0 +1,6 @@
+numpy
+pyyaml
+tensorboardX
+tqdm
+ujson
+yacs
diff --git a/text_processing/speechtask/punctuation_restoration/bin/avg_model.py b/text_processing/speechtask/punctuation_restoration/bin/avg_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a012e25816b0bdbec2c64d1bbe2aa67a6ce3ce3d
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/bin/avg_model.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import glob
+import json
+import os
+
+import numpy as np
+import paddle
+
+
+def main(args):
+    paddle.set_device('cpu')
+
+    val_scores = []
+    best_val_scores = np.array([])
+    selected_epochs = np.array([], dtype=np.int64)
+    if args.val_best:
+        jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json')
+        for y in jsons:
+            with open(y, 'r') as f:
+                dic_json = json.load(f)
+            score = dic_json['F1']
+            epoch = dic_json['epoch']
+            if epoch >= args.min_epoch and epoch <= args.max_epoch:
+                val_scores.append((epoch, score))
+
+        val_scores = np.array(val_scores)
+        # F1 is a higher-is-better metric, so sort in descending order
+        sort_idx = np.argsort(val_scores[:, 1])[::-1]
+        sorted_val_scores = val_scores[sort_idx]
+        path_list = [
+            args.ckpt_dir + '/{}.pdparams'.format(int(epoch))
+            for epoch in sorted_val_scores[:args.num, 0]
+        ]
+
+        best_val_scores = sorted_val_scores[:args.num, 1]
+        selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64)
+        print("best val scores = " + str(best_val_scores))
+        print("selected epochs = " + str(selected_epochs))
+    else:
+        path_list = glob.glob(f'{args.ckpt_dir}/[!avg][!final]*.pdparams')
+        path_list = sorted(path_list, key=os.path.getmtime)
+        path_list = path_list[-args.num:]
+
+    print(path_list)
+
+    avg = None
+    num = args.num
+    assert num == len(path_list)
+    for path in path_list:
+        print(f'Processing {path}')
+        states = paddle.load(path)
+        if avg is None:
+            avg = states
+        else:
+            for k in avg.keys():
+                avg[k] += states[k]
+    # average
+    for k in avg.keys():
+        if avg[k] is not None:
+            avg[k] /= num
+
+    paddle.save(avg, args.dst_model)
+    print(f'Saving to {args.dst_model}')
+
+    meta_path = os.path.splitext(args.dst_model)[0] + '.avg.json'
+    with open(meta_path, 'w') as f:
+        data = json.dumps({
+            "avg_ckpt": args.dst_model,
+            "ckpt": path_list,
+            "epoch": selected_epochs.tolist(),
+            "val_loss": best_val_scores.tolist(),
+        })
+        f.write(data + "\n")
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='average model')
+    parser.add_argument('--dst_model', required=True, help='averaged model')
+    parser.add_argument(
+        '--ckpt_dir', required=True, help='ckpt model dir for average')
+    parser.add_argument(
+        '--val_best',
+        action="store_true",
+        help='select checkpoints by best validation score')
+    parser.add_argument(
+        '--num', default=5, type=int, help='nums for averaged model')
+    parser.add_argument(
+        '--min_epoch',
+        default=0,
+        type=int,
+        help='min epoch used for averaging model')
+    parser.add_argument(
+        '--max_epoch',
+        default=65536,  # Big enough
+        type=int,
+        help='max epoch used for averaging model')
+
+    args = parser.parse_args()
+    print(args)
+
+    main(args)
diff --git a/text_processing/speechtask/punctuation_restoration/bin/pre_data.py b/text_processing/speechtask/punctuation_restoration/bin/pre_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..a074d7e3da85483c246b6e996c788444ec6d49a7
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/bin/pre_data.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Data preparation for punctuation_restoration task.""" +import yaml +from speechtask.punctuation_restoration.utils.default_parser import default_argument_parser +from speechtask.punctuation_restoration.utils.punct_pre import process_chinese_pure_senetence +from speechtask.punctuation_restoration.utils.punct_pre import process_english_pure_senetence +from speechtask.punctuation_restoration.utils.utility import print_arguments + + +# create dataset from raw data files +def main(config, args): + print("Start preparing data from raw data.") + if (config['type'] == 'chinese'): + process_chinese_pure_senetence(config) + elif (config['type'] == 'english'): + print('english!!!!') + process_english_pure_senetence(config) + else: + print('Error: Type should be chinese or english!!!!') + raise ValueError('Type should be chinese or english') + + print("Finish preparing data.") + + +if __name__ == "__main__": + parser = default_argument_parser() + args = parser.parse_args() + print_arguments(args, globals()) + + # https://yaml.org/type/float.html + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + # config.freeze() + print(config) + main(config, args) diff --git a/text_processing/speechtask/punctuation_restoration/bin/test.py b/text_processing/speechtask/punctuation_restoration/bin/test.py new file mode 100644 index 0000000000000000000000000000000000000000..17892fdb76e9a53df320a172368350ce6e826fb5 --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/bin/test.py @@ -0,0 +1,45 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Evaluation for model.""" +import yaml +from speechtask.punctuation_restoration.training.trainer import Tester +from speechtask.punctuation_restoration.utils.default_parser import default_argument_parser +from speechtask.punctuation_restoration.utils.utility import print_arguments + + +def main_sp(config, args): + exp = Tester(config, args) + exp.setup() + exp.run_test() + + +def main(config, args): + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + args = parser.parse_args() + print_arguments(args, globals()) + + # https://yaml.org/type/float.html + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + main(config, args) diff --git a/text_processing/speechtask/punctuation_restoration/bin/train.py b/text_processing/speechtask/punctuation_restoration/bin/train.py new file mode 100644 index 0000000000000000000000000000000000000000..1ffd79b7b192ce358d66104c1198212205a53040 --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/bin/train.py @@ -0,0 +1,49 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Trainer for punctuation_restoration task.""" +import yaml +from paddle import distributed as dist +from speechtask.punctuation_restoration.training.trainer import Trainer +from speechtask.punctuation_restoration.utils.default_parser import default_argument_parser +from speechtask.punctuation_restoration.utils.utility import print_arguments + + +def main_sp(config, args): + exp = Trainer(config, args) + exp.setup() + exp.run() + + +def main(config, args): + if args.device == "gpu" and args.nprocs > 1: + dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) + else: + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + args = parser.parse_args() + print_arguments(args, globals()) + + # https://yaml.org/type/float.html + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + main(config, args) diff --git a/text_processing/speechtask/punctuation_restoration/io/__init__.py b/text_processing/speechtask/punctuation_restoration/io/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..185a92b8d94d3426d616c0624f0f2ee04339349e --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/io/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/text_processing/speechtask/punctuation_restoration/io/collator.py b/text_processing/speechtask/punctuation_restoration/io/collator.py new file mode 100644 index 0000000000000000000000000000000000000000..5b63b5847f67ec7726ab909148b9a28eac7afea0 --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/io/collator.py @@ -0,0 +1,64 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
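All three `bin/` entry points follow the same driver shape: parse arguments, load the YAML config with `yaml.FullLoader` (the yaml.org/type/float.html comment refers to YAML 1.1 floats such as `1e-3` needing a dot to resolve as numbers), then dispatch to `main`. A minimal sketch of that pattern, with an inline config string standing in for the `--config` file:

```python
import yaml

# inline stand-in for the file passed via --config
raw = """
type: chinese
training:
  lr: 1.0e-3
  n_epoch: 20
"""

config = yaml.load(raw, Loader=yaml.FullLoader)
# 1.0e-3 (with the dot) resolves to a float under the YAML 1.1 rules
assert isinstance(config["training"]["lr"], float)

def main(config):
    print("would train with", config["training"])

main(config)
```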
+import numpy as np
+
+__all__ = ["TextCollator"]
+
+
+class TextCollator():
+    def __init__(self, padding_value):
+        self.padding_value = padding_value
+
+    def __call__(self, batch):
+        """Batch a list of (text, punctuation) examples.
+        Args:
+            batch (List[Tuple]): each item is (text, punctuation)
+                text (List[int]): token ids, length L
+                punctuation (List[int]): label ids, length L
+        Returns:
+            tuple(text, punctuation): batched data.
+            text : (B, Lmax)
+            punctuation : (B, Lmax)
+        """
+        texts = []
+        punctuations = []
+        for text, punctuation in batch:
+            texts.append(text)
+            punctuations.append(punctuation)
+
+        # pad token ids and labels to the longest sequence in the batch
+        x_pad = self.pad_sequence(texts).astype(np.int64)  # [B, Lmax]
+        y_pad = self.pad_sequence(punctuations).astype(np.int64)  # [B, Lmax]
+        return x_pad, y_pad
+
+    def pad_sequence(self, sequences):
+        # assuming trailing dimensions and type of all the Tensors
+        # in sequences are same and fetching those from sequences[0]
+        max_len = max([len(s) for s in sequences])
+        out_dims = (len(sequences), max_len)
+
+        out_tensor = np.full(out_dims, self.padding_value)
+        for i, tensor in enumerate(sequences):
+            length = len(tensor)
+            # use index notation to prevent duplicate references to the tensor
+            out_tensor[i, :length] = tensor
+
+        return out_tensor
diff --git a/text_processing/speechtask/punctuation_restoration/io/common.py b/text_processing/speechtask/punctuation_restoration/io/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ed4a6041ba7e532614911eb22c49c1e4e216768
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/io/common.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import codecs
+import re
+import unicodedata
+
+import ujson
+
+# special vocabulary symbols
+PAD = "<PAD>"
+UNK = "<UNK>"
+NUM = "<NUM>"
+END = "</S>"
+SPACE = "_SPACE"
+
+
+def write_json(filename, dataset):
+    with codecs.open(filename, mode="w", encoding="utf-8") as f:
+        ujson.dump(dataset, f)
+
+
+def word_convert(word, keep_number=True, lowercase=True):
+    if not keep_number:
+        if is_digit(word):
+            word = NUM
+    if lowercase:
+        word = word.lower()
+    return word
+
+
+def is_digit(word):
+    try:
+        float(word)
+        return True
+    except ValueError:
+        pass
+    try:
+        unicodedata.numeric(word)
+        return True
+    except (TypeError, ValueError):
+        pass
+    result = re.compile(r'^[-+]?[0-9]+,[0-9]+$').match(word)
+    if result:
+        return True
+    return False
diff --git a/text_processing/speechtask/punctuation_restoration/io/dataset.py b/text_processing/speechtask/punctuation_restoration/io/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..17c13c3877b7aefa9ca75053aa7edbec01f4ab7a
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/io/dataset.py
@@ -0,0 +1,310 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import random
+
+import numpy as np
+import paddle
+from paddle.io import Dataset
+from paddlenlp.transformers import BertTokenizer
+
+__all__ = ["PuncDataset", "PuncDatasetFromBertTokenizer"]
+
+
+class PuncDataset(Dataset):
+    """Punctuation dataset built from a plain-text token stream.
+    superclass
+    ----------
+    data.Dataset :
+        Dataset is an abstract class representing the real data.
+    """
+
+    def __init__(self, train_path, vocab_path, punc_path, seq_len=100):
+        # check that the input files exist
+        print(train_path)
+        print(vocab_path)
+        assert os.path.exists(train_path), "train file does not exist"
+        assert os.path.exists(vocab_path), "vocab file does not exist"
+        assert os.path.exists(punc_path), "punctuation file does not exist"
+        self.seq_len = seq_len
+
+        self.word2id = self.load_vocab(
+            vocab_path, extra_word_list=['<UNK>', '<END>'])
+        self.id2word = {v: k for k, v in self.word2id.items()}
+        self.punc2id = self.load_vocab(punc_path, extra_word_list=[" "])
+        self.id2punc = {k: v for (v, k) in self.punc2id.items()}
+
+        with open(train_path, encoding='utf-8') as f:
+            tmp_seqs = f.readlines()
+        self.txt_seqs = [i for seq in tmp_seqs for i in seq.split()]
+        self.preprocess(self.txt_seqs)
+        print('punc2id:')
+        print(self.punc2id)
+
+    def __len__(self):
+        """Return the number of fixed-length sequences built from the text.
+        """
+        return self.in_len
+
+    def __getitem__(self, index):
+        """Return the (token id sequence, punctuation id sequence) pair
+        at the given index.
+        Parameters
+        ----------
+        index : int
+        """
+        return self.input_data[index], self.label[index]
+
+    def load_vocab(self, vocab_path, extra_word_list=[], encoding='utf-8'):
+        n = len(extra_word_list)
+        with open(vocab_path, encoding='utf-8') as vf:
+            vocab = {word.strip(): i + n for i, word in enumerate(vf)}
+        for i, word in enumerate(extra_word_list):
+            vocab[word] = i
+        return vocab
+
+    def preprocess(self, txt_seqs: list):
+        """Convert the token stream into (word id, punctuation id) pairs.
+        Parameters
+        ----------
+        txt_seqs : the tokenized text; every word is followed by a space,
+            and every punctuation mark is also followed by a space
+        """
+        input_data = []
+        label = []
+        input_r = []
+        label_r = []
+        # txt_seqs is a list like: ['char', 'char', 'char', '*,*', 'char', ......]
+        count = 0
+        length = len(txt_seqs)
+        for token in txt_seqs:
+            count += 1
+            if count == length:
+                break
+            if token in self.punc2id:
+                continue
+            punc = txt_seqs[count]
+            if punc not in self.punc2id:
+                label.append(self.punc2id[" "])
+                input_data.append(
+                    self.word2id.get(token, self.word2id['<UNK>']))
+                input_r.append(token)
+                label_r.append(' ')
+            else:
+                label.append(self.punc2id[punc])
+                input_data.append(
+                    self.word2id.get(token, self.word2id['<UNK>']))
+                input_r.append(token)
+                label_r.append(punc)
+        assert len(input_data) == len(
+            label), 'error: length of input_data != length of label'
+        # truncate so the stream can be reshaped into fixed-length
+        # (seq_len) chunks
+        self.in_len = len(input_data) // self.seq_len
+        len_tmp = self.in_len * self.seq_len
+        input_data = input_data[:len_tmp]
+        label = label[:len_tmp]
+
+        self.input_data = paddle.to_tensor(
+            np.array(input_data, dtype='int64').reshape(-1, self.seq_len))
+        self.label = paddle.to_tensor(
+            np.array(label, dtype='int64').reshape(-1, self.seq_len))
+
+
+# unk_token='[UNK]'
+# sep_token='[SEP]'
+# pad_token='[PAD]'
+# cls_token='[CLS]'
+# mask_token='[MASK]'
+
+
+class PuncDatasetFromBertTokenizer(Dataset):
+    """Punctuation dataset tokenized with a pretrained BERT tokenizer.
+    superclass
+    ----------
+    data.Dataset :
+        Dataset is an abstract class representing the real data.
+    """
+
+    def __init__(self,
+                 train_path,
+                 is_eval,
+                 pretrained_token,
+                 punc_path,
+                 seq_len=100):
+        # check that the input files exist
+        print(train_path)
+        self.tokenizer = BertTokenizer.from_pretrained(
+            pretrained_token, do_lower_case=True)
+        self.paddingID = self.tokenizer.pad_token_id
+        assert os.path.exists(train_path), "train file does not exist"
+        assert os.path.exists(punc_path), "punctuation file does not exist"
+        self.seq_len = seq_len
+
+        self.punc2id = self.load_vocab(punc_path, extra_word_list=[" "])
+        self.id2punc = {k: v for (v, k) in self.punc2id.items()}
+
+        with open(train_path, encoding='utf-8') as f:
+            tmp_seqs = f.readlines()
+        self.txt_seqs = [i for seq in tmp_seqs for i in seq.split()]
+        if is_eval:
+            self.preprocess(self.txt_seqs)
+        else:
+            self.preprocess_shift(self.txt_seqs)
+        print("data len: %d" % (len(self.input_data)))
+        print('punc2id:')
+        print(self.punc2id)
+
+    def __len__(self):
+        """Return the number of sequences built from the text.
+        """
+        return self.in_len
+
+    def __getitem__(self, index):
+        """Return the (token id sequence, punctuation id sequence) pair
+        at the given index.
+        Parameters
+        ----------
+        index : int
+        """
+        return self.input_data[index], self.label[index]
+
+    def load_vocab(self, vocab_path, extra_word_list=[], encoding='utf-8'):
+        n = len(extra_word_list)
+        with open(vocab_path, encoding='utf-8') as vf:
+            vocab = {word.strip(): i + n for i, word in enumerate(vf)}
+        for i, word in enumerate(extra_word_list):
+            vocab[word] = i
+        return vocab
+
+    def preprocess(self, txt_seqs: list):
+        """Convert the token stream into (subword id, punctuation id) pairs.
+        Parameters
+        ----------
+        txt_seqs : the tokenized text; every word is followed by a space,
+            and every punctuation mark is also followed by a space
+        """
+        input_data = []
+        label = []
+        # txt_seqs is a list like: ['char', 'char', 'char', '*,*', 'char', ......]
+        for i in range(len(txt_seqs) - 1):
+            word = txt_seqs[i]
+            punc = txt_seqs[i + 1]
+            if word in self.punc2id:
+                continue
+
+            token = self.tokenizer(word)
+            x = token["input_ids"][1:-1]
+            input_data.extend(x)
+
+            # a word may split into several subword ids; only the last one
+            # carries the punctuation label
+            for _ in range(len(x) - 1):
+                label.append(self.punc2id[" "])
+
+            if punc not in self.punc2id:
+                label.append(self.punc2id[" "])
+            else:
+                label.append(self.punc2id[punc])
+
+        assert len(input_data) == len(
+            label), 'error: length of input_data != length of label'
+        # truncate so the stream can be reshaped into fixed-length
+        # (seq_len) chunks
+        self.in_len = len(input_data) // self.seq_len
+        len_tmp = self.in_len * self.seq_len
+        input_data = input_data[:len_tmp]
+        label = label[:len_tmp]
+        self.input_data = paddle.to_tensor(
+            np.array(input_data, dtype='int64').reshape(-1, self.seq_len))
+        self.label = paddle.to_tensor(
+            np.array(label, dtype='int64').reshape(-1, self.seq_len))
+
+    def preprocess_shift(self, txt_seqs: list):
+        """Like `preprocess`, but cut the stream into overlapping windows of
+        random length and random stride to augment the training data.
+        Parameters
+        ----------
+        txt_seqs : the tokenized text; every word is followed by a space,
+            and every punctuation mark is also followed by a space
+        """
+        input_data = []
+        label = []
+        for i in range(len(txt_seqs) - 1):
+            word = txt_seqs[i]
+            punc = txt_seqs[i + 1]
+            if word in self.punc2id:
+                continue
+
+            token = self.tokenizer(word)
+            x = token["input_ids"][1:-1]
+            input_data.extend(x)
+
+            for _ in range(len(x) - 1):
+                label.append(self.punc2id[" "])
+
+            if punc not in self.punc2id:
+                label.append(self.punc2id[" "])
+            else:
+                label.append(self.punc2id[punc])
+
+        assert len(input_data) == len(
+            label), 'error: length of input_data != length of label'
+
+        start = 0
+        processed_data = []
+        processed_label = []
+        while start < len(input_data) - self.seq_len:
+            # window end varies between seq_len // 2 and seq_len tokens
+            end = random.randint(start + self.seq_len // 2,
+                                 start + self.seq_len)
+            processed_data.append(input_data[start:end])
+            processed_label.append(label[start:end])
+
+            start = start + random.randint(1, self.seq_len // 2)
+
+        self.in_len = len(processed_data)
+        # variable-length windows: keep them as plain lists here and pad
+        # them to tensors in the collator
+        self.input_data = processed_data
+        self.label = processed_label
+
+
+if __name__ == '__main__':
+    # smoke test; needs real data files, e.g.:
+    # dataset = PuncDataset('train.txt', 'vocab.txt', 'punc.txt')
+    pass
diff --git a/text_processing/speechtask/punctuation_restoration/model/BertBLSTM.py b/text_processing/speechtask/punctuation_restoration/model/BertBLSTM.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc953adfdcc2e62e439a0ce8880d77adf9a6106f
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/model/BertBLSTM.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.nn as nn
+import paddle.nn.initializer as I
+from paddlenlp.transformers import BertForTokenClassification
+
+
+class BertBLSTMPunc(nn.Layer):
+    def __init__(self,
+                 pretrained_token="bert-large-uncased",
+                 output_size=4,
+                 dropout=0.0,
+                 bert_size=768,
+                 blstm_size=128,
+                 num_blstm_layers=2,
+                 init_scale=0.1):
+        super(BertBLSTMPunc, self).__init__()
+        self.output_size = output_size
+        # NOTE: the token-classification head is configured to emit
+        # bert_size-dim features (the BERT hidden size), which feed the BLSTM
+        self.bert = BertForTokenClassification.from_pretrained(
+            pretrained_token, num_classes=bert_size)
+
+        self.lstm = nn.LSTM(
+            input_size=bert_size,
+            hidden_size=blstm_size,
+            num_layers=num_blstm_layers,
+            direction="bidirect",
+            weight_ih_attr=paddle.ParamAttr(initializer=I.Uniform(
+                low=-init_scale, high=init_scale)),
+            weight_hh_attr=paddle.ParamAttr(initializer=I.Uniform(
+                low=-init_scale, high=init_scale)))
+
+        # the BLSTM is bidirectional, so the classifier input is blstm_size * 2
+        self.dropout = nn.Dropout(dropout)
+        self.fc = nn.Linear(blstm_size * 2, output_size)
+        self.softmax = nn.Softmax()
+
+    def forward(self, x):
+        x = self.bert(x)
+        y, (_, _) = self.lstm(x)
+        y = self.fc(self.dropout(y))
+        y = paddle.reshape(y, shape=[-1, self.output_size])
+        logit = self.softmax(y)
+        return y, logit
+
+
+if __name__ == '__main__':
+    print('start model')
+    model = BertBLSTMPunc()
+    x = paddle.randint(low=0, high=40, shape=[2, 5])
+    print(x)
+    y, logit = model(x)
diff --git a/text_processing/speechtask/punctuation_restoration/model/BertLinear.py b/text_processing/speechtask/punctuation_restoration/model/BertLinear.py
new file mode 100644
index 0000000000000000000000000000000000000000..854f522cff0de39eee13d71077de07d0e76f5bd0
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/model/BertLinear.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
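The `blstm_size * 2` input width of the classifier in `BertBLSTMPunc` comes from the bidirectional LSTM concatenating its forward and backward hidden states. A quick shape check with toy sizes (assumes a working paddle install):

```python
import paddle
import paddle.nn as nn

lstm = nn.LSTM(input_size=16, hidden_size=8, num_layers=2, direction="bidirect")
x = paddle.randn([4, 10, 16])  # (batch, time, features)
y, (h, c) = lstm(x)
print(y.shape)  # [4, 10, 16]: the last dim is hidden_size * 2
```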
+import paddle
+import paddle.nn as nn
+from paddlenlp.transformers import BertForTokenClassification
+
+
+class BertLinearPunc(nn.Layer):
+    def __init__(self,
+                 pretrained_token="bert-base-uncased",
+                 output_size=4,
+                 dropout=0.2,
+                 bert_size=768,
+                 hiddensize=1568):
+        super(BertLinearPunc, self).__init__()
+        self.output_size = output_size
+        # NOTE: the token-classification head is configured to emit
+        # bert_size-dim features (the BERT hidden size)
+        self.bert = BertForTokenClassification.from_pretrained(
+            pretrained_token, num_classes=bert_size)
+
+        self.dropout1 = nn.Dropout(dropout)
+        self.fc1 = nn.Linear(bert_size, hiddensize)
+        self.dropout2 = nn.Dropout(dropout)
+        self.relu = nn.ReLU()
+        self.fc2 = nn.Linear(hiddensize, output_size)
+        self.softmax = nn.Softmax()
+
+    def forward(self, x):
+        x = self.bert(x)
+        x = self.fc1(self.dropout1(x))
+        x = self.fc2(self.relu(self.dropout2(x)))
+        x = paddle.reshape(x, shape=[-1, self.output_size])
+        logit = self.softmax(x)
+        return x, logit
+
+
+if __name__ == '__main__':
+    print('start model')
+    model = BertLinearPunc()
+    x = paddle.randint(low=0, high=40, shape=[2, 5])
+    print(x)
+    y, logit = model(x)
diff --git a/text_processing/speechtask/punctuation_restoration/model/blstm.py b/text_processing/speechtask/punctuation_restoration/model/blstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcfd31a3e8191526b81610287e9bc95d7ccfc60c
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/model/blstm.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.nn as nn
+import paddle.nn.initializer as I
+
+
+class BiLSTM(nn.Layer):
+    """LSTM for Punctuation Restoration
+    """
+
+    def __init__(self,
+                 vocab_size,
+                 embedding_size,
+                 hidden_size,
+                 num_layers,
+                 num_class,
+                 init_scale=0.1):
+        super(BiLSTM, self).__init__()
+        # hyper parameters
+        self.vocab_size = vocab_size
+        self.embedding_size = embedding_size
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.num_class = num_class
+
+        # layers of the network
+        self.embedding = nn.Embedding(
+            vocab_size,
+            embedding_size,
+            weight_attr=paddle.ParamAttr(initializer=I.Uniform(
+                low=-init_scale, high=init_scale)))
+        self.lstm = nn.LSTM(
+            input_size=embedding_size,
+            hidden_size=hidden_size,
+            num_layers=num_layers,
+            direction="bidirect",
+            weight_ih_attr=paddle.ParamAttr(initializer=I.Uniform(
+                low=-init_scale, high=init_scale)),
+            weight_hh_attr=paddle.ParamAttr(initializer=I.Uniform(
+                low=-init_scale, high=init_scale)))
+        # The LSTM is bidirectional, so the linear layer input is
hidden_size * 2.
+        self.fc = nn.Linear(
+            in_features=hidden_size * 2,
+            out_features=num_class,
+            weight_attr=paddle.ParamAttr(initializer=I.Uniform(
+                low=-init_scale, high=init_scale)),
+            bias_attr=paddle.ParamAttr(initializer=I.Uniform(
+                low=-init_scale, high=init_scale)))
+
+        self.softmax = nn.Softmax()
+
+    def forward(self, inputs):
+        """The forward process of the network.
+        Parameters
+        ----------
+        inputs : tensor
+            Training data, batch first
+        """
+        # all sequences in a batch are padded to the same length,
+        # so no packing is needed here
+        embedding = self.embedding(inputs)
+        y, (_, _) = self.lstm(embedding)
+        y = self.fc(y)
+        y = paddle.reshape(y, shape=[-1, self.num_class])
+        logit = self.softmax(y)
+        return y, logit
diff --git a/text_processing/speechtask/punctuation_restoration/model/lstm.py b/text_processing/speechtask/punctuation_restoration/model/lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ec6853372a2b5e21b95ce1c9271c4a2d1694982
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/model/lstm.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
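The datasets earlier (`PuncDataset.preprocess` and the BERT-tokenized variant) all build training pairs the same way: walk the token stream and give each word the id of the punctuation token that follows it, or the blank label otherwise. A pure-Python sketch of that alignment (toy vocabulary; `<UNK>` handling omitted):

```python
tokens = ["hello", "world", ",", "how", "are", "you", "?"]
punc2id = {" ": 0, ",": 1, "?": 2}

words, labels = [], []
for i, tok in enumerate(tokens):
    if tok in punc2id:          # punctuation itself is never an input word
        continue
    nxt = tokens[i + 1] if i + 1 < len(tokens) else " "
    words.append(tok)
    labels.append(punc2id.get(nxt, punc2id[" "]))

print(list(zip(words, labels)))
# [('hello', 0), ('world', 1), ('how', 0), ('are', 0), ('you', 2)]
```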
+import paddle +import paddle.nn as nn +import paddle.nn.initializer as I + + +class RnnLm(nn.Layer): + def __init__(self, + vocab_size, + punc_size, + hidden_size, + num_layers=1, + init_scale=0.1, + dropout=0.0): + super(RnnLm, self).__init__() + self.hidden_size = hidden_size + self.num_layers = num_layers + self.init_scale = init_scale + self.punc_size = punc_size + + self.embedder = nn.Embedding( + vocab_size, + hidden_size, + weight_attr=paddle.ParamAttr(initializer=I.Uniform( + low=-init_scale, high=init_scale))) + + self.lstm = nn.LSTM( + input_size=hidden_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + weight_ih_attr=paddle.ParamAttr(initializer=I.Uniform( + low=-init_scale, high=init_scale)), + weight_hh_attr=paddle.ParamAttr(initializer=I.Uniform( + low=-init_scale, high=init_scale))) + + self.fc = nn.Linear( + hidden_size, + punc_size, + weight_attr=paddle.ParamAttr(initializer=I.Uniform( + low=-init_scale, high=init_scale)), + bias_attr=paddle.ParamAttr(initializer=I.Uniform( + low=-init_scale, high=init_scale))) + + self.dropout = nn.Dropout(p=dropout) + self.softmax = nn.Softmax() + + def forward(self, inputs): + x = inputs + x_emb = self.embedder(x) + x_emb = self.dropout(x_emb) + + y, (_, _) = self.lstm(x_emb) + + y = self.dropout(y) + y = self.fc(y) + y = paddle.reshape(y, shape=[-1, self.punc_size]) + logit = self.softmax(y) + return y, logit + + +class CrossEntropyLossForLm(nn.Layer): + def __init__(self): + super(CrossEntropyLossForLm, self).__init__() + + def forward(self, y, label): + label = paddle.unsqueeze(label, axis=2) + loss = paddle.nn.functional.cross_entropy( + input=y, label=label, reduction='none') + loss = paddle.squeeze(loss, axis=[2]) + loss = paddle.mean(loss, axis=[0]) + loss = paddle.sum(loss) + return loss diff --git a/text_processing/speechtask/punctuation_restoration/modules/__init__.py b/text_processing/speechtask/punctuation_restoration/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..185a92b8d94d3426d616c0624f0f2ee04339349e --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/modules/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/text_processing/speechtask/punctuation_restoration/modules/activation.py b/text_processing/speechtask/punctuation_restoration/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..6a13e4aabff5ec120aab9e82d6b9b07f782abc0d --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/modules/activation.py @@ -0,0 +1,141 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import OrderedDict
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+
+__all__ = ["get_activation", "brelu", "GLU", "LinearGLUBlock", "ConvGLUBlock"]
+
+
+def brelu(x, t_min=0.0, t_max=24.0, name=None):
+    # paddle.to_tensor is dygraph_only so it can not work under JIT;
+    # build the clipping bounds with paddle.full instead
+    t_min = paddle.full(shape=[1], fill_value=t_min, dtype='float32')
+    t_max = paddle.full(shape=[1], fill_value=t_max, dtype='float32')
+    return x.maximum(t_min).minimum(t_max)
+
+
+class GLU(nn.Layer):
+    """Gated Linear Unit as a layer, wrapping paddle.nn.functional.glu."""
+
+    def __init__(self, axis: int=-1):
+        super().__init__()
+        self.axis = axis
+
+    def forward(self, xs):
+        return F.glu(xs, axis=self.axis)
+
+
+class LinearGLUBlock(nn.Layer):
+    """A linear Gated Linear Units (GLU) block."""
+
+    def __init__(self, idim: int):
+        """ GLU.
+        Args:
+            idim (int): input and output dimension
+        """
+        super().__init__()
+        self.fc = nn.Linear(idim, idim * 2)
+
+    def forward(self, xs):
+        return F.glu(self.fc(xs), axis=-1)
+
+
+class ConvGLUBlock(nn.Layer):
+    def __init__(self, kernel_size, in_ch, out_ch, bottleneck_dim=0,
+                 dropout=0.):
+        """A convolutional Gated Linear Units (GLU) block.
+
+        Args:
+            kernel_size (int): kernel size
+            in_ch (int): number of input channels
+            out_ch (int): number of output channels
+            bottleneck_dim (int): dimension of the bottleneck layers for computational efficiency. Defaults to 0.
+            dropout (float): dropout probability. Defaults to 0..
+        """
+
+        super().__init__()
+
+        self.conv_residual = None
+        if in_ch != out_ch:
+            self.conv_residual = nn.utils.weight_norm(
+                nn.Conv2D(
+                    in_channels=in_ch, out_channels=out_ch, kernel_size=(1, 1)),
+                name='weight',
+                dim=0)
+            self.dropout_residual = nn.Dropout(p=dropout)
+
+        # causal padding: left-pad the time axis with zeros
+        # ([pad_left, pad_right, pad_top, pad_bottom] for NCHW)
+        self.pad_left = nn.Pad2D([0, 0, kernel_size - 1, 0], value=0.0)
+
+        layers = OrderedDict()
+        if bottleneck_dim == 0:
+            layers['conv'] = nn.utils.weight_norm(
+                nn.Conv2D(
+                    in_channels=in_ch,
+                    out_channels=out_ch * 2,
+                    kernel_size=(kernel_size, 1)),
+                name='weight',
+                dim=0)
+            # TODO(hirofumi0810): padding?
+            layers['dropout'] = nn.Dropout(p=dropout)
+            layers['glu'] = GLU()
+
+        elif bottleneck_dim > 0:
+            layers['conv_in'] = nn.utils.weight_norm(
+                nn.Conv2D(
+                    in_channels=in_ch,
+                    out_channels=bottleneck_dim,
+                    kernel_size=(1, 1)),
+                name='weight',
+                dim=0)
+            layers['dropout_in'] = nn.Dropout(p=dropout)
+            layers['conv_bottleneck'] = nn.utils.weight_norm(
+                nn.Conv2D(
+                    in_channels=bottleneck_dim,
+                    out_channels=bottleneck_dim,
+                    kernel_size=(kernel_size, 1)),
+                name='weight',
+                dim=0)
+            layers['dropout'] = nn.Dropout(p=dropout)
+            layers['glu'] = GLU()
+            layers['conv_out'] = nn.utils.weight_norm(
+                nn.Conv2D(
+                    in_channels=bottleneck_dim,
+                    out_channels=out_ch * 2,
+                    kernel_size=(1, 1)),
+                name='weight',
+                dim=0)
+            layers['dropout_out'] = nn.Dropout(p=dropout)
+
+        # paddle's Sequential takes (name, layer) tuples rather than an
+        # OrderedDict
+        self.layers = nn.Sequential(*layers.items())
+
+    def forward(self, xs):
+        """Forward pass.
+ Args: + xs (FloatTensor): `[B, in_ch, T, feat_dim]` + Returns: + out (FloatTensor): `[B, out_ch, T, feat_dim]` + """ + residual = xs + if self.conv_residual is not None: + residual = self.dropout_residual(self.conv_residual(residual)) + xs = self.pad_left(xs) # `[B, embed_dim, T+kernel-1, 1]` + xs = self.layers(xs) # `[B, out_ch * 2, T ,1]` + xs = xs + residual + return xs + + +def get_activation(act): + """Return activation function.""" + # Lazy load to avoid unused import + activation_funcs = { + "hardtanh": paddle.nn.Hardtanh, + "tanh": paddle.nn.Tanh, + "relu": paddle.nn.ReLU, + "selu": paddle.nn.SELU, + "swish": paddle.nn.Swish, + "gelu": paddle.nn.GELU, + "brelu": brelu, + } + + return activation_funcs[act]() diff --git a/text_processing/speechtask/punctuation_restoration/modules/attention.py b/text_processing/speechtask/punctuation_restoration/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..1a7363c4d87336e16c026c53c132ac00ca3edcce --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/modules/attention.py @@ -0,0 +1,229 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Multi-Head Attention layer definition.""" +import math +from typing import Optional +from typing import Tuple + +import paddle +from paddle import nn +from paddle.nn import initializer as I + +__all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention"] + +# Relative Positional Encodings +# https://www.jianshu.com/p/c0608efcc26f +# https://zhuanlan.zhihu.com/p/344604604 + + +class MultiHeadedAttention(nn.Layer): + """Multi-Head Attention layer.""" + + def __init__(self, n_head: int, n_feat: int, dropout_rate: float): + """Construct an MultiHeadedAttention object. + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + """ + super().__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.dropout = nn.Dropout(p=dropout_rate) + + def forward_qkv(self, + query: paddle.Tensor, + key: paddle.Tensor, + value: paddle.Tensor + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Transform query, key and value. + Args: + query (paddle.Tensor): Query tensor (#batch, time1, size). + key (paddle.Tensor): Key tensor (#batch, time2, size). + value (paddle.Tensor): Value tensor (#batch, time2, size). + Returns: + paddle.Tensor: Transformed query tensor, size + (#batch, n_head, time1, d_k). + paddle.Tensor: Transformed key tensor, size + (#batch, n_head, time2, d_k). + paddle.Tensor: Transformed value tensor, size + (#batch, n_head, time2, d_k). 
+ """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) + k = k.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k) + v = v.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k) + + return q, k, v + + def forward_attention(self, + value: paddle.Tensor, + scores: paddle.Tensor, + mask: Optional[paddle.Tensor]) -> paddle.Tensor: + """Compute attention context vector. + Args: + value (paddle.Tensor): Transformed value, size + (#batch, n_head, time2, d_k). + scores (paddle.Tensor): Attention score, size + (#batch, n_head, time1, time2). + mask (paddle.Tensor): Mask, size (#batch, 1, time2) or + (#batch, time1, time2). + Returns: + paddle.Tensor: Transformed value weighted + by the attention score, (#batch, time1, d_model). + """ + n_batch = value.size(0) + if mask is not None: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + scores = scores.masked_fill(mask, -float('inf')) + attn = paddle.softmax( + scores, axis=-1).masked_fill(mask, + 0.0) # (batch, head, time1, time2) + else: + attn = paddle.softmax( + scores, axis=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(attn) + x = paddle.matmul(p_attn, value) # (batch, head, time1, d_k) + x = x.transpose([0, 2, 1, 3]).contiguous().view( + n_batch, -1, self.h * self.d_k) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, + query: paddle.Tensor, + key: paddle.Tensor, + value: paddle.Tensor, + mask: Optional[paddle.Tensor]) -> paddle.Tensor: + """Compute scaled dot product attention. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + """ + q, k, v = self.forward_qkv(query, key, value) + scores = paddle.matmul(q, + k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask) + + +class RelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding.""" + + def __init__(self, n_head, n_feat, dropout_rate): + """Construct an RelPositionMultiHeadedAttention object. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. 
+ """ + super().__init__(n_head, n_feat, dropout_rate) + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + #self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + #self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + #torch.nn.init.xavier_uniform_(self.pos_bias_u) + #torch.nn.init.xavier_uniform_(self.pos_bias_v) + pos_bias_u = self.create_parameter( + [self.h, self.d_k], default_initializer=I.XavierUniform()) + self.add_parameter('pos_bias_u', pos_bias_u) + pos_bias_v = self.create_parameter( + (self.h, self.d_k), default_initializer=I.XavierUniform()) + self.add_parameter('pos_bias_v', pos_bias_v) + + def rel_shift(self, x, zero_triu: bool=False): + """Compute relative positinal encoding. + Args: + x (paddle.Tensor): Input tensor (batch, head, time1, time1). + zero_triu (bool): If true, return the lower triangular part of + the matrix. + Returns: + paddle.Tensor: Output tensor. (batch, head, time1, time1) + """ + zero_pad = paddle.zeros( + (x.size(0), x.size(1), x.size(2), 1), dtype=x.dtype) + x_padded = paddle.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(x.size(0), x.size(1), x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1] + + if zero_triu: + ones = paddle.ones((x.size(2), x.size(3))) + x = x * paddle.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, + query: paddle.Tensor, + key: paddle.Tensor, + value: paddle.Tensor, + pos_emb: paddle.Tensor, + mask: Optional[paddle.Tensor]): + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (paddle.Tensor): Query tensor (#batch, time1, size). + key (paddle.Tensor): Key tensor (#batch, time2, size). + value (paddle.Tensor): Value tensor (#batch, time2, size). + pos_emb (paddle.Tensor): Positional embedding tensor + (#batch, time1, size). + mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + Returns: + paddle.Tensor: Output tensor (#batch, time1, d_model). + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3]) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3]) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2])) + + # compute matrix b and matrix d + # (batch, head, time1, time2) + matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2])) + # Remove rel_shift since it is useless in speech recognition, + # and it requires special attention for streaming. 
+        # matrix_bd = self.rel_shift(matrix_bd)
+
+        scores = (matrix_ac + matrix_bd) / math.sqrt(
+            self.d_k)  # (batch, head, time1, time2)
+
+        return self.forward_attention(v, scores, mask)
diff --git a/text_processing/speechtask/punctuation_restoration/modules/crf.py b/text_processing/speechtask/punctuation_restoration/modules/crf.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a53ae6f8c94ace0d066cfc62c4ef7c2ac412718
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/modules/crf.py
@@ -0,0 +1,366 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+from paddle import nn
+
+__all__ = ['CRF']
+
+
+class CRF(nn.Layer):
+    """
+    Linear-chain Conditional Random Field (CRF).
+
+    Args:
+        nb_labels (int): number of labels in your tagset, including special symbols.
+        bos_tag_id (int): integer representing the beginning of sentence symbol in
+            your tagset.
+        eos_tag_id (int): integer representing the end of sentence symbol in your tagset.
+        pad_tag_id (int, optional): integer representing the pad symbol in your tagset.
+            If None, the model will treat the PAD as a normal tag. Otherwise, the model
+            will apply constraints for PAD transitions.
+        batch_first (bool): Whether the first dimension represents the batch dimension.
+    """
+
+    def __init__(self,
+                 nb_labels: int,
+                 bos_tag_id: int,
+                 eos_tag_id: int,
+                 pad_tag_id: int=None,
+                 batch_first: bool=True):
+        super().__init__()
+
+        self.nb_labels = nb_labels
+        self.BOS_TAG_ID = bos_tag_id
+        self.EOS_TAG_ID = eos_tag_id
+        self.PAD_TAG_ID = pad_tag_id
+        self.batch_first = batch_first
+
+        # initialize transitions from a random uniform distribution between -0.1 and 0.1
+        self.transitions = self.create_parameter(
+            [self.nb_labels, self.nb_labels],
+            default_initializer=nn.initializer.Uniform(-0.1, 0.1))
+        self.init_weights()
+
+    def init_weights(self):
+        # enforce constraints (rows=from, columns=to) with a big negative number
+        # so exp(-10000) will tend to zero
+
+        # no transitions allowed to the beginning of sentence
+        self.transitions[:, self.BOS_TAG_ID] = -10000.0
+        # no transitions allowed from the end of sentence
+        self.transitions[self.EOS_TAG_ID, :] = -10000.0
+
+        if self.PAD_TAG_ID is not None:
+            # no transitions from padding
+            self.transitions[self.PAD_TAG_ID, :] = -10000.0
+            # no transitions to padding
+            self.transitions[:, self.PAD_TAG_ID] = -10000.0
+            # except if the end of sentence is reached
+            # or we are already in a pad position
+            self.transitions[self.PAD_TAG_ID, self.EOS_TAG_ID] = 0.0
+            self.transitions[self.PAD_TAG_ID, self.PAD_TAG_ID] = 0.0
+
+    def forward(self,
+                emissions: paddle.Tensor,
+                tags: paddle.Tensor,
+                mask: paddle.Tensor=None) -> paddle.Tensor:
+        """Compute the negative log-likelihood.
See `log_likelihood` method."""
+        nll = -self.log_likelihood(emissions, tags, mask=mask)
+        return nll
+
+    def log_likelihood(self, emissions, tags, mask=None):
+        """Compute the probability of a sequence of tags given a sequence of
+        emissions scores.
+
+        Args:
+            emissions (paddle.Tensor): Sequence of emissions for each label.
+                Shape of (batch_size, seq_len, nb_labels) if batch_first is True,
+                (seq_len, batch_size, nb_labels) otherwise.
+            tags (paddle.LongTensor): Sequence of labels.
+                Shape of (batch_size, seq_len) if batch_first is True,
+                (seq_len, batch_size) otherwise.
+            mask (paddle.FloatTensor, optional): Tensor representing valid positions.
+                If None, all positions are considered valid.
+                Shape of (batch_size, seq_len) if batch_first is True,
+                (seq_len, batch_size) otherwise.
+
+        Returns:
+            paddle.Tensor: sum of the log-likelihoods for each sequence in the batch,
+                as a scalar tensor.
+        """
+        # fix tensors order by setting batch as the first dimension
+        if not self.batch_first:
+            emissions = emissions.transpose(0, 1)
+            tags = tags.transpose(0, 1)
+
+        if mask is None:
+            mask = paddle.ones(emissions.shape[:2], dtype='float32')
+
+        scores = self._compute_scores(emissions, tags, mask=mask)
+        partition = self._compute_log_partition(emissions, mask=mask)
+        return paddle.sum(scores - partition)
+
+    def decode(self, emissions, mask=None):
+        """Find the most probable sequence of labels given the emissions using
+        the Viterbi algorithm.
+
+        Args:
+            emissions (paddle.Tensor): Sequence of emissions for each label.
+                Shape (batch_size, seq_len, nb_labels) if batch_first is True,
+                (seq_len, batch_size, nb_labels) otherwise.
+            mask (paddle.FloatTensor, optional): Tensor representing valid positions.
+                If None, all positions are considered valid.
+                Shape (batch_size, seq_len) if batch_first is True,
+                (seq_len, batch_size) otherwise.
+
+        Returns:
+            paddle.Tensor: the viterbi score for each batch.
+                Shape of (batch_size,)
+            list of lists: the best viterbi sequence of labels for each batch. [B, T]
+        """
+        # fix tensors order by setting batch as the first dimension
+        # (decode takes no tags, only emissions and the mask)
+        if not self.batch_first:
+            emissions = emissions.transpose(0, 1)
+
+        if mask is None:
+            mask = paddle.ones(emissions.shape[:2], dtype='float32')
+
+        scores, sequences = self._viterbi_decode(emissions, mask)
+        return scores, sequences
+
+    def _compute_scores(self, emissions, tags, mask):
+        """Compute the scores for a given batch of emissions with their tags.
+
+        Args:
+            emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
+            tags (paddle.LongTensor): (batch_size, seq_len)
+            mask (paddle.FloatTensor): (batch_size, seq_len)
+
+        Returns:
+            paddle.Tensor: Scores for each batch.
+                Shape of (batch_size,)
+        """
+        batch_size, seq_length = tags.shape
+        scores = paddle.zeros([batch_size])
+
+        # save first and last tags to be used later
+        first_tags = tags[:, 0]
+        last_valid_idx = mask.int().sum(1) - 1
+
+        # TODO(Hui Zhang): not support fancy index.
+        # last_tags = tags.gather(last_valid_idx.unsqueeze(1), axis=1).squeeze()
+        batch_idx = paddle.arange(batch_size, dtype=last_valid_idx.dtype)
+        gather_last_valid_idx = paddle.stack(
+            [batch_idx, last_valid_idx], axis=-1)
+        last_tags = tags.gather_nd(gather_last_valid_idx)
+
+        # add the transition from BOS to the first tags for each batch
+        # t_scores = self.transitions[self.BOS_TAG_ID, first_tags]
+        t_scores = self.transitions[self.BOS_TAG_ID].gather(first_tags)
+
+        # add the [unary] emission scores for the first tags for each batch
+        # for all batches, the first word, see the correspondent emissions
+        # for the first tags (which is a list of ids):
+        # emissions[:, 0, [tag_1, tag_2, ..., tag_nblabels]]
+        # e_scores = emissions[:, 0].gather(1, first_tags.unsqueeze(1)).squeeze()
+        gather_first_tags_idx = paddle.stack([batch_idx, first_tags], axis=-1)
+        e_scores = emissions[:, 0].gather_nd(gather_first_tags_idx)
+
+        # the score for a word is just the sum of both scores
+        scores += e_scores + t_scores
+
+        # now let's do this for each remaining word
+        for i in range(1, seq_length):
+
+            # we could iterate over batches, check if we reached a mask symbol
+            # and stop the iteration, but vectorizing is faster on GPU,
+            # so instead we perform an element-wise multiplication
+            is_valid = mask[:, i]
+
+            previous_tags = tags[:, i - 1]
+            current_tags = tags[:, i]
+
+            # calculate emission and transition scores as we did before
+            # e_scores = emissions[:, i].gather(1, current_tags.unsqueeze(1)).squeeze()
+            gather_current_tags_idx = paddle.stack(
+                [batch_idx, current_tags], axis=-1)
+            e_scores = emissions[:, i].gather_nd(gather_current_tags_idx)
+            # t_scores = self.transitions[previous_tags, current_tags]
+            gather_transitions_idx = paddle.stack(
+                [previous_tags, current_tags], axis=-1)
+            t_scores = self.transitions.gather_nd(gather_transitions_idx)
+
+            # apply the mask
+            e_scores = e_scores * is_valid
+            t_scores = t_scores * is_valid
+
+            scores += e_scores + t_scores
+
+        # add the transition from the end tag to the EOS tag for each batch
+        # scores += self.transitions[last_tags, self.EOS_TAG_ID]
+        scores += self.transitions.gather(last_tags)[:, self.EOS_TAG_ID]
+
+        return scores
+
+    def _compute_log_partition(self, emissions, mask):
+        """Compute the partition function in log-space using the forward algorithm.
+
+        Args:
+            emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
+            mask (paddle.FloatTensor): (batch_size, seq_len)
+
+        Returns:
+            paddle.Tensor: the partition scores for each batch.
+                Shape of (batch_size,)
+        """
+        batch_size, seq_length, nb_labels = emissions.shape
+
+        # in the first iteration, BOS will have all the scores
+        alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(
+            0) + emissions[:, 0]
+
+        for i in range(1, seq_length):
+            # (bs, nb_labels) -> (bs, 1, nb_labels)
+            e_scores = emissions[:, i].unsqueeze(1)
+
+            # (nb_labels, nb_labels) -> (bs, nb_labels, nb_labels)
+            t_scores = self.transitions.unsqueeze(0)
+
+            # (bs, nb_labels) -> (bs, nb_labels, 1)
+            a_scores = alphas.unsqueeze(2)
+
+            scores = e_scores + t_scores + a_scores
+            new_alphas = paddle.logsumexp(scores, axis=1)
+
+            # set alphas if the mask is valid, otherwise keep the current values
+            is_valid = mask[:, i].unsqueeze(-1)
+            alphas = is_valid * new_alphas + (1 - is_valid) * alphas
+
+        # add the scores for the final transition
+        last_transition = self.transitions[:, self.EOS_TAG_ID]
+        end_scores = alphas + last_transition.unsqueeze(0)
+
+        # return a *log* of sums of exps
+        return paddle.logsumexp(end_scores, axis=1)
+
+    def _viterbi_decode(self, emissions, mask):
+        """Compute the viterbi algorithm to find the most probable sequence of labels
+        given a sequence of emissions.
+
+        Args:
+            emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
+            mask (paddle.FloatTensor): (batch_size, seq_len)
+
+        Returns:
+            paddle.Tensor: the viterbi score for each batch.
+                Shape of (batch_size,)
+            list of lists of ints: the best viterbi sequence of labels for each batch
+        """
+        batch_size, seq_length, nb_labels = emissions.shape
+
+        # in the first iteration, BOS will have all the scores and then, the max
+        alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(
+            0) + emissions[:, 0]
+
+        backpointers = []
+
+        for i in range(1, seq_length):
+            # (bs, nb_labels) -> (bs, 1, nb_labels)
+            e_scores = emissions[:, i].unsqueeze(1)
+
+            # (nb_labels, nb_labels) -> (bs, nb_labels, nb_labels)
+            t_scores = self.transitions.unsqueeze(0)
+
+            # (bs, nb_labels) -> (bs, nb_labels, 1)
+            a_scores = alphas.unsqueeze(2)
+
+            # combine current scores with previous alphas
+            scores = e_scores + t_scores + a_scores
+
+            # so far this is exactly like the forward algorithm,
+            # but now, instead of calculating the logsumexp,
+            # we will find the highest score and the tag associated with it
+            # max_scores, max_score_tags = paddle.max(scores, axis=1)
+            max_scores = paddle.max(scores, axis=1)
+            max_score_tags = paddle.argmax(scores, axis=1)
+
+            # set alphas if the mask is valid, otherwise keep the current values
+            is_valid = mask[:, i].unsqueeze(-1)
+            alphas = is_valid * max_scores + (1 - is_valid) * alphas
+
+            # add the max_score_tags for our list of backpointers
+            # max_scores has shape (batch_size, nb_labels) so we transpose it to
+            # be compatible with our previous loopy version of viterbi
+            backpointers.append(max_score_tags.t())
+
+        # add the scores for the final transition
+        last_transition = self.transitions[:, self.EOS_TAG_ID]
+        end_scores = alphas + last_transition.unsqueeze(0)
+
+        # get the final most probable score and the final most probable tag
+        # max_final_scores, max_final_tags = paddle.max(end_scores, axis=1)
+        max_final_scores = paddle.max(end_scores, axis=1)
+        max_final_tags = paddle.argmax(end_scores, axis=1)
+
+        # find the best sequence of labels for each sample in the batch
+        best_sequences = []
+        emission_lengths = mask.int().sum(axis=1)
+        for i in range(batch_size):
+
+            # recover the original sentence length for the i-th sample in the batch
+            sample_length = emission_lengths[i].item()
+
+            # recover the max tag for
the last timestep
+            sample_final_tag = max_final_tags[i].item()
+
+            # limit the backpointers until the last but one
+            # since the last corresponds to the sample_final_tag
+            sample_backpointers = backpointers[:sample_length - 1]
+
+            # follow the backpointers to build the sequence of labels
+            sample_path = self._find_best_path(i, sample_final_tag,
+                                               sample_backpointers)
+
+            # add this path to the list of best sequences
+            best_sequences.append(sample_path)
+
+        return max_final_scores, best_sequences
+
+    def _find_best_path(self, sample_id, best_tag, backpointers):
+        """Auxiliary function to find the best path sequence for a specific sample.
+
+        Args:
+            sample_id (int): sample index in the range [0, batch_size)
+            best_tag (int): tag which maximizes the final score
+            backpointers (list of lists of tensors): list of pointers with
+                shape (seq_len_i-1, nb_labels, batch_size) where seq_len_i
+                represents the length of the ith sample in the batch
+
+        Returns:
+            list of ints: a list of tag indexes representing the best path
+        """
+        # add the final best_tag to our best path
+        best_path = [best_tag]
+
+        # traverse the backpointers backwards
+        for backpointers_t in reversed(backpointers):
+
+            # recover the best_tag at this timestep
+            best_tag = backpointers_t[best_tag][sample_id].item()
+
+            # append to the beginning of the list so we don't need to reverse it later
+            best_path.insert(0, best_tag)
+
+        return best_path
diff --git a/text_processing/speechtask/punctuation_restoration/training/__init__.py b/text_processing/speechtask/punctuation_restoration/training/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..185a92b8d94d3426d616c0624f0f2ee04339349e
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/training/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/text_processing/speechtask/punctuation_restoration/training/loss.py b/text_processing/speechtask/punctuation_restoration/training/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..356dfcab183ed023b9e96eb1afa4fdd7f6a6fac3
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/training/loss.py
@@ -0,0 +1,98 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
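`_compute_log_partition` above is the classic forward algorithm; on tiny inputs it can be checked against brute-force enumeration over all tag paths. A numpy sketch (random scores; BOS/EOS transitions and masking omitted for brevity):

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
T, L = 3, 4                      # seq_len, nb_labels
emissions = rng.normal(size=(T, L))
trans = rng.normal(size=(L, L))  # trans[i, j]: score of moving from tag i to tag j

# brute force: logsumexp over all L**T paths
path_scores = []
for tags in itertools.product(range(L), repeat=T):
    s = emissions[0, tags[0]]
    for t in range(1, T):
        s += trans[tags[t - 1], tags[t]] + emissions[t, tags[t]]
    path_scores.append(s)
brute = np.log(np.sum(np.exp(path_scores)))

# forward algorithm: O(T * L**2) instead of O(L**T)
alphas = emissions[0].copy()
for t in range(1, T):
    scores = alphas[:, None] + trans + emissions[t][None, :]
    alphas = np.log(np.exp(scores).sum(axis=0))
fwd = np.log(np.exp(alphas).sum())

assert np.isclose(brute, fwd)
```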
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class FocalLossHX(nn.Layer):
+    def __init__(self, gamma=0, size_average=True):
+        super(FocalLossHX, self).__init__()
+        self.gamma = gamma
+        self.size_average = size_average
+
+    def forward(self, input, target):
+        if len(input.shape) > 2:
+            input = paddle.reshape(
+                input,
+                shape=[input.shape[0], input.shape[1], -1])  # N,C,H,W => N,C,H*W
+            input = paddle.transpose(input, [0, 2, 1])  # N,C,H*W => N,H*W,C
+            input = paddle.reshape(
+                input, shape=[-1, input.shape[2]])  # N,H*W,C => N*H*W,C
+        target = paddle.reshape(target, shape=[-1])
+
+        logpt = F.log_softmax(input)
+
+        # get true class column from each row
+        all_rows = paddle.arange(len(input))
+        log_pt = logpt.numpy()[all_rows.numpy(), target.numpy()]
+
+        pt = paddle.to_tensor(log_pt, dtype='float64').exp()
+        ce = F.cross_entropy(input, target, reduction='none')
+
+        loss = (1 - pt)**self.gamma * ce
+        if self.size_average:
+            return loss.mean()
+        else:
+            return loss.sum()
+
+
+class FocalLoss(nn.Layer):
+    """
+    Focal Loss.
+    Code referenced from:
+    https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
+    Args:
+        gamma (float): the coefficient of Focal Loss.
+    """
+
+    def __init__(self, gamma=2.0):
+        super(FocalLoss, self).__init__()
+        self.gamma = gamma
+
+    def forward(self, logit, label):
+        # gather the log-probability of the true class for each position
+        label = paddle.reshape(label, [-1, 1])
+        range_ = paddle.arange(0, label.shape[0])
+        range_ = paddle.unsqueeze(range_, axis=-1)
+        label = paddle.cast(label, dtype='int64')
+        label = paddle.concat([range_, label], axis=-1)
+        logpt = F.log_softmax(logit)
+        logpt = paddle.gather_nd(logpt, label)
+
+        pt = paddle.exp(logpt.detach())
+        loss = -1 * (1 - pt)**self.gamma * logpt
+        loss = paddle.mean(loss)
+        return loss
diff --git a/text_processing/speechtask/punctuation_restoration/training/trainer.py b/text_processing/speechtask/punctuation_restoration/training/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dce88a3f5d4f5916835654f5642cf86b8999eaf
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/training/trainer.py
@@ -0,0 +1,651 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
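A quick numerical check of the focal-loss formula used by the classes above: with `gamma = 0` the modulating factor `(1 - pt)**gamma` is 1, so the loss reduces to plain cross entropy, while larger `gamma` down-weights easy (high-`pt`) examples (numpy sketch, single sample):

```python
import numpy as np

def log_softmax(z):
    z = z - z.max()
    return z - np.log(np.exp(z).sum())

logits = np.array([2.0, 0.5, -1.0])
target = 1

logpt = log_softmax(logits)[target]  # log-probability of the true class
pt = np.exp(logpt)
ce = -logpt                          # plain cross entropy

for gamma in (0.0, 2.0):
    focal = (1 - pt) ** gamma * ce
    print(gamma, focal)

assert np.isclose((1 - pt) ** 0.0 * ce, ce)  # gamma=0 recovers cross entropy
```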
+import logging
+import time
+from collections import defaultdict
+from pathlib import Path
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+import pandas as pd
+from paddle import distributed as dist
+from paddle.io import DataLoader
+from sklearn.metrics import classification_report
+from sklearn.metrics import f1_score
+from sklearn.metrics import precision_recall_fscore_support
+from speechtask.punctuation_restoration.io.dataset import PuncDataset
+from speechtask.punctuation_restoration.io.dataset import PuncDatasetFromBertTokenizer
+from speechtask.punctuation_restoration.model.BertBLSTM import BertBLSTMPunc
+from speechtask.punctuation_restoration.model.BertLinear import BertLinearPunc
+from speechtask.punctuation_restoration.model.blstm import BiLSTM
+from speechtask.punctuation_restoration.model.lstm import RnnLm
+from speechtask.punctuation_restoration.utils import layer_tools
+from speechtask.punctuation_restoration.utils import mp_tools
+from speechtask.punctuation_restoration.utils.checkpoint import Checkpoint
+from tensorboardX import SummaryWriter
+
+__all__ = ["Trainer", "Tester"]
+
+DefinedClassifier = {
+    "lstm": RnnLm,
+    "blstm": BiLSTM,
+    "BertLinear": BertLinearPunc,
+    "BertBLSTM": BertBLSTMPunc
+}
+
+DefinedLoss = {
+    "ce": nn.CrossEntropyLoss,
+}
+
+DefinedDataset = {
+    'PuncCh': PuncDataset,
+    'Bert': PuncDatasetFromBertTokenizer,
+}
+
+
+class Trainer():
+    """
+    An experiment template that structures the training code and takes care
+    of saving, loading, logging and visualization. It's intended to be
+    flexible and simple.
+
+    It only handles the output directory (creating the directory for the
+    output, creating a checkpoint directory, dumping the config in use and
+    creating the visualizer and logger) in a standard way, without enforcing
+    any input-output protocols on the model and dataloader. It leaves the
+    main part for the user to implement (setting up the model, criterion and
+    optimizer, defining a training step, defining a validation function and
+    customizing all the text and visual logs).
+    It does not remove much boilerplate code; users still have to write the
+    forward/backward/update logic manually, but they are free to add
+    non-standard behaviors if needed.
+    We have some conventions to follow.
+    1. The experiment should have ``model``, ``optimizer``, ``train_loader``,
+    ``valid_loader``, ``config`` and ``args`` attributes.
+    2. The config should have a ``training`` field, which has
+    ``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is
+    used to trigger validation, checkpointing and stopping of the
+    experiment.
+    3. There are four methods, namely ``train_batch``, ``valid``,
+    ``setup_model`` and ``setup_dataloader``, that should be implemented.
+    Feel free to add/overwrite other methods and standalone functions if you
+    need.
+
+    Parameters
+    ----------
+    config: yacs.config.CfgNode
+        The configuration used for the experiment.
+
+    args: argparse.Namespace
+        The parsed command line arguments.
+    Examples
+    --------
+    >>> def main_sp(config, args):
+    >>>     exp = Trainer(config, args)
+    >>>     exp.setup()
+    >>>     exp.run()
+    >>>
+    >>> config = get_cfg_defaults()
+    >>> parser = default_argument_parser()
+    >>> args = parser.parse_args()
+    >>> if args.config:
+    >>>     config.merge_from_file(args.config)
+    >>> if args.opts:
+    >>>     config.merge_from_list(args.opts)
+    >>> config.freeze()
+    >>>
+    >>> if args.nprocs > 1 and args.device == "gpu":
+    >>>     dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
+    >>> else:
+    >>>     main_sp(config, args)
+    """
+
+    def __init__(self, config, args):
+        self.config = config
+        self.args = args
+        self.optimizer = None
+        self.visualizer = None
+        self.output_dir = None
+        self.checkpoint_dir = None
+        self.iteration = 0
+        self.epoch = 0
+
+    def setup(self):
+        """Setup the experiment.
+        """
+        self.setup_logger()
+        paddle.set_device(self.args.device)
+        if self.parallel:
+            self.init_parallel()
+
+        self.setup_output_dir()
+        self.dump_config()
+        self.setup_visualizer()
+        self.setup_checkpointer()
+
+        self.setup_model()
+
+        self.setup_dataloader()
+
+        self.iteration = 0
+        self.epoch = 0
+
+    @property
+    def parallel(self):
+        """A flag indicating whether the experiment should run with
+        multiprocessing.
+        """
+        return self.args.device == "gpu" and self.args.nprocs > 1
+
+    def init_parallel(self):
+        """Init the environment for multiprocess training.
+        """
+        dist.init_parallel_env()
+
+    @mp_tools.rank_zero_only
+    def save(self, tag=None, infos: dict=None):
+        """Save checkpoint (model parameters and optimizer states).
+
+        Args:
+            tag (int or str, optional): None for step, else use the tag, e.g. epoch. Defaults to None.
+            infos (dict, optional): meta data to save. Defaults to None.
+        """
+
+        infos = infos if infos else dict()
+        infos.update({
+            "step": self.iteration,
+            "epoch": self.epoch,
+            "lr": self.optimizer.get_lr()
+        })
+        self.checkpointer.add_checkpoint(self.checkpoint_dir, self.iteration
+                                         if tag is None else tag, self.model,
+                                         self.optimizer, infos)
+
+    def resume_or_scratch(self):
+        """Resume from the latest checkpoint in the checkpoint directory of
+        the output directory, or load a specified checkpoint.
+
+        If ``args.checkpoint_path`` is not None, load that checkpoint; else
+        resume training.
+        """
+        scratch = None
+        infos = self.checkpointer.load_parameters(
+            self.model,
+            self.optimizer,
+            checkpoint_dir=self.checkpoint_dir,
+            checkpoint_path=self.args.checkpoint_path)
+        if infos:
+            # restore from ckpt
+            self.iteration = infos["step"]
+            self.epoch = infos["epoch"]
+            scratch = False
+        else:
+            self.iteration = 0
+            self.epoch = 0
+            scratch = True
+
+        return scratch
+
+    def new_epoch(self):
+        """Reset the train loader seed and increment ``epoch``.
+        """
+        self.epoch += 1
+        if self.parallel:
+            self.train_loader.batch_sampler.set_epoch(self.epoch)
+
+    def train(self):
+        """The training process, controlled by epoch."""
+        from_scratch = self.resume_or_scratch()
+
+        if from_scratch:
+            # save init model, i.e.
0 epoch
+            self.save(tag="init")
+
+        self.lr_scheduler.step(self.iteration)
+        if self.parallel:
+            self.train_loader.batch_sampler.set_epoch(self.epoch)
+
+        self.logger.info(
+            f"Train Total Examples: {len(self.train_loader.dataset)}")
+        self.punc_list = []
+        for i in range(len(self.train_loader.dataset.id2punc)):
+            self.punc_list.append(self.train_loader.dataset.id2punc[i])
+        while self.epoch < self.config["training"]["n_epoch"]:
+            self.model.train()
+            self.total_label_train = []
+            self.total_predict_train = []
+            try:
+                data_start_time = time.time()
+                for batch_index, batch in enumerate(self.train_loader):
+                    dataload_time = time.time() - data_start_time
+                    msg = "Train: Rank: {}, ".format(dist.get_rank())
+                    msg += "epoch: {}, ".format(self.epoch)
+                    msg += "step: {}, ".format(self.iteration)
+                    msg += "batch : {}/{}, ".format(batch_index + 1,
+                                                    len(self.train_loader))
+                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
+                    msg += "data time: {:>.3f}s, ".format(dataload_time)
+                    self.train_batch(batch_index, batch, msg)
+                    data_start_time = time.time()
+                t = classification_report(
+                    self.total_label_train,
+                    self.total_predict_train,
+                    target_names=self.punc_list)
+                self.logger.info(t)
+            except Exception as e:
+                self.logger.error(e)
+                raise e
+
+            total_loss, F1_score = self.valid()
+            self.logger.info("Epoch {} Val info val_loss {}, F1_score {}".
+                             format(self.epoch, total_loss, F1_score))
+            if self.visualizer:
+                self.visualizer.add_scalars("epoch", {
+                    "total_loss": total_loss,
+                    "lr": self.lr_scheduler()
+                }, self.epoch)
+
+            self.save(
+                tag=self.epoch, infos={"val_loss": total_loss,
+                                       "F1": F1_score})
+            # step lr every epoch
+            self.lr_scheduler.step()
+            self.new_epoch()
+
+    def run(self):
+        """The routine of the experiment after setup. This method is intended
+        to be used by the user.
+        """
+        try:
+            self.train()
+        except KeyboardInterrupt:
+            self.save()
+            exit(-1)
+        finally:
+            self.destroy()
+            self.logger.info("Training Done.")
+
+    def setup_output_dir(self):
+        """Create a directory used for output.
+        """
+        # output dir
+        output_dir = Path(self.args.output).expanduser()
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        self.output_dir = output_dir
+
+    def setup_checkpointer(self):
+        """Create a directory used to save checkpoints into.
+
+        It is "checkpoints" inside the output directory.
+        """
+        # checkpoint dir
+        self.checkpointer = Checkpoint(self.logger,
+                                       self.config["checkpoint"]["kbest_n"],
+                                       self.config["checkpoint"]["latest_n"])
+
+        checkpoint_dir = self.output_dir / "checkpoints"
+        checkpoint_dir.mkdir(exist_ok=True)
+
+        self.checkpoint_dir = checkpoint_dir
+
+    def setup_logger(self):
+        LOG_FORMAT = "%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s"
+        format_str = logging.Formatter(LOG_FORMAT)
+        logging.basicConfig(
+            filename=self.config["training"]["log_path"],
+            level=logging.INFO,
+            format=LOG_FORMAT)
+        self.logger = logging.getLogger(__name__)
+
+        self.logger.setLevel(logging.INFO)  # set the log level
+        sh = logging.StreamHandler()  # also print logs to the screen
+        sh.setFormatter(format_str)  # set the format for screen output
+        self.logger.addHandler(sh)  # attach the handler to the logger
+
+        self.logger.info("Setup logger!")
+
+    @mp_tools.rank_zero_only
+    def destroy(self):
+        """Close the visualizer to avoid hanging after training"""
+        # https://github.com/pytorch/fairseq/issues/2357
+        if self.visualizer:
+            self.visualizer.close()
+
+    @mp_tools.rank_zero_only
+    def setup_visualizer(self):
+        """Initialize a visualizer to log the experiment.
+
+        The visual log is saved in the output directory.
+
+        Notes
+        ------
+        Only the main process has a visualizer. Using multiple visualizers
+        in multiple processes to write to the same log file may cause
+        unexpected behaviors.
+        """
+        # visualizer
+        visualizer = SummaryWriter(logdir=str(self.output_dir))
+        self.visualizer = visualizer
+
+    @mp_tools.rank_zero_only
+    def dump_config(self):
+        """Save the configuration used for this experiment.
+
+        It is saved in ``config.yaml`` in the output directory at the
+        beginning of the experiment.
+        """
+        with open(self.output_dir / "config.yaml", "wt") as f:
+            print(self.config, file=f)
+
+    def train_batch(self, batch_index, batch_data, msg):
+        start = time.time()
+
+        input, label = batch_data
+        label = paddle.reshape(label, shape=[-1])
+        y, logit = self.model(input)
+        pred = paddle.argmax(logit, axis=1)
+        self.total_label_train.extend(label.numpy().tolist())
+        self.total_predict_train.extend(pred.numpy().tolist())
+        loss = self.crit(y, label)
+
+        loss.backward()
+        layer_tools.print_grads(self.model, print_func=None)
+        self.optimizer.step()
+        self.optimizer.clear_grad()
+        iteration_time = time.time() - start
+
+        losses_np = {
+            "train_loss": float(loss),
+        }
+        msg += "train time: {:>.3f}s, ".format(iteration_time)
+        msg += "batch size: {}, ".format(self.config["data"]["batch_size"])
+        msg += ", ".join("{}: {:>.6f}".format(k, v)
+                         for k, v in losses_np.items())
+        self.logger.info(msg)
+
+        if dist.get_rank() == 0 and self.visualizer:
+            for k, v in losses_np.items():
+                self.visualizer.add_scalar("train/{}".format(k), v,
+                                           self.iteration)
+        self.iteration += 1
+
+    @paddle.no_grad()
+    def valid(self):
+        self.logger.info(
+            f"Valid Total Examples: {len(self.valid_loader.dataset)}")
+        self.model.eval()
+        valid_losses = defaultdict(list)
+        num_seen_utts = 1
+        total_loss = 0.0
+        valid_total_label = []
+        valid_total_predict = []
+        for i, batch in enumerate(self.valid_loader):
+            input, label = batch
+            label = paddle.reshape(label, shape=[-1])
+            y, logit = self.model(input)
+            pred = paddle.argmax(logit, axis=1)
+            valid_total_label.extend(label.numpy().tolist())
+            valid_total_predict.extend(pred.numpy().tolist())
+            loss = self.crit(y, label)
+
+            if paddle.isfinite(loss):
+                num_utts = batch[1].shape[0]
+                num_seen_utts += num_utts
+                total_loss += float(loss) * num_utts
+                valid_losses["val_loss"].append(float(loss))
+
+            if (i + 1) % self.config["training"]["log_interval"] == 0:
+                valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
+                valid_dump["val_history_loss"] = total_loss / num_seen_utts
+
+                # logging
+                msg = f"Valid: Rank: {dist.get_rank()}, "
+                msg += "epoch: {}, ".format(self.epoch)
+                msg += "step: {}, ".format(self.iteration)
+                msg += "batch : {}/{}, ".format(i + 1, len(self.valid_loader))
+                msg += ", ".join("{}: {:>.6f}".format(k, v)
+                                 for k, v in valid_dump.items())
+                self.logger.info(msg)
+
+        self.logger.info("Rank {} Val info val_loss {}".format(
+            dist.get_rank(), total_loss / num_seen_utts))
+        F1_score = f1_score(
+            valid_total_label, valid_total_predict, average="macro")
+        return total_loss / num_seen_utts, F1_score
+
+    def setup_model(self):
+        config = self.config
+
+        model = DefinedClassifier[self.config["model_type"]](
+            **self.config["model_params"])
+        self.crit = DefinedLoss[self.config["loss_type"]](**self.config[
+            "loss"]) if "loss_type" in self.config else DefinedLoss["ce"]()
+
+        if self.parallel:
+            model = paddle.DataParallel(model)
+
+        self.logger.info(f"{model}")
+        layer_tools.print_params(model, self.logger.info)
+
+        lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
+            learning_rate=config["training"]["lr"],
gamma=config["training"]["lr_decay"],
+            verbose=True)
+        optimizer = paddle.optimizer.Adam(
+            learning_rate=lr_scheduler,
+            parameters=model.parameters(),
+            weight_decay=paddle.regularizer.L2Decay(
+                config["training"]["weight_decay"]))
+
+        self.model = model
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+        self.logger.info("Setup model/criterion/optimizer/lr_scheduler!")
+
+    def setup_dataloader(self):
+        config = self.config["data"].copy()
+
+        train_dataset = DefinedDataset[config["dataset_type"]](
+            train_path=config["train_path"], **config["data_params"])
+        dev_dataset = DefinedDataset[config["dataset_type"]](
+            train_path=config["dev_path"], **config["data_params"])
+
+        self.train_loader = DataLoader(
+            train_dataset,
+            num_workers=config["num_workers"],
+            batch_size=config["batch_size"])
+        self.valid_loader = DataLoader(
+            dev_dataset,
+            batch_size=config["batch_size"],
+            shuffle=False,
+            drop_last=False,
+            num_workers=config["num_workers"])
+        self.logger.info("Setup train/valid Dataloader!")
+
+
+class Tester(Trainer):
+    def __init__(self, config, args):
+        super().__init__(config, args)
+
+    @mp_tools.rank_zero_only
+    @paddle.no_grad()
+    def test(self):
+        self.logger.info(
+            f"Test Total Examples: {len(self.test_loader.dataset)}")
+        self.punc_list = []
+        for i in range(len(self.test_loader.dataset.id2punc)):
+            self.punc_list.append(self.test_loader.dataset.id2punc[i])
+        self.model.eval()
+        test_total_label = []
+        test_total_predict = []
+        with open(self.args.result_file, 'w') as fout:
+            for i, batch in enumerate(self.test_loader):
+                input, label = batch
+                label = paddle.reshape(label, shape=[-1])
+                y, logit = self.model(input)
+                pred = paddle.argmax(logit, axis=1)
+                test_total_label.extend(label.numpy().tolist())
+                test_total_predict.extend(pred.numpy().tolist())
+
+                # logging
+                msg = "Test: "
+                msg += "epoch: {}, ".format(self.epoch)
+                msg += "step: {}, ".format(self.iteration)
+                self.logger.info(msg)
+
+        t = classification_report(
+            test_total_label, test_total_predict, target_names=self.punc_list)
+        self.logger.info(t)
+        t2 = self.evaluation(test_total_label, test_total_predict)
+        self.logger.info(t2)
+
+    def evaluation(self, y_true, y_pred):
+        precision, recall, f1, _ = precision_recall_fscore_support(
+            y_true, y_pred, average=None, labels=[1, 2, 3])
+        overall = precision_recall_fscore_support(
+            y_true, y_pred, average='macro', labels=[1, 2, 3])
+        result = pd.DataFrame(
+            np.array([precision, recall, f1]),
+            columns=list(['O', 'COMMA',
'PERIOD', 'QUESTION'])[1:],
+            index=['Precision', 'Recall', 'F1'])
+        result['OVERALL'] = overall[:3]
+        return result
+
+    def run_test(self):
+        self.resume_or_scratch()
+        try:
+            self.test()
+        except KeyboardInterrupt:
+            exit(-1)
+
+    def setup(self):
+        """Setup the experiment.
+        """
+        paddle.set_device(self.args.device)
+        self.setup_logger()
+        self.setup_output_dir()
+        self.setup_checkpointer()
+
+        self.setup_dataloader()
+        self.setup_model()
+
+        self.iteration = 0
+        self.epoch = 0
+
+    def setup_model(self):
+        config = self.config
+        model = DefinedClassifier[self.config["model_type"]](
+            **self.config["model_params"])
+
+        self.model = model
+        self.logger.info("Setup model!")
+
+    def setup_dataloader(self):
+        config = self.config["data"].copy()
+
+        test_dataset = DefinedDataset[config["dataset_type"]](
+            train_path=config["test_path"], **config["data_params"])
+
+        self.test_loader = DataLoader(
+            test_dataset,
+            batch_size=config["batch_size"],
+            shuffle=False,
+            drop_last=False)
+        self.logger.info("Setup test Dataloader!")
+
+    def setup_output_dir(self):
+        """Create a directory used for output.
+        """
+        # output dir
+        if self.args.output:
+            output_dir = Path(self.args.output).expanduser()
+            output_dir.mkdir(parents=True, exist_ok=True)
+        else:
+            output_dir = Path(
+                self.args.checkpoint_path).expanduser().parent.parent
+            output_dir.mkdir(parents=True, exist_ok=True)
+
+        self.output_dir = output_dir
+
+    def setup_logger(self):
+        LOG_FORMAT = "%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s"
+        format_str = logging.Formatter(LOG_FORMAT)
+        logging.basicConfig(
+            filename=self.config["testing"]["log_path"],
+            level=logging.INFO,
+            format=LOG_FORMAT)
+        self.logger = logging.getLogger(__name__)
+
+        self.logger.setLevel(logging.INFO)  # set the log level
+        sh = logging.StreamHandler()  # also print logs to the screen
+        sh.setFormatter(format_str)  # set the format for screen output
+        self.logger.addHandler(sh)  # attach the handler to the logger
+
+        self.logger.info("Setup test logger!")
diff --git a/text_processing/speechtask/punctuation_restoration/utils/__init__.py b/text_processing/speechtask/punctuation_restoration/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..185a92b8d94d3426d616c0624f0f2ee04339349e
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/utils/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
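Putting the pieces together, a launch script for this trainer might look like the following sketch (based on the docstring example above; `get_cfg_defaults` is a hypothetical helper, not part of this diff):

```python
# Hypothetical launch script for the punctuation-restoration trainer.
from paddle import distributed as dist

from config import get_cfg_defaults  # hypothetical yacs-config helper (assumption)
from speechtask.punctuation_restoration.training.trainer import Trainer
from speechtask.punctuation_restoration.utils.default_parser import default_argument_parser


def main_sp(config, args):
    exp = Trainer(config, args)
    exp.setup()
    exp.run()


if __name__ == "__main__":
    parser = default_argument_parser()
    args = parser.parse_args()
    # the config is assumed to carry the "data", "training" and
    # "checkpoint" sections the Trainer reads
    config = get_cfg_defaults()
    if args.config:
        config.merge_from_file(args.config)
    config.freeze()
    if args.nprocs > 1 and args.device == "gpu":
        dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
    else:
        main_sp(config, args)
```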
diff --git a/text_processing/speechtask/punctuation_restoration/utils/checkpoint.py b/text_processing/speechtask/punctuation_restoration/utils/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ad4b5b363cef2c2e86a6e2e9aa4bc5931793a74
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/utils/checkpoint.py
@@ -0,0 +1,304 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import glob
+import json
+import os
+import re
+from pathlib import Path
+from typing import Mapping
+from typing import Text
+from typing import Union
+
+import paddle
+from paddle import distributed as dist
+from paddle.optimizer import Optimizer
+from speechtask.punctuation_restoration.utils import mp_tools
+
+__all__ = ["Checkpoint"]
+
+
+class Checkpoint():
+    def __init__(self,
+                 logger,
+                 kbest_n: int=5,
+                 latest_n: int=1,
+                 metric_type='val_loss'):
+        self.best_records: Mapping[Path, float] = {}
+        self.latest_records = []
+        self.kbest_n = kbest_n
+        self.latest_n = latest_n
+        self._save_all = (kbest_n == -1)
+        self.logger = logger
+        self.metric_type = metric_type
+
+    def add_checkpoint(self,
+                       checkpoint_dir,
+                       tag_or_iteration: Union[int, Text],
+                       model: paddle.nn.Layer,
+                       optimizer: Optimizer=None,
+                       infos: dict=None):
+        """Save a checkpoint in best_n and latest_n.
+        Args:
+            checkpoint_dir (str): the directory where the checkpoint is saved.
+            tag_or_iteration (int or str): the latest iteration (step or epoch) number or tag.
+            model (Layer): model to be checkpointed.
+            optimizer (Optimizer, optional): optimizer to be checkpointed.
+            infos (dict or None): any info you want to save.
+        """
+        metric_type = self.metric_type
+        if (metric_type not in infos.keys()):
+            self._save_parameters(checkpoint_dir, tag_or_iteration, model,
+                                  optimizer, infos)
+            return
+
+        # save best
+        if self._should_save_best(infos[metric_type]):
+            self._save_best_checkpoint_and_update(
+                infos[metric_type], checkpoint_dir, tag_or_iteration, model,
+                optimizer, infos)
+        # save latest
+        self._save_latest_checkpoint_and_update(
+            checkpoint_dir, tag_or_iteration, model, optimizer, infos)
+
+        if isinstance(tag_or_iteration, int):
+            self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
+
+    def load_parameters(self,
+                        model,
+                        optimizer=None,
+                        checkpoint_dir=None,
+                        checkpoint_path=None,
+                        record_file="checkpoint_latest"):
+        """Load a model checkpoint from disk.
+        Args:
+            model (Layer): model to load parameters.
+            optimizer (Optimizer, optional): optimizer to load states if needed.
+                Defaults to None.
+            checkpoint_dir (str, optional): the directory where the checkpoint is saved.
+            checkpoint_path (str, optional): if specified, load the checkpoint
+                stored in the checkpoint_path (prefix) and the argument 'checkpoint_dir' will
+                be ignored. Defaults to None.
+            record_file (str, optional): "checkpoint_latest" or "checkpoint_best".
+        Returns:
+            configs (dict): epoch or step, lr and other meta info should be saved.
+        """
+        configs = {}
+
+        if checkpoint_path is not None:
+            pass
+        elif checkpoint_dir is not None and record_file is not None:
+            # load the checkpoint from the record file
+            checkpoint_record = os.path.join(checkpoint_dir, record_file)
+            iteration = self._load_checkpoint_idx(checkpoint_record)
+            if iteration == -1:
+                return configs
+            checkpoint_path = os.path.join(checkpoint_dir,
+                                           "{}".format(iteration))
+        else:
+            raise ValueError(
+                "At least one of 'checkpoint_path' or 'checkpoint_dir' should be specified!"
+            )
+
+        rank = dist.get_rank()
+
+        params_path = checkpoint_path + ".pdparams"
+        model_dict = paddle.load(params_path)
+        model.set_state_dict(model_dict)
+        self.logger.info(
+            "Rank {}: loaded model from {}".format(rank, params_path))
+
+        optimizer_path = checkpoint_path + ".pdopt"
+        if optimizer and os.path.isfile(optimizer_path):
+            optimizer_dict = paddle.load(optimizer_path)
+            optimizer.set_state_dict(optimizer_dict)
+            self.logger.info("Rank {}: loaded optimizer state from {}".format(
+                rank, optimizer_path))
+
+        info_path = re.sub('.pdparams$', '.json', params_path)
+        if os.path.exists(info_path):
+            with open(info_path, 'r') as fin:
+                configs = json.load(fin)
+        return configs
+
+    def load_latest_parameters(self,
+                               model,
+                               optimizer=None,
+                               checkpoint_dir=None,
+                               checkpoint_path=None):
+        """Load the latest model checkpoint from disk.
+        Args:
+            model (Layer): model to load parameters.
+            optimizer (Optimizer, optional): optimizer to load states if needed.
+                Defaults to None.
+            checkpoint_dir (str, optional): the directory where the checkpoint is saved.
+            checkpoint_path (str, optional): if specified, load the checkpoint
+                stored in the checkpoint_path (prefix) and the argument 'checkpoint_dir' will
+                be ignored. Defaults to None.
+        Returns:
+            configs (dict): epoch or step, lr and other meta info should be saved.
+        """
+        return self.load_parameters(model, optimizer, checkpoint_dir,
+                                    checkpoint_path, "checkpoint_latest")
+
+    def load_best_parameters(self,
+                             model,
+                             optimizer=None,
+                             checkpoint_dir=None,
+                             checkpoint_path=None):
+        """Load the best model checkpoint from disk.
+        Args:
+            model (Layer): model to load parameters.
+            optimizer (Optimizer, optional): optimizer to load states if needed.
+                Defaults to None.
+            checkpoint_dir (str, optional): the directory where the checkpoint is saved.
+            checkpoint_path (str, optional): if specified, load the checkpoint
+                stored in the checkpoint_path (prefix) and the argument 'checkpoint_dir' will
+                be ignored. Defaults to None.
+        Returns:
+            configs (dict): epoch or step, lr and other meta info should be saved.
+ """ + return self.load_parameters(model, optimizer, checkpoint_dir, + checkpoint_path, "checkpoint_best") + + def _should_save_best(self, metric: float) -> bool: + if not self._best_full(): + return True + + # already full + worst_record_path = max(self.best_records, key=self.best_records.get) + # worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0] + worst_metric = self.best_records[worst_record_path] + return metric < worst_metric + + def _best_full(self): + return (not self._save_all) and len(self.best_records) == self.kbest_n + + def _latest_full(self): + return len(self.latest_records) == self.latest_n + + def _save_best_checkpoint_and_update(self, metric, checkpoint_dir, + tag_or_iteration, model, optimizer, + infos): + # remove the worst + if self._best_full(): + worst_record_path = max(self.best_records, + key=self.best_records.get) + self.best_records.pop(worst_record_path) + if (worst_record_path not in self.latest_records): + self.logger.info( + "remove the worst checkpoint: {}".format(worst_record_path)) + self._del_checkpoint(checkpoint_dir, worst_record_path) + + # add the new one + self._save_parameters(checkpoint_dir, tag_or_iteration, model, + optimizer, infos) + self.best_records[tag_or_iteration] = metric + + def _save_latest_checkpoint_and_update( + self, checkpoint_dir, tag_or_iteration, model, optimizer, infos): + # remove the old + if self._latest_full(): + to_del_fn = self.latest_records.pop(0) + if (to_del_fn not in self.best_records.keys()): + self.logger.info( + "remove the latest checkpoint: {}".format(to_del_fn)) + self._del_checkpoint(checkpoint_dir, to_del_fn) + self.latest_records.append(tag_or_iteration) + + self._save_parameters(checkpoint_dir, tag_or_iteration, model, + optimizer, infos) + + def _del_checkpoint(self, checkpoint_dir, tag_or_iteration): + checkpoint_path = os.path.join(checkpoint_dir, + "{}".format(tag_or_iteration)) + for filename in glob.glob(checkpoint_path + ".*"): + os.remove(filename) + self.logger.info("delete file: {}".format(filename)) + + def _load_checkpoint_idx(self, checkpoint_record: str) -> int: + """Get the iteration number corresponding to the latest saved checkpoint. + Args: + checkpoint_path (str): the saved path of checkpoint. + Returns: + int: the latest iteration number. -1 for no checkpoint to load. + """ + if not os.path.isfile(checkpoint_record): + return -1 + + # Fetch the latest checkpoint index. + with open(checkpoint_record, "rt") as handle: + latest_checkpoint = handle.readlines()[-1].strip() + iteration = int(latest_checkpoint.split(":")[-1]) + return iteration + + def _save_checkpoint_record(self, checkpoint_dir: str, iteration: int): + """Save the iteration number of the latest model to be checkpoint record. + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + iteration (int): the latest iteration number. 
+        Returns:
+            None
+        """
+        checkpoint_record_latest = os.path.join(checkpoint_dir,
+                                                "checkpoint_latest")
+        checkpoint_record_best = os.path.join(checkpoint_dir, "checkpoint_best")
+
+        with open(checkpoint_record_best, "w") as handle:
+            for i in self.best_records.keys():
+                handle.write("model_checkpoint_path:{}\n".format(i))
+        with open(checkpoint_record_latest, "w") as handle:
+            for i in self.latest_records:
+                handle.write("model_checkpoint_path:{}\n".format(i))
+
+    @mp_tools.rank_zero_only
+    def _save_parameters(self,
+                         checkpoint_dir: str,
+                         tag_or_iteration: Union[int, str],
+                         model: paddle.nn.Layer,
+                         optimizer: Optimizer=None,
+                         infos: dict=None):
+        """Checkpoint the latest trained model parameters.
+        Args:
+            checkpoint_dir (str): the directory where the checkpoint is saved.
+            tag_or_iteration (int or str): the latest iteration (step or epoch) number.
+            model (Layer): model to be checkpointed.
+            optimizer (Optimizer, optional): optimizer to be checkpointed.
+                Defaults to None.
+            infos (dict or None): any info you want to save.
+        Returns:
+            None
+        """
+        checkpoint_path = os.path.join(checkpoint_dir,
+                                       "{}".format(tag_or_iteration))
+
+        model_dict = model.state_dict()
+        params_path = checkpoint_path + ".pdparams"
+        paddle.save(model_dict, params_path)
+        self.logger.info("Saved model to {}".format(params_path))
+
+        if optimizer:
+            opt_dict = optimizer.state_dict()
+            optimizer_path = checkpoint_path + ".pdopt"
+            paddle.save(opt_dict, optimizer_path)
+            self.logger.info(
+                "Saved optimizer state to {}".format(optimizer_path))
+
+        info_path = re.sub('.pdparams$', '.json', params_path)
+        infos = {} if infos is None else infos
+        with open(info_path, 'w') as fout:
+            data = json.dumps(infos)
+            fout.write(data)
diff --git a/text_processing/speechtask/punctuation_restoration/utils/default_parser.py b/text_processing/speechtask/punctuation_restoration/utils/default_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..b83d989d62645d7f21df3f5b5c18006666a69df1
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/utils/default_parser.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+
+
+def default_argument_parser():
+    r"""A simple yet general argument parser for experiments with parakeet.
+
+    This is used in examples with parakeet, and it is intended to be used by
+    other experiments with parakeet as well. It requires a minimal set of
+    command line arguments to start a training script.
+
+    The ``--config`` and ``--opts`` are used to overwrite the default
+    configuration.
+
+    The ``--data`` and ``--output`` specify the data path and output path.
+    Resuming training from existing progress at the output directory is the
+    intended default behavior.
+
+    The ``--checkpoint_path`` specifies the checkpoint to load from.
+
+    The ``--device`` and ``--nprocs`` specify how to run the training.
+
+
+    See Also
+    --------
+    parakeet.training.experiment
+    Returns
+    -------
+    argparse.ArgumentParser
+        the parser
+    """
+    parser = argparse.ArgumentParser()
+
+    # yapf: disable
+    # data and output
+    parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite the default config with.")
+    parser.add_argument("--dump-config", metavar="FILE", help="dump the config to a yaml file.")
+    # parser.add_argument("--data", metavar="DATA_DIR", help="path to the dataset.")
+    parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoints and logs.")
+
+    # load from saved checkpoint
+    parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load")
+
+    # save jit model to
+    parser.add_argument("--export_path", type=str, help="path of the jit model to save")
+
+    # save asr result to
+    parser.add_argument("--result_file", type=str, help="path to save the asr result")
+
+    # running
+    parser.add_argument("--device", type=str, default='gpu', choices=["cpu", "gpu"],
+                        help="device type to use, cpu and gpu are supported.")
+    parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.")
+
+    # overwrite extra config and default config
+    # parser.add_argument("--opts", nargs=argparse.REMAINDER,
+    #                     help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
+    parser.add_argument("--opts", type=str, default=[], nargs='+',
+                        help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
+    # yapf: enable
+
+    return parser
diff --git a/text_processing/speechtask/punctuation_restoration/utils/layer_tools.py b/text_processing/speechtask/punctuation_restoration/utils/layer_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb076c0c716938b85e0b52a4268b71993e51a475
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/utils/layer_tools.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
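For reference, the parser defined above is used like this (a usage sketch; the flag values, including the config path, are purely illustrative):

```python
from speechtask.punctuation_restoration.utils.default_parser import default_argument_parser

parser = default_argument_parser()
# values are illustrative; --opts consumes KEY VALUE pairs
args = parser.parse_args([
    "--config", "conf/train_conf.yaml",
    "--output", "exp/punc",
    "--device", "gpu",
    "--nprocs", "1",
    "--opts", "training.n_epoch", "20",
])
print(args.config, args.output, args.nprocs, args.opts)
```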
+import numpy as np
+from paddle import nn
+
+__all__ = [
+    "summary", "gradient_norm", "freeze", "unfreeze", "print_grads",
+    "print_params"
+]
+
+
+def summary(layer: nn.Layer, print_func=print):
+    if print_func is None:
+        return
+    num_params = num_elements = 0
+    for name, param in layer.state_dict().items():
+        if print_func:
+            print_func(
+                "{} | {} | {}".format(name, param.shape, np.prod(param.shape)))
+        num_elements += np.prod(param.shape)
+        num_params += 1
+    if print_func:
+        num_elements = num_elements / 1024**2
+        print_func(
+            f"Total parameters: {num_params}, {num_elements:.2f}M elements.")
+
+
+def print_grads(model, print_func=print):
+    if print_func is None:
+        return
+    for n, p in model.named_parameters():
+        msg = f"param grad: {n}: shape: {p.shape} grad: {p.grad}"
+        print_func(msg)
+
+
+def print_params(model, print_func=print):
+    if print_func is None:
+        return
+    total = 0.0
+    num_params = 0.0
+    for n, p in model.named_parameters():
+        msg = f"{n} | {p.shape} | {np.prod(p.shape)} | {not p.stop_gradient}"
+        total += np.prod(p.shape)
+        num_params += 1
+        if print_func:
+            print_func(msg)
+    if print_func:
+        total = total / 1024**2
+        print_func(f"Total parameters: {num_params}, {total:.2f}M elements.")
+
+
+def gradient_norm(layer: nn.Layer):
+    grad_norm_dict = {}
+    for name, param in layer.state_dict().items():
+        if param.trainable:
+            grad = param.gradient()  # returns a numpy.ndarray
+            grad_norm_dict[name] = np.linalg.norm(grad) / grad.size
+    return grad_norm_dict
+
+
+def recursively_remove_weight_norm(layer: nn.Layer):
+    for sublayer in layer.sublayers():
+        try:
+            nn.utils.remove_weight_norm(sublayer)
+        except ValueError:
+            # this layer has no weight norm hook
+            pass
+
+
+def freeze(layer: nn.Layer):
+    for param in layer.parameters():
+        param.trainable = False
+
+
+def unfreeze(layer: nn.Layer):
+    for param in layer.parameters():
+        param.trainable = True
diff --git a/text_processing/speechtask/punctuation_restoration/utils/mp_tools.py b/text_processing/speechtask/punctuation_restoration/utils/mp_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3e25aab68ad597df14f168095db9080e48ee997
--- /dev/null
+++ b/text_processing/speechtask/punctuation_restoration/utils/mp_tools.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
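As a quick illustration of the `layer_tools` helpers above (a sketch with a toy model; any real model works the same way):

```python
import paddle.nn as nn

from speechtask.punctuation_restoration.utils import layer_tools

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
layer_tools.summary(model)       # name | shape | numel per parameter
layer_tools.freeze(model)        # sets param.trainable = False everywhere
layer_tools.unfreeze(model)      # re-enables training
layer_tools.print_params(model)  # also shows the trainable flag
```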
+from functools import wraps + +from paddle import distributed as dist + +__all__ = ["rank_zero_only"] + + +def rank_zero_only(func): + @wraps(func) + def wrapper(*args, **kwargs): + rank = dist.get_rank() + if rank != 0: + return + result = func(*args, **kwargs) + return result + + return wrapper diff --git a/text_processing/speechtask/punctuation_restoration/utils/punct_pre.py b/text_processing/speechtask/punctuation_restoration/utils/punct_pre.py new file mode 100644 index 0000000000000000000000000000000000000000..7f1431829df340348f834fa8fe4d9249859659e8 --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/utils/punct_pre.py @@ -0,0 +1,163 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import shutil + +CHINESE_PUNCTUATION_MAPPING = { + 'O': '', + ',': ",", + '。': '。', + '?': '?', +} + + +def process_one_file_chinese(raw_path, save_path): + f = open(raw_path, 'r', encoding='utf-8') + save_file = open(save_path, 'w', encoding='utf-8') + for line in f.readlines(): + line = line.strip().replace(' ', '').replace(' ', '') + for i in line: + save_file.write(i + ' ') + save_file.write('\n') + save_file.close() + + +def process_chinese_pure_senetence(config): + ####need raw_path, raw_train_file, raw_dev_file, raw_test_file, punc_file, save_path + assert os.path.exists( + os.path.join(config["raw_path"], config[ + "raw_train_file"])), "train file doesn't exist." + assert os.path.exists( + os.path.join(config["raw_path"], config[ + "raw_dev_file"])), "dev file doesn't exist." + assert os.path.exists( + os.path.join(config["raw_path"], config[ + "raw_test_file"])), "test file doesn't exist." + assert os.path.exists( + os.path.join(config["raw_path"], config[ + "punc_file"])), "punc file doesn't exist." 
+ + train_file = os.path.join(config["raw_path"], config["raw_train_file"]) + dev_file = os.path.join(config["raw_path"], config["raw_dev_file"]) + test_file = os.path.join(config["raw_path"], config["raw_test_file"]) + if not os.path.exists(config["save_path"]): + os.makedirs(config["save_path"]) + + shutil.copy( + os.path.join(config["raw_path"], config["punc_file"]), + os.path.join(config["save_path"], config["punc_file"])) + + process_one_file_chinese(train_file, + os.path.join(config["save_path"], "train")) + process_one_file_chinese(dev_file, os.path.join(config["save_path"], "dev")) + process_one_file_chinese(test_file, + os.path.join(config["save_path"], "test")) + + +def process_one_chinese_pair(raw_path, save_path): + + f = open(raw_path, 'r', encoding='utf-8') + save_file = open(save_path, 'w', encoding='utf-8') + for line in f.readlines(): + if (len(line.strip().split()) == 2): + word, punc = line.strip().split() + save_file.write(word + ' ' + CHINESE_PUNCTUATION_MAPPING[punc]) + if (punc == "。"): + save_file.write("\n") + else: + save_file.write(" ") + save_file.close() + + +def process_chinese_pair(config): + ### need raw_path, raw_train_file, raw_dev_file, raw_test_file, punc_file, save_path + assert os.path.exists( + os.path.join(config["raw_path"], config[ + "raw_train_file"])), "train file doesn't exist." + assert os.path.exists( + os.path.join(config["raw_path"], config[ + "raw_dev_file"])), "dev file doesn't exist." + assert os.path.exists( + os.path.join(config["raw_path"], config[ + "raw_test_file"])), "test file doesn't exist." + assert os.path.exists( + os.path.join(config["raw_path"], config[ + "punc_file"])), "punc file doesn't exist." + + train_file = os.path.join(config["raw_path"], config["raw_train_file"]) + dev_file = os.path.join(config["raw_path"], config["raw_dev_file"]) + test_file = os.path.join(config["raw_path"], config["raw_test_file"]) + + process_one_chinese_pair(train_file, + os.path.join(config["save_path"], "train")) + process_one_chinese_pair(dev_file, os.path.join(config["save_path"], "dev")) + process_one_chinese_pair(test_file, + os.path.join(config["save_path"], "test")) + + shutil.copy( + os.path.join(config["raw_path"], config["punc_file"]), + os.path.join(config["save_path"], config["punc_file"])) + + +english_punc = [',', '.', '?'] +ignore_english_punc = ['\"', '/'] + + +def process_one_file_english(raw_path, save_path): + f = open(raw_path, 'r', encoding='utf-8') + save_file = open(save_path, 'w', encoding='utf-8') + for line in f.readlines(): + for i in ignore_english_punc: + line = line.replace(i, '') + for i in english_punc: + line = line.replace(i, ' ' + i) + wordlist = line.strip().split(' ') + # print(type(wordlist)) + # print(wordlist) + for i in wordlist: + save_file.write(i + ' ') + save_file.write('\n') + save_file.close() + + +def process_english_pure_senetence(config): + ####need raw_path, raw_train_file, raw_dev_file, raw_test_file, punc_file, save_path + assert os.path.exists( + os.path.join(config["raw_path"], config[ + "raw_train_file"])), "train file doesn't exist." + assert os.path.exists( + os.path.join(config["raw_path"], config[ + "raw_dev_file"])), "dev file doesn't exist." + assert os.path.exists( + os.path.join(config["raw_path"], config[ + "raw_test_file"])), "test file doesn't exist." + assert os.path.exists( + os.path.join(config["raw_path"], config[ + "punc_file"])), "punc file doesn't exist." 
+ + train_file = os.path.join(config["raw_path"], config["raw_train_file"]) + dev_file = os.path.join(config["raw_path"], config["raw_dev_file"]) + test_file = os.path.join(config["raw_path"], config["raw_test_file"]) + if not os.path.exists(config["save_path"]): + os.makedirs(config["save_path"]) + + shutil.copy( + os.path.join(config["raw_path"], config["punc_file"]), + os.path.join(config["save_path"], config["punc_file"])) + + process_one_file_english(train_file, + os.path.join(config["save_path"], "train")) + process_one_file_english(dev_file, os.path.join(config["save_path"], "dev")) + process_one_file_english(test_file, + os.path.join(config["save_path"], "test")) diff --git a/text_processing/speechtask/punctuation_restoration/utils/utility.py b/text_processing/speechtask/punctuation_restoration/utils/utility.py new file mode 100644 index 0000000000000000000000000000000000000000..64570026bce8cd766a75e459aeabd7cae4b33a18 --- /dev/null +++ b/text_processing/speechtask/punctuation_restoration/utils/utility.py @@ -0,0 +1,81 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains common utility functions.""" +import distutils.util +import math +import os +from typing import List + +__all__ = ['print_arguments', 'add_arguments', "log_add"] + + +def print_arguments(args, info=None): + """Print argparse's arguments. + + Usage: + + .. code-block:: python + + parser = argparse.ArgumentParser() + parser.add_argument("name", default="Jonh", type=str, help="User name.") + args = parser.parse_args() + print_arguments(args) + + :param args: Input argparse.Namespace for printing. + :type args: argparse.Namespace + """ + filename = "" + if info: + filename = info["__file__"] + filename = os.path.basename(filename) + print(f"----------- {filename} Configuration Arguments -----------") + for arg, value in sorted(vars(args).items()): + print("%s: %s" % (arg, value)) + print("-----------------------------------------------------------") + + +def add_arguments(argname, type, default, help, argparser, **kwargs): + """Add argparse's argument. + + Usage: + + .. 
code-block:: python
+
+        parser = argparse.ArgumentParser()
+        add_argument("name", str, "John", "User name.", parser)
+        args = parser.parse_args()
+    """
+    type = distutils.util.strtobool if type == bool else type
+    argparser.add_argument(
+        "--" + argname,
+        default=default,
+        type=type,
+        help=help + ' Default: %(default)s.',
+        **kwargs)
+
+
+def log_add(args: List[int]) -> float:
+    """Stable log add
+
+    Args:
+        args (List[int]): log scores
+
+    Returns:
+        float: sum of log scores
+    """
+    if all(a == -float('inf') for a in args):
+        return -float('inf')
+    a_max = max(args)
+    lsp = math.log(sum(math.exp(a - a_max) for a in args))
+    return a_max + lsp
diff --git a/tools/Makefile b/tools/Makefile
index 87107a534baa3738fd94a96126ef957a4c268b41..e2aba8feba75ebc6da5dedf3f9f857b1db1d7171 100644
--- a/tools/Makefile
+++ b/tools/Makefile
@@ -24,7 +24,7 @@ clean:
 apt.done:
 	apt update -y
-	apt install -y bc flac jq vim tig tree pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
+	apt install -y bc flac jq vim tig tree pkg-config libsndfile1 libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
 	echo "check_certificate = off" >> ~/.wgetrc
 	touch apt.done
diff --git a/tools/extras/README.md b/tools/extras/README.md
index 19c06a1342bb2949f6fd15e31aae8add924f7c8d..7d03c4bee519970d712c98d902a7a4f47913138c 100644
--- a/tools/extras/README.md
+++ b/tools/extras/README.md
@@ -1,3 +1,7 @@
+# install scripts
+Call these from the `tools` dir.
+
+## Details
 1. kaldi deps
 gcc, mkl or openblas
diff --git a/tools/extras/install_soundfile.sh b/tools/extras/install_soundfile.sh
new file mode 100755
index 0000000000000000000000000000000000000000..cbc4e00d167a941ca3adb59371de3d93014b67de
--- /dev/null
+++ b/tools/extras/install_soundfile.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+# install package libsndfile
+
+WGET="wget --no-check-certificate"
+
+SOUNDFILE=libsndfile-1.0.28
+SOUNDFILE_LIB=${SOUNDFILE}.tar.gz
+
+echo "Install package libsndfile into default system path."
+test -e ${SOUNDFILE_LIB} || ${WGET} -c "http://www.mega-nerd.com/libsndfile/files/${SOUNDFILE_LIB}"
+if [ $? != 0 ]; then
+    echo "Download ${SOUNDFILE_LIB} failed !!!"
+    exit 1
+fi
+
+tar -zxvf ${SOUNDFILE_LIB}
+pushd ${SOUNDFILE}
+./configure > /dev/null && make > /dev/null && make install > /dev/null
+popd
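The `log_add` helper in `utility.py` above is the standard log-sum-exp trick; a quick numeric check (sketch):

```python
import math

from speechtask.punctuation_restoration.utils.utility import log_add

# log_add([log a, log b]) == log(a + b), computed stably by factoring out
# the max: log(sum_i exp(a_i)) = a_max + log(sum_i exp(a_i - a_max))
print(log_add([math.log(0.5), math.log(0.25)]))  # ~= math.log(0.75)

# still finite where a naive exp() would underflow to zero
print(log_add([-1000.0, -1000.0]))  # ~= -1000 + math.log(2)
```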
Acoustic Model Aishell 2 Conv + 5 LSTM layers with only forward direction 2 Conv + 5 LSTM layers with only forward direction Ds2 Online Aishell Model Text Frontend - chinese-fronted + chinese-frontend
Tacotron2 LJSpeech - tacotron2-vctk + tacotron2-vctk
TransformerTTS - transformer-ljspeech + transformer-ljspeech
SpeedySpeech CSMSC - speedyspeech-csmsc + speedyspeech-csmsc
FastSpeech2 AISHELL-3 - fastspeech2-aishell3 + fastspeech2-aishell3
VCTK fastspeech2-vctk fastspeech2-vctk
LJSpeech fastspeech2-ljspeech fastspeech2-ljspeech
CSMSC - fastspeech2-csmsc + fastspeech2-csmsc
WaveFlow LJSpeech - waveflow-ljspeech + waveflow-ljspeech
Parallel WaveGAN LJSpeech - PWGAN-ljspeech + PWGAN-ljspeech
VCTK - PWGAN-vctk + PWGAN-vctk
CSMSC - PWGAN-csmsc + PWGAN-csmsc
GE2E AISHELL-3, etc. - ge2e + ge2e
GE2E + Tacotron2 AISHELL-3 - ge2e-tactron2-aishell3 + ge2e-tacotron2-aishell3