diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 7431b2c892c44167bc340f797f0cd08d252c17a4..76ba6af78b76b10f940661620f494798ad20bd51 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -979,7 +979,7 @@ class DeepSpeedEngine(Module): torch.nn.utils.clip_grad_norm_(parameters=self.module.parameters(), max_norm=self.gradient_clipping()) - def _take_model_step(self): + def _take_model_step(self, lr_kwargs): if self.gradient_clipping() > 0.0: if not self.fp16_enabled() and not self.amp_enabled(): self.clip_fp32_gradients() @@ -1010,14 +1010,14 @@ class DeepSpeedEngine(Module): self.skipped_steps += 1 else: if self.lr_scheduler is not None: - self.lr_scheduler.step() + self.lr_scheduler.step(**(lr_kwargs or {})) if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0: self._report_progress(self.global_steps + 1) self.global_steps += 1 self.global_samples += self.train_batch_size() - def step(self): + def step(self, lr_kwargs=None): r"""Execute the weight update step after forward and backward propagation on effective_train_batch. """ @@ -1034,7 +1034,7 @@ class DeepSpeedEngine(Module): if self.progressive_layer_drop: self.progressive_layer_drop.update_state(self.global_steps) - self._take_model_step() + self._take_model_step(lr_kwargs) self.tput_timer.stop(report_progress) diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 954774e589123766ade0a1efa814e3c6bb76e18e..5c5d896dfc0d14f83cb5d4d5a1c55b98b0c06e49 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -940,14 +940,14 @@ class PipelineEngine(DeepSpeedEngine): if self.wall_clock_breakdown(): self.timers('pipe_recv_grad').stop() - def _exec_optimizer_step(self): + def _exec_optimizer_step(self, lr_kwargs=None): if self.wall_clock_breakdown(): self.timers('step_microstep').start() self.timers('step').start() self.mem_status('BEFORE STEP', reset_max=True) self._force_grad_boundary = True - self._take_model_step() + self._take_model_step(lr_kwargs) self._force_grad_boundary = False self.mem_status('AFTER STEP')