diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py
index d2f7870008e18af5e43c91ce3735055f8131f253..8c1d2003cb1ba9bb6bf1e1236e943d3fdf77d393 100755
--- a/deepspeed/runtime/fp16/fused_optimizer.py
+++ b/deepspeed/runtime/fp16/fused_optimizer.py
@@ -11,7 +11,7 @@
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from deepspeed.runtime.utils import get_grad_norm, CheckOverflow, get_weight_norm
 from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
-from deepspeed.utils import logger
+from deepspeed.utils import logger, log_dist
 
 
 class FP16_Optimizer(object):
@@ -204,9 +204,30 @@ class FP16_Optimizer(object):
         UPDATE_FP16 = 'update_fp16'
         STEP_TIMERS = OVERFLOW_TIMERS + [UNSCALE_AND_CLIP, BASIC_STEP, UPDATE_FP16]
 
-        # First compute norm for all group so we know if there is overflow
-        grads_groups_flat = []
+        # First determine if there is overflow.
+        self.start_timers([OVERFLOW_CHECK])
+        fp16_params = []
+        for i, group in enumerate(self.fp16_groups):
+            fp16_params.extend([p for p in group if p.grad is not None])
+        self.overflow = self.overflow_checker.has_overflow(fp16_params)
+        self.stop_timers([OVERFLOW_CHECK])
+        prev_scale = self.cur_scale
+        self._update_scale(self.overflow)
+        if self.overflow:
+            if self.verbose:
+                log_dist(
+                    "Overflow detected. Skipping step. Attempted loss "
+                    f"scale: {prev_scale}, reducing to {self.cur_scale}",
+                    ranks=[0])
+            # Clear gradients
+            for i, group in enumerate(self.fp16_groups):
+                for p in group:
+                    p.grad = None
+
+            self.log_timers(OVERFLOW_TIMERS)
+            return self.overflow
 
+        grads_groups_flat = []
         for i, group in enumerate(self.fp16_groups):
             data_type = self.fp32_groups_flat[i].dtype
 
@@ -227,22 +248,6 @@
         all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)
         self.stop_timers([COMPUTE_NORM])
 
-        self.start_timers([OVERFLOW_CHECK])
-        self.overflow = self.overflow_checker.check_using_norm([all_groups_norm])
-        self.stop_timers([OVERFLOW_CHECK])
-
-        prev_scale = self.cur_scale
-        self._update_scale(self.overflow)
-
-        if self.overflow:
-            if self.verbose:
-                print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
-                      "scale: {}, reducing to {} ".format(prev_scale,
-                                                          self.cur_scale))
-            self.log_timers(OVERFLOW_TIMERS)
-            grads_groups_flat = None
-            return self.overflow
-
         self.start_timers([UNSCALE_AND_CLIP])
         self.unscale_and_clip_grads(grads_groups_flat, [all_groups_norm])
         self.stop_timers([UNSCALE_AND_CLIP])
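
For reviewers, here is the new control flow in isolation: the overflow check on the raw fp16 gradients now runs before any flattening or norm computation, so a skipped step does no wasted work and releases its gradients immediately. This is a minimal single-process sketch, not the DeepSpeed implementation; `has_overflow` and `step_once` are hypothetical stand-ins for `CheckOverflow` and `FP16_Optimizer.step`, with no timers, no distributed reduction, and no dynamic scale growth.

```python
import torch


def has_overflow(params):
    # Serial inf/NaN scan over fp16 gradients. DeepSpeed's CheckOverflow
    # additionally all-reduces this flag across parallel ranks.
    for p in params:
        s = p.grad.float().sum().item()
        if s != s or s in (float("inf"), float("-inf")):
            return True
    return False


def step_once(fp16_groups, cur_scale, apply_update, min_scale=1.0):
    """Skip the update and halve the loss scale if any gradient overflowed."""
    fp16_params = [p for group in fp16_groups for p in group if p.grad is not None]
    if has_overflow(fp16_params):
        # Drop the stale gradients so the next backward starts clean.
        for group in fp16_groups:
            for p in group:
                p.grad = None
        return True, max(cur_scale / 2.0, min_scale)
    apply_update()  # unscale/clip, fp32 step, fp16 copy-back in the real optimizer
    return False, cur_scale


# Toy usage: a single fp16 parameter whose gradient overflows.
w = torch.nn.Parameter(torch.zeros(4, dtype=torch.float16))
w.grad = torch.full((4, ), float("inf"), dtype=torch.float16)
skipped, scale = step_once([[w]], cur_scale=65536.0, apply_update=lambda: None)
print(skipped, scale, w.grad)  # True 32768.0 None
```

Compared with the removed code, which flattened the gradients and derived overflow from the computed norm via `check_using_norm`, checking the fp16 gradients up front means the flatten and `get_grad_norm` work is only paid on steps that actually apply an update, and setting `p.grad = None` on the skipped path frees gradient memory instead of leaving stale inf/NaN tensors behind.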