diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000000000000000000000000000000000000..b31d9863116691dee0fe8ba7a9c1e002c8b068a6 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,42 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: '' +assignees: '' + +--- + +For support and discussions, please use our [Discourse forums](https://github.com/PaddlePaddle/DeepSpeech/discussions). + +If you've found a bug then please create an issue with the following information: + +**Describe the bug** +A clear and concise description of what the bug is. + +**To Reproduce** +Steps to reproduce the behavior: +1. Go to '...' +2. Click on '....' +3. Scroll down to '....' +4. See error + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Screenshots** +If applicable, add screenshots to help explain your problem. + +** Environment (please complete the following information):** + - OS: [e.g. Ubuntu] + - GCC/G++ Version [e.g. 8.3] + - Python Version [e.g. 3.7] + - PaddlePaddle Version [e.g. 2.0.0] + - Model Version [e.g. 2.0.0] + - GPU/DRIVER Informationo [e.g. Tesla V100-SXM2-32GB/440.64.00] + - CUDA/CUDNN Version [e.g. cuda-10.2] + - MKL Version +- TensorRT Version + +**Additional context** +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000000000000000000000000000000000000..94d507035946018f860f989aa0b1d537f1245d4f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,24 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: "[Feature request]" +labels: feature request +assignees: '' + +--- + +For support and discussions, please use our [Discourse forums](https://github.com/PaddlePaddle/DeepSpeech/discussions). + +If you've found a feature request then please create an issue with the following information: + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. diff --git a/.travis/unittest.sh b/.travis/unittest.sh index c152a1bc57e156f4a1856d6dc9aa0d787d929bc9..416042c8cbe2715b1f8fff9aab690d98a0fd9c35 100755 --- a/.travis/unittest.sh +++ b/.travis/unittest.sh @@ -11,6 +11,13 @@ abort(){ unittest(){ cd $1 > /dev/null + if [ -f "setup.sh" ]; then + bash setup.sh + export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH + fi + if [ $? != 0 ]; then + exit 1 + fi find . -path ./tools/venv -prune -false -o -name 'tests' -type d -print0 | \ xargs -0 -I{} -n1 bash -c \ 'python3 -m unittest discover -v -s {}' @@ -19,6 +26,15 @@ unittest(){ coverage(){ cd $1 > /dev/null + + if [ -f "setup.sh" ]; then + bash setup.sh + export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH + fi + if [ $? != 0 ]; then + exit 1 + fi + find . -path ./tools/venv -prune -false -o -name 'tests' -type d -print0 | \ xargs -0 -I{} -n1 bash -c \ 'python3 -m coverage run --branch {}' diff --git a/README.md b/README.md index f188814996d8149876011ea8234aa2e5bfb53125..9330c005f15c385d2559cfc0bb295e6abf26a98e 100644 --- a/README.md +++ b/README.md @@ -21,26 +21,13 @@ * python>=3.7 * paddlepaddle>=2.0.0 -- Run the setup script for the remaining dependencies - -```bash -git clone https://github.com/PaddlePaddle/DeepSpeech.git -cd DeepSpeech -pushd tools; make; popd -source tools/venv/bin/activate -bash setup.sh -``` - -- Source venv before do experiment. - -```bash -source tools/venv/bin/activate -``` +Please see [install](docs/install.md). ## Getting Started Please see [Getting Started](docs/src/geting_started.md) and [tiny egs](examples/tiny/README.md). + ## More Information * [Install](docs/src/install.md) @@ -56,7 +43,7 @@ Please see [Getting Started](docs/src/geting_started.md) and [tiny egs](examples ## Questions and Help -You are welcome to submit questions and bug reports in [Github Issues](https://github.com/PaddlePaddle/DeepSpeech/issues). You are also welcome to contribute to this project. +You are welcome to submit questions in [Github Discussions](https://github.com/PaddlePaddle/DeepSpeech/discussions) and bug reports in [Github Issues](https://github.com/PaddlePaddle/DeepSpeech/issues). You are also welcome to contribute to this project. ## License diff --git a/README_cn.md b/README_cn.md index 769130472f20700a0ae451ae52c63ddbe7182c1f..b50d205e904eaee416402826f284ad5ef15f5a34 100644 --- a/README_cn.md +++ b/README_cn.md @@ -18,24 +18,11 @@ ## 安装 + * python>=3.7 * paddlepaddle>=2.0.0 -- 安装依赖 - -```bash -git clone https://github.com/PaddlePaddle/DeepSpeech.git -cd DeepSpeech -pushd tools; make; popd -source tools/venv/bin/activate -bash setup.sh -``` - -- 开始实验前要source环境. - -```bash -source tools/venv/bin/activate -``` +参看 [安装](docs/install.md)。 ## 开始 @@ -55,7 +42,7 @@ source tools/venv/bin/activate ## 问题和帮助 -欢迎您在[Github问题](https://github.com/PaddlePaddle/models/issues)中提交问题和bug。也欢迎您为这个项目做出贡献。 +欢迎您在[Github讨论](https://github.com/PaddlePaddle/DeepSpeech/discussions)提交问题,[Github问题](https://github.com/PaddlePaddle/models/issues)中反馈bug。也欢迎您为这个项目做出贡献。 ## License diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 75335d318f7891e8c37bd98fe11173e0d894e607..e3d6369bb4731447497015f419f58103661c8e55 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -58,8 +58,6 @@ class DeepSpeech2Trainer(Trainer): losses_np = { 'train_loss': float(loss), - 'train_loss_div_batchsize': - float(loss) / self.config.data.batch_size } msg = "Train: Rank: {}, ".format(dist.get_rank()) msg += "epoch: {}, ".format(self.epoch) @@ -85,8 +83,6 @@ class DeepSpeech2Trainer(Trainer): loss = self.model(*batch) valid_losses['val_loss'].append(float(loss)) - valid_losses['val_loss_div_batchsize'].append( - float(loss) / self.config.data.batch_size) # write visual log valid_losses = {k: np.mean(v) for k, v in valid_losses.items()} @@ -265,7 +261,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): self.logger.info(msg) def run_test(self): - self.resume_or_load() + self.resume_or_scratch() try: self.test() except KeyboardInterrupt: diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index b5f90f04665f8a596d2cc71589c21bd6c59348aa..38e24cef58c477ed5c9867fdbe35b3ee7960f2da 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -80,12 +80,15 @@ class U2Trainer(Trainer): self.model.train() start = time.time() + loss, attention_loss, ctc_loss = self.model(*batch_data) loss.backward() layer_tools.print_grads(self.model, print_func=None) + if self.iteration % train_conf.accum_grad == 0: self.optimizer.step() self.optimizer.clear_grad() + self.lr_scheduler.step() iteration_time = time.time() - start @@ -102,11 +105,49 @@ class U2Trainer(Trainer): if self.iteration % train_conf.log_interval == 0: self.logger.info(msg) + # display if dist.get_rank() == 0 and self.visualizer: for k, v in losses_np.items(): self.visualizer.add_scalar("train/{}".format(k), v, self.iteration) + def train(self): + """The training process. + It includes forward/backward/update and periodical validation and + saving. + """ + # !!!IMPORTANT!!! + # Try to export the model by script, if fails, we should refine + # the code to satisfy the script export requirements + # script_model = paddle.jit.to_static(self.model) + # script_model_path = str(self.checkpoint_dir / 'init') + # paddle.jit.save(script_model, script_model_path) + + from_scratch = self.resume_or_scratch() + self.logger.info( + f"Train Total Examples: {len(self.train_loader.dataset)}") + + while self.epoch <= self.config.training.n_epoch: + try: + data_start_time = time.time() + for batch in self.train_loader: + dataload_time = time.time() - data_start_time + msg = "Train: Rank: {}, ".format(dist.get_rank()) + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "lr: {}, ".foramt(self.lr_scheduler()) + msg += "dataloader time: {:>.3f}s, ".format(dataload_time) + self.iteration += 1 + self.train_batch(batch, msg) + data_start_time = time.time() + except Exception as e: + self.logger.error(e) + raise e + + self.valid() + self.save() + self.new_epoch() + @mp_tools.rank_zero_only @paddle.no_grad() def valid(self): @@ -365,7 +406,7 @@ class U2Tester(U2Trainer): self.logger.info(msg) def run_test(self): - self.resume_or_load() + self.resume_or_scratch() try: self.test() except KeyboardInterrupt: diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py index 7cba2f2cb8c23d3d48fa71ffd2cd79041745b017..884fa4b1f1428a9285eb70336020e5337f5f1000 100644 --- a/deepspeech/models/deepspeech2.py +++ b/deepspeech/models/deepspeech2.py @@ -159,7 +159,8 @@ class DeepSpeech2Model(nn.Layer): enc_n_units=self.encoder.output_size, blank_id=dict_size, # last token is dropout_rate=0.0, - reduction=True) + reduction=True, # sum + batch_average=True) # sum / batch_size def forward(self, audio, audio_len, text, text_len): """Compute Model loss diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py index 16573a38fdb0cc0fa3bd5061040739a875e67a68..90f3e32272929b91275788769a6f2064c5f86620 100644 --- a/deepspeech/models/u2.py +++ b/deepspeech/models/u2.py @@ -834,7 +834,14 @@ class U2Model(U2BaseModel): decoder = TransformerDecoder(vocab_size, encoder.output_size(), **configs['decoder_conf']) - ctc = CTCDecoder(vocab_size, encoder.output_size()) + ctc = CTCDecoder( + odim=vocab_size, + enc_n_units=encoder.output_size(), + blank_id=0, + dropout_rate=0.0, + reduction=True, # sum + batch_average=True) # sum / batch_size + return vocab_size, encoder, decoder, ctc @classmethod diff --git a/deepspeech/modules/ctc.py b/deepspeech/modules/ctc.py index 64508a74d18183008fe44d88a61f657191b65cd2..1cd7a3c850a12989ac70acade0d4c91c69ebbb67 100644 --- a/deepspeech/modules/ctc.py +++ b/deepspeech/modules/ctc.py @@ -37,14 +37,16 @@ class CTCDecoder(nn.Layer): enc_n_units, blank_id=0, dropout_rate: float=0.0, - reduction: bool=True): + reduction: bool=True, + batch_average: bool=True): """CTC decoder Args: odim ([int]): text vocabulary size enc_n_units ([int]): encoder output dimention dropout_rate (float): dropout rate (0.0 ~ 1.0) - reduction (bool): reduce the CTC loss into a scalar + reduction (bool): reduce the CTC loss into a scalar, True for 'sum' or 'none' + batch_average (bool): do batch dim wise average. """ assert check_argument_types() super().__init__() @@ -54,7 +56,10 @@ class CTCDecoder(nn.Layer): self.dropout_rate = dropout_rate self.ctc_lo = nn.Linear(enc_n_units, self.odim) reduction_type = "sum" if reduction else "none" - self.criterion = CTCLoss(blank=self.blank_id, reduction=reduction_type) + self.criterion = CTCLoss( + blank=self.blank_id, + reduction=reduction_type, + batch_average=batch_average) # CTCDecoder LM Score handle self._ext_scorer = None diff --git a/deepspeech/modules/loss.py b/deepspeech/modules/loss.py index cb65ba1403d1894fd59ec807169fc7288680c29b..95ca644ad5f4ae4ced28b7b4cc499a250fcf63bc 100644 --- a/deepspeech/modules/loss.py +++ b/deepspeech/modules/loss.py @@ -24,32 +24,33 @@ __all__ = ['CTCLoss', "LabelSmoothingLoss"] class CTCLoss(nn.Layer): - def __init__(self, blank=0, reduction='sum'): + def __init__(self, blank=0, reduction='sum', batch_average=False): super().__init__() # last token id as blank id self.loss = nn.CTCLoss(blank=blank, reduction=reduction) + self.batch_average = batch_average def forward(self, logits, ys_pad, hlens, ys_lens): """Compute CTC loss. Args: - logits ([paddle.Tensor]): [description] - ys_pad ([paddle.Tensor]): [description] - hlens ([paddle.Tensor]): [description] - ys_lens ([paddle.Tensor]): [description] + logits ([paddle.Tensor]): [B, Tmax, D] + ys_pad ([paddle.Tensor]): [B, Tmax] + hlens ([paddle.Tensor]): [B] + ys_lens ([paddle.Tensor]): [B] Returns: [paddle.Tensor]: scalar. If reduction is 'none', then (N), where N = \text{batch size}. """ + B = paddle.shape(logits)[0] # warp-ctc need logits, and do softmax on logits by itself # warp-ctc need activation with shape [T, B, V + 1] # logits: (B, L, D) -> (L, B, D) logits = logits.transpose([1, 0, 2]) loss = self.loss(logits, ys_pad, hlens, ys_lens) - - # wenet do batch-size average, deepspeech2 not do this - # Batch-size average - # loss = loss / paddle.shape(logits)[1] + if self.batch_average: + # Batch-size average + loss = loss / B return loss diff --git a/deepspeech/training/scheduler.py b/deepspeech/training/scheduler.py index 54103e06105f69498d1cde1c4c1329fc088424bf..08e9d4121cc5abe95954d1009106b78d718046d0 100644 --- a/deepspeech/training/scheduler.py +++ b/deepspeech/training/scheduler.py @@ -54,4 +54,4 @@ class WarmupLR(LRScheduler): step_num**-0.5, step_num * self.warmup_steps**-1.5) def set_step(self, step: int): - self.last_epoch = step + self.step(step) diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 982faa989dfc0956f709670aea9b5760c15ea698..474f8d728444a47a62e8870a0a4c3126620d2597 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -139,7 +139,7 @@ class Trainer(): checkpoint.save_parameters(self.checkpoint_dir, self.iteration, self.model, self.optimizer, infos) - def resume_or_load(self): + def resume_or_scratch(self): """Resume from latest checkpoint at checkpoints in the output directory or load a specified checkpoint. @@ -152,8 +152,20 @@ class Trainer(): checkpoint_dir=self.checkpoint_dir, checkpoint_path=self.args.checkpoint_path) if infos: + # restore from ckpt self.iteration = infos["step"] self.epoch = infos["epoch"] + self.lr_scheduler.step(self.iteration) + if self.parallel: + self.train_loader.batch_sampler.set_epoch(self.epoch) + return False + else: + # from scratch, epoch and iteration init with zero + # save init model, i.e. 0 epoch + self.save() + # self.epoch start from 1. + self.new_epoch() + return True def new_epoch(self): """Reset the train loader and increment ``epoch``. @@ -166,22 +178,22 @@ class Trainer(): def train(self): """The training process. - It includes forward/backward/update and periodical validation and - saving. """ + from_scratch = self.resume_or_scratch() + self.logger.info( f"Train Total Examples: {len(self.train_loader.dataset)}") - self.new_epoch() while self.epoch <= self.config.training.n_epoch: try: data_start_time = time.time() for batch in self.train_loader: dataload_time = time.time() - data_start_time + # iteration start from 1. + self.iteration += 1 msg = "Train: Rank: {}, ".format(dist.get_rank()) msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) msg += "dataloader time: {:>.3f}s, ".format(dataload_time) - self.iteration += 1 self.train_batch(batch, msg) data_start_time = time.time() except Exception as e: @@ -190,6 +202,7 @@ class Trainer(): self.valid() self.save() + # lr control by epoch self.lr_scheduler.step() self.new_epoch() @@ -197,7 +210,6 @@ class Trainer(): """The routine of the experiment after setup. This method is intended to be used by the user. """ - self.resume_or_load() try: self.train() except KeyboardInterrupt: @@ -298,7 +310,7 @@ class Trainer(): # global logger stdout = False - save_path = log_file + save_path = str(log_file) logging.basicConfig( level=logging.DEBUG if stdout else logging.INFO, format=format, diff --git a/docs/src/geting_started.md b/docs/src/getting_started.md similarity index 100% rename from docs/src/geting_started.md rename to docs/src/getting_started.md diff --git a/docs/src/install.md b/docs/src/install.md index 72b7b6988f6093066010d59cb37b651c42a4d22f..01049a2fc5352c8b692e9c607e4a064a562e3623 100644 --- a/docs/src/install.md +++ b/docs/src/install.md @@ -45,7 +45,7 @@ source tools/venv/bin/activate ## Running in Docker Container (optional) -Docker is an open source tool to build, ship, and run distributed applications in an isolated environment. A Docker image for this project has been provided in [hub.docker.com](https://hub.docker.com) with all the dependencies installed, including the pre-built PaddlePaddle, CTC decoders, and other necessary Python and third-party packages. This Docker image requires the support of NVIDIA GPU, so please make sure its availiability and the [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) has been installed. +Docker is an open source tool to build, ship, and run distributed applications in an isolated environment. A Docker image for this project has been provided in [hub.docker.com](https://hub.docker.com) with all the dependencies installed. This Docker image requires the support of NVIDIA GPU, so please make sure its availiability and the [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) has been installed. Take several steps to launch the Docker image: @@ -79,3 +79,7 @@ For example, for CUDA 10.1, CuDNN7.5 install paddle 2.0.0: ```bash python3 -m pip install paddlepaddle-gpu==2.0.0 ``` + +- Install Deepspeech + +Please see [Setup](#setup) section. diff --git a/docs/src/ngram_lm.md b/docs/src/ngram_lm.md index 48c557ce93515a15439d4e8572e251763b0da450..1417d329e8650b806b858f1f0e733cb57fd64c7b 100644 --- a/docs/src/ngram_lm.md +++ b/docs/src/ngram_lm.md @@ -1,6 +1,8 @@ # Prepare Language Model -A language model is required to improve the decoder's performance. We have prepared two language models (with lossy compression) for users to download and try. One is for English and the other is for Mandarin. Users can simply run this to download the preprared language models: +A language model is required to improve the decoder's performance. We have prepared two language models (with lossy compression) for users to download and try. One is for English and the other is for Mandarin. The bash script to download LM is example's `local/download_lm_*.sh`. + +For example, users can simply run this to download the preprared mandarin language models: ```bash cd examples/aishell @@ -8,7 +10,9 @@ source path.sh bash local/download_lm_ch.sh ``` -If you wish to train your own better language model, please refer to [KenLM](https://github.com/kpu/kenlm) for tutorials. Here we provide some tips to show how we preparing our English and Mandarin language models. You can take it as a reference when you train your own. +If you wish to train your own better language model, please refer to [KenLM](https://github.com/kpu/kenlm) for tutorials. +Here we provide some tips to show how we preparing our English and Mandarin language models. +You can take it as a reference when you train your own. ## English LM diff --git a/examples/aishell/.gitignore b/examples/aishell/.gitignore index 389676a70b38c845d7cc16d93166a7400084fb37..3c13afe8afdd8128054fac130d4d810c14b7bc33 100644 --- a/examples/aishell/.gitignore +++ b/examples/aishell/.gitignore @@ -2,3 +2,4 @@ data ckpt* demo_cache *.log +log diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml index e06ae0239e8a3a746576ce0863d135dc0400dc0c..5a386b985cf47f273d140ab959773a3729fe0238 100644 --- a/examples/aishell/s0/conf/deepspeech2.yaml +++ b/examples/aishell/s0/conf/deepspeech2.yaml @@ -29,8 +29,8 @@ model: use_gru: True share_rnn_weights: False training: - n_epoch: 30 - lr: 5e-4 + n_epoch: 50 + lr: 2e-3 lr_decay: 0.83 weight_decay: 1e-06 global_grad_clip: 5.0 @@ -39,7 +39,7 @@ decoding: error_rate_type: cer decoding_method: ctc_beam_search lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm - alpha: 2.6 + alpha: 1.9 beta: 5.0 beam_size: 300 cutoff_prob: 0.99 diff --git a/examples/aishell/s0/local/infer.sh b/examples/aishell/s0/local/infer.sh index 41ccabf803f55975866f41a47dc410abb5edca9a..8c6a4dca28a40dada801caa846a4fbb39448fdfb 100644 --- a/examples/aishell/s0/local/infer.sh +++ b/examples/aishell/s0/local/infer.sh @@ -1,6 +1,6 @@ #! /usr/bin/env bash -if [[ $# != 1 ]]; +if [[ $# != 1 ]]; then echo "usage: $0 ckpt-path" exit -1 fi diff --git a/examples/aishell/s0/local/train.sh b/examples/aishell/s0/local/train.sh index c286566a8d1e5a68279916d8d182ef72531afddc..245ed2172d0e1ec54f99ab4f675fe8298ed062d9 100644 --- a/examples/aishell/s0/local/train.sh +++ b/examples/aishell/s0/local/train.sh @@ -2,7 +2,7 @@ # train model # if you wish to resume from an exists model, uncomment --init_from_pretrained_model -export FLAGS_sync_nccl_allreduce=0 +#export FLAGS_sync_nccl_allreduce=0 ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));') echo "using $ngpu gpus..." diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh index 8beb6bf0f5459d1f8a0a0eef1bde180e1564a06c..2e215a999211201548ffe182235054d42ddf733c 100644 --- a/examples/aishell/s0/run.sh +++ b/examples/aishell/s0/run.sh @@ -7,7 +7,7 @@ source path.sh bash ./local/data.sh # train model -CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./local/train.sh +CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./local/train.sh baseline # test model CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh @@ -16,4 +16,4 @@ CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh CUDA_VISIBLE_DEVICES=0 bash ./local/infer.sh ckpt/checkpoints/step-3284 # export model -bash ./local/export.sh ckpt/checkpoints/step-3284 jit.model \ No newline at end of file +bash ./local/export.sh ckpt/checkpoints/step-3284 jit.model diff --git a/examples/librispeech/README.md b/examples/librispeech/README.md index e109e1ae449a5da457298e5f635ad85632283eca..697cb91d46d03ba7a62707d0eb68f4a1048665ff 100644 --- a/examples/librispeech/README.md +++ b/examples/librispeech/README.md @@ -1 +1 @@ -* s0 for deepspeech2 +* s0 is for deepspeech diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml index 81313e611caefd73f6123dbe487b659ac2063214..2be8f78a9cf5e9c8c1ffc803e22f599b69ede4cd 100644 --- a/examples/librispeech/s0/conf/deepspeech2.yaml +++ b/examples/librispeech/s0/conf/deepspeech2.yaml @@ -29,8 +29,8 @@ model: use_gru: False share_rnn_weights: True training: - n_epoch: 20 - lr: 5e-4 + n_epoch: 50 + lr: 1e-3 lr_decay: 0.83 weight_decay: 1e-06 global_grad_clip: 5.0 @@ -39,7 +39,7 @@ decoding: error_rate_type: wer decoding_method: ctc_beam_search lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm - alpha: 2.5 + alpha: 1.9 beta: 0.3 beam_size: 500 cutoff_prob: 1.0 diff --git a/examples/librispeech/s0/local/infer.sh b/examples/librispeech/s0/local/infer.sh index 6fc8d39fc82c63b0fdcf4a5ae71a64d67ef13139..98b3b016a3038140f53db1e2f59b025e2c86f5df 100644 --- a/examples/librispeech/s0/local/infer.sh +++ b/examples/librispeech/s0/local/infer.sh @@ -1,6 +1,6 @@ #! /usr/bin/env bash -if [[ $# != 1 ]]; +if [[ $# != 1 ]];then echo "usage: $0 ckpt-path" exit -1 fi diff --git a/examples/librispeech/s0/local/train.sh b/examples/librispeech/s0/local/train.sh index 507947e9ebd4a93a4e1a4d4a6f536da5e977639f..cbccb1896b0cbc49fac5a4ee4606e58c0931b003 100644 --- a/examples/librispeech/s0/local/train.sh +++ b/examples/librispeech/s0/local/train.sh @@ -1,8 +1,9 @@ #! /usr/bin/env bash -export FLAGS_sync_nccl_allreduce=0 +#export FLAGS_sync_nccl_allreduce=0 + # https://github.com/PaddlePaddle/Paddle/pull/28484 -export NCCL_SHM_DISABLE=1 +#export NCCL_SHM_DISABLE=1 ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));') echo "using $ngpu gpus..." @@ -11,7 +12,7 @@ python3 -u ${BIN_DIR}/train.py \ --device 'gpu' \ --nproc ${ngpu} \ --config conf/deepspeech2.yaml \ ---output ckpt +--output ckpt-${1} if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/tiny/s0/local/infer.sh b/examples/tiny/s0/local/infer.sh index 1243c0d082d23362c2a6146cf6185a70dcc8ac6c..b36f9000ac2b3fbb5c2e74a45b97f2b6c84d5182 100644 --- a/examples/tiny/s0/local/infer.sh +++ b/examples/tiny/s0/local/infer.sh @@ -1,6 +1,6 @@ #! /usr/bin/env bash -if [[ $# != 1 ]]; +if [[ $# != 1 ]];then echo "usage: $0 ckpt-path" exit -1 fi diff --git a/examples/tiny/s0/local/test.sh b/examples/tiny/s0/local/test.sh index a0f200799adc0803b5a59b4422500cb25b2f43b6..8c8c278c6b8f3f334668370c4327dade5c2c6f0a 100644 --- a/examples/tiny/s0/local/test.sh +++ b/examples/tiny/s0/local/test.sh @@ -6,7 +6,6 @@ if [ $? -ne 0 ]; then exit 1 fi -CUDA_VISIBLE_DEVICES=0 \ python3 -u ${BIN_DIR}/test.py \ --device 'gpu' \ --nproc 1 \ diff --git a/examples/tiny/s0/local/train.sh b/examples/tiny/s0/local/train.sh index 369ccc924bb104b7c638be25059b3915ae79d0e0..af62ae55f7a8b833ae42fa3b64dd2c1da8786990 100644 --- a/examples/tiny/s0/local/train.sh +++ b/examples/tiny/s0/local/train.sh @@ -2,7 +2,6 @@ export FLAGS_sync_nccl_allreduce=0 -CUDA_VISIBLE_DEVICES=0 \ python3 -u ${BIN_DIR}/train.py \ --device 'gpu' \ --nproc 1 \ diff --git a/setup.sh b/setup.sh index a58bd796740bc722b69fc7ef1558645b1213af1d..5141fd904992d9510d2819d50cd5e83efe789365 100644 --- a/setup.sh +++ b/setup.sh @@ -1,5 +1,8 @@ #! /usr/bin/env bash +source utils/log.sh + + SUDO='sudo' if [ $(id -u) -eq 0 ]; then SUDO='' @@ -8,6 +11,8 @@ fi if [ -e /etc/lsb-release ];then #${SUDO} apt-get update ${SUDO} apt-get install -y sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev + error_msg "Please using Ubuntu or install `pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev` by user." + exit -1 fi # install python dependencies @@ -15,17 +20,17 @@ if [ -f "requirements.txt" ]; then pip3 install -r requirements.txt fi if [ $? != 0 ]; then - echo "Install python dependencies failed !!!" + error_msg "Install python dependencies failed !!!" exit 1 fi # install package libsndfile python3 -c "import soundfile" if [ $? != 0 ]; then - echo "Install package libsndfile into default system path." + info_msg "Install package libsndfile into default system path." wget "http://www.mega-nerd.com/libsndfile/files/libsndfile-1.0.28.tar.gz" if [ $? != 0 ]; then - echo "Download libsndfile-1.0.28.tar.gz failed !!!" + error_msg "Download libsndfile-1.0.28.tar.gz failed !!!" exit 1 fi tar -zxvf libsndfile-1.0.28.tar.gz @@ -43,6 +48,10 @@ if [ $? != 0 ]; then sh setup.sh cd - > /dev/null fi +python3 -c "import pkg_resources; pkg_resources.require(\"swig_decoders==1.1\")" +if [ $? != 0 ]; then + error_msg "Please check why decoder install error!" + exit -1 +fi - -echo "Install all dependencies successfully." +info_msg "Install all dependencies successfully." diff --git a/utils/log.sh b/utils/log.sh new file mode 100644 index 0000000000000000000000000000000000000000..84591b076f37515c05ec7f7dd71f434c625c7eee --- /dev/null +++ b/utils/log.sh @@ -0,0 +1,11 @@ +_HDR_FMT="%.23s %s[%s]: " +_ERR_MSG_FMT="ERROR: ${_HDR_FMT}%s\n" +_INFO_MSG_FMT="INFO: ${_HDR_FMT}%s\n" + +error_msg() { + printf "$_ERR_MSG_FMT" $(date +%F.%T.%N) ${BASH_SOURCE[1]##*/} ${BASH_LINENO[0]} "${@}" +} + +info_msg() { + printf "$_INFO_MSG_FMT" $(date +%F.%T.%N) ${BASH_SOURCE[1]##*/} ${BASH_LINENO[0]} "${@}" +}