Unverified commit 21c28029, authored by Joe Mayer, committed by GitHub

Adding Gradient Accumulation Data Type Config (#2512)

* Adding gradient accumulation dtype config.

* Switching to new DtypeEnum

* Adding standalone check function, and unit tests

* Variable disambiguation

* Adding checks for unsupported states.

* Updating for PR comments.

* Reorganizing unit test.
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent 90ae6884
......@@ -656,6 +656,10 @@ def get_checkpoint_params(param_dict):
return param_dict.get(CHECKPOINT, {})
def get_data_types_params(param_dict):
return param_dict.get(DATA_TYPES, {})
def get_checkpoint_tag_validation_mode(checkpoint_params):
tag_validation_mode = checkpoint_params.get(CHECKPOINT_TAG_VALIDATION,
CHECKPOINT_TAG_VALIDATION_DEFAULT)
......@@ -905,6 +909,10 @@ class DeepSpeedConfig(object):
USE_NODE_LOCAL_STORAGE_CHECKPOINT,
USE_NODE_LOCAL_STORAGE_CHECKPOINT_DEFAULT)
data_types_params = get_data_types_params(param_dict)
self.grad_accum_dtype = data_types_params.get(GRAD_ACCUM_DTYPE,
GRAD_ACCUM_DTYPE_DEFAULT)
par_write_pipe = get_checkpoint_parallel_write_pipeline(checkpoint_params)
self.checkpoint_parallel_write_pipeline = par_write_pipe
......
......@@ -391,6 +391,18 @@ CHECKPOINT_PARALLEL_WRITE = "parallel_write"
CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE = "pipeline_stage"
CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE_DEFAULT = False
#########################################
# Data types config params
#########################################
# "data_types": {
# grad_accum_dtype=["bf16"|"fp16"|"fp32"]
# }
# }
DATA_TYPES = "data_types"
GRAD_ACCUM_DTYPE = "grad_accum_dtype"
GRAD_ACCUM_DTYPE_DEFAULT = None
#########################################
# Drop the last incomplete Batch
#########################################
......
......@@ -25,6 +25,7 @@ from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
......@@ -37,7 +38,7 @@ from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \
from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
PLD_THETA, PLD_GAMMA, BFLOAT16, FP16
PLD_THETA, PLD_GAMMA, BFLOAT16, FP16, AMP
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.compression import compression_scheduler
from deepspeed.compression.constants import \
......@@ -78,6 +79,8 @@ from ..git_version_info import version
from deepspeed.profiling.flops_profiler.profiler import FlopsProfiler
from deepspeed.utils.logging import print_json_dist
from deepspeed.inference.config import DtypeEnum
# Set to torch's distributed package or deepspeed.comm inside DeepSpeedEngine init, based on the config
dist = None
......@@ -796,6 +799,23 @@ class DeepSpeedEngine(Module):
def aio_config(self):
return self._config.aio_config
def get_data_types(self):
model_dtype = torch.float32
if self.fp16_enabled():
model_dtype = torch.float16
elif self.bfloat16_enabled():
model_dtype = torch.bfloat16
if self._config.grad_accum_dtype is None:
if model_dtype == torch.bfloat16:
grad_accum_dtype = torch.float32
else:
grad_accum_dtype = model_dtype
else:
grad_accum_dtype = DtypeEnum(self._config.grad_accum_dtype).value
return (model_dtype, grad_accum_dtype)
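For reference, the default resolution implemented by `get_data_types` can be summarized in a standalone sketch (`resolve_grad_accum_dtype` is a hypothetical helper used only for illustration, not engine code):

```python
import torch

def resolve_grad_accum_dtype(model_dtype, grad_accum_dtype=None):
    # When "data_types" is omitted (grad_accum_dtype=None), fp16 and fp32 models
    # accumulate gradients in the model dtype, while bf16 models default to fp32.
    if grad_accum_dtype is None:
        return torch.float32 if model_dtype == torch.bfloat16 else model_dtype
    return grad_accum_dtype

assert resolve_grad_accum_dtype(torch.float16) == torch.float16
assert resolve_grad_accum_dtype(torch.bfloat16) == torch.float32
assert resolve_grad_accum_dtype(torch.float32) == torch.float32
```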
def _configure_lr_scheduler(self, client_lr_scheduler):
# First check for scheduler in json configuration
lr_scheduler = self._scheduler_from_config(self.optimizer)
......@@ -1106,6 +1126,61 @@ class DeepSpeedEngine(Module):
])
assert occurrence <= 1, f"Parameter with name: {name} occurs multiple times in optimizer.param_groups. Make sure it only appears once to prevent undefined behaviour."
def _do_optimizer_sanity_check(self, basic_optimizer):
model_dtype, grad_accum_dtype = self.get_data_types()
zero_enabled = self.zero_optimization()
amp_enabled = self.amp_enabled()
# config based assertions
assert (
not (amp_enabled and zero_enabled)
), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2"
if zero_enabled:
if model_dtype != grad_accum_dtype:
raise NotImplementedError(
"Model data type and gradient accumulation data type must be equal to use ZeRO"
)
if not is_zero_supported_optimizer(basic_optimizer):
assert (
self.zero_allow_untested_optimizer()
), 'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'
if self.global_rank == 0:
logger.warning(
"**** You are using ZeRO with an untested optimizer, proceed with caution *****"
)
return ZERO_OPTIMIZATION
elif amp_enabled:
if model_dtype != grad_accum_dtype:
raise NotImplementedError(
"Model data type and gradient accumulation data type must be equal to use Amp"
)
if model_dtype == torch.bfloat16 or model_dtype == torch.float16:
raise NotImplementedError(
"Cannot enable both amp with (legacy) fp16 or bfloat16 mode")
try:
logger.info("Initializing Apex amp from: {}".format(amp.__path__))
except NameError:
# If apex/amp is available it will be imported above
raise RuntimeError(
"Unable to import apex/amp, please make sure it is installed")
return AMP
# data type checks
elif model_dtype == grad_accum_dtype:
if model_dtype == torch.bfloat16:
raise NotImplementedError(
"Bfloat16 wrapper must use a gradient accumulation type of fp32, enable ZeRO to use Bfloat16 gradient accumulation"
)
if model_dtype == torch.float16:
return FP16
# else optimizer_wrapper = None
elif model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32:
return BFLOAT16
else:
raise NotImplementedError(
"unsupported mix of model dtype and gradient accummulation type")
return None
# Configure optimizer
def _configure_optimizer(self, client_optimizer, model_parameters):
if client_optimizer is not None:
......@@ -1135,44 +1210,26 @@ class DeepSpeedEngine(Module):
basic_optimizer.__class__.__name__),
ranks=[0])
if self.zero_optimization():
assert (
not self.amp_enabled()
), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2"
if not is_zero_supported_optimizer(basic_optimizer):
assert (
self.zero_allow_untested_optimizer()
), 'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'
optimizer_wrapper = self._do_optimizer_sanity_check(basic_optimizer)
if self.global_rank == 0:
logger.warning(
"**** You are using ZeRO with an untested optimizer, proceed with caution *****"
)
# The engine's optimizer here is the ZeRO stage 1/2 or stage 3 optimizer, based on the 'stage' config,
# while the ZeRO optimizer itself wraps the original optimizer.
if optimizer_wrapper == ZERO_OPTIMIZATION:
self.optimizer = self._configure_zero_optimizer(basic_optimizer)
elif self.amp_enabled():
assert not (self.fp16_enabled() or self.bfloat16_enabled()), "Cannot enable both amp with (legacy) fp16 or bfloat16 mode"
elif optimizer_wrapper == AMP:
amp_params = self.amp_params()
log_dist(f"Initializing AMP with these params: {amp_params}", ranks=[0])
try:
logger.info("Initializing Apex amp from: {}".format(amp.__path__))
except NameError:
# If apex/amp is available it will be imported above
raise RuntimeError(
"Unable to import apex/amp, please make sure it is installed")
model, self.optimizer = amp.initialize(
self.module, basic_optimizer, **amp_params
)
self._set_client_model(model)
self._broadcast_model()
# TODO: maybe need to broadcast experts differently?
elif self.fp16_enabled():
elif optimizer_wrapper == FP16:
self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
elif self.bfloat16_enabled():
elif optimizer_wrapper == BFLOAT16:
self.optimizer = self._configure_bf16_optimizer(basic_optimizer)
else:
self.optimizer = basic_optimizer
log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer_name()),
ranks=[0])
......
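Taken together, the new `_do_optimizer_sanity_check` turns the (ZeRO, amp, model dtype, gradient accumulation dtype) combination into an optimizer wrapper choice. The following condensed sketch restates that decision logic with plain string labels in place of the engine constants (illustration only; unsupported combinations raise `NotImplementedError`, as above):

```python
import torch

def select_wrapper(zero_enabled, amp_enabled, model_dtype, grad_accum_dtype):
    # Condensed restatement of _do_optimizer_sanity_check; returns a label for the
    # selected wrapper, or None when the basic optimizer is used unwrapped.
    assert not (zero_enabled and amp_enabled), "Amp and ZeRO are not compatible"
    if zero_enabled:
        if model_dtype != grad_accum_dtype:
            raise NotImplementedError("ZeRO requires matching model and grad accum dtypes")
        return "zero"
    if amp_enabled:
        if model_dtype != torch.float32 or grad_accum_dtype != torch.float32:
            raise NotImplementedError("Amp requires an fp32 model and fp32 accumulation")
        return "amp"
    if model_dtype == grad_accum_dtype:
        if model_dtype == torch.bfloat16:
            raise NotImplementedError("bf16/bf16 accumulation needs ZeRO; otherwise use fp32 accumulation")
        return "fp16" if model_dtype == torch.float16 else None
    if model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32:
        return "bf16"
    raise NotImplementedError("unsupported mix of model dtype and gradient accumulation type")
```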
......@@ -1513,3 +1513,18 @@ Different pruning sets, this is used for different pruning parameters. In this e
| Description | Default |
| ------------------------------------------------------------- | ------- |
| Use pipeline stages to parallelize the writing of checkpoints.| `false` |
### Data Type options
```json
"data_types": {
"grad_accum_dtype"=["fp32"|"fp16"|"bf16"]
}
}
```
<i>**grad_accum_dtype**</i>: ["fp32"|"fp16"|"bf16"]
| Description | Default |
| --------------------------------------------------------------------------------------------------------------| ------- |
| Specifies the data type in which gradient accumulation is performed. If `None`, it defaults to the model data type, except for `bf16` models, which accumulate gradients in `fp32`. | `None` |
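As a usage sketch (hypothetical toy model; mirrors the unit test added in this PR and assumes the script is run under a distributed launcher), a bf16 model can be told to accumulate gradients in fp32 without enabling ZeRO:

```python
import torch
import deepspeed

# Any torch.nn.Module works here; a single Linear layer keeps the example small.
model = torch.nn.Linear(10, 10)

ds_config = {
    "train_batch_size": 1,
    "bf16": {"enabled": True},
    # New in this PR: accumulate gradients in fp32 while the model runs in bf16.
    "data_types": {"grad_accum_dtype": "fp32"},
    "optimizer": {"type": "Adam", "params": {"lr": 0.001}},
}

engine, optimizer, _, _ = deepspeed.initialize(config=ds_config,
                                               model=model,
                                               model_parameters=model.parameters())
```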
......@@ -6,7 +6,7 @@ from torch.optim.lr_scheduler import _LRScheduler, LambdaLR
from unit.simple_model import SimpleModel, random_dataloader
from unit.common import DistributedTest
from unit.util import required_torch_version
from unit.util import required_torch_version, bf16_required_version_check, required_amp_check
import deepspeed
from deepspeed.ops.adam import FusedAdam
......@@ -116,6 +116,89 @@ class TestConfigOptimizer(DistributedTest):
assert isinstance(ds_optimizer, FusedAdam)
@pytest.mark.parametrize('optimizer_extension', ['zero', 'amp', None])
@pytest.mark.parametrize('model_dtype', ['fp16', 'bf16', 'fp32'])
@pytest.mark.parametrize('grad_accum_dtype', [None, 'fp16', 'bf16', 'fp32'])
class TestOptimizerImplementation(DistributedTest):
world_size = 1
def test(self, optimizer_extension, model_dtype, grad_accum_dtype):
zero_stage = 1 if optimizer_extension == 'zero' else 0
amp = (optimizer_extension == 'amp')
fp16 = (model_dtype == 'fp16')
bf16 = (model_dtype == 'bf16')
# Skip checks
if bf16 and not bf16_required_version_check():
pytest.skip(
"DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly"
)
if amp and not required_amp_check():
pytest.skip("Amp is not installed can't run amp check")
# Config declaration
ds_config = {
"train_batch_size": 1,
'fp16': {
'enabled': fp16
},
'bf16': {
'enabled': bf16
},
'amp': {
'enabled': amp
},
'zero_optimization': {
"stage": zero_stage
},
"data_types": {
"grad_accum_dtype": grad_accum_dtype
},
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.001
}
}
}
key = (optimizer_extension, model_dtype, grad_accum_dtype)
# Enumerate supported configurations
is_supported = {}
# Zero Wrapper
is_supported[('zero', 'fp16', None)] = True
is_supported[('zero', 'fp16', 'fp16')] = True
is_supported[('zero', 'bf16', 'bf16')] = True
is_supported[('zero', 'fp32', None)] = True
is_supported[('zero', 'fp32', 'fp32')] = True
# Amp Wrapper
is_supported[('amp', 'fp32', None)] = True
is_supported[('amp', 'fp32', 'fp32')] = True
# FP16 Wrapper
is_supported[(None, 'fp16', None)] = True
is_supported[(None, 'fp16', 'fp16')] = True
# BF16 Wrapper
is_supported[(None, 'bf16', 'fp32')] = True
is_supported[(None, 'bf16', None)] = True
# No Wrapper
is_supported[(None, 'fp32', None)] = True
is_supported[(None, 'fp32', 'fp32')] = True
hidden_dim = 10
model = SimpleModel(hidden_dim)
model_parameters = list(model.parameters())
if key in is_supported:
_, ds_optimizer, _, _ = deepspeed.initialize(config=ds_config,
model=model,
model_parameters=model_parameters)
assert ds_optimizer is not None
else:
with pytest.raises(NotImplementedError):
_, ds_optimizer, _, _ = deepspeed.initialize(config=ds_config,
model=model,
model_parameters=model_parameters)
@pytest.mark.parametrize("scheduler_type", [None, _LRScheduler, Callable])
@pytest.mark.parametrize("optimizer_type", [None, Optimizer, Callable])
class TestClientLrScheduler(DistributedTest):
......
......@@ -45,3 +45,11 @@ def required_maximum_torch_version(major_version, minor_version):
return False
return TORCH_MAJOR < major_version or TORCH_MINOR <= minor_version
def required_amp_check():
from importlib.util import find_spec
if find_spec('apex') is None:
return False
else:
return True