Unverified commit 376818ef, authored by Jeff Rasley, committed by GitHub

Empty grad fix (#291)

* empty grad fix
* add unit tests for empty grad
Parent 607814fe
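Context for the change below, not part of the original commit message: buffered_allreduce_fallback collects gradients into fixed-size buffers and all-reduces them (as the elements_per_buffer argument suggests). If a parameter receives no gradient on some ranks, for example because it is unused in that rank's forward pass, skipping it leaves ranks reducing buffers of different sizes. The fix substitutes a zero tensor for the missing gradient. Below is a minimal illustrative sketch of that idea, not the DeepSpeed implementation; torch._utils._flatten_dense_tensors is used here only to show the size mismatch.

# Minimal sketch (see assumptions above), not the DeepSpeed implementation.
import torch
from torch._utils import _flatten_dense_tensors

def gather_grads(params, zero_fill_missing):
    """Collect per-parameter grads; optionally zero-fill the missing ones."""
    grads = []
    for p in params:
        if p.grad is None:
            if zero_fill_missing:
                # The fix: append zeros so every rank reduces the same size.
                grads.append(torch.zeros(p.size(), dtype=p.dtype, device=p.device))
            # Old behaviour: silently skip -> buffer sizes diverge across ranks.
        else:
            grads.append(p.grad.data)
    return _flatten_dense_tensors(grads)

params = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(2)]
params[0].grad = torch.randn(4, 4)  # both "ranks" produced this grad
# "Rank 0" also produced a grad for params[1]; "rank 1" did not (grad is None).

print(gather_grads(params, zero_fill_missing=False).numel())  # 16 on "rank 1"
print(gather_grads(params, zero_fill_missing=True).numel())   # 32, matches rank 0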
@@ -979,7 +979,17 @@ class DeepSpeedLight(Module):
     def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
         grads = []
         for param_name, param in self.module.named_parameters():
-            if param.grad is not None:
+            if param.grad is None:
+                # In cases where there is an imbalance of empty grads across
+                # ranks we must create empty grads, this will ensure that every
+                # rank is reducing the same size. In some cases it may make
+                # sense in the future to support the ability to average not
+                # w.r.t. world size but with a different value.
+                grads.append(
+                    torch.zeros(param.size(),
+                                dtype=param.dtype,
+                                device=param.device))
+            else:
                 grad_data = param.grad.data
                 if self.sparse_gradients_enabled(
                 ) and param_name in self.csr_tensor_module_names:
@@ -5,16 +5,21 @@ import torch
 class SimpleModel(torch.nn.Module):
-    def __init__(self, hidden_dim, empty_grad=False):
+    def __init__(self, hidden_dim, empty_grad=False, rank=0):
         super(SimpleModel, self).__init__()
         self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
         if empty_grad:
-            self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)])
+            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
         self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
+        self.rank = rank
+        self.empty_grad = empty_grad

     def forward(self, x, y):
         hidden_dim = x
-        hidden_dim = self.linear(hidden_dim)
+        if self.rank == 0 and self.empty_grad:
+            hidden_dim = self.linear(hidden_dim) + self.linear2(hidden_dim)
+        else:
+            hidden_dim = self.linear(hidden_dim)
         return self.cross_entropy_loss(hidden_dim, y)
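A usage sketch, not part of the commit, showing what the modified SimpleModel does on a non-zero rank: linear2 is created but never used in forward, so its gradient stays None after backward, which is exactly the empty-grad condition the engine change above now tolerates. It assumes the SimpleModel definition from the hunk above is importable.

# Sketch assuming the SimpleModel class from the hunk above is in scope.
import torch

model = SimpleModel(hidden_dim=10, empty_grad=True, rank=1)  # rank != 0
x = torch.randn(4, 10)
y = torch.randint(0, 10, (4,))

model(x, y).backward()

print(model.linear.weight.grad is None)   # False: linear is used on every rank
print(model.linear2.weight.grad is None)  # True: linear2 is unused when rank != 0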
@@ -33,9 +33,10 @@ def test_lamb_fp32_grad_clip(tmpdir):
         data_loader = random_dataloader(model=model,
                                         total_samples=50,
                                         hidden_dim=hidden_dim,
-                                        device=model.device)
+                                        device=model.device,
+                                        dtype=torch.float)
         for n, batch in enumerate(data_loader):
-            loss = model(batch[0].float(), batch[1])
+            loss = model(batch[0], batch[1])
             model.backward(loss)
             model.step()
@@ -81,7 +82,7 @@ def test_lamb_fp16_basic(tmpdir):

 def test_lamb_fp16_empty_grad(tmpdir):
     config_dict = {
-        "train_batch_size": 1,
+        "train_batch_size": 2,
         "steps_per_print": 1,
         "optimizer": {
             "type": "Lamb",
@@ -97,9 +98,9 @@ def test_lamb_fp16_empty_grad(tmpdir):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
-    model = SimpleModel(hidden_dim, empty_grad=True)
+    model = SimpleModel(hidden_dim, empty_grad=True, rank=args.local_rank)

-    @distributed_test(world_size=[1])
+    @distributed_test(world_size=[2])
     def _test_lamb_fp16_empty_grad(args, model, hidden_dim):
         model, _, _,_ = deepspeed.initialize(args=args,
                                              model=model,
@@ -116,6 +117,44 @@ def test_lamb_fp16_empty_grad(tmpdir):
     _test_lamb_fp16_empty_grad(args=args, model=model, hidden_dim=hidden_dim)


+def test_adam_fp32_empty_grad(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "gradient_clipping": 1.0,
+        "fp16": {
+            "enabled": False
+        }
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+    model = SimpleModel(hidden_dim, empty_grad=True, rank=args.local_rank)
+
+    @distributed_test(world_size=[2])
+    def _test_adam_fp32_empty_grad(args, model, hidden_dim):
+        model, _, _,_ = deepspeed.initialize(args=args,
+                                             model=model,
+                                             model_parameters=model.parameters())
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device,
+                                        dtype=torch.float)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_adam_fp32_empty_grad(args=args, model=model, hidden_dim=hidden_dim)
+
+
 def test_adamw_fp16_basic(tmpdir):
     config_dict = {
         "train_batch_size": 1,
@@ -495,3 +534,41 @@ def test_adam_amp_o2(tmpdir):
         model.step()

     _test_adam_amp_o2(args=args, model=model, hidden_dim=hidden_dim)
+
+
+def test_adam_amp_o2_empty_grad(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "gradient_clipping": 1.0,
+        "amp": {
+            "enabled": True,
+            "opt_level": "O2"
+        }
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+    model = SimpleModel(hidden_dim, empty_grad=False, rank=args.local_rank)
+
+    @distributed_test(world_size=[2])
+    def _test_adam_amp_o2_empty_grad(args, model, hidden_dim):
+        model, _, _,_ = deepspeed.initialize(args=args,
+                                             model=model,
+                                             model_parameters=model.parameters())
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_adam_amp_o2_empty_grad(args=args, model=model, hidden_dim=hidden_dim)
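The new tests use world_size=[2] so that rank 0 routes through linear2 while rank 1 does not, producing the gradient imbalance the fix targets. For readers unfamiliar with the distributed_test helper, the sketch below is only a rough, hypothetical illustration of the pattern such a decorator typically implements (spawn world_size processes, initialize a process group, run the test body on each rank); none of these names come from the DeepSpeed test utilities.

# Hypothetical sketch of a distributed test launcher, not the DeepSpeed helper.
import os
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank, world_size, fn, args):
    # Each spawned process becomes one rank of the test's process group.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    try:
        fn(rank, *args)
    finally:
        dist.destroy_process_group()

def run_distributed(fn, world_size=2, args=()):
    # Launch world_size processes and wait for all of them to finish.
    mp.spawn(_worker, args=(world_size, fn, args), nprocs=world_size, join=True)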