From ba67bd9a141f0277d451b876b63862c34e30bfea Mon Sep 17 00:00:00 2001
From: Nicholas Cilfone <23509131+ncilfone@users.noreply.github.com>
Date: Fri, 29 Jul 2022 10:24:25 -0400
Subject: [PATCH] Added retain_graph as a kwarg to the main engine backward
 function (#1149)

Co-authored-by: Jeff Rasley
---
 deepspeed/runtime/engine.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 9d1b8b6a..5ab14a71 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -1711,11 +1711,14 @@ class DeepSpeedEngine(Module):
                  loss,
                  allreduce_gradients=True,
                  release_loss=False,
+                 retain_graph=False,
                  scale_wrt_gas=True):
         r"""Execute backward pass on the loss
         Arguments:
             loss: Torch tensor on which to execute backward propagation
             allreduce_gradients: is deprecated, ignored, and will soon be removed'
+            retain_graph: bool, default: false
+                forward on user defined choice of retain_graph
         """

         see_memory_usage("Engine before backward", force=self.memory_breakdown())
@@ -1751,9 +1754,9 @@ class DeepSpeedEngine(Module):
         self._start_timers(self.engine_timers.backward_inner_timers)

         if self.zero_optimization():
-            self.optimizer.is_gradient_accumulation_boundary = (
-                self.is_gradient_accumulation_boundary())
-            self.optimizer.backward(loss)
+            self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary(
+            )
+            self.optimizer.backward(loss, retain_graph=retain_graph)
         elif self.amp_enabled():
             # AMP requires delaying unscale when inside gradient accumulation boundaries
             # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
@@ -1761,19 +1764,19 @@
             delay_unscale = not self.is_gradient_accumulation_boundary()
             with amp.scale_loss(loss,
                                 self.optimizer,
                                 delay_unscale=delay_unscale) as scaled_loss:
-                scaled_loss.backward()
+                scaled_loss.backward(retain_graph=retain_graph)
         elif self.fp16_enabled():
             if self.eigenvalue_enabled():
                 self.optimizer.backward(loss, create_graph=True, retain_graph=True)
             else:
-                self.optimizer.backward(loss)
+                self.optimizer.backward(loss, retain_graph=retain_graph)
         elif self.bfloat16_enabled():
             self.optimizer.backward(loss)
         else:
             if self.eigenvalue_enabled():
                 loss.backward(create_graph=True, retain_graph=True)
             else:
-                loss.backward()
+                loss.backward(retain_graph=retain_graph)

         self._stop_timers(self.engine_timers.backward_inner_timers)
--
GitLab
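
A minimal usage sketch of the new kwarg, assuming a standard deepspeed.initialize() setup; the toy model, the ds_config dict, and the two example losses are hypothetical and not part of the patch. With this change, retain_graph=True is simply forwarded to the underlying backward call so the autograd graph is kept alive for a second backward pass over the same activations.

import torch
import deepspeed

# Hypothetical toy model and minimal config, for illustration only.
model = torch.nn.Linear(8, 2)
ds_config = {
    "train_batch_size": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

x = torch.randn(1, 8).to(engine.device)
out = engine(x)

# Two losses share the same activations: the first backward keeps the graph
# alive via retain_graph so the second backward can traverse it again.
engine.backward(out.sum(), retain_graph=True)
engine.backward((out * out).sum())
engine.step()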