Unverified commit f5453124, authored by Jeff Rasley, committed by GitHub

Support amp deepspeed backend (#286)

* add amp support for deepspeed (non-ZeRO)
* tests for amp mode
Parent 4a3234e0
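A minimal usage sketch of the new backend (not part of this commit): it assumes an argparse-style `args` object whose `deepspeed_config` points at a JSON file containing an "amp" section like the ones in the tests added below; the model and optimizer here are placeholders.

import torch
import deepspeed

def build_engine(args):
    # args.deepspeed_config is assumed to reference a config such as:
    #   {"train_batch_size": 1, "amp": {"enabled": true, "opt_level": "O1"}}
    model = torch.nn.Linear(10, 10)                   # placeholder model
    optimizer = torch.optim.Adam(model.parameters())  # client optimizer, as in the new tests
    # With amp enabled (and ZeRO/legacy fp16 off), DeepSpeed wraps the model and
    # optimizer via apex.amp.initialize() and uses amp.scale_loss() in engine.backward().
    engine, optimizer, _, _ = deepspeed.initialize(args=args,
                                                   model=model,
                                                   optimizer=optimizer)
    return engine, optimizer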
@@ -5,6 +5,7 @@ Licensed under the MIT license.
import torch
import json
import copy
from deepspeed.pt.deepspeed_constants import *
from deepspeed.pt.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE
from deepspeed.pt.deepspeed_config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys
@@ -18,6 +19,22 @@ LAMB_OPTIMIZER = 'lamb'
DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER]
def get_amp_enabled(param_dict):
if AMP in param_dict.keys():
return get_scalar_param(param_dict[AMP], AMP_ENABLED, AMP_ENABLED_DEFAULT)
else:
return False
def get_amp_params(param_dict):
if AMP in param_dict.keys():
amp_params = copy.copy(param_dict[AMP])
amp_params.pop(AMP_ENABLED, None)  # "enabled" may be omitted from the amp section
return amp_params
else:
return False
def get_fp16_enabled(param_dict):
if FP16 in param_dict.keys():
return get_scalar_param(param_dict[FP16], FP16_ENABLED, FP16_ENABLED_DEFAULT)
@@ -315,6 +332,8 @@ class DeepSpeedConfig(object):
self.gradient_clipping = get_gradient_clipping(param_dict)
self.fp16_enabled = get_fp16_enabled(param_dict)
self.amp_enabled = get_amp_enabled(param_dict)
self.amp_params = get_amp_params(param_dict)
self.loss_scale = get_loss_scale(param_dict)
self.initial_dynamic_scale = get_initial_dynamic_scale(param_dict)
self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict)
......
@@ -117,6 +117,24 @@ FP16_HYSTERESIS_DEFAULT = 2
FP16_MIN_LOSS_SCALE = "min_loss_scale"
FP16_MIN_LOSS_SCALE_DEFAULT = 1
#########################################
# Apex AMP support
#########################################
# Use Apex AMP for mixed-precision support. All parameters other than 'enabled' are passed to
# amp.initialize(model, optimizer, **amp_params).
# See the apex documentation for supported parameters/features: https://nvidia.github.io/apex/amp.html#apex.amp.initialize
AMP_FORMAT = '''
"amp": {
  "enabled": true,
  "opt_level": "O1",
  ...
}
'''
AMP = "amp"
AMP_ENABLED = "enabled"
AMP_ENABLED_DEFAULT = False
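# Illustrative sketch only (not part of this file): how the "amp" section is consumed.
# It mirrors get_amp_enabled()/get_amp_params() in deepspeed_config.py above: the
# "enabled" key is read by DeepSpeed itself, and every remaining key is forwarded
# verbatim to apex.amp.initialize(). The specific kwargs below ("opt_level",
# "loss_scale") are just examples of valid apex.amp.initialize parameters.
#
#     ds_config = {
#         "train_batch_size": 2,
#         "amp": {
#             "enabled": True,          # AMP_ENABLED, consumed by DeepSpeed
#             "opt_level": "O2",        # forwarded to apex.amp.initialize
#             "loss_scale": "dynamic",  # forwarded to apex.amp.initialize
#         },
#     }
#
#     amp_kwargs = dict(ds_config["amp"])
#     amp_kwargs.pop("enabled")
#     # inside the engine this becomes:
#     #   model, optimizer = amp.initialize(model, optimizer, **amp_kwargs)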
#########################################
# Gradient clipping
#########################################
......
@@ -8,6 +8,7 @@ import warnings
import torch.distributed as dist
from torch.nn.modules import Module
from torch.distributed.distributed_c10d import _get_global_rank
from apex import amp
from tensorboardX import SummaryWriter
@@ -312,6 +313,12 @@ class DeepSpeedLight(Module):
def fp16_enabled(self):
return self._config.fp16_enabled
def amp_enabled(self):
return self._config.amp_enabled
def amp_params(self):
return self._config.amp_params
def loss_scale(self):
return self._config.loss_scale
@@ -449,28 +456,33 @@ class DeepSpeedLight(Module):
assert self.dynamic_loss_scale(), \
'DeepSpeed {} optimizer requires dynamic loss scaling'.format(self.optimizer_name())
def _broadcast_model(self):
    for p in self.module.parameters():
        if torch.is_tensor(p):
            dist.broadcast(p,
                           self.broadcast_src_rank,
                           group=self.data_parallel_group)

def _configure_distributed_model(self, model):
    self.module = model
    if self.fp16_enabled():
        self.module.half()
    self.module.to(self.device)

    if self.mpu is None:
        self.data_parallel_group = _initialize_parameter_parallel_groups()
        self.dp_world_size = dist.get_world_size()
        self.broadcast_src_rank = 0
    else:
        self.data_parallel_group = self.mpu.get_data_parallel_group()
        self.dp_world_size = self.mpu.get_data_parallel_world_size()
        self.broadcast_src_rank = _get_global_rank(
            self.mpu.get_data_parallel_group(),
            0)
        logger.info(f"global src_rank={self.broadcast_src_rank}")

    # With amp enabled, the broadcast happens in _configure_optimizer() instead,
    # after amp.initialize() has wrapped the model and optimizer.
    if not self.amp_enabled():
        self._broadcast_model()
# Configure optimizer
def _configure_optimizer(self, client_optimizer, model_parameters):
@@ -486,6 +498,7 @@ class DeepSpeedLight(Module):
logger.info('DeepSpeed Basic Optimizer = {}'.format(basic_optimizer))
if self.zero_optimization():
assert not self.amp_enabled(), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2"
if self.optimizer_name() != ADAM_OPTIMIZER:
assert self.zero_allow_untested_optimizer(), \
'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'
@@ -494,6 +507,12 @@ class DeepSpeedLight(Module):
"**** You are using ZeRO with an untested optimizer, proceed with caution *****"
)
self.optimizer = self._configure_zero_optimizer(basic_optimizer)
elif self.amp_enabled():
assert not self.fp16_enabled(), "Cannot enable both amp with (legacy) fp16 mode"
amp_params = self.amp_params()
logger.info(f"Initializing AMP with these params: {amp_params}")
self.module, self.optimizer = amp.initialize(self.module, basic_optimizer, **amp_params)
self._broadcast_model()
elif self.fp16_enabled():
self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
else:
@@ -748,12 +767,11 @@ class DeepSpeedLight(Module):
if self.zero_optimization():
    self.optimizer.backward(loss)
elif self.amp_enabled():
    # Apex AMP owns loss scaling; scale the loss before the backward pass.
    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
        scaled_loss.backward()
elif self.fp16_enabled():
    self.optimizer.backward(loss)
else:
    loss.backward()
......
@@ -395,3 +395,103 @@ def test_zero_empty_partition(tmpdir, zero_stage):
model.step()
_test_zero_empty_partition(args)
def test_adam_amp_basic(tmpdir):
config_dict = {"train_batch_size": 1, "steps_per_print": 1, "amp": {"enabled": True}}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[1])
def _test_adam_amp_basic(args, model, hidden_dim):
optimizer = torch.optim.Adam(params=model.parameters())
model, _, _, _ = deepspeed.initialize(args=args,
model=model,
optimizer=optimizer)
data_loader = random_dataloader(model=model,
total_samples=50,
hidden_dim=hidden_dim,
device=model.device)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
_test_adam_amp_basic(args=args, model=model, hidden_dim=hidden_dim)
def test_lamb_amp_basic(tmpdir):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Lamb",
"params": {
"lr": 0.00015
}
},
"gradient_clipping": 1.0,
"amp": {
"enabled": True,
}
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[1, 2])
def _test_lamb_amp_basic(args, model, hidden_dim):
model, _, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
data_loader = random_dataloader(model=model,
total_samples=50,
hidden_dim=hidden_dim,
device=model.device)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
_test_lamb_amp_basic(args=args, model=model, hidden_dim=hidden_dim)
def test_adam_amp_o2(tmpdir):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"gradient_clipping": 1.0,
"amp": {
"enabled": True,
"opt_level": "O2"
}
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[1, 2])
def _test_adam_amp_o2(args, model, hidden_dim):
model, _, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
data_loader = random_dataloader(model=model,
total_samples=50,
hidden_dim=hidden_dim,
device=model.device)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
_test_adam_amp_o2(args=args, model=model, hidden_dim=hidden_dim)