Unverified commit a52cbf80, authored by Jeff Rasley, committed by GitHub

[zero-3] add bwd support for list/dict types returned in fwd (#1857)

Parent b4fcd98f
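Before this change, `_apply_to_tensors_only` recursed only into plain tuples, so a module whose forward returned a list or dict never had ZeRO-3's backward hooks attached to the tensors inside its output. A minimal sketch of the kind of module this commit enables (illustrative only; the class and names below are hypothetical, not taken from the patch):

import torch

class DictOutputModel(torch.nn.Module):
    """Hypothetical module whose forward returns a dict instead of a tuple."""
    def __init__(self, hidden_dim=10):
        super().__init__()
        self.l1 = torch.nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        h = self.l1(x)
        # Prior to this patch, tensors nested in a dict (or list) return value
        # were invisible to ZeRO-3's hook wrapping; only tuple and bare
        # torch.Tensor outputs were traversed.
        return {'hidden': h, 'loss': h.sum()}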
@@ -73,9 +73,14 @@ def move_to_cpu(tensor_list):
         tensor.data = tensor.data.cpu()


+def is_builtin_type(obj):
+    # https://stackoverflow.com/a/17795199
+    return obj.__class__.__module__ == '__builtin__' or obj.__class__.__module__ == "builtins"
+
+
 #apply torch.autograd.Function that calls a backward_function to tensors in output
 def _apply_to_tensors_only(module, functional, backward_function, outputs):
-    if type(outputs) is tuple:
+    if isinstance(outputs, (tuple, list)):
         touched_outputs = []
         for output in outputs:
             touched_output = _apply_to_tensors_only(module,
@@ -83,10 +88,23 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):
                                                     backward_function,
                                                     output)
             touched_outputs.append(touched_output)
-        return tuple(touched_outputs)
+        return outputs.__class__(touched_outputs)
+    elif isinstance(outputs, dict):
+        # apply inplace to avoid recreating dict inherited objects
+        for key in outputs.keys():
+            outputs[key] = _apply_to_tensors_only(module,
+                                                  functional,
+                                                  backward_function,
+                                                  outputs[key])
+        return outputs
     elif type(outputs) is torch.Tensor:
         return functional.apply(module, backward_function, outputs)
     else:
+        if not is_builtin_type(outputs):
+            logger.warning(
+                f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. "
+                "The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and "
+                "output tensors and therefore may not get triggered properly.")
         return outputs
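For reference, the traversal the patch implements can be exercised standalone. The sketch below mirrors the patched control flow under assumed names: `apply_to_tensors` and the doubling lambda are stand-ins for `_apply_to_tensors_only` and the autograd `functional`, not DeepSpeed API.

import torch

def apply_to_tensors(fn, outputs):
    # Same shape as the patched function: rebuild tuples/lists with the same
    # container class, mutate dicts in place, apply fn to bare tensors, and
    # pass every other value through untouched.
    if isinstance(outputs, (tuple, list)):
        return outputs.__class__(apply_to_tensors(fn, o) for o in outputs)
    elif isinstance(outputs, dict):
        for key in outputs.keys():
            outputs[key] = apply_to_tensors(fn, outputs[key])
        return outputs
    elif type(outputs) is torch.Tensor:
        return fn(outputs)
    return outputs

out = {'a': torch.ones(2), 'loss': torch.tensor(0.5), 'b': 1, 'c': None}
out = apply_to_tensors(lambda t: t * 2, out)
print(out)  # 'a' and 'loss' are doubled; 'b' and 'c' pass through unchanged

One caveat of `outputs.__class__(touched_outputs)`: it assumes the container can be rebuilt from a single iterable, which holds for `tuple`, `list`, and most subclasses, but a `namedtuple` constructor expects positional fields and would raise a `TypeError` here; the dict branch avoids the analogous problem by mutating in place, per its inline comment.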
@@ -1222,3 +1222,67 @@ def test_zero_offload_stage1():
             model.step()

     _go(model=model, hidden_dim=hidden_dim)
+
+
+@pytest.mark.parametrize('return_type', [tuple, list, dict])
+def test_z3_dict_fwd(return_type):
+    config_dict = {
+        "train_batch_size": 4,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 1e-4
+            }
+        },
+        "fp16": {
+            "enabled": True
+        },
+        "zero_optimization": {
+            "stage": 3
+        }
+    }
+    hidden_dim = 10
+
+    class MyModel(torch.nn.Module):
+        def __init__(self, hidden_dim):
+            super(MyModel, self).__init__()
+            self.l1 = torch.nn.Linear(hidden_dim, hidden_dim)
+            self.cel = torch.nn.CrossEntropyLoss()
+
+        def forward(self, x, y):
+            x = self.l1(x)
+            loss = self.cel(x, y)
+            if return_type == dict:
+                val = {'a': x, 'loss': loss, 'b': 1, 'c': None}
+            elif return_type == list:
+                val = [x, loss]
+            elif return_type == tuple:
+                val = (x, loss)
+            else:
+                raise NotImplementedError
+            return val
+
+    @distributed_test(world_size=[1])
+    def _go(hidden_dim):
+        with deepspeed.zero.Init():
+            model = MyModel(hidden_dim)
+
+        model, _, _, _ = deepspeed.initialize(model=model,
+                                              model_parameters=model.parameters(),
+                                              config=config_dict)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        torch.distributed.barrier()
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            if return_type == dict:
+                loss = loss['loss']
+            else:
+                loss = loss[1]
+            model.backward(loss)
+            model.step()
+
+    _go(hidden_dim)
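The test parametrizes over all three container types: the tuple case guards the pre-existing behavior, while the list and dict cases exercise the new traversal paths. The dict case also carries non-tensor entries ('b': 1, 'c': None) to confirm that builtin values flow through the final else branch without tripping the new warning.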