Unverified commit f5cce75e. Author: Shaden Smith. Committer: GitHub.

Overflow fix (#416)

* Switches fused_optimizer overflow calculation
Parent 7d91be97
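For orientation before the diff: this commit moves the fused FP16 optimizer's overflow detection to the very start of step(). Instead of flattening the gradients and inferring overflow from the flattened fp32 gradient norm, it scans the fp16 gradients directly and, on overflow, clears them and skips the step. Below is a minimal, self-contained sketch of that control flow; the helper names grads_have_overflow and step_sketch are hypothetical stand-ins, not the actual DeepSpeed CheckOverflow / FP16_Optimizer.step code.

# Minimal sketch (illustrative only) of the new overflow handling in step().
import math
import torch

def grads_have_overflow(params):
    # True if any gradient contains inf or NaN (stand-in for the overflow checker).
    for p in params:
        grad_sum = float(p.grad.float().sum())
        if math.isinf(grad_sum) or math.isnan(grad_sum):
            return True
    return False

def step_sketch(fp16_groups):
    # Check the raw fp16 gradients before doing any other work.
    fp16_params = [p for group in fp16_groups for p in group if p.grad is not None]
    if grads_have_overflow(fp16_params):
        # On overflow: drop the gradients and skip the step entirely, so no
        # flattening, norm computation, unscaling, or optimizer update happens.
        for group in fp16_groups:
            for p in group:
                p.grad = None
        return True  # overflow detected, step skipped
    # Otherwise: flatten grads, unscale/clip, run the inner optimizer step,
    # and copy the updated fp32 weights back to fp16 (as in the real optimizer).
    return False

w = torch.nn.Parameter(torch.zeros(4, dtype=torch.float16))
w.grad = torch.full((4, ), float('inf'), dtype=torch.float16)
print(step_sketch([[w]]))  # True, and w.grad is now None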
@@ -11,7 +11,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from deepspeed.runtime.utils import get_grad_norm, CheckOverflow, get_weight_norm
 from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
-from deepspeed.utils import logger
+from deepspeed.utils import logger, log_dist

 class FP16_Optimizer(object):
@@ -204,9 +204,30 @@ class FP16_Optimizer(object):
         UPDATE_FP16 = 'update_fp16'
         STEP_TIMERS = OVERFLOW_TIMERS + [UNSCALE_AND_CLIP, BASIC_STEP, UPDATE_FP16]

-        # First compute norm for all group so we know if there is overflow
-        grads_groups_flat = []
+        # First determine if there is overflow.
+        self.start_timers([OVERFLOW_CHECK])
+        fp16_params = []
+        for i, group in enumerate(self.fp16_groups):
+            fp16_params.extend([p for p in group if p.grad is not None])
+        self.overflow = self.overflow_checker.has_overflow(fp16_params)
+        self.stop_timers([OVERFLOW_CHECK])
+        prev_scale = self.cur_scale
+        self._update_scale(self.overflow)
+        if self.overflow:
+            if self.verbose:
+                log_dist(
+                    "Overflow detected. Skipping step. Attempted loss "
+                    f"scale: {prev_scale}, reducing to {self.cur_scale}",
+                    ranks=[0])
+            # Clear gradients
+            for i, group in enumerate(self.fp16_groups):
+                for p in group:
+                    p.grad = None
+            self.log_timers(OVERFLOW_TIMERS)
+            return self.overflow
+
+        grads_groups_flat = []
         for i, group in enumerate(self.fp16_groups):
             data_type = self.fp32_groups_flat[i].dtype
@@ -227,22 +248,6 @@ class FP16_Optimizer(object):
         all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)
         self.stop_timers([COMPUTE_NORM])

-        self.start_timers([OVERFLOW_CHECK])
-        self.overflow = self.overflow_checker.check_using_norm([all_groups_norm])
-        self.stop_timers([OVERFLOW_CHECK])
-
-        prev_scale = self.cur_scale
-        self._update_scale(self.overflow)
-        if self.overflow:
-            if self.verbose:
-                print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
-                      "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
-            self.log_timers(OVERFLOW_TIMERS)
-            grads_groups_flat = None
-            return self.overflow
-
         self.start_timers([UNSCALE_AND_CLIP])
         self.unscale_and_clip_grads(grads_groups_flat, [all_groups_norm])
         self.stop_timers([UNSCALE_AND_CLIP])