Unverified commit 08c96a1b, authored by Jeff Rasley, committed by GitHub

ZeRO-1 tune max-elems + bug fix (#532)

* zero-1 memory fix

* auto-tune max elems per comm to reduce padding/comm intervals

* clean-up and added previously missing reduction options

* fix testing backend to work with torch 1.7
Parent fdd81c30
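The "auto-tune max elems per comm" bullet can be illustrated with a small sketch: pick the per-allreduce bucket size that minimizes the padding needed to split a flattened parameter group evenly across ranks. The names, candidate list, and tie-breaking rule below are illustrative only, not the commit's actual implementation.

# Hypothetical sketch of auto-tuning "max elements per comm" (not DeepSpeed's code).
def padding_for(total_elems, max_elems_per_comm, world_size):
    """Padding (in elements) so every comm interval divides evenly by world_size."""
    padding = 0
    remaining = total_elems
    while remaining > 0:
        chunk = min(remaining, max_elems_per_comm)
        if chunk % world_size != 0:
            padding += world_size - (chunk % world_size)
        remaining -= chunk
    return padding

def tune_max_elems(total_elems, world_size, candidates):
    """Candidate with the least total padding; ties broken toward the larger bucket."""
    return min(candidates, key=lambda c: (padding_for(total_elems, c, world_size), -c))

if __name__ == "__main__":
    candidates = [2 ** p for p in range(20, 31)]   # 1M .. 1G elements, illustrative only
    best = tune_max_elems(123_456_789, world_size=16, candidates=candidates)
    print("chosen max elems per comm:", best)

Larger buckets mean fewer communication intervals, so the tie-break prefers the biggest candidate that does not increase padding.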
@@ -661,7 +661,7 @@ class DeepSpeedEngine(Module):
    def _configure_zero_optimizer(self, optimizer):
        zero_stage = self.zero_optimization_stage()
        logger.info('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage))
        assert not self.allreduce_always_fp32(), "ZeRO does not support 'fp32_allreduce': true"
        if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
            assert self.zero_reduce_scatter(), 'Stage 1 only supports reduce scatter mode'
            optimizer = FP16_DeepSpeedZeroOptimizer_Stage1(
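For context, the assertions above require ZeRO stage 1 to run with reduce-scatter enabled and without fp32 allreduce. A hedged example of a DeepSpeed config that satisfies them; the key names follow the public config schema, but the values are illustrative, not recommendations from this commit.

# Illustrative config: ZeRO stage 1 with reduce-scatter on and fp32 allreduce off.
ds_config = {
    "train_batch_size": 32,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 1,
        "reduce_scatter": True      # Stage 1 only supports reduce-scatter mode
    },
    "fp32_allreduce": False         # setting this to True would trip the assert above
}

# Typical usage (model assumed to be defined elsewhere):
# import deepspeed
# engine, optimizer, _, _ = deepspeed.initialize(
#     model=model, model_parameters=model.parameters(), config_params=ds_config)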
@@ -68,12 +68,6 @@ class CheckOverflow(object):
        return bool(overflow)

    def check(self, param_groups=None):
        # TODO: what's the equivalent here? do we need this?
        # for group in self.fp32_from_fp32_groups:
        #     for param in group:
        #         params.append(param)
        params = []
        if param_groups is None:
            params = self.params
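A minimal, self-contained sketch of what an overflow check over a list of parameters typically does: scan gradients for inf/NaN. This mirrors the intent of CheckOverflow.check but is not DeepSpeed's actual code.

import torch

def has_inf_or_nan_grads(params):
    """Minimal sketch (not DeepSpeed's code): True if any gradient contains inf/NaN."""
    for p in params:
        if p.grad is None:
            continue
        # summing propagates any inf/NaN in the gradient into one scalar we can test
        s = float(p.grad.float().sum())
        if s != s or s == float('inf') or s == float('-inf'):
            return True
    return False

# Example: check a tiny model after a backward pass
model = torch.nn.Linear(4, 2)
model(torch.randn(3, 4)).sum().backward()
print(has_inf_or_nan_grads(model.parameters()))   # False for a well-behaved loss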
This diff is collapsed.
@@ -41,6 +41,8 @@ def distributed_test(world_size=2, backend='nccl'):
            if torch.cuda.is_available():
                torch.cuda.set_device(local_rank)

            if 'args' in func_kwargs:
                func_kwargs['args'].local_rank = local_rank

            run_func(*func_args, **func_kwargs)

        def dist_launcher(num_procs, *func_args, **func_kwargs):
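A hedged usage sketch, not a test from the repo: with the lines added above, a test function that takes an argparse-style `args` object gets `args.local_rank` filled in for each spawned process, matching what torch 1.7 launcher conventions expect.

from types import SimpleNamespace

from common import distributed_test   # the decorator patched in this hunk (import path assumed)

@distributed_test(world_size=2)
def _check_local_rank(args):
    # after the fix, every worker sees its own device index on args.local_rank
    assert args.local_rank in (0, 1)

def test_local_rank_is_propagated():
    _check_local_rank(args=SimpleNamespace(local_rank=None))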