Commit 466672e1 authored by Hui Zhang

no_sync if paddle support else nullcontext

Parent b4e16eb8
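The same context-selection block recurs in each trainer below. As a minimal sketch of the pattern this commit introduces, assuming a Paddle-style training loop with a `parallel` flag and an `accum_grad` setting (the function signature and parameter names here are illustrative, not the repository's actual code):

    # A minimal sketch of the guarded no_sync pattern; the signature,
    # `parallel` flag, and `accum_grad` setting are illustrative
    # assumptions, not the repository's exact trainer code.
    from contextlib import nullcontext

    def train_batch(batch_index, batch, model, optimizer, parallel, accum_grad):
        # paddle.DataParallel exposes `no_sync()` to skip the gradient
        # all-reduce while accumulating; a plain model (CPU, single
        # process, or a Paddle version without `no_sync`) does not, so
        # fall back to `nullcontext` when the attribute is missing.
        if accum_grad > 1:
            context = model.no_sync if (hasattr(model, "no_sync") and
                                        parallel) else nullcontext
        else:
            # Single-GPU training, or the batch on which DDP should
            # synchronize the accumulated gradients.
            context = nullcontext

        with context():
            audio, audio_len, text, text_len = batch
            loss = model(audio, audio_len, text, text_len)
            (loss / accum_grad).backward()

        # Update parameters only once every `accum_grad` batches.
        if (batch_index + 1) % accum_grad == 0:
            optimizer.step()
            optimizer.clear_grad()

With this guard, the same code path works both under paddle.DataParallel and on CPU or single-process runs, where the model is not wrapped and therefore has no `no_sync` method.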
@@ -87,7 +87,8 @@ class DeepSpeech2Trainer(Trainer):
             # Disable gradient synchronizations across DDP processes.
             # Within this context, gradients will be accumulated on module
             # variables, which will later be synchronized.
-            context = self.model.no_sync
+            context = self.model.no_sync if (hasattr(self.model, "no_sync") and
+                                             self.parallel) else nullcontext
         else:
             # Used for single gpu training and DDP gradient synchronization
             # processes.

@@ -106,7 +106,8 @@ class U2Trainer(Trainer):
             # Within this context, gradients will be accumulated on module
             # variables, which will later be synchronized.
             # When using cpu w/o DDP, model does not have `no_sync`
-            context = self.model.no_sync if self.parallel else nullcontext
+            context = self.model.no_sync if (hasattr(self.model, "no_sync") and
+                                             self.parallel) else nullcontext
         else:
             # Used for single gpu training and DDP gradient synchronization
             # processes.

@@ -105,7 +105,8 @@ class U2Trainer(Trainer):
             # Disable gradient synchronizations across DDP processes.
             # Within this context, gradients will be accumulated on module
             # variables, which will later be synchronized.
-            context = self.model.no_sync
+            context = self.model.no_sync if (hasattr(self.model, "no_sync") and
+                                             self.parallel) else nullcontext
         else:
             # Used for single gpu training and DDP gradient synchronization
             # processes.

@@ -110,7 +110,8 @@ class U2STTrainer(Trainer):
             # Disable gradient synchronizations across DDP processes.
             # Within this context, gradients will be accumulated on module
             # variables, which will later be synchronized.
-            context = self.model.no_sync
+            context = self.model.no_sync if (hasattr(self.model, "no_sync") and
+                                             self.parallel) else nullcontext
         else:
             # Used for single gpu training and DDP gradient synchronization
             # processes.