Unverified commit ba67bd9a authored by Nicholas Cilfone, committed by GitHub

Added retain_graph as a kwarg to the main engine backward function (#1149)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent 556f0051
@@ -1711,11 +1711,14 @@ class DeepSpeedEngine(Module):
     def backward(self,
                  loss,
                  allreduce_gradients=True,
                  release_loss=False,
+                 retain_graph=False,
                  scale_wrt_gas=True):
         r"""Execute backward pass on the loss
         Arguments:
             loss: Torch tensor on which to execute backward propagation
             allreduce_gradients: is deprecated, ignored, and will soon be removed
+            retain_graph: bool, default: False
+                forwards the user's choice of retain_graph to the underlying backward call
         """
         see_memory_usage("Engine before backward", force=self.memory_breakdown())
@@ -1751,9 +1754,9 @@ class DeepSpeedEngine(Module):
         self._start_timers(self.engine_timers.backward_inner_timers)
         if self.zero_optimization():
-            self.optimizer.is_gradient_accumulation_boundary = (
-                self.is_gradient_accumulation_boundary())
-            self.optimizer.backward(loss)
+            self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary(
+            )
+            self.optimizer.backward(loss, retain_graph=retain_graph)
         elif self.amp_enabled():
             # AMP requires delaying unscale when inside gradient accumulation boundaries
             # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
@@ -1761,19 +1764,19 @@ class DeepSpeedEngine(Module):
             with amp.scale_loss(loss,
                                 self.optimizer,
                                 delay_unscale=delay_unscale) as scaled_loss:
-                scaled_loss.backward()
+                scaled_loss.backward(retain_graph=retain_graph)
         elif self.fp16_enabled():
             if self.eigenvalue_enabled():
                 self.optimizer.backward(loss, create_graph=True, retain_graph=True)
             else:
-                self.optimizer.backward(loss)
+                self.optimizer.backward(loss, retain_graph=retain_graph)
         elif self.bfloat16_enabled():
             self.optimizer.backward(loss)
         else:
             if self.eigenvalue_enabled():
                 loss.backward(create_graph=True, retain_graph=True)
             else:
-                loss.backward()
+                loss.backward(retain_graph=retain_graph)
         self._stop_timers(self.engine_timers.backward_inner_timers)
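
For context, a minimal usage sketch of the new kwarg: two losses computed from a single forward pass, where the first backward call must retain the autograd graph so the second can reuse it. The model, config values, and loss functions below are illustrative assumptions, not part of the commit; the sketch also assumes a DeepSpeed version whose deepspeed.initialize accepts a config dict, and a single available CUDA device.

import torch
import deepspeed

# Hypothetical model and config, for illustration only.
model = torch.nn.Linear(10, 10)
ds_config = {
    "train_batch_size": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

x = torch.randn(1, 10, device=engine.device)
out = engine(x)
loss_a = out.pow(2).mean()
loss_b = out.abs().mean()

# The first backward keeps the graph alive; without retain_graph=True the
# second backward would raise PyTorch's "Trying to backward through the
# graph a second time" error.
engine.backward(loss_a, retain_graph=True)
engine.backward(loss_b)
engine.step()

Note that the kwarg is simply threaded through to whichever backward implementation the engine dispatches to (ZeRO, AMP, fp16, or plain loss.backward()); the bfloat16 branch in the diff above does not forward it.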