Unverified · Commit adc21a4d authored by Jeff Rasley, committed by GitHub

ZeRO-1 empty grads fix + tests (#1273)

* fix empty grad zero tests

* don't clear grads in stage 1 code path

* prevent None grads from being reduced
Parent 1ff25748
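Context for the fix: when a module's parameters are never touched in the forward pass (the "empty grad" case this commit tests), PyTorch leaves their `.grad` set to `None` after backward, which is why the reducer below now skips `None` grads and why the stage-1 path must not clear grads it does not own. A minimal standalone sketch, not part of this diff (the `Toy` module and its names are made up for illustration):

# Illustrative only: an unused submodule ends up with grad=None after backward.
import torch

class Toy(torch.nn.Module):
    def __init__(self, dim=10):
        super().__init__()
        self.used = torch.nn.Linear(dim, dim)
        self.unused = torch.nn.Linear(dim, dim)  # never called in forward -> "empty grad"

    def forward(self, x):
        return self.used(x).sum()

model = Toy()
model(torch.randn(2, 10)).backward()
assert model.used.weight.grad is not None
assert model.unused.weight.grad is None  # reducing this without a guard would fail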
@@ -473,7 +473,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
         if not self.cpu_offload:
             for group in self.single_partition_of_fp32_groups:
-                group.grad = None
+                group.grad = None #class init
 
         return
@@ -497,7 +497,8 @@ class FP16_DeepSpeedZeroOptimizer(object):
         if not self.overlap_comm:
             for i, group in enumerate(self.fp16_groups):
                 for param in group:
-                    self.reduce_ready_partitions_and_remove_grads(param, i)
+                    if param.grad is not None:
+                        self.reduce_ready_partitions_and_remove_grads(param, i)
         # reduce any pending grads in either hook/non-hook case
         self.overlapping_partition_gradients_reduce_epilogue()
@@ -974,7 +975,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
             src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements).float()
             dest_tensor.copy_(src_tensor, non_blocking=True)
-            param.grad = None
+            param.grad = None #offload only
 
     def complete_grad_norm_calculation_for_cpu_offload(self, params):
         total_norm = 0.0
@@ -1095,17 +1096,19 @@ class FP16_DeepSpeedZeroOptimizer(object):
                 Multiple gradient reduction is currently not supported"
             self.params_already_reduced[param_id] = True
 
-            if not self.is_param_in_current_partition[param_id]:
-                if self.overlap_comm and self.contiguous_gradients is False:
-                    # Clear grads of other partitions during the next reduction
-                    # to avoid clearing them before the reduction is complete.
-                    if self.previous_reduced_grads is None:
-                        self.previous_reduced_grads = []
-                    self.previous_reduced_grads.append(param)
-                else:
-                    param.grad = None
-            elif self.contiguous_gradients:
-                self.copy_grads_in_partition(param)
+            if self.partition_gradients:
+                if not self.is_param_in_current_partition[param_id]:
+                    if self.overlap_comm and self.contiguous_gradients is False:
+                        # Clear grads of other partitions during the next reduction
+                        # to avoid clearing them before the reduction is complete.
+                        if self.previous_reduced_grads is None:
+                            self.previous_reduced_grads = []
+                        self.previous_reduced_grads.append(param)
+                    else:
+                        param.grad = None #only if self.partition_gradients
+                elif self.contiguous_gradients:
+                    self.copy_grads_in_partition(param)
 
         self.grads_in_ipg_bucket = []
         self.params_in_ipg_bucket = []
@@ -1125,7 +1128,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
             for params_id in self.is_grad_computed[i][partition_id]:
                 if are_all_related_partitions_reduced(params_id):
-                    self.param_dict[params_id].grad = None
+                    self.param_dict[params_id].grad = None # dead code
 
     def flatten_and_print(self, message, tensors, start=0, n=5):
         flatten_tensor = self.flatten(tensors)
@@ -1214,7 +1217,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
     def _clear_previous_reduced_grads(self):
         if self.previous_reduced_grads is not None:
             for param in self.previous_reduced_grads:
-                param.grad = None
+                param.grad = None # overlap enabled
             self.previous_reduced_grads = None
 
     # if rank is specified do a reduction instead of an allreduce
@@ -1331,7 +1334,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
         for group in self.fp16_groups:
             for p in group:
                 if set_grads_to_None:
-                    p.grad = None
+                    p.grad = None # epilogue and in step
                 else:
                     if p.grad is not None:
                         p.grad.detach_()
@@ -1457,7 +1460,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
     def free_grad_in_param_list(self, param_list):
         for p in param_list:
-            p.grad = None
+            p.grad = None # in step
 
     def reset_cpu_buffers(self):
         self.norm_for_param_grads = {}
@@ -1583,7 +1586,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
         # get rid of the fp32 gradients. Not needed anymore
         if not self.cpu_offload:
             for group in self.single_partition_of_fp32_groups:
-                group.grad = None
+                group.grad = None # in step
 
         for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
             fp16_partitions[partition_id].data.copy_(fp32_partition.data)
@@ -17,10 +17,7 @@ class SimpleModel(torch.nn.Module):
     def forward(self, x, y):
         hidden_dim = x
-        if self.empty_grad and torch.distributed.get_rank() == 0:
-            hidden_dim = self.linear(hidden_dim) + self.linear2(hidden_dim)
-        else:
-            hidden_dim = self.linear(hidden_dim)
+        hidden_dim = self.linear(hidden_dim)
         return self.cross_entropy_loss(hidden_dim, y)
@@ -856,3 +856,38 @@ def test_zero3_lazyscatter(tmpdir):
         model.step()
 
     _go(args=args)
+
+
+@pytest.mark.parametrize('stage', [1, 2, 3])
+def test_zero_empty_grad(tmpdir, stage):
+    config_dict = {
+        "train_batch_size": 1,
+        "steps_per_print": 1,
+        "fp16": {
+            "enabled": True
+        },
+        "zero_optimization": {
+            "stage": stage
+        }
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+    model = SimpleModel(hidden_dim)
+
+    @distributed_test(world_size=[1])
+    def _go(args, model, hidden_dim):
+        optimizer = torch.optim.Adam(model.parameters())
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              optimizer=optimizer)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _go(args=args, model=model, hidden_dim=hidden_dim)
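To reproduce the new coverage locally, something like the following should work; the test file's path is not shown in this diff, so tests/unit/test_zero.py is an assumption:

# Hypothetical runner for the new parametrized test; the file path is assumed.
import sys
import pytest

sys.exit(pytest.main(["tests/unit/test_zero.py", "-k", "test_zero_empty_grad", "-v"]))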