提交 7e136d08 编写于 作者: H Hui Zhang

support no_sync for backward; ds support accum grad

上级 41e58631
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import os import os
import time import time
from collections import defaultdict from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
...@@ -65,29 +66,51 @@ class DeepSpeech2Trainer(Trainer): ...@@ -65,29 +66,51 @@ class DeepSpeech2Trainer(Trainer):
super().__init__(config, args) super().__init__(config, args)
def train_batch(self, batch_index, batch_data, msg): def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config.training
start = time.time() start = time.time()
# forward
utt, audio, audio_len, text, text_len = batch_data utt, audio, audio_len, text, text_len = batch_data
loss = self.model(audio, audio_len, text, text_len) loss = self.model(audio, audio_len, text, text_len)
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
self.optimizer.step()
self.optimizer.clear_grad()
iteration_time = time.time() - start
losses_np = { losses_np = {
'train_loss': float(loss), 'train_loss': float(loss),
} }
# loss backward
if (batch_index + 1) % train_conf.accum_grad != 0:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context = self.model.no_sync
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
context = nullcontext
with context():
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
# optimizer step
if (batch_index + 1) % train_conf.accum_grad == 0:
self.optimizer.step()
self.optimizer.clear_grad()
self.iteration += 1
iteration_time = time.time() - start
msg += "train time: {:>.3f}s, ".format(iteration_time) msg += "train time: {:>.3f}s, ".format(iteration_time)
msg += "batch size: {}, ".format(self.config.collator.batch_size) msg += "batch size: {}, ".format(self.config.collator.batch_size)
msg += "accum: {}, ".format(train_conf.accum_grad)
msg += ', '.join('{}: {:>.6f}'.format(k, v) msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_np.items()) for k, v in losses_np.items())
logger.info(msg) logger.info(msg)
if dist.get_rank() == 0 and self.visualizer: if dist.get_rank() == 0 and self.visualizer:
for k, v in losses_np.items(): for k, v in losses_np.items():
# `step -1` since we update `step` after optimizer.step().
self.visualizer.add_scalar("train/{}".format(k), v, self.visualizer.add_scalar("train/{}".format(k), v,
self.iteration) self.iteration - 1)
self.iteration += 1
@paddle.no_grad() @paddle.no_grad()
def valid(self): def valid(self):
......
...@@ -17,6 +17,7 @@ import os ...@@ -17,6 +17,7 @@ import os
import sys import sys
import time import time
from collections import defaultdict from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
...@@ -79,21 +80,35 @@ class U2Trainer(Trainer): ...@@ -79,21 +80,35 @@ class U2Trainer(Trainer):
def train_batch(self, batch_index, batch_data, msg): def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config.training train_conf = self.config.training
start = time.time() start = time.time()
utt, audio, audio_len, text, text_len = batch_data
# forward
utt, audio, audio_len, text, text_len = batch_data
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
text_len) text_len)
# loss div by `batch_size * accum_grad` # loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad loss /= train_conf.accum_grad
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
losses_np = {'loss': float(loss) * train_conf.accum_grad} losses_np = {'loss': float(loss) * train_conf.accum_grad}
if attention_loss: if attention_loss:
losses_np['att_loss'] = float(attention_loss) losses_np['att_loss'] = float(attention_loss)
if ctc_loss: if ctc_loss:
losses_np['ctc_loss'] = float(ctc_loss) losses_np['ctc_loss'] = float(ctc_loss)
# loss backward
if (batch_index + 1) % train_conf.accum_grad != 0:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context = self.model.no_sync
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
context = nullcontext
with context():
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
# optimizer step
if (batch_index + 1) % train_conf.accum_grad == 0: if (batch_index + 1) % train_conf.accum_grad == 0:
self.optimizer.step() self.optimizer.step()
self.optimizer.clear_grad() self.optimizer.clear_grad()
......
...@@ -17,6 +17,7 @@ import os ...@@ -17,6 +17,7 @@ import os
import sys import sys
import time import time
from collections import defaultdict from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
...@@ -83,20 +84,34 @@ class U2Trainer(Trainer): ...@@ -83,20 +84,34 @@ class U2Trainer(Trainer):
train_conf = self.config.training train_conf = self.config.training
start = time.time() start = time.time()
# forward
utt, audio, audio_len, text, text_len = batch_data utt, audio, audio_len, text, text_len = batch_data
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
text_len) text_len)
# loss div by `batch_size * accum_grad` # loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad loss /= train_conf.accum_grad
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
losses_np = {'loss': float(loss) * train_conf.accum_grad} losses_np = {'loss': float(loss) * train_conf.accum_grad}
if attention_loss: if attention_loss:
losses_np['att_loss'] = float(attention_loss) losses_np['att_loss'] = float(attention_loss)
if ctc_loss: if ctc_loss:
losses_np['ctc_loss'] = float(ctc_loss) losses_np['ctc_loss'] = float(ctc_loss)
# loss backward
if (batch_index + 1) % train_conf.accum_grad != 0:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context = self.model.no_sync
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
context = nullcontext
with context():
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
# optimizer step
if (batch_index + 1) % train_conf.accum_grad == 0: if (batch_index + 1) % train_conf.accum_grad == 0:
self.optimizer.step() self.optimizer.step()
self.optimizer.clear_grad() self.optimizer.clear_grad()
......
...@@ -17,6 +17,7 @@ import os ...@@ -17,6 +17,7 @@ import os
import sys import sys
import time import time
from collections import defaultdict from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
...@@ -83,6 +84,7 @@ class U2STTrainer(Trainer): ...@@ -83,6 +84,7 @@ class U2STTrainer(Trainer):
def train_batch(self, batch_index, batch_data, msg): def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config.training train_conf = self.config.training
start = time.time() start = time.time()
# forward
utt, audio, audio_len, text, text_len = batch_data utt, audio, audio_len, text, text_len = batch_data
if isinstance(text, list) and isinstance(text_len, list): if isinstance(text, list) and isinstance(text_len, list):
# joint training with ASR. Two decoding texts [translation, transcription] # joint training with ASR. Two decoding texts [translation, transcription]
...@@ -94,18 +96,30 @@ class U2STTrainer(Trainer): ...@@ -94,18 +96,30 @@ class U2STTrainer(Trainer):
else: else:
loss, st_loss, attention_loss, ctc_loss = self.model( loss, st_loss, attention_loss, ctc_loss = self.model(
audio, audio_len, text, text_len) audio, audio_len, text, text_len)
# loss div by `batch_size * accum_grad` # loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad loss /= train_conf.accum_grad
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
losses_np = {'loss': float(loss) * train_conf.accum_grad} losses_np = {'loss': float(loss) * train_conf.accum_grad}
losses_np['st_loss'] = float(st_loss)
if attention_loss: if attention_loss:
losses_np['att_loss'] = float(attention_loss) losses_np['att_loss'] = float(attention_loss)
if ctc_loss: if ctc_loss:
losses_np['ctc_loss'] = float(ctc_loss) losses_np['ctc_loss'] = float(ctc_loss)
# loss backward
if (batch_index + 1) % train_conf.accum_grad != 0:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context = self.model.no_sync
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
context = nullcontext
with context():
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
# optimizer step
if (batch_index + 1) % train_conf.accum_grad == 0: if (batch_index + 1) % train_conf.accum_grad == 0:
self.optimizer.step() self.optimizer.step()
self.optimizer.clear_grad() self.optimizer.clear_grad()
......
...@@ -44,6 +44,7 @@ model: ...@@ -44,6 +44,7 @@ model:
training: training:
n_epoch: 80 n_epoch: 80
accum_grad: 1
lr: 2e-3 lr: 2e-3
lr_decay: 0.83 lr_decay: 0.83
weight_decay: 1e-06 weight_decay: 1e-06
......
...@@ -46,6 +46,7 @@ model: ...@@ -46,6 +46,7 @@ model:
training: training:
n_epoch: 50 n_epoch: 50
accum_grad: 1
lr: 2e-3 lr: 2e-3
lr_decay: 0.9 # 0.83 lr_decay: 0.9 # 0.83
weight_decay: 1e-06 weight_decay: 1e-06
......
...@@ -11,7 +11,7 @@ data: ...@@ -11,7 +11,7 @@ data:
max_output_input_ratio: .inf max_output_input_ratio: .inf
collator: collator:
batch_size: 20 batch_size: 15
mean_std_filepath: data/mean_std.json mean_std_filepath: data/mean_std.json
unit_type: char unit_type: char
vocab_filepath: data/vocab.txt vocab_filepath: data/vocab.txt
...@@ -44,6 +44,7 @@ model: ...@@ -44,6 +44,7 @@ model:
training: training:
n_epoch: 50 n_epoch: 50
accum_grad: 4
lr: 1e-3 lr: 1e-3
lr_decay: 0.83 lr_decay: 0.83
weight_decay: 1e-06 weight_decay: 1e-06
......
...@@ -11,7 +11,7 @@ data: ...@@ -11,7 +11,7 @@ data:
max_output_input_ratio: .inf max_output_input_ratio: .inf
collator: collator:
batch_size: 20 batch_size: 15
mean_std_filepath: data/mean_std.json mean_std_filepath: data/mean_std.json
unit_type: char unit_type: char
vocab_filepath: data/vocab.txt vocab_filepath: data/vocab.txt
...@@ -46,6 +46,7 @@ model: ...@@ -46,6 +46,7 @@ model:
training: training:
n_epoch: 50 n_epoch: 50
accum_grad: 4
lr: 1e-3 lr: 1e-3
lr_decay: 0.83 lr_decay: 0.83
weight_decay: 1e-06 weight_decay: 1e-06
......
...@@ -45,6 +45,7 @@ model: ...@@ -45,6 +45,7 @@ model:
training: training:
n_epoch: 10 n_epoch: 10
accum_grad: 1
lr: 1e-5 lr: 1e-5
lr_decay: 1.0 lr_decay: 1.0
weight_decay: 1e-06 weight_decay: 1e-06
......
...@@ -4,7 +4,7 @@ data: ...@@ -4,7 +4,7 @@ data:
dev_manifest: data/manifest.tiny dev_manifest: data/manifest.tiny
test_manifest: data/manifest.tiny test_manifest: data/manifest.tiny
min_input_len: 0.0 min_input_len: 0.0
max_input_len: 27.0 max_input_len: 30.0
min_output_len: 0.0 min_output_len: 0.0
max_output_len: 400.0 max_output_len: 400.0
min_output_input_ratio: 0.05 min_output_input_ratio: 0.05
...@@ -47,6 +47,7 @@ model: ...@@ -47,6 +47,7 @@ model:
training: training:
n_epoch: 10 n_epoch: 10
accum_grad: 1
lr: 1e-5 lr: 1e-5
lr_decay: 1.0 lr_decay: 1.0
weight_decay: 1e-06 weight_decay: 1e-06
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册