Unverified · Commit a4763f55 · Authored by: carefree0910 · Committed by: GitHub

Supported customizing kwargs for lr_scheduler (#584)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent: 66268bd3
@@ -979,7 +979,7 @@ class DeepSpeedEngine(Module):
         torch.nn.utils.clip_grad_norm_(parameters=self.module.parameters(),
                                        max_norm=self.gradient_clipping())
 
-    def _take_model_step(self):
+    def _take_model_step(self, lr_kwargs):
         if self.gradient_clipping() > 0.0:
             if not self.fp16_enabled() and not self.amp_enabled():
                 self.clip_fp32_gradients()
@@ -1010,14 +1010,14 @@ class DeepSpeedEngine(Module):
             self.skipped_steps += 1
         else:
             if self.lr_scheduler is not None:
-                self.lr_scheduler.step()
+                self.lr_scheduler.step(**(lr_kwargs or {}))
 
         if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0:
             self._report_progress(self.global_steps + 1)
         self.global_steps += 1
         self.global_samples += self.train_batch_size()
 
-    def step(self):
+    def step(self, lr_kwargs=None):
         r"""Execute the weight update step after forward and backward propagation
         on effective_train_batch.
         """
@@ -1034,7 +1034,7 @@ class DeepSpeedEngine(Module):
         if self.progressive_layer_drop:
             self.progressive_layer_drop.update_state(self.global_steps)
 
-        self._take_model_step()
+        self._take_model_step(lr_kwargs)
 
         self.tput_timer.stop(report_progress)
@@ -940,14 +940,14 @@ class PipelineEngine(DeepSpeedEngine):
         if self.wall_clock_breakdown():
             self.timers('pipe_recv_grad').stop()
 
-    def _exec_optimizer_step(self):
+    def _exec_optimizer_step(self, lr_kwargs=None):
         if self.wall_clock_breakdown():
             self.timers('step_microstep').start()
             self.timers('step').start()
         self.mem_status('BEFORE STEP', reset_max=True)
 
         self._force_grad_boundary = True
-        self._take_model_step()
+        self._take_model_step(lr_kwargs)
         self._force_grad_boundary = False
 
         self.mem_status('AFTER STEP')
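For context, a minimal usage sketch of the new lr_kwargs path. This is not part of the commit: the toy model, data, and ds_config below are illustrative, and the exact deepspeed.initialize signature may differ across DeepSpeed versions. With this change, keyword arguments passed to engine.step() are forwarded to lr_scheduler.step(**(lr_kwargs or {})), which makes metric-driven schedulers such as torch.optim.lr_scheduler.ReduceLROnPlateau usable through the engine:

import torch
import deepspeed

# Illustrative toy model and config; a real DeepSpeed config carries more fields.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# ReduceLROnPlateau.step() requires a `metrics` argument, which previously
# could not be supplied when the engine drove the scheduler.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')

ds_config = {"train_batch_size": 8}

engine, _, _, _ = deepspeed.initialize(model=model,
                                       optimizer=optimizer,
                                       lr_scheduler=scheduler,
                                       config=ds_config)

x = torch.randn(8, 10)
y = torch.randn(8, 1)

loss = torch.nn.functional.mse_loss(engine(x), y)
engine.backward(loss)
# New in this commit: lr_kwargs is forwarded to the scheduler as
# self.lr_scheduler.step(**(lr_kwargs or {})).
engine.step(lr_kwargs={'metrics': loss.item()})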