diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index f3e3fcadf99daacea13e39d0f6273e2124c0d01a..fbc357ca05c33155de91a629dbe658d2ba916f33 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -15,6 +15,7 @@
 import os
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional
 
@@ -65,29 +66,51 @@ class DeepSpeech2Trainer(Trainer):
         super().__init__(config, args)
 
     def train_batch(self, batch_index, batch_data, msg):
+        train_conf = self.config.training
         start = time.time()
+
+        # forward
         utt, audio, audio_len, text, text_len = batch_data
         loss = self.model(audio, audio_len, text, text_len)
-        loss.backward()
-        layer_tools.print_grads(self.model, print_func=None)
-        self.optimizer.step()
-        self.optimizer.clear_grad()
-        iteration_time = time.time() - start
-
         losses_np = {
             'train_loss': float(loss),
         }
+
+        # loss backward
+        if (batch_index + 1) % train_conf.accum_grad != 0:
+            # Disable gradient synchronizations across DDP processes.
+            # Within this context, gradients will be accumulated on module
+            # variables, which will later be synchronized.
+            context = self.model.no_sync
+        else:
+            # Used for single gpu training and DDP gradient synchronization
+            # processes.
+            context = nullcontext
+
+        with context():
+            loss.backward()
+            layer_tools.print_grads(self.model, print_func=None)
+
+        # optimizer step
+        if (batch_index + 1) % train_conf.accum_grad == 0:
+            self.optimizer.step()
+            self.optimizer.clear_grad()
+            self.iteration += 1
+
+        iteration_time = time.time() - start
+
         msg += "train time: {:>.3f}s, ".format(iteration_time)
         msg += "batch size: {}, ".format(self.config.collator.batch_size)
+        msg += "accum: {}, ".format(train_conf.accum_grad)
         msg += ', '.join('{}: {:>.6f}'.format(k, v)
                          for k, v in losses_np.items())
         logger.info(msg)
 
         if dist.get_rank() == 0 and self.visualizer:
             for k, v in losses_np.items():
+                # `step -1` since we update `step` after optimizer.step().
                 self.visualizer.add_scalar("train/{}".format(k), v,
-                                           self.iteration)
-        self.iteration += 1
+                                           self.iteration - 1)
 
     @paddle.no_grad()
     def valid(self):
diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py
index 0662e38d9fcdbf60ab764f8a2936f4b2006790f1..8ab9a26e83590648695ade9563d498943360d602 100644
--- a/deepspeech/exps/u2/model.py
+++ b/deepspeech/exps/u2/model.py
@@ -17,6 +17,7 @@ import os
 import sys
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional
 
@@ -79,21 +80,35 @@ class U2Trainer(Trainer):
     def train_batch(self, batch_index, batch_data, msg):
         train_conf = self.config.training
         start = time.time()
-        utt, audio, audio_len, text, text_len = batch_data
 
+        # forward
+        utt, audio, audio_len, text, text_len = batch_data
         loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
                                                     text_len)
 
+        # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
-        loss.backward()
-        layer_tools.print_grads(self.model, print_func=None)
-
         losses_np = {'loss': float(loss) * train_conf.accum_grad}
         if attention_loss:
             losses_np['att_loss'] = float(attention_loss)
         if ctc_loss:
             losses_np['ctc_loss'] = float(ctc_loss)
 
+        # loss backward
+        if (batch_index + 1) % train_conf.accum_grad != 0:
+            # Disable gradient synchronizations across DDP processes.
+            # Within this context, gradients will be accumulated on module
+            # variables, which will later be synchronized.
+            context = self.model.no_sync
+        else:
+            # Used for single gpu training and DDP gradient synchronization
+            # processes.
+            context = nullcontext
+        with context():
+            loss.backward()
+            layer_tools.print_grads(self.model, print_func=None)
+
+        # optimizer step
         if (batch_index + 1) % train_conf.accum_grad == 0:
             self.optimizer.step()
             self.optimizer.clear_grad()
diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py
index 6a932d75137b302f98eb9f8e66c402dbacc6d787..140ee947fae5eaa13e31f5002c3b6af5e3577c1a 100644
--- a/deepspeech/exps/u2_kaldi/model.py
+++ b/deepspeech/exps/u2_kaldi/model.py
@@ -17,6 +17,7 @@ import os
 import sys
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional
 
@@ -83,20 +84,34 @@
         train_conf = self.config.training
         start = time.time()
 
+        # forward
         utt, audio, audio_len, text, text_len = batch_data
         loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
                                                     text_len)
 
+        # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
-        loss.backward()
-        layer_tools.print_grads(self.model, print_func=None)
-
         losses_np = {'loss': float(loss) * train_conf.accum_grad}
         if attention_loss:
             losses_np['att_loss'] = float(attention_loss)
         if ctc_loss:
             losses_np['ctc_loss'] = float(ctc_loss)
 
+        # loss backward
+        if (batch_index + 1) % train_conf.accum_grad != 0:
+            # Disable gradient synchronizations across DDP processes.
+            # Within this context, gradients will be accumulated on module
+            # variables, which will later be synchronized.
+            context = self.model.no_sync
+        else:
+            # Used for single gpu training and DDP gradient synchronization
+            # processes.
+            context = nullcontext
+        with context():
+            loss.backward()
+            layer_tools.print_grads(self.model, print_func=None)
+
+        # optimizer step
         if (batch_index + 1) % train_conf.accum_grad == 0:
             self.optimizer.step()
             self.optimizer.clear_grad()
diff --git a/deepspeech/exps/u2_st/model.py b/deepspeech/exps/u2_st/model.py
index 5734e15f58c5fdcc843602b69475fbf60ecd006c..ef5938b7712f459a73ab7944bbb80c5ddf1fc81c 100644
--- a/deepspeech/exps/u2_st/model.py
+++ b/deepspeech/exps/u2_st/model.py
@@ -17,6 +17,7 @@ import os
 import sys
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from pathlib import Path
 from typing import Optional
 
@@ -83,6 +84,7 @@ class U2STTrainer(Trainer):
     def train_batch(self, batch_index, batch_data, msg):
         train_conf = self.config.training
         start = time.time()
+        # forward
         utt, audio, audio_len, text, text_len = batch_data
         if isinstance(text, list) and isinstance(text_len, list):
             # joint training with ASR. Two decoding texts [translation, transcription]
@@ -94,18 +96,30 @@
         else:
             loss, st_loss, attention_loss, ctc_loss = self.model(
                 audio, audio_len, text, text_len)
+        # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
-        loss.backward()
-        layer_tools.print_grads(self.model, print_func=None)
-
         losses_np = {'loss': float(loss) * train_conf.accum_grad}
-
         losses_np['st_loss'] = float(st_loss)
         if attention_loss:
             losses_np['att_loss'] = float(attention_loss)
         if ctc_loss:
             losses_np['ctc_loss'] = float(ctc_loss)
 
+        # loss backward
+        if (batch_index + 1) % train_conf.accum_grad != 0:
+            # Disable gradient synchronizations across DDP processes.
+            # Within this context, gradients will be accumulated on module
+            # variables, which will later be synchronized.
+            context = self.model.no_sync
+        else:
+            # Used for single gpu training and DDP gradient synchronization
+            # processes.
+            context = nullcontext
+        with context():
+            loss.backward()
+            layer_tools.print_grads(self.model, print_func=None)
+
+        # optimizer step
         if (batch_index + 1) % train_conf.accum_grad == 0:
             self.optimizer.step()
             self.optimizer.clear_grad()
diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml
index 0f465a8f7bca3af736bd07390a43cbc23179c1a3..4bf03ec63e67a59c0d75bb41540097b798651c3d 100644
--- a/examples/aishell/s0/conf/deepspeech2.yaml
+++ b/examples/aishell/s0/conf/deepspeech2.yaml
@@ -44,6 +44,7 @@ model:
 
 training:
   n_epoch: 80
+  accum_grad: 1
   lr: 2e-3
   lr_decay: 0.83
   weight_decay: 1e-06
diff --git a/examples/aishell/s0/conf/deepspeech2_online.yaml b/examples/aishell/s0/conf/deepspeech2_online.yaml
index 9f05d8dd8613aa9687ae405f6c90bb53b7d654a6..9946852d0325bd6221917208aac42488dee4ac19 100644
--- a/examples/aishell/s0/conf/deepspeech2_online.yaml
+++ b/examples/aishell/s0/conf/deepspeech2_online.yaml
@@ -46,6 +46,7 @@ model:
 
 training:
   n_epoch: 50
+  accum_grad: 1
   lr: 2e-3
   lr_decay: 0.9 # 0.83
   weight_decay: 1e-06
diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml
index 2c31e66e12207f0310d8dc1a214a02bd308f1a2f..0e6ed5bab1b7bc4497356a6f0f7e4c2a3c64a8eb 100644
--- a/examples/librispeech/s0/conf/deepspeech2.yaml
+++ b/examples/librispeech/s0/conf/deepspeech2.yaml
@@ -11,7 +11,7 @@ data:
   max_output_input_ratio: .inf
 
 collator:
-  batch_size: 20
+  batch_size: 15
   mean_std_filepath: data/mean_std.json
   unit_type: char
   vocab_filepath: data/vocab.txt
@@ -44,6 +44,7 @@ model:
 
 training:
   n_epoch: 50
+  accum_grad: 4
   lr: 1e-3
   lr_decay: 0.83
   weight_decay: 1e-06
diff --git a/examples/librispeech/s0/conf/deepspeech2_online.yaml b/examples/librispeech/s0/conf/deepspeech2_online.yaml
index 87445c0b47638251a9b112c0756e76979d2579d5..6e74f704260b1c9bbd43c02c5f251e59f4074a6c 100644
--- a/examples/librispeech/s0/conf/deepspeech2_online.yaml
+++ b/examples/librispeech/s0/conf/deepspeech2_online.yaml
@@ -11,7 +11,7 @@ data:
   max_output_input_ratio: .inf
 
 collator:
-  batch_size: 20
+  batch_size: 15
   mean_std_filepath: data/mean_std.json
   unit_type: char
   vocab_filepath: data/vocab.txt
@@ -46,6 +46,7 @@ model:
 
 training:
   n_epoch: 50
+  accum_grad: 4
   lr: 1e-3
   lr_decay: 0.83
   weight_decay: 1e-06
diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml
index c93217d32f93e735ed84369fa691de86f29270bc..5c9436e3941ef021a0e4609376e15eea26399d66 100644
--- a/examples/tiny/s0/conf/deepspeech2.yaml
+++ b/examples/tiny/s0/conf/deepspeech2.yaml
@@ -45,6 +45,7 @@ model:
 
 training:
   n_epoch: 10
+  accum_grad: 1
   lr: 1e-5
   lr_decay: 1.0
   weight_decay: 1e-06
diff --git a/examples/tiny/s0/conf/deepspeech2_online.yaml b/examples/tiny/s0/conf/deepspeech2_online.yaml
index 4205a04acdc831d76be2593167d679408bfe3834..e435ff9691b4905a3f136cfc122dd4f1876d4605 100644
--- a/examples/tiny/s0/conf/deepspeech2_online.yaml
+++ b/examples/tiny/s0/conf/deepspeech2_online.yaml
@@ -4,7 +4,7 @@ data:
   dev_manifest: data/manifest.tiny
   test_manifest: data/manifest.tiny
   min_input_len: 0.0
-  max_input_len: 27.0
+  max_input_len: 30.0
   min_output_len: 0.0
   max_output_len: 400.0
   min_output_input_ratio: 0.05
@@ -47,6 +47,7 @@ model:
 
 training:
   n_epoch: 10
+  accum_grad: 1
   lr: 1e-5
   lr_decay: 1.0
   weight_decay: 1e-06
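
Note: all four trainers now follow the same accumulate-then-step pattern: every backward pass except the update step runs under `model.no_sync()` so DDP skips the gradient all-reduce, and the synchronized gradients are applied and cleared only when `(batch_index + 1) % accum_grad == 0`. The snippet below is a minimal standalone sketch of that pattern, not the project's Trainer code; `model`, `optimizer`, and `batches` are hypothetical stand-ins, `nullcontext` assumes Python 3.7+, and `no_sync`/`clear_grad` are assumed to behave like the paddle.DataParallel and paddle optimizer methods used in the patch.

```python
from contextlib import nullcontext


def train(model, optimizer, batches, accum_grad=4):
    """Illustrative gradient accumulation with optional DDP no_sync."""
    for batch_index, (inputs, labels) in enumerate(batches):
        # forward; scale the loss so the summed grads match a large batch
        loss = model(inputs, labels) / accum_grad

        if (batch_index + 1) % accum_grad != 0 and hasattr(model, "no_sync"):
            # non-update step: accumulate locally, skip the DDP all-reduce
            context = model.no_sync
        else:
            # single-GPU training, or the step where gradients must sync
            context = nullcontext

        with context():
            loss.backward()

        # apply and clear the accumulated gradients every `accum_grad` batches
        if (batch_index + 1) % accum_grad == 0:
            optimizer.step()
            optimizer.clear_grad()
```

With `accum_grad: 1` (the aishell and tiny configs) this reduces to a plain per-batch update; the librispeech configs instead lower `batch_size` from 20 to 15 and set `accum_grad: 4`, giving an effective batch of 60 per card.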