Unverified commit 53c73fe3, authored by Olatunji Ruwase and committed by GitHub

Support multi-output models (#170)

* Push to remote

* Correctly handle multi-output models by performing loss scaling in backward()
Add unit tests for multi-output models

* Fix formatting issues

* Fix formatting issues

* Fix formatting

* Update DeepSpeedExamples submodule
Enable Megatron model tests
Parent 43f27332
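In short, forward() may now return a tuple or list of losses, and the engine scales whatever the caller passes to backward() by 1/gradient_accumulation_steps, returning the scaled value. A rough usage sketch (illustrative only; TwoLossModel, train_step, engine, args, and data_loader are assumptions for this example, not part of the commit):

import torch
import deepspeed


class TwoLossModel(torch.nn.Module):
    # Toy model whose forward() returns a tuple of per-task losses.
    def __init__(self, hidden_dim=10):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        return tuple(self.criterion(self.linear(x), y) for x, y in zip(inputs, targets))


def train_step(args, data_loader):
    # args and data_loader are assumed to exist: a DeepSpeed config/args object
    # and a loader yielding (inputs, targets) pairs.
    model = TwoLossModel()
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           model_parameters=model.parameters())
    for inputs, targets in data_loader:
        losses = engine(inputs, targets)   # forward() returns a tuple of losses
        summed_loss = sum(losses)          # the caller decides how to combine them
        # Scaling by 1/gradient_accumulation_steps now happens inside backward(),
        # which returns the scaled loss for inspection.
        scaled_loss = engine.backward(summed_loss)
        engine.step()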
-Subproject commit e0d2d7f4a86f03612bc0a210a5e4dbcc798b48a6
+Subproject commit 32ae819bdb514a348d47a071ceaca509e352fb5e
@@ -576,20 +576,25 @@ class DeepSpeedLight(Module):
         self.warn_unscaled_loss = True
         self.module.train(False)

-    def _scale_loss(self, loss):
-        if isinstance(loss, torch.Tensor):
-            loss = loss / self.gradient_accumulation_steps()
-        elif isinstance(loss, tuple) and isinstance(loss[0], torch.Tensor):
-            loss = (l / self.gradient_accumulation_steps() for l in loss)
-        elif isinstance(loss, list) and isinstance(loss[0], torch.Tensor):
-            loss = [l / self.gradient_accumulation_steps() for l in loss]
+    def _scale_loss(self, prescaled_loss):
+        if isinstance(prescaled_loss, torch.Tensor):
+            scaled_loss = prescaled_loss / self.gradient_accumulation_steps()
+        elif isinstance(prescaled_loss, tuple) or isinstance(prescaled_loss, list):
+            scaled_loss = []
+            for l in prescaled_loss:
+                if isinstance(l, torch.Tensor):
+                    scaled_loss.append(l / self.gradient_accumulation_steps())
+                else:
+                    scaled_loss.append(l)
         else:
+            scaled_loss = prescaled_loss
             if self.warn_unscaled_loss:
                 logging.warning(
-                    f'DeepSpeed unable to scale loss because of type: {type(loss)}')
+                    f'DeepSpeed unable to scale loss because of type: {type(prescaled_loss)}'
+                )
                 self.warn_unscaled_loss = False

-        return loss
+        return scaled_loss

     def forward(self, *inputs, **kwargs):
         r"""Execute forward propagation
@@ -607,10 +612,6 @@ class DeepSpeedLight(Module):
         self.tput_timer.start()
         loss = self.module(*inputs, **kwargs)

-        # scale loss w.r.t. gradient accumulation if needed
-        if self.gradient_accumulation_steps() > 1:
-            loss = self._scale_loss(loss)
-
         if self.wall_clock_breakdown():
             self.timers('forward').stop()
             self.timers('forward_microstep').stop()
@@ -629,6 +630,10 @@ class DeepSpeedLight(Module):
             allreduce_gradients: If this is False, then gradient averaging will be skipped. Default is True.
         """

+        # scale loss w.r.t. gradient accumulation if needed
+        if self.gradient_accumulation_steps() > 1:
+            loss = self._scale_loss(loss)
+
         # Log training Loss
         if self.tensorboard_enabled():
             if self.is_gradient_accumulation_boundary():
@@ -684,6 +689,8 @@ class DeepSpeedLight(Module):
             self.timers('backward').stop()
             self.timers('backward_microstep').stop()

+        return loss
+
     def is_gradient_accumulation_boundary(self):
         return (self.micro_steps + 1) % \
             self.gradient_accumulation_steps() == 0
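Taken together, these hunks move gradient-accumulation loss scaling out of forward() and into backward(), which now also returns the scaled loss so callers (and the tests below) can check it. A minimal standalone sketch of the new scaling rule, outside of DeepSpeed and with the accumulation steps hard-coded for illustration:

import torch


def scale_loss(prescaled_loss, gradient_accumulation_steps=2):
    # Mirrors the new _scale_loss: tensors are divided by the number of
    # accumulation steps, tuples/lists are scaled element-wise, and anything
    # else passes through unchanged (the real engine also logs a warning).
    if isinstance(prescaled_loss, torch.Tensor):
        return prescaled_loss / gradient_accumulation_steps
    if isinstance(prescaled_loss, (tuple, list)):
        return [l / gradient_accumulation_steps if isinstance(l, torch.Tensor) else l
                for l in prescaled_loss]
    return prescaled_loss


losses = (torch.tensor(2.0), torch.tensor(4.0))
print(scale_loss(losses))  # [tensor(1.), tensor(2.)]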
@@ -33,8 +33,8 @@ def test_run():
     runner = unittest.TextTestRunner(failfast=True)

     # Add test suites here.
-    #pytest_hack(runner.run(Megatron_GPT2.suite()))
-    #pytest_hack(runner.run(Megatron_GPT2.checkpoint_suite()))
+    pytest_hack(runner.run(Megatron_GPT2.suite()))
+    pytest_hack(runner.run(Megatron_GPT2.checkpoint_suite()))
     pytest_hack(runner.run(BingBertSquad.suite()))
import os
import json
import argparse
import torch


class MultiOutputModel(torch.nn.Module):
    def __init__(self, hidden_dim, weight_value):
        super(MultiOutputModel, self).__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.linear.weight.data.fill_(weight_value)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        # Return one cross-entropy loss per (input, target) pair as a tuple,
        # giving the engine a multi-output loss to scale.
        losses = []
        for x, y in zip(inputs, targets):
            hidden_dim = self.linear(x)
            loss = self.cross_entropy_loss(hidden_dim, y)
            losses.append(loss)
        return tuple(losses)


def multi_output_dataloader(model, total_samples, hidden_dim, device, inputs, targets):
    assert len(inputs) == len(targets)
    batch_size = model.train_micro_batch_size_per_gpu()

    # One constant-valued fp16 input tensor and one constant label tensor per output.
    train_data = [
        torch.full(size=(total_samples, hidden_dim),
                   fill_value=x,
                   device=device,
                   dtype=torch.half,
                   requires_grad=True) for x in inputs
    ]
    train_label = [
        torch.empty(total_samples, device=device, dtype=torch.long).fill_(y)
        for y in targets
    ]

    train_dataset = torch.utils.data.TensorDataset(*train_data, *train_label)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
    return train_loader
import torch
import deepspeed
import argparse
import pytest
from pytest import approx
import json
import os
from common import distributed_test
from simple_model import args_from_dict
from multi_output_model import MultiOutputModel, multi_output_dataloader


def create_config_dict(micro_batch_size, grad_accumulation_steps, world_size):
    # train_batch_size must equal micro_batch_size * grad_accumulation_steps * world_size.
    return {
        "train_micro_batch_size_per_gpu": micro_batch_size,
        "gradient_accumulation_steps": grad_accumulation_steps,
        "train_batch_size": micro_batch_size * grad_accumulation_steps * world_size,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        "fp16": {
            "enabled": True
        }
    }
def test_two_output_model(tmpdir):
    gradient_accumulation_steps = 2
    micro_batch_size = 1
    world_size = 1
    config_dict = create_config_dict(micro_batch_size,
                                     gradient_accumulation_steps,
                                     world_size)
    hidden_dim = 10
    weight_value = 0.1

    args = args_from_dict(tmpdir, config_dict)
    model = MultiOutputModel(hidden_dim, weight_value)

    @distributed_test(world_size=[1])
    def _test_two_output_model(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
        total_samples = 4
        data_loader = multi_output_dataloader(model=model,
                                              total_samples=total_samples,
                                              hidden_dim=hidden_dim,
                                              device=model.device,
                                              inputs=[1.0, 2.0],
                                              targets=[1, 2])
        for n, batch in enumerate(data_loader):
            assert len(batch) % 2 == 0, \
                "multi_output_dataloader failed to return an even number of data samples (input+target)"
            midpoint = len(batch) // 2
            inputs, targets = batch[:midpoint], batch[midpoint:]
            loss_tuple = model(inputs, targets)

            expected_loss = torch.tensor(2.302734375,
                                         dtype=torch.half,
                                         device=model.device)
            for loss in loss_tuple:
                assert loss.shape == torch.Size([])
                assert loss.item() == approx(expected_loss.item())

            summed_loss = sum(loss_tuple)
            scaled_loss = model.backward(summed_loss)
            expected_scaled_loss = summed_loss / gradient_accumulation_steps
            assert scaled_loss.item() == approx(expected_scaled_loss.item())

            model.step()

    _test_two_output_model(args=args, model=model, hidden_dim=hidden_dim)
def test_three_output_model(tmpdir):
    gradient_accumulation_steps = 3
    micro_batch_size = 1
    world_size = 1
    config_dict = create_config_dict(micro_batch_size,
                                     gradient_accumulation_steps,
                                     world_size)
    hidden_dim = 10
    weight_value = 0.1

    args = args_from_dict(tmpdir, config_dict)
    model = MultiOutputModel(hidden_dim, weight_value)

    @distributed_test(world_size=[1])
    def _test_three_output_model(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
        total_samples = gradient_accumulation_steps * micro_batch_size * 2
        data_loader = multi_output_dataloader(model=model,
                                              total_samples=total_samples,
                                              hidden_dim=hidden_dim,
                                              device=model.device,
                                              inputs=[1.0, 2.0, 3.0],
                                              targets=[1, 2, 3])
        for n, batch in enumerate(data_loader):
            assert len(batch) % 2 == 0, \
                "multi_output_dataloader failed to return an even number of data samples (input+target)"
            midpoint = len(batch) // 2
            inputs, targets = batch[:midpoint], batch[midpoint:]
            loss_tuple = model(inputs, targets)
            assert len(loss_tuple) == 3

            expected_loss = torch.tensor(2.302734375,
                                         dtype=torch.half,
                                         device=model.device)
            for loss in loss_tuple:
                assert loss.shape == torch.Size([])
                assert loss.item() == approx(expected_loss.item())

            summed_loss = sum(loss_tuple)
            scaled_loss = model.backward(summed_loss)
            expected_scaled_loss = summed_loss / gradient_accumulation_steps
            assert scaled_loss.item() == approx(expected_scaled_loss.item())

            model.step()

    _test_three_output_model(args=args, model=model, hidden_dim=hidden_dim)
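A note on the expected_loss constant used in both tests: with the weight matrix filled with 0.1 and hidden_dim = 10, every logit of a sample is identical, so the softmax is uniform over the 10 classes and each cross-entropy loss equals ln(10) ≈ 2.3026 regardless of the input value or target; the nearest fp16 value is 2.302734375. A quick check (illustrative, not part of the commit):

import math
import torch

# ln(10) is the cross-entropy of a uniform 10-way softmax against any target.
print(math.log(10))                                         # 2.302585092994046
print(torch.tensor(math.log(10), dtype=torch.half).item())  # 2.302734375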